
Commit 24d5eb9

Hilly12 authored and recml authors committed

Refactor a few APIs.

Notably this removes the `rng` argument from `JaxTrainer` to avoid implicitly passing it.

Reverts changelist 793734230

PiperOrigin-RevId: 789073073
1 parent 847628b commit 24d5eb9

36 files changed: 4,993 additions and 630 deletions

recml/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -38,3 +38,4 @@
 from recml.core.utils.types import Factory
 from recml.core.utils.types import FactoryProtocol
 from recml.core.utils.types import ObjectFactory
+from recml.layers.common import EmbeddingSpec
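With this re-export, `EmbeddingSpec` becomes importable from the package root; a minimal sketch, assuming the new import resolves as shown above:

    from recml import EmbeddingSpec  # previously only recml.layers.common.EmbeddingSpec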

recml/core/data/tf_dataset_factory.py

Lines changed: 63 additions & 21 deletions
@@ -24,6 +24,7 @@
 import re
 from typing import Any, Protocol
 
+from absl import flags
 from absl import logging
 import jax
 from recml.core.utils import types
@@ -162,12 +163,23 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       Defaults to False.
     seed: An optional seed to use for deterministic shuffling / preprocessing.
       Defaults to None.
-    tf_data_service_address: An optional URI of a tf.data service to offload
-      preprocessing onto during training. The URI should be in the format
-      "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
-      data service will be applied.
+    enable_tf_data_service: Whether to apply tf.data service for this dataset.
+      If True, the flag `tf_data_service_address` must be set.
     tf_data_service_policy: Sharding policy to use for tf.data service when it
       is enabled.
+    tf_data_service_job_name: Job name to use for tf.data service. If None, the
+      default job name will be used.
+    offload_preprocessing_to_tf_data_service: Whether to offload preprocessing
+      to tf.data service. If True, `enable_tf_data_service` must also be True,
+      and the preprocessing transformations will be offloaded to tf.data
+      service workers. Otherwise, the preprocessing transformations will be
+      applied on the host CPU. If tf.data service is not enabled, this must be
+      set to False. Defaults to False.
+    tf_data_service_replicate_on_split: Whether to replicate the file dataset
+      on split when distributing data to tf.data service workers. Note: this
+      can be used when multiple datasets are processed together under the
+      `Dynamic` mode. A dataset with `tf_data_service_replicate_on_split`
+      enabled is equivalent to having that dataset processed in `Off` mode.
     feature_spec: A mapping of feature keys to `FixedLenFeature`,
       `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
       used to parse the TF examples, or as context_features spec to parse TF
@@ -206,12 +218,13 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       dataset. Defaults to `ShardingInfo(num_processes=jax.process_count(),
       process_index=jax.process_index())`. This is similar to `InputContext` in
       tensorflow.
+    cache_reading: Whether to cache the reading of the dataset. This is useful
+      for debugging and testing. Defaults to False.
     debug: An optional boolean indicating whether to debug input boundedness. If
       `True`, the dataset will consist of a single batch that's cached and
-      infinitely repeated
+      infinitely repeated.
   """
 
-  cache_reading: bool = False
   input_path: str | Sequence[str] = ""
   tfds_source: str | Sequence[str] = ""
   file_format: FileFormat = FileFormat.RECORDIO
@@ -231,10 +244,13 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   readahead: str | None = None
   group_uris_by_dir: bool = False
   seed: int | None = None
-  tf_data_service_address: str | None = None
+  enable_tf_data_service: bool = False
+  tf_data_service_job_name: str | None = None
   tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = (
       tf.data.experimental.service.ShardingPolicy.OFF
   )
+  offload_preprocessing_to_tf_data_service: bool = False
+  tf_data_service_replicate_on_split: bool = False
   feature_spec: Mapping[str, IO_Feature] | None = None
   sequence_feature_spec: Mapping[str, IO_Feature] | None = None
   tf_transform_output: TFTransformOutput | None = None
@@ -246,14 +262,26 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   sharding_info: DatasetShardingInfo = dataclasses.field(
       default_factory=DatasetShardingInfo
   )
+  cache_reading: bool = False
   debug: bool = False
 
   def __post_init__(self):
-    if self.tf_data_service_address is not None:
+    if self.enable_tf_data_service:
+      if flags.FLAGS.tf_data_service_address is None:
+        raise ValueError(
+            "Flag `tf_data_service_address` must be set when"
+            " `enable_tf_data_service` is True."
+        )
       if self.seed is not None:
         raise ValueError("`seed` must be None for data service.")
       if self.sharding:
         raise ValueError("`sharding` must be set to False for data service.")
+    else:
+      if self.offload_preprocessing_to_tf_data_service:
+        raise ValueError(
+            "`offload_preprocessing_to_tf_data_service` must be False when"
+            " `enable_tf_data_service` is False."
+        )
 
   @functools.cached_property
   def tfds_metadata(self) -> TFDSMetadata | None:
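Taken together, the new fields and the `__post_init__` checks suggest a configuration along these lines. This is a hypothetical sketch with placeholder values; other required fields of `TFDatasetFactory` are omitted:

    factory = TFDatasetFactory(
        input_path="/data/train/examples*",  # placeholder path
        global_batch_size=4096,
        enable_tf_data_service=True,  # requires the --tf_data_service_address flag to be set
        tf_data_service_job_name="my_training_job",
        offload_preprocessing_to_tf_data_service=True,
        seed=None,       # must remain None when the data service is enabled
        sharding=False,  # must be False when the data service is enabled
    )
    dataset = factory.make()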
@@ -464,6 +492,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
     # Create a dataset of file / file group uris.
     dataset = tf.data.Dataset.from_tensor_slices(uris)
 
+    if self.tf_data_service_replicate_on_split:
+      dataset = tf.data.apply_rewrite(dataset, rewrite="replicate_on_split")
+
     # Repeat the dataset. We might need to repeat the dataset here in case the
     # issue is encountered: internal screenshot link:6jAKKoEMT3afXRe
     # even we do have enough shards for the input data.
@@ -478,7 +509,7 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
     )
 
     # Generate a tf.Example dataset by cycling through all uris in parallel.
-    return dataset.interleave(
+    dataset = dataset.interleave(
         map_func=reader,
         cycle_length=self.cycle_length,
         block_length=self.block_length,
@@ -490,6 +521,12 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
         deterministic=self.deterministic,
     )
 
+    # Cache the reading of examples from files.
+    if self.cache_reading:
+      dataset = dataset.cache()
+
+    return dataset
+
   def _parse_dataset(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
     """Batches and parses an examples dataset."""
     # Batch the dataset to the global or per replica batch size.
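Moving the cache into `_create_dataset` keeps the cached data ahead of shuffling. A standalone sketch (plain tf.data, not RecML code) of why caching before shuffle/repeat helps:

    ds = tf.data.TFRecordDataset(["/tmp/shard-0.tfrecord"])  # placeholder input
    ds = ds.cache()                    # raw examples are read from disk only once
    ds = ds.shuffle(10_000).repeat()   # later epochs reshuffle the cached elements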
@@ -533,45 +570,51 @@ def _maybe_apply_tf_data_service(
       self, dataset: tf.data.Dataset
   ) -> tf.data.Dataset:
     """Applies the tf.data service to the dataset."""
-    if self.tf_data_service_address is None:
+    if not self.enable_tf_data_service:
       return dataset
 
+    tf_data_service_address = flags.FLAGS.tf_data_service_address
+
     per_proc_batch_size = self.sharding_info.per_process_batch_size(
         self.global_batch_size
     )
     logging.info(
         "Applying tf.data service with address %s and per replica batch"
         " size %s",
-        self.tf_data_service_address,
+        tf_data_service_address,
         per_proc_batch_size,
     )
     return dataset.apply(
         tf.data.experimental.service.distribute(
             processing_mode=self.tf_data_service_policy,
-            service=self.tf_data_service_address,
-            job_name=f"bs_{per_proc_batch_size}",
+            service=tf_data_service_address,
+            job_name=self.tf_data_service_job_name
+            or "tf_data_service_shared_job_name",
         )
     )
 
   def make(self) -> tf.data.Dataset:
     """Creates a `tf.data.Dataset` instance with all dataset ops applied."""
     # Create an examples dataset.
-    if self.cache_reading:
-      dataset = self._create_dataset().cache()
-    else:
-      dataset = self._create_dataset()
+    dataset = self._create_dataset()
     # Shuffle and repeat the dataset.
     dataset = self._maybe_shuffle_and_repeat(dataset)
     # Batch and parse the examples dataset.
     dataset = self._parse_dataset(dataset)
     # Apply filters to the batched dataset.
     dataset = self._maybe_filter_dataset(dataset)
-    # Apply data service.
-    dataset = self._maybe_apply_tf_data_service(dataset)
+    # Apply TF Data service before preprocessing.
+    if not self.offload_preprocessing_to_tf_data_service:
+      dataset = self._maybe_apply_tf_data_service(dataset)
+
     # Apply transformations on the dataset.
     for fn in self.map_fns:
       dataset = dataset.map(fn, num_parallel_calls=self.num_parallel_threads)
 
+    # Apply TF Data Service after preprocessing.
+    if self.offload_preprocessing_to_tf_data_service:
+      dataset = self._maybe_apply_tf_data_service(dataset)
+
     if self.debug:
       dataset = dataset.take(1).cache().repeat()
 
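For reference, a standalone sketch of the underlying `tf.data.experimental.service.distribute` transform used above (the address and job name here are placeholders):

    ds = tf.data.Dataset.range(1_000).batch(32)
    ds = ds.apply(
        tf.data.experimental.service.distribute(
            processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
            service="grpc://tf-data-service:5050",
            job_name="shared_job",
        )
    )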
@@ -778,8 +821,7 @@ def _vectorized_filter(features: FeaturesDictType) -> FeaturesDictType:
     if isinstance(features[name], tf.SparseTensor):
       outputs[name] = tf.sparse_boolean_mask(features[name], mask)
     elif isinstance(features[name], tf.RaggedTensor):
-      # TODO(b/307323524): Support this when we start using Ragged tensors.
-      raise ValueError("Filtering ragged tensors is not supported.")
+      outputs[name] = tf.ragged.boolean_mask(features[name], mask)
     else:
       outputs[name] = tf.boolean_mask(features[name], mask)
   return outputs
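The new branch relies on `tf.ragged.boolean_mask`, which keeps whole rows when given a per-row mask; a quick illustration:

    rt = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
    mask = tf.constant([True, False, True])
    tf.ragged.boolean_mask(rt, mask)  # -> [[1, 2], [4, 5, 6]]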

recml/core/ops/hstu_ops.py

Lines changed: 8 additions & 8 deletions
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]
 
     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
     k_sequence = k_offset + jax.lax.broadcasted_iota(
         jnp.int32, (k_slice.size, bq), 0
     )
-    q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+    q_sequence = q_sequence_ref[:1, :]  # [1, bq]
     q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))
 
     assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(
 
   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
      )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
       masks.append(q_ids == kv_ids)
 
   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
 
     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )
 
     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]
 
     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
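The pattern throughout this file replaces explicit `pl.load` calls with direct `Ref` indexing, which recent Pallas versions support for static slices and `pl.ds` slices; the equivalences, as used in this diff:

    mask = mask_ref[:, k_slice]  # same as pl.load(mask_ref, (slice(None), k_slice))
    k = k_ref[slice_k, :]        # same as pl.load(k_ref, (slice_k, slice(None)))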

recml/core/training/core.py

Lines changed: 32 additions & 26 deletions
@@ -13,12 +13,15 @@
 # limitations under the License.
 """Core training library for Jax."""
 
+from __future__ import annotations
+
 import abc
 from collections.abc import Mapping, Sequence
 import dataclasses
 import enum
 from typing import Any, Generic, TypeVar
 
+import fiddle as fdl
 import jax
 import jax.numpy as jnp
 from recml.core.data import iterator
@@ -37,7 +40,6 @@
 ORBAX_CHECKPOINT_DEFAULT_KEY = "default"
 
 DEFAULT_RNG_SEED = 0
-IN_TRAINER_CONTEXT = False  # Set to true when run from the main trainer.
 STATE_CHECKPOINT_KEY = "state"
 
 TaskT = TypeVar("TaskT")
@@ -57,6 +59,14 @@
 class Trainer(abc.ABC, Generic[TaskT]):
   """A base trainer interface for training and evaluation."""
 
+  class Mode(enum.StrEnum):
+    """Mode to run an experiment."""
+
+    TRAIN = "train"
+    EVAL = "eval"
+    TRAIN_AND_EVAL = "train_and_eval"
+    CONTINUOUS_EVAL = "continuous_eval"
+
   @abc.abstractmethod
   def __init__(self, model_dir: str, *args, **kwargs):
     """Initializes the instance."""
@@ -77,6 +87,23 @@ def train_and_evaluate(self, task: TaskT, *args, **kwargs) -> Logs | None:
   def evaluate_continuously(self, task: TaskT, *args, **kwargs) -> Logs | None:
     """Performs continuous evaluation until a condition is met."""
 
+  def run(self, task: TaskT, mode: Any) -> Logs | None:
+    """Runs the experiment in the given mode."""
+    if mode == Trainer.Mode.TRAIN_AND_EVAL:
+      return self.train_and_evaluate(task)
+    elif mode == Trainer.Mode.TRAIN:
+      return self.train(task)
+    elif mode == Trainer.Mode.EVAL:
+      return self.evaluate(task)
+    elif mode == Trainer.Mode.CONTINUOUS_EVAL:
+      return self.evaluate_continuously(task)
+    else:
+      raise ValueError(f"The job mode provided is not supported: {mode}.")
+
+  @classmethod
+  def setup_experiment(cls, experiment_cfg: fdl.Config[Experiment]):
+    """Sets up the experiment before it is instantiated."""
+
 
 @dataclasses.dataclass(frozen=True)
 class Experiment(Generic[TaskT]):
@@ -90,32 +117,13 @@ class Experiment(Generic[TaskT]):
     trainer: The trainer to use for the experiment.
   """
 
-  class Mode(enum.StrEnum):
-    """Mode to run an experiment."""
-
-    TRAIN = "train"
-    EVAL = "eval"
-    TRAIN_AND_EVAL = "train_and_eval"
-    CONTINUOUS_EVAL = "continuous_eval"
-
   task: TaskT
   trainer: Trainer[TaskT]
 
 
-def run_experiment(
-    experiment: Experiment, mode: Experiment.Mode
-) -> Logs | None:
+def run_experiment(experiment: Experiment, mode: Any) -> Logs | None:
   """Runs an experiment."""
-  if mode == Experiment.Mode.TRAIN_AND_EVAL:
-    return experiment.trainer.train_and_evaluate(experiment.task)
-  elif mode == Experiment.Mode.TRAIN:
-    return experiment.trainer.train(experiment.task)
-  elif mode == Experiment.Mode.EVAL:
-    return experiment.trainer.evaluate(experiment.task)
-  elif mode == Experiment.Mode.CONTINUOUS_EVAL:
-    return experiment.trainer.evaluate_continuously(experiment.task)
-  else:
-    raise ValueError(f"The job mode provided is not supported: {mode}.")
+  return experiment.trainer.run(experiment.task, mode)
 
 
 def get_iterators(
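From a caller's perspective, the refactor moves the mode enum onto `Trainer` and turns `run_experiment` into a thin delegation. A hypothetical usage sketch (`my_task` and `my_trainer` are placeholders):

    experiment = Experiment(task=my_task, trainer=my_trainer)
    run_experiment(experiment, Trainer.Mode.TRAIN_AND_EVAL)
    run_experiment(experiment, "train")  # plain strings also compare equal, since Mode is a StrEnum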
@@ -161,9 +169,7 @@ def get_iterators(
        k: iterator.TFDatasetIterator(v) for k, v in eval_datasets.items()
    }
 
-  if not all(
-      isinstance(v, iterator.Iterator) for v in eval_datasets.values()
-  ):
+  if not all(isinstance(v, iterator.Iterator) for v in eval_datasets.values()):
     raise ValueError(
         "Expected all values in the evaluation datasets mapping to be either"
         " `tf.data.Dataset` instances or CLU `DatasetIterator` instances,"
179185
"""Gets the shape of a dense / sparse / ragged tensor or tensor spec."""
180186
if isinstance(x, tf.SparseTensor):
181187
return [x.shape[0]] + [None for _ in x.shape[1:]]
182-
return x.shape.as_list()
188+
return x.shape.as_list() # pylint: disable=attribute-error
183189

184190

185191
def in_tracing_context() -> bool:
