fix: rng handling with workers #135
Conversation
Codecov Report ✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main     #135     +/- ##
=======================================
- Coverage   93.36%   91.49%   -1.88%
  Files          10       11       +1
  Lines         829      811      -18
=======================================
- Hits          774      742      -32
- Misses         55       69      +14
```
`np.random.default_rng().shuffle(batch_indices)` each iter
@ilan-gold could you rename to "Rng handling with workers" since we split this PR
Co-authored-by: Philipp A. <flying-sheep@web.de>
OK, I think you’re making things too complicated. Why not just

```python
worker_info = None
if ...:
    worker_info = get_worker_info()
```

and later

```python
batch_rng = np.random.default_rng(worker_info.seed) if worker_info else self._rng
```

or so?
I agree actually. I wanted to work with the old choices but I guess
ilan-gold
left a comment
Can you add a test for this that would fail on main? If we're fixing something, I would guess something is broken?
```python
chunks_per_request = split_given_size(chunks, self._preload_nchunks)
batch_indices = np.arange(in_memory_size)
split_batch_indices = split_given_size(batch_indices, self._batch_size)
batch_rng = _spawn_worker_rng(self._rng, worker_info.id) if worker_info else self._rng
```
again, why define a _spawn_worker_rng instead of just using np.random.default_rng(worker_info.seed)?
Ah sorry, I didn't see the "and later" part in your last comment.
I don't think that would be consistent behavior, because I want the rngs created to be derived from the root ChunkSampler._rng. Each worker's rng would then also be reproducible from the root rng, and wouldn't have to rely on setting torch's seed.
Although ChunkSampler doesn't provide determinism (the workers' yields are unordered), we can still expect that what each worker produces is reproducible within itself from the main rng we were given.
I have this idea because I am trying to mimic jax's way of handling rngs, which I think is a good approach. In jax we would just use

```python
rng, batch_rng = jax.random.split(rng, 2)
```

This way we have a hierarchy of rngs.
Co-authored-by: Ilan Gold <ilanbassgold@gmail.com>
@ilan-gold @flying-sheep even if we were to use torch.seed, we would need to use it at the root ChunkSampler._rng level again, since we want everything to be reproducible via torch.seed. If we want explicit torch support: when no rng is provided, we could create the base rng from torch.seed if torch is present. But I think it's better to explicitly document: "This loader doesn't use the torch random number generator, therefore setting torch.seed will not affect reproducibility. For reproducibility see ..." Since we take an rng instead of a seed, I think this makes it even clearer.
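The fallback described above could look roughly like this. This is a hypothetical helper, not annbatch API; `make_base_rng` and its behavior are assumptions for illustration only:

```python
import numpy as np


def make_base_rng(rng: "np.random.Generator | None" = None) -> np.random.Generator:
    """Hypothetical: use an explicit Generator if given; otherwise fall back
    to torch's initial seed when torch is installed, else fresh OS entropy."""
    if rng is not None:
        return rng
    try:
        import torch  # optional dependency; only used for seed fallback

        return np.random.default_rng(torch.initial_seed())
    except ImportError:
        return np.random.default_rng()
```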
```python
class WorkerInfo(NamedTuple):
    """Minimal worker info for RNG handling."""

    id: int
    num_workers: int
```
Why do we need our own NamedTuple if torch provides identical properties?
Yeah, I wanted to talk about this. Instead of this, would passing a worker_info_fn to ChunkSampler.__init__ make sense? worker_info_fn would be called and would return maybe a similar typed dict. This way we stay decoupled from torch.
> Instead of this, would passing a worker_info_fn to ChunkSampler.__init__ make sense? worker_info_fn would be called and would return maybe a similar typed dict. This way we stay decoupled from torch.

But isn't the whole point of this PR to be coupled to torch?
I think the codebase is currently coupled to torch internally, because we have self._get_worker_handle calls internally. The fact that we had to create MockWorkerHandles and MockSamplers to have tests is, I think, convincing enough that we can't just minimally inject what torch needs.
Yes, I get that generalizing toward what torch expects isn't that useful, since torch will probably be the only library that leaves worker sharding to the sampler level. I also get that while you are unsure about supporting torch compatibility you don't want to change the interface, so I think we can drop the worker_info_fn idea if so. I just think that even if there is only one possible case, generalizing creates clarity and says "we are doing this for a torch compatibility layer, but if there were other cases we could also support them".
But beyond the worker_info_fn idea, wouldn't a typed dict return type be a better way to explicitly state what we minimally need from torch? Also, we wouldn't have to type it as whatever torch class there is.
wdyt?
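A TypedDict version of the idea might look like this. This is a sketch: `torch_worker_info` and the exact field set are assumptions, mirroring the attributes exposed by torch's `get_worker_info()`:

```python
from __future__ import annotations

from typing import TypedDict


class WorkerInfoDict(TypedDict):
    """Minimal fields the sampler needs from a worker-info provider."""

    id: int
    num_workers: int
    seed: int


def torch_worker_info() -> WorkerInfoDict | None:
    """Hypothetical adapter from torch's worker info to the minimal dict."""
    try:
        from torch.utils.data import get_worker_info
    except ImportError:
        return None  # no torch installed: behave like the main process
    info = get_worker_info()
    if info is None:
        return None  # running in the main process, not inside a worker
    return WorkerInfoDict(id=info.id, num_workers=info.num_workers, seed=info.seed)
```

The sampler would then depend only on this plain dict shape, not on any torch class.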
Please remove and open a separate PR.
Hi @ilan-gold @flying-sheep,
The current code doesn't use spawn and handles worker rngs in a somewhat unconventional way. For example, this line returns the same rng for every worker, because it was used for sharing chunks across workers:
annbatch/src/annbatch/utils.py
Lines 127 to 129 in 9bbaa8b
Another problem is
annbatch/src/annbatch/samplers/_chunk_sampler.py
Lines 164 to 166 in 9bbaa8b
Here is the proposed way to handle rngs now
why we should use spawn: https://numpy.org/doc/stable/reference/random/multithreading.html
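In short, the pattern that page recommends: spawn child `SeedSequence`s so every worker or thread gets a statistically independent stream, all reproducible from one root seed:

```python
import numpy as np

num_workers = 4
root_seq = np.random.SeedSequence(12345)

# one independent child sequence per worker, all derived from seed 12345
child_seqs = root_seq.spawn(num_workers)
streams = [np.random.default_rng(s) for s in child_seqs]
```

Each element of `streams` can be handed to one worker; no two streams overlap, and re-running with the same root seed reproduces all of them.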
Note: we can talk about the details of WorkerHandler.__init__; it depends on the expected usage, I guess. I just wanted to show what I was talking about in the standup.