2626from jax import numpy as jnp
2727import jaxtyping as jt
2828from torax ._src import array_typing
29+ from torax ._src import jax_utils
2930import typing_extensions
3031
3132
def _zero() -> array_typing.FloatScalar:
  """Returns a scalar zero as a jax Array.

  Uses the project-wide default float dtype so that default boundary
  constraints match the precision of the rest of the arrays.
  """
  return jnp.asarray(0.0, dtype=jax_utils.get_dtype())
3536
3637
37- @chex .dataclass (frozen = True )
38+ @array_typing .jaxtyped
39+ @jax .tree_util .register_dataclass
40+ @dataclasses .dataclass (frozen = True )
3841class CellVariable :
3942 """A variable representing values of the cells along the radius.
4043
@@ -53,15 +56,14 @@ class CellVariable:
5356 of the gradient on the rightmost face variable.
5457 """
5558
56- # t* means match 0 or more leading time dimensions.
57- value : jt .Float [chex .Array , 't* cell' ]
58- dr : jt .Float [chex .Array , 't*' ]
59- left_face_constraint : jt .Float [chex .Array , 't*' ] | None = None
60- right_face_constraint : jt .Float [chex .Array , 't*' ] | None = None
61- left_face_grad_constraint : jt .Float [chex .Array , 't*' ] | None = (
59+ value : jt .Float [jax .Array , 'cell' ]
60+ dr : array_typing .FloatScalar
61+ left_face_constraint : array_typing .FloatScalar | None = None
62+ right_face_constraint : array_typing .FloatScalar | None = None
63+ left_face_grad_constraint : array_typing .FloatScalar | None = (
6264 dataclasses .field (default_factory = _zero )
6365 )
64- right_face_grad_constraint : jt . Float [ chex . Array , 't*' ] | None = (
66+ right_face_grad_constraint : array_typing . FloatScalar | None = (
6567 dataclasses .field (default_factory = _zero )
6668 )
6769 # Can't make the above default values be jax zeros because that would be a
@@ -119,21 +121,9 @@ def __post_init__(self):
119121 'right_face_grad_constraint must be set.'
120122 )
121123
122- def _assert_unbatched (self ):
123- if len (self .value .shape ) != 1 :
124- raise AssertionError (
125- 'CellVariable must be unbatched, but has `value` shape '
126- f'{ self .value .shape } . Consider using vmap to batch the function call.'
127- )
128- if self .dr .shape :
129- raise AssertionError (
130- 'CellVariable must be unbatched, but has `dr` shape '
131- f'{ self .dr .shape } . Consider using vmap to batch the function call.'
132- )
133-
134124 def face_grad (
135- self , x : jt .Float [chex .Array , 'cell' ] | None = None
136- ) -> jt .Float [chex .Array , 'face' ]:
125+ self , x : jt .Float [array_typing .Array , 'cell' ] | None = None
126+ ) -> jt .Float [jax .Array , 'face' ]:
137127 """Returns the gradient of this value with respect to the faces.
138128
139129 Implemented using forward differencing of cells. Leftmost and rightmost
@@ -146,7 +136,6 @@ def face_grad(
146136 Returns:
147137 A jax.Array of shape (num_faces,) containing the gradient.
148138 """
149- self ._assert_unbatched ()
150139 if x is None :
151140 forward_difference = jnp .diff (self .value ) / self .dr
152141 else :
@@ -194,7 +183,7 @@ def constrained_grad(
194183 right = jnp .expand_dims (right_grad , axis = 0 )
195184 return jnp .concatenate ([left , forward_difference , right ])
196185
197- def _left_face_value (self ) -> jt .Float [chex .Array , '#t' ]:
186+ def _left_face_value (self ) -> jt .Float [jax .Array , '#t' ]:
198187 """Calculates the value of the leftmost face."""
199188 if self .left_face_constraint is not None :
200189 value = self .left_face_constraint
@@ -206,7 +195,7 @@ def _left_face_value(self) -> jt.Float[chex.Array, '#t']:
206195 value = self .value [..., 0 :1 ]
207196 return value
208197
209- def _right_face_value (self ) -> jt .Float [chex .Array , '#t' ]:
198+ def _right_face_value (self ) -> jt .Float [jax .Array , '#t' ]:
210199 """Calculates the value of the rightmost face."""
211200 if self .right_face_constraint is not None :
212201 value = self .right_face_constraint
@@ -222,14 +211,14 @@ def _right_face_value(self) -> jt.Float[chex.Array, '#t']:
222211 )
223212 return value
224213
225- def face_value (self ) -> jt .Float [jax .Array , 't* face' ]:
214+ def face_value (self ) -> jt .Float [jax .Array , 'face' ]:
226215 """Calculates values of this variable on the face grid."""
227216 inner = (self .value [..., :- 1 ] + self .value [..., 1 :]) / 2.0
228217 return jnp .concatenate (
229218 [self ._left_face_value (), inner , self ._right_face_value ()], axis = - 1
230219 )
231220
232- def grad (self ) -> jt .Float [jax .Array , 't* face' ]:
221+ def grad (self ) -> jt .Float [jax .Array , 'face' ]:
233222 """Returns the gradient of this variable wrt cell centers."""
234223 face = self .face_value ()
235224 return jnp .diff (face ) / jnp .expand_dims (self .dr , axis = - 1 )
@@ -251,7 +240,7 @@ def __str__(self) -> str:
251240 output_string += ')'
252241 return output_string
253242
254- def cell_plus_boundaries (self ) -> jt .Float [jax .Array , 't* cell+2' ]:
243+ def cell_plus_boundaries (self ) -> jt .Float [jax .Array , 'cell+2' ]:
255244 """Returns the value of this variable plus left and right boundaries."""
256245 right_value = self ._right_face_value ()
257246 left_value = self ._left_face_value ()
0 commit comments