Skip to content

Commit 9b33a67

Browse files
committed
Add Shaing corrections to Angioni-Sauter model
- Add smooth transition between models - Add self-consistent potato orbit width
1 parent 20701d0 commit 9b33a67

File tree

2 files changed

+261
-12
lines changed

2 files changed

+261
-12
lines changed

torax/_src/neoclassical/transport/angioni_sauter.py

Lines changed: 255 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
https://gitlab.epfl.ch/spc/public/neos [O. Sauter et al]
2222
"""
2323

24+
import dataclasses
2425
from typing import Annotated, Literal
2526

27+
import jax
2628
from jax import numpy as jnp
2729
from torax._src import array_typing
2830
from torax._src import constants
@@ -39,20 +41,53 @@
3941
# pylint: disable=invalid-name
4042

4143

44+
@jax.tree_util.register_dataclass
45+
@dataclasses.dataclass(frozen=True)
46+
class RuntimeParams(transport_runtime_params.RuntimeParams):
47+
"""RuntimeParams for the Angioni-Sauter neoclassical transport model."""
48+
49+
use_shaing_correction: array_typing.BoolScalar
50+
shaing_chi_i_multiplier: array_typing.FloatScalar
51+
shaing_chi_e_multiplier: array_typing.FloatScalar
52+
shaing_D_e_multiplier: array_typing.FloatScalar
53+
shaing_blend_start: array_typing.FloatScalar
54+
shaing_blend_rate: array_typing.FloatScalar
55+
shaing_blend_use_potato_orbit_width: array_typing.BoolScalar
56+
57+
4258
class AngioniSauterModelConfig(base.NeoclassicalTransportModelConfig):
4359
"""Pydantic model for the Angioni-Sauter neoclassical transport model."""
4460

4561
model_name: Annotated[
4662
Literal['angioni_sauter'], torax_pydantic.JAX_STATIC
4763
] = 'angioni_sauter'
64+
use_shaing_correction: bool = False
65+
shaing_chi_i_multiplier: float = 0.8
66+
shaing_chi_e_multiplier: float = 1.0
67+
shaing_D_e_multiplier: float = 1.0
68+
shaing_blend_use_potato_orbit_width: bool = True
69+
shaing_blend_start: float = 0.2
70+
shaing_blend_rate: float = 10.0
4871

4972
@override
5073
def build_model(self) -> 'AngioniSauterModel':
5174
return AngioniSauterModel()
5275

5376
@override
54-
def build_runtime_params(self) -> transport_runtime_params.RuntimeParams:
55-
return super().build_runtime_params()
77+
# TODO: Why does this not accept t?
78+
def build_runtime_params(self) -> RuntimeParams:
79+
80+
base_kwargs = dataclasses.asdict(super().build_runtime_params())
81+
return RuntimeParams(
82+
use_shaing_correction=self.use_shaing_correction,
83+
shaing_chi_i_multiplier=self.shaing_chi_i_multiplier,
84+
shaing_chi_e_multiplier=self.shaing_chi_e_multiplier,
85+
shaing_D_e_multiplier=self.shaing_D_e_multiplier,
86+
shaing_blend_use_potato_orbit_width=self.shaing_blend_use_potato_orbit_width,
87+
shaing_blend_start=self.shaing_blend_start,
88+
shaing_blend_rate=self.shaing_blend_rate,
89+
**base_kwargs
90+
)
5691

5792

5893
class AngioniSauterModel(base.NeoclassicalTransportModel):
@@ -65,12 +100,69 @@ def _call_implementation(
65100
geometry: geometry_lib.Geometry,
66101
core_profiles: state.CoreProfiles,
67102
) -> base.NeoclassicalTransport:
68-
"""Calculates neoclassical transport coefficients."""
69-
return _calculate_angioni_sauter_transport(
103+
"""Calculates neoclassical transport coefficients with smooth blend.
104+
105+
When use_shaing_correction is enabled, smoothly blends between Shaing
106+
(near axis) and Angioni-Sauter (far from axis) models using an exponential
107+
transition function.
108+
"""
109+
angioni_sauter = _calculate_angioni_sauter_transport(
70110
runtime_params=runtime_params,
71111
geometry=geometry,
72112
core_profiles=core_profiles,
73113
)
114+
shaing, shaing_delta_rho_norm_electron, shaing_delta_rho_norm_ion = (
115+
_calculate_shaing_transport(
116+
runtime_params=runtime_params,
117+
geometry=geometry,
118+
core_profiles=core_profiles,
119+
)
120+
)
121+
122+
# Calculate sigmoid blend weight for Angioni-Sauter (alpha)
123+
# If correction disabled: alpha = 1 (pure Angioni-Sauter)
124+
# If correction enabled: alpha varies smoothly with rho_norm
125+
electron_rho_norm_start = jnp.where(
126+
runtime_params.neoclassical.transport.shaing_blend_use_potato_orbit_width,
127+
shaing_delta_rho_norm_electron,
128+
runtime_params.neoclassical.transport.shaing_blend_start,
129+
)
130+
ion_rho_norm_start = jnp.where(
131+
runtime_params.neoclassical.transport.shaing_blend_use_potato_orbit_width,
132+
shaing_delta_rho_norm_ion,
133+
runtime_params.neoclassical.transport.shaing_blend_start,
134+
)
135+
alpha_electron = jnp.where(
136+
runtime_params.neoclassical.transport.use_shaing_correction,
137+
_calculate_blend_alpha(
138+
rho_norm=geometry.rho_face_norm,
139+
start=electron_rho_norm_start,
140+
rate=runtime_params.neoclassical.transport.shaing_blend_rate,
141+
),
142+
1.0, # Pure Angioni-Sauter when correction disabled
143+
)
144+
alpha_ion = jnp.where(
145+
runtime_params.neoclassical.transport.use_shaing_correction,
146+
_calculate_blend_alpha(
147+
rho_norm=geometry.rho_face_norm,
148+
start=ion_rho_norm_start,
149+
rate=runtime_params.neoclassical.transport.shaing_blend_rate,
150+
),
151+
1.0, # Pure Angioni-Sauter when correction disabled
152+
)
153+
# Blend: (1-alpha)*Shaing + alpha*Angioni-Sauter
154+
return base.NeoclassicalTransport(
155+
chi_neo_i=(1.0 - alpha_ion) * shaing.chi_neo_i
156+
+ alpha_ion * angioni_sauter.chi_neo_i,
157+
chi_neo_e=(1.0 - alpha_electron) * shaing.chi_neo_e
158+
+ alpha_electron * angioni_sauter.chi_neo_e,
159+
D_neo_e=(1.0 - alpha_electron) * shaing.D_neo_e
160+
+ alpha_electron * angioni_sauter.D_neo_e,
161+
V_neo_e=(1.0 - alpha_electron) * shaing.V_neo_e
162+
+ alpha_electron * angioni_sauter.V_neo_e,
163+
V_neo_ware_e=(1.0 - alpha_electron) * shaing.V_neo_ware_e
164+
+ alpha_electron * angioni_sauter.V_neo_ware_e,
165+
)
74166

75167
def __hash__(self) -> int:
76168
return hash(self.__class__.__name__)
@@ -587,3 +679,162 @@ def _calculate_Lmn(
587679
)
588680

589681
return Lmn_e, Lmn_i
682+
683+
684+
def _calculate_shaing_transport(
685+
runtime_params: runtime_params_slice.RuntimeParams,
686+
geometry: geometry_lib.Geometry,
687+
core_profiles: state.CoreProfiles,
688+
) -> (
689+
base.NeoclassicalTransport,
690+
array_typing.FloatScalar,
691+
array_typing.FloatScalar,
692+
):
693+
"""JIT-compatible implementation of the Shaing transport model.
694+
695+
Args:
696+
runtime_params: Runtime parameters.
697+
geometry: Geometry object.
698+
core_profiles: Core profiles object.
699+
700+
Returns:
701+
- Neoclassical transport coefficients.
702+
- Radius of validity (in terms of rho_norm) of Shaing model for electrons.
703+
- Radius of validity (in terms of rho_norm) of Shaing model for ions.
704+
"""
705+
# Aliases for readability
706+
m_ion = core_profiles.A_i * constants.CONSTANTS.m_amu
707+
Z_ion = core_profiles.Z_i_face
708+
q = core_profiles.q_face
709+
kappa = geometry.elongation_face # Note: denoted delta in Shaing
710+
F = geometry.F_face # Note: denoted I in Shaing
711+
R = geometry.R_major_profile_face
712+
T_i_J = core_profiles.T_i.face_value() * constants.CONSTANTS.keV_to_J
713+
T_e_J = core_profiles.T_e.face_value() * constants.CONSTANTS.keV_to_J
714+
ln_Lambda_ii = collisions.calculate_log_lambda_ii(
715+
core_profiles.T_i.face_value(),
716+
core_profiles.n_i.face_value(),
717+
core_profiles.Z_i_face,
718+
)
719+
ln_Lambda_ei = collisions.calculate_log_lambda_ei(
720+
core_profiles.T_e.face_value(), core_profiles.n_e.face_value()
721+
)
722+
tau_ii = (
723+
12
724+
* jnp.pi ** (3 / 2)
725+
* constants.CONSTANTS.epsilon_0**2
726+
* m_ion ** (1 / 2)
727+
* T_i_J ** (3 / 2)
728+
/ (
729+
core_profiles.n_i.face_value()
730+
* Z_ion**4
731+
* constants.CONSTANTS.q_e**4
732+
* ln_Lambda_ii
733+
)
734+
)
735+
tau_ei = (
736+
3
737+
* (2 * jnp.pi) ** (3 / 2)
738+
* constants.CONSTANTS.epsilon_0**2
739+
* constants.CONSTANTS.m_e ** (1 / 2)
740+
* T_e_J ** (3 / 2)
741+
/ (
742+
core_profiles.n_i.face_value()
743+
* Z_ion**4
744+
* constants.CONSTANTS.q_e**4
745+
* ln_Lambda_ei
746+
)
747+
)
748+
nu_ii = 1 / tau_ii
749+
nu_ei = 1 / tau_ei
750+
v_t_ion = jnp.sqrt(2 * T_i_J / m_ion)
751+
v_t_electron = jnp.sqrt(2 * T_e_J / constants.CONSTANTS.m_e)
752+
# Larmor / gyrofrequencies for ions and electrons
753+
Omega_0_ion = (
754+
constants.CONSTANTS.q_e * core_profiles.Z_i_face * geometry.B_0 / m_ion
755+
)
756+
Omega_0_electron = (
757+
constants.CONSTANTS.q_e * geometry.B_0 / constants.CONSTANTS.m_e
758+
)
759+
760+
# Common terms
761+
# Large aspect ratio approximation (Equation 3, Shaing March 1997)
762+
C_1 = (2 * q / (kappa * F * R)) ** (1 / 2)
763+
# Conversion from flux^2/s -> m^2/s
764+
# TODO: make an informed choice for dpsi_drhon near the axis
765+
dpsi_drhon = core_profiles.psi.face_grad()
766+
dpsi_drhon = dpsi_drhon.at[0].add(constants.CONSTANTS.eps)
767+
conversion_factor = 1 / (dpsi_drhon / geometry.rho_b) ** 2
768+
769+
# Trapped particle fraction (Equation 46, Shaing March 1997)
770+
f_t_ion = (F * v_t_ion * C_1**2 / Omega_0_ion) ** (1 / 3)
771+
f_t_electron = (F * v_t_electron * C_1**2 / Omega_0_electron) ** (1 / 3)
772+
773+
# Orbit width in psi coordinates (Equation 73, Shaing March 1997)
774+
Delta_psi_ion = (F**2 * v_t_ion**2 * C_1 / Omega_0_ion**2) ** (2 / 3)
775+
Delta_psi_electron = (
776+
F**2 * v_t_electron**2 * C_1 / Omega_0_electron**2
777+
) ** (2 / 3)
778+
779+
# Orbit width on axis in rho coordinates
780+
# Shaing is probably using psi/2pi normalization, whereas TORAX uses psi
781+
psi_norm = core_profiles.psi.face_value() - core_profiles.psi.face_value()[0]
782+
Delta_rho_norm_electron = jnp.interp(
783+
Delta_psi_electron[0] * 2 * jnp.pi, psi_norm, geometry.rho_face
784+
)
785+
Delta_rho_norm_ion = jnp.interp(
786+
Delta_psi_ion[0] * 2 * jnp.pi, psi_norm, geometry.rho_face
787+
)
788+
789+
# Chi i term
790+
# Equation 74, Shaing March 1997
791+
# psi normalization difference accounted for in conversion_factor
792+
chi_i = (nu_ii * Delta_psi_ion**2 / f_t_ion) * conversion_factor
793+
794+
# Chi e term
795+
# Equation 31, Shaing May 1997
796+
# psi normalization difference accounted for in conversion_factor
797+
chi_e = (nu_ei * Delta_psi_electron**2 / f_t_electron) * conversion_factor
798+
799+
return (
800+
base.NeoclassicalTransport(
801+
chi_neo_i=runtime_params.neoclassical.transport.shaing_chi_i_multiplier
802+
* chi_i,
803+
chi_neo_e=runtime_params.neoclassical.transport.shaing_chi_e_multiplier
804+
* chi_e,
805+
D_neo_e=runtime_params.neoclassical.transport.shaing_D_e_multiplier
806+
* chi_e,
807+
# TODO: implement a convection term
808+
V_neo_e=jnp.zeros_like(geometry.rho_face),
809+
V_neo_ware_e=jnp.zeros_like(geometry.rho_face),
810+
),
811+
Delta_rho_norm_electron,
812+
Delta_rho_norm_ion,
813+
)
814+
815+
816+
def _calculate_blend_alpha(
817+
rho_norm: array_typing.FloatVectorFace,
818+
start: array_typing.FloatScalar,
819+
rate: array_typing.FloatScalar,
820+
) -> array_typing.FloatVectorFace:
821+
"""Calculate blending weight between Angioni-Sauter and Shaing models.
822+
823+
The blend is:
824+
result = (1-alpha)*Shaing + alpha*Angioni-Sauter
825+
where alpha = 1 / (1 + exp(-2*rate*(rho_norm - start))).
826+
827+
This gives:
828+
- At axis (rho_norm = 0 << start): alpha ~ 0 (pure Shaing)
829+
- At start: alpha = 0.5 (equal blend)
830+
- Far from axis (rho_norm >> start): alpha ~ 1 (pure Angioni-Sauter)
831+
832+
Args:
833+
rho_norm: Normalized toroidal flux coordinate (face grid)
834+
start: Rho norm value where blend transition is centered
835+
rate: Controls transition steepness (higher = sharper transition)
836+
837+
Returns:
838+
Blend factor alpha in range [0, 1]
839+
"""
840+
return 1.0 / (1.0 + jnp.exp(-2.0 * rate * (rho_norm - start)))

torax/examples/step_flattop_bgb.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -149,20 +149,18 @@
149149
"D_e_max": 100.0,
150150
"V_e_min": -50.0,
151151
"V_e_max": 50.0,
152-
# Patching
153-
# Replaces neoclassical in the core, pending potato orbit correction
154-
# https://github.com/google-deepmind/torax/issues/1406
155-
"apply_inner_patch": True,
156-
"rho_inner": 0.05,
157-
"chi_e_inner": 1.0,
158-
"chi_i_inner": 15.0,
159152
# Smoothing
160153
"smooth_everywhere": True,
161154
"smoothing_width": 0.05,
162155
},
163156
"neoclassical": {
164157
"bootstrap_current": {"model_name": "sauter"},
165-
"transport": {"model_name": "angioni_sauter"},
158+
"transport": {
159+
"model_name": "angioni_sauter",
160+
"use_shaing_correction": True,
161+
"shaing_blend_use_potato_orbit_width": True,
162+
"chi_max": 20.0,
163+
},
166164
},
167165
"numerics": {
168166
"t_initial": 0.0,

0 commit comments

Comments
 (0)