@@ -297,8 +297,11 @@ def get_shank_count(self) -> int:
297297 """
298298 Return the number of shanks for this probe.
299299 """
300- assert self .shank_ids is not None
301- n = len (np .unique (self .shank_ids ))
300+ # assert self.shank_ids is not None
301+ if self .shank_ids is None :
302+ n = 1
303+ else :
304+ n = len (np .unique (self .shank_ids ))
302305 return n
303306
304307 def set_contacts (
@@ -380,7 +383,8 @@ def set_contacts(
380383 self .set_contact_ids (contact_ids )
381384
382385 if shank_ids is None :
383- self ._shank_ids = np .zeros (n , dtype = str )
386+ # self._shank_ids = np.zeros(n, dtype=str)
387+ self ._shank_ids = None
384388 else :
385389 self ._shank_ids = np .asarray (shank_ids ).astype (str )
386390 if self .shank_ids .size != n :
@@ -601,11 +605,15 @@ def get_shanks(self):
601605 """
602606 Return the list of Shank objects for this Probe
603607 """
604- assert self .shank_ids is not None , "Can only get shanks if `shank_ids` exist"
605- shanks = []
606- for shank_id in np .unique (self .shank_ids ):
607- shank = Shank (probe = self , shank_id = shank_id )
608- shanks .append (shank )
608+ # assert self.shank_ids is not None, "Can only get shanks if `shank_ids` exist"
609+ if self .shank_ids is None :
610+ # has a unique shank
611+ shanks = [Shank (probe = self , shank_id = None )]
612+ else :
613+ shanks = []
614+ for shank_id in np .unique (self .shank_ids ):
615+ shank = Shank (probe = self , shank_id = shank_id )
616+ shanks .append (shank )
609617 return shanks
610618
611619 def __eq__ (self , other ):
@@ -937,13 +945,18 @@ def from_dict(d: dict) -> "Probe":
937945 """
938946 probe = Probe (ndim = d ["ndim" ], si_units = d ["si_units" ])
939947
948+ shank_ids = d .get ("shank_ids" , None )
949+ if shank_ids is not None and np .all (shank_ids == "" ):
950+ # backward compatible hack with previous version
951+ shank_ids = None
952+
940953 probe .set_contacts (
941954 positions = d ["contact_positions" ],
942955 plane_axes = d ["contact_plane_axes" ],
943956 shapes = d ["contact_shapes" ],
944957 shape_params = d ["contact_shape_params" ],
945958 contact_ids = d .get ("contact_ids" , None ),
946- shank_ids = d . get ( " shank_ids" , None ) ,
959+ shank_ids = shank_ids ,
947960 contact_sides = d .get ("contact_sides" , None ),
948961 )
949962
@@ -1032,7 +1045,11 @@ def to_numpy(self, complete: bool = False) -> np.array:
10321045 param_shape .append (k )
10331046 for k in param_shape :
10341047 dtype += [(k , "float64" )]
1035- dtype += [("shank_ids" , "U64" ), ("contact_ids" , "U64" )]
1048+
1049+ if self ._shank_ids is not None :
1050+ dtype += [("shank_ids" , "U64" )]
1051+
1052+ dtype += [("contact_ids" , "U64" )]
10361053
10371054 if self ._contact_sides is not None :
10381055 dtype += [
@@ -1060,7 +1077,8 @@ def to_numpy(self, complete: bool = False) -> np.array:
10601077 for k , v in p .items ():
10611078 arr [k ][i ] = v
10621079
1063- arr ["shank_ids" ] = self .shank_ids
1080+ if self ._shank_ids is not None :
1081+ arr ["shank_ids" ] = self .shank_ids
10641082
10651083 if self ._contact_sides is not None :
10661084 arr ["contact_sides" ] = self .contact_sides
0 commit comments