Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions spec/ndx-probeinterface.extensions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ groups:
default_value: micrometer
doc: SI unit used to define the probe; e.g. 'meter'.
required: false
- name: annotations
dtype: text
doc: annotations attached to the probe
required: false
datasets:
- name: planar_contour
dtype: float
Expand Down
37 changes: 26 additions & 11 deletions src/pynwb/ndx_probeinterface/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union, List, Optional
import numpy as np
import json
from probeinterface import Probe, ProbeGroup
from pynwb.file import Device

Expand All @@ -11,8 +12,7 @@
inverted_unit_map = {v: k for k, v in unit_map.items()}


def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup],
name: Optional[str] = None) -> List[Device]:
def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup]) -> List[Device]:
"""
Construct ndx-probeinterface Probe devices from a probeinterface.Probe

Expand All @@ -33,7 +33,7 @@ def from_probeinterface(probe_or_probegroup: Union[Probe, ProbeGroup],
probes = probe_or_probegroup.probes
devices = []
for probe in probes:
devices.append(_single_probe_to_nwb_device(probe, name=name))
devices.append(_single_probe_to_nwb_device(probe))
return devices


Expand All @@ -53,6 +53,11 @@ def to_probeinterface(ndx_probe) -> Probe:
"""
ndim = ndx_probe.ndim
unit = inverted_unit_map[ndx_probe.unit]
name = ndx_probe.name
serial_number = ndx_probe.serial_number
model_name = ndx_probe.model_name
manufacturer = ndx_probe.manufacturer

polygon = ndx_probe.planar_contour

positions = []
Expand Down Expand Up @@ -105,19 +110,27 @@ def to_probeinterface(ndx_probe) -> Probe:
if device_channel_indices is not None:
device_channel_indices = [item for sublist in device_channel_indices for item in sublist]

probeinterface_probe = Probe(ndim=ndim, si_units=unit)
probeinterface_probe = Probe(
ndim=ndim,
si_units=unit,
name=name,
serial_number=serial_number,
manufacturer=manufacturer,
model_name=model_name
)
probeinterface_probe.set_contacts(
positions=positions, shapes=shapes, shape_params=shape_params, plane_axes=plane_axes, shank_ids=shank_ids
)
probeinterface_probe.set_contact_ids(contact_ids=contact_ids)
if device_channel_indices is not None:
probeinterface_probe.set_device_channel_indices(channel_indices=device_channel_indices)
probeinterface_probe.set_planar_contour(polygon)
probeinterface_probe.annotate(**json.loads(ndx_probe.annotations))

return probeinterface_probe


def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None):
def _single_probe_to_nwb_device(probe: Probe):
from pynwb import get_class

Probe = get_class("Probe", "ndx-probeinterface")
Expand Down Expand Up @@ -156,10 +169,11 @@ def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None):
kwargs["shank_id"] = probe.shank_ids[index]
contact_table.add_row(kwargs)

serial_number = probe.serial_number
model_name = probe.model_name
manufacturer = probe.manufacturer
name = name if name is not None else probe.name
annotations = probe.annotations.copy()
name = annotations.pop("name") if "name" in annotations else None
serial_number = annotations.pop("serial_number") if "serial_number" in annotations else None
model_name = annotations.pop("model_name") if "model_name" in annotations else None
manufacturer = annotations.pop("manufacturer") if "manufacturer" in annotations else None

probe_device = Probe(
name=name,
Expand All @@ -169,7 +183,8 @@ def _single_probe_to_nwb_device(probe: Probe, name: Optional[str]=None):
ndim=probe.ndim,
unit=unit_map[probe.si_units],
planar_contour=planar_contour,
contact_table=contact_table
contact_table=contact_table,
annotations=json.dumps(annotations)
)

return probe_device
return probe_device
41 changes: 34 additions & 7 deletions src/pynwb/tests/test_probe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import datetime
import numpy as np
import json

import probeinterface as pi

Expand All @@ -25,14 +26,14 @@ def set_up_nwbfile():

def create_single_shank_probe():
probe = pi.generate_linear_probe()
probe.annotate(name="Single-shank")
probe.annotate(name="Single-shank", custom_key="custom annotation")
probe.set_contact_ids([f"c{i}" for i in range(probe.get_contact_count())])
return probe


def create_multi_shank_probe():
probe = pi.generate_multi_shank()
probe.annotate(name="Multi-shank")
probe.annotate(name="Multi-shank", custom_key="custom annotation")
probe.set_contact_ids([f"cm{i}" for i in range(probe.get_contact_count())])
return probe

