@@ -1958,3 +1958,33 @@ def test_vectorize_over_posterior_matches_sample():
19581958            atol = 0.6  /  np .sqrt (10000 ),
19591959        )
19601960        assert  np .all (np .abs (vect_obs  -  x_posterior [..., None ]) <  1 )
1961+ 
1962+ 
1963+ def  test_vectorize_over_posterior_with_intermediate_rvs ():
1964+     with  pm .Model () as  model :
1965+         a  =  pm .Normal ("a" )
1966+         b  =  pm .Normal .dist (a )
1967+         c  =  b  +  1 
1968+         d  =  pm .Normal .dist (c )
1969+         idata  =  pm .sample_prior_predictive (100 , var_names = ["a" ])
1970+         idata .add_groups ({"posterior" : idata .prior })
1971+     _ , _ , vectorized_no_intermediate  =  vectorize_over_posterior (
1972+         outputs = [b , c , d ],
1973+         posterior = idata .posterior ,
1974+         input_rvs = [a ],
1975+         allow_rvs_in_graph = True ,
1976+     )
1977+     [vectorized_intermediate_rvs ] =  vectorize_over_posterior (
1978+         outputs = [d ],
1979+         posterior = idata .posterior ,
1980+         input_rvs = [a ],
1981+         allow_rvs_in_graph = True ,
1982+     )
1983+     assert  vectorized_no_intermediate .type .shape  ==  (1 , 100 )
1984+     assert  vectorized_no_intermediate .type .shape  ==  vectorized_intermediate_rvs .type .shape 
1985+     [a_ancestor1 ] =  get_var_by_name ([vectorized_no_intermediate ], "a" )
1986+     [a_ancestor2 ] =  get_var_by_name ([vectorized_intermediate_rvs ], "a" )
1987+     assert  isinstance (a_ancestor1 , TensorConstant )
1988+     assert  np .array_equiv (a_ancestor1 .eval (), idata .posterior .a .data )
1989+     assert  isinstance (a_ancestor2 , TensorConstant )
1990+     assert  np .array_equiv (a_ancestor2 .eval (), idata .posterior .a .data )
0 commit comments