Skip to content

Commit facd202

Browse files
goodfeliTorax team
authored andcommitted
Introduce PedestalPolicy.
This is a step toward having full Martin scaling, where there will be a Martin scaling PedestalPolicy that ramps up and ramps down the pedestal height. This PR is just a step that defines the new stateful interface required for doing so and transitions the existing `set_pedestal` functionality to being implemented with the new PedestalPolicy class. PiperOrigin-RevId: 810477225
1 parent 5ba4eef commit facd202

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+736
-135
lines changed

torax/_src/config/build_runtime_params.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def __call__(
103103
numerics=self.numerics.build_runtime_params(t),
104104
neoclassical=self.neoclassical.build_runtime_params(),
105105
pedestal=self.pedestal.build_runtime_params(t),
106+
pedestal_policy=self.pedestal.build_pedestal_policy_runtime_params(),
106107
mhd=self.mhd.build_runtime_params(t),
107108
time_step_calculator=self.time_step_calculator.build_runtime_params(),
108109
)

torax/_src/config/runtime_params_slice.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from torax._src.mhd import runtime_params as mhd_runtime_params
4949
from torax._src.neoclassical import runtime_params as neoclassical_params
5050
from torax._src.pedestal_model import runtime_params as pedestal_model_params
51+
from torax._src.pedestal_policy import runtime_params as pedestal_policy_runtime_params
5152
from torax._src.solver import runtime_params as solver_params
5253
from torax._src.sources import runtime_params as sources_params
5354
from torax._src.time_step_calculator import runtime_params as time_step_calculator_runtime_params
@@ -78,6 +79,7 @@ class RuntimeParams:
7879
neoclassical: neoclassical_params.RuntimeParams
7980
numerics: numerics.RuntimeParams
8081
pedestal: pedestal_model_params.RuntimeParams
82+
pedestal_policy: pedestal_policy_runtime_params.PedestalPolicyRuntimeParams
8183
plasma_composition: plasma_composition.RuntimeParams
8284
profile_conditions: profile_conditions.RuntimeParams
8385
solver: solver_params.RuntimeParams

torax/_src/config/tests/build_runtime_params_test.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,32 @@ def test_pedestal_is_time_dependent(self):
9595
set_pedestal={0.0: True, 1.0: False},
9696
)
9797
)
98+
pedestal_policy = pedestal.build_pedestal_model().pedestal_policy
9899
# Check at time 0.
99100

100101
pedestal_params = pedestal.build_runtime_params(t=0.0)
102+
pedestal_policy_rp = pedestal.build_pedestal_policy_runtime_params()
101103
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
102-
np.testing.assert_allclose(pedestal_params.set_pedestal, True)
104+
np.testing.assert_allclose(
105+
pedestal_policy.initial_state(
106+
t=0.0, runtime_params=pedestal_policy_rp
107+
).use_pedestal,
108+
True,
109+
)
103110
np.testing.assert_allclose(pedestal_params.T_i_ped, 0.0)
104111
np.testing.assert_allclose(pedestal_params.T_e_ped, 1.0)
105112
np.testing.assert_allclose(pedestal_params.n_e_ped, 2.0e20)
106113
np.testing.assert_allclose(pedestal_params.rho_norm_ped_top, 3.0)
107114
# And check after the time limit.
108115
pedestal_params = pedestal.build_runtime_params(t=1.0)
116+
# Note: pedestal_policy_rp does not depend on time for its structure
109117
assert isinstance(pedestal_params, set_tped_nped.RuntimeParams)
110-
np.testing.assert_allclose(pedestal_params.set_pedestal, False)
118+
np.testing.assert_allclose(
119+
pedestal_policy.initial_state(
120+
t=1.0, runtime_params=pedestal_policy_rp
121+
).use_pedestal,
122+
False,
123+
)
111124
np.testing.assert_allclose(pedestal_params.T_i_ped, 1.0)
112125
np.testing.assert_allclose(pedestal_params.T_e_ped, 2.0)
113126
np.testing.assert_allclose(pedestal_params.n_e_ped, 3.0e20)

