Description
In the following graph:
```python
import pymc as pm

with pm.Model(check_bounds=True) as m:
    sigma = pm.HalfNormal("sigma")
    y = pm.Normal("y", sigma=sigma)

m.compile_logp().fn.dprint()
```

There ends up being a `switch(exp(i0) >= 0, something, -inf)` that disappears with `check_bounds=False` (one other stays regardless):
```
├─ Switch [id E]
│ ├─ GE [id F]
│ │ ├─ exp [id G] 't1'
│ │ │ └─ i0 [id H]
│ │ └─ 0.0 [id I]
│ ├─ add [id J]
│ │ ├─ -0.22579135264472738 [id K]
│ │ └─ mul [id L]
│ │ ├─ -0.5 [id M]
│ │ └─ sqr [id N]
│ │ └─ exp [id G] 't1'
│ │ └─ ···
│ └─ -inf [id O]
```
We could rewrite `exp(x) >= 0 -> True` in PyTensor. The one case where this is not correct is when `x` is NaN to begin with. We could tag such a rewrite as nan-unsafe, or we could promise PyTensor that we will never propose NaN for `x`. I prefer the first approach.
But some of the switches are more complex, such as the one introduced by `IntervalTransform(0, 1)` for unit variables. We could, however, check that `rvs_to_transforms` holds the right transform that ensures the constraint, and remove the check the same way `check_bounds` does. It may be easier to do that than to verify that the computational graph itself implies the constraints.
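To see why the transform makes the bound check redundant, here is a minimal numpy sketch. It assumes the backward map of a (0, 1) interval transform is a rescaled logistic sigmoid (the real implementation is more general): every finite unconstrained value maps strictly inside the interval, so the check can never fire.

```python
# Minimal numpy sketch (assumption: the backward map of an interval
# transform on (0, 1) is the logistic sigmoid; PyMC's implementation is
# more general). Any finite value proposed on the unconstrained space
# lands strictly inside (lower, upper), so the bounds switch is redundant.
import numpy as np


def interval_backward(z, lower=0.0, upper=1.0):
    # Logistic sigmoid maps R -> (0, 1); rescale to (lower, upper).
    return lower + (upper - lower) / (1.0 + np.exp(-z))


z = np.linspace(-30, 30, 101)
vals = interval_backward(z)
assert np.all((vals > 0.0) & (vals < 1.0))
```

If `rvs_to_transforms` confirms the variable uses such a transform, the `switch(..., -inf)` guard adds nothing but graph bloat.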
These points are admittedly a bit myopic. The general idea is to make `check_bounds` redundant whenever a graph is well defined, and ultimately remove that flag from the model, streamlining the UX.