File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
src/maxdiffusion/models/wan Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -154,7 +154,7 @@ def __call__(
154154
155155 if cache_x is not None :
156156 x_concat = jnp .concatenate ([cache_x , x ], axis = 1 )
157- new_cache = x_concat [:, - CACHE_T :, ...]
157+ new_cache = x_concat [:, - CACHE_T :, ...]. astype ( cache_x . dtype )
158158
159159 padding_needed = self ._depth_padding_before - cache_x .shape [1 ]
160160 if padding_needed < 0 :
@@ -415,7 +415,7 @@ def __call__(
415415 prev_cache = cache .get ("time_conv" )
416416 x_combined = jnp .concatenate ([prev_cache , x ], axis = 1 )
417417 x , _ = self .time_conv (x_combined , cache_x = None )
418- new_cache ["time_conv" ] = x_combined [:, - CACHE_T :, ...]
418+ new_cache ["time_conv" ] = x_combined [:, - CACHE_T :, ...]. astype ( prev_cache . dtype )
419419
420420 else :
421421 if hasattr (self , "resample" ):
You can’t perform that action at this time.
0 commit comments