2727from torax ._src .fvm import cell_variable
2828from torax ._src .geometry import geometry
2929from torax ._src .pedestal_model import pedestal_model as pedestal_model_lib
30+ from torax ._src .pedestal_policy import pedestal_policy
3031from torax ._src .sources import source_profile_builders
3132from torax ._src .sources import source_profiles as source_profiles_lib
3233import typing_extensions
@@ -63,6 +64,7 @@ def __call__(
6364 core_profiles : state .CoreProfiles ,
6465 x : tuple [cell_variable .CellVariable , ...],
6566 explicit_source_profiles : source_profiles_lib .SourceProfiles ,
67+ pedestal_policy_state : pedestal_policy .PedestalPolicyState ,
6668 allow_pereverzev : bool = False ,
6769 # Checks if reduced calc_coeffs for explicit terms when theta_implicit=1
6870 # should be called
@@ -86,6 +88,7 @@ def __call__(
8688 not recalculated at time t+plus_dt with updated state during the solver
8789 iterations. For sources that are implicit, their explicit profiles are
8890 set to all zeros.
91+ pedestal_policy_state: State held by the pedestal policy.
8992 allow_pereverzev: If True, then the coeffs are being called within a
9093 linear solver. Thus could be either the use_predictor_corrector solver
9194 or as part of calculating the initial guess for the nonlinear solver. In
@@ -101,6 +104,16 @@ def __call__(
101104 coeffs: The diffusion, convection, etc. coefficients for this state.
102105 """
103106
107+ # There are cases where pytype fails to enforce this
108+ if not isinstance (
109+ pedestal_policy_state , pedestal_policy .PedestalPolicyState
110+ ):
111+ raise TypeError (
112+ 'Expected `pedestal_policy_state` to be of type '
113+ '`pedestal_policy.PedestalPolicyState`' ,
114+ f'got `{ type (pedestal_policy_state )} `.' ,
115+ )
116+
104117 # Update core_profiles with the subset of new values of evolving variables
105118 core_profiles = updaters .update_core_profiles_during_step (
106119 x ,
@@ -121,6 +134,7 @@ def __call__(
121134 explicit_source_profiles = explicit_source_profiles ,
122135 physics_models = self .physics_models ,
123136 evolving_names = self .evolving_names ,
137+ pedestal_policy_state = pedestal_policy_state ,
124138 use_pereverzev = use_pereverzev ,
125139 explicit_call = explicit_call ,
126140 )
@@ -219,6 +233,7 @@ def calc_coeffs(
219233 explicit_source_profiles : source_profiles_lib .SourceProfiles ,
220234 physics_models : physics_models_lib .PhysicsModels ,
221235 evolving_names : tuple [str , ...],
236+ pedestal_policy_state : pedestal_policy .PedestalPolicyState ,
222237 use_pereverzev : bool = False ,
223238 explicit_call : bool = False ,
224239) -> block_1d_coeffs .Block1DCoeffs :
@@ -241,6 +256,7 @@ def calc_coeffs(
241256 physics_models: The physics models to use for the simulation.
242257 evolving_names: The names of the evolving variables in the order that their
243258 coefficients should be written to `coeffs`.
259+ pedestal_policy_state: State held by the pedestal policy.
244260 use_pereverzev: Toggle whether to calculate Pereverzev terms
245261 explicit_call: If True, indicates that calc_coeffs is being called for the
246262 explicit component of the PDE. Then calculates a reduced Block1DCoeffs if
@@ -251,6 +267,14 @@ def calc_coeffs(
251267 coeffs: Block1DCoeffs containing the coefficients at this time step.
252268 """
253269
270+ # There are cases where pytype fails to enforce this
271+ if not isinstance (pedestal_policy_state , pedestal_policy .PedestalPolicyState ):
272+ raise TypeError (
273+ 'Expected `pedestal_policy_state` to be of type '
274+ '`pedestal_policy.PedestalPolicyState`' ,
275+ f'got `{ type (pedestal_policy_state )} `.' ,
276+ )
277+
254278 # If we are fully implicit and we are making a call for calc_coeffs for the
255279 # explicit components of the PDE, only return a cheaper reduced Block1DCoeffs
256280 if explicit_call and runtime_params .solver .theta_implicit == 1.0 :
@@ -267,6 +291,7 @@ def calc_coeffs(
267291 explicit_source_profiles = explicit_source_profiles ,
268292 physics_models = physics_models ,
269293 evolving_names = evolving_names ,
294+ pedestal_policy_state = pedestal_policy_state ,
270295 use_pereverzev = use_pereverzev ,
271296 )
272297
@@ -285,14 +310,26 @@ def _calc_coeffs_full(
285310 explicit_source_profiles : source_profiles_lib .SourceProfiles ,
286311 physics_models : physics_models_lib .PhysicsModels ,
287312 evolving_names : tuple [str , ...],
313+ pedestal_policy_state : pedestal_policy .PedestalPolicyState ,
288314 use_pereverzev : bool = False ,
289315) -> block_1d_coeffs .Block1DCoeffs :
290316 """See `calc_coeffs` for details."""
291317
292318 consts = constants .CONSTANTS
293319
320+ # There are cases where pytype fails to enforce this
321+ if not isinstance (pedestal_policy_state , pedestal_policy .PedestalPolicyState ):
322+ raise TypeError (
323+ 'Expected `pedestal_policy_state` to be of type '
324+ '`pedestal_policy.PedestalPolicyState`' ,
325+ f'got `{ type (pedestal_policy_state )} `.' ,
326+ )
327+
294328 pedestal_model_output = physics_models .pedestal_model (
295- runtime_params , geo , core_profiles
329+ runtime_params ,
330+ geo ,
331+ core_profiles ,
332+ pedestal_policy_state = pedestal_policy_state ,
296333 )
297334
298335 # Boolean mask for enforcing internal temperature boundary conditions to
@@ -352,7 +389,11 @@ def _calc_coeffs_full(
352389
353390 # Diffusion term coefficients
354391 turbulent_transport = physics_models .transport_model (
355- runtime_params , geo , core_profiles , pedestal_model_output
392+ runtime_params ,
393+ geo ,
394+ core_profiles ,
395+ pedestal_policy_state ,
396+ pedestal_model_output ,
356397 )
357398 neoclassical_transport = physics_models .neoclassical_models .transport (
358399 runtime_params , geo , core_profiles
0 commit comments