Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/maxdiffusion/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from absl import app
from maxdiffusion import (pyconfig, FlaxDDIMScheduler, max_utils)

from maxdiffusion.train_utils import transformer_engine_context
from maxdiffusion.maxdiffusion_utils import rescale_noise_cfg
from flax.linen import partitioning as nn_partitioning
from maxdiffusion.image_processor import VaeImageProcessor
Expand Down Expand Up @@ -261,4 +262,5 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
app.run(main)
4 changes: 3 additions & 1 deletion src/maxdiffusion/generate_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from maxdiffusion import FlaxAutoencoderKL, pyconfig, max_logging
from maxdiffusion.models.flux.transformers.transformer_flux_flax import FluxTransformer2DModel
from maxdiffusion.train_utils import transformer_engine_context
from maxdiffusion.max_utils import (
device_put_replicated,
get_memory_allocations,
Expand Down Expand Up @@ -492,4 +493,5 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
app.run(main)
4 changes: 3 additions & 1 deletion src/maxdiffusion/generate_flux_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from maxdiffusion import pyconfig, max_logging, max_utils

from maxdiffusion.checkpointing.checkpointing_utils import load_params_from_path
from maxdiffusion.train_utils import transformer_engine_context
from maxdiffusion.max_utils import setup_initial_state


Expand Down Expand Up @@ -123,4 +124,5 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
app.run(main)
4 changes: 3 additions & 1 deletion src/maxdiffusion/generate_ltx_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from maxdiffusion.pipelines.ltx_video.ltx_video_pipeline import LTXMultiScalePipeline, ConditioningItem
import maxdiffusion.pipelines.ltx_video.crf_compressor as crf_compressor
from maxdiffusion import pyconfig, max_logging
from maxdiffusion.train_utils import transformer_engine_context
import torchvision.transforms.functional as TVF
import imageio
from datetime import datetime
Expand Down Expand Up @@ -267,4 +268,5 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
app.run(main)
4 changes: 3 additions & 1 deletion src/maxdiffusion/generate_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from maxdiffusion import pyconfig, max_utils
from maxdiffusion.image_processor import VaeImageProcessor
from maxdiffusion.train_utils import transformer_engine_context
from maxdiffusion.maxdiffusion_utils import (
get_add_time_ids,
rescale_noise_cfg,
Expand Down Expand Up @@ -322,4 +323,5 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
app.run(main)
4 changes: 3 additions & 1 deletion src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from maxdiffusion.checkpointing.wan_checkpointer_i2v_2p2 import WanCheckpointerI2V_2_2
from maxdiffusion import pyconfig, max_logging, max_utils
from absl import app
from maxdiffusion.train_utils import transformer_engine_context
from maxdiffusion.utils import export_to_video
from maxdiffusion.utils.loading_utils import load_image
from google.cloud import storage
Expand Down Expand Up @@ -296,4 +297,5 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
app.run(main)
10 changes: 1 addition & 9 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import List, Union, Optional
from ...pyconfig import HyperParameters
from functools import partial
from contextlib import nullcontext
from flax import nnx
from flax.linen import partitioning as nn_partitioning
import jax
Expand Down Expand Up @@ -116,15 +115,8 @@ def __call__(
scheduler=self.scheduler,
scheduler_state=scheduler_state,
)
# Set the TE shard_guard context_manager if using TE cudnn_flash attention
if self.config.attention == "cudnn_flash_te":
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error

shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
else:
shard_guard = nullcontext()

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
latents = p_run_inference(
graphdef=graphdef,
sharded_state=state,
Expand Down
10 changes: 1 addition & 9 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from typing import List, Union, Optional
from ...pyconfig import HyperParameters
from functools import partial
from contextlib import nullcontext
from flax import nnx
from flax.linen import partitioning as nn_partitioning
import jax
Expand Down Expand Up @@ -140,15 +139,8 @@ def __call__(
scheduler=self.scheduler,
scheduler_state=scheduler_state,
)
# Set the TE shard_guard context_manager if using TE cudnn_flash attention
if self.config.attention == "cudnn_flash_te":
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error

shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
else:
shard_guard = nullcontext()

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules), shard_guard:
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
latents = p_run_inference(
low_noise_graphdef=low_noise_graphdef,
low_noise_state=low_noise_state,
Expand Down
4 changes: 3 additions & 1 deletion src/maxdiffusion/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
max_logging,
pyconfig,
)
from maxdiffusion.train_utils import transformer_engine_context

from maxdiffusion.train_utils import (
validate_train_config,
Expand All @@ -43,4 +44,5 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
app.run(main)
4 changes: 2 additions & 2 deletions src/maxdiffusion/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ def transformer_engine_context():
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
# Inform TransformerEngine of MaxDiffusion's physical mesh resources.
mesh_resource = MeshResource(
dp_resource="data",
dp_resource=None,
tp_resource="tensor",
fsdp_resource="fsdp",
pp_resource=None,
cp_resource=None,
cp_resource="context",
)
with global_shard_guard(mesh_resource):
yield
Expand Down
8 changes: 6 additions & 2 deletions src/maxdiffusion/train_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import jax
from absl import app
from maxdiffusion import max_logging, pyconfig
from maxdiffusion.train_utils import validate_train_config
from maxdiffusion.train_utils import (
validate_train_config,
transformer_engine_context,
)
import flax


Expand All @@ -43,4 +46,5 @@ def main(argv: Sequence[str]) -> None:


if __name__ == "__main__":
app.run(main)
with transformer_engine_context():
app.run(main)
11 changes: 1 addition & 10 deletions src/maxdiffusion/trainers/wan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import pprint
import numpy as np
import threading
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor
import tensorflow as tf
import jax.numpy as jnp
Expand Down Expand Up @@ -392,18 +391,10 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
max_utils.activate_profiler(self.config)
start_step_time = datetime.datetime.now()

# Designate the context parallel axis for sharding
if self.config.attention == "cudnn_flash_te":
from transformer_engine.jax.sharding import global_shard_guard, MeshResource # pytype: disable=import-error

shard_guard = global_shard_guard(MeshResource(cp_resource="context"))
else:
shard_guard = nullcontext()

next_batch_future = executor.submit(load_next_batch, train_data_iterator, example_batch, self.config)
with jax.profiler.StepTraceAnnotation(
"train", step_num=step
), pipeline.mesh, shard_guard, nn_partitioning.axis_rules(self.config.logical_axis_rules):
), pipeline.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
state, scheduler_state, train_metric, rng = p_train_step(state, example_batch, rng, scheduler_state)
train_metric["scalar"]["learning/loss"].block_until_ready()
last_step_completion = datetime.datetime.now()
Expand Down
Loading