Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 174 additions & 8 deletions torax/_src/neoclassical/transport/angioni_sauter.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,61 @@
https://gitlab.epfl.ch/spc/public/neos [O. Sauter et al]
"""

import dataclasses
from typing import Annotated, Literal

import jax
from jax import numpy as jnp
from torax._src import array_typing
from torax._src import constants
from torax._src import state
from typing_extensions import override

from torax._src import array_typing, constants, state
from torax._src.config import runtime_params as runtime_params_lib
from torax._src.geometry import geometry as geometry_lib
from torax._src.neoclassical import formulas
from torax._src.neoclassical.transport import base
from torax._src.neoclassical.transport import runtime_params as transport_runtime_params
from torax._src.physics import collisions
from torax._src.torax_pydantic import torax_pydantic
from typing_extensions import override

# pylint: disable=invalid-name


@jax.tree_util.register_dataclass
@dataclasses.dataclass(frozen=True)
class RuntimeParams(transport_runtime_params.RuntimeParams):
"""RuntimeParams for the Angioni-Sauter neoclassical transport model."""

use_shaing_ion_correction: array_typing.BoolScalar
shaing_ion_multiplier: array_typing.FloatScalar
shaing_blend_start: array_typing.FloatScalar
shaing_blend_rate: array_typing.FloatScalar


class AngioniSauterModelConfig(base.NeoclassicalTransportModelConfig):
"""Pydantic model for the Angioni-Sauter neoclassical transport model."""

model_name: Annotated[
Literal['angioni_sauter'], torax_pydantic.JAX_STATIC
] = 'angioni_sauter'
use_shaing_ion_correction: bool = False
shaing_ion_multiplier: float = 1.8
shaing_blend_start: float = 0.2
shaing_blend_rate: float = 5.0

@override
def build_model(self) -> 'AngioniSauterModel':
return AngioniSauterModel()

@override
def build_runtime_params(self) -> transport_runtime_params.RuntimeParams:
return super().build_runtime_params()
def build_runtime_params(self) -> RuntimeParams:
base_kwargs = dataclasses.asdict(super().build_runtime_params())
return RuntimeParams(
use_shaing_ion_correction=self.use_shaing_ion_correction,
shaing_ion_multiplier=self.shaing_ion_multiplier,
shaing_blend_start=self.shaing_blend_start,
shaing_blend_rate=self.shaing_blend_rate,
**base_kwargs
)


class AngioniSauterModel(base.NeoclassicalTransportModel):
Expand All @@ -65,12 +88,47 @@ def _call_implementation(
geometry: geometry_lib.Geometry,
core_profiles: state.CoreProfiles,
) -> base.NeoclassicalTransport:
"""Calculates neoclassical transport coefficients."""
return _calculate_angioni_sauter_transport(
"""Calculates neoclassical transport coefficients with smooth blend.

When use_shaing_ion_correction is enabled, chi_ion is smoothly blended
between Shaing (near axis) and Angioni-Sauter (far from axis) models using
an exponential transition function.
"""
angioni_sauter = _calculate_angioni_sauter_transport(
runtime_params=runtime_params,
geometry=geometry,
core_profiles=core_profiles,
)
shaing = _calculate_shaing_transport(
runtime_params=runtime_params,
geometry=geometry,
core_profiles=core_profiles,
)


# Calculate sigmoid blend weight for Angioni-Sauter (alpha)
# If correction disabled: alpha = 1 (pure Angioni-Sauter)
# If correction enabled: alpha varies smoothly with rho_norm
alpha = jnp.where(
runtime_params.neoclassical.transport.use_shaing_ion_correction,
_calculate_blend_alpha(
rho_norm=geometry.rho_face_norm,
start=runtime_params.neoclassical.transport.shaing_blend_start,
rate=runtime_params.neoclassical.transport.shaing_blend_rate,
),
1.0, # Pure Angioni-Sauter when correction disabled
)

return base.NeoclassicalTransport(
# Ion transport blend: (1-alpha)*Shaing + alpha*Angioni-Sauter
chi_neo_i=(1.0 - alpha) * shaing.chi_neo_i
+ alpha * angioni_sauter.chi_neo_i,
# Electron transport: pure Angioni-Sauter
chi_neo_e=angioni_sauter.chi_neo_e,
D_neo_e=angioni_sauter.D_neo_e,
V_neo_e=angioni_sauter.V_neo_e,
V_neo_ware_e=angioni_sauter.V_neo_ware_e,
)

def __hash__(self) -> int:
return hash(self.__class__.__name__)
Expand Down Expand Up @@ -587,3 +645,111 @@ def _calculate_Lmn(
)

return Lmn_e, Lmn_i


def _calculate_shaing_transport(
runtime_params: runtime_params_lib.RuntimeParams,
geometry: geometry_lib.Geometry,
core_profiles: state.CoreProfiles,
) -> base.NeoclassicalTransport:
"""JIT-compatible implementation of the Shaing transport model.

