README.md (264 additions, 1 deletion)

After installation completes, run the training script.

## Wan 2.1 Training

In the first part, we'll run on a single-host VM to get familiar with the workflow, then run on xpk for large-scale training.

Although not required, attaching an external disk is recommended, since model weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).

This workflow was tested using v5p-8 with a 500GB disk attached.
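As a minimal sketch of the disk setup (assuming the attached disk appears as `/dev/sdb` and is mounted at the `/mnt/disks/external_disk` path used throughout this guide; the linked Cloud TPU docs are the authoritative reference):

```bash
# Check the device name first; /dev/sdb below is an assumption.
lsblk

# Format the disk (destroys existing data) and mount it where this guide expects it.
sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb
sudo mkdir -p /mnt/disks/external_disk
sudo mount -o discard,defaults /dev/sdb /mnt/disks/external_disk
sudo chmod a+w /mnt/disks/external_disk
```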

### Dataset Preparation

For this example, we'll be using the [PusaV1 dataset](https://huggingface.co/datasets/RaphaelLiu/PusaV1_training).

First, download the dataset.

```bash
export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/
export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1
huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR
```
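Optionally, confirm the download landed where expected:

```bash
# Rough size and file count of the downloaded dataset.
du -sh $HF_DATASET_DIR
find $HF_DATASET_DIR -type f | wc -l
```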

Next, run the TFRecords conversion script. This step prepares the training and eval datasets. Validation is done as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206); more details [here](https://github.com/mlcommons/training/tree/master/text_to_image#5-quality).

Training dataset:

```bash
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/train no_records_per_shard=10 enable_eval_timesteps=False
```

The script does not print any output, but you can check progress with:

```bash
ls -ll $TFRECORDS_DATASET_DIR/train
```

Evaluation dataset:

```bash
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/eval no_records_per_shard=10 enable_eval_timesteps=True
```

The evaluation dataset creation takes the first 420 samples of the dataset and adds a timestep field to each. We then need to manually delete those first 420 samples from the `train` folder so they are not used in training.


```bash
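# List all train shards, keep only those whose numeric field after the '-' is <= 420
# (the samples reserved for eval), and delete them. This assumes $TFRECORDS_DATASET_DIR
# itself contains no '-' or '.' characters, since awk splits on those.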
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' rm
```

Then verify that they no longer exist (this should print no filenames):

```bash
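# Same filter as above, but with echo instead of rm; it should print no filenames.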
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' echo
```

After both conversion runs finish, you should see the following directory structure inside `$TFRECORDS_DATASET_DIR`:

```
train
eval_timesteps
```

In some instances, an empty file `file_42-430.tfrec` is created inside `eval_timesteps`. As a sanity check, delete it:

```bash
rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec
```
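As an extra sanity check on the split sizes (with `no_records_per_shard=10`, the 420 eval samples correspond to 42 shards):

```bash
# Count the shards left for training and the shards reserved for eval.
ls $TFRECORDS_DATASET_DIR/train | wc -l
ls $TFRECORDS_DATASET_DIR/eval_timesteps | wc -l
```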

### Training on a Single VM

Loading the data is supported both locally from the disk created above and from GCS. In this guide, we'll use a GCS bucket for training. First, copy the data to the GCS bucket:

```bash
BUCKET_NAME=my-bucket
gsutil -m cp -r $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
```
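Optionally, verify the upload before launching training:

```bash
# Total size and per-split shard counts in the bucket.
gsutil du -sh gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
gsutil ls gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/ | wc -l
gsutil ls gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/ | wc -l
```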

Now set the run variables and XLA flags, then launch the training command:

```bash
RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
```

```bash
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
```

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
max_train_steps=1000 \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='FULL' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.25 \
ici_data_parallelism=1 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1
```

It is important to note a couple of things:
- `per_device_batch_size` can be fractional, but it must multiply with the number of devices to a whole number. In this example, 0.25 * 4 devices gives an effective global batch size of 1.
- The step time on v5p-8 with global batch size 1 is large due to using `FULL` remat. On a larger number of chips we can run larger batch sizes, greatly increasing MFU, as we will see in the next section on deploying with xpk.
- To enable eval during training, set `eval_every` to a value > 0.
- In Wan2.1, the `ici_fsdp_parallelism` axis is used for sequence parallelism and the `ici_tensor_parallelism` axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 attention heads and 40 must be evenly divisible by `ici_tensor_parallelism`.
- For sequence parallelism, the code pads the sequence length so it divides evenly across the axis. Try out different `ici_fsdp_parallelism` values; we currently find 2 and 4 to work best (see the sketch after this list).
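As one hypothetical example (not benchmarked here), the same 4-chip v5p-8 run could split the mesh between sequence and head parallelism while keeping the global batch size at 1; only the sharding flags below would change relative to the training command above:

```bash
# Hypothetical alternative sharding for the v5p-8 command above (an assumption, not a tested recipe):
# 2-way sequence parallelism x 2-way head parallelism. 40 heads are divisible by 2,
# and 0.25 * 4 devices still gives a whole global batch size of 1.
per_device_batch_size=0.25 \
ici_data_parallelism=1 \
ici_fsdp_parallelism=2 \
ici_tensor_parallelism=2
```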

You should eventually see training output like the following:

```bash
***** Running training *****
Instantaneous batch size per device = 0.25
Total train batch size (w. parallel & distributed) = 1
Total optimization steps = 1000
Calculated TFLOPs per pass: 4893.2719
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270
To see full metrics 'tensorboard --logdir=gs://jfacevedo-maxdiffusion-v5p/wan/jfacevedo-wan-v5p-8-17263/tensorboard/'
completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144
completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210
completed step: 3, seconds: 36.016, TFLOP/s/device: 135.864, loss: 0.120
completed step: 4, seconds: 36.008, TFLOP/s/device: 135.894, loss: 0.107
completed step: 5, seconds: 36.008, TFLOP/s/device: 135.895, loss: 0.346
completed step: 6, seconds: 36.006, TFLOP/s/device: 135.900, loss: 0.169
```
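To dig into the full metrics, point TensorBoard at the run's log directory; the path below follows the `OUTPUT_DIR`/`RUN_NAME` layout shown in the log line above, and reading `gs://` paths assumes TensorBoard has GCS support (e.g. TensorFlow installed):

```bash
# OUTPUT_DIR already ends in a slash, so the run's logs live at ${OUTPUT_DIR}${RUN_NAME}/tensorboard/.
tensorboard --logdir=${OUTPUT_DIR}${RUN_NAME}/tensorboard/ --port=6006
```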

### Deploying with XPK

This assumes the user has already created an xpk cluster, installed all dependencies, and created the dataset from the steps above. For getting started with MaxDiffusion and xpk, see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).
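The xpk command below also assumes the usual cluster variables are already exported; the values here are placeholders for illustration only:

```bash
# Placeholder values -- substitute your own cluster, project, zone, and image.
export CLUSTER_NAME=my-xpk-cluster
export PROJECT=my-gcp-project
export ZONE=us-east5-a
export DEVICE_TYPE=v5p-256
export IMAGE_DIR=gcr.io/my-gcp-project/maxdiffusion:latest  # docker image built per the xpk guide
export BUCKET_NAME=my-bucket
```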

Using a v5p-256, the command to run on xpk is as follows:

```bash
RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
```

```bash
LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
```

```bash
python3 ~/xpk/xpk.py workload create \
--cluster=$CLUSTER_NAME \
--project=$PROJECT \
--zone=$ZONE \
--device-type=$DEVICE_TYPE \
--num-slices=1 \
--command=" \
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
src/maxdiffusion/configs/base_wan_14b.yml \
attention='flash' \
weights_dtype=bfloat16 \
activations_dtype=bfloat16 \
guidance_scale=5.0 \
flow_shift=5.0 \
fps=16 \
skip_jax_distributed_system=False \
run_name=${RUN_NAME} \
output_dir=${OUTPUT_DIR} \
train_data_dir=${DATASET_DIR} \
load_tfrecord_cached=True \
height=1280 \
width=720 \
num_frames=81 \
num_inference_steps=50 \
jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
enable_profiler=True \
dataset_save_location=${SAVE_DATASET_DIR} \
remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
flash_min_seq_length=0 \
seed=$RANDOM \
skip_first_n_steps_for_profiler=3 \
profiler_steps=3 \
per_device_batch_size=0.25 \
ici_data_parallelism=32 \
ici_fsdp_parallelism=4 \
ici_tensor_parallelism=1 \
max_train_steps=5000 \
eval_every=100 \
eval_data_dir=${EVAL_DATA_DIR} \
enable_generate_video_for_eval=True \
warmup_steps_fraction=0.025" \
--base-docker-image=${IMAGE_DIR} \
--enable-debug-logs \
--workload=${RUN_NAME} \
--priority=medium \
--max-restarts=0
```
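Once submitted, the workload can be checked and cleaned up with xpk; the subcommands below are the standard `workload list`/`workload delete` commands, so verify the flags against your xpk checkout if they differ:

```bash
# List workloads on the cluster and find this run by name.
python3 ~/xpk/xpk.py workload list --cluster=$CLUSTER_NAME

# Delete the workload when finished (or before resubmitting with new settings).
python3 ~/xpk/xpk.py workload delete --workload=${RUN_NAME} --cluster=$CLUSTER_NAME
```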

## Flux Training

src/maxdiffusion/configs/base_wan_14b.yml (2 additions, 2 deletions; `considered_timesteps_list` is renamed to `timesteps_list`):

global_batch_size: 0
tfrecords_dir: ''
no_records_per_shard: 0
enable_eval_timesteps: False
timesteps_list: [125, 250, 375, 500, 625, 750, 875]
num_eval_samples: 420

warmup_steps_fraction: 0.1
qwix_module_path: ".*"
eval_every: -1
eval_data_dir: ""
enable_generate_video_for_eval: False # This will increase the used TPU memory.
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).

enable_ssim: False