Skip to content

Commit 747e2ad

Browse files
authored
Pointwise resolution (#258)
* initial commit * fixed unit tests * updated smearing * updated/fixed pointwise impl * fix method name * datasets now have proper names, instead of `R_0` and `Qz_0`. * added color changer, renamed default model, minor fixes * fix ruff, fix test name * added experiment index * added tests * fix issue with multiple experiments and a single model * added a helper method
1 parent f7032e1 commit 747e2ad

File tree

19 files changed

+331
-54
lines changed

19 files changed

+331
-54
lines changed

docs/src/tutorials/advancedfitting/multi_contrast.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,8 +358,8 @@
358358
"d83acmw.head_layer.area_per_molecule_parameter.enabled = True\n",
359359
"d83acmw.tail_layer.area_per_molecule_parameter.enabled = True\n",
360360
"\n",
361-
"d70d2o.constain_multiple_contrast(d13d2o)\n",
362-
"d83acmw.constain_multiple_contrast(d70d2o)"
361+
"d70d2o.constrain_multiple_contrast(d13d2o)\n",
362+
"d83acmw.constrain_multiple_contrast(d70d2o)"
363363
]
364364
},
365365
{

docs/src/tutorials/simulation/resolution_functions.ipynb

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
"from easyreflectometry.model import Model\n",
4747
"from easyreflectometry.model import LinearSpline\n",
4848
"from easyreflectometry.model import PercentageFwhm\n",
49+
"from easyreflectometry.model import Pointwise\n",
4950
"from easyreflectometry.sample import Layer\n",
5051
"from easyreflectometry.sample import Material\n",
5152
"from easyreflectometry.sample import Multilayer\n",
@@ -115,6 +116,16 @@
115116
"dict_reference['10'] = load(file_path_10)"
116117
]
117118
},
119+
{
120+
"cell_type": "code",
121+
"execution_count": null,
122+
"id": "e5f65ed7",
123+
"metadata": {},
124+
"outputs": [],
125+
"source": [
126+
"dict_reference['0']"
127+
]
128+
},
118129
{
119130
"cell_type": "markdown",
120131
"id": "1ab3a164-62c8-4bd3-b0d8-e6f22c83dc74",
@@ -251,9 +262,15 @@
251262
"id": "defd6dd5-c618-4af6-a5c7-17532207f0a0",
252263
"metadata": {},
253264
"source": [
254-
"## Resolution functions\n",
255-
"\n",
256-
"We now define the different resoultion functions. "
265+
"## Resolution functions "
266+
]
267+
},
268+
{
269+
"cell_type": "markdown",
270+
"id": "c9d903db",
271+
"metadata": {},
272+
"source": [
273+
"We can now define the different resoultion functions. "
257274
]
258275
},
259276
{
@@ -376,11 +393,53 @@
376393
"plt.yscale('log')\n",
377394
"plt.show()"
378395
]
396+
},
397+
{
398+
"cell_type": "code",
399+
"execution_count": null,
400+
"id": "43881642",
401+
"metadata": {},
402+
"outputs": [],
403+
"source": [
404+
"key = '1'\n",
405+
"reference_coords = dict_reference[key]['coords']['Qz_0'].values\n",
406+
"reference_variances = dict_reference[key]['coords']['Qz_0'].variances\n",
407+
"reference_data = dict_reference[key]['data']['R_0'].values\n",
408+
"model_coords = np.linspace(\n",
409+
" start=min(reference_coords),\n",
410+
" stop=max(reference_coords),\n",
411+
" num=1000,\n",
412+
")\n",
413+
"\n",
414+
"model.resolution_function = resolution_function_dict[key]\n",
415+
"model_data = model.interface().reflectity_profile(\n",
416+
" model_coords,\n",
417+
" model.unique_name,\n",
418+
")\n",
419+
"plt.plot(model_coords, model_data, 'k-', label=f'Variable', linewidth=5)\n",
420+
"data_points = []\n",
421+
"data_points.append(reference_coords) # Qz\n",
422+
"data_points.append(reference_data) # R\n",
423+
"data_points.append(reference_variances) # sQz\n",
424+
"model.resolution_function = Pointwise(q_data_points=data_points)\n",
425+
"model_data = model.interface().reflectity_profile(\n",
426+
" model_coords,\n",
427+
" model.unique_name,\n",
428+
")\n",
429+
"plt.plot(model_coords, model_data, 'r-', label=f'Pointwise')\n",
430+
"\n",
431+
"ax = plt.gca()\n",
432+
"ax.set_xlim([-0.01, 0.45])\n",
433+
"ax.set_ylim([1e-10, 2.5])\n",
434+
"plt.legend()\n",
435+
"plt.yscale('log')\n",
436+
"plt.show()"
437+
]
379438
}
380439
],
381440
"metadata": {
382441
"kernelspec": {
383-
"display_name": "easyref",
442+
"display_name": "erl",
384443
"language": "python",
385444
"name": "python3"
386445
},
@@ -394,7 +453,7 @@
394453
"name": "python",
395454
"nbconvert_exporter": "python",
396455
"pygments_lexer": "ipython3",
397-
"version": "3.12.9"
456+
"version": "3.12.10"
398457
}
399458
},
400459
"nbformat": 4,

