@@ -571,7 +571,6 @@ def preprocess_multipole_nexprs(self,
571571 def loopy_preprocess_multipole (self ,
572572 tgt_expansion : LocalExpansionBase ,
573573 src_expansion : MultipoleExpansionBase ,
574- # FIXME: why is result_dtype unused here?
575574 result_dtype : DTypeLike ,
576575 ) -> tuple [lp .TranslationUnit , Sequence [OptimizationCallable ]]:
577576 assert isinstance (tgt_expansion , VolumeTaylorLocalExpansionBase )
@@ -659,7 +658,7 @@ def loopy_preprocess_multipole(self,
659658 knl = lp .make_function (domains , insns ,
660659 kernel_data = [
661660 lp .ValueArg ("src_rscale" , None ),
662- lp .GlobalArg ("output_coeffs" , None , shape = ncoeff_preprocessed ,
661+ lp .GlobalArg ("output_coeffs" , result_dtype , shape = ncoeff_preprocessed ,
663662 is_input = False , is_output = True ),
664663 lp .GlobalArg ("input_coeffs" , None , shape = ncoeff_src ),
665664 ...],
@@ -805,7 +804,7 @@ def result_func(x: ArithmeticExpression) -> ArithmeticExpression:
805804 kernel_data = [
806805 lp .ValueArg ("src_rscale" , None ),
807806 lp .ValueArg ("tgt_rscale" , None ),
808- lp .GlobalArg ("output_coeffs" , None ,
807+ lp .GlobalArg ("output_coeffs" , result_dtype ,
809808 shape = ncoeff_tgt , is_input = False ,
810809 is_output = True ),
811810 lp .GlobalArg ("input_coeffs" , None ,
0 commit comments