Skip to content

Commit a996cce

Browse files
committed
fix for dtypes
1 parent 63bf3e9 commit a996cce

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def __call__(
154154

155155
if cache_x is not None:
156156
x_concat = jnp.concatenate([cache_x, x], axis=1)
157-
new_cache = x_concat[:, -CACHE_T:, ...]
157+
new_cache = x_concat[:, -CACHE_T:, ...].astype(cache_x.dtype)
158158

159159
padding_needed = self._depth_padding_before - cache_x.shape[1]
160160
if padding_needed < 0:
@@ -415,7 +415,7 @@ def __call__(
415415
prev_cache = cache.get("time_conv")
416416
x_combined = jnp.concatenate([prev_cache, x], axis=1)
417417
x, _ = self.time_conv(x_combined, cache_x=None)
418-
new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...]
418+
new_cache["time_conv"] = x_combined[:, -CACHE_T:, ...].astype(prev_cache.dtype)
419419

420420
else:
421421
if hasattr(self, "resample"):

0 commit comments

Comments
 (0)