Expand Down Expand Up @@ -69,6 +70,9 @@ def test_constructor_from_probe_single_shank(self):
probe_array = probe.to_numpy()
np.testing.assert_array_equal(contact_table["contact_position"][:], probe.contact_positions)
np.testing.assert_array_equal(contact_table["contact_shape"][:], probe_array["contact_shapes"])
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)

# set channel indices
device_channel_indices = np.arange(probe.get_contact_count())
Expand All @@ -78,9 +82,6 @@ def test_constructor_from_probe_single_shank(self):
contact_table = device_w_indices.contact_table
np.testing.assert_array_equal(contact_table["device_channel_index_pi"][:], device_channel_indices)

devices_w_names = Probe.from_probeinterface(probe, name="Test Probe")
assert devices_w_names[0].name == "Test Probe"

def test_constructor_from_probe_multi_shank(self):
"""Test that the constructor from Probe sets values as expected for multi-shank."""

Expand Down Expand Up @@ -108,6 +109,9 @@ def test_constructor_from_probe_multi_shank(self):
np.testing.assert_array_equal(
contact_table["shank_id"][:], probe.shank_ids
)
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)

def test_constructor_from_probegroup(self):
"""Test that the constructor from probegroup sets values as expected."""
Expand Down Expand Up @@ -142,6 +146,10 @@ def test_constructor_from_probegroup(self):
contact_table["device_channel_index_pi"][:], device_channel_indices
)

keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
filtered_annotations = {key: value for key, value in probe.annotations.items() if key not in keys_to_filter}
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)


class TestProbeRoundtrip(TestCase):
"""Simple roundtrip test for Probe device."""
Expand Down Expand Up @@ -174,7 +182,13 @@ def test_roundtrip_nwb_from_probe_single_shank(self):
with NWBHDF5IO(self.path0, mode="r", load_namespaces=True) as io:
read_nwbfile = io.read()
devices = read_nwbfile.devices
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
filtered_annotations = {
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
if key not in keys_to_filter
}
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)

def test_roundtrip_nwb_from_probe_multi_shank(self):
devices = Probe.from_probeinterface(self.probe1)
Expand All @@ -188,6 +202,12 @@ def test_roundtrip_nwb_from_probe_multi_shank(self):
read_nwbfile = io.read()
devices = read_nwbfile.devices
self.assertContainerEqual(device, read_nwbfile.devices[device.name])
keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
filtered_annotations = {
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
if key not in keys_to_filter
}
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)

def test_roundtrip_nwb_from_probegroup(self):
devices = Probe.from_probeinterface(self.probegroup)
Expand All @@ -201,18 +221,25 @@ def test_roundtrip_nwb_from_probegroup(self):
read_nwbfile = io.read()
for device in devices:
self.assertContainerEqual(device, read_nwbfile.devices[device.name])

keys_to_filter = ["name", "manufacturer", "model_name", "serial_number"]
filtered_annotations = {
key: value for key, value in read_nwbfile.devices[device.name].to_probeinterface().annotations.items()
if key not in keys_to_filter
}
self.assertDictEqual(json.loads(device.annotations), filtered_annotations)
def test_roundtrip_pi_from_probe_single_shank(self):
probe_arr = self.probe0.to_numpy()
devices = Probe.from_probeinterface(self.probe0)
device = devices[0]
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
self.assertDictEqual(self.probe0.annotations, device.to_probeinterface().annotations)

def test_roundtrip_pi_from_probe_multi_shank(self):
probe_arr = self.probe1.to_numpy()
devices = Probe.from_probeinterface(self.probe1)
device = devices[0]
np.testing.assert_array_equal(probe_arr, device.to_probeinterface().to_numpy())
self.assertDictEqual(self.probe1.annotations, device.to_probeinterface().annotations)


if __name__ == "__main__":
Expand Down
16 changes: 14 additions & 2 deletions src/spec/create_extension_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,13 @@ def main():
probe = NWBGroupSpec(
doc="Neural probe object according to probeinterface specification",
attributes=[
NWBAttributeSpec(name="ndim", doc="dimension of the probe", dtype="int", required=True, default_value=2),
NWBAttributeSpec(
name="ndim",
doc="dimension of the probe",
dtype="int",
required=True,
default_value=2
),
NWBAttributeSpec(
name="model_name",
doc="model of the probe; e.g. 'Neuropixels 1.0'",
Expand All @@ -120,6 +126,12 @@ def main():
required=True,
default_value="micrometer",
),
NWBAttributeSpec(
name="annotations",
doc="annotations attached to the probe",
dtype="text",
required=False
),
],
neurodata_type_inc="Device",
neurodata_type_def="Probe",
Expand Down Expand Up @@ -151,4 +163,4 @@ def main():

if __name__ == "__main__":
# usage: python create_extension_spec.py
main()
main()