Skip to content

Commit 6d8096c

Browse files
hamelphiTorax team
authored andcommitted
Consolidate rotation calculations and compute ExB drift
PiperOrigin-RevId: 837611623
1 parent a24241a commit 6d8096c

File tree

5 files changed

+482
-6
lines changed

5 files changed

+482
-6
lines changed

torax/_src/neoclassical/formulas.py

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@
1313
# limitations under the License.
1414
"""Common formulas used in neoclassical models."""
1515

16+
import jax
1617
import jax.numpy as jnp
1718
from torax._src import array_typing
1819
from torax._src import constants
20+
from torax._src.fvm import cell_variable
1921
from torax._src.geometry import geometry as geometry_lib
20-
21-
# pylint: disable=invalid-name
22+
from torax._src.physics import collisions
2223

2324

25+
# pylint: disable=invalid-name
2426
def calculate_f_trap(
2527
geo: geometry_lib.Geometry,
2628
) -> array_typing.FloatVectorFace:
@@ -180,3 +182,127 @@ def calculate_nu_i_star(
180182
* (geo.epsilon_face + constants.CONSTANTS.eps) ** 1.5
181183
)
182184
)
185+
186+
187+
# Functions to calculate the neoclassical poloidal velocity.
188+
def _calculate_neoclassical_k_neo(
189+
nu_star: array_typing.FloatScalar, epsilon: array_typing.FloatScalar
190+
):
191+
"""Calculates the neoclassical coefficient k_neo.
192+
193+
Equation (6.135) from
194+
Hinton, F. L., & Hazeltine, R. D.,
195+
"Theory of plasma transport in toroidal confinement systems"
196+
Rev. Mod. Phys. 48(2), 239–308. (1976)
197+
https://doi.org/10.1103/RevModPhys.48.239
198+
199+
Limits:
200+
- Banana regime (nu_star -> 0): ~1.17
201+
- Pfirsch-Schluter regime (nu_star -> inf): ~ -2.1
202+
203+
Args:
204+
nu_star : The normalized ion collisionality.
205+
epsilon : The inverse aspect ratio.
206+
207+
Returns:
208+
k_neo : The neoclassical coefficient.
209+
"""
210+
# Calculate the first term (Banana-Plateau transition)
211+
# (1.17 - 0.35 * sqrt(nu)) / (1 + 0.7 * sqrt(nu))
212+
sqrt_nu = jnp.sqrt(nu_star)
213+
term1 = (1.17 - 0.35 * sqrt_nu) / (1.0 + 0.7 * sqrt_nu)
214+
215+
# Calculate the second term (Pfirsch-Schluter driver)
216+
# 2.1 * nu^2 * epsilon^3
217+
ps_factor = (nu_star**2) * (epsilon**3)
218+
term2 = 2.1 * ps_factor
219+
220+
# Calculate the final denominator (Switching function)
221+
# 1 + nu^2 * epsilon^3
222+
denominator = 1.0 + ps_factor
223+
224+
return (term1 - term2) / denominator
225+
226+
# TODO(b/381199010): Implement alternative Sauter-based k_neo calculation.
227+
# See Sauter (1999) Eq. 17a-17b
228+
229+
230+
@jax.jit
231+
def calculate_poloidal_velocity(
232+
T_i: cell_variable.CellVariable,
233+
n_i: array_typing.FloatVectorFace,
234+
q: array_typing.FloatVectorFace,
235+
Z_eff: array_typing.FloatVectorFace,
236+
Z_i: array_typing.FloatVectorFace,
237+
B_tor: array_typing.FloatVectorFace,
238+
B_total_squared: array_typing.FloatVectorFace,
239+
geo: geometry_lib.Geometry,
240+
rotation_multiplier: array_typing.FloatScalar = 1.0,
241+
) -> cell_variable.CellVariable:
242+
"""Computes the neoclassical ion poloidal velocity profile.
243+
244+
Implementing eq.33 from
245+
Y. B. Kim , P. H. Diamond , R. J. Groebner.
246+
"Neoclassical poloidal and toroidal rotation in tokamaks"
247+
Phys. Fluids B 3, 2050–2060 (1991)
248+
https://doi.org/10.1063/1.859671
249+
250+
Eq. 33 can be simplified to the following form in SI units:
251+
v_pol = k_neo * (dT/dr) * (B_tor / <B^2>) / (Z * e)
252+
253+
Args:
254+
T_i: Ion temperature as a cell variable [keV].
255+
n_i: Ion density on the face grid [m^-3].
256+
q: Safety factor on the face grid.
257+
Z_eff: Effective charge on the face grid.
258+
Z_i: Main ion charge on the face grid.
259+
B_tor: Toroidal magnetic field on the face grid [T].
260+
B_total_squared: Total magnetic field (toroidal + poloidal) on the face grid
261+
[T].
262+
geo : Geometry
263+
rotation_multiplier: A multiplier to apply to the poloidal velocity.
264+
Returns:
265+
v_pol : Poloidal velocity profile [m/s].
266+
"""
267+
# Note: all computations are performed on the face grid.
268+
269+
T_i_face = T_i.face_value()
270+
epsilon = geo.epsilon_face
271+
272+
# Calculate Neoclassical Coefficient k_i
273+
log_lambda_ii = collisions.calculate_log_lambda_ii(
274+
T_i_face,
275+
n_i,
276+
Z_eff,
277+
)
278+
nu_i_star = calculate_nu_i_star(
279+
q=q,
280+
geo=geo,
281+
n_i=n_i,
282+
T_i=T_i_face,
283+
Z_eff=Z_eff,
284+
log_lambda_ii=log_lambda_ii,
285+
)
286+
k_neo = _calculate_neoclassical_k_neo(nu_i_star, epsilon)
287+
288+
# Calculate Radial Temperature Gradient (dT/dr)
289+
grad_Ti = T_i.face_grad(geo.r_mid) * constants.CONSTANTS.keV_to_J # [J/m]
290+
291+
# Calculate Poloidal Velocity
292+
# v_pol = k_i * (dT/dr) * (B_tor / <B^2>) / (Z * e)
293+
B_total_squared_safe = jnp.maximum(B_total_squared, constants.CONSTANTS.eps)
294+
v_pol = (
295+
k_neo
296+
* grad_Ti
297+
* (B_tor / B_total_squared_safe)
298+
/ (constants.CONSTANTS.q_e * Z_i)
299+
)
300+
301+
v_pol = rotation_multiplier * v_pol
302+
303+
return cell_variable.CellVariable(
304+
value=geometry_lib.face_to_cell(v_pol),
305+
dr=geo.drho_norm,
306+
right_face_constraint=v_pol[-1],
307+
right_face_grad_constraint=None,
308+
)

