2626from torax ._src .fvm import cell_variable
2727from torax ._src .fvm import implicit_solve_block
2828from torax ._src .fvm import residual_and_loss
29+ from torax ._src .pedestal_policy import pedestal_policy as pedestal_policy_lib
2930from torax ._src .sources import runtime_params as source_runtime_params
3031from torax ._src .sources import source_profile_builders
3132from torax ._src .test_utils import default_sources
@@ -247,13 +248,18 @@ def test_nonlinear_solve_block_loss_minimum(
247248 core_profiles = core_profiles ,
248249 explicit = True ,
249250 )
251+ pedestal_policy = physics_models .pedestal_model .pedestal_policy
252+ pedestal_policy_state = pedestal_policy .initial_state (
253+ t = torax_config .numerics .t_initial
254+ )
250255 coeffs = calc_coeffs .calc_coeffs (
251256 runtime_params = runtime_params ,
252257 geo = geo ,
253258 core_profiles = core_profiles ,
254259 physics_models = physics_models ,
255260 explicit_source_profiles = explicit_source_profiles ,
256261 evolving_names = evolving_names ,
262+ pedestal_policy_state = pedestal_policy_state ,
257263 use_pereverzev = False ,
258264 )
259265 # dt well under the explicit stability limit for dx=1 and chi=1
@@ -287,19 +293,23 @@ def test_nonlinear_solve_block_loss_minimum(
287293 explicit_source_profiles = explicit_source_profiles ,
288294 coeffs_old = coeffs ,
289295 evolving_names = evolving_names ,
296+ pedestal_policy_state = pedestal_policy_state ,
290297 )
291298
292299 residual = residual_and_loss .theta_method_block_residual (
293300 dt = dt ,
294301 runtime_params_t_plus_dt = runtime_params ,
295302 geo_t_plus_dt = geo ,
303+
304+
296305 x_new_guess_vec = jnp .concatenate ([var .value for var in x_new ]),
297306 x_old = x_old ,
298307 core_profiles_t_plus_dt = core_profiles ,
299308 physics_models = physics_models ,
300309 explicit_source_profiles = explicit_source_profiles ,
301310 coeffs_old = coeffs ,
302311 evolving_names = evolving_names ,
312+ pedestal_policy_state = pedestal_policy_state ,
303313 )
304314
305315 np .testing .assert_allclose (loss , 0.0 , atol = 1e-7 )
@@ -353,13 +363,18 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
353363 dt = jnp .array (1.0 )
354364 evolving_names = tuple (['T_i' ])
355365
366+ pedestal_policy = physics_models .pedestal_model .pedestal_policy
367+ pedestal_policy_state = pedestal_policy .initial_state (
368+ t = torax_config .numerics .t_initial
369+ )
356370 coeffs = calc_coeffs .calc_coeffs (
357371 runtime_params = runtime_params ,
358372 geo = geo ,
359373 core_profiles = initial_core_profiles ,
360374 physics_models = physics_models ,
361375 explicit_source_profiles = explicit_source_profiles ,
362376 evolving_names = evolving_names ,
377+ pedestal_policy_state = pedestal_policy_state ,
363378 use_pereverzev = False ,
364379 )
365380 initial_right_boundary = jnp .array (0.0 )
@@ -470,13 +485,18 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
470485 dt = jnp .array (1.0 )
471486 evolving_names = tuple (['T_i' ])
472487
488+ pedestal_policy = physics_models .pedestal_model .pedestal_policy
489+ pedestal_policy_state = pedestal_policy .initial_state (
490+ t = torax_config .numerics .t_initial
491+ )
473492 coeffs_old = calc_coeffs .calc_coeffs (
474493 runtime_params = runtime_params_theta05 ,
475494 geo = geo ,
476495 core_profiles = initial_core_profiles ,
477496 physics_models = physics_models ,
478497 explicit_source_profiles = explicit_source_profiles ,
479498 evolving_names = evolving_names ,
499+ pedestal_policy_state = pedestal_policy_state ,
480500 use_pereverzev = False ,
481501 )
482502
@@ -513,6 +533,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
513533 explicit_source_profiles = explicit_source_profiles ,
514534 coeffs_old = coeffs_old ,
515535 evolving_names = evolving_names ,
536+ pedestal_policy_state = pedestal_policy_state ,
516537 )
517538 np .testing .assert_allclose (residual , 0.0 )
518539 with self .subTest ('updated_boundary_conditions' ):
@@ -536,6 +557,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
536557 physics_models = physics_models ,
537558 explicit_source_profiles = explicit_source_profiles ,
538559 coeffs_old = coeffs_old ,
560+ pedestal_policy_state = pedestal_policy_state ,
539561 )
540562 np .testing .assert_allclose (residual , 0.0 )
541563 # But when theta_implicit > 0, the residual should be non-zero.
@@ -555,6 +577,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
555577 explicit_source_profiles = explicit_source_profiles ,
556578 coeffs_old = coeffs_old ,
557579 evolving_names = evolving_names ,
580+ pedestal_policy_state = pedestal_policy_state ,
558581 )
559582 self .assertGreater (jnp .abs (jnp .sum (residual )), 0.0 )
560583
0 commit comments