| 
3 | 3 | import itertools  | 
4 | 4 | import sys  | 
5 | 5 | from collections import defaultdict  | 
6 |  | -from collections.abc import Iterable, Sequence  | 
 | 6 | +from collections.abc import Generator, Iterable, Sequence  | 
7 | 7 | from contextlib import suppress  | 
8 | 8 | from functools import partial  | 
9 | 9 | from operator import itemgetter  | 
@@ -126,11 +126,10 @@ def __init__(  | 
126 | 126 |         self._cdims_default = cdims  | 
127 | 127 | 
 
  | 
128 | 128 |         if len({learner.__class__ for learner in self.learners}) > 1:  | 
129 |  | -            raise TypeError(  | 
130 |  | -                "A BalacingLearner can handle only one type" " of learners."  | 
131 |  | -            )  | 
 | 129 | +            raise TypeError("A BalacingLearner can handle only one type of learners.")  | 
132 | 130 | 
 
  | 
133 | 131 |         self.strategy: STRATEGY_TYPE = strategy  | 
 | 132 | +        self._gen: Generator | None = None  | 
134 | 133 | 
 
  | 
135 | 134 |     def new(self) -> BalancingLearner:  | 
136 | 135 |         """Create a new `BalancingLearner` with the same parameters."""  | 
@@ -288,27 +287,16 @@ def _ask_and_tell_based_on_cycle(  | 
288 | 287 |     def _ask_and_tell_based_on_sequential(  | 
289 | 288 |         self, n: int  | 
290 | 289 |     ) -> tuple[list[tuple[Int, Any]], list[float]]:  | 
 | 290 | +        if self._gen is None:  | 
 | 291 | +            self._gen = _sequential_generator(self.learners)  | 
291 | 292 |         points: list[tuple[Int, Any]] = []  | 
292 | 293 |         loss_improvements: list[float] = []  | 
293 |  | -        learner_index = 0  | 
294 |  | - | 
295 |  | -        while len(points) < n:  | 
296 |  | -            learner = self.learners[learner_index]  | 
297 |  | -            if learner.done():  # type: ignore[attr-defined]  | 
298 |  | -                if learner_index == len(self.learners) - 1:  | 
299 |  | -                    break  | 
300 |  | -                learner_index += 1  | 
301 |  | -                continue  | 
302 |  | - | 
303 |  | -            point, loss_improvement = learner.ask(n=1)  | 
304 |  | -            if not point:  # if learner is exhausted, we don't get points  | 
305 |  | -                if learner_index == len(self.learners) - 1:  | 
306 |  | -                    break  | 
307 |  | -                learner_index += 1  | 
308 |  | -                continue  | 
309 |  | -            points.append((learner_index, point[0]))  | 
310 |  | -            loss_improvements.append(loss_improvement[0])  | 
311 |  | -            self.tell_pending((learner_index, point[0]))  | 
 | 294 | +        for learner_index, point, loss_improvement in self._gen:  | 
 | 295 | +            points.append((learner_index, point))  | 
 | 296 | +            loss_improvements.append(loss_improvement)  | 
 | 297 | +            self.tell_pending((learner_index, point))  | 
 | 298 | +            if len(points) >= n:  | 
 | 299 | +                break  | 
312 | 300 | 
 
  | 
313 | 301 |         return points, loss_improvements  | 
314 | 302 | 
 
  | 
@@ -629,3 +617,27 @@ def __getstate__(self) -> tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]:  | 
629 | 617 |     def __setstate__(self, state: tuple[list[BaseLearner], CDIMS_TYPE, STRATEGY_TYPE]):  | 
630 | 618 |         learners, cdims, strategy = state  | 
631 | 619 |         self.__init__(learners, cdims=cdims, strategy=strategy)  # type: ignore[misc]  | 
 | 620 | + | 
 | 621 | + | 
 | 622 | +def _sequential_generator(  | 
 | 623 | +    learners: list[BaseLearner],  | 
 | 624 | +) -> Generator[tuple[int, Any, float], None, None]:  | 
 | 625 | +    learner_index = 0  | 
 | 626 | +    if not hasattr(learners[0], "done"):  | 
 | 627 | +        msg = "All learners must have a `done` method to use the 'sequential' strategy."  | 
 | 628 | +        raise ValueError(msg)  | 
 | 629 | +    while True:  | 
 | 630 | +        learner = learners[learner_index]  | 
 | 631 | +        if learner.done():  # type: ignore[attr-defined]  | 
 | 632 | +            if learner_index == len(learners) - 1:  | 
 | 633 | +                return  | 
 | 634 | +            learner_index += 1  | 
 | 635 | +            continue  | 
 | 636 | + | 
 | 637 | +        point, loss_improvement = learner.ask(n=1)  | 
 | 638 | +        if not point:  # if learner is exhausted, we don't get points  | 
 | 639 | +            if learner_index == len(learners) - 1:  | 
 | 640 | +                return  | 
 | 641 | +            learner_index += 1  | 
 | 642 | +            continue  | 
 | 643 | +        yield learner_index, point[0], loss_improvement[0]  | 
0 commit comments