From e15f3ce647d9b5c2e9d924362e08dd5d19647c21 Mon Sep 17 00:00:00 2001 From: Carl Persson Date: Tue, 3 Mar 2026 08:25:03 +0000 Subject: [PATCH] Apply TE shard_guard to train/generate scripts --- src/maxdiffusion/generate.py | 4 +++- src/maxdiffusion/generate_flux.py | 4 +++- src/maxdiffusion/generate_flux_pipeline.py | 4 +++- src/maxdiffusion/generate_ltx_video.py | 4 +++- src/maxdiffusion/generate_sdxl.py | 4 +++- src/maxdiffusion/generate_wan.py | 4 +++- src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py | 10 +--------- src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py | 10 +--------- src/maxdiffusion/train.py | 4 +++- src/maxdiffusion/train_utils.py | 4 ++-- src/maxdiffusion/train_wan.py | 8 ++++++-- src/maxdiffusion/trainers/wan_trainer.py | 11 +---------- 12 files changed, 32 insertions(+), 39 deletions(-) diff --git a/src/maxdiffusion/generate.py b/src/maxdiffusion/generate.py index 7b1f1f626..942577687 100644 --- a/src/maxdiffusion/generate.py +++ b/src/maxdiffusion/generate.py @@ -26,6 +26,7 @@ from absl import app from maxdiffusion import (pyconfig, FlaxDDIMScheduler, max_utils) +from maxdiffusion.train_utils import transformer_engine_context from maxdiffusion.maxdiffusion_utils import rescale_noise_cfg from flax.linen import partitioning as nn_partitioning from maxdiffusion.image_processor import VaeImageProcessor @@ -261,4 +262,5 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py index 0ba8a7a85..28c94f195 100644 --- a/src/maxdiffusion/generate_flux.py +++ b/src/maxdiffusion/generate_flux.py @@ -33,6 +33,7 @@ from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel +from maxdiffusion.train_utils import transformer_engine_context from maxdiffusion.max_utils import ( device_put_replicated, get_memory_allocations, @@ -492,4 +493,5 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/generate_flux_pipeline.py b/src/maxdiffusion/generate_flux_pipeline.py index c89f413a8..de6f72895 100644 --- a/src/maxdiffusion/generate_flux_pipeline.py +++ b/src/maxdiffusion/generate_flux_pipeline.py @@ -26,6 +26,7 @@ from maxdiffusion import pyconfig, max_logging, max_utils from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path +from maxdiffusion.train_utils import transformer_engine_context from maxdiffusion.max_utils import setup_initial_state @@ -123,4 +124,5 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/generate_ltx_video.py b/src/maxdiffusion/generate_ltx_video.py index 93753f0c8..5249a5081 100644 --- a/src/maxdiffusion/generate_ltx_video.py +++ b/src/maxdiffusion/generate_ltx_video.py @@ -21,6 +21,7 @@ from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor from maxdiffusion import pyconfig, max_logging +from maxdiffusion.train_utils import transformer_engine_context import torchvision.transforms.functional as TVF import imageio from datetime import datetime @@ -267,4 +268,5 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/generate_sdxl.py b/src/maxdiffusion/generate_sdxl.py index 3ab703706..81055ad41 100644 --- a/src/maxdiffusion/generate_sdxl.py +++ b/src/maxdiffusion/generate_sdxl.py @@ -29,6 +29,7 @@ from maxdiffusion import pyconfig, max_utils from maxdiffusion.image_processor import VaeImageProcessor +from maxdiffusion.train_utils import transformer_engine_context from maxdiffusion.maxdiffusion_utils import ( get_add_time_ids, rescale_noise_cfg, @@ -322,4 +323,5 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 7c03a21c0..f53cc59b6 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -23,6 +23,7 @@ from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2 from maxdiffusion import pyconfig, max_logging, max_utils from absl import app +from maxdiffusion.train_utils import transformer_engine_context from maxdiffusion.utils import export_to_video from maxdiffusion.utils.loading_utils import load_image from google.cloud import storage @@ -296,4 +297,5 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index 62c1a34a5..c247facb5 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -17,7 +17,6 @@ from typing import List, Union, Optional from ...pyconfig import HyperParameters from functools import partial -from contextlib import nullcontext from flax import nnx from flax.linen import partitioning as nn_partitioning import jax @@ -116,15 +115,8 @@ def __call__( scheduler=self.scheduler, scheduler_state=scheduler_state, ) - # Set the TE shard_guard context_manager if using TE cudnn_flash attention - if self.config.attention == "cudnn_flash_te": - from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error - shard_guard = global_shard_guard(MeshResource(cp_resource="context")) - else: - shard_guard = nullcontext() - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard: + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( graphdef=graphdef, sharded_state=state, diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 82261edac..16b601bad 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -17,7 +17,6 @@ from typing import List, Union, Optional from ...pyconfig import HyperParameters from functools import partial -from contextlib import nullcontext from flax import nnx from flax.linen import partitioning as nn_partitioning import jax @@ -140,15 +139,8 @@ def __call__( scheduler=self.scheduler, scheduler_state=scheduler_state, ) - # Set the TE shard_guard context_manager if using TE cudnn_flash attention - if self.config.attention == "cudnn_flash_te": - from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error - shard_guard = global_shard_guard(MeshResource(cp_resource="context")) - else: - shard_guard = nullcontext() - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard: + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( low_noise_graphdef=low_noise_graphdef, low_noise_state=low_noise_state, diff --git a/src/maxdiffusion/train.py b/src/maxdiffusion/train.py index b206a9a92..5949439ae 100644 --- a/src/maxdiffusion/train.py +++ b/src/maxdiffusion/train.py @@ -22,6 +22,7 @@ max_logging, pyconfig, ) +from maxdiffusion.train_utils import transformer_engine_context from maxdiffusion.train_utils import ( validate_train_config, @@ -43,4 +44,5 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py index 8db92a40c..bbf44dfe6 100644 --- a/src/maxdiffusion/train_utils.py +++ b/src/maxdiffusion/train_utils.py @@ -206,11 +206,11 @@ def transformer_engine_context(): from transformer_engine.jax.sharding import global_shard_guard, MeshResource # Inform TransformerEngine of MaxDiffusion's physical mesh resources. mesh_resource = MeshResource( - dp_resource="data", + dp_resource=None, tp_resource="tensor", fsdp_resource="fsdp", pp_resource=None, - cp_resource=None, + cp_resource="context", ) with global_shard_guard(mesh_resource): yield diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py index d272ca237..f0920b70d 100644 --- a/src/maxdiffusion/train_wan.py +++ b/src/maxdiffusion/train_wan.py @@ -19,7 +19,10 @@ import jax from absl import app from maxdiffusion import max_logging, pyconfig -from maxdiffusion.train_utils import validate_train_config +from maxdiffusion.train_utils import ( + validate_train_config, + transformer_engine_context, +) import flax @@ -43,4 +46,5 @@ def main(argv: Sequence[str]) -> None: if __name__ == "__main__": - app.run(main) + with transformer_engine_context(): + app.run(main) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index be018a650..dc35f0256 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -20,7 +20,6 @@ import pprint import numpy as np import threading -from contextlib import nullcontext from concurrent.futures import ThreadPoolExecutor import tensorflow as tf import jax.numpy as jnp @@ -392,18 +391,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data max_utils.activate_profiler(self.config) start_step_time = datetime.datetime.now() - # Designate the context parallel axis for sharding - if self.config.attention == "cudnn_flash_te": - from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error - - shard_guard = global_shard_guard(MeshResource(cp_resource="context")) - else: - shard_guard = nullcontext() - next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config) with jax.profiler.StepTraceAnnotation( "train", step_num=step - ), pipeline.mesh, shard_guard, nn_partitioning.axis_rules(self.config.logical_axis_rules): + ), pipeline.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state) train_metric["scalar"]["learning/loss"].block_until_ready() last_step_completion = datetime.datetime.now()