torax/_src/neoclassical/tests/formulas_test.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def setUp(self):
6464
torax_config
6565
)
6666
)
67-
runtime_params, geo = (
67+
runtime_params, self.geo = (
6868
build_runtime_params.get_consistent_runtime_params_and_geometry(
6969
t=torax_config.numerics.t_initial,
7070
runtime_params_provider=params_provider,
@@ -75,7 +75,7 @@ def setUp(self):
7575
neoclassical_models = torax_config.neoclassical.build_models()
7676
self.core_profiles = initialization.initial_core_profiles(
7777
runtime_params,
78-
geo,
78+
self.geo,
7979
source_models=source_models,
8080
neoclassical_models=neoclassical_models,
8181
)
@@ -85,14 +85,14 @@ def setUp(self):
8585
)
8686
self.nu_e_star = formulas.calculate_nu_e_star(
8787
q=self.core_profiles.q_face,
88-
geo=geo,
88+
geo=self.geo,
8989
n_e=self.core_profiles.n_e.face_value(),
9090
T_e=self.core_profiles.T_e.face_value(),
9191
Z_eff=self.core_profiles.Z_eff_face,
9292
log_lambda_ei=log_lambda_ei,
9393
)
9494

95-
self.f_trap = formulas.calculate_f_trap(geo)
95+
self.f_trap = formulas.calculate_f_trap(self.geo)
9696

9797
def test_calculate_f_trap_positive_triangularity(self):
9898
geo = mock.create_autospec(
@@ -128,10 +128,30 @@ def test_L32_values_are_correct(self):
128128
)
129129
np.testing.assert_allclose(L32, _L32_EXPECTED, atol=_A_TOL, rtol=_R_TOL)
130130

131+
def test_calculate_poloidal_velocity_values_are_correct(self):
132+
poloidal_velocity = formulas.calculate_poloidal_velocity(
133+
T_i=self.core_profiles.T_i,
134+
n_i=self.core_profiles.n_i.face_value(),
135+
q=self.core_profiles.q_face,
136+
Z_eff=self.core_profiles.Z_eff_face,
137+
Z_i=self.core_profiles.Z_i_face,
138+
B_tor=np.ones_like(self.geo.rho_face_norm),
139+
B_total_squared=np.ones_like(self.geo.rho_face_norm),
140+
geo=self.geo,
141+
)
142+
np.testing.assert_allclose(
143+
_POLOIDAL_VELOCITY_EXPECTED,
144+
poloidal_velocity.face_value(),
145+
atol=_A_TOL,
146+
rtol=_R_TOL,
147+
)
148+
131149

132150
# Reference values from running test code in a notebook.
133151
# The test thus does not directly test the implementation, but rather
134152
# guards against unexpected modifications.
153+
# If a change is expected to theese reference values, the new values can b
154+
# copied/pasted from the logs of a failing test.
135155
_L31_EXPECTED = np.array([
136156
0.0,
137157
0.25942749,
@@ -158,6 +178,19 @@ def test_L32_values_are_correct(self):
158178
0.08557197,
159179
0.16296924,
160180
])
181+
_POLOIDAL_VELOCITY_EXPECTED = np.array([
182+
-1485.871716,
183+
-2507.496827,
184+
-3933.755809,
185+
-4537.621566,
186+
-4854.858931,
187+
-5031.592012,
188+
-5073.608117,
189+
-4858.248803,
190+
-3559.941551,
191+
3011.279478,
192+
17562.499243,
193+
])
161194

162195
if __name__ == '__main__':
163196
absltest.main()

torax/_src/physics/formulas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from torax._src.geometry import geometry
3939
from torax._src.physics import psi_calculations
4040

41+
4142
# pylint: disable=invalid-name
4243

4344

torax/_src/physics/rotation.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# Copyright 2025 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Calculations related to the rotation of the plasma."""
15+
16+
from jax import numpy as jnp
17+
from torax._src import array_typing
18+
from torax._src import constants
19+
from torax._src.fvm import cell_variable
20+
from torax._src.geometry import geometry
21+
from torax._src.neoclassical import formulas as neoclassical_formulas
22+
from torax._src.physics import psi_calculations
23+
24+
25+
# pylint: disable=invalid-name
26+
def _calculate_radial_electric_field(
27+
pressure_thermal_i: cell_variable.CellVariable,
28+
toroidal_velocity: cell_variable.CellVariable,
29+
poloidal_velocity: cell_variable.CellVariable,
30+
n_i: cell_variable.CellVariable,
31+
Z_i_face: array_typing.FloatVector,
32+
B_pol_face: array_typing.FloatVector,
33+
B_tor_face: array_typing.FloatVector,
34+
geo: geometry.Geometry,
35+
) -> cell_variable.CellVariable:
36+
"""Calculates the radial electric field Er.
37+
38+
Er = (1 / (Zi * e * ni)) * dpi/dr - v_phi * B_theta + v_theta * B_phi
39+
40+
Args:
41+
pressure_thermal_i: Pressure profile as a cell variable.
42+
toroidal_velocity: Toroidal velocity profile as a cell variable.
43+
poloidal_velocity: Poloidal velocity profile as a cell variable.
44+
n_i: Main ion density profile as a cell variable.
45+
Z_i_face: Main ion charge on the face grid.
46+
B_pol_face: Flux-surface-averaged poloidal magnetic field on the face grid.
47+
B_tor_face: Flux-surface-averaged toroidal magnetic field on the face grid.
48+
geo: Geometry object.
49+
50+
Returns:
51+
Er: Radial electric field [V/m] on the cell grid.
52+
"""
53+
# Calculate dpi/dr with respect to a midplane-averaged radial coordinate.
54+
dpi_dr = pressure_thermal_i.face_grad(geo.r_mid)
55+
56+
# Calculate Er
57+
denominator = Z_i_face * constants.CONSTANTS.q_e * n_i.face_value()
58+
Er = (
59+
(1.0 / denominator) * dpi_dr
60+
- toroidal_velocity.face_value() * B_pol_face
61+
+ poloidal_velocity.face_value() * B_tor_face
62+
)
63+
return cell_variable.CellVariable(
64+
value=geometry.face_to_cell(Er),
65+
dr=geo.drho_norm,
66+
right_face_constraint=Er[-1],
67+
right_face_grad_constraint=None,
68+
)
69+
70+
71+
def _calculate_v_ExB(
72+
Er_face: array_typing.FloatVectorFace,
73+
B_total_face: array_typing.FloatVectorFace,
74+
) -> array_typing.FloatVectorFace:
75+
"""Calculates the ExB velocity, on the face grid."""
76+
B_total_face = jnp.maximum(B_total_face, constants.CONSTANTS.eps)
77+
return jnp.where(B_total_face > 0, Er_face / B_total_face, 0.0)
78+
79+
80+
def calculate_rotation(
81+
T_i: cell_variable.CellVariable,
82+
psi: cell_variable.CellVariable,
83+
n_i: cell_variable.CellVariable,
84+
q_face: array_typing.FloatVectorFace,
85+
Z_eff_face: array_typing.FloatVectorFace,
86+
Z_i_face: array_typing.FloatVector,
87+
toroidal_velocity: cell_variable.CellVariable,
88+
pressure_thermal_i: cell_variable.CellVariable,
89+
geo: geometry.Geometry,
90+
rotation_multiplier: float = 1.0,
91+
):
92+
"""Calculates quantities related to the rotation of the plasma.
93+
94+
Args:
95+
T_i: Ion temperature profile as a cell variable.
96+
psi: Poloidal flux profile as a cell variable.
97+
n_i: Main ion density profile as a cell variable.
98+
q_face: Safety factor on the face grid.
99+
Z_eff_face: Effective charge on the face grid.
100+
Z_i_face: Main ion charge on the face grid.
101+
toroidal_velocity: Toroidal velocity profile as a cell variable.
102+
pressure_thermal_i: Pressure profile as a cell variable.
103+
geo: Geometry object.
104+
rotation_multiplier: A multiplier to apply to the poloidal velocity.
105+
106+
Returns:
107+
v_ExB: ExB velocity profile on the face grid [m/s].
108+
Er: Radial electric field as a cell variable [V/m] .
109+
poloidal_velocity: Poloidal velocity as a cell variable [m/s].
110+
"""
111+
112+
# Flux surface average of `B_phi = F/R`.
113+
B_tor_face = geo.F_face / geo.R_major_profile_face # Tesla
114+
115+
# flux-surface-averaged B_theta.
116+
B_pol_squared_face = psi_calculations.calc_bpol_squared(
117+
geo, psi
118+
) # On the face grid.
119+
B_pol_face = jnp.sqrt(B_pol_squared_face) # Tesla
120+
B_total_squared_face = B_pol_squared_face + B_tor_face**2
121+
B_total_face = jnp.sqrt(B_total_squared_face)
122+
123+
poloidal_velocity = neoclassical_formulas.calculate_poloidal_velocity(
124+
T_i=T_i,
125+
n_i=n_i.face_value(),
126+
q=q_face,
127+
Z_eff=Z_eff_face,
128+
Z_i=Z_i_face,
129+
B_tor=B_tor_face,
130+
B_total_squared=B_total_squared_face,
131+
geo=geo,
132+
rotation_multiplier=rotation_multiplier,
133+
)
134+
135+
Er = _calculate_radial_electric_field(
136+
pressure_thermal_i=pressure_thermal_i,
137+
toroidal_velocity=toroidal_velocity,
138+
poloidal_velocity=poloidal_velocity,
139+
n_i=n_i,
140+
Z_i_face=Z_i_face,
141+
B_pol_face=B_pol_face,
142+
B_tor_face=B_tor_face,
143+
geo=geo,
144+
)
145+
146+
v_ExB = _calculate_v_ExB(Er.face_value(), B_total_face)
147+
148+
return v_ExB, Er, poloidal_velocity

0 commit comments

Comments
 (0)