2424import re
2525from typing import Any , Protocol
2626
27+ from absl import flags
2728from absl import logging
2829import jax
2930from recml .core .utils import types
@@ -162,12 +163,23 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
162163 Defaults to False.
163164 seed: An optional seed to use for deterministic shuffling / preprocessing.
164165 Defaults to None.
165- tf_data_service_address: An optional URI of a tf.data service to offload
166- preprocessing onto during training. The URI should be in the format
167- "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
168- data service will be applied.
166+ enable_tf_data_service: Whether to apply tf.data service for this dataset.
167+ If True, flag `tf_data_service_address` must be set.
169168 tf_data_service_policy: Sharding policy to use for tf.data service when it
170169 is enabled.
170+ tf_data_service_job_name: Job name to use for tf.data service. If None, the
171+ default job name will be used.
172+ offload_preprocessing_to_tf_data_service: Whether to offload preprocessing
173+ to tf.data service. If True, enable_tf_data_service must also be True, and
174+ the preprocessing transformation will be offloaded to tf data service
175+ workers. Otherwise, the preprocessing transformation will be applied on
176+ the host CPU. If tf data service is not enabled, this arg must be set
177+ False. Defaults to False.
178+ tf_data_service_replicate_on_split: Whether to replicate the file dataset on
179+ split when distributing data to tf.data service workers. Note: it could be
180+ used in the case where multiple datasets are processed together under
181+ `Dynamic` mode. The dataset with `tf_data_service_replicate_on_split`
182+ enabled is equivalent to having that dataset processed as `Off` mode.
171183 feature_spec: A mapping of feature keys to `FixedLenFeature`,
172184 `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
173185 used to parse the TF examples, or as context_features spec to parse TF
@@ -208,7 +220,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
208220 tensorflow.
209221 debug: An optional boolean indicating whether to debug input boundedness. If
210222 `True`, the dataset will consist of a single batch that's cached and
211- infinitely repeated
223+ infinitely repeated.
212224 """
213225
214226 cache_reading : bool = False
@@ -231,10 +243,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
231243 readahead : str | None = None
232244 group_uris_by_dir : bool = False
233245 seed : int | None = None
234- tf_data_service_address : str | None = None
246+ enable_tf_data_service : bool = False
247+ tf_data_service_job_name : str | None = None
235248 tf_data_service_policy : tf .data .experimental .service .ShardingPolicy = (
236249 tf .data .experimental .service .ShardingPolicy .OFF
237250 )
251+ offload_preprocessing_to_tf_data_service : bool = False
238252 feature_spec : Mapping [str , IO_Feature ] | None = None
239253 sequence_feature_spec : Mapping [str , IO_Feature ] | None = None
240254 tf_transform_output : TFTransformOutput | None = None
@@ -246,14 +260,26 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
246260 sharding_info : DatasetShardingInfo = dataclasses .field (
247261 default_factory = DatasetShardingInfo
248262 )
263+ tf_data_service_replicate_on_split : bool = False
249264 debug : bool = False
250265
251266 def __post_init__ (self ):
252- if self .tf_data_service_address is not None :
267+ if self .enable_tf_data_service :
268+ if flags .FLAGS .tf_data_service_address is None :
269+ raise ValueError (
270+ "Flag `tf_data_service_address` must be set when"
271+ " `enable_tf_data_service` is True."
272+ )
253273 if self .seed is not None :
254274 raise ValueError ("`seed` must be None for data service." )
255275 if self .sharding :
256276 raise ValueError ("`sharding` must be set to False for data service." )
277+ else :
278+ if self .offload_preprocessing_to_tf_data_service :
279+ raise ValueError (
280+ "`offload_preprocessing_to_tf_data_service` must be False when"
281+ " `enable_tf_data_service` is False."
282+ )
257283
258284 @functools .cached_property
259285 def tfds_metadata (self ) -> TFDSMetadata | None :
@@ -464,6 +490,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
464490 # Create a dataset of file / file group uris.
465491 dataset = tf .data .Dataset .from_tensor_slices (uris )
466492
493+ if self .tf_data_service_replicate_on_split :
494+ dataset = tf .data .apply_rewrite (dataset , rewrite = "replicate_on_split" )
495+
467496 # Repeat the dataset. We might need to repeat the dataset here in case the
468497 # issue is encountered: internal screenshot link:6jAKKoEMT3afXRe
469498 # even we do have enough shards for the input data.
@@ -533,23 +562,26 @@ def _maybe_apply_tf_data_service(
533562 self , dataset : tf .data .Dataset
534563 ) -> tf .data .Dataset :
535564 """Applies the tf.data service to the dataset."""
536- if self .tf_data_service_address is None :
565+ if not self .enable_tf_data_service :
537566 return dataset
538567
568+ tf_data_service_address = flags .FLAGS .tf_data_service_address
569+
539570 per_proc_batch_size = self .sharding_info .per_process_batch_size (
540571 self .global_batch_size
541572 )
542573 logging .info (
543574 "Applying tf.data service with address %s and per replica batch"
544575 " size %s" ,
545- self . tf_data_service_address ,
576+ tf_data_service_address ,
546577 per_proc_batch_size ,
547578 )
548579 return dataset .apply (
549580 tf .data .experimental .service .distribute (
550581 processing_mode = self .tf_data_service_policy ,
551- service = self .tf_data_service_address ,
552- job_name = f"bs_{ per_proc_batch_size } " ,
582+ service = tf_data_service_address ,
583+ job_name = self .tf_data_service_job_name
584+ or "tf_data_service_shared_job_name" ,
553585 )
554586 )
555587
@@ -566,12 +598,18 @@ def make(self) -> tf.data.Dataset:
566598 dataset = self ._parse_dataset (dataset )
567599 # Apply filters to the batched dataset.
568600 dataset = self ._maybe_filter_dataset (dataset )
569- # Apply data service.
570- dataset = self ._maybe_apply_tf_data_service (dataset )
601+ # Apply TF Data service before preprocessing.
602+ if not self .offload_preprocessing_to_tf_data_service :
603+ dataset = self ._maybe_apply_tf_data_service (dataset )
604+
571605 # Apply transformations on the dataset.
572606 for fn in self .map_fns :
573607 dataset = dataset .map (fn , num_parallel_calls = self .num_parallel_threads )
574608
609+ # Apply TF Data Service after preprocessing.
610+ if self .offload_preprocessing_to_tf_data_service :
611+ dataset = self ._maybe_apply_tf_data_service (dataset )
612+
575613 if self .debug :
576614 dataset = dataset .take (1 ).cache ().repeat ()
577615
0 commit comments