@@ -146,6 +146,12 @@ class BaseRunner(metaclass=abc.ABCMeta):
146146 the point is present in ``runner.failed``.
147147 raise_if_retries_exceeded : bool, default: True
148148 Raise the error after a point ``x`` failed `retries`.
149+ dynamic_args_provider : callable, optional
150+ A callable that takes the learner as its sole argument and returns additional
151+ arguments to pass to the function being learned. This allows you to dynamically
152+ adjust parameters of the function based on the current state of the learner.
153+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
154+ instead of just `function(x)`.
149155 allow_running_forever : bool, default: False
150156 Allow the runner to run forever when the goal is None.
151157
@@ -188,6 +194,7 @@ def __init__(
188194 shutdown_executor : bool = False ,
189195 retries : int = 0 ,
190196 raise_if_retries_exceeded : bool = True ,
197+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
191198 allow_running_forever : bool = False ,
192199 ):
193200 self .executor = _ensure_executor (executor )
@@ -228,6 +235,8 @@ def __init__(
228235 next , itertools .count ()
229236 ) # some unique id to be associated with each point
230237
238+ self .dynamic_args_provider = dynamic_args_provider
239+
231240 def _get_max_tasks (self ) -> int :
232241 return self ._max_tasks or _get_ncores (self .executor )
233242
@@ -432,6 +441,12 @@ class BlockingRunner(BaseRunner):
432441 the point is present in ``runner.failed``.
433442 raise_if_retries_exceeded : bool, default: True
434443 Raise the error after a point ``x`` failed `retries`.
444+ dynamic_args_provider : callable, optional
445+ A callable that takes the learner as its sole argument and returns additional
446+ arguments to pass to the function being learned. This allows you to dynamically
447+ adjust parameters of the function based on the current state of the learner.
448+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
449+ instead of just `function(x)`.
435450
436451 Attributes
437452 ----------
@@ -476,6 +491,7 @@ def __init__(
476491 shutdown_executor : bool = False ,
477492 retries : int = 0 ,
478493 raise_if_retries_exceeded : bool = True ,
494+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
479495 ) -> None :
480496 if inspect .iscoroutinefunction (learner .function ):
481497 raise ValueError ("Coroutine functions can only be used with 'AsyncRunner'." )
@@ -497,7 +513,14 @@ def __init__(
497513 self ._run ()
498514
499515 def _submit (self , x : tuple [float , ...] | float | int ) -> FutureTypes :
500- return self .executor .submit (self .learner .function , x )
516+ args = (
517+ (x ,)
518+ if not self .dynamic_args_provider
519+ else (x , self .dynamic_args_provider (self .learner ))
520+ )
521+ if self .dynamic_args_provider :
522+ return self .executor .submit (self .learner .function , * args )
523+ return self .executor .submit (self .learner .function , * args )
501524
502525 def _run (self ) -> None :
503526 first_completed = concurrent .FIRST_COMPLETED
@@ -582,8 +605,12 @@ class AsyncRunner(BaseRunner):
582605 the point is present in ``runner.failed``.
583606 raise_if_retries_exceeded : bool, default: True
584607 Raise the error after a point ``x`` failed `retries`.
585- allow_running_forever : bool, default: True
586- If True, the runner will run forever if the goal is not provided.
608+ dynamic_args_provider : callable, optional
609+ A callable that takes the learner as its sole argument and returns additional
610+ arguments to pass to the function being learned. This allows you to dynamically
611+ adjust parameters of the function based on the current state of the learner.
612+ If provided, the function will be called as `function(x, dynamic_args_provider(learner))`
613+ instead of just `function(x)`.
587614
588615 Attributes
589616 ----------
@@ -636,6 +663,7 @@ def __init__(
636663 ioloop = None ,
637664 retries : int = 0 ,
638665 raise_if_retries_exceeded : bool = True ,
666+ dynamic_args_provider : Callable [[LearnerType ], Any ] | None = None ,
639667 ) -> None :
640668 if (
641669 executor is None
@@ -666,6 +694,7 @@ def __init__(
666694 shutdown_executor = shutdown_executor ,
667695 retries = retries ,
668696 raise_if_retries_exceeded = raise_if_retries_exceeded ,
697+ dynamic_args_provider = dynamic_args_provider ,
669698 allow_running_forever = True ,
670699 )
671700 self .ioloop = ioloop or asyncio .get_event_loop ()
@@ -694,10 +723,15 @@ def __init__(
694723
695724 def _submit (self , x : Any ) -> asyncio .Task | asyncio .Future :
696725 ioloop = self .ioloop
726+ args = (
727+ (x ,)
728+ if not self .dynamic_args_provider
729+ else (x , self .dynamic_args_provider (self .learner ))
730+ )
697731 if inspect .iscoroutinefunction (self .learner .function ):
698- return ioloop .create_task (self .learner .function (x ))
732+ return ioloop .create_task (self .learner .function (* args ))
699733 else :
700- return ioloop .run_in_executor (self .executor , self .learner .function , x )
734+ return ioloop .run_in_executor (self .executor , self .learner .function , * args )
701735
702736 def status (self ) -> str :
703737 """Return the runner status as a string.
0 commit comments