Skip to content

Commit 9da5f7d

Browse files
jcitrinTorax team
authored andcommitted
sawtooth simple redistribution bugfix
Include `impurity_fractions` in `Ions` construction during sawtooth redistribution when both n_e and T_e are non-evolving. This bug was caused by insufficient test coverage of simple_redistribution and was triggered by the new impurity input API features. A test was added for better coverage. Fixes #1586 PiperOrigin-RevId: 810833062
1 parent 327e6d8 commit 9da5f7d

File tree

2 files changed

+140
-0
lines changed

2 files changed

+140
-0
lines changed

torax/_src/mhd/sawtooth/simple_redistribution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __call__(
124124
ions_redistributed = getters.Ions(
125125
n_i=core_profiles_t.n_i,
126126
n_impurity=core_profiles_t.n_impurity,
127+
impurity_fractions=core_profiles_t.impurity_fractions,
127128
Z_i=core_profiles_t.Z_i,
128129
Z_i_face=core_profiles_t.Z_i_face,
129130
Z_impurity=core_profiles_t.Z_impurity,
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the simple_redistribution module."""
16+
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
19+
import jax
20+
from jax import numpy as jnp
21+
import numpy as np
22+
from torax._src.config import build_runtime_params
23+
from torax._src.core_profiles import initialization
24+
from torax._src.mhd.sawtooth import simple_redistribution
25+
from torax._src.physics import psi_calculations
26+
from torax._src.torax_pydantic import model_config
27+
28+
# Set jax_enable_x64 to True to ensure high precision for tests.
29+
jax.config.update('jax_enable_x64', True)
30+
31+
32+
class SimpleRedistributionTest(parameterized.TestCase):
33+
34+
@parameterized.product(
35+
evolve_ion_heat=[True, False],
36+
evolve_electron_heat=[True, False],
37+
evolve_density=[True, False],
38+
)
39+
def test_simple_redistribution_with_evolving_profiles(
40+
self, evolve_ion_heat, evolve_electron_heat, evolve_density
41+
):
42+
"""Tests that SimpleRedistribution works with all evolving profiles."""
43+
config_dict = {
44+
'numerics': {
45+
'evolve_ion_heat': evolve_ion_heat,
46+
'evolve_electron_heat': evolve_electron_heat,
47+
'evolve_density': evolve_density,
48+
'evolve_current': True,
49+
},
50+
'profile_conditions': { # Set up to ensure q[0] < 1
51+
'Ip': 15e6,
52+
'initial_j_is_total_current': True,
53+
'initial_psi_from_j': True,
54+
'current_profile_nu': 3,
55+
},
56+
'plasma_composition': {},
57+
'geometry': {'geometry_type': 'circular', 'n_rho': 10},
58+
'pedestal': {},
59+
'sources': {},
60+
'solver': {},
61+
'transport': {},
62+
'mhd': {
63+
'sawtooth': {
64+
'trigger_model': {'model_name': 'simple'},
65+
'redistribution_model': {
66+
'model_name': 'simple',
67+
'flattening_factor': 1.01,
68+
'mixing_radius_multiplier': 1.5,
69+
},
70+
}
71+
},
72+
}
73+
74+
torax_config = model_config.ToraxConfig.from_dict(config_dict)
75+
76+
assert torax_config.mhd is not None
77+
assert torax_config.mhd.sawtooth is not None
78+
79+
redistribution_model = (
80+
torax_config.mhd.sawtooth.redistribution_model.build_redistribution_model()
81+
)
82+
self.assertIsInstance(
83+
redistribution_model, simple_redistribution.SimpleRedistribution
84+
)
85+
runtime_params_provider = (
86+
build_runtime_params.RuntimeParamsProvider.from_config(torax_config)
87+
)
88+
geo_provider = torax_config.geometry.build_provider
89+
90+
runtime_params_t = runtime_params_provider(t=0.0)
91+
geo_t = geo_provider(t=0.0)
92+
93+
core_profiles_t = initialization.initial_core_profiles(
94+
runtime_params=runtime_params_t,
95+
geo=geo_t,
96+
source_models=torax_config.sources.build_models(),
97+
neoclassical_models=torax_config.neoclassical.build_models(),
98+
)
99+
100+
# Find the q=1 surface radius to pass to the model
101+
q_face_before = core_profiles_t.q_face
102+
self.assertLess(
103+
q_face_before[0],
104+
1.0,
105+
'Initial q-profile must be below 1 for this test.',
106+
)
107+
rho_norm_q1 = np.interp(
108+
1.0,
109+
q_face_before,
110+
geo_t.rho_face_norm,
111+
)
112+
113+
# Call the redistribution model
114+
redistributed_core_profiles = redistribution_model(
115+
jnp.asarray(rho_norm_q1),
116+
runtime_params_t,
117+
geo_t,
118+
core_profiles_t,
119+
)
120+
121+
# Main check: Ensure no errors were raised.
122+
# Also, perform a basic check to ensure redistribution occurred.
123+
q_face_after = psi_calculations.calc_q_face(
124+
geo_t, redistributed_core_profiles.psi
125+
)
126+
self.assertGreater(
127+
q_face_after[0],
128+
q_face_before[0],
129+
'On-axis q should increase after redistribution.',
130+
)
131+
self.assertGreaterEqual(
132+
q_face_after[0],
133+
1.0,
134+
'On-axis q should be at least 1.0 after redistribution.',
135+
)
136+
137+
138+
if __name__ == '__main__':
139+
absltest.main()

0 commit comments

Comments
 (0)