Skip to content

Commit 943559a

Browse files
committed
Switch psi sources to being <j.B>/B_0
1 parent 5b87c03 commit 943559a

File tree

6 files changed

+180
-31
lines changed

6 files changed

+180
-31
lines changed

torax/_src/output_tools/output.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717
import dataclasses
1818
import inspect
1919
import itertools
20-
2120
import os
2221

2322
from absl import logging
2423
import chex
2524
import jax
2625
import numpy as np
27-
import os
2826
from torax._src import array_typing
2927
from torax._src import state
3028
from torax._src.geometry import geometry as geometry_lib
@@ -58,14 +56,18 @@
5856
Z_EFF = "Z_eff"
5957
SIGMA_PARALLEL = "sigma_parallel"
6058
V_LOOP_LCFS = "v_loop_lcfs"
61-
J_TOTAL = "j_total"
6259
IP_PROFILE = "Ip_profile"
6360
IP = "Ip"
6461

65-
# Calculated or derived currents.
66-
J_OHMIC = "j_ohmic"
67-
J_EXTERNAL = "j_external"
68-
J_BOOTSTRAP = "j_bootstrap"
62+
# Calculated or derived current densities
63+
J_PAR_TOTAL = "j_parallel_total"
64+
J_PAR_OHMIC = "j_parallel_ohmic"
65+
J_PAR_EXTERNAL = "j_parallel_external"
66+
J_PAR_BOOTSTRAP = "j_parallel_bootstrap"
67+
J_TOR_TOTAL = "j_total"
68+
J_TOR_OHMIC = "j_ohmic"
69+
J_TOR_EXTERNAL = "j_external"
70+
J_TOR_BOOTSTRAP = "j_bootstrap"
6971
I_BOOTSTRAP = "I_bootstrap"
7072

7173
# Core transport.
@@ -494,6 +496,7 @@ def _save_core_profiles(
494496
"Ip_profile_face": IP_PROFILE,
495497
"q_face": Q,
496498
"s_face": MAGNETIC_SHEAR,
499+
"j_total": J_TOR_TOTAL, # j_total is the fsa toroidal current, dI/dS
497500
}
498501

