2020from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
2121
2222
23- def random_function (* args , ** kwargs ):
23+ def compile_random_function (* args , ** kwargs ):
2424 with pytest .warns (
2525 UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
2626 ):
@@ -35,7 +35,7 @@ def test_random_RandomStream():
3535 srng = RandomStream (seed = 123 )
3636 out = srng .normal () - srng .normal ()
3737
38- fn = random_function ([], out , mode = jax_mode )
38+ fn = compile_random_function ([], out , mode = jax_mode )
3939 jax_res_1 = fn ()
4040 jax_res_2 = fn ()
4141
@@ -48,7 +48,7 @@ def test_random_updates(rng_ctor):
4848 rng = shared (original_value , name = "original_rng" , borrow = False )
4949 next_rng , x = pt .random .normal (name = "x" , rng = rng ).owner .outputs
5050
51- f = random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
51+ f = compile_random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
5252 assert f () != f ()
5353
5454 # Check that original rng variable content was not overwritten when calling jax_typify
@@ -79,7 +79,7 @@ def test_random_updates_input_storage_order():
7979 # This function replaces inp by input_shared in the update expression
8080 # This is what caused the RNG to appear later than inp_shared in the input_storage
8181
82- fn = random_function (
82+ fn = compile_random_function (
8383 inputs = [],
8484 outputs = [],
8585 updates = {inp_shared : inp_update },
@@ -453,7 +453,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
453453 else :
454454 rng = shared (np .random .RandomState (29402 ))
455455 g = rv_op (* dist_params , size = (10_000 ,) + base_size , rng = rng )
456- g_fn = random_function (dist_params , g , mode = jax_mode )
456+ g_fn = compile_random_function (dist_params , g , mode = jax_mode )
457457 samples = g_fn (
458458 * [
459459 i .tag .test_value
@@ -477,7 +477,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
477477def test_random_bernoulli (size ):
478478 rng = shared (np .random .RandomState (123 ))
479479 g = pt .random .bernoulli (0.5 , size = (1000 ,) + size , rng = rng )
480- g_fn = random_function ([], g , mode = jax_mode )
480+ g_fn = compile_random_function ([], g , mode = jax_mode )
481481 samples = g_fn ()
482482 np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
483483
@@ -488,7 +488,7 @@ def test_random_mvnormal():
488488 mu = np .ones (4 )
489489 cov = np .eye (4 )
490490 g = pt .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
491- g_fn = random_function ([], g , mode = jax_mode )
491+ g_fn = compile_random_function ([], g , mode = jax_mode )
492492 samples = g_fn ()
493493 np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
494494
@@ -503,7 +503,7 @@ def test_random_mvnormal():
503503def test_random_dirichlet (parameter , size ):
504504 rng = shared (np .random .RandomState (123 ))
505505 g = pt .random .dirichlet (parameter , size = (1000 ,) + size , rng = rng )
506- g_fn = random_function ([], g , mode = jax_mode )
506+ g_fn = compile_random_function ([], g , mode = jax_mode )
507507 samples = g_fn ()
508508 np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
509509
@@ -513,29 +513,29 @@ def test_random_choice():
513513 num_samples = 10000
514514 rng = shared (np .random .RandomState (123 ))
515515 g = pt .random .choice (np .arange (4 ), size = num_samples , rng = rng )
516- g_fn = random_function ([], g , mode = jax_mode )
516+ g_fn = compile_random_function ([], g , mode = jax_mode )
517517 samples = g_fn ()
518518 np .testing .assert_allclose (np .sum (samples == 3 ) / num_samples , 0.25 , 2 )
519519
520520 # `replace=False` produces unique results
521521 rng = shared (np .random .RandomState (123 ))
522522 g = pt .random .choice (np .arange (100 ), replace = False , size = 99 , rng = rng )
523- g_fn = random_function ([], g , mode = jax_mode )
523+ g_fn = compile_random_function ([], g , mode = jax_mode )
524524 samples = g_fn ()
525525 assert len (np .unique (samples )) == 99
526526
527527 # We can pass an array with probabilities
528528 rng = shared (np .random .RandomState (123 ))
529529 g = pt .random .choice (np .arange (3 ), p = np .array ([1.0 , 0.0 , 0.0 ]), size = 10 , rng = rng )
530- g_fn = random_function ([], g , mode = jax_mode )
530+ g_fn = compile_random_function ([], g , mode = jax_mode )
531531 samples = g_fn ()
532532 np .testing .assert_allclose (samples , np .zeros (10 ))
533533
534534
535535def test_random_categorical ():
536536 rng = shared (np .random .RandomState (123 ))
537537 g = pt .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
538- g_fn = random_function ([], g , mode = jax_mode )
538+ g_fn = compile_random_function ([], g , mode = jax_mode )
539539 samples = g_fn ()
540540 np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
541541
@@ -544,7 +544,7 @@ def test_random_permutation():
544544 array = np .arange (4 )
545545 rng = shared (np .random .RandomState (123 ))
546546 g = pt .random .permutation (array , rng = rng )
547- g_fn = random_function ([], g , mode = jax_mode )
547+ g_fn = compile_random_function ([], g , mode = jax_mode )
548548 permuted = g_fn ()
549549 with pytest .raises (AssertionError ):
550550 np .testing .assert_allclose (array , permuted )
@@ -554,7 +554,7 @@ def test_random_geometric():
554554 rng = shared (np .random .RandomState (123 ))
555555 p = np .array ([0.3 , 0.7 ])
556556 g = pt .random .geometric (p , size = (10_000 , 2 ), rng = rng )
557- g_fn = random_function ([], g , mode = jax_mode )
557+ g_fn = compile_random_function ([], g , mode = jax_mode )
558558 samples = g_fn ()
559559 np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
560560 np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -565,7 +565,7 @@ def test_negative_binomial():
565565 n = np .array ([10 , 40 ])
566566 p = np .array ([0.3 , 0.7 ])
567567 g = pt .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
568- g_fn = random_function ([], g , mode = jax_mode )
568+ g_fn = compile_random_function ([], g , mode = jax_mode )
569569 samples = g_fn ()
570570 np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
571571 np .testing .assert_allclose (
@@ -579,7 +579,7 @@ def test_binomial():
579579 n = np .array ([10 , 40 ])
580580 p = np .array ([0.3 , 0.7 ])
581581 g = pt .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
582- g_fn = random_function ([], g , mode = jax_mode )
582+ g_fn = compile_random_function ([], g , mode = jax_mode )
583583 samples = g_fn ()
584584 np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
585585 np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -594,7 +594,7 @@ def test_beta_binomial():
594594 a = np .array ([1.5 , 13 ])
595595 b = np .array ([0.5 , 9 ])
596596 g = pt .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
597- g_fn = random_function ([], g , mode = jax_mode )
597+ g_fn = compile_random_function ([], g , mode = jax_mode )
598598 samples = g_fn ()
599599 np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
600600 np .testing .assert_allclose (
@@ -612,7 +612,7 @@ def test_multinomial():
612612 n = np .array ([10 , 40 ])
613613 p = np .array ([[0.3 , 0.7 , 0.0 ], [0.1 , 0.4 , 0.5 ]])
614614 g = pt .random .multinomial (n , p , size = (10_000 , 2 ), rng = rng )
615- g_fn = random_function ([], g , mode = jax_mode )
615+ g_fn = compile_random_function ([], g , mode = jax_mode )
616616 samples = g_fn ()
617617 np .testing .assert_allclose (samples .mean (axis = 0 ), n [..., None ] * p , rtol = 0.1 )
618618 np .testing .assert_allclose (
@@ -628,7 +628,7 @@ def test_vonmises_mu_outside_circle():
628628 mu = np .array ([- 30 , 40 ])
629629 kappa = np .array ([100 , 10 ])
630630 g = pt .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
631- g_fn = random_function ([], g , mode = jax_mode )
631+ g_fn = compile_random_function ([], g , mode = jax_mode )
632632 samples = g_fn ()
633633 np .testing .assert_allclose (
634634 samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -728,15 +728,15 @@ def test_random_concrete_shape():
728728 rng = shared (np .random .RandomState (123 ))
729729 x_pt = pt .dmatrix ()
730730 out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
731- jax_fn = random_function ([x_pt ], out , mode = jax_mode )
731+ jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
732732 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
733733
734734
735735def test_random_concrete_shape_from_param ():
736736 rng = shared (np .random .RandomState (123 ))
737737 x_pt = pt .dmatrix ()
738738 out = pt .random .normal (x_pt , 1 , rng = rng )
739- jax_fn = random_function ([x_pt ], out , mode = jax_mode )
739+ jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
740740 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
741741
742742
@@ -755,7 +755,7 @@ def test_random_concrete_shape_subtensor():
755755 rng = shared (np .random .RandomState (123 ))
756756 x_pt = pt .dmatrix ()
757757 out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
758- jax_fn = random_function ([x_pt ], out , mode = jax_mode )
758+ jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
759759 assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
760760
761761
@@ -771,7 +771,7 @@ def test_random_concrete_shape_subtensor_tuple():
771771 rng = shared (np .random .RandomState (123 ))
772772 x_pt = pt .dmatrix ()
773773 out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
774- jax_fn = random_function ([x_pt ], out , mode = jax_mode )
774+ jax_fn = compile_random_function ([x_pt ], out , mode = jax_mode )
775775 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
776776
777777
@@ -782,5 +782,5 @@ def test_random_concrete_shape_graph_input():
782782 rng = shared (np .random .RandomState (123 ))
783783 size_pt = pt .scalar ()
784784 out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
785- jax_fn = random_function ([size_pt ], out , mode = jax_mode )
785+ jax_fn = compile_random_function ([size_pt ], out , mode = jax_mode )
786786 assert jax_fn (10 ).shape == (10 ,)
0 commit comments