Skip to content

Commit 7425b98

Browse files
goodfeliTorax team
authored andcommitted
Introduce PedestalPolicy.
This is a step toward having full Martin scaling, where there will be a Martin scaling PedestalPolicy that ramps up and ramps down the pedestal height. This PR is just a step that defines the new stateful interface required for doing so and transitions the existing `set_pedestal` functionality to being implemented with the new PedestalPolicy class. PiperOrigin-RevId: 810477225
1 parent 2e96ab0 commit 7425b98

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+836
-135
lines changed

torax/_src/config/build_runtime_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def __call__(
131131
numerics=self.numerics.build_runtime_params(t),
132132
neoclassical=self.neoclassical.build_runtime_params(),
133133
pedestal=self.pedestal.build_runtime_params(t),
134+
pedestal_policy=self.pedestal.build_pedestal_policy_runtime_params(),
134135
mhd=self.mhd.build_runtime_params(t),
135136
time_step_calculator=self.time_step_calculator.build_runtime_params(),
136137
edge=None if self.edge is None else self.edge.build_runtime_params(t),

torax/_src/config/runtime_params.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from torax._src.mhd import runtime_params as mhd_runtime_params
5050
from torax._src.neoclassical import runtime_params as neoclassical_params
5151
from torax._src.pedestal_model import runtime_params as pedestal_model_params
52+
from torax._src.pedestal_policy import runtime_params as pedestal_policy_runtime_params
5253
from torax._src.solver import runtime_params as solver_params
5354
from torax._src.sources import runtime_params as sources_params
5455
from torax._src.time_step_calculator import runtime_params as time_step_calculator_runtime_params
@@ -80,6 +81,7 @@ class RuntimeParams:
8081
neoclassical: neoclassical_params.RuntimeParams
8182
numerics: numerics.RuntimeParams
8283
pedestal: pedestal_model_params.RuntimeParams
84+
pedestal_policy: pedestal_policy_runtime_params.PedestalPolicyRuntimeParams
8385
plasma_composition: plasma_composition.RuntimeParams
8486
profile_conditions: profile_conditions.RuntimeParams
8587
solver: solver_params.RuntimeParams

torax/_src/config/tests/build_runtime_params_test.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,19 +103,32 @@ def test_pedestal_is_time_dependent(self):
103103
set_pedestal={0.0: True, 1.0: False},
104104
)
105105
)
106+
pedestal_policy = pedestal.build_pedestal_model().pedestal_policy
106107
# Check at time 0.
107108

108109
pedestal_params = pedestal.build_runtime_params(t=0.0)
110+
pedestal_policy_rp = pedestal.build_pedestal_policy_runtime_params()
109111
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
110-
np.testing.assert_allclose(pedestal_params.set_pedestal, True)
112+
np.testing.assert_allclose(
113+
pedestal_policy.initial_state(
114+
t=0.0, runtime_params=pedestal_policy_rp
115+
).use_pedestal,
116+
True,
117+
)
111118
np.testing.assert_allclose(pedestal_params.T_i_ped, 0.0)
112119
np.testing.assert_allclose(pedestal_params.T_e_ped, 1.0)
113120
np.testing.assert_allclose(pedestal_params.n_e_ped, 2.0e20)
114121
np.testing.assert_allclose(pedestal_params.rho_norm_ped_top, 3.0)
115122
# And check after the time limit.
116123
pedestal_params = pedestal.build_runtime_params(t=1.0)
124+
# Note: pedestal_policy_rp does not depend on time for its structure
117125
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
118-
np.testing.assert_allclose(pedestal_params.set_pedestal, False)
126+
np.testing.assert_allclose(
127+
pedestal_policy.initial_state(
128+
t=1.0, runtime_params=pedestal_policy_rp
129+
).use_pedestal,
130+
False,
131+
)
119132
np.testing.assert_allclose(pedestal_params.T_i_ped, 1.0)
120133
np.testing.assert_allclose(pedestal_params.T_e_ped, 2.0)
121134
np.testing.assert_allclose(pedestal_params.n_e_ped, 3.0e20)

