@@ -103,6 +103,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
103103 assert_size_argument_jax_compatible (node )
104104
105105 def sample_fn (rng , size , dtype , * parameters ):
106+ # PyTensor uses empty size to represent size = None
107+ if jax .numpy .asarray (size ).shape == (0 ,):
108+ size = None
106109 return jax_sample_fn (op )(rng , size , out_dtype , * parameters )
107110
108111 else :
@@ -161,6 +164,8 @@ def sample_fn(rng, size, dtype, *parameters):
161164 rng_key = rng ["jax_state" ]
162165 rng_key , sampling_key = jax .random .split (rng_key , 2 )
163166 loc , scale = parameters
167+ if size is None :
168+ size = jax .numpy .broadcast_arrays (loc , scale )[0 ].shape
164169 sample = loc + jax_op (sampling_key , size , dtype ) * scale
165170 rng ["jax_state" ] = rng_key
166171 return (rng , sample )
@@ -184,15 +189,16 @@ def sample_fn(rng, size, dtype, p):
184189
185190
186191@jax_sample_fn .register (ptr .CategoricalRV )
187- def jax_sample_fn_no_dtype (op ):
188- """Generic JAX implementation of random variables."""
189- name = op .name
190- jax_op = getattr (jax .random , name )
192+ def jax_sample_fn_categorical (op ):
193+ """JAX implementation of `CategoricalRV`."""
191194
192- def sample_fn (rng , size , dtype , * parameters ):
195+ # We need a separate dispatch because Categorical expects logits in JAX
196+ def sample_fn (rng , size , dtype , p ):
193197 rng_key = rng ["jax_state" ]
194198 rng_key , sampling_key = jax .random .split (rng_key , 2 )
195- sample = jax_op (sampling_key , * parameters , shape = size )
199+
200+ logits = jax .scipy .special .logit (p )
201+ sample = jax .random .categorical (sampling_key , logits = logits , shape = size )
196202 rng ["jax_state" ] = rng_key
197203 return (rng , sample )
198204
@@ -243,6 +249,8 @@ def jax_sample_fn_shape_scale(op):
243249 def sample_fn (rng , size , dtype , shape , scale ):
244250 rng_key = rng ["jax_state" ]
245251 rng_key , sampling_key = jax .random .split (rng_key , 2 )
252+ if size is None :
253+ size = jax .numpy .broadcast_arrays (shape , scale )[0 ].shape
246254 sample = jax_op (sampling_key , shape , size , dtype ) * scale
247255 rng ["jax_state" ] = rng_key
248256 return (rng , sample )
@@ -254,10 +262,11 @@ def sample_fn(rng, size, dtype, shape, scale):
254262def jax_sample_fn_exponential (op ):
255263 """JAX implementation of `ExponentialRV`."""
256264
257- def sample_fn (rng , size , dtype , * parameters ):
265+ def sample_fn (rng , size , dtype , scale ):
258266 rng_key = rng ["jax_state" ]
259267 rng_key , sampling_key = jax .random .split (rng_key , 2 )
260- (scale ,) = parameters
268+ if size is None :
269+ size = jax .numpy .asarray (scale ).shape
261270 sample = jax .random .exponential (sampling_key , size , dtype ) * scale
262271 rng ["jax_state" ] = rng_key
263272 return (rng , sample )
@@ -269,14 +278,11 @@ def sample_fn(rng, size, dtype, *parameters):
269278def jax_sample_fn_t (op ):
270279 """JAX implementation of `StudentTRV`."""
271280
272- def sample_fn (rng , size , dtype , * parameters ):
281+ def sample_fn (rng , size , dtype , df , loc , scale ):
273282 rng_key = rng ["jax_state" ]
274283 rng_key , sampling_key = jax .random .split (rng_key , 2 )
275- (
276- df ,
277- loc ,
278- scale ,
279- ) = parameters
284+ if size is None :
285+ size = jax .numpy .broadcast_arrays (df , loc , scale )[0 ].shape
280286 sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
281287 rng ["jax_state" ] = rng_key
282288 return (rng , sample )
0 commit comments