Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions neuracore-dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ UNITREE
unitreeh1
URDF
usefixtures
chonk
dinov
dinov2
Vaswani
vertadr
vertnum
Expand All @@ -218,6 +221,28 @@ ylabel
znear
bigym
secho
adarms
ADARMS
meanpooling
colwise
rowwise
autocast
broadcastable
seqlen
layernorm
attns
torchdynamo
llava
Llava
triu
gptj
GPTJ
erfinv
lecun
CLIPMLP
altclip
loglik
logsigmoid
xyzw
wxyz
nans
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,6 @@ def training_step(self, batch: BatchedTrainingSamples) -> BatchedTrainingOutputs
action_data = torch.cat(action_targets, dim=-1) # (B, T, total_action_dim)

target_actions = self.action_normalizer.normalize(action_data)
target_actions = target_actions

# Sample noise to add to the trajectory.
eps = torch.randn(target_actions.shape, device=target_actions.device)
Expand Down
119 changes: 118 additions & 1 deletion neuracore/ml/algorithms/pi0/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,118 @@
"""Init."""
"""PI0 algorithm with transformers patching.

Automatically patches the installed transformers library with custom modifications
required by PI0. This eliminates the need to manually copy files into the transformers
installation directory.

The patching includes:
- Gemma model with Adaptive RMSNorm support
- Gated residual connections for Gemma modeling
- Custom PaliGemma and SigLIP modifications
- Python 3.10 UnionType annotation support for transformers docs
"""

# cspell:ignore adarms
import logging
import shutil
from pathlib import Path

logger = logging.getLogger(__name__)


def check_whether_transformers_replace_is_installed_correctly() -> bool:
"""Check whether transformers has been patched with PI0 modifications.

Verifies that the installed `transformers` library has been patched by checking
for custom attributes and functions that are not present in upstream.

Returns:
True if patches are detected, False otherwise.
"""
try:
from transformers.models.gemma import modeling_gemma
from transformers.models.gemma.configuration_gemma import GemmaConfig

cfg = GemmaConfig()
if not hasattr(cfg, "use_adarms"):
return False
if not hasattr(modeling_gemma, "_gated_residual"):
return False
return True
except Exception:
return False


def _patch_transformers_args_doc() -> None:
"""Patch transformers args_doc to handle Python 3.10 UnionType annotations.

Fixes documentation generation errors caused by UnionType syntax
(e.g., `int | str`). The patch is applied once and marked to prevent
re-patching.
"""
try:
import inspect
import re
import types
from collections.abc import Callable
from typing import Any, get_args

from transformers.utils import args_doc

if getattr(args_doc, "_UNIONTYPE_PATCHED", False):
return

original = args_doc._process_parameter_type

def _process_parameter_type(
param: inspect.Parameter, param_name: str, func: Callable[..., Any]
) -> tuple[str, bool]:
if param.annotation != inspect.Parameter.empty and isinstance(
param.annotation, types.UnionType
):
param_type = str(param.annotation).replace("transformers.", "~")
optional = any(arg is type(None) for arg in get_args(param.annotation))
if "ForwardRef" in param_type:
param_type = re.sub(r"ForwardRef\('([\w.]+)'\)", r"\1", param_type)
if "Optional" in param_type:
param_type = re.sub(r"Optional\[(.*?)\]", r"\1", param_type)
optional = True
return param_type, optional
return original(param, param_name, func)

args_doc._process_parameter_type = _process_parameter_type
args_doc._UNIONTYPE_PATCHED = True
except Exception:
return


def _patch_transformers() -> None:
"""Automatically patch transformers with custom modifications.

Checks if patching is needed, then copies files from transformers_replace/
to the installed transformers library. The process is idempotent and works
across different installation methods.

Raises:
ValueError: If patching fails due to permission issues.
"""
if check_whether_transformers_replace_is_installed_correctly():
return # Already patched
else:
logger.info("Transformers not patched; attempting to patch now.")

try:
import transformers

src = Path(__file__).parent / "transformers_replace"
dst = Path(transformers.__file__).parent
if src.exists():
for f in src.rglob("*.py"):
target = dst / f.relative_to(src)
target.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(f, target)
except Exception:
raise ValueError("Failed to patch transformers because of permission issues")


_patch_transformers()
_patch_transformers_args_doc()
Loading