diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 041199f45..e3780ec40 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -15,9 +15,17 @@ jobs: with: python-version: '3.10' cache: 'pip' - - run: pip install --upgrade pip + - name: Create and activate venv for uv + run: | + python -m venv .venv + echo "VIRTUAL_ENV=$PWD/.venv" >> $GITHUB_ENV + echo "$PWD/.venv/bin" >> $GITHUB_PATH + - run: pip install --upgrade pip uv # TODO(markblee): Remove gcp,vertexai_tensorboard from CI. (needed by pytype) - - run: pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + - run: | + uv pip install '.[core,audio,orbax,dev,gcp,vertexai_tensorboard,open_api]' + # Ensure PyTorch installs CPU wheels (with stubs) so pytype can resolve torch.nn + uv pip install --index-url https://download.pytorch.org/whl/cpu 'torch==2.1.1' 'torchvision==0.16.1' # pylint uses approx 12GB of memory during this run, look into split to decrease? - run: | # Start memory monitor as a background process. diff --git a/.gitignore b/.gitignore index 4d452573e..0a28a1119 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,4 @@ bazel-* # Emacs *~ +run_specific_test.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a1e66b4b..a4a6fb84a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,16 +14,19 @@ repos: - id: black name: black entry: black - language: system + language: python + additional_dependencies: ['black==23.12.0'] types: [python] - id: isort name: isort entry: isort - language: system + language: python + additional_dependencies: ['isort==5.13.2'] types: [python] - id: pylint name: pylint entry: pylint args: ['--msg-template="{abspath}:{line}: [{msg_id}({symbol}), {obj}] {msg}"'] - language: system + language: python + additional_dependencies: ['pylint==3.3.7'] types: [python] diff --git a/axlearn/audio/adapter.py b/axlearn/audio/adapter.py new file mode 100644 index 000000000..a5d2c459e --- /dev/null +++ b/axlearn/audio/adapter.py @@ -0,0 +1,246 @@ +# Copyright © 2023 Apple Inc. + +"""Audio model adapter for efficient fine-tuning.""" + +from typing import Optional + +import jax + +from axlearn.common.base_layer import BaseLayer +from axlearn.common.config import REQUIRED, Required, config_class +from axlearn.common.layers import BatchNorm, LayerNorm, Linear +from axlearn.common.module import Module +from axlearn.common.module import functional as F +from axlearn.common.param_init import DefaultInitializer, WeightInitializer + + +class AudioModelAdapter(BaseLayer): + """Adapter layer for efficient fine-tuning of audio models.""" + + @config_class + class Config(BaseLayer.Config): + """Configures AudioModelAdapter.""" + + # Input feature dimension. + input_dim: Required[int] = REQUIRED + # Bottleneck dimension (typically much smaller than input_dim). + bottleneck_dim: Required[int] = REQUIRED + # Whether to apply layer normalization before the adapter. + use_layer_norm: bool = True + # Whether to apply batch normalization in the adapter. + use_batch_norm: bool = False + # Scaling factor for the adapter output. + adapter_scale: float = 1.0 + # Activation function to use. + activation: str = "relu" + # Whether to add a residual connection. 
+ residual: bool = True + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg = self.config + + # Initialize with small weights to make adapter less disruptive initially + weight_init = WeightInitializer.default_config().set( + distribution="normal", + fan="fan_in", + scale=0.01, + ) + + bias_init = WeightInitializer.default_config().set( + distribution="normal", + fan=None, + scale=0.01, + ) + + param_init = DefaultInitializer.default_config().set( + init_by_param_name={ + ".*weight": weight_init, + ".*bias": bias_init, + }, + ) + + # Down projection to bottleneck dimension + self._add_child( + "down_proj", + Linear.default_config().set( + input_dim=cfg.input_dim, + output_dim=cfg.bottleneck_dim, + bias=True, + param_init=param_init, + ), + ) + + # Optional batch normalization + if cfg.use_batch_norm: + self._add_child( + "batch_norm", + BatchNorm.default_config().set( + input_dim=cfg.bottleneck_dim, + decay=0.9, + ), + ) + + # Up projection back to input dimension + self._add_child( + "up_proj", + Linear.default_config().set( + input_dim=cfg.bottleneck_dim, + output_dim=cfg.input_dim, + bias=True, + param_init=param_init, + ), + ) + + # Optional layer normalization + if cfg.use_layer_norm: + self._add_child( + "layer_norm", + LayerNorm.default_config().set( + input_dim=cfg.input_dim, + ), + ) + + def forward(self, inputs, **_kwargs): + """Apply the adapter transformation. + + Args: + inputs: Input tensor of shape [batch_size, seq_len, input_dim]. + **_kwargs: Additional keyword arguments (unused, kept for API compatibility). + + Returns: + Tensor of the same shape as inputs. + """ + cfg = self.config + residual = inputs + + # Apply layer normalization if specified + x = inputs + if cfg.use_layer_norm: + x = self.layer_norm(x) + + # Down projection + x = self.down_proj(x) + + # Apply batch normalization if specified + if cfg.use_batch_norm: + # BatchNorm uses is_training from context automatically + x = self.batch_norm(x) + + # Activation + if cfg.activation == "relu": + x = jax.nn.relu(x) + elif cfg.activation == "gelu": + x = jax.nn.gelu(x) + + # Up projection + x = self.up_proj(x) + + # Scale the output + if cfg.adapter_scale != 1.0: + x = x * cfg.adapter_scale + + # Add residual connection if specified + if cfg.residual: + x = x + residual + + return x + + +class ASRModelAdapter(BaseLayer): + """Adapter for Automatic Speech Recognition (ASR) models.""" + + @config_class + class Config(BaseLayer.Config): + """Configures ASRModelAdapter.""" + + # Feature dimension of the encoder. + encoder_dim: Required[int] = REQUIRED + # Bottleneck dimension for encoder adapters. + encoder_bottleneck_dim: Required[int] = REQUIRED + # Feature dimension of the decoder. + decoder_dim: Optional[int] = None + # Bottleneck dimension for decoder adapters. + decoder_bottleneck_dim: Optional[int] = None + # Whether to add adapters to the encoder. + adapt_encoder: bool = True + # Whether to add adapters to the decoder. + adapt_decoder: bool = False + # Adapter configuration. 
+ adapter: AudioModelAdapter.Config = AudioModelAdapter.default_config() + + def __init__(self, cfg: Config, *, parent: Optional[Module]): + super().__init__(cfg, parent=parent) + cfg = self.config + + if cfg.adapt_encoder: + self._add_child( + "encoder_adapter", + cfg.adapter.clone( + input_dim=cfg.encoder_dim, + bottleneck_dim=cfg.encoder_bottleneck_dim, + ), + ) + + if ( + cfg.adapt_decoder + and cfg.decoder_dim is not None + and cfg.decoder_bottleneck_dim is not None + ): + self._add_child( + "decoder_adapter", + cfg.adapter.clone( + input_dim=cfg.decoder_dim, + bottleneck_dim=cfg.decoder_bottleneck_dim, + ), + ) + + def adapt_encoder_features(self, features, *, is_training=False, prng_key, state): + """Apply adaptation to encoder features. + + Args: + features: Encoder features to adapt. + is_training: Whether the model is in training mode. + prng_key: PRNG key for stochastic operations. + state: State for the adapter. + + Returns: + Adapted encoder features. + """ + cfg = self.config + if not cfg.adapt_encoder: + return features + + outputs, _ = F( + self.encoder_adapter, + inputs=(features,), + is_training=is_training, + prng_key=prng_key, + state=state["encoder_adapter"], + ) + return outputs + + def adapt_decoder_features(self, features, *, is_training=False, prng_key, state): + """Apply adaptation to decoder features. + + Args: + features: Decoder features to adapt. + is_training: Whether the model is in training mode. + prng_key: PRNG key for stochastic operations. + state: State for the adapter. + + Returns: + Adapted decoder features. + """ + cfg = self.config + if not cfg.adapt_decoder or not hasattr(self, "decoder_adapter"): + return features + + outputs, _ = F( + self.decoder_adapter, + inputs=(features,), + is_training=is_training, + prng_key=prng_key, + state=state["decoder_adapter"], + ) + return outputs diff --git a/axlearn/audio/adapter_test.py b/axlearn/audio/adapter_test.py new file mode 100644 index 000000000..e7b50b4d6 --- /dev/null +++ b/axlearn/audio/adapter_test.py @@ -0,0 +1,379 @@ +# Copyright © 2024 Apple Inc. 
+ +"""Tests for audio adapters.""" + +import jax +import jax.numpy as jnp +import numpy as np +from absl.testing import parameterized + +from axlearn.audio.adapter import ASRModelAdapter, AudioModelAdapter +from axlearn.common.module import functional as F +from axlearn.common.test_utils import TestCase, assert_allclose + + +class AudioModelAdapterTest(TestCase): + """Tests AudioModelAdapter.""" + + def test_forward_basic(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + self.assertTrue(jnp.isfinite(outputs).all()) + + def test_forward_with_layer_norm(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + use_layer_norm=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + def test_forward_without_residual(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + residual=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + with self.assertRaises(AssertionError): + assert_allclose(outputs, inputs) + + def test_forward_with_scaling(self): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + adapter_scale = 0.5 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + adapter_scale=adapter_scale, + residual=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + @parameterized.parameters(["relu", "gelu"]) + def 
test_forward_with_activation(self, activation: str): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + activation=activation, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + self.assertTrue(jnp.isfinite(outputs).all()) + + def test_parameter_counts(self): + input_dim, bottleneck_dim = 256, 64 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + use_layer_norm=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + layer_params = layer.initialize_parameters_recursively(prng_key) + + down_proj_weight = layer_params["down_proj"]["weight"] + down_proj_bias = layer_params["down_proj"]["bias"] + up_proj_weight = layer_params["up_proj"]["weight"] + up_proj_bias = layer_params["up_proj"]["bias"] + layer_norm_scale = layer_params["layer_norm"]["scale"] + layer_norm_bias = layer_params["layer_norm"]["bias"] + + self.assertEqual(down_proj_weight.shape, (input_dim, bottleneck_dim)) + self.assertEqual(down_proj_bias.shape, (bottleneck_dim,)) + self.assertEqual(up_proj_weight.shape, (bottleneck_dim, input_dim)) + self.assertEqual(up_proj_bias.shape, (input_dim,)) + self.assertEqual(layer_norm_scale.shape, (input_dim,)) + self.assertEqual(layer_norm_bias.shape, (input_dim,)) + + total_params = np.prod(down_proj_weight.shape) + total_params += np.prod(down_proj_bias.shape) + total_params += np.prod(up_proj_weight.shape) + total_params += np.prod(up_proj_bias.shape) + total_params += np.prod(layer_norm_scale.shape) + total_params += np.prod(layer_norm_bias.shape) + + self.assertEqual(total_params, 33600) + + @parameterized.parameters([True, False]) + def test_training_vs_eval_mode(self, is_training: bool): + batch_size, seq_len, input_dim, bottleneck_dim = 4, 10, 128, 32 + + cfg = AudioModelAdapter.default_config().set( + input_dim=input_dim, + bottleneck_dim=bottleneck_dim, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + inputs = jax.random.normal(input_key, (batch_size, seq_len, input_dim)) + + outputs, _ = F( + layer, + inputs=(inputs,), + is_training=is_training, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(outputs.shape, inputs.shape) + + +class ASRModelAdapterTest(TestCase): + """Tests ASRModelAdapter.""" + + def test_encoder_adapter_only(self): + encoder_dim = 256 + encoder_bottleneck_dim = 64 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=encoder_bottleneck_dim, + adapt_encoder=True, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + 
layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) + + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_features.shape, encoder_features.shape) + + def test_decoder_adapter_only(self): + decoder_dim = 256 + decoder_bottleneck_dim = 64 + batch_size, seq_len = 4, 50 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=128, + encoder_bottleneck_dim=32, + decoder_dim=decoder_dim, + decoder_bottleneck_dim=decoder_bottleneck_dim, + adapt_encoder=False, + adapt_decoder=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + decoder_features = jax.random.normal(input_key, (batch_size, seq_len, decoder_dim)) + + adapted_features = layer.adapt_decoder_features( + decoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_features.shape, decoder_features.shape) + + def test_both_encoders_and_decoders(self): + encoder_dim, encoder_bottleneck_dim = 256, 64 + decoder_dim, decoder_bottleneck_dim = 256, 64 + batch_size, enc_seq_len, dec_seq_len = 4, 100, 50 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=encoder_bottleneck_dim, + decoder_dim=decoder_dim, + decoder_bottleneck_dim=decoder_bottleneck_dim, + adapt_encoder=True, + adapt_decoder=True, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key1, input_key2 = jax.random.split(prng_key, num=4) + layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key1, (batch_size, enc_seq_len, encoder_dim)) + decoder_features = jax.random.normal(input_key2, (batch_size, dec_seq_len, decoder_dim)) + + adapted_enc_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + adapted_dec_features = layer.adapt_decoder_features( + decoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_enc_features.shape, encoder_features.shape) + self.assertEqual(adapted_dec_features.shape, decoder_features.shape) + + def test_no_adaptation(self): + encoder_dim = 256 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=64, + adapt_encoder=False, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) + + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + assert_allclose(adapted_features, encoder_features) + + def test_direct_call_fallback(self): + encoder_dim = 256 + batch_size, seq_len = 4, 100 + + cfg = ASRModelAdapter.default_config().set( + encoder_dim=encoder_dim, + encoder_bottleneck_dim=64, + 
adapt_encoder=True, + adapt_decoder=False, + dtype=jnp.float32, + ) + layer = cfg.set(name="test").instantiate(parent=None) + + prng_key = jax.random.PRNGKey(123) + prng_key, init_key, input_key = jax.random.split(prng_key, num=3) + layer_params = layer.initialize_parameters_recursively(init_key) + encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim)) + + adapted_features = layer.adapt_encoder_features( + encoder_features, + is_training=True, + prng_key=prng_key, + state=layer_params, + ) + + self.assertEqual(adapted_features.shape, encoder_features.shape) diff --git a/axlearn/audio/aligner/ctc_aligner.py b/axlearn/audio/aligner/ctc_aligner.py index 1bf6acb58..8ae402053 100644 --- a/axlearn/audio/aligner/ctc_aligner.py +++ b/axlearn/audio/aligner/ctc_aligner.py @@ -15,6 +15,7 @@ """ + from dataclasses import asdict, dataclass from typing import Literal, NamedTuple, Optional, Tuple diff --git a/axlearn/audio/evaler_asr_test.py b/axlearn/audio/evaler_asr_test.py index e6bb72340..53aca811d 100644 --- a/axlearn/audio/evaler_asr_test.py +++ b/axlearn/audio/evaler_asr_test.py @@ -133,13 +133,13 @@ def _compute_metrics( if brevity_penalty: decode_kwargs["brevity_penalty"] = brevity_penalty - cfg: WordErrorRateMetricCalculator.Config = ( - WordErrorRateMetricCalculator.default_config().set( - vocab=config_for_class(seqio.SentencePieceVocabulary).set( - sentencepiece_model_file=vocab_file, - ), - model_method_kwargs=decode_kwargs, - ) + cfg: ( + WordErrorRateMetricCalculator.Config + ) = WordErrorRateMetricCalculator.default_config().set( + vocab=config_for_class(seqio.SentencePieceVocabulary).set( + sentencepiece_model_file=vocab_file, + ), + model_method_kwargs=decode_kwargs, ) calculator: WordErrorRateMetricCalculator = cfg.set(name="test-metric").instantiate( parent=None, model=model, model_param_partition_specs={} diff --git a/axlearn/cloud/common/bastion_test.py b/axlearn/cloud/common/bastion_test.py index b3d6255e3..5e14c9ce1 100644 --- a/axlearn/cloud/common/bastion_test.py +++ b/axlearn/cloud/common/bastion_test.py @@ -1690,10 +1690,9 @@ def test_sync_jobs_for_valid_pending_to_sudden_invalid_jobs(self): mock_validator_cfg = MockStatefulJobValidator.default_config() mock_append_to_job_history = mock.MagicMock() - with self._patch_bastion( - validator_cfg=mock_validator_cfg - ) as mock_bastion, mock.patch.object( - mock_bastion, "_append_to_job_history", mock_append_to_job_history + with ( + self._patch_bastion(validator_cfg=mock_validator_cfg) as mock_bastion, + mock.patch.object(mock_bastion, "_append_to_job_history", mock_append_to_job_history), ): os.makedirs(mock_bastion._active_dir, exist_ok=True) os.makedirs(_JOB_DIR, exist_ok=True) @@ -1802,10 +1801,9 @@ def test_sync_jobs_for_immediate_invalid_pending_jobs(self): mock_validator_cfg = MockAlwaysInvalidValidator.default_config() mock_append_to_job_history = mock.MagicMock() - with self._patch_bastion( - validator_cfg=mock_validator_cfg - ) as mock_bastion, mock.patch.object( - mock_bastion, "_append_to_job_history", mock_append_to_job_history + with ( + self._patch_bastion(validator_cfg=mock_validator_cfg) as mock_bastion, + mock.patch.object(mock_bastion, "_append_to_job_history", mock_append_to_job_history), ): os.makedirs(mock_bastion._active_dir, exist_ok=True) os.makedirs(_JOB_DIR, exist_ok=True) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index c405e8dfb..cfc68a33e 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -33,7 +33,6 @@ class 
_ServiceProtocol(enum.Enum): - """https://kubernetes.io/docs/reference/networking/service-protocols/""" TCP = "TCP" @@ -42,7 +41,6 @@ class _ServiceProtocol(enum.Enum): class _ServiceType(enum.Enum): - """https://cloud.google.com/kubernetes-engine/docs/concepts/service#types-of-services sss""" CLUSTER_IP = "ClusterIP" diff --git a/axlearn/cloud/gcp/k8s_service.py b/axlearn/cloud/gcp/k8s_service.py index b8e81ac60..a27c25046 100644 --- a/axlearn/cloud/gcp/k8s_service.py +++ b/axlearn/cloud/gcp/k8s_service.py @@ -1,4 +1,5 @@ -""" k8s service module.""" +"""k8s service module.""" + import copy import logging from typing import Any, Optional diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py index 1d1f47161..05688e09e 100644 --- a/axlearn/cloud/gcp/measurement_test.py +++ b/axlearn/cloud/gcp/measurement_test.py @@ -177,9 +177,10 @@ def test_record_event_context_manager_handles_runtime_error(self): recorder = GoodputRecorder(cfg) with mock.patch("jax.process_index", return_value=0): - with mock.patch( - "ml_goodput_measurement.goodput.GoodputRecorder" - ) as mock_recorder_cls, mock.patch.object(logging, "warning") as mock_warning: + with ( + mock.patch("ml_goodput_measurement.goodput.GoodputRecorder") as mock_recorder_cls, + mock.patch.object(logging, "warning") as mock_warning, + ): mock_instance = mock_recorder_cls.return_value def raise_runtime_error(*args, **kwargs): diff --git a/axlearn/cloud/gcp/pathways_utils.py b/axlearn/cloud/gcp/pathways_utils.py index 528ac543d..b4e1df10e 100644 --- a/axlearn/cloud/gcp/pathways_utils.py +++ b/axlearn/cloud/gcp/pathways_utils.py @@ -110,7 +110,7 @@ def get_pathways_tpu_version(gke_machine_type: str) -> str: def get_megascale_options( - xla_options: dict[str, Union[str, bool, int]] + xla_options: dict[str, Union[str, bool, int]], ) -> dict[str, Union[str, bool, int]]: """Filters XLA options for those pertaining to Megascale. @@ -125,7 +125,7 @@ def get_megascale_options( def get_xla_options( - xla_options: dict[str, Union[str, bool, int]] + xla_options: dict[str, Union[str, bool, int]], ) -> dict[str, Union[str, bool, int]]: """Filters XLA options for those starting with 'xla_'. @@ -962,9 +962,9 @@ def _build_head_container(self) -> dict: ], imagePullPolicy="Always", resources=resources, - ports=[dict(containerPort=self.config.target_port)] - if self.config.enable_service - else [], + ports=( + [dict(containerPort=self.config.target_port)] if self.config.enable_service else [] + ), ) def build_leader_pod(self) -> Nested[Any]: diff --git a/axlearn/cloud/gcp/tpu_health_check_test.py b/axlearn/cloud/gcp/tpu_health_check_test.py index 51772f6df..d2815fd8f 100644 --- a/axlearn/cloud/gcp/tpu_health_check_test.py +++ b/axlearn/cloud/gcp/tpu_health_check_test.py @@ -47,8 +47,11 @@ def test_parsing(self): def test_global_health_check(self): # On CPU CI, this should pass. 
- with mock.patch("os.kill") as mock_exit, mock.patch.dict( - os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + with ( + mock.patch("os.kill") as mock_exit, + mock.patch.dict( + os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + ), ): global_health_check("global=180", output_dir="") mock_exit.assert_not_called() @@ -63,10 +66,12 @@ def _check_failure_file(self, folder: str, keyword: str): self.fail("should not reach here") def test_global_health_check_timeout(self): - with mock.patch( - "os.kill" - ) as mock_exit, tempfile.TemporaryDirectory() as d, mock.patch.dict( - os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + with ( + mock.patch("os.kill") as mock_exit, + tempfile.TemporaryDirectory() as d, + mock.patch.dict( + os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + ), ): global_health_check("global=0.000001", output_dir=d) mock_exit.assert_called_once() @@ -79,10 +84,13 @@ def test_raises_with_no_megascale_env(self): pairwise_slice_health_check("pairwise=1", output_dir="") def test_global_health_check_failure(self): - with mock.patch("os.kill") as mock_exit, mock.patch( - f"{tpu_health_check_main.__name__}.main", lambda: False - ), tempfile.TemporaryDirectory() as d, mock.patch.dict( - os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + with ( + mock.patch("os.kill") as mock_exit, + mock.patch(f"{tpu_health_check_main.__name__}.main", lambda: False), + tempfile.TemporaryDirectory() as d, + mock.patch.dict( + os.environ, {"HOSTNAME": "h", "NODE_NAME": "n", "MEGASCALE_NUM_SLICES": "1"} + ), ): global_health_check("global=180", output_dir=d) mock_exit.assert_called_once() diff --git a/axlearn/common/adapter_torch.py b/axlearn/common/adapter_torch.py index b6b2620da..0faaec8f0 100644 --- a/axlearn/common/adapter_torch.py +++ b/axlearn/common/adapter_torch.py @@ -179,6 +179,7 @@ def _axlearn_weight_mapper(self, weight_name: str, weight: torch.Tensor) -> torc class LayerNorm(nn.LayerNorm, TorchModule): + # pylint: disable=useless-parent-delegation # maintains interface compatibility def __init__(self, *args, eps=1e-6, **kwargs): super().__init__(*args, eps=eps, **kwargs) diff --git a/axlearn/common/aot_compilation_test.py b/axlearn/common/aot_compilation_test.py index 4ea7f60ee..57c1e6c4d 100644 --- a/axlearn/common/aot_compilation_test.py +++ b/axlearn/common/aot_compilation_test.py @@ -1,4 +1,5 @@ """Tests aot_compilation utils.""" + from typing import cast from axlearn.common import test_utils diff --git a/axlearn/common/array_serialization_test.py b/axlearn/common/array_serialization_test.py index c086c7d5b..ecf240c8d 100644 --- a/axlearn/common/array_serialization_test.py +++ b/axlearn/common/array_serialization_test.py @@ -125,11 +125,13 @@ def transfer_to_host_patch(*args, **kwargs): return old_transfer(*args, **kwargs) d2h_future = array_serialization.futures.Future() - with mock.patch( - f"{array_serialization.__name__}.{_ts_open}", - ts_open_patch, - ), get_tensorstore_spec(arr) as spec, mock.patch( - f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch + with ( + mock.patch( + f"{array_serialization.__name__}.{_ts_open}", + ts_open_patch, + ), + get_tensorstore_spec(arr) as spec, + mock.patch(f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch), ): # Either RuntimeError(Array has been deleted with shape) or # ValueError(...Buffer has been deleted or donated...) may occur. 
@@ -151,11 +153,13 @@ def transfer_to_host_patch(*args, **kwargs): arr = self._create_partially_replicated_array(sharded) arr_host = jax.device_get(arr) d2h_future = array_serialization.futures.Future() - with mock.patch( - f"{array_serialization.__name__}.{_ts_open}", - ts_open_patch, - ), get_tensorstore_spec(arr) as spec, mock.patch( - f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch + with ( + mock.patch( + f"{array_serialization.__name__}.{_ts_open}", + ts_open_patch, + ), + get_tensorstore_spec(arr) as spec, + mock.patch(f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch), ): f = _CommitFuture( _run_serializer( @@ -185,10 +189,13 @@ async def ts_open_patch(*_, **__): raise RuntimeError("Test") d2h_future = array_serialization.futures.Future() - with mock.patch( - f"{array_serialization.__name__}.{_ts_open}", - ts_open_patch, - ), get_tensorstore_spec(arr) as spec: + with ( + mock.patch( + f"{array_serialization.__name__}.{_ts_open}", + ts_open_patch, + ), + get_tensorstore_spec(arr) as spec, + ): f = _CommitFuture( _run_serializer( [arr], [spec], [d2h_future], max_data_shard_degree=-1, shard_threshold_bytes=-1 @@ -202,10 +209,13 @@ def transfer_to_host_patch(*_): raise RuntimeError("Test") d2h_future = array_serialization.futures.Future() - with mock.patch( - f"{array_serialization.__name__}._transfer_to_host", - transfer_to_host_patch, - ), get_tensorstore_spec(arr) as spec: + with ( + mock.patch( + f"{array_serialization.__name__}._transfer_to_host", + transfer_to_host_patch, + ), + get_tensorstore_spec(arr) as spec, + ): f = _CommitFuture( _run_serializer( [arr], [spec], [d2h_future], max_data_shard_degree=-1, shard_threshold_bytes=-1 @@ -362,13 +372,15 @@ async def mock_ts_open(spec_arg, *args, **kwargs): return await original_ts_open(call_arg, *args, **kwargs) # Write the data to local files - with get_tensorstore_spec_for_deserialization(data) as ( - tensorstore_spec, - temp_path, - ), mock.patch( - f"{array_serialization.__name__}.{_ts_open}", new=mock_ts_open - ), mock.patch.dict( - "os.environ", {"JAX_PLATFORMS": jax_platforms, "ENABLE_GCS_GRPC": enable_gcs_grpc} + with ( + get_tensorstore_spec_for_deserialization(data) as ( + tensorstore_spec, + temp_path, + ), + mock.patch(f"{array_serialization.__name__}.{_ts_open}", new=mock_ts_open), + mock.patch.dict( + "os.environ", {"JAX_PLATFORMS": jax_platforms, "ENABLE_GCS_GRPC": enable_gcs_grpc} + ), ): manager.serialize( data, diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index f2162a3c9..b8c38ef90 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -474,9 +474,11 @@ def _get_spec(self, step: int, state: NestedTensor, ckpt_dir: str) -> Checkpoint spec.shardings.append( jax.sharding.NamedSharding( mesh, - jax.sharding.PartitionSpec() - if value.mesh_axes is None - else value.mesh_axes, + ( + jax.sharding.PartitionSpec() + if value.mesh_axes is None + else value.mesh_axes + ), ) ) elif isinstance(value, tf.data.Iterator): diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 689dc65ae..8cb4d2143 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -111,7 +111,7 @@ def _key_value_iterator_indices(block_mask_map: np.ndarray) -> Tuple[Tensor, Ten return jnp.asarray(index_offset), jnp.asarray(index_offset_size) -def _mha_forward_kernel( +def _mha_forward_kernel( # pylint: 
disable=too-many-positional-arguments q_ref, k_ref, v_ref, diff --git a/axlearn/common/inference_test.py b/axlearn/common/inference_test.py index d6d946ed3..20bdbf2be 100644 --- a/axlearn/common/inference_test.py +++ b/axlearn/common/inference_test.py @@ -779,10 +779,13 @@ def test_pipeline_summary_writer( mock_summary_writer = mock.Mock(return_value=None) - with mock.patch( - "axlearn.common.summary_writer.SummaryWriter.Config.instantiate", - mock.MagicMock(return_value=mock_summary_writer), - ), tempfile.TemporaryDirectory() as local_tmp_dir: + with ( + mock.patch( + "axlearn.common.summary_writer.SummaryWriter.Config.instantiate", + mock.MagicMock(return_value=mock_summary_writer), + ), + tempfile.TemporaryDirectory() as local_tmp_dir, + ): root_dir = local_tmp_dir if local_run else "gs://axlearn-public/testdata/inference_test" with set_data_dir(root_dir): prng_key = jax.random.PRNGKey(11) diff --git a/axlearn/common/input_lm.py b/axlearn/common/input_lm.py index e0f3d0ef4..f8c2b8fa0 100644 --- a/axlearn/common/input_lm.py +++ b/axlearn/common/input_lm.py @@ -628,7 +628,7 @@ def map_targets_out_of_class(example: dict[str, tf.Tensor]) -> dict[str, tf.Tens def _trim_and_pack_with_segments( - feature_lengths: dict[str, int] + feature_lengths: dict[str, int], ) -> input_tf_data.DatasetToDatasetFn: """Trim and pack inputs, injecting `*_segment_ids` and `*_positions`. @@ -683,7 +683,7 @@ def restore_intermediate_zeros(example: dict[str, tf.Tensor]): def _trim_and_pad_with_segments( - feature_lengths: dict[str, int] + feature_lengths: dict[str, int], ) -> input_tf_data.DatasetToDatasetFn: """Trim and pad inputs, injecting `*_segment_ids`, `*_positions`. diff --git a/axlearn/common/input_t5.py b/axlearn/common/input_t5.py index b47d13c3a..c55dc5329 100644 --- a/axlearn/common/input_t5.py +++ b/axlearn/common/input_t5.py @@ -232,7 +232,7 @@ def split_tokens( @seqio.map_over_dataset def split_tokens_example( - x: dict[str, tf.Tensor] + x: dict[str, tf.Tensor], ) -> tuple[dict[str, tf.Tensor], dict[str, tf.Tensor]]: """Split one token sequence into multiple sequences.""" tokens = x[input_key] diff --git a/axlearn/common/input_text.py b/axlearn/common/input_text.py index f2827b964..4d07acc30 100644 --- a/axlearn/common/input_text.py +++ b/axlearn/common/input_text.py @@ -179,7 +179,7 @@ def add_token_type_ids( input_key = [input_key] def example_fn( - example: dict[str, Union[tf.Tensor, tf.RaggedTensor]] + example: dict[str, Union[tf.Tensor, tf.RaggedTensor]], ) -> dict[str, Union[tf.Tensor, tf.RaggedTensor]]: token_type_ids = [] for i, key in enumerate(input_key): diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index 8914c6663..90220bf8d 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -1348,7 +1348,7 @@ def test_embed_with_constant_scale_validation(self): dim, num_embeddings, rng, - scale=Embedding.Scale.CONSTANT + scale=Embedding.Scale.CONSTANT, # Missing scale_constant ) self.assertIn("scale_constant must be specified", str(cm.exception)) diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index 2fa4baac9..1b3f8e9b0 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -2033,9 +2033,9 @@ def _update2(u: Tensor, param: OptParam, weight_decay_scale: float = 1.0): params, per_param_scale=weight_decay_per_param_scale ) updates2 = jax.tree.map( - lambda u, p, wds: None - if u is None - else _update2(u, param=p, weight_decay_scale=wds), + lambda u, p, wds: ( + None if u is None else _update2(u, 
param=p, weight_decay_scale=wds) + ), updates, params, weight_decay_scales, diff --git a/axlearn/common/optimizers_test.py b/axlearn/common/optimizers_test.py index 6fe408239..e72d41f17 100644 --- a/axlearn/common/optimizers_test.py +++ b/axlearn/common/optimizers_test.py @@ -1310,8 +1310,8 @@ def test_param_ema(self, decay, dtype): self.assertEqual(new_state.count, 1) if isinstance(decay, float): - ema_fn = ( - lambda p: (1 - decay) * p.value + ema_fn = lambda p: ( + (1 - decay) * p.value if jnp.issubdtype(p.value.dtype, jnp.floating) else p.value ) diff --git a/axlearn/common/quantized_dot_general/utils.py b/axlearn/common/quantized_dot_general/utils.py index 292e42bf9..fa6ceeb7c 100644 --- a/axlearn/common/quantized_dot_general/utils.py +++ b/axlearn/common/quantized_dot_general/utils.py @@ -14,9 +14,7 @@ # Copyright 2024 The AQT Authors. # Licensed under the Apache License, Version 2.0 (the "License"). -"""QuantizedDotGeneral Utilities. Hosts default quantization configuration. - -""" +"""QuantizedDotGeneral Utilities. Hosts default quantization configuration.""" import functools import jax diff --git a/axlearn/common/rattention/kernels/linear_attention_kernels.py b/axlearn/common/rattention/kernels/linear_attention_kernels.py index d3542cc5f..2bb5f3381 100644 --- a/axlearn/common/rattention/kernels/linear_attention_kernels.py +++ b/axlearn/common/rattention/kernels/linear_attention_kernels.py @@ -1,6 +1,6 @@ # Copyright © 2025 Apple Inc. -""" Pallas kernels for Linear Attention (LA) specialized for sliding window attention. +"""Pallas kernels for Linear Attention (LA) specialized for sliding window attention. A specialized feature map from the following reference is used to support sliding window attention. The chunking strategy is similar to the one used in ssm_kernels/ssd_kernels.py. diff --git a/axlearn/common/rattention/rattention.py b/axlearn/common/rattention/rattention.py index 65f09b97d..669ab3e4a 100644 --- a/axlearn/common/rattention/rattention.py +++ b/axlearn/common/rattention/rattention.py @@ -1,4 +1,5 @@ """Implementation of RAttention with residual linear attention.""" + from functools import partial from typing import Callable, Optional, Union diff --git a/axlearn/common/rattention/rattention_test.py b/axlearn/common/rattention/rattention_test.py index 0b587fcc5..34f956d91 100644 --- a/axlearn/common/rattention/rattention_test.py +++ b/axlearn/common/rattention/rattention_test.py @@ -1,4 +1,5 @@ """Tests for RAttention and ResidualLinearAttention.""" + import copy import jax diff --git a/axlearn/common/rattention/utils.py b/axlearn/common/rattention/utils.py index 24559dca9..39c889721 100644 --- a/axlearn/common/rattention/utils.py +++ b/axlearn/common/rattention/utils.py @@ -1,4 +1,5 @@ """Utilities for RAttention.""" + from typing import Optional import jax diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index b1c086b95..c33fa96b3 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -409,7 +409,7 @@ def forward( # We need to jit a function before shard_mapping it. @jax.jit - def jit_mamba_scan(x, a, b, c, delta, d): + def jit_mamba_scan(x, a, b, c, delta, d): # pylint: disable=too-many-positional-arguments y = compute_mamba_scan( # [batch_size, seq_len, inner_dim] x, a, @@ -459,7 +459,8 @@ def default_mamba_dim_to_partition_specs( the Pallas-based Mamba implementation. The inner dimension is sharded over the default tensor-parallel axis name if present, - and the the batch is sharded over the remainder of the axes. 
+ the sequence dimension is sharded over the sequence-parallel axis name if present, + and the batch is sharded over the remainder of the axes. Args: mesh_axis_names: Mesh axis names. @@ -467,13 +468,14 @@ def default_mamba_dim_to_partition_specs( Returns: A dictionary keyed by Mamba tensor dims with partition spec values. """ - batch_axis_names = tuple(el for el in mesh_axis_names if el != "model") + batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "seq")) tp_axis_name = "model" if "model" in mesh_axis_names else None + seq_axis_name = "seq" if "seq" in mesh_axis_names else None - # TODO(swiseman): support sequence parallelism. - x_spec = PartitionSpec(batch_axis_names, None, tp_axis_name) + # Support sequence parallelism by sharding the sequence dimension (middle dim in btd/bts). + x_spec = PartitionSpec(batch_axis_names, seq_axis_name, tp_axis_name) a_spec = PartitionSpec(None, tp_axis_name) - b_spec = PartitionSpec(batch_axis_names, None, None) + b_spec = PartitionSpec(batch_axis_names, seq_axis_name, None) d_spec = PartitionSpec(None, tp_axis_name) partition_specs = {"btd": x_spec, "sd": a_spec, "bts": b_spec, "1d": d_spec} return partition_specs @@ -481,12 +483,13 @@ def default_mamba_dim_to_partition_specs( def default_output_partition_spec( mesh_axis_names: Sequence[str], -) -> dict[str, PartitionSpec]: +) -> PartitionSpec: """Builds a default output partition spec for the shard_mapped Pallas-based Mamba implementation. The inner dimension is sharded over the default tensor-parallel axis name if present, - and the the batch is sharded over the remainder of the axes. + the sequence dimension is sharded over the sequence-parallel axis name if present, + and the batch is sharded over the remainder of the axes. Args: mesh_axis_names: Mesh axis names. @@ -494,10 +497,11 @@ def default_output_partition_spec( Returns: A PartitionSpec. """ - batch_axis_names = tuple(el for el in mesh_axis_names if el != "model") + batch_axis_names = tuple(el for el in mesh_axis_names if el not in ("model", "seq")) tp_axis_name = "model" if "model" in mesh_axis_names else None - # TODO(swiseman): support sequence parallelism. - return PartitionSpec(batch_axis_names, None, tp_axis_name) + seq_axis_name = "seq" if "seq" in mesh_axis_names else None + # Support sequence parallelism by sharding the sequence dimension. + return PartitionSpec(batch_axis_names, seq_axis_name, tp_axis_name) def _at_least_float32(x: Tensor) -> Tensor: @@ -560,10 +564,10 @@ class Config(BaseLayer.Config): # The recurrence implementation to use for full-sequence inputs. mamba_recurrence: BaseMambaRecurrence = LinearScanMambaRecurrence.default_config() # The recurrence implementation to use for inference. - inference_mamba_recurrence: BaseMambaRecurrence = ( - LinearScanMambaRecurrence.default_config().set( - output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES - ) + inference_mamba_recurrence: ( + BaseMambaRecurrence + ) = LinearScanMambaRecurrence.default_config().set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES ) class MambaOutput(NamedTuple): @@ -1768,10 +1772,10 @@ class Config(BaseLayer.Config): # The recurrence implementation to use for full-sequence inputs. ssd_recurrence: BaseSSDRecurrence = PallasSSDRecurrence.default_config() # The recurrence implementation to use for inference. 
- inference_mamba_recurrence: BaseSSDRecurrence = ( - LinearScanSSDRecurrence.default_config().set( - output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES - ) + inference_mamba_recurrence: ( + BaseSSDRecurrence + ) = LinearScanSSDRecurrence.default_config().set( + output_mode=MambaRecurrenceOutputMode.OUTPUTS_AND_STATES ) class Mamba2Output(NamedTuple): diff --git a/axlearn/common/ssm_kernels/ssd_kernels.py b/axlearn/common/ssm_kernels/ssd_kernels.py index 9fa3bb244..3f96f275b 100644 --- a/axlearn/common/ssm_kernels/ssd_kernels.py +++ b/axlearn/common/ssm_kernels/ssd_kernels.py @@ -1,6 +1,6 @@ # Copyright © 2024 Apple Inc. -""" Pallas kernels for Mamba2 +"""Pallas kernels for Mamba2 High-level idea: this kernel implements a two-level chunking algorithm to balance memory consumption and running speed. Intuitively, we store chunk-level diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 0bef3b0dd..99f905699 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -15,6 +15,7 @@ """Tests Mamba/Mamba2 and Jamba implementations.""" +# pytype: disable=module-attr import math from typing import Optional @@ -23,6 +24,7 @@ import numpy as np import pytest import torch +import torch.nn.functional as F_torch from absl.testing import parameterized from jax._src.mesh import ResourceEnv, thread_resources from jax.experimental import mesh_utils @@ -47,6 +49,8 @@ RepeatedSSMLayer, StackedMixedSSMTransformerLayer, StackedSSMLayer, + default_mamba_dim_to_partition_specs, + default_output_partition_spec, ) from axlearn.common.ssm_kernels.ssd_kernels import ssd from axlearn.common.test_utils import TestCase, assert_allclose, set_threefry_partitionable @@ -233,6 +237,7 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None): hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: + # pylint: disable=not-callable # pad is callable, false positive conv_state = torch.nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) @@ -253,7 +258,8 @@ def slow_forward(self, input_states, cache_params: Optional[MambaCache]=None): ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] - discrete_time_step = torch.nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] + # pylint: disable=not-callable # softplus is callable, false positive + discrete_time_step = F_torch.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] # 3.b. 
Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] @@ -1240,7 +1246,7 @@ def teardown_class(cls): dtype=[jnp.float32, jnp.bfloat16], ) def forward( - self, input_dim: int, state_dim: int, num_heads: int, num_groups: int, dtype: jnp.dtype + self, input_dim: int, *, state_dim: int, num_heads: int, num_groups: int, dtype: jnp.dtype ): mamba2block_cfg = JambaMamba2Block.default_config().set( name="test", @@ -1283,6 +1289,7 @@ def extend_step( self, batch_size: int, input_dim: int, + *, seq_len: int, state_dim: int, num_heads: int, @@ -1352,6 +1359,7 @@ def test_prefill_states( self, batch_size: int, input_dim: int, + *, seq_len: int, state_dim: int, num_heads: int, @@ -1549,3 +1557,68 @@ def _j2t(param): torch_output_np = torch_output.cpu().detach().numpy() assert_allclose(torch_output_np, jax_output_np, atol=1e-2, rtol=1e-2) + + +class PartitionSpecTest(TestCase): + """Tests for Mamba partition spec helper functions.""" + + def test_default_mamba_dim_to_partition_specs_without_seq(self): + """Test partition specs without sequence parallelism.""" + mesh_axis_names = ("data", "fsdp", "model") + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # batch should be sharded over data and fsdp + # sequence should not be sharded (None) + # inner dim should be sharded over model + self.assertEqual(specs["btd"], PartitionSpec(("data", "fsdp"), None, "model")) + self.assertEqual(specs["sd"], PartitionSpec(None, "model")) + self.assertEqual(specs["bts"], PartitionSpec(("data", "fsdp"), None, None)) + self.assertEqual(specs["1d"], PartitionSpec(None, "model")) + + def test_default_mamba_dim_to_partition_specs_with_seq(self): + """Test partition specs with sequence parallelism enabled.""" + mesh_axis_names = ("data", "fsdp", "seq", "model") + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # batch should be sharded over data and fsdp (not seq or model) + # sequence should be sharded over seq + # inner dim should be sharded over model + self.assertEqual(specs["btd"], PartitionSpec(("data", "fsdp"), "seq", "model")) + self.assertEqual(specs["sd"], PartitionSpec(None, "model")) + self.assertEqual(specs["bts"], PartitionSpec(("data", "fsdp"), "seq", None)) + self.assertEqual(specs["1d"], PartitionSpec(None, "model")) + + def test_default_mamba_dim_to_partition_specs_only_seq(self): + """Test partition specs with only sequence axis.""" + mesh_axis_names = ("seq",) + specs = default_mamba_dim_to_partition_specs(mesh_axis_names) + + # Only seq parallelism, no batch or model sharding + self.assertEqual(specs["btd"], PartitionSpec((), "seq", None)) + self.assertEqual(specs["sd"], PartitionSpec(None, None)) + self.assertEqual(specs["bts"], PartitionSpec((), "seq", None)) + self.assertEqual(specs["1d"], PartitionSpec(None, None)) + + def test_default_output_partition_spec_without_seq(self): + """Test output partition spec without sequence parallelism.""" + mesh_axis_names = ("data", "fsdp", "model") + spec = default_output_partition_spec(mesh_axis_names) + + # batch over data and fsdp, no seq sharding, inner dim over model + self.assertEqual(spec, PartitionSpec(("data", "fsdp"), None, "model")) + + def test_default_output_partition_spec_with_seq(self): + """Test output partition spec with sequence parallelism enabled.""" + mesh_axis_names = ("data", "fsdp", "seq", "model") + spec = default_output_partition_spec(mesh_axis_names) + + # batch over data and fsdp, sequence over seq, 
inner dim over model + self.assertEqual(spec, PartitionSpec(("data", "fsdp"), "seq", "model")) + + def test_default_output_partition_spec_minimal(self): + """Test output partition spec with minimal mesh.""" + mesh_axis_names = ("model",) + spec = default_output_partition_spec(mesh_axis_names) + + # No batch or seq sharding, only model sharding + self.assertEqual(spec, PartitionSpec((), None, "model")) diff --git a/axlearn/vision/coco_utils.py b/axlearn/vision/coco_utils.py index e86395eb7..1d9b04540 100644 --- a/axlearn/vision/coco_utils.py +++ b/axlearn/vision/coco_utils.py @@ -65,6 +65,7 @@ def __init__(self, eval_type="box", annotation_file=None, gt_dataset=None): self.dataset = gt_dataset self.createIndex() + # pylint: disable=invalid-name # matches COCO API naming convention def loadRes(self, predictions: list[dict[str, Any]]) -> coco.COCO: """Loads result file and return a result api object.
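Usage note (not part of the diff): the `adapter_test.py` hunks above exercise the new `AudioModelAdapter` through axlearn's functional API; the sketch below condenses that pattern into a standalone example of adapting frozen encoder features. The feature dimensions, PRNG seed, and layer name are illustrative assumptions, not values taken from the change, and `ASRModelAdapter` wraps the same layer for encoder/decoder features via `adapt_encoder_features` / `adapt_decoder_features`.

# Minimal sketch, mirroring the forward-pass pattern used in adapter_test.py.
import jax
import jax.numpy as jnp

from axlearn.audio.adapter import AudioModelAdapter
from axlearn.common.module import functional as F

cfg = AudioModelAdapter.default_config().set(
    name="adapter",
    input_dim=512,       # assumed: must match the frozen audio model's feature dim
    bottleneck_dim=64,   # small bottleneck keeps the trainable parameter count low
    use_layer_norm=True,
    residual=True,
)
layer = cfg.instantiate(parent=None)

prng_key = jax.random.PRNGKey(0)
prng_key, init_key, input_key = jax.random.split(prng_key, num=3)
params = layer.initialize_parameters_recursively(init_key)

# [batch, seq_len, input_dim] features, e.g. outputs of a frozen ASR encoder.
features = jax.random.normal(input_key, (2, 16, 512), dtype=jnp.float32)
outputs, _ = F(
    layer,
    inputs=(features,),
    is_training=True,
    prng_key=prng_key,
    state=params,
)
assert outputs.shape == features.shape  # adapter preserves the input shape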