4747 ignore_logprob ,
4848 logcdf ,
4949 logp ,
50+ reconsider_logprob ,
5051)
5152from pymc .logprob .abstract import get_measurable_outputs
5253from pymc .model import Model , Potential
@@ -315,7 +316,7 @@ def test_unexpected_rvs():
315316 model .logp ()
316317
317318
318- def test_ignore_logprob_basic ():
319+ def test_ignore_reconsider_logprob_basic ():
319320 x = Normal .dist ()
320321 (measurable_x_out ,) = get_measurable_outputs (x .owner .op , x .owner )
321322 assert measurable_x_out is x .owner .outputs [1 ]
@@ -328,18 +329,34 @@ def test_ignore_logprob_basic():
328329 assert get_measurable_outputs (new_x .owner .op , new_x .owner ) == []
329330
330331 # Test that it will not clone a variable that is already unmeasurable
331- new_new_x = ignore_logprob (new_x )
332- assert new_new_x is new_x
333-
334-
335- def test_ignore_logprob_model ():
336- # logp that does not depend on input
337- def logp (value , x ):
338- return value
332+ assert ignore_logprob (new_x ) is new_x
333+
334+ orig_x = reconsider_logprob (new_x )
335+ assert orig_x is not new_x
336+ assert isinstance (orig_x .owner .op , Normal )
337+ assert type (orig_x .owner .op ).__name__ == "NormalRV"
338+ # Confirm that it has measurable outputs again
339+ assert get_measurable_outputs (orig_x .owner .op , orig_x .owner ) == [orig_x .owner .outputs [1 ]]
340+
341+ # Test that will not clone a variable that is already measurable
342+ assert reconsider_logprob (x ) is x
343+ assert reconsider_logprob (orig_x ) is orig_x
344+
345+
346+ def test_ignore_reconsider_logprob_model ():
347+ def custom_logp (value , x ):
348+ # custom_logp is just the logp of x at value
349+ x = reconsider_logprob (x )
350+ return _joint_logp (
351+ [x ],
352+ rvs_to_values = {x : value },
353+ rvs_to_transforms = {},
354+ rvs_to_total_sizes = {},
355+ )
339356
340357 with Model () as m :
341358 x = Normal .dist ()
342- y = CustomDist ("y" , x , logp = logp )
359+ y = CustomDist ("y" , x , logp = custom_logp )
343360 with pytest .warns (
344361 UserWarning ,
345362 match = "Found a random variable that was neither among the observations "
@@ -355,7 +372,7 @@ def logp(value, x):
355372 # The above warning should go away with ignore_logprob.
356373 with Model () as m :
357374 x = ignore_logprob (Normal .dist ())
358- y = CustomDist ("y" , x , logp = logp )
375+ y = CustomDist ("y" , x , logp = custom_logp )
359376 with warnings .catch_warnings ():
360377 warnings .simplefilter ("error" )
361378 assert _joint_logp (
0 commit comments