 import re
 from typing import Any, Protocol

+from absl import flags
 from absl import logging
 import jax
 from recml.core.utils import types
@@ -162,12 +163,23 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       Defaults to False.
     seed: An optional seed to use for deterministic shuffling / preprocessing.
       Defaults to None.
-    tf_data_service_address: An optional URI of a tf.data service to offload
-      preprocessing onto during training. The URI should be in the format
-      "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
-      data service will be applied.
+    enable_tf_data_service: Whether to apply tf.data service for this dataset.
+      If True, the flag `tf_data_service_address` must be set.
     tf_data_service_policy: Sharding policy to use for tf.data service when it
       is enabled.
+    tf_data_service_job_name: Job name to use for the tf.data service. If
+      None, a default shared job name will be used.
+    offload_preprocessing_to_tf_data_service: Whether to offload the
+      preprocessing transformations to tf.data service workers. If True,
+      `enable_tf_data_service` must also be True and the preprocessing
+      transformations run on the tf.data service workers; otherwise they run
+      on the host CPU. Must be False when tf.data service is not enabled.
+      Defaults to False.
+    tf_data_service_replicate_on_split: Whether to replicate the file dataset
+      on split when distributing data to tf.data service workers. This is
+      useful when multiple datasets are processed together under the `DYNAMIC`
+      sharding policy: a dataset with this option enabled is effectively
+      processed under the `OFF` policy.
     feature_spec: A mapping of feature keys to `FixedLenFeature`,
       `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
       used to parse the TF examples, or as context_features spec to parse TF
@@ -206,12 +218,13 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       dataset. Defaults to `ShardingInfo(num_processes=jax.process_count(),
       process_index=jax.process_index())`. This is similar to `InputContext` in
       tensorflow.
+    cache_reading: Whether to cache the reading of the dataset. This is useful
+      for debugging and testing. Defaults to False.
     debug: An optional boolean indicating whether to debug input boundedness. If
       `True`, the dataset will consist of a single batch that's cached and
-      infinitely repeated
+      infinitely repeated.
   """

-  cache_reading: bool = False
   input_path: str | Sequence[str] = ""
   tfds_source: str | Sequence[str] = ""
   file_format: FileFormat = FileFormat.RECORDIO
@@ -231,10 +244,13 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   readahead: str | None = None
   group_uris_by_dir: bool = False
   seed: int | None = None
-  tf_data_service_address: str | None = None
+  enable_tf_data_service: bool = False
+  tf_data_service_job_name: str | None = None
   tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = (
       tf.data.experimental.service.ShardingPolicy.OFF
   )
+  offload_preprocessing_to_tf_data_service: bool = False
+  tf_data_service_replicate_on_split: bool = False
   feature_spec: Mapping[str, IO_Feature] | None = None
   sequence_feature_spec: Mapping[str, IO_Feature] | None = None
   tf_transform_output: TFTransformOutput | None = None
@@ -246,14 +262,26 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   sharding_info: DatasetShardingInfo = dataclasses.field(
       default_factory=DatasetShardingInfo
   )
+  cache_reading: bool = False
   debug: bool = False

   def __post_init__(self):
-    if self.tf_data_service_address is not None:
+    if self.enable_tf_data_service:
+      if flags.FLAGS.tf_data_service_address is None:
+        raise ValueError(
+            "Flag `tf_data_service_address` must be set when"
+            " `enable_tf_data_service` is True."
+        )
       if self.seed is not None:
         raise ValueError("`seed` must be None for data service.")
       if self.sharding:
         raise ValueError("`sharding` must be set to False for data service.")
+    else:
+      if self.offload_preprocessing_to_tf_data_service:
+        raise ValueError(
+            "`offload_preprocessing_to_tf_data_service` must be False when"
+            " `enable_tf_data_service` is False."
+        )

   @functools.cached_property
   def tfds_metadata(self) -> TFDSMetadata | None:
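With this change the service address comes from the `tf_data_service_address` absl flag instead of a dataclass field. A minimal sketch of how a caller might configure the factory, assuming the flag is defined elsewhere in the binary and flags are parsed before construction (the flag definition below is shown only for illustration):

```python
from absl import flags

# Assumed to be defined elsewhere in the binary; repeated here for clarity.
flags.DEFINE_string(
    "tf_data_service_address",
    None,
    "URI of the tf.data service, e.g. grpc://tf-data-service:5050.",
)

# `__post_init__` raises a ValueError if the flag is unset, if `seed` is
# set, or if `sharding` is True while the service is enabled.
factory = TFDatasetFactory(
    global_batch_size=1024,
    enable_tf_data_service=True,
    offload_preprocessing_to_tf_data_service=True,
    sharding=False,
)
```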
@@ -464,6 +492,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
     # Create a dataset of file / file group uris.
     dataset = tf.data.Dataset.from_tensor_slices(uris)

+    if self.tf_data_service_replicate_on_split:
+      dataset = tf.data.apply_rewrite(dataset, rewrite="replicate_on_split")
+
     # Repeat the dataset. We might need to repeat it here in case the issue in
     # internal screenshot link:6jAKKoEMT3afXRe is encountered, even when we
     # have enough shards for the input data.
@@ -478,7 +509,7 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
     )

     # Generate a tf.Example dataset by cycling through all uris in parallel.
-    return dataset.interleave(
+    dataset = dataset.interleave(
         map_func=reader,
         cycle_length=self.cycle_length,
         block_length=self.block_length,
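For reference, `interleave` cycles through up to `cycle_length` concurrently open readers, pulling `block_length` consecutive elements from each before moving on. A toy standalone illustration with a hypothetical in-memory reader:

```python
import tensorflow as tf

uris = tf.data.Dataset.from_tensor_slices(["a", "b", "c"])

def reader(uri: tf.Tensor) -> tf.data.Dataset:
  # Stand-in for a per-file reader; yields two records per "file".
  return tf.data.Dataset.from_tensor_slices(
      [tf.strings.join([uri, ":0"]), tf.strings.join([uri, ":1"])]
  )

ds = uris.interleave(map_func=reader, cycle_length=2, block_length=1)
print(list(ds.as_numpy_iterator()))
# [b'a:0', b'b:0', b'a:1', b'b:1', b'c:0', b'c:1']
```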
@@ -490,6 +521,12 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
         deterministic=self.deterministic,
     )

+    # Cache the reading of examples from files.
+    if self.cache_reading:
+      dataset = dataset.cache()
+
+    return dataset
+
   def _parse_dataset(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
     """Batches and parses an examples dataset."""
     # Batch the dataset to the global or per replica batch size.
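Caching right after the interleave means later epochs replay the raw examples from memory instead of re-reading files, which is the debugging/testing behavior `cache_reading` documents. A toy standalone illustration of `tf.data.Dataset.cache` semantics:

```python
import tensorflow as tf

reads = []

def expensive_read(x):
  # Stand-in for reading a record from disk; records each invocation.
  reads.append(int(x))
  return x

ds = tf.data.Dataset.range(3).map(
    lambda x: tf.py_function(expensive_read, [x], tf.int64)
)
ds = ds.cache()  # the first pass materializes elements; later passes replay

list(ds.as_numpy_iterator())  # reads == [0, 1, 2]
list(ds.as_numpy_iterator())  # reads is unchanged: served from the cache
```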
@@ -533,45 +570,51 @@ def _maybe_apply_tf_data_service(
       self, dataset: tf.data.Dataset
   ) -> tf.data.Dataset:
     """Applies the tf.data service to the dataset."""
-    if self.tf_data_service_address is None:
+    if not self.enable_tf_data_service:
       return dataset

+    tf_data_service_address = flags.FLAGS.tf_data_service_address
+
     per_proc_batch_size = self.sharding_info.per_process_batch_size(
         self.global_batch_size
     )
     logging.info(
         "Applying tf.data service with address %s and per replica batch"
         " size %s",
-        self.tf_data_service_address,
+        tf_data_service_address,
         per_proc_batch_size,
     )
     return dataset.apply(
         tf.data.experimental.service.distribute(
             processing_mode=self.tf_data_service_policy,
-            service=self.tf_data_service_address,
-            job_name=f"bs_{per_proc_batch_size}",
+            service=tf_data_service_address,
+            job_name=self.tf_data_service_job_name
+            or "tf_data_service_shared_job_name",
         )
     )

   def make(self) -> tf.data.Dataset:
     """Creates a `tf.data.Dataset` instance with all dataset ops applied."""
     # Create an examples dataset.
-    if self.cache_reading:
-      dataset = self._create_dataset().cache()
-    else:
-      dataset = self._create_dataset()
+    dataset = self._create_dataset()
     # Shuffle and repeat the dataset.
     dataset = self._maybe_shuffle_and_repeat(dataset)
     # Batch and parse the examples dataset.
     dataset = self._parse_dataset(dataset)
     # Apply filters to the batched dataset.
     dataset = self._maybe_filter_dataset(dataset)
-    # Apply data service.
-    dataset = self._maybe_apply_tf_data_service(dataset)
+    # Apply tf.data service before preprocessing.
+    if not self.offload_preprocessing_to_tf_data_service:
+      dataset = self._maybe_apply_tf_data_service(dataset)
+
     # Apply transformations on the dataset.
     for fn in self.map_fns:
       dataset = dataset.map(fn, num_parallel_calls=self.num_parallel_threads)

+    # Apply tf.data service after preprocessing.
+    if self.offload_preprocessing_to_tf_data_service:
+      dataset = self._maybe_apply_tf_data_service(dataset)
+
     if self.debug:
       dataset = dataset.take(1).cache().repeat()

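The two branches in `make` place the service `distribute` transformation either before or after the `map_fns`. A rough standalone sketch of the difference, using the public `tf.data.experimental.service.distribute` API with a hypothetical address and preprocessing function:

```python
import tensorflow as tf

ADDRESS = "grpc://tf-data-service:5050"  # hypothetical service address

def distribute(ds: tf.data.Dataset) -> tf.data.Dataset:
  return ds.apply(
      tf.data.experimental.service.distribute(
          processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
          service=ADDRESS,
          job_name="tf_data_service_shared_job_name",
      )
  )

def preprocess(batch):
  return batch  # stand-in for the factory's map_fns

ds = tf.data.Dataset.range(64).batch(8)

# offload_preprocessing_to_tf_data_service=False: workers read and batch;
# preprocessing runs on the host CPU after elements arrive.
host_side = distribute(ds).map(preprocess)

# offload_preprocessing_to_tf_data_service=True: preprocessing is part of
# the graph shipped to and executed by the tf.data service workers.
worker_side = distribute(ds.map(preprocess))
```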
@@ -778,8 +821,7 @@ def _vectorized_filter(features: FeaturesDictType) -> FeaturesDictType:
       if isinstance(features[name], tf.SparseTensor):
         outputs[name] = tf.sparse_boolean_mask(features[name], mask)
       elif isinstance(features[name], tf.RaggedTensor):
-        # TODO(b/307323524): Support this when we start using Ragged tensors.
-        raise ValueError("Filtering ragged tensors is not supported.")
+        outputs[name] = tf.ragged.boolean_mask(features[name], mask)
       else:
         outputs[name] = tf.boolean_mask(features[name], mask)
     return outputs
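For reference, `tf.ragged.boolean_mask` with a rank-1 mask drops whole rows, mirroring what `tf.boolean_mask` does for the dense features above. A small standalone example:

```python
import tensorflow as tf

values = tf.ragged.constant([[1, 2], [3], [4, 5, 6]])
mask = tf.constant([True, False, True])  # keep examples 0 and 2

print(tf.ragged.boolean_mask(values, mask))
# <tf.RaggedTensor [[1, 2], [4, 5, 6]]>
```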