2121https://gitlab.epfl.ch/spc/public/neos [O. Sauter et al]
2222"""
2323
24+ import dataclasses
2425from typing import Annotated , Literal
2526
27+ import jax
2628from jax import numpy as jnp
2729from torax ._src import array_typing
2830from torax ._src import constants
3941# pylint: disable=invalid-name
4042
4143
44+ @jax .tree_util .register_dataclass
45+ @dataclasses .dataclass (frozen = True )
46+ class RuntimeParams (transport_runtime_params .RuntimeParams ):
47+ """RuntimeParams for the Angioni-Sauter neoclassical transport model."""
48+
49+ use_shaing_correction : array_typing .BoolScalar
50+ shaing_chi_i_multiplier : array_typing .FloatScalar
51+ shaing_chi_e_multiplier : array_typing .FloatScalar
52+ shaing_D_e_multiplier : array_typing .FloatScalar
53+ shaing_blend_rho_norm_start : array_typing .FloatScalar
54+ shaing_blend_rate : array_typing .FloatScalar
55+
56+
4257class AngioniSauterModelConfig (base .NeoclassicalTransportModelConfig ):
4358 """Pydantic model for the Angioni-Sauter neoclassical transport model."""
4459
4560 model_name : Annotated [
4661 Literal ['angioni_sauter' ], torax_pydantic .JAX_STATIC
4762 ] = 'angioni_sauter'
63+ use_shaing_correction : bool = False
64+ shaing_chi_i_multiplier : float = 0.8
65+ shaing_chi_e_multiplier : float = 1.0
66+ shaing_D_e_multiplier : float = 1.0
67+ shaing_blend_rho_norm_start : float = 0.2
68+ shaing_blend_rate : float = 0.005
4869
4970 @override
5071 def build_model (self ) -> 'AngioniSauterModel' :
5172 return AngioniSauterModel ()
5273
5374 @override
54- def build_runtime_params (self ) -> transport_runtime_params .RuntimeParams :
55- return super ().build_runtime_params ()
75+ # TODO: Why does this not accept t?
76+ def build_runtime_params (self ) -> RuntimeParams :
77+
78+ base_kwargs = dataclasses .asdict (super ().build_runtime_params ())
79+ return RuntimeParams (
80+ use_shaing_correction = self .use_shaing_correction ,
81+ shaing_chi_i_multiplier = self .shaing_chi_i_multiplier ,
82+ shaing_chi_e_multiplier = self .shaing_chi_e_multiplier ,
83+ shaing_D_e_multiplier = self .shaing_D_e_multiplier ,
84+ shaing_blend_rho_norm_start = self .shaing_blend_rho_norm_start ,
85+ shaing_blend_rate = self .shaing_blend_rate ,
86+ ** base_kwargs
87+ )
88+
89+
90+ def _calculate_blend_alpha (
91+ rho_norm : array_typing .FloatVectorFace ,
92+ start : array_typing .FloatScalar ,
93+ rate : array_typing .FloatScalar ,
94+ ) -> array_typing .FloatVectorFace :
95+ """Calculate exponential blend factor for smooth transition between models.
96+
97+ The blend function is: alpha = 1 - exp(-rate * max(0, rho_norm - start))
98+
99+ This gives:
100+ - Near axis (rho_norm < start): alpha ~ 0 (full Shaing correction)
101+ - Far from axis (rho_norm >> start): alpha ~ 1 (pure Angioni-Sauter)
102+
103+ Args:
104+ rho_norm: Normalized toroidal flux coordinate (face grid)
105+ start: Rho norm value where blend transition begins
106+ rate: Controls transition steepness (higher = sharper transition)
107+
108+ Returns:
109+ Blend factor alpha in range [0, 1]
110+ """
111+ return 1.0 - jnp .exp (- rate * jnp .maximum (0.0 , rho_norm - start ))
56112
57113
58114class AngioniSauterModel (base .NeoclassicalTransportModel ):
@@ -65,12 +121,46 @@ def _call_implementation(
65121 geometry : geometry_lib .Geometry ,
66122 core_profiles : state .CoreProfiles ,
67123 ) -> base .NeoclassicalTransport :
68- """Calculates neoclassical transport coefficients."""
69- return _calculate_angioni_sauter_transport (
124+ """Calculates neoclassical transport coefficients with smooth blend.
125+
126+ When use_shaing_correction is enabled, smoothly blends between Shaing
127+ (near axis) and Angioni-Sauter (far from axis) models using an exponential
128+ transition function.
129+ """
130+ angioni_sauter = _calculate_angioni_sauter_transport (
70131 runtime_params = runtime_params ,
71132 geometry = geometry ,
72133 core_profiles = core_profiles ,
73134 )
135+ shaing = _calculate_shaing_transport (
136+ runtime_params = runtime_params ,
137+ geometry = geometry ,
138+ core_profiles = core_profiles ,
139+ )
140+
141+ # Calculate blend factor: alpha = 0 near axis (Shaing), alpha = 1 far (A-S)
142+ # If correction disabled, alpha = 1 everywhere (pure Angioni-Sauter)
143+ alpha = jnp .where (
144+ runtime_params .neoclassical .transport .use_shaing_correction ,
145+ _calculate_blend_alpha (
146+ rho_norm = geometry .rho_face_norm ,
147+ start = runtime_params .neoclassical .transport .shaing_blend_rho_norm_start ,
148+ rate = runtime_params .neoclassical .transport .shaing_blend_rate ,
149+ ),
150+ 1.0 , # Pure Angioni-Sauter when correction disabled
151+ )
152+
153+ # Blend: (1-alpha)*Shaing + alpha*Angioni-Sauter
154+ return base .NeoclassicalTransport (
155+ chi_neo_i = (1.0 - alpha ) * shaing .chi_neo_i
156+ + alpha * angioni_sauter .chi_neo_i ,
157+ chi_neo_e = (1.0 - alpha ) * shaing .chi_neo_e
158+ + alpha * angioni_sauter .chi_neo_e ,
159+ D_neo_e = (1.0 - alpha ) * shaing .D_neo_e + alpha * angioni_sauter .D_neo_e ,
160+ V_neo_e = (1.0 - alpha ) * shaing .V_neo_e + alpha * angioni_sauter .V_neo_e ,
161+ V_neo_ware_e = (1.0 - alpha ) * shaing .V_neo_ware_e
162+ + alpha * angioni_sauter .V_neo_ware_e ,
163+ )
74164
75165 def __hash__ (self ) -> int :
76166 return hash (self .__class__ .__name__ )
@@ -595,3 +685,100 @@ def _calculate_Lmn(
595685 )
596686
597687 return Lmn_e , Lmn_i
688+
689+
690+ def _calculate_shaing_transport (
691+ runtime_params : runtime_params_slice .RuntimeParams ,
692+ geometry : geometry_lib .Geometry ,
693+ core_profiles : state .CoreProfiles ,
694+ ) -> base .NeoclassicalTransport :
695+ """JIT-compatible implementation of the Shaing transport model.
696+
697+ Args:
698+ runtime_params: Runtime parameters.
699+ geometry: Geometry object.
700+ core_profiles: Core profiles object.
701+
702+ Returns:
703+ Neoclassical transport coefficients.
704+ """
705+ # Aliases for readability
706+ m_ion = core_profiles .A_i * constants .CONSTANTS .m_amu
707+ Z_ion = core_profiles .Z_i_face
708+ q = core_profiles .q_face
709+ delta = geometry .elongation_face
710+ F = geometry .F_face # Note: denoted I in Shaing
711+ R = geometry .R_major
712+ T_i_J = core_profiles .T_i .face_value () * constants .CONSTANTS .keV_to_J
713+ T_e_J = core_profiles .T_e .face_value () * constants .CONSTANTS .keV_to_J
714+ ln_Lambda_ii = collisions .calculate_log_lambda_ii (
715+ core_profiles .T_i .face_value (),
716+ core_profiles .n_i .face_value (),
717+ core_profiles .Z_i_face ,
718+ )
719+ ln_Lambda_ei = collisions .calculate_log_lambda_ei (
720+ core_profiles .T_e .face_value (), core_profiles .n_e .face_value ()
721+ )
722+ tau_ii = (
723+ 12
724+ * jnp .pi ** (3 / 2 )
725+ * constants .CONSTANTS .epsilon_0 ** 2
726+ * m_ion ** (1 / 2 )
727+ * T_i_J ** (3 / 2 )
728+ / (
729+ core_profiles .n_i .face_value ()
730+ * Z_ion ** 4
731+ * constants .CONSTANTS .q_e ** 4
732+ * ln_Lambda_ii
733+ )
734+ )
735+ tau_ei = (
736+ 3
737+ * (2 * jnp .pi ) ** (3 / 2 )
738+ * constants .CONSTANTS .epsilon_0 ** 2
739+ * constants .CONSTANTS .m_e ** (1 / 2 )
740+ * T_e_J ** (3 / 2 )
741+ / (
742+ core_profiles .n_i .face_value ()
743+ * Z_ion ** 4
744+ * constants .CONSTANTS .q_e ** 4
745+ * ln_Lambda_ei
746+ )
747+ )
748+ nu_ii = 1 / tau_ii
749+ nu_ei = 1 / tau_ei
750+ v_t_ion = jnp .sqrt (2 * T_i_J / m_ion )
751+ v_t_electron = jnp .sqrt (2 * T_e_J / constants .CONSTANTS .m_e )
752+ Omega_0_ion = (
753+ constants .CONSTANTS .q_e * core_profiles .Z_i_face * geometry .B_0 / m_ion
754+ )
755+ Omega_0_electron = (
756+ constants .CONSTANTS .q_e * geometry .B_0 / constants .CONSTANTS .m_e
757+ )
758+
759+ # Common terms
760+ # Large aspect ratio approximation
761+ # (See Equation 3, Shaing March 1997)
762+ C_1 = (2 * q / (delta * F * R )) ** (1 / 2 )
763+
764+ # Chi i term
765+ # Equation 74, Shaing March 1997
766+ chi_i = nu_ii * (F * v_t_ion / Omega_0_ion ) ** (7 / 3 ) * C_1 ** (2 / 3 )
767+
768+ # Chi e term
769+ # Equation 31, Shaing May 1997
770+ chi_e = (
771+ nu_ei * (F * v_t_electron / Omega_0_electron ) ** (7 / 3 ) * C_1 ** (2 / 3 )
772+ )
773+
774+ return base .NeoclassicalTransport (
775+ chi_neo_i = runtime_params .neoclassical .transport .shaing_chi_i_multiplier
776+ * chi_i ,
777+ chi_neo_e = runtime_params .neoclassical .transport .shaing_chi_e_multiplier
778+ * chi_e ,
779+ D_neo_e = runtime_params .neoclassical .transport .shaing_D_e_multiplier
780+ * chi_e ,
781+ # TODO
782+ V_neo_e = jnp .zeros_like (geometry .rho_face ),
783+ V_neo_ware_e = jnp .zeros_like (geometry .rho_face ),
784+ )
0 commit comments