2121https://gitlab.epfl.ch/spc/public/neos [O. Sauter et al]
2222"""
2323
24+ import dataclasses
2425from typing import Annotated , Literal
2526
27+ import jax
2628from jax import numpy as jnp
2729from torax ._src import array_typing
2830from torax ._src import constants
3941# pylint: disable=invalid-name
4042
4143
44+ @jax .tree_util .register_dataclass
45+ @dataclasses .dataclass (frozen = True )
46+ class RuntimeParams (transport_runtime_params .RuntimeParams ):
47+ """RuntimeParams for the Angioni-Sauter neoclassical transport model."""
48+
49+ use_shaing_correction : array_typing .BoolScalar
50+ shaing_chi_i_multiplier : array_typing .FloatScalar
51+ shaing_chi_e_multiplier : array_typing .FloatScalar
52+ shaing_D_e_multiplier : array_typing .FloatScalar
53+ shaing_blend_rho_norm_start : array_typing .FloatScalar
54+ shaing_blend_rate : array_typing .FloatScalar
55+
56+
4257class AngioniSauterModelConfig (base .NeoclassicalTransportModelConfig ):
4358 """Pydantic model for the Angioni-Sauter neoclassical transport model."""
4459
4560 model_name : Annotated [
4661 Literal ['angioni_sauter' ], torax_pydantic .JAX_STATIC
4762 ] = 'angioni_sauter'
63+ use_shaing_correction : bool = False
64+ shaing_chi_i_multiplier : float = 0.8
65+ shaing_chi_e_multiplier : float = 1.0
66+ shaing_D_e_multiplier : float = 1.0
67+ shaing_blend_rho_norm_start : float = 0.2
68+ shaing_blend_rate : float = 0.005
4869
4970 @override
5071 def build_model (self ) -> 'AngioniSauterModel' :
5172 return AngioniSauterModel ()
5273
5374 @override
54- def build_runtime_params (self ) -> transport_runtime_params .RuntimeParams :
55- return super ().build_runtime_params ()
75+ # TODO: Why does this not accept t?
76+ def build_runtime_params (self ) -> RuntimeParams :
77+
78+ base_kwargs = dataclasses .asdict (super ().build_runtime_params ())
79+ return RuntimeParams (
80+ use_shaing_correction = self .use_shaing_correction ,
81+ shaing_chi_i_multiplier = self .shaing_chi_i_multiplier ,
82+ shaing_chi_e_multiplier = self .shaing_chi_e_multiplier ,
83+ shaing_D_e_multiplier = self .shaing_D_e_multiplier ,
84+ shaing_blend_rho_norm_start = self .shaing_blend_rho_norm_start ,
85+ shaing_blend_rate = self .shaing_blend_rate ,
86+ ** base_kwargs
87+ )
5688
5789
5890class AngioniSauterModel (base .NeoclassicalTransportModel ):
@@ -65,13 +97,48 @@ def _call_implementation(
6597 geometry : geometry_lib .Geometry ,
6698 core_profiles : state .CoreProfiles ,
6799 ) -> base .NeoclassicalTransport :
68- """Calculates neoclassical transport coefficients."""
69- return _calculate_angioni_sauter_transport (
100+ """Calculates neoclassical transport coefficients with smooth blend.
101+
102+ When use_shaing_correction is enabled, smoothly blends between Shaing
103+ (near axis) and Angioni-Sauter (far from axis) models using an exponential
104+ transition function.
105+ """
106+ angioni_sauter = _calculate_angioni_sauter_transport (
107+ runtime_params = runtime_params ,
108+ geometry = geometry ,
109+ core_profiles = core_profiles ,
110+ )
111+ shaing = _calculate_shaing_transport (
70112 runtime_params = runtime_params ,
71113 geometry = geometry ,
72114 core_profiles = core_profiles ,
73115 )
74116
117+ # Calculate sigmoid blend weight for Angioni-Sauter (alpha)
118+ # If correction disabled: alpha = 1 (pure Angioni-Sauter)
119+ # If correction enabled: alpha varies smoothly with rho_norm
120+ alpha = jnp .where (
121+ runtime_params .neoclassical .transport .use_shaing_correction ,
122+ _calculate_blend_alpha (
123+ rho_norm = geometry .rho_face_norm ,
124+ start = runtime_params .neoclassical .transport .shaing_blend_rho_norm_start ,
125+ rate = runtime_params .neoclassical .transport .shaing_blend_rate ,
126+ ),
127+ 1.0 , # Pure Angioni-Sauter when correction disabled
128+ )
129+
130+ # Blend: (1-alpha)*Shaing + alpha*Angioni-Sauter
131+ return base .NeoclassicalTransport (
132+ chi_neo_i = (1.0 - alpha ) * shaing .chi_neo_i
133+ + alpha * angioni_sauter .chi_neo_i ,
134+ chi_neo_e = (1.0 - alpha ) * shaing .chi_neo_e
135+ + alpha * angioni_sauter .chi_neo_e ,
136+ D_neo_e = (1.0 - alpha ) * shaing .D_neo_e + alpha * angioni_sauter .D_neo_e ,
137+ V_neo_e = (1.0 - alpha ) * shaing .V_neo_e + alpha * angioni_sauter .V_neo_e ,
138+ V_neo_ware_e = (1.0 - alpha ) * shaing .V_neo_ware_e
139+ + alpha * angioni_sauter .V_neo_ware_e ,
140+ )
141+
75142 def __hash__ (self ) -> int :
76143 return hash (self .__class__ .__name__ )
77144
@@ -595,3 +662,127 @@ def _calculate_Lmn(
595662 )
596663
597664 return Lmn_e , Lmn_i
665+
666+
667+ def _calculate_shaing_transport (
668+ runtime_params : runtime_params_slice .RuntimeParams ,
669+ geometry : geometry_lib .Geometry ,
670+ core_profiles : state .CoreProfiles ,
671+ ) -> base .NeoclassicalTransport :
672+ """JIT-compatible implementation of the Shaing transport model.
673+
674+ Args:
675+ runtime_params: Runtime parameters.
676+ geometry: Geometry object.
677+ core_profiles: Core profiles object.
678+
679+ Returns:
680+ Neoclassical transport coefficients.
681+ """
682+ # Aliases for readability
683+ m_ion = core_profiles .A_i * constants .CONSTANTS .m_amu
684+ Z_ion = core_profiles .Z_i_face
685+ q = core_profiles .q_face
686+ delta = geometry .elongation_face
687+ F = geometry .F_face # Note: denoted I in Shaing
688+ R = geometry .R_major
689+ T_i_J = core_profiles .T_i .face_value () * constants .CONSTANTS .keV_to_J
690+ T_e_J = core_profiles .T_e .face_value () * constants .CONSTANTS .keV_to_J
691+ ln_Lambda_ii = collisions .calculate_log_lambda_ii (
692+ core_profiles .T_i .face_value (),
693+ core_profiles .n_i .face_value (),
694+ core_profiles .Z_i_face ,
695+ )
696+ ln_Lambda_ei = collisions .calculate_log_lambda_ei (
697+ core_profiles .T_e .face_value (), core_profiles .n_e .face_value ()
698+ )
699+ tau_ii = (
700+ 12
701+ * jnp .pi ** (3 / 2 )
702+ * constants .CONSTANTS .epsilon_0 ** 2
703+ * m_ion ** (1 / 2 )
704+ * T_i_J ** (3 / 2 )
705+ / (
706+ core_profiles .n_i .face_value ()
707+ * Z_ion ** 4
708+ * constants .CONSTANTS .q_e ** 4
709+ * ln_Lambda_ii
710+ )
711+ )
712+ tau_ei = (
713+ 3
714+ * (2 * jnp .pi ) ** (3 / 2 )
715+ * constants .CONSTANTS .epsilon_0 ** 2
716+ * constants .CONSTANTS .m_e ** (1 / 2 )
717+ * T_e_J ** (3 / 2 )
718+ / (
719+ core_profiles .n_i .face_value ()
720+ * Z_ion ** 4
721+ * constants .CONSTANTS .q_e ** 4
722+ * ln_Lambda_ei
723+ )
724+ )
725+ nu_ii = 1 / tau_ii
726+ nu_ei = 1 / tau_ei
727+ v_t_ion = jnp .sqrt (2 * T_i_J / m_ion )
728+ v_t_electron = jnp .sqrt (2 * T_e_J / constants .CONSTANTS .m_e )
729+ Omega_0_ion = (
730+ constants .CONSTANTS .q_e * core_profiles .Z_i_face * geometry .B_0 / m_ion
731+ )
732+ Omega_0_electron = (
733+ constants .CONSTANTS .q_e * geometry .B_0 / constants .CONSTANTS .m_e
734+ )
735+
736+ # Common terms
737+ # Large aspect ratio approximation
738+ # (See Equation 3, Shaing March 1997)
739+ C_1 = (2 * q / (delta * F * R )) ** (1 / 2 )
740+
741+ # Chi i term
742+ # Equation 74, Shaing March 1997
743+ chi_i = nu_ii * (F * v_t_ion / Omega_0_ion ) ** (7 / 3 ) * C_1 ** (2 / 3 )
744+
745+ # Chi e term
746+ # Equation 31, Shaing May 1997
747+ chi_e = (
748+ nu_ei * (F * v_t_electron / Omega_0_electron ) ** (7 / 3 ) * C_1 ** (2 / 3 )
749+ )
750+
751+ return base .NeoclassicalTransport (
752+ chi_neo_i = runtime_params .neoclassical .transport .shaing_chi_i_multiplier
753+ * chi_i ,
754+ chi_neo_e = runtime_params .neoclassical .transport .shaing_chi_e_multiplier
755+ * chi_e ,
756+ D_neo_e = runtime_params .neoclassical .transport .shaing_D_e_multiplier
757+ * chi_e ,
758+ # TODO
759+ V_neo_e = jnp .zeros_like (geometry .rho_face ),
760+ V_neo_ware_e = jnp .zeros_like (geometry .rho_face ),
761+ )
762+
763+
764+ def _calculate_blend_alpha (
765+ rho_norm : array_typing .FloatVectorFace ,
766+ start : array_typing .FloatScalar ,
767+ rate : array_typing .FloatScalar ,
768+ ) -> array_typing .FloatVectorFace :
769+ """Calculate blending weight between Angioni-Sauter and Shaing models.
770+
771+ The blend is:
772+ result = (1-alpha)*Shaing + alpha*Angioni-Sauter
773+ where alpha = 1 / (1 + exp(-2*rate*(rho_norm - start))).
774+
775+ This gives:
776+ - At axis (rho_norm = 0 << start): alpha ~ 0 (full Shaing)
777+ - At start: alpha = 0.5 (equal blend)
778+ - Far from axis (rho_norm >> start): alpha ~ 1 (pure Angioni-Sauter)
779+
780+ Args:
781+ rho_norm: Normalized toroidal flux coordinate (face grid)
782+ start: Rho norm value where blend transition is centered
783+ rate: Controls transition steepness (higher = sharper transition)
784+
785+ Returns:
786+ Blend factor alpha in range [0, 1]
787+ """
788+ return 1.0 / (1.0 + jnp .exp (- 2.0 * rate * (rho_norm - start )))
0 commit comments