Skip to content

Commit 9eb2d5f

Browse files
committed
Add Shaing corrections to Angioni-Sauter model
1 parent dd93cc5 commit 9eb2d5f

File tree

1 file changed

+191
-4
lines changed

1 file changed

+191
-4
lines changed

torax/_src/neoclassical/transport/angioni_sauter.py

Lines changed: 191 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@
2121
https://gitlab.epfl.ch/spc/public/neos [O. Sauter et al]
2222
"""
2323

24+
import dataclasses
2425
from typing import Annotated, Literal
2526

27+
import jax
2628
from jax import numpy as jnp
2729
from torax._src import array_typing
2830
from torax._src import constants
@@ -39,20 +41,74 @@
3941
# pylint: disable=invalid-name
4042

4143

44+
@jax.tree_util.register_dataclass
45+
@dataclasses.dataclass(frozen=True)
46+
class RuntimeParams(transport_runtime_params.RuntimeParams):
47+
"""RuntimeParams for the Angioni-Sauter neoclassical transport model."""
48+
49+
use_shaing_correction: array_typing.BoolScalar
50+
shaing_chi_i_multiplier: array_typing.FloatScalar
51+
shaing_chi_e_multiplier: array_typing.FloatScalar
52+
shaing_D_e_multiplier: array_typing.FloatScalar
53+
shaing_blend_rho_norm_start: array_typing.FloatScalar
54+
shaing_blend_rate: array_typing.FloatScalar
55+
56+
4257
class AngioniSauterModelConfig(base.NeoclassicalTransportModelConfig):
4358
"""Pydantic model for the Angioni-Sauter neoclassical transport model."""
4459

4560
model_name: Annotated[
4661
Literal['angioni_sauter'], torax_pydantic.JAX_STATIC
4762
] = 'angioni_sauter'
63+
use_shaing_correction: bool = False
64+
shaing_chi_i_multiplier: float = 0.8
65+
shaing_chi_e_multiplier: float = 1.0
66+
shaing_D_e_multiplier: float = 1.0
67+
shaing_blend_rho_norm_start: float = 0.2
68+
shaing_blend_rate: float = 0.005
4869

4970
@override
5071
def build_model(self) -> 'AngioniSauterModel':
5172
return AngioniSauterModel()
5273

5374
@override
54-
def build_runtime_params(self) -> transport_runtime_params.RuntimeParams:
55-
return super().build_runtime_params()
75+
# TODO: Why does this not accept t?
76+
def build_runtime_params(self) -> RuntimeParams:
77+
78+
base_kwargs = dataclasses.asdict(super().build_runtime_params())
79+
return RuntimeParams(
80+
use_shaing_correction=self.use_shaing_correction,
81+
shaing_chi_i_multiplier=self.shaing_chi_i_multiplier,
82+
shaing_chi_e_multiplier=self.shaing_chi_e_multiplier,
83+
shaing_D_e_multiplier=self.shaing_D_e_multiplier,
84+
shaing_blend_rho_norm_start=self.shaing_blend_rho_norm_start,
85+
shaing_blend_rate=self.shaing_blend_rate,
86+
**base_kwargs
87+
)
88+
89+
90+
def _calculate_blend_alpha(
91+
rho_norm: array_typing.FloatVectorFace,
92+
start: array_typing.FloatScalar,
93+
rate: array_typing.FloatScalar,
94+
) -> array_typing.FloatVectorFace:
95+
"""Calculate exponential blend factor for smooth transition between models.
96+
97+
The blend function is: alpha = 1 - exp(-rate * max(0, rho_norm - start))
98+
99+
This gives:
100+
- Near axis (rho_norm < start): alpha ~ 0 (full Shaing correction)
101+
- Far from axis (rho_norm >> start): alpha ~ 1 (pure Angioni-Sauter)
102+
103+
Args:
104+
rho_norm: Normalized toroidal flux coordinate (face grid)
105+
start: Rho norm value where blend transition begins
106+
rate: Controls transition steepness (higher = sharper transition)
107+
108+
Returns:
109+
Blend factor alpha in range [0, 1]
110+
"""
111+
return 1.0 - jnp.exp(-rate * jnp.maximum(0.0, rho_norm - start))
56112

57113

58114
class AngioniSauterModel(base.NeoclassicalTransportModel):
@@ -65,12 +121,46 @@ def _call_implementation(
65121
geometry: geometry_lib.Geometry,
66122
core_profiles: state.CoreProfiles,
67123
) -> base.NeoclassicalTransport:
68-
"""Calculates neoclassical transport coefficients."""
69-
return _calculate_angioni_sauter_transport(
124+
"""Calculates neoclassical transport coefficients with smooth blend.
125+
126+
When use_shaing_correction is enabled, smoothly blends between Shaing
127+
(near axis) and Angioni-Sauter (far from axis) models using an exponential
128+
transition function.
129+
"""
130+
angioni_sauter = _calculate_angioni_sauter_transport(
70131
runtime_params=runtime_params,
71132
geometry=geometry,
72133
core_profiles=core_profiles,
73134
)
135+
shaing = _calculate_shaing_transport(
136+
runtime_params=runtime_params,
137+
geometry=geometry,
138+
core_profiles=core_profiles,
139+
)
140+
141+
# Calculate blend factor: alpha = 0 near axis (Shaing), alpha = 1 far (A-S)
142+
# If correction disabled, alpha = 1 everywhere (pure Angioni-Sauter)
143+
alpha = jnp.where(
144+
runtime_params.neoclassical.transport.use_shaing_correction,
145+
_calculate_blend_alpha(
146+
rho_norm=geometry.rho_face_norm,
147+
start=runtime_params.neoclassical.transport.shaing_blend_rho_norm_start,
148+
rate=runtime_params.neoclassical.transport.shaing_blend_rate,
149+
),
150+
1.0, # Pure Angioni-Sauter when correction disabled
151+
)
152+
153+
# Blend: (1-alpha)*Shaing + alpha*Angioni-Sauter
154+
return base.NeoclassicalTransport(
155+
chi_neo_i=(1.0 - alpha) * shaing.chi_neo_i
156+
+ alpha * angioni_sauter.chi_neo_i,
157+
chi_neo_e=(1.0 - alpha) * shaing.chi_neo_e
158+
+ alpha * angioni_sauter.chi_neo_e,
159+
D_neo_e=(1.0 - alpha) * shaing.D_neo_e + alpha * angioni_sauter.D_neo_e,
160+
V_neo_e=(1.0 - alpha) * shaing.V_neo_e + alpha * angioni_sauter.V_neo_e,
161+
V_neo_ware_e=(1.0 - alpha) * shaing.V_neo_ware_e
162+
+ alpha * angioni_sauter.V_neo_ware_e,
163+
)
74164

75165
def __hash__(self) -> int:
76166
return hash(self.__class__.__name__)
@@ -595,3 +685,100 @@ def _calculate_Lmn(
595685
)
596686

597687
return Lmn_e, Lmn_i
688+
689+
690+
def _calculate_shaing_transport(
691+
runtime_params: runtime_params_slice.RuntimeParams,
692+
geometry: geometry_lib.Geometry,
693+
core_profiles: state.CoreProfiles,
694+
) -> base.NeoclassicalTransport:
695+
"""JIT-compatible implementation of the Shaing transport model.
696+
697+
Args:
698+
runtime_params: Runtime parameters.
699+
geometry: Geometry object.
700+
core_profiles: Core profiles object.
701+
702+
Returns:
703+
Neoclassical transport coefficients.
704+
"""
705+
# Aliases for readability
706+
m_ion = core_profiles.A_i * constants.CONSTANTS.m_amu
707+
Z_ion = core_profiles.Z_i_face
708+
q = core_profiles.q_face
709+
delta = geometry.elongation_face
710+
F = geometry.F_face # Note: denoted I in Shaing
711+
R = geometry.R_major
712+
T_i_J = core_profiles.T_i.face_value() * constants.CONSTANTS.keV_to_J
713+
T_e_J = core_profiles.T_e.face_value() * constants.CONSTANTS.keV_to_J
714+
ln_Lambda_ii = collisions.calculate_log_lambda_ii(
715+
core_profiles.T_i.face_value(),
716+
core_profiles.n_i.face_value(),
717+
core_profiles.Z_i_face,
718+
)
719+
ln_Lambda_ei = collisions.calculate_log_lambda_ei(
720+
core_profiles.T_e.face_value(), core_profiles.n_e.face_value()
721+
)
722+
tau_ii = (
723+
12
724+
* jnp.pi ** (3 / 2)
725+
* constants.CONSTANTS.epsilon_0**2
726+
* m_ion ** (1 / 2)
727+
* T_i_J ** (3 / 2)
728+
/ (
729+
core_profiles.n_i.face_value()
730+
* Z_ion**4
731+
* constants.CONSTANTS.q_e**4
732+
* ln_Lambda_ii
733+
)
734+
)
735+
tau_ei = (
736+
3
737+
* (2 * jnp.pi) ** (3 / 2)
738+
* constants.CONSTANTS.epsilon_0**2
739+
* constants.CONSTANTS.m_e ** (1 / 2)
740+
* T_e_J ** (3 / 2)
741+
/ (
742+
core_profiles.n_i.face_value()
743+
* Z_ion**4
744+
* constants.CONSTANTS.q_e**4
745+
* ln_Lambda_ei
746+
)
747+
)
748+
nu_ii = 1 / tau_ii
749+
nu_ei = 1 / tau_ei
750+
v_t_ion = jnp.sqrt(2 * T_i_J / m_ion)
751+
v_t_electron = jnp.sqrt(2 * T_e_J / constants.CONSTANTS.m_e)
752+
Omega_0_ion = (
753+
constants.CONSTANTS.q_e * core_profiles.Z_i_face * geometry.B_0 / m_ion
754+
)
755+
Omega_0_electron = (
756+
constants.CONSTANTS.q_e * geometry.B_0 / constants.CONSTANTS.m_e
757+
)
758+
759+
# Common terms
760+
# Large aspect ratio approximation
761+
# (See Equation 3, Shaing March 1997)
762+
C_1 = (2 * q / (delta * F * R)) ** (1 / 2)
763+
764+
# Chi i term
765+
# Equation 74, Shaing March 1997
766+
chi_i = nu_ii * (F * v_t_ion / Omega_0_ion) ** (7 / 3) * C_1 ** (2 / 3)
767+
768+
# Chi e term
769+
# Equation 31, Shaing May 1997
770+
chi_e = (
771+
nu_ei * (F * v_t_electron / Omega_0_electron) ** (7 / 3) * C_1 ** (2 / 3)
772+
)
773+
774+
return base.NeoclassicalTransport(
775+
chi_neo_i=runtime_params.neoclassical.transport.shaing_chi_i_multiplier
776+
* chi_i,
777+
chi_neo_e=runtime_params.neoclassical.transport.shaing_chi_e_multiplier
778+
* chi_e,
779+
D_neo_e=runtime_params.neoclassical.transport.shaing_D_e_multiplier
780+
* chi_e,
781+
# TODO
782+
V_neo_e=jnp.zeros_like(geometry.rho_face),
783+
V_neo_ware_e=jnp.zeros_like(geometry.rho_face),
784+
)

0 commit comments

Comments
 (0)