diff --git a/examples/community/README.md b/examples/community/README.md
index e314463077f0..4013248e84c5 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -88,7 +88,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
 | FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
 | Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
 | Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - |  [Net-Mist](https://github.com/Net-Mist) |
-
+| Lumina-DiMOO Pipeline | Implementation of Lumina-DiMOO, an omni-foundational model for unified multimodal generation and understanding. | [Lumina-DiMOO Pipeline](#lumina-dimoo) | [](https://huggingface.co/Alpha-VLLM/Lumina-DiMOO) | [Yi Xin](https://synbol.github.io/) and [Qi Qin](https://github.com/ChinChyi)|
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
@@ -5527,3 +5527,160 @@ images = pipe(
 ).images
 images[0].save("pizzeria.png")
 ```
+
+
+# Lumina-DiMOO
+[Project](https://synbol.github.io/Lumina-DiMOO/) / [GitHub](https://github.com/Alpha-VLLM/Lumina-DiMOO/) / [Model](https://huggingface.co/Alpha-VLLM/Lumina-DiMOO)
+
+Lumina-DiMOO is a discrete-diffusion omni-modal foundation model that unifies multimodal generation and understanding. This community pipeline exposes Lumina-DiMOO through a single `task` switch covering text-to-image (T2I) generation, image-to-image (I2I) editing, and multimodal understanding (MMU).
+
+### Key features
+
+- **Unified Discrete Diffusion Architecture**: Employs a fully discrete diffusion framework to process inputs and outputs across diverse modalities.
+- **Versatile Multimodal Capabilities**: Supports a wide range of multimodal tasks, including text-to-image generation (arbitrary and high-resolution), image-to-image generation (e.g., image editing, subject-driven generation, inpainting), and advanced image understanding.
+- **Higher Sampling Efficiency**:  Outperforms previous autoregressive (AR) or hybrid AR-diffusion models with significantly faster sampling. A custom caching mechanism further boosts sampling speed by up to 2×.
+
+
+### Example Usage
+
+The Lumina-DiMOO pipeline provides three core functions: text-to-image (T2I), image-to-image (I2I), and multimodal understanding (MMU).
+For detailed implementation examples and creative applications, please visit the [GitHub repository](https://github.com/Alpha-VLLM/Lumina-DiMOO).
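+
+All three tasks below share the same loading code; only the `task` argument and the task-specific keyword arguments change per call. As a reference, here is a minimal sketch of the shared setup (the same lines appear in each full example below):
+
+```python
+import torch
+
+from diffusers import VQModel, DiffusionPipeline
+from transformers import AutoTokenizer
+
+# Shared setup: load the VQ-VAE and tokenizer, then the custom pipeline.
+vqvae = VQModel.from_pretrained("Alpha-VLLM/Lumina-DiMOO", subfolder="vqvae").to(device='cuda', dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained("Alpha-VLLM/Lumina-DiMOO", trust_remote_code=True)
+
+pipe = DiffusionPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-DiMOO",
+    vqvae=vqvae,
+    tokenizer=tokenizer,
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="lumina_dimoo",
+)
+pipe.to("cuda")
+
+# The task is then selected per call: "text_to_image", "image_to_image", or "multimodal_understanding".
+```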
+
+
+#### Text-to-Image
+**prompt**             |  **image** 
+:-------------------------:|:-------------------------:
+| "A striking photograph of a glass of orange juice on a wooden kitchen table, capturing a playful moment. The orange juice splashes out of the glass and forms the word \"Smile\" in a whimsical, swirling script just above the glass. The background is softly blurred, revealing a cozy, homely kitchen with warm lighting and a sense of comfort." |  +
+```python
+import torch
+
+from diffusers import VQModel, DiffusionPipeline
+from transformers import AutoTokenizer
+
+vqvae = VQModel.from_pretrained("Alpha-VLLM/Lumina-DiMOO", subfolder="vqvae").to(device='cuda', dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained("Alpha-VLLM/Lumina-DiMOO", trust_remote_code=True)
+
+pipe = DiffusionPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-DiMOO",
+    vqvae=vqvae,
+    tokenizer=tokenizer,
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="lumina_dimoo",
+)
+pipe.to("cuda")
+
+prompt = '''A striking photograph of a glass of orange juice on a wooden kitchen table, capturing a playful moment. The orange juice splashes out of the glass and forms the word \"Smile\" in a whimsical, swirling script just above the glass. The background is softly blurred, revealing a cozy, homely kitchen with warm lighting and a sense of comfort.'''
+
+img = pipe(
+    prompt=prompt,
+    task="text_to_image",
+    height=768,
+    width=1536,
+    num_inference_steps=64,
+    cfg_scale=4.0,     
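+    # Feature-caching options (see "Higher Sampling Efficiency" above):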
+    use_cache=True,
+    cache_ratio=0.9, 
+    warmup_ratio=0.3,
+    refresh_interval=5
+).images[0]
+
+img.save("t2i_test_output.png")
+```
+
+
+#### Image-to-Image
+**prompt**             |  **image_before**   |  **image_after**  
+:-------------------------:|:-------------------------:|:-------------------------:
+| "A functional wooden printer stand.Nestled next to a brick wall in a bustling city street, it stands firm as pedestrians hustle by, illuminated by the warm glow of vintage street lamps." |  |
 |  |
+
+```python
+import torch
+
+from diffusers import VQModel, DiffusionPipeline
+from transformers import AutoTokenizer
+from diffusers.utils import load_image
+
+vqvae = VQModel.from_pretrained("Alpha-VLLM/Lumina-DiMOO", subfolder="vqvae").to(device='cuda', dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained("Alpha-VLLM/Lumina-DiMOO", trust_remote_code=True)
+
+pipe = DiffusionPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-DiMOO",
+    vqvae=vqvae,
+    tokenizer=tokenizer,
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="lumina_dimoo",
+)
+pipe.to("cuda")
+
+input_image = load_image(
+    "https://raw.githubusercontent.com/Alpha-VLLM/Lumina-DiMOO/main/examples/example_2.jpg"
+).convert("RGB")
+
+prompt = "A functional wooden printer stand.Nestled next to a brick wall in a bustling city street, it stands firm as pedestrians hustle by, illuminated by the warm glow of vintage street lamps."
+
+img = pipe(
+    prompt=prompt,
+    image=input_image,
+    edit_type="depth_control",
+    num_inference_steps=64,
+    temperature=1.0,
+    cfg_scale=2.5,
+    cfg_img=4.0,
+    task="image_to_image"
+).images[0]
+
+img.save("i2i_test_output.png")
+
+```
+
+
+#### Multimodal Understanding
+**question**       |   **image**      |   **answer** 
+:-------------------------:|:-------------------------:|:-------------------------:
+| "Please describe the image." |  | "The image shows a vibrant orange sports car parked in a showroom. The car has a sleek, aerodynamic design with a prominent front grille and side vents. The body is adorned with black and orange racing stripes, creating a striking contrast against the orange paint. The car is equipped with black alloy wheels and a low-profile body style. The background features a white wall with a large emblem that reads "BREITZEN" and includes a silhouette of a horse and text. The floor is tiled with dark tiles, and the showroom is well-lit, highlighting the car. The overall setting suggests a high-end, possibly luxury, automotive environment."|
+
+
+```python
+import torch
+
+from diffusers import VQModel, DiffusionPipeline
+from transformers import AutoTokenizer
+from diffusers.utils import load_image
+
+vqvae = VQModel.from_pretrained("Alpha-VLLM/Lumina-DiMOO", subfolder="vqvae").to(device='cuda', dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained("Alpha-VLLM/Lumina-DiMOO", trust_remote_code=True)
+
+pipe = DiffusionPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-DiMOO",
+    vqvae=vqvae,
+    tokenizer=tokenizer,
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="lumina_dimoo",
+)
+pipe.to("cuda")
+
+question = "Please describe the image."
+
+input_image = load_image(
+    "https://raw.githubusercontent.com/Alpha-VLLM/Lumina-DiMOO/main/examples/example_8.png"
+).convert("RGB")
+
+out = pipe(
+    prompt=question,
+    image=input_image,
+    task="multimodal_understanding",
+    num_inference_steps=128,
+    gen_length=128,
+    block_length=32,
+    temperature=0.0,
+    cfg_scale=0.0,
+)
+
+text = getattr(out, "text", out)
+with open("mmu_answer.txt", "w", encoding="utf-8") as f:
+    f.write(text.strip() + "\n")
+```
+
diff --git a/examples/community/lumina_dimoo.py b/examples/community/lumina_dimoo.py
new file mode 100644
index 000000000000..6ed4b0d6398c
--- /dev/null
+++ b/examples/community/lumina_dimoo.py
@@ -0,0 +1,2676 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import math
+import random
+import sys
+from abc import abstractmethod
+from dataclasses import dataclass, fields
+from enum import Enum
+from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple, Union, cast
+
+import numpy as np
+import torch
+import torch.backends.cuda
+import torch.nn as nn
+import torch.nn.functional as F
+from accelerate import init_empty_weights
+from PIL import Image
+from torch import einsum
+from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedModel
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from diffusers import DiffusionPipeline, VQModel
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+# --- Start of model definition copied from Lumina-DiMOO ---
+
+
+class StrEnum(str, Enum):
+    """
+    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
+    We include this here for compatibility with older version of Python.
+    """
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __repr__(self) -> str:
+        return f"'{str(self)}'"
+
+
+class LayerNormType(StrEnum):
+    default = "default"
+    low_precision = "low_precision"
+    rms = "rms"
+    gemma_rms = "gemma_rms"
+    amd_compatible = "amd_compatible"
+
+
+class ActivationType(StrEnum):
+    gelu = "gelu"
+    relu = "relu"
+    silu = "silu"
+    swiglu = "swiglu"
+
+
+class BlockType(StrEnum):
+    sequential = "sequential"
+    parallel = "parallel"
+    llama = "llama"
+
+
+class InitFnType(StrEnum):
+    mitchell = "mitchell"
+    normal = "normal"
+    kaiming_normal = "kaiming_normal"
+    fan_in = "fan_in"
+    full_megatron = "full_megatron"
+
+
+@dataclass
+class ModelConfig:
+    """
+    LLaDA (model) configuration.
+    """
+
+    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
+
+    d_model: int = 768
+    n_heads: int = 12
+    n_kv_heads: Optional[int] = None
+    n_layers: int = 12
+    mlp_ratio: int = 4
+    mlp_hidden_size: Optional[int] = None
+    activation_type: ActivationType = ActivationType.swiglu
+    block_type: BlockType = BlockType.sequential
+    block_group_size: int = 1
+    alibi: bool = False
+    alibi_bias_max: float = 8.0
+    rope: bool = False
+    rope_full_precision: bool = True
+    flash_attention: bool = False
+    attention_dropout: float = 0.1
+    multi_query_attention: Optional[bool] = None
+    attention_layer_norm: bool = False
+    residual_dropout: float = 0.1
+    embedding_dropout: float = 0.1
+    input_emb_norm: bool = False
+    layer_norm_type: LayerNormType = LayerNormType.default
+    layer_norm_with_affine: bool = True
+    rms_norm_eps: float = 1e-05
+    attention_layer_norm_with_affine: bool = True
+    max_sequence_length: int = 1024
+    rope_theta: float = 10000.0
+    include_qkv_bias: Optional[bool] = False
+    include_bias: bool = False
+    bias_for_layer_norm: Optional[bool] = None
+    scale_logits: bool = False
+    vocab_size: int = 50257
+    embedding_size: Optional[int] = 50304
+    weight_tying: bool = True
+    eos_token_id: int = 50256
+    pad_token_id: int = 50256
+    mask_token_id: Optional[int] = 50256
+    init_device: Optional[str] = None
+    init_fn: InitFnType = InitFnType.normal
+    init_std: float = 0.02
+    init_cutoff_factor: Optional[float] = None
+    precision: Optional[str] = None
+
+    @property
+    def effective_n_kv_heads(self) -> int:
+        if self.n_kv_heads is None:
+            if self.multi_query_attention is True:
+                return 1
+            else:
+                return self.n_heads
+        else:
+            if self.multi_query_attention is None:
+                return self.n_kv_heads
+            if self.multi_query_attention:
+                n_kv_heads_should_be = 1
+            else:
+                n_kv_heads_should_be = self.n_heads
+            if self.n_kv_heads == n_kv_heads_should_be:
+                return n_kv_heads_should_be
+            else:
+                raise Exception("You can't set `multi_query_attention` and `n_kv_heads` at the same time.")
+
+
+class ActivationCheckpointingStrategy(StrEnum):
+    whole_layer = "whole_layer"
+    one_in_two = "one_in_two"
+    one_in_three = "one_in_three"
+    one_in_four = "one_in_four"
+    two_in_three = "two_in_three"
+    three_in_four = "three_in_four"
+    four_in_five = "four_in_five"
+    nine_in_ten = "nine_in_ten"
+    fine_grained = "fine_grained"
+
+
+class LLaDAConfig(PretrainedConfig):
+    model_type = "llada"
+    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm
+
+    def __init__(self, use_cache: bool = False, **kwargs):
+        model_config = ModelConfig()
+        all_kwargs = model_config.__dict__
+        all_kwargs.update(kwargs)
+        all_kwargs.update({"use_cache": use_cache})
+        all_kwargs.update({"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])})
+        super().__init__(**all_kwargs)
+
+    @property
+    def num_attention_heads(self):
+        return self.n_heads
+
+    @property
+    def num_hidden_layers(self):
+        return self.n_layers
+
+    @property
+    def hidden_size(self):
+        return self.d_model
+
+
+if sys.version_info.minor > 8:
+    from collections.abc import MutableMapping
+elif sys.version_info.minor == 8:
+    from typing import MutableMapping
+else:
+    raise SystemExit("This script supports Python 3.8 or higher")
+
+
+class ModuleType(StrEnum):
+    in_module = "in"
+    out_module = "out"
+    emb = "emb"
+    final_out = "final_out"
+
+
+def init_weights(
+    config: ModelConfig,
+    module: Union[nn.Linear, nn.Embedding],
+    d: Optional[int] = None,
+    layer_id: Optional[int] = None,
+    std_factor: float = 1.0,
+    type_of_module: Optional[ModuleType] = None,
+) -> None:
+    d = d if d is not None else config.d_model
+    if config.init_fn == InitFnType.normal:
+        std = config.init_std * std_factor
+        if config.init_cutoff_factor is not None:
+            cutoff_value = config.init_cutoff_factor * std
+            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
+        else:
+            nn.init.normal_(module.weight, mean=0.0, std=std)
+    elif config.init_fn == InitFnType.mitchell:
+        std = std_factor / math.sqrt(d)
+        if layer_id is not None:
+            std = std / math.sqrt(2 * (layer_id + 1))
+        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
+    elif config.init_fn == InitFnType.kaiming_normal:
+        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
+    elif config.init_fn == InitFnType.fan_in:
+        std = std_factor / math.sqrt(d)
+        nn.init.normal_(module.weight, mean=0.0, std=std)
+    elif config.init_fn == InitFnType.full_megatron:
+        if type_of_module is None:
+            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
+
+        cutoff_factor = config.init_cutoff_factor
+        if cutoff_factor is None:
+            cutoff_factor = 3
+
+        if type_of_module == ModuleType.in_module:
+            # for att_proj (same as QKV), ff_proj
+            std = config.init_std
+        elif type_of_module == ModuleType.out_module:
+            # for attn_out, ff_out
+            std = config.init_std / math.sqrt(2.0 * config.n_layers)
+        elif type_of_module == ModuleType.emb:
+            # positional embeddings (wpe)
+            # token embeddings (wte)
+            std = config.init_std
+        elif type_of_module == ModuleType.final_out:
+            # final output (ff_out)
+            std = config.d_model**-0.5
+        else:
+            raise RuntimeError(f"Unknown module type '{type_of_module}'")
+        nn.init.trunc_normal_(
+            module.weight,
+            mean=0.0,
+            std=std,
+            a=-cutoff_factor * std,
+            b=cutoff_factor * std,
+        )
+    else:
+        raise NotImplementedError(config.init_fn)
+
+    if isinstance(module, nn.Linear):
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+
+        if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
+            with torch.no_grad():
+                module.weight.div_(math.sqrt(2 * config.n_layers))
+
+
+def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
+    if check_neg_inf:
+        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
+    if check_pos_inf:
+        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
+
+
+def activation_checkpoint_function(cfg: ModelConfig):
+    preserve_rng_state = (
+        (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
+    )
+    from torch.utils.checkpoint import checkpoint
+
+    return functools.partial(
+        checkpoint,
+        preserve_rng_state=preserve_rng_state,
+        use_reentrant=False,
+    )
+
+
+class BufferCache(dict, MutableMapping[str, torch.Tensor]):
+    """
+    Cache for attention biases and other things that would normally be stored as buffers.
+    We avoid using buffers because we've run into various issues doing so with FSDP.
+    In general it appears the way FSDP handles buffers is not well-defined.
+    It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
+    since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
+    NaNs when they're synchronized due to casting or some other issue.
+    """
+
+
+def _non_meta_init_device(config: ModelConfig) -> torch.device:
+    if config.init_device is not None and config.init_device != "meta":
+        return torch.device(config.init_device)
+    else:
+        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Dropout(nn.Dropout):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.p == 0.0:
+            return input
+        else:
+            return F.dropout(input, self.p, self.training, self.inplace)
+
+
+class LayerNormBase(nn.Module):
+    def __init__(
+        self,
+        config: ModelConfig,
+        *,
+        size: Optional[int] = None,
+        elementwise_affine: Optional[bool] = True,
+        eps: float = 1e-05,
+    ):
+        super().__init__()
+        self.config = config
+        self.eps = eps
+        self.normalized_shape = (size or config.d_model,)
+        if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
+            self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
+            use_bias = self.config.bias_for_layer_norm
+            if use_bias is None:
+                use_bias = self.config.include_bias
+            if use_bias:
+                self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
+            else:
+                self.register_parameter("bias", None)
+        else:
+            self.register_parameter("bias", None)
+            self.register_parameter("weight", None)
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    @classmethod
+    def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> "LayerNormBase":
+        if config.layer_norm_type == LayerNormType.default:
+            return LayerNorm(config, size=size, low_precision=False, **kwargs)
+        elif config.layer_norm_type == LayerNormType.low_precision:
+            return LayerNorm(config, size=size, low_precision=True, **kwargs)
+        elif config.layer_norm_type == LayerNormType.rms:
+            return RMSLayerNorm(config, size=size, **kwargs)
+        elif config.layer_norm_type == LayerNormType.gemma_rms:
+            return GemmaRMSLayerNorm(config, size=size, **kwargs)
+        else:
+            raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
+
+    def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if tensor.device.type == "cuda" and torch.is_autocast_enabled():
+            return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
+        elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+            return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
+        else:
+            return tensor
+
+    def reset_parameters(self):
+        if self.weight is not None:
+            torch.nn.init.ones_(self.weight)  # type: ignore
+        if self.bias is not None:
+            torch.nn.init.zeros_(self.bias)  # type: ignore
+
+
+class LayerNorm(LayerNormBase):
+    def __init__(
+        self,
+        config: ModelConfig,
+        size: Optional[int] = None,
+        low_precision: bool = False,
+        elementwise_affine: Optional[bool] = None,
+        eps: float = 1e-05,
+    ):
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
+        self.low_precision = low_precision
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.low_precision:
+            module_device = x.device
+            downcast_x = self._cast_if_autocast_enabled(x)
+            downcast_weight = (
+                self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
+            )
+            downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
+            with torch.autocast(enabled=False, device_type=module_device.type):
+                return F.layer_norm(
+                    downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
+                )
+        else:
+            return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
+
+
+class RMSLayerNorm(LayerNormBase):
+    def __init__(
+        self,
+        config: ModelConfig,
+        size: Optional[int] = None,
+        elementwise_affine: Optional[bool] = None,
+        eps: float = 1e-5,
+    ):
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        with torch.autocast(enabled=False, device_type=x.device.type):
+            og_dtype = x.dtype
+            x = x.to(torch.float32)
+            variance = x.pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.eps)
+            x = x.to(og_dtype)
+
+        if self.weight is not None:
+            if self.bias is not None:
+                return self.weight * x + self.bias
+            else:
+                return self.weight * x
+        else:
+            return x
+
+
+class GemmaRMSLayerNorm(LayerNormBase):
+    def __init__(
+        self,
+        config: ModelConfig,
+        size: Optional[int] = None,
+        elementwise_affine: Optional[bool] = None,
+        eps: float = 1e-5,
+    ):
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        with torch.autocast(enabled=False, device_type=x.device.type):
+            og_dtype = x.dtype
+            x = x.to(torch.float32)
+            variance = x.pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.eps)
+            x = x.to(og_dtype)
+
+        if self.weight is not None:
+            if self.bias is not None:
+                return x * (1 + self.weight) + self.bias
+            else:
+                return x * (1 + self.weight)
+        else:
+            return x
+
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, config: ModelConfig, cache: BufferCache):
+        super().__init__()
+        self.config = config
+        self.__cache = cache
+        # Warm up cache.
+        self.rope_theta = config.rope_theta
+        self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
+
+    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
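+        # Reuse the cached sin/cos tables when they are long enough, moving them to the requested device if needed.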
+        if (
+            (pos_sin := self.__cache.get("rope_pos_sin")) is not None
+            and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
+            and pos_sin.shape[-2] >= seq_len
+            and pos_cos.shape[-2] >= seq_len
+        ):
+            if pos_sin.device != device:
+                pos_sin = pos_sin.to(device)
+                self.__cache["rope_pos_sin"] = pos_sin
+            if pos_cos.device != device:
+                pos_cos = pos_cos.to(device)
+                self.__cache["rope_pos_cos"] = pos_cos
+            return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
+
+        with torch.autocast(device.type, enabled=False):
+            dim = self.config.d_model // self.config.n_heads
+            inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
+            seq = torch.arange(seq_len, device=device, dtype=torch.float)
+            freqs = einsum("i , j -> i j", seq, inv_freq)
+            positions = torch.cat((freqs, freqs), dim=-1)
+            pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
+        self.__cache["rope_pos_sin"] = pos_sin
+        self.__cache["rope_pos_cos"] = pos_cos
+        return pos_sin, pos_cos
+
+    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
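+        # Split the head dimension into halves and rotate: (x1, x2) -> (-x2, x1).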
+        B, nh, T, hs = x.size()
+        x = x.view(B, nh, T, 2, hs // 2)
+        x1, x2 = x.unbind(dim=-2)
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, q_mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.config.rope_full_precision:
+            q_, k_ = q.float(), k.float()
+        else:
+            q_, k_ = q, k
+
+        with torch.autocast(q.device.type, enabled=False):
+            query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
+            pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
+            pos_sin = pos_sin.type_as(q_)
+            pos_cos = pos_cos.type_as(q_)
+            if q_mask is None:
+                q_ = self.apply_rotary_pos_emb(
+                    pos_sin[:, :, key_len - query_len : key_len, :],
+                    pos_cos[:, :, key_len - query_len : key_len, :],
+                    q_,
+                )
+            else:
+                q_ = self.apply_rotary_pos_emb(
+                    pos_sin[:, :, q_mask, :],
+                    pos_cos[:, :, q_mask, :],
+                    q_,
+                )
+            k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
+        return q_.type_as(q), k_.type_as(k)
+
+
+class Activation(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def output_multiplier(self) -> float:
+        raise NotImplementedError
+
+    @classmethod
+    def build(cls, config: ModelConfig) -> "Activation":
+        if config.activation_type == ActivationType.gelu:
+            return cast(Activation, GELU(approximate="none"))
+        elif config.activation_type == ActivationType.relu:
+            return cast(Activation, ReLU(inplace=False))
+        elif config.activation_type == ActivationType.silu:
+            return cast(Activation, SiLU(inplace=False))
+        elif config.activation_type == ActivationType.swiglu:
+            return SwiGLU(config)
+        else:
+            raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
+
+
+class GELU(nn.GELU):
+    @property
+    def output_multiplier(self) -> float:
+        return 1.0
+
+
+class ReLU(nn.ReLU):
+    @property
+    def output_multiplier(self) -> float:
+        return 1.0
+
+
+class SiLU(nn.SiLU):
+    @property
+    def output_multiplier(self) -> float:
+        return 1.0
+
+
+class SwiGLU(Activation):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        return F.silu(gate) * x
+
+    @property
+    def output_multiplier(self) -> float:
+        return 0.5
+
+
+def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
+    att_bias = torch.triu(
+        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
+        diagonal=1,
+    )
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
+    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
+
+
+def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
+    if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
+        if causal_bias.device != device:
+            causal_bias = causal_bias.to(device)
+            cache["causal_attention_bias"] = causal_bias
+        return causal_bias
+    with torch.autocast(device.type, enabled=False):
+        causal_bias = causal_attention_bias(seq_len, device)
+    cache["causal_attention_bias"] = causal_bias
+    return causal_bias
+
+
+def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
+
+    # shape: (1, 1, seq_len, seq_len)
+    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
+    alibi_bias.abs_().mul_(-1)
+
+    # shape: (n_heads,)
+    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
+    m.mul_(config.alibi_bias_max / config.n_heads)
+
+    # shape: (1, n_heads, seq_len, seq_len)
+    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
+
+
+class LLaDABlock(nn.Module):
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__()
+        self.layer_id = layer_id
+        self.config = config
+        self.hidden_size = (
+            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
+        )
+        self.__cache = cache
+        assert config.d_model % config.n_heads == 0
+
+        self._activation_checkpoint_fn = None
+
+        # Dropout.
+        self.dropout = Dropout(config.residual_dropout)
+
+        # Layer norms.
+        self.k_norm: Optional[LayerNormBase] = None
+        self.q_norm: Optional[LayerNormBase] = None
+        if config.attention_layer_norm:
+            self.k_norm = LayerNormBase.build(
+                config,
+                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
+                elementwise_affine=config.attention_layer_norm_with_affine,
+            )
+            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
+
+        # Activation function.
+        self.act = Activation.build(config)
+        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+
+        # Attention output projection.
+        self.attn_out = nn.Linear(
+            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
+        )
+
+        # Feed-forward output projection.
+        self.ff_out = nn.Linear(
+            int(self.act.output_multiplier * self.hidden_size),
+            config.d_model,
+            bias=config.include_bias,
+            device=config.init_device,
+        )
+        self.ff_out._is_residual = True  # type: ignore
+
+        # Rotary embeddings.
+        if self.config.rope:
+            self.rotary_emb = RotaryEmbedding(config, self.__cache)
+
+        self.flash_attn_func = None
+        if config.flash_attention:
+            try:
+                from flash_attn import flash_attn_func  # type: ignore
+
+                self.flash_attn_func = flash_attn_func
+            except ModuleNotFoundError:
+                pass
+
+        self.use_cache = False
+        self.init_cache()
+
+    def init_cache(self):
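+        # Feature cache (keyed by the `cat` argument of `forward`) storing key/value projections reused across sampling steps.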
+        self.cache = {"k": {}, "v": {}, "out": {}}
+
+    def caching(self, enable: bool = True):
+        self.use_cache = enable
+        self.init_cache()
+
+    def reset_parameters(self):
+        if self.k_norm is not None:
+            self.k_norm.reset_parameters()
+        if self.q_norm is not None:
+            self.q_norm.reset_parameters()
+        init_weights(
+            self.config,
+            self.attn_out,
+            d=self.config.d_model,
+            layer_id=self.layer_id,
+            type_of_module=ModuleType.out_module,
+        )
+        init_weights(
+            self.config,
+            self.ff_out,
+            d=self.ff_out.in_features,
+            layer_id=self.layer_id,
+            type_of_module=ModuleType.out_module,
+        )
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        if strategy == ActivationCheckpointingStrategy.fine_grained:
+            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+        else:
+            self._activation_checkpoint_fn = None
+
+    @classmethod
+    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
+        target_dtype = input_dtype
+        if bias.device.type == "cuda" and torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+            target_dtype = torch.get_autocast_cpu_dtype()
+        if bias.dtype != target_dtype:
+            bias = bias.to(target_dtype)
+            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
+        return bias
+
+    def _scaled_dot_product_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+    ) -> torch.Tensor:
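+        # flash-attn expects (batch, seq_len, n_heads, head_dim), hence the transposes around the call.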
+        if self.flash_attn_func is not None and attn_mask is None:
+            r = self.flash_attn_func(
+                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
+            )
+            return r.transpose(1, 2)
+        else:
+            assert k.size(1) == v.size(1)
+            num_kv_heads = k.size(1)
+            num_q_heads = q.size(1)
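+            # Grouped-query attention: repeat k/v heads to match the number of query heads.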
+            if num_q_heads != num_kv_heads:
+                assert num_q_heads % num_kv_heads == 0
+                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+
+            return F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=None,
+                dropout_p=dropout_p,
+                is_causal=False,
+            )
+
+    def attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        to_compute_mask=None,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        B, T, C = q.size()  # batch size, sequence length, d_model
+        dtype = k.dtype
+
+        if self.q_norm is not None and self.k_norm is not None:
+            q = self.q_norm(q).to(dtype=dtype)
+            k = self.k_norm(k).to(dtype=dtype)
+
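+        # shape: (B, n_heads, seq_len, head_dim) for q; (B, n_kv_heads, seq_len, head_dim) for k and v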
+        q = q.view(B, -1, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
+        k = k.view(B, -1, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+        v = v.view(B, -1, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
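+        # With caching, rotary embeddings for q are applied at the original positions of the recomputed tokens.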
+        if self.config.rope:
+            to_compute_index = (
+                to_compute_mask.nonzero(as_tuple=True)[1] if self.use_cache and to_compute_mask is not None else None
+            )
+            q, k = self.rotary_emb(q, k, q_mask=to_compute_index)
+
+        if attention_bias is not None:
+            attention_bias = self._cast_attn_bias(attention_bias, dtype)
+
+        att = self._scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            dropout_p=0.0 if not self.training else self.config.attention_dropout,
+            is_causal=False,
+        )
+
+        att = att.transpose(1, 2).contiguous().view(B, T, C)
+
+        return self.attn_out(att), None
+
+    @abstractmethod
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.FloatTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        raise NotImplementedError
+
+    @classmethod
+    def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> "LLaDABlock":
+        if config.block_type == BlockType.sequential:
+            return LLaDASequentialBlock(layer_id, config, cache)
+        elif config.block_type == BlockType.llama:
+            return LLaDALlamaBlock(layer_id, config, cache)
+        else:
+            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
+
+
+class LLaDASequentialBlock(LLaDABlock):
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__(layer_id, config, cache)
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        head_dim = config.d_model // config.n_heads
+        self.fused_dims = (
+            config.d_model,
+            config.effective_n_kv_heads * head_dim,
+            config.effective_n_kv_heads * head_dim,
+        )
+        self.att_proj = nn.Linear(
+            config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+
+    def reset_parameters(self):
+        super().reset_parameters()
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        init_weights(
+            self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
+        )
+        init_weights(
+            self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        if self._activation_checkpoint_fn is not None:
+            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
+
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past)
+
+        x = x + self.dropout(att)
+
+        og_x = x
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
+        x = self.ff_proj(x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
+        x = self.ff_out(x)
+        x = self.dropout(x)
+        x = og_x + x
+
+        return x, cache
+
+
+class LLaDALlamaBlock(LLaDABlock):
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__(layer_id, config, cache)
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        self.__cache = cache
+
+        head_dim = config.d_model // config.n_heads
+        q_proj_out_dim = config.d_model
+        k_proj_out_dim = config.effective_n_kv_heads * head_dim
+        v_proj_out_dim = config.effective_n_kv_heads * head_dim
+        self.q_proj = nn.Linear(
+            config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        self.k_proj = nn.Linear(
+            config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        self.v_proj = nn.Linear(
+            config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+
+        self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+        self.up_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+
+    def reset_parameters(self):
+        super().reset_parameters()
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+        cat="cond",
+        to_compute_mask=None,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        B, T, D = x.shape
+
+        x_normed = self.attn_norm(x)
+        q = self.q_proj(x_normed)
+        k = self.k_proj(x_normed)
+        v = self.v_proj(x_normed)
+
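+        # With caching, only positions selected by to_compute_mask are recomputed; cached key/value features fill in the rest.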
+        if use_cache:
+            if cat not in self.cache["k"]:
+                self.cache["k"][cat] = torch.zeros_like(x)
+                self.cache["v"][cat] = torch.zeros_like(x)
+            if to_compute_mask is not None:
+                self.cache["k"][cat][to_compute_mask] = k.view(-1, D)
+                self.cache["v"][cat][to_compute_mask] = v.view(-1, D)
+                k = self.cache["k"][cat]
+                v = self.cache["v"][cat]
+            else:
+                self.cache["k"][cat] = k
+                self.cache["v"][cat] = v
+
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, to_compute_mask=to_compute_mask
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, to_compute_mask=to_compute_mask)
+
+        x = x + self.dropout(att)
+
+        og_x = x
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
+        x, x_up = self.ff_proj(x), self.up_proj(x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
+        x = x * x_up
+        x = self.ff_out(x)
+        x = self.dropout(x)
+        x = og_x + x
+
+        return x, cache
+
+
+class LLaDAOutput(NamedTuple):
+    logits: torch.FloatTensor
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
+    hidden_states: Optional[Tuple[torch.Tensor]]
+
+
+class LLaDAGenerateOutput(NamedTuple):
+    token_ids: torch.LongTensor
+    scores: torch.FloatTensor
+
+
+class LLaDABlockGroup(nn.ModuleList):
+    def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
+        super().__init__(modules)
+        self.config = config
+        self.layer_offset = layer_offset
+        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
+        self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.FloatTensor] = None,
+        layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
+        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+        for block_idx, block in enumerate(self):
+            layer_past = None if layers_past is None else layers_past[block_idx]
+            block_idx += self.layer_offset
+            if (
+                (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
+                    and block_idx % 2 == 0
+                )
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
+                    and block_idx % 3 == 0
+                )
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
+                    and block_idx % 4 == 0
+                )
+            ):
+                x, cache = self._activation_checkpoint_fn(  # type: ignore
+                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                )
+            else:
+                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.append(cache)
+        return x, attn_key_values
+
+    def reset_parameters(self):
+        for block in self:
+            block.reset_parameters()
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        self.activation_checkpointing_strategy = strategy
+        for block in self:
+            block.set_activation_checkpointing(strategy)
+
+
+class LLaDAModel(nn.Module):
+    def __init__(self, config: ModelConfig, init_params: bool = True):
+        super().__init__()
+        self.config = config
+        self.__cache = BufferCache()
+
+        if self.config.alibi and self.config.flash_attention:
+            raise Exception("ALiBi is currently not supported with FlashAttention")
+
+        if self.config.alibi and self.config.rope:
+            raise Exception("ALiBi and RoPE are mutually exclusive")
+
+        if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
+            if self.config.embedding_size < self.config.vocab_size:
+                raise Exception("embedding size should be at least as big as vocab size")
+            elif self.config.embedding_size % 128 != 0:
+                import warnings
+
+                warnings.warn(
+                    "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
+                )
+
+        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
+        self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
+
+        if not (
+            0 < self.config.block_group_size <= self.config.n_layers
+            and self.config.n_layers % self.config.block_group_size == 0
+        ):
+            raise Exception("n layers must be divisible by block group size")
+
+        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it
+
+        self.transformer = nn.ModuleDict(
+            dict(
+                wte=nn.Embedding(
+                    config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
+                ),
+                emb_drop=Dropout(config.embedding_dropout),
+                ln_f=LayerNorm.build(config),
+            )
+        )
+
+        blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
+        if self.config.block_group_size > 1:
+            block_groups = [
+                LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
+                for i in range(0, config.n_layers, config.block_group_size)
+            ]
+            self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
+        else:
+            self.transformer.update({"blocks": nn.ModuleList(blocks)})
+
+        if not (self.config.alibi or self.config.rope):
+            self.transformer.update(
+                {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
+            )
+        if not config.weight_tying:
+            self.transformer.update(
+                {
+                    "ff_out": nn.Linear(
+                        config.d_model,
+                        config.embedding_size or config.vocab_size,
+                        bias=config.include_bias,
+                        device=config.init_device,
+                    )
+                }
+            )
+        if init_params and self.config.init_device != "meta":
+            self.reset_parameters()
+        self.__num_fwd_flops: Optional[int] = None
+
+        if self.config.alibi:
+            get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
+            self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
+
+        self.logit_cache = {}
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        self.activation_checkpointing_strategy = strategy
+        if self.config.block_group_size != 1:
+            for block_group in self.transformer.block_groups:
+                block_group.set_activation_checkpointing(strategy)
+        else:
+            for block in self.transformer.blocks:
+                block.set_activation_checkpointing(strategy)
+
+    @property
+    def device(self) -> torch.device:
+        device: torch.device = self.transformer.wte.weight.device  # type: ignore
+        if device.type == "meta":
+            return _non_meta_init_device(self.config)
+        else:
+            return device
+
+    def reset_parameters(self):
+        logger.info("Initializing model parameters...")
+        init_weights(
+            self.config,
+            self.transformer.wte,  # type: ignore
+            std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
+            type_of_module=ModuleType.emb,
+        )
+        if hasattr(self.transformer, "wpe"):
+            init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb)  # type: ignore
+
+        self.transformer.ln_f.reset_parameters()  # type: ignore
+
+        if hasattr(self.transformer, "ff_out"):
+            init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out)  # type: ignore
+
+        if self.config.block_group_size == 1:
+            for block in self.transformer.blocks:
+                block.reset_parameters()
+        else:
+            for block_group in self.transformer.block_groups:
+                block_group.reset_parameters()
+
+    def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
+            -1
+        ] >= seq_len:
+            if alibi_bias.device != device:
+                alibi_bias = alibi_bias.to(device)
+                self.__cache["alibi_attention_bias"] = alibi_bias
+            return alibi_bias
+        with torch.autocast(device.type, enabled=False):
+            alibi_bias = alibi_attention_bias(seq_len, self.config, device)
+        self.__cache["alibi_attention_bias"] = alibi_bias
+        return alibi_bias
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        input_embeddings: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
+        last_logits_only: bool = False,
+        output_hidden_states: Optional[bool] = None,
+        use_cache=False,
+        to_compute_mask=None,
+        cat="",
+    ) -> LLaDAOutput:
+        if use_cache and to_compute_mask is not None:
+            input_ids = input_ids[to_compute_mask].view(input_ids.shape[0], -1)
+
+        assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM."
+        assert self.config.rope, "Rope must be used in Llama-Encoder for MDM."
+
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+        if past_key_values:
+            assert len(past_key_values) == self.config.n_layers
+
+        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+        if past_key_values is None:
+            past_length = 0
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
+
+        if self.config.input_emb_norm:
+            x = x * (self.config.d_model**0.5)
+
+        if not (self.config.alibi or self.config.rope):
+            pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
+            pos_emb = self.transformer.wpe(pos)  # type: ignore
+            x = pos_emb + x
+
+        x = self.transformer.emb_drop(x)  # type: ignore
+
+        if attention_mask is not None and 0.0 in attention_mask:
+            attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
+            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
+        else:
+            attention_mask = None
+
+        if (
+            attention_bias is not None
+            or attention_mask is not None
+            or self.config.alibi
+            or past_key_values is not None
+        ):
+            if attention_bias is None and self.config.alibi:
+                attention_bias = get_causal_attention_bias(
+                    self.__cache, past_length + seq_len, x.device
+                ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
+            elif attention_bias is None:
+                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+            elif attention_bias.dtype in (torch.int8, torch.bool):
+                attention_bias = attention_bias.to(dtype=torch.float)
+                attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
+
+            mask_len = seq_len
+            if attention_mask is not None:
+                mask_len = attention_mask.shape[-1]
+            elif past_key_values is not None:
+                mask_len = past_key_values[0][0].shape[-2] + seq_len
+            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
+
+            if attention_mask is not None:
+                attention_bias = attention_bias + attention_mask
+                ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
+
+        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+
+        all_hidden_states = []
+
+        if self.config.block_group_size == 1:
+            for block_idx, block in enumerate(self.transformer.blocks):
+                if output_hidden_states:
+                    all_hidden_states.append(x)
+
+                layer_past = None if past_key_values is None else past_key_values[block_idx]
+                if (
+                    (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
+                    or (
+                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
+                        and block_idx % 2 == 0
+                    )
+                    or (
+                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
+                        and block_idx % 3 == 0
+                    )
+                    or (
+                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
+                        and block_idx % 4 == 0
+                    )
+                ):
+                    x, _ = self._activation_checkpoint_fn(
+                        block, x, attention_bias=attention_bias, layer_past=layer_past, to_compute_mask=to_compute_mask, use_cache=use_cache, cat=cat
+                    )
+                else:
+                    # non-checkpointed path: calls the block directly (LLaDALlamaBlock.forward)
+                    x, _ = block(
+                        x, attention_bias=attention_bias, layer_past=layer_past, to_compute_mask=to_compute_mask, use_cache=use_cache, cat=cat
+                    )
+        else:
+            for group_idx, block_group in enumerate(self.transformer.block_groups):
+                if output_hidden_states:
+                    all_hidden_states.append(x)
+
+                layers_past = (
+                    None
+                    if past_key_values is None
+                    else past_key_values[
+                        group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
+                    ]
+                )
+                x, _ = block_group(
+                    x, attention_bias=attention_bias, layers_past=layers_past, to_compute_mask=to_compute_mask, use_cache=use_cache, cat=cat
+                )
+
+        if last_logits_only:
+            x = x[:, -1, :].unsqueeze(1)
+
+        x = self.transformer.ln_f(x)  # type: ignore
+        if output_hidden_states:
+            all_hidden_states.append(x)
+
+        if self.config.weight_tying:
+            logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
+        else:
+            logits = self.transformer.ff_out(x)  # type: ignore
+        if self.config.scale_logits:
+            logits.mul_(1 / math.sqrt(self.config.d_model))
+
+        if use_cache:
+            if cat not in self.logit_cache:
+                self.logit_cache[cat] = torch.zeros_like(logits)
+            if to_compute_mask is not None:
+                self.logit_cache[cat][to_compute_mask] = logits.view(-1, logits.shape[-1])
+                logits = self.logit_cache[cat]
+            else:
+                self.logit_cache[cat] = logits
+
+        return LLaDAOutput(
+            logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None
+        )  # type: ignore[arg-type]
+
+    def caching(self, enable: bool = True):
+        # delegate to each block's caching() (see LLaDABlock.caching)
+        for block in self.transformer.blocks:
+            block.caching(enable)
+        self.logit_cache = {}
+
+    def empty_cache(self):
+        for block in self.transformer.blocks:
+            block.init_cache()
+        self.logit_cache = {}
+
+
+def create_model_config_from_pretrained_config(config: LLaDAConfig):
+    kwargs = {}
+    for field in fields(ModelConfig):
+        kwargs[field.name] = getattr(config, field.name)
+
+    model_config = ModelConfig(**kwargs)
+    return model_config
+
+
+class LLaDAModelLM(PreTrainedModel):
+    config_class = LLaDAConfig
+    base_model_prefix = "model"
+    _no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
+
+    def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
+        super().__init__(config)
+
+        if not model:
+            model_config = create_model_config_from_pretrained_config(config)
+            model_config.init_device = "cpu"
+            self.model = LLaDAModel(model_config, init_params=init_params)
+        else:
+            self.model = model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[Cache] = None,  # This is a hack mitigation of an issue in transformers `4.39.x`
+        use_cache=False,
+        to_compute_mask=None,
+        cat="",
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if output_attentions:
+            raise ValueError("output_attentions is not yet supported in LLaDA")
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            input_embeddings=inputs_embeds,
+            attention_mask=attention_mask,
+            attention_bias=attention_bias,
+            past_key_values=past_key_values,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            to_compute_mask=to_compute_mask,
+            cat=cat,
+        )
+
+        logits = outputs.logits
+        hidden_states = outputs.hidden_states
+
+        loss = None
+        if labels is not None:
+            import warnings
+
+            warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            logits=logits,
+            past_key_values=outputs.attn_key_values,
+            hidden_states=hidden_states,
+        )
+
+    def can_generate(self) -> bool:
+        return True
+
+    def prepare_inputs_for_generation(
+        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
+
+        model_inputs.update(kwargs)
+        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
+        return model_inputs
+
+    def get_input_embeddings(self) -> torch.nn.Module:
+        return self.model.transformer.wte
+
+    def set_input_embeddings(self, value: torch.nn.Module):
+        self.model.transformer.wte = value
+
+    def get_output_embeddings(self):
+        if self.config.weight_tying:
+            return self.model.transformer.wte
+        else:
+            return self.model.transformer.ff_out
+
+    def set_output_embeddings(self, value: torch.nn.Module):
+        if self.config.weight_tying:
+            self.model.transformer.wte = value
+        else:
+            self.model.transformer.ff_out = value
+
+    def tie_weights(self):
+        if self.config.weight_tying:
+            self.model.transformer.ff_out = self.model.transformer.wte
+
+    def caching(self, enable: bool = True):
+        self.model.caching(enable)
+
+    def empty_cache(self):
+        self.model.empty_cache()
+
+
+
+
+def create_attention_mask(original_lengths, max_tokens, device):
+    batch_size = len(original_lengths)
+    attention_mask = torch.zeros(batch_size, max_tokens, dtype=torch.bool, device=device)
+    for i, length in enumerate(original_lengths):
+        attention_mask[i, :length] = 1
+    return attention_mask
+
+
+class LLaDAForMultiModalGeneration(LLaDAModelLM):
+    config_class = LLaDAConfig
+    base_model_prefix = "model"
+
+    def __init__(self, config: LLaDAConfig, *args, **kwargs):
+        logger.info(f"Initializing MMadaModelLM with config: {config}")
+        super().__init__(config, *args, **kwargs)
+
+    def forward(self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat="", **kwargs):
+        input_ids = input_ids.tolist()
+        max_tokens = max([len(_) for _ in input_ids])
+        original_lengths = [len(example) for example in input_ids]
+        input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids]
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device)
+        attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)
+
+        output = LLaDAModelLM.forward(
+            self, input_ids=input_ids, attention_mask=attention_mask, use_cache=use_cache, to_compute_mask=to_compute_mask, cat=cat
+        )
+        if infer:
+            return output
+
+    def get_fsdp_wrap_module_list(self) -> List:
+        modules = [*list(self.model.transformer.blocks), self.model.transformer.ff_out]
+        return modules
+
+
+AutoConfig.register("llada", LLaDAConfig)
+
+# --- End of model definition ---
+
+EXAMPLE_DOC_STRING = """
+Examples:
+```py
+>>> import torch
+>>> from PIL import Image
+>>> from diffusers import VQModel, DiffusionPipeline
+>>> from transformers import AutoTokenizer
+>>> from diffusers.utils import load_image
+
+>>> CHECKPOINT = "Alpha-VLLM/Lumina-DiMOO"
+
+>>> # Load VQ-VAE and tokenizer
+>>> vqvae = VQModel.from_pretrained(CHECKPOINT, subfolder="vqvae").to(device="cuda", dtype=torch.bfloat16)
+>>> tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT, trust_remote_code=True)
+
+>>> # Initialize the Lumina-DiMOO pipeline
+>>> pipe = DiffusionPipeline.from_pretrained(
+...     CHECKPOINT,
+...     custom_pipeline="lumina_dimoo",
+...     vqvae=vqvae,
+...     tokenizer=tokenizer,
+...     torch_dtype=torch.bfloat16
+... )
+>>> pipe.to("cuda")
+
+>>> # Load input image
+>>> input_image = Image.open("path/to/your/ref_image.png").convert("RGB")
+
+>>> prompt = "your prompt"
+
+>>> # Run image-to-image generation
+>>> out = pipe(
+...     prompt=prompt,
+...     image=input_image,
+...     edit_type="depth_control",
+...     num_inference_steps=64,
+...     task="image_to_image",
+... )
+
+>>> out.images[0].save("i2i_test_output.png")
+"""
+
+
+# --- Helper functions ---
+
+
+def cosine_schedule(t):
+    return torch.cos(t * math.pi / 2)
+
+
+def gumbel_noise(t: torch.Tensor, *, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+    if generator is None:
+        u = torch.rand_like(t)
+    else:
+        u = torch.rand(t.shape, device=t.device, dtype=t.dtype, generator=generator)
+    return -torch.log(-torch.log(u + 1e-20) + 1e-20)
+
+
+def add_gumbel_noise(logits, temperature):
+    """
+    Gumbel noise addition function
+    According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality
+    Therefore using float64
+    """
+    if temperature == 0:
+        return logits
+    logits = logits.to(torch.float64)
+    noise = torch.rand_like(logits, dtype=torch.float64)
+    gumbel_noise = (- torch.log(noise)) ** temperature
+    return logits.exp() / gumbel_noise
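+
+# Illustrative note (not part of the original code): for temperature > 0 the function
+# returns exp(logits) / (-log(u)) ** temperature with u ~ Uniform(0, 1); taking the
+# argmax of that quantity is a Gumbel-Max style draw whose randomness grows with
+# temperature, while temperature == 0 returns the raw logits so argmax is greedy.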
+
+
+def gumbel_max_sample(logits, temperature=1.0, generator=None):
+    if temperature == 0.0:
+        return logits.argmax(dim=-1)
+    gumbel_noise_ = gumbel_noise(logits, generator=generator)
+    return torch.argmax(logits / temperature + gumbel_noise_, dim=-1)
+
+def get_num_transfer_tokens(mask_index, steps):
+    """
+    In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals
+    Since LLaDA employs a linear noise schedule (as defined in Eq.(8)),
+    the expected number of tokens transitioned at each step should be consistent
+    
+    This function is designed to precompute the number of tokens that need to be transitioned at each step
+    """
+    mask_num = mask_index.sum(dim=1, keepdim=True)
+
+    base = mask_num // steps
+    remainder = mask_num % steps
+
+    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+
+    for i in range(mask_num.size(0)):
+        num_transfer_tokens[i, :remainder[i]] += 1
+
+    return num_transfer_tokens
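+
+# Worked example (comment only): with 10 masked tokens and steps=4, base=2 and
+# remainder=2, so the per-step transfer counts are [3, 3, 2, 2]; the remainder is
+# spread over the earliest steps.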
+
+
+def mask_by_random_topk(keep_n, probs, temperature=1.0, generator=None):
+    B, S = probs.shape
+    noise = gumbel_noise(probs, generator=generator)
+
+    conf = probs / temperature + noise
+
+    mask = torch.zeros_like(conf, dtype=torch.bool)
+    for i in range(B):
+        k = keep_n[i]
+        if k > 0:
+            top_k_indices = torch.topk(conf[i], k, largest=True).indices
+            mask[i, top_k_indices] = True
+    return mask
+
+
+def calculate_vq_params(height, width, vae_scale_factor=32):
+    token_grid_height = height // vae_scale_factor
+    token_grid_width = width // vae_scale_factor
+    seq_len = token_grid_height * token_grid_width
+    newline_every = token_grid_width
+    return seq_len, newline_every, token_grid_height, token_grid_width
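+
+# Worked example (comment only): a 1024x768 (height x width) image with
+# vae_scale_factor=32 yields a 32x24 token grid, so seq_len=768 image tokens and
+# newline_every=24; add_break_line later inserts a newline token between rows.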
+
+
+def add_break_line(tokens, token_grid_height, token_grid_width, new_number):
+    new_tokens = []
+    for i in range(token_grid_height):
+        start = i * token_grid_width
+        end = (i + 1) * token_grid_width
+        row = tokens[start:end]
+        new_tokens.extend(row)
+        if i < token_grid_height - 1:
+            new_tokens.append(new_number)
+    return new_tokens
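+
+# Worked example (comment only): for a 2x3 grid, tokens [a, b, c, d, e, f] become
+# [a, b, c, NEWLINE, d, e, f]; the newline id separates rows and is not appended
+# after the final row.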
+
+
+def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
+    assert max_ratio >= 1.0
+    crop_size_list = []
+    wp, hp = num_patches, 1
+    while wp > 0:
+        if max(wp, hp) / min(wp, hp) <= max_ratio:
+            crop_size_list.append((wp * patch_size, hp * patch_size))
+        if (hp + 1) * wp <= num_patches:
+            hp += 1
+        else:
+            wp -= 1
+    return crop_size_list
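+
+# Illustrative note (comment only): the walk starts from the widest single-row crop
+# and trades width for height, keeping every (width, height) pair whose patch grid
+# fits within num_patches and whose aspect ratio is at most max_ratio; e.g.
+# generate_crop_size_list((512 // 32) ** 2, 32) contains the square crop (512, 512).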
+
+
+def center_crop(pil_image, crop_size):
+    while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
+        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+    scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
+    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+    crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
+    crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
+    crop_right = crop_left + crop_size[0]
+    crop_lower = crop_upper + crop_size[1]
+    return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))
+
+
+def var_center_crop(pil_image, crop_size_list, random_top_k=1):
+    w, h = pil_image.size
+    rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
+    crop_size = random.choice(
+        sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
+    )[1]
+    return center_crop(pil_image, crop_size)
+
+
+def preprocess_image(image: Image.Image):
+    image = image.convert("RGB")
+    w, h = image.size
+    w, h = (x - x % 32 for x in (w, h))
+    image = image.resize((w, h), resample=Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
+
+
+def encode_img_with_breaks(image, vqvae, special_tokens, vae_scale_factor: int = 16):
+    """
+    Encode image, add VQ offset, add newlines, and wrap with BOI/EOI tokens.
+    This function mirrors the logic from the original inference script.
+    """
+    orig = image.convert("RGB")
+    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
+    pixels = image_processor.preprocess(orig).to(vqvae.device, dtype=vqvae.dtype)
+    latents = vqvae.encode(pixels).latents
+
+    latents_bsz, _, lat_h, lat_w = latents.shape
+
+    quantized = vqvae.quantize(latents)[2][2] + special_tokens["image_token_offset"]
+    quantized_with_offset = quantized.reshape(latents_bsz, lat_h, lat_w).flatten().tolist()
+
+    tokens_with_breaks = add_break_line(
+        quantized_with_offset, lat_h, lat_w, special_tokens["newline_token"]
+    )
+    return [special_tokens["boi"]] + tokens_with_breaks + [special_tokens["eoi"]]
+
+
+def create_prompt_templates():
+    """Create prompt templates for various tasks based on prompt_utils.py"""
+    templates = {
+        "text_understanding": "You are a multimodal model that can process both text and images. Answer the following question based on the provided images. Analyze each image and combine relevant details to answer.",
+        "image_generation": "Generate an image according to the text prompt.",
+        "image_editing": "Generate an image applying the following editing instruction based on the original image.",
+        "dense_prediction": "Perform dense prediction on the given images.",
+        "control_generation": "Generate an image according to the text prompt and the given control image.",
+        "subject_generation": "Generate an image according to the text prompt and the given object image.",
+        "multi_view": "Generate a view-image based on the given image.",
+        "style_transfer": "Transform the current image into the style of the provided image.",
+    }
+    return templates
+
+
+def generate_image_to_image_prompt(prompt_text, edit_type, templates):
+    """
+    Generate prompt for image-to-image generation based on prompt_utils.py
+    """
+    if "dense" in edit_type or "canny_pred" in edit_type:
+        des = {
+            "canny": "canny edge map",
+            "hed": "hed edge map",
+            "normal": "normal map",
+            "sam2mask": "sam2 mask",
+            "depth": "depth map",
+            "openpose": "pose estimation map",
+        }
+        system_prompt = templates["dense_prediction"]
+        prompt_text_used = f"Generate a {des.get(edit_type.split('_')[0], 'dense map')} according to the image."
+
+    elif "control" in edit_type:
+        system_prompt = templates["control_generation"]
+        prompt_text_used = prompt_text
+
+    elif "subject" in edit_type:
+        system_prompt = templates["subject_generation"]
+        prompt_text_used = prompt_text
+
+    elif "edit" in edit_type:
+        system_prompt = templates["image_editing"]
+        prompt_text_used = prompt_text
+
+    elif "ref_transfer" in edit_type or "image_ref_transfer" in edit_type:
+        system_prompt = templates["style_transfer"]
+        prompt_text_used = "Transform the current image into the style of the provided image."
+
+    elif "multi_view" in edit_type:
+        system_prompt = templates["multi_view"]
+        prompt_text_used = f"Generate the {edit_type.split('_')[-1]} view based on the provided front view."
+
+    else:
+        system_prompt = "Generate an image according to the prompt and image."
+        prompt_text_used = prompt_text
+
+    input_prompt = "" + system_prompt + "" + "" + prompt_text_used + ""
+    uncon_prompt = "" + system_prompt + "" + "" + "" + ""
+
+    return input_prompt, uncon_prompt
+
+def generate_text_to_image_prompt(prompt_text: str, templates: Optional[Dict] = None) -> Tuple[str, str]:
+    """
+    Generate prompt for text-to-image generation
+    
+    Args:
+        prompt_text: User input text prompt
+        templates: Optional prompt templates dict
+        
+    Returns:
+        Tuple of (input_prompt, unconditional_prompt)
+    """
+    if templates is None:
+        templates = create_prompt_templates()
+    
+    system_prompt = templates["image_generation"]
+    input_prompt = "" + system_prompt + "" + "" + prompt_text + ""
+    uncon_prompt = "" + system_prompt + "" + "" + "" + ""
+    
+    return input_prompt, uncon_prompt
+
+
+def generate_multimodal_understanding_prompt(question: str, templates: Optional[Dict] = None) -> str:
+    """
+    Generate prompt for multimodal understanding (MMU)
+    
+    Args:
+        question: User question about the image
+        templates: Optional prompt templates dict
+        
+    Returns:
+        Formatted input prompt
+    """
+    if templates is None:
+        templates = create_prompt_templates()
+    
+    system_prompt = "You are a multimodal model that can process both text and images. Answer the following question based on the provided images. Analyze each image and combine relevant details to answer."
+    input_prompt = "" + system_prompt + "" + "" + question + ""
+    
+    return input_prompt
+
+
+@torch.no_grad()
+def encode_img_with_paint(
+    img: Image.Image,
+    vqvae: VQModel,
+    *,
+    mask_h_ratio: float = 1.0,      # Height ratio of the mask region
+    mask_w_ratio: float = 0.2,      # Width ratio of the mask region
+    gray_value: int = 127,          # Visualization gray value
+    downsample_mode: str = "area",  # Pixel mask alignment to latent grid
+    dilate_latent_k: int = 0,       # Optional dilation on latent grid (grid count)
+    mask_mode: str = "inpainting",  # "inpainting" | "outpainting"
+    special_tokens
+):
+    """
+    Encode image with mask for inpainting/outpainting tasks
+    
+    Args:
+        img: Input PIL image
+        vqvae: VQ-VAE model for encoding
+        mask_h_ratio: Height ratio for mask region (default: 1.0)
+        mask_w_ratio: Width ratio for mask region (default: 0.2)
+        gray_value: Gray value for mask visualization (default: 127)
+        downsample_mode: Downsampling mode for mask alignment ("area", "nearest", "bilinear")
+        dilate_latent_k: Dilation kernel size for latent grid (default: 0)
+        mask_mode: Mask mode - "inpainting" (mask inside) or "outpainting" (mask outside)
+    
+    Returns:
+        img_token: List[int] - Token sequence with newlines (126084) inserted at row ends;
+                              masked positions = 126336, others = index + 126356
+        vis_img: PIL.Image - Gray mask visualization image (consistent with mask_mode)
+    
+    Note:
+        * Encoding uses original image strictly; mask only maps to latent grid to determine
+          which tokens are set to MASK_TOKEN_ID.
+        * mask_mode="inpainting": mask inside rectangle; "outpainting": mask outside rectangle (inverse).
+    """
+
+    assert mask_mode in ("inpainting", "outpainting"), "mask_mode must be 'inpainting' or 'outpainting'"
+
+    # --- 1) Calculate center rectangle and generate visualization ---
+    img = img.convert("RGB")
+    W, H = img.size
+    mh = int(round(H * mask_h_ratio))
+    mw = int(round(W * mask_w_ratio))
+    top = (H - mh) // 2
+    left = (W - mw) // 2
+    bottom = top + mh
+    right = left + mw
+
+    if mask_mode == "inpainting":
+        vis_img = img.copy()
+        draw = ImageDraw.Draw(vis_img)
+        draw.rectangle([left, top, right, bottom], fill=(gray_value, gray_value, gray_value))
+    elif mask_mode == "outpainting":  # outpainting
+        bg = Image.new("RGB", (W, H), (gray_value, gray_value, gray_value))
+        crop = img.crop((left, top, right, bottom))
+        bg.paste(crop, (left, top))
+        vis_img = bg
+
+    # --- 2) VQ encoding using original image ---
+    vae_scale_factor = 2 ** (len(vqvae.config.block_out_channels) - 1)
+    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
+    x = image_processor.preprocess(img).to(vqvae.device)  # 1 x 3 x H' x W'
+    latents = vqvae.encode(x).latents                     # 1 x C x h x w
+    _, _, lat_h, lat_w = latents.shape
+
+    # Quantization indices
+    quant_pack = vqvae.quantize(latents)
+    indices = quant_pack[2][2].view(1, lat_h, lat_w)      # 1 x h x w, long
+
+    # --- 3) Pixel mask -> latent grid mask (aligned with encoding input size) ---
+    Hp, Wp = x.shape[-2:]
+    mask_px = torch.zeros((1, 1, Hp, Wp), dtype=torch.float32, device=vqvae.device)
+    # First generate mask where "rectangle inside=1, outside=0"
+    top_p  = int(round(top  * Hp / H))
+    left_p = int(round(left * Wp / W))
+    bh_p   = int(round(mh   * Hp / H))
+    bw_p   = int(round(mw   * Wp / W))
+    mask_px[:, :, top_p:top_p+bh_p, left_p:left_p+bw_p] = 1.0
+
+    # If outpainting, need to invert (outside=1, inside=0 is the masked region)
+    if mask_mode == "outpainting":
+        mask_px = 1.0 - mask_px
+
+    if downsample_mode not in ("nearest", "area", "bilinear"):
+        downsample_mode = "area"
+    mask_lat = F.interpolate(mask_px, size=(lat_h, lat_w), mode=downsample_mode)
+    mask_lat = (mask_lat > 0.5) if downsample_mode == "area" else (mask_lat >= 0.5)
+    mask_lat = mask_lat[0, 0]        # h x w (bool)
+
+    # Optional: latent grid dilation (after inversion is applied)
+    if dilate_latent_k > 0:
+        m = mask_lat.float().unsqueeze(0).unsqueeze(0)
+        ker = 2 * dilate_latent_k + 1
+        m = F.max_pool2d(m, kernel_size=ker, stride=1, padding=dilate_latent_k)
+        mask_lat = (m[0, 0] > 0.5)
+
+    # --- 4) Generate tokens: masked positions=MASK_TOKEN_ID, others=indices+VQ_OFFSET ---
+    idx_flat = indices.view(-1)
+    mask_flat = mask_lat.view(-1)
+    tokens = torch.empty_like(idx_flat)
+    tokens[mask_flat] = special_tokens['mask_token']
+    tokens[~mask_flat] = idx_flat[~mask_flat] + special_tokens['image_token_offset']
+    tokens_list = tokens.tolist()
+
+    # --- 5) Insert newline tokens (BOI/EOI wrapping is left to the caller) ---
+
+    img_token = add_break_line(tokens_list, lat_h, lat_w, special_tokens['newline_token'])
+    return img_token, vis_img
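+
+# Illustrative note (comment only): with mask_mode="inpainting", mask_h_ratio=1.0 and
+# mask_w_ratio=0.2, a centered vertical strip covering 20% of the width is replaced by
+# MASK tokens while the remaining positions keep their offset VQ indices;
+# mask_mode="outpainting" inverts the region and masks everything outside the strip.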
+
+
+class LuminaDiMOOPipelineOutput(BaseOutput):
+    """
+    Output class for the Lumina-DiMOO pipeline.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`, *optional*):
+            List of generated PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+            num_channels)`.
+        text (`str`, *optional*):
+            Generated text from the multimodal understanding task.
+    """
+
+    images: Optional[Union[List[Image.Image], np.ndarray]] = None
+    text: Optional[str] = None
+
+
+class LuminaDiMOOPipeline(DiffusionPipeline):
+    """
+    A unified pipeline for Text-to-Image, Image-to-Image, and Multimodal Understanding
+    using the Lumina-DiMOO model.
+
+    This model was contributed by https://huggingface.co/Alpha-VLLM
+
+    Args:
+        llm ([`LLaDAForMultiModalGeneration`]):
+            The core LLM for multimodal generation, e.g., `LLaDAForMultiModalGeneration`.
+        vqvae ([`VQModel`]):
+            Vector Quantized Variational Auto-Encoder (VQ-VAE) model to encode and decode images to and from discrete
+            latent representations.
+        tokenizer ([`AutoTokenizer`]):
+            An `AutoTokenizer` to tokenize text prompts.
+    """
+
+    def __init__(
+        self,
+        vqvae: VQModel,
+        tokenizer: AutoTokenizer,
+        checkpoint: Optional[str] = "Alpha-VLLM/Lumina-DiMOO",
+        torch_dtype: Optional[torch.dtype] = torch.bfloat16,
+        device_map: Optional[str] = "auto",         
+        low_cpu_mem_usage: bool = True, 
+    ):
+        super().__init__()
+        self.register_modules(
+            vqvae=vqvae,
+            tokenizer=tokenizer,
+        )
+        
+        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+        self.special_tokens = {
+            "mask_token": 126336,
+            "newline_token": 126084,
+            "boa": 126354,
+            "eoa": 126355,
+            "boi": 126349,
+            "eoi": 126350,
+            "image_token_offset": 126356,
+            "uncondition":126351
+        }
+        self.prompt_templates = create_prompt_templates()
+
+        # If `checkpoint` is not provided, fall back to the model path used with `from_pretrained`
+        if checkpoint is None:
+            checkpoint = getattr(self, "_name_or_path", None)
+            if checkpoint is None:
+                raise ValueError(
+                    "A `checkpoint` path must be provided to load the LLM, either directly or via `from_pretrained`."
+                )
+
+        print("[Lumina] start loading LLaDA ...")
+        self.llm = LLaDAForMultiModalGeneration.from_pretrained(
+            checkpoint, torch_dtype=torch_dtype, trust_remote_code=True,
+            device_map=device_map,                   
+            low_cpu_mem_usage=low_cpu_mem_usage,    
+            use_safetensors=True,
+        )
+        print("   LlaDA Loaded Successfully.")
+
+    @staticmethod
+    @torch.no_grad()
+    def generate_i2i(
+        model: LLaDAForMultiModalGeneration,
+        prompt: torch.LongTensor,
+        *,
+        seq_len: int = 1024,
+        newline_every: int = 16,
+        timesteps: int = 18,
+        mask_token_id: int = 126336,
+        newline_id: int = 126084,
+        temperature: float = 1.0,
+        cfg_scale: float = 0.0,
+        cfg_img: float = 0.0,
+        uncon_text: torch.LongTensor,
+        uncon_image: torch.LongTensor,
+        code_start: Optional[int] = None,
+        codebook_size: int = 8192,
+        noise_schedule: Callable[[torch.Tensor], torch.Tensor] = cosine_schedule,
+        text_vocab_size: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.LongTensor:
+        """
+        Image-to-image MaskGit generation (supports CFG for text and image)
+        
+        Args:
+            model: Model
+            prompt: Prompt tensor
+            seq_len: Sequence length
+            newline_every: Newline interval per row
+            timesteps: Number of timesteps
+            mask_token_id: Mask token id
+            newline_id: Newline token id
+            temperature: Temperature
+            cfg_scale: Text CFG scale
+            cfg_img: Image CFG scale
+            code_start: Start index of the predicted image tokens
+            uncon_text: Unconditional text input
+            uncon_image: Unconditional image input
+            codebook_size: Codebook size
+            noise_schedule: Noise schedule function
+            text_vocab_size: Text vocabulary size
+            generator: Random number generator
+        
+        Returns:
+            Final VQ codes (1, seq_len)
+        """
+        device = next(model.parameters()).device
+        prompt = prompt.to(device)
+        B, P = prompt.shape
+        assert B == 1, "batch>1 not supported – wrap in loop if needed"
+        
+        x = prompt
+
+        vq_mask = x == mask_token_id
+        unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
+        vq_len = unknown_cnt
+
+        # Infer text vocabulary size
+        if text_vocab_size is None:
+            vocab_total = model(torch.zeros(1, 1, dtype=torch.long, device=device), infer=True).logits.size(-1)
+            text_vocab_size = vocab_total - codebook_size
+        vocab_offset = text_vocab_size
+
+        for step in range(timesteps):
+            if unknown_cnt.item() == 0:
+                break
+
+            # Calculate number of tokens to keep (continue masking) this round
+            if step < timesteps - 1:
+                frac = noise_schedule(torch.tensor([(step + 1) / timesteps], device=device))
+                keep_n = (vq_len.float() * frac).floor().clamp_min(1).long()
+            else:
+                keep_n = torch.zeros_like(unknown_cnt)
+
+            # Forward pass (with/without CFG)
+            if cfg_scale > 0 or cfg_img > 0:
+                # CFG text
+                uncond_text = torch.cat((uncon_text.to(x.device), x[:, code_start-2:]), dim=1)
+                uncond_text_vq_mask = torch.cat((torch.zeros((1, uncon_text.size(1)), dtype=torch.bool, device=x.device), vq_mask[:, code_start-2:]), dim=1)
+                # CFG image
+                uncond_img = torch.cat((uncon_image.to(x.device), x[:, code_start-2:]), dim=1)
+                uncond_img_vq_mask = torch.cat((torch.zeros((1, uncon_image.size(1)), dtype=torch.bool, device=x.device), vq_mask[:, code_start-2:]), dim=1)
+
+                cond_logits = model(x, infer=True).logits[:, vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+                uncond_logits_text = model(uncond_text, infer=True).logits[:, uncond_text_vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+                uncond_logits_img = model(uncond_img, infer=True).logits[:, uncond_img_vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+                logits = cond_logits + cfg_scale * (cond_logits - uncond_logits_text) + cfg_img * (cond_logits - uncond_logits_img)
+            else:
+                logits = model(x, infer=True).logits[:, vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+
+            sampled = gumbel_max_sample(logits, temperature, generator=generator)
+            sampled_full = sampled + vocab_offset
+            probs = torch.softmax(logits, dim=-1)
+            conf = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
+
+            flat_idx = vq_mask.nonzero(as_tuple=False)[:, 1]
+            x.view(-1)[flat_idx] = sampled_full.view(-1)
+
+            conf_map = torch.full_like(x, -math.inf, dtype=probs.dtype)
+            conf_map.view(-1)[flat_idx] = conf.view(-1)
+
+            mask_sel = mask_by_random_topk(keep_n.squeeze(1), conf, temperature=temperature, generator=generator)
+            x.view(-1)[flat_idx[mask_sel.view(-1)]] = mask_token_id
+            vq_mask = x == mask_token_id
+            unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
+
+        # Remove newline tokens
+        vq_ids = x[0, code_start:-2]
+        vq_ids = vq_ids[vq_ids != newline_id].view(1, seq_len)
+        return vq_ids
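+
+    # Guidance sketch (comment only): when either scale is positive, the sampled logits are
+    # cond + cfg_scale * (cond - uncond_text) + cfg_img * (cond - uncond_image),
+    # pushing the prediction away from the text-unconditional and image-unconditional
+    # branches independently.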
+
+
+    @staticmethod
+    @torch.no_grad()
+    def generate_image(
+        model: LLaDAForMultiModalGeneration,
+        prompt: torch.LongTensor,
+        *,
+        seq_len: int = 1024,
+        newline_every: int = 16,
+        timesteps: int = 18,
+        mask_token_id: int = 126336,
+        newline_id: int = 126084,
+        temperature: float = 1.0,
+        cfg_scale: float = 0.0,
+        uncon_ids: torch.LongTensor,
+        code_start: Optional[int] = None,
+        codebook_size: int = 8192,
+        noise_schedule: Callable[[torch.Tensor], torch.Tensor] = cosine_schedule,
+        text_vocab_size: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
+        use_cache=True,
+        cache_ratio=0.9,
+        refresh_interval=5,
+        warmup_ratio=0.3
+    ) -> torch.LongTensor:
+        """
+        MaskGit parallel decoding to generate VQ tokens
+        
+        Args:
+            model: Model
+            prompt: Prompt tensor
+            seq_len: Sequence length
+            newline_every: Newline interval per row
+            timesteps: Number of timesteps
+            mask_token_id: Mask token id
+            newline_id: Newline token id
+            temperature: Temperature
+            cfg_scale: CFG scale
+            uncon_ids: Unconditional input
+            code_start: Start index of the image tokens
+            codebook_size: Codebook size
+            noise_schedule: Noise schedule function
+            text_vocab_size: Text vocabulary size
+            generator: Random number generator
+        
+        Returns:
+            Final VQ codes (1, seq_len)
+        """
+
+
+        device = next(model.parameters()).device
+        prompt = prompt.to(device)
+        B, P = prompt.shape
+        assert B == 1, "batch>1 not supported – wrap in loop if needed"
+
+        x = prompt
+        
+        vq_mask = x == mask_token_id
+        unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
+        vq_len = unknown_cnt
+
+        if isinstance(model, LLaDAForMultiModalGeneration):
+            model.caching(use_cache)
+        else:  # DDP
+            model.module.caching(use_cache)
+
+        warmup_step = int(timesteps * warmup_ratio)
+        refresh_steps = torch.zeros(timesteps, dtype=torch.bool)
+        for step in range(timesteps):
+            if not use_cache or step <= warmup_step or (step-warmup_step) % refresh_interval == 0:
+                refresh_steps[step] = True
+        compute_ratio = 1 - cache_ratio
+
+        # Infer text vocabulary size
+        if text_vocab_size is None:
+            vocab_total = model(torch.zeros(1, 1, dtype=torch.long, device=device), infer=True).logits.size(-1)
+            text_vocab_size = vocab_total - codebook_size
+        vocab_offset = text_vocab_size
+
+        for step in range(timesteps):
+            if unknown_cnt.item() == 0:
+                break
+
+            # Calculate number of tokens to keep (continue masking) this round
+            if step < timesteps - 1:
+                frac = noise_schedule(torch.tensor([(step + 1) / timesteps], device=device))
+                keep_n = (vq_len.float() * frac).floor().clamp_min(1).long()
+            else:
+                keep_n = torch.zeros_like(unknown_cnt)
+
+            if use_cache and step and refresh_steps[step]:
+                if isinstance(model, LLaDAForMultiModalGeneration):
+                    model.empty_cache()
+                else:  # DDP
+                    model.module.empty_cache()
+
+            # Forward pass (with/without CFG)
+            if cfg_scale > 0:
+                uncond = torch.cat((uncon_ids.to(x.device), x[:, code_start-2:]), axis=1)
+                uncond_vq_mask = torch.cat((torch.zeros((1, uncon_ids.size()[1]), dtype=torch.bool).to(x.device), vq_mask[:, code_start-2:]), axis=1)
+                cond_logits = model(x, infer=True,
+                        cat='cond', use_cache=use_cache, 
+                        to_compute_mask = cond_to_compute_mask if not refresh_steps[step] else None,
+                    ).logits[..., vocab_offset : vocab_offset + codebook_size]
+                cond_mask_logits = cond_logits[vq_mask].view(B, -1, codebook_size)
+                uncond_logits = model(uncond, infer=True,
+                        cat='uncond', use_cache=use_cache, 
+                        to_compute_mask = uncond_to_compute_mask if not refresh_steps[step] else None
+                    ).logits[..., vocab_offset : vocab_offset + codebook_size]
+                uncond_mask_logits = uncond_logits[uncond_vq_mask].view(B, -1, codebook_size)
+                logits = (1 + cfg_scale) * cond_mask_logits - cfg_scale * uncond_mask_logits
+            else:
+                logits = model(x, infer=True).logits[:, vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+
+            sampled = gumbel_max_sample(logits, temperature, generator=generator)
+            sampled_full = sampled + vocab_offset
+            probs = torch.softmax(logits, dim=-1)
+            conf = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
+
+            flat_idx = vq_mask.nonzero(as_tuple=False)[:, 1]
+            x.view(-1)[flat_idx] = sampled_full.view(-1)
+
+            conf_map = torch.full_like(x, -math.inf, dtype=probs.dtype)
+            conf_map.view(-1)[flat_idx] = conf.view(-1)
+
+            mask_sel = mask_by_random_topk(keep_n.squeeze(1), conf, temperature=temperature, generator=generator)
+            x.view(-1)[flat_idx[mask_sel.view(-1)]] = mask_token_id
+            vq_mask = x == mask_token_id
+            unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
+
+            if use_cache and step < timesteps - 1 and not refresh_steps[step+1]:
+                cond_conf = cond_logits.max(dim=-1)[0]
+                cond_conf_threshold = torch.quantile(cond_conf.to(torch.float), compute_ratio, dim=-1, keepdim=True)
+                cond_to_compute_mask = cond_conf <= cond_conf_threshold
+
+                uncond_conf = uncond_logits.max(dim=-1)[0]
+                uncond_conf_threshold = torch.quantile(uncond_conf.to(torch.float), compute_ratio, dim=-1, keepdim=True)
+                uncond_to_compute_mask = uncond_conf <= uncond_conf_threshold
+                
+        # Remove newline tokens
+        vq_ids = x[0, code_start:-2]
+        vq_ids = vq_ids[vq_ids != newline_id].view(1, seq_len)
+        return vq_ids
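+
+    # Cache schedule sketch (comment only): with timesteps=64, warmup_ratio=0.3 and
+    # refresh_interval=5, warmup_step=19, so steps 0-19 and then 24, 29, 34, ... recompute
+    # every position (refresh), while the remaining steps only recompute the
+    # (1 - cache_ratio) fraction of positions whose previous logits were least confident.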
+
+
+    @staticmethod
+    @torch.no_grad()
+    def generate_text_understanding(
+        model: LLaDAForMultiModalGeneration,
+        prompt, 
+        steps=128, 
+        gen_length=128, 
+        block_length=128, 
+        temperature=0.,
+        cfg_scale=0., 
+        remasking='low_confidence', 
+        mask_id=126336, 
+        code_start: Optional[int] = None,
+    ):
+        """
+        Text understanding generation function
+        
+        Args:
+            model: Mask predictor
+            prompt: Input prompt tensor (1, L)
+            steps: Sampling steps, less than or equal to gen_length
+            gen_length: Generated answer length
+            block_length: Block length, less than or equal to gen_length
+            temperature: Categorical distribution sampling temperature
+            cfg_scale: Unsupervised classifier-free guidance scale
+            remasking: Remasking strategy 'low_confidence' or 'random'
+            mask_id: The token id of [MASK] is 126336
+            code_start: Start index of the predicted text tokens
+        """
+        device = next(model.parameters()).device
+
+        x = prompt
+
+        prompt_index = (x != mask_id)
+
+        assert gen_length % block_length == 0
+        num_blocks = gen_length // block_length
+
+        assert steps % num_blocks == 0
+        steps = steps // num_blocks
+
+        for num_block in range(num_blocks):
+            block_mask_index = (x[:, code_start + num_block * block_length: code_start + (num_block + 1) * block_length] == mask_id)
+            num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
+            
+            for i in range(steps):
+                mask_index = (x == mask_id)
+                if cfg_scale > 0.:
+                    un_x = x.clone()
+                    un_x[prompt_index] = mask_id
+                    x_ = torch.cat([x, un_x], dim=0)
+                    logits = model(x_, infer=True).logits
+                    logits, un_logits = torch.chunk(logits, 2, dim=0)
+                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+                else:
+                    logits = model(x, infer=True).logits
+
+                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+                x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
+
+                if remasking == 'low_confidence':
+                    p = F.softmax(logits.to(torch.float64), dim=-1)
+                    x0_p = torch.squeeze(
+                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
+                elif remasking == 'random':
+                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+                else:
+                    raise NotImplementedError(remasking)
+
+                x0_p[:, code_start + (num_block + 1) * block_length:] = -np.inf
+
+                x0 = torch.where(mask_index, x0, x)
+                confidence = torch.where(mask_index, x0_p, -np.inf)
+
+                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+                for j in range(confidence.shape[0]):
+                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
+                    transfer_index[j, select_index] = True
+                x[transfer_index] = x0[transfer_index]
+            
+        
+        return x
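+
+    # Decoding sketch (comment only): with gen_length=1024, block_length=128 and steps=128,
+    # the answer is filled block by block (8 blocks of 128 tokens); each block gets
+    # 128 // 8 = 16 iterations, and per iteration the masked positions with the highest
+    # confidence scores (random scores when remasking='random') are committed, with the
+    # per-iteration count given by get_num_transfer_tokens, while the rest stay masked.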
+
+
+
+    @torch.no_grad()
+    def _image_to_image(
+        self,
+        prompt: str,
+        image: Union[Image.Image, str],
+        ref_image: Optional[PipelineImageInput] = None,
+        edit_type: str = "canny_pred",
+        num_inference_steps: int = 64,
+        temperature: float = 1.0,
+        cfg_scale: float = 2.5,
+        cfg_img: float = 4.0,
+        output_type: Optional[str] = "pil",
+    ):
+        
+        if isinstance(prompt, list):
+            raise ValueError("Batching is not supported for this pipeline.")
+
+        if isinstance(image, str):
+            image = Image.open(image).convert("RGB")
+        if isinstance(ref_image, str):
+            ref_image = Image.open(ref_image).convert("RGB")
+
+        input_prompt, uncon_text = generate_image_to_image_prompt(prompt, edit_type, self.prompt_templates)
+
+        crop_size_list = generate_crop_size_list((512 // 32) ** 2, 32)
+        
+        # Correctly encode input images with newline tokens
+        if "image_ref_transfer" in edit_type:
+            if ref_image is None:
+                raise ValueError("`ref_image` must be provided for `image_ref_transfer` edit type.")
+            processed_img = var_center_crop(image, crop_size_list=crop_size_list)
+            input_img_token = encode_img_with_breaks(processed_img, self.vqvae, self.special_tokens)
+
+            referring_img = var_center_crop(ref_image, crop_size_list=crop_size_list)
+            referring_img_token = encode_img_with_breaks(referring_img, self.vqvae, self.special_tokens)
+
+            image_width, image_height = referring_img.size
+            seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(
+                referring_img.height, referring_img.width, self.vae_scale_factor
+            )
+        else:
+            processed_img = var_center_crop(image, crop_size_list=crop_size_list)
+            input_img_token = encode_img_with_breaks(processed_img, self.vqvae, self.special_tokens)
+            image_width, image_height = processed_img.size
+            seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(
+                processed_img.height, processed_img.width, self.vae_scale_factor
+            )
+
+        prompt_ids = self.tokenizer(input_prompt)["input_ids"]
+        uncon_text_ids = self.tokenizer(uncon_text)["input_ids"]
+
+        img_mask_token = add_break_line(
+            [self.special_tokens["mask_token"]] * seq_len,
+            token_grid_height,
+            token_grid_width,
+            new_number=self.special_tokens["newline_token"],
+        )
+        img_pred_token = (
+            [self.special_tokens["boa"]]
+            + [self.special_tokens["boi"]]
+            + img_mask_token
+            + [self.special_tokens["eoi"]]
+            + [self.special_tokens["eoa"]]
+        )
+
+        if "image_ref_transfer" in edit_type:
+            con_input = prompt_ids[:-1] + input_img_token + referring_img_token + prompt_ids[-1:]
+            uncon_input_text = uncon_text_ids[:-1] + input_img_token + referring_img_token + uncon_text_ids[-1:]
+        else:
+            con_input = prompt_ids[:-1] + input_img_token + prompt_ids[-1:]
+            uncon_input_text = uncon_text_ids[:-1] + input_img_token + uncon_text_ids[-1:]
+        uncon_input_image = prompt_ids
+
+        code_start = len(con_input) + 2
+
+        con_input = torch.tensor(con_input + img_pred_token, device=self.device).unsqueeze(0)
+        uncon_input_text = torch.tensor(uncon_input_text, device=self.device).unsqueeze(0)
+        uncon_input_image = torch.tensor(uncon_input_image, device=self.device).unsqueeze(0)
+
+        vq_tokens = self.generate_i2i(
+            self.llm,
+            con_input,
+            seq_len=seq_len,
+            newline_every=newline_every,
+            timesteps=num_inference_steps,
+            temperature=temperature,
+            cfg_scale=cfg_scale,
+            cfg_img=cfg_img,
+            uncon_text=uncon_input_text,
+            uncon_image=uncon_input_image,
+            code_start=code_start
+        )
+
+        if vq_tokens.shape[1] != token_grid_height * token_grid_width:
+            raise ValueError(
+                f"VQ codes length mismatch: {vq_tokens.shape[1]} != {token_grid_height * token_grid_width} "
+                f"for image size ({image_height},{image_width}) with scale {self.vae_scale_factor}"
+            )
+
+        latents = (
+            vq_tokens.view(1, token_grid_height, token_grid_width).to(self.vqvae.device) - self.special_tokens["image_token_offset"]
+        ).long()
+
+        shape = (1, token_grid_height, token_grid_width, self.vqvae.config.latent_channels)
+
+        recon = self.vqvae.decode(
+            latents,
+            force_not_quantize=True,
+            shape=shape,
+        ).sample.clip(0, 1)
+
+        img_proc = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+        image = img_proc.postprocess(recon.detach(), output_type=output_type)
+
+        return image
+
+    @torch.no_grad()
+    def _text_to_image(
+        self,
+        prompt: str,
+        height: int,
+        width: int,
+        painting_mode: Optional[str] = None,
+        painting_image: Optional[PipelineImageInput] = None,
+        cfg_scale: float = 4.0,
+        use_cache: bool = True,
+        cache_ratio: float = 0.9,
+        refresh_interval: int = 5,
+        warmup_ratio: float = 0.3,
+        num_inference_steps: int = 64,
+        temperature: float = 1.0,
+        mask_h_ratio: float = 1.0,
+        mask_w_ratio: float = 0.2
+    ):
+        if isinstance(painting_image, str):
+            painting_image = Image.open(painting_image)
+
+        if painting_mode and painting_image:
+            width, height = painting_image.size
+
+        seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(height, width, self.vae_scale_factor)
+
+        input_prompt, uncon_prompt = generate_text_to_image_prompt(prompt, self.prompt_templates)
+
+        con_prompt_token = self.tokenizer(input_prompt)["input_ids"]
+        uncon_prompt_token = self.tokenizer(uncon_prompt)["input_ids"]
+
+        if painting_mode:
+            img_mask_token, img_vis = encode_img_with_paint(
+                painting_image,
+                vqvae=self.vqvae,
+                mask_h_ratio=mask_h_ratio,
+                mask_w_ratio=mask_w_ratio,
+                mask_mode=painting_mode,
+                special_tokens=self.special_tokens,
+            )
+        else:
+            img_mask_token = add_break_line(
+                [self.special_tokens["mask_token"]] * seq_len,
+                token_grid_height,
+                token_grid_width,
+                new_number=self.special_tokens["newline_token"],
+            )
+
+        img_pred_token = (
+            [self.special_tokens["boa"]]
+            + [self.special_tokens["boi"]]
+            + img_mask_token
+            + [self.special_tokens["eoi"]]
+            + [self.special_tokens["eoa"]]
+        )
+
+        prompt_ids = torch.tensor(con_prompt_token + img_pred_token, device=self.device).unsqueeze(0)
+        uncon_ids = torch.tensor(uncon_prompt_token, device=self.device).unsqueeze(0)
+
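+        # the image-code region starts after the prompt tokens plus the <boa> and <boi> tokens at the head of `img_pred_token`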
+        code_start = len(con_prompt_token) + 2
+
+        vq_tokens = self.generate_image(
+            model=self.llm,
+            prompt=prompt_ids,
+            seq_len=seq_len,
+            newline_every=newline_every,
+            timesteps=num_inference_steps,
+            temperature=temperature,
+            cfg_scale=cfg_scale,
+            uncon_ids=uncon_ids,
+            code_start=code_start,
+            use_cache=use_cache,
+            cache_ratio=cache_ratio,
+            refresh_interval=refresh_interval,
+            warmup_ratio=warmup_ratio
+        )
+
+        latents = (
+            vq_tokens.view(1, token_grid_height, token_grid_width).to(self.vqvae.device) - self.special_tokens["image_token_offset"]
+        ).long()
+
+        shape = (1, token_grid_height, token_grid_width, self.vqvae.config.latent_channels)
+        recon = self.vqvae.decode(latents, force_not_quantize=True, shape=shape).sample.clip(0, 1)
+
+        img_proc = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+        image = img_proc.postprocess(recon.detach(), output_type="pil")
+
+        return image
+
+    @torch.no_grad()
+    def _multimodal_understanding(
+        self,
+        prompt: str,
+        image: PipelineImageInput,
+        num_inference_steps: int = 128,
+        gen_length: int = 1024,
+        block_length: int = 128,
+        temperature: float = 0.0,
+        cfg_scale: float = 0.0,
+        remasking: str = "low_confidence",
+    ):
+
+        if isinstance(image, str):
+            image = Image.open(image)
+
+        input_prompt = generate_multimodal_understanding_prompt(prompt)
+        input_ids = self.tokenizer(input_prompt)["input_ids"]
+
+        crop_size_list = generate_crop_size_list((1024 // 32) ** 2, 32)
+        processed_image = var_center_crop(image, crop_size_list=crop_size_list)
+        
+        image_width, image_height = processed_image.size
+        seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(
+            image_height, image_width, self.vae_scale_factor
+        )
+
+        input_img_token = encode_img_with_breaks(processed_image, self.vqvae, self.special_tokens)
+
+        input_token = input_ids[:-1] + input_img_token + input_ids[-1:]
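+        # generated answer tokens begin right after the <boa> token appended below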
+        code_start = len(input_token) + 1
+
+        input_token = input_token + [self.special_tokens["boa"]] + gen_length * [self.special_tokens["mask_token"]] + [self.special_tokens["eoa"]]
+        input_ids = torch.tensor(input_token, device=self.device).unsqueeze(0)
+
+        output_tokens = self.generate_text_understanding(
+            model=self.llm,
+            prompt=input_ids,
+            steps=num_inference_steps,
+            gen_length=gen_length,
+            block_length=block_length, 
+            cfg_scale=cfg_scale, 
+            temperature=temperature,
+            remasking=remasking,
+            code_start=code_start
+        )
+
+        generated_text = self.tokenizer.batch_decode(output_tokens[:, code_start:-1], skip_special_tokens=True)[0]
+        return generated_text
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: str,
+        image: Optional[PipelineImageInput] = None,
+        task: str = "auto",
+        **kwargs,
+    ) -> LuminaDiMOOPipelineOutput:
+        r"""
+        Unified entry for 'text_to_image' | 'image_to_image' | 'multimodal_understanding'.
+
+        Examples:
+        {EXAMPLE_DOC_STRING}
+        """
+        if task == "auto":
+            if image is None:
+                task = "text_to_image"
+            elif "edit_type" in kwargs:
+                task = "image_to_image"
+            else:
+                task = "multimodal_understanding"
+
+        if task == "text_to_image":
+            # Default values from inference_t2i.py
+            t2i_kwargs = {
+                "height": kwargs.pop("height", 1024),
+                "width": kwargs.pop("width", 1024),
+                "num_inference_steps": kwargs.pop("num_inference_steps", 64), 
+                "cfg_scale": kwargs.pop("cfg_scale", 4.0), 
+                "temperature": kwargs.pop("temperature", 1.0),
+                "painting_mode": kwargs.pop("painting_mode", None),
+                "painting_image": kwargs.pop("painting_image", None),
+                "mask_h_ratio": kwargs.pop("mask_h_ratio", 1.0),
+                "mask_w_ratio": kwargs.pop("mask_w_ratio", 0.2),
+                "use_cache": kwargs.pop("use_cache", True),
+                "cache_ratio": kwargs.pop("cache_ratio", 0.9),
+                "refresh_interval": kwargs.pop("refresh_interval", 5),
+                "warmup_ratio": kwargs.pop("warmup_ratio", 0.3),
+            }
+            images = self._text_to_image(prompt=prompt, **t2i_kwargs)
+            return LuminaDiMOOPipelineOutput(images=images, text=None)
+
+        elif task == "image_to_image":
+            if image is None:
+                raise ValueError("`image` must be provided for image_to_image task.")
+            i2i_kwargs = {
+                "ref_image": kwargs.pop("ref_image", None),
+                "edit_type": kwargs.pop("edit_type", "canny_pred"),
+                "num_inference_steps": kwargs.pop("num_inference_steps", 64), 
+                "temperature": kwargs.pop("temperature", 1.0),
+                "cfg_scale": kwargs.pop("cfg_scale", 2.5), 
+                "cfg_img": kwargs.pop("cfg_img", 4.0),
+            }
+            images = self._image_to_image(prompt=prompt, image=image, **i2i_kwargs)
+            return LuminaDiMOOPipelineOutput(images=images, text=None)
+
+        elif task == "multimodal_understanding":
+            if image is None:
+                raise ValueError("`image` must be provided for multimodal_understanding task.")
+            mmu_kwargs = {
+                "num_inference_steps": kwargs.pop("num_inference_steps", 128), 
+                "gen_length": kwargs.pop("gen_length", 1024),
+                "block_length": kwargs.pop("block_length", 256),
+                "temperature": kwargs.pop("temperature", 0.0),
+                "cfg_scale": kwargs.pop("cfg_scale", 0.0),
+                "remasking": kwargs.pop("remasking", "low_confidence"),
+            }
+            text = self._multimodal_understanding(prompt=prompt, image=image, **mmu_kwargs)
+            return LuminaDiMOOPipelineOutput(images=None, text=text)
+
+        else:
+            raise ValueError(f"Unknown task: {task}. Supported tasks are 'text_to_image', 'image_to_image', 'multimodal_understanding', and 'auto'.")
 | "The image shows a vibrant orange sports car parked in a showroom. The car has a sleek, aerodynamic design with a prominent front grille and side vents. The body is adorned with black and orange racing stripes, creating a striking contrast against the orange paint. The car is equipped with black alloy wheels and a low-profile body style. The background features a white wall with a large emblem that reads "BREITZEN" and includes a silhouette of a horse and text. The floor is tiled with dark tiles, and the showroom is well-lit, highlighting the car. The overall setting suggests a high-end, possibly luxury, automotive environment."|
+
+
+```python
+import torch
+
+from diffusers import VQModel, DiffusionPipeline
+from transformers import AutoTokenizer
+from diffusers.utils import load_image
+
+vqvae = VQModel.from_pretrained("Alpha-VLLM/Lumina-DiMOO", subfolder="vqvae").to(device='cuda', dtype=torch.bfloat16)
+tokenizer = AutoTokenizer.from_pretrained("Alpha-VLLM/Lumina-DiMOO", trust_remote_code=True)
+
+pipe = DiffusionPipeline.from_pretrained(
+    "Alpha-VLLM/Lumina-DiMOO",
+    vqvae=vqvae,
+    tokenizer=tokenizer,
+    torch_dtype=torch.bfloat16,
+    custom_pipeline="lumina_dimoo",
+)
+pipe.to("cuda")
+
+question = "Please describe the image."
+
+input_image = load_image(
+    "https://raw.githubusercontent.com/Alpha-VLLM/Lumina-DiMOO/main/examples/example_8.png"
+).convert("RGB")
+
+out = pipe(
+    prompt=question,
+    image=input_image,
+    task="multimodal_understanding",
+    num_inference_steps=128,
+    gen_length=128,
+    block_length=32,
+    temperature=0.0,
+    cfg_scale=0.0,
+)
+
+# `out.text` holds the generated answer
+text = out.text
+with open("mmu_answer.txt", "w", encoding="utf-8") as f:
+    f.write(text.strip() + "\n")
+```
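+
+For text-to-image generation, the same `pipe` object can be reused. The sketch below is a minimal example: the prompt and output filename are illustrative, and the keyword arguments mirror the defaults of the `text_to_image` branch of `__call__` (1024x1024, 64 steps, `cfg_scale=4.0`, caching enabled), so they can be omitted or adjusted:
+
+```python
+prompt = "A serene mountain lake at sunrise, photorealistic, soft morning light."
+
+out = pipe(
+    prompt=prompt,
+    task="text_to_image",
+    height=1024,
+    width=1024,
+    num_inference_steps=64,
+    cfg_scale=4.0,
+    use_cache=True,  # caching mechanism for faster sampling
+)
+
+# `out.images` is a list of PIL images decoded by the VQ-VAE
+out.images[0].save("t2i_sample.png")
+```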
+
+
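+For image-to-image editing, a minimal sketch under the same setup is shown below. `edit_type` defaults to `"canny_pred"` and `cfg_scale`/`cfg_img` default to 2.5/4.0 in the `image_to_image` branch of `__call__`; the input URL and prompt here are illustrative, and the exact prompt format for each edit type follows the upstream Lumina-DiMOO examples:
+
+```python
+from diffusers.utils import load_image
+
+init_image = load_image(
+    "https://raw.githubusercontent.com/Alpha-VLLM/Lumina-DiMOO/main/examples/example_8.png"
+).convert("RGB")
+
+out = pipe(
+    prompt="Generate a canny edge map for this image.",
+    image=init_image,
+    task="image_to_image",
+    num_inference_steps=64,
+    cfg_scale=2.5,
+    cfg_img=4.0,
+)
+
+out.images[0].save("i2i_sample.png")
+```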
+
+
diff --git a/examples/community/lumina_dimoo.py b/examples/community/lumina_dimoo.py
new file mode 100644
index 000000000000..6ed4b0d6398c
--- /dev/null
+++ b/examples/community/lumina_dimoo.py
@@ -0,0 +1,2676 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+import math
+import random
+import sys
+from abc import abstractmethod
+from dataclasses import dataclass, fields
+from enum import Enum
+from typing import Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple, Union, cast
+
+from accelerate import init_empty_weights
+import numpy as np
+import torch
+import torch.backends.cuda
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import einsum
+from PIL import Image
+from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedModel
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from diffusers import DiffusionPipeline, VQModel
+from diffusers.utils import BaseOutput, logging, replace_example_docstring
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+
+from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+# --- Start of model definition copied from Lumina-DiMOO ---
+
+
+class StrEnum(str, Enum):
+    """
+    This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
+    We include it here for compatibility with older versions of Python.
+    """
+
+    def __str__(self) -> str:
+        return self.value
+
+    def __repr__(self) -> str:
+        return f"'{str(self)}'"
+
+
+class LayerNormType(StrEnum):
+    default = "default"
+    low_precision = "low_precision"
+    rms = "rms"
+    gemma_rms = "gemma_rms"
+    amd_compatible = "amd_compatible"
+
+
+class ActivationType(StrEnum):
+    gelu = "gelu"
+    relu = "relu"
+    silu = "silu"
+    swiglu = "swiglu"
+
+
+class BlockType(StrEnum):
+    sequential = "sequential"
+    parallel = "parallel"
+    llama = "llama"
+
+
+class InitFnType(StrEnum):
+    mitchell = "mitchell"
+    normal = "normal"
+    kaiming_normal = "kaiming_normal"
+    fan_in = "fan_in"
+    full_megatron = "full_megatron"
+
+
+@dataclass
+class ModelConfig:
+    """
+    LLaDA (model) configuration.
+    """
+
+    # Note that the defaults for these attributes are equivalent to the base GPT2 model.
+
+    d_model: int = 768
+    n_heads: int = 12
+    n_kv_heads: Optional[int] = None
+    n_layers: int = 12
+    mlp_ratio: int = 4
+    mlp_hidden_size: Optional[int] = None
+    activation_type: ActivationType = ActivationType.swiglu
+    block_type: BlockType = BlockType.sequential
+    block_group_size: int = 1
+    alibi: bool = False
+    alibi_bias_max: float = 8.0
+    rope: bool = False
+    rope_full_precision: bool = True
+    flash_attention: bool = False
+    attention_dropout: float = 0.1
+    multi_query_attention: Optional[bool] = None
+    attention_layer_norm: bool = False
+    residual_dropout: float = 0.1
+    embedding_dropout: float = 0.1
+    input_emb_norm: bool = False
+    layer_norm_type: LayerNormType = LayerNormType.default
+    layer_norm_with_affine: bool = True
+    rms_norm_eps: float = 1e-05
+    attention_layer_norm_with_affine: bool = True
+    max_sequence_length: int = 1024
+    rope_theta: float = 10000.0
+    include_qkv_bias: Optional[bool] = False
+    include_bias: bool = False
+    bias_for_layer_norm: Optional[bool] = None
+    scale_logits: bool = False
+    vocab_size: int = 50257
+    embedding_size: Optional[int] = 50304
+    weight_tying: bool = True
+    eos_token_id: int = 50256
+    pad_token_id: int = 50256
+    mask_token_id: Optional[int] = 50256
+    init_device: Optional[str] = None
+    init_fn: InitFnType = InitFnType.normal
+    init_std: float = 0.02
+    init_cutoff_factor: Optional[float] = None
+    precision: Optional[str] = None
+
+    @property
+    def effective_n_kv_heads(self) -> int:
+        if self.n_kv_heads is None:
+            if self.multi_query_attention is True:
+                return 1
+            else:
+                return self.n_heads
+        else:
+            if self.multi_query_attention is None:
+                return self.n_kv_heads
+            if self.multi_query_attention:
+                n_kv_heads_should_be = 1
+            else:
+                n_kv_heads_should_be = self.n_heads
+            if self.n_kv_heads == n_kv_heads_should_be:
+                return n_kv_heads_should_be
+            else:
+                raise Exception("You can't set `multi_query_attention` and `n_kv_heads` at the same time.")
+
+
+class ActivationCheckpointingStrategy(StrEnum):
+    whole_layer = "whole_layer"
+    one_in_two = "one_in_two"
+    one_in_three = "one_in_three"
+    one_in_four = "one_in_four"
+    two_in_three = "two_in_three"
+    three_in_four = "three_in_four"
+    four_in_five = "four_in_five"
+    nine_in_ten = "nine_in_ten"
+    fine_grained = "fine_grained"
+
+
+class LLaDAConfig(PretrainedConfig):
+    model_type = "llada"
+    keys_to_ignore_at_inference = ["past_key_values"]  # TODO: confirm
+
+    def __init__(self, use_cache: bool = False, **kwargs):
+        model_config = ModelConfig()
+        all_kwargs = model_config.__dict__
+        all_kwargs.update(kwargs)
+        all_kwargs.update({"use_cache": use_cache})
+        all_kwargs.update({"architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])})
+        super().__init__(**all_kwargs)
+
+    @property
+    def num_attention_heads(self):
+        return self.n_heads
+
+    @property
+    def num_hidden_layers(self):
+        return self.n_layers
+
+    @property
+    def hidden_size(self):
+        return self.d_model
+
+
+if sys.version_info.minor > 8:
+    from collections.abc import MutableMapping
+elif sys.version_info.minor == 8:
+    from typing import MutableMapping
+else:
+    raise SystemExit("This script supports Python 3.8 or higher")
+
+
+class ModuleType(StrEnum):
+    in_module = "in"
+    out_module = "out"
+    emb = "emb"
+    final_out = "final_out"
+
+
+def init_weights(
+    config: ModelConfig,
+    module: Union[nn.Linear, nn.Embedding],
+    d: Optional[int] = None,
+    layer_id: Optional[int] = None,
+    std_factor: float = 1.0,
+    type_of_module: Optional[ModuleType] = None,
+) -> None:
+    d = d if d is not None else config.d_model
+    if config.init_fn == InitFnType.normal:
+        std = config.init_std * std_factor
+        if config.init_cutoff_factor is not None:
+            cutoff_value = config.init_cutoff_factor * std
+            nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
+        else:
+            nn.init.normal_(module.weight, mean=0.0, std=std)
+    elif config.init_fn == InitFnType.mitchell:
+        std = std_factor / math.sqrt(d)
+        if layer_id is not None:
+            std = std / math.sqrt(2 * (layer_id + 1))
+        nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
+    elif config.init_fn == InitFnType.kaiming_normal:
+        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
+    elif config.init_fn == InitFnType.fan_in:
+        std = std_factor / math.sqrt(d)
+        nn.init.normal_(module.weight, mean=0.0, std=std)
+    elif config.init_fn == InitFnType.full_megatron:
+        if type_of_module is None:
+            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
+
+        cutoff_factor = config.init_cutoff_factor
+        if cutoff_factor is None:
+            cutoff_factor = 3
+
+        if type_of_module == ModuleType.in_module:
+            # for att_proj (same as QKV), ff_proj
+            std = config.init_std
+        elif type_of_module == ModuleType.out_module:
+            # for attn_out, ff_out
+            std = config.init_std / math.sqrt(2.0 * config.n_layers)
+        elif type_of_module == ModuleType.emb:
+            # positional embeddings (wpe)
+            # token embeddings (wte)
+            std = config.init_std
+        elif type_of_module == ModuleType.final_out:
+            # final output (ff_out)
+            std = config.d_model**-0.5
+        else:
+            raise RuntimeError(f"Unknown module type '{type_of_module}'")
+        nn.init.trunc_normal_(
+            module.weight,
+            mean=0.0,
+            std=std,
+            a=-cutoff_factor * std,
+            b=cutoff_factor * std,
+        )
+    else:
+        raise NotImplementedError(config.init_fn)
+
+    if isinstance(module, nn.Linear):
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+
+        if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
+            with torch.no_grad():
+                module.weight.div_(math.sqrt(2 * config.n_layers))
+
+
+def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
+    if check_neg_inf:
+        x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
+    if check_pos_inf:
+        x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
+
+
+def activation_checkpoint_function(cfg: ModelConfig):
+    preserve_rng_state = (
+        (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
+    )
+    from torch.utils.checkpoint import checkpoint
+
+    return functools.partial(
+        checkpoint,
+        preserve_rng_state=preserve_rng_state,
+        use_reentrant=False,
+    )
+
+
+class BufferCache(dict, MutableMapping[str, torch.Tensor]):
+    """
+    Cache for attention biases and other things that would normally be stored as buffers.
+    We avoid using buffers because we've run into various issues doing so with FSDP.
+    In general it appears the way FSDP handles buffers is not well-defined.
+    It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
+    since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
+    NaNs when they're synchronized due to casting or some other issue.
+    """
+
+
+def _non_meta_init_device(config: ModelConfig) -> torch.device:
+    if config.init_device is not None and config.init_device != "meta":
+        return torch.device(config.init_device)
+    else:
+        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Dropout(nn.Dropout):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self.p == 0.0:
+            return input
+        else:
+            return F.dropout(input, self.p, self.training, self.inplace)
+
+
+class LayerNormBase(nn.Module):
+    def __init__(
+        self,
+        config: ModelConfig,
+        *,
+        size: Optional[int] = None,
+        elementwise_affine: Optional[bool] = True,
+        eps: float = 1e-05,
+    ):
+        super().__init__()
+        self.config = config
+        self.eps = eps
+        self.normalized_shape = (size or config.d_model,)
+        if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
+            self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
+            use_bias = self.config.bias_for_layer_norm
+            if use_bias is None:
+                use_bias = self.config.include_bias
+            if use_bias:
+                self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
+            else:
+                self.register_parameter("bias", None)
+        else:
+            self.register_parameter("bias", None)
+            self.register_parameter("weight", None)
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    @classmethod
+    def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> "LayerNormBase":
+        if config.layer_norm_type == LayerNormType.default:
+            return LayerNorm(config, size=size, low_precision=False, **kwargs)
+        elif config.layer_norm_type == LayerNormType.low_precision:
+            return LayerNorm(config, size=size, low_precision=True, **kwargs)
+        elif config.layer_norm_type == LayerNormType.rms:
+            return RMSLayerNorm(config, size=size, **kwargs)
+        elif config.layer_norm_type == LayerNormType.gemma_rms:
+            return GemmaRMSLayerNorm(config, size=size, **kwargs)
+        else:
+            raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
+
+    def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+        if tensor.device.type == "cuda" and torch.is_autocast_enabled():
+            return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
+        elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+            return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
+        else:
+            return tensor
+
+    def reset_parameters(self):
+        if self.weight is not None:
+            torch.nn.init.ones_(self.weight)  # type: ignore
+        if self.bias is not None:
+            torch.nn.init.zeros_(self.bias)  # type: ignore
+
+
+class LayerNorm(LayerNormBase):
+    def __init__(
+        self,
+        config: ModelConfig,
+        size: Optional[int] = None,
+        low_precision: bool = False,
+        elementwise_affine: Optional[bool] = None,
+        eps: float = 1e-05,
+    ):
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
+        self.low_precision = low_precision
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.low_precision:
+            module_device = x.device
+            downcast_x = self._cast_if_autocast_enabled(x)
+            downcast_weight = (
+                self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
+            )
+            downcast_bias = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
+            with torch.autocast(enabled=False, device_type=module_device.type):
+                return F.layer_norm(
+                    downcast_x, self.normalized_shape, weight=downcast_weight, bias=downcast_bias, eps=self.eps
+                )
+        else:
+            return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
+
+
+class RMSLayerNorm(LayerNormBase):
+    def __init__(
+        self,
+        config: ModelConfig,
+        size: Optional[int] = None,
+        elementwise_affine: Optional[bool] = None,
+        eps: float = 1e-5,
+    ):
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        with torch.autocast(enabled=False, device_type=x.device.type):
+            og_dtype = x.dtype
+            x = x.to(torch.float32)
+            variance = x.pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.eps)
+            x = x.to(og_dtype)
+
+        if self.weight is not None:
+            if self.bias is not None:
+                return self.weight * x + self.bias
+            else:
+                return self.weight * x
+        else:
+            return x
+
+
+class GemmaRMSLayerNorm(LayerNormBase):
+    def __init__(
+        self,
+        config: ModelConfig,
+        size: Optional[int] = None,
+        elementwise_affine: Optional[bool] = None,
+        eps: float = 1e-5,
+    ):
+        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        with torch.autocast(enabled=False, device_type=x.device.type):
+            og_dtype = x.dtype
+            x = x.to(torch.float32)
+            variance = x.pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.eps)
+            x = x.to(og_dtype)
+
+        if self.weight is not None:
+            if self.bias is not None:
+                return x * (1 + self.weight) + self.bias
+            else:
+                return x * (1 + self.weight)
+        else:
+            return x
+
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, config: ModelConfig, cache: BufferCache):
+        super().__init__()
+        self.config = config
+        self.__cache = cache
+        # Warm up cache.
+        self.rope_theta = config.rope_theta
+        self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
+
+    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+        if (
+            (pos_sin := self.__cache.get("rope_pos_sin")) is not None
+            and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
+            and pos_sin.shape[-2] >= seq_len
+            and pos_cos.shape[-2] >= seq_len
+        ):
+            if pos_sin.device != device:
+                pos_sin = pos_sin.to(device)
+                self.__cache["rope_pos_sin"] = pos_sin
+            if pos_cos.device != device:
+                pos_cos = pos_cos.to(device)
+                self.__cache["rope_pos_cos"] = pos_cos
+            return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
+
+        with torch.autocast(device.type, enabled=False):
+            dim = self.config.d_model // self.config.n_heads
+            inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
+            seq = torch.arange(seq_len, device=device, dtype=torch.float)
+            freqs = einsum("i , j -> i j", seq, inv_freq)
+            positions = torch.cat((freqs, freqs), dim=-1)
+            pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
+        self.__cache["rope_pos_sin"] = pos_sin
+        self.__cache["rope_pos_cos"] = pos_cos
+        return pos_sin, pos_cos
+
+    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
+        B, nh, T, hs = x.size()
+        x = x.view(B, nh, T, 2, hs // 2)
+        x1, x2 = x.unbind(dim=-2)
+        return torch.cat((-x2, x1), dim=-1)
+
+    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
+
+    def forward(self, q: torch.Tensor, k: torch.Tensor, q_mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.config.rope_full_precision:
+            q_, k_ = q.float(), k.float()
+        else:
+            q_, k_ = q, k
+
+        with torch.autocast(q.device.type, enabled=False):
+            query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
+            pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
+            pos_sin = pos_sin.type_as(q_)
+            pos_cos = pos_cos.type_as(q_)
+            if q_mask is None:
+                q_ = self.apply_rotary_pos_emb(
+                    pos_sin[:, :, key_len - query_len : key_len, :],
+                    pos_cos[:, :, key_len - query_len : key_len, :],
+                    q_,
+                )
+            else:
+                q_ = self.apply_rotary_pos_emb(
+                    pos_sin[:, :, q_mask, :],
+                    pos_cos[:, :, q_mask, :],
+                    q_,
+                )
+            k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
+        return q_.type_as(q), k_.type_as(k)
+
+
+class Activation(nn.Module):
+    def __init__(self, config: ModelConfig):
+        super().__init__()
+        self.config = config
+
+    @abstractmethod
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def output_multiplier(self) -> float:
+        raise NotImplementedError
+
+    @classmethod
+    def build(cls, config: ModelConfig) -> "Activation":
+        if config.activation_type == ActivationType.gelu:
+            return cast(Activation, GELU(approximate="none"))
+        elif config.activation_type == ActivationType.relu:
+            return cast(Activation, ReLU(inplace=False))
+        elif config.activation_type == ActivationType.silu:
+            return cast(Activation, SiLU(inplace=False))
+        elif config.activation_type == ActivationType.swiglu:
+            return SwiGLU(config)
+        else:
+            raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
+
+
+class GELU(nn.GELU):
+    @property
+    def output_multiplier(self) -> float:
+        return 1.0
+
+
+class ReLU(nn.ReLU):
+    @property
+    def output_multiplier(self) -> float:
+        return 1.0
+
+
+class SiLU(nn.SiLU):
+    @property
+    def output_multiplier(self) -> float:
+        return 1.0
+
+
+class SwiGLU(Activation):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x, gate = x.chunk(2, dim=-1)
+        return F.silu(gate) * x
+
+    @property
+    def output_multiplier(self) -> float:
+        return 0.5
+
+
+def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
+    att_bias = torch.triu(
+        torch.ones(seq_len, seq_len, device=device, dtype=torch.float),
+        diagonal=1,
+    )
+    att_bias.masked_fill_(att_bias == 1, torch.finfo(att_bias.dtype).min)
+    return att_bias.view(1, 1, seq_len, seq_len)  # type: ignore
+
+
+def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
+    if (causal_bias := cache.get("causal_attention_bias")) is not None and causal_bias.shape[-1] >= seq_len:
+        if causal_bias.device != device:
+            causal_bias = causal_bias.to(device)
+            cache["causal_attention_bias"] = causal_bias
+        return causal_bias
+    with torch.autocast(device.type, enabled=False):
+        causal_bias = causal_attention_bias(seq_len, device)
+    cache["causal_attention_bias"] = causal_bias
+    return causal_bias
+
+
+def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, 1, seq_len)
+
+    # shape: (1, 1, seq_len, seq_len)
+    alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.float, device=device).view(1, 1, seq_len, 1)
+    alibi_bias.abs_().mul_(-1)
+
+    # shape: (n_heads,)
+    m = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
+    m.mul_(config.alibi_bias_max / config.n_heads)
+
+    # shape: (1, n_heads, seq_len, seq_len)
+    return alibi_bias * (1.0 / (2 ** m.view(1, config.n_heads, 1, 1)))  # type: ignore
+
+
+class LLaDABlock(nn.Module):
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__()
+        self.layer_id = layer_id
+        self.config = config
+        self.hidden_size = (
+            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
+        )
+        self.__cache = cache
+        assert config.d_model % config.n_heads == 0
+
+        self._activation_checkpoint_fn = None
+
+        # Dropout.
+        self.dropout = Dropout(config.residual_dropout)
+
+        # Layer norms.
+        self.k_norm: Optional[LayerNormBase] = None
+        self.q_norm: Optional[LayerNormBase] = None
+        if config.attention_layer_norm:
+            self.k_norm = LayerNormBase.build(
+                config,
+                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
+                elementwise_affine=config.attention_layer_norm_with_affine,
+            )
+            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)
+
+        # Activation function.
+        self.act = Activation.build(config)
+        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0
+
+        # Attention output projection.
+        self.attn_out = nn.Linear(
+            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
+        )
+
+        # Feed-forward output projection.
+        self.ff_out = nn.Linear(
+            int(self.act.output_multiplier * self.hidden_size),
+            config.d_model,
+            bias=config.include_bias,
+            device=config.init_device,
+        )
+        self.ff_out._is_residual = True  # type: ignore
+
+        # Rotary embeddings.
+        if self.config.rope:
+            self.rotary_emb = RotaryEmbedding(config, self.__cache)
+
+        self.flash_attn_func = None
+        if config.flash_attention:
+            try:
+                from flash_attn import flash_attn_func  # type: ignore
+
+                self.flash_attn_func = flash_attn_func
+            except ModuleNotFoundError:
+                pass
+
+        self.use_cache = False
+        self.init_cache()
+
+    def init_cache(self):
+        self.cache = {"k": {}, "v": {}, "out": {}}
+
+    def caching(self, enable: bool = True):
+        self.use_cache = enable
+        self.init_cache()
+
+    def reset_parameters(self):
+        if self.k_norm is not None:
+            self.k_norm.reset_parameters()
+        if self.q_norm is not None:
+            self.q_norm.reset_parameters()
+        init_weights(
+            self.config,
+            self.attn_out,
+            d=self.config.d_model,
+            layer_id=self.layer_id,
+            type_of_module=ModuleType.out_module,
+        )
+        init_weights(
+            self.config,
+            self.ff_out,
+            d=self.ff_out.in_features,
+            layer_id=self.layer_id,
+            type_of_module=ModuleType.out_module,
+        )
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        if strategy == ActivationCheckpointingStrategy.fine_grained:
+            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+        else:
+            self._activation_checkpoint_fn = None
+
+    @classmethod
+    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
+        target_dtype = input_dtype
+        if bias.device.type == "cuda" and torch.is_autocast_enabled():
+            target_dtype = torch.get_autocast_gpu_dtype()
+        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
+            target_dtype = torch.get_autocast_cpu_dtype()
+        if bias.dtype != target_dtype:
+            bias = bias.to(target_dtype)
+            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
+        return bias
+
+    def _scaled_dot_product_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+    ) -> torch.Tensor:
+        if self.flash_attn_func is not None and attn_mask is None:
+            r = self.flash_attn_func(
+                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
+            )
+            return r.transpose(1, 2)
+        else:
+            assert k.size(1) == v.size(1)
+            num_kv_heads = k.size(1)
+            num_q_heads = q.size(1)
+            if num_q_heads != num_kv_heads:
+                assert num_q_heads % num_kv_heads == 0
+                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
+
+            return F.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=None,
+                dropout_p=dropout_p,
+                is_causal=False,
+            )
+
+    def attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        to_compute_mask=None,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        B, T, C = q.size()  # batch size, sequence length, d_model
+        dtype = k.dtype
+
+        if self.q_norm is not None and self.k_norm is not None:
+            q = self.q_norm(q).to(dtype=dtype)
+            k = self.k_norm(k).to(dtype=dtype)
+
+        q = q.view(B, -1, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
+        k = k.view(B, -1, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+        v = v.view(B, -1, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
+
+        if layer_past is not None:
+            past_key, past_value = layer_past
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        if self.config.rope:
+            to_compute_index = (
+                to_compute_mask.nonzero(as_tuple=True)[1] if self.use_cache and to_compute_mask is not None else None
+            )
+            q, k = self.rotary_emb(q, k, q_mask=to_compute_index)
+
+        if attention_bias is not None:
+            attention_bias = self._cast_attn_bias(attention_bias, dtype)
+
+        att = self._scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=None,
+            dropout_p=0.0 if not self.training else self.config.attention_dropout,
+            is_causal=False,
+        )
+
+        att = att.transpose(1, 2).contiguous().view(B, T, C)
+
+        return self.attn_out(att), None
+
+    @abstractmethod
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.FloatTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        raise NotImplementedError
+
+    @classmethod
+    def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> "LLaDABlock":
+        if config.block_type == BlockType.sequential:
+            return LLaDASequentialBlock(layer_id, config, cache)
+        elif config.block_type == BlockType.llama:
+            return LLaDALlamaBlock(layer_id, config, cache)
+        else:
+            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
+
+
+class LLaDASequentialBlock(LLaDABlock):
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__(layer_id, config, cache)
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        head_dim = config.d_model // config.n_heads
+        self.fused_dims = (
+            config.d_model,
+            config.effective_n_kv_heads * head_dim,
+            config.effective_n_kv_heads * head_dim,
+        )
+        self.att_proj = nn.Linear(
+            config.d_model, sum(self.fused_dims), bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+
+    def reset_parameters(self):
+        super().reset_parameters()
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        init_weights(
+            self.config, self.att_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
+        )
+        init_weights(
+            self.config, self.ff_proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        if self._activation_checkpoint_fn is not None:
+            q, k, v = self.att_proj(self._activation_checkpoint_fn(self.attn_norm, x)).split(
+                self.fused_dims, dim=-1
+            )
+        else:
+            q, k, v = self.att_proj(self.attn_norm(x)).split(self.fused_dims, dim=-1)
+
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past)
+
+        x = x + self.dropout(att)
+
+        og_x = x
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
+        x = self.ff_proj(x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
+        x = self.ff_out(x)
+        x = self.dropout(x)
+        x = og_x + x
+
+        return x, cache
+
+
+class LLaDALlamaBlock(LLaDABlock):
+    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
+        super().__init__(layer_id, config, cache)
+        self.attn_norm = LayerNorm.build(config)
+        self.ff_norm = LayerNorm.build(config)
+        self.__cache = cache
+
+        head_dim = config.d_model // config.n_heads
+        q_proj_out_dim = config.d_model
+        k_proj_out_dim = config.effective_n_kv_heads * head_dim
+        v_proj_out_dim = config.effective_n_kv_heads * head_dim
+        self.q_proj = nn.Linear(
+            config.d_model, q_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        self.k_proj = nn.Linear(
+            config.d_model, k_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+        self.v_proj = nn.Linear(
+            config.d_model, v_proj_out_dim, bias=config.include_bias | config.include_qkv_bias, device=config.init_device
+        )
+
+        self.ff_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+        self.up_proj = nn.Linear(
+            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
+        )
+
+    def reset_parameters(self):
+        super().reset_parameters()
+        self.attn_norm.reset_parameters()
+        self.ff_norm.reset_parameters()
+        init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.ff_proj, d=self.config.d_model, layer_id=None)
+        init_weights(self.config, self.up_proj, d=self.config.d_model, layer_id=None)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.Tensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        use_cache: bool = False,
+        cat="cond",
+        to_compute_mask=None,
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+        B, T, D = x.shape
+
+        x_normed = self.attn_norm(x)
+        q = self.q_proj(x_normed)
+        k = self.k_proj(x_normed)
+        v = self.v_proj(x_normed)
+
+        if use_cache:
+            if cat not in self.cache["k"]:
+                self.cache["k"][cat] = torch.zeros_like(x)
+                self.cache["v"][cat] = torch.zeros_like(x)
+            if to_compute_mask is not None:
+                self.cache["k"][cat][to_compute_mask] = k.view(-1, D)
+                self.cache["v"][cat][to_compute_mask] = v.view(-1, D)
+                k = self.cache["k"][cat]
+                v = self.cache["v"][cat]
+            else:
+                self.cache["k"][cat] = k
+                self.cache["v"][cat] = v
+
+        if self._activation_checkpoint_fn is not None:
+            att, cache = self._activation_checkpoint_fn(  # type: ignore
+                self.attention, q, k, v, attention_bias, layer_past=layer_past, to_compute_mask=to_compute_mask
+            )
+        else:
+            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, to_compute_mask=to_compute_mask)
+
+        x = x + self.dropout(att)
+
+        og_x = x
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
+        else:
+            x = self.ff_norm(x)
+        x, x_up = self.ff_proj(x), self.up_proj(x)
+        if self._activation_checkpoint_fn is not None:
+            x = self._activation_checkpoint_fn(self.act, x)  # type: ignore
+        else:
+            x = self.act(x)
+        x = x * x_up
+        x = self.ff_out(x)
+        x = self.dropout(x)
+        x = og_x + x
+
+        return x, cache
+
+
+class LLaDAOutput(NamedTuple):
+    logits: torch.FloatTensor
+    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
+    hidden_states: Optional[Tuple[torch.Tensor]]
+
+
+class LLaDAGenerateOutput(NamedTuple):
+    token_ids: torch.LongTensor
+    scores: torch.FloatTensor
+
+
+class LLaDABlockGroup(nn.ModuleList):
+    def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
+        super().__init__(modules)
+        self.config = config
+        self.layer_offset = layer_offset
+        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
+        self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_bias: Optional[torch.FloatTensor] = None,
+        layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
+        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+        for block_idx, block in enumerate(self):
+            layer_past = None if layers_past is None else layers_past[block_idx]
+            block_idx += self.layer_offset
+            if (
+                (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
+                    and block_idx % 2 == 0
+                )
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
+                    and block_idx % 3 == 0
+                )
+                or (
+                    self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
+                    and block_idx % 4 == 0
+                )
+            ):
+                x, cache = self._activation_checkpoint_fn(  # type: ignore
+                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
+                )
+            else:
+                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
+            if attn_key_values is not None:
+                assert cache is not None
+                attn_key_values.append(cache)
+        return x, attn_key_values
+
+    def reset_parameters(self):
+        for block in self:
+            block.reset_parameters()
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        self.activation_checkpointing_strategy = strategy
+        for block in self:
+            block.set_activation_checkpointing(strategy)
+
+
+class LLaDAModel(nn.Module):
+    def __init__(self, config: ModelConfig, init_params: bool = True):
+        super().__init__()
+        self.config = config
+        self.__cache = BufferCache()
+
+        if self.config.alibi and self.config.flash_attention:
+            raise Exception("ALiBi is currently not supported with FlashAttention")
+
+        if self.config.alibi and self.config.rope:
+            raise Exception("ALiBi and RoPE are mutually exclusive")
+
+        if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
+            if self.config.embedding_size < self.config.vocab_size:
+                raise Exception("embedding size should be at least as big as vocab size")
+            elif self.config.embedding_size % 128 != 0:
+                import warnings
+
+                warnings.warn(
+                    "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
+                )
+
+        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
+        self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)
+
+        if not (
+            0 < self.config.block_group_size <= self.config.n_layers
+            and self.config.n_layers % self.config.block_group_size == 0
+        ):
+            raise Exception("n layers must be divisible by block group size")
+
+        torch.backends.cuda.enable_flash_sdp(True)
+        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it
+
+        self.transformer = nn.ModuleDict(
+            dict(
+                wte=nn.Embedding(
+                    config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
+                ),
+                emb_drop=Dropout(config.embedding_dropout),
+                ln_f=LayerNorm.build(config),
+            )
+        )
+
+        blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
+        if self.config.block_group_size > 1:
+            block_groups = [
+                LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
+                for i in range(0, config.n_layers, config.block_group_size)
+            ]
+            self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
+        else:
+            self.transformer.update({"blocks": nn.ModuleList(blocks)})
+
+        if not (self.config.alibi or self.config.rope):
+            self.transformer.update(
+                {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
+            )
+        if not config.weight_tying:
+            self.transformer.update(
+                {
+                    "ff_out": nn.Linear(
+                        config.d_model,
+                        config.embedding_size or config.vocab_size,
+                        bias=config.include_bias,
+                        device=config.init_device,
+                    )
+                }
+            )
+        if init_params and self.config.init_device != "meta":
+            self.reset_parameters()
+        self.__num_fwd_flops: Optional[int] = None
+
+        if self.config.alibi:
+            get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
+            self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
+
+        self.logit_cache = {}
+
+    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
+        self.activation_checkpointing_strategy = strategy
+        if self.config.block_group_size != 1:
+            for block_group in self.transformer.block_groups:
+                block_group.set_activation_checkpointing(strategy)
+        else:
+            for block in self.transformer.blocks:
+                block.set_activation_checkpointing(strategy)
+
+    @property
+    def device(self) -> torch.device:
+        device: torch.device = self.transformer.wte.weight.device  # type: ignore
+        if device.type == "meta":
+            return _non_meta_init_device(self.config)
+        else:
+            return device
+
+    def reset_parameters(self):
+        logger.info("Initializing model parameters...")
+        init_weights(
+            self.config,
+            self.transformer.wte,  # type: ignore
+            std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
+            type_of_module=ModuleType.emb,
+        )
+        if hasattr(self.transformer, "wpe"):
+            init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb)  # type: ignore
+
+        self.transformer.ln_f.reset_parameters()  # type: ignore
+
+        if hasattr(self.transformer, "ff_out"):
+            init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out)  # type: ignore
+
+        if self.config.block_group_size == 1:
+            for block in self.transformer.blocks:
+                block.reset_parameters()
+        else:
+            for block_group in self.transformer.block_groups:
+                block_group.reset_parameters()
+
+    def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
+        if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
+            -1
+        ] >= seq_len:
+            if alibi_bias.device != device:
+                alibi_bias = alibi_bias.to(device)
+                self.__cache["alibi_attention_bias"] = alibi_bias
+            return alibi_bias
+        with torch.autocast(device.type, enabled=False):
+            alibi_bias = alibi_attention_bias(seq_len, self.config, device)
+        self.__cache["alibi_attention_bias"] = alibi_bias
+        return alibi_bias
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        input_embeddings: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
+        last_logits_only: bool = False,
+        output_hidden_states: Optional[bool] = None,
+        use_cache=False,
+        to_compute_mask=None,
+        cat="",
+    ) -> LLaDAOutput:
+        if use_cache and to_compute_mask is not None:
+            input_ids = input_ids[to_compute_mask].view(input_ids.shape[0], -1)
+
+        assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM."
+        assert self.config.rope, "Rope must be used in Llama-Encoder for MDM."
+
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
+
+        if past_key_values:
+            assert len(past_key_values) == self.config.n_layers
+
+        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
+        if past_key_values is None:
+            past_length = 0
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore
+
+        if self.config.input_emb_norm:
+            x = x * (self.config.d_model**0.5)
+
+        if not (self.config.alibi or self.config.rope):
+            pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
+            pos_emb = self.transformer.wpe(pos)  # type: ignore
+            x = pos_emb + x
+
+        x = self.transformer.emb_drop(x)  # type: ignore
+
+        if attention_mask is not None and 0.0 in attention_mask:
+            attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
+            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
+        else:
+            attention_mask = None
+
+        if (
+            attention_bias is not None
+            or attention_mask is not None
+            or self.config.alibi
+            or past_key_values is not None
+        ):
+            if attention_bias is None and self.config.alibi:
+                attention_bias = get_causal_attention_bias(
+                    self.__cache, past_length + seq_len, x.device
+                ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
+            elif attention_bias is None:
+                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
+            elif attention_bias.dtype in (torch.int8, torch.bool):
+                attention_bias = attention_bias.to(dtype=torch.float)
+                attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)
+
+            mask_len = seq_len
+            if attention_mask is not None:
+                mask_len = attention_mask.shape[-1]
+            elif past_key_values is not None:
+                mask_len = past_key_values[0][0].shape[-2] + seq_len
+            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)
+
+            if attention_mask is not None:
+                attention_bias = attention_bias + attention_mask
+                ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)
+
+        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
+
+        all_hidden_states = []
+
+        if self.config.block_group_size == 1:
+            for block_idx, block in enumerate(self.transformer.blocks):
+                if output_hidden_states:
+                    all_hidden_states.append(x)
+
+                layer_past = None if past_key_values is None else past_key_values[block_idx]
+                if (
+                    (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
+                    or (
+                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
+                        and block_idx % 2 == 0
+                    )
+                    or (
+                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
+                        and block_idx % 3 == 0
+                    )
+                    or (
+                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
+                        and block_idx % 4 == 0
+                    )
+                ):
+                    x, _ = self._activation_checkpoint_fn(
+                        block, x, attention_bias=attention_bias, layer_past=layer_past, to_compute_mask=to_compute_mask, use_cache=use_cache, cat=cat
+                    )
+                else:
+                    x, _ = block(
+                        x, attention_bias=attention_bias, layer_past=layer_past, to_compute_mask=to_compute_mask, use_cache=use_cache, cat=cat
+                    )
+        else:
+            for group_idx, block_group in enumerate(self.transformer.block_groups):
+                if output_hidden_states:
+                    all_hidden_states.append(x)
+
+                layers_past = (
+                    None
+                    if past_key_values is None
+                    else past_key_values[
+                        group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
+                    ]
+                )
+                x, _ = block_group(
+                    x, attention_bias=attention_bias, layers_past=layers_past, to_compute_mask=to_compute_mask, use_cache=use_cache, cat=cat
+                )
+
+        if last_logits_only:
+            x = x[:, -1, :].unsqueeze(1)
+
+        x = self.transformer.ln_f(x)  # type: ignore
+        if output_hidden_states:
+            all_hidden_states.append(x)
+
+        if self.config.weight_tying:
+            logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
+        else:
+            logits = self.transformer.ff_out(x)  # type: ignore
+        if self.config.scale_logits:
+            logits.mul_(1 / math.sqrt(self.config.d_model))
+
+        if use_cache:
+            if cat not in self.logit_cache:
+                self.logit_cache[cat] = torch.zeros_like(logits)
+            if to_compute_mask is not None:
+                self.logit_cache[cat][to_compute_mask] = logits.view(-1, logits.shape[-1])
+                logits = self.logit_cache[cat]
+            else:
+                self.logit_cache[cat] = logits
+
+        return LLaDAOutput(
+            logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None
+        )  # type: ignore[arg-type]
+
+    def caching(self, enable: bool = True):
+        for block in self.transformer.blocks:
+            block.caching(enable)
+        self.logit_cache = {}
+
+    def empty_cache(self):
+        for block in self.transformer.blocks:
+            block.init_cache()
+        self.logit_cache = {}
+
+
+def create_model_config_from_pretrained_config(config: LLaDAConfig):
+    kwargs = {}
+    for field in fields(ModelConfig):
+        kwargs[field.name] = getattr(config, field.name)
+
+    model_config = ModelConfig(**kwargs)
+    return model_config
+
+
+class LLaDAModelLM(PreTrainedModel):
+    config_class = LLaDAConfig
+    base_model_prefix = "model"
+    _no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]
+
+    def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
+        super().__init__(config)
+
+        if not model:
+            model_config = create_model_config_from_pretrained_config(config)
+            model_config.init_device = "cpu"
+            self.model = LLaDAModel(model_config, init_params=init_params)
+        else:
+            self.model = model
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        attention_bias: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[Cache] = None,  # This is a hack mitigation of an issue in transformers `4.39.x`
+        use_cache=False,
+        to_compute_mask=None,
+        cat="",
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        if output_attentions:
+            raise ValueError("output_attentions is not yet supported in LLaDA")
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model.forward(
+            input_ids=input_ids,
+            input_embeddings=inputs_embeds,
+            attention_mask=attention_mask,
+            attention_bias=attention_bias,
+            past_key_values=past_key_values,
+            output_hidden_states=output_hidden_states,
+            use_cache=use_cache,
+            to_compute_mask=to_compute_mask,
+            cat=cat,
+        )
+
+        logits = outputs.logits
+        hidden_states = outputs.hidden_states
+
+        loss = None
+        if labels is not None:
+            import warnings
+
+            warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            logits=logits,
+            past_key_values=outputs.attn_key_values,
+            hidden_states=hidden_states,
+        )
+
+    def can_generate(self) -> bool:
+        return True
+
+    def prepare_inputs_for_generation(
+        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
+
+        model_inputs.update(kwargs)
+        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
+        return model_inputs
+
+    def get_input_embeddings(self) -> torch.nn.Module:
+        return self.model.transformer.wte
+
+    def set_input_embeddings(self, value: torch.nn.Module):
+        self.model.transformer.wte = value
+
+    def get_output_embeddings(self):
+        if self.config.weight_tying:
+            return self.model.transformer.wte
+        else:
+            return self.model.transformer.ff_out
+
+    def set_output_embeddings(self, value: torch.nn.Module):
+        if self.config.weight_tying:
+            self.model.transformer.wte = value
+        else:
+            self.model.transformer.ff_out = value
+
+    def tie_weights(self):
+        if self.config.weight_tying:
+            self.model.transformer.ff_out = self.model.transformer.wte
+
+    def caching(self, enable: bool = True):
+        self.model.caching(enable)
+
+    def empty_cache(self):
+        self.model.empty_cache()
+
+
+
+def create_attention_mask(original_lengths, max_tokens, device):
+    batch_size = len(original_lengths)
+    attention_mask = torch.zeros(batch_size, max_tokens, dtype=torch.bool, device=device)
+    for i, length in enumerate(original_lengths):
+        attention_mask[i, :length] = 1
+    return attention_mask
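+
+# Illustrative sketch: two sequences of lengths 3 and 5 padded to max_tokens=5 give a
+# boolean mask that is True only at the real (non-padding) positions.
+# >>> create_attention_mask([3, 5], max_tokens=5, device="cpu")
+# tensor([[ True,  True,  True, False, False],
+#         [ True,  True,  True,  True,  True]])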
+
+
+class LLaDAForMultiModalGeneration(LLaDAModelLM):
+    config_class = LLaDAConfig
+    base_model_prefix = "model"
+
+    def __init__(self, config: LLaDAConfig, *args, **kwargs):
+        logger.info(f"Initializing MMadaModelLM with config: {config}")
+        super().__init__(config, *args, **kwargs)
+
+    def forward(self, input_ids=None, labels=None, infer=False, use_cache=False, to_compute_mask=None, cat="", **kwargs):
+        input_ids = input_ids.tolist()
+        max_tokens = max([len(_) for _ in input_ids])
+        original_lengths = [len(example) for example in input_ids]
+        input_ids = [example + [0] * (max_tokens - len(example)) for example in input_ids]
+        input_ids = torch.tensor(input_ids, dtype=torch.int64, device=self.device)
+        attention_mask = create_attention_mask(original_lengths, max_tokens, self.device)
+
+        output = LLaDAModelLM.forward(
+            self, input_ids=input_ids, attention_mask=attention_mask, use_cache=use_cache, to_compute_mask=to_compute_mask, cat=cat
+        )
+        if infer:
+            return output
+
+    def get_fsdp_wrap_module_list(self) -> List:
+        modules = [*list(self.model.transformer.blocks), self.model.transformer.ff_out]
+        return modules
+
+
+AutoConfig.register("llada", LLaDAConfig)
+
+# --- End of model definition ---
+
+EXAMPLE_DOC_STRING = """
+Examples:
+```py
+>>> import torch
+>>> from PIL import Image
+>>> from diffusers import VQModel, DiffusionPipeline
+>>> from transformers import AutoTokenizer
+>>> from diffusers.utils import load_image
+
+>>> CHECKPOINT = "Alpha-VLLM/Lumina-DiMOO"
+
+>>> # Load VQ-VAE and tokenizer
+>>> vqvae = VQModel.from_pretrained(CHECKPOINT, subfolder="vqvae").to(device="cuda", dtype=torch.bfloat16)
+>>> tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT, trust_remote_code=True)
+
+>>> # Initialize the Lumina-DiMOO pipeline
+>>> pipe = DiffusionPipeline.from_pretrained(
+...     CHECKPOINT,
+...     custom_pipeline="lumina_dimoo",
+...     vqvae=vqvae,
+...     tokenizer=tokenizer,
+...     torch_dtype=torch.bfloat16
+... )
+>>> pipe.to("cuda")
+
+>>> # Load input image
+>>> input_image = Image.open("path/to/your/ref_image.png").convert("RGB")
+
+>>> prompt = "your prompt here"
+
+>>> # Run image-to-image generation
+>>> out = pipe(
+...     prompt=prompt,
+...     image=input_image,
+...     edit_type="depth_control",
+...     num_inference_steps=64,
+...     task="image_to_image",
+... )
+
+>>> out.images[0].save("i2i_test_output.png")
+"""
+
+
+# --- Helper functions ---
+
+
+def cosine_schedule(t):
+    return torch.cos(t * math.pi / 2)
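+
+# Illustrative sketch: this is the masking schedule used by the MaskGit-style samplers
+# below; the fraction of tokens that stay masked decays from 1 toward 0 as t goes 0 -> 1,
+# e.g. cosine_schedule(torch.tensor([0.0, 0.5, 1.0])) is approximately [1.0, 0.707, 0.0].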
+
+
+def gumbel_noise(t: torch.Tensor, *, generator: Optional[torch.Generator] = None) -> torch.Tensor:
+    if generator is None:
+        u = torch.rand_like(t)
+    else:
+        u = torch.rand(t.shape, device=t.device, dtype=t.dtype, generator=generator)
+    return -torch.log(-torch.log(u + 1e-20) + 1e-20)
+
+
+def add_gumbel_noise(logits, temperature):
+    """
+    Add Gumbel noise to logits for sampling.
+
+    According to arXiv:2409.02908, low-precision Gumbel-Max sampling improves the perplexity
+    score of MDMs but reduces generation quality, so float64 is used here.
+    """
+    if temperature == 0:
+        return logits
+    logits = logits.to(torch.float64)
+    noise = torch.rand_like(logits, dtype=torch.float64)
+    gumbel_noise = (- torch.log(noise)) ** temperature
+    return logits.exp() / gumbel_noise
+
+
+def gumbel_max_sample(logits, temperature=1.0, generator=None):
+    if temperature == 0.0:
+        return logits.argmax(dim=-1)
+    gumbel_noise_ = gumbel_noise(logits, generator=generator)
+    return torch.argmax(logits / temperature + gumbel_noise_, dim=-1)
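+
+# Illustrative sketch: at temperature=0 the sampler reduces to a plain argmax; otherwise
+# the logits are perturbed with Gumbel noise before the argmax.
+# >>> gumbel_max_sample(torch.tensor([[0.1, 2.0, 0.3]]), temperature=0.0)
+# tensor([1])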
+
+def get_num_transfer_tokens(mask_index, steps):
+    """
+    In the reverse process, the interval [0, 1] is uniformly discretized into `steps` intervals.
+    Since LLaDA employs a linear noise schedule (as defined in Eq. (8)), the expected number of
+    tokens transitioned at each step should be consistent.
+
+    This function precomputes the number of tokens that need to be transitioned at each step.
+    """
+    mask_num = mask_index.sum(dim=1, keepdim=True)
+
+    base = mask_num // steps
+    remainder = mask_num % steps
+
+    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
+
+    for i in range(mask_num.size(0)):
+        num_transfer_tokens[i, :remainder[i]] += 1
+
+    return num_transfer_tokens
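+
+# Illustrative sketch: with a single sequence containing 10 masked positions and steps=4,
+# the masked tokens are split as evenly as possible across the reverse steps.
+# >>> get_num_transfer_tokens(torch.tensor([[True] * 10 + [False] * 6]), steps=4)
+# tensor([[3, 3, 2, 2]])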
+
+
+def mask_by_random_topk(keep_n, probs, temperature=1.0, generator=None):
+    B, S = probs.shape
+    noise = gumbel_noise(probs, generator=generator)
+
+    conf = probs / temperature + noise
+
+    mask = torch.zeros_like(conf, dtype=torch.bool)
+    for i in range(B):
+        k = keep_n[i]
+        if k > 0:
+            top_k_indices = torch.topk(conf[i], k, largest=True).indices
+            mask[i, top_k_indices] = True
+    return mask
+
+
+def calculate_vq_params(height, width, vae_scale_factor=32):
+    token_grid_height = height // vae_scale_factor
+    token_grid_width = width // vae_scale_factor
+    seq_len = token_grid_height * token_grid_width
+    newline_every = token_grid_width
+    return seq_len, newline_every, token_grid_height, token_grid_width
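+
+# Illustrative sketch: a 1024x1024 image with vae_scale_factor=32 maps to a 32x32 token
+# grid, i.e. 1024 image tokens with a newline marker inserted after every 32 tokens.
+# >>> calculate_vq_params(1024, 1024, vae_scale_factor=32)
+# (1024, 32, 32, 32)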
+
+
+def add_break_line(tokens, token_grid_height, token_grid_width, new_number):
+    new_tokens = []
+    for i in range(token_grid_height):
+        start = i * token_grid_width
+        end = (i + 1) * token_grid_width
+        row = tokens[start:end]
+        new_tokens.extend(row)
+        if i < token_grid_height - 1:
+            new_tokens.append(new_number)
+    return new_tokens
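+
+# Illustrative sketch: a flattened 2x3 token grid gets the newline marker (99 is a
+# placeholder id here) inserted between rows, but not after the last row.
+# >>> add_break_line([1, 2, 3, 4, 5, 6], token_grid_height=2, token_grid_width=3, new_number=99)
+# [1, 2, 3, 99, 4, 5, 6]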
+
+
+def generate_crop_size_list(num_patches, patch_size, max_ratio=4.0):
+    assert max_ratio >= 1.0
+    crop_size_list = []
+    wp, hp = num_patches, 1
+    while wp > 0:
+        if max(wp, hp) / min(wp, hp) <= max_ratio:
+            crop_size_list.append((wp * patch_size, hp * patch_size))
+        if (hp + 1) * wp <= num_patches:
+            hp += 1
+        else:
+            wp -= 1
+    return crop_size_list
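+
+# Illustrative sketch: with a budget of 4 patches of 32 pixels, the candidate
+# (width, height) crops sweep every aspect ratio allowed by max_ratio.
+# >>> generate_crop_size_list(4, 32)
+# [(128, 32), (96, 32), (64, 32), (64, 64), (32, 64), (32, 96), (32, 128)]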
+
+
+def center_crop(pil_image, crop_size):
+    while pil_image.size[0] >= 2 * crop_size[0] and pil_image.size[1] >= 2 * crop_size[1]:
+        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+    scale = max(crop_size[0] / pil_image.size[0], crop_size[1] / pil_image.size[1])
+    pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+    crop_left = random.randint(0, pil_image.size[0] - crop_size[0])
+    crop_upper = random.randint(0, pil_image.size[1] - crop_size[1])
+    crop_right = crop_left + crop_size[0]
+    crop_lower = crop_upper + crop_size[1]
+    return pil_image.crop(box=(crop_left, crop_upper, crop_right, crop_lower))
+
+
+def var_center_crop(pil_image, crop_size_list, random_top_k=1):
+    w, h = pil_image.size
+    rem_percent = [min(cw / w, ch / h) / max(cw / w, ch / h) for cw, ch in crop_size_list]
+    crop_size = random.choice(
+        sorted(((x, y) for x, y in zip(rem_percent, crop_size_list)), reverse=True)[:random_top_k]
+    )[1]
+    return center_crop(pil_image, crop_size)
+
+
+def preprocess_image(image: Image.Image):
+    image = image.convert("RGB")
+    w, h = image.size
+    w, h = (x - x % 32 for x in (w, h))
+    image = image.resize((w, h), resample=Image.LANCZOS)
+    image = np.array(image).astype(np.float32) / 255.0
+    image = image[None].transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return 2.0 * image - 1.0
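+
+# Illustrative sketch: a 517x259 image is snapped down to multiples of 32 (512x256),
+# rescaled to [-1, 1], and returned as a 1x3xHxW float tensor.
+# >>> x = preprocess_image(Image.new("RGB", (517, 259)))
+# >>> tuple(x.shape)
+# (1, 3, 256, 512)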
+
+
+def encode_img_with_breaks(image, vqvae, special_tokens, vae_scale_factor: int = 16):
+    """
+    Encode image, add VQ offset, add newlines, and wrap with BOI/EOI tokens.
+    This function mirrors the logic from the original inference script.
+    """
+    orig = image.convert("RGB")
+    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
+    pixels = image_processor.preprocess(orig).to(vqvae.device, dtype=vqvae.dtype)
+    latents = vqvae.encode(pixels).latents
+
+    latents_bsz, _, lat_h, lat_w = latents.shape
+
+    quantized = vqvae.quantize(latents)[2][2] + special_tokens["image_token_offset"]
+    quantized_with_offset = quantized.reshape(latents_bsz, lat_h, lat_w).flatten().tolist()
+
+    tokens_with_breaks = add_break_line(
+        quantized_with_offset, lat_h, lat_w, special_tokens["newline_token"]
+    )
+    return [special_tokens["boi"]] + tokens_with_breaks + [special_tokens["eoi"]]
+
+
+def create_prompt_templates():
+    """Create prompt templates for various tasks based on prompt_utils.py"""
+    templates = {
+        "text_understanding": "You are a multimodal model that can process both text and images. Answer the following question based on the provided images. Analyze each image and combine relevant details to answer.",
+        "image_generation": "Generate an image according to the text prompt.",
+        "image_editing": "Generate an image applying the following editing instruction based on the original image.",
+        "dense_prediction": "Perform dense prediction on the given images.",
+        "control_generation": "Generate an image according to the text prompt and the given control image.",
+        "subject_generation": "Generate an image according to the text prompt and the given object image.",
+        "multi_view": "Generate a view-image based on the given image.",
+        "style_transfer": "Transform the current image into the style of the provided image.",
+    }
+    return templates
+
+
+def generate_image_to_image_prompt(prompt_text, edit_type, templates):
+    """
+    Generate prompt for image-to-image generation based on prompt_utils.py
+    """
+    if "dense" in edit_type or "canny_pred" in edit_type:
+        des = {
+            "canny": "canny edge map",
+            "hed": "hed edge map",
+            "normal": "normal map",
+            "sam2mask": "sam2 mask",
+            "depth": "depth map",
+            "openpose": "pose estimation map",
+        }
+        system_prompt = templates["dense_prediction"]
+        prompt_text_used = f"Generate a {des.get(edit_type.split('_')[0], 'dense map')} according to the image."
+
+    elif "control" in edit_type:
+        system_prompt = templates["control_generation"]
+        prompt_text_used = prompt_text
+
+    elif "subject" in edit_type:
+        system_prompt = templates["subject_generation"]
+        prompt_text_used = prompt_text
+
+    elif "edit" in edit_type:
+        system_prompt = templates["image_editing"]
+        prompt_text_used = prompt_text
+
+    elif "ref_transfer" in edit_type or "image_ref_transfer" in edit_type:
+        system_prompt = templates["style_transfer"]
+        prompt_text_used = "Transform the current image into the style of the provided image."
+
+    elif "multi_view" in edit_type:
+        system_prompt = templates["multi_view"]
+        prompt_text_used = f"Generate the {edit_type.split('_')[-1]} view based on the provided front view."
+
+    else:
+        system_prompt = "Generate an image according to the prompt and image."
+        prompt_text_used = prompt_text
+
+    input_prompt = "" + system_prompt + "" + "" + prompt_text_used + ""
+    uncon_prompt = "" + system_prompt + "" + "" + "" + ""
+
+    return input_prompt, uncon_prompt
+
+def generate_text_to_image_prompt(prompt_text: str, templates: Optional[Dict] = None) -> Tuple[str, str]:
+    """
+    Generate prompt for text-to-image generation
+    
+    Args:
+        prompt_text: User input text prompt
+        templates: Optional prompt templates dict
+        
+    Returns:
+        Tuple of (input_prompt, unconditional_prompt)
+    """
+    if templates is None:
+        templates = create_prompt_templates()
+    
+    system_prompt = templates["image_generation"]
+    input_prompt = "" + system_prompt + "" + "" + prompt_text + ""
+    uncon_prompt = "" + system_prompt + "" + "" + "" + ""
+    
+    return input_prompt, uncon_prompt
+
+
+def generate_multimodal_understanding_prompt(question: str, templates: Optional[Dict] = None) -> str:
+    """
+    Generate prompt for multimodal understanding (MMU)
+    
+    Args:
+        question: User question about the image
+        templates: Optional prompt templates dict
+        
+    Returns:
+        Formatted input prompt
+    """
+    if templates is None:
+        templates = create_prompt_templates()
+    
+    system_prompt = "You are a multimodal model that can process both text and images. Answer the following question based on the provided images. Analyze each image and combine relevant details to answer."
+    input_prompt = "" + system_prompt + "" + "" + question + ""
+    
+    return input_prompt
+
+
+@torch.no_grad()
+def encode_img_with_paint(
+    img: Image.Image,
+    vqvae: VQModel,
+    *,
+    mask_h_ratio: float = 1.0,      # Height ratio of the masked region
+    mask_w_ratio: float = 0.2,      # Width ratio of the masked region
+    gray_value: int = 127,          # Visualization gray value
+    downsample_mode: str = "area",  # Pixel mask alignment to latent grid
+    dilate_latent_k: int = 0,       # Optional dilation on latent grid (grid count)
+    mask_mode: str = "inpainting",  # "inpainting" | "outpainting"
+    special_tokens: Dict[str, int],
+):
+    """
+    Encode image with mask for inpainting/outpainting tasks
+    
+    Args:
+        img: Input PIL image
+        vqvae: VQ-VAE model for encoding
+        mask_h_ratio: Height ratio for mask region (default: 1.0)
+        mask_w_ratio: Width ratio for mask region (default: 0.2)
+        gray_value: Gray value for mask visualization (default: 127)
+        downsample_mode: Downsampling mode for mask alignment ("area", "nearest", "bilinear")
+        dilate_latent_k: Dilation kernel size for latent grid (default: 0)
+        mask_mode: Mask mode - "inpainting" (mask inside) or "outpainting" (mask outside)
+    
+    Returns:
+        img_token: List[int] - Token sequence with newlines (126084) inserted at row ends;
+                              masked positions = 126336, others = index + 126356
+        vis_img: PIL.Image - Gray mask visualization image (consistent with mask_mode)
+    
+    Note:
+        * Encoding uses original image strictly; mask only maps to latent grid to determine
+          which tokens are set to MASK_TOKEN_ID.
+        * mask_mode="inpainting": mask inside rectangle; "outpainting": mask outside rectangle (inverse).
+    """
+
+    assert mask_mode in ("inpainting", "outpainting"), "mask_mode must be 'inpainting' or 'outpainting'"
+
+    # --- 1) Calculate center rectangle and generate visualization ---
+    img = img.convert("RGB")
+    W, H = img.size
+    mh = int(round(H * mask_h_ratio))
+    mw = int(round(W * mask_w_ratio))
+    top = (H - mh) // 2
+    left = (W - mw) // 2
+    bottom = top + mh
+    right = left + mw
+
+    if mask_mode == "inpainting":
+        vis_img = img.copy()
+        draw = ImageDraw.Draw(vis_img)
+        draw.rectangle([left, top, right, bottom], fill=(gray_value, gray_value, gray_value))
+    elif mask_mode == "outpainting":  # outpainting
+        bg = Image.new("RGB", (W, H), (gray_value, gray_value, gray_value))
+        crop = img.crop((left, top, right, bottom))
+        bg.paste(crop, (left, top))
+        vis_img = bg
+
+    # --- 2) VQ encoding using original image ---
+    vae_scale_factor = 2 ** (len(vqvae.config.block_out_channels) - 1)
+    image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor, do_normalize=False)
+    x = image_processor.preprocess(img).to(vqvae.device)  # 1 x 3 x H' x W'
+    latents = vqvae.encode(x).latents                     # 1 x C x h x w
+    _, _, lat_h, lat_w = latents.shape
+
+    # Quantization indices
+    quant_pack = vqvae.quantize(latents)
+    indices = quant_pack[2][2].view(1, lat_h, lat_w)      # 1 x h x w, long
+
+    # --- 3) Pixel mask -> latent grid mask (aligned with encoding input size) ---
+    Hp, Wp = x.shape[-2:]
+    mask_px = torch.zeros((1, 1, Hp, Wp), dtype=torch.float32, device=vqvae.device)
+    # First generate mask where "rectangle inside=1, outside=0"
+    top_p  = int(round(top  * Hp / H))
+    left_p = int(round(left * Wp / W))
+    bh_p   = int(round(mh   * Hp / H))
+    bw_p   = int(round(mw   * Wp / W))
+    mask_px[:, :, top_p:top_p+bh_p, left_p:left_p+bw_p] = 1.0
+
+    # If outpainting, need to invert (outside=1, inside=0 is the masked region)
+    if mask_mode == "outpainting":
+        mask_px = 1.0 - mask_px
+
+    if downsample_mode not in ("nearest", "area", "bilinear"):
+        downsample_mode = "area"
+    mask_lat = F.interpolate(mask_px, size=(lat_h, lat_w), mode=downsample_mode)
+    mask_lat = (mask_lat > 0.5) if downsample_mode == "area" else (mask_lat >= 0.5)
+    mask_lat = mask_lat[0, 0]        # h x w (bool)
+
+    # Optional: latent grid dilation (after inversion is applied)
+    if dilate_latent_k > 0:
+        m = mask_lat.float().unsqueeze(0).unsqueeze(0)
+        ker = 2 * dilate_latent_k + 1
+        m = F.max_pool2d(m, kernel_size=ker, stride=1, padding=dilate_latent_k)
+        mask_lat = (m[0, 0] > 0.5)
+
+    # --- 4) Generate tokens: masked positions=MASK_TOKEN_ID, others=indices+VQ_OFFSET ---
+    idx_flat = indices.view(-1)
+    mask_flat = mask_lat.view(-1)
+    tokens = torch.empty_like(idx_flat)
+    tokens[mask_flat] = special_tokens['mask_token']
+    tokens[~mask_flat] = idx_flat[~mask_flat] + special_tokens['image_token_offset']
+    tokens_list = tokens.tolist()
+
+    # --- 5) Insert newlines (no longer wrapped in /, consistent with current return) ---
+
+    img_token = add_break_line(tokens_list, lat_h, lat_w, special_tokens['newline_token'])
+    return img_token, vis_img
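+
+# Minimal inpainting sketch (assumes a loaded `vqvae` and the pipeline's `special_tokens`
+# dict; "photo.png" is a placeholder path): mask a centered strip covering the full height
+# and 20% of the width, then inspect the gray-box preview image.
+# >>> img_token, vis_img = encode_img_with_paint(
+# ...     Image.open("photo.png"), vqvae,
+# ...     mask_h_ratio=1.0, mask_w_ratio=0.2, mask_mode="inpainting",
+# ...     special_tokens=special_tokens,
+# ... )
+# >>> vis_img.save("inpainting_preview.png")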
+
+
+class LuminaDiMOOPipelineOutput(BaseOutput):
+    """
+    Output class for the Lumina-DiMOO pipeline.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`, *optional*):
+            List of generated PIL images of length `batch_size` or NumPy array of shape `(batch_size, height, width,
+            num_channels)`.
+        text (`str`, *optional*):
+            Generated text from the multimodal understanding task.
+    """
+
+    images: Optional[Union[List[Image.Image], np.ndarray]] = None
+    text: Optional[str] = None
+
+
+class LuminaDiMOOPipeline(DiffusionPipeline):
+    """
+    A unified pipeline for Text-to-Image, Image-to-Image, and Multimodal Understanding
+    using the Lumina-DiMOO model.
+
+    This model was contributed by https://huggingface.co/Alpha-VLLM
+
+    Args:
+        vqvae ([`VQModel`]):
+            Vector Quantized Variational Auto-Encoder (VQ-VAE) model to encode and decode images to and from discrete
+            latent representations.
+        tokenizer ([`AutoTokenizer`]):
+            An `AutoTokenizer` to tokenize text prompts.
+        checkpoint (`str`, *optional*, defaults to `"Alpha-VLLM/Lumina-DiMOO"`):
+            Model repository or local path from which the core `LLaDAForMultiModalGeneration` LLM is loaded.
+    """
+
+    def __init__(
+        self,
+        vqvae: VQModel,
+        tokenizer: AutoTokenizer,
+        checkpoint: Optional[str] = "Alpha-VLLM/Lumina-DiMOO",
+        torch_dtype: Optional[torch.dtype] = torch.bfloat16,
+        device_map: Optional[str] = "auto",         
+        low_cpu_mem_usage: bool = True, 
+    ):
+        super().__init__()
+        self.register_modules(
+            vqvae=vqvae,
+            tokenizer=tokenizer,
+        )
+        
+        self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1)
+        self.special_tokens = {
+            "mask_token": 126336,
+            "newline_token": 126084,
+            "boa": 126354,
+            "eoa": 126355,
+            "boi": 126349,
+            "eoi": 126350,
+            "image_token_offset": 126356,
+            "uncondition":126351
+        }
+        self.prompt_templates = create_prompt_templates()
+
+        # If `checkpoint` is not provided, fall back to the model path recorded by `from_pretrained`
+        if checkpoint is None:
+            checkpoint = getattr(self, "_name_or_path", None)
+            if checkpoint is None:
+                raise ValueError(
+                    "A `checkpoint` path must be provided to load the LLM, either directly or via `from_pretrained`."
+                )
+
+        print("[Lumina] start loading LLaDA ...")
+        self.llm = LLaDAForMultiModalGeneration.from_pretrained(
+            checkpoint, torch_dtype=torch_dtype, trust_remote_code=True,
+            device_map=device_map,                   
+            low_cpu_mem_usage=low_cpu_mem_usage,    
+            use_safetensors=True,
+        )
+        print("   LlaDA Loaded Successfully.")
+
+    @staticmethod
+    @torch.no_grad()
+    def generate_i2i(
+        model: LLaDAForMultiModalGeneration,
+        prompt: torch.LongTensor,
+        *,
+        seq_len: int = 1024,
+        newline_every: int = 16,
+        timesteps: int = 18,
+        mask_token_id: int = 126336,
+        newline_id: int = 126084,
+        temperature: float = 1.0,
+        cfg_scale: float = 0.0,
+        cfg_img: float = 0.0,
+        uncon_text: torch.LongTensor,
+        uncon_image: torch.LongTensor,
+        code_start: Optional[int] = None,
+        codebook_size: int = 8192,
+        noise_schedule: Callable[[torch.Tensor], torch.Tensor] = cosine_schedule,
+        text_vocab_size: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
+    ) -> torch.LongTensor:
+        """
+        Image-to-image MaskGit generation (supports CFG for text and image)
+        
+        Args:
+            model: Model
+            prompt: Prompt tensor
+            seq_len: Sequence length
+            newline_every: Newline interval per row
+            timesteps: Number of timesteps
+            mask_token_id: Mask token id
+            newline_id: Newline token id
+            temperature: Temperature
+            cfg_scale: Text CFG scale
+            cfg_img: Image CFG scale
+            code_start: Start index of the predicted image tokens
+            uncon_text: Unconditional text input
+            uncon_image: Unconditional image input
+            codebook_size: Codebook size
+            noise_schedule: Noise schedule function
+            text_vocab_size: Text vocabulary size
+            generator: Random number generator
+        
+        Returns:
+            Final VQ codes (1, seq_len)
+        """
+        device = next(model.parameters()).device
+        prompt = prompt.to(device)
+        B, P = prompt.shape
+        assert B == 1, "batch>1 not supported – wrap in loop if needed"
+        
+        x = prompt
+
+        vq_mask = x == mask_token_id
+        unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
+        vq_len = unknown_cnt
+
+        # Infer text vocabulary size
+        if text_vocab_size is None:
+            vocab_total = model(torch.zeros(1, 1, dtype=torch.long, device=device), infer=True).logits.size(-1)
+            text_vocab_size = vocab_total - codebook_size
+        vocab_offset = text_vocab_size
+
+        for step in range(timesteps):
+            if unknown_cnt.item() == 0:
+                break
+
+            # Calculate number of tokens to keep (continue masking) this round
+            if step < timesteps - 1:
+                frac = noise_schedule(torch.tensor([(step + 1) / timesteps], device=device))
+                keep_n = (vq_len.float() * frac).floor().clamp_min(1).long()
+            else:
+                keep_n = torch.zeros_like(unknown_cnt)
+
+            # Forward pass (with/without CFG)
+            if cfg_scale > 0 or cfg_img > 0:
+                # CFG text
+                uncond_text = torch.cat((uncon_text.to(x.device), x[:, code_start-2:]), dim=1)
+                uncond_text_vq_mask = torch.cat((torch.zeros((1, uncon_text.size(1)), dtype=torch.bool, device=x.device), vq_mask[:, code_start-2:]), dim=1)
+                # CFG image
+                uncond_img = torch.cat((uncon_image.to(x.device), x[:, code_start-2:]), dim=1)
+                uncond_img_vq_mask = torch.cat((torch.zeros((1, uncon_image.size(1)), dtype=torch.bool, device=x.device), vq_mask[:, code_start-2:]), dim=1)
+
+                cond_logits = model(x, infer=True).logits[:, vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+                uncond_logits_text = model(uncond_text, infer=True).logits[:, uncond_text_vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+                uncond_logits_img = model(uncond_img, infer=True).logits[:, uncond_img_vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+                logits = cond_logits + cfg_scale * (cond_logits - uncond_logits_text) + cfg_img * (cond_logits - uncond_logits_img)
+            else:
+                logits = model(x, infer=True).logits[:, vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+
+            sampled = gumbel_max_sample(logits, temperature, generator=generator)
+            sampled_full = sampled + vocab_offset
+            probs = torch.softmax(logits, dim=-1)
+            conf = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
+
+            flat_idx = vq_mask.nonzero(as_tuple=False)[:, 1]
+            x.view(-1)[flat_idx] = sampled_full.view(-1)
+
+            conf_map = torch.full_like(x, -math.inf, dtype=probs.dtype)
+            conf_map.view(-1)[flat_idx] = conf.view(-1)
+
+            mask_sel = mask_by_random_topk(keep_n.squeeze(1), conf, temperature=temperature, generator=generator)
+            x.view(-1)[flat_idx[mask_sel.view(-1)]] = mask_token_id
+            vq_mask = x == mask_token_id
+            unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
+
+        # Remove newline tokens
+        vq_ids = x[0, code_start:-2]
+        vq_ids = vq_ids[vq_ids != newline_id].view(1, seq_len)
+        return vq_ids
+
+
+    @staticmethod
+    @torch.no_grad()
+    def generate_image(
+        model: LLaDAForMultiModalGeneration,
+        prompt: torch.LongTensor,
+        *,
+        seq_len: int = 1024,
+        newline_every: int = 16,
+        timesteps: int = 18,
+        mask_token_id: int = 126336,
+        newline_id: int = 126084,
+        temperature: float = 1.0,
+        cfg_scale: float = 0.0,
+        uncon_ids: torch.LongTensor,
+        code_start: Optional[int] = None,
+        codebook_size: int = 8192,
+        noise_schedule: Callable[[torch.Tensor], torch.Tensor] = cosine_schedule,
+        text_vocab_size: Optional[int] = None,
+        generator: Optional[torch.Generator] = None,
+        use_cache=True,
+        cache_ratio=0.9,
+        refresh_interval=5,
+        warmup_ratio=0.3
+    ) -> torch.LongTensor:
+        """
+        MaskGit parallel decoding to generate VQ tokens
+        
+        Args:
+            model: Model
+            prompt: Prompt tensor
+            seq_len: Sequence length
+            newline_every: Newline interval per row
+            timesteps: Number of timesteps
+            mask_token_id: Mask token id
+            newline_id: Newline token id
+            temperature: Temperature
+            cfg_scale: CFG scale
+            uncon_ids: Unconditional input
+            code_start: Start index of the image tokens to predict
+            codebook_size: Codebook size
+            noise_schedule: Noise schedule function
+            text_vocab_size: Text vocabulary size
+            generator: Random number generator
+        
+        Returns:
+            Final VQ codes (1, seq_len)
+        """
+        device = next(model.parameters()).device
+        prompt = prompt.to(device)
+        B, P = prompt.shape
+        assert B == 1, "batch>1 not supported – wrap in loop if needed"
+
+        x = prompt
+        
+        vq_mask = x == mask_token_id
+        unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
+        vq_len = unknown_cnt
+
+        if isinstance(model, LLaDAForMultiModalGeneration):
+            model.caching(use_cache)
+        else:  # DDP
+            model.module.caching(use_cache)
+
+        warmup_step = int(timesteps * warmup_ratio)
+        refresh_steps = torch.zeros(timesteps, dtype=torch.bool)
+        for step in range(timesteps):
+            if not use_cache or step <= warmup_step or (step-warmup_step) % refresh_interval == 0:
+                refresh_steps[step] = True
+        compute_ratio = 1 - cache_ratio
+
+        # Infer text vocabulary size
+        if text_vocab_size is None:
+            vocab_total = model(torch.zeros(1, 1, dtype=torch.long, device=device), infer=True).logits.size(-1)
+            text_vocab_size = vocab_total - codebook_size
+        vocab_offset = text_vocab_size
+
+        for step in range(timesteps):
+            if unknown_cnt.item() == 0:
+                break
+
+            # Calculate number of tokens to keep (continue masking) this round
+            if step < timesteps - 1:
+                frac = noise_schedule(torch.tensor([(step + 1) / timesteps], device=device))
+                keep_n = (vq_len.float() * frac).floor().clamp_min(1).long()
+            else:
+                keep_n = torch.zeros_like(unknown_cnt)
+
+            if use_cache and step and refresh_steps[step]:
+                if isinstance(model, LLaDAForMultiModalGeneration):
+                    model.empty_cache()
+                else:  # DDP
+                    model.module.empty_cache()
+
+            # Forward pass (with/without CFG)
+            if cfg_scale > 0:
+                uncond = torch.cat((uncon_ids.to(x.device), x[:, code_start-2:]), axis=1)
+                uncond_vq_mask = torch.cat((torch.zeros((1, uncon_ids.size()[1]), dtype=torch.bool).to(x.device), vq_mask[:, code_start-2:]), axis=1)
+                cond_logits = model(x, infer=True,
+                        cat='cond', use_cache=use_cache, 
+                        to_compute_mask = cond_to_compute_mask if not refresh_steps[step] else None,
+                    ).logits[..., vocab_offset : vocab_offset + codebook_size]
+                cond_mask_logits = cond_logits[vq_mask].view(B, -1, codebook_size)
+                uncond_logits = model(uncond, infer=True,
+                        cat='uncond', use_cache=use_cache, 
+                        to_compute_mask = uncond_to_compute_mask if not refresh_steps[step] else None
+                    ).logits[..., vocab_offset : vocab_offset + codebook_size]
+                uncond_mask_logits = uncond_logits[uncond_vq_mask].view(B, -1, codebook_size)
+                logits = (1 + cfg_scale) * cond_mask_logits - cfg_scale * uncond_mask_logits
+            else:
+                logits = model(x, infer=True).logits[:, vq_mask[0], vocab_offset : vocab_offset + codebook_size]
+
+            sampled = gumbel_max_sample(logits, temperature, generator=generator)
+            sampled_full = sampled + vocab_offset
+            probs = torch.softmax(logits, dim=-1)
+            conf = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
+
+            flat_idx = vq_mask.nonzero(as_tuple=False)[:, 1]
+            x.view(-1)[flat_idx] = sampled_full.view(-1)
+
+            conf_map = torch.full_like(x, -math.inf, dtype=probs.dtype)
+            conf_map.view(-1)[flat_idx] = conf.view(-1)
+
+            mask_sel = mask_by_random_topk(keep_n.squeeze(1), conf, temperature=temperature, generator=generator)
+            x.view(-1)[flat_idx[mask_sel.view(-1)]] = mask_token_id
+            vq_mask = x == mask_token_id
+            unknown_cnt = vq_mask.sum(dim=1, keepdim=True)
+
+            if use_cache and step < timesteps - 1 and not refresh_steps[step+1]:
+                cond_conf = cond_logits.max(dim=-1)[0]
+                cond_conf_threshold = torch.quantile(cond_conf.to(torch.float), compute_ratio, dim=-1, keepdim=True)
+                cond_to_compute_mask = cond_conf <= cond_conf_threshold
+
+                uncond_conf = uncond_logits.max(dim=-1)[0]
+                uncond_conf_threshold = torch.quantile(uncond_conf.to(torch.float), compute_ratio, dim=-1, keepdim=True)
+                uncond_to_compute_mask = uncond_conf <= uncond_conf_threshold
+                
+        # Remove newline tokens
+        vq_ids = x[0, code_start:-2]
+        vq_ids = vq_ids[vq_ids != newline_id].view(1, seq_len)
+        return vq_ids
+
+
+    @staticmethod
+    @torch.no_grad()
+    def generate_text_understanding(
+        model: LLaDAForMultiModalGeneration,
+        prompt, 
+        steps=128, 
+        gen_length=128, 
+        block_length=128, 
+        temperature=0.,
+        cfg_scale=0., 
+        remasking='low_confidence', 
+        mask_id=126336, 
+        code_start: Optional[int] = None,
+    ):
+        """
+        Text understanding generation function
+        
+        Args:
+            model: Mask predictor
+            prompt: Input prompt tensor (1, L)
+            steps: Sampling steps, less than or equal to gen_length
+            gen_length: Generated answer length
+            block_length: Block length, less than or equal to gen_length
+            temperature: Categorical distribution sampling temperature
+            cfg_scale: Unsupervised classifier-free guidance scale
+            remasking: Remasking strategy 'low_confidence' or 'random'
+            mask_id: The token id of [MASK] is 126336
+            code_start: Start index of the predicted text tokens
+        """
+        device = next(model.parameters()).device
+
+        x = prompt
+
+        prompt_index = (x != mask_id)
+
+        assert gen_length % block_length == 0
+        num_blocks = gen_length // block_length
+
+        assert steps % num_blocks == 0
+        steps = steps // num_blocks
+
+        for num_block in range(num_blocks):
+            block_mask_index = (x[:, code_start + num_block * block_length: code_start + (num_block + 1) * block_length:] == mask_id)
+            num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
+            
+            for i in range(steps):
+                mask_index = (x == mask_id)
+                if cfg_scale > 0.:
+                    un_x = x.clone()
+                    un_x[prompt_index] = mask_id
+                    x_ = torch.cat([x, un_x], dim=0)
+                    logits = model(x_, infer=True).logits
+                    logits, un_logits = torch.chunk(logits, 2, dim=0)
+                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
+                else:
+                    logits = model(x, infer=True).logits
+
+                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
+                x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
+
+                if remasking == 'low_confidence':
+                    p = F.softmax(logits.to(torch.float64), dim=-1)
+                    x0_p = torch.squeeze(
+                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
+                elif remasking == 'random':
+                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
+                else:
+                    raise NotImplementedError(remasking)
+
+                x0_p[:, code_start + (num_block + 1) * block_length:] = -np.inf
+
+                x0 = torch.where(mask_index, x0, x)
+                confidence = torch.where(mask_index, x0_p, -np.inf)
+
+                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
+                for j in range(confidence.shape[0]):
+                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
+                    transfer_index[j, select_index] = True
+                x[transfer_index] = x0[transfer_index]
+            
+        
+        return x
+
+    @torch.no_grad()
+    def _image_to_image(
+        self,
+        prompt: str,
+        image: Union[Image.Image, str],
+        ref_image: Optional[PipelineImageInput] = None,
+        edit_type: str = "canny_pred",
+        num_inference_steps: int = 64,
+        temperature: float = 1.0,
+        cfg_scale: float = 2.5,
+        cfg_img: float = 4.0,
+        output_type: Optional[str] = "pil",
+    ):
+        
+        if isinstance(prompt, list):
+            raise ValueError("Batching is not supported for this pipeline.")
+
+        if isinstance(image, str):
+            image = Image.open(image).convert("RGB")
+        if isinstance(ref_image, str):
+            ref_image = Image.open(ref_image).convert("RGB")
+
+        input_prompt, uncon_text = generate_image_to_image_prompt(prompt, edit_type, self.prompt_templates)
+
+        crop_size_list = generate_crop_size_list((512 // 32) ** 2, 32)
+        
+        # Correctly encode input images with newline tokens
+        if "image_ref_transfer" in edit_type:
+            if ref_image is None:
+                raise ValueError("`ref_image` must be provided for `image_ref_transfer` edit type.")
+            processed_img = var_center_crop(image, crop_size_list=crop_size_list)
+            input_img_token = encode_img_with_breaks(processed_img, self.vqvae, self.special_tokens)
+
+            referring_img = var_center_crop(ref_image, crop_size_list=crop_size_list)
+            referring_img_token = encode_img_with_breaks(referring_img, self.vqvae, self.special_tokens)
+
+            image_width, image_height = referring_img.size
+            seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(
+                referring_img.height, referring_img.width, self.vae_scale_factor
+            )
+        else:
+            processed_img = var_center_crop(image, crop_size_list=crop_size_list)
+            input_img_token = encode_img_with_breaks(processed_img, self.vqvae, self.special_tokens)
+            image_width, image_height = processed_img.size
+            seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(
+                processed_img.height, processed_img.width, self.vae_scale_factor
+            )
+
+        prompt_ids = self.tokenizer(input_prompt)["input_ids"]
+        uncon_text_ids = self.tokenizer(uncon_text)["input_ids"]
+
+        img_mask_token = add_break_line(
+            [self.special_tokens["mask_token"]] * seq_len,
+            token_grid_height,
+            token_grid_width,
+            new_number=self.special_tokens["newline_token"],
+        )
+        img_pred_token = (
+            [self.special_tokens["boa"]]
+            + [self.special_tokens["boi"]]
+            + img_mask_token
+            + [self.special_tokens["eoi"]]
+            + [self.special_tokens["eoa"]]
+        )
+
+        if "image_ref_transfer" in edit_type:
+            con_input = prompt_ids[:-1] + input_img_token + referring_img_token + prompt_ids[-1:]
+            uncon_input_text = uncon_text_ids[:-1] + input_img_token + referring_img_token + uncon_text_ids[-1:]
+        else:
+            con_input = prompt_ids[:-1] + input_img_token + prompt_ids[-1:]
+            uncon_input_text = uncon_text_ids[:-1] + input_img_token + uncon_text_ids[-1:]
+        uncon_input_image = prompt_ids
+
+        code_start = len(con_input) + 2
+
+        con_input = torch.tensor(con_input + img_pred_token, device=self.device).unsqueeze(0)
+        uncon_input_text = torch.tensor(uncon_input_text, device=self.device).unsqueeze(0)
+        uncon_input_image = torch.tensor(uncon_input_image, device=self.device).unsqueeze(0)
+
+        vq_tokens = self.generate_i2i(
+            self.llm,
+            con_input,
+            seq_len=seq_len,
+            newline_every=newline_every,
+            timesteps=num_inference_steps,
+            temperature=temperature,
+            cfg_scale=cfg_scale,
+            cfg_img=cfg_img,
+            uncon_text=uncon_input_text,
+            uncon_image=uncon_input_image,
+            code_start=code_start
+        )
+
+        if vq_tokens.shape[1] != token_grid_height * token_grid_width:
+            raise ValueError(
+                f"VQ codes length mismatch: {vq_tokens.shape[1]} != {token_grid_height * token_grid_width} "
+                f"for image size ({image_height},{image_width}) with scale {self.vae_scale_factor}"
+            )
+
+        latents = (
+            vq_tokens.view(1, token_grid_height, token_grid_width).to(self.vqvae.device) - self.special_tokens["image_token_offset"]
+        ).long()
+
+        shape = (1, token_grid_height, token_grid_width, self.vqvae.config.latent_channels)
+
+        recon = self.vqvae.decode(
+            latents,
+            force_not_quantize=True,
+            shape=shape,
+        ).sample.clip(0, 1)
+
+        img_proc = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+        image = img_proc.postprocess(recon.detach(), output_type=output_type)
+
+        return image
+
+    @torch.no_grad()
+    def _text_to_image(
+        self,
+        prompt: str,
+        height: int,
+        width: int,
+        painting_mode: Optional[str] = None,
+        painting_image: Optional[PipelineImageInput] = None,
+        cfg_scale: float = 4.0,
+        use_cache: bool = True,
+        cache_ratio: float = 0.9,
+        refresh_interval: int = 5,
+        warmup_ratio: float = 0.3,
+        num_inference_steps: int = 64,
+        temperature: float = 1.0,
+        mask_h_ratio: float = 1.0,
+        mask_w_ratio: float = 0.2
+    ):
+        if isinstance(painting_image, str):
+            painting_image = Image.open(painting_image)
+
+        if painting_mode and painting_image:
+            width, height = painting_image.size
+
+        seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(height, width, self.vae_scale_factor)
+
+        input_prompt, uncon_prompt = generate_text_to_image_prompt(prompt, self.prompt_templates)
+
+        con_prompt_token = self.tokenizer(input_prompt)["input_ids"]
+        uncon_prompt_token = self.tokenizer(uncon_prompt)["input_ids"]
+
+        if painting_mode:
+            img_mask_token, img_vis = encode_img_with_paint(
+                painting_image,
+                vqvae=self.vqvae,
+                mask_h_ratio=mask_h_ratio,
+                mask_w_ratio=mask_w_ratio,
+                mask_mode=painting_mode,
+                special_tokens=self.special_tokens,
+            )
+        else:
+            img_mask_token = add_break_line(
+                [self.special_tokens["mask_token"]] * seq_len,
+                token_grid_height,
+                token_grid_width,
+                new_number=self.special_tokens["newline_token"],
+            )
+
+        img_pred_token = (
+            [self.special_tokens["boa"]]
+            + [self.special_tokens["boi"]]
+            + img_mask_token
+            + [self.special_tokens["eoi"]]
+            + [self.special_tokens["eoa"]]
+        )
+
+        prompt_ids = torch.tensor(con_prompt_token + img_pred_token, device=self.device).unsqueeze(0)
+        uncon_ids = torch.tensor(uncon_prompt_token, device=self.device).unsqueeze(0)
+
+        code_start = len(con_prompt_token) + 2
+
+        vq_tokens = self.generate_image(
+            model=self.llm,
+            prompt=prompt_ids,
+            seq_len=seq_len,
+            newline_every=newline_every,
+            timesteps=num_inference_steps,
+            temperature=temperature,
+            cfg_scale=cfg_scale,
+            uncon_ids=uncon_ids,
+            code_start=code_start,
+            use_cache=use_cache,
+            cache_ratio=cache_ratio,
+            refresh_interval=refresh_interval,
+            warmup_ratio=warmup_ratio
+        )
+
+        latents = (
+            vq_tokens.view(1, token_grid_height, token_grid_width).to(self.vqvae.device) - self.special_tokens["image_token_offset"]
+        ).long()
+
+        shape = (1, token_grid_height, token_grid_width, self.vqvae.config.latent_channels)
+        recon = self.vqvae.decode(latents, force_not_quantize=True, shape=shape).sample.clip(0, 1)
+
+        img_proc = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False)
+        image = img_proc.postprocess(recon.detach(), output_type="pil")
+
+        return image
+
+    @torch.no_grad()
+    def _multimodal_understanding(
+        self,
+        prompt: str,
+        image: PipelineImageInput,
+        num_inference_steps: int = 128,
+        gen_length: int = 1024,
+        block_length: int = 128,
+        temperature: float = 0.0,
+        cfg_scale: float = 0.0,
+        remasking: str = "low_confidence",
+    ):
+
+        if isinstance(image, str):
+            image = Image.open(image)
+
+        input_prompt = generate_multimodal_understanding_prompt(prompt)
+        input_ids = self.tokenizer(input_prompt)["input_ids"]
+
+        crop_size_list = generate_crop_size_list((1024 // 32) ** 2, 32)
+        processed_image = var_center_crop(image, crop_size_list=crop_size_list)
+        
+        image_width, image_height = processed_image.size
+        seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(
+            image_height, image_width, self.vae_scale_factor
+        )
+
+        input_img_token = encode_img_with_breaks(processed_image, self.vqvae, self.special_tokens)
+
+        input_token = input_ids[:-1] + input_img_token + input_ids[-1:]
+        code_start = len(input_token) + 1
+
+        input_token = input_token + [self.special_tokens["boa"]] + gen_length * [self.special_tokens["mask_token"]] + [self.special_tokens["eoa"]]
+        input_ids = torch.tensor(input_token, device=self.device).unsqueeze(0)
+
+        output_tokens = self.generate_text_understanding(
+            model=self.llm,
+            prompt=input_ids,
+            steps=num_inference_steps,
+            gen_length=gen_length,
+            block_length=block_length, 
+            cfg_scale=cfg_scale, 
+            temperature=temperature,
+            remasking=remasking,
+            code_start=code_start
+        )
+
+        generated_text = self.tokenizer.batch_decode(output_tokens[:, code_start:-1], skip_special_tokens=True)[0]
+        return generated_text
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: str,
+        image: Optional[PipelineImageInput] = None,
+        task: str = "auto",
+        **kwargs,
+    ) -> LuminaDiMOOPipelineOutput:
+        r"""
+        Unified entry for 'text_to_image' | 'image_to_image' | 'multimodal_understanding'.
+
+        Examples:
+        """
+        if task == "auto":
+            if image is None:
+                task = "text_to_image"
+            elif "edit_type" in kwargs:
+                task = "image_to_image"
+            else:
+                task = "multimodal_understanding"
+
+        if task == "text_to_image":
+            # Default values from inference_t2i.py
+            t2i_kwargs = {
+                "height": kwargs.pop("height", 1024),
+                "width": kwargs.pop("width", 1024),
+                "num_inference_steps": kwargs.pop("num_inference_steps", 64), 
+                "cfg_scale": kwargs.pop("cfg_scale", 4.0), 
+                "temperature": kwargs.pop("temperature", 1.0),
+                "painting_mode": kwargs.pop("painting_mode", None),
+                "painting_image": kwargs.pop("painting_image", None),
+                "mask_h_ratio": kwargs.pop("mask_h_ratio", 1.0),
+                "mask_w_ratio": kwargs.pop("mask_w_ratio", 0.2),
+                "use_cache": kwargs.pop("use_cache", True),
+                "cache_ratio": kwargs.pop("cache_ratio", 0.9),
+                "refresh_interval": kwargs.pop("refresh_interval", 5),
+                "warmup_ratio": kwargs.pop("warmup_ratio", 0.3),
+            }
+            images = self._text_to_image(prompt=prompt, **t2i_kwargs)
+            return LuminaDiMOOPipelineOutput(images=images, text=None)
+
+        elif task == "image_to_image":
+            if image is None:
+                raise ValueError("`image` must be provided for image_to_image task.")
+            i2i_kwargs = {
+                "ref_image": kwargs.pop("ref_image", None),
+                "edit_type": kwargs.pop("edit_type", "canny_pred"),
+                "num_inference_steps": kwargs.pop("num_inference_steps", 64), 
+                "temperature": kwargs.pop("temperature", 1.0),
+                "cfg_scale": kwargs.pop("cfg_scale", 2.5), 
+                "cfg_img": kwargs.pop("cfg_img", 4.0),
+            }
+            images = self._image_to_image(prompt=prompt, image=image, **i2i_kwargs)
+            return LuminaDiMOOPipelineOutput(images=images, text=None)
+
+        elif task == "multimodal_understanding":
+            if image is None:
+                raise ValueError("`image` must be provided for multimodal_understanding task.")
+            mmu_kwargs = {
+                "num_inference_steps": kwargs.pop("num_inference_steps", 128), 
+                "gen_length": kwargs.pop("gen_length", 1024),
+                "block_length": kwargs.pop("block_length", 256),
+                "temperature": kwargs.pop("temperature", 0.0),
+                "cfg_scale": kwargs.pop("cfg_scale", 0.0),
+                "remasking": kwargs.pop("remasking", "low_confidence"),
+            }
+            text = self._multimodal_understanding(prompt=prompt, image=image, **mmu_kwargs)
+            return LuminaDiMOOPipelineOutput(images=None, text=text)
+
+        else:
+            raise ValueError(f"Unknown task: {task}. Supported tasks are 'text_to_image', 'image_to_image', 'multimodal_understanding', and 'auto'.")