Skip to content

Commit 1907e4c

Browse files
committed
Add Shaing corrections to Angioni-Sauter model
- Add smooth transition between models - Fix conversion from psi with 2pi factor - Tune default parameters for an 'ok' match with NCLASS on STEP and ITER cases - Add regression test
1 parent f9f79d0 commit 1907e4c

File tree

4 files changed

+399
-97
lines changed

4 files changed

+399
-97
lines changed

torax/_src/neoclassical/transport/angioni_sauter.py

Lines changed: 174 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,38 +21,61 @@
2121
https://gitlab.epfl.ch/spc/public/neos [O. Sauter et al]
2222
"""
2323

24+
import dataclasses
2425
from typing import Annotated, Literal
2526

27+
import jax
2628
from jax import numpy as jnp
27-
from torax._src import array_typing
28-
from torax._src import constants
29-
from torax._src import state
29+
from typing_extensions import override
30+
31+
from torax._src import array_typing, constants, state
3032
from torax._src.config import runtime_params as runtime_params_lib
3133
from torax._src.geometry import geometry as geometry_lib
3234
from torax._src.neoclassical import formulas
3335
from torax._src.neoclassical.transport import base
3436
from torax._src.neoclassical.transport import runtime_params as transport_runtime_params
3537
from torax._src.physics import collisions
3638
from torax._src.torax_pydantic import torax_pydantic
37-
from typing_extensions import override
3839

3940
# pylint: disable=invalid-name
4041

4142

43+
@jax.tree_util.register_dataclass
44+
@dataclasses.dataclass(frozen=True)
45+
class RuntimeParams(transport_runtime_params.RuntimeParams):
46+
"""RuntimeParams for the Angioni-Sauter neoclassical transport model."""
47+
48+
use_shaing_ion_correction: array_typing.BoolScalar
49+
shaing_ion_multiplier: array_typing.FloatScalar
50+
shaing_blend_start: array_typing.FloatScalar
51+
shaing_blend_rate: array_typing.FloatScalar
52+
53+
4254
class AngioniSauterModelConfig(base.NeoclassicalTransportModelConfig):
4355
"""Pydantic model for the Angioni-Sauter neoclassical transport model."""
4456

4557
model_name: Annotated[
4658
Literal['angioni_sauter'], torax_pydantic.JAX_STATIC
4759
] = 'angioni_sauter'
60+
use_shaing_ion_correction: bool = False
61+
shaing_ion_multiplier: float = 1.8
62+
shaing_blend_start: float = 0.2
63+
shaing_blend_rate: float = 5.0
4864

4965
@override
5066
def build_model(self) -> 'AngioniSauterModel':
5167
return AngioniSauterModel()
5268

5369
@override
54-
def build_runtime_params(self) -> transport_runtime_params.RuntimeParams:
55-
return super().build_runtime_params()
70+
def build_runtime_params(self) -> RuntimeParams:
71+
base_kwargs = dataclasses.asdict(super().build_runtime_params())
72+
return RuntimeParams(
73+
use_shaing_ion_correction=self.use_shaing_ion_correction,
74+
shaing_ion_multiplier=self.shaing_ion_multiplier,
75+
shaing_blend_start=self.shaing_blend_start,
76+
shaing_blend_rate=self.shaing_blend_rate,
77+
**base_kwargs
78+
)
5679

5780

5881
class AngioniSauterModel(base.NeoclassicalTransportModel):
@@ -65,12 +88,47 @@ def _call_implementation(
6588
geometry: geometry_lib.Geometry,
6689
core_profiles: state.CoreProfiles,
6790
) -> base.NeoclassicalTransport:
68-
"""Calculates neoclassical transport coefficients."""
69-
return _calculate_angioni_sauter_transport(
91+
"""Calculates neoclassical transport coefficients with smooth blend.
92+
93+
When use_shaing_ion_correction is enabled, chi_ion is smoothly blended
94+
between Shaing (near axis) and Angioni-Sauter (far from axis) models using
95+
an exponential transition function.
96+
"""
97+
angioni_sauter = _calculate_angioni_sauter_transport(
7098
runtime_params=runtime_params,
7199
geometry=geometry,
72100
core_profiles=core_profiles,
73101
)
102+
shaing = _calculate_shaing_transport(
103+
runtime_params=runtime_params,
104+
geometry=geometry,
105+
core_profiles=core_profiles,
106+
)
107+
108+
109+
# Calculate sigmoid blend weight for Angioni-Sauter (alpha)
110+
# If correction disabled: alpha = 1 (pure Angioni-Sauter)
111+
# If correction enabled: alpha varies smoothly with rho_norm
112+
alpha = jnp.where(
113+
runtime_params.neoclassical.transport.use_shaing_ion_correction,
114+
_calculate_blend_alpha(
115+
rho_norm=geometry.rho_face_norm,
116+
start=runtime_params.neoclassical.transport.shaing_blend_start,
117+
rate=runtime_params.neoclassical.transport.shaing_blend_rate,
118+
),
119+
1.0, # Pure Angioni-Sauter when correction disabled
120+
)
121+
122+
return base.NeoclassicalTransport(
123+
# Ion transport blend: (1-alpha)*Shaing + alpha*Angioni-Sauter
124+
chi_neo_i=(1.0 - alpha) * shaing.chi_neo_i
125+
+ alpha * angioni_sauter.chi_neo_i,
126+
# Electron transport: pure Angioni-Sauter
127+
chi_neo_e=angioni_sauter.chi_neo_e,
128+
D_neo_e=angioni_sauter.D_neo_e,
129+
V_neo_e=angioni_sauter.V_neo_e,
130+
V_neo_ware_e=angioni_sauter.V_neo_ware_e,
131+
)
74132

75133
def __hash__(self) -> int:
76134
return hash(self.__class__.__name__)
@@ -587,3 +645,111 @@ def _calculate_Lmn(
587645
)
588646

589647
return Lmn_e, Lmn_i
648+
649+
650+
def _calculate_shaing_transport(
651+
runtime_params: runtime_params_lib.RuntimeParams,
652+
geometry: geometry_lib.Geometry,
653+
core_profiles: state.CoreProfiles,
654+
) -> base.NeoclassicalTransport:
655+
"""JIT-compatible implementation of the Shaing transport model.
656+
657+
Currently only implements ion thermal transport. Other contributions are
658+
negligible.
659+
660+
Args:
661+
runtime_params: Runtime parameters.
662+
geometry: Geometry object.
663+
core_profiles: Core profiles object.
664+
665+
Returns:
666+
- Neoclassical transport coefficients.
667+
- Radius of validity (in terms of rho_norm) of Shaing model for electrons.
668+
- Radius of validity (in terms of rho_norm) of Shaing model for ions.
669+
"""
670+
# Aliases for readability
671+
m_ion = core_profiles.A_i * constants.CONSTANTS.m_amu
672+
q = core_profiles.q_face
673+
kappa = geometry.elongation_face # Note: denoted delta in Shaing
674+
F = geometry.F_face # Note: denoted I in Shaing
675+
R = geometry.R_major_profile_face
676+
T_i_J = core_profiles.T_i.face_value() * constants.CONSTANTS.keV_to_J
677+
678+
# Collisionality
679+
ln_Lambda_ii = collisions.calculate_log_lambda_ii(
680+
core_profiles.T_i.face_value(),
681+
core_profiles.n_i.face_value(),
682+
core_profiles.Z_i_face,
683+
)
684+
tau_ii = collisions.calculate_tau_ii(
685+
A_i=core_profiles.A_i,
686+
Z_i=core_profiles.Z_i_face,
687+
T_i=core_profiles.T_i.face_value(),
688+
n_i=core_profiles.n_i.face_value(),
689+
ln_Lambda_ii=ln_Lambda_ii,
690+
)
691+
nu_ii = 1 / tau_ii # Ion-ion collision frequency
692+
693+
# Thermal velocity
694+
v_t_ion = jnp.sqrt(2 * T_i_J / m_ion)
695+
696+
# Larmor (gyro)frequency
697+
Omega_0_ion = (
698+
constants.CONSTANTS.q_e * core_profiles.Z_i_face * geometry.B_0 / m_ion
699+
)
700+
701+
# Large aspect ratio approximation (Equation 3, Shaing March 1997)
702+
C_1 = (2 * q / (kappa * F * R)) ** (1 / 2)
703+
704+
# Conversion from flux^2/s -> m^2/s
705+
# TODO: make a more informed choice for dpsi_drhon near the axis (currently
706+
# we simply copy the value at i=1). This is ok as chi[0] is never used.
707+
dpsi_drhon = core_profiles.psi.face_grad()
708+
dpsi_drhon = dpsi_drhon.at[0].set(dpsi_drhon[1])
709+
conversion_factor = 1 / (dpsi_drhon / (2*jnp.pi*geometry.rho_b)) ** 2
710+
711+
# Trapped particle fraction (Equation 46, Shaing March 1997)
712+
f_t_ion = (F * v_t_ion * C_1**2 / Omega_0_ion) ** (1 / 3)
713+
714+
# Orbit width in psi coordinates (Equation 73, Shaing March 1997)
715+
Delta_psi_ion = (F**2 * v_t_ion**2 * C_1 / Omega_0_ion**2) ** (2 / 3)
716+
717+
# Chi i term (Equation 74, Shaing March 1997)
718+
# psi normalization difference accounted for in conversion_factor
719+
chi_i = (nu_ii * Delta_psi_ion**2 / f_t_ion) * conversion_factor
720+
721+
return base.NeoclassicalTransport(
722+
chi_neo_i=runtime_params.neoclassical.transport.shaing_ion_multiplier
723+
* chi_i,
724+
chi_neo_e=jnp.zeros_like(geometry.rho_face),
725+
D_neo_e=jnp.zeros_like(geometry.rho_face),
726+
V_neo_e=jnp.zeros_like(geometry.rho_face),
727+
V_neo_ware_e=jnp.zeros_like(geometry.rho_face),
728+
)
729+
730+
731+
def _calculate_blend_alpha(
732+
rho_norm: array_typing.FloatVectorFace,
733+
start: array_typing.FloatScalar,
734+
rate: array_typing.FloatScalar,
735+
) -> array_typing.FloatVectorFace:
736+
"""Calculate blending weight between Angioni-Sauter and Shaing models.
737+
738+
The blend is:
739+
result = (1-alpha)*Shaing + alpha*Angioni-Sauter
740+
where alpha = 1 / (1 + exp(-2*rate*(rho_norm - start))).
741+
742+
This gives:
743+
- At axis (rho_norm = 0 << start): alpha ~ 0 (pure Shaing)
744+
- At start: alpha = 0.5 (equal blend)
745+
- Far from axis (rho_norm >> start): alpha ~ 1 (pure Angioni-Sauter)
746+
747+
Args:
748+
rho_norm: Normalized toroidal flux coordinate (face grid)
749+
start: Rho norm value where blend transition is centered
750+
rate: Controls transition steepness (higher = sharper transition)
751+
752+
Returns:
753+
Blend factor alpha in range [0, 1]
754+
"""
755+
return 1.0 / (1.0 + jnp.exp(-2.0 * rate * (rho_norm - start)))

0 commit comments

Comments
 (0)