3 changes: 3 additions & 0 deletions README.md
@@ -254,6 +254,9 @@ After installation completes, run the training script.
- In Wan2.1, the ici_fsdp_parallelism axis is used for sequence parallelism and the ici_tensor_parallelism axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 attention heads, and 40 must be evenly divisible by ici_tensor_parallelism.
- For sequence parallelism, the code pads the sequence length so that it divides evenly across the axis. Try different ici_fsdp_parallelism values; we currently find 2 and 4 to work best.
- On GPU, it is recommended to enable the cudnn_flash_te attention kernel for optimal performance.
- Best performance is achieved with batch parallelism, which can be enabled via the ici_fsdp_batch_parallelism axis. Note that this parallelism strategy does not support fractional batch sizes.
- ici_fsdp_batch_parallelism and ici_fsdp_parallelism can be combined to allow fractional batch sizes. However, padding is not currently supported for the cudnn_flash_te attention kernel, so the sequence length must be divisible by the number of devices in the ici_fsdp_parallelism axis (see the sketch below).
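As a quick illustration of these constraints, here is a minimal sketch (not part of the repository; all numbers are hypothetical placeholders for your own config values):

```python
# Illustrative sanity check for the parallelism rules above; all values are hypothetical.
num_heads = 40                  # Wan2.1 attention heads
seq_len = 75600                 # example latent sequence length
global_batch = 8                # example global batch size
ici_tensor_parallelism = 4      # head parallelism
ici_fsdp_batch_parallelism = 2  # batch parallelism
ici_fsdp_parallelism = 4        # sequence parallelism

# Heads must shard evenly across the tensor axis.
assert num_heads % ici_tensor_parallelism == 0
# Batch parallelism alone does not support fractional batch sizes.
assert global_batch % ici_fsdp_batch_parallelism == 0
# With cudnn_flash_te, sequence padding is unsupported, so the sequence
# length must divide evenly across the fsdp (sequence-parallel) axis.
assert seq_len % ici_fsdp_parallelism == 0
```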

You should eventually see a training run as:

16 changes: 9 additions & 7 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -148,7 +148,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu'
skip_jax_distributed_system: False

# Parallelism
mesh_axes: ['data', 'fsdp', 'tensor']
mesh_axes: ['data', 'fsdp_batch', 'fsdp', 'tensor']

# batch : batch dimension of data and activations
# hidden :
@@ -163,30 +163,32 @@ mesh_axes: ['data', 'fsdp', 'tensor']
# conv_in : conv.shape[2] weight
# conv_out : conv.shape[-1] weight
logical_axis_rules: [
['batch', 'data'],
['activation_batch', 'data'],
['batch', ['data', 'fsdp_batch']],
['activation_batch', ['data', 'fsdp_batch']],
['activation_length', 'fsdp'],
['activation_self_attn_heads', ['fsdp', 'tensor']],
['activation_cross_attn_q_length', ['fsdp', 'tensor']],
['activation_length', 'fsdp'],
['activation_heads', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['embed', ['fsdp', 'fsdp_batch']],
['heads', 'tensor'],
['norm', 'tensor'],
['conv_batch', ['data','fsdp']],
['conv_batch', ['data', 'fsdp', 'fsdp_batch']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
]
data_sharding: [['data', 'fsdp', 'tensor']]
data_sharding: [['data', 'fsdp_batch', 'fsdp', 'tensor']]

# One axis for each parallelism type may hold a placeholder (-1)
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
dcn_fsdp_batch_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_batch_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_tensor_parallelism: 1
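To illustrate the convention described in the comments above, the following sketch (not repository code; values are hypothetical and assume a single slice with 8 devices) shows how a -1 placeholder is resolved so the ICI axes multiply out to the device count:

```python
# Hypothetical ICI settings with one -1 placeholder to auto-fill.
import math

devices_per_slice = 8
ici_axes = {"data": 1, "fsdp_batch": 2, "fsdp": -1, "tensor": 1}

known = math.prod(v for v in ici_axes.values() if v != -1)
assert devices_per_slice % known == 0, "remaining devices must divide evenly"
ici_axes = {k: (devices_per_slice // known if v == -1 else v) for k, v in ici_axes.items()}
assert math.prod(ici_axes.values()) == devices_per_slice
print(ici_axes)  # {'data': 1, 'fsdp_batch': 2, 'fsdp': 4, 'tensor': 1}
```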

5 changes: 4 additions & 1 deletion src/maxdiffusion/generate_wan.py
@@ -209,7 +209,10 @@ def run(config, pipeline=None, filename_prefix=""):

def main(argv: Sequence[str]) -> None:
pyconfig.initialize(argv)
flax.config.update("flax_always_shard_variable", False)
try:
flax.config.update("flax_always_shard_variable", False)
except Exception:
# The flax_always_shard_variable flag is not available in every flax version.
pass
run(pyconfig.config)


35 changes: 24 additions & 11 deletions src/maxdiffusion/max_utils.py
@@ -268,17 +268,30 @@ def create_device_mesh(config, devices=None, logging=True):
max_logging.log(f"Devices: {devices} (num_devices: {num_devices})")

multi_slice_env = num_slices > 1

dcn_parallelism = [
config.dcn_data_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_tensor_parallelism,
]
ici_parallelism = [
config.ici_data_parallelism,
config.ici_fsdp_parallelism,
config.ici_tensor_parallelism,
]
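# Configs without the fsdp_batch axes fall back to the original 3-axis mesh.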
if "dcn_fsdp_batch_parallelism" in config.get_keys():
dcn_parallelism = [
config.dcn_data_parallelism,
config.dcn_fsdp_batch_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_tensor_parallelism,
]
ici_parallelism = [
config.ici_data_parallelism,
config.ici_fsdp_batch_parallelism,
config.ici_fsdp_parallelism,
config.ici_tensor_parallelism,
]
else:
dcn_parallelism = [
config.dcn_data_parallelism,
config.dcn_fsdp_parallelism,
config.dcn_tensor_parallelism,
]
ici_parallelism = [
config.ici_data_parallelism,
config.ici_fsdp_parallelism,
config.ici_tensor_parallelism,
]

# Find possible unspecified parallelisms
ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI")
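For context, here is a minimal sketch of how a fully specified parallelism list like the one above typically becomes a named JAX device mesh (assuming a single slice of 8 devices; the repository's helper additionally handles multi-slice DCN meshes and -1 placeholders):

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Hypothetical 4-axis layout matching base_wan_14b.yml's mesh_axes.
mesh_axes = ("data", "fsdp_batch", "fsdp", "tensor")
ici_parallelism = [1, 2, 4, 1]  # product must equal the number of devices

device_grid = mesh_utils.create_device_mesh(ici_parallelism, devices=jax.devices())
mesh = Mesh(device_grid, mesh_axes)
print(mesh.shape)  # {'data': 1, 'fsdp_batch': 2, 'fsdp': 4, 'tensor': 1}
```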
74 changes: 43 additions & 31 deletions src/maxdiffusion/models/attention_flax.py
@@ -78,8 +78,11 @@ def _reshape_data_from_cudnn_flash(tensor):

def _reshape_data_for_cudnn_flash(tensor, heads):
# reshapes from [b, s, h * d] to [b, s, h, d] (input format to flash format)
batch, seq, heads_and_dim_head = tensor.shape
tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads)
if len(tensor.shape) == 3:
# [b, s, h * d] -> [b, s, h, d]
batch, seq, heads_and_dim_head = tensor.shape
tensor = tensor.reshape(batch, seq, heads, heads_and_dim_head // heads)
else:
# already split into heads as [b, h, s, d]; transpose to [b, s, h, d]
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
return tensor


@@ -89,7 +92,8 @@ def _reshape_batch_dim_to_heads(tensor, heads):
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = jnp.transpose(tensor, (0, 2, 1, 3))
reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size)
return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)


def _reshape_heads_to_batch_dim(tensor, heads):
@@ -102,8 +106,8 @@ def _reshape_heads_to_batch_dim(tensor, heads):
else:
batch_size, head_size, seq_len, head_dim = tensor.shape
reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim)

return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)


def _reshape_heads_to_head_dim(tensor):
@@ -112,7 +116,8 @@ def _reshape_heads_to_head_dim(tensor):
b, h, s, d = tensor.shape
tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3])
reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d))
return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
return jax.lax.with_sharding_constraint(reshaped_tensor, axis_names)


def _unflatten_heads(tensor, heads):
@@ -492,24 +497,12 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m
key = _reshape_data_for_cudnn_flash(key, heads)
value = _reshape_data_for_cudnn_flash(value, heads)

cudnn_flash_axis_names = (BATCH, LENGTH, HEAD, D_KV)
axis_names = nn.logical_to_mesh_axes(cudnn_flash_axis_names)

query = nn.with_logical_constraint(query, axis_names)
key = nn.with_logical_constraint(key, axis_names)
value = nn.with_logical_constraint(value, axis_names)

@functools.partial(
shard_map.shard_map,
mesh=mesh,
in_specs=(axis_names, axis_names, axis_names),
out_specs=axis_names,
check_rep=False,
)
def wrap_flash_attention(query, key, value):
return jax.vmap(dpa_layer)(query, key, value, mask=None)

out = wrap_flash_attention(query, key, value)
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD, D_KV))
query = jax.lax.with_sharding_constraint(query, axis_names)
key = jax.lax.with_sharding_constraint(key, axis_names)
value = jax.lax.with_sharding_constraint(value, axis_names)

out = dpa_layer(query, key, value, mask=None)
return _reshape_data_from_cudnn_flash(out)
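A standalone sketch of the logical-axis constraint pattern used above, with illustrative rules and shapes (assuming a power-of-two device count; this is not the module's actual call path):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
from jax.experimental import mesh_utils
from jax.sharding import Mesh

# Put all devices on the 'fsdp' (sequence-parallel) axis; other axes are size 1.
mesh = Mesh(mesh_utils.create_device_mesh((1, 1, jax.device_count(), 1)),
            ("data", "fsdp_batch", "fsdp", "tensor"))
rules = (("activation_batch", ("data", "fsdp_batch")),
         ("activation_length", "fsdp"),
         ("activation_heads", "tensor"))

@jax.jit
def constrain(x):
  # Logical names resolve to a PartitionSpec over the mesh axes above.
  spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
  return jax.lax.with_sharding_constraint(x, spec)

x = jnp.zeros((2, 4096, 5120))  # hypothetical [batch, sequence, hidden]
with mesh, nn_partitioning.axis_rules(rules):
  y = constrain(x)
  print(y.sharding)
```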


@@ -706,7 +699,24 @@ def __init__(
):
self.dpa_layer = None
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"{self} has not been tested with {attention_kernel}")
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
jax.config.update("jax_use_shardy_partitioner", False)

dpa_layer = DotProductAttention(
head_dim=dim_head,
num_attention_heads=heads,
num_gqa_groups=heads,
attn_mask_type="no_mask", # 'no_mask', 'padding', 'causal', or 'padding_causal'
attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
# attention_dropout=self.dropout_rate,
dropout_rng_name="aqt",
dtype=dtype,
qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
scale_factor=scale,
transpose_batch_sequence=False,
)
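# The TE DotProductAttention module defines no trainable parameters, so apply() is bound to an empty variables dict.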
variables = {}
self.dpa_layer = functools.partial(dpa_layer.apply, variables)

self.mesh = mesh
self.scale = scale
@@ -769,8 +779,9 @@ def setup(self):
self.dpa_layer = None
if self.attention_kernel == "cudnn_flash_te":
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error
jax.config.update("jax_use_shardy_partitioner", False)

self.dpa_layer = DotProductAttention(
dpa_layer = DotProductAttention(
head_dim=self.dim_head,
num_attention_heads=self.heads,
num_gqa_groups=self.heads,
Expand All @@ -784,6 +795,9 @@ def setup(self):
scale_factor=self.scale,
transpose_batch_sequence=False,
)
variables = {}
self.dpa_layer = functools.partial(dpa_layer.apply, variables)


def apply_attention(self, query: Array, key: Array, value: Array):
return _apply_attention(
@@ -839,9 +853,6 @@ def __init__(
residual_checkpoint_name: str | None = None,
enable_jax_named_scopes: bool = False,
):
if attention_kernel == "cudnn_flash_te":
raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")

if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
self.dim_head = dim_head
@@ -998,8 +1009,9 @@ def __call__(
deterministic: bool = True,
rngs: nnx.Rngs = None,
) -> jax.Array:
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD))
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
dtype = hidden_states.dtype
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
7 changes: 5 additions & 2 deletions src/maxdiffusion/models/wan/autoencoder_kl_wan.py
@@ -28,7 +28,10 @@
BlockSizes = common_types.BlockSizes

CACHE_T = 2
flax.config.update('flax_always_shard_variable', False)
try:
flax.config.update('flax_always_shard_variable', False)
except Exception:
# The flax_always_shard_variable flag is not available in every flax version.
pass

# Helper to ensure kernel_size, stride, padding are tuples of 3 integers
def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
@@ -73,7 +76,7 @@ def __init__(
self._depth_padding_before = self._causal_padding[1][0] # 2 * padding_tuple[0]

# Set sharding dynamically based on out_channels.
num_fsdp_axis_devices = mesh.device_ids.shape[1]
num_fsdp_axis_devices = mesh.shape["fsdp"]
kernel_sharding = (None, None, None, None, None)
if out_channels % num_fsdp_axis_devices == 0:
kernel_sharding = (None, None, None, None, "conv_out")
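A small illustration (assuming 8 devices; not repository code) of why the name-based lookup above is preferable to positional indexing now that the mesh has a fourth axis:

```python
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh

device_grid = mesh_utils.create_device_mesh((1, 2, 4, 1), devices=jax.devices())
mesh = Mesh(device_grid, ("data", "fsdp_batch", "fsdp", "tensor"))

print(mesh.device_ids.shape[1])  # 2 -> this index is now the fsdp_batch axis, not fsdp
print(mesh.shape["fsdp"])        # 4 -> name-based lookup stays correct
```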
10 changes: 6 additions & 4 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -362,9 +362,11 @@ def __call__(
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
)
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
hidden_states = checkpoint_name(hidden_states, "hidden_states")
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_kv"))
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)

# 1. Self-attention
with self.conditional_named_scope("self_attn"):
@@ -515,7 +517,7 @@ def init_block(rngs):
if scan_layers:
self.blocks = init_block(rngs)
else:
blocks = nnx.List([])
blocks = []
for _ in range(num_layers):
block = WanTransformerBlock(
rngs=rngs,
Expand All @@ -535,7 +537,7 @@ def init_block(rngs):
enable_jax_named_scopes=enable_jax_named_scopes,
)
blocks.append(block)
self.blocks = blocks
self.blocks = nnx.data(blocks)

self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
self.proj_out = nnx.Linear(
9 changes: 8 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
@@ -17,6 +17,7 @@
from typing import List, Union, Optional
from ...pyconfig import HyperParameters
from functools import partial
from contextlib import nullcontext
from flax import nnx
from flax.linen import partitioning as nn_partitioning
import jax
@@ -113,8 +114,14 @@ def __call__(
scheduler=self.scheduler,
scheduler_state=scheduler_state,
)
# Use the TE global_shard_guard context manager when running the cudnn_flash_te attention kernel.
if self.config.attention == "cudnn_flash_te":
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
shard_guard = global_shard_guard(MeshResource(cp_resource="fsdp"))
else:
shard_guard = nullcontext()

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
latents = p_run_inference(
graphdef=graphdef,
sharded_state=state,
11 changes: 9 additions & 2 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
@@ -17,6 +17,7 @@
from typing import List, Union, Optional
from ...pyconfig import HyperParameters
from functools import partial
from contextlib import nullcontext
from flax import nnx
from flax.linen import partitioning as nn_partitioning
import jax
@@ -127,8 +128,14 @@ def __call__(
scheduler=self.scheduler,
scheduler_state=scheduler_state,
)

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
# Use the TE global_shard_guard context manager when running the cudnn_flash_te attention kernel.
if self.config.attention == "cudnn_flash_te":
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error
shard_guard = global_shard_guard(MeshResource(cp_resource="fsdp"))
else:
shard_guard = nullcontext()

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
latents = p_run_inference(
low_noise_graphdef=low_noise_graphdef,
low_noise_state=low_noise_state,
5 changes: 4 additions & 1 deletion src/maxdiffusion/train_wan.py
@@ -35,7 +35,10 @@ def main(argv: Sequence[str]) -> None:
config = pyconfig.config
validate_train_config(config)
max_logging.log(f"Found {jax.device_count()} devices.")
flax.config.update("flax_always_shard_variable", False)
try:
flax.config.update("flax_always_shard_variable", False)
except Exception:
# The flax_always_shard_variable flag is not available in every flax version.
pass
train(config)

