Skip to content

Commit 5ba4eef

Browse files
goodfeliTorax team
authored andcommitted
Make PedestalModel a StaticDataclass
PiperOrigin-RevId: 822642397
1 parent 9e7435b commit 5ba4eef

File tree

4 files changed

+8
-70
lines changed

4 files changed

+8
-70
lines changed

torax/_src/pedestal_model/no_pedestal.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""A pedestal model for when there is no pedestal."""
15+
import dataclasses
1516
from jax import numpy as jnp
1617
from torax._src import state
1718
from torax._src.config import runtime_params_slice
1819
from torax._src.geometry import geometry
1920
from torax._src.pedestal_model import pedestal_model
2021

2122

23+
@dataclasses.dataclass(frozen=True, eq=False)
2224
class NoPedestal(pedestal_model.PedestalModel):
2325
"""A pedestal model for when there is no pedestal.
2426
@@ -29,10 +31,6 @@ class NoPedestal(pedestal_model.PedestalModel):
2931
for the jax cond to work.
3032
"""
3133

32-
def __init__(self):
33-
super().__init__()
34-
self._frozen = True
35-
3634
def _call_implementation(
3735
self,
3836
runtime_params: runtime_params_slice.RuntimeParams,
@@ -46,9 +44,3 @@ def _call_implementation(
4644
n_e_ped=0.0,
4745
rho_norm_ped_top_idx=geo.torax_mesh.nx,
4846
)
49-
50-
def __hash__(self):
51-
return hash('NoPedestal')
52-
53-
def __eq__(self, other) -> bool:
54-
return isinstance(other, NoPedestal)

torax/_src/pedestal_model/pedestal_model.py

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import jax.numpy as jnp
2424
from torax._src import array_typing
2525
from torax._src import state
26+
from torax._src import static_dataclass
2627
from torax._src.config import runtime_params_slice
2728
from torax._src.geometry import geometry
2829

@@ -47,38 +48,16 @@ class PedestalModelOutput:
4748
n_e_ped: array_typing.FloatScalar
4849

4950

50-
class PedestalModel(abc.ABC):
51-
"""Calculates temperature and density of the pedestal.
52-
53-
Subclass responsbilities:
54-
- Must set _frozen = True at the end of the subclass __init__ method to
55-
activate immutability.
56-
"""
57-
58-
def __setattr__(self, attr, value):
59-
# pylint: disable=g-doc-args
60-
# pylint: disable=g-doc-return-or-yield
61-
"""Override __setattr__ to make the class (sort of) immutable.
62-
63-
Note that you can still do obj.field.subfield = x, so it is not true
64-
immutability, but this to helps to avoid some careless errors.
65-
"""
66-
if getattr(self, "_frozen", False):
67-
raise AttributeError("PedestalModels are immutable.")
68-
return super().__setattr__(attr, value)
51+
@dataclasses.dataclass(frozen=True, eq=False)
52+
class PedestalModel(static_dataclass.StaticDataclass, abc.ABC):
53+
"""Calculates temperature and density of the pedestal."""
6954

7055
def __call__(
7156
self,
7257
runtime_params: runtime_params_slice.RuntimeParams,
7358
geo: geometry.Geometry,
7459
core_profiles: state.CoreProfiles,
7560
) -> PedestalModelOutput:
76-
if not getattr(self, "_frozen", False):
77-
raise RuntimeError(
78-
f"Subclass implementation {type(self)} forgot to "
79-
"freeze at the end of __init__."
80-
)
81-
8261
return jax.lax.cond(
8362
runtime_params.pedestal.set_pedestal,
8463
lambda: self._call_implementation(runtime_params, geo, core_profiles),
@@ -103,16 +82,3 @@ def _call_implementation(
10382
core_profiles: state.CoreProfiles,
10483
) -> PedestalModelOutput:
10584
"""Calculate the pedestal values."""
106-
107-
@abc.abstractmethod
108-
def __hash__(self) -> int:
109-
"""Hash function for the pedestal model.
110-
111-
Needed for jax.jit caching to work.
112-
"""
113-
...
114-
115-
@abc.abstractmethod
116-
def __eq__(self, other) -> bool:
117-
"""Equality function for the pedestal model."""
118-
...

torax/_src/pedestal_model/set_pped_tpedratio_nped.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,12 @@ class RuntimeParams(runtime_params_lib.RuntimeParams):
4040
n_e_ped_is_fGW: array_typing.BoolScalar
4141

4242

43+
@dataclasses.dataclass(frozen=True, eq=False)
4344
class SetPressureTemperatureRatioAndDensityPedestalModel(
4445
pedestal_model.PedestalModel
4546
):
4647
"""Pedestal model with specification of pressure, temp ratio, and density."""
4748

48-
def __init__(self):
49-
super().__init__()
50-
self._frozen = True
51-
5249
@override
5350
def _call_implementation(
5451
self,
@@ -113,9 +110,3 @@ def _call_implementation(
113110
rho_norm_ped_top=runtime_params.pedestal.rho_norm_ped_top,
114111
rho_norm_ped_top_idx=ped_idx,
115112
)
116-
117-
def __hash__(self) -> int:
118-
return hash('SetPressureTemperatureRatioAndDensityPedestalModel')
119-
120-
def __eq__(self, other) -> bool:
121-
return isinstance(other, SetPressureTemperatureRatioAndDensityPedestalModel)

torax/_src/pedestal_model/set_tped_nped.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,10 @@ class RuntimeParams(runtime_params_lib.RuntimeParams):
3838
n_e_ped_is_fGW: array_typing.BoolScalar
3939

4040

41+
@dataclasses.dataclass(frozen=True, eq=False)
4142
class SetTemperatureDensityPedestalModel(pedestal_model.PedestalModel):
4243
"""A basic version of the pedestal model that uses direct specification."""
4344

44-
def __init__(
45-
self,
46-
):
47-
super().__init__()
48-
self._frozen = True
49-
5045
@override
5146
def _call_implementation(
5247
self,
@@ -77,9 +72,3 @@ def _call_implementation(
7772
geo.rho_norm - pedestal_params.rho_norm_ped_top
7873
).argmin(),
7974
)
80-
81-
def __hash__(self) -> int:
82-
return hash('SetTemperatureDensityPedestalModel')
83-
84-
def __eq__(self, other) -> bool:
85-
return isinstance(other, SetTemperatureDensityPedestalModel)

0 commit comments

Comments
 (0)