92 changes: 92 additions & 0 deletions axlearn/common/input_grain.py
@@ -30,6 +30,9 @@
* `repeat` with `num_repeat=None` will produce datasets with size `sys.maxsize`.
"""


import json
import os
import sys
from dataclasses import dataclass
from typing import Any, Callable, Optional, Protocol, Sequence, TypeVar, Union, runtime_checkable
@@ -44,13 +47,15 @@
from grain._src.python.dataset import dataset as dataset_base
from jax.experimental import multihost_utils

from axlearn.common import file_system as fs
from axlearn.common import input_base, utils
from axlearn.common.config import (
REQUIRED,
ConfigOr,
Required,
config_class,
config_for_class,
config_for_function,
maybe_instantiate,
)
from axlearn.common.module import Module
@@ -762,3 +767,90 @@ def shape_dtype(x):
return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)

return jax.tree.map(shape_dtype, example)


def mixture_train_input_source(
*,
preprocessor: Union[ConfigOr, list[ConfigOr]],
data_mixture_components: Union[ConfigOr, list],
global_logical_batch_size: int,
seed: Optional[int] = 42,
) -> BuildDatasetFn:
"""Build mixture training input source for decoder-only LM model using grain.

Mixture sampling happens after input processing but before batching, meaning that each batch
example will only contain tokens from a single source.

Args:
preprocessor: A single or a list of lm text preprocessor config(s). When
used as a list, each preprocessor must correspond to one data source;
when used as a single config, it will be broadcast for all data sources.
data_mixture_components: List of DataMixtureComponent(s).
        global_logical_batch_size: The global logical batch size.
        seed: Optional random seed used to shuffle each source dataset.

Returns:
A BuildDatasetFn that mixes the given list of DataMixtureComponent(s).
"""

data_mixture_components = maybe_instantiate(data_mixture_components)

def build_dataset_fn(
dispatch_config: DispatchConfig,
) -> Dataset:
sources = []
weights = []

for component in data_mixture_components:
dataset_name = component.name.replace(":", "/")

# Construct ArrayRecord paths
arrayrecord_dataset_dir = os.path.join(
"/tmp/gcsfuse/tensorflow_datasets/array_record", dataset_name
)

# Use fs.listdir to list all files in the directory
all_files = fs.listdir(arrayrecord_dataset_dir)

# Filter for arrayrecord files
arrayrecord_files = [
os.path.join(arrayrecord_dataset_dir, f)
for f in all_files
if f.endswith(".arrayrecord")
]

# Create ArrayRecord dataset
source_ds = array_record_dataset(paths=arrayrecord_files, seed=seed).shuffle().repeat()
source_ds = shard_dataset(source_ds, dispatch_config)
            # Load the tfds feature spec so serialized examples can be deserialized.
features_json = os.path.join(arrayrecord_dataset_dir, "features.json")
# pylint: disable-next=import-outside-toplevel
import tensorflow_datasets as tfds

logging.info(
"Found %s; will assume tfds features and deserialize accordingly.", features_json
)
with fs.open(features_json) as f:
features_dict = tfds.features.FeaturesDict.from_json(json.load(f))
source_ds = source_ds.map(features_dict.deserialize_example_np)
            # Apply the preprocessor to the source dataset. A list of preprocessors maps
            # one-to-one onto data sources; a single preprocessor is broadcast to all.
            source_preprocessor = (
                preprocessor[len(sources)] if isinstance(preprocessor, list) else preprocessor
            )
            source_ds = maybe_instantiate(source_preprocessor)(source_ds)

sources.append(source_ds)
weights.append(component.weight)

        # Mix the datasets according to their weights.
        mixed_ds = sample_from_datasets(sources=sources, weights=weights)
        logging.info("Global batch size for grain is set to %s", global_logical_batch_size)

        mixed_ds = per_feed_batch(
            mixed_ds,
            global_batch_size=global_logical_batch_size,
            dispatch_config=dispatch_config,
        )

        # Return the mixed, per-feed-batched dataset.
        return mixed_ds

return build_dataset_fn
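
Usage sketch (illustrative, not part of this diff): the new builder can be wrapped with config_for_function and attached as the source of a grain Input config. The component names, weights, and batch size below are hypothetical placeholders, and the Component dataclass is a stand-in for DataMixtureComponent with only the fields the builder reads.

# A minimal sketch, assuming axlearn and grain are installed.
from dataclasses import dataclass

from axlearn.common import input_grain
from axlearn.common.config import config_for_function

@dataclass
class Component:  # Stand-in for DataMixtureComponent.
    name: str
    weight: float = 1.0

def identity_preprocessor(ds):
    # Placeholder per-source preprocessing; a real pipeline would tokenize/pack here.
    return ds

source_cfg = config_for_function(input_grain.mixture_train_input_source).set(
    preprocessor=identity_preprocessor,
    data_mixture_components=[
        Component("c4/en/3.0.1", weight=0.7),  # Hypothetical component names.
        Component("pile/train", weight=0.3),
    ],
    global_logical_batch_size=16,
    seed=42,
)
input_cfg = input_grain.Input.default_config().set(source=source_cfg)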
7 changes: 4 additions & 3 deletions axlearn/common/input_grain_lm.py
@@ -545,9 +545,10 @@ def strided_slice(example: dict[str, Tensor]) -> dict[str, Tensor]:
ds = ds.map(strided_slice)

# Produce batches.
ds = input_grain_text.count_num_bytes(
ds, input_key="target_labels", vocab=vocab, output_key="target_num_bytes"
)
# Skip computing target_num_bytes for now.
# ds = input_grain_text.count_num_bytes(
# ds, input_key="target_labels", vocab=vocab, output_key="target_num_bytes"
# )
ds = ds.map(_drop_empty_targets)
ds = input_grain.maybe_to_iter_dataset(ds)
ds = input_grain.unbatch(ds)
127 changes: 126 additions & 1 deletion axlearn/common/trainer_config_modifier.py
@@ -2,9 +2,13 @@

"""Defines trainer config modifiers, which will be used in model definitions."""

from typing import Dict, Sequence, Union
import json
import logging
import os
from typing import Dict, Optional, Sequence, Union

from axlearn.common import config
from axlearn.common import file_system as fs
from axlearn.common.base_layer import RematSpec
from axlearn.common.config import (
REQUIRED,
@@ -13,6 +17,7 @@
Configurable,
Required,
config_class,
config_for_function,
maybe_instantiate,
)
from axlearn.common.gradient_accumulation import with_minibatch_steps
@@ -319,3 +324,123 @@ def enter_fn(_, value, default_kv):
update_cfg.transformation = cfg.learner.optimizer
cfg.learner.optimizer = update_cfg
return cfg


class GrainConfigModifier(ConfigModifier):
"""Converts tf.data input pipelines to grain input pipelines."""

@config_class
class Config(ConfigModifier.Config):
"""Configure GrainConfigModifier.

        TODO: Support evaluation pipelines using grain.

Attributes:
convert_training_input: Whether to convert the training input pipeline to grain.
grain_source_builder: Optional grain source builder function to use.
If None, attempts to automatically convert from tf.data sources.
"""

convert_training_input: bool = True
grain_source_builder: Optional[ConfigOr[Configurable]] = None

def __init__(self, cfg: Config):
super().__init__(cfg)
cfg = self.config
self._convert_training_input = cfg.convert_training_input
self._grain_source_builder = cfg.grain_source_builder

def _convert_tf_data_to_grain_source(
self,
tf_data_config: ConfigOr[Configurable],
global_logical_batch_size: int,
) -> ConfigOr[Configurable]:
"""Converts a tf.data source config to a grain source config.

Args:
tf_data_config: The tf.data source configuration.
global_logical_batch_size: the global logical batch size used.

Returns:
A grain source configuration.
"""
# Import grain modules here to avoid circular imports
import grain.python as grain

from axlearn.common import input_grain, input_grain_lm

# Extract data mixture components from tf_data_config
components = tf_data_config.data_mixture_components
        # Extract other relevant config parameters from tf_data_config.
vocab_cfg = tf_data_config.vocab_cfg
max_sequence_length = tf_data_config.max_sequence_length

def processing_fn(ds):
return input_grain_lm.text_to_lm_training_input(
ds,
vocab=vocab_cfg,
max_len=max_sequence_length,
max_padding_fraction=tf_data_config.preprocessor.max_padding_fraction,
window_size=tf_data_config.preprocessor.window_size,
read_options=grain.ReadOptions(num_threads=8, prefetch_buffer_size=128),
)

        # Use the existing mixture_train_input_source function, which already handles
        # dataset path construction and fs.listdir operations.
        return config_for_function(input_grain.mixture_train_input_source).set(
            preprocessor=processing_fn,
            data_mixture_components=components,
            global_logical_batch_size=global_logical_batch_size,
            seed=42,
        )

def _convert_input_to_grain(self, input_config: Configurable.Config) -> Configurable.Config:
"""Converts a tf.data Input config to a grain Input config.

Args:
input_config: The tf.data Input configuration.

Returns:
A grain Input configuration.
"""
# Import grain input module
from axlearn.common import input_grain

# Create new grain input config
grain_input_config = input_grain.Input.default_config()

# Convert the source
if self._grain_source_builder is not None:
grain_input_config.source = self._grain_source_builder
else:
assert hasattr(input_config, "source")
# Attempt automatic conversion
grain_input_config.source = self._convert_tf_data_to_grain_source(
input_config.source,
global_logical_batch_size=input_config.input_dispatcher.global_logical_batch_size,
)

# Copies input_dispatcher and input_partitioner.
if hasattr(input_config, "input_dispatcher"):
grain_input_config.input_dispatcher = input_config.input_dispatcher
if hasattr(input_config, "input_partitioner"):
grain_input_config.input_partitioner = input_config.input_partitioner

return grain_input_config

def __call__(self, cfg: SpmdTrainer.Config) -> SpmdTrainer.Config:
"""Converts tf.data input pipelines to grain input pipelines.

Args:
cfg: The trainer config to be modified.

Returns:
The modified trainer config with grain input pipelines.
"""
# Convert training input if requested
if self._convert_training_input and hasattr(cfg, "input"):
cfg.input = self._convert_input_to_grain(cfg.input)

return cfg
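
Usage sketch (illustrative, not part of this diff): once defined, the modifier can be applied to an existing SpmdTrainer config, either directly as below or as one entry in a ChainConfigModifier. The `trainer_cfg` variable is a hypothetical placeholder for a trainer config whose `cfg.input` is a tf.data Input config with an input_dispatcher.

# A minimal sketch, assuming `trainer_cfg` is an existing SpmdTrainer.Config.
from axlearn.common.trainer_config_modifier import GrainConfigModifier

modifier = GrainConfigModifier.default_config().instantiate()
trainer_cfg = modifier(trainer_cfg)  # cfg.input.source now builds a grain mixture dataset.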
97 changes: 97 additions & 0 deletions axlearn/common/trainer_config_modifier_test.py
@@ -16,6 +16,7 @@
ChainConfigModifier,
FP8ConfigModifier,
GradientAccumulationModifier,
GrainConfigModifier,
MeshShapeModifier,
ModuleConfigModifier,
OverrideInplaceUpdateTransformation,
@@ -209,5 +210,101 @@ def test_fp8_config_modifier(self, use_config_fn):
self.assertEqual(cfg.model.linear.quantized_dot_general.fp8_amax_history_length, 1)


class GrainConfigModifierTest(test_utils.TestCase):
    def test_convert_tf_data_to_grain_source(self):
        """Tests conversion of a tf.data config to a grain source config."""

        # Mock data mixture components.
        class MockComponent:
            def __init__(self, name, weight=1.0):
                self.name = name
                self.weight = weight

        # Mock tf.data input config carrying the attributes that the conversion reads.
        class MockConfig:
            def __init__(
                self, components, vocab_cfg=None, preprocessor=None, max_sequence_length=1024
            ):
                self.data_mixture_components = components
                self.vocab_cfg = vocab_cfg
                self.preprocessor = preprocessor
                self.max_sequence_length = max_sequence_length

        components = [
            MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/c4/en/3.0.1"),
            MockComponent("gs://permanent-us-central1-0rxn/tensorflow_datasets/pile/train"),
        ]

        # Create mock vocab and preprocessor configs.
        from axlearn.common.config import config_for_function

        mock_vocab_cfg = config_for_function(lambda: "mock_vocab")
        mock_preprocessor_cfg = config_for_function(lambda x: x)

        tf_data_config = MockConfig(
            components=components,
            vocab_cfg=mock_vocab_cfg,
            preprocessor=mock_preprocessor_cfg,
            max_sequence_length=1024,
        )

        modifier = GrainConfigModifier.default_config().instantiate()
        grain_source = modifier._convert_tf_data_to_grain_source(
            tf_data_config, global_logical_batch_size=16
        )

        # The conversion returns a config for mixture_train_input_source whose fields are
        # populated from tf_data_config; verify the fields directly instead of mocking.
        self.assertEqual(grain_source.data_mixture_components, components)
        self.assertEqual(grain_source.global_logical_batch_size, 16)
        self.assertEqual(grain_source.seed, 42)
        self.assertTrue(callable(grain_source.preprocessor))

    def test_convert_tf_data_to_grain_source_missing_attributes(self):
        """Tests that conversion fails when the tf.data config lacks required attributes."""

        # Mock data mixture components.
        class MockComponent:
            def __init__(self, name, weight=1.0):
                self.name = name
                self.weight = weight

        class MockConfig:
            def __init__(self, components):
                self.data_mixture_components = components
                # Other attributes (e.g. vocab_cfg) are intentionally omitted.

        components = [MockComponent("test_dataset")]
        tf_data_config = MockConfig(components)

        modifier = GrainConfigModifier.default_config().instantiate()

        # The conversion reads vocab_cfg and max_sequence_length eagerly, so a config
        # without them should raise AttributeError rather than fall back to defaults.
        with self.assertRaises(AttributeError):
            modifier._convert_tf_data_to_grain_source(
                tf_data_config, global_logical_batch_size=16
            )


if __name__ == "__main__":
absltest.main()