feat: ChunkSampler with replacement when n_iters is set#143
feat: ChunkSampler with replacement when n_iters is set#143selmanozleyen wants to merge 27 commits intoscverse:mainfrom
n_iters is set#143Conversation
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #143 +/- ##
==========================================
- Coverage 93.71% 91.88% -1.83%
==========================================
Files 11 11
Lines 811 863 +52
==========================================
+ Hits 760 793 +33
- Misses 51 70 +19
🚀 New features to boost your workflow:
|
no need to dynamically compute
| Must be ``False`` when ``n_iters`` is set. | ||
| n_iters | ||
| If set, enables with-replacement sampling for exactly this many | ||
| batches instead of epoch-based iteration. |
There was a problem hiding this comment.
Probably makes sense to make this explicit i.e., n_iters can override without-replacement sampling behavior, drop_last should be renamed to drop_undersized (we can keep the top-level Loader API of drop_last), and then we have a with_replacement arg
There was a problem hiding this comment.
Then with_replacement would need to be None by default in the signature so that it resolves to True when n_iters is None.
| last = chunk_pool[-1] | ||
| if last.stop - last.start < self._chunk_size: | ||
| new_stop = min(last.start + self._chunk_size, stop) | ||
| new_start = new_stop - self._chunk_size | ||
| chunk_pool[-1] = slice(new_start, new_stop) |
There was a problem hiding this comment.
And then this just becomes drop the last one or not, which is also the smallest
| num_workers = worker_handle.num_workers | ||
| worker_id = worker_handle._worker_info.id | ||
| base, remainder = divmod(n_iters, num_workers) | ||
| n_iters = base + (1 if worker_id < remainder else 0) |
There was a problem hiding this comment.
Very tempted to just error here i.e., we only support torch in the basic chunking use-case. I don't think we should be encouraging people to do this but @felix0097 am curious what you think
for more information, see https://pre-commit.ci
for more information, see https://pre-commit.ci
| If ``None``, it's set to True if ``n_iters`` is provided, otherwise False. | ||
| n_iters | ||
| Number of batches to yield. Required when ``with_replacement`` is True. | ||
| Can't be provided if with_replacement is False. |
There was a problem hiding this comment.
This isn't what I meant, sorry for being unclear - if n_iters is passed in, we just use that instead of deriving from n_obs to define the epoch. This way, there's no "default" behavior switching i.e., If ``None``, it's set to True if ``n_iters`` is provided, otherwise False. is unnecessary. To highlight
n_itersis set, then independent ofwith_replacementthat number of iterations is donewith_replacement=Truemeans thatn_itershas to be setwith_replacement=Falsecan usen_itersif it's set.
Does this make sense? I've proposed three changes that highlight this logic
There was a problem hiding this comment.
I am not sure, as without replacement would mean there is no duplicate of what you see also right? If you think its fine then I will do it. I though of just supporting truncating but then we would need to error when possible_n_iters < n_iters.
There was a problem hiding this comment.
I am not sure, as without replacement would mean there is no duplicate of what you see also right? If you think its fine then I will do it. I though of just supporting truncating but then we would need to error when possible_n_iters < n_iters.
This seems reasonable - I hadn't considered it, but it could just be part of validation to check n_iters against n_obs if sampling without replacement.
| start, stop = self._mask.start or 0, self._mask.stop or n_obs | ||
| total_obs = stop - start | ||
| return total_obs // self._batch_size if self._drop_last else math.ceil(total_obs / self._batch_size) | ||
| return self._possible_n_iters(n_obs) if not self._with_replacement else self._n_iters |
There was a problem hiding this comment.
Then this line is clearer because you wouldn't rely on _with_replacement but instead just _n_iters
| return self._possible_n_iters(n_obs) if not self._with_replacement else self._n_iters | |
| return self._possible_n_iters(n_obs) if not self._n_iters is None else self._n_iters |
| if with_replacement is None: | ||
| with_replacement = n_iters is not None | ||
| if with_replacement and n_iters is None: | ||
| raise ValueError("n_iters is required when with_replacement is True.") | ||
| if not with_replacement and n_iters is not None: | ||
| raise ValueError("n_iters is only supported when with_replacement is True.") | ||
| self._n_iters, self._with_replacement = n_iters, with_replacement |
There was a problem hiding this comment.
| if with_replacement is None: | |
| with_replacement = n_iters is not None | |
| if with_replacement and n_iters is None: | |
| raise ValueError("n_iters is required when with_replacement is True.") | |
| if not with_replacement and n_iters is not None: | |
| raise ValueError("n_iters is only supported when with_replacement is True.") | |
| self._n_iters, self._with_replacement = n_iters, with_replacement | |
| if with_replacement and n_iters is None: | |
| raise ValueError("n_iters is required when with_replacement is True.") | |
| self._n_iters, self._with_replacement = n_iters, with_replacement |
| shuffle: bool = False, | ||
| drop_last: bool = False, | ||
| drop_undersized: bool = False, | ||
| with_replacement: bool | None = None, |
There was a problem hiding this comment.
| with_replacement: bool | None = None, | |
| with_replacement: bool = False, |
we decided it might be smart to add with replacement sampling to continue for the categorical sampler: #119 (comment)