Skip to content

Commit fa4de20

Browse files
goodfeliTorax team
authored andcommitted
Martin scaling
PiperOrigin-RevId: 810477225
1 parent 9ed25a1 commit fa4de20

File tree

9 files changed

+130
-39
lines changed

9 files changed

+130
-39
lines changed

torax/_src/fvm/calc_coeffs.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from torax._src.fvm import cell_variable
3030
from torax._src.geometry import geometry
3131
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
32+
from torax._src.pedestal_policy import pedestal_policy
3233
from torax._src.sources import source_profile_builders
3334
from torax._src.sources import source_profiles as source_profiles_lib
3435
import typing_extensions
@@ -286,14 +287,18 @@ def _calc_coeffs_full(
286287
explicit_source_profiles: source_profiles_lib.SourceProfiles,
287288
physics_models: physics_models_lib.PhysicsModels,
288289
evolving_names: tuple[str, ...],
290+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
289291
use_pereverzev: bool = False,
290292
) -> block_1d_coeffs.Block1DCoeffs:
291293
"""See `calc_coeffs` for details."""
292294

293295
consts = constants.CONSTANTS
294296

295297
pedestal_model_output = physics_models.pedestal_model(
296-
runtime_params, geo, core_profiles
298+
runtime_params,
299+
geo,
300+
core_profiles,
301+
pedestal_policy_state,
297302
)
298303

