3838from torax ._src .fvm import discrete_system
3939from torax ._src .fvm import fvm_conversions
4040from torax ._src .geometry import geometry
41+ from torax ._src .pedestal_policy import pedestal_policy
4142from torax ._src .sources import source_profiles
4243
4344Block1DCoeffs : TypeAlias = block_1d_coeffs .Block1DCoeffs
@@ -201,6 +202,7 @@ def theta_method_block_residual(
201202 x_old : tuple [cell_variable .CellVariable , ...],
202203 core_profiles_t_plus_dt : state .CoreProfiles ,
203204 explicit_source_profiles : source_profiles .SourceProfiles ,
205+ pedestal_policy_state : pedestal_policy .PedestalPolicyState ,
204206 physics_models : physics_models_lib .PhysicsModels ,
205207 coeffs_old : Block1DCoeffs ,
206208 evolving_names : tuple [str , ...],
@@ -220,6 +222,7 @@ def theta_method_block_residual(
220222 being evolved by the PDE system.
221223 explicit_source_profiles: Pre-calculated sources implemented as explicit
222224 sources in the PDE.
225+ pedestal_policy_state: State variables held by the pedestal policy.
223226 physics_models: Physics models used for the calculations.
224227 coeffs_old: The coefficients calculated at x_old.
225228 evolving_names: The names of variables within the core profiles that should
@@ -252,6 +255,7 @@ def theta_method_block_residual(
252255 core_profiles = core_profiles_t_plus_dt ,
253256 explicit_source_profiles = explicit_source_profiles ,
254257 physics_models = physics_models ,
258+ pedestal_policy_state = pedestal_policy_state ,
255259 evolving_names = evolving_names ,
256260 use_pereverzev = False ,
257261 )
@@ -290,6 +294,7 @@ def theta_method_block_loss(
290294 x_old : tuple [cell_variable .CellVariable , ...],
291295 core_profiles_t_plus_dt : state .CoreProfiles ,
292296 explicit_source_profiles : source_profiles .SourceProfiles ,
297+ pedestal_policy_state : pedestal_policy .PedestalPolicyState ,
293298 physics_models : physics_models_lib .PhysicsModels ,
294299 coeffs_old : Block1DCoeffs ,
295300 evolving_names : tuple [str , ...],
@@ -309,6 +314,7 @@ def theta_method_block_loss(
309314 being evolved by the PDE system.
310315 explicit_source_profiles: pre-calculated sources implemented as explicit
311316 sources in the PDE
317+ pedestal_policy_state: State variables held by the pedestal policy.
312318 physics_models: Physics models used for the calculations.
313319 coeffs_old: The coefficients calculated at x_old.
314320 evolving_names: The names of variables within the core profiles that should
@@ -326,6 +332,7 @@ def theta_method_block_loss(
326332 x_new_guess_vec = x_new_guess_vec ,
327333 core_profiles_t_plus_dt = core_profiles_t_plus_dt ,
328334 explicit_source_profiles = explicit_source_profiles ,
335+ pedestal_policy_state = pedestal_policy_state ,
329336 physics_models = physics_models ,
330337 coeffs_old = coeffs_old ,
331338 evolving_names = evolving_names ,
@@ -349,6 +356,7 @@ def jaxopt_solver(
349356 init_x_new_vec : jax .Array ,
350357 core_profiles_t_plus_dt : state .CoreProfiles ,
351358 explicit_source_profiles : source_profiles .SourceProfiles ,
359+ pedestal_policy_state : pedestal_policy .PedestalPolicy ,
352360 physics_models : physics_models_lib .PhysicsModels ,
353361 coeffs_old : Block1DCoeffs ,
354362 evolving_names : tuple [str , ...],
@@ -370,6 +378,7 @@ def jaxopt_solver(
370378 being evolved by the PDE system.
371379 explicit_source_profiles: pre-calculated sources implemented as explicit
372380 sources in the PDE.
381+ pedestal_policy_state: State variables held by the pedestal policy.
373382 physics_models: Physics models used for the calculations.
374383 coeffs_old: The coefficients calculated at x_old.
375384 evolving_names: The names of variables within the core profiles that should
@@ -394,6 +403,7 @@ def jaxopt_solver(
394403 physics_models = physics_models ,
395404 coeffs_old = coeffs_old ,
396405 evolving_names = evolving_names ,
406+ pedestal_policy_state = pedestal_policy_state ,
397407 )
398408 solver = jaxopt .LBFGS (fun = loss , maxiter = maxiter , tol = tol , implicit_diff = True )
399409 solver_output = solver .run (init_x_new_vec )
0 commit comments