Skip to content

Commit da909ae

Browse files
goodfeliTorax team
authored andcommitted
Introduce PedestalPolicy.
This is a step toward having full Martin scaling, where there will be a Martin scaling PedestalPolicy that ramps up and ramps down the pedestal height. This PR is just a step that defines the new stateful interface required for doing so and transitions the existing `set_pedestal` functionality to being implemented with the new PedestalPolicy class. PiperOrigin-RevId: 810477225
1 parent e2d8dab commit da909ae

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+757
-195
lines changed

torax/_src/config/build_runtime_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __call__(
103103
numerics=self.numerics.build_runtime_params(t),
104104
neoclassical=self.neoclassical.build_runtime_params(),
105105
pedestal=self.pedestal.build_runtime_params(t),
106+
pedestal_policy=self.pedestal.build_pedestal_policy_runtime_params(),
106107
mhd=self.mhd.build_runtime_params(t),
107108
time_step_calculator=self.time_step_calculator.build_runtime_params(),
108109
)

torax/_src/config/runtime_params_slice.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from torax._src.mhd import runtime_params as mhd_runtime_params
4949
from torax._src.neoclassical import runtime_params as neoclassical_params
5050
from torax._src.pedestal_model import runtime_params as pedestal_model_params
51+
from torax._src.pedestal_policy import runtime_params as pedestal_policy_runtime_params
5152
from torax._src.solver import runtime_params as solver_params
5253
from torax._src.sources import runtime_params as sources_params
5354
from torax._src.time_step_calculator import runtime_params as time_step_calculator_runtime_params
@@ -78,6 +79,7 @@ class RuntimeParams:
7879
neoclassical: neoclassical_params.RuntimeParams
7980
numerics: numerics.RuntimeParams
8081
pedestal: pedestal_model_params.RuntimeParams
82+
pedestal_policy: pedestal_policy_runtime_params.PedestalPolicyRuntimeParams
8183
plasma_composition: plasma_composition.RuntimeParams
8284
profile_conditions: profile_conditions.RuntimeParams
8385
solver: solver_params.RuntimeParams

torax/_src/config/tests/build_runtime_params_test.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,24 @@ def test_pedestal_is_time_dependent(self):
9595
set_pedestal={0.0: True, 1.0: False},
9696
)
9797
)
98+
pedestal_policy = pedestal.build_pedestal_model().pedestal_policy
9899
# Check at time 0.
99100

100101
pedestal_params = pedestal.build_runtime_params(t=0.0)
101102
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
102-
np.testing.assert_allclose(pedestal_params.set_pedestal, True)
103+
np.testing.assert_allclose(
104+
pedestal_policy.initial_state(t=0.0).use_pedestal, True
105+
)
103106
np.testing.assert_allclose(pedestal_params.T_i_ped, 0.0)
104107
np.testing.assert_allclose(pedestal_params.T_e_ped, 1.0)
105108
np.testing.assert_allclose(pedestal_params.n_e_ped, 2.0e20)
106109
np.testing.assert_allclose(pedestal_params.rho_norm_ped_top, 3.0)
107110
# And check after the time limit.
108111
pedestal_params = pedestal.build_runtime_params(t=1.0)
109112
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
110-
np.testing.assert_allclose(pedestal_params.set_pedestal, False)
113+
np.testing.assert_allclose(
114+
pedestal_policy.initial_state(t=1.0).use_pedestal, False
115+
)
111116
np.testing.assert_allclose(pedestal_params.T_i_ped, 1.0)
112117
np.testing.assert_allclose(pedestal_params.T_e_ped, 2.0)
113118
np.testing.assert_allclose(pedestal_params.n_e_ped, 3.0e20)

