diff --git a/axlearn/common/input_grain.py b/axlearn/common/input_grain.py
index 10f58a853..080c26315 100644
--- a/axlearn/common/input_grain.py
+++ b/axlearn/common/input_grain.py
@@ -30,6 +30,9 @@
 * `repeat` with `num_repeat=None` will produce datasets with size `sys.maxsize`.
 """
+
+import json
+import os
 import sys
 from dataclasses import dataclass
 from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, Union, runtime_checkable
@@ -44,6 +47,7 @@
 from grain._src.python.dataset import dataset as dataset_base
 from jax.experimental import multihost_utils
 
+from axlearn.common import file_system as fs
 from axlearn.common import input_base, utils
 from axlearn.common.config import (
     REQUIRED,
@@ -51,6 +55,7 @@
     Required,
     config_class,
     config_for_class,
+    config_for_function,
     maybe_instantiate,
 )
 from axlearn.common.module import Module
@@ -762,3 +767,90 @@ def shape_dtype(x):
         return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
 
     return jax.tree.map(shape_dtype, example)
+
+
+def mixture_train_input_source(
+    *,
+    preprocessor: Union[ConfigOr, list[ConfigOr]],
+    data_mixture_components: Union[ConfigOr, list],
+    global_logical_batch_size: int,
+    seed: Optional[int] = 42,
+) -> BuildDatasetFn:
+    """Builds a mixture training input source for decoder-only LM models using grain.
+
+    Mixture sampling happens after input processing but before batching, meaning that each batch
+    example will only contain tokens from a single source.
+
+    Args:
+        preprocessor: A single or a list of lm text preprocessor config(s). When
+            used as a list, each preprocessor must correspond to one data source;
+            when used as a single config, it will be broadcast for all data sources.
+        data_mixture_components: List of DataMixtureComponent(s).
+        global_logical_batch_size: The global logical batch size.
+        seed: Seed used for shuffling each source dataset.
+
+    Returns:
+        A BuildDatasetFn that mixes the given list of DataMixtureComponent(s).
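+
+    Example (a sketch; assumes `my_grain_processor` maps a grain Dataset to a processed Dataset,
+    and that each mixture component exposes `name` and `weight` attributes):
+
+        source = mixture_train_input_source(
+            preprocessor=my_grain_processor,
+            data_mixture_components=components,
+            global_logical_batch_size=16,
+        )
+        ds = source(dispatch_config)  # `dispatch_config` is normally supplied by the Input module.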
+    """
+    data_mixture_components = maybe_instantiate(data_mixture_components)
+
+    def build_dataset_fn(
+        dispatch_config: DispatchConfig,
+    ) -> Dataset:
+        sources = []
+        weights = []
+
+        for component in data_mixture_components:
+            dataset_name = component.name.replace(":", "/")
+
+            # Construct ArrayRecord paths.
+            arrayrecord_dataset_dir = os.path.join(
+                "/tmp/gcsfuse/tensorflow_datasets/array_record", dataset_name
+            )
+
+            # List all files in the directory.
+            all_files = fs.listdir(arrayrecord_dataset_dir)
+
+            # Filter for ArrayRecord files.
+            arrayrecord_files = [
+                os.path.join(arrayrecord_dataset_dir, f)
+                for f in all_files
+                if f.endswith(".arrayrecord")
+            ]
+
+            # Create the ArrayRecord dataset, shuffled and repeated.
+            source_ds = array_record_dataset(paths=arrayrecord_files, seed=seed).shuffle().repeat()
+            source_ds = shard_dataset(source_ds, dispatch_config)
+
+            # Deserialize examples using the tfds feature metadata stored alongside the records.
+            features_json = os.path.join(arrayrecord_dataset_dir, "features.json")
+            # pylint: disable-next=import-outside-toplevel
+            import tensorflow_datasets as tfds
+
+            logging.info(
+                "Found %s; will assume tfds features and deserialize accordingly.", features_json
+            )
+            with fs.open(features_json) as f:
+                features_dict = tfds.features.FeaturesDict.from_json(json.load(f))
+            source_ds = source_ds.map(features_dict.deserialize_example_np)
+            # Apply the preprocessor to the source dataset.
+            source_ds = preprocessor(source_ds)
+
+            sources.append(source_ds)
+            weights.append(component.weight)
+
+        # Mix the datasets.
+        mixed_ds = sample_from_datasets(sources=sources, weights=weights)
+        global_batch_size = global_logical_batch_size
+        logging.info("Global batch size for grain is set to %s", global_batch_size)
+
+        # Batch the mixture per feed; per-source sharding already happened above.
+        mixed_ds = per_feed_batch(
+            mixed_ds,
+            global_batch_size=global_batch_size,
+            dispatch_config=dispatch_config,
+        )
+
+        return mixed_ds
+
+    return build_dataset_fn
diff --git a/axlearn/common/input_grain_lm.py b/axlearn/common/input_grain_lm.py
index 24f6ec40d..9a2b3f5a3 100644
--- a/axlearn/common/input_grain_lm.py
+++ b/axlearn/common/input_grain_lm.py
@@ -545,9 +545,10 @@ def strided_slice(example: dict[str, Tensor]) -> dict[str, Tensor]:
     ds = ds.map(strided_slice)
 
     # Produce batches.
-    ds = input_grain_text.count_num_bytes(
-        ds, input_key="target_labels", vocab=vocab, output_key="target_num_bytes"
-    )
+    # Skip computing `target_num_bytes` for now.
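+    # (`target_num_bytes` is typically only consumed by byte-level logging metrics; skipping it
+    # avoids a per-example detokenization pass through the vocab.)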
+    # ds = input_grain_text.count_num_bytes(
+    #     ds, input_key="target_labels", vocab=vocab, output_key="target_num_bytes"
+    # )
     ds = ds.map(_drop_empty_targets)
     ds = input_grain.maybe_to_iter_dataset(ds)
     ds = input_grain.unbatch(ds)
diff --git a/axlearn/common/trainer_config_modifier.py b/axlearn/common/trainer_config_modifier.py
index b24f859f6..d1dc01dd0 100644
--- a/axlearn/common/trainer_config_modifier.py
+++ b/axlearn/common/trainer_config_modifier.py
@@ -2,9 +2,13 @@
 
 """Defines trainer config modifiers, which will be used in model definitions."""
 
-from typing import Dict, Sequence, Union
+import json
+import logging
+import os
+from typing import Dict, Optional, Sequence, Union
 
 from axlearn.common import config
+from axlearn.common import file_system as fs
 from axlearn.common.base_layer import RematSpec
 from axlearn.common.config import (
     REQUIRED,
@@ -13,6 +17,7 @@
     Configurable,
     Required,
     config_class,
+    config_for_function,
     maybe_instantiate,
 )
 from axlearn.common.gradient_accumulation import with_minibatch_steps
@@ -319,3 +324,123 @@ def enter_fn(_, value, default_kv):
         update_cfg.transformation = cfg.learner.optimizer
         cfg.learner.optimizer = update_cfg
         return cfg
+
+
+class GrainConfigModifier(ConfigModifier):
+    """Converts tf.data input pipelines to grain input pipelines."""
+
+    @config_class
+    class Config(ConfigModifier.Config):
+        """Configure GrainConfigModifier.
+
+        TODO: Support evaluation pipelines using grain.
+
+        Attributes:
+            convert_training_input: Whether to convert the training input pipeline to grain.
+            grain_source_builder: Optional grain source builder function to use.
+                If None, attempts to automatically convert from tf.data sources.
+        """
+
+        convert_training_input: bool = True
+        grain_source_builder: Optional[ConfigOr[Configurable]] = None
+
+    def __init__(self, cfg: Config):
+        super().__init__(cfg)
+        cfg = self.config
+        self._convert_training_input = cfg.convert_training_input
+        self._grain_source_builder = cfg.grain_source_builder
+
+    def _convert_tf_data_to_grain_source(
+        self,
+        tf_data_config: ConfigOr[Configurable],
+        global_logical_batch_size: int,
+    ) -> ConfigOr[Configurable]:
+        """Converts a tf.data source config to a grain source config.
+
+        Args:
+            tf_data_config: The tf.data source configuration.
+            global_logical_batch_size: The global logical batch size to use.
+
+        Returns:
+            A grain source configuration.
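+            Specifically, a `config_for_function(input_grain.mixture_train_input_source)` config
+            with `preprocessor`, `data_mixture_components`, `global_logical_batch_size`, and
+            `seed` populated.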
+        """
+        # Import grain modules here to avoid circular imports.
+        import grain.python as grain
+
+        from axlearn.common import input_grain, input_grain_lm
+
+        # Extract the data mixture components from tf_data_config.
+        components = tf_data_config.data_mixture_components
+        # Extract other relevant config parameters from tf_data_config.
+        vocab_cfg = tf_data_config.vocab_cfg
+        max_sequence_length = tf_data_config.max_sequence_length
+
+        def processing_fn(ds):
+            return input_grain_lm.text_to_lm_training_input(
+                ds,
+                vocab=vocab_cfg,
+                max_len=max_sequence_length,
+                max_padding_fraction=tf_data_config.preprocessor.max_padding_fraction,
+                window_size=tf_data_config.preprocessor.window_size,
+                read_options=grain.ReadOptions(num_threads=8, prefetch_buffer_size=128),
+            )
+
+        preprocessor = processing_fn
+
+        # Use the existing mixture_train_input_source function, which already handles
+        # GCS path conversion and fs.listdir operations.
+        return config_for_function(input_grain.mixture_train_input_source).set(
+            preprocessor=preprocessor,
+            data_mixture_components=components,
+            global_logical_batch_size=global_logical_batch_size,
+            seed=42,
+        )
+
+    def _convert_input_to_grain(self, input_config: Configurable.Config) -> Configurable.Config:
+        """Converts a tf.data Input config to a grain Input config.
+
+        Args:
+            input_config: The tf.data Input configuration.
+
+        Returns:
+            A grain Input configuration.
+        """
+        # Import the grain input module here to avoid circular imports.
+        from axlearn.common import input_grain
+
+        # Create a new grain input config.
+        grain_input_config = input_grain.Input.default_config()
+
+        # Convert the source.
+        if self._grain_source_builder is not None:
+            grain_input_config.source = self._grain_source_builder
+        else:
+            assert hasattr(input_config, "source")
+            # Attempt automatic conversion.
+            grain_input_config.source = self._convert_tf_data_to_grain_source(
+                input_config.source,
+                global_logical_batch_size=input_config.input_dispatcher.global_logical_batch_size,
+            )
+
+        # Copies input_dispatcher and input_partitioner.
+        if hasattr(input_config, "input_dispatcher"):
+            grain_input_config.input_dispatcher = input_config.input_dispatcher
+        if hasattr(input_config, "input_partitioner"):
+            grain_input_config.input_partitioner = input_config.input_partitioner
+
+        return grain_input_config
+
+    def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
+        """Converts tf.data input pipelines to grain input pipelines.
+
+        Args:
+            cfg: The trainer config to be modified.
+
+        Returns:
+            The modified trainer config with grain input pipelines.
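+
+        Example (a sketch mirroring how the modifier is applied in experiment configs):
+
+            modifier = GrainConfigModifier.default_config().set(convert_training_input=True)
+            trainer_cfg = modifier.instantiate()(trainer_cfg)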
+        """
+        # Convert the training input if requested.
+        if self._convert_training_input and hasattr(cfg, "input"):
+            cfg.input = self._convert_input_to_grain(cfg.input)
+
+        return cfg
diff --git a/axlearn/common/trainer_config_modifier_test.py b/axlearn/common/trainer_config_modifier_test.py
index 15305f114..25fa149cb 100644
--- a/axlearn/common/trainer_config_modifier_test.py
+++ b/axlearn/common/trainer_config_modifier_test.py
@@ -16,6 +16,7 @@
     ChainConfigModifier,
     FP8ConfigModifier,
     GradientAccumulationModifier,
+    GrainConfigModifier,
     MeshShapeModifier,
     ModuleConfigModifier,
     OverrideInplaceUpdateTransformation,
@@ -209,5 +210,101 @@ def test_fp8_config_modifier(self, use_config_fn):
         self.assertEqual(cfg.model.linear.quantized_dot_general.fp8_amax_history_length, 1)
 
 
+class GrainConfigModifierTest(test_utils.TestCase):
+    def test_convert_tf_data_to_grain_source(self):
+        """Tests conversion of a tf.data config to a grain source via mixture_train_input_source."""
+
+        # Mock data mixture components.
+        class MockComponent:
+            def __init__(self, name, weight=1.0):
+                self.name = name
+                self.weight = weight
+
+        # Mock preprocessor config exposing the attributes read by the conversion.
+        class MockPreprocessor:
+            max_padding_fraction = 0.5
+            window_size = 128
+
+        class MockConfig:
+            def __init__(
+                self, components, *, vocab_cfg=None, preprocessor=None, max_sequence_length=1024
+            ):
+                self.data_mixture_components = components
+                self.vocab_cfg = vocab_cfg
+                self.preprocessor = preprocessor
+                self.max_sequence_length = max_sequence_length
+
+        components = [
+            MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/c4/en/3.0.1"),
+            MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/pile/train"),
+        ]
+
+        # Create a mock vocab config.
+        from axlearn.common import input_grain
+        from axlearn.common.config import config_for_function
+
+        mock_vocab_cfg = config_for_function(lambda: "mock_vocab")
+        tf_data_config = MockConfig(
+            components,
+            vocab_cfg=mock_vocab_cfg,
+            preprocessor=MockPreprocessor(),
+            max_sequence_length=1024,
+        )
+
+        modifier = GrainConfigModifier.default_config().instantiate()
+        grain_source = modifier._convert_tf_data_to_grain_source(
+            tf_data_config, global_logical_batch_size=8
+        )
+
+        # The conversion should delegate to input_grain.mixture_train_input_source.
+        self.assertIs(grain_source.fn, input_grain.mixture_train_input_source)
+
+        # Verify that the config values are taken from tf_data_config and the call arguments.
+        self.assertEqual(grain_source.data_mixture_components, components)
+        self.assertEqual(grain_source.global_logical_batch_size, 8)
+        self.assertTrue(callable(grain_source.preprocessor))
+        self.assertEqual(grain_source.seed, 42)
+
+    def test_convert_tf_data_to_grain_source_defaults(self):
+        """Tests that the conversion applies the default seed."""
+
+        # Mock data mixture components.
+        class MockComponent:
+            def __init__(self, name, weight=1.0):
+                self.name = name
+                self.weight = weight
+
+        class MockConfig:
+            def __init__(self, components):
+                self.data_mixture_components = components
+                self.vocab_cfg = None
+                self.preprocessor = None
+                self.max_sequence_length = 512
+
+        components = [MockComponent("test_dataset")]
+        tf_data_config = MockConfig(components)
+
+        modifier = GrainConfigModifier.default_config().instantiate()
+        grain_source = modifier._convert_tf_data_to_grain_source(
+            tf_data_config, global_logical_batch_size=16
+        )
+
+        # Verify that the default values are applied.
+        self.assertEqual(grain_source.data_mixture_components, components)
+        self.assertEqual(grain_source.global_logical_batch_size, 16)
+        self.assertEqual(grain_source.seed, 42)  # Default seed.
+
+
 if __name__ == "__main__":
     absltest.main()
diff --git a/axlearn/experiments/text/gpt/benchmark.py b/axlearn/experiments/text/gpt/benchmark.py
new file mode 100644
index 000000000..c0711cc36
--- /dev/null
+++ b/axlearn/experiments/text/gpt/benchmark.py
@@ -0,0 +1,93 @@
+import os
+
+os.environ["JAX_PLATFORMS"] = "cpu"
+import time
+
+import coloredlogs
+import jax
+import numpy as np
+from absl import flags, logging
+
+from axlearn.common.utils import set_data_dir
+
+# coloredlogs.install("INFO", fmt="%(asctime)s %(name)s:%(lineno)s[%(process)d] %(levelname)s %(message)s")
+formatter = coloredlogs.ColoredFormatter(
+    fmt="%(asctime)s %(filename)s:%(lineno)s[%(process)d] %(levelname)s %(message)s"
+)
+
+# logging.get_absl_handler().setFormatter(None)
+
+FLAGS = flags.FLAGS
+
+
+def _print_stats(res, idx):
+    res = np.array(res)
+    res = np.diff(res)
+    logging.warning(
+        f"{idx} batches, per-batch time {np.mean(res) * 1e-6:.2f} ms, std {np.std(res) * 1e-6:.2f} ms"
+    )
+    # Drop the first half as warmup and report again.
+    res = res[len(res) // 2 :]
+    logging.warning(
+        f"{idx} batches (excluding warmup), per-batch time {np.mean(res) * 1e-6:.2f} ms, "
+        f"std {np.std(res) * 1e-6:.2f} ms"
+    )
+
+
+def benchmark(ds=None, ds_iter=None, max_iters=None):
+    if ds_iter is None:
+        assert ds is not None
+        ds_iter = iter(ds)
+    if max_iters is None:
+        max_iters = 2**30
+
+    idx = 1
+    res = [time.time_ns()]
+    while idx <= max_iters:
+        next(ds_iter)
+        idx += 1
+        res.append(time.time_ns())
+        if idx % 20 == 0:
+            _print_stats(res, idx)
+        if idx == max_iters:
+            break
+
+    _print_stats(res, idx)
+    # Return per-batch time stats over the second half (excluding warmup).
+    res = np.diff(np.array(res))
+    res = res[len(res) // 2 :]
+    return np.mean(res), np.std(res)
+
+
+def timed(fn, msg):
+    begin = time.time_ns()
+    ret = fn()
+    end = time.time_ns()
+    logging.info(f"{msg}, {(end - begin) * 1e-6:.2f} ms")
+    return ret
+
+
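+# Example invocation (a sketch; assumes the fuji "-grain" configs are registered and the
+# /tmp/gcsfuse/tensorflow_datasets mount exists):
+#   python -m axlearn.experiments.text.gpt.benchmark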
"iter") + ds_iter = timed(lambda: iter(ds), "iter") + + x = next(ds_iter) + while True: + benchmark(ds_iter=ds_iter, max_iters=5000) diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 1e60849a8..cd6123fe4 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -43,6 +43,7 @@ ChainConfigModifier, FP8ConfigModifier, GradientAccumulationModifier, + GrainConfigModifier, MeshShapeModifier, ModuleConfigModifier, PartitionSpecModifier, @@ -1082,4 +1083,41 @@ def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config: ) config_map[f"{config_name}-fp8-single-host"] = make_single_host_fp8_config_func + if model_size in ("70B"): + + def make_grain_config(base_config_name: str) -> SpmdTrainer.Config: + """Make a grain input processor variant of the base config. + + This configuration uses the grain input processing framework for + improved data loading and preprocessing performance. + + Args: + base_config_name: The base config name. + + Returns: + A trainer config that uses grain input processing. + """ + + # pytype: disable=annotation-type-mismatch + cfg: SpmdTrainer.Config = config_map[base_config_name]().clone() + # pytype: enable=annotation-type-mismatch + + # Apply grain config modifier to convert tf.data to Grain + grain_modifier = GrainConfigModifier.default_config().set( + convert_training_input=True, + ) + cfg = grain_modifier.instantiate()(cfg) + return cfg + + # Make grain config + make_grain_config_func = functools.partial(make_grain_config, config_name) + config_map[f"{config_name}-grain"] = make_grain_config_func + + # Make grain configs for FP8 + if f"{config_name}-fp8" in config_map: + make_grain_fp8_config_func = functools.partial( + make_grain_config, f"{config_name}-fp8" + ) + config_map[f"{config_name}-fp8-grain"] = make_grain_fp8_config_func + return config_map