Hi, and thank you for the excellent work on this package!
I've been using StreamingDataset for large-scale training with a dataset stored on S3, and I've encountered two major limitations that I couldn't find addressed in the documentation. I’d appreciate clarification or guidance on best practices, and I'd like to propose feature support if these aren't currently handled.
📦 Use Case
We’re working with a very large dataset (~100 TB) stored in an S3-compatible object storage system. The dataset is preprocessed into MDS format and hosted in a shared S3 bucket. Each sample includes metadata (e.g., age, sex), and users often want to:
- Filter the dataset based on metadata (e.g., only use female patients over age 60)
- Split the dataset into train/val/test subsets using arbitrary logic
This is a common scenario in medical and scientific domains where researchers need flexible subsets from a shared, centralized dataset.
❗ Current Limitation 1: `Subset` behavior is unclear and incompatible with DDP
For single-GPU training, we’ve been able to use `torch.utils.data.Subset(streaming_dataset, indices)` to split or filter the dataset, and it seems to work as expected.
However, this usage is not documented anywhere, and I want to confirm:
- Is using `Subset` on a `StreamingDataset` officially supported?
- Are there caveats around `__len__`, shuffling, or streaming behavior when using `Subset` this way?
- Will it break prefetching or caching mechanisms internally?
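For concreteness, here is a minimal sketch of the single-GPU pattern described above. A plain list of dicts stands in for the real `StreamingDataset` so the `Subset` semantics stay visible and runnable; with `StreamingDataset` the wrapping is identical, but the indices refer to whatever view the dataset exposes to the current process. The class and sample values are illustrative, not from the original report.

```python
from torch.utils.data import Dataset, Subset

class ToyMetadataDataset(Dataset):
    """Stand-in for an MDS-backed dataset whose samples carry metadata."""
    def __init__(self, samples):
        self.samples = samples
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, i):
        return self.samples[i]

samples = [{"age": a, "sex": s} for a, s in
           [(45, "F"), (62, "F"), (70, "M"), (66, "F")]]
ds = ToyMetadataDataset(samples)

# Metadata filter: female patients over age 60 (the example from the text).
keep = [i for i in range(len(ds))
        if ds[i]["sex"] == "F" and ds[i]["age"] > 60]
subset = Subset(ds, keep)  # selects original indices 1 and 3
```

Note that building `keep` requires reading each sample's metadata, which against a streaming dataset triggers shard downloads; at 100 TB scale the metadata would need to live in a separate index.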
For multi-GPU training (DDP), the problem becomes more serious:
As soon as the `StreamingDataset` is initialized, it automatically shards the dataset across processes based on `RANK`, `WORLD_SIZE`, etc. This means:
- Any filtering or splitting logic (e.g., `Subset`) only applies to that local shard
- The user has no way to enforce a consistent global split across ranks
- It becomes impossible to implement deterministic train/val/test splits or metadata-based filtering unless the dataset is physically reprocessed into new shard directories (which is infeasible at 100 TB scale)
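The mismatch can be made concrete with a few lines of pure Python. The round-robin partition below is only a stand-in for the library's actual (more involved) sharding scheme, but it shows why the same `Subset` index refers to different global samples on different ranks:

```python
# Simulate two DDP ranks over 8 global samples.
def local_view(global_samples, rank, world_size):
    # Round-robin partition, standing in for the dataset's real sharding.
    return global_samples[rank::world_size]

global_samples = list(range(8))
rank0 = local_view(global_samples, 0, 2)  # rank 0 sees [0, 2, 4, 6]
rank1 = local_view(global_samples, 1, 2)  # rank 1 sees [1, 3, 5, 7]

# Subset indices (0, 1) select *different* global samples on each rank,
# so any index-based split applied after sharding is globally inconsistent.
picked0 = [rank0[i] for i in (0, 1)]  # global samples 0 and 2
picked1 = [rank1[i] for i in (0, 1)]  # global samples 1 and 3
```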
We would benefit greatly from:
- A mechanism to restrict the global sample list before DDP sharding occurs, such as a `split()` method or `partition_index` constructor argument that defines a subset of the dataset globally
- Or a way to defer sharding so that the user can apply their own filtering or sampling logic prior to partitioning for distributed training
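Until something like this exists, one workaround sketch is to make the split a pure function of a stable per-sample identifier (e.g., the global index or a patient ID), so every rank and every job computes the identical assignment with no cross-process communication. The helper below is an assumption-laden illustration, not a `streaming` API:

```python
import hashlib

def assign_split(sample_id: str, seed: str = "v1",
                 val_frac: float = 0.1, test_frac: float = 0.1) -> str:
    """Deterministically map a sample to 'train', 'val', or 'test'.

    Hashing (seed, sample_id) gives a reproducible pseudo-uniform value,
    so the split is globally consistent across ranks and jobs.
    """
    h = hashlib.sha256(f"{seed}:{sample_id}".encode()).digest()
    u = int.from_bytes(h[:8], "big") / 2**64  # uniform in [0, 1)
    if u < test_frac:
        return "test"
    if u < test_frac + val_frac:
        return "val"
    return "train"
```

The catch remains that applying this after `StreamingDataset` construction still only filters the rank-local shard, which is exactly why a pre-sharding hook in the library would help.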
❗ Current Limitation 2: Shared local cache not supported
In our cluster setup, multiple users and jobs need to train on the same dataset. To save bandwidth and storage, we want to share the same local cache directory (e.g., `/scratch/mds_cache`) between users or jobs.
Currently:
- Each process or user creates its own local cache path, leading to duplicated downloads
- Reusing the same path leads to potential race conditions, errors, or inconsistent states
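To illustrate the kind of coordination being requested (not something the library does today), here is a sketch of advisory file locking around shard downloads, so concurrent jobs sharing one cache directory download each shard at most once. `fcntl` makes this POSIX-only; the function names are hypothetical:

```python
import fcntl
import os
from contextlib import contextmanager

@contextmanager
def shard_lock(cache_dir: str, shard_name: str):
    """Hold an exclusive advisory lock for one shard in a shared cache."""
    os.makedirs(cache_dir, exist_ok=True)
    lock_path = os.path.join(cache_dir, shard_name + ".lock")
    with open(lock_path, "w") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # blocks until the lock is free
        try:
            yield
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)

def ensure_shard(cache_dir, shard_name, download):
    """Download a shard only if no other job has already cached it."""
    path = os.path.join(cache_dir, shard_name)
    with shard_lock(cache_dir, shard_name):
        if not os.path.exists(path):
            download(path)  # user-supplied fetch, e.g. from S3
    return path
```

A real implementation would also need to handle crashed holders (stale lock files) and cross-node filesystems where advisory locks may not propagate.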
We would love to see support for:
- A safe shared cache mode, possibly with options for:
  - Read-only access
  - Locking mechanisms
  - Checksum validation for partially downloaded files

This would make streaming much more scalable for large shared environments.
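For the checksum-validation point, the shape of the check is simple; a sketch like the following (hypothetical helper names, expected digests assumed to come from the dataset's index) could let a job distinguish a fully cached shard from a truncated one left by a crashed download:

```python
import hashlib

def file_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream a file through SHA-256 in chunks (shards can be large)."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            h.update(chunk)
    return h.hexdigest()

def is_valid_shard(path: str, expected_hex: str) -> bool:
    """True only if the file exists and matches the expected digest,
    so partially downloaded shards are treated as missing."""
    try:
        return file_sha256(path) == expected_hex
    except FileNotFoundError:
        return False
```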