Changes from all commits (30 commits)
e085a00  Add audio model adapters and improve SSM partition specs (vishesh9131, Oct 31, 2025)
6617a60  Fix return type annotation for default_output_partition_spec (vishesh9131, Nov 2, 2025)
8b04152  Remove input_grain_csv_test.py - incomplete implementation (vishesh9131, Nov 2, 2025)
fd94c7a  Fix audio adapter test failures (vishesh9131, Nov 2, 2025)
bdead97  Fix ASRModelAdapter state passing to F() and parameter count (vishesh9131, Nov 2, 2025)
ba55210  Remove fallback direct calls from ASRModelAdapter (vishesh9131, Nov 2, 2025)
6e22d63  Use uv pip install in pre-commit workflow (vishesh9131, Nov 3, 2025)
8f41b40  Fix uv pip install to use --system flag (vishesh9131, Nov 3, 2025)
c569281  Add pylint disable for too-many-positional-arguments in jit_mamba_scan (vishesh9131, Nov 3, 2025)
b2c673e  Add pylint disable for MambaConfig.__init__ too-many-positional-argum… (vishesh9131, Nov 3, 2025)
3e94b99  Use uv pip install without --system flag for pre-commit (vishesh9131, Nov 3, 2025)
be7466e  Revert to pip install matching upstream (vishesh9131, Nov 3, 2025)
d13700d  Add uv back to workflow and debug venv detection (vishesh9131, Nov 3, 2025)
b67e66c  Revert to pip install to match upstream (vishesh9131, Nov 3, 2025)
5d6cddf  Use uv with explicit VIRTUAL_ENV for ml-dtypes override (vishesh9131, Nov 3, 2025)
d7a9e1f  Fix VIRTUAL_ENV export in same step as uv pip install (vishesh9131, Nov 3, 2025)
d87892e  Use pip with legacy resolver to handle ml-dtypes conflict (vishesh9131, Nov 3, 2025)
d346600  Fix import order for isort 7.0 and add pylint disables (vishesh9131, Nov 3, 2025)
e29f655  Reinstall google-cloud-aiplatform after legacy-resolver install (vishesh9131, Nov 3, 2025)
81b08d4  Also reinstall transformers after legacy-resolver (vishesh9131, Nov 3, 2025)
18bb746  CI: use dedicated venv + uv for dependency install (vishesh9131, Nov 3, 2025)
a92685a  Pin isort==7.0.0 to align CI and local formatting (vishesh9131, Nov 3, 2025)
ac29875  CI: pin isort==5.13.2 to satisfy pylint (<6) and allow uv resolution (vishesh9131, Nov 3, 2025)
ebcb758  chore(ci): trigger CI on latest workflow/deps changes (vishesh9131, Nov 3, 2025)
5f7e142  Apply isort 5.13.2 formatting to match CI (vishesh9131, Nov 3, 2025)
f8d0ca7  CI: pin torch==2.1.1 and torchvision==0.16.1 for pytype compatibility… (vishesh9131, Nov 3, 2025)
efe1397  CI: install PyTorch CPU wheels via index-url so pytype sees torch.nn … (vishesh9131, Nov 3, 2025)
176d39e  pytype: disable module-attr in ssm_test to allow torch.nn usage under… (vishesh9131, Nov 3, 2025)
0766664  Merge branch 'main' into main (vishesh9131, Nov 21, 2025)
13115d4  Refactor with-statement formatting and minor style fixes (vishesh9131, Nov 21, 2025)
12 changes: 10 additions & 2 deletions .github/workflows/pre-commit.yml
@@ -15,9 +15,17 @@ jobs:
         with:
           python-version: '3.10'
           cache: 'pip'
-      - run: pip install --upgrade pip
+      - name: Create and activate venv for uv
+        run: |
+          python -m venv .venv
+          echo "VIRTUAL_ENV=$PWD/.venv" >> $GITHUB_ENV
+          echo "$PWD/.venv/bin" >> $GITHUB_PATH
+      - run: pip install --upgrade pip uv
       # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype)
-      - run: pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]'
+      - run: |
+          uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]'
+          # Ensure PyTorch installs CPU wheels (with stubs) so pytype can resolve torch.nn
+          uv pip install --index-url https://download.pytorch.org/whl/cpu 'torch==2.1.1' 'torchvision==0.16.1'
       # pylint uses approx 12GB of memory during this run, look into split to decrease?
       - run: |
           # Start memory monitor as a background process.
1 change: 1 addition & 0 deletions .gitignore
@@ -176,3 +176,4 @@ bazel-*
 
 # Emacs
 *~
+run_specific_test.sh
9 changes: 6 additions & 3 deletions .pre-commit-config.yaml
@@ -14,16 +14,19 @@ repos:
       - id: black
         name: black
         entry: black
-        language: system
+        language: python
+        additional_dependencies: ['black==23.12.0']
         types: [python]
       - id: isort
         name: isort
         entry: isort
-        language: system
+        language: python
+        additional_dependencies: ['isort==5.13.2']
         types: [python]
       - id: pylint
         name: pylint
         entry: pylint
         args: ['--msg-template="{abspath}:{line}: [{msg_id}({symbol}), {obj}] {msg}"']
-        language: system
+        language: python
+        additional_dependencies: ['pylint==3.3.7']
         types: [python]
246 changes: 246 additions & 0 deletions axlearn/audio/adapter.py
@@ -0,0 +1,246 @@
# Copyright © 2023 Apple Inc.

"""Audio model adapter for efficient fine-tuning."""

from typing import Optional

import jax

from axlearn.common.base_layer import BaseLayer
from axlearn.common.config import REQUIRED, Required, config_class
from axlearn.common.layers import BatchNorm, LayerNorm, Linear
from axlearn.common.module import Module
from axlearn.common.module import functional as F
from axlearn.common.param_init import DefaultInitializer, WeightInitializer


class AudioModelAdapter(BaseLayer):
    """Adapter layer for efficient fine-tuning of audio models."""

    @config_class
    class Config(BaseLayer.Config):
        """Configures AudioModelAdapter."""

        # Input feature dimension.
        input_dim: Required[int] = REQUIRED
        # Bottleneck dimension (typically much smaller than input_dim).
        bottleneck_dim: Required[int] = REQUIRED
        # Whether to apply layer normalization before the adapter.
        use_layer_norm: bool = True
        # Whether to apply batch normalization in the adapter.
        use_batch_norm: bool = False
        # Scaling factor for the adapter output.
        adapter_scale: float = 1.0
        # Activation function to use.
        activation: str = "relu"
        # Whether to add a residual connection.
        residual: bool = True

    def __init__(self, cfg: Config, *, parent: Optional[Module]):
        super().__init__(cfg, parent=parent)
        cfg = self.config

        # Initialize with small weights to make adapter less disruptive initially
        weight_init = WeightInitializer.default_config().set(
            distribution="normal",
            fan="fan_in",
            scale=0.01,
        )

        bias_init = WeightInitializer.default_config().set(
            distribution="normal",
            fan=None,
            scale=0.01,
        )

        param_init = DefaultInitializer.default_config().set(
            init_by_param_name={
                ".*weight": weight_init,
                ".*bias": bias_init,
            },
        )

        # Down projection to bottleneck dimension
        self._add_child(
            "down_proj",
            Linear.default_config().set(
                input_dim=cfg.input_dim,
                output_dim=cfg.bottleneck_dim,
                bias=True,
                param_init=param_init,
            ),
        )

        # Optional batch normalization
        if cfg.use_batch_norm:
            self._add_child(
                "batch_norm",
                BatchNorm.default_config().set(
                    input_dim=cfg.bottleneck_dim,
                    decay=0.9,
                ),
            )

        # Up projection back to input dimension
        self._add_child(
            "up_proj",
            Linear.default_config().set(
                input_dim=cfg.bottleneck_dim,
                output_dim=cfg.input_dim,
                bias=True,
                param_init=param_init,
            ),
        )

        # Optional layer normalization
        if cfg.use_layer_norm:
            self._add_child(
                "layer_norm",
                LayerNorm.default_config().set(
                    input_dim=cfg.input_dim,
                ),
            )

    def forward(self, inputs, **_kwargs):
        """Apply the adapter transformation.

        Args:
            inputs: Input tensor of shape [batch_size, seq_len, input_dim].
            **_kwargs: Additional keyword arguments (unused, kept for API compatibility).

        Returns:
            Tensor of the same shape as inputs.
        """
        cfg = self.config
        residual = inputs

        # Apply layer normalization if specified
        x = inputs
        if cfg.use_layer_norm:
            x = self.layer_norm(x)

        # Down projection
        x = self.down_proj(x)

        # Apply batch normalization if specified
        if cfg.use_batch_norm:
            # BatchNorm uses is_training from context automatically
            x = self.batch_norm(x)

        # Activation
        if cfg.activation == "relu":
            x = jax.nn.relu(x)
        elif cfg.activation == "gelu":
            x = jax.nn.gelu(x)

        # Up projection
        x = self.up_proj(x)

        # Scale the output
        if cfg.adapter_scale != 1.0:
            x = x * cfg.adapter_scale

        # Add residual connection if specified
        if cfg.residual:
            x = x + residual

        return x


class ASRModelAdapter(BaseLayer):
    """Adapter for Automatic Speech Recognition (ASR) models."""

    @config_class
    class Config(BaseLayer.Config):
        """Configures ASRModelAdapter."""

        # Feature dimension of the encoder.
        encoder_dim: Required[int] = REQUIRED
        # Bottleneck dimension for encoder adapters.
        encoder_bottleneck_dim: Required[int] = REQUIRED
        # Feature dimension of the decoder.
        decoder_dim: Optional[int] = None
        # Bottleneck dimension for decoder adapters.
        decoder_bottleneck_dim: Optional[int] = None
        # Whether to add adapters to the encoder.
        adapt_encoder: bool = True
        # Whether to add adapters to the decoder.
        adapt_decoder: bool = False
        # Adapter configuration.
        adapter: AudioModelAdapter.Config = AudioModelAdapter.default_config()

    def __init__(self, cfg: Config, *, parent: Optional[Module]):
        super().__init__(cfg, parent=parent)
        cfg = self.config

        if cfg.adapt_encoder:
            self._add_child(
                "encoder_adapter",
                cfg.adapter.clone(
                    input_dim=cfg.encoder_dim,
                    bottleneck_dim=cfg.encoder_bottleneck_dim,
                ),
            )

        if (
            cfg.adapt_decoder
            and cfg.decoder_dim is not None
            and cfg.decoder_bottleneck_dim is not None
        ):
            self._add_child(
                "decoder_adapter",
                cfg.adapter.clone(
                    input_dim=cfg.decoder_dim,
                    bottleneck_dim=cfg.decoder_bottleneck_dim,
                ),
            )

    def adapt_encoder_features(self, features, *, is_training=False, prng_key, state):
        """Apply adaptation to encoder features.

        Args:
            features: Encoder features to adapt.
            is_training: Whether the model is in training mode.
            prng_key: PRNG key for stochastic operations.
            state: State for the adapter.

        Returns:
            Adapted encoder features.
        """
        cfg = self.config
        if not cfg.adapt_encoder:
            return features

        outputs, _ = F(
            self.encoder_adapter,
            inputs=(features,),
            is_training=is_training,
            prng_key=prng_key,
            state=state["encoder_adapter"],
        )
        return outputs

    def adapt_decoder_features(self, features, *, is_training=False, prng_key, state):
        """Apply adaptation to decoder features.

        Args:
            features: Decoder features to adapt.
            is_training: Whether the model is in training mode.
            prng_key: PRNG key for stochastic operations.
            state: State for the adapter.

        Returns:
            Adapted decoder features.
        """
        cfg = self.config
        if not cfg.adapt_decoder or not hasattr(self, "decoder_adapter"):
            return features

        outputs, _ = F(
            self.decoder_adapter,
            inputs=(features,),
            is_training=is_training,
            prng_key=prng_key,
            state=state["decoder_adapter"],
        )
        return outputs
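
For reviewers, here is a minimal usage sketch showing how AudioModelAdapter could be driven through axlearn's functional F() API, mirroring the call pattern used by ASRModelAdapter.adapt_encoder_features above. The dimensions, the zero-valued features, and the initialize_parameters_recursively call are illustrative assumptions and are not part of this diff.

```python
# Illustrative sketch only (not part of the PR): run AudioModelAdapter via F(),
# mirroring the call pattern in ASRModelAdapter.adapt_encoder_features.
import jax
import jax.numpy as jnp

from axlearn.audio.adapter import AudioModelAdapter
from axlearn.common.module import functional as F

# Assumed dimensions; anything matching the encoder feature dim would work.
cfg = AudioModelAdapter.default_config().set(
    name="adapter",
    input_dim=512,
    bottleneck_dim=64,
)
layer = cfg.instantiate(parent=None)

# Initialize adapter parameters (standard axlearn BaseLayer pattern, assumed here).
prng_key = jax.random.PRNGKey(0)
state = layer.initialize_parameters_recursively(prng_key)

# Stand-in [batch, seq_len, input_dim] features.
features = jnp.zeros([2, 10, 512])

# The adapter returns a tensor of the same shape as its input.
outputs, _ = F(
    layer,
    inputs=(features,),
    is_training=False,
    prng_key=prng_key,
    state=state,
)
```

When wrapped inside ASRModelAdapter, the same call happens internally, with the adapter's state sub-dict keyed by "encoder_adapter" or "decoder_adapter" as shown in the diff.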