From d208f6359ec972e94a4709572d5023a50d2b60f7 Mon Sep 17 00:00:00 2001 From: Haoshuo Huang Date: Fri, 25 Jul 2025 11:48:26 -0700 Subject: [PATCH 1/4] minimal example --- axlearn/common/input_grain.py | 142 +++++++++++++++++- axlearn/common/trainer_config_modifier.py | 114 +++++++++++++- .../common/trainer_config_modifier_test.py | 97 ++++++++++++ axlearn/experiments/text/gpt/fuji.py | 48 ++++++ 4 files changed, 394 insertions(+), 7 deletions(-) diff --git a/axlearn/common/input_grain.py b/axlearn/common/input_grain.py index 10f58a853..5b4d4b1f9 100644 --- a/axlearn/common/input_grain.py +++ b/axlearn/common/input_grain.py @@ -30,6 +30,9 @@ * `repeat` with `num_repeat=None` will produce datasets with size `sys.maxsize`. """ + +import os +import json import sys from dataclasses import dataclass from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, Union, runtime_checkable @@ -45,12 +48,14 @@ from jax.experimental import multihost_utils from axlearn.common import input_base, utils +from axlearn.common import file_system as fs from axlearn.common.config import ( REQUIRED, ConfigOr, Required, config_class, config_for_class, + config_for_function, maybe_instantiate, ) from axlearn.common.module import Module @@ -69,14 +74,12 @@ class RaggedTensor(list): @runtime_checkable class _CallableTransform(Protocol): - def __call__(self, example: Any) -> Any: - ... + def __call__(self, example: Any) -> Any: ... @runtime_checkable class _RandomCallableTransform(Protocol): - def __call__(self, example: Any, rng: np.random.Generator) -> Any: - ... + def __call__(self, example: Any, rng: np.random.Generator) -> Any: ... # Grain supports a set of predefined transformations (e.g. grain.MapTransform), as well as callables @@ -125,8 +128,7 @@ def __post_init__(self): class BuildDatasetFn(Protocol): """A function to create a grain data source.""" - def __call__(self, dispatch_config: DispatchConfig) -> Dataset: - ... + def __call__(self, dispatch_config: DispatchConfig) -> Dataset: ... def _copy_tree(x: _T) -> _T: @@ -762,3 +764,131 @@ def shape_dtype(x): return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype) return jax.tree.map(shape_dtype, example) + + +def mixture_train_input_source( + *, + is_training: bool, + vocab_cfg: ConfigOr, + preprocessor: Union[ConfigOr, list[ConfigOr]], + data_mixture_components: Union[ConfigOr, list], + max_sequence_length: int, + replace_newlines_with: str = "", + fake_input_source_cfg: Optional[ConfigOr] = None, + seed: Optional[int] = 42, +) -> BuildDatasetFn: + """Build mixture training input source for decoder-only LM model using grain. + + Mixture sampling happens after input processing but before batching, meaning that each batch + example will only contain tokens from a single source. + + Args: + is_training: A boolean indicating that inputs will be used for training. + vocab_cfg: Config to instantiate the seqio vocab. + preprocessor: A single or a list of lm text preprocessor config(s). When + used as a list, each preprocessor must correspond to one data source; + when used as a single config, it will be broadcast for all data sources. + data_mixture_components: List of DataMixtureComponent(s). + max_sequence_length: Maximum sequence length of an example. + replace_newlines_with: Value to replace newlines with in the text. + fake_input_source_cfg: A config that instantiates to a BuildDatasetFn for the input source + used during unittest. + seed: Seed for any downstream transformations (e.g. `shuffle` or `random_map`). + + Returns: + A BuildDatasetFn that mixes the given list of DataMixtureComponent(s). + """ + from axlearn.common.config import maybe_instantiate, maybe_set_config + + data_mixture_components = maybe_instantiate(data_mixture_components) + + def build_dataset_fn( + dispatch_config: DispatchConfig, + *, + is_training: bool, + vocab_cfg: ConfigOr, + preprocessor: Union[ConfigOr, list[ConfigOr]], + data_mixture_components: Union[ConfigOr, list], + max_sequence_length: int, + replace_newlines_with: str = "", + seed: Optional[int] = 42, + ) -> Dataset: + sources = [] + weights = [] + + for component in data_mixture_components: + dataset_name = component.name.replace(":", "/") + + # Construct ArrayRecord paths + arrayrecord_dataset_dir = os.path.join( + "/tmp/tensorflow_datasets/array_record", dataset_name + ) + + # Use fs.listdir to list all files in the directory + all_files = fs.listdir(arrayrecord_dataset_dir) + + # Filter for arrayrecord files + arrayrecord_files = [ + os.path.join(arrayrecord_dataset_dir, f) + for f in all_files + if f.endswith(".arrayrecord") + ] + + # Create ArrayRecord dataset + source_ds = ( + array_record_dataset(paths=arrayrecord_files, seed=seed).shuffle().repeat() + ) + source_ds = shard_dataset(source_ds, dispatch_config) + # + features_json = os.path.join(arrayrecord_dataset_dir, "features.json") + # pylint: disable-next=import-outside-toplevel + import tensorflow_datasets as tfds + + logging.info( + "Found %s; will assume tfds features and deserialize accordingly.", features_json + ) + with fs.open(features_json) as f: + features_dict = tfds.features.FeaturesDict.from_json(json.load(f)) + source_ds = source_ds.map(features_dict.deserialize_example_np) + + + # Apply preprocessing + def _set_config_for_preprocessor(p: ConfigOr) -> ConfigOr: + return maybe_set_config( + p, + vocab_cfg=vocab_cfg, + max_sequence_length=max_sequence_length, + replace_newlines_with=replace_newlines_with, + ) + + if isinstance(preprocessor, list): + assert len(preprocessor) == len(data_mixture_components) + processor_cfg = _set_config_for_preprocessor(preprocessor[len(sources)]) + else: + processor_cfg = _set_config_for_preprocessor(preprocessor) + + # Apply processor to the source dataset + processor_fn = maybe_instantiate(processor_cfg) + source_ds = source_ds.map(processor_fn) + + # Repeat the dataset for mixing + source_ds = source_ds.repeat() + + sources.append(source_ds) + weights.append(component.weight) + + # Mix the datasets + mixed_ds = sample_from_datasets(sources=sources, weights=weights) + + # Shard the mixed dataset + return mixed_ds + + return config_for_function(build_dataset_fn).set( + is_training=is_training, + vocab_cfg=vocab_cfg, + preprocessor=preprocessor, + data_mixture_components=data_mixture_components, + max_sequence_length=max_sequence_length, + replace_newlines_with=replace_newlines_with, + seed=seed, + ) diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py index b24f859f6..fe08d608d 100644 --- a/axlearn/common/trainer_config_modifier.py +++ b/axlearn/common/trainer_config_modifier.py @@ -2,7 +2,10 @@ """Defines trainer config modifiers, which will be used in model definitions.""" -from typing import Dict, Sequence, Union +import json +import logging +import os +from typing import Dict, Optional, Sequence, Union from axlearn.common import config from axlearn.common.base_layer import RematSpec @@ -13,6 +16,7 @@ Configurable, Required, config_class, + config_for_function, maybe_instantiate, ) from axlearn.common.gradient_accumulation import with_minibatch_steps @@ -23,6 +27,7 @@ QuantizedDotGeneral, get_all_fp8_param_names, ) +from axlearn.common import file_system as fs from axlearn.common.trainer import SpmdTrainer from axlearn.common.update_transformation import OverrideInplaceUpdateTransformation from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec @@ -319,3 +324,110 @@ def enter_fn(_, value, default_kv): update_cfg.transformation = cfg.learner.optimizer cfg.learner.optimizer = update_cfg return cfg + + +class GrainConfigModifier(ConfigModifier): + """Converts tf.data input pipelines to grain input pipelines.""" + + @config_class + class Config(ConfigModifier.Config): + """Configure GrainConfigModifier. + + TODO: Supports evaluation pipeline using grain. + + Attributes: + convert_training_input: Whether to convert the training input pipeline to grain. + grain_source_builder: Optional grain source builder function to use. + If None, attempts to automatically convert from tf.data sources. + """ + + convert_training_input: bool = True + grain_source_builder: Optional[ConfigOr[Configurable]] = None + + def __init__(self, cfg: Config): + super().__init__(cfg) + cfg = self.config + self._convert_training_input = cfg.convert_training_input + self._grain_source_builder = cfg.grain_source_builder + + def _convert_tf_data_to_grain_source( + self, tf_data_config: ConfigOr[Configurable] + ) -> ConfigOr[Configurable]: + """Converts a tf.data source config to a grain source config. + + Args: + tf_data_config: The tf.data source configuration. + + Returns: + A grain source configuration. + """ + # Import grain modules here to avoid circular imports + from axlearn.common import input_grain, input_grain_lm + import grain.python as grain + + # Extract data mixture components from tf_data_config + components = tf_data_config.data_mixture_components + # Extract other relevant config parameters from tf_data_config, with fallbacks + vocab_cfg = tf_data_config.vocab_cfg + max_sequence_length = tf_data_config.max_sequence_length + preprocessor = config_for_function(input_grain_lm.text_to_lm_training_input).set( + vocab=vocab_cfg, + max_len=max_sequence_length, + max_padding_fraction=tf_data_config.preprocessor.max_padding_fraction, + window_size=tf_data_config.preprocessor.window_size, + read_options=grain.ReadOptions(num_threads=2, prefetch_buffer_size=16), + ) + max_sequence_length = tf_data_config.max_sequence_length + replace_newlines_with = tf_data_config.replace_newlines_with + + # Use the existing mixture_train_input_source function which already handles + # GCS path conversion and fs.listdir operations + return input_grain.mixture_train_input_source( + is_training=True, + vocab_cfg=vocab_cfg, + preprocessor=preprocessor, + data_mixture_components=components, + max_sequence_length=max_sequence_length, + replace_newlines_with=replace_newlines_with, + seed=42, + ) + + def _convert_input_to_grain(self, input_config: Configurable.Config) -> Configurable.Config: + """Converts a tf.data Input config to a grain Input config. + + Args: + input_config: The tf.data Input configuration. + + Returns: + A grain Input configuration. + """ + # Import grain input module + from axlearn.common import input_grain + + # Create new grain input config + grain_input_config = input_grain.Input.default_config() + + # Convert the source + if self._grain_source_builder is not None: + grain_input_config.source = self._grain_source_builder + else: + assert hasattr(input_config, "source") + # Attempt automatic conversion + grain_input_config.source = self._convert_tf_data_to_grain_source(input_config.source) + + return grain_input_config + + def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config: + """Converts tf.data input pipelines to grain input pipelines. + + Args: + cfg: The trainer config to be modified. + + Returns: + The modified trainer config with grain input pipelines. + """ + # Convert training input if requested + if self._convert_training_input and hasattr(cfg, "input"): + cfg.input = self._convert_input_to_grain(cfg.input) + + return cfg diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py index 15305f114..25fa149cb 100644 --- a/axlearn/common/trainer_config_modifier_test.py +++ b/axlearn/common/trainer_config_modifier_test.py @@ -16,6 +16,7 @@ ChainConfigModifier, FP8ConfigModifier, GradientAccumulationModifier, + GrainConfigModifier, MeshShapeModifier, ModuleConfigModifier, OverrideInplaceUpdateTransformation, @@ -209,5 +210,101 @@ def test_fp8_config_modifier(self, use_config_fn): self.assertEqual(cfg.model.linear.quantized_dot_general.fp8_amax_history_length, 1) +class GrainConfigModifierTest(test_utils.TestCase): + def test_convert_tf_data_to_grain_source(self): + """Test conversion of tf.data config to grain source using mixture_train_input_source.""" + # Mock data mixture components + class MockComponent: + def __init__(self, name, weight=1.0): + self.name = name + self.weight = weight + + class MockConfig: + def __init__(self, components, vocab_cfg=None, preprocessor=None, max_sequence_length=1024, seed=123): + self.data_mixture_components = components + self.vocab_cfg = vocab_cfg + self.preprocessor = preprocessor + self.max_sequence_length = max_sequence_length + self.replace_newlines_with = "" + self.seed = seed + + # Test that the conversion leverages mixture_train_input_source + components = [ + MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/c4/en/3.0.1"), + MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/pile/train"), + ] + + # Create a mock vocab and preprocessor config + from axlearn.common.config import config_for_function + mock_vocab_cfg = config_for_function(lambda: "mock_vocab") + mock_preprocessor_cfg = config_for_function(lambda x: x) + + tf_data_config = MockConfig( + components=components, + vocab_cfg=mock_vocab_cfg, + preprocessor=mock_preprocessor_cfg, + max_sequence_length=1024, + seed=123 + ) + + modifier = GrainConfigModifier.default_config().instantiate() + + # Mock the mixture_train_input_source function + import unittest.mock + with unittest.mock.patch('axlearn.common.input_grain.mixture_train_input_source') as mock_mixture: + mock_mixture.return_value = lambda dispatch_config: None # Mock return value + + grain_source = modifier._convert_tf_data_to_grain_source(tf_data_config) + + # Should call mixture_train_input_source with the components + mock_mixture.assert_called_once() + call_args = mock_mixture.call_args + + # Verify the call arguments use values from tf_data_config + self.assertTrue(call_args[1]['is_training']) + self.assertEqual(call_args[1]['data_mixture_components'], components) + self.assertEqual(call_args[1]['vocab_cfg'], mock_vocab_cfg) + self.assertEqual(call_args[1]['preprocessor'], mock_preprocessor_cfg) + self.assertEqual(call_args[1]['max_sequence_length'], 1024) + self.assertEqual(call_args[1]['replace_newlines_with'], "") + self.assertEqual(call_args[1]['seed'], 123) + + def test_convert_tf_data_to_grain_source_with_defaults(self): + """Test conversion with missing config values uses defaults.""" + # Mock data mixture components + class MockComponent: + def __init__(self, name, weight=1.0): + self.name = name + self.weight = weight + + class MockConfig: + def __init__(self, components): + self.data_mixture_components = components + # Missing other config attributes to test defaults + + components = [MockComponent("test_dataset")] + tf_data_config = MockConfig(components) + + modifier = GrainConfigModifier.default_config().instantiate() + + # Mock the mixture_train_input_source function + import unittest.mock + with unittest.mock.patch('axlearn.common.input_grain.mixture_train_input_source') as mock_mixture: + mock_mixture.return_value = lambda dispatch_config: None # Mock return value + + grain_source = modifier._convert_tf_data_to_grain_source(tf_data_config) + + # Should call mixture_train_input_source with default values + mock_mixture.assert_called_once() + call_args = mock_mixture.call_args + + # Verify the call arguments use default values + self.assertTrue(call_args[1]['is_training']) + self.assertEqual(call_args[1]['data_mixture_components'], components) + self.assertEqual(call_args[1]['max_sequence_length'], 512) # Default + self.assertEqual(call_args[1]['replace_newlines_with'], "") # Default + self.assertEqual(call_args[1]['seed'], 42) # Default + + if __name__ == "__main__": absltest.main() diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 1e60849a8..6c2d8dcdd 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -43,6 +43,7 @@ ChainConfigModifier, FP8ConfigModifier, GradientAccumulationModifier, + GrainConfigModifier, MeshShapeModifier, ModuleConfigModifier, PartitionSpecModifier, @@ -1082,4 +1083,51 @@ def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config: ) config_map[f"{config_name}-fp8-single-host"] = make_single_host_fp8_config_func + if model_size in ("1B", "3B", "7B", "8B"): + + def make_grain_config(base_config_name: str) -> SpmdTrainer.Config: + """Make a grain input processor variant of the base config. + + This configuration uses the grain input processing framework for + improved data loading and preprocessing performance. + + Args: + base_config_name: The base config name. + + Returns: + A trainer config that uses grain input processing. + """ + + # pytype: disable=annotation-type-mismatch + cfg: SpmdTrainer.Config = config_map[base_config_name]().clone() + # pytype: enable=annotation-type-mismatch + + # Apply grain config modifier to convert tf.data to Grain + grain_modifier = GrainConfigModifier.default_config().set( + convert_training_input=True, + ) + cfg = grain_modifier.instantiate()(cfg) + + # Configure grain-specific input processing settings + # pylint: disable=cell-var-from-loop + # Adjust batch size for grain processing if needed + for evaler in cfg.evalers.values(): + if hasattr(evaler.input, 'input_dispatcher'): + evaler.input.input_dispatcher.global_logical_batch_size //= ( + 64 if version in (Version.V3, Version.V3_TIKTOKEN) else 16 + ) + # pylint: enable=cell-var-from-loop + return cfg + + # Make grain config + make_grain_config_func = functools.partial(make_grain_config, config_name) + config_map[f"{config_name}-grain"] = make_grain_config_func + + # Make grain configs for FP8 + if f"{config_name}-fp8" in config_map: + make_grain_fp8_config_func = functools.partial( + make_grain_config, f"{config_name}-fp8" + ) + config_map[f"{config_name}-fp8-grain"] = make_grain_fp8_config_func + return config_map From 6f7df3a9df6a8189422e287420c90275c05a7bf3 Mon Sep 17 00:00:00 2001 From: Haoshuo Huang Date: Mon, 28 Jul 2025 13:32:05 -0700 Subject: [PATCH 2/4] A couple of minor fixes along with benchmarking script --- axlearn/common/input_grain.py | 71 +--- axlearn/common/trainer_config_modifier.py | 32 +- .../common/trainer_config_modifier_test.py | 310 ------------------ axlearn/experiments/text/gpt/benchmark.py | 94 ++++++ axlearn/experiments/text/gpt/fuji.py | 2 +- 5 files changed, 126 insertions(+), 383 deletions(-) delete mode 100644 axlearn/common/trainer_config_modifier_test.py create mode 100644 axlearn/experiments/text/gpt/benchmark.py diff --git a/axlearn/common/input_grain.py b/axlearn/common/input_grain.py index 5b4d4b1f9..975700bfe 100644 --- a/axlearn/common/input_grain.py +++ b/axlearn/common/input_grain.py @@ -31,8 +31,8 @@ """ -import os import json +import os import sys from dataclasses import dataclass from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, Union, runtime_checkable @@ -47,8 +47,8 @@ from grain._src.python.dataset import dataset as dataset_base from jax.experimental import multihost_utils -from axlearn.common import input_base, utils from axlearn.common import file_system as fs +from axlearn.common import input_base, utils from axlearn.common.config import ( REQUIRED, ConfigOr, @@ -74,12 +74,14 @@ class RaggedTensor(list): @runtime_checkable class _CallableTransform(Protocol): - def __call__(self, example: Any) -> Any: ... + def __call__(self, example: Any) -> Any: + ... @runtime_checkable class _RandomCallableTransform(Protocol): - def __call__(self, example: Any, rng: np.random.Generator) -> Any: ... + def __call__(self, example: Any, rng: np.random.Generator) -> Any: + ... # Grain supports a set of predefined transformations (e.g. grain.MapTransform), as well as callables @@ -128,7 +130,8 @@ def __post_init__(self): class BuildDatasetFn(Protocol): """A function to create a grain data source.""" - def __call__(self, dispatch_config: DispatchConfig) -> Dataset: ... + def __call__(self, dispatch_config: DispatchConfig) -> Dataset: + ... def _copy_tree(x: _T) -> _T: @@ -768,13 +771,8 @@ def shape_dtype(x): def mixture_train_input_source( *, - is_training: bool, - vocab_cfg: ConfigOr, preprocessor: Union[ConfigOr, list[ConfigOr]], data_mixture_components: Union[ConfigOr, list], - max_sequence_length: int, - replace_newlines_with: str = "", - fake_input_source_cfg: Optional[ConfigOr] = None, seed: Optional[int] = 42, ) -> BuildDatasetFn: """Build mixture training input source for decoder-only LM model using grain. @@ -798,21 +796,13 @@ def mixture_train_input_source( Returns: A BuildDatasetFn that mixes the given list of DataMixtureComponent(s). """ - from axlearn.common.config import maybe_instantiate, maybe_set_config + from axlearn.common.config import maybe_instantiate data_mixture_components = maybe_instantiate(data_mixture_components) def build_dataset_fn( - dispatch_config: DispatchConfig, - *, - is_training: bool, - vocab_cfg: ConfigOr, - preprocessor: Union[ConfigOr, list[ConfigOr]], - data_mixture_components: Union[ConfigOr, list], - max_sequence_length: int, - replace_newlines_with: str = "", - seed: Optional[int] = 42, - ) -> Dataset: + dispatch_config: DispatchConfig, + ) -> Dataset: sources = [] weights = [] @@ -821,7 +811,7 @@ def build_dataset_fn( # Construct ArrayRecord paths arrayrecord_dataset_dir = os.path.join( - "/tmp/tensorflow_datasets/array_record", dataset_name + "/tmp/gcsfuse/tensorflow_datasets/array_record", dataset_name ) # Use fs.listdir to list all files in the directory @@ -835,9 +825,7 @@ def build_dataset_fn( ] # Create ArrayRecord dataset - source_ds = ( - array_record_dataset(paths=arrayrecord_files, seed=seed).shuffle().repeat() - ) + source_ds = array_record_dataset(paths=arrayrecord_files, seed=seed).shuffle().repeat() source_ds = shard_dataset(source_ds, dispatch_config) # features_json = os.path.join(arrayrecord_dataset_dir, "features.json") @@ -850,29 +838,8 @@ def build_dataset_fn( with fs.open(features_json) as f: features_dict = tfds.features.FeaturesDict.from_json(json.load(f)) source_ds = source_ds.map(features_dict.deserialize_example_np) - - - # Apply preprocessing - def _set_config_for_preprocessor(p: ConfigOr) -> ConfigOr: - return maybe_set_config( - p, - vocab_cfg=vocab_cfg, - max_sequence_length=max_sequence_length, - replace_newlines_with=replace_newlines_with, - ) - - if isinstance(preprocessor, list): - assert len(preprocessor) == len(data_mixture_components) - processor_cfg = _set_config_for_preprocessor(preprocessor[len(sources)]) - else: - processor_cfg = _set_config_for_preprocessor(preprocessor) - # Apply processor to the source dataset - processor_fn = maybe_instantiate(processor_cfg) - source_ds = source_ds.map(processor_fn) - - # Repeat the dataset for mixing - source_ds = source_ds.repeat() + source_ds = preprocessor(source_ds) sources.append(source_ds) weights.append(component.weight) @@ -883,12 +850,4 @@ def _set_config_for_preprocessor(p: ConfigOr) -> ConfigOr: # Shard the mixed dataset return mixed_ds - return config_for_function(build_dataset_fn).set( - is_training=is_training, - vocab_cfg=vocab_cfg, - preprocessor=preprocessor, - data_mixture_components=data_mixture_components, - max_sequence_length=max_sequence_length, - replace_newlines_with=replace_newlines_with, - seed=seed, - ) + return build_dataset_fn diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py index fe08d608d..807266d3a 100644 --- a/axlearn/common/trainer_config_modifier.py +++ b/axlearn/common/trainer_config_modifier.py @@ -8,6 +8,7 @@ from typing import Dict, Optional, Sequence, Union from axlearn.common import config +from axlearn.common import file_system as fs from axlearn.common.base_layer import RematSpec from axlearn.common.config import ( REQUIRED, @@ -27,7 +28,6 @@ QuantizedDotGeneral, get_all_fp8_param_names, ) -from axlearn.common import file_system as fs from axlearn.common.trainer import SpmdTrainer from axlearn.common.update_transformation import OverrideInplaceUpdateTransformation from axlearn.common.utils import HybridMeshShape, MeshShape, PartitionSpec @@ -362,33 +362,33 @@ def _convert_tf_data_to_grain_source( A grain source configuration. """ # Import grain modules here to avoid circular imports - from axlearn.common import input_grain, input_grain_lm import grain.python as grain + from axlearn.common import input_grain, input_grain_lm + # Extract data mixture components from tf_data_config components = tf_data_config.data_mixture_components # Extract other relevant config parameters from tf_data_config, with fallbacks vocab_cfg = tf_data_config.vocab_cfg max_sequence_length = tf_data_config.max_sequence_length - preprocessor = config_for_function(input_grain_lm.text_to_lm_training_input).set( - vocab=vocab_cfg, - max_len=max_sequence_length, - max_padding_fraction=tf_data_config.preprocessor.max_padding_fraction, - window_size=tf_data_config.preprocessor.window_size, - read_options=grain.ReadOptions(num_threads=2, prefetch_buffer_size=16), - ) - max_sequence_length = tf_data_config.max_sequence_length - replace_newlines_with = tf_data_config.replace_newlines_with + + def processing_fn(ds): + return input_grain_lm.text_to_lm_training_input( + ds, + vocab=vocab_cfg, + max_len=max_sequence_length, + max_padding_fraction=tf_data_config.preprocessor.max_padding_fraction, + window_size=tf_data_config.preprocessor.window_size, + read_options=grain.ReadOptions(num_threads=8, prefetch_buffer_size=128), + ) + + preprocessor = processing_fn # Use the existing mixture_train_input_source function which already handles # GCS path conversion and fs.listdir operations - return input_grain.mixture_train_input_source( - is_training=True, - vocab_cfg=vocab_cfg, + return config_for_function(input_grain.mixture_train_input_source).set( preprocessor=preprocessor, data_mixture_components=components, - max_sequence_length=max_sequence_length, - replace_newlines_with=replace_newlines_with, seed=42, ) diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py deleted file mode 100644 index 25fa149cb..000000000 --- a/axlearn/common/trainer_config_modifier_test.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright © 2024 Apple Inc. - -"""Test various ConfigModifier classes in trainer_config_modifier.py.""" - -import jax -from absl.testing import absltest, parameterized - -from axlearn.common import causal_lm, test_utils -from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer -from axlearn.common.base_layer import RematSpec -from axlearn.common.config import config_for_function -from axlearn.common.optimizers import sgd_optimizer -from axlearn.common.quantized_dot_general.layers import get_all_fp8_param_names -from axlearn.common.trainer import SpmdTrainer -from axlearn.common.trainer_config_modifier import ( - ChainConfigModifier, - FP8ConfigModifier, - GradientAccumulationModifier, - GrainConfigModifier, - MeshShapeModifier, - ModuleConfigModifier, - OverrideInplaceUpdateTransformation, - PartitionSpecModifier, - RematSpecModifier, -) -from axlearn.common.trainer_test import DummyModel - - -class GradientAccumulationModifierTest(test_utils.TestCase): - def test_gradient_accumulation_override(self): - cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) - cfg_modifier = ( - GradientAccumulationModifier.default_config().set(grad_acc_steps=4).instantiate() - ) - cfg = cfg_modifier(cfg) - self.assertEqual(cfg.learner.forward_fn_transformation.steps, 4) - - -class RematSpecModifierTest(test_utils.TestCase): - def test_remat_policy_override(self): - cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) - cfg_modifier = ( - RematSpecModifier.default_config() - .set( - remat_policies={ - "model.linear": RematSpec( - prevent_cse=True, - policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable, - ), - } - ) - .instantiate() - ) - cfg = cfg_modifier(cfg) - self.assertRegex(str(cfg.model.linear), "dots_saveable") - cfg_modifier = ( - RematSpecModifier.default_config() - .set( - remat_policies={ - "model.linear": RematSpec( - prevent_cse=True, - policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable, - ), - "model.unknown": RematSpec( - prevent_cse=True, - policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable, - ), - } - ) - .instantiate() - ) - # Ensure that the exception is working. - with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"): - _ = cfg_modifier(cfg) - - -class ModuleConfigModifierTest(test_utils.TestCase): - def test_model_config_override(self): - cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config()) - self.assertTrue( - str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config()) - ) - - cfg_modifier = ( - ModuleConfigModifier.default_config() - .set( - target_config="model.decoder.transformer", - modification=RepeatedTransformerLayer.default_config(), - ) - .instantiate() - ) - - cfg = cfg_modifier(cfg) - # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer - self.assertTrue( - str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config()) - ) - cfg_modifier = ( - ModuleConfigModifier.default_config() - .set( - target_config="model.decoder.unknown", - modification=RepeatedTransformerLayer.default_config(), - ) - .instantiate() - ) - # Ensure that the exception is working. - with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"): - _ = cfg_modifier(cfg) - - -class PartitionSpecModifierTest(test_utils.TestCase): - def test_partition_spec_override(self): - cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) - cfg_modifier = ( - PartitionSpecModifier.default_config() - .set( - partition_specs={ - "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, - }, - ) - .instantiate() - ) - cfg = cfg_modifier(cfg) - self.assertTrue( - str(cfg.model.linear.param_partition_spec), """("model", ("expert", "fsdp", "seq")""" - ) - cfg_modifier = ( - PartitionSpecModifier.default_config() - .set( - partition_specs={ - "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, - "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, - }, - ) - .instantiate() - ) - # Ensure that the exception is working. - with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"): - _ = cfg_modifier(cfg) - - cfg_modifier = ( - PartitionSpecModifier.default_config() - .set( - partition_specs={ - "model.linear": { - "param_partition_spec": ("model", ("expert", "fsdp", "seq")), - "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")), - }, - }, - ) - .instantiate() - ) - with self.assertRaisesRegex(AttributeError, "unknown_partition_spec *"): - _ = cfg_modifier(cfg) - - -class MeshShapeModifierTest(test_utils.TestCase): - def test_mesh_shape_update(self): - cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) - cfg_modifier = MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)).instantiate() - cfg = cfg_modifier(cfg) - self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1)) - - -class ChainConfigModifierTest(test_utils.TestCase): - def test_chain_config_modifier(self): - cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) - cfg_modifier = ( - ChainConfigModifier.default_config() - .set( - config_modifiers=[ - GradientAccumulationModifier.default_config().set(grad_acc_steps=4), - MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)), - ] - ) - .instantiate() - ) - cfg = cfg_modifier(cfg) - self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1)) - self.assertEqual(cfg.learner.forward_fn_transformation.steps, 4) - - -class FP8ConfigModifierTest(test_utils.TestCase): - @parameterized.parameters([True, False]) - def test_fp8_config_modifier(self, use_config_fn): - cfg: SpmdTrainer.Config = SpmdTrainer.default_config().set( - model=DummyModel.default_config() - ) - if use_config_fn: - cfg.learner.optimizer = config_for_function(sgd_optimizer).set( - learning_rate=0.5, - decouple_weight_decay=True, - ) - else: - cfg.learner.optimizer = sgd_optimizer( - learning_rate=0.5, - decouple_weight_decay=True, - ) - - cfg_modifier = ( - FP8ConfigModifier.default_config().set(fp8_amax_history_length=1).instantiate() - ) - cfg = cfg_modifier(cfg) - - self.assertIsInstance(cfg.learner.optimizer, OverrideInplaceUpdateTransformation.Config) - self.assertEqual( - cfg.learner.optimizer.rules, - [f".*/{x}" for x in get_all_fp8_param_names()], - ) - self.assertEqual(cfg.model.linear.quantized_dot_general.fp8_amax_history_length, 1) - - -class GrainConfigModifierTest(test_utils.TestCase): - def test_convert_tf_data_to_grain_source(self): - """Test conversion of tf.data config to grain source using mixture_train_input_source.""" - # Mock data mixture components - class MockComponent: - def __init__(self, name, weight=1.0): - self.name = name - self.weight = weight - - class MockConfig: - def __init__(self, components, vocab_cfg=None, preprocessor=None, max_sequence_length=1024, seed=123): - self.data_mixture_components = components - self.vocab_cfg = vocab_cfg - self.preprocessor = preprocessor - self.max_sequence_length = max_sequence_length - self.replace_newlines_with = "" - self.seed = seed - - # Test that the conversion leverages mixture_train_input_source - components = [ - MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/c4/en/3.0.1"), - MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/pile/train"), - ] - - # Create a mock vocab and preprocessor config - from axlearn.common.config import config_for_function - mock_vocab_cfg = config_for_function(lambda: "mock_vocab") - mock_preprocessor_cfg = config_for_function(lambda x: x) - - tf_data_config = MockConfig( - components=components, - vocab_cfg=mock_vocab_cfg, - preprocessor=mock_preprocessor_cfg, - max_sequence_length=1024, - seed=123 - ) - - modifier = GrainConfigModifier.default_config().instantiate() - - # Mock the mixture_train_input_source function - import unittest.mock - with unittest.mock.patch('axlearn.common.input_grain.mixture_train_input_source') as mock_mixture: - mock_mixture.return_value = lambda dispatch_config: None # Mock return value - - grain_source = modifier._convert_tf_data_to_grain_source(tf_data_config) - - # Should call mixture_train_input_source with the components - mock_mixture.assert_called_once() - call_args = mock_mixture.call_args - - # Verify the call arguments use values from tf_data_config - self.assertTrue(call_args[1]['is_training']) - self.assertEqual(call_args[1]['data_mixture_components'], components) - self.assertEqual(call_args[1]['vocab_cfg'], mock_vocab_cfg) - self.assertEqual(call_args[1]['preprocessor'], mock_preprocessor_cfg) - self.assertEqual(call_args[1]['max_sequence_length'], 1024) - self.assertEqual(call_args[1]['replace_newlines_with'], "") - self.assertEqual(call_args[1]['seed'], 123) - - def test_convert_tf_data_to_grain_source_with_defaults(self): - """Test conversion with missing config values uses defaults.""" - # Mock data mixture components - class MockComponent: - def __init__(self, name, weight=1.0): - self.name = name - self.weight = weight - - class MockConfig: - def __init__(self, components): - self.data_mixture_components = components - # Missing other config attributes to test defaults - - components = [MockComponent("test_dataset")] - tf_data_config = MockConfig(components) - - modifier = GrainConfigModifier.default_config().instantiate() - - # Mock the mixture_train_input_source function - import unittest.mock - with unittest.mock.patch('axlearn.common.input_grain.mixture_train_input_source') as mock_mixture: - mock_mixture.return_value = lambda dispatch_config: None # Mock return value - - grain_source = modifier._convert_tf_data_to_grain_source(tf_data_config) - - # Should call mixture_train_input_source with default values - mock_mixture.assert_called_once() - call_args = mock_mixture.call_args - - # Verify the call arguments use default values - self.assertTrue(call_args[1]['is_training']) - self.assertEqual(call_args[1]['data_mixture_components'], components) - self.assertEqual(call_args[1]['max_sequence_length'], 512) # Default - self.assertEqual(call_args[1]['replace_newlines_with'], "") # Default - self.assertEqual(call_args[1]['seed'], 42) # Default - - -if __name__ == "__main__": - absltest.main() diff --git a/axlearn/experiments/text/gpt/benchmark.py b/axlearn/experiments/text/gpt/benchmark.py new file mode 100644 index 000000000..6d4bbd468 --- /dev/null +++ b/axlearn/experiments/text/gpt/benchmark.py @@ -0,0 +1,94 @@ +import os + +os.environ["JAX_PLATFORMS"] = "cpu" +import time + +import coloredlogs +import jax +import numpy as np +from absl import flags, logging + +from axlearn.common.utils import set_data_dir + +# coloredlogs.install("INFO", fmt="%(asctime)s %(name)s:%(lineno)s[%(process)d] %(levelname)s %(message)s") +formatter = coloredlogs.ColoredFormatter( + fmt="%(asctime)s %(filename)s:%(lineno)s[%(process)d] %(levelname)s %(message)s" +) + +# logging.get_absl_handler().setFormatter(None) + +FLAGS = flags.FLAGS + + +def _print_stats(res, idx): + res = np.array(res) + res = np.diff(res) + logging.warning( + f"{idx} batches, per-batch time {np.mean(res) * 1e-6:.2f} ms, std {np.std(res) * 1e-6:.2f} ms" + ) + res = res[len(res) // 2 :] + logging.warning( + f"{idx} batches, per-batch time {np.mean(res) * 1e-6:.2f} ms, std {np.std(res) * 1e-6:.2f} ms" + ) + + +def benchmark(ds=None, ds_iter=None, max_iters=None): + if ds_iter is None: + assert ds is not None + ds_iter = iter(ds) + if max_iters is None: + max_iters = 2**30 + + idx = 1 + res = [time.time_ns()] + while idx <= max_iters: + next(ds_iter) + idx += 1 + res.append(time.time_ns()) + if idx % 20 == 0: + _print_stats(res, idx) + if max_iters is not None and idx == max_iters: + break + + _print_stats(res, idx) + res = res[: len(res) // 2] + return np.mean(res), np.std(res) + + +def timed(fn, msg): + begin = time.time_ns() + ret = fn() + end = time.time_ns() + logging.info(f"{msg}, {(end - begin) * 1e-6:.2f} ms") + return ret + + +from ajax.experiments.speech.pretrain.online_pretrain_utils import audio_to_modality_config + +if __name__ == "__main__": + # def main(argv): + logging.set_verbosity(logging.INFO) + logging.get_absl_handler().setFormatter(formatter) + + with set_data_dir("/tmp/gcsfuse/tensorflow_datasets"): + with jax.make_mesh((1, 1, 1, 1), ("data", "expert", "fsdp", "seq")): + from axlearn.experiments.text.gpt.c4_trainer import named_trainer_configs + + cfg = named_trainer_configs()["fuji-7B-v2-grain"]() + + from ajax.experiments import general_lm + from jax.sharding import PartitionSpec + + cfg.input.partition_spec = PartitionSpec( + general_lm.batch_axis_names_from(general_lm.MESH_AXIS_NAMES) + ) + ds = timed( + lambda: cfg.input.set(name="input").instantiate(parent=None), + "initialize", + ) + # ds_iter = timed(lambda: iter(ds.dataset().parents[0]), "iter") + ds_iter = timed(lambda: iter(ds), "iter") + + x = next(ds_iter) + while True: + benchmark(ds_iter=ds_iter, max_iters=5000) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 6c2d8dcdd..21d0c0c84 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -1112,7 +1112,7 @@ def make_grain_config(base_config_name: str) -> SpmdTrainer.Config: # pylint: disable=cell-var-from-loop # Adjust batch size for grain processing if needed for evaler in cfg.evalers.values(): - if hasattr(evaler.input, 'input_dispatcher'): + if hasattr(evaler.input, "input_dispatcher"): evaler.input.input_dispatcher.global_logical_batch_size //= ( 64 if version in (Version.V3, Version.V3_TIKTOKEN) else 16 ) From 55871da9ce1f771a5ad794e7da52c1443062a826 Mon Sep 17 00:00:00 2001 From: Haoshuo Huang Date: Mon, 28 Jul 2025 13:59:07 -0700 Subject: [PATCH 3/4] Revert unit test --- .../common/trainer_config_modifier_test.py | 310 ++++++++++++++++++ 1 file changed, 310 insertions(+) create mode 100644 axlearn/common/trainer_config_modifier_test.py diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py new file mode 100644 index 000000000..25fa149cb --- /dev/null +++ b/axlearn/common/trainer_config_modifier_test.py @@ -0,0 +1,310 @@ +# Copyright © 2024 Apple Inc. + +"""Test various ConfigModifier classes in trainer_config_modifier.py.""" + +import jax +from absl.testing import absltest, parameterized + +from axlearn.common import causal_lm, test_utils +from axlearn.common.attention import RepeatedTransformerLayer, StackedTransformerLayer +from axlearn.common.base_layer import RematSpec +from axlearn.common.config import config_for_function +from axlearn.common.optimizers import sgd_optimizer +from axlearn.common.quantized_dot_general.layers import get_all_fp8_param_names +from axlearn.common.trainer import SpmdTrainer +from axlearn.common.trainer_config_modifier import ( + ChainConfigModifier, + FP8ConfigModifier, + GradientAccumulationModifier, + GrainConfigModifier, + MeshShapeModifier, + ModuleConfigModifier, + OverrideInplaceUpdateTransformation, + PartitionSpecModifier, + RematSpecModifier, +) +from axlearn.common.trainer_test import DummyModel + + +class GradientAccumulationModifierTest(test_utils.TestCase): + def test_gradient_accumulation_override(self): + cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) + cfg_modifier = ( + GradientAccumulationModifier.default_config().set(grad_acc_steps=4).instantiate() + ) + cfg = cfg_modifier(cfg) + self.assertEqual(cfg.learner.forward_fn_transformation.steps, 4) + + +class RematSpecModifierTest(test_utils.TestCase): + def test_remat_policy_override(self): + cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) + cfg_modifier = ( + RematSpecModifier.default_config() + .set( + remat_policies={ + "model.linear": RematSpec( + prevent_cse=True, + policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable, + ), + } + ) + .instantiate() + ) + cfg = cfg_modifier(cfg) + self.assertRegex(str(cfg.model.linear), "dots_saveable") + cfg_modifier = ( + RematSpecModifier.default_config() + .set( + remat_policies={ + "model.linear": RematSpec( + prevent_cse=True, + policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable, + ), + "model.unknown": RematSpec( + prevent_cse=True, + policy=jax.ad_checkpoint.checkpoint_policies.dots_saveable, + ), + } + ) + .instantiate() + ) + # Ensure that the exception is working. + with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"): + _ = cfg_modifier(cfg) + + +class ModuleConfigModifierTest(test_utils.TestCase): + def test_model_config_override(self): + cfg = SpmdTrainer.default_config().set(model=causal_lm.Model.default_config()) + self.assertTrue( + str(cfg.model.decoder.transformer) == str(StackedTransformerLayer.default_config()) + ) + + cfg_modifier = ( + ModuleConfigModifier.default_config() + .set( + target_config="model.decoder.transformer", + modification=RepeatedTransformerLayer.default_config(), + ) + .instantiate() + ) + + cfg = cfg_modifier(cfg) + # The default StackedTransformerLayer should have changed to RepeatedTransformerLayer + self.assertTrue( + str(cfg.model.decoder.transformer) == str(RepeatedTransformerLayer.default_config()) + ) + cfg_modifier = ( + ModuleConfigModifier.default_config() + .set( + target_config="model.decoder.unknown", + modification=RepeatedTransformerLayer.default_config(), + ) + .instantiate() + ) + # Ensure that the exception is working. + with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"): + _ = cfg_modifier(cfg) + + +class PartitionSpecModifierTest(test_utils.TestCase): + def test_partition_spec_override(self): + cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) + cfg_modifier = ( + PartitionSpecModifier.default_config() + .set( + partition_specs={ + "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, + }, + ) + .instantiate() + ) + cfg = cfg_modifier(cfg) + self.assertTrue( + str(cfg.model.linear.param_partition_spec), """("model", ("expert", "fsdp", "seq")""" + ) + cfg_modifier = ( + PartitionSpecModifier.default_config() + .set( + partition_specs={ + "model.linear": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, + "model.unknown": {"param_partition_spec": ("model", ("expert", "fsdp", "seq"))}, + }, + ) + .instantiate() + ) + # Ensure that the exception is working. + with self.assertRaisesRegex(AttributeError, r"unknown \(keys are *"): + _ = cfg_modifier(cfg) + + cfg_modifier = ( + PartitionSpecModifier.default_config() + .set( + partition_specs={ + "model.linear": { + "param_partition_spec": ("model", ("expert", "fsdp", "seq")), + "unknown_partition_spec": ("model", ("expert", "fsdp", "seq")), + }, + }, + ) + .instantiate() + ) + with self.assertRaisesRegex(AttributeError, "unknown_partition_spec *"): + _ = cfg_modifier(cfg) + + +class MeshShapeModifierTest(test_utils.TestCase): + def test_mesh_shape_update(self): + cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) + cfg_modifier = MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)).instantiate() + cfg = cfg_modifier(cfg) + self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1)) + + +class ChainConfigModifierTest(test_utils.TestCase): + def test_chain_config_modifier(self): + cfg = SpmdTrainer.default_config().set(model=DummyModel.default_config()) + cfg_modifier = ( + ChainConfigModifier.default_config() + .set( + config_modifiers=[ + GradientAccumulationModifier.default_config().set(grad_acc_steps=4), + MeshShapeModifier.default_config().set(mesh_shape=(4, 1, 8, 1)), + ] + ) + .instantiate() + ) + cfg = cfg_modifier(cfg) + self.assertEqual(cfg.mesh_shape, (4, 1, 8, 1)) + self.assertEqual(cfg.learner.forward_fn_transformation.steps, 4) + + +class FP8ConfigModifierTest(test_utils.TestCase): + @parameterized.parameters([True, False]) + def test_fp8_config_modifier(self, use_config_fn): + cfg: SpmdTrainer.Config = SpmdTrainer.default_config().set( + model=DummyModel.default_config() + ) + if use_config_fn: + cfg.learner.optimizer = config_for_function(sgd_optimizer).set( + learning_rate=0.5, + decouple_weight_decay=True, + ) + else: + cfg.learner.optimizer = sgd_optimizer( + learning_rate=0.5, + decouple_weight_decay=True, + ) + + cfg_modifier = ( + FP8ConfigModifier.default_config().set(fp8_amax_history_length=1).instantiate() + ) + cfg = cfg_modifier(cfg) + + self.assertIsInstance(cfg.learner.optimizer, OverrideInplaceUpdateTransformation.Config) + self.assertEqual( + cfg.learner.optimizer.rules, + [f".*/{x}" for x in get_all_fp8_param_names()], + ) + self.assertEqual(cfg.model.linear.quantized_dot_general.fp8_amax_history_length, 1) + + +class GrainConfigModifierTest(test_utils.TestCase): + def test_convert_tf_data_to_grain_source(self): + """Test conversion of tf.data config to grain source using mixture_train_input_source.""" + # Mock data mixture components + class MockComponent: + def __init__(self, name, weight=1.0): + self.name = name + self.weight = weight + + class MockConfig: + def __init__(self, components, vocab_cfg=None, preprocessor=None, max_sequence_length=1024, seed=123): + self.data_mixture_components = components + self.vocab_cfg = vocab_cfg + self.preprocessor = preprocessor + self.max_sequence_length = max_sequence_length + self.replace_newlines_with = "" + self.seed = seed + + # Test that the conversion leverages mixture_train_input_source + components = [ + MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/c4/en/3.0.1"), + MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/pile/train"), + ] + + # Create a mock vocab and preprocessor config + from axlearn.common.config import config_for_function + mock_vocab_cfg = config_for_function(lambda: "mock_vocab") + mock_preprocessor_cfg = config_for_function(lambda x: x) + + tf_data_config = MockConfig( + components=components, + vocab_cfg=mock_vocab_cfg, + preprocessor=mock_preprocessor_cfg, + max_sequence_length=1024, + seed=123 + ) + + modifier = GrainConfigModifier.default_config().instantiate() + + # Mock the mixture_train_input_source function + import unittest.mock + with unittest.mock.patch('axlearn.common.input_grain.mixture_train_input_source') as mock_mixture: + mock_mixture.return_value = lambda dispatch_config: None # Mock return value + + grain_source = modifier._convert_tf_data_to_grain_source(tf_data_config) + + # Should call mixture_train_input_source with the components + mock_mixture.assert_called_once() + call_args = mock_mixture.call_args + + # Verify the call arguments use values from tf_data_config + self.assertTrue(call_args[1]['is_training']) + self.assertEqual(call_args[1]['data_mixture_components'], components) + self.assertEqual(call_args[1]['vocab_cfg'], mock_vocab_cfg) + self.assertEqual(call_args[1]['preprocessor'], mock_preprocessor_cfg) + self.assertEqual(call_args[1]['max_sequence_length'], 1024) + self.assertEqual(call_args[1]['replace_newlines_with'], "") + self.assertEqual(call_args[1]['seed'], 123) + + def test_convert_tf_data_to_grain_source_with_defaults(self): + """Test conversion with missing config values uses defaults.""" + # Mock data mixture components + class MockComponent: + def __init__(self, name, weight=1.0): + self.name = name + self.weight = weight + + class MockConfig: + def __init__(self, components): + self.data_mixture_components = components + # Missing other config attributes to test defaults + + components = [MockComponent("test_dataset")] + tf_data_config = MockConfig(components) + + modifier = GrainConfigModifier.default_config().instantiate() + + # Mock the mixture_train_input_source function + import unittest.mock + with unittest.mock.patch('axlearn.common.input_grain.mixture_train_input_source') as mock_mixture: + mock_mixture.return_value = lambda dispatch_config: None # Mock return value + + grain_source = modifier._convert_tf_data_to_grain_source(tf_data_config) + + # Should call mixture_train_input_source with default values + mock_mixture.assert_called_once() + call_args = mock_mixture.call_args + + # Verify the call arguments use default values + self.assertTrue(call_args[1]['is_training']) + self.assertEqual(call_args[1]['data_mixture_components'], components) + self.assertEqual(call_args[1]['max_sequence_length'], 512) # Default + self.assertEqual(call_args[1]['replace_newlines_with'], "") # Default + self.assertEqual(call_args[1]['seed'], 42) # Default + + +if __name__ == "__main__": + absltest.main() From 5708df34baf81339e43048058647d980d211c618 Mon Sep 17 00:00:00 2001 From: Haoshuo Huang Date: Mon, 28 Jul 2025 22:03:51 -0700 Subject: [PATCH 4/4] A couple of fixes --- axlearn/common/input_grain.py | 17 ++++++++++------- axlearn/common/input_grain_lm.py | 7 ++++--- axlearn/common/trainer_config_modifier.py | 17 +++++++++++++++-- axlearn/experiments/text/gpt/benchmark.py | 3 +-- axlearn/experiments/text/gpt/fuji.py | 12 +----------- 5 files changed, 31 insertions(+), 25 deletions(-) diff --git a/axlearn/common/input_grain.py b/axlearn/common/input_grain.py index 975700bfe..080c26315 100644 --- a/axlearn/common/input_grain.py +++ b/axlearn/common/input_grain.py @@ -773,6 +773,7 @@ def mixture_train_input_source( *, preprocessor: Union[ConfigOr, list[ConfigOr]], data_mixture_components: Union[ConfigOr, list], + global_logical_batch_size: int, seed: Optional[int] = 42, ) -> BuildDatasetFn: """Build mixture training input source for decoder-only LM model using grain. @@ -781,17 +782,11 @@ def mixture_train_input_source( example will only contain tokens from a single source. Args: - is_training: A boolean indicating that inputs will be used for training. - vocab_cfg: Config to instantiate the seqio vocab. preprocessor: A single or a list of lm text preprocessor config(s). When used as a list, each preprocessor must correspond to one data source; when used as a single config, it will be broadcast for all data sources. data_mixture_components: List of DataMixtureComponent(s). - max_sequence_length: Maximum sequence length of an example. - replace_newlines_with: Value to replace newlines with in the text. - fake_input_source_cfg: A config that instantiates to a BuildDatasetFn for the input source - used during unittest. - seed: Seed for any downstream transformations (e.g. `shuffle` or `random_map`). + global_logical_batch_size: The global logical batch size. Returns: A BuildDatasetFn that mixes the given list of DataMixtureComponent(s). @@ -846,6 +841,14 @@ def build_dataset_fn( # Mix the datasets mixed_ds = sample_from_datasets(sources=sources, weights=weights) + global_batch_size = global_logical_batch_size + logging.info("Global batch size for grain is set to %s", global_batch_size) + + mixed_ds = per_feed_batch( + mixed_ds, + global_batch_size=global_batch_size, + dispatch_config=dispatch_config, + ) # Shard the mixed dataset return mixed_ds diff --git a/axlearn/common/input_grain_lm.py b/axlearn/common/input_grain_lm.py index 24f6ec40d..9a2b3f5a3 100644 --- a/axlearn/common/input_grain_lm.py +++ b/axlearn/common/input_grain_lm.py @@ -545,9 +545,10 @@ def strided_slice(example: dict[str, Tensor]) -> dict[str, Tensor]: ds = ds.map(strided_slice) # Produce batches. - ds = input_grain_text.count_num_bytes( - ds, input_key="target_labels", vocab=vocab, output_key="target_num_bytes" - ) + # Skips target_num_bytes for now. + # ds = input_grain_text.count_num_bytes( + # ds, input_key="target_labels", vocab=vocab, output_key="target_num_bytes" + # ) ds = ds.map(_drop_empty_targets) ds = input_grain.maybe_to_iter_dataset(ds) ds = input_grain.unbatch(ds) diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py index 807266d3a..d1dc01dd0 100644 --- a/axlearn/common/trainer_config_modifier.py +++ b/axlearn/common/trainer_config_modifier.py @@ -351,12 +351,15 @@ def __init__(self, cfg: Config): self._grain_source_builder = cfg.grain_source_builder def _convert_tf_data_to_grain_source( - self, tf_data_config: ConfigOr[Configurable] + self, + tf_data_config: ConfigOr[Configurable], + global_logical_batch_size: int, ) -> ConfigOr[Configurable]: """Converts a tf.data source config to a grain source config. Args: tf_data_config: The tf.data source configuration. + global_logical_batch_size: the global logical batch size used. Returns: A grain source configuration. @@ -389,6 +392,7 @@ def processing_fn(ds): return config_for_function(input_grain.mixture_train_input_source).set( preprocessor=preprocessor, data_mixture_components=components, + global_logical_batch_size=global_logical_batch_size, seed=42, ) @@ -413,7 +417,16 @@ def _convert_input_to_grain(self, input_config: Configurable.Config) -> Configur else: assert hasattr(input_config, "source") # Attempt automatic conversion - grain_input_config.source = self._convert_tf_data_to_grain_source(input_config.source) + grain_input_config.source = self._convert_tf_data_to_grain_source( + input_config.source, + global_logical_batch_size=input_config.input_dispatcher.global_logical_batch_size, + ) + + # Copies input_dispatcher and input_partitioner. + if hasattr(input_config, "input_dispatcher"): + grain_input_config.input_dispatcher = input_config.input_dispatcher + if hasattr(input_config, "input_partitioner"): + grain_input_config.input_partitioner = input_config.input_partitioner return grain_input_config diff --git a/axlearn/experiments/text/gpt/benchmark.py b/axlearn/experiments/text/gpt/benchmark.py index 6d4bbd468..c0711cc36 100644 --- a/axlearn/experiments/text/gpt/benchmark.py +++ b/axlearn/experiments/text/gpt/benchmark.py @@ -63,8 +63,6 @@ def timed(fn, msg): return ret -from ajax.experiments.speech.pretrain.online_pretrain_utils import audio_to_modality_config - if __name__ == "__main__": # def main(argv): logging.set_verbosity(logging.INFO) @@ -82,6 +80,7 @@ def timed(fn, msg): cfg.input.partition_spec = PartitionSpec( general_lm.batch_axis_names_from(general_lm.MESH_AXIS_NAMES) ) + cfg.input.source.global_logical_batch_size = 8 ds = timed( lambda: cfg.input.set(name="input").instantiate(parent=None), "initialize", diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 21d0c0c84..cd6123fe4 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -1083,7 +1083,7 @@ def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config: ) config_map[f"{config_name}-fp8-single-host"] = make_single_host_fp8_config_func - if model_size in ("1B", "3B", "7B", "8B"): + if model_size in ("70B"): def make_grain_config(base_config_name: str) -> SpmdTrainer.Config: """Make a grain input processor variant of the base config. @@ -1107,16 +1107,6 @@ def make_grain_config(base_config_name: str) -> SpmdTrainer.Config: convert_training_input=True, ) cfg = grain_modifier.instantiate()(cfg) - - # Configure grain-specific input processing settings - # pylint: disable=cell-var-from-loop - # Adjust batch size for grain processing if needed - for evaler in cfg.evalers.values(): - if hasattr(evaler.input, "input_dispatcher"): - evaler.input.input_dispatcher.global_logical_batch_size //= ( - 64 if version in (Version.V3, Version.V3_TIKTOKEN) else 16 - ) - # pylint: enable=cell-var-from-loop return cfg # Make grain config