Skip to content

Commit d536255

Browse files
Merge pull request #5177 from medha-14/parameter_serialisation
[GSoC 2025] Serialising Parameter Sets
2 parents 4a19a92 + bd0b460 commit d536255

File tree

9 files changed

+697
-29
lines changed

9 files changed

+697
-29
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@
9292
"https://doi.org/10.1137/20M1336898", # DOI link to ignore
9393
"https://en.wikipedia.org/wiki/", # Wikipedia link to ignore
9494
"https://books.google.co.uk/books",
95+
"https://docs.scipy.org/doc/scipy", # SciPy docs timeout intermittently
9596
]
9697

9798

src/pybamm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from .expression_tree.operations.jacobian import Jacobian
5959
from .expression_tree.operations.convert_to_casadi import CasadiConverter
6060
from .expression_tree.operations.unpack_symbols import SymbolUnpacker
61-
from .expression_tree.operations.serialise import Serialise
61+
from .expression_tree.operations.serialise import Serialise,ExpressionFunctionParameter
6262

6363
# Model classes
6464
from .models.base_model import BaseModel

src/pybamm/expression_tree/operations/serialise.py

Lines changed: 142 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import importlib
4+
import inspect
45
import json
56
import numbers
67
import re
@@ -16,6 +17,31 @@
1617
SUPPORTED_SCHEMA_VERSION = "1.0"
1718

1819

