Skip to content

Commit 119bc9f

Browse files
sbodensteinTorax team
authored andcommitted
Remove batch dimension from CellVariable.
PiperOrigin-RevId: 795051091
1 parent 1c3c89e commit 119bc9f

File tree

1 file changed

+18
-29
lines changed

1 file changed

+18
-29
lines changed

torax/_src/fvm/cell_variable.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@
2626
from jax import numpy as jnp
2727
import jaxtyping as jt
2828
from torax._src import array_typing
29+
from torax._src import jax_utils
2930
import typing_extensions
3031

3132

3233
def _zero() -> array_typing.FloatScalar:
3334
"""Returns a scalar zero as a jax Array."""
34-
return jnp.zeros(())
35+
return jnp.zeros((), dtype=jax_utils.get_dtype())
3536

3637

37-
@chex.dataclass(frozen=True)
38+
@array_typing.jaxtyped
39+
@jax.tree_util.register_dataclass
40+
@dataclasses.dataclass(frozen=True)
3841
class CellVariable:
3942
"""A variable representing values of the cells along the radius.
4043
@@ -53,15 +56,14 @@ class CellVariable:
5356
of the gradient on the rightmost face variable.
5457
"""
5558

56-
# t* means match 0 or more leading time dimensions.
57-
value: jt.Float[chex.Array, 't* cell']
58-
dr: jt.Float[chex.Array, 't*']
59-
left_face_constraint: jt.Float[chex.Array, 't*'] | None = None
60-
right_face_constraint: jt.Float[chex.Array, 't*'] | None = None
61-
left_face_grad_constraint: jt.Float[chex.Array, 't*'] | None = (
59+
value: jt.Float[jax.Array, 'cell']
60+
dr: array_typing.FloatScalar
61+
left_face_constraint: array_typing.FloatScalar | None = None
62+
right_face_constraint: array_typing.FloatScalar | None = None
63+
left_face_grad_constraint: array_typing.FloatScalar | None = (
6264
dataclasses.field(default_factory=_zero)
6365
)
64-
right_face_grad_constraint: jt.Float[chex.Array, 't*'] | None = (
66+
right_face_grad_constraint: array_typing.FloatScalar | None = (
6567
dataclasses.field(default_factory=_zero)
6668
)
6769
# Can't make the above default values be jax zeros because that would be a
@@ -119,21 +121,9 @@ def __post_init__(self):
119121
'right_face_grad_constraint must be set.'
120122
)
121123

122-
def _assert_unbatched(self):
123-
if len(self.value.shape) != 1:
124-
raise AssertionError(
125-
'CellVariable must be unbatched, but has `value` shape '
126-
f'{self.value.shape}. Consider using vmap to batch the function call.'
127-
)
128-
if self.dr.shape:
129-
raise AssertionError(
130-
'CellVariable must be unbatched, but has `dr` shape '
131-
f'{self.dr.shape}. Consider using vmap to batch the function call.'
132-
)
133-
134124
def face_grad(
135-
self, x: jt.Float[chex.Array, 'cell'] | None = None
136-
) -> jt.Float[chex.Array, 'face']:
125+
self, x: jt.Float[array_typing.Array, 'cell'] | None = None
126+
) -> jt.Float[jax.Array, 'face']:
137127
"""Returns the gradient of this value with respect to the faces.
138128
139129
Implemented using forward differencing of cells. Leftmost and rightmost
@@ -146,7 +136,6 @@ def face_grad(
146136
Returns:
147137
A jax.Array of shape (num_faces,) containing the gradient.
148138
"""
149-
self._assert_unbatched()
150139
if x is None:
151140
forward_difference = jnp.diff(self.value) / self.dr
152141
else:
@@ -194,7 +183,7 @@ def constrained_grad(
194183
right = jnp.expand_dims(right_grad, axis=0)
195184
return jnp.concatenate([left, forward_difference, right])
196185

197-
def _left_face_value(self) -> jt.Float[chex.Array, '#t']:
186+
def _left_face_value(self) -> jt.Float[jax.Array, '#t']:
198187
"""Calculates the value of the leftmost face."""
199188
if self.left_face_constraint is not None:
200189
value = self.left_face_constraint
@@ -206,7 +195,7 @@ def _left_face_value(self) -> jt.Float[chex.Array, '#t']:
206195
value = self.value[..., 0:1]
207196
return value
208197

209-
def _right_face_value(self) -> jt.Float[chex.Array, '#t']:
198+
def _right_face_value(self) -> jt.Float[jax.Array, '#t']:
210199
"""Calculates the value of the rightmost face."""
211200
if self.right_face_constraint is not None:
212201
value = self.right_face_constraint
@@ -222,14 +211,14 @@ def _right_face_value(self) -> jt.Float[chex.Array, '#t']:
222211
)
223212
return value
224213

225-
def face_value(self) -> jt.Float[jax.Array, 't* face']:
214+
def face_value(self) -> jt.Float[jax.Array, 'face']:
226215
"""Calculates values of this variable on the face grid."""
227216
inner = (self.value[..., :-1] + self.value[..., 1:]) / 2.0
228217
return jnp.concatenate(
229218
[self._left_face_value(), inner, self._right_face_value()], axis=-1
230219
)
231220

232-
def grad(self) -> jt.Float[jax.Array, 't* face']:
221+
def grad(self) -> jt.Float[jax.Array, 'face']:
233222
"""Returns the gradient of this variable wrt cell centers."""
234223
face = self.face_value()
235224
return jnp.diff(face) / jnp.expand_dims(self.dr, axis=-1)
@@ -251,7 +240,7 @@ def __str__(self) -> str:
251240
output_string += ')'
252241
return output_string
253242

254-
def cell_plus_boundaries(self) -> jt.Float[jax.Array, 't* cell+2']:
243+
def cell_plus_boundaries(self) -> jt.Float[jax.Array, 'cell+2']:
255244
"""Returns the value of this variable plus left and right boundaries."""
256245
right_value = self._right_face_value()
257246
left_value = self._left_face_value()

0 commit comments

Comments
 (0)