@@ -677,6 +677,8 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
677677 vector<tuple<Type, CastFuncTy>> test_params = {
678678 {Int (8 ), in_i8}, {Int (16 ), in_i16}, {Int (32 ), in_i32}, {Int (64 ), in_i64}, {UInt (8 ), in_u8}, {UInt (16 ), in_u16}, {UInt (32 ), in_u32}, {UInt (64 ), in_u64}, {Float (16 ), in_f16}, {Float (32 ), in_f32}, {Float (64 ), in_f64}};
679679
680+ const int base_vec_bits = has_sve () ? target.vector_bits : 128 ;
681+
680682 for (const auto &[elt, in_im] : test_params) {
681683 const int bits = elt.bits ();
682684 if ((elt == Float (16 ) && !is_float16_supported ()) ||
@@ -712,40 +714,55 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
712714 }
713715 }
714716
715- // LD2/ST2 - Load/Store two-element structures
716- int base_vec_bits = has_sve () ? target.vector_bits : 128 ;
717- for (int width = base_vec_bits; width <= base_vec_bits * 4 ; width *= 2 ) {
717+ // LDn - Structured Load strided elements
718+ for (int stride = 2 ; stride <= 4 ; ++stride) {
719+
720+ for (int factor = 1 ; factor <= 4 ; factor *= 2 ) {
721+ const int vector_lanes = base_vec_bits * factor / bits;
722+
723+ // In StageStridedLoads.cp (stride < r->lanes) is the condition for staging to happen
724+ // See https://github.com/halide/Halide/issues/8819
725+ if (vector_lanes <= stride) continue ;
726+
727+ AddTestFunctor add_ldn (*this , bits, vector_lanes);
728+
729+ Expr load_n = in_im (x * stride) + in_im (x * stride + stride - 1 );
730+
731+ const string ldn_str = " ld" + to_string (stride);
732+ if (has_sve ()) {
733+ add_ldn ({get_sve_ls_instr (ldn_str, bits)}, vector_lanes, load_n);
734+ } else {
735+ add_ldn (sel_op (" v" + ldn_str + " ." , ldn_str), load_n);
736+ }
737+ }
738+ }
739+
740+ // ST2 - Store two-element structures
741+ for (int width = base_vec_bits * 2 ; width <= base_vec_bits * 8 ; width *= 2 ) {
718742 const int total_lanes = width / bits;
719743 const int vector_lanes = total_lanes / 2 ;
720744 const int instr_lanes = min (vector_lanes, base_vec_bits / bits);
721745 if (instr_lanes < 2 ) continue ; // bail out scalar op
722746
723- AddTestFunctor add_ldn (*this , bits, vector_lanes);
724747 AddTestFunctor add_stn (*this , bits, instr_lanes, total_lanes);
725748
726749 Func tmp1, tmp2;
727750 tmp1 (x) = cast (elt, x);
728751 tmp1.compute_root ();
729752 tmp2 (x, y) = select (x % 2 == 0 , tmp1 (x / 2 ), tmp1 (x / 2 + 16 ));
730753 tmp2.compute_root ().vectorize (x, total_lanes);
731- Expr load_2 = in_im (x * 2 ) + in_im (x * 2 + 1 );
732754 Expr store_2 = tmp2 (0 , 0 ) + tmp2 (0 , 127 );
733755
734756 if (has_sve ()) {
735- // TODO(inssue needed): Added strided load support.
736- #if 0
737- add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2);
738- #endif
739757 add_stn ({get_sve_ls_instr (" st2" , bits)}, total_lanes, store_2);
740758 } else {
741- add_ldn (sel_op (" vld2." , " ld2" ), load_2);
742759 add_stn (sel_op (" vst2." , " st2" ), store_2);
743760 }
744761 }
745762
746763 // Also check when the two expressions interleaved have a common
747764 // subexpression, which results in a vector var being lifted out.
748- for (int width = base_vec_bits; width <= base_vec_bits * 4 ; width *= 2 ) {
765+ for (int width = base_vec_bits * 2 ; width <= base_vec_bits * 4 ; width *= 2 ) {
749766 const int total_lanes = width / bits;
750767 const int vector_lanes = total_lanes / 2 ;
751768 const int instr_lanes = Instruction::get_instr_lanes (bits, vector_lanes, target);
@@ -768,14 +785,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
768785 }
769786 }
770787
771- // LD3/ ST3 - Store three-element structures
772- for (int width = 192 ; width <= 192 * 4 ; width *= 2 ) {
788+ // ST3 - Store three-element structures
789+ for (int width = base_vec_bits * 3 ; width <= base_vec_bits * 3 * 2 ; width *= 2 ) {
773790 const int total_lanes = width / bits;
774791 const int vector_lanes = total_lanes / 3 ;
775792 const int instr_lanes = Instruction::get_instr_lanes (bits, vector_lanes, target);
776793 if (instr_lanes < 2 ) continue ; // bail out scalar op
777794
778- AddTestFunctor add_ldn (*this , bits, vector_lanes);
779795 AddTestFunctor add_stn (*this , bits, instr_lanes, total_lanes);
780796
781797 Func tmp1, tmp2;
@@ -785,29 +801,22 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
785801 x % 3 == 1 , tmp1 (x / 3 + 16 ),
786802 tmp1 (x / 3 + 32 ));
787803 tmp2.compute_root ().vectorize (x, total_lanes);
788- Expr load_3 = in_im (x * 3 ) + in_im (x * 3 + 1 ) + in_im (x * 3 + 2 );
789804 Expr store_3 = tmp2 (0 , 0 ) + tmp2 (0 , 127 );
790805
791806 if (has_sve ()) {
792- // TODO(issue needed): Added strided load support.
793- #if 0
794- add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3);
795807 add_stn ({get_sve_ls_instr (" st3" , bits)}, total_lanes, store_3);
796- #endif
797808 } else {
798- add_ldn (sel_op (" vld3." , " ld3" ), load_3);
799809 add_stn (sel_op (" vst3." , " st3" ), store_3);
800810 }
801811 }
802812
803- // LD4/ ST4 - Store four-element structures
804- for (int width = 256 ; width <= 256 * 4 ; width *= 2 ) {
813+ // ST4 - Store four-element structures
814+ for (int width = base_vec_bits * 4 ; width <= base_vec_bits * 4 * 2 ; width *= 2 ) {
805815 const int total_lanes = width / bits;
806816 const int vector_lanes = total_lanes / 4 ;
807817 const int instr_lanes = Instruction::get_instr_lanes (bits, vector_lanes, target);
808818 if (instr_lanes < 2 ) continue ; // bail out scalar op
809819
810- AddTestFunctor add_ldn (*this , bits, vector_lanes);
811820 AddTestFunctor add_stn (*this , bits, instr_lanes, total_lanes);
812821
813822 Func tmp1, tmp2;
@@ -818,17 +827,11 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
818827 x % 4 == 2 , tmp1 (x / 4 + 32 ),
819828 tmp1 (x / 4 + 48 ));
820829 tmp2.compute_root ().vectorize (x, total_lanes);
821- Expr load_4 = in_im (x * 4 ) + in_im (x * 4 + 1 ) + in_im (x * 4 + 2 ) + in_im (x * 4 + 3 );
822830 Expr store_4 = tmp2 (0 , 0 ) + tmp2 (0 , 127 );
823831
824832 if (has_sve ()) {
825- // TODO(issue needed): Added strided load support.
826- #if 0
827- add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4);
828833 add_stn ({get_sve_ls_instr (" st4" , bits)}, total_lanes, store_4);
829- #endif
830834 } else {
831- add_ldn (sel_op (" vld4." , " ld4" ), load_4);
832835 add_stn (sel_op (" vst4." , " st4" ), store_4);
833836 }
834837 }
@@ -1295,6 +1298,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
12951298
12961299 auto ext = Internal::get_output_info (target);
12971300 std::map<OutputFileType, std::string> outputs = {
1301+ {OutputFileType::stmt, file_name + ext.at (OutputFileType::stmt).extension },
12981302 {OutputFileType::llvm_assembly, file_name + ext.at (OutputFileType::llvm_assembly).extension },
12991303 {OutputFileType::c_header, file_name + ext.at (OutputFileType::c_header).extension },
13001304 {OutputFileType::object, file_name + ext.at (OutputFileType::object).extension },
0 commit comments