20+
class ExpressionFunctionParameter(pybamm.UnaryOperator):
21+
def __init__(self, name, child, func_name, func_args):
22+
super().__init__(name, child)
23+
self.func_name = func_name
24+
self.func_args = func_args
25+
26+
def _unary_evaluate(self, child):
27+
"""Evaluate the symbolic expression (the child)"""
28+
return child
29+
30+
def to_source(self):
31+
"""
32+
Creates python source code for the function.
33+
"""
34+
src = f"def {self.func_name}({', '.join(self.func_args)}):\n"
35+
36+
expression = self.child.create_copy()
37+
for child in expression.pre_order():
38+
if isinstance(child, pybamm.Parameter) and child.name not in self.func_args:
39+
child.name = f'Parameter("{child.name}")'
40+
41+
src += f" return {expression.to_equation()}\n"
42+
return src
43+
44+
1945
class Serialise:
2046
"""
2147
Converts a discretised model to and from a JSON file.
@@ -1285,6 +1311,72 @@ def load_custom_model(filename: str | dict) -> pybamm.BaseModel:
12851311

12861312
return model
12871313

1314+
@staticmethod
1315+
def save_parameters(parameters: dict, filename=None):
1316+
"""
1317+
Serializes a dictionary of parameters to a JSON file.
1318+
The values can be numbers, PyBaMM symbols, or callables.
1319+
1320+
Parameters
1321+
----------
1322+
parameters : dict
1323+
A dictionary of parameter names and values.
1324+
Values can be numeric, PyBaMM symbols, or callables.
1325+
1326+
filename : str, optional
1327+
If given, saves the serialized parameters to this file.
1328+
"""
1329+
parameter_values_dict = {}
1330+
1331+
for k, v in parameters.items():
1332+
if callable(v):
1333+
parameter_values_dict[k] = Serialise.convert_symbol_to_json(
1334+
Serialise.convert_function_to_symbolic_expression(v, k)
1335+
)
1336+
else:
1337+
parameter_values_dict[k] = Serialise.convert_symbol_to_json(v)
1338+
1339+
if filename is not None:
1340+
with open(filename, "w") as f:
1341+
json.dump(parameter_values_dict, f, indent=4)
1342+
1343+
@staticmethod
1344+
def load_parameters(filename):
1345+
"""
1346+
Load a JSON file of parameters (either from Serialise.save_parameters
1347+
or from a standard pybamm.ParameterValues.save), and return a
1348+
pybamm.ParameterValues object.
1349+
1350+
- If a value is a dict with a "type" key, deserialize it as a PyBaMM symbol.
1351+
- Otherwise (float, int, bool, str, list, dict-without-type), leave it as-is.
1352+
"""
1353+
with open(filename) as f:
1354+
raw_dict = json.load(f)
1355+
1356+
deserialized = {}
1357+
for key, val in raw_dict.items():
1358+
if isinstance(val, dict) and "type" in val:
1359+
deserialized[key] = Serialise.convert_symbol_from_json(val)
1360+
1361+
elif isinstance(val, list):
1362+
deserialized[key] = val
1363+
1364+
elif isinstance(val, (numbers.Number | bool)):
1365+
deserialized[key] = val
1366+
1367+
elif isinstance(val, str):
1368+
deserialized[key] = val
1369+
1370+
elif isinstance(val, dict):
1371+
deserialized[key] = val
1372+
1373+
else:
1374+
raise ValueError(
1375+
f"Unsupported parameter format for key '{key}': {val!r}"
1376+
)
1377+
1378+
return pybamm.ParameterValues(deserialized)
1379+
12881380
# Helper functions
12891381

12901382
def _get_pybamm_class(self, snippet: dict):
@@ -1448,6 +1540,40 @@ def _convert_options(self, d):
14481540
else:
14491541
return d
14501542

1543+
def convert_function_to_symbolic_expression(func, name=None):
1544+
"""
1545+
Converts a Python function to a PyBaMM ExpressionFunctionParameter object.
1546+
1547+
1548+
Parameters
1549+
----------
1550+
func : callable
1551+
The Python function to convert. Its body should operate on symbolic inputs
1552+
(e.g., x+1, x*y) so it can be represented as a PyBaMM expression.
1553+
1554+
1555+
name : str, optional
1556+
The name of the function to use in the symbolic expression. If not provided,
1557+
the function's __name__ is used.
1558+
1559+
1560+
Returns
1561+
-------
1562+
ExpressionFunctionParameter
1563+
A symbolic wrapper for the function that preserves its name, arguments,
1564+
and expression body.
1565+
"""
1566+
func_name = name or func.__name__
1567+
1568+
sig = inspect.signature(func)
1569+
arg_names = list(sig.parameters.keys())
1570+
1571+
sym_inputs = [pybamm.Parameter(arg) for arg in arg_names]
1572+
1573+
sym_output = func(*sym_inputs)
1574+
1575+
return ExpressionFunctionParameter(func_name, sym_output, func_name, arg_names)
1576+
14511577
@staticmethod
14521578
def convert_symbol_to_json(
14531579
symbol: pybamm.Symbol | numbers.Number | list,
@@ -1481,8 +1607,15 @@ def convert_symbol_to_json(
14811607
>>> Serialise.convert_symbol_to_json(v)
14821608
{'type': 'Variable', 'name': 'c', 'domains': {'primary': [], 'secondary': [], 'tertiary': [], 'quaternary': []}, 'bounds': [{'type': 'Scalar', 'value': np.float64(-inf)}, {'type': 'Scalar', 'value': np.float64(inf)}]}
14831609
"""
1484-
1485-
if isinstance(symbol, numbers.Number | list):
1610+
if isinstance(symbol, ExpressionFunctionParameter):
1611+
return {
1612+
"type": "ExpressionFunctionParameter",
1613+
"name": symbol.name,
1614+
"children": [Serialise.convert_symbol_to_json(symbol.child)],
1615+
"func_name": symbol.func_name,
1616+
"func_args": symbol.func_args,
1617+
}
1618+
elif isinstance(symbol, numbers.Number | list):
14861619
return symbol
14871620
elif isinstance(symbol, pybamm.Time):
14881621
return {"type": "Time"}
@@ -1692,6 +1825,13 @@ def convert_symbol_from_json(
16921825
},
16931826
diff_variable=diff_variable,
16941827
)
1828+
elif symbol_type == "ExpressionFunctionParameter":
1829+
return ExpressionFunctionParameter(
1830+
json_data["name"],
1831+
Serialise.convert_symbol_from_json(json_data["children"][0]),
1832+
json_data["func_name"],
1833+
json_data["func_args"],
1834+
)
16951835
elif symbol_type == "PrimaryBroadcast":
16961836
child = Serialise.convert_symbol_from_json(json_data["children"][0])
16971837
domain = json_data["broadcast_domain"]

