@@ -292,11 +292,11 @@ def sample(
292292 chains : int | None = None ,
293293 cores : int | None = None ,
294294 random_seed : RandomState = None ,
295+ step = None ,
296+ external_sampler : ExternalSampler | None = None ,
295297 progressbar : bool | ProgressBarType = True ,
296298 progressbar_theme : Theme | None = default_progress_theme ,
297- step = None ,
298299 var_names : Sequence [str ] | None = None ,
299- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
300300 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
301301 init : str = "auto" ,
302302 jitter_max_retries : int = 10 ,
@@ -324,11 +324,11 @@ def sample(
324324 chains : int | None = None ,
325325 cores : int | None = None ,
326326 random_seed : RandomState = None ,
327+ step = None ,
328+ external_sampler : ExternalSampler | None = None ,
327329 progressbar : bool | ProgressBarType = True ,
328330 progressbar_theme : Theme | None = default_progress_theme ,
329- step = None ,
330331 var_names : Sequence [str ] | None = None ,
331- nuts_sampler : Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = "pymc" ,
332332 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
333333 init : str = "auto" ,
334334 jitter_max_retries : int = 10 ,
@@ -356,11 +356,11 @@ def sample(
356356 chains : int | None = None ,
357357 cores : int | None = None ,
358358 random_seed : RandomState = None ,
359+ step = None ,
360+ external_sampler : ExternalSampler | None = None ,
359361 progressbar : bool | ProgressBarType = True ,
360362 progressbar_theme : Theme | None = None ,
361- step = None ,
362363 var_names : Sequence [str ] | None = None ,
363- nuts_sampler : None | Literal ["pymc" , "nutpie" , "numpyro" , "blackjax" ] = None ,
364364 initvals : StartDict | Sequence [StartDict | None ] | None = None ,
365365 init : str = "auto" ,
366366 jitter_max_retries : int = 10 ,
@@ -407,6 +407,12 @@ def sample(
407407 A ``TypeError`` will be raised if a legacy :py:class:`~numpy.random.RandomState` object is passed.
408408 We no longer support ``RandomState`` objects because their seeding mechanism does not allow
409409 easy spawning of new independent random streams that are needed by the step methods.
410+ step : function or iterable of functions, optional
411+ A step function or collection of functions. If there are variables without step methods,
412+ step methods for those variables will be assigned automatically. By default the NUTS step
413+ method will be used, if appropriate to the model. Not compatible with external_sampler
414+ external_sampler: ExternalSampler, optional
415+ An external sampler to sample the whole model. Not compatible with step.
410416 progressbar: bool or ProgressType, optional
411417 How and whether to display the progress bar. If False, no progress bar is displayed. Otherwise, you can ask
412418 for one of the following:
@@ -419,10 +425,6 @@ def sample(
419425 are also displayed.
420426
421427 If True, the default is "split+stats" is used.
422- step : function or iterable of functions
423- A step function or collection of functions. If there are variables without step methods,
424- step methods for those variables will be assigned automatically. By default the NUTS step
425- method will be used, if appropriate to the model.
426428 var_names : list of str, optional
427429 Names of variables to be stored in the trace. Defaults to all free variables and deterministics.
428430 nuts_sampler : str
@@ -608,35 +610,38 @@ def joined_blas_limiter():
608610 rngs = get_random_generator (random_seed ).spawn (chains )
609611 random_seed_list = [rng .integers (2 ** 30 ) for rng in rngs ]
610612
611- if step is None and nuts_sampler not in (None , "pymc" ):
612- # Temporarily instantiate external samplers for user, for backwards-compat
613+ if "nuts_sampler" in kwargs :
614+ # Transition backwards-compatibility
615+ nuts_sampler = kwargs .pop ("nuts_sampler" )
613616 warnings .warn (
614617 f"Setting `pm.sample(nuts_sampler='{ nuts_sampler } , nuts_sampler_kwargs=...)'` is deprecated.\n "
615- f"Use `pm.sample(step =pm.external.{ nuts_sampler .capitalize ()} (**nuts_sampler_kwargs))` instead" ,
618+ f"Use `pm.sample(external_sampler =pm.external.{ nuts_sampler .capitalize ()} (**nuts_sampler_kwargs))` instead" ,
616619 FutureWarning ,
617620 )
618621 from pymc .sampling import external
619622
620- step = getattr (external , nuts_sampler .capitalize ())(
623+ external_sampler = getattr (external , nuts_sampler .capitalize ())(
621624 model = model ,
622625 ** (nuts_sampler_kwargs or {}),
623626 )
624627 nuts_sampler_kwargs = None
625628
626- if isinstance (step , list | tuple ) and len (step ) == 1 :
627- [step ] = step
629+ if external_sampler is not None :
630+ if step is not None :
631+ raise ValueError ("`step` and `external_sampler` cannot be used together" )
628632
629- if isinstance (step , ExternalSampler ):
630- if step .model is not model :
631- raise ValueError ("External step model does not match model detected by sample" )
633+ if external_sampler .model is not model :
634+ raise ValueError (
635+ "External sampler model does not match model detected by sample function"
636+ )
632637 if nuts_sampler_kwargs :
633638 raise ValueError (
634639 f"{ nuts_sampler_kwargs = } should be passed when constructing external sampler"
635640 )
636641 if "nuts" in kwargs :
637- kwargs .update (kwargs [ "nuts" ] .pop ())
642+ kwargs .update (kwargs .pop ("nuts" ))
638643 with joined_blas_limiter ():
639- return step .sample (
644+ return external_sampler .sample (
640645 tune = tune ,
641646 draws = draws ,
642647 chains = chains ,
0 commit comments