Skip to content

Commit 6e0de47

Browse files
authored
more sophisticated model colour algorithm (#285)
1 parent 072f43b commit 6e0de47

File tree

2 files changed

+112
-7
lines changed

2 files changed

+112
-7
lines changed

src/easyreflectometry/model/model_collection.py

Lines changed: 49 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
interface=None,
2424
unique_name: Optional[str] = None,
2525
populate_if_none: bool = True,
26+
next_color_index: Optional[int] = None,
2627
**kwargs,
2728
):
2829
if not models:
@@ -33,17 +34,25 @@ def __init__(
3334
# Needed to ensure an empty list is created when saving and instatiating the object as_dict -> from_dict
3435
# Else collisions might occur in global_object.map
3536
self.populate_if_none = False
37+
self._next_color_index = next_color_index
3638

37-
super().__init__(name, interface, unique_name=unique_name, *models, **kwargs)
39+
super().__init__(name, interface, *models, unique_name=unique_name, **kwargs)
40+
41+
color_count = len(COLORS)
42+
if color_count == 0:
43+
self._next_color_index = 0
44+
elif self._next_color_index is None:
45+
self._next_color_index = len(self) % color_count
46+
else:
47+
self._next_color_index %= color_count
3848

3949
def add_model(self, model: Optional[Model] = None):
4050
"""Add a model to the collection.
4151
4252
:param model: Model to add.
4353
"""
4454
if model is None:
45-
color = COLORS[len(self) % len(COLORS)]
46-
model = Model(name='Model', interface=self.interface, color=color)
55+
model = Model(name='Model', interface=self.interface, color=self._current_color())
4756
self.append(model)
4857

4958
def duplicate_model(self, index: int):
@@ -59,6 +68,7 @@ def duplicate_model(self, index: int):
5968
def as_dict(self, skip: List[str] | None = None) -> dict:
6069
this_dict = super().as_dict(skip=skip)
6170
this_dict['populate_if_none'] = self.populate_if_none
71+
this_dict['next_color_index'] = self._next_color_index
6272
return this_dict
6373

6474
@classmethod
@@ -69,16 +79,48 @@ def from_dict(cls, this_dict: dict) -> ModelCollection:
6979
:param data: The dictionary for the collection
7080
"""
7181
collection_dict = this_dict.copy()
72-
# We neeed to call from_dict on the base class to get the models
73-
dict_data = collection_dict['data']
74-
del collection_dict['data']
82+
# We need to call from_dict on the base class to get the models
83+
dict_data = collection_dict.pop('data')
84+
next_color_index = collection_dict.pop('next_color_index', None)
7585

7686
collection = super().from_dict(collection_dict) # type: ModelCollection
7787

7888
for model_data in dict_data:
79-
collection.add_model(Model.from_dict(model_data))
89+
collection._append_internal(Model.from_dict(model_data), advance=False)
8090

8191
if len(collection) != len(this_dict['data']):
8292
raise ValueError(f'Expected {len(collection)} models, got {len(this_dict["data"])}')
8393

94+
color_count = len(COLORS)
95+
if color_count == 0:
96+
collection._next_color_index = 0
97+
elif next_color_index is None:
98+
collection._next_color_index = len(collection) % color_count
99+
else:
100+
collection._next_color_index = next_color_index % color_count
101+
84102
return collection
103+
104+
def append(self, model: Model) -> None: # type: ignore[override]
105+
self._append_internal(model, advance=True)
106+
107+
def _append_internal(self, model: Model, advance: bool) -> None:
108+
super().append(model)
109+
if advance:
110+
self._advance_color_index()
111+
112+
def _advance_color_index(self) -> None:
113+
if not COLORS:
114+
self._next_color_index = 0
115+
return
116+
if self._next_color_index is None:
117+
self._next_color_index = len(self) % len(COLORS)
118+
return
119+
self._next_color_index = (self._next_color_index + 1) % len(COLORS)
120+
121+
def _current_color(self) -> str:
122+
if not COLORS:
123+
raise ValueError('No colors defined for models.')
124+
if self._next_color_index is None:
125+
self._next_color_index = 0
126+
return COLORS[self._next_color_index]

tests/model/test_model_collection.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from easyscience import global_object
22

3+
from easyreflectometry.model.model import COLORS
34
from easyreflectometry.model.model import Model
45
from easyreflectometry.model.model_collection import ModelCollection
56

@@ -52,6 +53,44 @@ def test_add_model(self):
5253
assert collection[0].name == 'Model1'
5354
assert collection[1].name == 'Model2'
5455

56+
def test_add_model_color_cycle(self):
57+
collection = ModelCollection(populate_if_none=False)
58+
59+
collection.add_model()
60+
assert collection[0].color == COLORS[0]
61+
62+
collection.add_model()
63+
assert collection[1].color == COLORS[1]
64+
65+
collection.remove(0)
66+
collection.add_model()
67+
68+
assert collection[0].color == COLORS[1]
69+
assert collection[1].color == COLORS[2]
70+
71+
def test_add_model_color_wrap(self):
72+
collection = ModelCollection(populate_if_none=False)
73+
74+
for _ in range(len(COLORS)):
75+
collection.add_model()
76+
77+
collection.add_model()
78+
79+
assert collection[-1].color == COLORS[0]
80+
81+
def test_add_model_preserves_explicit_color(self):
82+
collection = ModelCollection(populate_if_none=False)
83+
collection.add_model()
84+
expected_index = collection._next_color_index
85+
86+
custom_color = '#ABCDEF'
87+
custom_model = Model(name='Custom', color=custom_color)
88+
89+
collection.add_model(custom_model)
90+
91+
assert collection[-1].color == custom_color
92+
assert collection._next_color_index == (expected_index + 1) % len(COLORS)
93+
5594
def test_delete_model(self):
5695
# When
5796
model_1 = Model(name='Model1')
@@ -94,3 +133,27 @@ def test_dict_round_trip(self):
94133
q.as_dict(skip=['resolution_function', 'interface'])
95134
)
96135
assert p[0]._resolution_function.smearing(5.5) == q[0]._resolution_function.smearing(5.5)
136+
137+
def test_next_color_index_round_trip(self):
138+
collection = ModelCollection(populate_if_none=False)
139+
for _ in range(3):
140+
collection.add_model()
141+
142+
expected_index = collection._next_color_index
143+
dict_repr = collection.as_dict()
144+
global_object.map._clear()
145+
146+
restored = ModelCollection.from_dict(dict_repr)
147+
148+
assert restored._next_color_index == expected_index
149+
150+
def test_legacy_from_dict_sets_color_index(self):
151+
collection = ModelCollection()
152+
legacy_dict = collection.as_dict()
153+
legacy_dict.pop('next_color_index', None)
154+
global_object.map._clear()
155+
156+
restored = ModelCollection.from_dict(legacy_dict)
157+
restored.add_model()
158+
159+
assert [model.color for model in restored] == [COLORS[0], COLORS[1]]

0 commit comments

Comments
 (0)