src/pybamm/models/full_battery_models/lithium_ion/basic_spm_with_3d_thermal.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ def __init__(self, options=None, name="SPM with Separate Cell Domain"):
4848
T = pybamm.Variable("Cell temperature [K]", domain="cell")
4949

5050
if self.options.get("cell geometry") == "pouch":
51-
x = pybamm.SpatialVariable("x", domain="cell")
52-
y = pybamm.SpatialVariable("y", domain="cell")
53-
z = pybamm.SpatialVariable("z", domain="cell")
54-
integration_vars = [x, y, z]
51+
self.x_cell = pybamm.SpatialVariable("x", domain="cell")
52+
self.y_cell = pybamm.SpatialVariable("y", domain="cell")
53+
self.z_cell = pybamm.SpatialVariable("z", domain="cell")
54+
integration_vars = [self.x_cell, self.y_cell, self.z_cell]
5555
elif self.options.get("cell geometry") == "cylindrical":
56-
r = pybamm.SpatialVariable("r_macro", domain="cell")
57-
z = pybamm.SpatialVariable("z", domain="cell")
58-
integration_vars = [r, z]
56+
self.r_cell = pybamm.SpatialVariable("r_macro", domain="cell")
57+
self.z_cell = pybamm.SpatialVariable("z", domain="cell")
58+
integration_vars = [self.r_cell, self.z_cell]
5959
else:
6060
raise ValueError(
6161
f"Geometry type '{self.options.get('cell geometry')}' is not supported. "
@@ -261,9 +261,8 @@ def __init__(self, options=None, name="SPM with Separate Cell Domain"):
261261
def set_thermal_bcs(self, T):
262262
geometry_type = self.options.get("cell geometry", "pouch")
263263
if geometry_type == "pouch":
264-
y = pybamm.SpatialVariable("y", "cell")
265-
z = pybamm.SpatialVariable("z", "cell")
266-
T_amb = self.param.T_amb(y, z, pybamm.t)
264+
# Reuse the spatial variables created in __init__
265+
T_amb = self.param.T_amb(self.y_cell, self.z_cell, pybamm.t)
267266
face_params = {
268267
"x_min": self.param.h_edge_x_min,
269268
"x_max": self.param.h_edge_x_max,
@@ -273,9 +272,8 @@ def set_thermal_bcs(self, T):
273272
"z_max": self.param.h_edge_z_max,
274273
}
275274
elif geometry_type == "cylindrical":
276-
r = pybamm.SpatialVariable("r_macro", "cell")
277-
z = pybamm.SpatialVariable("z", "cell")
278-
T_amb = self.param.T_amb(r, z, pybamm.t)
275+
# Reuse the spatial variables created in __init__
276+
T_amb = self.param.T_amb(self.r_cell, self.z_cell, pybamm.t)
279277
face_params = {
280278
"r_min": self.param.h_edge_radial_min,
281279
"r_max": self.param.h_edge_radial_max,

src/pybamm/parameters/parameter_values.py

Lines changed: 118 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
import json
12
import numbers
23
from collections import defaultdict
4+
from pathlib import Path
35
from pprint import pformat
46
from warnings import warn
57

68
import numpy as np
79

810
import pybamm
11+
from pybamm.expression_tree.operations.serialise import Serialise
912
from pybamm.models.full_battery_models.lithium_ion.msmr import (
1013
is_deprecated_msmr_name,
1114
replace_deprecated_msmr_name,
@@ -668,7 +671,10 @@ def process_symbol(self, symbol):
668671
try:
669672
return self._processed_symbols[symbol]
670673
except KeyError:
671-
processed_symbol = self._process_symbol(symbol)
674+
if not isinstance(symbol, pybamm.FunctionParameter):
675+
processed_symbol = self._process_symbol(symbol)
676+
else:
677+
processed_symbol = self._process_function_parameter(symbol)
672678
self._processed_symbols[symbol] = processed_symbol
673679

674680
return processed_symbol
@@ -835,6 +841,88 @@ def _process_symbol(self, symbol):
835841
# Backup option: return the object
836842
return symbol
837843

844+
def _process_function_parameter(self, symbol):
845+
function_parameter = self[symbol.name]
846+
# Handle symbolic function parameter case
847+
if isinstance(function_parameter, pybamm.ExpressionFunctionParameter):
848+
# Process children
849+
new_children = []
850+
for child in symbol.children:
851+
if symbol.diff_variable is not None and any(
852+
x == symbol.diff_variable for x in child.pre_order()
853+
):
854+
# Wrap with NotConstant to avoid simplification,
855+
# which would stop symbolic diff from working properly
856+
new_child = pybamm.NotConstant(child)
857+
new_children.append(self.process_symbol(new_child))
858+
else:
859+
new_children.append(self.process_symbol(child))
860+
861+
# Get the expression and inputs for the function
862+
expression = function_parameter.child
863+
inputs = {
864+
arg: child
865+
for arg, child in zip(
866+
function_parameter.func_args, symbol.children, strict=False
867+
)
868+
}
869+
870+
# Set domains for function inputs in post-order traversal
871+
for node in expression.post_order():
872+
if node.name in inputs:
873+
node.domains = inputs[node.name].domains
874+
else:
875+
node.domains = node.get_children_domains(node.children)
876+
877+
# Combine parameter values with inputs
878+
combined_params = ParameterValues({**self, **inputs})
879+
880+
# Process any FunctionParameter children first to avoid recursion
881+
for child in expression.pre_order():
882+
if isinstance(child, pybamm.FunctionParameter):
883+
# Build new child with parent inputs
884+
new_child_children = [
885+
inputs[child_child.name]
886+
if isinstance(child_child, pybamm.Parameter)
887+
and child_child.name in inputs
888+
else child_child
889+
for child_child in child.children
890+
]
891+
new_child = pybamm.FunctionParameter(
892+
child.name,
893+
dict(zip(child.input_names, new_child_children, strict=False)),
894+
diff_variable=child.diff_variable,
895+
print_name=child.print_name,
896+
)
897+
898+
# For this local combined parameter values, process the new child
899+
# and store the result as the processed symbol for this child
900+
# This means the child is evaluated with the parent inputs only when
901+
# it is called from within the parent function (not elsewhere in
902+
# the expression tree)
903+
combined_params._processed_symbols[child] = (
904+
combined_params.process_symbol(new_child)
905+
)
906+
907+
# Process function with combined parameter values to get a symbolic
908+
# expression
909+
function = combined_params.process_symbol(expression)
910+
911+
# Differentiate if necessary
912+
if symbol.diff_variable is None:
913+
# Use ones_like so that we get the right shapes
914+
function_out = function * pybamm.ones_like(*new_children)
915+
else:
916+
# return differentiated function
917+
new_diff_variable = self.process_symbol(symbol.diff_variable)
918+
function_out = function.diff(new_diff_variable)
919+
920+
return function_out
921+
922+
# Handle non-symbolic function_name case
923+
else:
924+
return self._process_symbol(symbol)
925+
838926
def evaluate(self, symbol, inputs=None):
839927
"""
840928
Process and evaluate a symbol.
@@ -992,3 +1080,32 @@ def __contains__(self, key):
9921080

9931081
def __iter__(self):
9941082
return iter(self._dict_items)
1083+
1084+
@staticmethod
1085+
def from_json(filename_or_dict):
1086+
"""
1087+
Loads a ParameterValues object from a JSON file or a dictionary.
1088+
1089+
Parameters
1090+
----------
1091+
filename_or_dict : string-like or dict
1092+
The filename to load the JSON file from, or a dictionary.
1093+
1094+
Returns
1095+
-------
1096+
ParameterValues
1097+
The ParameterValues object
1098+
"""
1099+
if isinstance(filename_or_dict, str | Path):
1100+
with open(filename_or_dict) as f:
1101+
parameter_values_dict = json.load(f)
1102+
elif isinstance(filename_or_dict, dict):
1103+
parameter_values_dict = filename_or_dict.copy()
1104+
else:
1105+
raise TypeError("Input must be a filename (str or pathlib.Path) or a dict")
1106+
1107+
for key, value in parameter_values_dict.items():
1108+
if isinstance(value, dict):
1109+
parameter_values_dict[key] = Serialise.convert_symbol_from_json(value)
1110+
1111+
return ParameterValues(parameter_values_dict)

0 commit comments

Comments
 (0)