Skip to content

Commit 3de360a

Browse files
goodfeliTorax team
authored andcommitted
Introduce PedestalPolicy.
This is a step toward having full Martin scaling, where there will be a Martin scaling PedestalPolicy that ramps up and ramps down the pedestal height. This PR is just a step that defines the new stateful interface required for doing so and transitions the existing `set_pedestal` functionality to being implemented with the new PedestalPolicy class. PiperOrigin-RevId: 810477225
1 parent 8ba6fa9 commit 3de360a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+552
-141
lines changed

torax/_src/config/tests/build_runtime_params_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,20 @@ def test_pedestal_is_time_dependent(self):
9595
set_pedestal={0.0: True, 1.0: False},
9696
)
9797
)
98+
pedestal_policy = pedestal.build_pedestal_model().pedestal_policy
9899
# Check at time 0.
99100

100101
pedestal_params = pedestal.build_runtime_params(t=0.0)
101102
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
102-
np.testing.assert_allclose(pedestal_params.set_pedestal, True)
103+
np.testing.assert_allclose(pedestal_policy.initial_state(t=0.0).use_pedestal, True)
103104
np.testing.assert_allclose(pedestal_params.T_i_ped, 0.0)
104105
np.testing.assert_allclose(pedestal_params.T_e_ped, 1.0)
105106
np.testing.assert_allclose(pedestal_params.n_e_ped, 2.0e20)
106107
np.testing.assert_allclose(pedestal_params.rho_norm_ped_top, 3.0)
107108
# And check after the time limit.
108109
pedestal_params = pedestal.build_runtime_params(t=1.0)
109110
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
110-
np.testing.assert_allclose(pedestal_params.set_pedestal, False)
111+
np.testing.assert_allclose(pedestal_policy.initial_state(t=1.0).use_pedestal, False)
111112
np.testing.assert_allclose(pedestal_params.T_i_ped, 1.0)
112113
np.testing.assert_allclose(pedestal_params.T_e_ped, 2.0)
113114
np.testing.assert_allclose(pedestal_params.n_e_ped, 3.0e20)

