2121)
2222from pymc .exceptions import ImputationWarning
2323
24+ # Turn all warnings into errors for this module
25+ pytestmark = pytest .mark .filterwarnings ("error" )
26+
2427
2528@pytest .fixture (scope = "module" )
2629def eight_schools_params ():
@@ -635,7 +638,9 @@ def test_include_transformed(self):
635638 pm .Uniform ("p" , 0 , 1 )
636639
637640 # First check that the default is to exclude the transformed variables
638- sample_kwargs = dict (tune = 5 , draws = 7 , chains = 2 , cores = 1 )
641+ sample_kwargs = dict (
642+ tune = 5 , draws = 7 , chains = 2 , cores = 1 , compute_convergence_checks = False
643+ )
639644 inference_data = pm .sample (** sample_kwargs , step = pm .Metropolis ())
640645 assert "p_interval__" not in inference_data .posterior
641646
@@ -647,6 +652,17 @@ def test_include_transformed(self):
647652 )
648653 assert "p_interval__" in inference_data .posterior
649654
655+ @pytest .mark .parametrize ("chains" , (1 , 2 ))
656+ def test_single_chain (self , chains ):
657+ # Test that no UserWarning is raised when sampling with NUTS defaults
658+
659+ # When this test was added, a `UserWarning: More chains (500) than draws (1)` used to be issued
660+ # when sampling with a single chain
661+ warnings .simplefilter ("error" )
662+ with pm .Model ():
663+ pm .Normal ("x" )
664+ pm .sample (chains = chains , return_inferencedata = True )
665+
650666
651667class TestPyMCWarmupHandling :
652668 @pytest .mark .parametrize ("save_warmup" , [False , True ])
0 commit comments