src/easyreflectometry/data/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from .data_store import ProjectData
33
from .measurement import load
44
from .measurement import load_as_dataset
5+
from .measurement import merge_datagroups
56

67
__all__ = [
78
"load",
89
"load_as_dataset",
10+
"merge_datagroups",
911
"ProjectData",
1012
"DataSet1D",
1113
]

src/easyreflectometry/data/measurement.py

Lines changed: 46 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
__author__ = 'github.com/arm61'
22

3+
import os
34
from typing import TextIO
45
from typing import Union
56

@@ -25,11 +26,16 @@ def load(fname: Union[TextIO, str]) -> sc.DataGroup:
2526
def load_as_dataset(fname: Union[TextIO, str]) -> DataSet1D:
2627
"""Load data from an ORSO .ort file as a DataSet1D."""
2728
data_group = load(fname)
29+
basename = os.path.splitext(os.path.basename(fname))[0]
30+
data_name = 'R_' + basename
31+
coords_name = 'Qz_' + basename
32+
coords_name = list(data_group['coords'].keys())[0] if coords_name not in data_group['coords'] else coords_name
33+
data_name = list(data_group['data'].keys())[0] if data_name not in data_group['data'] else data_name
2834
return DataSet1D(
29-
x=data_group['coords']['Qz_0'].values,
30-
y=data_group['data']['R_0'].values,
31-
ye=data_group['data']['R_0'].variances,
32-
xe=data_group['coords']['Qz_0'].variances,
35+
x=data_group['coords'][coords_name].values,
36+
y=data_group['data'][data_name].values,
37+
ye=data_group['data'][data_name].variances,
38+
xe=data_group['coords'][coords_name].variances,
3339
)
3440

3541

@@ -86,6 +92,8 @@ def _load_txt(fname: Union[TextIO, str]) -> sc.DataGroup:
8692
if ',' in first_line:
8793
delimiter = ','
8894

95+
basename = os.path.splitext(os.path.basename(fname))[0]
96+
8997
try:
9098
# First load only the data to check column count
9199
data = np.loadtxt(fname, delimiter=delimiter, comments='#')
@@ -110,13 +118,44 @@ def _load_txt(fname: Union[TextIO, str]) -> sc.DataGroup:
110118
# Re-raise with more descriptive message
111119
raise ValueError(f"Failed to load data from {fname}: {str(error)}") from error
112120

