
Commit 7a0c763

sbodenstein authored and Torax team committed
Add a BatchedCellVariable class.
This will allow for vectorizing code that currently loops across sequences of CellVariables.

PiperOrigin-RevId: 795051091
1 parent 9da5f7d commit 7a0c763
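
An editorial illustration (not part of the commit) of the vectorization this enables, with hypothetical names: a Python loop over `CellVariable` objects becomes a single array expression over the stacked fields.

# Hypothetical per-variable loop:
scaled = [v.value * v.dr for v in cell_vars]

# The same computation over one stacked structure:
batched = cell_variable.BatchedCellVariable.construct(cell_vars)
scaled = batched.value * batched.dr[:, None]  # shape [batch, cell]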

File tree

4 files changed: +200 −1 lines changed


torax/_src/fvm/cell_variable.py

Lines changed: 62 additions & 1 deletion
@@ -20,18 +20,20 @@
   [https://www.ctcms.nist.gov/fipy/]
 """
 import dataclasses
+from typing import Self, Sequence

 import chex
 import jax
 from jax import numpy as jnp
 import jaxtyping as jt
 from torax._src import array_typing
+from torax._src import jax_utils
 import typing_extensions


 def _zero() -> array_typing.FloatScalar:
   """Returns a scalar zero as a jax Array."""
-  return jnp.zeros(())
+  return jnp.zeros((), dtype=jax_utils.get_dtype())


 @chex.dataclass(frozen=True)

@@ -266,3 +268,62 @@ def __eq__(self, other: typing_extensions.Self) -> bool:
       return True
     except AssertionError:
       return False
+
+
+@dataclasses.dataclass(frozen=True)
+class BatchedCellVariable:
+  """An object representing a batch of `CellVariable` objects."""
+
+  value: jt.Float[jax.Array, 'batch cell']
+  dr: jt.Float[jax.Array, 'batch']
+  left_face_constraint: jt.Float[jax.Array, 'batch']
+  right_face_constraint: jt.Float[jax.Array, 'batch']
+  left_face_grad_constraint: jt.Float[jax.Array, 'batch']
+  right_face_grad_constraint: jt.Float[jax.Array, 'batch']
+
+  @classmethod
+  def construct(cls, x: CellVariable | Sequence[CellVariable]) -> Self:
+    """Constructs a `BatchedCellVariable` object.
+
+    Note that any `None` values are represented as `jnp.nan` values.
+
+    Args:
+      x: A `CellVariable` or a sequence of `CellVariable` objects.
+
+    Returns:
+      A `BatchedCellVariable` object.
+    """
+    x = (x,) if isinstance(x, CellVariable) else x
+    if not x:
+      raise ValueError('The sequence must have at least one element.')
+
+    x = [_conform(v) for v in x]
+
+    def concat(*x):
+      return jnp.concatenate(x, axis=0)
+
+    x = jax.tree_util.tree_map(concat, *x)
+    return cls(**x)
+
+
+def _conform(x: CellVariable):
+  def conform_scalar(x):
+    if x is None:
+      return jnp.array([jnp.nan], dtype=jax_utils.get_dtype())
+
+    if x.ndim == 0:
+      return jnp.expand_dims(x, axis=0)
+    return x
+
+  value = jnp.expand_dims(x.value, axis=0) if x.value.ndim == 1 else x.value
+
+  return {
+      'value': value,
+      'dr': conform_scalar(x.dr),
+      'left_face_constraint': conform_scalar(x.left_face_constraint),
+      'right_face_constraint': conform_scalar(x.right_face_constraint),
+      'left_face_grad_constraint': conform_scalar(x.left_face_grad_constraint),
+      'right_face_grad_constraint': conform_scalar(
+          x.right_face_grad_constraint
+      ),
+  }
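
A note on the construction (editorial, with made-up leaves): `_conform` gives every field a leading batch dimension of size 1, and `jax.tree_util.tree_map` applied to several pytrees at once calls the mapped function with the corresponding leaf from each tree, so `concat` stacks the batch leaf-by-leaf. A minimal self-contained sketch of the same trick:

import jax
from jax import numpy as jnp

# Two pytrees of identical structure; each leaf has a leading batch dim of 1.
a = {'value': jnp.ones((1, 3)), 'dr': jnp.array([0.1])}
b = {'value': jnp.zeros((1, 3)), 'dr': jnp.array([0.2])}

# tree_map over multiple trees passes matching leaves together, so this
# concatenates leaf-by-leaf along the batch axis.
batched = jax.tree_util.tree_map(lambda *xs: jnp.concatenate(xs, axis=0), a, b)
assert batched['value'].shape == (2, 3)
assert batched['dr'].shape == (2,)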

torax/_src/fvm/tests/cell_variable_test.py

Lines changed: 38 additions & 0 deletions
@@ -584,6 +584,44 @@ def test_almost_equal(
     with self.assertRaises(AssertionError):
       chex.assert_trees_all_close(var1, var2, atol=atol)

+  def test_construct_batched_cell_variable(self):
+    var_1 = cell_variable.CellVariable(
+        value=jnp.array([1.0, 2.0]),
+        dr=jnp.array(0.1),
+        left_face_constraint=jnp.array(3.0),
+        left_face_grad_constraint=None,
+    )
+    var_2 = cell_variable.CellVariable(
+        value=jnp.array([5.0, 6.0]),
+        dr=jnp.array(0.2),
+        right_face_constraint=jnp.array(5.0),
+        right_face_grad_constraint=None,
+    )
+    with self.subTest('from_tuple'):
+      var = cell_variable.BatchedCellVariable.construct((var_1, var_2))
+      chex.assert_trees_all_equal(
+          var.value, jnp.array([[1.0, 2.0], [5.0, 6.0]])
+      )
+      chex.assert_trees_all_equal(var.dr, jnp.array([0.1, 0.2]))
+
+      chex.assert_trees_all_equal(
+          var.left_face_constraint, jnp.array([3.0, jnp.nan])
+      )
+      chex.assert_trees_all_equal(
+          var.right_face_constraint, jnp.array([jnp.nan, 5.0])
+      )
+      chex.assert_trees_all_equal(
+          var.left_face_grad_constraint, jnp.array([jnp.nan, 0.0])
+      )
+      chex.assert_trees_all_equal(
+          var.right_face_grad_constraint, jnp.array([0.0, jnp.nan])
+      )
+
+    with self.subTest('from_single_var'):
+      var = cell_variable.BatchedCellVariable.construct(var_1)
+      chex.assert_trees_all_equal(var.value, jnp.array([[1.0, 2.0]]))
+      chex.assert_trees_all_equal(var.dr, jnp.array([0.1]))
+

 if __name__ == '__main__':
   absltest.main()

torax/_src/jax_utils.py

Lines changed: 65 additions & 0 deletions
@@ -29,6 +29,7 @@
 T = TypeVar('T')
 BooleanNumeric: TypeAlias = Any  # A bool, or a Boolean array.
 _State = ParamSpec('_State')
+PyTree: TypeAlias = Any


 @functools.cache

@@ -299,6 +300,70 @@ def init_array(x):
   return jax.tree_util.tree_map(init_array, t)


+def batched_cond(
+    pred: jax.Array,
+    true_fun: Callable[..., PyTree],
+    false_fun: Callable[..., PyTree],
+    operands: tuple[PyTree, ...],
+    implementation: Literal['vectorize', 'map'] = 'vectorize',
+):
+  """A batched version of `jax.lax.cond`.
+
+  JAX provides two approaches for implementing a batched version of
+  `jax.lax.cond`, neither of which is always faster:
+  `implementation='vectorize'` is equivalent to `jnp.select`, which evaluates
+  both branches for every batch element. This is fully vectorized, allowing
+  for parallel execution on CPU/GPU, but requires twice the number of
+  function evaluations. `implementation='map'` sequentially evaluates
+  `jax.lax.cond`, preventing vectorized execution, but requires only a single
+  function evaluation per batch element.
+
+  This function also handles the special case where `pred` is a concrete
+  array of length 1, in which case the control flow is done in Python,
+  avoiding the tracing of both branches that `jax.lax.cond` performs.
+
+  Args:
+    pred: Boolean 1D array `[batch_size]`, indicating which branch function
+      to apply.
+    true_fun: Function (A -> B), to be applied if `pred` is True.
+    false_fun: Function (A -> B), to be applied if `pred` is False.
+    operands: A tuple of arguments to pass to the functions. Each `jax.Array`
+      (every PyTree leaf) must have a leading batch dimension of size
+      `batch_size`.
+    implementation: The implementation to use. 'vectorize' compiles to a
+      `jax.lax.select`, where both branches are evaluated. 'map' uses
+      `jax.lax.map`.
+
+  Returns:
+    The result of applying the appropriate function to each element of the
+    batch.
+  """
+
+  if not isinstance(operands, tuple):
+    raise ValueError('The operands must be a tuple.')
+
+  if pred.ndim != 1 or pred.dtype != jnp.bool:
+    raise ValueError('pred must be a 1D array of bools.')
+
+  # For the special case where `pred` is a concrete array of length 1, we can
+  # avoid tracing both branches by doing the control flow in Python.
+  if len(pred) == 1 and not isinstance(pred, jax.core.Tracer):
+    f = true_fun if bool(pred) else false_fun
+    operands = jax.tree.map(lambda x: jnp.squeeze(x, axis=0), operands)
+    out = f(*operands)
+    return jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), out)
+
+  f = lambda args: jax.lax.cond(args[0], true_fun, false_fun, *args[1])
+  match implementation:
+    case 'vectorize':
+      # Compiles to a jax.lax.select, where both branches are evaluated.
+      return jax.vmap(f)((pred, operands))
+    case 'map':
+      return jax.lax.map(f, (pred, operands))
+    case _:
+      raise ValueError(f'Unknown implementation: {implementation}')
+
+
 @functools.partial(
     jit, static_argnames=['cond_fun', 'body_fun', 'max_steps', 'scan_unroll']
 )
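
For intuition (an editorial sketch, not the library's code): with a single array operand, the 'vectorize' path behaves like evaluating both branches for the whole batch and selecting per element.

import jax
from jax import numpy as jnp

def batched_cond_via_select(pred, true_fun, false_fun, x):
  # Evaluate both branches for every batch element, then pick per element.
  # This mirrors the select that the vmapped jax.lax.cond lowers to.
  t = jax.vmap(true_fun)(x)
  f = jax.vmap(false_fun)(x)
  mask = pred.reshape(pred.shape + (1,) * (t.ndim - 1))
  return jnp.where(mask, t, f)

pred = jnp.array([True, False])
x = jnp.array([[2.0, 3.0], [5.0, 6.0]])
out = batched_cond_via_select(pred, lambda v: 2.0 * v, lambda v: v**2, x)
# out == [[4., 6.], [25., 36.]]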

torax/_src/tests/jax_utils_test.py

Lines changed: 35 additions & 0 deletions
@@ -150,6 +150,41 @@ def f(x, z, y=2.0):
     x = {'temp1': jnp.array(1.3), 'temp2': jnp.array(2.6)}
     chex.assert_trees_all_close(f_non_inlined(x, z='left'), f(x, z='left'))

+  @parameterized.parameters(['map', 'vectorize'])
+  def test_batched_cond(self, implementation):
+    pred = jnp.array([True, False])
+    x = jnp.array([[2.0, 3.0, 4.0], [5.0, 6.0, 7.0]])
+    out = jax_utils.batched_cond(
+        pred=pred,
+        true_fun=lambda x, y: x * y,
+        false_fun=lambda x, y: x * y**2,
+        operands=(x, x),
+        implementation=implementation,
+    )
+    out_gt = jnp.array(
+        [[4.0, 9.0, 16.0], [125.0, 216.0, 343.0]], dtype=jnp.float32
+    )
+    chex.assert_trees_all_equal(out, out_gt)
+
+  @parameterized.parameters(['map', 'vectorize'])
+  def test_batched_cond_concrete_special(self, implementation):
+    pred = jnp.array([True])
+    x = jnp.array([[2.0, 3.0, 4.0]])
+
+    @jax.jit
+    def f(x):
+      return jax_utils.batched_cond(
+          pred=pred,
+          true_fun=lambda x, y: x * y,
+          false_fun=lambda x, y: x * y**2,
+          operands=(x, x),
+          implementation=implementation,
+      )
+
+    out = f(x)
+    out_gt = jnp.array([[4.0, 9.0, 16.0]], dtype=jnp.float32)
+    chex.assert_trees_all_equal(out, out_gt)
+
   def test_max_steps_while_loop(self):
     terminating_step = 4
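
An editorial note on the second test: because `pred` is closed over rather than passed as an argument, it remains a concrete array (not a Tracer) even under `jax.jit`, so `batched_cond` takes the Python fast path and traces only one branch. A sketch of the distinction, with hypothetical function names:

import jax
from jax import numpy as jnp

pred = jnp.array([True])

@jax.jit
def closed_over(x):
  # `pred` is a captured constant: concrete at trace time, so Python
  # control flow on it is allowed.
  return x + 1.0 if bool(pred[0]) else x - 1.0

@jax.jit
def as_argument(p, x):
  # Here `p` is a Tracer: `bool(p[0])` would raise at trace time, so a
  # traced select is needed instead.
  return jnp.where(p, x + 1.0, x - 1.0)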