Currently only implements ion thermal transport. Other contributions are
negligible.

Args:
runtime_params: Runtime parameters.
geometry: Geometry object.
core_profiles: Core profiles object.

Returns:
- Neoclassical transport coefficients.
- Radius of validity (in terms of rho_norm) of Shaing model for electrons.
- Radius of validity (in terms of rho_norm) of Shaing model for ions.
"""
# Aliases for readability
m_ion = core_profiles.A_i * constants.CONSTANTS.m_amu
q = core_profiles.q_face
kappa = geometry.elongation_face # Note: denoted delta in Shaing
F = geometry.F_face # Note: denoted I in Shaing
R = geometry.R_major_profile_face
T_i_J = core_profiles.T_i.face_value() * constants.CONSTANTS.keV_to_J

# Collisionality
ln_Lambda_ii = collisions.calculate_log_lambda_ii(
core_profiles.T_i.face_value(),
core_profiles.n_i.face_value(),
core_profiles.Z_i_face,
)
tau_ii = collisions.calculate_tau_ii(
A_i=core_profiles.A_i,
Z_i=core_profiles.Z_i_face,
T_i=core_profiles.T_i.face_value(),
n_i=core_profiles.n_i.face_value(),
ln_Lambda_ii=ln_Lambda_ii,
)
nu_ii = 1 / tau_ii # Ion-ion collision frequency

# Thermal velocity
v_t_ion = jnp.sqrt(2 * T_i_J / m_ion)

# Larmor (gyro)frequency
Omega_0_ion = (
constants.CONSTANTS.q_e * core_profiles.Z_i_face * geometry.B_0 / m_ion
)

# Large aspect ratio approximation (Equation 3, Shaing March 1997)
C_1 = (2 * q / (kappa * F * R)) ** (1 / 2)

# Conversion from flux^2/s -> m^2/s
# TODO: make a more informed choice for dpsi_drhon near the axis (currently
# we simply copy the value at i=1). This is ok as chi[0] is never used.
dpsi_drhon = core_profiles.psi.face_grad()
dpsi_drhon = dpsi_drhon.at[0].set(dpsi_drhon[1])
conversion_factor = 1 / (dpsi_drhon / (2*jnp.pi*geometry.rho_b)) ** 2

# Trapped particle fraction (Equation 46, Shaing March 1997)
f_t_ion = (F * v_t_ion * C_1**2 / Omega_0_ion) ** (1 / 3)

# Orbit width in psi coordinates (Equation 73, Shaing March 1997)
Delta_psi_ion = (F**2 * v_t_ion**2 * C_1 / Omega_0_ion**2) ** (2 / 3)

# Chi i term (Equation 74, Shaing March 1997)
# psi normalization difference accounted for in conversion_factor
chi_i = (nu_ii * Delta_psi_ion**2 / f_t_ion) * conversion_factor

return base.NeoclassicalTransport(
chi_neo_i=runtime_params.neoclassical.transport.shaing_ion_multiplier
* chi_i,
chi_neo_e=jnp.zeros_like(geometry.rho_face),
D_neo_e=jnp.zeros_like(geometry.rho_face),
V_neo_e=jnp.zeros_like(geometry.rho_face),
V_neo_ware_e=jnp.zeros_like(geometry.rho_face),
)


def _calculate_blend_alpha(
rho_norm: array_typing.FloatVectorFace,
start: array_typing.FloatScalar,
rate: array_typing.FloatScalar,
) -> array_typing.FloatVectorFace:
"""Calculate blending weight between Angioni-Sauter and Shaing models.

The blend is:
result = (1-alpha)*Shaing + alpha*Angioni-Sauter
where alpha = 1 / (1 + exp(-2*rate*(rho_norm - start))).

This gives:
- At axis (rho_norm = 0 << start): alpha ~ 0 (pure Shaing)
- At start: alpha = 0.5 (equal blend)
- Far from axis (rho_norm >> start): alpha ~ 1 (pure Angioni-Sauter)

Args:
rho_norm: Normalized toroidal flux coordinate (face grid)
start: Rho norm value where blend transition is centered
rate: Controls transition steepness (higher = sharper transition)

Returns:
Blend factor alpha in range [0, 1]
"""
return 1.0 / (1.0 + jnp.exp(-2.0 * rate * (rho_norm - start)))
Loading