Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 33 additions & 21 deletions torax/_src/core_profiles/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,9 @@ def _iterate_psi_and_sources(
runtime_params,
geo,
source_profiles.bootstrap_current,
sum(source_profiles.psi.values()),
j_toroidal_external=psi_calculations.j_parallel_to_j_toroidal(
sum(source_profiles.psi.values()), geo
),
)
psi = update_psi_from_j(
runtime_params.profile_conditions.Ip,
Expand Down Expand Up @@ -578,34 +580,44 @@ def get_j_total_hires_with_external_sources(
runtime_params: runtime_params_lib.RuntimeParams,
geo: geometry.Geometry,
bootstrap_current: bootstrap_current_base.BootstrapCurrent,
external_current: jax.Array,
j_toroidal_external: jax.Array,
) -> jax.Array:
"""Calculates j_total hires when the Ohmic current is given by a formula."""
Ip = runtime_params.profile_conditions.Ip
psi_current = external_current + bootstrap_current.j_bootstrap

j_bootstrap_hires = jnp.interp(
geo.rho_hires, geo.rho_face, bootstrap_current.j_bootstrap_face
# Convert bootstrap current density to toroidal, and calculate high-resolution
# version
j_toroidal_bootstrap = psi_calculations.j_parallel_to_j_toroidal(
bootstrap_current.j_parallel_bootstrap, geo
)
j_toroidal_bootstrap_hires = jnp.interp(
geo.rho_hires, geo.rho_face, bootstrap_current.j_parallel_bootstrap_face
)

# calculate hi-res "External" current profile (e.g. ECCD) on cell grid.
external_current_face = math_utils.cell_to_face(
external_current,
# Calculate high-resolution version of external (eg ECCD) current density
j_toroidal_external_face = math_utils.cell_to_face(
j_toroidal_external,
geo,
preserved_quantity=math_utils.IntegralPreservationQuantity.SURFACE,
)
external_current_hires = jnp.interp(
geo.rho_hires, geo.rho_face, external_current_face
j_toroidal_external_hires = jnp.interp(
geo.rho_hires, geo.rho_face, j_toroidal_external_face
)

# calculate high resolution j_total and Ohmic current profile
jformula_hires = (
# Calculate high resolution j_total and j_ohmic
j_toroidal_ohmic_formula_hires = (
1 - geo.rho_hires_norm**2
) ** runtime_params.profile_conditions.current_profile_nu
denom = _trapz(jformula_hires * geo.spr_hires, geo.rho_hires_norm)
I_non_inductive = math_utils.area_integration(psi_current, geo)
Iohm = Ip - I_non_inductive
Cohm_hires = Iohm / denom
j_ohmic_hires = jformula_hires * Cohm_hires
j_total_hires = j_ohmic_hires + external_current_hires + j_bootstrap_hires
return j_total_hires
denom = _trapz(
j_toroidal_ohmic_formula_hires * geo.spr_hires, geo.rho_hires_norm
)
I_noninductive = math_utils.area_integration(
j_toroidal_external + j_toroidal_bootstrap, geo
)
I_ohmic = runtime_params.profile_conditions.Ip - I_noninductive
C_ohm_hires = I_ohmic / denom
j_toroidal_ohmic_hires = j_toroidal_ohmic_formula_hires * C_ohm_hires
j_toroidal_total_hires = (
j_toroidal_ohmic_hires
+ j_toroidal_external_hires
+ j_toroidal_bootstrap_hires
)
return j_toroidal_total_hires
76 changes: 43 additions & 33 deletions torax/_src/core_profiles/tests/initialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torax._src.geometry import geometry
from torax._src.geometry import standard_geometry
from torax._src.neoclassical.bootstrap_current import base as bootstrap_current_base
from torax._src.physics import psi_calculations
from torax._src.sources import generic_current_source
from torax._src.sources import runtime_params as runtime_params_lib
from torax._src.sources import source_profile_builders
Expand All @@ -42,11 +43,11 @@
class _Currents:
"""Container for the various currents used in tests."""

j_total: jax.Array
j_total_face: jax.Array
j_external: jax.Array
j_bootstrap: jax.Array
j_ohmic: jax.Array
j_toroidal_total: jax.Array
j_toroidal_total_face: jax.Array
j_toroidal_external: jax.Array
j_toroidal_bootstrap: jax.Array
j_toroidal_ohmic: jax.Array


class InitializationTest(parameterized.TestCase):
Expand All @@ -62,19 +63,22 @@ def test_update_psi_from_j(self):
# Turn on the external current source.
runtime_params, geo = references.get_runtime_params_and_geo()
bootstrap = bootstrap_current_base.BootstrapCurrent.zeros(geo)
external_current = generic_current_source.calculate_generic_current(
j_parallel_external = generic_current_source.calculate_generic_current(
runtime_params=runtime_params,
geo=geo,
source_name=generic_current_source.GenericCurrentSource.SOURCE_NAME,
unused_state=mock.ANY,
unused_calculated_source_profiles=mock.ANY,
unused_conductivity=mock.ANY,
)[0]
j_toroidal_external = psi_calculations.j_parallel_to_j_toroidal(
j_parallel_external, geo
)
j_total_hires = initialization.get_j_total_hires_with_external_sources(
bootstrap_current=bootstrap,
external_current=external_current,
runtime_params=runtime_params,
geo=geo,
j_toroidal_external=j_toroidal_external,
)
psi = initialization.update_psi_from_j(
runtime_params.profile_conditions.Ip,
Expand Down Expand Up @@ -197,38 +201,38 @@ def test_compare_initial_currents_with_different_initial_j_ohmic(
np.testing.assert_raises(
AssertionError,
np.testing.assert_allclose,
np.mean(currents1.j_total),
np.mean(currents2.j_total),
np.mean(currents1.j_toroidal_total),
np.mean(currents2.j_toroidal_total),
rtol=_TOL,
)

# Check that the total current agrees with the expected reference formula.
np.testing.assert_allclose(
np.mean(currents1.j_total),
np.mean(currents1.j_toroidal_total),
np.mean(j_total1_expected),
rtol=_TOL,
)

# The only non-inductive current is the external current. Therefore the
# sum of Ohmic + external current should be equal to the total current.
np.testing.assert_allclose(
np.mean(currents1.j_external + currents1.j_ohmic),
np.mean(currents1.j_total),
np.mean(currents1.j_toroidal_external + currents1.j_toroidal_ohmic),
np.mean(currents1.j_toroidal_total),
rtol=_TOL,
)

# j_ohmic2_expected is the expected formula for j_ohmic when setting
# initial_j_is_total_current=False as in Case 2. It is the "nu" formula
# scaled down to compensate for the external current.
np.testing.assert_allclose(
np.mean(currents2.j_ohmic),
np.mean(currents2.j_toroidal_ohmic),
np.mean(j_ohmic2_expected),
rtol=_TOL,
)

# Check that the face conversions agree with the expected reference.
np.testing.assert_allclose(
np.mean(currents1.j_total_face),
np.mean(currents1.j_toroidal_total_face),
np.mean(
math_utils.cell_to_face(
j_total1_expected,
Expand All @@ -242,7 +246,7 @@ def test_compare_initial_currents_with_different_initial_j_ohmic(
np.testing.assert_raises(
AssertionError,
np.testing.assert_allclose,
np.mean(currents2.j_total_face),
np.mean(currents2.j_toroidal_total_face),
np.mean(
math_utils.cell_to_face(
j_total1_expected,
Expand Down Expand Up @@ -315,8 +319,8 @@ def test_initial_psi_from_j_with_bootstrap_is_consistent_with_case_without_boots

# In Case 1, all the current should be Ohmic current.
np.testing.assert_allclose(
np.mean(currents1.j_ohmic),
np.mean(currents1.j_total),
np.mean(currents1.j_toroidal_ohmic),
np.mean(currents1.j_toroidal_total),
rtol=_TOL,
)

Expand All @@ -325,17 +329,17 @@ def test_initial_psi_from_j_with_bootstrap_is_consistent_with_case_without_boots
np.testing.assert_raises(
AssertionError,
np.testing.assert_allclose,
np.mean(currents1.j_ohmic),
np.mean(currents2.j_ohmic),
np.mean(currents1.j_toroidal_ohmic),
np.mean(currents2.j_toroidal_ohmic),
rtol=_TOL,
)

# The only non-inductive current in Case 2 is the bootstrap current.
# Thus, the sum of Ohmic and booststrap currents should be equal to the
# total (ohmic) current in Case 1.
np.testing.assert_allclose(
np.mean(currents1.j_total),
np.mean(currents2.j_ohmic + currents2.j_bootstrap),
np.mean(currents1.j_toroidal_total),
np.mean(currents2.j_toroidal_ohmic + currents2.j_toroidal_bootstrap),
rtol=_TOL,
)

Expand All @@ -347,11 +351,11 @@ def test_initial_psi_from_geo_noop_circular(self):
}
torax_config = model_config.ToraxConfig.from_dict(config)
_, _, currents1 = _get_initial_state(torax_config)
jtotal1 = currents1.j_total
jtotal1 = currents1.j_toroidal_total

torax_config.update_fields({'profile_conditions.initial_psi_from_j': True})
_, _, currents2 = _get_initial_state(torax_config)
jtotal2 = currents2.j_total
jtotal2 = currents2.j_toroidal_total

np.testing.assert_allclose(jtotal1, jtotal2)

Expand Down Expand Up @@ -496,17 +500,23 @@ def _get_initial_state(
neoclassical_models=neoclassical_models,
conductivity=conductivity,
)
j_total = core_profiles.j_total
j_total_face = core_profiles.j_total_face
j_external = sum(core_sources.psi.values())
j_bootstrap = core_sources.bootstrap_current.j_bootstrap
j_ohmic = j_total - j_external - j_bootstrap
j_toroidal_total = core_profiles.j_total
j_toroidal_total_face = core_profiles.j_total_face
j_toroidal_external = psi_calculations.j_parallel_to_j_toroidal(
sum(core_sources.psi.values()), geo
)
j_toroidal_bootstrap = psi_calculations.j_parallel_to_j_toroidal(
core_sources.bootstrap_current.j_parallel_bootstrap, geo
)
j_toroidal_ohmic = (
j_toroidal_total - j_toroidal_external - j_toroidal_bootstrap
)
currents = _Currents(
j_total=j_total,
j_total_face=j_total_face,
j_external=j_external,
j_bootstrap=j_bootstrap,
j_ohmic=j_ohmic,
j_toroidal_total=j_toroidal_total,
j_toroidal_total_face=j_toroidal_total_face,
j_toroidal_external=j_toroidal_external,
j_toroidal_bootstrap=j_toroidal_bootstrap,
j_toroidal_ohmic=j_toroidal_ohmic,
)
return core_profiles, geo, currents

Expand Down
22 changes: 11 additions & 11 deletions torax/_src/mhd/sawtooth/tests/sawtooth_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@

import dataclasses

from absl.testing import absltest
from absl.testing import parameterized
import jax
import numpy as np
from absl.testing import absltest, parameterized

from torax._src import state
from torax._src.config import build_runtime_params
from torax._src.orchestration import initial_state as initial_state_lib
Expand Down Expand Up @@ -274,15 +274,15 @@ def test_no_subsequent_sawtooth_crashes(self):
])

_POST_CRASH_PSI = np.array([
9.778742,
11.342102,
14.360384,
18.737049,
24.378128,
31.058185,
38.126174,
44.844899,
50.742815,
8.245389,
9.864265,
12.989683,
17.52163,
23.362746,
30.278108,
37.587465,
44.522205,
50.597804,
55.729866,
])

Expand Down
8 changes: 4 additions & 4 deletions torax/_src/neoclassical/bootstrap_current/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
class BootstrapCurrent:
"""Values returned by a bootstrap current model."""

j_bootstrap: jax.Array
j_bootstrap_face: jax.Array
j_parallel_bootstrap: jax.Array
j_parallel_bootstrap_face: jax.Array

@classmethod
def zeros(cls, geometry: geometry_lib.Geometry) -> 'BootstrapCurrent':
"""Returns a BootstrapCurrent with all values set to zero."""
return cls(
j_bootstrap=jnp.zeros_like(geometry.rho_norm),
j_bootstrap_face=jnp.zeros_like(geometry.rho_face_norm),
j_parallel_bootstrap=jnp.zeros_like(geometry.rho_norm),
j_parallel_bootstrap_face=jnp.zeros_like(geometry.rho_face_norm),
)


Expand Down
14 changes: 7 additions & 7 deletions torax/_src/neoclassical/bootstrap_current/sauter.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def calculate_bootstrap_current(
geo=geometry,
)
return base.BootstrapCurrent(
j_bootstrap=result.j_bootstrap,
j_bootstrap_face=result.j_bootstrap_face,
j_parallel_bootstrap=result.j_parallel_bootstrap,
j_parallel_bootstrap_face=result.j_parallel_bootstrap_face,
)

def __eq__(self, other) -> bool:
Expand Down Expand Up @@ -110,7 +110,7 @@ def _calculate_bootstrap_current(
q_face: array_typing.FloatVectorFace,
geo: geometry_lib.Geometry,
) -> base.BootstrapCurrent:
"""Calculates j_bootstrap and j_bootstrap_face using the Sauter model."""
"""Calculates j_parallel_bootstrap and j_parallel_bootstrap_face using the Sauter model."""
# pylint: disable=invalid-name

# Formulas from Sauter PoP 1999. Future work can include Redl PoP 2021
Expand Down Expand Up @@ -169,17 +169,17 @@ def _calculate_bootstrap_current(
tecoeff = (L31 + L32) * pe
ticoeff = (L31 + alpha * L34) * pi

j_bootstrap_face = global_coeff * (
j_parallel_bootstrap_face = global_coeff * (
necoeff * dlnne_drnorm
+ nicoeff * dlnni_drnorm
+ tecoeff * dlnte_drnorm
+ ticoeff * dlnti_drnorm
)
j_bootstrap = geometry_lib.face_to_cell(j_bootstrap_face)
j_parallel_bootstrap = geometry_lib.face_to_cell(j_parallel_bootstrap_face)

return base.BootstrapCurrent(
j_bootstrap=j_bootstrap,
j_bootstrap_face=j_bootstrap_face,
j_parallel_bootstrap=j_parallel_bootstrap,
j_parallel_bootstrap_face=j_parallel_bootstrap_face,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_sauter_bootstrap_current_is_correct_shape(self):
result = model.calculate_bootstrap_current(
runtime_params, geo, core_profiles
)
self.assertEqual(result.j_bootstrap.shape, (n_rho,))
self.assertEqual(result.j_bootstrap_face.shape, (n_rho + 1,))
self.assertEqual(result.j_parallel_bootstrap.shape, (n_rho,))
self.assertEqual(result.j_parallel_bootstrap_face.shape, (n_rho + 1,))


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions torax/_src/neoclassical/bootstrap_current/zeros.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def calculate_bootstrap_current(
) -> base.BootstrapCurrent:
"""Calculates bootstrap current."""
return base.BootstrapCurrent(
j_bootstrap=jnp.zeros_like(geometry.rho),
j_bootstrap_face=jnp.zeros_like(geometry.rho_face),
j_parallel_bootstrap=jnp.zeros_like(geometry.rho),
j_parallel_bootstrap_face=jnp.zeros_like(geometry.rho_face),
)

def __eq__(self, other) -> bool:
Expand Down
8 changes: 4 additions & 4 deletions torax/_src/orchestration/tests/initial_state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ def test_core_profile_final_step(self, test_config):
output.IP_PROFILE,
output.Q,
output.MAGNETIC_SHEAR,
output.J_BOOTSTRAP,
output.J_OHMIC,
output.J_EXTERNAL,
output.J_TOTAL,
output.J_TOR_BOOTSTRAP,
output.J_TOR_OHMIC,
output.J_TOR_EXTERNAL,
output.J_TOR_TOTAL,
output.SIGMA_PARALLEL,
]
index = -1
Expand Down
Loading