Skip to content

Commit 46935e9

Browse files
authored
Merge pull request #391 from samuelgarcia/shank_id_none
First attemps to make shank_ids optional (None) instead of array of [""]
2 parents 1ac74b1 + c376138 commit 46935e9

File tree

5 files changed

+51
-20
lines changed

5 files changed

+51
-20
lines changed

src/probeinterface/probe.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,11 @@ def get_shank_count(self) -> int:
297297
"""
298298
Return the number of shanks for this probe.
299299
"""
300-
assert self.shank_ids is not None
301-
n = len(np.unique(self.shank_ids))
300+
# assert self.shank_ids is not None
301+
if self.shank_ids is None:
302+
n = 1
303+
else:
304+
n = len(np.unique(self.shank_ids))
302305
return n
303306

304307
def set_contacts(
@@ -380,7 +383,8 @@ def set_contacts(
380383
self.set_contact_ids(contact_ids)
381384

382385
if shank_ids is None:
383-
self._shank_ids = np.zeros(n, dtype=str)
386+
# self._shank_ids = np.zeros(n, dtype=str)
387+
self._shank_ids = None
384388
else:
385389
self._shank_ids = np.asarray(shank_ids).astype(str)
386390
if self.shank_ids.size != n:
@@ -601,11 +605,15 @@ def get_shanks(self):
601605
"""
602606
Return the list of Shank objects for this Probe
603607
"""
604-
assert self.shank_ids is not None, "Can only get shanks if `shank_ids` exist"
605-
shanks = []
606-
for shank_id in np.unique(self.shank_ids):
607-
shank = Shank(probe=self, shank_id=shank_id)
608-
shanks.append(shank)
608+
# assert self.shank_ids is not None, "Can only get shanks if `shank_ids` exist"
609+
if self.shank_ids is None:
610+
# has a unique shank
611+
shanks = [Shank(probe=self, shank_id=None)]
612+
else:
613+
shanks = []
614+
for shank_id in np.unique(self.shank_ids):
615+
shank = Shank(probe=self, shank_id=shank_id)
616+
shanks.append(shank)
609617
return shanks
610618

611619
def __eq__(self, other):
@@ -937,13 +945,18 @@ def from_dict(d: dict) -> "Probe":
937945
"""
938946
probe = Probe(ndim=d["ndim"], si_units=d["si_units"])
939947

948+
shank_ids = d.get("shank_ids", None)
949+
if shank_ids is not None and np.all(shank_ids == ""):
950+
# backward compatible hack with previous version
951+
shank_ids = None
952+
940953
probe.set_contacts(
941954
positions=d["contact_positions"],
942955
plane_axes=d["contact_plane_axes"],
943956
shapes=d["contact_shapes"],
944957
shape_params=d["contact_shape_params"],
945958
contact_ids=d.get("contact_ids", None),
946-
shank_ids=d.get("shank_ids", None),
959+
shank_ids=shank_ids,
947960
contact_sides=d.get("contact_sides", None),
948961
)
949962

@@ -1032,7 +1045,11 @@ def to_numpy(self, complete: bool = False) -> np.array:
10321045
param_shape.append(k)
10331046
for k in param_shape:
10341047
dtype += [(k, "float64")]
1035-
dtype += [("shank_ids", "U64"), ("contact_ids", "U64")]
1048+
1049+
if self._shank_ids is not None:
1050+
dtype += [("shank_ids", "U64")]
1051+
1052+
dtype += [("contact_ids", "U64")]
10361053

10371054
if self._contact_sides is not None:
10381055
dtype += [
@@ -1060,7 +1077,8 @@ def to_numpy(self, complete: bool = False) -> np.array:
10601077
for k, v in p.items():
10611078
arr[k][i] = v
10621079

1063-
arr["shank_ids"] = self.shank_ids
1080+
if self._shank_ids is not None:
1081+
arr["shank_ids"] = self.shank_ids
10641082

10651083
if self._contact_sides is not None:
10661084
arr["contact_sides"] = self.contact_sides

src/probeinterface/shank.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ def __init__(self, probe, shank_id):
1515
self.shank_id = shank_id
1616

1717
def get_indices(self):
18-
(inds,) = np.nonzero(self.probe.shank_ids == self.shank_id)
18+
if self.probe.shank_ids is None:
19+
inds = np.arange(self.probe.get_contact_count(), dtype=int)
20+
else:
21+
inds = np.flatnonzero(self.probe.shank_ids == self.shank_id)
1922
return inds
2023

2124
def get_contact_count(self):

tests/test_io/test_3brain.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,7 @@ def test_3brain():
4343
assert np.all(np.isclose(np.diff(unique_rows), contact_pitch)), file
4444
unique_cols = np.unique(probe.contact_positions[:, 0])
4545
assert np.all(np.isclose(np.diff(unique_cols), contact_pitch))
46+
47+
48+
if __name__ == "__main__":
49+
test_3brain()

tests/test_io/test_io.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def test_BIDS_format(tmp_path):
8383
probe.set_contact_ids(probe_el_ids)
8484

8585
# switch to more generic dtype for shank_ids
86-
probe.set_shank_ids(probe.shank_ids.astype(str))
86+
if probe.shank_ids is not None:
87+
probe.set_shank_ids(probe.shank_ids.astype(str))
8788

8889
write_BIDS_probe(folder_path, probegroup)
8990

@@ -103,7 +104,8 @@ def test_BIDS_format(tmp_path):
103104
t = np.array([list(probe_read.contact_ids).index(elid) for elid in probe_orig.contact_ids])
104105

105106
assert all(probe_orig.contact_ids == probe_read.contact_ids[t])
106-
assert all(probe_orig.shank_ids == probe_read.shank_ids[t])
107+
if probe_orig.shank_ids is not None:
108+
assert all(probe_orig.shank_ids == probe_read.shank_ids[t])
107109
assert all(probe_orig.contact_shapes == probe_read.contact_shapes[t])
108110
assert probe_orig.ndim == probe_read.ndim
109111
assert probe_orig.si_units == probe_read.si_units
@@ -206,8 +208,12 @@ def test_prb(tmp_path):
206208

207209

208210
if __name__ == "__main__":
209-
# test_probeinterface_format()
210-
# test_BIDS_format()
211+
import tempfile
212+
213+
tmp_path = Path(tempfile.mkdtemp())
214+
215+
# test_probeinterface_format(tmp_path)
216+
test_BIDS_format(tmp_path)
211217
# test_BIDS_format_empty()
212218
# test_BIDS_format_minimal()
213219
pass

tests/test_probe.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,10 +229,10 @@ def test_double_side_probe():
229229

230230

231231
if __name__ == "__main__":
232-
test_probe()
232+
import tempfile
233233

234-
tmp_path = Path("tmp")
235-
tmp_path.mkdir(exist_ok=True)
236-
test_save_to_zarr(tmp_path)
234+
tmp_path = Path(tempfile.mkdtemp())
237235

236+
test_probe()
237+
test_save_to_zarr(tmp_path)
238238
test_double_side_probe()

0 commit comments

Comments
 (0)