113-
data = {'R_0': sc.array(dims=['Qz_0'], values=y, variances=np.square(e))}
121+
data_name = 'R_' + basename
122+
coords_name = 'Qz_' + basename
123+
data = {data_name: sc.array(dims=[coords_name], values=y, variances=np.square(e))}
114124
coords = {
115-
data['R_0'].dims[0]: sc.array(
116-
dims=['Qz_0'],
125+
data[data_name].dims[0]: sc.array(
126+
dims=[coords_name],
117127
values=x,
118128
variances=np.square(xe),
119129
unit=sc.Unit('1/angstrom'),
120130
)
121131
}
122132
return sc.DataGroup(data=data, coords=coords)
133+
134+
def merge_datagroups(*data_groups: sc.DataGroup) -> sc.DataGroup:
135+
"""Merge multiple DataGroups into a single DataGroup."""
136+
merged_data = {}
137+
merged_coords = {}
138+
merged_attrs = {}
139+
140+
for group in data_groups:
141+
for key, value in group['data'].items():
142+
if key not in merged_data:
143+
merged_data[key] = value
144+
else:
145+
merged_data[key] = sc.concatenate([merged_data[key], value])
146+
147+
for key, value in group['coords'].items():
148+
if key not in merged_coords:
149+
merged_coords[key] = value
150+
else:
151+
merged_coords[key] = sc.concatenate([merged_coords[key], value])
152+
153+
if 'attrs' not in group:
154+
continue
155+
for key, value in group['attrs'].items():
156+
if key not in merged_attrs:
157+
merged_attrs[key] = value
158+
else:
159+
merged_attrs[key] = {**merged_attrs[key], **value}
160+
161+
return sc.DataGroup(data=merged_data, coords=merged_coords, attrs=merged_attrs)

src/easyreflectometry/fitting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def fit(self, data: sc.DataGroup, id: int = 0) -> sc.DataGroup:
5050
)
5151
sld_profile = self.easy_science_multi_fitter._fit_objects[i].interface.sld_profile(self._models[i].unique_name)
5252
new_data[f'SLD_{id}'] = sc.array(dims=[f'z_{id}'], values=sld_profile[1] * 1e-6, unit=sc.Unit('1/angstrom') ** 2)
53-
new_data['attrs'][f'R_{id}_model'] = {'model': sc.scalar(self._models[i].as_dict())}
53+
if 'attrs' in new_data:
54+
new_data['attrs'][f'R_{id}_model'] = {'model': sc.scalar(self._models[i].as_dict())}
5455
new_data['coords'][f'z_{id}'] = sc.array(
5556
dims=[f'z_{id}'], values=sld_profile[0], unit=(1 / new_data['coords'][f'Qz_{id}'].unit).unit
5657
)

src/easyreflectometry/model/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from .model_collection import ModelCollection
33
from .resolution_functions import LinearSpline
44
from .resolution_functions import PercentageFwhm
5+
from .resolution_functions import Pointwise
56
from .resolution_functions import ResolutionFunction
67

78
__all__ = (
89
"LinearSpline",
910
"PercentageFwhm",
11+
"Pointwise",
1012
"ResolutionFunction",
1113
"Model",
1214
"ModelCollection",

src/easyreflectometry/model/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
},
4343
}
4444

45+
COLORS =["#0173B2", "#DE8F05", "#029E73", "#D55E00", "#CC78BC", "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9"]
4546

4647
class Model(BaseObj):
4748
"""Model is the class that represents the experiment.
@@ -60,8 +61,8 @@ def __init__(
6061
scale: Union[Parameter, Number, None] = None,
6162
background: Union[Parameter, Number, None] = None,
6263
resolution_function: Union[ResolutionFunction, None] = None,
63-
name: str = 'EasyModel',
64-
color: str = 'black',
64+
name: str = 'Model',
65+
color: str = COLORS[0],
6566
unique_name: Optional[str] = None,
6667
interface=None,
6768
):
@@ -70,7 +71,7 @@ def __init__(
7071
:param sample: The sample being modelled.
7172
:param scale: Scaling factor of profile.
7273
:param background: Linear background magnitude.
73-
:param name: Name of the model, defaults to 'EasyModel'.
74+
:param name: Name of the model, defaults to 'Model'.
7475
:param resolution_function: Resolution function, defaults to PercentageFwhm.
7576
:param interface: Calculator interface, defaults to `None`.
7677

src/easyreflectometry/model/model_collection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55
from typing import Tuple
66

7+
from easyreflectometry.model.model import COLORS
78
from easyreflectometry.sample.collections.base_collection import BaseCollection
89

910
from .model import Model
@@ -18,7 +19,7 @@ class ModelCollection(BaseCollection):
1819
def __init__(
1920
self,
2021
*models: Tuple[Model],
21-
name: str = 'EasyModels',
22+
name: str = 'Models',
2223
interface=None,
2324
unique_name: Optional[str] = None,
2425
populate_if_none: bool = True,
@@ -41,7 +42,8 @@ def add_model(self, model: Optional[Model] = None):
4142
:param model: Model to add.
4243
"""
4344
if model is None:
44-
model = Model(name='EasyModel added', interface=self.interface)
45+
color = COLORS[len(self) % len(COLORS)]
46+
model = Model(name='Model', interface=self.interface, color=color)
4547
self.append(model)
4648

