2121https://gitlab.epfl.ch/spc/public/neos [O. Sauter et al]
2222"""
2323
24+ import dataclasses
2425from typing import Annotated , Literal
2526
27+ import jax
2628from jax import numpy as jnp
2729from torax ._src import array_typing
2830from torax ._src import constants
3941# pylint: disable=invalid-name
4042
4143
44+ @jax .tree_util .register_dataclass
45+ @dataclasses .dataclass (frozen = True )
46+ class RuntimeParams (transport_runtime_params .RuntimeParams ):
47+ """RuntimeParams for the Angioni-Sauter neoclassical transport model."""
48+
49+ use_shaing_ion_correction : array_typing .BoolScalar
50+ shaing_ion_multiplier : array_typing .FloatScalar
51+ shaing_blend_start : array_typing .FloatScalar
52+ shaing_blend_rate : array_typing .FloatScalar
53+
54+
4255class AngioniSauterModelConfig (base .NeoclassicalTransportModelConfig ):
4356 """Pydantic model for the Angioni-Sauter neoclassical transport model."""
4457
4558 model_name : Annotated [
4659 Literal ['angioni_sauter' ], torax_pydantic .JAX_STATIC
4760 ] = 'angioni_sauter'
61+ use_shaing_ion_correction : bool = False
62+ shaing_ion_multiplier : float = 1.8
63+ shaing_blend_start : float = 0.2
64+ shaing_blend_rate : float = 5.0
4865
4966 @override
5067 def build_model (self ) -> 'AngioniSauterModel' :
5168 return AngioniSauterModel ()
5269
5370 @override
54- def build_runtime_params (self ) -> transport_runtime_params .RuntimeParams :
55- return super ().build_runtime_params ()
71+ def build_runtime_params (self ) -> RuntimeParams :
72+ base_kwargs = dataclasses .asdict (super ().build_runtime_params ())
73+ return RuntimeParams (
74+ use_shaing_ion_correction = self .use_shaing_ion_correction ,
75+ shaing_ion_multiplier = self .shaing_ion_multiplier ,
76+ shaing_blend_start = self .shaing_blend_start ,
77+ shaing_blend_rate = self .shaing_blend_rate ,
78+ ** base_kwargs
79+ )
5680
5781
5882class AngioniSauterModel (base .NeoclassicalTransportModel ):
@@ -65,13 +89,48 @@ def _call_implementation(
6589 geometry : geometry_lib .Geometry ,
6690 core_profiles : state .CoreProfiles ,
6791 ) -> base .NeoclassicalTransport :
68- """Calculates neoclassical transport coefficients."""
69- return _calculate_angioni_sauter_transport (
92+ """Calculates neoclassical transport coefficients with smooth blend.
93+
94+ When use_shaing_correction is enabled, chi_ion is smoothly blended between
95+ Shaing (near axis) and Angioni-Sauter (far from axis) models using an
96+ exponential transition function.
97+ """
98+ angioni_sauter = _calculate_angioni_sauter_transport (
99+ runtime_params = runtime_params ,
100+ geometry = geometry ,
101+ core_profiles = core_profiles ,
102+ )
103+ shaing = _calculate_shaing_transport (
70104 runtime_params = runtime_params ,
71105 geometry = geometry ,
72106 core_profiles = core_profiles ,
73107 )
74108
109+
110+ # Calculate sigmoid blend weight for Angioni-Sauter (alpha)
111+ # If correction disabled: alpha = 1 (pure Angioni-Sauter)
112+ # If correction enabled: alpha varies smoothly with rho_norm
113+ alpha = jnp .where (
114+ runtime_params .neoclassical .transport .use_shaing_correction ,
115+ _calculate_blend_alpha (
116+ rho_norm = geometry .rho_face_norm ,
117+ start = runtime_params .neoclassical .transport .shaing_blend_start ,
118+ rate = runtime_params .neoclassical .transport .shaing_blend_rate ,
119+ ),
120+ 1.0 , # Pure Angioni-Sauter when correction disabled
121+ )
122+
123+ return base .NeoclassicalTransport (
124+ # Ion transport blend: (1-alpha)*Shaing + alpha*Angioni-Sauter
125+ chi_neo_i = (1.0 - alpha ) * shaing .chi_neo_i
126+ + alpha * angioni_sauter .chi_neo_i ,
127+ # Electron transport: pure Angioni-Sauter
128+ chi_neo_e = angioni_sauter .chi_neo_e ,
129+ D_neo_e = angioni_sauter .D_neo_e ,
130+ V_neo_e = angioni_sauter .V_neo_e ,
131+ V_neo_ware_e = angioni_sauter .V_neo_ware_e ,
132+ )
133+
75134 def __hash__ (self ) -> int :
76135 return hash (self .__class__ .__name__ )
77136
@@ -587,3 +646,111 @@ def _calculate_Lmn(
587646 )
588647
589648 return Lmn_e , Lmn_i
649+
650+
651+ def _calculate_shaing_transport (
652+ runtime_params : runtime_params_lib .RuntimeParams ,
653+ geometry : geometry_lib .Geometry ,
654+ core_profiles : state .CoreProfiles ,
655+ ) -> base .NeoclassicalTransport :
656+ """JIT-compatible implementation of the Shaing transport model.
657+
658+ Currently only implements ion thermal transport. Other contributions are
659+ negligible.
660+
661+ Args:
662+ runtime_params: Runtime parameters.
663+ geometry: Geometry object.
664+ core_profiles: Core profiles object.
665+
666+ Returns:
667+ - Neoclassical transport coefficients.
668+ - Radius of validity (in terms of rho_norm) of Shaing model for electrons.
669+ - Radius of validity (in terms of rho_norm) of Shaing model for ions.
670+ """
671+ # Aliases for readability
672+ m_ion = core_profiles .A_i * constants .CONSTANTS .m_amu
673+ q = core_profiles .q_face
674+ kappa = geometry .elongation_face # Note: denoted delta in Shaing
675+ F = geometry .F_face # Note: denoted I in Shaing
676+ R = geometry .R_major_profile_face
677+ T_i_J = core_profiles .T_i .face_value () * constants .CONSTANTS .keV_to_J
678+
679+ # Collisionality
680+ ln_Lambda_ii = collisions .calculate_log_lambda_ii (
681+ core_profiles .T_i .face_value (),
682+ core_profiles .n_i .face_value (),
683+ core_profiles .Z_i_face ,
684+ )
685+ tau_ii = collisions .calculate_tau_ii (
686+ A_i = core_profiles .A_i ,
687+ Z_i = core_profiles .Z_i_face ,
688+ T_i = core_profiles .T_i .face_value (),
689+ n_i = core_profiles .n_i .face_value (),
690+ ln_Lambda_ii = ln_Lambda_ii ,
691+ )
692+ nu_ii = 1 / tau_ii # Ion-ion collision frequency
693+
694+ # Thermal velocity
695+ v_t_ion = jnp .sqrt (2 * T_i_J / m_ion )
696+
697+ # Larmor (gyro)frequency
698+ Omega_0_ion = (
699+ constants .CONSTANTS .q_e * core_profiles .Z_i_face * geometry .B_0 / m_ion
700+ )
701+
702+ # Large aspect ratio approximation (Equation 3, Shaing March 1997)
703+ C_1 = (2 * q / (kappa * F * R )) ** (1 / 2 )
704+
705+ # Conversion from flux^2/s -> m^2/s
706+ # TODO: make a more informed choice for dpsi_drhon near the axis (currently
707+ # we simply copy the value at i=1). This is ok as chi[0] is never used.
708+ dpsi_drhon = core_profiles .psi .face_grad ()
709+ dpsi_drhon = dpsi_drhon .at [0 ].set (dpsi_drhon [1 ])
710+ conversion_factor = 1 / (dpsi_drhon / (2 * jnp .pi * geometry .rho_b )) ** 2
711+
712+ # Trapped particle fraction (Equation 46, Shaing March 1997)
713+ f_t_ion = (F * v_t_ion * C_1 ** 2 / Omega_0_ion ) ** (1 / 3 )
714+
715+ # Orbit width in psi coordinates (Equation 73, Shaing March 1997)
716+ Delta_psi_ion = (F ** 2 * v_t_ion ** 2 * C_1 / Omega_0_ion ** 2 ) ** (2 / 3 )
717+
718+ # Chi i term (Equation 74, Shaing March 1997)
719+ # psi normalization difference accounted for in conversion_factor
720+ chi_i = (nu_ii * Delta_psi_ion ** 2 / f_t_ion ) * conversion_factor
721+
722+ return base .NeoclassicalTransport (
723+ chi_neo_i = runtime_params .neoclassical .transport .shaing_ion_multiplier
724+ * chi_i ,
725+ chi_neo_e = jnp .zeros_like (geometry .rho_face ),
726+ D_neo_e = jnp .zeros_like (geometry .rho_face ),
727+ V_neo_e = jnp .zeros_like (geometry .rho_face ),
728+ V_neo_ware_e = jnp .zeros_like (geometry .rho_face ),
729+ )
730+
731+
732+ def _calculate_blend_alpha (
733+ rho_norm : array_typing .FloatVectorFace ,
734+ start : array_typing .FloatScalar ,
735+ rate : array_typing .FloatScalar ,
736+ ) -> array_typing .FloatVectorFace :
737+ """Calculate blending weight between Angioni-Sauter and Shaing models.
738+
739+ The blend is:
740+ result = (1-alpha)*Shaing + alpha*Angioni-Sauter
741+ where alpha = 1 / (1 + exp(-2*rate*(rho_norm - start))).
742+
743+ This gives:
744+ - At axis (rho_norm = 0 << start): alpha ~ 0 (pure Shaing)
745+ - At start: alpha = 0.5 (equal blend)
746+ - Far from axis (rho_norm >> start): alpha ~ 1 (pure Angioni-Sauter)
747+
748+ Args:
749+ rho_norm: Normalized toroidal flux coordinate (face grid)
750+ start: Rho norm value where blend transition is centered
751+ rate: Controls transition steepness (higher = sharper transition)
752+
753+ Returns:
754+ Blend factor alpha in range [0, 1]
755+ """
756+ return 1.0 / (1.0 + jnp .exp (- 2.0 * rate * (rho_norm - start )))
0 commit comments