Skip to content

Commit 20701d0

Browse files
jcitrinTorax team
authored andcommitted
Add cumulative cell integration functions to math_utils.
This change introduces `cumulative_cell_integration`, `cumulative_area_integration`, and `cumulative_volume_integration` to compute cumulative integrals of profiles over the radial grid using cell, area, and volume metrics. Tests are added to verify that the final value of the cumulative integral matches the total integral and to check against manual `cumsum` calculations. Needed in #1489 PiperOrigin-RevId: 833240972
1 parent c906e01 commit 20701d0

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

torax/_src/math_utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,45 @@ def volume_average(
251251
) -> array_typing.FloatScalar:
252252
"""Calculates volume-averaged value from input profile."""
253253
return cell_integration(value * geo.vpr, geo) / geo.volume_face[-1]
254+
255+
256+
@array_typing.jaxtyped
257+
def cumulative_cell_integration(
258+
x: array_typing.FloatVectorCell, geo: geometry.Geometry
259+
) -> array_typing.FloatVectorCell:
260+
r"""Cumulative integration of a value `x` over the rhon grid.
261+
262+
Args:
263+
x: The cell averaged value to integrate.
264+
geo: The geometry instance.
265+
266+
Returns:
267+
Cumulative integration array same size as x.
268+
"""
269+
if x.shape != geo.rho_norm.shape:
270+
raise ValueError(
271+
'For cumulative_cell_integration, input "x" must have same shape as '
272+
f'the cell grid. Got x.shape={x.shape}, '
273+
f'expected {geo.rho_norm.shape}.'
274+
)
275+
# Uses cumsum to accumulate x * drho_norm.
276+
# The first element will be x[0] * drho_norm[0].
277+
return jnp.cumsum(x * geo.drho_norm)
278+
279+
280+
@array_typing.jaxtyped
281+
def cumulative_area_integration(
282+
value: array_typing.FloatVectorCell,
283+
geo: geometry.Geometry,
284+
) -> array_typing.FloatVectorCell:
285+
"""Calculates cumulative integral of value using an area metric."""
286+
return cumulative_cell_integration(value * geo.spr, geo)
287+
288+
289+
@array_typing.jaxtyped
290+
def cumulative_volume_integration(
291+
value: array_typing.FloatVectorCell,
292+
geo: geometry.Geometry,
293+
) -> array_typing.FloatVectorCell:
294+
"""Calculates cumulative integral of value using a volume metric."""
295+
return cumulative_cell_integration(value * geo.vpr, geo)

torax/_src/tests/math_utils_test.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,63 @@ def test_cell_to_face(
254254
),
255255
)
256256

257+
@parameterized.parameters(5, 50)
258+
def test_cumulative_cell_integration(self, num_cell_grid_points: int):
259+
"""Tests cumulative_cell_integration against cell_integration."""
260+
geo = geometry_pydantic_model.CircularConfig(
261+
n_rho=num_cell_grid_points
262+
).build_geometry()
263+
x = jax.random.uniform(jax.random.PRNGKey(1), shape=(num_cell_grid_points,))
264+
265+
cumulative_result = math_utils.cumulative_cell_integration(x, geo)
266+
expected = np.zeros(num_cell_grid_points)
267+
268+
for i in range(len(cumulative_result)):
269+
expected[i] = np.sum(x[: i + 1] * geo.drho_norm)
270+
271+
np.testing.assert_allclose(
272+
cumulative_result,
273+
expected,
274+
)
275+
276+
@parameterized.parameters(5, 50)
277+
def test_cumulative_area_integration(self, num_cell_grid_points: int):
278+
"""Tests cumulative_area_integration against area_integration."""
279+
geo = geometry_pydantic_model.CircularConfig(
280+
n_rho=num_cell_grid_points
281+
).build_geometry()
282+
x = jax.random.uniform(jax.random.PRNGKey(2), shape=(num_cell_grid_points,))
283+
284+
cumulative_result = math_utils.cumulative_area_integration(x, geo)
285+
expected = np.zeros(num_cell_grid_points)
286+
287+
for i in range(len(cumulative_result)):
288+
expected[i] = np.sum(x[: i + 1] * geo.spr[: i + 1] * geo.drho_norm)
289+
290+
np.testing.assert_allclose(
291+
cumulative_result,
292+
expected,
293+
)
294+
295+
@parameterized.parameters(5, 50)
296+
def test_cumulative_volume_integration(self, num_cell_grid_points: int):
297+
"""Tests cumulative_volume_integration against volume_integration."""
298+
geo = geometry_pydantic_model.CircularConfig(
299+
n_rho=num_cell_grid_points
300+
).build_geometry()
301+
x = jax.random.uniform(jax.random.PRNGKey(3), shape=(num_cell_grid_points,))
302+
303+
cumulative_result = math_utils.cumulative_volume_integration(x, geo)
304+
expected = np.zeros(num_cell_grid_points)
305+
306+
for i in range(len(cumulative_result)):
307+
expected[i] = np.sum(x[: i + 1] * geo.vpr[: i + 1] * geo.drho_norm)
308+
309+
np.testing.assert_allclose(
310+
cumulative_result,
311+
expected,
312+
)
313+
257314

258315
if __name__ == '__main__':
259316
absltest.main()

0 commit comments

Comments
 (0)