Skip to content

Commit e4fa17a

Browse files
committed
M2L preprocess: respect result_dtype
1 parent bf6dd8e commit e4fa17a

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

sumpy/expansion/m2l.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,6 @@ def preprocess_multipole_nexprs(self,
571571
def loopy_preprocess_multipole(self,
572572
tgt_expansion: LocalExpansionBase,
573573
src_expansion: MultipoleExpansionBase,
574-
# FIXME: why is result_dtype unused here?
575574
result_dtype: DTypeLike,
576575
) -> tuple[lp.TranslationUnit, Sequence[OptimizationCallable]]:
577576
assert isinstance(tgt_expansion, VolumeTaylorLocalExpansionBase)
@@ -659,7 +658,7 @@ def loopy_preprocess_multipole(self,
659658
knl = lp.make_function(domains, insns,
660659
kernel_data=[
661660
lp.ValueArg("src_rscale", None),
662-
lp.GlobalArg("output_coeffs", None, shape=ncoeff_preprocessed,
661+
lp.GlobalArg("output_coeffs", result_dtype, shape=ncoeff_preprocessed,
663662
is_input=False, is_output=True),
664663
lp.GlobalArg("input_coeffs", None, shape=ncoeff_src),
665664
...],
@@ -805,7 +804,7 @@ def result_func(x: ArithmeticExpression) -> ArithmeticExpression:
805804
kernel_data=[
806805
lp.ValueArg("src_rscale", None),
807806
lp.ValueArg("tgt_rscale", None),
808-
lp.GlobalArg("output_coeffs", None,
807+
lp.GlobalArg("output_coeffs", result_dtype,
809808
shape=ncoeff_tgt, is_input=False,
810809
is_output=True),
811810
lp.GlobalArg("input_coeffs", None,

0 commit comments

Comments
 (0)