torax/_src/fvm/calc_coeffs.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from torax._src.fvm import cell_variable
2828
from torax._src.geometry import geometry
2929
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
30+
from torax._src.pedestal_policy import pedestal_policy
3031
from torax._src.sources import source_profile_builders
3132
from torax._src.sources import source_profiles as source_profiles_lib
3233
import typing_extensions
@@ -63,6 +64,7 @@ def __call__(
6364
core_profiles: state.CoreProfiles,
6465
x: tuple[cell_variable.CellVariable, ...],
6566
explicit_source_profiles: source_profiles_lib.SourceProfiles,
67+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
6668
allow_pereverzev: bool = False,
6769
# Checks if reduced calc_coeffs for explicit terms when theta_implicit=1
6870
# should be called
@@ -86,6 +88,7 @@ def __call__(
8688
not recalculated at time t+plus_dt with updated state during the solver
8789
iterations. For sources that are implicit, their explicit profiles are
8890
set to all zeros.
91+
pedestal_policy_state: State held by the pedestal policy.
8992
allow_pereverzev: If True, then the coeffs are being called within a
9093
linear solver. Thus could be either the use_predictor_corrector solver
9194
or as part of calculating the initial guess for the nonlinear solver. In
@@ -121,6 +124,7 @@ def __call__(
121124
explicit_source_profiles=explicit_source_profiles,
122125
physics_models=self.physics_models,
123126
evolving_names=self.evolving_names,
127+
pedestal_policy_state=pedestal_policy_state,
124128
use_pereverzev=use_pereverzev,
125129
explicit_call=explicit_call,
126130
)
@@ -219,6 +223,7 @@ def calc_coeffs(
219223
explicit_source_profiles: source_profiles_lib.SourceProfiles,
220224
physics_models: physics_models_lib.PhysicsModels,
221225
evolving_names: tuple[str, ...],
226+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
222227
use_pereverzev: bool = False,
223228
explicit_call: bool = False,
224229
) -> block_1d_coeffs.Block1DCoeffs:
@@ -241,6 +246,7 @@ def calc_coeffs(
241246
physics_models: The physics models to use for the simulation.
242247
evolving_names: The names of the evolving variables in the order that their
243248
coefficients should be written to `coeffs`.
249+
pedestal_policy_state: State held by the pedestal policy.
244250
use_pereverzev: Toggle whether to calculate Pereverzev terms
245251
explicit_call: If True, indicates that calc_coeffs is being called for the
246252
explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
@@ -267,6 +273,7 @@ def calc_coeffs(
267273
explicit_source_profiles=explicit_source_profiles,
268274
physics_models=physics_models,
269275
evolving_names=evolving_names,
276+
pedestal_policy_state=pedestal_policy_state,
270277
use_pereverzev=use_pereverzev,
271278
)
272279

@@ -285,14 +292,18 @@ def _calc_coeffs_full(
285292
explicit_source_profiles: source_profiles_lib.SourceProfiles,
286293
physics_models: physics_models_lib.PhysicsModels,
287294
evolving_names: tuple[str, ...],
295+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
288296
use_pereverzev: bool = False,
289297
) -> block_1d_coeffs.Block1DCoeffs:
290298
"""See `calc_coeffs` for details."""
291299

292300
consts = constants.CONSTANTS
293301

294302
pedestal_model_output = physics_models.pedestal_model(
295-
runtime_params, geo, core_profiles
303+
runtime_params,
304+
geo,
305+
core_profiles,
306+
pedestal_policy_state,
296307
)
297308

298309
# Boolean mask for enforcing internal temperature boundary conditions to
@@ -352,7 +363,11 @@ def _calc_coeffs_full(
352363

353364
# Diffusion term coefficients
354365
turbulent_transport = physics_models.transport_model(
355-
runtime_params, geo, core_profiles, pedestal_model_output
366+
runtime_params,
367+
geo,
368+
core_profiles,
369+
pedestal_policy_state,
370+
pedestal_model_output,
356371
)
357372
neoclassical_transport = physics_models.neoclassical_models.transport(
358373
runtime_params, geo, core_profiles

torax/_src/fvm/newton_raphson_solve_block.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from torax._src.fvm import jax_root_finding
3636
from torax._src.fvm import residual_and_loss
3737
from torax._src.geometry import geometry
38+
from torax._src.pedestal_policy import pedestal_policy
3839
from torax._src.solver import predictor_corrector_method
3940
from torax._src.sources import source_profiles
4041

@@ -68,6 +69,7 @@ def newton_raphson_solve_block(
6869
physics_models: physics_models_lib.PhysicsModels,
6970
coeffs_callback: calc_coeffs.CoeffsCallback,
7071
evolving_names: tuple[str, ...],
72+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
7173
initial_guess_mode: enums.InitialGuessMode,
7274
maxiter: int,
7375
tol: float,
@@ -129,6 +131,7 @@ def newton_raphson_solve_block(
129131
core_profiles. Repeatedly called by the iterative optimizer.
130132
evolving_names: The names of variables within the core profiles that should
131133
evolve.
134+
pedestal_policy_state: State variables held by the pedestal policy.
132135
initial_guess_mode: chooses the initial_guess for the iterative method,
133136
either x_old or linear step. When taking the linear step, it is also
134137
recommended to use Pereverzev-Corrigan terms if the transport coefficients
@@ -160,6 +163,7 @@ def newton_raphson_solve_block(
160163
core_profiles_t,
161164
x_old,
162165
explicit_source_profiles=explicit_source_profiles,
166+
pedestal_policy_state=pedestal_policy_state,
163167
explicit_call=True,
164168
)
165169

@@ -176,6 +180,7 @@ def newton_raphson_solve_block(
176180
core_profiles_t,
177181
x_old,
178182
explicit_source_profiles=explicit_source_profiles,
183+
pedestal_policy_state=pedestal_policy_state,
179184
allow_pereverzev=True,
180185
explicit_call=True,
181186
)
@@ -194,6 +199,7 @@ def newton_raphson_solve_block(
194199
coeffs_exp=coeffs_exp_linear,
195200
coeffs_callback=coeffs_callback,
196201
explicit_source_profiles=explicit_source_profiles,
202+
pedestal_policy_state=pedestal_policy_state,
197203
)
198204
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
199205
case enums.InitialGuessMode.X_OLD:
@@ -215,6 +221,7 @@ def newton_raphson_solve_block(
215221
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
216222
physics_models=physics_models,
217223
explicit_source_profiles=explicit_source_profiles,
224+
pedestal_policy_state=pedestal_policy_state,
218225
coeffs_old=coeffs_old,
219226
evolving_names=evolving_names,
220227
)

torax/_src/fvm/optimizer_solve_block.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torax._src.fvm import fvm_conversions
3333
from torax._src.fvm import residual_and_loss
3434
from torax._src.geometry import geometry
35+
from torax._src.pedestal_policy import pedestal_policy
3536
from torax._src.solver import predictor_corrector_method
3637
from torax._src.sources import source_profiles
3738

@@ -57,6 +58,7 @@ def optimizer_solve_block(
5758
core_profiles_t: state.CoreProfiles,
5859
core_profiles_t_plus_dt: state.CoreProfiles,
5960
explicit_source_profiles: source_profiles.SourceProfiles,
61+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
6062
physics_models: physics_models_lib.PhysicsModels,
6163
coeffs_callback: calc_coeffs.CoeffsCallback,
6264
evolving_names: tuple[str, ...],
@@ -98,6 +100,7 @@ def optimizer_solve_block(
98100
being evolved by the PDE system.
99101
explicit_source_profiles: Pre-calculated sources implemented as explicit
100102
sources in the PDE.
103+
pedestal_policy_state: State variables held by the pedestal policy.
101104
physics_models: Physics models used for the calculations.
102105
coeffs_callback: Calculates diffusion, convection etc. coefficients given a
103106
core_profiles. Repeatedly called by the iterative optimizer.
@@ -124,6 +127,7 @@ def optimizer_solve_block(
124127
core_profiles_t,
125128
x_old,
126129
explicit_source_profiles=explicit_source_profiles,
130+
pedestal_policy_state=pedestal_policy_state,
127131
explicit_call=True,
128132
)
129133

@@ -141,6 +145,7 @@ def optimizer_solve_block(
141145
core_profiles_t,
142146
x_old,
143147
explicit_source_profiles=explicit_source_profiles,
148+
pedestal_policy_state=pedestal_policy_state,
144149
allow_pereverzev=True,
145150
explicit_call=True,
146151
)
@@ -158,6 +163,7 @@ def optimizer_solve_block(
158163
coeffs_exp=coeffs_exp_linear,
159164
coeffs_callback=coeffs_callback,
160165
explicit_source_profiles=explicit_source_profiles,
166+
pedestal_policy_state=pedestal_policy_state,
161167
)
162168
init_x_new_vec = fvm_conversions.cell_variable_tuple_to_vec(init_x_new)
163169
case enums.InitialGuessMode.X_OLD:
@@ -187,6 +193,7 @@ def optimizer_solve_block(
187193
init_x_new_vec=init_x_new_vec,
188194
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
189195
explicit_source_profiles=explicit_source_profiles,
196+
pedestal_policy_state=pedestal_policy_state,
190197
physics_models=physics_models,
191198
coeffs_old=coeffs_old,
192199
evolving_names=evolving_names,

torax/_src/fvm/residual_and_loss.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from torax._src.fvm import discrete_system
3939
from torax._src.fvm import fvm_conversions
4040
from torax._src.geometry import geometry
41+
from torax._src.pedestal_policy import pedestal_policy
4142
from torax._src.sources import source_profiles
4243

4344
Block1DCoeffs: TypeAlias = block_1d_coeffs.Block1DCoeffs
@@ -201,6 +202,7 @@ def theta_method_block_residual(
201202
x_old: tuple[cell_variable.CellVariable, ...],
202203
core_profiles_t_plus_dt: state.CoreProfiles,
203204
explicit_source_profiles: source_profiles.SourceProfiles,
205+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
204206
physics_models: physics_models_lib.PhysicsModels,
205207
coeffs_old: Block1DCoeffs,
206208
evolving_names: tuple[str, ...],
@@ -220,6 +222,7 @@ def theta_method_block_residual(
220222
being evolved by the PDE system.
221223
explicit_source_profiles: Pre-calculated sources implemented as explicit
222224
sources in the PDE.
225+
pedestal_policy_state: State variables held by the pedestal policy.
223226
physics_models: Physics models used for the calculations.
224227
coeffs_old: The coefficients calculated at x_old.
225228
evolving_names: The names of variables within the core profiles that should
@@ -252,6 +255,7 @@ def theta_method_block_residual(
252255
core_profiles=core_profiles_t_plus_dt,
253256
explicit_source_profiles=explicit_source_profiles,
254257
physics_models=physics_models,
258+
pedestal_policy_state=pedestal_policy_state,
255259
evolving_names=evolving_names,
256260
use_pereverzev=False,
257261
)
@@ -290,6 +294,7 @@ def theta_method_block_loss(
290294
x_old: tuple[cell_variable.CellVariable, ...],
291295
core_profiles_t_plus_dt: state.CoreProfiles,
292296
explicit_source_profiles: source_profiles.SourceProfiles,
297+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
293298
physics_models: physics_models_lib.PhysicsModels,
294299
coeffs_old: Block1DCoeffs,
295300
evolving_names: tuple[str, ...],
@@ -309,6 +314,7 @@ def theta_method_block_loss(
309314
being evolved by the PDE system.
310315
explicit_source_profiles: pre-calculated sources implemented as explicit
311316
sources in the PDE
317+
pedestal_policy_state: State variables held by the pedestal policy.
312318
physics_models: Physics models used for the calculations.
313319
coeffs_old: The coefficients calculated at x_old.
314320
evolving_names: The names of variables within the core profiles that should
@@ -326,6 +332,7 @@ def theta_method_block_loss(
326332
x_new_guess_vec=x_new_guess_vec,
327333
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
328334
explicit_source_profiles=explicit_source_profiles,
335+
pedestal_policy_state=pedestal_policy_state,
329336
physics_models=physics_models,
330337
coeffs_old=coeffs_old,
331338
evolving_names=evolving_names,
@@ -349,6 +356,7 @@ def jaxopt_solver(
349356
init_x_new_vec: jax.Array,
350357
core_profiles_t_plus_dt: state.CoreProfiles,
351358
explicit_source_profiles: source_profiles.SourceProfiles,
359+
pedestal_policy_state: pedestal_policy.PedestalPolicy,
352360
physics_models: physics_models_lib.PhysicsModels,
353361
coeffs_old: Block1DCoeffs,
354362
evolving_names: tuple[str, ...],
@@ -370,6 +378,7 @@ def jaxopt_solver(
370378
being evolved by the PDE system.
371379
explicit_source_profiles: pre-calculated sources implemented as explicit
372380
sources in the PDE.
381+
pedestal_policy_state: State variables held by the pedestal policy.
373382
physics_models: Physics models used for the calculations.
374383
coeffs_old: The coefficients calculated at x_old.
375384
evolving_names: The names of variables within the core profiles that should
@@ -394,6 +403,7 @@ def jaxopt_solver(
394403
physics_models=physics_models,
395404
coeffs_old=coeffs_old,
396405
evolving_names=evolving_names,
406+
pedestal_policy_state=pedestal_policy_state,
397407
)
398408
solver = jaxopt.LBFGS(fun=loss, maxiter=maxiter, tol=tol, implicit_diff=True)
399409
solver_output = solver.run(init_x_new_vec)

torax/_src/fvm/tests/calc_coeffs_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,13 +79,20 @@ def test_calc_coeffs_smoke_test(
7979
neoclassical_models=physics_models.neoclassical_models,
8080
explicit=True,
8181
)
82+
pedestal_policy_state = (
83+
physics_models.pedestal_model.pedestal_policy.initial_state(
84+
t=torax_config.numerics.t_initial,
85+
runtime_params=runtime_params.pedestal_policy,
86+
)
87+
)
8288
calc_coeffs.calc_coeffs(
8389
runtime_params=runtime_params,
8490
geo=geo,
8591
core_profiles=core_profiles,
8692
physics_models=physics_models,
8793
explicit_source_profiles=explicit_source_profiles,
8894
evolving_names=evolving_names,
95+
pedestal_policy_state=pedestal_policy_state,
8996
use_pereverzev=False,
9097
)
9198

0 commit comments

Comments
 (0)