torax/_src/fvm/calc_coeffs.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torax._src.fvm import cell_variable
2828
from torax._src.geometry import geometry
2929
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
30+
from torax._src.pedestal_policy import pedestal_policy
3031
from torax._src.sources import source_profile_builders
3132
from torax._src.sources import source_profiles as source_profiles_lib
3233
import typing_extensions
@@ -63,6 +64,7 @@ def __call__(
6364
core_profiles: state.CoreProfiles,
6465
x: tuple[cell_variable.CellVariable, ...],
6566
explicit_source_profiles: source_profiles_lib.SourceProfiles,
67+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
6668
allow_pereverzev: bool = False,
6769
# Checks if reduced calc_coeffs for explicit terms when theta_implicit=1
6870
# should be called
@@ -86,6 +88,7 @@ def __call__(
8688
not recalculated at time t+plus_dt with updated state during the solver
8789
iterations. For sources that are implicit, their explicit profiles are
8890
set to all zeros.
91+
pedestal_policy_state: State held by the pedestal policy.
8992
allow_pereverzev: If True, then the coeffs are being called within a
9093
linear solver. Thus could be either the use_predictor_corrector solver
9194
or as part of calculating the initial guess for the nonlinear solver. In
@@ -121,6 +124,7 @@ def __call__(
121124
explicit_source_profiles=explicit_source_profiles,
122125
physics_models=self.physics_models,
123126
evolving_names=self.evolving_names,
127+
pedestal_policy_state=pedestal_policy_state,
124128
use_pereverzev=use_pereverzev,
125129
explicit_call=explicit_call,
126130
)
@@ -219,6 +223,7 @@ def calc_coeffs(
219223
explicit_source_profiles: source_profiles_lib.SourceProfiles,
220224
physics_models: physics_models_lib.PhysicsModels,
221225
evolving_names: tuple[str, ...],
226+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
222227
use_pereverzev: bool = False,
223228
explicit_call: bool = False,
224229
) -> block_1d_coeffs.Block1DCoeffs:
@@ -241,6 +246,7 @@ def calc_coeffs(
241246
physics_models: The physics models to use for the simulation.
242247
evolving_names: The names of the evolving variables in the order that their
243248
coefficients should be written to `coeffs`.
249+
pedestal_policy_state: State held by the pedestal policy.
244250
use_pereverzev: Toggle whether to calculate Pereverzev terms
245251
explicit_call: If True, indicates that calc_coeffs is being called for the
246252
explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
@@ -267,6 +273,7 @@ def calc_coeffs(
267273
explicit_source_profiles=explicit_source_profiles,
268274
physics_models=physics_models,
269275
evolving_names=evolving_names,
276+
pedestal_policy_state=pedestal_policy_state,
270277
use_pereverzev=use_pereverzev,
271278
)
272279

@@ -284,14 +291,18 @@ def _calc_coeffs_full(
284291
explicit_source_profiles: source_profiles_lib.SourceProfiles,
285292
physics_models: physics_models_lib.PhysicsModels,
286293
evolving_names: tuple[str, ...],
294+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
287295
use_pereverzev: bool = False,
288296
) -> block_1d_coeffs.Block1DCoeffs:
289297
"""See `calc_coeffs` for details."""
290298

291299
consts = constants.CONSTANTS
292300

293301
pedestal_model_output = physics_models.pedestal_model(
294-
runtime_params, geo, core_profiles
302+
runtime_params,
303+
geo,
304+
core_profiles,
305+
pedestal_policy_state,
295306
)
296307

297308
# Boolean mask for enforcing internal temperature boundary conditions to
@@ -351,7 +362,11 @@ def _calc_coeffs_full(
351362

352363
# Diffusion term coefficients
353364
turbulent_transport = physics_models.transport_model(
354-
runtime_params, geo, core_profiles, pedestal_model_output
365+
runtime_params,
366+
geo,
367+
core_profiles,
368+
pedestal_policy_state,
369+
pedestal_model_output,
355370
)
356371
neoclassical_transport = physics_models.neoclassical_models.transport(
357372
runtime_params, geo, core_profiles

torax/_src/fvm/newton_raphson_solve_block.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torax._src.fvm import jax_root_finding
3636
from torax._src.fvm import residual_and_loss
3737
from torax._src.geometry import geometry
38+
from torax._src.pedestal_policy import pedestal_policy
3839
from torax._src.solver import predictor_corrector_method
3940
from torax._src.sources import source_profiles
4041

@@ -67,6 +68,7 @@ def newton_raphson_solve_block(
6768
physics_models: physics_models_lib.PhysicsModels,
6869
coeffs_callback: calc_coeffs.CoeffsCallback,
6970
evolving_names: tuple[str, ...],
71+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
7072
initial_guess_mode: enums.InitialGuessMode,
7173
maxiter: int,
7274
tol: float,
@@ -128,6 +130,7 @@ def newton_raphson_solve_block(
128130
core_profiles. Repeatedly called by the iterative optimizer.
129131
evolving_names: The names of variables within the core profiles that should
130132
evolve.
133+
pedestal_policy_state: State variables held by the pedestal policy.
131134
initial_guess_mode: chooses the initial_guess for the iterative method,
132135
either x_old or linear step. When taking the linear step, it is also
133136
recommended to use Pereverzev-Corrigan terms if the transport coefficients
@@ -159,6 +162,7 @@ def newton_raphson_solve_block(
159162
core_profiles_t,
160163
x_old,
161164
explicit_source_profiles=explicit_source_profiles,
165+
pedestal_policy_state=pedestal_policy_state,
162166
explicit_call=True,
163167
)
164168

@@ -175,6 +179,7 @@ def newton_raphson_solve_block(
175179
core_profiles_t,
176180
x_old,
177181
explicit_source_profiles=explicit_source_profiles,
182+
pedestal_policy_state=pedestal_policy_state,
178183
allow_pereverzev=True,
179184
explicit_call=True,
180185
)
@@ -193,6 +198,7 @@ def newton_raphson_solve_block(
193198
coeffs_exp=coeffs_exp_linear,
194199
coeffs_callback=coeffs_callback,
195200
explicit_source_profiles=explicit_source_profiles,
201+
pedestal_policy_state=pedestal_policy_state,
196202
)
197203
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
198204
case enums.InitialGuessMode.X_OLD:
@@ -214,6 +220,7 @@ def newton_raphson_solve_block(
214220
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
215221
physics_models=physics_models,
216222
explicit_source_profiles=explicit_source_profiles,
223+
pedestal_policy_state=pedestal_policy_state,
217224
coeffs_old=coeffs_old,
218225
evolving_names=evolving_names,
219226
)

torax/_src/fvm/optimizer_solve_block.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torax._src.fvm import fvm_conversions
3333
from torax._src.fvm import residual_and_loss
3434
from torax._src.geometry import geometry
35+
from torax._src.pedestal_policy import pedestal_policy
3536
from torax._src.solver import predictor_corrector_method
3637
from torax._src.sources import source_profiles
3738

@@ -56,6 +57,7 @@ def optimizer_solve_block(
5657
core_profiles_t: state.CoreProfiles,
5758
core_profiles_t_plus_dt: state.CoreProfiles,
5859
explicit_source_profiles: source_profiles.SourceProfiles,
60+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
5961
physics_models: physics_models_lib.PhysicsModels,
6062
coeffs_callback: calc_coeffs.CoeffsCallback,
6163
evolving_names: tuple[str, ...],
@@ -97,6 +99,7 @@ def optimizer_solve_block(
9799
being evolved by the PDE system.
98100
explicit_source_profiles: Pre-calculated sources implemented as explicit
99101
sources in the PDE.
102+
pedestal_policy_state: State variables held by the pedestal policy.
100103
physics_models: Physics models used for the calculations.
101104
coeffs_callback: Calculates diffusion, convection etc. coefficients given a
102105
core_profiles. Repeatedly called by the iterative optimizer.
@@ -123,6 +126,7 @@ def optimizer_solve_block(
123126
core_profiles_t,
124127
x_old,
125128
explicit_source_profiles=explicit_source_profiles,
129+
pedestal_policy_state=pedestal_policy_state,
126130
explicit_call=True,
127131
)
128132

@@ -140,6 +144,7 @@ def optimizer_solve_block(
140144
core_profiles_t,
141145
x_old,
142146
explicit_source_profiles=explicit_source_profiles,
147+
pedestal_policy_state=pedestal_policy_state,
143148
allow_pereverzev=True,
144149
explicit_call=True,
145150
)

torax/_src/fvm/residual_and_loss.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from torax._src.fvm import discrete_system
3939
from torax._src.fvm import fvm_conversions
4040
from torax._src.geometry import geometry
41+
from torax._src.pedestal_policy import pedestal_policy
4142
from torax._src.sources import source_profiles
4243

4344
Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs
@@ -200,6 +201,7 @@ def theta_method_block_residual(
200201
x_old: tuple[cell_variable.CellVariable, ...],
201202
core_profiles_t_plus_dt: state.CoreProfiles,
202203
explicit_source_profiles: source_profiles.SourceProfiles,
204+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
203205
physics_models: physics_models_lib.PhysicsModels,
204206
coeffs_old: Block1DCoeffs,
205207
evolving_names: tuple[str, ...],
@@ -219,6 +221,7 @@ def theta_method_block_residual(
219221
being evolved by the PDE system.
220222
explicit_source_profiles: Pre-calculated sources implemented as explicit
221223
sources in the PDE.
224+
pedestal_policy_state: State variables held by the pedestal policy.
222225
physics_models: Physics models used for the calculations.
223226
coeffs_old: The coefficients calculated at x_old.
224227
evolving_names: The names of variables within the core profiles that should
@@ -251,6 +254,7 @@ def theta_method_block_residual(
251254
core_profiles=core_profiles_t_plus_dt,
252255
explicit_source_profiles=explicit_source_profiles,
253256
physics_models=physics_models,
257+
pedestal_policy_state=pedestal_policy_state,
254258
evolving_names=evolving_names,
255259
use_pereverzev=False,
256260
)
@@ -288,6 +292,7 @@ def theta_method_block_loss(
288292
x_old: tuple[cell_variable.CellVariable, ...],
289293
core_profiles_t_plus_dt: state.CoreProfiles,
290294
explicit_source_profiles: source_profiles.SourceProfiles,
295+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
291296
physics_models: physics_models_lib.PhysicsModels,
292297
coeffs_old: Block1DCoeffs,
293298
evolving_names: tuple[str, ...],
@@ -324,6 +329,7 @@ def theta_method_block_loss(
324329
x_new_guess_vec=x_new_guess_vec,
325330
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
326331
explicit_source_profiles=explicit_source_profiles,
332+
pedestal_policy_state=pedestal_policy_state,
327333
physics_models=physics_models,
328334
coeffs_old=coeffs_old,
329335
evolving_names=evolving_names,

torax/_src/fvm/tests/calc_coeffs_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torax._src.core_profiles import initialization
2121
from torax._src.fvm import calc_coeffs
2222
from torax._src.sources import source_profile_builders
23+
from torax._src.pedestal_policy import pedestal_policy as pedestal_policy_lib
2324
from torax._src.test_utils import default_sources
2425
from torax._src.torax_pydantic import model_config
2526

@@ -79,13 +80,17 @@ def test_calc_coeffs_smoke_test(
7980
neoclassical_models=physics_models.neoclassical_models,
8081
explicit=True,
8182
)
83+
pedestal_policy_state = physics_models.pedestal_model.pedestal_policy.initial_state(
84+
t=torax_config.numerics.t_initial
85+
)
8286
calc_coeffs.calc_coeffs(
8387
runtime_params=runtime_params,
8488
geo=geo,
8589
core_profiles=core_profiles,
8690
physics_models=physics_models,
8791
explicit_source_profiles=explicit_source_profiles,
8892
evolving_names=evolving_names,
93+
pedestal_policy_state=pedestal_policy_state,
8994
use_pereverzev=False,
9095
)
9196

torax/_src/fvm/tests/fvm_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from torax._src.fvm import cell_variable
2727
from torax._src.fvm import implicit_solve_block
2828
from torax._src.fvm import residual_and_loss
29+
from torax._src.pedestal_policy import pedestal_policy as pedestal_policy_lib
2930
from torax._src.sources import runtime_params as source_runtime_params
3031
from torax._src.sources import source_profile_builders
3132
from torax._src.test_utils import default_sources
@@ -247,13 +248,18 @@ def test_nonlinear_solve_block_loss_minimum(
247248
core_profiles=core_profiles,
248249
explicit=True,
249250
)
251+
pedestal_policy = physics_models.pedestal_model.pedestal_policy
252+
pedestal_policy_state = pedestal_policy.initial_state(
253+
t=torax_config.numerics.t_initial
254+
)
250255
coeffs = calc_coeffs.calc_coeffs(
251256
runtime_params=runtime_params,
252257
geo=geo,
253258
core_profiles=core_profiles,
254259
physics_models=physics_models,
255260
explicit_source_profiles=explicit_source_profiles,
256261
evolving_names=evolving_names,
262+
pedestal_policy_state=pedestal_policy_state,
257263
use_pereverzev=False,
258264
)
259265
# dt well under the explicit stability limit for dx=1 and chi=1
@@ -287,19 +293,23 @@ def test_nonlinear_solve_block_loss_minimum(
287293
explicit_source_profiles=explicit_source_profiles,
288294
coeffs_old=coeffs,
289295
evolving_names=evolving_names,
296+
pedestal_policy_state=pedestal_policy_state,
290297
)
291298

292299
residual = residual_and_loss.theta_method_block_residual(
293300
dt=dt,
294301
runtime_params_t_plus_dt=runtime_params,
295302
geo_t_plus_dt=geo,
303+
304+
296305
x_new_guess_vec=jnp.concatenate([var.value for var in x_new]),
297306
x_old=x_old,
298307
core_profiles_t_plus_dt=core_profiles,
299308
physics_models=physics_models,
300309
explicit_source_profiles=explicit_source_profiles,
301310
coeffs_old=coeffs,
302311
evolving_names=evolving_names,
312+
pedestal_policy_state=pedestal_policy_state,
303313
)
304314

305315
np.testing.assert_allclose(loss, 0.0, atol=1e-7)
@@ -353,13 +363,18 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
353363
dt = jnp.array(1.0)
354364
evolving_names = tuple(['T_i'])
355365

366+
pedestal_policy = physics_models.pedestal_model.pedestal_policy
367+
pedestal_policy_state = pedestal_policy.initial_state(
368+
t=torax_config.numerics.t_initial
369+
)
356370
coeffs = calc_coeffs.calc_coeffs(
357371
runtime_params=runtime_params,
358372
geo=geo,
359373
core_profiles=initial_core_profiles,
360374
physics_models=physics_models,
361375
explicit_source_profiles=explicit_source_profiles,
362376
evolving_names=evolving_names,
377+
pedestal_policy_state=pedestal_policy_state,
363378
use_pereverzev=False,
364379
)
365380
initial_right_boundary = jnp.array(0.0)
@@ -470,13 +485,18 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
470485
dt = jnp.array(1.0)
471486
evolving_names = tuple(['T_i'])
472487

488+
pedestal_policy = physics_models.pedestal_model.pedestal_policy
489+
pedestal_policy_state = pedestal_policy.initial_state(
490+
t=torax_config.numerics.t_initial
491+
)
473492
coeffs_old = calc_coeffs.calc_coeffs(
474493
runtime_params=runtime_params_theta05,
475494
geo=geo,
476495
core_profiles=initial_core_profiles,
477496
physics_models=physics_models,
478497
explicit_source_profiles=explicit_source_profiles,
479498
evolving_names=evolving_names,
499+
pedestal_policy_state=pedestal_policy_state,
480500
use_pereverzev=False,
481501
)
482502

@@ -513,6 +533,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
513533
explicit_source_profiles=explicit_source_profiles,
514534
coeffs_old=coeffs_old,
515535
evolving_names=evolving_names,
536+
pedestal_policy_state=pedestal_policy_state,
516537
)
517538
np.testing.assert_allclose(residual, 0.0)
518539
with self.subTest('updated_boundary_conditions'):
@@ -536,6 +557,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
536557
physics_models=physics_models,
537558
explicit_source_profiles=explicit_source_profiles,
538559
coeffs_old=coeffs_old,
560+
pedestal_policy_state=pedestal_policy_state,
539561
)
540562
np.testing.assert_allclose(residual, 0.0)
541563
# But when theta_implicit > 0, the residual should be non-zero.
@@ -555,6 +577,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
555577
explicit_source_profiles=explicit_source_profiles,
556578
coeffs_old=coeffs_old,
557579
evolving_names=evolving_names,
580+
pedestal_policy_state=pedestal_policy_state,
558581
)
559582
self.assertGreater(jnp.abs(jnp.sum(residual)), 0.0)
560583

0 commit comments

Comments
 (0)