299304
# Boolean mask for enforcing internal temperature boundary conditions to
@@ -353,7 +358,11 @@ def _calc_coeffs_full(
353358

354359
# Diffusion term coefficients
355360
turbulent_transport = physics_models.transport_model(
356-
runtime_params, geo, core_profiles, pedestal_model_output
361+
runtime_params,
362+
geo,
363+
core_profiles,
364+
pedestal_policy_state,
365+
pedestal_model_output,
357366
)
358367
neoclassical_transport = physics_models.neoclassical_models.transport(
359368
runtime_params, geo, core_profiles

torax/_src/pedestal_model/no_pedestal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ class NoPedestal(pedestal_model.PedestalModel):
2424
2525
This is a placeholder pedestal model that is used when there is no pedestal.
2626
It returns infinite pedestal location and zero temperature and density.
27-
Assuming set_pedestal is set to False properly this will not be used, but
28-
this is a safe fallback in case set_pedestal is not set properly and is needed
27+
Assuming use_pedestal is set to False properly this will not be used, but
28+
this is a safe fallback in case use_pedestal is not set properly and is needed
2929
for the jax cond to work.
3030
"""
3131

torax/_src/pedestal_model/pedestal_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from torax._src import state
2626
from torax._src.config import runtime_params_slice
2727
from torax._src.geometry import geometry
28+
from torax._src.pedestal_policy import pedestal_policy
2829

2930
# pylint: disable=invalid-name
3031
# Using physics notation naming convention
@@ -72,6 +73,7 @@ def __call__(
7273
runtime_params: runtime_params_slice.RuntimeParams,
7374
geo: geometry.Geometry,
7475
core_profiles: state.CoreProfiles,
76+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
7577
) -> PedestalModelOutput:
7678
if not getattr(self, "_frozen", False):
7779
raise RuntimeError(
@@ -80,7 +82,7 @@ def __call__(
8082
)
8183

8284
return jax.lax.cond(
83-
runtime_params.pedestal.set_pedestal,
85+
pedestal_policy_state.use_pedestal,
8486
lambda: self._call_implementation(runtime_params, geo, core_profiles),
8587
# Set the pedestal location to infinite to indicate that the pedestal is
8688
# not present.

torax/_src/pedestal_model/pydantic_model.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,8 @@
3030
class BasePedestal(torax_pydantic.BaseModelFrozen, abc.ABC):
3131
"""Base class for pedestal models.
3232
33-
Attributes:
34-
set_pedestal: Whether to use the pedestal model and set the pedestal. Can be
35-
time varying.
3633
"""
3734

38-
set_pedestal: torax_pydantic.TimeVaryingScalar = (
39-
torax_pydantic.ValidatedDefault(False)
40-
)
41-
4235
@abc.abstractmethod
4336
def build_pedestal_model(self) -> pedestal_model.PedestalModel:
4437
"""Builds the pedestal model."""
@@ -91,7 +84,6 @@ def build_runtime_params(
9184
self, t: chex.Numeric
9285
) -> set_pped_tpedratio_nped.RuntimeParams:
9386
return set_pped_tpedratio_nped.RuntimeParams(
94-
set_pedestal=self.set_pedestal.get_value(t),
9587
P_ped=self.P_ped.get_value(t),
9688
n_e_ped=self.n_e_ped.get_value(t),
9789
n_e_ped_is_fGW=self.n_e_ped_is_fGW,
@@ -138,7 +130,6 @@ def build_runtime_params(
138130
self, t: chex.Numeric
139131
) -> set_tped_nped.RuntimeParams:
140132
return set_tped_nped.RuntimeParams(
141-
set_pedestal=self.set_pedestal.get_value(t),
142133
n_e_ped=self.n_e_ped.get_value(t),
143134
n_e_ped_is_fGW=self.n_e_ped_is_fGW,
144135
T_i_ped=self.T_i_ped.get_value(t),
@@ -150,7 +141,7 @@ def build_runtime_params(
150141
class NoPedestal(BasePedestal):
151142
"""A pedestal model for when there is no pedestal.
152143
153-
Note that setting `set_pedestal` to True with a NoPedestal model is the
144+
Note that setting `use_pedestal` to True with a NoPedestal model is the
154145
equivalent of setting it to False.
155146
"""
156147

@@ -167,7 +158,6 @@ def build_runtime_params(
167158
self, t: chex.Numeric
168159
) -> runtime_params.RuntimeParams:
169160
return runtime_params.RuntimeParams(
170-
set_pedestal=self.set_pedestal.get_value(t),
171161
)
172162

173163

torax/_src/pedestal_model/runtime_params.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,11 @@
1616
import dataclasses
1717

1818
import jax
19-
from torax._src import array_typing
2019

2120

2221
@jax.tree_util.register_dataclass
2322
@dataclasses.dataclass(frozen=True)
2423
class RuntimeParams:
2524
"""Input params for the pedestal model."""
2625

27-
set_pedestal: array_typing.BoolScalar
26+
pass
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2024 DeepMind Technologies Limited
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""The PedestalPolicy abstract base class.
16+
17+
Determines potentially changing pedestal settings during simulation.
18+
"""
19+
import abc
20+
import dataclasses
21+
from typing import Optional
22+
23+
import jax
24+
import jax.numpy as jnp
25+
from torax._src import array_typing
26+
from torax._src import state
27+
28+
# pylint: disable=invalid-name
29+
# Using physics notation naming convention
30+
31+
32+
@jax.tree_util.register_dataclass
33+
@dataclasses.dataclass(frozen=True)
34+
class PedestalPolicyState:
35+
"""State of the PedestalPolicy."""
36+
37+
# Whether to use the pedestal on this time step
38+
use_pedestal: array_typing.BoolScalar
39+
# Factor to scale pedestal height by.
40+
# Not all pedestal policies use this.
41+
scale_pedestal: Optional[array_typing.FloatScalar]
42+
43+
44+
class PedestalPolicy(abc.ABC):
45+
"""Determines potentially changing pedestal settings during simulation.
46+
47+
Subclass responsbilities:
48+
- Must set _frozen = True at the end of the subclass __init__ method to
49+
activate immutability.
50+
"""
51+
52+
def __setattr__(self, attr, value):
53+
# pylint: disable=g-doc-args
54+
# pylint: disable=g-doc-return-or-yield
55+
"""Override __setattr__ to make the class (sort of) immutable.
56+
57+
Note that you can still do obj.field.subfield = x, so it is not true
58+
immutability, but this to helps to avoid some careless errors.
59+
"""
60+
if getattr(self, "_frozen", False):
61+
raise AttributeError("PedestalPolicy is immutable.")
62+
return super().__setattr__(attr, value)
63+
64+
# @abc.abstractmethod
65+
# def _call_implementation(
66+
# self,
67+
# runtime_params: runtime_params_slice.RuntimeParams,
68+
# geo: geometry.Geometry,
69+
# core_profiles: state.CoreProfiles,
70+
# ) -> PedestalModelOutput:
71+
# """Calculate the pedestal values."""
72+
73+
@abc.abstractmethod
74+
def __hash__(self) -> int:
75+
"""Hash function for the pedestal model.
76+
77+
Needed for jax.jit caching to work.
78+
"""
79+
...
80+
81+
@abc.abstractmethod
82+
def __eq__(self, other) -> bool:
83+
"""Equality function for the pedestal model."""
84+
...

torax/_src/transport_model/combined.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from torax._src.config import runtime_params_slice
2626
from torax._src.geometry import geometry
2727
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
28+
from torax._src.pedestal_policy import pedestal_policy
2829
from torax._src.transport_model import runtime_params as runtime_params_lib
2930
from torax._src.transport_model import transport_model as transport_model_lib
3031

@@ -56,6 +57,7 @@ def __call__(
5657
runtime_params: runtime_params_slice.RuntimeParams,
5758
geo: geometry.Geometry,
5859
core_profiles: state.CoreProfiles,
60+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
5961
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
6062
) -> transport_model_lib.TurbulentTransport:
6163
if not getattr(self, "_frozen", False):
@@ -97,6 +99,7 @@ def __call__(
9799
runtime_params,
98100
geo,
99101
transport_coeffs,
102+
pedestal_policy_state,
100103
pedestal_model_output,
101104
)
102105

torax/_src/transport_model/transport_coefficients_builder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from torax._src.geometry import geometry
2323
from torax._src.neoclassical import neoclassical_models as neoclassical_models_lib
2424
from torax._src.pedestal_model import pedestal_model as pedestal_model_lib
25+
from torax._src.pedestal_policy import pedestal_policy
2526
from torax._src.transport_model import transport_model as transport_model_lib
2627

2728

@@ -33,13 +34,20 @@ def calculate_total_transport_coeffs(
3334
runtime_params: runtime_params_slice.RuntimeParams,
3435
geo: geometry.Geometry,
3536
core_profiles: state.CoreProfiles,
37+
pedestal_policy_state: pedestal_policy.PedestalPolicyState,
3638
) -> state.CoreTransport:
3739
"""Calculates the transport coefficients."""
38-
pedestal_model_output = pedestal_model(runtime_params, geo, core_profiles)
40+
pedestal_model_output = pedestal_model(
41+
runtime_params,
42+
geo,
43+
core_profiles,
44+
pedestal_policy_state,
45+
)
3946
turbulent_transport = transport_model(
4047
runtime_params=runtime_params,
4148
geo=geo,
4249
core_profiles=core_profiles,
50+
pedestal_policy_state=pedestal_policy_state,
4351
pedestal_model_output=pedestal_model_output,
4452
)
4553
neoclassical_transport_coeffs = neoclassical_models.transport(

0 commit comments

Comments
 (0)