2424import re
2525from typing import Any , Protocol
2626
27+ from absl import flags
2728from absl import logging
2829import jax
2930from recml .core .utils import types
@@ -162,12 +163,17 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
162163 Defaults to False.
163164 seed: An optional seed to use for deterministic shuffling / preprocessing.
164165 Defaults to None.
165- tf_data_service_address: An optional URI of a tf.data service to offload
166- preprocessing onto during training. The URI should be in the format
167- "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
168- data service will be applied.
166+ enable_tf_data_service: Whether to apply tf.data service for this dataset.
167+ If True, flag `tf_data_service_address` must be set.
169168 tf_data_service_policy: Sharding policy to use for tf.data service when it
170169 is enabled.
170+ tf_data_service_job_name: Job name to use for tf.data service. If None, the
171+ default job name "tf_data_service_shared_job_name" will be used.
172+ tf_data_service_replicate_on_split: Whether to replicate the file dataset on
173+ split when distributing data to tf.data service workers. Note: this can be
174+ used when multiple datasets are processed together under `Dynamic` mode. A
175+ dataset with `tf_data_service_replicate_on_split` enabled is processed as
176+ if it were in `Off` mode.
171177 feature_spec: A mapping of feature keys to `FixedLenFeature`,
172178 `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
173179 used to parse the TF examples, or as context_features spec to parse TF
@@ -208,7 +214,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
208214 tensorflow.
209215 debug: An optional boolean indicating whether to debug input boundedness. If
210216 `True`, the dataset will consist of a single batch that's cached and
211- infinitely repeated
217+ infinitely repeated.
212218 """
213219
214220 cache_reading : bool = False
@@ -231,7 +237,8 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
231237 readahead : str | None = None
232238 group_uris_by_dir : bool = False
233239 seed : int | None = None
234- tf_data_service_address : str | None = None
240+ enable_tf_data_service : bool = False
241+ tf_data_service_job_name : str | None = None
235242 tf_data_service_policy : tf .data .experimental .service .ShardingPolicy = (
236243 tf .data .experimental .service .ShardingPolicy .OFF
237244 )
@@ -246,10 +253,16 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
246253 sharding_info : DatasetShardingInfo = dataclasses .field (
247254 default_factory = DatasetShardingInfo
248255 )
256+ tf_data_service_replicate_on_split : bool = False
249257 debug : bool = False
250258
251259 def __post_init__ (self ):
252- if self .tf_data_service_address is not None :
260+ if self .enable_tf_data_service :
261+ if flags .FLAGS .tf_data_service_address is None :
262+ raise ValueError (
263+ "Flag `tf_data_service_address` must be set when"
264+ " `enable_tf_data_service` is True."
265+ )
253266 if self .seed is not None :
254267 raise ValueError ("`seed` must be None for data service." )
255268 if self .sharding :
@@ -464,6 +477,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
464477 # Create a dataset of file / file group uris.
465478 dataset = tf .data .Dataset .from_tensor_slices (uris )
466479
480+ if self .tf_data_service_replicate_on_split :
481+ dataset = tf .data .apply_rewrite (dataset , rewrite = "replicate_on_split" )
482+
467483 # Repeat the dataset. We might need to repeat the dataset here in case the
468484 # issue is encountered: internal screenshot link:6jAKKoEMT3afXRe
469485 # even though we do have enough shards for the input data.
@@ -533,23 +549,26 @@ def _maybe_apply_tf_data_service(
533549 self , dataset : tf .data .Dataset
534550 ) -> tf .data .Dataset :
535551 """Applies the tf.data service to the dataset."""
536- if self .tf_data_service_address is None :
552+ if not self .enable_tf_data_service :
537553 return dataset
538554
555+ tf_data_service_address = flags .FLAGS .tf_data_service_address
556+
539557 per_proc_batch_size = self .sharding_info .per_process_batch_size (
540558 self .global_batch_size
541559 )
542560 logging .info (
543561 "Applying tf.data service with address %s and per replica batch"
544562 " size %s" ,
545- self . tf_data_service_address ,
563+ tf_data_service_address ,
546564 per_proc_batch_size ,
547565 )
548566 return dataset .apply (
549567 tf .data .experimental .service .distribute (
550568 processing_mode = self .tf_data_service_policy ,
551- service = self .tf_data_service_address ,
552- job_name = f"bs_{ per_proc_batch_size } " ,
569+ service = tf_data_service_address ,
570+ job_name = self .tf_data_service_job_name
571+ or "tf_data_service_shared_job_name" ,
553572 )
554573 )
555574
0 commit comments