4749
def duplicate_model(self, index: int):

src/easyreflectometry/model/resolution_functions.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def from_dict(cls, data: dict) -> ResolutionFunction:
3030
return PercentageFwhm(data['constant'])
3131
if data['smearing'] == 'LinearSpline':
3232
return LinearSpline(data['q_data_points'], data['fwhm_values'])
33+
if data['smearing'] == 'Pointwise':
34+
return Pointwise([data['q_data_points'], data['R_data_points'], data['sQz_data_points']])
3335
raise ValueError('Unknown resolution function type')
3436

3537

@@ -60,3 +62,60 @@ def as_dict(
6062
self, skip: Optional[List[str]] = None
6163
) -> dict[str, str]: # skip is kept for consistency of the as_dict signature
6264
return {'smearing': 'LinearSpline', 'q_data_points': list(self.q_data_points), 'fwhm_values': list(self.fwhm_values)}
65+
66+
# add pointwise smearing funtion
67+
class Pointwise(ResolutionFunction):
68+
def __init__(self, q_data_points: list[np.ndarray]):
69+
self.q_data_points = q_data_points
70+
self.q = None
71+
72+
def smearing(self, q: Union[np.ndarray, float] = None) -> np.ndarray:
73+
74+
Qz = self.q_data_points[0]
75+
R = self.q_data_points[1]
76+
sQz = self.q_data_points[2]
77+
if q is None:
78+
q = self.q_data_points[0]
79+
self.q = q
80+
sQzs = np.sqrt(sQz)
81+
if isinstance(Qz, float):
82+
Qz = np.array(Qz)
83+
84+
smeared = self.apply_smooth_smearing(Qz, R, sQzs)
85+
return smeared
86+
87+
def as_dict(
88+
self, skip: Optional[List[str]] = None
89+
) -> dict[str, str]: # skip is kept for consistency of the as_dict signature
90+
return {'smearing': 'Pointwise',
91+
'q_data_points': list(self.q_data_points[0]),
92+
'R_data_points': list(self.q_data_points[1]),
93+
'sQz_data_points': list(self.q_data_points[2])}
94+
95+
def gaussian_smearing(self, qt, Qz, R, sQz):
96+
weights = np.exp(-0.5 * ((qt - Qz) / sQz) ** 2)
97+
if np.sum(weights) == 0 or not np.isfinite(np.sum(weights)):
98+
return np.sum(R)
99+
weights /= (sQz * np.sqrt(2 * np.pi))
100+
return np.sum(R * weights) / np.sum(weights)
101+
102+
103+
def apply_smooth_smearing(self, Qz, R, sQzs):
104+
"""
105+
Apply smooth resolution smearing using convolution with Gaussian kernel.
106+
"""
107+
if self.q is None:
108+
R_smeared = np.zeros_like(Qz)
109+
else:
110+
R_smeared = np.zeros_like(self.q)
111+
112+
if not isinstance(Qz, np.ndarray):
113+
Qz = np.array(Qz)
114+
if not isinstance(R, np.ndarray):
115+
R = np.array(R)
116+
R_smeared = np.zeros_like(self.q)
117+
118+
for i, qt in enumerate(self.q):
119+
R_smeared[i] = self.gaussian_smearing(qt, Qz, R, sQzs)
120+
121+
return R_smeared

0 commit comments

Comments
 (0)