diff --git a/neuracore-dictionary.txt b/neuracore-dictionary.txt index df233361..a1a3dbab 100644 --- a/neuracore-dictionary.txt +++ b/neuracore-dictionary.txt @@ -201,6 +201,9 @@ UNITREE unitreeh1 URDF usefixtures +chonk +dinov +dinov2 Vaswani vertadr vertnum @@ -218,6 +221,28 @@ ylabel znear bigym secho +adarms +ADARMS +meanpooling +colwise +rowwise +autocast +broadcastable +seqlen +layernorm +attns +torchdynamo +llava +Llava +triu +gptj +GPTJ +erfinv +lecun +CLIPMLP +altclip +loglik +logsigmoid xyzw wxyz nans diff --git a/neuracore/ml/algorithms/diffusion_policy/diffusion_policy.py b/neuracore/ml/algorithms/diffusion_policy/diffusion_policy.py index 94c5d97d..f48dc511 100644 --- a/neuracore/ml/algorithms/diffusion_policy/diffusion_policy.py +++ b/neuracore/ml/algorithms/diffusion_policy/diffusion_policy.py @@ -596,7 +596,6 @@ def training_step(self, batch: BatchedTrainingSamples) -> BatchedTrainingOutputs action_data = torch.cat(action_targets, dim=-1) # (B, T, total_action_dim) target_actions = self.action_normalizer.normalize(action_data) - target_actions = target_actions # Sample noise to add to the trajectory. eps = torch.randn(target_actions.shape, device=target_actions.device) diff --git a/neuracore/ml/algorithms/pi0/__init__.py b/neuracore/ml/algorithms/pi0/__init__.py index 14e8999c..cbf80c4d 100644 --- a/neuracore/ml/algorithms/pi0/__init__.py +++ b/neuracore/ml/algorithms/pi0/__init__.py @@ -1 +1,118 @@ -"""Init.""" +"""PI0 algorithm with transformers patching. + +Automatically patches the installed transformers library with custom modifications +required by PI0. This eliminates the need to manually copy files into the transformers +installation directory. + +The patching includes: +- Gemma model with Adaptive RMSNorm support +- Gated residual connections for Gemma modeling +- Custom PaliGemma and SigLIP modifications +- Python 3.10 UnionType annotation support for transformers docs +""" + +# cspell:ignore adarms +import logging +import shutil +from pathlib import Path + +logger = logging.getLogger(__name__) + + +def check_whether_transformers_replace_is_installed_correctly() -> bool: + """Check whether transformers has been patched with PI0 modifications. + + Verifies that the installed `transformers` library has been patched by checking + for custom attributes and functions that are not present in upstream. + + Returns: + True if patches are detected, False otherwise. + """ + try: + from transformers.models.gemma import modeling_gemma + from transformers.models.gemma.configuration_gemma import GemmaConfig + + cfg = GemmaConfig() + if not hasattr(cfg, "use_adarms"): + return False + if not hasattr(modeling_gemma, "_gated_residual"): + return False + return True + except Exception: + return False + + +def _patch_transformers_args_doc() -> None: + """Patch transformers args_doc to handle Python 3.10 UnionType annotations. + + Fixes documentation generation errors caused by UnionType syntax + (e.g., `int | str`). The patch is applied once and marked to prevent + re-patching. 
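+
+    For example, a parameter annotated as `int | None` (a `types.UnionType`)
+    is rendered from its string form and flagged as optional instead of
+    raising during documentation generation.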
+ """ + try: + import inspect + import re + import types + from collections.abc import Callable + from typing import Any, get_args + + from transformers.utils import args_doc + + if getattr(args_doc, "_UNIONTYPE_PATCHED", False): + return + + original = args_doc._process_parameter_type + + def _process_parameter_type( + param: inspect.Parameter, param_name: str, func: Callable[..., Any] + ) -> tuple[str, bool]: + if param.annotation != inspect.Parameter.empty and isinstance( + param.annotation, types.UnionType + ): + param_type = str(param.annotation).replace("transformers.", "~") + optional = any(arg is type(None) for arg in get_args(param.annotation)) + if "ForwardRef" in param_type: + param_type = re.sub(r"ForwardRef\('([\w.]+)'\)", r"\1", param_type) + if "Optional" in param_type: + param_type = re.sub(r"Optional\[(.*?)\]", r"\1", param_type) + optional = True + return param_type, optional + return original(param, param_name, func) + + args_doc._process_parameter_type = _process_parameter_type + args_doc._UNIONTYPE_PATCHED = True + except Exception: + return + + +def _patch_transformers() -> None: + """Automatically patch transformers with custom modifications. + + Checks if patching is needed, then copies files from transformers_replace/ + to the installed transformers library. The process is idempotent and works + across different installation methods. + + Raises: + ValueError: If patching fails due to permission issues. + """ + if check_whether_transformers_replace_is_installed_correctly(): + return # Already patched + else: + logger.info("Transformers not patched; attempting to patch now.") + + try: + import transformers + + src = Path(__file__).parent / "transformers_replace" + dst = Path(transformers.__file__).parent + if src.exists(): + for f in src.rglob("*.py"): + target = dst / f.relative_to(src) + target.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(f, target) + except Exception: + raise ValueError("Failed to patch transformers because of permission issues") + + +_patch_transformers() +_patch_transformers_args_doc() diff --git a/neuracore/ml/algorithms/pi0/gemma_pytorch.py b/neuracore/ml/algorithms/pi0/gemma_pytorch.py new file mode 100644 index 00000000..456432a7 --- /dev/null +++ b/neuracore/ml/algorithms/pi0/gemma_pytorch.py @@ -0,0 +1,505 @@ +"""Minimal Gemma/PaliGemma helpers for PI0.""" + +# cspell:ignore adarms layernorm +from dataclasses import dataclass +from typing import Literal + +import torch +import torch.nn as nn +from transformers.models.auto import CONFIG_MAPPING +from transformers.models.gemma import modeling_gemma +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM +from transformers.models.paligemma.modeling_paligemma import ( + PaliGemmaForConditionalGeneration, +) + + +def compute_shared_attention_layer( + layer_idx: int, + inputs_embeds: list[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: torch.Tensor, + adarms_cond: list[torch.Tensor | None], + paligemma: PaliGemmaForConditionalGeneration, + gemma_expert: GemmaForCausalLM, +) -> list[torch.Tensor]: + """Run a single transformer layer jointly across prefix/suffix branches. + + This function performs shared attention computation between the PaliGemma + vision-language model and the Gemma action expert model. It concatenates + the query, key, and value states from both models, applies rotary positional + embeddings, computes attention, and then separates the outputs back to each + model's branch. 
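+
+    The two branches are concatenated in prefix-then-suffix order along the
+    sequence dimension, so `attention_mask` and `position_ids` must cover the
+    combined sequence length of both branches.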
+ + Args: + layer_idx: Index of the transformer layer to process. + inputs_embeds: List of two tensors containing hidden states for + [prefix (PaliGemma), suffix (Gemma expert)] branches. + attention_mask: Attention mask tensor of shape (batch_size, 1, seq_len, seq_len) + for masking padded or future tokens. + position_ids: Position indices tensor of shape (batch_size, seq_len) for + rotary positional embeddings. + adarms_cond: List of two optional conditioning tensors for adaptive RMS + normalization, one per branch. None values indicate no conditioning. + paligemma: The PaliGemma vision-language model instance. + gemma_expert: The Gemma action expert model instance. + + Returns: + List of two tensors containing the output hidden states for each branch + after attention and feed-forward processing. + """ + models = [paligemma.language_model, gemma_expert.model] + query_states = [] + key_states = [] + value_states = [] + gates = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + if adarms_cond[i] is None: + hidden_states = layer.input_layernorm(hidden_states)[0] # noqa: PLW2901 + gate = None + else: + hidden_states, gate = layer.input_layernorm( + hidden_states, cond=adarms_cond[i] + ) # noqa: PLW2901 + gates.append(gate) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + query_state = ( + layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + ) + key_state = ( + layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + ) + value_state = ( + layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + ) + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + query_states = torch.cat(query_states, dim=2) + key_states = torch.cat(key_states, dim=2) + value_states = torch.cat(value_states, dim=2) + dummy_tensor = torch.zeros( + query_states.shape[0], + query_states.shape[2], + query_states.shape[-1], + device=query_states.device, + dtype=query_states.dtype, + ) + cos, sin = paligemma.model.language_model.rotary_emb(dummy_tensor, position_ids) + query_states, key_states = modeling_gemma.apply_rotary_pos_emb( + query_states, key_states, cos, sin, unsqueeze_dim=1 + ) + batch_size = query_states.shape[0] + scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling + att_output, _ = modeling_gemma.eager_attention_forward( + paligemma.language_model.layers[layer_idx].self_attn, + query_states, + key_states, + value_states, + attention_mask, + scaling, + ) + head_dim = paligemma.language_model.layers[layer_idx].self_attn.head_dim + num_heads = paligemma.language_model.layers[ + layer_idx + ].self_attn.config.num_attention_heads + att_output = att_output.reshape(batch_size, -1, num_heads * head_dim) + outputs_embeds = [] + start_pos = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = models[i].layers[layer_idx] + end_pos = start_pos + hidden_states.shape[1] + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + out_emb = layer.self_attn.o_proj(att_output[:, start_pos:end_pos]) + out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i]) + after_first_residual = out_emb.clone() + if adarms_cond[i] is None: + out_emb = layer.post_attention_layernorm(out_emb)[0] + gate = None + else: + out_emb, gate = layer.post_attention_layernorm(out_emb, cond=adarms_cond[i]) + if layer.mlp.up_proj.weight.dtype 
== torch.bfloat16: + out_emb = out_emb.to(dtype=torch.bfloat16) + out_emb = layer.mlp(out_emb) + out_emb = modeling_gemma._gated_residual(after_first_residual, out_emb, gate) + outputs_embeds.append(out_emb) + start_pos = end_pos + return outputs_embeds + + +@dataclass(slots=True) +class GemmaConfig: + """Configuration for Gemma model variants. + + Attributes: + variant: Model variant name (e.g., "gemma_300m", "gemma_2b", "gemma_tiny"). + width: Hidden size of the model (embedding dimension). + depth: Number of transformer decoder layers. + mlp_dim: Intermediate size of the feed-forward network. + num_heads: Number of attention heads. + num_kv_heads: Number of key-value heads for grouped query attention. + head_dim: Dimension of each attention head. + """ + + variant: str + width: int + depth: int + mlp_dim: int + num_heads: int + num_kv_heads: int + head_dim: int + + +def get_gemma_config(variant: str) -> GemmaConfig: + """Return the GemmaConfig for a known Gemma model variant. + + Provides predefined configurations for different Gemma model sizes used + in the PI0 architecture. + + Args: + variant: The model variant name. Supported values are: + - "gemma_300m": Small 300M parameter model with width=1024, depth=18. + - "gemma_2b": Larger 2B parameter model with width=2048, depth=18. + - "gemma_tiny": Tiny model for testing with width=16, depth=4. + + Returns: + A GemmaConfig dataclass containing the model hyperparameters. + + Raises: + ValueError: If the variant name is not recognized. + """ + if variant == "gemma_300m": + return GemmaConfig( + variant="gemma_300m", + width=1024, + depth=18, + mlp_dim=4096, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + if variant == "gemma_2b": + return GemmaConfig( + variant="gemma_2b", + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + if variant == "gemma_tiny": + return GemmaConfig( + variant="gemma_tiny", + width=16, + depth=4, + mlp_dim=64, + num_heads=2, + num_kv_heads=1, + head_dim=16, + ) + raise ValueError(f"Unknown variant: {variant}") + + +class PaliGemmaWithExpertModel(nn.Module): + """PaliGemma model with action expert for PI0.""" + + def __init__( + self, + vlm_config: GemmaConfig, + action_expert_config: GemmaConfig, + use_adarms: tuple[bool, bool] | None = None, + precision: Literal["bfloat16", "float32"] = "bfloat16", + ) -> None: + """Initialize the joint vision-language and action expert model. + + Creates a PaliGemma vision-language model alongside a separate Gemma + action expert model. These models share attention computation during + forward passes for efficient multi-modal reasoning in the PI0 architecture. + + Args: + vlm_config: Configuration for the vision-language model (PaliGemma), + specifying dimensions, depth, and attention parameters. + action_expert_config: Configuration for the action expert model (Gemma), + which processes action-related tokens. + use_adarms: Optional tuple of two booleans indicating whether to use + adaptive RMS normalization for (VLM, action expert) respectively. + Defaults to (False, False) if not provided. + precision: Model precision, either "bfloat16" for mixed precision + training/inference or "float32" for full precision. Defaults to + "bfloat16". 
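+
+        Example (illustrative; uses the tiny test variant defined in this
+        module):
+
+            model = PaliGemmaWithExpertModel(
+                vlm_config=get_gemma_config("gemma_tiny"),
+                action_expert_config=get_gemma_config("gemma_tiny"),
+                use_adarms=(False, True),
+                precision="float32",
+            )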
+ """ + if use_adarms is None: + use_adarms = (False, False) + super().__init__() + + paligemma_config = CONFIG_MAPPING["paligemma"]() + paligemma_config._vocab_size = 257152 # noqa: SLF001 + paligemma_config.image_token_index = 257152 + paligemma_config.text_config.hidden_size = vlm_config.width + paligemma_config.text_config.intermediate_size = vlm_config.mlp_dim + paligemma_config.text_config.num_attention_heads = vlm_config.num_heads + paligemma_config.text_config.head_dim = vlm_config.head_dim + paligemma_config.text_config.num_hidden_layers = vlm_config.depth + paligemma_config.text_config.num_key_value_heads = vlm_config.num_kv_heads + paligemma_config.text_config.hidden_activation = "gelu_pytorch_tanh" + paligemma_config.text_config.torch_dtype = "float32" + paligemma_config.text_config.vocab_size = 257152 + paligemma_config.text_config.use_adarms = use_adarms[0] + paligemma_config.text_config.adarms_cond_dim = ( + vlm_config.width if use_adarms[0] else None + ) + paligemma_config.vision_config.intermediate_size = 4304 + # Keep pretrained-compatible projection size for real models, but make + # tiny variant align vision/text dims to avoid concat shape mismatch. + if vlm_config.variant == "gemma_tiny": + paligemma_config.vision_config.projection_dim = vlm_config.width + else: + paligemma_config.vision_config.projection_dim = 2048 + paligemma_config.vision_config.projector_hidden_act = "gelu_fast" + paligemma_config.vision_config.torch_dtype = "float32" + + action_expert_config_gemma = CONFIG_MAPPING["gemma"]( + head_dim=action_expert_config.head_dim, + hidden_size=action_expert_config.width, + intermediate_size=action_expert_config.mlp_dim, + num_attention_heads=action_expert_config.num_heads, + num_hidden_layers=action_expert_config.depth, + num_key_value_heads=action_expert_config.num_kv_heads, + vocab_size=257152, + hidden_activation="gelu_pytorch_tanh", + torch_dtype="float32", + use_adarms=use_adarms[1], + adarms_cond_dim=action_expert_config.width if use_adarms[1] else None, + ) + + self.paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config) + self.gemma_expert = GemmaForCausalLM(config=action_expert_config_gemma) + self.gemma_expert.model.embed_tokens = None + self.to_bfloat16_for_selected_params(precision) + + def to_bfloat16_for_selected_params( + self, precision: Literal["bfloat16", "float32"] = "bfloat16" + ) -> None: + """Move parameters to bfloat16, keeping sensitive ones in float32. + + Converts most model parameters to bfloat16 for memory efficiency while + preserving numerical stability by keeping certain sensitive parameters + in float32. Parameters kept in float32 include vision embeddings + (patch and position) and layer normalization weights. + + Args: + precision: Target precision for the model. If "bfloat16", converts + most parameters while keeping sensitive ones in float32. If + "float32", keeps all parameters in full precision. + + Raises: + ValueError: If precision is not "bfloat16" or "float32". 
+ """ + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def embed_image(self, image: torch.Tensor) -> torch.Tensor: + """Embed an input image using the vision tower. + + Processes raw image pixels through the PaliGemma vision encoder to + produce visual embeddings that can be concatenated with language tokens. + + Args: + image: Input image tensor of shape (batch_size, channels, height, width), + typically normalized to the expected range for the vision model. + + Returns: + Image embeddings tensor of shape (batch_size, num_patches, embed_dim) + representing the encoded visual features. + """ + return self.paligemma.model.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """Embed language tokens with the language model embedding layer. + + Converts discrete token IDs into continuous embeddings using the + PaliGemma language model's embedding table. + + Args: + tokens: Token IDs tensor of shape (batch_size, seq_len) containing + integer indices into the vocabulary. + + Returns: + Token embeddings tensor of shape (batch_size, seq_len, embed_dim) + representing the embedded language tokens. + """ + return self.paligemma.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor | None] | None = None, + ) -> tuple[list[torch.Tensor | None], list[torch.FloatTensor] | None]: + """Forward pass for prefix (vision/lang) and suffix (action) branches. + + Processes input embeddings through the joint PaliGemma and Gemma expert + architecture. Supports three modes of operation: + 1. Prefix-only: When suffix embeddings are None, runs only PaliGemma. + 2. Suffix-only: When prefix embeddings are None, runs only Gemma expert. + 3. Joint: When both are provided, runs shared attention across both models. + + Args: + attention_mask: Optional mask tensor of shape + (batch_size, 1, seq_len, seq_len) to prevent attention to + certain positions (e.g., padding, future tokens). + position_ids: Optional position indices of shape (batch_size, seq_len) for + rotary positional embeddings. + past_key_values: Optional cached key-value states from previous forward + passes for efficient autoregressive generation. + inputs_embeds: List of two tensors [prefix_embeds, suffix_embeds] where + each is of shape (batch_size, seq_len, embed_dim). One can be None + to run only the other branch. + use_cache: Whether to return updated key-value cache for generation. + adarms_cond: Optional list of two conditioning tensors for adaptive RMS + normalization, one per branch [VLM_cond, expert_cond]. 
+ + Returns: + A tuple containing: + - List of two output tensors [prefix_output, suffix_output], where each + is the final hidden states of shape (batch_size, seq_len, embed_dim). + One may be None if that branch was not computed. + - Updated past_key_values cache if use_cache=True and running prefix-only, + otherwise None. + + Raises: + ValueError: If inputs_embeds is None. + """ + adarms: list[torch.Tensor | None] + if adarms_cond is None: + adarms = [None, None] + else: + adarms = adarms_cond + if inputs_embeds is None: + raise ValueError("inputs_embeds must be provided") + if inputs_embeds[1] is None: + prefix_output = self.paligemma.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms[0], + ) + prefix_past_key_values = prefix_output.past_key_values + prefix_output = prefix_output.last_hidden_state + suffix_output = None + elif inputs_embeds[0] is None: + suffix_output = self.gemma_expert.model.forward( + inputs_embeds=inputs_embeds[1], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms[1], + ) + suffix_output = suffix_output.last_hidden_state + prefix_output = None + prefix_past_key_values = None + else: + models = [self.paligemma.language_model, self.gemma_expert.model] + num_layers = self.paligemma.config.text_config.num_hidden_layers + use_gradient_checkpointing = ( + hasattr(self.gemma_expert.model, "gradient_checkpointing") + and self.gemma_expert.model.gradient_checkpointing + and self.training + ) or ( + hasattr(self, "gradient_checkpointing") + and self.gradient_checkpointing + and self.training + ) + for layer_idx in range(num_layers): + if use_gradient_checkpointing: + inputs_embeds = torch.utils.checkpoint.checkpoint( + compute_shared_attention_layer, + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms, + use_reentrant=False, + preserve_rng_state=False, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + else: + inputs_embeds = compute_shared_attention_layer( + layer_idx, + inputs_embeds, + attention_mask, + position_ids, + adarms, + paligemma=self.paligemma, + gemma_expert=self.gemma_expert, + ) + + def compute_final_norms( + inputs_embeds: list[torch.Tensor], + adarms_cond: list[torch.Tensor | None], + ) -> list[torch.Tensor]: + """Apply final layer normalization to both model branches. + + Performs the final RMS normalization step after all transformer + layers have been processed. Supports optional adaptive conditioning. + + Args: + inputs_embeds: List of hidden state tensors for each branch, + each of shape (batch_size, seq_len, embed_dim). + adarms_cond: List of optional conditioning tensors for adaptive + RMS normalization, one per branch. + + Returns: + List of normalized output tensors, one per branch. 
+ """ + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i]) + outputs_embeds.append(out_emb) + return outputs_embeds + + if use_gradient_checkpointing: + outputs_embeds = torch.utils.checkpoint.checkpoint( + compute_final_norms, + inputs_embeds, + adarms, + use_reentrant=False, + preserve_rng_state=False, + ) + else: + outputs_embeds = compute_final_norms(inputs_embeds, adarms) + + prefix_output = outputs_embeds[0] + suffix_output = outputs_embeds[1] + prefix_past_key_values = None + + return [prefix_output, suffix_output], prefix_past_key_values diff --git a/neuracore/ml/algorithms/pi0/modules.py b/neuracore/ml/algorithms/pi0/modules.py index 83d5a8e9..5f71b1d1 100644 --- a/neuracore/ml/algorithms/pi0/modules.py +++ b/neuracore/ml/algorithms/pi0/modules.py @@ -1,474 +1,732 @@ -"""Gemma MoE model with custom attention.""" +"""Core PyTorch modules for the PI0 algorithm. +This module implements the PI0Policy model that combines a PaliGemma +vision-language model with a Gemma action expert for robot manipulation. +The model uses flow matching to denoise action sequences conditioned on +visual observations and proprioceptive state. +""" + +# cspell:ignore OPENPI adarms layernorm silu huggingface openpi denoised + +from __future__ import annotations + +import logging import math from collections.abc import Callable -from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, TypeVar import torch -import torch.nn as nn -from transformers.cache_utils import DynamicCache -from transformers.modeling_flash_attention_utils import FlashAttentionKwargs -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from transformers.models.gemma.modeling_gemma import ( - Cache, - GemmaAttention, - GemmaConfig, - GemmaDecoderLayer, - GemmaRotaryEmbedding, - apply_rotary_pos_emb, - eager_attention_forward, +import torch.nn.functional as F +from safetensors.torch import load_file +from torch import Tensor, nn +from transformers.models.paligemma.modeling_paligemma import ( + PaliGemmaForConditionalGeneration, ) -from transformers.processing_utils import Unpack -from transformers.utils import logging +from transformers.utils import cached_file -logger = logging.get_logger(__name__) +from neuracore.ml.algorithms.pi0.gemma_pytorch import ( + PaliGemmaWithExpertModel, + get_gemma_config, +) +from neuracore.ml.algorithms.pi0.utils import ( + OPENPI_ATTENTION_MASK_VALUE, + PI0Config, + _align_mask_length, + _create_sinusoidal_pos_embedding, + _make_att_2d_masks, + _sample_beta, +) +T = TypeVar("T") -@dataclass -class MoeExpertConfig: - """Configuration for the MoE model.""" +logger = logging.getLogger(__name__) - hidden_size: int # aka width - intermediate_size: int - head_dim: int - num_attention_heads: int - num_key_value_heads: int - use_cache: bool = False - hidden_activation: str = "gelu_pytorch_tanh" +class PI0Policy(nn.Module): + """Core PI0 model combining PaliGemma VLM with Gemma action expert. -class CustomGemmaAttention(GemmaAttention): - """Multi-headed attention from 'Attention Is All You Need' paper. + This model processes visual observations and language through PaliGemma, + then uses a separate Gemma model as the action expert to predict + denoised action sequences via flow matching. - Note this is a replica of the GemmaAttention module from the Hugging Face. 
- We have to replicate it here to be able to modify the forward pass, - and expose the query, key, and value states for the mixed attention. + The architecture supports gradient checkpointing and torch.compile + optimization for efficient training and inference. """ - def forward( - self, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: torch.Tensor | None, - past_key_value: Cache | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, torch.Tensor | None]: - """Forward pass for the CustomGemmaAttention module. + def __init__(self, config: PI0Config): + """Initialize the PI0 model. Args: - hidden_states: Input hidden states. - position_embeddings: Position embeddings. - attention_mask: Attention mask. - past_key_value: Past key-value cache. - cache_position: Cache position. - **kwargs: Additional keyword arguments. - - Returns: - Output hidden states, attention weights, and past key-value cache. + config: Model configuration specifying architecture and hyperparameters """ - input_shape = hidden_states.shape[:-1] - hidden_shape = (*input_shape, -1, self.head_dim) + super().__init__() + self.config = config + + paligemma_config = get_gemma_config(config.paligemma_variant) + action_expert_config = get_gemma_config(config.action_expert_variant) - self.query_states = query_states = ( - self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + self.paligemma_with_expert = PaliGemmaWithExpertModel( + paligemma_config, + action_expert_config, + use_adarms=config.use_adarms, + precision=config.dtype, ) - self.key_states = key_states = ( - self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + self.action_in_proj = nn.Linear( + config.max_action_dim, action_expert_config.width ) - self.value_states = value_states = ( - self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + self.action_out_proj = nn.Linear( + action_expert_config.width, config.max_action_dim ) - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin + self.state_proj = nn.Linear(config.max_state_dim, action_expert_config.width) + self.action_time_mlp_in = nn.Linear( + 2 * action_expert_config.width, action_expert_config.width + ) + self.action_time_mlp_out = nn.Linear( + action_expert_config.width, action_expert_config.width ) - if past_key_value is not None: - # sin and cos are specific to RoPE models; - # cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + self.gradient_checkpointing_enabled = False - attention_interface: Callable = eager_attention_forward - if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and kwargs.get( - "output_attentions", False - ): - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` " - "does not support `output_attentions=True`. Falling back to " - "eager attention. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' 
- ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[ - self.config._attn_implementation - ] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - **kwargs, - ) + if config.gradient_checkpointing: + self.gradient_checkpointing_enable() + if config.device is not None: + self.to(config.device) - attn_output = attn_output.reshape(*input_shape, -1).contiguous() - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights + def gradient_checkpointing_enable(self) -> None: + """Enable gradient checkpointing on all submodules.""" + self.gradient_checkpointing_enabled = True + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = ( + True + ) + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = True + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = True + logging.info("Enabled gradient checkpointing for PI0Pytorch model") + + def gradient_checkpointing_disable(self) -> None: + """Disable gradient checkpointing on all submodules.""" + self.gradient_checkpointing_enabled = False + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing = ( + False + ) + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing = False + self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False + logging.info("Disabled gradient checkpointing for PI0Pytorch model") + + def compile_model_enable(self) -> None: + """Enable model compilation.""" + torch.set_float32_matmul_precision("high") + self.sample_actions = torch.compile( # type: ignore[method-assign] + self.sample_actions, mode=self.config.compile_mode + ) + self.forward = torch.compile( # type: ignore[method-assign] + self.forward, mode=self.config.compile_mode + ) + logging.info("Enabled model compilation for PI0Pytorch model") - def o_project(self, hidden_states: torch.Tensor) -> torch.Tensor: - """Output projection for the attention module. + def _apply_checkpoint(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: + """Apply gradient checkpointing to a function if enabled. Args: - hidden_states: Input hidden states. + func: Function to potentially checkpoint + *args: Positional arguments for the function + **kwargs: Keyword arguments for the function Returns: - torch.Tensor: Output hidden states. + Function output, computed with or without checkpointing. """ - return self.o_proj(hidden_states) + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + def _prepare_attention_masks_4d(self, att_2d_masks: Tensor) -> Tensor: + """Expand 2D attention masks to 4D format for transformer layers. -class GemmaMoELayer(nn.Module): - """A layer that combines individual Gemma experts with cross-expert attention.""" + Args: + att_2d_masks: 2D attention mask [B, seq_len, seq_len] - def __init__(self, expert_configs: dict[str, MoeExpertConfig], layer_idx: int): - """Initialize the GemmaMoELayer. + Returns: + 4D attention mask [B, 1, seq_len, seq_len] with fill values applied. 
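+
+        True entries (attend) map to 0.0 and False entries map to the
+        OPENPI_ATTENTION_MASK_VALUE fill used to block attention.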
+ """ + att_2d_masks_4d = att_2d_masks[:, None, :, :] + return torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + + def _sample_noise( + self, shape: torch.Size | tuple[int, ...], device: torch.device + ) -> Tensor: + """Sample standard normal noise for flow matching. Args: - expert_configs: Configuration for the experts. - layer_idx: Index of the layer. + shape: Shape of the noise tensor + device: Target device + + Returns: + Tensor of standard normal noise. """ - super().__init__() - self.expert_configs = expert_configs - self.layer_idx = layer_idx - - self.experts = nn.ModuleDict() - self.rotary_embs = nn.ModuleDict() - for name, config in expert_configs.items(): - # Create Gemma config for this expert - gemma_config = GemmaConfig(**asdict(config)) - # Ensure attention implementation is set to eager to avoid None lookups - # in ALL_ATTENTION_FUNCTIONS during CustomGemmaAttention.forward - setattr(gemma_config, "_attn_implementation", "eager") - setattr(gemma_config, "attn_implementation", "eager") - self.experts[name] = GemmaDecoderLayer(gemma_config, layer_idx) - self.experts[name].self_attn = CustomGemmaAttention( - config=gemma_config, layer_idx=layer_idx - ) - self.rotary_embs[name] = GemmaRotaryEmbedding(config=gemma_config) + return torch.normal( + mean=0.0, std=1.0, size=shape, dtype=torch.float32, device=device + ) - def mix_attention( - self, - queries: torch.Tensor, - keys: torch.Tensor, - values: torch.Tensor, - attention_mask: torch.Tensor, - dropout_p: float = 0.0, - ) -> torch.Tensor: - """Compute mixed attention across experts. + def _sample_time(self, bsize: int, device: torch.device) -> Tensor: + """Sample diffusion time steps from beta distribution. Args: - queries: Query tensor. - keys: Key tensor. - values: Value tensor. - attention_mask: Attention mask. - dropout_p: Dropout probability. + bsize: Batch size + device: Target device Returns: - torch.Tensor: Mixed attention output. + Tensor of time values [bsize] in range [offset, offset + scale]. """ - # Compute attention scores - attn_weights = torch.matmul(queries, keys.transpose(-1, -2)) - attn_weights = attn_weights / math.sqrt(queries.size(-1)) - - # Apply attention mask - if attention_mask is not None: - attn_weights = attn_weights + attention_mask - - # Apply softmax and dropout - attn_weights = torch.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype) - attn_weights = nn.functional.dropout( - attn_weights, p=dropout_p, training=self.training + time_beta = _sample_beta( + self.config.time_sampling_beta_alpha, + self.config.time_sampling_beta_beta, + bsize, + device, ) + time = ( + time_beta * self.config.time_sampling_scale + + self.config.time_sampling_offset + ) + return time.to(dtype=torch.float32, device=device) - # Compute mixed attention output - mixed_output = torch.matmul(attn_weights, values) - return mixed_output - - def forward( + def _embed_prefix( self, - hidden_states: dict[str, torch.FloatTensor], - expert_attention_masks: dict[str, torch.Tensor] | None = None, - mix_attention_mask: torch.Tensor | None = None, - position_ids: dict[str, torch.LongTensor] | None = None, - past_key_values: dict[str, DynamicCache] | None = None, - use_cache: bool = False, - ) -> dict[str, torch.FloatTensor]: - """Forward pass for the GemmaMoELayer. + images: list[Tensor], + img_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + ) -> tuple[Tensor, Tensor, Tensor]: + """Embed image and language inputs for the prefix sequence. Args: - hidden_states: Input hidden states. 
- expert_attention_masks: Attention masks for the experts. - mix_attention_mask: Mixed attention mask. - position_ids: Position IDs. - past_key_values: Past key-value caches. - use_cache: Whether to use caching. + images: List of image tensors [B, C, H, W] per camera + img_masks: List of image masks [B] per camera + lang_tokens: Language token IDs [B, L] + lang_masks: Language attention mask [B, L] Returns: - Dict[str, torch.FloatTensor]: Output hidden states. + Tuple of (embeddings, padding_masks, attention_masks). """ - expert_outputs = {} # Store the expert outputs - query_states_all, key_states_all, value_states_all = {}, {}, {} - for name, states in hidden_states.items(): - pos_ids = position_ids.get(name) if position_ids else None - past_kv = past_key_values.get(name) if past_key_values else None - - # Get pos embeddings and run through expert - position_embeddings = self.rotary_embs[name](states, pos_ids) - expert_output = self.experts[name]( - hidden_states=states, - attention_mask=( - expert_attention_masks[name] if expert_attention_masks else None - ), - position_ids=pos_ids, - past_key_value=past_kv, - use_cache=use_cache, - position_embeddings=position_embeddings, - ) - expert_outputs[name] = expert_output[0] # Store the output + embs = [] + pad_masks = [] + att_masks = [] - # Store attention states - query_states_all[name] = self.experts[name].self_attn.query_states - key_states_all[name] = self.experts[name].self_attn.key_states - value_states_all[name] = self.experts[name].self_attn.value_states + for img, img_mask in zip(images, img_masks, strict=True): - # Concatenate for mixed attention - queries = torch.cat(tuple(query_states_all.values()), dim=2) - keys = torch.cat(tuple(key_states_all.values()), dim=2) - values = torch.cat(tuple(value_states_all.values()), dim=2) + def image_embed_func(img: Tensor) -> Tensor: + return self.paligemma_with_expert.embed_image(img) - # Run mixed attention - mixed_output = self.mix_attention(queries, keys, values, mix_attention_mask) + img_emb = self._apply_checkpoint(image_embed_func, img) + bsize, num_img_embs = img_emb.shape[:2] - attn_output = mixed_output.transpose(1, 2).contiguous() - batch_size = queries.size(0) - q_lens = [hidden_states.size(1) for hidden_states in hidden_states.values()] - attn_output = attn_output.view(batch_size, sum(q_lens), -1) + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + att_masks += [0] * num_img_embs - # Split back per expert - attn_outputs = torch.split(attn_output, q_lens, dim=1) + def lang_embed_func(lang_tokens: Tensor) -> Tensor: + lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) - # Combine with expert outputs - outputs = {} - for name, states in zip(hidden_states.keys(), attn_outputs): - proj_mixed = self.experts[name].self_attn.o_project(states) - # Add expert output as residual - outputs[name] = expert_outputs[name] + proj_mixed + lang_emb = self._apply_checkpoint(lang_embed_func, lang_tokens) + embs.append(lang_emb) + pad_masks.append(lang_masks) - return outputs + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1).to(dtype=torch.bool) + att_masks_t = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + att_masks_t = _align_mask_length(att_masks_t, pad_masks.shape[1]) + bsize = pad_masks.shape[0] + att_masks_t = att_masks_t[None, 
:].expand(bsize, att_masks_t.shape[0]) + return embs, pad_masks, att_masks_t -class GemmaMoE(nn.Module): - """Main MoE model that uses Gemma experts.""" - - def __init__( - self, - depth: int, - expert_configs: dict[str, MoeExpertConfig], - ): - """Initialize the GemmaMoE model. + def _embed_suffix( + self, state: Tensor, noisy_actions: Tensor, timestep: Tensor + ) -> tuple[Tensor, Tensor, Tensor, None]: + """Embed state, noisy actions, and timestep for the action expert. Args: - depth: Depth of the MoE model. - expert_configs: Configuration for the experts. + state: Proprioceptive state [B, state_dim] + noisy_actions: Noisy action sequence [B, chunk_size, action_dim] + timestep: Diffusion timestep [B] + + Returns: + Tuple of (embeddings, padding_masks, attention_masks, adarms_cond). """ - super().__init__() - self.expert_names = list(expert_configs.keys()) - self.expert_configs = expert_configs + embs = [] + pad_masks = [] + att_masks = [] + + if self.state_proj.weight.dtype == torch.float32: + state = state.to(torch.float32) + + def state_proj_func(state: Tensor) -> Tensor: + return self.state_proj(state) + + state_emb = self._apply_checkpoint(state_proj_func, state) + embs.append(state_emb[:, None, :]) + bsize = state_emb.shape[0] + device = state_emb.device + + state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + att_masks += [1] + + time_emb = _create_sinusoidal_pos_embedding( + timestep, + self.action_in_proj.out_features, + min_period=self.config.min_period, + max_period=self.config.max_period, + device=timestep.device, + ) + time_emb = time_emb.type(dtype=timestep.dtype) + + def action_proj_func(noisy_actions: Tensor) -> Tensor: + return self.action_in_proj(noisy_actions) + + action_emb = self._apply_checkpoint(action_proj_func, noisy_actions) + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + def mlp_func(action_time_emb: Tensor) -> Tensor: + x = self.action_time_mlp_in(action_time_emb) + x = F.silu(x) + return self.action_time_mlp_out(x) - # Create layers with Gemma experts - self.layers = nn.ModuleList( - [GemmaMoELayer(expert_configs, i) for i in range(depth)] + action_time_emb = self._apply_checkpoint(mlp_func, action_time_emb) + adarms_cond = None + + embs.append(action_time_emb) + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones( + bsize, action_time_dim, dtype=torch.bool, device=timestep.device ) + pad_masks.append(action_time_mask) + att_masks += [1] + ([0] * (self.config.chunk_size - 1)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks_t = torch.tensor(att_masks, dtype=torch.bool, device=embs.device) + att_masks_t = _align_mask_length(att_masks_t, pad_masks.shape[1]) + att_masks_t = att_masks_t[None, :].expand(bsize, att_masks_t.shape[0]) - # Create final layer norms for each expert - self.final_norms = nn.ModuleDict() - for name, config in expert_configs.items(): - self.final_norms[name] = nn.LayerNorm(config.hidden_size) + return embs, pad_masks, att_masks_t, adarms_cond - # Track which experts use caching - self.cache_names = [ - name for name, config in expert_configs.items() if config.use_cache - ] + def forward( + self, + images: list[Tensor], + img_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + state: Tensor, + actions: Tensor, + noise: Tensor | None = None, + time: Tensor | None = None, + ) -> Tensor: + """Compute flow matching loss for training. 
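+
+        A sketch of the objective (matching the implementation below): with
+        time t drawn from a scaled Beta distribution, the model receives the
+        interpolation x_t = t * noise + (1 - t) * actions and is trained to
+        predict the velocity u_t = noise - actions; the returned value is the
+        element-wise MSE between u_t and the predicted v_t.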
- def _init_caches(self) -> dict[str, DynamicCache]: - """Initialize caches for the experts. + Args: + images: List of image tensors [B, C, H, W] per camera + img_masks: List of image masks [B] per camera + lang_tokens: Language token IDs [B, L] + lang_masks: Language attention mask [B, L] + state: Proprioceptive state [B, state_dim] + actions: Target action sequence [B, chunk_size, action_dim] + noise: Optional pre-sampled noise + time: Optional pre-sampled diffusion time Returns: - Dict[str, DynamicCache]: Initialized caches. + Per-element MSE loss [B, chunk_size, action_dim]. """ - return {name: DynamicCache() for name in self.cache_names} + if noise is None: + noise = self._sample_noise(actions.shape, actions.device) + if time is None: + time = self._sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + + prefix_embs, prefix_pad_masks, prefix_att_masks = self._embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = ( + self._embed_suffix(state, x_t, time) + ) + + if ( + self.paligemma_with_expert.paligemma.language_model.layers[ + 0 + ].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + suffix_embs = suffix_embs.to(dtype=torch.bfloat16) + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = _make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + att_2d_masks_4d = self._prepare_attention_masks_4d(att_2d_masks) + + def forward_func( + prefix_embs: Tensor, + suffix_embs: Tensor, + att_2d_masks_4d: Tensor, + position_ids: Tensor, + adarms_cond: Tensor | None, + ) -> Tensor: + (_, suffix_out), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_masks_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + return suffix_out + + suffix_out = self._apply_checkpoint( + forward_func, + prefix_embs, + suffix_embs, + att_2d_masks_4d, + position_ids, + adarms_cond, + ) + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + + def action_out_proj_func(suffix_out: Tensor) -> Tensor: + return self.action_out_proj(suffix_out) - def _normalize_inputs( - self, hidden_states: dict[str, torch.FloatTensor] - ) -> dict[str, torch.FloatTensor]: - """Normalize input hidden states. + v_t = self._apply_checkpoint(action_out_proj_func, suffix_out) + return F.mse_loss(u_t, v_t, reduction="none") + + @torch.no_grad() + def sample_actions( + self, + images: list[Tensor], + img_masks: list[Tensor], + lang_tokens: Tensor, + lang_masks: Tensor, + state: Tensor, + noise: Tensor | None = None, + num_steps: int | None = None, + ) -> Tensor: + """Sample action sequence via Euler integration. + + From pure noise to actions using the flow matching ODE. Args: - hidden_states: Input hidden states. 
+ images: List of image tensors [B, C, H, W] per camera + img_masks: List of image masks [B] per camera + lang_tokens: Language token IDs [B, L] + lang_masks: Language attention mask [B, L] + state: Proprioceptive state [B, state_dim] + noise: Optional initial noise + num_steps: Number of Euler steps (default: config.num_inference_steps) Returns: - Dict[str, torch.FloatTensor]: Normalized hidden states. + Sampled action sequence [B, chunk_size, action_dim]. """ - normalized = {} - for name, states in hidden_states.items(): - hidden_size = states.shape[-1] - normalizer = torch.sqrt( - torch.tensor(hidden_size, dtype=states.dtype, device=states.device) + if num_steps is None: + num_steps = self.config.num_inference_steps + + bsize = state.shape[0] + device = state.device + + if noise is None: + actions_shape = ( + bsize, + self.config.chunk_size, + self.config.max_action_dim, ) - normalized[name] = states * normalizer - return normalized + noise = self._sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self._embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + prefix_att_2d_masks = _make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + prefix_att_2d_masks_4d = self._prepare_attention_masks_4d(prefix_att_2d_masks) + paligemma_lm_config = self.paligemma_with_expert.paligemma.language_model.config + paligemma_lm_config._attn_implementation = "eager" + + _, past_key_values = self.paligemma_with_expert.forward( + attention_mask=prefix_att_2d_masks_4d, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, + ) + + dt = -1.0 / num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self._denoise_step( + state, + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + x_t = x_t + dt * v_t + time += dt + + return x_t - def get_parameters(self, mixture_name: str) -> list: - """Get the parameters for a specific mixture. + def _denoise_step( + self, + state: Tensor, + prefix_pad_masks: Tensor, + past_key_values: list[torch.FloatTensor] | None, + x_t: Tensor, + timestep: Tensor, + ) -> Tensor: + """Compute velocity field for a single Euler denoising step. Args: - mixture_name: Name of the mixture. + state: Proprioceptive state [B, state_dim] + prefix_pad_masks: Padding masks from prefix embedding + past_key_values: Cached key-values from prefix forward pass + x_t: Current noisy actions [B, chunk_size, action_dim] + timestep: Current diffusion time [B] Returns: - list: List of parameters. + Predicted velocity [B, chunk_size, action_dim]. """ - params = [] - for layer in self.layers: - for name, expert in layer.experts.items(): - if name == mixture_name: - params.extend([p for p in expert.parameters()]) - return params + suffix_embs, suffix_pad_masks, suffix_att_masks, adarms_cond = ( + self._embed_suffix(state, x_t, timestep) + ) - def forward( - self, - hidden_states: dict[str, torch.FloatTensor], - expert_attention_masks: dict[str, torch.Tensor] | None = None, - mix_attention_mask: torch.Tensor | None = None, - position_ids: dict[str, torch.LongTensor] | None = None, - past_key_values: dict[str, DynamicCache] | None = None, - use_cache: bool = False, - ) -> torch.Tensor: - """Forward pass for the GemmaMoE model. 
+ suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand( + batch_size, suffix_len, prefix_len + ) + suffix_att_2d_masks = _make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + full_att_2d_masks_4d = self._prepare_attention_masks_4d(full_att_2d_masks) + gemma_config = self.paligemma_with_expert.gemma_expert.model.config + gemma_config._attn_implementation = "eager" + + outputs_embeds, _ = self.paligemma_with_expert.forward( + attention_mask=full_att_2d_masks_4d, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=False, + adarms_cond=[None, adarms_cond], + ) + + suffix_out = outputs_embeds[1] + assert suffix_out is not None + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + return self.action_out_proj(suffix_out) + + @classmethod + def from_pretrained( + cls, + pretrained_name_or_path: str | Path | None = None, + *, + config: PI0Config | None = None, + strict: bool = True, + **kwargs: Any, + ) -> PI0Policy: + """Load a pretrained PI0 model from HuggingFace Hub or local path. Args: - hidden_states: Input hidden states. - expert_attention_masks: Attention masks for the experts. - mix_attention_mask: Mixed attention mask. - position_ids: Position IDs. - past_key_values: Past key-value caches. - use_cache: Whether to use caching. + pretrained_name_or_path: HuggingFace repo id or local path + config: Model configuration (default: PI0Config()) + strict: Whether to strictly enforce state dict loading + **kwargs: Additional arguments (cache_dir, force_download, etc.) Returns: - hidden_states: Output hidden states. + PI0Policy model with loaded weights. 
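+
+        Example (illustrative; downloads weights from the Hub on first use):
+
+            policy = PI0Policy.from_pretrained("lerobot/pi0_base")
+            actions = policy.sample_actions(
+                images, img_masks, lang_tokens, lang_masks, state
+            )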
""" - # Initialize caches if needed - if past_key_values is None and use_cache: - past_key_values = self._init_caches() - - # Normalize inputs - hidden_states = self._normalize_inputs(hidden_states) - - # Process through layers - for layer in self.layers: - hidden_states = layer( - hidden_states, - expert_attention_masks=expert_attention_masks, - mix_attention_mask=mix_attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - use_cache=use_cache, + if pretrained_name_or_path is None: + pretrained_name_or_path = "lerobot/pi0_base" + logging.warning( + "No pretrained model path provided; using default pi0_base model" ) + if config is None: + config = PI0Config() - # Apply final layer norms - hidden_states = { - name: self.final_norms[name](states) - for name, states in hidden_states.items() - } - return hidden_states + model = cls(config, **kwargs) + if cached_file is None or load_file is None: + logging.warning( + "transformers/safetensors not available; loading weights skipped" + ) + return model + + try: + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + token=kwargs.get("token") or kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + original_state_dict = load_file(resolved_file) + logging.info("Loaded state dict from %s", resolved_file) + except Exception as exc: + logging.warning( + "Could not load state dict from %s: %s", pretrained_name_or_path, exc + ) + return model -class SinusoidalPosEmb(nn.Module): - """1D sinusoidal positional embedding module used for time embedding. - - This module implements sinusoidal positional embeddings for time steps, - commonly used in diffusion models and transformers. - """ + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict) - def __init__(self, dim: int): - """Initialize the SinusoidalPosEmb module. + missing_keys, unexpected_keys = model.load_state_dict( + fixed_state_dict, strict=False + ) + if missing_keys: + logging.warning("Missing keys when loading state dict: %s", missing_keys) + if unexpected_keys: + logging.warning( + "Unexpected keys when loading state dict: %s", unexpected_keys + ) - Args: - dim: Dimension of the positional embedding. - """ - super().__init__() - self.dim = dim + tie_key = ( + "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ) + if tie_key in missing_keys: + paligemma = model.paligemma_with_expert.paligemma + if model._tie_or_copy_language_embeddings(paligemma): + logging.info("Tied language embeddings to lm_head weight") + missing_keys = [key for key in missing_keys if key != tie_key] + logging.warning( + "Missing keys after tying language embeddings: %s", missing_keys + ) + logging.info( + "Successfully loaded pretrained PI0 weights from %s", + pretrained_name_or_path, + ) + return model - def forward(self, t: torch.Tensor, max_period: float = 10000.0) -> torch.Tensor: - """Forward pass for the SinusoidalPosEmb module. + def _tie_or_copy_language_embeddings( + self, paligemma: PaliGemmaForConditionalGeneration + ) -> bool: + """Tie or copy language embeddings to lm_head weight. Args: - t: Input tensor. - max_period: Maximum period for the sinusoidal embedding. + paligemma: PaliGemma model instance Returns: - torch.Tensor: Positional embeddings. 
+ True if embeddings were successfully tied, False otherwise. """ - half_dim = self.dim // 2 - emb = math.log(max_period) / (half_dim - 1) - emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb).to(t.dtype) - emb = t[:, None] * emb[None, :] - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb + language_model = getattr( + getattr(paligemma, "model", None), "language_model", None + ) + lm_head = getattr(paligemma, "lm_head", None) + if language_model is None or lm_head is None: + return False + embed_tokens = getattr(language_model, "embed_tokens", None) + lm_head_weight = getattr(lm_head, "weight", None) + if embed_tokens is None or lm_head_weight is None: + return False -class ActionEncoder(nn.Module): - """Action encoder for the Pi0 model.""" + embed_weight = getattr(embed_tokens, "weight", None) + if embed_weight is None or embed_weight.shape != lm_head_weight.shape: + return False - def __init__(self, action_dim: int, width: int): - """Initialize the ActionEncoder module. + with torch.no_grad(): + embed_weight.copy_(lm_head_weight) - Args: - action_dim: Dimension of the action space. - width: Width of the encoder. - """ - super().__init__() - self.linear_1 = nn.Linear(action_dim, width) - self.linear_2 = nn.Linear(2 * width, width) - self.nonlinearity = nn.SiLU() - self.linear_3 = nn.Linear(width, width) + if hasattr(paligemma, "tie_weights"): + paligemma.tie_weights() - def forward( - self, action: torch.Tensor, time_emb: torch.Tensor | None = None - ) -> torch.Tensor: - """Forward pass for the ActionEncoder module. + tied_embed = getattr(language_model.embed_tokens, "weight", None) + return ( + tied_embed is not None + and tied_embed.data_ptr() == lm_head_weight.data_ptr() + ) + + def _fix_pytorch_state_dict_keys( + self, state_dict: dict[str, torch.Tensor] + ) -> dict[str, torch.Tensor]: + """Fix state dict keys to match current model architecture. + + Handles key remapping and filtering for compatibility with + different checkpoint formats (e.g., OpenPI vs current). Args: - action: Input action tensor. - time_emb: Time embedding tensor. + state_dict: Original state dict from checkpoint Returns: - torch.Tensor: Encoded action tensor. + Fixed state dict with compatible keys. """ - emb = self.linear_1(action) # [B, H, W] - if time_emb is not None: - time_emb_full = time_emb.unsqueeze(1).expand(-1, action.size(1), -1) - else: - time_emb_full = torch.zeros_like(emb) - emb = torch.cat([time_emb_full, emb], dim=-1) # [B, H, W * 2] - emb = self.nonlinearity(self.linear_2(emb)) # [B, H, W] - emb = self.linear_3(emb) # [B, H, W] - return emb # [B, H, W] + import re + + fixed_state_dict: dict[str, torch.Tensor] = {} + + for key, value in state_dict.items(): + new_key = key + + if re.match( + ( + r"paligemma_with_expert\.gemma_expert\.model\.layers\.\d+\." 
+ r"(input_layernorm|post_attention_layernorm)\.weight" + ), + key, + ): + expert_uses_adarms = getattr( + self.paligemma_with_expert.gemma_expert.config, + "use_adarms", + False, + ) + if expert_uses_adarms: + logging.warning( + "Skipping layer norm key (adaRMS mismatch): %s", key + ) + continue + + if re.match( + r"paligemma_with_expert\.gemma_expert\.model\.norm\.weight", key + ): + expert_uses_adarms = getattr( + self.paligemma_with_expert.gemma_expert.config, + "use_adarms", + False, + ) + if expert_uses_adarms: + logging.warning("Skipping norm key (adaRMS mismatch): %s", key) + continue + + if key.startswith("time_mlp_in."): + new_key = key.replace("time_mlp_in.", "action_time_mlp_in.") + elif key.startswith("time_mlp_out."): + new_key = key.replace("time_mlp_out.", "action_time_mlp_out.") + + if "patch_embedding" in key: + logging.warning("Vision embedding key might need handling: %s", key) + + fixed_state_dict[new_key] = value + + return fixed_state_dict diff --git a/neuracore/ml/algorithms/pi0/pi0.py b/neuracore/ml/algorithms/pi0/pi0.py index 0236e9e6..cf14ff3c 100644 --- a/neuracore/ml/algorithms/pi0/pi0.py +++ b/neuracore/ml/algorithms/pi0/pi0.py @@ -9,14 +9,12 @@ for General Robot Control." arXiv preprint https://arxiv.org/abs/2410.24164. """ +from __future__ import annotations + import logging -import os -from typing import cast +from typing import Any, Literal, cast import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as T from neuracore_types import ( BatchedJointData, BatchedLanguageData, @@ -30,7 +28,7 @@ ModelInitDescription, ParallelGripperOpenAmountDataStats, ) -from transformers import AutoProcessor, PaliGemmaForConditionalGeneration +from torch.optim.lr_scheduler import LambdaLR from neuracore.ml import ( BatchedInferenceInputs, @@ -40,90 +38,119 @@ ) from neuracore.ml.algorithm_utils.normalizer import MeanStdNormalizer -from .modules import ActionEncoder, GemmaMoE, MoeExpertConfig, SinusoidalPosEmb - -logging.getLogger("transformers").setLevel(logging.CRITICAL) +from .modules import PI0Policy +from .utils import PI0Config, build_lr_lambda, pad_vector, resize_with_pad_torch logger = logging.getLogger(__name__) PROPRIO_NORMALIZER = MeanStdNormalizer # or MinMaxNormalizer ACTION_NORMALIZER = MeanStdNormalizer # or MinMaxNormalizer - -VLM_BACKBONE = "google/paligemma-3b-pt-224" -VLM_EXPERT_WIDTH = 2048 # Width of the VLM expert, matches PaliGemma's hidden size +IMAGE_RESIZE_SHAPE = (224, 224) class Pi0(NeuracoreModel): - """Implementation of Pi0 model from the Physical Intelligence paper. + """Vision-language-action flow model for robot manipulation. + + Implements the π0 model from Physical Intelligence that combines a + PaliGemma vision-language model with a Gemma action expert. The model + uses flow matching to predict action sequences from visual observations, + proprioceptive state, and optional language instructions. - Currently only supports finetuning the action expert. The model combines - vision-language understanding with action prediction through a mixture of - experts architecture. + The architecture supports flexible finetuning strategies including + action-expert-only, vision+action, or full model training. 
""" def __init__( self, model_init_description: ModelInitDescription, - vlm_expert_intermediate_size: int = 16384, - vlm_expert_num_heads: int = 8, - vlm_expert_num_kv_heads: int = 1, - vlm_expert_head_dim: int = 256, - vlm_max_text_tokens: int = 128, - action_expert_width: int = 1024, - action_expert_intermediate_size: int = 4096, - action_expert_num_heads: int = 8, - action_expert_num_kv_heads: int = 1, - action_expert_head_dim: int = 256, - moe_depth: int = 18, + vlm_max_text_tokens: int = 48, num_inference_steps: int = 10, - flow_sig_min: float = 0.001, - flow_alpha: float = 1.5, - flow_beta: float = 1.0, - lr: float = 5e-5, - weight_decay: float = 0.0, - dtype: torch.dtype = torch.float32, + dtype: Literal["bfloat16", "float32"] = "float32", + paligemma_variant: str = "gemma_2b", + action_expert_variant: str = "gemma_300m", + use_pretrained_weights: bool = True, + pretrained_name_or_path: str | None = "lerobot/pi0_base", + time_sampling_beta_alpha: float = 1.5, + time_sampling_beta_beta: float = 1.0, + time_sampling_scale: float = 0.999, + time_sampling_offset: float = 0.001, + min_period: float = 4e-3, + max_period: float = 4.0, + gradient_checkpointing: bool = False, + compile_model: bool = False, + compile_mode: str = "max-autotune", + optimizer_lr: float = 2.5e-5, + optimizer_betas: tuple[float, float] = (0.9, 0.95), + optimizer_eps: float = 1e-8, + optimizer_weight_decay: float = 0.01, + clip_grad_norm: float = 1.0, + lr_scheduler_warmup_steps: int = 1000, + lr_scheduler_num_decay_steps: int = 30000, + lr_scheduler_decay_lr: float = 2.5e-6, + finetune_action_expert_only: bool = False, + freeze_language_model_only: bool = False, ): """Initialize the Pi0 model. Args: - model_init_description: Model initialization configuration. - vlm_expert_intermediate_size: Intermediate size of the VLM expert. - vlm_expert_num_heads: Number of attention heads in the VLM expert. - vlm_expert_num_kv_heads: Number of key-value heads in the VLM expert. - vlm_expert_head_dim: Dimension of each attention head in the VLM expert. - vlm_max_text_tokens: Maximum number of text tokens for the VLM. - action_expert_width: Width of the action expert. - action_expert_intermediate_size: Intermediate size of the action expert. - action_expert_num_heads: Number of attention heads in the action expert. - action_expert_num_kv_heads: Number of key-value heads in the action expert. - action_expert_head_dim: Dimension of each attention head in action expert. - moe_depth: Depth of the mixture of experts. - num_inference_steps: Number of inference steps. - flow_sig_min: Minimum value for the flow sigma. - flow_alpha: Alpha parameter for the flow beta distribution. - flow_beta: Beta parameter for the flow beta distribution. - lr: Learning rate for the model. - weight_decay: Weight decay for the model. - dtype: Data type for model parameters and computations. 
+ model_init_description: Model initialization parameters + vlm_max_text_tokens: Maximum number of language tokens + num_inference_steps: Number of Euler denoising steps + dtype: Model precision ("bfloat16" or "float32") + paligemma_variant: VLM size ("gemma_300m" or "gemma_2b") + action_expert_variant: Action expert size ("gemma_300m" or "gemma_2b") + use_pretrained_weights: Whether to load pretrained weights + pretrained_name_or_path: HuggingFace repo id or local path + time_sampling_beta_alpha: Alpha for beta distribution time sampling + time_sampling_beta_beta: Beta for beta distribution time sampling + time_sampling_scale: Scale factor for sampled time values + time_sampling_offset: Offset added to sampled time values + min_period: Minimum period for sinusoidal time embeddings + max_period: Maximum period for sinusoidal time embeddings + gradient_checkpointing: Enable gradient checkpointing + compile_model: Enable torch.compile optimization + compile_mode: Compilation mode for torch.compile + optimizer_lr: Learning rate + optimizer_betas: Adam beta parameters + optimizer_eps: Adam epsilon + optimizer_weight_decay: Weight decay + clip_grad_norm: Gradient clipping norm (unused, for config compatibility) + lr_scheduler_warmup_steps: Linear warmup steps + lr_scheduler_num_decay_steps: Cosine decay steps + lr_scheduler_decay_lr: Final learning rate after decay + finetune_action_expert_only: Only train action expert parameters + freeze_language_model_only: Freeze language model, train vision+action """ super().__init__(model_init_description) - if not os.environ.get("HF_TOKEN"): - raise ValueError( - "Hugging Face token not found. " - "Please set the HF_TOKEN environment variable." - ) - - self.action_expert_width = action_expert_width + self.max_state_dim = self.max_action_dim = 32 self.vlm_max_text_tokens = vlm_max_text_tokens self.num_inference_steps = num_inference_steps - self.flow_sig_min = flow_sig_min - self.flow_beta_dist = torch.distributions.Beta(flow_alpha, flow_beta) - self.lr = lr - self.weight_decay = weight_decay self.dtype = dtype + self.time_sampling_beta_alpha = time_sampling_beta_alpha + self.time_sampling_beta_beta = time_sampling_beta_beta + self.time_sampling_scale = time_sampling_scale + self.time_sampling_offset = time_sampling_offset + self.min_period = min_period + self.max_period = max_period + self.gradient_checkpointing = gradient_checkpointing + self.compile_model = compile_model + self.compile_mode = compile_mode + self.optimizer_lr = optimizer_lr + self.optimizer_betas = optimizer_betas + self.optimizer_eps = optimizer_eps + self.optimizer_weight_decay = optimizer_weight_decay + self.lr_scheduler_warmup_steps = lr_scheduler_warmup_steps + self.lr_scheduler_num_decay_steps = lr_scheduler_num_decay_steps + self.lr_scheduler_decay_lr = lr_scheduler_decay_lr + self.use_pretrained_weights = use_pretrained_weights + self.pretrained_name_or_path = pretrained_name_or_path + self.finetune_action_expert_only = finetune_action_expert_only + self.freeze_language_model_only = freeze_language_model_only data_stats: dict[DataType, DataItemStats] = {} + # Track per-data-type feature sizes to preserve ordering when splitting + self.output_slices: dict[DataType, list[int]] = {} # Setup proprioceptive data self.proprio_dims: dict[DataType, tuple[int, int]] = {} @@ -160,8 +187,6 @@ def __init__( self.proprio_dims[data_type] = (current_dim, current_dim + dim) current_dim += dim - proprio_dim = current_dim - # Setup output data self.max_output_size = 0 output_stats = [] @@ -205,7 
+230,6 @@ def __init__( self.max_output_size += dim self.action_dim = self.max_output_size - self.action_horizon = self.output_prediction_horizon # Setup normalizers # Only create proprio_normalizer if there are proprioception stats @@ -220,97 +244,110 @@ def __init__( ) # Setup RGB cameras - num_rgbs = 0 if DataType.RGB_IMAGES in self.input_data_types: stats = cast( list[CameraDataStats], self.dataset_statistics[DataType.RGB_IMAGES] ) - num_rgbs = len(stats) - - self.vlm_max_tokens = num_rgbs * 256 + self.vlm_max_text_tokens - - self.vlm = PaliGemmaForConditionalGeneration.from_pretrained( - VLM_BACKBONE, dtype=self.dtype, attn_implementation="eager" - ) - self.vlm_processor = AutoProcessor.from_pretrained( - VLM_BACKBONE, padding_side="right" + len(stats) + + # Build PI0 config + self.config = PI0Config( + paligemma_variant=paligemma_variant, + action_expert_variant=action_expert_variant, + dtype=dtype, + chunk_size=self.output_prediction_horizon, + max_state_dim=self.max_state_dim, + max_action_dim=self.max_action_dim, + num_inference_steps=self.num_inference_steps, + time_sampling_beta_alpha=self.time_sampling_beta_alpha, + time_sampling_beta_beta=self.time_sampling_beta_beta, + time_sampling_scale=self.time_sampling_scale, + time_sampling_offset=self.time_sampling_offset, + min_period=self.min_period, + max_period=self.max_period, + gradient_checkpointing=self.gradient_checkpointing, + compile_model=self.compile_model, + compile_mode=self.compile_mode, + device=self.device, ) - self.vlm_embedding_module = self.vlm.get_input_embeddings() - assert self.vlm_processor.tokenizer.padding_side == "right" - - # Disable finetuning of the VLM - for param in self.vlm.parameters(): - param.requires_grad = False - - # Create a mixture of experts (MoE) model consisting of 2 experts: - # 1. VLM expert - # 2. 
Action expert - expert_configs = { - "vlm": MoeExpertConfig( - hidden_size=VLM_EXPERT_WIDTH, - intermediate_size=vlm_expert_intermediate_size, - head_dim=vlm_expert_head_dim, - num_attention_heads=vlm_expert_num_heads, - num_key_value_heads=vlm_expert_num_kv_heads, - ), - "action": MoeExpertConfig( - hidden_size=action_expert_width, - intermediate_size=action_expert_intermediate_size, - head_dim=action_expert_head_dim, - num_attention_heads=action_expert_num_heads, - num_key_value_heads=action_expert_num_kv_heads, - ), - } - self.moe = GemmaMoE(moe_depth, expert_configs) - self.action_encoder = ActionEncoder(self.action_dim, action_expert_width) - self.time_embedding = SinusoidalPosEmb(action_expert_width) - # Only create proprio_encoder if there's proprioception data - # This allows the algorithm to work without proprioception (visual-only) - if proprio_dim > 0: - self.proprio_encoder = nn.Linear(proprio_dim, action_expert_width) + + # Core model from the reference implementation + if self.use_pretrained_weights and self.pretrained_name_or_path: + self.model = PI0Policy.from_pretrained( + self.pretrained_name_or_path, config=self.config + ) else: - # Create a dummy encoder that outputs zeros (will be replaced in forward) - self.proprio_encoder = None - self.action_decoder = nn.Linear( - action_expert_width, - self.action_dim, - ) + self.model = PI0Policy(self.config) - gemma_config = self.vlm.config.text_config - self.using_pretrained_paligemma = ( - gemma_config.intermediate_size == vlm_expert_intermediate_size - and gemma_config.hidden_size == VLM_EXPERT_WIDTH - ) + if self.config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() - # Load PaliGemma weights into VLM expert - if self.using_pretrained_paligemma: - self._load_pretrained_vlm_weights() - else: - logger.warning("Using custom VLM weights, not pretrained PaliGemma") + self._setup_optimizer_param_groups() - # disable grads for VLM part of MoE if using pretrained - if self.using_pretrained_paligemma: - for param in self.moe.get_parameters("vlm"): - param.requires_grad = False + def gradient_checkpointing_enable(self) -> None: + """Enable gradient checkpointing on the underlying PI0 model.""" + self.model.gradient_checkpointing_enable() - # Delete the language model to save memory (keep only embeddings) - # Note: We delete model.language_model (the actual module), not - # language_model (the property) - del self.vlm.model.language_model + def gradient_checkpointing_disable(self) -> None: + """Disable gradient checkpointing on the underlying PI0 model.""" + self.model.gradient_checkpointing_disable() - # Resize the images to 224x224 - self.image_normalizer = torch.nn.Sequential( - T.Resize((224, 224)), - ) + def _setup_optimizer_param_groups(self) -> None: + """Setup optimizer parameter groups for the underlying PI0 model. + + There are two logical groups: the VLM model and the action expert model. + You can either finetune everything or just the action expert while + freezing the VLM model. 
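The parameter-group selection implemented below keeps frozen parameters out of the optimizer entirely. An equivalent and common alternative is to flip `requires_grad` on everything outside the action expert; a small sketch of that approach (the substrings mirror the name patterns used below):

```python
# Alternative freezing strategy: disable gradients instead of filtering param groups.
import torch

ACTION_EXPERT_SUBSTRINGS = (
    "gemma_expert", "action_in_proj", "action_out_proj",
    "state_proj", "action_time_mlp_in", "action_time_mlp_out",
)


def freeze_all_but_action_expert(model: torch.nn.Module) -> int:
    trainable = 0
    for name, param in model.named_parameters():
        is_expert = any(s in name for s in ACTION_EXPERT_SUBSTRINGS)
        param.requires_grad_(is_expert)
        if is_expert:
            trainable += param.numel()
    return trainable  # number of trainable parameters left
```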
+ """ + # Define parameter name patterns + ACTION_EXPERT_PARAM_NAMES = [ + "gemma_expert", + "action_in_proj", + "action_out_proj", + "state_proj", + "action_time_mlp_in", + "action_time_mlp_out", + ] + VISION_ENCODER_PARAM_NAMES = ["vision_tower", "multi_modal"] + + # Determine which parameters to include + if self.finetune_action_expert_only: + params = [ + param + for name, param in self.model.named_parameters() + if any(param_name in name for param_name in ACTION_EXPERT_PARAM_NAMES) + ] + self.param_groups = [{"params": params, "lr": self.optimizer_lr}] + elif self.freeze_language_model_only: + params = [ + param + for name, param in self.model.named_parameters() + if any( + param_name in name + for param_name in ACTION_EXPERT_PARAM_NAMES + + VISION_ENCODER_PARAM_NAMES + ) + ] + self.param_groups = [{"params": params, "lr": self.optimizer_lr}] + else: + # Train all parameters + self.param_groups = [{ + "params": list(self.model.parameters()), + "lr": self.optimizer_lr, + }] def _combine_proprio(self, batch: BatchedInferenceInputs) -> torch.FloatTensor: - """Combine different types of joint state data. + """Combine and normalize proprioceptive state data. + + Concatenates joint positions, velocities, torques, and gripper states + into a single normalized state vector padded to max_state_dim. Args: batch: Input batch containing joint state data Returns: - torch.FloatTensor: Combined and normalized joint state features + Combined and normalized state tensor [B, max_state_dim], or None + if no proprioceptive data is available. """ proprio_list = [] for data_type in [ @@ -357,22 +394,28 @@ def _combine_proprio(self, batch: BatchedInferenceInputs) -> torch.FloatTensor: "Proprioception inputs were provided but no normalizer was available." ) normalized_proprio = self.proprio_normalizer.normalize(all_proprio) + # Pad proprio to max state dim since PI0 expects fixed-size input. + # Pad after normalization to avoid padding artifacts. + normalized_proprio = pad_vector(normalized_proprio, self.max_state_dim).to( + self.device + ) return normalized_proprio def _prepare_rgb_images( self, batch: BatchedInferenceInputs ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: - """Prepare the RGB images and masks. + """Prepare RGB images for the vision encoder. - First resize to 224x224 and then normalize values to [-1,1]. And transform - the image dimension to (num_cams, B, C, H, W). + Resizes images to 224x224 and normalizes pixel values to [-1, 1] + as expected by the SigLIP vision encoder. Args: - batch: Batch of inference samples. + batch: Batch of inference samples Returns: - tuple[list[torch.Tensor], list[torch.Tensor]]: List of images and masks. + Tuple of (images, masks) where images is a list of tensors + [B, C, H, W] per camera and masks is a list of [B] tensors. """ if DataType.RGB_IMAGES not in batch.inputs: raise ValueError("RGB images are required but not provided") @@ -384,7 +427,7 @@ def _prepare_rgb_images( image_masks = [] for cam_id, input_rgb in enumerate(batched_rgb_data): last_frame = input_rgb.frame[:, -1, :, :, :] # (B, 3, H, W) - image = self.image_normalizer(last_frame) + image = resize_with_pad_torch(last_frame, *IMAGE_RESIZE_SHAPE) # Normalize from range [0,1] to [-1,1] as expected by siglip image = image * 2.0 - 1.0 images.append(image) @@ -396,14 +439,14 @@ def _process_language_tokens( self, batch: BatchedInferenceInputs, ) -> tuple[torch.Tensor, torch.Tensor]: - """Process the language tokens. + """Extract language tokens and attention masks from batch. 
Args: - batch: Batch of inference samples. + batch: Batch of inference samples Returns: - torch.Tensor: Language tokens tensor. - torch.Tensor: Language mask tensor. + Tuple of (tokens, mask) where tokens is [B, L] token IDs + and mask is [B, L] attention mask. """ batch_size = len(batch) if DataType.LANGUAGE not in batch.inputs: @@ -429,328 +472,87 @@ def _process_language_tokens( return language_tokens, language_mask - def _load_pretrained_vlm_weights(self) -> None: - """Load pretrained PaliGemma weights into the VLM expert of the MoE.""" - logger.info("Loading pretrained PaliGemma weights into VLM expert...") - vlm_state_dict = self.vlm.model.language_model.state_dict() - moe_state_dict = self.moe.state_dict() - new_state_dict = {} - for moe_key, moe_param in moe_state_dict.items(): - # Check if this is a VLM expert parameter - if "experts.vlm" in moe_key: - # Convert MoE key format to PaliGemma key format - vlm_key = moe_key.replace("experts.vlm.", "") - - # If this key exists in the VLM state dict, copy it - if vlm_key not in vlm_state_dict: - raise ValueError(f"VLM key not found: {vlm_key}") - new_state_dict[moe_key] = vlm_state_dict[vlm_key] - else: - # Keep non-VLM parameters as is - new_state_dict[moe_key] = moe_param - - # Load the combined state dict - missing_keys, unexpected_keys = self.moe.load_state_dict( - new_state_dict, strict=True - ) - - # Log any mismatches for debugging - if missing_keys: - raise ValueError(f"Missing keys when loading VLM weights: {missing_keys}") - if unexpected_keys: - raise ValueError( - f"Unexpected keys when loading VLM weights: {unexpected_keys}" - ) - - logger.info("Successfully loaded pretrained PaliGemma weights into VLM expert.") - - def _create_expert_attention_masks( - self, batch_size: int, pad_masks: torch.Tensor = None - ) -> dict[str, torch.Tensor]: - """Create attention masks for the experts. + def _build_inputs_from_batch( + self, batch: BatchedInferenceInputs + ) -> tuple[ + list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor + ]: + """Build model inputs from a batch of inference samples. Args: - batch_size: Size of the batch. - pad_masks: Padding masks for the merged text and images tensor. + batch: Batch of inference samples Returns: - dict[str, torch.Tensor]: Attention masks for the experts. + Tuple of (images, image_masks, lang_tokens, lang_masks, proprios). 
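`pad_vector` (imported from the algorithm's utils) is used above to right-pad the normalized state up to the fixed 32-dimensional slot that PI0 expects. Its implementation is not shown in this diff; a minimal sketch of what such a helper plausibly looks like:

```python
# Hedged sketch of a pad_vector-style helper: zero-pad the last dimension to new_dim.
# The real helper lives in the algorithm's utils module and may differ in detail.
import torch


def pad_vector(x: torch.Tensor, new_dim: int) -> torch.Tensor:
    if x.shape[-1] >= new_dim:
        return x
    pad = torch.zeros(*x.shape[:-1], new_dim - x.shape[-1], dtype=x.dtype, device=x.device)
    return torch.cat([x, pad], dim=-1)


state = torch.randn(4, 14)        # e.g. 14 proprio dims
padded = pad_vector(state, 32)    # -> shape (4, 32), extra dims are zeros
assert padded.shape == (4, 32)
```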
""" - # generate 2d padding mask from 1d padding mask - # pad_masks has shape [batch_size, seq_len] - # Create attention mask: [batch_size, seq_len, seq_len] - vlm_mask = pad_masks.unsqueeze(1) * pad_masks.unsqueeze(2) - # Convert to attention mask format (0 for attended positions, -inf for masked) - vlm_mask = torch.where(vlm_mask == 1, 0.0, torch.finfo(self.dtype).min).to( - self.dtype - ) - state_len = 1 - action_len = self.action_horizon - - stat_act_len = state_len + action_len # proprio + actions - state_action_mask = torch.zeros( - (stat_act_len, stat_act_len), device=self.device, dtype=self.dtype - ) - - # Proprio can only attend to itself - state_action_mask[0, 0] = 1 - - # Each action can attend to proprio and previous actions - for i in range(1, stat_act_len): # i starts at 1 (first action) - # Can attend to proprio - state_action_mask[i, 0] = 1 - # Can attend to self and previous actions - state_action_mask[i, 1 : i + 1] = 1 - - # Convert to attention mask format (0 for attended positions, -inf for masked) - state_action_mask = torch.where( - state_action_mask == 1, 0.0, torch.finfo(self.dtype).min - ).to(self.dtype) - - # Add head dimension: [batch_size, 1, seq_len, seq_len] - vlm_mask = vlm_mask.unsqueeze(1) - state_action_mask = ( - state_action_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, 1, -1, -1) - ) - - return {"vlm": vlm_mask, "action": state_action_mask} + images, image_masks = self._prepare_rgb_images(batch) + lang_tokens, lang_masks = self._process_language_tokens(batch) + proprios = self._combine_proprio(batch) + return images, image_masks, lang_tokens, lang_masks, proprios - def _create_pi0_mix_attention_mask( - self, batch_size: int, vlm_seq_len: int | None = None - ) -> torch.Tensor: - """Create the mixed attention mask for the Pi0 model. + def _predict_action(self, batch: BatchedInferenceInputs) -> torch.Tensor: + """Predict action sequence for the given batch. Args: - batch_size: Size of the batch. - vlm_seq_len: Actual VLM sequence length. + batch: Input batch with observations Returns: - torch.Tensor: Mixed attention mask. 
+ Predicted action tensor [B, chunk_size, action_dim] """ - # Calculate sequence lengths for each block - vlm_len = vlm_seq_len if vlm_seq_len is not None else self.vlm_max_tokens - state_len = 1 - action_len = self.action_horizon - total_seq_len = vlm_len + state_len + action_len - - # Create base mask allowing full attention within each block - mask = torch.zeros( - (total_seq_len, total_seq_len), device=self.device, dtype=self.dtype + images, image_masks, lang_tokens, lang_masks, proprios = ( + self._build_inputs_from_batch(batch) ) - - # (VLM): Can only attend to itself - mask[:vlm_len, :vlm_len] = 1 - - # (State / Action): Can attend to VLM - mask[vlm_len:, :vlm_len] = 1 - - # Proprio can attend to itself and vl - mask[vlm_len : vlm_len + state_len, : vlm_len + state_len] = 1 - - action_start = vlm_len + state_len - # Actions follow causal pattern - for i in range(0, action_len): - # Can attend to proprio and previous actions - mask[action_start + i, : action_start + i + 1] = 1 - - # Add batch dimension and head dimension - mask = mask.unsqueeze(0).unsqueeze(1) - mask = mask.expand(batch_size, 1, -1, -1) - # Convert to attention mask format (0 for attended positions, -inf for masked) - attention_mask = torch.where(mask == 1, 0.0, torch.finfo(self.dtype).min).to( - self.dtype + actions = self.model.sample_actions( + images, image_masks, lang_tokens, lang_masks, proprios ) - return attention_mask - - def _create_pi0_position_ids( - self, batch_size: int, vlm_seq_len: int | None = None - ) -> dict[str, torch.Tensor]: - """Create position IDs for the Pi0 model. - - Args: - batch_size: Size of the batch. - vlm_seq_len: Actual VLM sequence length. - - Returns: - dict[str, torch.Tensor]: Position IDs for VLM and action blocks. - """ - # VLM positions: Use actual sequence length - vlm_len = vlm_seq_len if vlm_seq_len is not None else self.vlm_max_tokens - vlm_pos = torch.arange(1, vlm_len + 1, device=self.device).type(self.dtype) - vlm_pos = vlm_pos.unsqueeze(0).expand(batch_size, -1) - - # State and Action positions: Sequential positions for state and action sequence - state_action_pos = torch.arange( - 1, 1 + self.action_horizon + 1, device=self.device - ).type(self.dtype) - state_action_pos = state_action_pos.unsqueeze(0).expand(batch_size, -1) - - position_ids = {"vlm": vlm_pos, "action": state_action_pos} + actions = actions[:, :, : self.action_dim] # output pad to max action dim + return actions - return position_ids - - def _forward_vlm_merged_text_images( - self, - images: list[torch.Tensor], - image_masks: list[torch.Tensor], - language_tokens: torch.Tensor, - language_masks: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Forward pass for merging text and images in the VLM. - - Generates the mixed image-language embeddings and padding masks. - - Args: - images: Input images tensor. - image_masks: Input image masks tensor. - language_tokens: Input language tokens tensor. - language_masks: Input language masks tensor. - - Returns: - tuple[torch.Tensor, torch.Tensor]: Merged text and images - tensor, mixed padding mask. 
- """ - embs = [] - pad_masks = [] - - # iterate over num_cam images - for img, img_mask in zip(images, image_masks): - img_emb = self.vlm.model.get_image_features(img) - img_emb = img_emb.to(dtype=self.dtype, device=self.device) - - bsize, num_img_embs = img_emb.shape[:2] - img_mask = ( - img_mask[:, None].expand(bsize, num_img_embs).to(device=self.device) - ) - - embs.append(img_emb) - pad_masks.append(img_mask) - - language_embeddings = self.vlm_embedding_module(language_tokens) - embs.append(language_embeddings) - pad_masks.append(language_masks) - - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - return embs, pad_masks - - def _sample_fm_time(self, batch_size: int) -> torch.Tensor: - """Sample flow matching timesteps. - - Args: - batch_size: Size of the batch. - - Returns: - torch.Tensor: Sampled timesteps. - """ - z = self.flow_beta_dist.sample((batch_size,)) - t = (1 - self.flow_sig_min) * (1 - z) - return t.to(self.device).to(self.dtype) + @classmethod + def from_pretrained( + cls, + model_init_description: ModelInitDescription, + pretrained_name_or_path: str | None = None, + **kwargs: Any, + ) -> Pi0: + """Load a pretrained PI0 model while keeping the Neuracore model interface. - def _predict_action( - self, - merged_text_images: torch.Tensor, - proprio_embeds: torch.Tensor, - action: torch.Tensor, - t: torch.Tensor, - vlm_seq_len: int | None = None, - pad_masks: torch.Tensor | None = None, - ) -> torch.Tensor: - """Predict action sequence from observations. + By default, downloads weights from https://huggingface.co/lerobot/pi0_base + which contains the π₀ base model from Physical Intelligence. Args: - merged_text_images: Merged text and images tensor. - proprio_embeds: Proprioceptive embeddings tensor. - action: Action tensor. - t: Time tensor. - vlm_seq_len: Actual VLM Embeddings sequence length. - pad_masks: Padding masks for the merged text and images tensor. + model_init_description: Neuracore model initialization config. + pretrained_name_or_path: HuggingFace repo id (e.g. "lerobot/pi0_base") + or local path. Defaults to "lerobot/pi0_base". + **kwargs: Additional arguments passed to PI0Policy.from_pretrained + (e.g. cache_dir, force_download, token, revision). Returns: - torch.Tensor: Predicted action tensor. + Pi0 model with loaded pretrained weights. """ - batch_size = proprio_embeds.size(0) - time_cond = self.time_embedding(t) - # [B, H, E] - action_embeds = self.action_encoder(action, time_cond) - # [B, 1 + H, E] - proprio_embeds = proprio_embeds.unsqueeze(1) # [B, 1, E] - proprio_action_tokens = torch.cat([proprio_embeds, action_embeds], dim=1) - # [B, 1 + H, E] - proprio_action_embeds = self.moe( - hidden_states={ - "vlm": merged_text_images, - "action": proprio_action_tokens, - }, - expert_attention_masks=self._create_expert_attention_masks( - batch_size, pad_masks - ), - mix_attention_mask=self._create_pi0_mix_attention_mask( - batch_size, vlm_seq_len - ), - position_ids=self._create_pi0_position_ids(batch_size, vlm_seq_len), - )["action"] - # [B, H, E] - action_embeds = proprio_action_embeds[:, 1:] - return self.action_decoder(action_embeds) + model = PI0Policy.from_pretrained(pretrained_name_or_path, **kwargs) + obj = cls(model_init_description) + obj.model = model + obj.config = model.config + return obj def forward( self, batch: BatchedInferenceInputs ) -> dict[DataType, list[BatchedNCData]]: - """Forward pass for generating actions. + """Perform inference to predict action sequence. Args: - batch: Batch of inference samples. 
+ batch: Input batch with observations Returns: - dict[DataType, list[BatchedNCData]]: Model predictions with action sequences + Dictionary mapping output data types to lists of batched predictions. """ - batch_size = len(batch) - - if DataType.RGB_IMAGES not in batch.inputs: - raise ValueError("No RGB images available") - - images, image_masks = self._prepare_rgb_images(batch) - language_tokens, language_masks = self._process_language_tokens(batch) - merged_text_images, pad_masks = self._forward_vlm_merged_text_images( - images, image_masks, language_tokens, language_masks - ) - proprio_states = self._combine_proprio(batch) - # If no proprioception, create zero tensor with appropriate dimensions - # This allows the algorithm to work with visual-only inputs - if proprio_states is None or self.proprio_encoder is None: - # Create zero tensor with shape (B, action_expert_width) - proprio_embeds = torch.zeros( - batch_size, - self.action_expert_width, - device=self.device, - dtype=merged_text_images.dtype, - ) - else: - proprio_embeds = self.proprio_encoder(proprio_states) # (B, E) - - delta_t = 1.0 / self.num_inference_steps - t = torch.zeros( - batch_size, device=self.device, dtype=proprio_embeds.dtype - ) # (B,) - action = torch.randn( - (batch_size, self.action_horizon, self.action_dim), - device=self.device, - dtype=proprio_embeds.dtype, - ) # (B, H, A) - # Get the actual sequence length from the merged embeddings - actual_seq_len = merged_text_images.shape[1] - - for _ in range(self.num_inference_steps): - action_vel = self._predict_action( - merged_text_images, proprio_embeds, action, t, actual_seq_len, pad_masks - ) - action += delta_t * action_vel - t += delta_t - - # (B, T, action_dim) - predictions = self.action_normalizer.unnormalize(action) + self.model.eval() + self.model.gradient_checkpointing_disable() + self.model.compile_model_enable() + actions = self._predict_action(batch) + predictions = self.action_normalizer.unnormalize(actions) output_tensors: dict[DataType, list[BatchedNCData]] = {} for data_type in self.output_data_types: @@ -794,15 +596,9 @@ def training_step(self, batch: BatchedTrainingSamples) -> BatchedTrainingOutputs batch_size=batch.batch_size, ) - proprios = self._combine_proprio(inference_sample) - # If no proprioception, create zero tensor with appropriate dimensions - if proprios is None or self.proprio_encoder is None: - proprios = torch.zeros( - len(batch), - self.action_expert_width, - device=self.device, - dtype=torch.float32, - ) + images, image_masks, lang_tokens, lang_masks, proprios = ( + self._build_inputs_from_batch(inference_sample) + ) if set(batch.outputs.keys()) != set(self.output_data_types): raise ValueError( @@ -829,33 +625,15 @@ def training_step(self, batch: BatchedTrainingSamples) -> BatchedTrainingOutputs action_data = torch.cat(action_targets, dim=-1) # (B, T, total_action_dim) - target_actions = self.action_normalizer.normalize(action_data) - target_actions = target_actions - - t = self._sample_fm_time(len(batch)) - x0 = torch.randn_like(target_actions) - x1 = target_actions - # Calculate conditional flow - _t = t.view(-1, 1, 1) - psi_t = (1 - (1 - self.flow_sig_min) * _t) * x0 + _t * x1 + target_actions = self.action_normalizer.normalize(data=action_data) + # Pad to the max action dim after normalization to avoid padding artifacts + target_actions = pad_vector(target_actions, self.max_action_dim).to(self.device) - if DataType.RGB_IMAGES not in batch.inputs: - raise ValueError("RGB images are required for training") - - images, 
image_masks = self._prepare_rgb_images(inference_sample) - lang_tokens, lang_masks = self._process_language_tokens(inference_sample) - merged_text_images, pad_masks = self._forward_vlm_merged_text_images( - images, image_masks, lang_tokens, lang_masks - ) - proprio_embeds = self.proprio_encoder(proprios) # (B, E) - # Get the actual sequence length from the merged embeddings - actual_seq_len = merged_text_images.shape[1] - v_psi = self._predict_action( - merged_text_images, proprio_embeds, psi_t, t, actual_seq_len, pad_masks + mse_losses = self.model.forward( + images, image_masks, lang_tokens, lang_masks, proprios, target_actions ) - d_psi = x1 - (1 - self.flow_sig_min) * x0 - loss = F.mse_loss(v_psi, d_psi, reduction="none") - loss = loss.mean() + # Mask to the real action dims + loss = mse_losses[:, :, : self.action_dim].mean() losses = { "mse_loss": loss, @@ -868,41 +646,63 @@ def training_step(self, batch: BatchedTrainingSamples) -> BatchedTrainingOutputs metrics=metrics, ) - def _get_action_expert_parameters(self) -> list[torch.nn.Parameter]: - """Get parameters of the action expert. + def configure_optimizers(self) -> list[torch.optim.Optimizer]: + """Configure optimizer for training. Returns: - list: List of action expert parameters. + List containing a single AdamW optimizer. """ - return ( - list(self.action_encoder.parameters()) - + list(self.action_decoder.parameters()) - + list(self.proprio_encoder.parameters()) - + list(self.moe.get_parameters("action")) - ) + return [ + torch.optim.AdamW( + self.param_groups, + weight_decay=self.optimizer_weight_decay, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + ) + ] - def configure_optimizers( - self, - ) -> list[torch.optim.Optimizer]: - """Configure optimizer with different learning rates. + def configure_schedulers( + self, optimizers: list[torch.optim.Optimizer], num_training_steps: int + ) -> list[LambdaLR]: + """Configure learning rate schedulers. - Uses separate learning rates for image encoder backbone and other - model parameters. + Creates schedulers with linear warmup and cosine decay. Automatically + scales warmup and decay periods if training steps are fewer than + configured decay steps. + + Args: + optimizers: List of optimizers to create schedulers for + num_training_steps: Total number of training steps Returns: - list[torch.optim.Optimizer]: List of optimizers for model parameters + List of LambdaLR schedulers, one per optimizer. 
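`build_lr_lambda` (imported from the algorithm's utils, not shown in this diff) produces the multiplier used by `LambdaLR` below: linear warmup to the base learning rate followed by cosine decay toward `lr_scheduler_decay_lr`. A hedged sketch of such a schedule factory; the real implementation may differ in clamping or edge-case behaviour:

```python
# Hedged sketch of a warmup + cosine-decay schedule factory for LambdaLR.
import math


def build_lr_lambda(actual_warmup_steps: int, actual_decay_steps: int,
                    decay_lr: float, optimizer_lr: float):
    min_ratio = decay_lr / optimizer_lr

    def lr_lambda(step: int) -> float:
        if step < actual_warmup_steps:
            return (step + 1) / max(1, actual_warmup_steps)  # linear warmup
        progress = (step - actual_warmup_steps) / max(1, actual_decay_steps - actual_warmup_steps)
        progress = min(progress, 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 -> 0
        return min_ratio + (1.0 - min_ratio) * cosine        # decays to decay_lr

    return lr_lambda
```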
""" - if self.using_pretrained_paligemma: - # Only train action expert parameters when using pretrained VLM - trainable_params = self._get_action_expert_parameters() - else: - # Train all parameters when not using pretrained weights - trainable_params = [p for p in self.parameters() if p.requires_grad] - param_groups = [ - {"params": trainable_params, "lr": self.lr}, - ] + actual_warmup_steps = self.lr_scheduler_warmup_steps + actual_decay_steps = self.lr_scheduler_num_decay_steps + + # Auto-scale warmup and decay steps if training steps are fewer than + # configured decay steps + if num_training_steps < self.lr_scheduler_num_decay_steps: + scale = num_training_steps / self.lr_scheduler_num_decay_steps + actual_warmup_steps = int(self.lr_scheduler_warmup_steps * scale) + actual_decay_steps = num_training_steps + logger.info( + "Auto-scaling LR scheduler: warmup %s->%s, decay %s->%s (scale %.3f)", + self.lr_scheduler_warmup_steps, + actual_warmup_steps, + self.lr_scheduler_num_decay_steps, + actual_decay_steps, + scale, + ) + + lr_lambda = build_lr_lambda( + actual_warmup_steps=actual_warmup_steps, + actual_decay_steps=actual_decay_steps, + decay_lr=self.lr_scheduler_decay_lr, + optimizer_lr=self.optimizer_lr, + ) - return [torch.optim.AdamW(param_groups, weight_decay=self.weight_decay)] + return [LambdaLR(optimizer, lr_lambda, -1) for optimizer in optimizers] @staticmethod def get_supported_input_data_types() -> set[DataType]: diff --git a/neuracore/ml/algorithms/pi0/requirements.txt b/neuracore/ml/algorithms/pi0/requirements.txt new file mode 100644 index 00000000..9049df36 --- /dev/null +++ b/neuracore/ml/algorithms/pi0/requirements.txt @@ -0,0 +1,6 @@ +transformers==4.53.2 +scipy>=1.14.0 +scikit-learn>=1.5.0 +pandas>=2.2.0 +pyarrow>=16.0.0 +huggingface-hub==0.36.0 \ No newline at end of file diff --git a/neuracore/ml/algorithms/pi0/transformers_replace/models/gemma/configuration_gemma.py b/neuracore/ml/algorithms/pi0/transformers_replace/models/gemma/configuration_gemma.py new file mode 100644 index 00000000..8a9cc794 --- /dev/null +++ b/neuracore/ml/algorithms/pi0/transformers_replace/models/gemma/configuration_gemma.py @@ -0,0 +1,176 @@ +"""Gemma configuration for the PI0 transformers library modification. + +This file started life as auto-generated code from the upstream transformers +Gemma configuration and is now maintained here and adapted for the Neuracore PI0 +implementation. +""" + +from collections.abc import Callable +from typing import Any + +from ...configuration_utils import PretrainedConfig + + +class GemmaConfig(PretrainedConfig): + """Configuration for [`GemmaModel`]. + + This stores the configuration of a [`GemmaModel`] and is used to + instantiate a model according to the specified arguments, defining the + model architecture. Instantiating a configuration with the defaults yields + a similar configuration to Gemma-7B, e.g. + [google/gemma-7b](https://huggingface.co/google/gemma-7b). + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma model. Defines the number of different + tokens that can be represented by the `inputs_ids` passed when + calling [`GemmaModel`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. 
+ intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + Number of key/value heads used to implement Grouped Query Attention. + If `num_key_value_heads=num_attention_heads`, the model will use + Multi Head Attention (MHA). If `num_key_value_heads=1` the model will + use Multi Query Attention (MQA), otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group + key and value head should be constructed by meanpooling all the + original heads within that group. For more details, see + [this paper](https://huggingface.co/papers/2305.13245). If it is not + specified, defaults to `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The legacy activation function. It is overwritten by the + `hidden_activation`. + hidden_activation (`str` or `function`, *optional*): + The non-linear activation function (function or string) in the + decoder. Defaults to `"gelu_pytorch_tanh"` if not specified. + `"gelu_pytorch_tanh"` uses an approximation of the `"gelu"` + activation function. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/value attentions + (not used by all models). Only relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection + layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + use_adarms (`bool`, *optional*, defaults to `False`): + Whether to use ADARMS. + adarms_cond_dim (`int`, *optional*, defaults to `None`): + The dimension of the ADARMS condition. 
+ ```python + >>> from transformers import GemmaModel, GemmaConfig + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaConfig() + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "gemma" + keys_to_ignore_at_inference = ["past_key_values"] + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + "layers.*.self_attn.k_proj": "colwise", + "layers.*.self_attn.v_proj": "colwise", + "layers.*.self_attn.o_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + def __init__( + self, + vocab_size: int = 256000, + hidden_size: int = 3072, + intermediate_size: int = 24576, + num_hidden_layers: int = 28, + num_attention_heads: int = 16, + num_key_value_heads: int = 16, + head_dim: int = 256, + hidden_act: str | Callable = "gelu_pytorch_tanh", + hidden_activation: str | Callable | None = None, + max_position_embeddings: int = 8192, + initializer_range: float = 0.02, + rms_norm_eps: float = 1e-6, + use_cache: bool = True, + pad_token_id: int = 0, + eos_token_id: int = 1, + bos_token_id: int = 2, + tie_word_embeddings: bool = True, + rope_theta: float = 10000.0, + attention_bias: bool = False, + attention_dropout: float = 0.0, + use_adarms: bool = False, + adarms_cond_dim: int | None = None, + **kwargs: Any, + ) -> None: + """Initialize the configuration.""" + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.hidden_activation = hidden_activation + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.use_adarms = use_adarms + self.adarms_cond_dim = adarms_cond_dim + + # Set default for adarms_cond_dim if use_adarms is True + if self.use_adarms and self.adarms_cond_dim is None: + self.adarms_cond_dim = self.hidden_size + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +__all__ = ["GemmaConfig"] diff --git a/neuracore/ml/algorithms/pi0/transformers_replace/models/gemma/modeling_gemma.py b/neuracore/ml/algorithms/pi0/transformers_replace/models/gemma/modeling_gemma.py new file mode 100644 index 00000000..024725c4 --- /dev/null +++ b/neuracore/ml/algorithms/pi0/transformers_replace/models/gemma/modeling_gemma.py @@ -0,0 +1,1131 @@ +"""Gemma model implementation for the PI0 transformers library modification. + +This module implements the Gemma transformer architecture adapted for the +Neuracore PI0 algorithm. The model supports adaptive RMS normalization (ADARMS) +for conditional generation, rotary position embeddings (RoPE), and various +attention backends including flash attention. 
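With this patched `configuration_gemma` module installed in place of the upstream one, the extra ADARMS fields behave as documented above: `adarms_cond_dim` falls back to `hidden_size` when adaptive RMSNorm is enabled but no condition dimension is given. A short usage check, assuming the patch has been applied:

```python
# Usage check for the patched GemmaConfig (assumes the file above is installed).
from transformers.models.gemma.configuration_gemma import GemmaConfig

cfg = GemmaConfig(use_adarms=True, hidden_size=1024)
assert cfg.use_adarms is True
assert cfg.adarms_cond_dim == 1024  # defaults to hidden_size when unset

cfg_plain = GemmaConfig()
assert cfg_plain.use_adarms is False and cfg_plain.adarms_cond_dim is None
```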
+ +The implementation includes: +- GemmaRMSNorm: Root mean square layer normalization with optional adaptive + conditioning for ADARMS +- GemmaMLP: Multi-layer perceptron with gated activation +- GemmaRotaryEmbedding: Rotary position embeddings for relative position encoding +- GemmaAttention: Multi-headed self-attention with support for grouped query + attention +- GemmaDecoderLayer: Transformer decoder layer with self-attention and MLP +- GemmaModel: Full decoder-only transformer model +- GemmaForCausalLM: Causal language modeling head on top of GemmaModel + +This file started life as auto-generated code from the upstream transformers +Gemma model and is now maintained here and adapted for the Neuracore PI0 +implementation. +""" + +from collections.abc import Callable +from typing import Any + +import torch +from torch import nn + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, can_return_tuple, logging +from .configuration_gemma import GemmaConfig + +logger = logging.get_logger(__name__) + + +class GemmaRMSNorm(nn.Module): + """Root mean square layer normalization with optional adaptive conditioning. + + This module implements RMS normalization as used in the Gemma architecture. + When a condition dimension is provided, it supports adaptive RMS normalization + (ADARMS) which modulates the normalization using a learned dense layer that + produces scale, shift, and gate parameters from a conditioning vector. + + Args: + dim: Hidden dimension size + eps: Small epsilon value for numerical stability (default: 1e-6) + cond_dim: Optional condition dimension for adaptive normalization + """ + + def __init__( + self, dim: int, eps: float = 1e-6, cond_dim: int | None = None + ) -> None: + super().__init__() + self.eps = eps + self.dim = dim + self.cond_dim = cond_dim + + # Dense layer for adaptive normalization (if cond_dim is provided) + if cond_dim is not None: + # self.dense = nn.Linear(cond_dim, dim * 3, bias=True, dtype=torch.bfloat16) + self.dense = nn.Linear(cond_dim, dim * 3, bias=True) + # Initialize with zeros (matches source implementation) + nn.init.zeros_(self.dense.weight) + else: + self.weight = nn.Parameter(torch.zeros(dim, dtype=torch.bfloat16)) + self.dense = None + + def _norm(self, x: torch.Tensor) -> torch.Tensor: + """Compute RMS normalization of input tensor. + + Args: + x: Input tensor of shape [..., dim] + + Returns: + Normalized tensor with same shape as input. + """ + # Compute variance in float32 (like the source implementation) + var = torch.mean(torch.square(x.float()), dim=-1, keepdim=True) + # Compute normalization in float32 + normed_inputs = x * torch.rsqrt(var + self.eps) + return normed_inputs + + def forward( + self, x: torch.Tensor, cond: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Apply RMS normalization with optional adaptive conditioning. 
+ + Args: + x: Input tensor of shape [batch, seq_len, dim] or [batch, dim] + cond: Optional condition tensor of shape [batch, cond_dim] for + adaptive normalization. If provided, must match cond_dim. + + Returns: + Tuple of (normalized_tensor, gate_tensor). The gate tensor is None + for standard RMSNorm and contains the gate values for ADARMS. + """ + dtype = x.dtype # original dtype, could be half-precision + normed_inputs = self._norm(x) + + if cond is None or self.dense is None: + # regular RMSNorm + # scale by learned parameter in float32 (matches source implementation) + normed_inputs = normed_inputs * (1.0 + self.weight.float()) + return ( + normed_inputs.to(dtype), + None, + ) # return in original dtype with None gate + + # adaptive RMSNorm (if cond is provided and dense layer exists) + if cond.shape[-1] != self.cond_dim: + raise ValueError( + f"Expected cond dimension {self.cond_dim}, got {cond.shape[-1]}" + ) + + # self.dense.to(dtype=torch.bfloat16).to(dtype=torch.float32) + modulation = self.dense(cond) + # Reshape modulation to broadcast properly: + # [batch, 1, features] for [batch, seq, features] + if len(x.shape) == 3: # [batch, seq, features] + modulation = modulation.unsqueeze(1) + + scale, shift, gate = torch.chunk(modulation, 3, dim=-1) + + # Apply adaptive normalization: use model weight dtype to ensure compatibility + # model_dtype = self.dense.weight.dtype # Use the model's dtype (bfloat16) + # scale = scale.to(model_dtype) + # shift = shift.to(model_dtype) + # gate = gate.to(model_dtype) + # normed_inputs = normed_inputs.to(model_dtype) + # Convert normed_inputs to model dtype + + normed_inputs = normed_inputs * (1 + scale.to(torch.float32)) + shift.to( + torch.float32 + ) + + return normed_inputs.to(dtype), gate.to(dtype) + + def extra_repr(self) -> str: + """Return a string representation of the module configuration. + + Returns: + String describing the module parameters. + """ + repr_str = f"{tuple(self.weight.shape)}, eps={self.eps}" + if self.dense is not None: + repr_str += f", adaptive=True, cond_dim={self.cond_dim}" + return repr_str + + +class GemmaMLP(nn.Module): + """Multi-layer perceptron with gated activation for Gemma. + + This MLP uses a gated activation pattern where the input is projected + through both a gate projection and an up projection, then combined via + element-wise multiplication before being projected down. + + Args: + config: Gemma configuration containing hidden_size, intermediate_size, + and hidden_act activation function. + """ + + def __init__(self, config: GemmaConfig) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the MLP transformation with gated activation. + + Args: + x: Input tensor of shape [batch, seq_len, hidden_size] + + Returns: + Output tensor of shape [batch, seq_len, hidden_size] + """ + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class GemmaRotaryEmbedding(nn.Module): + """Rotary position embedding (RoPE) for relative position encoding. 
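Numerically, the adaptive path of `GemmaRMSNorm` above projects the conditioning vector to `(scale, shift, gate)`, modulates the normalized activations as `(1 + scale) * x + shift`, and hands the gate back to the caller, which later uses it to weight the block output before the residual sum. A standalone shape-level sketch of that data flow, independent of the real modules:

```python
# Standalone sketch of the ADARMS data flow: norm -> (scale, shift, gate) -> gated residual.
import torch

batch, seq, dim, cond_dim = 2, 5, 8, 8
x = torch.randn(batch, seq, dim)
cond = torch.randn(batch, cond_dim)
dense = torch.nn.Linear(cond_dim, dim * 3)

# RMS normalization (no learned weight in this sketch)
normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)

scale, shift, gate = torch.chunk(dense(cond).unsqueeze(1), 3, dim=-1)  # each [B, 1, dim]
modulated = normed * (1 + scale) + shift                               # adaptive RMSNorm output

block_out = torch.randn(batch, seq, dim)  # e.g. attention or MLP output
residual = x + block_out * gate           # gated residual, as in _gated_residual
assert residual.shape == (batch, seq, dim)
```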
+ + This module implements rotary position embeddings that encode relative + position information directly into the attention mechanism. Supports + various RoPE scaling types for handling longer sequences. + + Args: + config: Gemma configuration containing max_position_embeddings and + optional rope_scaling configuration + device: Target device for buffer registration + """ + + def __init__( + self, config: GemmaConfig, device: torch.device | str | None = None + ) -> None: + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get( + "rope_type", config.rope_scaling.get("type") + ) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + # Power user: used with advanced RoPE types (e.g. dynamic rope). + @dynamic_rope_update + def forward( + self, x: torch.Tensor, position_ids: torch.LongTensor + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute cosine and sine embeddings for rotary position encoding. + + Args: + x: Input tensor used to determine device and dtype + position_ids: Position indices of shape [batch_size, seq_len] + + Returns: + Tuple of (cos_emb, sin_emb) tensors of shape + [batch_size, seq_len, head_dim] matching the dtype of x. + """ + inv_freq_expanded = ( + self.inv_freq[None, :, None] + .float() + .expand(position_ids.shape[0], -1, 1) + .to(x.device) + ) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = ( + x.device.type + if isinstance(x.device.type, str) and x.device.type != "mps" + else "cpu" + ) + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = ( + inv_freq_expanded.float() @ position_ids_expanded.float() + ).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dimensions of the input for RoPE. + + This function splits the last dimension in half and rotates the two halves, + which is used in the rotary position embedding computation. + + Args: + x: Input tensor with shape [..., hidden_dim] + + Returns: + Rotated tensor with same shape as input. + """ + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: torch.Tensor | None = None, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply rotary position embedding to query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be broadcast to q and k. For + example, cos[position_ids] and sin[position_ids] have shape + [batch_size, seq_len, head_dim]. If q and k have shape + [batch_size, heads, seq_len, head_dim], then setting + unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] + broadcastable to q and k. If q and k have shape + [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2. + + Returns: + `tuple(torch.Tensor)` comprising the rotated query and key tensors. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """Repeat key/value heads across attention groups for grouped query attention. + + This function implements grouped query attention (GQA) by repeating key/value + heads to match the number of query heads. This is more efficient than + multi-head attention when num_key_value_heads < num_attention_heads. + + Args: + hidden_states: Key or value tensor of shape + [batch, num_key_value_heads, seq_len, head_dim] + n_rep: Number of repetitions (num_attention_heads // num_key_value_heads) + + Returns: + Repeated tensor of shape [batch, num_attention_heads, seq_len, head_dim] + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _gated_residual( + x: torch.Tensor | None, + y: torch.Tensor | None, + gate: torch.Tensor | None, +) -> torch.Tensor | None: + """Apply a gated residual connection. + + Args: + x: Input tensor (residual). + y: Output tensor to be added. + gate: Optional gate tensor to modulate the addition. + + Returns: + The gated residual sum. + """ + if x is None and y is None: + return None + if x is None or y is None: + return x if x is not None else y + if gate is None: + return x + y + return x + y * gate + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Any, +) -> tuple[torch.Tensor, torch.Tensor]: + """Eager implementation of scaled dot-product attention. + + This function computes attention using standard matrix operations without + optimized kernels. Used as a fallback when flash attention is not available. 
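`repeat_kv` above is what lets grouped query attention reuse a small number of key/value heads for many query heads: each KV head is tiled `num_attention_heads // num_key_value_heads` times along the head axis. A quick shape check illustrating the idea (the sizes are arbitrary):

```python
# Shape check for grouped query attention head repetition (arbitrary sizes).
import torch

batch, kv_heads, q_heads, seq, head_dim = 2, 2, 8, 16, 64
n_rep = q_heads // kv_heads  # 4 query heads share each KV head

k = torch.randn(batch, kv_heads, seq, head_dim)
k_rep = k[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq, head_dim)
k_rep = k_rep.reshape(batch, kv_heads * n_rep, seq, head_dim)

assert k_rep.shape == (batch, q_heads, seq, head_dim)
# Every repeated head is an exact copy of its source KV head.
assert torch.equal(k_rep[:, 0], k_rep[:, 1])
```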
+ + Args: + module: Attention module containing num_key_value_groups attribute + query: Query tensor of shape [batch, num_heads, seq_len, head_dim] + key: Key tensor of shape [batch, num_kv_heads, seq_len, head_dim] + value: Value tensor of shape [batch, num_kv_heads, seq_len, head_dim] + attention_mask: Optional attention mask of shape + [batch, 1, seq_len, seq_len] + scaling: Attention scaling factor (typically 1/sqrt(head_dim)) + dropout: Dropout probability for attention weights + **kwargs: Additional keyword arguments (unused) + + Returns: + Tuple of (attention_output, attention_weights) where: + - attention_output: [batch, seq_len, num_heads, head_dim] + - attention_weights: [batch, num_heads, seq_len, seq_len] + """ + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class GemmaAttention(nn.Module): + """Multi-headed self-attention with rotary position embeddings. + + This module implements scaled dot-product attention with support for: + - Grouped query attention (GQA) for efficient key/value caching + - Rotary position embeddings (RoPE) for relative position encoding + - Multiple attention backends (eager, flash attention, SDPA) + - Causal masking for autoregressive generation + + Args: + config: Gemma configuration containing attention parameters + layer_idx: Layer index for cache management + """ + + def __init__(self, config: GemmaConfig, layer_idx: int) -> None: + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: torch.Tensor | None, + past_key_value: Cache | None = None, + cache_position: torch.LongTensor | None = None, + use_cache: bool = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Compute self-attention with rotary position embeddings. 
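+
+        Note: if `past_key_value` is provided while `use_cache=False`, it is
+        treated as an indexable container where `past_key_value[layer_idx]`
+        yields a `(key, value)` pair, and the cached states are concatenated
+        along the sequence dimension rather than updated in place.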
+ + Args: + hidden_states: Input tensor of shape [batch, seq_len, hidden_size] + position_embeddings: Tuple of (cos, sin) tensors for RoPE + attention_mask: Optional attention mask of shape + [batch, 1, seq_len, seq_len] + past_key_value: Optional cache for key/value states + cache_position: Optional position indices for cache updates + use_cache: Whether to update and return the cache + **kwargs: Additional attention backend arguments + + Returns: + Tuple of (attention_output, attention_weights) where: + - attention_output: [batch, seq_len, hidden_size] + - attention_weights: [batch, num_heads, seq_len, seq_len] or None + """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + # Use cache if provided + if past_key_value is not None: + if use_cache: + # sin and cos are specific to RoPE models; cache_position is + # needed for the static cache. + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + else: + key_states = torch.cat( + [past_key_value[self.layer_idx][0], key_states], dim=2 + ) + value_states = torch.cat( + [past_key_value[self.layer_idx][1], value_states], dim=2 + ) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class GemmaDecoderLayer(GradientCheckpointingLayer): + """Transformer decoder layer with self-attention and MLP. + + This layer implements a standard transformer decoder block with: + - Pre-attention layer normalization (with optional ADARMS) + - Multi-headed self-attention with RoPE + - Gated residual connection + - Post-attention layer normalization (with optional ADARMS) + - Feed-forward MLP with gated activation + - Gated residual connection + + Supports gradient checkpointing for memory-efficient training. 
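+
+    A minimal sketch of the gated residual wiring, using the module-level
+    `_gated_residual` helper defined above (plain tensors, no model weights):
+
+    ```python
+    import torch
+
+    x = torch.randn(2, 4, 8)    # residual stream
+    y = torch.randn(2, 4, 8)    # sub-layer output (attention or MLP)
+    gate = torch.ones(2, 4, 8)  # ADARMS gate; a gate of ones is a plain residual
+    assert torch.allclose(_gated_residual(x, y, gate), x + y)
+    assert torch.allclose(_gated_residual(x, y, None), x + y)
+    ```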
+ + Args: + config: Gemma configuration containing layer parameters + layer_idx: Layer index for attention cache management + """ + + def __init__(self, config: GemmaConfig, layer_idx: int) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GemmaAttention(config=config, layer_idx=layer_idx) + + self.mlp = GemmaMLP(config) + cond_dim = ( + getattr(config, "adarms_cond_dim", None) + if getattr(config, "use_adarms", False) + else None + ) + self.input_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim + ) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool | None = False, + use_cache: bool | None = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: None | ( + tuple[torch.Tensor, torch.Tensor] + ) = None, # necessary, but kept here for BC + adarms_cond: torch.Tensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.FloatTensor] | tuple[torch.FloatTensor, torch.FloatTensor | None]: + """Apply the decoder layer transformation. + + Args: + hidden_states: Input tensor of shape [batch, seq_len, hidden_size] + attention_mask: Optional attention mask + position_ids: Optional position indices + past_key_value: Optional cache for key/value states + output_attentions: Whether to return attention weights + use_cache: Whether to update and return the cache + cache_position: Optional position indices for cache updates + position_embeddings: Optional precomputed RoPE embeddings + adarms_cond: Optional condition tensor for ADARMS + **kwargs: Additional attention backend arguments + + Returns: + Tuple containing: + - hidden_states: [batch, seq_len, hidden_size] + - attention_weights: Optional [batch, num_heads, seq_len, seq_len] + """ + residual = hidden_states + hidden_states, gate = self.input_layernorm(hidden_states, adarms_cond) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = _gated_residual(residual, hidden_states, gate) + + # Fully Connected + residual = hidden_states + hidden_states, gate = self.post_attention_layernorm(hidden_states, adarms_cond) + hidden_states = self.mlp(hidden_states) + hidden_states = _gated_residual(residual, hidden_states, gate) + + if output_attentions: + return (hidden_states, self_attn_weights) + return (hidden_states,) + + +@auto_docstring +class GemmaPreTrainedModel(PreTrainedModel): + """Base class for Gemma models. + + This class provides common functionality for all Gemma model variants, + including weight initialization and support for various attention backends + and optimization features. 
+ """ + + config_class = GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["GemmaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_3 = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_attention_backend = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize weights for different module types. + + Args: + module: PyTorch module to initialize + """ + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, GemmaRMSNorm): + if hasattr(module, "weight"): + module.weight.data.fill_(1.0) + + +@auto_docstring +class GemmaModel(GemmaPreTrainedModel): + """Gemma decoder-only transformer model. + + This model implements a standard decoder-only transformer architecture + with token embeddings, multiple decoder layers, and final layer normalization. + Supports adaptive RMS normalization (ADARMS) for conditional generation. + + Args: + config: Gemma configuration specifying model architecture + """ + + def __init__(self, config: GemmaConfig) -> None: + """Initialize the Gemma model. + + Args: + config: Gemma configuration containing model hyperparameters + """ + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList([ + GemmaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + + cond_dim = ( + getattr(config, "adarms_cond_dim", None) + if getattr(config, "use_adarms", False) + else None + ) + self.norm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps, cond_dim=cond_dim + ) + self.rotary_emb = GemmaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + """Return the input token embeddings. + + Returns: + Embedding layer for input tokens + """ + return self.embed_tokens + + def set_input_embeddings(self, value: nn.Embedding) -> None: + """Set the input token embeddings. + + Args: + value: New embedding layer to use for input tokens + """ + self.embed_tokens = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + adarms_cond: torch.Tensor | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> BaseModelOutputWithPast: + """Run the forward pass for the decoder. + + Args: + input_ids (`torch.LongTensor`, *optional*): + Input token IDs. + attention_mask (`torch.Tensor`, *optional*): + Attention mask for the input tokens. 
+            position_ids (`torch.LongTensor`, *optional*):
+                Position indices for the input tokens.
+            past_key_values (`Cache`, *optional*):
+                Cached key/value states for faster decoding.
+            inputs_embeds (`torch.FloatTensor`, *optional*):
+                Precomputed input embeddings.
+            use_cache (`bool`, *optional*):
+                Whether to return key/value cache.
+            output_attentions (`bool`, *optional*):
+                Whether to return attention weights.
+            output_hidden_states (`bool`, *optional*):
+                Whether to return hidden states.
+            cache_position (`torch.LongTensor`, *optional*):
+                Positions used for cache updates.
+            adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*):
+                Condition for ADARMS.
+            **kwargs:
+                Additional attention-related keyword arguments.
+        """
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. "
+                "Setting `use_cache=False`."
+            )
+            use_cache = False
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if use_cache and past_key_values is None:
+            past_key_values = DynamicCache()
+
+        if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+        if position_ids is None:
+            position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = create_causal_mask(
+            config=self.config,
+            input_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            past_key_values=past_key_values,
+            position_ids=position_ids,
+        )
+
+        # embed positions
+        hidden_states = inputs_embeds
+        # Convert to bfloat16 if the first layer uses bfloat16
+        if (
+            len(self.layers) > 0
+            and self.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16
+        ):
+            hidden_states = hidden_states.to(torch.bfloat16)
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+        # Upstream Gemma scales the embeddings by sqrt(hidden_size) here, but
+        # downcasts the normalizer to float16, causing sqrt(3072)=55.4256 to
+        # become 55.5. The scaling is skipped in this adaptation.
+        # See https://github.com/huggingface/transformers/pull/29402
+        # normalizer = torch.tensor(
+        #     self.config.hidden_size**0.5, dtype=hidden_states.dtype
+        # )
+        # hidden_states = hidden_states * normalizer
+
+        # decoder layers
+        all_hidden_states: tuple[torch.Tensor, ...] | None = (
+            () if output_hidden_states else None
+        )
+        all_self_attns: tuple[torch.Tensor, ...] 
| None = ( + () if output_attentions else None + ) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + adarms_cond=adarms_cond, + **kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + assert all_self_attns is not None + all_self_attns += (layer_outputs[1],) + + hidden_states, _ = self.norm(hidden_states, adarms_cond) + + # add hidden states from the last decoder layer + if output_hidden_states: + assert all_hidden_states is not None + all_hidden_states += (hidden_states,) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): + """Type alias for keyword arguments accepted by GemmaForCausalLM. + + Combines flash attention kwargs and loss kwargs for type checking. + """ + + +@auto_docstring +class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin): + """Gemma model with a causal language modeling head. + + This model adds a language modeling head on top of the GemmaModel decoder, + enabling next-token prediction for autoregressive text generation. Supports + efficient logit computation via logits_to_keep parameter. + + Args: + config: Gemma configuration specifying model architecture + """ + + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} + + def __init__(self, config: GemmaConfig) -> None: + """Initialize the causal language model. + + Args: + config: Gemma configuration containing model hyperparameters + """ + super().__init__(config) + self.model = GemmaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + """Return the input token embeddings. + + Returns: + Embedding layer for input tokens + """ + return self.model.embed_tokens + + def set_input_embeddings(self, value: nn.Embedding) -> None: + """Set the input token embeddings. + + Args: + value: New embedding layer to use for input tokens + """ + self.model.embed_tokens = value + + def get_output_embeddings(self) -> nn.Linear: + """Return the output token embeddings. + + Returns: + Linear layer mapping hidden states to vocabulary logits + """ + return self.lm_head + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + """Set the output token embeddings. + + Args: + new_embeddings: New linear layer for output embeddings + """ + self.lm_head = new_embeddings + + def set_decoder(self, decoder: GemmaModel) -> None: + """Set the decoder module. + + Args: + decoder: GemmaModel instance to use as decoder + """ + self.model = decoder + + def get_decoder(self) -> GemmaModel: + """Return the decoder module. 
+ + Returns: + The underlying GemmaModel decoder + """ + return self.model + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + adarms_cond: torch.Tensor | None = None, + **kwargs: Unpack[KwargsForCausalLM], + ) -> CausalLMOutputWithPast: + """Run the forward pass for causal language modeling. + + Args: + input_ids (`torch.LongTensor`, *optional*): + Input token IDs. + attention_mask (`torch.Tensor`, *optional*): + Attention mask for the input tokens. + position_ids (`torch.LongTensor`, *optional*): + Position indices for the input tokens. + past_key_values (`Cache`, *optional*): + Cached key/value states for faster decoding. + inputs_embeds (`torch.FloatTensor`, *optional*): + Precomputed input embeddings. + labels (`torch.LongTensor`, *optional*): + Labels for computing the masked language modeling loss with + shape `(batch_size, sequence_length)`. Indices should either be + in `[0, ..., config.vocab_size]` or -100 (see `input_ids` + docstring). Tokens with indices set to `-100` are ignored + (masked), so the loss is computed only for tokens with labels in + `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + Whether to return key/value cache. + output_attentions (`bool`, *optional*): + Whether to return attention weights. + output_hidden_states (`bool`, *optional*): + Whether to return hidden states. + cache_position (`torch.LongTensor`, *optional*): + Positions used for cache updates. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + Number of logits to compute or indices to keep. + adarms_cond (`torch.Tensor` of shape `(batch_size, cond_dim)`, *optional*): + Condition for ADARMS. + **kwargs: + Additional attention or loss-related keyword arguments. + + Example: + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode( + ... generate_ids, + ... skip_special_tokens=True, + ... clean_up_tokenization_spaces=False, + ... )[0] + "What is your favorite condiment?" 
+ ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + adarms_cond=adarms_cond, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we + # are not computing the loss. + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +__all__ = [ + "GemmaModel", + "GemmaForCausalLM", + "GemmaPreTrainedModel", +] diff --git a/neuracore/ml/algorithms/pi0/transformers_replace/models/paligemma/modeling_paligemma.py b/neuracore/ml/algorithms/pi0/transformers_replace/models/paligemma/modeling_paligemma.py new file mode 100644 index 00000000..814c3bbb --- /dev/null +++ b/neuracore/ml/algorithms/pi0/transformers_replace/models/paligemma/modeling_paligemma.py @@ -0,0 +1,1024 @@ +"""PyTorch PaliGemma model implementation. + +This module implements the PaliGemma vision-language model adapted for the +Neuracore PI0 algorithm. PaliGemma combines a vision encoder (SigLIP) with a +language model (Gemma) to enable vision-language understanding and generation. + +The implementation includes: +- PaliGemmaMultiModalProjector: Projects vision features to language model space +- PaliGemmaModel: Base model with vision and language backbones +- PaliGemmaForConditionalGeneration: Full model with language modeling head +- Support for prefix attention, causal masking, and efficient caching + +This file started life as auto-generated code from the upstream transformers +PaliGemma model and is now maintained here and adapted for the Neuracore PI0 +implementation. +""" + +from dataclasses import dataclass +from typing import Any, cast + +import torch +import torch.utils.checkpoint +from torch import nn + +from ...cache_utils import Cache, HybridCache, StaticCache +from ...generation import GenerationMixin +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + ModelOutput, + auto_docstring, + can_return_tuple, + is_torchdynamo_compiling, + logging, +) +from ..auto import AutoModel +from .configuration_paligemma import PaliGemmaConfig + +logger = logging.get_logger(__name__) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for Paligemma outputs, with hidden states and attentions. 
+ """ +) +class PaligemmaModelOutputWithPast(BaseModelOutputWithPast): + """Base class for Paligemma outputs with past key values. + + Args: + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`. Contains pre-computed + hidden-states (key and values in the self-attention blocks) that can + be used (see `past_key_values` input) to speed up sequential + decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, + sequence_length, hidden_size)`. Image hidden states of the model + produced by the vision encoder and after projecting the last hidden + state. + """ + + image_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for PaliGemma causal language model (or autoregressive) outputs. + """ +) +class PaliGemmaCausalLMOutputWithPast(ModelOutput): + """Outputs for Paligemma causal language modeling. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, + config.text_config.vocab_size)`): + Prediction scores of the language modeling head (scores for each + vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned + when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, + with each tuple having 2 tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`. Contains pre-computed + hidden-states (key and values in the self-attention blocks) that can + be used (see `past_key_values` input) to speed up sequential + decoding. + image_hidden_states (`torch.FloatTensor`, *optional*): + A `torch.FloatTensor` of size `(batch_size, num_images, + sequence_length, hidden_size)`. Image hidden states of the model + produced by the vision encoder after projecting last hidden state. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: list[torch.FloatTensor] | Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + image_hidden_states: torch.FloatTensor | None = None + + +class PaliGemmaMultiModalProjector(nn.Module): + """Multi-modal projector for aligning vision and language features. + + This module projects vision encoder features to the language model's + embedding space, enabling the language model to process visual information. + + Args: + config: PaliGemma configuration containing vision and projection parameters + """ + + def __init__(self, config: PaliGemmaConfig) -> None: + super().__init__() + self.linear = nn.Linear( + config.vision_config.hidden_size, + config.vision_config.projection_dim, + bias=True, + ) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + """Project image features to language model embedding space. 
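+
+        A minimal shape sketch (library default configs assumed):
+
+        ```python
+        import torch
+        from transformers import PaliGemmaConfig
+
+        cfg = PaliGemmaConfig()  # default SigLIP vision + Gemma text configs
+        proj = PaliGemmaMultiModalProjector(cfg)
+        feats = torch.randn(1, 16, cfg.vision_config.hidden_size)
+        assert proj(feats).shape == (1, 16, cfg.vision_config.projection_dim)
+        ```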
+ + Args: + image_features: Vision encoder features of shape + [batch_size, num_patches, vision_hidden_size] + + Returns: + Projected features of shape [batch_size, num_patches, projection_dim] + """ + hidden_states = self.linear(image_features) + + return hidden_states + + +@auto_docstring +class PaliGemmaPreTrainedModel(PreTrainedModel): + """Base class for PaliGemma models. + + This class provides common functionality for all PaliGemma model variants, + including weight initialization and support for various attention backends + and optimization features. Note that this ported version is intended for + inference and fine-tuning, not training from scratch. + """ + + config_class = PaliGemmaConfig + base_model_prefix = "" + supports_gradient_checkpointing = True + _no_split_modules = ["PaliGemmaMultiModalProjector"] + _skip_keys_device_placement = "past_key_values" + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize weights for different module types. + + Note: This ported version of PaliGemma isn't meant for training from + scratch - only inference and fine-tuning. + + Args: + module: PyTorch module to initialize + """ + # important: this ported version of PaliGemma isn't meant for training + # from scratch - only inference and fine-tuning + std = getattr( + self.config, + "initializer_range", + self.config.get_text_config().initializer_range, + ) + + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + + +@auto_docstring( + custom_intro=""" + The base Paligemma model which consists of a vision backbone and a language + model without a language modeling head. + """ +) +class PaliGemmaModel(PaliGemmaPreTrainedModel): + """PaliGemma model with vision and language backbones.""" + + _checkpoint_conversion_mapping = {"language_model.model": "language_model"} + # We are filtering the logits/labels so we shouldn't divide the loss based + # on num_items_in_batch. + accepts_loss_kwargs = False + + def __init__(self, config: PaliGemmaConfig) -> None: + """Initialize the PaliGemma model. + + Args: + config: PaliGemma configuration containing vision and text model + hyperparameters + """ + super().__init__(config) + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.multi_modal_projector = PaliGemmaMultiModalProjector(config) + self.vocab_size = config.text_config.vocab_size + + language_model = AutoModel.from_config(config=config.text_config) + self.language_model = language_model + + self.pad_token_id = ( + self.config.pad_token_id if self.config.pad_token_id is not None else -1 + ) + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + """Return the language model input embeddings. + + Returns: + Embedding layer for input tokens + """ + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Module) -> None: + """Set the language model input embeddings. + + Args: + value: New embedding layer to use for input tokens + """ + self.language_model.set_input_embeddings(value) + + def set_decoder(self, decoder: nn.Module) -> None: + """Set the language model decoder. 
+ + Args: + decoder: Language model decoder module + """ + self.language_model = decoder + + def get_decoder(self) -> nn.Module: + """Return the language model decoder. + + Returns: + The underlying language model decoder + """ + return self.language_model + + def _update_causal_mask( + self, + attention_mask: torch.Tensor | None, + token_type_ids: torch.Tensor | None = None, + past_key_values: Cache | None = None, + cache_position: torch.Tensor | None = None, + input_tensor: torch.Tensor | None = None, + is_training: bool | None = None, + ) -> torch.Tensor | None: + """Update causal attention mask for prefix attention. + + This method creates or updates a causal attention mask that supports: + - Prefix attention: allows attending to image tokens during training + - Causal masking: prevents attending to future tokens + - Padding masking: masks padding tokens + - Cache-aware masking: handles static and dynamic caches + + Args: + attention_mask: Optional 2D or 4D attention mask + token_type_ids: Optional token type IDs for prefix attention + (required during training) + past_key_values: Optional cache for key/value states + cache_position: Optional position indices for cache updates + input_tensor: Optional input tensor to determine sequence length + is_training: Optional flag indicating training mode + + Returns: + Updated 4D causal attention mask of shape + [batch_size, 1, seq_len, target_len] or None if flash attention + is used and mask is not needed + """ + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + is_training = is_training if is_training is not None else self.training + using_static_cache = isinstance(past_key_values, StaticCache) + if using_static_cache: + assert past_key_values is not None + min_dtype = torch.finfo(self.dtype).min + if input_tensor is None: + input_tensor = attention_mask + assert input_tensor is not None + + inputs_lead_dim, sequence_length = input_tensor.shape[:2] + if using_static_cache: + if past_key_values is None: + raise ValueError("past_key_values must be provided for static cache.") + static_cache = cast(StaticCache, past_key_values) + target_length = static_cache.get_max_cache_shape() + elif isinstance(past_key_values, HybridCache): + hybrid_cache = cast(HybridCache, past_key_values) + target_length = hybrid_cache.get_max_cache_shape() + else: + if isinstance(attention_mask, torch.Tensor): + target_length = attention_mask.shape[-1] + else: + if cache_position is None: + raise ValueError( + "cache_position must be provided without attention_mask." + ) + target_length = cache_position[0] + sequence_length + 1 + + if attention_mask is not None and attention_mask.dim() == 4: + # The mask comes already in inverted form and requires no inversion + # or slicing. + return attention_mask + + if cache_position is None: + raise ValueError("cache_position must be provided to build a causal mask.") + cache_position_tensor = cache_position + causal_mask: torch.Tensor + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=self.dtype, + device=cache_position_tensor.device, + ) + # Causal diagonal mask only if training, otherwise attend to the whole + # prefix. Training-specific attention for prefix is handled below. 
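+        # For example, for a fresh 3-token sequence during training
+        # (cache_position = [0, 1, 2], target_length = 3), the lines below
+        # yield (with m = min_dtype, i.e. masked):
+        #     [[0, m, m],
+        #      [0, 0, m],
+        #      [0, 0, 0]]
+        # before prefix positions (token_type_ids == 0) are unmasked and the
+        # padding mask is applied further down.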
+ if sequence_length != 1: + if is_training: + causal_mask = torch.triu(causal_mask, diagonal=1) + else: + causal_mask[:, :sequence_length] = 0.0 + + causal_mask *= torch.arange( + target_length, device=cache_position_tensor.device + ) > cache_position_tensor.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + + # First unmask prefix tokens during training + if is_training: + if token_type_ids is None: + raise ValueError("Token type ids must be provided during training") + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill( + token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0 + ) + + # Then apply padding mask (will mask pad tokens) + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ + :, None, None, : + ].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + """Obtain image features from the vision tower. + + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, + height, width)`): + The tensors corresponding to the input images. + + Returns: + image_features (`torch.Tensor`): + Image feature tensor of shape `(num_images, image_length, + embed_dim)`. + """ + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + return image_features + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + token_type_ids: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple | PaligemmaModelOutputWithPast: + """Run the forward pass for the multi-modal model. + + Args: + input_ids (`torch.LongTensor`, *optional*): + Input token IDs. + pixel_values (`torch.FloatTensor`, *optional*): + Image pixel values. + attention_mask (`torch.Tensor`, *optional*): + Attention mask for input tokens. + position_ids (`torch.LongTensor`, *optional*): + Position indices for the input tokens. + past_key_values (`Cache`, *optional*): + Cached key/value states for faster decoding. + token_type_ids (`torch.LongTensor`, *optional*): + Token type ids for prefix attention. + cache_position (`torch.LongTensor`, *optional*): + Cache positions for decoding. + inputs_embeds (`torch.FloatTensor`, *optional*): + Precomputed input embeddings. + labels (`torch.LongTensor`, *optional*): + Labels for computing the masked language modeling loss. Indices + should either be in `[0, ..., config.text_config.vocab_size]` or + -100 (see `input_ids` docstring). 
Tokens with indices set to + `-100` are ignored (masked), so the loss is computed only for + tokens with labels in `[0, ..., config.text_config.vocab_size]`. + use_cache (`bool`, *optional*): + Whether to return key/value cache. + output_attentions (`bool`, *optional*): + Whether to return attention weights. + output_hidden_states (`bool`, *optional*): + Whether to return hidden states. + return_dict (`bool`, *optional*): + Whether to return a ModelOutput dict. + **kwargs: + Additional attention-related keyword arguments. + + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor + >>> from transformers import PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained( + ... "google/paligemma2-3b-mix-224" + ... ) + >>> processor = AutoProcessor.from_pretrained( + ... "google/paligemma2-3b-mix-224" + ... ) + + >>> prompt = "Where is the cat standing?" + >>> url = ( + ... "https://huggingface.co/datasets/huggingface/" + ... "documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + ... ) + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs) + >>> processor.batch_decode( + ... generate_ids, + ... skip_special_tokens=True, + ... clean_up_tokenization_spaces=False, + ... )[0] + "Where is the cat standing? snow" + ``` + """ + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + is_training = token_type_ids is not None and labels is not None + + # Replace image id with PAD if the image token if OOV, to avoid index-errors + if input_ids is not None and self.config.image_token_id >= self.vocab_size: + special_image_mask = input_ids == self.config.image_token_id + llm_input_ids = input_ids.clone() + llm_input_ids[special_image_mask] = 0 + else: + llm_input_ids = input_ids + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(llm_input_ids) + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + + if position_ids is None: + position_ids = ( + cache_position.unsqueeze(0) + 1 + ) # Paligemma positions are 1-indexed + + # Merge text and images + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor( + self.config.image_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + ) + else: + special_image_mask = ( + input_ids == self.config.image_token_id + ).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds).to( + inputs_embeds.device + ) + + if ( + not is_torchdynamo_compiling() + and inputs_embeds[special_image_mask].numel() != image_features.numel() + ): + image_tokens_in_text = 
(special_image_mask).sum(dim=1).sum(dim=0)[0] + raise ValueError( + "Number of images does not match number of special image " + "tokens in the input text. " + f"Got {image_tokens_in_text} image tokens in the text but " + f"{image_features.shape[0] * image_features.shape[1]} " + "tokens from image embeddings." + ) + image_features = image_features.to( + inputs_embeds.device, inputs_embeds.dtype + ) + inputs_embeds = inputs_embeds.masked_scatter( + special_image_mask, image_features + ) + + causal_mask = self._update_causal_mask( + attention_mask, + token_type_ids, + past_key_values, + cache_position, + inputs_embeds, + is_training, + ) + outputs = self.language_model( + attention_mask=causal_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + return PaligemmaModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): + """Type alias for keyword arguments accepted by PaliGemmaForConditionalGeneration. + + Combines flash attention kwargs and loss kwargs for type checking. + """ + + +@auto_docstring( + custom_intro=""" + The base Paligemma model which consists of a vision backbone and a language + model without language modeling head. + """ +) +class PaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin): + """PaliGemma model with a conditional generation head.""" + + _checkpoint_conversion_mapping = { + "^language_model.model": "model.language_model", + "^vision_tower": "model.vision_tower", + "^multi_modal_projector": "model.multi_modal_projector", + "^language_model.lm_head": "lm_head", + } + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: PaliGemmaConfig) -> None: + """Initialize the conditional generation model. + + Args: + config: PaliGemma configuration containing model hyperparameters + """ + super().__init__(config) + self.model = PaliGemmaModel(config) + self.lm_head = nn.Linear( + config.text_config.hidden_size, config.text_config.vocab_size, bias=False + ) + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + """Return the input token embeddings. + + Returns: + Embedding layer for input tokens + """ + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value: nn.Module) -> None: + """Set the input token embeddings. + + Args: + value: New embedding layer to use for input tokens + """ + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Linear: + """Return the output token embeddings. + + Returns: + Linear layer mapping hidden states to vocabulary logits + """ + return self.lm_head + + def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: + """Set the output token embeddings. + + Args: + new_embeddings: New linear layer for output embeddings + """ + self.lm_head = new_embeddings + + def set_decoder(self, decoder: nn.Module) -> None: + """Set the language model decoder. + + Args: + decoder: Language model decoder module + """ + self.model.set_decoder(decoder) + + def get_decoder(self) -> nn.Module: + """Return the language model decoder. 
+ + Returns: + The underlying language model decoder + """ + return self.model.get_decoder() + + def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + """Return image features from the vision tower. + + Args: + pixel_values: Image pixel values of shape + [batch_size, channels, height, width] + + Returns: + Image feature tensor of shape [num_images, image_length, embed_dim] + """ + return self.model.get_image_features(pixel_values) + + # Make modules available through conditional class for BC + @property + def language_model(self) -> nn.Module: + """Return the language model module. + + Returns: + The underlying language model (Gemma) decoder + """ + return self.model.language_model + + @property + def vision_tower(self) -> nn.Module: + """Return the vision tower module. + + Returns: + The vision encoder (SigLIP) model + """ + return self.model.vision_tower + + @property + def multi_modal_projector(self) -> nn.Module: + """Return the multimodal projector module. + + Returns: + The projector that maps vision features to language model space + """ + return self.model.multi_modal_projector + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + token_type_ids: torch.LongTensor | None = None, + cache_position: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + return_dict: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[KwargsForCausalLM], + ) -> tuple | PaliGemmaCausalLMOutputWithPast: + """Run the forward pass for conditional generation. + + Args: + input_ids (`torch.LongTensor`, *optional*): + Input token IDs. + pixel_values (`torch.FloatTensor`, *optional*): + Image pixel values. + attention_mask (`torch.Tensor`, *optional*): + Attention mask for input tokens. + position_ids (`torch.LongTensor`, *optional*): + Position indices for the input tokens. + past_key_values (`Cache`, *optional*): + Cached key/value states for faster decoding. + token_type_ids (`torch.LongTensor`, *optional*): + Token type ids for prefix attention. + cache_position (`torch.LongTensor`, *optional*): + Cache positions for decoding. + inputs_embeds (`torch.FloatTensor`, *optional*): + Precomputed input embeddings. + labels (`torch.LongTensor`, *optional*): + Labels for computing the masked language modeling loss. Indices + should either be in `[0, ..., config.text_config.vocab_size]` or + -100 (see `input_ids` docstring). Tokens with indices set to + `-100` are ignored (masked), so the loss is computed only for + tokens with labels in `[0, ..., config.text_config.vocab_size]`. + use_cache (`bool`, *optional*): + Whether to return key/value cache. + output_attentions (`bool`, *optional*): + Whether to return attention weights. + output_hidden_states (`bool`, *optional*): + Whether to return hidden states. + return_dict (`bool`, *optional*): + Whether to return a ModelOutput dict. + logits_to_keep (`int` or `torch.Tensor`, *optional*): + Number of logits to compute or indices to keep. + **kwargs: + Additional attention or loss-related keyword arguments. 
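+
+        Note: with `logits_to_keep=1` the language modeling head is applied
+        only to the final position, so `logits` has shape
+        `(batch_size, 1, vocab_size)`; with the default `0`, logits are
+        computed for every position.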
+ + Example: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor + >>> from transformers import PaliGemmaForConditionalGeneration + + >>> model = PaliGemmaForConditionalGeneration.from_pretrained( + ... "google/paligemma2-3b-mix-224" + ... ) + >>> processor = AutoProcessor.from_pretrained( + ... "google/paligemma2-3b-mix-224" + ... ) + + >>> prompt = "Where is the cat standing?" + >>> url = ( + ... "https://huggingface.co/datasets/huggingface/" + ... "documentation-images/resolve/main/pipeline-cat-chonk.jpeg" + ... ) + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, text=prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs) + >>> processor.batch_decode( + ... generate_ids, + ... skip_special_tokens=True, + ... clean_up_tokenization_spaces=False, + ... )[0] + "Where is the cat standing? snow" + ``` + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we + # are not computing the loss. + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.text_config.vocab_size, + **kwargs, + ) + + return PaliGemmaCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Cache | None = None, + inputs_embeds: torch.Tensor | None = None, + cache_position: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + pixel_values: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + token_type_ids: torch.Tensor | None = None, + use_cache: bool = True, + logits_to_keep: int | torch.Tensor | None = None, + labels: torch.Tensor | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Prepare inputs for generation with image-aware defaults. 
+ + This method handles special requirements for PaliGemma generation: + - Position IDs are 1-indexed (unlike standard transformers) + - Pixel values are only needed at the first cache position + - Causal mask is updated for hybrid cache scenarios + + Args: + input_ids: Input token IDs + past_key_values: Optional cache for key/value states + inputs_embeds: Optional precomputed input embeddings + cache_position: Optional position indices for cache updates + position_ids: Optional position indices (will be adjusted to 1-indexed) + pixel_values: Optional image pixel values (only used at cache start) + attention_mask: Optional attention mask + token_type_ids: Optional token type IDs for prefix attention + use_cache: Whether to use caching + logits_to_keep: Optional number of logits to compute + labels: Optional labels for training + **kwargs: Additional generation arguments + + Returns: + Dictionary of prepared inputs for the model forward pass + """ + # Overwritten -- custom `position_ids` and `pixel_values` handling + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + cache_position=cache_position, + use_cache=use_cache, + logits_to_keep=logits_to_keep, + token_type_ids=token_type_ids, + **kwargs, + ) + + # position_ids in Paligemma are 1-indexed + if model_inputs.get("position_ids") is not None: + model_inputs["position_ids"] += 1 + if cache_position is None: + return model_inputs + # If we're in cached decoding stage, pixel values should be None because + # input ids do not contain special image token anymore. + # Otherwise we need pixel values to be passed to model. NOTE: + # use_cache=False needs pixel_values always. + if cache_position[0] == 0: + model_inputs["pixel_values"] = pixel_values + is_training = token_type_ids is not None and labels is not None + if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): + input_tensor = inputs_embeds if inputs_embeds is not None else input_ids + causal_mask = self.model._update_causal_mask( + attention_mask, + token_type_ids, + past_key_values, + cache_position, + input_tensor, + is_training, + ) + model_inputs["attention_mask"] = causal_mask + + return model_inputs + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs: Any, + ) -> torch.Tensor: + """Create a 4D causal attention mask. + + Creates a causal 4D mask of shape `(batch_size, 1, query_length, + key_value_length)` from a 2D mask of shape `(batch_size, + key_value_length)`, or if the input `attention_mask` is already 4D, does + nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or + a 4D attention mask of shape `(batch_size, 1, query_length, + key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length. When generating with static cache, the mask + should be as long as the static cache to account for the 0 + padding and the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in + the sequence. + batch_size (`int`): + Batch size. 
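+            **kwargs (`Any`):
+                Additional keyword arguments accepted for interface
+                compatibility; unused here.
+
+        Returns:
+            `torch.Tensor`: A 4D mask of shape `(batch_size, 1,
+            sequence_length, target_length)` with masked positions filled
+            with the minimum value of `dtype`; an already-4D
+            `attention_mask` is returned unchanged.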
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # The mask comes already in inverted form and requires no inversion + # or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), + fill_value=min_dtype, + dtype=dtype, + device=cache_position.device, + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange( + target_length, device=cache_position.device + ) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = ( + causal_mask.clone() + ) # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[ + :, None, None, : + ].to(causal_mask.device) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[ + :, :, :, :mask_length + ].masked_fill(padding_mask, min_dtype) + + return causal_mask + + +__all__ = [ + "PaliGemmaForConditionalGeneration", + "PaliGemmaPreTrainedModel", + "PaliGemmaModel", +] diff --git a/neuracore/ml/algorithms/pi0/transformers_replace/models/siglip/modeling_siglip.py b/neuracore/ml/algorithms/pi0/transformers_replace/models/siglip/modeling_siglip.py new file mode 100644 index 00000000..5ee234db --- /dev/null +++ b/neuracore/ml/algorithms/pi0/transformers_replace/models/siglip/modeling_siglip.py @@ -0,0 +1,1600 @@ +"""PyTorch SigLIP model implementation.""" + +import math +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out + +from ...activations import ACT2FN +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ModelOutput, auto_docstring, can_return_tuple, logging, torch_int +from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig + +logger = logging.get_logger(__name__) + + +def _trunc_normal_( + tensor: torch.Tensor, + mean: float, + std: float, + a: float, + b: float, +) -> torch.Tensor: + """Fill tensor with values from a truncated normal distribution (in-place). + + Uses the inverse CDF method to sample from a truncated normal distribution. + Values outside [a, b] are redrawn until within bounds. + + Based on: https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + + Args: + tensor: The tensor to fill with random values. + mean: Mean of the normal distribution. + std: Standard deviation of the normal distribution. + a: Minimum cutoff value. + b: Maximum cutoff value. + + Returns: + The input tensor filled with truncated normal values. + """ + + def norm_cdf(x: float) -> float: + """Compute the standard normal cumulative distribution function.""" + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower_cdf = norm_cdf((a - mean) / std) + upper_cdf = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_tf_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> torch.Tensor: + r"""Fill the input tensor with values drawn from a truncated normal distribution. + + The values are effectively drawn from the normal distribution + :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` with values outside + :math:`[a, b]` redrawn until they are within the bounds. The method used + for generating the random values works best when + :math:`a \\leq \text{mean} \\leq b`. + + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with + mean=0, std=1.0 and the result is subsequently scaled and shifted by the + mean and std args. + + Args: + tensor: An n-dimensional `torch.Tensor`. + mean: The mean of the normal distribution. + std: The standard deviation of the normal distribution. + a: The minimum cutoff value. + b: The maximum cutoff value. + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + return tensor + + +def variance_scaling_( + tensor: torch.Tensor, + scale: float = 1.0, + mode: str = "fan_in", + distribution: str = "normal", +) -> torch.Tensor: + """Initialize tensor using variance scaling (in-place). + + Scales the variance of the initialization based on the number of input + and/or output units, following the approach from various initialization + schemes (Xavier, He, LeCun). + + Args: + tensor: The tensor to initialize. + scale: Scaling factor for the variance. Defaults to 1.0. + mode: Fan mode for computing variance denominator. One of: + - "fan_in": Use number of input units (default). + - "fan_out": Use number of output units. + - "fan_avg": Use average of fan_in and fan_out. + distribution: Distribution to sample from. One of: + - "truncated_normal": Truncated normal distribution. + - "normal": Standard normal distribution (default). + - "uniform": Uniform distribution. + + Returns: + The initialized tensor. + + Raises: + ValueError: If distribution is not recognized. 
+ """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + return tensor + + +def lecun_normal_(tensor: torch.Tensor) -> torch.Tensor: + """Initialize tensor using LeCun normal initialization (in-place). + + Uses variance scaling with fan_in mode and truncated normal distribution, + as described in the Self-Normalizing Neural Networks paper. + + Args: + tensor: The tensor to initialize. + + Returns: + The initialized tensor. + """ + return variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor: torch.Tensor) -> torch.Tensor: + """Initialize embedding tensor using default Flax initialization (in-place). + + Uses variance scaling with fan_in mode and normal distribution, + matching the default embedding initialization in Flax/JAX. + + Args: + tensor: The tensor to initialize. + + Returns: + The initialized tensor. + """ + return variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for vision model outputs that also contains pooled image + embeddings from the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with +# CLIP->Siglip. +class SiglipVisionModelOutput(ModelOutput): + """Outputs for SigLIP vision models. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, + *optional*, returned when model is initialized with + `with_projection=True`): + The image embeddings obtained by applying the projection layer to + the pooler_output. + """ + + image_embeds: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for text model outputs that also contains pooled text embeddings + from the last hidden states. + """ +) +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with +# CLIP->Siglip. +class SiglipTextModelOutput(ModelOutput): + """Outputs for SigLIP text models. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, + *optional*, returned when model is initialized with + `with_projection=True`): + The text embeddings obtained by applying the projection layer to + the pooler_output. + """ + + text_embeds: torch.FloatTensor | None = None + last_hidden_state: torch.FloatTensor | None = None + hidden_states: tuple[torch.FloatTensor, ...] | None = None + attentions: tuple[torch.FloatTensor, ...] | None = None + + +@dataclass +@auto_docstring +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + """Outputs for SigLIP image-text similarity models. 
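+
+    The pairwise logits are computed from the L2-normalized pooled embeddings
+    as `logit_scale.exp() * text_embeds @ image_embeds.T + logit_bias`.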
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when + `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, + text_batch_size)`): + The scaled dot product scores between `image_embeds` and + `text_embeds`. This represents the image-text similarity scores. + logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, + image_batch_size)`): + The scaled dot product scores between `text_embeds` and + `image_embeds`. This represents the text-image similarity scores. + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to + the pooled output of [`SiglipTextModel`]. + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to + the pooled output of [`SiglipVisionModel`]. + text_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. + vision_model_output (`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: torch.FloatTensor | None = None + logits_per_image: torch.FloatTensor | None = None + logits_per_text: torch.FloatTensor | None = None + text_embeds: torch.FloatTensor | None = None + image_embeds: torch.FloatTensor | None = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> tuple[Any, ...]: + """Convert the output to a tuple representation. + + Recursively converts nested model outputs (text_model_output and + vision_model_output) to tuples as well. + + Returns: + A tuple containing all output values, with nested outputs also + converted to tuples. + """ + return tuple( + ( + self[k] + if k not in ["text_model_output", "vision_model_output"] + else getattr(self, k).to_tuple() + ) + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + """Vision embeddings for SigLIP, converting images to patch embeddings.""" + + def __init__(self, config: SiglipVisionConfig) -> None: + """Initialize the vision embeddings module. + + Args: + config: Vision configuration containing hidden_size, image_size, + patch_size, and num_channels. + """ + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size) ** 2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding( + self, embeddings: torch.Tensor, height: int, width: int + ) -> torch.Tensor: + """Interpolate pre-trained position encodings for higher resolutions. + + This method is adapted to support torch.jit tracing and no class + embeddings. 
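+
+        The pre-trained position grid (`sqrt(num_positions)` entries per
+        side) is reshaped and resized with bicubic interpolation to
+        `height // patch_size` by `width // patch_size`, so inputs at
+        resolutions other than the training resolution can be handled.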
+ + Adapted from: + - https://github.com/facebookresearch/dino/blob/ + de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py + #L174-L194 + - https://github.com/facebookresearch/dinov2/blob/ + e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py + #L179-L211 + """ + num_patches = embeddings.shape[1] + num_positions = self.position_embedding.weight.shape[0] + + # Always interpolate when tracing to ensure the exported model works + # for dynamic input shapes. + if ( + not torch.jit.is_tracing() + and num_patches == num_positions + and height == width + ): + return self.position_embedding(self.position_ids) + + patch_pos_embed = self.position_embedding.weight.unsqueeze(0) + + dim = embeddings.shape[-1] + + new_height = height // self.patch_size + new_width = width // self.patch_size + + sqrt_num_positions = torch_int(num_positions**0.5) + patch_pos_embed = patch_pos_embed.reshape( + 1, sqrt_num_positions, sqrt_num_positions, dim + ) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + size=(new_height, new_width), + mode="bicubic", + align_corners=False, + ) + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward( + self, + pixel_values: torch.FloatTensor, + interpolate_pos_encoding: bool = False, + ) -> torch.Tensor: + """Convert pixel values to patch embeddings with positional encoding. + + Args: + pixel_values: Input images of shape (batch_size, channels, height, width). + interpolate_pos_encoding: Whether to interpolate positional encodings + for images with different resolutions than training. Defaults to False. + + Returns: + Patch embeddings of shape (batch_size, num_patches, embed_dim) with + positional encodings added. + """ + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width + ) + else: + embeddings = embeddings + self.position_embedding(self.position_ids) + return embeddings + + +class SiglipTextEmbeddings(nn.Module): + """Text embeddings for SigLIP, combining token and positional embeddings.""" + + def __init__(self, config: SiglipTextConfig) -> None: + """Initialize the text embeddings module. + + Args: + config: Text configuration containing vocab_size, hidden_size, + and max_position_embeddings. + """ + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding( + config.max_position_embeddings, embed_dim + ) + + # position_ids (1, len position emb) is contiguous in memory and + # exported when serialized. + self.register_buffer( + "position_ids", + torch.arange(config.max_position_embeddings).expand((1, -1)), + persistent=False, + ) + + def forward( + self, + input_ids: torch.LongTensor | None = None, + position_ids: torch.LongTensor | None = None, + inputs_embeds: torch.FloatTensor | None = None, + ) -> torch.Tensor: + """Compute text embeddings from token IDs or pre-computed embeddings. + + Args: + input_ids: Token IDs of shape (batch_size, seq_length). Either this + or inputs_embeds must be provided. + position_ids: Position IDs of shape (batch_size, seq_length). 
If None, + sequential positions starting from 0 are used. + inputs_embeds: Pre-computed token embeddings of shape + (batch_size, seq_length, embed_dim). Either this or input_ids + must be provided. + + Returns: + Combined token and position embeddings of shape + (batch_size, seq_length, embed_dim). + + Raises: + ValueError: If neither input_ids nor inputs_embeds is provided. + ValueError: If sequence length exceeds max_position_embeddings. + """ + if input_ids is None and inputs_embeds is None: + raise ValueError("You must specify either input_ids or inputs_embeds.") + if input_ids is not None: + seq_length = input_ids.shape[-1] + else: + assert inputs_embeds is not None + seq_length = inputs_embeds.shape[-2] + max_position_embedding = self.position_embedding.weight.shape[0] + + if seq_length > max_position_embedding: + raise ValueError( + "Sequence length must be less than max_position_embeddings (got " + f"`sequence length`: {seq_length} and max_position_embeddings: " + f"{max_position_embedding}" + ) + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + assert input_ids is not None + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor | None, + scaling: float, + dropout: float = 0.0, + **kwargs: Any, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute scaled dot-product attention using eager (non-fused) implementation. + + Args: + module: The attention module (used for training state). + query: Query tensor of shape (batch_size, num_heads, seq_len, head_dim). + key: Key tensor of shape (batch_size, num_heads, seq_len, head_dim). + value: Value tensor of shape (batch_size, num_heads, seq_len, head_dim). + attention_mask: Optional mask of shape (batch_size, 1, seq_len, seq_len) + with large negative values for masked positions. + scaling: Scaling factor (typically 1/sqrt(head_dim)). + dropout: Dropout probability for attention weights. Defaults to 0.0. + **kwargs: Additional arguments (ignored for compatibility). + + Returns: + A tuple of: + - Attention output of shape (batch_size, seq_len, num_heads, head_dim). + - Attention weights of shape (batch_size, num_heads, seq_len, seq_len). + """ + attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query.dtype + ) + attn_weights = nn.functional.dropout( + attn_weights, p=dropout, training=module.training + ) + + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class SiglipAttention(nn.Module): + """Multi-headed attention from "Attention Is All You Need" paper. + + Implements scaled dot-product attention with multiple heads, supporting + various attention backends (eager, SDPA, Flash Attention). + """ + + def __init__(self, config: SiglipVisionConfig | SiglipTextConfig) -> None: + """Initialize the multi-head attention module. + + Args: + config: Model configuration containing hidden_size, num_attention_heads, + and attention_dropout. + + Raises: + ValueError: If embed_dim is not divisible by num_heads. 
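+
+        Example:
+            Illustrative config values, assumed for demonstration only.
+
+        ```python
+        >>> hidden_size, num_heads = 768, 12
+        >>> head_dim = hidden_size // num_heads  # 64
+        >>> scale = head_dim ** -0.5             # queries are scaled by this
+        >>> # num_heads = 10 would raise ValueError: 768 % 10 != 0.
+        ```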
+ """ + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`: {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = False, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + """Compute multi-head self-attention. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, embed_dim). + attention_mask: Optional attention mask of shape + (batch_size, 1, seq_length, seq_length). Defaults to None. + output_attentions: Whether to return attention weights. Defaults to False. + + Returns: + A tuple of: + - Attention output of shape (batch_size, seq_length, embed_dim). + - Attention weights of shape (batch_size, num_heads, seq_length, seq_length) + if output_attentions=True, else None. + """ + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view( + batch_size, seq_length, self.num_heads, self.head_dim + ).transpose(1, 2) + keys = keys.view( + batch_size, seq_length, self.num_heads, self.head_dim + ).transpose(1, 2) + values = values.view( + batch_size, seq_length, self.num_heads, self.head_dim + ).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and output_attentions: + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does " + "not support `output_attentions=True`. Falling back to " + "eager attention. This warning can be removed using the " + 'argument `attn_implementation="eager"` when loading the ' + "model." + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation + ] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape( + batch_size, seq_length, embed_dim + ).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +class SiglipMLP(nn.Module): + """Feed-forward MLP block for SigLIP transformer layers.""" + + def __init__(self, config: SiglipVisionConfig | SiglipTextConfig) -> None: + """Initialize the MLP module. + + Args: + config: Model configuration containing hidden_size, intermediate_size, + and hidden_act activation function name. 
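+
+        Note:
+            With illustrative (assumed) sizes `hidden_size=768` and
+            `intermediate_size=3072`, the block maps `(B, L, 768)` through
+            `fc1` to `(B, L, 3072)`, applies the activation, and maps back
+            to `(B, L, 768)` through `fc2`.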
+ """ + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply the feed-forward transformation. + + Args: + hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + Output tensor of shape (batch_size, seq_length, hidden_size). + """ + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(GradientCheckpointingLayer): + """Single transformer encoder layer for SigLIP. + + Consists of self-attention followed by a feed-forward MLP, each with + layer normalization and residual connections (pre-norm architecture). + """ + + def __init__(self, config: SiglipVisionConfig | SiglipTextConfig) -> None: + """Initialize the encoder layer. + + Args: + config: Model configuration containing hidden_size, layer_norm_eps, + and attention/MLP parameters. + """ + super().__init__() + self.embed_dim = config.hidden_size + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.self_attn = SiglipAttention(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: bool | None = False, + ) -> tuple[torch.FloatTensor] | tuple[torch.FloatTensor, torch.Tensor | None]: + """Run a transformer encoder layer. + + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where + padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more + detail. + """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if output_attentions: + return (hidden_states, attn_weights) + return (hidden_states,) + + +@auto_docstring +class SiglipPreTrainedModel(PreTrainedModel): + """Base class for SigLIP models providing weight initialization and loading. + + Handles common functionality like weight initialization, gradient checkpointing, + and compatibility with various attention implementations. + """ + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + _no_split_modules = [ + "SiglipTextEmbeddings", + "SiglipEncoderLayer", + "SiglipVisionEmbeddings", + "SiglipEncoderLayer", + "SiglipMultiheadAttentionPoolingHead", + ] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + _supports_attention_backend = True + + def _init_weights(self, module: nn.Module) -> None: + """Initialize weights for a given module based on its type. 
+ + Uses different initialization strategies depending on the module type: + - Vision embeddings: Normal distribution scaled by 1/sqrt(width) + - Embeddings: Default Flax embedding initialization + - Attention layers: Xavier uniform for projections + - MLPs: Xavier uniform with small bias noise + - Pooling head: Xavier uniform for probe and attention weights + - SiglipModel: Initialize logit_scale and logit_bias + - Linear/Conv2d: LeCun normal initialization + - LayerNorm: Zero bias, ones for weight + + Args: + module: The module to initialize. + """ + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.xavier_uniform_(module.q_proj.weight) + nn.init.xavier_uniform_(module.k_proj.weight) + nn.init.xavier_uniform_(module.v_proj.weight) + nn.init.xavier_uniform_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.xavier_uniform_(module.fc1.weight) + nn.init.xavier_uniform_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.xavier_uniform_(module.probe.data) + nn.init.xavier_uniform_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.log(torch.tensor(1.0)) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class SiglipEncoder(nn.Module): + """Transformer encoder consisting of multiple self-attention layers. + + Stacks `config.num_hidden_layers` transformer encoder layers, with + optional gradient checkpointing for memory efficiency during training. + """ + + def __init__(self, config: SiglipConfig) -> None: + """Initialize the encoder with stacked transformer layers. + + Args: + config: Model configuration containing num_hidden_layers and + layer-specific parameters. + """ + super().__init__() + self.config = config + self.layers = nn.ModuleList( + [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) + self.gradient_checkpointing = False + + # Ignore copy + @can_return_tuple + def forward( + self, + inputs_embeds: torch.FloatTensor, + attention_mask: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + ) -> BaseModelOutput: + """Run the transformer encoder. + + Args: + inputs_embeds (`torch.FloatTensor`): + Embedded inputs of shape `(batch_size, sequence_length, + hidden_size)`. Optionally, instead of passing `input_ids` you + can choose to directly pass an embedded representation. This is + useful if you want more control over how to convert `input_ids` + indices into associated vectors than the model's internal + embedding lookup matrix. 
+ attention_mask (`torch.Tensor`, *optional*): + Mask of shape `(batch_size, sequence_length)` to avoid + performing attention on padding token indices. Mask values + selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention + layers. See `attentions` under returned tensors for more + detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See + `hidden_states` under returned tensors for more detail. + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + encoder_states: tuple[torch.Tensor, ...] | None = ( + () if output_hidden_states else None + ) + all_attentions: tuple[torch.Tensor, ...] | None = ( + () if output_attentions else None + ) + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + assert encoder_states is not None + encoder_states = encoder_states + (hidden_states,) + + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + assert all_attentions is not None + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + assert encoder_states is not None + encoder_states = encoder_states + (hidden_states,) + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +class SiglipTextTransformer(nn.Module): + """Text transformer for SigLIP, encoding text sequences. + + Combines text embeddings, transformer encoder, final layer normalization, + and a projection head for the text representation. + """ + + def __init__(self, config: SiglipTextConfig) -> None: + """Initialize the text transformer. + + Args: + config: Text model configuration containing hidden_size, + projection_size, and encoder parameters. 
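+
+        Note:
+            The forward pass embeds tokens, runs the encoder without a causal
+            mask, applies the final layer norm, pools the last token (assuming
+            "sticky" EOS tokenization), and projects it to `projection_size`.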
+ """ + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, config.projection_size) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + ) -> BaseModelOutputWithPooling: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # Note: SigLIP's text model does not use a causal mask, unlike the + # original CLIP model. + # expand attention_mask + if attention_mask is not None and not self._use_flash_attention_2: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask( + attention_mask, hidden_states.dtype + ) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The text model from SigLIP without any head or projection on top. + """ +) +class SiglipTextModel(SiglipPreTrainedModel): + """SigLIP text encoder without a projection head. + + Encodes text sequences into token-level and pooled representations + using a Transformer architecture. + """ + + config_class = SiglipTextConfig + + def __init__(self, config: SiglipTextConfig) -> None: + """Initialize the SigLIP text model. + + Args: + config: Text configuration containing model architecture parameters. + """ + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Embedding: + """Return the token embedding layer. + + Returns: + The nn.Embedding module used for token embeddings. + """ + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value: nn.Embedding) -> None: + """Set the token embedding layer. + + Args: + value: New embedding module to use for token embeddings. 
+ """ + self.text_model.embeddings.token_embedding = value + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + ) -> BaseModelOutputWithPooling: + """Run the forward pass for text inputs. + + Examples: + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + + >>> model = SiglipTextModel.from_pretrained( + ... "google/siglip-base-patch16-224" + ... ) + >>> tokenizer = AutoTokenizer.from_pretrained( + ... "google/siglip-base-patch16-224" + ... ) + + >>> # Important: make sure to set padding="max_length". + >>> inputs = tokenizer( + ... ["a photo of a cat", "a photo of a dog"], + ... padding="max_length", + ... return_tensors="pt", + ... ) + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ``` + """ + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + +class SiglipVisionTransformer(nn.Module): + """Vision transformer for SigLIP, encoding image patches. + + Combines patch embeddings, transformer encoder, post-layer normalization, + and an optional multi-head attention pooling head. + """ + + def __init__(self, config: SiglipVisionConfig) -> None: + """Initialize the vision transformer. + + Args: + config: Vision model configuration containing hidden_size, + layer_norm_eps, vision_use_head, and encoder parameters. + """ + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.use_head = ( + True if not hasattr(config, "vision_use_head") else config.vision_use_head + ) + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + interpolate_pos_encoding: bool | None = False, + ) -> BaseModelOutputWithPooling: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + hidden_states = self.embeddings( + pixel_values, interpolate_pos_encoding=interpolate_pos_encoding + ) + # Convert to bfloat16 if the encoder uses bfloat16 + if ( + len(self.encoder.layers) > 0 + and self.encoder.layers[0].self_attn.q_proj.weight.dtype == torch.bfloat16 + ): + hidden_states = hidden_states.to(torch.bfloat16) + + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + last_hidden_state = encoder_outputs.last_hidden_state + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooler_output = self.head(last_hidden_state) if self.use_head else None + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooler_output, + hidden_states=encoder_outputs.hidden_states, + 
attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead attention pooling for aggregating sequence representations. + + Uses a learnable probe token to attend over the sequence and produce + a single pooled representation, followed by layer norm and MLP. + """ + + def __init__(self, config: SiglipVisionConfig) -> None: + """Initialize the attention pooling head. + + Args: + config: Vision configuration containing hidden_size, + num_attention_heads, and layer_norm_eps. + """ + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True + ) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + """Pool the sequence using attention with a learnable probe. + + Args: + hidden_state: Input tensor of shape (batch_size, seq_length, hidden_size). + + Returns: + Pooled representation of shape (batch_size, hidden_size). + """ + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@auto_docstring( + custom_intro=""" + The vision model from SigLIP without any head or projection on top. + """ +) +class SiglipVisionModel(SiglipPreTrainedModel): + """SigLIP vision encoder without a projection head. + + Encodes images into patch-level and pooled representations using a + Vision Transformer architecture. + """ + + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig) -> None: + """Initialize the SigLIP vision model. + + Args: + config: Vision configuration containing model architecture parameters. + """ + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Conv2d: + """Return the patch embedding layer. + + Returns: + The nn.Conv2d module used for converting image patches to embeddings. + """ + return self.vision_model.embeddings.patch_embedding + + @can_return_tuple + @auto_docstring + def forward( + self, + pixel_values: torch.FloatTensor, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + interpolate_pos_encoding: bool = False, + ) -> BaseModelOutputWithPooling: + """Run the forward pass for vision inputs. + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + + >>> model = SiglipVisionModel.from_pretrained( + ... "google/siglip-base-patch16-224" + ... ) + >>> processor = AutoProcessor.from_pretrained( + ... "google/siglip-base-patch16-224" + ... 
) + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ``` + """ + return self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + +@auto_docstring +class SiglipModel(SiglipPreTrainedModel): + """SigLIP model combining text and vision encoders for image-text matching. + + Implements the Sigmoid Loss for Language Image Pre-Training (SigLIP) approach, + which uses a sigmoid loss instead of softmax for contrastive learning. + """ + + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig) -> None: + """Initialize the combined SigLIP model with text and vision encoders. + + Args: + config: Combined configuration containing text_config and vision_config. + + Raises: + TypeError: If text_config or vision_config are not of expected types. + """ + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise TypeError( + "config.text_config is expected to be of type SiglipTextConfig " + f"but is of type {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise TypeError( + "config.vision_config is expected to be of type SiglipVisionConfig " + f"but is of type {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + # First, initialize the text and vision models with proper attention + # implementation. + text_model = SiglipTextModel._from_config(text_config) + vision_model = SiglipVisionModel._from_config(vision_config) + + # Second, get the text and vision submodules (for backward compatibility) + self.text_model = text_model.text_model + self.vision_model = vision_model.vision_model + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def get_text_features( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + ) -> torch.FloatTensor: + """Return pooled text features. + + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to + the pooled output of [`SiglipTextModel`]. + + Examples: + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + + >>> # important: make sure to set padding="max_length". + >>> inputs = tokenizer( + ... ["a photo of a cat", "a photo of a dog"], + ... padding="max_length", + ... return_tensors="pt", + ... ) + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ``` + """ + # Use SigLIP model's config for some fields (if specified) instead of + # those of vision & text components. 
+ output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + pooled_output = text_outputs.pooler_output + + return pooled_output + + @auto_docstring + def get_image_features( + self, + pixel_values: torch.FloatTensor | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + interpolate_pos_encoding: bool = False, + ) -> torch.FloatTensor: + """Return pooled image features. + + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to + the pooled output of [`SiglipVisionModel`]. + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(images=image, return_tensors="pt") + + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ``` + """ + # Use SiglipModel's config for some fields (if specified) instead of + # those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + pooled_output = vision_outputs.pooler_output + + return pooled_output + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + pixel_values: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + return_loss: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + interpolate_pos_encoding: bool = False, + ) -> SiglipOutput: + """Run the forward pass for image-text similarity. + + Args: + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained( + ... "google/siglip-base-patch16-224" + ... ) + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # Important: we pass padding="max_length". + >>> inputs = processor( + ... text=texts, + ... images=image, + ... 
padding="max_length", + ... return_tensors="pt", + ... ) + + >>> with torch.no_grad(): + ... outputs = model(**inputs) + + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ``` + """ + # Use SigLIP model's config for some fields (if specified) instead of + # those of vision & text components. + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + vision_outputs: BaseModelOutputWithPooling = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + text_outputs: BaseModelOutputWithPooling = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + image_embeds = vision_outputs.pooler_output + text_embeds = text_outputs.pooler_output + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul( + text_embeds, image_embeds.t().to(text_embeds.device) + ) + + logit_scale, logit_bias = self.logit_scale.to( + text_embeds.device + ), self.logit_bias.to(text_embeds.device) + logits_per_text = logits_per_text * logit_scale.exp() + logit_bias + + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + # Adapted from: + # https://github.com/google-research/big_vision/blob/ + # 01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/ + # proj/image_text/siglip.py#L287 + eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device) + m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye + loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text) + nll = -torch.sum(loglik, dim=-1) + loss = nll.mean() + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +__all__ = [ + "SiglipModel", + "SiglipPreTrainedModel", + "SiglipTextModel", + "SiglipVisionModel", +] diff --git a/neuracore/ml/algorithms/pi0/utils.py b/neuracore/ml/algorithms/pi0/utils.py new file mode 100644 index 00000000..25f5f27f --- /dev/null +++ b/neuracore/ml/algorithms/pi0/utils.py @@ -0,0 +1,349 @@ +"""Utility functions and configuration for the PI0 algorithm. + +This module provides helper functions for flow matching, attention mask +construction, and image preprocessing used by the PI0 model. It also +defines the PI0Config dataclass for model configuration. 
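+
+Example:
+    A minimal configuration sketch; the values shown are the defaults and are
+    illustrative only.
+
+    ```python
+    from neuracore.ml.algorithms.pi0.utils import PI0Config
+
+    config = PI0Config(paligemma_variant="gemma_2b", chunk_size=50)
+    config.validate_features()  # raises ValueError on invalid settings
+    ```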
+""" + +# cspell:ignore OPENPI adarms + +from __future__ import annotations + +import math +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Literal + +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor + +# Constant used for attention masking in the OPENPI implementation +OPENPI_ATTENTION_MASK_VALUE = -1e9 + + +@dataclass(slots=True) +class PI0Config: + """Configuration for the PI0 model and training hyperparameters. + + Attributes: + paligemma_variant: PaliGemma model size ("gemma_300m" or "gemma_2b"). + action_expert_variant: Action expert model size ("gemma_300m" or "gemma_2b"). + dtype: Model precision ("bfloat16" or "float32"). + chunk_size: Number of action steps predicted per inference. + max_state_dim: Maximum dimension for state input vectors. + max_action_dim: Maximum dimension for action output vectors. + num_inference_steps: Number of Euler steps for action denoising. + use_adarms: Whether to use adaptive RMSNorm for (VLM, action expert). + time_sampling_beta_alpha: Alpha parameter for beta distribution time sampling. + time_sampling_beta_beta: Beta parameter for beta distribution time sampling. + time_sampling_scale: Scale factor for sampled time values. + time_sampling_offset: Offset added to sampled time values. + min_period: Minimum period for sinusoidal time embeddings. + max_period: Maximum period for sinusoidal time embeddings. + gradient_checkpointing: Whether to enable gradient checkpointing. + compile_model: Whether to compile the model with torch.compile. + compile_mode: Compilation mode for torch.compile. + device: Device to place the model on. + input_features: Mapping of input feature names to dimensions. + output_features: Mapping of output feature names to dimensions. + image_features: List of image feature names used as input. + """ + + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + dtype: Literal["bfloat16", "float32"] = "float32" + chunk_size: int = 50 + max_state_dim: int = 32 + max_action_dim: int = 32 + num_inference_steps: int = 10 + use_adarms: tuple[bool, bool] = (False, False) + time_sampling_beta_alpha: float = 1.5 + time_sampling_beta_beta: float = 1.0 + time_sampling_scale: float = 0.999 + time_sampling_offset: float = 0.001 + min_period: float = 4e-3 + max_period: float = 4.0 + gradient_checkpointing: bool = False + compile_model: bool = False + compile_mode: str = "max-autotune" + device: str | None = None + input_features: dict = field(default_factory=dict) + output_features: dict = field(default_factory=dict) + image_features: list[str] = field(default_factory=list) + + def validate_features(self) -> None: + """Validate configuration values. + + Raises: + ValueError: If any configuration value is invalid. + """ + if self.device is None: + self.device = "cpu" + + if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]: + raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}") + + if self.action_expert_variant not in ["gemma_300m", "gemma_2b"]: + raise ValueError( + f"Invalid action_expert_variant: {self.action_expert_variant}" + ) + + if self.dtype not in ["bfloat16", "float32"]: + raise ValueError(f"Invalid dtype: {self.dtype}") + + +def _get_safe_dtype(target_dtype: torch.dtype, device_type: str) -> torch.dtype: + """Get a device-compatible dtype. + + Some devices don't support certain dtypes (e.g., MPS doesn't support + float64, CPU doesn't efficiently support bfloat16). 
This function + returns a safe fallback dtype. + + Args: + target_dtype: Desired dtype + device_type: Device type string ("cpu", "mps", "cuda") + + Returns: + Compatible dtype for the given device. + """ + if device_type == "mps" and target_dtype == torch.float64: + return torch.float32 + if device_type == "cpu": + if target_dtype == torch.bfloat16: + return torch.float32 + if target_dtype == torch.float64: + return torch.float64 + return target_dtype + + +def _create_sinusoidal_pos_embedding( + time: torch.Tensor, + dimension: int, + min_period: float, + max_period: float, + device: torch.device | str = "cpu", +) -> Tensor: + """Create sinusoidal positional embeddings for diffusion timesteps. + + Uses logarithmically-spaced frequencies between min_period and max_period + to create rich time representations for the flow matching model. + + Args: + time: Diffusion timesteps [batch_size] + dimension: Embedding dimension (must be even) + min_period: Minimum frequency period + max_period: Maximum frequency period + device: Target device + + Returns: + Sinusoidal embeddings [batch_size, dimension]. + + Raises: + ValueError: If dimension is odd or time tensor has wrong shape. + """ + device = torch.device(device) + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + dtype = _get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * time[:, None] + return torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + + +def _sample_beta( + alpha: float | torch.Tensor, + beta: float | torch.Tensor, + bsize: int, + device: torch.device | str, +) -> Tensor: + """Sample from beta distribution for time sampling. + + Args: + alpha: Beta distribution alpha parameter + beta: Beta distribution beta parameter + bsize: Number of samples to draw + device: Target device + + Returns: + Beta-distributed samples [bsize]. + """ + alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) + beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + dist = torch.distributions.Beta(alpha_t, beta_t) + return dist.sample((bsize,)) + + +def _make_att_2d_masks( + pad_masks: torch.Tensor, att_masks: torch.Tensor +) -> torch.Tensor: + """Build causal 2D attention masks from padding and attention masks. + + Combines padding information with causal masking to create the final + attention mask used by transformer layers. + + Args: + pad_masks: Padding mask [batch_size, seq_len] + att_masks: Attention mask [batch_size, seq_len] + + Returns: + Combined causal mask [batch_size, seq_len, seq_len]. + + Raises: + ValueError: If input masks don't have 2 dimensions. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + return att_2d_masks & pad_2d_masks + + +def pad_vector(vector: torch.Tensor, new_dim: int) -> torch.Tensor: + """Right-pad tensor's last dimension to target size. + + Args: + vector: Input tensor + new_dim: Target size for last dimension + + Returns: + Padded tensor, or original if already large enough. 
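+
+    Example:
+        Illustrative shapes, assumed for demonstration.
+
+    ```python
+    >>> import torch
+    >>> pad_vector(torch.ones(2, 3), new_dim=5).shape
+    torch.Size([2, 5])
+    >>> pad_vector(torch.ones(2, 6), new_dim=5).shape  # already wide enough
+    torch.Size([2, 6])
+    ```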
+ """ + if vector.shape[-1] >= new_dim: + return vector + return F.pad(vector, (0, new_dim - vector.shape[-1])) + + +def build_lr_lambda( + actual_warmup_steps: int, + actual_decay_steps: int, + decay_lr: float, + optimizer_lr: float, +) -> Callable[[int], float]: + """Create a learning rate scheduler lambda with warmup and cosine decay. + + Args: + actual_warmup_steps: Warmup steps after any scaling. + actual_decay_steps: Cosine decay steps after any scaling. + decay_lr: Final learning rate after decay. + optimizer_lr: Base optimizer learning rate. + + Returns: + Callable that maps the current step to a LR multiplier. + """ + + def linear_warmup(step: int) -> float: + if step <= 0: + return 1 / (actual_warmup_steps + 1) + frac = 1 - step / actual_warmup_steps + return (1 / (actual_warmup_steps + 1) - 1) * frac + 1 + + def cosine_decay(step: int) -> float: + step = min(step, actual_decay_steps) + cosine = 0.5 * (1 + math.cos(math.pi * step / actual_decay_steps)) + alpha = decay_lr / optimizer_lr + return (1 - alpha) * cosine + alpha + + def lr_lambda(current_step: int) -> float: + if current_step < actual_warmup_steps: + return linear_warmup(current_step) + return cosine_decay(current_step) + + return lr_lambda + + +def _align_mask_length(mask_1d: torch.Tensor, target_len: int) -> torch.Tensor: + """Pad or trim a 1D mask to target length. + + Args: + mask_1d: Input mask tensor + target_len: Desired length + + Returns: + Mask tensor with exactly target_len elements. + """ + current_len = mask_1d.shape[0] + if current_len == target_len: + return mask_1d + if current_len < target_len: + pad = torch.zeros( + target_len - current_len, device=mask_1d.device, dtype=mask_1d.dtype + ) + return torch.cat([mask_1d, pad], dim=0) + return mask_1d[:target_len] + + +def resize_with_pad_torch( + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """Resize images to target size while preserving aspect ratio. + + Resizes the image to fit within the target dimensions while maintaining + aspect ratio, then pads to reach exact target size. Automatically + detects channels-first vs channels-last format. + + Args: + images: Input images (B, C, H, W) or (B, H, W, C) + height: Target height + width: Target width + mode: Interpolation mode + + Returns: + Resized and padded images in original format. + + Raises: + ValueError: If image dtype is not uint8 or float32. 
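+
+    Example:
+        Illustrative shapes and dtype, assumed for demonstration.
+
+    ```python
+    >>> import torch
+    >>> frames = torch.zeros(2, 3, 180, 320, dtype=torch.float32)  # (B, C, H, W)
+    >>> resize_with_pad_torch(frames, height=224, width=224).shape
+    torch.Size([2, 3, 224, 224])
+    ```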
+ """ + if images.shape[-1] <= 4: # assume channels-last + channels_last = True + if images.dim() == 3: + images = images.unsqueeze(0) + images = images.permute(0, 3, 1, 2) + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) + + _, _, cur_height, cur_width = images.shape + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), + mode="constant", + value=constant_value, + ) + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) + return padded_images diff --git a/neuracore/ml/config/algorithm/pi0.yaml b/neuracore/ml/config/algorithm/pi0.yaml index eb8ede03..7e2ba29e 100644 --- a/neuracore/ml/config/algorithm/pi0.yaml +++ b/neuracore/ml/config/algorithm/pi0.yaml @@ -1,22 +1,30 @@ # @package _global_ +# cspell:ignore bfloat checkpointing finetune algorithm: _target_: neuracore.ml.algorithms.pi0.pi0.Pi0 - vlm_expert_intermediate_size: 16384 - vlm_expert_num_heads: 8 - vlm_expert_num_kv_heads: 1 - vlm_expert_head_dim: 256 - vlm_max_text_tokens: 128 - action_expert_width: 1024 - action_expert_intermediate_size: 4096 - action_expert_num_heads: 8 - action_expert_num_kv_heads: 1 - action_expert_head_dim: 256 - moe_depth: 18 + vlm_max_text_tokens: 48 num_inference_steps: 10 - flow_sig_min: 0.001 - flow_alpha: 1.5 - flow_beta: 1.0 - lr: 5e-5 - weight_decay: 0.0 + dtype: "bfloat16" + paligemma_variant: "gemma_2b" + action_expert_variant: "gemma_300m" + use_pretrained_weights: True + pretrained_name_or_path: "lerobot/pi0_base" + time_sampling_beta_alpha: 1.5 + time_sampling_beta_beta: 1.0 + time_sampling_scale: 0.999 + time_sampling_offset: 0.001 + min_period: 4e-3 + max_period: 4.0 + gradient_checkpointing: True + compile_model: False + compile_mode: "max-autotune" + optimizer_lr: 2.5e-5 + optimizer_betas: [0.9, 0.95] + optimizer_eps: 1e-8 + optimizer_weight_decay: 0.01 clip_grad_norm: 1.0 - dtype: torch.float32 \ No newline at end of file + lr_scheduler_warmup_steps: 1000 + lr_scheduler_num_decay_steps: 30000 + lr_scheduler_decay_lr: 2.5e-6 + finetune_action_expert_only: False + freeze_language_model_only: True \ No newline at end of file diff --git a/neuracore/ml/utils/algorithm_loader.py b/neuracore/ml/utils/algorithm_loader.py index e817c451..8e1d97e0 100644 --- a/neuracore/ml/utils/algorithm_loader.py +++ b/neuracore/ml/utils/algorithm_loader.py @@ -147,8 +147,7 @@ def install_requirements(self) -> bool: def get_all_files(self) -> list[Path]: """Get all Python files in the algorithm directory recursively. - Scans the algorithm directory and all subdirectories for Python files, - excluding __init__.py files which are handled separately. + Scans the algorithm directory and all subdirectories for Python files. 
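+
+        For example (illustrative), an algorithm directory containing `pi0.py`,
+        `__init__.py`, and `utils/helpers.py` yields all three paths;
+        `__init__.py` files are included as well.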
Returns: List of Path objects representing all Python files found. @@ -156,7 +155,7 @@ def get_all_files(self) -> list[Path]: files = [] for root, _, filenames in os.walk(self.algorithm_dir): for filename in filenames: - if filename.endswith(".py") and filename != "__init__.py": + if filename.endswith(".py"): files.append(Path(root) / filename) return files diff --git a/tests/unit/ml/algorithms/test_pi0.py b/tests/unit/ml/algorithms/test_pi0.py index e058b59f..17121cb3 100644 --- a/tests/unit/ml/algorithms/test_pi0.py +++ b/tests/unit/ml/algorithms/test_pi0.py @@ -15,27 +15,24 @@ from neuracore.ml.algorithms.pi0.pi0 import Pi0 from neuracore.ml.core.ml_types import BatchedTrainingOutputs from neuracore.ml.datasets.pytorch_dummy_dataset import PytorchDummyDataset -from neuracore.ml.utils.device_utils import get_default_device from neuracore.ml.utils.validate import run_validation -BS = 2 -DEVICE = get_default_device() -OUTPUT_PREDICTION_HORIZON = 5 +BS = 1 +OUTPUT_PREDICTION_HORIZON = 1 # Use cpu because the model takes a lot of vram DEVICE = torch.device("cpu") SKIP_TEST = os.environ.get("CI", "false").lower() == "true" -PI_TINY_ARGS: dict[str, Any] = { - "vlm_expert_intermediate_size": 4, - "vlm_expert_num_heads": 1, - "vlm_expert_head_dim": 4, - "action_expert_width": 16, - "action_expert_intermediate_size": 4, - "action_expert_num_heads": 1, - "action_expert_head_dim": 4, - "moe_depth": 1, +PI0_TEST_ARGS: dict[str, Any] = { + "paligemma_variant": "gemma_tiny", + "action_expert_variant": "gemma_tiny", + "use_pretrained_weights": False, + "num_inference_steps": 1, + "vlm_max_text_tokens": 4, + "compile_model": False, + "gradient_checkpointing": False, } @@ -106,10 +103,8 @@ def sample_training_batch( @pytest.mark.skipif(SKIP_TEST, reason="Skipping test in CI environment") -def test_model_construction( - model_init_description: ModelInitDescription, model_config: dict -): - model = Pi0(model_init_description, **PI_TINY_ARGS) +def test_model_construction(model_init_description: ModelInitDescription): + model = Pi0(model_init_description, **PI0_TEST_ARGS) model = model.to(DEVICE) assert isinstance(model, nn.Module) @@ -119,7 +114,7 @@ def test_model_forward( model_init_description: ModelInitDescription, sample_inference_batch: BatchedInferenceInputs, ): - model = Pi0(model_init_description, **PI_TINY_ARGS) + model = Pi0(model_init_description, **PI0_TEST_ARGS) model = model.to(DEVICE) sample_inference_batch = sample_inference_batch.to(DEVICE) output: dict[DataType, list[BatchedNCData]] = model(sample_inference_batch) @@ -136,7 +131,7 @@ def test_model_backward( model_init_description: ModelInitDescription, sample_training_batch: BatchedTrainingSamples, ): - model = Pi0(model_init_description, **PI_TINY_ARGS) + model = Pi0(model_init_description, **PI0_TEST_ARGS) model = model.to(DEVICE) sample_training_batch = sample_training_batch.to(DEVICE) output: BatchedTrainingOutputs = model.training_step(sample_training_batch) @@ -152,7 +147,10 @@ def test_model_backward( if param.requires_grad: # VLM parameters may not get gradients if they're not used in the # forward pass - is_vlm_param = any(keyword in name.lower() for keyword in ["vlm", "vision"]) + is_vlm_param = any( + keyword in name.lower() + for keyword in ["vlm", "vision", "paligemma", "language_model"] + ) if not is_vlm_param: # Non-VLM parameters should definitely have gradients @@ -170,7 +168,12 @@ def test_model_backward( @pytest.mark.skipif(SKIP_TEST, reason="Skipping test in CI environment") -def test_run_validation(tmp_path: Path, 
mock_login): +def test_run_validation(tmp_path: Path, mock_login, monkeypatch): + from neuracore.ml.algorithms.pi0.pi0 import Pi0 + from neuracore.ml.utils import validate as validate_module + + monkeypatch.setattr(validate_module.AlgorithmLoader, "load_model", lambda self: Pi0) + # Long timeout due to larger model run on CPU os.environ["NEURACORE_ENDPOINT_TIMEOUT"] = "120" algorithm_dir = Path(inspect.getfile(Pi0)).parent @@ -179,7 +182,7 @@ def test_run_validation(tmp_path: Path, mock_login): algorithm_dir=algorithm_dir, port=random.randint(10000, 20000), skip_endpoint_check=False, - algorithm_config=PI_TINY_ARGS, + algorithm_config=PI0_TEST_ARGS, device=DEVICE, ) if len(error_msg) > 0: