
Feature request: support for split/Subset-like behavior and shared local cache in StreamingDataset #916

@Nima-Hs

Description

Hi, and thank you for the excellent work on this package!

I've been using StreamingDataset for large-scale training with a dataset stored on S3, and I've encountered two major limitations that I couldn't find addressed in the documentation. I’d appreciate clarification or guidance on best practices, and I'd like to propose feature support if these aren't currently handled.


📦 Use Case

We’re working with a very large dataset (~100 TB) stored in an S3-compatible object storage system. The dataset is preprocessed into MDS format and hosted in a shared S3 bucket. Each sample includes metadata (e.g., age, sex, etc.), and users often want to:

  1. Filter the dataset based on metadata (e.g., only use female patients over age 60)

  2. Split the dataset into train/val/test subsets using arbitrary logic

This is a common scenario in medical and scientific domains where researchers need flexible subsets from a shared, centralized dataset.
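To make the filtering use case concrete, here is a minimal, hedged sketch in plain Python. It assumes a per-sample metadata table is available separately from the dataset itself (the `metadata` list and field names below are illustrative, not part of the streaming API); the point is that the desired output is a global list of sample indices:

```python
# Hypothetical sketch: derive a global index list from a per-sample
# metadata table, independent of any StreamingDataset instance.
# The metadata records and field names are illustrative assumptions.

def filter_indices(metadata, predicate):
    """Return the global sample indices whose metadata satisfies `predicate`."""
    return [i for i, meta in enumerate(metadata) if predicate(meta)]

metadata = [
    {"age": 72, "sex": "F"},
    {"age": 45, "sex": "M"},
    {"age": 64, "sex": "F"},
]

# e.g. only female patients over age 60
selected = filter_indices(metadata, lambda m: m["sex"] == "F" and m["age"] > 60)
print(selected)  # [0, 2]
```

The open question is how to hand such a global index list to StreamingDataset without it being invalidated by per-rank sharding.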


❗ Current Limitation 1: Subset behavior unclear and incompatible with DDP

For single-GPU training, we’ve been able to use torch.utils.data.Subset(streaming_dataset, indices) to split or filter the dataset, and it seems to work as expected.
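For context, this is all that Subset does semantically — a sketch with a plain list standing in for the dataset (SubsetView is a hand-rolled stand-in here, not the torch class): it remaps indices and overrides `__len__`, leaving the underlying dataset untouched. Whether StreamingDataset's shuffling and prefetching tolerate that remapping is exactly the open question below.

```python
# Minimal stand-in for the semantics of torch.utils.data.Subset.
# The wrapped object only needs __getitem__ and __len__; a plain list
# plays the role of the dataset in this sketch.

class SubsetView:
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        # The view reports the subset's length, not the base dataset's.
        return len(self.indices)

    def __getitem__(self, i):
        # Local index i is translated into the base dataset's index space.
        return self.dataset[self.indices[i]]

base = ["a", "b", "c", "d", "e"]
view = SubsetView(base, [4, 1])
print(len(view), view[0], view[1])  # 2 e b
```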

However, this usage is not documented anywhere, and I want to confirm:

  • Is using Subset on a StreamingDataset officially supported?

  • Are there caveats around __len__, shuffling, or streaming behavior when using Subset this way?

  • Will it break prefetching or caching mechanisms internally?

For multi-GPU training (DDP), the problem becomes more serious: as soon as the StreamingDataset is initialized, it automatically shards the dataset across processes based on RANK, WORLD_SIZE, etc. This means:

  • Any filtering or splitting logic (e.g., Subset) only applies to that local shard

  • The user has no way to enforce a consistent global split across ranks

  • It becomes impossible to implement deterministic train/val/test splits or metadata-based filtering unless the dataset is physically reprocessed into new shard directories (which is infeasible at 100 TB scale)

We would benefit greatly from:

  • A mechanism to restrict the global sample list before DDP sharding occurs, such as a split() method or partition_index constructor argument that defines a subset of the dataset globally

  • Or a way to defer sharding so that the user can apply their own filtering or sampling logic prior to partitioning for distributed training
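To illustrate the semantics we are asking for (this is a pure-Python sketch of the desired behavior, not existing streaming functionality; `global_split` and `shard_for_rank` are hypothetical names): the split is applied deterministically to the global index space first, and only then is the result partitioned across ranks, so every rank agrees on the same train/val/test membership.

```python
# Sketch of the proposed ordering: global deterministic split first,
# per-rank partitioning second. Function names are illustrative only.

import hashlib

def global_split(num_samples, fraction, seed="train"):
    """Deterministically keep roughly `fraction` of the global index space.

    Hashing the (seed, index) pair makes membership stable across
    ranks, processes, and runs, with no coordination needed.
    """
    def keep(i):
        h = hashlib.sha256(f"{seed}:{i}".encode()).digest()
        return h[0] / 255.0 < fraction
    return [i for i in range(num_samples) if keep(i)]

def shard_for_rank(indices, rank, world_size):
    """Round-robin partition of the already-split global index list."""
    return indices[rank::world_size]

train = global_split(1000, 0.8)
shards = [shard_for_rank(train, r, 4) for r in range(4)]
# The rank shards are disjoint and together cover the split exactly.
assert sorted(i for s in shards for i in s) == train
```

Today the ordering is reversed: sharding happens inside StreamingDataset before any user logic can run, which is why a constructor-level hook or a deferred-sharding mode would be needed.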


❗ Current Limitation 2: Shared local cache not supported

In our cluster setup, multiple users and jobs need to train on the same dataset. To save bandwidth and storage, we want to share the same local cache directory (e.g., /scratch/mds_cache) between users or jobs.

Currently:

  • Each process or user creates its own local cache path, leading to duplicated downloads

  • Reusing the same path leads to potential race conditions, errors, or inconsistent states

We would love to see support for:

  • A safe shared cache mode — possibly with options for:

    • Read-only access

    • Locking mechanisms

    • Checksum validation for partially downloaded files

This would make streaming much more scalable for large shared environments.
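As one possible shape for such a mode, here is a hedged sketch of a locking scheme a shared cache could use (this is a workaround idea built on POSIX `fcntl.flock`, not an existing streaming feature; the function names and path layout are assumptions): an exclusive per-shard lock plus a double-check ensures each shard is downloaded at most once even when concurrent jobs point at the same directory.

```python
# Sketch of per-shard download locking for a shared cache directory.
# Built on POSIX advisory locks (fcntl.flock); names are illustrative.

import fcntl
import os
from contextlib import contextmanager

@contextmanager
def shard_lock(cache_dir, shard_name):
    """Block until this process holds an exclusive lock for one shard."""
    lock_path = os.path.join(cache_dir, shard_name + ".lock")
    with open(lock_path, "w") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # other processes block here
        try:
            yield
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)

def ensure_shard(cache_dir, shard_name, download):
    """Download a shard at most once, even with concurrent callers."""
    path = os.path.join(cache_dir, shard_name)
    if os.path.exists(path):
        return path  # fast path: already cached by some job
    with shard_lock(cache_dir, shard_name):
        if not os.path.exists(path):  # re-check after acquiring the lock
            download(path)
    return path
```

Checksum validation for partially downloaded files could be layered on top by verifying the file's hash before trusting the fast path, and a read-only mode would simply skip the lock and error out on a cache miss.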

Metadata

Labels: enhancement (New feature or request)