
Commit e54a3d8

Support strided Load/Store in SVE2
Structured loads now rely on shuffle_vectors for scalable vectors, which are lowered to llvm.vector.deinterleave.
1 parent b993d90 commit e54a3d8
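
This change drops the hand-rolled llvm.aarch64.sve.ldN path from CodeGen_ARM.cpp and lets strided (structured) loads go through the existing shuffle_vectors lowering, which LLVM turns into llvm.vector.deinterleave for scalable vectors. The sketch below is illustrative only and not part of this commit: it mirrors the stride-2 access pattern used in the updated test and shows the kind of pipeline that should now hit the LD2 path on an SVE2 target; the function name, schedule, and target string are assumptions.

    // Illustrative sketch, not from this commit: a stride-2 access pattern
    // whose vectorized loads are expressed with shuffle_vectors and, for
    // scalable vectors, lowered via llvm.vector.deinterleave (ld2 on SVE2).
    #include "Halide.h"
    using namespace Halide;

    int main() {
        ImageParam in(Int(16), 1);
        Func f("f_ld2");
        Var x("x");
        // Each output element reads an adjacent pair at stride 2.
        f(x) = in(x * 2) + in(x * 2 + 1);
        f.vectorize(x, 8);
        // Assumed target string; any AArch64 target with the sve2 feature applies.
        f.compile_to_llvm_assembly("f_ld2.ll", {in}, "f_ld2",
                                   Target("arm-64-linux-sve2-vector_bits_128"));
        return 0;
    }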

2 files changed: 42 additions & 115 deletions

src/CodeGen_ARM.cpp

Lines changed: 9 additions & 86 deletions
@@ -1452,17 +1452,13 @@ void CodeGen_ARM::visit(const Store *op) {
             is_float16_and_has_feature(elt) ||
             elt == Int(8) || elt == Int(16) || elt == Int(32) || elt == Int(64) ||
             elt == UInt(8) || elt == UInt(16) || elt == UInt(32) || elt == UInt(64)) {
-            // TODO(zvookin): Handle vector_bits_*.
+            const int target_vector_bits = native_vector_bits();
             if (vec_bits % 128 == 0) {
                 type_ok_for_vst = true;
-                int target_vector_bits = native_vector_bits();
-                if (target_vector_bits == 0) {
-                    target_vector_bits = 128;
-                }
                 intrin_type = intrin_type.with_lanes(target_vector_bits / t.bits());
             } else if (vec_bits % 64 == 0) {
                 type_ok_for_vst = true;
-                auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? 128 : 64;
+                auto intrin_bits = (vec_bits % 128 == 0 || target.has_feature(Target::SVE2)) ? target_vector_bits : 64;
                 intrin_type = intrin_type.with_lanes(intrin_bits / t.bits());
             }
         }
@@ -1490,7 +1486,6 @@ void CodeGen_ARM::visit(const Store *op) {
         for (int i = 0; i < num_vecs; ++i) {
             args[i] = codegen(shuffle->vectors[i]);
         }
-        Value *store_pred_val = codegen(op->predicate);

         bool is_sve = target.has_feature(Target::SVE2);

@@ -1536,8 +1531,8 @@ void CodeGen_ARM::visit(const Store *op) {
         llvm::FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
         internal_assert(fn);

-        // SVE2 supports predication for smaller than whole vector size.
-        internal_assert(target.has_feature(Target::SVE2) || (t.lanes() >= intrin_type.lanes()));
+        // Scalable vector supports predication for smaller than whole vector size.
+        internal_assert(target_vscale() > 0 || (t.lanes() >= intrin_type.lanes()));

         for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) {
             Expr slice_base = simplify(ramp->base + i * num_vecs);
@@ -1557,16 +1552,12 @@ void CodeGen_ARM::visit(const Store *op) {
                 // Set the alignment argument
                 slice_args.push_back(ConstantInt::get(i32_t, alignment));
             } else {
-                if (is_sve) {
-                    // Set the predicate argument
+                // TODO: we could handle is_predicated_store==true once shuffle_vector gets robust for scalable vectors
+                if (is_sve && !is_predicated_store) {
+                    // Set the predicate argument to mask active lanes
                     auto active_lanes = std::min(t.lanes() - i, intrin_type.lanes());
-                    Value *vpred_val;
-                    if (is_predicated_store) {
-                        vpred_val = slice_vector(store_pred_val, i, intrin_type.lanes());
-                    } else {
-                        Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
-                        vpred_val = codegen(vpred);
-                    }
+                    Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
+                    Value *vpred_val = codegen(vpred);
                     slice_args.push_back(vpred_val);
                 }
                 // Set the pointer argument
@@ -1787,74 +1778,6 @@ void CodeGen_ARM::visit(const Load *op) {
             CodeGen_Posix::visit(op);
             return;
         }
-    } else if (stride && (2 <= stride->value && stride->value <= 4)) {
-        // Structured load ST2/ST3/ST4 of SVE
-
-        Expr base = ramp->base;
-        ModulusRemainder align = op->alignment;
-
-        int aligned_stride = gcd(stride->value, align.modulus);
-        int offset = 0;
-        if (aligned_stride == stride->value) {
-            offset = mod_imp((int)align.remainder, aligned_stride);
-        } else {
-            const Add *add = base.as<Add>();
-            if (const IntImm *add_c = add ? add->b.as<IntImm>() : base.as<IntImm>()) {
-                offset = mod_imp(add_c->value, stride->value);
-            }
-        }
-
-        if (offset) {
-            base = simplify(base - offset);
-        }
-
-        Value *load_pred_val = codegen(op->predicate);
-
-        // We need to slice the result in to native vector lanes to use sve intrin.
-        // LLVM will optimize redundant ld instructions afterwards
-        const int slice_lanes = target.natural_vector_size(op->type);
-        vector<Value *> results;
-        for (int i = 0; i < op->type.lanes(); i += slice_lanes) {
-            int load_base_i = i * stride->value;
-            Expr slice_base = simplify(base + load_base_i);
-            Expr slice_index = Ramp::make(slice_base, stride, slice_lanes);
-            std::ostringstream instr;
-            instr << "llvm.aarch64.sve.ld"
-                  << stride->value
-                  << ".sret.nxv"
-                  << slice_lanes
-                  << (op->type.is_float() ? 'f' : 'i')
-                  << op->type.bits();
-            llvm::Type *elt = llvm_type_of(op->type.element_of());
-            llvm::Type *slice_type = get_vector_type(elt, slice_lanes);
-            StructType *sret_type = StructType::get(module->getContext(), std::vector(stride->value, slice_type));
-            std::vector<llvm::Type *> arg_types{get_vector_type(i1_t, slice_lanes), ptr_t};
-            llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
-            FunctionCallee fn = module->getOrInsertFunction(instr.str(), fn_type);
-
-            // Set the predicate argument
-            int active_lanes = std::min(op->type.lanes() - i, slice_lanes);
-
-            Expr vpred = make_vector_predicate_1s_0s(active_lanes, slice_lanes - active_lanes);
-            Value *vpred_val = codegen(vpred);
-            vpred_val = convert_fixed_or_scalable_vector_type(vpred_val, get_vector_type(vpred_val->getType()->getScalarType(), slice_lanes));
-            if (is_predicated_load) {
-                Value *sliced_load_vpred_val = slice_vector(load_pred_val, i, slice_lanes);
-                vpred_val = builder->CreateAnd(vpred_val, sliced_load_vpred_val);
-            }
-
-            Value *elt_ptr = codegen_buffer_pointer(op->name, op->type.element_of(), slice_base);
-            CallInst *load_i = builder->CreateCall(fn, {vpred_val, elt_ptr});
-            add_tbaa_metadata(load_i, op->name, slice_index);
-            // extract one element out of returned struct
-            Value *extracted = builder->CreateExtractValue(load_i, offset);
-            results.push_back(extracted);
-        }
-
-        // Retrieve original lanes
-        value = concat_vectors(results);
-        value = slice_vector(value, 0, op->type.lanes());
-        return;
     } else if (op->index.type().is_vector()) {
         // General Gather Load

test/correctness/simd_op_check_sve2.cpp

Lines changed: 33 additions & 29 deletions
@@ -677,6 +677,8 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
         vector<tuple<Type, CastFuncTy>> test_params = {
             {Int(8), in_i8}, {Int(16), in_i16}, {Int(32), in_i32}, {Int(64), in_i64}, {UInt(8), in_u8}, {UInt(16), in_u16}, {UInt(32), in_u32}, {UInt(64), in_u64}, {Float(16), in_f16}, {Float(32), in_f32}, {Float(64), in_f64}};

+        const int base_vec_bits = has_sve() ? target.vector_bits : 128;
+
         for (const auto &[elt, in_im] : test_params) {
             const int bits = elt.bits();
             if ((elt == Float(16) && !is_float16_supported()) ||
@@ -712,40 +714,55 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
                 }
             }

-            // LD2/ST2 - Load/Store two-element structures
-            int base_vec_bits = has_sve() ? target.vector_bits : 128;
-            for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
+            // LDn - Structured Load strided elements
+            for (int stride = 2; stride <= 4; ++stride) {
+
+                for (int factor = 1; factor <= 4; factor *= 2) {
+                    const int vector_lanes = base_vec_bits * factor / bits;
+
+                    // In StageStridedLoads.cp (stride < r->lanes) is the condition for staging to happen
+                    // See https://github.com/halide/Halide/issues/8819
+                    if (vector_lanes <= stride) continue;
+
+                    AddTestFunctor add_ldn(*this, bits, vector_lanes);
+
+                    Expr load_n = in_im(x * stride) + in_im(x * stride + stride - 1);
+
+                    const string ldn_str = "ld" + to_string(stride);
+                    if (has_sve()) {
+                        add_ldn({get_sve_ls_instr(ldn_str, bits)}, vector_lanes, load_n);
+                    } else {
+                        add_ldn(sel_op("v" + ldn_str + ".", ldn_str), load_n);
+                    }
+                }
+            }
+
+            // ST2 - Store two-element structures
+            for (int width = base_vec_bits * 2; width <= base_vec_bits * 8; width *= 2) {
                 const int total_lanes = width / bits;
                 const int vector_lanes = total_lanes / 2;
                 const int instr_lanes = min(vector_lanes, base_vec_bits / bits);
                 if (instr_lanes < 2) continue;  // bail out scalar op

-                AddTestFunctor add_ldn(*this, bits, vector_lanes);
                 AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

                 Func tmp1, tmp2;
                 tmp1(x) = cast(elt, x);
                 tmp1.compute_root();
                 tmp2(x, y) = select(x % 2 == 0, tmp1(x / 2), tmp1(x / 2 + 16));
                 tmp2.compute_root().vectorize(x, total_lanes);
-                Expr load_2 = in_im(x * 2) + in_im(x * 2 + 1);
                 Expr store_2 = tmp2(0, 0) + tmp2(0, 127);

                 if (has_sve()) {
-                    // TODO(inssue needed): Added strided load support.
-#if 0
-                    add_ldn({get_sve_ls_instr("ld2", bits)}, vector_lanes, load_2);
-#endif
                     add_stn({get_sve_ls_instr("st2", bits)}, total_lanes, store_2);
                 } else {
-                    add_ldn(sel_op("vld2.", "ld2"), load_2);
                     add_stn(sel_op("vst2.", "st2"), store_2);
                 }
             }

             // Also check when the two expressions interleaved have a common
             // subexpression, which results in a vector var being lifted out.
-            for (int width = base_vec_bits; width <= base_vec_bits * 4; width *= 2) {
+            for (int width = base_vec_bits * 2; width <= base_vec_bits * 4; width *= 2) {
                 const int total_lanes = width / bits;
                 const int vector_lanes = total_lanes / 2;
                 const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
@@ -768,14 +785,13 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
                 }
             }

-            // LD3/ST3 - Store three-element structures
-            for (int width = 192; width <= 192 * 4; width *= 2) {
+            // ST3 - Store three-element structures
+            for (int width = base_vec_bits * 3; width <= base_vec_bits * 3 * 2; width *= 2) {
                 const int total_lanes = width / bits;
                 const int vector_lanes = total_lanes / 3;
                 const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
                 if (instr_lanes < 2) continue;  // bail out scalar op

-                AddTestFunctor add_ldn(*this, bits, vector_lanes);
                 AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

                 Func tmp1, tmp2;
@@ -785,29 +801,22 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
                                     x % 3 == 1, tmp1(x / 3 + 16),
                                     tmp1(x / 3 + 32));
                 tmp2.compute_root().vectorize(x, total_lanes);
-                Expr load_3 = in_im(x * 3) + in_im(x * 3 + 1) + in_im(x * 3 + 2);
                 Expr store_3 = tmp2(0, 0) + tmp2(0, 127);

                 if (has_sve()) {
-                    // TODO(issue needed): Added strided load support.
-#if 0
-                    add_ldn({get_sve_ls_instr("ld3", bits)}, vector_lanes, load_3);
                     add_stn({get_sve_ls_instr("st3", bits)}, total_lanes, store_3);
-#endif
                 } else {
-                    add_ldn(sel_op("vld3.", "ld3"), load_3);
                     add_stn(sel_op("vst3.", "st3"), store_3);
                 }
             }

-            // LD4/ST4 - Store four-element structures
-            for (int width = 256; width <= 256 * 4; width *= 2) {
+            // ST4 - Store four-element structures
+            for (int width = base_vec_bits * 4; width <= base_vec_bits * 4 * 2; width *= 2) {
                 const int total_lanes = width / bits;
                 const int vector_lanes = total_lanes / 4;
                 const int instr_lanes = Instruction::get_instr_lanes(bits, vector_lanes, target);
                 if (instr_lanes < 2) continue;  // bail out scalar op

-                AddTestFunctor add_ldn(*this, bits, vector_lanes);
                 AddTestFunctor add_stn(*this, bits, instr_lanes, total_lanes);

                 Func tmp1, tmp2;
@@ -818,17 +827,11 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
                                     x % 4 == 2, tmp1(x / 4 + 32),
                                     tmp1(x / 4 + 48));
                 tmp2.compute_root().vectorize(x, total_lanes);
-                Expr load_4 = in_im(x * 4) + in_im(x * 4 + 1) + in_im(x * 4 + 2) + in_im(x * 4 + 3);
                 Expr store_4 = tmp2(0, 0) + tmp2(0, 127);

                 if (has_sve()) {
-                    // TODO(issue needed): Added strided load support.
-#if 0
-                    add_ldn({get_sve_ls_instr("ld4", bits)}, vector_lanes, load_4);
                     add_stn({get_sve_ls_instr("st4", bits)}, total_lanes, store_4);
-#endif
                 } else {
-                    add_ldn(sel_op("vld4.", "ld4"), load_4);
                     add_stn(sel_op("vst4.", "st4"), store_4);
                 }
             }
@@ -1295,6 +1298,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {

     auto ext = Internal::get_output_info(target);
     std::map<OutputFileType, std::string> outputs = {
+        {OutputFileType::stmt, file_name + ext.at(OutputFileType::stmt).extension},
         {OutputFileType::llvm_assembly, file_name + ext.at(OutputFileType::llvm_assembly).extension},
         {OutputFileType::c_header, file_name + ext.at(OutputFileType::c_header).extension},
         {OutputFileType::object, file_name + ext.at(OutputFileType::object).extension},