Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxdiffusion/common_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

WAN2_1 = "wan2.1"
WAN2_2 = "wan2.2"
LTX2_VIDEO = "ltx2_video"

WAN_MODEL = WAN2_1

Expand Down
87 changes: 87 additions & 0 deletions src/maxdiffusion/configs/ltx2_video.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Hardware configuration
hardware: 'tpu'
skip_jax_distributed_system: False
attention: 'flash'
attention_sharding_uniform: True
precision: 'bf16'
data_sharding: ['data', 'fsdp', 'context', 'tensor']
remat_policy: "NONE"
names_which_can_be_saved: []
names_which_can_be_offloaded: []

jax_cache_dir: ''
weights_dtype: 'bfloat16'
activations_dtype: 'bfloat16'

run_name: ''
output_dir: ''
config_path: ''
save_config_to_gcs: False

frame_rate: 30
max_sequence_length: 1024
sampler: "from_checkpoint"

# Generation parameters
dataset_name: ''
dataset_save_location: ''
global_batch_size_to_train_on: 1
num_inference_steps: 40
guidance_scale: 3.0
fps: 24
prompt: "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
height: 512
width: 768
num_frames: 121
decode_timestep: 0.05
decode_noise_scale: 0.025
quantization: "int8"
seed: 10
# Parallelism configuration
mesh_axes: ['data', 'fsdp', 'context', 'tensor']
logical_axis_rules: [
['batch', 'data'],
['activation_heads', 'fsdp'],
['activation_batch', 'data'],
['activation_kv', 'tensor'],
['mlp','tensor'],
['embed','fsdp'],
['heads', 'tensor'],
['norm', 'fsdp'],
['conv_batch', ['data','fsdp']],
['out_channels', 'tensor'],
['conv_out', 'fsdp'],
['conv_in', 'fsdp']
]
dcn_data_parallelism: 1
dcn_fsdp_parallelism: -1
dcn_context_parallelism: 1
dcn_tensor_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1
ici_context_parallelism: 1
ici_tensor_parallelism: 1
enable_profiler: False

replicate_vae: False

allow_split_physical_axes: False
learning_rate_schedule_steps: -1
max_train_steps: 500
pretrained_model_name_or_path: 'Lightricks/LTX-2'
model_name: "ltx2_video"
model_type: "T2V"
unet_checkpoint: ''
checkpoint_dir: ""
cache_latents_text_encoder_outputs: True
per_device_batch_size: 1
compile_topology_num_slices: -1
quantization_local_shard_count: -1
use_qwix_quantization: False
weight_quantization_calibration_method: "absmax"
act_quantization_calibration_method: "absmax"
bwd_quantization_calibration_method: "absmax"
qwix_module_path: ".*"
jit_initializers: True
enable_single_replica_ckpt_restoring: False
75 changes: 75 additions & 0 deletions src/maxdiffusion/models/embeddings_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,78 @@ def __call__(self, timestep, guidance, pooled_projection):
conditioning = time_guidance_emb + pooled_projections

return conditioning


class NNXTimesteps(nnx.Module):
  """Sinusoidal timestep projection (NNX flavor).

  Holds no learnable state: only the embedding hyper-parameters, with the
  actual computation delegated to ``get_sinusoidal_embeddings``.
  """

  def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
    # Width of the produced embedding vector.
    self.num_channels = num_channels
    # Whether the cos half precedes the sin half in the output.
    self.flip_sin_to_cos = flip_sin_to_cos
    # Shift applied to the frequency schedule.
    self.downscale_freq_shift = downscale_freq_shift
    # Multiplier applied to the raw sinusoidal features.
    self.scale = scale

  def __call__(self, timesteps: jax.Array) -> jax.Array:
    """Project scalar timesteps into sinusoidal embeddings of width ``num_channels``."""
    return get_sinusoidal_embeddings(
        timesteps=timesteps,
        embedding_dim=self.num_channels,
        flip_sin_to_cos=self.flip_sin_to_cos,
        freq_shift=self.downscale_freq_shift,
        scale=self.scale,
    )


class NNXPixArtAlphaCombinedTimestepSizeEmbeddings(nnx.Module):
  """Combined timestep / size conditioning embedder (NNX flavor).

  The timestep is sinusoidally projected to 256 features and passed through an
  MLP of width ``embedding_dim``. When ``use_additional_conditions`` is set,
  resolution and aspect-ratio values are embedded the same way (each to
  ``size_emb_dim``), concatenated, and added to the timestep embedding.
  """

  def __init__(
      self,
      rngs: nnx.Rngs,
      embedding_dim: int,
      size_emb_dim: int,
      use_additional_conditions: bool = False,
      dtype: jnp.dtype = jnp.float32,
      weights_dtype: jnp.dtype = jnp.float32,
  ):
    self.outdim = size_emb_dim
    self.use_additional_conditions = use_additional_conditions

    # Timestep branch: 256-dim sinusoidal features -> MLP of width embedding_dim.
    self.time_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    self.timestep_embedder = NNXTimestepEmbedding(
        rngs=rngs, in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, weights_dtype=weights_dtype
    )

    if use_additional_conditions:
      # One sinusoidal projection shared by both size conditions, but
      # separate learned embedders for resolution and aspect ratio.
      self.additional_condition_proj = NNXTimesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
      self.resolution_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )
      self.aspect_ratio_embedder = NNXTimestepEmbedding(
          rngs=rngs, in_channels=256, time_embed_dim=size_emb_dim, dtype=dtype, weights_dtype=weights_dtype
      )

  def _embed_condition(self, values: jax.Array, embedder, batch: int, hidden_dtype: jnp.dtype) -> jax.Array:
    """Sinusoidally project ``values``, embed them, and reshape to (batch, -1)."""
    projected = self.additional_condition_proj(values.flatten()).astype(hidden_dtype)
    return embedder(projected).reshape(batch, -1)

  def __call__(
      self,
      timestep: jax.Array,
      resolution: Optional[jax.Array] = None,
      aspect_ratio: Optional[jax.Array] = None,
      hidden_dtype: jnp.dtype = jnp.float32,
  ) -> jax.Array:
    """Return the conditioning vector for ``timestep`` (plus size conditions if enabled).

    Raises:
      ValueError: if additional conditions are enabled but ``resolution`` or
        ``aspect_ratio`` is missing.
    """
    emb = self.timestep_embedder(self.time_proj(timestep).astype(hidden_dtype))

    if not self.use_additional_conditions:
      return emb

    if resolution is None or aspect_ratio is None:
      raise ValueError("resolution and aspect_ratio must be provided when use_additional_conditions is True")

    batch = timestep.shape[0]
    resolution_emb = self._embed_condition(resolution, self.resolution_embedder, batch, hidden_dtype)
    aspect_ratio_emb = self._embed_condition(aspect_ratio, self.aspect_ratio_embedder, batch, hidden_dtype)
    return emb + jnp.concatenate([resolution_emb, aspect_ratio_emb], axis=1)
Loading
Loading