499502
core_profile_field_names = {
@@ -630,7 +633,7 @@ def _save_core_sources(
630633
else:
631634
xr_dict[f"p_{profile}_e"] = self._stacked_core_sources.T_e[profile]
632635
for profile in self._stacked_core_sources.psi:
633-
xr_dict[f"j_{profile}"] = self._stacked_core_sources.psi[profile]
636+
xr_dict[f"j_parallel_{profile}"] = self._stacked_core_sources.psi[profile]
634637
for profile in self._stacked_core_sources.n_e:
635638
xr_dict[f"s_{profile}"] = self._stacked_core_sources.n_e[profile]
636639

torax/_src/output_tools/post_processing.py

Lines changed: 80 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,18 @@ class PostProcessedOutputs:
146146
rho_q_3_1_second: Second outermost rho_norm value that intercepts the q=3/1
147147
plane. If no intercept is found, set to -inf.
148148
I_bootstrap: Total bootstrap current [A].
149-
j_external: Total current density from psi sources which are external to the
150-
plasma (aka not bootstrap) [A m^-2]
151-
j_ohmic: Ohmic current density [A/m^2]
149+
j_parallel_external: Parallel current density from external psi sources
150+
(i.e., excluding bootstrap) [A m^-2]
151+
j_parallel_ohmic: Parallel ohmic current density [Am^-2]
152+
j_total: Total toroidal current density [Am^-2]
153+
j_bootstrap: Toroidal bootstrap current density [Am^-2]
154+
j_ohmic: Toroidal ohmic current density [Am^-2]
155+
j_external: Toroidal current density from external psi sources
156+
(i.e., excluding bootstrap) [A m^-2]
157+
j_generic_current: Toroidal current density from generic current source
158+
[Am^-2]
159+
j_ecrh: Toroidal current density from electron cyclotron heating
160+
and current source [Am^-2]
152161
S_gas_puff: Integrated gas puff source [s^-1]
153162
S_pellet: Integrated pellet source [s^-1]
154163
S_generic_particle: Integrated generic particle source [s^-1]
@@ -233,8 +242,14 @@ class PostProcessedOutputs:
233242
rho_q_3_1_first: array_typing.FloatScalar
234243
rho_q_3_1_second: array_typing.FloatScalar
235244
I_bootstrap: array_typing.FloatScalar
236-
j_external: array_typing.FloatVector
245+
j_parallel_external: array_typing.FloatVector
246+
j_parallel_ohmic: array_typing.FloatVector
247+
j_total: array_typing.FloatVector
248+
j_bootstrap: array_typing.FloatVector
237249
j_ohmic: array_typing.FloatVector
250+
j_external: array_typing.FloatVector
251+
j_generic_current: array_typing.FloatVector
252+
j_ecrh: array_typing.FloatVector
238253
S_gas_puff: array_typing.FloatScalar
239254
S_pellet: array_typing.FloatScalar
240255
S_generic_particle: array_typing.FloatScalar
@@ -337,8 +352,15 @@ def zeros(cls, geo: geometry.Geometry) -> typing_extensions.Self:
337352
rho_q_2_1_second=jnp.array(0.0, dtype=jax_utils.get_dtype()),
338353
rho_q_3_1_second=jnp.array(0.0, dtype=jax_utils.get_dtype()),
339354
I_bootstrap=jnp.array(0.0, dtype=jax_utils.get_dtype()),
340-
j_external=jnp.zeros(geo.rho_face.shape),
355+
# TODO (v2): rename j_* to j_toroidal_* for clarity
356+
j_parallel_external=jnp.zeros(geo.rho_face.shape),
357+
j_parallel_ohmic=jnp.zeros(geo.rho_face.shape),
358+
j_total=jnp.zeros(geo.rho_face.shape),
359+
j_bootstrap=jnp.zeros(geo.rho_face.shape),
341360
j_ohmic=jnp.zeros(geo.rho_face.shape),
361+
j_external=jnp.zeros(geo.rho_face.shape),
362+
j_generic_current=jnp.zeros(geo.rho_face.shape),
363+
j_ecrh=jnp.zeros(geo.rho_face.shape),
342364
S_gas_puff=jnp.array(0.0, dtype=jax_utils.get_dtype()),
343365
S_pellet=jnp.array(0.0, dtype=jax_utils.get_dtype()),
344366
S_generic_particle=jnp.array(0.0, dtype=jax_utils.get_dtype()),
@@ -514,7 +536,7 @@ def _calculate_integrated_sources(
514536
integrated['P_external_injected'] += integrated[f'{value}']
515537

516538
for key, value in CURRENT_SOURCE_TRANSFORMATIONS.items():
517-
integrated[f'{value}'] = _get_integrated_source_value(
539+
integrated[value] = _get_integrated_source_value(
518540
core_sources.psi, key, geo, math_utils.area_integration
519541
)
520542

@@ -590,10 +612,9 @@ def make_post_processed_outputs(
590612
runtime_params,
591613
)
592614
# Calculate fusion gain with a zero division guard.
593-
Q_fusion = (
594-
integrated_sources['P_fusion']
595-
/ (integrated_sources['P_external_total']
596-
+ constants.CONSTANTS.eps))
615+
Q_fusion = integrated_sources['P_fusion'] / (
616+
integrated_sources['P_external_total'] + constants.CONSTANTS.eps
617+
)
597618

598619
P_LH_hi_dens, P_LH_min, P_LH, n_e_min_P_LH = (
599620
scaling_laws.calculate_plh_scaling_factor(
@@ -756,11 +777,48 @@ def cumulative_values():
756777
sim_state.core_sources.bootstrap_current.j_bootstrap, sim_state.geometry
757778
)
758779

759-
j_external = sum(sim_state.core_sources.psi.values())
760-
psi_current = (
761-
j_external + sim_state.core_sources.bootstrap_current.j_bootstrap
780+
# Parallel current densities
781+
# j_total is toroidal by default (see psi_calculations.calc_j_total)
782+
j_parallel_total = psi_calculations.j_tor_to_j_parallel(
783+
sim_state.core_profiles.j_total, sim_state.geometry
784+
)
785+
# Core sources psi are all <j.B>/B0
786+
j_parallel_external = sum(sim_state.core_sources.psi.values())
787+
j_parallel_ohmic = (
788+
j_parallel_total
789+
- j_parallel_external
790+
- sim_state.core_sources.bootstrap_current.j_bootstrap # parallel by default
791+
)
792+
793+
# Toroidal current densities
794+
# j_total is toroidal by default (see psi_calculations.calc_j_total)
795+
j_total = sim_state.core_profiles.j_total
796+
# Other sources are parallel, so convert to toroidal
797+
j_bootstrap = psi_calculations.j_parallel_to_j_tor(
798+
sim_state.core_sources.bootstrap_current.j_bootstrap, sim_state.geometry
799+
)
800+
j_ohmic = psi_calculations.j_parallel_to_j_tor(
801+
j_parallel_ohmic, sim_state.geometry
802+
)
803+
j_external = psi_calculations.j_parallel_to_j_tor(
804+
j_parallel_external, sim_state.geometry
762805
)
763-
j_ohmic = sim_state.core_profiles.j_total - psi_current
806+
j_sources = {}
807+
for source_name in ['ecrh', 'generic_current']:
808+
if source_name in sim_state.core_sources.psi.keys():
809+
j_sources[f'j_{source_name}'] = psi_calculations.j_parallel_to_j_tor(
810+
sim_state.core_sources.psi[source_name], sim_state.geometry
811+
)
812+
j_sources[f'j_parallel_{source_name}'] = sim_state.core_sources.psi[
813+
source_name
814+
]
815+
else:
816+
j_sources[f'j_{source_name}'] = jnp.array(
817+
0.0, dtype=jax_utils.get_dtype()
818+
)
819+
j_sources[f'j_parallel_{source_name}'] = jnp.array(
820+
0.0, dtype=jax_utils.get_dtype()
821+
)
764822

765823
beta_tor, beta_pol, beta_N = formulas.calculate_betas(
766824
sim_state.core_profiles, sim_state.geometry
@@ -813,8 +871,15 @@ def cumulative_values():
813871
rho_q_2_1_second=safety_factor_fit_outputs.rho_q_2_1_second,
814872
rho_q_3_1_second=safety_factor_fit_outputs.rho_q_3_1_second,
815873
I_bootstrap=I_bootstrap,
816-
j_external=j_external,
874+
j_parallel_total=j_parallel_total,
875+
j_parallel_ohmic=j_parallel_ohmic,
876+
j_parallel_bootstrap=j_parallel_bootstrap,
877+
j_parallel_external=j_parallel_external,
878+
j_total=j_total,
817879
j_ohmic=j_ohmic,
880+
j_bootstrap=j_bootstrap,
881+
j_external=j_external,
882+
**j_sources,
818883
beta_tor=beta_tor,
819884
beta_pol=beta_pol,
820885
beta_N=beta_N,

torax/_src/physics/psi_calculations.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
- calculate_psi_grad_constraint_from_Ip: Calculates the gradient
2727
constraint on the poloidal flux (psi) from Ip.
2828
- _calc_bpol2: Calculates square of poloidal field (Bp).
29+
constraint on the poloidal flux (psi) from Ip.
30+
- j_tor_to_j_parallel: Calculates <j.B>/B0 from j_tor = dI/dS.
31+
- j_parallel_to_j_tor: Calculates j_tor = dI/dS from <j.B>/B0.
2932
"""
3033
import jax
3134
from jax import numpy as jnp
@@ -338,3 +341,47 @@ def calculate_psidot_from_psi_sources(
338341
psidot = (jnp.dot(c_mat, psi.value) + c) / toc_psi
339342

340343
return psidot
344+
345+
346+
def j_tor_to_j_parallel(
347+
j_tor: array_typing.FloatVector, geo: geometry.Geometry
348+
) -> array_typing.FloatVector:
349+
r"""Calculates <j.B>/B0 from j_tor = dI/dS.
350+
351+
The relationship is
352+
353+
.. math::
354+
355+
\frac{\langle j.B \rangle}{B_0} = \frac{F^2 \langle 1/R^2 \rangle}{B_0^2}
356+
\left(\frac{j_{tor}}{2 \pi \rho} \frac{dS}{d\rho} - \frac{1}{2 \pi \rho F}
357+
\frac{dF}{d\rho} \int_0^\rho j_{tor} \frac{dS}{d\rho'} d\rho' \right)
358+
"""
359+
dF_drho = jnp.gradient(geo.F, geo.rho)
360+
term1 = geo.spr * j_tor / (2 * jnp.pi * geo.rho)
361+
term2 = (
362+
dF_drho
363+
/ (2 * jnp.pi * geo.rho * geo.F)
364+
* jnp.cumsum(j_tor * geo.spr * geo.drho)
365+
)
366+
return geo.F**2 * geo.g3 / geo.B_0**2 * (term1 - term2)
367+
368+
369+
def j_parallel_to_j_tor(
370+
j_parallel: array_typing.FloatVector, geo: geometry.Geometry
371+
) -> array_typing.FloatVector:
372+
r"""Calculates j_tor = dI/dS from <j.B>/B0.
373+
374+
The relationship is
375+
376+
.. math::
377+
j_\mathrm{tor} = \frac{\partial}{\partial S} (2 \pi B_0^2 F \int_{0}^{\rho'}
378+
\frac{\langle j.B \rangle}{B_0 F^3 \langle 1/R^2 \rangle} \rho' d\rho')
379+
"""
380+
I = (
381+
2
382+
* jnp.pi
383+
* geo.B_0**2
384+
* geo.F
385+
* jnp.cumsum(j_parallel * geo.rho / (geo.F**3 * geo.g3))
386+
)
387+
return jnp.gradient(I, geo.area)

torax/_src/physics/tests/psi_calculations_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,28 @@ def test_calc_Wpol(self):
190190
# our circular geometry, but approximates it at low inverse aspect ratio.
191191
np.testing.assert_allclose(calculated_Wpol, expected_Wpol, rtol=1e-3)
192192

193+
# pylint: enable=invalid-name
194+
195+
def test_calc_j_tor_from_j_parallel(self):
196+
geo = None
197+
j_tor_truth = None
198+
j_parallel_truth = None
199+
j_tor_from_j_parallel = psi_calculations.j_parallel_to_j_tor(
200+
j_parallel_truth, geo
201+
)
202+
np.testing.assert_allclose(j_tor_from_j_parallel, j_tor_truth, rtol=1e-6)
203+
204+
def test_calc_j_parallel_from_j_tor(self):
205+
geo = None
206+
j_tor_truth = None
207+
j_parallel_truth = None
208+
j_parallel_from_j_tor = psi_calculations.j_tor_to_j_parallel(
209+
j_tor_truth, geo
210+
)
211+
np.testing.assert_allclose(
212+
j_parallel_from_j_tor, j_parallel_truth, rtol=1e-6
213+
)
214+
193215

194216
if __name__ == '__main__':
195217
absltest.main()

torax/_src/sources/generic_current_source.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from torax._src.config import runtime_params_slice
2626
from torax._src.geometry import geometry
2727
from torax._src.neoclassical.conductivity import base as conductivity_base
28+
from torax._src.physics import psi_calculations
2829
from torax._src.sources import base as source_base
2930
from torax._src.sources import runtime_params as runtime_params_lib
3031
from torax._src.sources import source
@@ -58,7 +59,8 @@ def calculate_generic_current(
5859
unused_calculated_source_profiles: source_profiles.SourceProfiles | None,
5960
unused_conductivity: conductivity_base.Conductivity | None,
6061
) -> tuple[array_typing.FloatVectorCell, ...]:
61-
"""Calculates the external current density profiles on the cell grid."""
62+
"""Calculates the external parallel current density profile on the cell
63+
grid."""
6264
source_params = runtime_params.sources[source_name]
6365
# pytype: enable=name-error
6466
assert isinstance(source_params, RuntimeParams)
@@ -73,8 +75,9 @@ def calculate_generic_current(
7375
)
7476

7577
Cext = I_generic / math_utils.area_integration(generic_current_form, geo)
76-
generic_current_profile = Cext * generic_current_form
77-
return (generic_current_profile,)
78+
j_tor = Cext * generic_current_form
79+
80+
return (psi_calculations.j_tor_to_j_parallel(j_tor, geo),)
7881

7982

8083
def _calculate_I_generic(

torax/_src/sources/source_profiles.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,20 @@ def merge(
121121
)
122122

123123
def total_psi_sources(self, geo: geometry.Geometry) -> jax.Array:
124-
total = self.bootstrap_current.j_bootstrap
125-
total += sum(self.psi.values())
126-
mu0 = constants.CONSTANTS.mu_0
127-
prefactor = 8 * geo.vpr * jnp.pi**2 * geo.B_0 * mu0 * geo.Phi_b / geo.F**2
128-
return -total * prefactor
124+
# All psi sources are assumed to be parallel to the magnetic field, ie
125+
# self.psi.values() is <j.B> / B0
126+
total_j_dot_B_over_B0 = self.bootstrap_current.j_bootstrap
127+
total_j_dot_B_over_B0 += sum(self.psi.values())
128+
total_j_dot_B = total_j_dot_B_over_B0 * geo.B_0
129+
prefactor = (
130+
8
131+
* geo.vpr
132+
* jnp.pi**2
133+
* constants.CONSTANTS.mu_0
134+
* geo.Phi_b
135+
/ geo.F**2
136+
)
137+
return -total_j_dot_B * prefactor
129138

130139
def total_sources(
131140
self,

0 commit comments

Comments
 (0)