diff --git a/spec/ndx-probeinterface.extensions.yaml b/spec/ndx-probeinterface.extensions.yaml index 6aa5604..1a1c857 100644 --- a/spec/ndx-probeinterface.extensions.yaml +++ b/spec/ndx-probeinterface.extensions.yaml @@ -21,6 +21,10 @@ groups: default_value: micrometer doc: SI unit used to define the probe; e.g. 'meter'. required: false + - name: annotations + dtype: text + doc: annotations attached to the probe + required: false datasets: - name: planar_contour dtype: float diff --git a/src/pynwb/ndx_probeinterface/io.py b/src/pynwb/ndx_probeinterface/io.py index 2a192f9..07993ed 100644 --- a/src/pynwb/ndx_probeinterface/io.py +++ b/src/pynwb/ndx_probeinterface/io.py @@ -1,5 +1,6 @@ from typing import Union, List, Optional import numpy as np +import json from probeinterface import Probe, ProbeGroup from pynwb.file import Device @@ -11,8 +12,7 @@ inverted_unit_map = {v: k for k, v in unit_map.items()} -def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup], - name: Optional[str] = None) -> List[Device]: +def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[Device]: """ Construct ndx-probeinterface Probe devices from a probeinterface.Probe @@ -33,7 +33,7 @@ def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup], probes = probe_or_probegroup.probes devices = [] for probe in probes: - devices.append(_single_probe_to_nwb_device(probe, name=name)) + devices.append(_single_probe_to_nwb_device(probe)) return devices @@ -53,6 +53,11 @@ def to_probeinterface(ndx_probe) -> Probe: """ ndim = ndx_probe.ndim unit = inverted_unit_map[ndx_probe.unit] + name = ndx_probe.name + serial_number = ndx_probe.serial_number + model_name = ndx_probe.model_name + manufacturer = ndx_probe.manufacturer + polygon = ndx_probe.planar_contour positions = [] @@ -105,7 +110,14 @@ def to_probeinterface(ndx_probe) -> Probe: if device_channel_indices is not None: device_channel_indices = [item for sublist in device_channel_indices for item in sublist] - probeinterface_probe = Probe(ndim=ndim, si_units=unit) + probeinterface_probe = Probe( + ndim=ndim, + si_units=unit, + name=name, + serial_number=serial_number, + manufacturer=manufacturer, + model_name=model_name + ) probeinterface_probe.set_contacts( positions=positions, shapes=shapes, shape_params=shape_params, plane_axes=plane_axes, shank_ids=shank_ids ) @@ -113,11 +125,12 @@ def to_probeinterface(ndx_probe) -> Probe: if device_channel_indices is not None: probeinterface_probe.set_device_channel_indices(channel_indices=device_channel_indices) probeinterface_probe.set_planar_contour(polygon) + probeinterface_probe.annotate(**json.loads(ndx_probe.annotations)) return probeinterface_probe -def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None): +def _single_probe_to_nwb_device(probe: Probe): from pynwb import get_class Probe = get_class("Probe", "ndx-probeinterface") @@ -156,10 +169,11 @@ def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None): kwargs["shank_id"] = probe.shank_ids[index] contact_table.add_row(kwargs) - serial_number = probe.serial_number - model_name = probe.model_name - manufacturer = probe.manufacturer - name = name if name is not None else probe.name + annotations = probe.annotations.copy() + name = annotations.pop("name") if "name" in annotations else None + serial_number = annotations.pop("serial_number") if "serial_number" in annotations else None + model_name = annotations.pop("model_name") if "model_name" in annotations else None + manufacturer = annotations.pop("manufacturer") if "manufacturer" in annotations else None probe_device = Probe( name=name, @@ -169,7 +183,8 @@ def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None): ndim=probe.ndim, unit=unit_map[probe.si_units], planar_contour=planar_contour, - contact_table=contact_table + contact_table=contact_table, + annotations=json.dumps(annotations) ) - return probe_device + return probe_device \ No newline at end of file diff --git a/src/pynwb/tests/test_probe.py b/src/pynwb/tests/test_probe.py index 0697c00..6e4b597 100644 --- a/src/pynwb/tests/test_probe.py +++ b/src/pynwb/tests/test_probe.py @@ -1,6 +1,7 @@ import pytest import datetime import numpy as np +import json import probeinterface as pi @@ -25,14 +26,14 @@ def set_up_nwbfile(): def create_single_shank_probe(): probe = pi.generate_linear_probe() - probe.annotate(name="Single-shank") + probe.annotate(name="Single-shank", custom_key="custom annotation") probe.set_contact_ids([f"c{i}" for i in range(probe.get_contact_count())]) return probe def create_multi_shank_probe(): probe = pi.generate_multi_shank() - probe.annotate(name="Multi-shank") + probe.annotate(name="Multi-shank", custom_key="custom annotation") probe.set_contact_ids([f"cm{i}" for i in range(probe.get_contact_count())]) return probe @@ -69,6 +70,9 @@ def test_constructor_from_probe_single_shank(self): probe_array = probe.to_numpy() np.testing.assert_array_equal(contact_table["contact_position"][:], probe.contact_positions) np.testing.assert_array_equal(contact_table["contact_shape"][:], probe_array["contact_shapes"]) + keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"] + filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter} + self.assertDictEqual(json.loads(device.annotations), filtered_annotations) # set channel indices device_channel_indices = np.arange(probe.get_contact_count()) @@ -78,9 +82,6 @@ def test_constructor_from_probe_single_shank(self): contact_table = device_w_indices.contact_table np.testing.assert_array_equal(contact_table["device_channel_index_pi"][:], device_channel_indices) - devices_w_names = Probe.from_probeinterface(probe, name="Test Probe") - assert devices_w_names[0].name == "Test Probe" - def test_constructor_from_probe_multi_shank(self): """Test that the constructor from Probe sets values as expected for multi-shank.""" @@ -108,6 +109,9 @@ def test_constructor_from_probe_multi_shank(self): np.testing.assert_array_equal( contact_table["shank_id"][:], probe.shank_ids ) + keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"] + filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter} + self.assertDictEqual(json.loads(device.annotations), filtered_annotations) def test_constructor_from_probegroup(self): """Test that the constructor from probegroup sets values as expected.""" @@ -142,6 +146,10 @@ def test_constructor_from_probegroup(self): contact_table["device_channel_index_pi"][:], device_channel_indices ) + keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"] + filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter} + self.assertDictEqual(json.loads(device.annotations), filtered_annotations) + class TestProbeRoundtrip(TestCase): """Simple roundtrip test for Probe device.""" @@ -174,7 +182,13 @@ def test_roundtrip_nwb_from_probe_single_shank(self): with NWBHDF5IO(self.path0, mode="r", load_namespaces=True) as io: read_nwbfile = io.read() devices = read_nwbfile.devices - self.assertContainerEqual(device, read_nwbfile.devices[device.name]) + self.assertContainerEqual(device, read_nwbfile.devices[device.name]) + keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"] + filtered_annotations = { + key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items() + if key not in keys_to_filter + } + self.assertDictEqual(json.loads(device.annotations), filtered_annotations) def test_roundtrip_nwb_from_probe_multi_shank(self): devices = Probe.from_probeinterface(self.probe1) @@ -188,6 +202,12 @@ def test_roundtrip_nwb_from_probe_multi_shank(self): read_nwbfile = io.read() devices = read_nwbfile.devices self.assertContainerEqual(device, read_nwbfile.devices[device.name]) + keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"] + filtered_annotations = { + key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items() + if key not in keys_to_filter + } + self.assertDictEqual(json.loads(device.annotations), filtered_annotations) def test_roundtrip_nwb_from_probegroup(self): devices = Probe.from_probeinterface(self.probegroup) @@ -201,18 +221,25 @@ def test_roundtrip_nwb_from_probegroup(self): read_nwbfile = io.read() for device in devices: self.assertContainerEqual(device, read_nwbfile.devices[device.name]) - + keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"] + filtered_annotations = { + key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items() + if key not in keys_to_filter + } + self.assertDictEqual(json.loads(device.annotations), filtered_annotations) def test_roundtrip_pi_from_probe_single_shank(self): probe_arr = self.probe0.to_numpy() devices = Probe.from_probeinterface(self.probe0) device = devices[0] np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy()) + self.assertDictEqual(self.probe0.annotations, device.to_probeinterface().annotations) def test_roundtrip_pi_from_probe_multi_shank(self): probe_arr = self.probe1.to_numpy() devices = Probe.from_probeinterface(self.probe1) device = devices[0] np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy()) + self.assertDictEqual(self.probe1.annotations, device.to_probeinterface().annotations) if __name__ == "__main__": diff --git a/src/spec/create_extension_spec.py b/src/spec/create_extension_spec.py index a6a3628..127dc22 100644 --- a/src/spec/create_extension_spec.py +++ b/src/spec/create_extension_spec.py @@ -100,7 +100,13 @@ def main(): probe = NWBGroupSpec( doc="Neural probe object according to probeinterface specification", attributes=[ - NWBAttributeSpec(name="ndim", doc="dimension of the probe", dtype="int", required=True, default_value=2), + NWBAttributeSpec( + name="ndim", + doc="dimension of the probe", + dtype="int", + required=True, + default_value=2 + ), NWBAttributeSpec( name="model_name", doc="model of the probe; e.g. 'Neuropixels 1.0'", @@ -120,6 +126,12 @@ def main(): required=True, default_value="micrometer", ), + NWBAttributeSpec( + name="annotations", + doc="annotations attached to the probe", + dtype="text", + required=False + ), ], neurodata_type_inc="Device", neurodata_type_def="Probe", @@ -151,4 +163,4 @@ def main(): if __name__ == "__main__": # usage: python create_extension_spec.py - main() + main() \ No newline at end of file