fix: rng handling with workers #135

Merged
ilan-gold merged 34 commits into scverse:main from selmanozleyen:fix/rng-handling
Feb 12, 2026

Conversation

@selmanozleyen
Member

@selmanozleyen selmanozleyen commented Feb 5, 2026

hi @ilan-gold @flying-sheep ,

The current code doesn't use spawn, and its worker rng handling is somewhat unconventional. For example, this line makes every worker return the same rng, because it was used for sharing chunks across workers:

# Use the same seed for all workers that the resulting splits are the same across workers
# torch default seed is `base_seed + worker_id`. Hence, subtract worker_id to get the base seed
return np.random.default_rng(self._worker_info.seed - self._worker_info.id)

Another problem is

# Avoid copies using in-place shuffling since `self._shuffle` should not change mid-training
np.random.default_rng().shuffle(batch_indices)
split_batch_indices = split_given_size(batch_indices, self._batch_size)
where batch shuffling is not reproducible at all. I added a unit test here which you can cherry-pick (https://github.com/scverse/annbatch/tree/1bcbb8887e5042c52e0c21dc3ad327c1a1a15bc1) and see that it fails on main but passes on this fix (the other changes I made to the file update the other tests to use MockWorkerHandles).

Here is the proposed way to handle rngs now

  • For batch shuffling I use the worker_rng, because each worker usually gets different chunks and should shuffle its batches independently.
  • For chunk sharing/shuffling I use ChunkSampler._rng, because the ChunkSampler is created before the torch workers are spawned, so each worker's copy of the ChunkSampler holds the same ChunkSampler._rng, guaranteeing that chunks are shared the same way without overlapping.

Why we should use spawn: https://numpy.org/doc/stable/reference/random/multithreading.html
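The pattern that page recommends, in miniature: spawn one independent child SeedSequence per worker instead of sharing a single generator (the seed and the number of workers below are illustrative):

```python
import numpy as np

# One independent child SeedSequence per worker process.
root_seq = np.random.SeedSequence(1234)
child_seqs = root_seq.spawn(3)
streams = [np.random.default_rng(seq) for seq in child_seqs]

# Independent, non-overlapping streams, all reproducible from seed 1234.
draws = [rng.integers(0, 2**32, size=4).tolist() for rng in streams]
```

Recreating the hierarchy from the same root seed yields the same streams, which is exactly the reproducibility property we want.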

Note: we can discuss the details of WorkerHandler.__init__; it depends on the expected usage, I guess, but I mainly wanted to show what I was talking about in the standup.

@codecov

codecov bot commented Feb 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 91.49%. Comparing base (dd6e16c) to head (872e328).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #135      +/-   ##
==========================================
- Coverage   93.36%   91.49%   -1.88%     
==========================================
  Files          10       11       +1     
  Lines         829      811      -18     
==========================================
- Hits          774      742      -32     
- Misses         55       69      +14     
Files with missing lines Coverage Δ
src/annbatch/samplers/_chunk_sampler.py 93.18% <100.00%> (-0.57%) ⬇️
src/annbatch/samplers/_utils.py 100.00% <100.00%> (ø)
src/annbatch/utils.py 86.36% <100.00%> (-4.63%) ⬇️

... and 2 files with indirect coverage changes


@ilan-gold ilan-gold changed the title from "Fix: rng handling with workers and no more np.random.default_rng().shuffle(batch_indices) each iter" to "fix: rng handling with workers and no more np.random.default_rng().shuffle(batch_indices) each iter" Feb 5, 2026
@selmanozleyen
Member Author

@ilan-gold could you rename it to "Rng handling with workers", since we split this PR?

@selmanozleyen selmanozleyen changed the title from "fix: rng handling with workers and no more np.random.default_rng().shuffle(batch_indices) each iter" to "fix: rng handling with workers" Feb 6, 2026
@flying-sheep
Member

flying-sheep commented Feb 6, 2026

OK, I think you’re making things too complicated. Why not just delete WorkerHandle and do

worker_info = None
if ...:
    worker_info = get_worker_info()

and later

batch_rng = np.random.default_rng(worker_info.seed) if worker_info else self._rng

or so?

@selmanozleyen
Member Author

selmanozleyen commented Feb 6, 2026

OK, I think you’re making things too complicated. Why not just delete WorkerHandle and do

I agree, actually. I wanted to stick with the old design choices, but I guess WorkerHandle doesn't make sense with the Sampler API anyway.

Collaborator

@ilan-gold ilan-gold left a comment

Can you add a test for this that would fail on main? If we're fixing something, I would guess something is broken?

chunks_per_request = split_given_size(chunks, self._preload_nchunks)
batch_indices = np.arange(in_memory_size)
split_batch_indices = split_given_size(batch_indices, self._batch_size)
batch_rng = _spawn_worker_rng(self._rng, worker_info.id) if worker_info else self._rng
Member

again, why define a _spawn_worker_rng instead of just using np.random.default_rng(worker_info.seed)?

Member Author

ah sorry, I didn't see the "and later" part of your last comment.

I don't think that would be consistent behavior, because I want the created rngs to be derived from the root ChunkSampler._rng. This way each worker's rng is also reproducible from the root rng, and we wouldn't have to rely on setting torch's seed.

Although ChunkSampler doesn't provide determinism (the workers' yields are unordered), we can still expect that what each worker produces is reproducible within itself via the main rng we gave it.

I have this idea because I am trying to mimic jax's way of handling rngs, which I think is a good approach. In jax we would just write

rng, batch_rng = jax.random.split(rng)

This way we have a hierarchy of rngs.

@selmanozleyen
Member Author

selmanozleyen commented Feb 10, 2026

@ilan-gold @flying-sheep even if we were to use torch.seed, we would need to apply it at the root ChunkSampler._rng level again, since we want everything to be reproducible from torch.seed.

If we want explicit torch support, we could create the base rng from torch.seed when no rng is provided and torch is present. But I think it's better to document this explicitly: "This loader doesn't use the torch random number generator, therefore setting torch.seed will not affect reproducibility. For reproducibility see ..."

Since we take an rng instead of a seed, we make this even clearer, I think.

Collaborator

@ilan-gold ilan-gold left a comment

Can you add this behavior explanation to the docs along with a warning that using torch is neither encouraged nor guaranteed to behave as people expect?

Comment on lines +20 to +25
class WorkerInfo(NamedTuple):
"""Minimal worker info for RNG handling."""

id: int
num_workers: int

Collaborator

Why do we need our own NamedTuple if torch provides identical properties?

Member Author

yeah, I wanted to talk about this. Instead of this, would passing a worker_info_fn to ChunkSampler.__init__ make sense? worker_info_fn would be called and return maybe a similarly typed dict. This way we stay decoupled from torch.

Collaborator

Instead of this, would passing a worker_info_fn to ChunkSampler.__init__ make sense? worker_info_fn would be called and return maybe a similarly typed dict. This way we stay decoupled from torch.

But isn't the whole point of this PR to be coupled to torch?

Member Author

I think the codebase is currently coupled to torch internally, because we have self._get_worker_handle calls internally. The fact that we had to create MockWorkerHandles and MockSamplers to have tests is, I think, convincing enough evidence that we can't just inject what torch needs minimally.

Yes, I get that generalizing toward what torch expects isn't that useful, since torch will probably be the only library that leaves worker sharding to the sampler level. I also get that while you are unsure about supporting torch compatibility you don't want to change the interface, so we can drop the worker_info_fn idea if so. I just think that even if there is only one possible case, generalizing something creates clarity: it says "we are doing this for the torch compatibility layer, but if there were other cases we could also support them".

But beyond the worker_info_fn idea, wouldn't a typed dict return type be a better way to explicitly state what we minimally need from torch? Also, we wouldn't have to type it as whatever torch class there is.

wdyt?

Collaborator

Please remove and open a separate PR.

Member Author

done #147

Member

@flying-sheep flying-sheep left a comment

Looks good!

@ilan-gold ilan-gold added the skip-gpu-ci Whether gpu ci should be skipped label Feb 12, 2026
@ilan-gold ilan-gold enabled auto-merge (squash) February 12, 2026 15:51
@ilan-gold ilan-gold disabled auto-merge February 12, 2026 17:33
@ilan-gold ilan-gold merged commit 747d1c2 into scverse:main Feb 12, 2026
13 of 14 checks passed