torax/_src/fvm/calc_coeffs.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torax._src.fvm import cell_variable
2828
from torax._src.geometry import geometry
2929
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
30+
from torax._src.pedestal_policy import pedestal_policy
3031
from torax._src.sources import source_profile_builders
3132
from torax._src.sources import source_profiles as source_profiles_lib
3233
import typing_extensions
@@ -63,6 +64,7 @@ def __call__(
6364
core_profiles: state.CoreProfiles,
6465
x: tuple[cell_variable.CellVariable, ...],
6566
explicit_source_profiles: source_profiles_lib.SourceProfiles,
67+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
6668
allow_pereverzev: bool = False,
6769
# Checks if reduced calc_coeffs for explicit terms when theta_implicit=1
6870
# should be called
@@ -86,6 +88,7 @@ def __call__(
8688
not recalculated at time t+plus_dt with updated state during the solver
8789
iterations. For sources that are implicit, their explicit profiles are
8890
set to all zeros.
91+
pedestal_policy_state: State held by the pedestal policy.
8992
allow_pereverzev: If True, then the coeffs are being called within a
9093
linear solver. Thus could be either the use_predictor_corrector solver
9194
or as part of calculating the initial guess for the nonlinear solver. In
@@ -101,6 +104,16 @@ def __call__(
101104
coeffs: The diffusion, convection, etc. coefficients for this state.
102105
"""
103106

107+
# There are cases where pytype fails to enforce this
108+
if not isinstance(
109+
pedestal_policy_state, pedestal_policy.PedestalPolicyState
110+
):
111+
raise TypeError(
112+
'Expected `pedestal_policy_state` to be of type '
113+
'`pedestal_policy.PedestalPolicyState`',
114+
f'got `{type(pedestal_policy_state)}`.',
115+
)
116+
104117
# Update core_profiles with the subset of new values of evolving variables
105118
core_profiles = updaters.update_core_profiles_during_step(
106119
x,
@@ -121,6 +134,7 @@ def __call__(
121134
explicit_source_profiles=explicit_source_profiles,
122135
physics_models=self.physics_models,
123136
evolving_names=self.evolving_names,
137+
pedestal_policy_state=pedestal_policy_state,
124138
use_pereverzev=use_pereverzev,
125139
explicit_call=explicit_call,
126140
)
@@ -219,6 +233,7 @@ def calc_coeffs(
219233
explicit_source_profiles: source_profiles_lib.SourceProfiles,
220234
physics_models: physics_models_lib.PhysicsModels,
221235
evolving_names: tuple[str, ...],
236+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
222237
use_pereverzev: bool = False,
223238
explicit_call: bool = False,
224239
) -> block_1d_coeffs.Block1DCoeffs:
@@ -241,6 +256,7 @@ def calc_coeffs(
241256
physics_models: The physics models to use for the simulation.
242257
evolving_names: The names of the evolving variables in the order that their
243258
coefficients should be written to `coeffs`.
259+
pedestal_policy_state: State held by the pedestal policy.
244260
use_pereverzev: Toggle whether to calculate Pereverzev terms
245261
explicit_call: If True, indicates that calc_coeffs is being called for the
246262
explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
@@ -251,6 +267,14 @@ def calc_coeffs(
251267
coeffs: Block1DCoeffs containing the coefficients at this time step.
252268
"""
253269

270+
# There are cases where pytype fails to enforce this
271+
if not isinstance(pedestal_policy_state, pedestal_policy.PedestalPolicyState):
272+
raise TypeError(
273+
'Expected `pedestal_policy_state` to be of type '
274+
'`pedestal_policy.PedestalPolicyState`',
275+
f'got `{type(pedestal_policy_state)}`.',
276+
)
277+
254278
# If we are fully implicit and we are making a call for calc_coeffs for the
255279
# explicit components of the PDE, only return a cheaper reduced Block1DCoeffs
256280
if explicit_call and runtime_params.solver.theta_implicit == 1.0:
@@ -267,6 +291,7 @@ def calc_coeffs(
267291
explicit_source_profiles=explicit_source_profiles,
268292
physics_models=physics_models,
269293
evolving_names=evolving_names,
294+
pedestal_policy_state=pedestal_policy_state,
270295
use_pereverzev=use_pereverzev,
271296
)
272297

@@ -285,14 +310,26 @@ def _calc_coeffs_full(
285310
explicit_source_profiles: source_profiles_lib.SourceProfiles,
286311
physics_models: physics_models_lib.PhysicsModels,
287312
evolving_names: tuple[str, ...],
313+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
288314
use_pereverzev: bool = False,
289315
) -> block_1d_coeffs.Block1DCoeffs:
290316
"""See `calc_coeffs` for details."""
291317

292318
consts = constants.CONSTANTS
293319

320+
# There are cases where pytype fails to enforce this
321+
if not isinstance(pedestal_policy_state, pedestal_policy.PedestalPolicyState):
322+
raise TypeError(
323+
'Expected `pedestal_policy_state` to be of type '
324+
'`pedestal_policy.PedestalPolicyState`',
325+
f'got `{type(pedestal_policy_state)}`.',
326+
)
327+
294328
pedestal_model_output = physics_models.pedestal_model(
295-
runtime_params, geo, core_profiles
329+
runtime_params,
330+
geo,
331+
core_profiles,
332+
pedestal_policy_state=pedestal_policy_state,
296333
)
297334

298335
# Boolean mask for enforcing internal temperature boundary conditions to
@@ -352,7 +389,11 @@ def _calc_coeffs_full(
352389

353390
# Diffusion term coefficients
354391
turbulent_transport = physics_models.transport_model(
355-
runtime_params, geo, core_profiles, pedestal_model_output
392+
runtime_params,
393+
geo,
394+
core_profiles,
395+
pedestal_policy_state,
396+
pedestal_model_output,
356397
)
357398
neoclassical_transport = physics_models.neoclassical_models.transport(
358399
runtime_params, geo, core_profiles

torax/_src/fvm/newton_raphson_solve_block.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from torax._src.fvm import fvm_conversions
3535
from torax._src.fvm import residual_and_loss
3636
from torax._src.geometry import geometry
37+
from torax._src.pedestal_policy import pedestal_policy
3738
from torax._src.solver import jax_root_finding
3839
from torax._src.solver import predictor_corrector_method
3940
from torax._src.sources import source_profiles
@@ -68,6 +69,8 @@ def newton_raphson_solve_block(
6869
physics_models: physics_models_lib.PhysicsModels,
6970
coeffs_callback: calc_coeffs.CoeffsCallback,
7071
evolving_names: tuple[str, ...],
72+
pedestal_policy_state_t: pedestal_policy.PedestalPolicyState,
73+
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicyState,
7174
initial_guess_mode: enums.InitialGuessMode,
7275
maxiter: int,
7376
tol: float,
@@ -129,6 +132,8 @@ def newton_raphson_solve_block(
129132
core_profiles. Repeatedly called by the iterative optimizer.
130133
evolving_names: The names of variables within the core profiles that should
131134
evolve.
135+
pedestal_policy_state_t: Pedestal policy state at time t
136+
pedestal_policy_state_t_plus_dt: Pedestal policy state at time t + dt
132137
initial_guess_mode: chooses the initial_guess for the iterative method,
133138
either x_old or linear step. When taking the linear step, it is also
134139
recommended to use Pereverzev-Corrigan terms if the transport coefficients
@@ -160,6 +165,7 @@ def newton_raphson_solve_block(
160165
core_profiles_t,
161166
x_old,
162167
explicit_source_profiles=explicit_source_profiles,
168+
pedestal_policy_state=pedestal_policy_state_t,
163169
explicit_call=True,
164170
)
165171

@@ -176,6 +182,7 @@ def newton_raphson_solve_block(
176182
core_profiles_t,
177183
x_old,
178184
explicit_source_profiles=explicit_source_profiles,
185+
pedestal_policy_state=pedestal_policy_state_t,
179186
allow_pereverzev=True,
180187
explicit_call=True,
181188
)
@@ -194,6 +201,7 @@ def newton_raphson_solve_block(
194201
coeffs_exp=coeffs_exp_linear,
195202
coeffs_callback=coeffs_callback,
196203
explicit_source_profiles=explicit_source_profiles,
204+
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
197205
)
198206
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
199207
case enums.InitialGuessMode.X_OLD:
@@ -215,6 +223,7 @@ def newton_raphson_solve_block(
215223
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
216224
physics_models=physics_models,
217225
explicit_source_profiles=explicit_source_profiles,
226+
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
218227
coeffs_old=coeffs_old,
219228
evolving_names=evolving_names,
220229
)

torax/_src/fvm/optimizer_solve_block.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torax._src.fvm import fvm_conversions
3333
from torax._src.fvm import residual_and_loss
3434
from torax._src.geometry import geometry
35+
from torax._src.pedestal_policy import pedestal_policy
3536
from torax._src.solver import predictor_corrector_method
3637
from torax._src.sources import source_profiles
3738

@@ -57,6 +58,8 @@ def optimizer_solve_block(
5758
core_profiles_t: state.CoreProfiles,
5859
core_profiles_t_plus_dt: state.CoreProfiles,
5960
explicit_source_profiles: source_profiles.SourceProfiles,
61+
pedestal_policy_state_t: pedestal_policy.PedestalPolicyState,
62+
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicyState,
6063
physics_models: physics_models_lib.PhysicsModels,
6164
coeffs_callback: calc_coeffs.CoeffsCallback,
6265
evolving_names: tuple[str, ...],
@@ -98,6 +101,9 @@ def optimizer_solve_block(
98101
being evolved by the PDE system.
99102
explicit_source_profiles: Pre-calculated sources implemented as explicit
100103
sources in the PDE.
104+
pedestal_policy_state_t: State variables held by the pedestal policy.
105+
pedestal_policy_state_t_plus_dt: State variables held by the pedestal
106+
policy.
101107
physics_models: Physics models used for the calculations.
102108
coeffs_callback: Calculates diffusion, convection etc. coefficients given a
103109
core_profiles. Repeatedly called by the iterative optimizer.
@@ -124,6 +130,7 @@ def optimizer_solve_block(
124130
core_profiles_t,
125131
x_old,
126132
explicit_source_profiles=explicit_source_profiles,
133+
pedestal_policy_state=pedestal_policy_state_t,
127134
explicit_call=True,
128135
)
129136

@@ -141,6 +148,7 @@ def optimizer_solve_block(
141148
core_profiles_t,
142149
x_old,
143150
explicit_source_profiles=explicit_source_profiles,
151+
pedestal_policy_state=pedestal_policy_state_t,
144152
allow_pereverzev=True,
145153
explicit_call=True,
146154
)
@@ -158,6 +166,7 @@ def optimizer_solve_block(
158166
coeffs_exp=coeffs_exp_linear,
159167
coeffs_callback=coeffs_callback,
160168
explicit_source_profiles=explicit_source_profiles,
169+
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
161170
)
162171
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
163172
case enums.InitialGuessMode.X_OLD:
@@ -180,6 +189,7 @@ def optimizer_solve_block(
180189
init_x_new_vec=init_x_new_vec,
181190
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
182191
explicit_source_profiles=explicit_source_profiles,
192+
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
183193
physics_models=physics_models,
184194
coeffs_old=coeffs_old,
185195
evolving_names=evolving_names,

torax/_src/fvm/residual_and_loss.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from torax._src.fvm import discrete_system
3939
from torax._src.fvm import fvm_conversions
4040
from torax._src.geometry import geometry
41+
from torax._src.pedestal_policy import pedestal_policy
4142
from torax._src.sources import source_profiles
4243

4344
Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs
@@ -201,6 +202,7 @@ def theta_method_block_residual(
201202
x_old: tuple[cell_variable.CellVariable, ...],
202203
core_profiles_t_plus_dt: state.CoreProfiles,
203204
explicit_source_profiles: source_profiles.SourceProfiles,
205+
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicyState,
204206
physics_models: physics_models_lib.PhysicsModels,
205207
coeffs_old: Block1DCoeffs,
206208
evolving_names: tuple[str, ...],
@@ -220,6 +222,8 @@ def theta_method_block_residual(
220222
being evolved by the PDE system.
221223
explicit_source_profiles: Pre-calculated sources implemented as explicit
222224
sources in the PDE.
225+
pedestal_policy_state_t_plus_dt: State variables held by the pedestal
226+
policy.
223227
physics_models: Physics models used for the calculations.
224228
coeffs_old: The coefficients calculated at x_old.
225229
evolving_names: The names of variables within the core profiles that should
@@ -252,6 +256,7 @@ def theta_method_block_residual(
252256
core_profiles=core_profiles_t_plus_dt,
253257
explicit_source_profiles=explicit_source_profiles,
254258
physics_models=physics_models,
259+
pedestal_policy_state=pedestal_policy_state_t_plus_dt,
255260
evolving_names=evolving_names,
256261
use_pereverzev=False,
257262
)
@@ -290,6 +295,7 @@ def theta_method_block_loss(
290295
x_old: tuple[cell_variable.CellVariable, ...],
291296
core_profiles_t_plus_dt: state.CoreProfiles,
292297
explicit_source_profiles: source_profiles.SourceProfiles,
298+
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicyState,
293299
physics_models: physics_models_lib.PhysicsModels,
294300
coeffs_old: Block1DCoeffs,
295301
evolving_names: tuple[str, ...],
@@ -309,6 +315,8 @@ def theta_method_block_loss(
309315
being evolved by the PDE system.
310316
explicit_source_profiles: pre-calculated sources implemented as explicit
311317
sources in the PDE
318+
pedestal_policy_state_t_plus_dt: State variables held by the pedestal
319+
policy.
312320
physics_models: Physics models used for the calculations.
313321
coeffs_old: The coefficients calculated at x_old.
314322
evolving_names: The names of variables within the core profiles that should
@@ -326,6 +334,7 @@ def theta_method_block_loss(
326334
x_new_guess_vec=x_new_guess_vec,
327335
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
328336
explicit_source_profiles=explicit_source_profiles,
337+
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
329338
physics_models=physics_models,
330339
coeffs_old=coeffs_old,
331340
evolving_names=evolving_names,
@@ -349,6 +358,7 @@ def jaxopt_solver(
349358
init_x_new_vec: jax.Array,
350359
core_profiles_t_plus_dt: state.CoreProfiles,
351360
explicit_source_profiles: source_profiles.SourceProfiles,
361+
pedestal_policy_state_t_plus_dt: pedestal_policy.PedestalPolicy,
352362
physics_models: physics_models_lib.PhysicsModels,
353363
coeffs_old: Block1DCoeffs,
354364
evolving_names: tuple[str, ...],
@@ -370,6 +380,8 @@ def jaxopt_solver(
370380
being evolved by the PDE system.
371381
explicit_source_profiles: pre-calculated sources implemented as explicit
372382
sources in the PDE.
383+
pedestal_policy_state_t_plus_dt: State variables held by the pedestal
384+
policy.
373385
physics_models: Physics models used for the calculations.
374386
coeffs_old: The coefficients calculated at x_old.
375387
evolving_names: The names of variables within the core profiles that should
@@ -394,6 +406,7 @@ def jaxopt_solver(
394406
physics_models=physics_models,
395407
coeffs_old=coeffs_old,
396408
evolving_names=evolving_names,
409+
pedestal_policy_state_t_plus_dt=pedestal_policy_state_t_plus_dt,
397410
)
398411
solver = jaxopt.LBFGS(fun=loss, maxiter=maxiter, tol=tol, implicit_diff=True)
399412
solver_output = solver.run(init_x_new_vec)

0 commit comments

Comments
 (0)