torax/_src/fvm/calc_coeffs.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torax._src.fvm import cell_variable
2828
from torax._src.geometry import geometry
2929
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
30+
from torax._src.pedestal_policy import pedestal_policy
3031
from torax._src.sources import source_profile_builders
3132
from torax._src.sources import source_profiles as source_profiles_lib
3233
import typing_extensions
@@ -63,6 +64,7 @@ def __call__(
6364
core_profiles: state.CoreProfiles,
6465
x: tuple[cell_variable.CellVariable, ...],
6566
explicit_source_profiles: source_profiles_lib.SourceProfiles,
67+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
6668
allow_pereverzev: bool = False,
6769
# Checks if reduced calc_coeffs for explicit terms when theta_implicit=1
6870
# should be called
@@ -86,6 +88,7 @@ def __call__(
8688
not recalculated at time t+plus_dt with updated state during the solver
8789
iterations. For sources that are implicit, their explicit profiles are
8890
set to all zeros.
91+
pedestal_policy_state: State held by the pedestal policy.
8992
allow_pereverzev: If True, then the coeffs are being called within a
9093
linear solver. Thus could be either the use_predictor_corrector solver
9194
or as part of calculating the initial guess for the nonlinear solver. In
@@ -121,6 +124,7 @@ def __call__(
121124
explicit_source_profiles=explicit_source_profiles,
122125
physics_models=self.physics_models,
123126
evolving_names=self.evolving_names,
127+
pedestal_policy_state=pedestal_policy_state,
124128
use_pereverzev=use_pereverzev,
125129
explicit_call=explicit_call,
126130
)
@@ -219,6 +223,7 @@ def calc_coeffs(
219223
explicit_source_profiles: source_profiles_lib.SourceProfiles,
220224
physics_models: physics_models_lib.PhysicsModels,
221225
evolving_names: tuple[str, ...],
226+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
222227
use_pereverzev: bool = False,
223228
explicit_call: bool = False,
224229
) -> block_1d_coeffs.Block1DCoeffs:
@@ -241,6 +246,7 @@ def calc_coeffs(
241246
physics_models: The physics models to use for the simulation.
242247
evolving_names: The names of the evolving variables in the order that their
243248
coefficients should be written to `coeffs`.
249+
pedestal_policy_state: State held by the pedestal policy.
244250
use_pereverzev: Toggle whether to calculate Pereverzev terms
245251
explicit_call: If True, indicates that calc_coeffs is being called for the
246252
explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
@@ -267,6 +273,7 @@ def calc_coeffs(
267273
explicit_source_profiles=explicit_source_profiles,
268274
physics_models=physics_models,
269275
evolving_names=evolving_names,
276+
pedestal_policy_state=pedestal_policy_state,
270277
use_pereverzev=use_pereverzev,
271278
)
272279

@@ -284,14 +291,18 @@ def _calc_coeffs_full(
284291
explicit_source_profiles: source_profiles_lib.SourceProfiles,
285292
physics_models: physics_models_lib.PhysicsModels,
286293
evolving_names: tuple[str, ...],
294+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
287295
use_pereverzev: bool = False,
288296
) -> block_1d_coeffs.Block1DCoeffs:
289297
"""See `calc_coeffs` for details."""
290298

291299
consts = constants.CONSTANTS
292300

293301
pedestal_model_output = physics_models.pedestal_model(
294-
runtime_params, geo, core_profiles
302+
runtime_params,
303+
geo,
304+
core_profiles,
305+
pedestal_policy_state,
295306
)
296307

297308
# Boolean mask for enforcing internal temperature boundary conditions to
@@ -351,7 +362,11 @@ def _calc_coeffs_full(
351362

352363
# Diffusion term coefficients
353364
turbulent_transport = physics_models.transport_model(
354-
runtime_params, geo, core_profiles, pedestal_model_output
365+
runtime_params,
366+
geo,
367+
core_profiles,
368+
pedestal_policy_state,
369+
pedestal_model_output,
355370
)
356371
neoclassical_transport = physics_models.neoclassical_models.transport(
357372
runtime_params, geo, core_profiles

torax/_src/fvm/newton_raphson_solve_block.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torax._src.fvm import jax_root_finding
3636
from torax._src.fvm import residual_and_loss
3737
from torax._src.geometry import geometry
38+
from torax._src.pedestal_policy import pedestal_policy
3839
from torax._src.solver import predictor_corrector_method
3940
from torax._src.sources import source_profiles
4041

@@ -67,6 +68,7 @@ def newton_raphson_solve_block(
6768
physics_models: physics_models_lib.PhysicsModels,
6869
coeffs_callback: calc_coeffs.CoeffsCallback,
6970
evolving_names: tuple[str, ...],
71+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
7072
initial_guess_mode: enums.InitialGuessMode,
7173
maxiter: int,
7274
tol: float,
@@ -128,6 +130,7 @@ def newton_raphson_solve_block(
128130
core_profiles. Repeatedly called by the iterative optimizer.
129131
evolving_names: The names of variables within the core profiles that should
130132
evolve.
133+
pedestal_policy_state: State variables held by the pedestal policy.
131134
initial_guess_mode: chooses the initial_guess for the iterative method,
132135
either x_old or linear step. When taking the linear step, it is also
133136
recommended to use Pereverzev-Corrigan terms if the transport coefficients
@@ -159,6 +162,7 @@ def newton_raphson_solve_block(
159162
core_profiles_t,
160163
x_old,
161164
explicit_source_profiles=explicit_source_profiles,
165+
pedestal_policy_state=pedestal_policy_state,
162166
explicit_call=True,
163167
)
164168

@@ -175,6 +179,7 @@ def newton_raphson_solve_block(
175179
core_profiles_t,
176180
x_old,
177181
explicit_source_profiles=explicit_source_profiles,
182+
pedestal_policy_state=pedestal_policy_state,
178183
allow_pereverzev=True,
179184
explicit_call=True,
180185
)
@@ -193,6 +198,7 @@ def newton_raphson_solve_block(
193198
coeffs_exp=coeffs_exp_linear,
194199
coeffs_callback=coeffs_callback,
195200
explicit_source_profiles=explicit_source_profiles,
201+
pedestal_policy_state=pedestal_policy_state,
196202
)
197203
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
198204
case enums.InitialGuessMode.X_OLD:
@@ -214,6 +220,7 @@ def newton_raphson_solve_block(
214220
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
215221
physics_models=physics_models,
216222
explicit_source_profiles=explicit_source_profiles,
223+
pedestal_policy_state=pedestal_policy_state,
217224
coeffs_old=coeffs_old,
218225
evolving_names=evolving_names,
219226
)

torax/_src/fvm/optimizer_solve_block.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torax._src.fvm import fvm_conversions
3333
from torax._src.fvm import residual_and_loss
3434
from torax._src.geometry import geometry
35+
from torax._src.pedestal_policy import pedestal_policy
3536
from torax._src.solver import predictor_corrector_method
3637
from torax._src.sources import source_profiles
3738

@@ -56,6 +57,7 @@ def optimizer_solve_block(
5657
core_profiles_t: state.CoreProfiles,
5758
core_profiles_t_plus_dt: state.CoreProfiles,
5859
explicit_source_profiles: source_profiles.SourceProfiles,
60+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
5961
physics_models: physics_models_lib.PhysicsModels,
6062
coeffs_callback: calc_coeffs.CoeffsCallback,
6163
evolving_names: tuple[str, ...],
@@ -97,6 +99,7 @@ def optimizer_solve_block(
9799
being evolved by the PDE system.
98100
explicit_source_profiles: Pre-calculated sources implemented as explicit
99101
sources in the PDE.
102+
pedestal_policy_state: State variables held by the pedestal policy.
100103
physics_models: Physics models used for the calculations.
101104
coeffs_callback: Calculates diffusion, convection etc. coefficients given a
102105
core_profiles. Repeatedly called by the iterative optimizer.
@@ -123,6 +126,7 @@ def optimizer_solve_block(
123126
core_profiles_t,
124127
x_old,
125128
explicit_source_profiles=explicit_source_profiles,
129+
pedestal_policy_state=pedestal_policy_state,
126130
explicit_call=True,
127131
)
128132

@@ -140,6 +144,7 @@ def optimizer_solve_block(
140144
core_profiles_t,
141145
x_old,
142146
explicit_source_profiles=explicit_source_profiles,
147+
pedestal_policy_state=pedestal_policy_state,
143148
allow_pereverzev=True,
144149
explicit_call=True,
145150
)
@@ -157,6 +162,7 @@ def optimizer_solve_block(
157162
coeffs_exp=coeffs_exp_linear,
158163
coeffs_callback=coeffs_callback,
159164
explicit_source_profiles=explicit_source_profiles,
165+
pedestal_policy_state=pedestal_policy_state,
160166
)
161167
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
162168
case enums.InitialGuessMode.X_OLD:
@@ -186,6 +192,7 @@ def optimizer_solve_block(
186192
init_x_new_vec=init_x_new_vec,
187193
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
188194
explicit_source_profiles=explicit_source_profiles,
195+
pedestal_policy_state=pedestal_policy_state,
189196
physics_models=physics_models,
190197
coeffs_old=coeffs_old,
191198
evolving_names=evolving_names,

torax/_src/fvm/residual_and_loss.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from torax._src.fvm import discrete_system
3939
from torax._src.fvm import fvm_conversions
4040
from torax._src.geometry import geometry
41+
from torax._src.pedestal_policy import pedestal_policy
4142
from torax._src.sources import source_profiles
4243

4344
Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs
@@ -200,6 +201,7 @@ def theta_method_block_residual(
200201
x_old: tuple[cell_variable.CellVariable, ...],
201202
core_profiles_t_plus_dt: state.CoreProfiles,
202203
explicit_source_profiles: source_profiles.SourceProfiles,
204+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
203205
physics_models: physics_models_lib.PhysicsModels,
204206
coeffs_old: Block1DCoeffs,
205207
evolving_names: tuple[str, ...],
@@ -219,6 +221,7 @@ def theta_method_block_residual(
219221
being evolved by the PDE system.
220222
explicit_source_profiles: Pre-calculated sources implemented as explicit
221223
sources in the PDE.
224+
pedestal_policy_state: State variables held by the pedestal policy.
222225
physics_models: Physics models used for the calculations.
223226
coeffs_old: The coefficients calculated at x_old.
224227
evolving_names: The names of variables within the core profiles that should
@@ -251,6 +254,7 @@ def theta_method_block_residual(
251254
core_profiles=core_profiles_t_plus_dt,
252255
explicit_source_profiles=explicit_source_profiles,
253256
physics_models=physics_models,
257+
pedestal_policy_state=pedestal_policy_state,
254258
evolving_names=evolving_names,
255259
use_pereverzev=False,
256260
)
@@ -288,6 +292,7 @@ def theta_method_block_loss(
288292
x_old: tuple[cell_variable.CellVariable, ...],
289293
core_profiles_t_plus_dt: state.CoreProfiles,
290294
explicit_source_profiles: source_profiles.SourceProfiles,
295+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
291296
physics_models: physics_models_lib.PhysicsModels,
292297
coeffs_old: Block1DCoeffs,
293298
evolving_names: tuple[str, ...],
@@ -307,6 +312,7 @@ def theta_method_block_loss(
307312
being evolved by the PDE system.
308313
explicit_source_profiles: pre-calculated sources implemented as explicit
309314
sources in the PDE
315+
pedestal_policy_state: State variables held by the pedestal policy.
310316
physics_models: Physics models used for the calculations.
311317
coeffs_old: The coefficients calculated at x_old.
312318
evolving_names: The names of variables within the core profiles that should
@@ -324,6 +330,7 @@ def theta_method_block_loss(
324330
x_new_guess_vec=x_new_guess_vec,
325331
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
326332
explicit_source_profiles=explicit_source_profiles,
333+
pedestal_policy_state=pedestal_policy_state,
327334
physics_models=physics_models,
328335
coeffs_old=coeffs_old,
329336
evolving_names=evolving_names,
@@ -346,6 +353,7 @@ def jaxopt_solver(
346353
init_x_new_vec: jax.Array,
347354
core_profiles_t_plus_dt: state.CoreProfiles,
348355
explicit_source_profiles: source_profiles.SourceProfiles,
356+
pedestal_policy_state: pedestal_policy.PedestalPolicy,
349357
physics_models: physics_models_lib.PhysicsModels,
350358
coeffs_old: Block1DCoeffs,
351359
evolving_names: tuple[str, ...],
@@ -367,6 +375,7 @@ def jaxopt_solver(
367375
being evolved by the PDE system.
368376
explicit_source_profiles: pre-calculated sources implemented as explicit
369377
sources in the PDE.
378+
pedestal_policy_state: State variables held by the pedestal policy.
370379
physics_models: Physics models used for the calculations.
371380
coeffs_old: The coefficients calculated at x_old.
372381
evolving_names: The names of variables within the core profiles that should
@@ -391,6 +400,7 @@ def jaxopt_solver(
391400
physics_models=physics_models,
392401
coeffs_old=coeffs_old,
393402
evolving_names=evolving_names,
403+
pedestal_policy_state=pedestal_policy_state,
394404
)
395405
solver = jaxopt.LBFGS(fun=loss, maxiter=maxiter, tol=tol)
396406
solver_output = solver.run(init_x_new_vec)

torax/_src/fvm/tests/calc_coeffs_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,19 @@ def test_calc_coeffs_smoke_test(
7979
neoclassical_models=physics_models.neoclassical_models,
8080
explicit=True,
8181
)
82+
pedestal_policy_state = (
83+
physics_models.pedestal_model.pedestal_policy.initial_state(
84+
t=torax_config.numerics.t_initial
85+
)
86+
)
8287
calc_coeffs.calc_coeffs(
8388
runtime_params=runtime_params,
8489
geo=geo,
8590
core_profiles=core_profiles,
8691
physics_models=physics_models,
8792
explicit_source_profiles=explicit_source_profiles,
8893
evolving_names=evolving_names,
94+
pedestal_policy_state=pedestal_policy_state,
8995
use_pereverzev=False,
9096
)
9197

0 commit comments

Comments
 (0)