diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4e1f3085..1caf1af3 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -51,7 +51,7 @@ jobs: python -m pip install -U pip # if running a cron job, we add the --pre flag to test against pre-releases python -m pip install .[test] ${{ github.event_name == 'schedule' && '--pre' || '' }} - + - name: Restore shared data cache id: cache-data uses: actions/cache@v4 @@ -79,6 +79,7 @@ jobs: update_existing: true - name: Coverage + if: success() && matrix.platform == 'ubuntu-latest' uses: codecov/codecov-action@v5 with: token: ${{ secrets.CODECOV_TOKEN }} @@ -118,4 +119,4 @@ jobs: - uses: softprops/action-gh-release@v2 with: generate_release_notes: true - files: './dist/*' + files: "./dist/*" diff --git a/CHANGELOG.md b/CHANGELOG.md index 205a57e8..dd6006db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,27 @@ # Changelog +## [v0.5.0] + +### Features +- Add support for OME-NGFF v0.5 +- Move to zarr-python v3 +- Add an API to delete labels and tables from OME-Zarr containers and HCS plates. + +### API Breaking Changes + +- New `Roi` models, now supporting arbitrary axes. +- The `compressor` argument has been renamed to `compressors` in all relevant functions and methods to reflect the support for multiple compressors in zarr v3. +- The `version` argument has been renamed to `ngff_version` in all relevant functions and methods to specify the OME-NGFF version. +- Remove the `parallel_safe` argument from all zarr-related functions and methods. The locking mechanism is now handled internally and only depends on the +`cache`. +- Remove the unused `parent` argument from `ZarrGroupHandler`. +- Internal changes to `ZarrGroupHandler` to support the cleanup of unused APIs. +- Remove `ngio_logger` in favor of the standard `warnings` module. 
+ ## [v0.4.6] ### Bug Fixes -- Fix channel selection from `wavelenght_id` +- Fix channel selection from `wavelength_id` - Fix table opening mode to stop wrtiting groups when opening in append mode. ## [v0.4.5] diff --git a/pyproject.toml b/pyproject.toml index c66e2318..f7ad6b8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,8 +35,8 @@ classifiers = [ dependencies = [ "numpy", "filelock", - "zarr<3", - "anndata>=0.8.0,<0.11.4", # To be removed when we transition to zarr v3 + "zarr>3", + "anndata", "pydantic", "pandas>=1.2.0", "requests", @@ -54,14 +54,23 @@ dependencies = [ # https://peps.python.org/pep-0621/#dependencies-optional-dependencies # "extras" (e.g. for `pip install .[test]`) [project.optional-dependencies] -test = ["pytest", "pytest-cov", "scikit-image"] +test = [ + "pytest", + "pytest-cov", + "scikit-image", + "moto[server]", + "boto", + "pytest_httpserver", + "devtools", + "s3fs", +] dev = [ + "zarrs", "Pympler", "napari", "pyqt5", "matplotlib", - "devtools", "notebook", "mypy", "pdbpp", # https://github.com/pdbpp/pdbpp @@ -155,7 +164,9 @@ minversion = "7.0" testpaths = ["tests"] filterwarnings = [ "error", - "ignore::FutureWarning", # TODO remove after zarr-python v3 + "ignore::zarr.errors.ZarrUserWarning", # required for anndata + "ignore::UserWarning", # from ngio.utils._zarr_utils + "ignore::DeprecationWarning", # temporary ignore deprecation warnings ] addopts = [ "-vv", @@ -224,6 +235,7 @@ test13 = { features = ["py13", "test"], solve-group = "py13" } # dev env dev = { features = ["dev", "test"], solve-group = "py11" } +test = { features = ["test"], solve-group = "py11" } [tool.pixi.tasks] serve_docs = "mkdocs serve" diff --git a/src/ngio/__init__.py b/src/ngio/__init__.py index 235b0d91..7a9e9519 100644 --- a/src/ngio/__init__.py +++ b/src/ngio/__init__.py @@ -9,7 +9,7 @@ __author__ = "Lorenzo Cerrone" __email__ = "lorenzo.cerrone@uzh.ch" -from ngio.common import Dimensions, Roi, RoiPixels +from ngio.common import Dimensions, Roi, 
RoiSlice from ngio.hcs import ( OmeZarrPlate, OmeZarrWell, @@ -37,6 +37,7 @@ NgffVersions, PixelSize, ) +from ngio.utils import NgioSupportedStore, StoreOrGroup __all__ = [ "AxesSetup", @@ -47,12 +48,14 @@ "ImageInWellPath", "Label", "NgffVersions", + "NgioSupportedStore", "OmeZarrContainer", "OmeZarrPlate", "OmeZarrWell", "PixelSize", "Roi", - "RoiPixels", + "RoiSlice", + "StoreOrGroup", "create_empty_ome_zarr", "create_empty_plate", "create_empty_well", diff --git a/src/ngio/common/__init__.py b/src/ngio/common/__init__.py index 171dbcf9..92249f14 100644 --- a/src/ngio/common/__init__.py +++ b/src/ngio/common/__init__.py @@ -2,22 +2,27 @@ from ngio.common._dimensions import Dimensions from ngio.common._masking_roi import compute_masking_roi -from ngio.common._pyramid import consolidate_pyramid, init_empty_pyramid, on_disk_zoom -from ngio.common._roi import ( - Roi, - RoiPixels, +from ngio.common._pyramid import ( + ChunksLike, + ImagePyramidBuilder, + ShardsLike, + consolidate_pyramid, + on_disk_zoom, ) +from ngio.common._roi import Roi, RoiSlice from ngio.common._zoom import InterpolationOrder, dask_zoom, numpy_zoom __all__ = [ + "ChunksLike", "Dimensions", + "ImagePyramidBuilder", "InterpolationOrder", "Roi", - "RoiPixels", + "RoiSlice", + "ShardsLike", "compute_masking_roi", "consolidate_pyramid", "dask_zoom", - "init_empty_pyramid", "numpy_zoom", "on_disk_zoom", ] diff --git a/src/ngio/common/_masking_roi.py b/src/ngio/common/_masking_roi.py index 9cf592ab..83d1a53c 100644 --- a/src/ngio/common/_masking_roi.py +++ b/src/ngio/common/_masking_roi.py @@ -7,7 +7,7 @@ import scipy.ndimage as ndi from dask.delayed import delayed -from ngio.common._roi import Roi, RoiPixels +from ngio.common._roi import Roi from ngio.ome_zarr_meta import PixelSize from ngio.utils import NgioValueError @@ -135,52 +135,23 @@ def compute_masking_roi( rois = [] for label, slice_ in slices.items(): if len(slice_) == 2: - min_t, max_t = None, None - min_z, max_z = None, None - min_y, 
min_x = slice_[0].start, slice_[1].start - max_y, max_x = slice_[0].stop, slice_[1].stop + slices = {"y": slice_[0], "x": slice_[1]} elif len(slice_) == 3: - min_t, max_t = None, None - min_z, min_y, min_x = slice_[0].start, slice_[1].start, slice_[2].start - max_z, max_y, max_x = slice_[0].stop, slice_[1].stop, slice_[2].stop + slices = {"z": slice_[0], "y": slice_[1], "x": slice_[2]} elif len(slice_) == 4: - min_t, min_z, min_y, min_x = ( - slice_[0].start, - slice_[1].start, - slice_[2].start, - slice_[3].start, - ) - max_t, max_z, max_y, max_x = ( - slice_[0].stop, - slice_[1].stop, - slice_[2].stop, - slice_[3].stop, - ) + slices = { + "t": slice_[0], + "z": slice_[1], + "y": slice_[2], + "x": slice_[3], + } else: raise ValueError("Invalid slice length.") - if max_t is None: - t_length = None - else: - t_length = max_t - min_t - - if max_z is None: - z_length = None - else: - z_length = max_z - min_z - - roi = RoiPixels( - name=str(label), - x_length=max_x - min_x, - y_length=max_y - min_y, - z_length=z_length, - t_length=t_length, - x=min_x, - y=min_y, - z=min_z, - label=label, + roi = Roi.from_values( + name=str(label), slices=slices, label=label, space="pixel" ) - roi = roi.to_roi(pixel_size) + roi = roi.to_world(pixel_size=pixel_size) rois.append(roi) return rois diff --git a/src/ngio/common/_pyramid.py b/src/ngio/common/_pyramid.py index 5ebba70a..a908ba37 100644 --- a/src/ngio/common/_pyramid.py +++ b/src/ngio/common/_pyramid.py @@ -1,12 +1,10 @@ -import math -from collections.abc import Callable, Sequence -from typing import Literal +from collections.abc import Callable, Mapping, Sequence +from typing import Any, Literal -import dask import dask.array as da import numpy as np import zarr -from zarr.types import DIMENSION_SEPARATOR +from pydantic import BaseModel, ConfigDict, model_validator from ngio.common._zoom import ( InterpolationOrder, @@ -15,10 +13,7 @@ numpy_zoom, ) from ngio.utils import ( - AccessModeLiteral, NgioValueError, - StoreOrGroup, - 
open_group_wrapper, ) @@ -27,7 +22,10 @@ def _on_disk_numpy_zoom( target: zarr.Array, order: InterpolationOrder, ) -> None: - target[...] = numpy_zoom(source[...], target_shape=target.shape, order=order) + source_array = source[...] + if not isinstance(source_array, np.ndarray): + raise NgioValueError("source zarr array could not be read as a numpy array") + target[...] = numpy_zoom(source_array, target_shape=target.shape, order=order) def _on_disk_dask_zoom( @@ -37,18 +35,20 @@ def _on_disk_dask_zoom( ) -> None: source_array = da.from_zarr(source) target_array = dask_zoom(source_array, target_shape=target.shape, order=order) + # This is a potential fix for Dask 2025.11 + # import dask.config # chunk_size_bytes = np.prod(target.chunks) * target_array.dtype.itemsize - # current_chunk_size = dask.config.get("array.chunk-size", 0) - # - #if current_chunk_size < chunk_size_bytes: - # # Increase the chunk size to avoid dask potentially creating - # # corrupted chunks when writing chunks that are not multiple of the - # # target chunk size - # dask.config.set({"array.chunk-size": f"{chunk_size_bytes}B"}) + # current_chunk_size = dask.config.get("array.chunk-size") + # Increase the chunk size to avoid dask potentially creating + # corrupted chunks when writing chunks that are not multiple of the + # target chunk size + # dask.config.set({"array.chunk-size": f"{chunk_size_bytes}B"}) target_array = target_array.rechunk(target.chunks) target_array = target_array.compute_chunk_sizes() target_array.to_zarr(target) + # Restore previous chunk size + # dask.config.set({"array.chunk-size": current_chunk_size}) def _on_disk_coarsen( @@ -194,67 +194,224 @@ def consolidate_pyramid( processed.append(target_image) -def _maybe_int(value: float | int) -> float | int: - """Convert a float to an int if it is an integer.""" - if isinstance(value, int): - return value - if value.is_integer(): - return int(value) - return value - - -def init_empty_pyramid( - store: StoreOrGroup, - paths: 
list[str], - ref_shape: Sequence[int], - scaling_factors: Sequence[float], - chunks: Sequence[int] | None = None, - dtype: str = "uint16", - mode: AccessModeLiteral = "a", - dimension_separator: DIMENSION_SEPARATOR = "/", - compressor="default", -) -> None: - # Return the an Image object - if chunks is not None and len(chunks) != len(ref_shape): - raise NgioValueError( - "The shape and chunks must have the same number of dimensions." - ) +################################################ +# +# Builders for image pyramids +# +################################################ + +ChunksLike = tuple[int, ...] | Literal["auto"] +ShardsLike = tuple[int, ...] | Literal["auto"] + - if chunks is not None: - chunks = [min(c, s) for c, s in zip(chunks, ref_shape, strict=True)] +def shapes_from_scaling_factors( + base_shape: tuple[int, ...], + scaling_factors: tuple[float, ...], + num_levels: int, +) -> list[tuple[int, ...]]: + """Compute the shapes of each level in the pyramid from scaling factors. - if len(ref_shape) != len(scaling_factors): - raise NgioValueError( - "The shape and scaling factor must have the same number of dimensions." + Args: + base_shape (tuple[int, ...]): The shape of the base level. + scaling_factors (tuple[float, ...]): The scaling factors between levels. + num_levels (int): The number of levels in the pyramid. + + Returns: + list[tuple[int, ...]]: The shapes of each level in the pyramid. 
+ """ + shapes = [] + current_shape = base_shape + for _ in range(num_levels): + shapes.append(current_shape) + current_shape = tuple( + max(1, int(s / f)) + for s, f in zip(current_shape, scaling_factors, strict=True) ) + return shapes + + +def _check_order(shapes: Sequence[tuple[int, ...]]): + """Check if the shapes are in decreasing order.""" + num_pixels = [np.prod(shape) for shape in shapes] + for i in range(1, len(num_pixels)): + if num_pixels[i] >= num_pixels[i - 1]: + raise NgioValueError("Shapes are not in decreasing order.") - # Ensure scaling factors are int if possible - # To reduce the risk of floating point issues - scaling_factors = [_maybe_int(s) for s in scaling_factors] - root_group = open_group_wrapper(store, mode=mode) +class PyramidLevel(BaseModel): + path: str + shape: tuple[int, ...] + scale: tuple[float, ...] + chunks: ChunksLike = "auto" + shards: ShardsLike | None = None - for path in paths: - if any(s < 1 for s in ref_shape): + @model_validator(mode="after") + def _model_validation(self) -> "PyramidLevel": + # Same length as shape + if len(self.scale) != len(self.shape): raise NgioValueError( - "Level shape must be at least 1 on all dimensions. " - f"Calculated shape: {ref_shape} at level {path}." 
+ "Scale must have the same length as shape " + f"({len(self.shape)}), got {len(self.scale)}" ) - new_arr = root_group.zeros( - name=path, - shape=ref_shape, - dtype=dtype, + if any(isinstance(s, float) and s < 0 for s in self.scale): + raise NgioValueError("Scale values must be positive.") + + if isinstance(self.chunks, tuple): + if len(self.chunks) != len(self.shape): + raise NgioValueError( + "Chunks must have the same length as shape " + f"({len(self.shape)}), got {len(self.chunks)}" + ) + normalized_chunks = [] + for dim_size, chunk_size in zip(self.shape, self.chunks, strict=True): + normalized_chunks.append(min(dim_size, chunk_size)) + self.chunks = tuple(normalized_chunks) + + if isinstance(self.shards, tuple): + if len(self.shards) != len(self.shape): + raise NgioValueError( + "Shards must have the same length as shape " + f"({len(self.shape)}), got {len(self.shards)}" + ) + normalized_shards = [] + for dim_size, shard_size in zip(self.shape, self.shards, strict=True): + normalized_shards.append(min(dim_size, shard_size)) + self.shards = tuple(normalized_shards) + return self + + +class ImagePyramidBuilder(BaseModel): + levels: list[PyramidLevel] + axes: tuple[str, ...] 
+ data_type: str = "uint16" + dimension_separator: Literal[".", "/"] = "/" + compressors: Any = "auto" + zarr_format: Literal[2, 3] = 2 + other_array_kwargs: Mapping[str, Any] = {} + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @classmethod + def from_scaling_factors( + cls, + levels_paths: tuple[str, ...], + scaling_factors: tuple[float, ...], + base_shape: tuple[int, ...], + base_scale: tuple[float, ...], + axes: tuple[str, ...], + chunks: ChunksLike = "auto", + shards: ShardsLike | None = None, + data_type: str = "uint16", + dimension_separator: Literal[".", "/"] = "/", + compressors: Any = "auto", + zarr_format: Literal[2, 3] = 2, + other_array_kwargs: Mapping[str, Any] | None = None, + ) -> "ImagePyramidBuilder": + shapes = shapes_from_scaling_factors( + base_shape=base_shape, + scaling_factors=scaling_factors, + num_levels=len(levels_paths), + ) + return cls.from_shapes( + shapes=shapes, + base_scale=base_scale, + axes=axes, + levels_paths=levels_paths, chunks=chunks, + shards=shards, + data_type=data_type, dimension_separator=dimension_separator, - overwrite=True, - compressor=compressor, + compressors=compressors, + zarr_format=zarr_format, + other_array_kwargs=other_array_kwargs, ) - ref_shape = [ - math.floor(s / sc) for s, sc in zip(ref_shape, scaling_factors, strict=True) - ] - chunks = tuple( - min(c, s) for c, s in zip(new_arr.chunks, ref_shape, strict=True) + @classmethod + def from_shapes( + cls, + shapes: Sequence[tuple[int, ...]], + base_scale: tuple[float, ...], + axes: tuple[str, ...], + levels_paths: Sequence[str] | None = None, + chunks: ChunksLike = "auto", + shards: ShardsLike | None = None, + data_type: str = "uint16", + dimension_separator: Literal[".", "/"] = "/", + compressors: Any = "auto", + zarr_format: Literal[2, 3] = 2, + other_array_kwargs: Mapping[str, Any] | None = None, + ) -> "ImagePyramidBuilder": + levels = [] + if levels_paths is None: + levels_paths = tuple(str(i) for i in range(len(shapes))) + 
_check_order(shapes) + scale_ = base_scale + for i, (path, shape) in enumerate(zip(levels_paths, shapes, strict=True)): + levels.append( + PyramidLevel( + path=path, + shape=shape, + scale=scale_, + chunks=chunks, + shards=shards, + ) + ) + if i + 1 < len(shapes): + # This only works for downsampling pyramids + # The _check_order function ensures that + # shapes are decreasing + next_shape = shapes[i + 1] + scaling_factor = tuple( + s1 / s2 + for s1, s2 in zip( + shape, + next_shape, + strict=True, + ) + ) + scale_ = tuple( + s * f for s, f in zip(scale_, scaling_factor, strict=True) + ) + other_array_kwargs = other_array_kwargs or {} + return cls( + levels=levels, + axes=axes, + data_type=data_type, + dimension_separator=dimension_separator, + compressors=compressors, + zarr_format=zarr_format, + other_array_kwargs=other_array_kwargs, ) - return None + def to_zarr(self, group: zarr.Group) -> None: + """Save the pyramid specification to a Zarr group. + + Args: + group (zarr.Group): The Zarr group to save the pyramid specification to. + """ + array_static_kwargs = { + "dtype": self.data_type, + "overwrite": True, + "compressors": self.compressors, + **self.other_array_kwargs, + } + + if self.zarr_format == 2: + array_static_kwargs["chunk_key_encoding"] = { + "name": "v2", + "separator": self.dimension_separator, + } + else: + array_static_kwargs["chunk_key_encoding"] = { + "name": "default", + "separator": self.dimension_separator, + } + array_static_kwargs["dimension_names"] = self.axes + for p_level in self.levels: + group.create_array( + name=p_level.path, + shape=tuple(p_level.shape), + chunks=p_level.chunks, + shards=p_level.shards, + **array_static_kwargs, + ) diff --git a/src/ngio/common/_roi.py b/src/ngio/common/_roi.py index bc974907..18ed61c4 100644 --- a/src/ngio/common/_roi.py +++ b/src/ngio/common/_roi.py @@ -4,17 +4,16 @@ the ImageLikeHandler. 
""" -from typing import TypeVar -from warnings import warn +from collections.abc import Callable +from typing import Literal, Self -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, field_validator -from ngio.common._dimensions import Dimensions -from ngio.ome_zarr_meta.ngio_specs import DefaultSpaceUnit, PixelSize, SpaceUnits +from ngio.ome_zarr_meta import PixelSize from ngio.utils import NgioValueError -def _world_to_raster(value: float, pixel_size: float, eps: float = 1e-6) -> float: +def world_to_pixel(value: float, pixel_size: float, eps: float = 1e-6) -> float: raster_value = value / pixel_size # If the value is very close to an integer, round it @@ -26,362 +25,291 @@ def _world_to_raster(value: float, pixel_size: float, eps: float = 1e-6) -> floa return raster_value -def _to_raster(value: float, length: float, pixel_size: float) -> tuple[float, float]: - """Convert to raster coordinates.""" - raster_value = _world_to_raster(value, pixel_size) - raster_length = _world_to_raster(length, pixel_size) - return raster_value, raster_length +def pixel_to_world(value: float, pixel_size: float) -> float: + return value * pixel_size -def _to_slice(start: float | None, length: float | None) -> slice: - if length is not None: - assert start is not None - end = start + length - else: - end = None - return slice(start, end) +def _join_roi_names(name1: str | None, name2: str | None) -> str | None: + if name1 is not None and name2 is not None: + if name1 == name2: + return name1 + return f"{name1}:{name2}" + return name1 or name2 + + +def _join_roi_labels(label1: int | None, label2: int | None) -> int | None: + if label1 is not None and label2 is not None: + if label1 == label2: + return label1 + raise NgioValueError("Cannot join ROIs with different labels") + return label1 or label2 + + +class RoiSlice(BaseModel): + axis_name: str + start: float | None = Field(default=None) + length: float | None = Field(default=None, ge=0) + + 
model_config = ConfigDict(extra="forbid") + + @classmethod + def _from_slice( + cls, + axis_name: str, + selection: slice, + ) -> "RoiSlice": + start = selection.start + length = ( + None + if selection.stop is None or selection.start is None + else selection.stop - selection.start + ) + return cls(axis_name=axis_name, start=start, length=length) + + @classmethod + def from_value( + cls, + axis_name: str, + value: float | tuple[float | None, float | None] | slice, + ) -> "RoiSlice": + if isinstance(value, slice): + return cls._from_slice(axis_name=axis_name, selection=value) + elif isinstance(value, tuple): + return cls(axis_name=axis_name, start=value[0], length=value[1]) + elif isinstance(value, int | float): + return cls(axis_name=axis_name, start=value, length=1) + else: + raise TypeError(f"Unsupported type for slice value: {type(value)}") + def __repr__(self) -> str: + return f"{self.axis_name}: {self.start}->{self.end}" -def _raster_to_world(value: int | float, pixel_size: float) -> float: - """Convert to world coordinates.""" - return value * pixel_size + @property + def end(self) -> float | None: + if self.start is None or self.length is None: + return None + return self.start + self.length + def to_slice(self) -> slice: + return slice(self.start, self.end) -T = TypeVar("T", int, float) + def _is_compatible(self, other: "RoiSlice", msg: str) -> None: + if self.axis_name != other.axis_name: + raise NgioValueError( + f"{msg}: Cannot operate on RoiSlices with different axis names" + ) + def union(self, other: "RoiSlice") -> "RoiSlice": + self._is_compatible(other, "RoiSlice union failed") + start = min(self.start or 0, other.start or 0) + end = max(self.end or float("inf"), other.end or float("inf")) + length = end - start if end > start else 0 + if length == float("inf"): + length = None + return RoiSlice(axis_name=self.axis_name, start=start, length=length) + + def intersection(self, other: "RoiSlice") -> "RoiSlice | None": + self._is_compatible(other, 
"RoiSlice intersection failed") + start = max(self.start or 0, other.start or 0) + end = min(self.end or float("inf"), other.end or float("inf")) + if end <= start: + # No intersection + return None + length = end - start + if length == float("inf"): + length = None + return RoiSlice(axis_name=self.axis_name, start=start, length=length) + + def to_world(self, pixel_size: float) -> "RoiSlice": + start = ( + pixel_to_world(self.start, pixel_size) if self.start is not None else None + ) + length = ( + pixel_to_world(self.length, pixel_size) if self.length is not None else None + ) + return RoiSlice(axis_name=self.axis_name, start=start, length=length) -class GenericRoi(BaseModel): - """A generic Region of Interest (ROI) model.""" + def to_pixel(self, pixel_size: float) -> "RoiSlice": + start = ( + world_to_pixel(self.start, pixel_size) if self.start is not None else None + ) + length = ( + world_to_pixel(self.length, pixel_size) if self.length is not None else None + ) + return RoiSlice(axis_name=self.axis_name, start=start, length=length) - name: str | None = None - x: float - y: float - z: float | None = None - t: float | None = None - x_length: float - y_length: float - z_length: float | None = None - t_length: float | None = None - label: int | None = None - unit: SpaceUnits | str | None = None + def zoom(self, zoom_factor: float = 1.0) -> "RoiSlice": + if zoom_factor <= 0: + raise NgioValueError("Zoom factor must be greater than 0") + zoom_factor -= 1.0 + if self.length is None: + return self - model_config = ConfigDict(extra="allow") + diff_length = self.length * zoom_factor + length = self.length + diff_length + start = max((self.start or 0) - (diff_length / 2), 0) + return RoiSlice(axis_name=self.axis_name, start=start, length=length) - def intersection(self, other: "GenericRoi") -> "GenericRoi | None": - """Calculate the intersection of this ROI with another ROI.""" - return roi_intersection(self, other) - def _nice_str(self) -> str: - if self.t is not None: 
- t_start = self.t - else: - t_start = None - if self.t_length is not None and t_start is not None: - t_end = t_start + self.t_length - else: - t_end = None +class Roi(BaseModel): + name: str | None + slices: list[RoiSlice] = Field(min_length=2) + label: int | None = Field(default=None, ge=0) + space: Literal["world", "pixel"] = "world" - t_str = f"t={t_start}->{t_end}" + model_config = ConfigDict(extra="allow") - if self.z is not None: - z_start = self.z - else: - z_start = None - if self.z_length is not None and z_start is not None: - z_end = z_start + self.z_length + @field_validator("slices") + @classmethod + def validate_no_duplicate_axes(cls, v: list[RoiSlice]) -> list[RoiSlice]: + axis_names = [s.axis_name for s in v] + if len(axis_names) != len(set(axis_names)): + raise NgioValueError("Roi slices must have unique axis names") + return v + + def _nice_repr__(self) -> str: + slices_repr = ", ".join(repr(s) for s in self.slices) + if self.label is None: + label_str = "" else: - z_end = None - z_str = f"z={z_start}->{z_end}" - - y_str = f"y={self.y}->{self.y + self.y_length}" - x_str = f"x={self.x}->{self.x + self.x_length}" - - if self.label is not None: label_str = f", label={self.label}" + + if self.name is None: + name_str = "" else: - label_str = "" - cls_name = self.__class__.__name__ - return f"{cls_name}({t_str}, {z_str}, {y_str}, {x_str}{label_str})" + name_str = f"name={self.name}, " + return f"Roi({name_str}{slices_repr}{label_str}, space={self.space})" + + @classmethod + def from_values( + cls, + slices: dict[str, float | tuple[float | None, float | None] | slice], + name: str | None, + label: int | None = None, + space: Literal["world", "pixel"] = "world", + **kwargs, + ) -> Self: + _slices = [] + for axis, _slice in slices.items(): + _slices.append(RoiSlice.from_value(axis_name=axis, value=_slice)) + return cls.model_construct( + name=name, slices=_slices, label=label, space=space, **kwargs + ) + + def get(self, axis_name: str) -> RoiSlice | None: 
+ for roi_slice in self.slices: + if roi_slice.axis_name == axis_name: + return roi_slice + return None def get_name(self) -> str: - """Get the name of the ROI, or a default if not set.""" if self.name is not None: return self.name - return self._nice_str() + if self.label is not None: + return str(self.label) + return self._nice_repr__() + + @staticmethod + def _apply_sym_ops( + self_slices: list[RoiSlice], + other_slices: list[RoiSlice], + op: Callable[[RoiSlice, RoiSlice], RoiSlice | None], + ) -> list[RoiSlice] | None: + self_axis_dict = {s.axis_name: s for s in self_slices} + other_axis_dict = {s.axis_name: s for s in other_slices} + common_axis_names = self_axis_dict.keys() | other_axis_dict.keys() + new_slices = [] + for axis_name in common_axis_names: + slice_a = self_axis_dict.get(axis_name) + slice_b = other_axis_dict.get(axis_name) + if slice_a is not None and slice_b is not None: + result = op(slice_a, slice_b) + if result is None: + return None + new_slices.append(result) + elif slice_a is not None: + new_slices.append(slice_a) + elif slice_b is not None: + new_slices.append(slice_b) + return new_slices + + def intersection(self, other: Self) -> Self | None: + if self.space != other.space: + raise NgioValueError( + "Roi intersection failed: One ROI is in pixel space and the " + "other in world space" + ) - def __repr__(self) -> str: - return self._nice_str() - - def __str__(self) -> str: - return self._nice_str() - - def to_slicing_dict(self, pixel_size: PixelSize) -> dict[str, slice]: - raise NotImplementedError - - -def _1d_intersection( - a: T | None, a_length: T | None, b: T | None, b_length: T | None -) -> tuple[T | None, T | None]: - """Calculate the intersection of two 1D intervals.""" - if a is None: - if b is not None and b_length is not None: - return b, b_length - return None, None - if b is None: - if a is not None and a_length is not None: - return a, a_length - return None, None - - assert ( - a is not None - and a_length is not None - 
and b is not None - and b_length is not None - ) - start = max(a, b) - end = min(a + a_length, b + b_length) - length = end - start - - if length <= 0: - return None, None - - return start, length - - -def roi_intersection(ref_roi: GenericRoi, other_roi: GenericRoi) -> GenericRoi | None: - """Calculate the intersection of two ROIs.""" - if ( - ref_roi.unit is not None - and other_roi.unit is not None - and ref_roi.unit != other_roi.unit - ): - raise NgioValueError( - "Cannot calculate intersection of ROIs with different units." + out_slices = self._apply_sym_ops( + self.slices, other.slices, op=lambda a, b: a.intersection(b) ) + if out_slices is None: + return None - x, x_length = _1d_intersection( - ref_roi.x, ref_roi.x_length, other_roi.x, other_roi.x_length - ) - if x is None and x_length is None: - # No intersection - return None - assert x is not None and x_length is not None - - y, y_length = _1d_intersection( - ref_roi.y, ref_roi.y_length, other_roi.y, other_roi.y_length - ) - if y is None and y_length is None: - # No intersection - return None - assert y is not None and y_length is not None - - z, z_length = _1d_intersection( - ref_roi.z, ref_roi.z_length, other_roi.z, other_roi.z_length - ) - t, t_length = _1d_intersection( - ref_roi.t, ref_roi.t_length, other_roi.t, other_roi.t_length - ) - - if (z_length is not None and z_length <= 0) or ( - t_length is not None and t_length <= 0 - ): - # No intersection - return None + name = _join_roi_names(self.name, other.name) + label = _join_roi_labels(self.label, other.label) + return self.model_copy( + update={"name": name, "slices": out_slices, "label": label} + ) - # Find label - if ref_roi.label is not None and other_roi.label is not None: - if ref_roi.label != other_roi.label: + def union(self, other: Self) -> Self: + if self.space != other.space: raise NgioValueError( - "Cannot calculate intersection of ROIs with different labels." 
+ "Roi union failed: One ROI is in pixel space and the " + "other in world space" ) - label = ref_roi.label or other_roi.label - - if ref_roi.name is not None and other_roi.name is not None: - name = f"{ref_roi.name}:{other_roi.name}" - else: - name = ref_roi.name or other_roi.name - - cls_ref = ref_roi.__class__ - return cls_ref( - name=name, - x=x, - y=y, - z=z, - t=t, - x_length=x_length, - y_length=y_length, - z_length=z_length, - t_length=t_length, - unit=ref_roi.unit, - label=label, - ) - - -class Roi(GenericRoi): - x: float = 0.0 - y: float = 0.0 - unit: SpaceUnits | str | None = DefaultSpaceUnit - - def to_roi_pixels(self, pixel_size: PixelSize) -> "RoiPixels": - """Convert to raster coordinates.""" - x, x_length = _to_raster(self.x, self.x_length, pixel_size.x) - y, y_length = _to_raster(self.y, self.y_length, pixel_size.y) - - if self.z is None: - z, z_length = None, None - else: - assert self.z_length is not None - z, z_length = _to_raster(self.z, self.z_length, pixel_size.z) - if self.t is None: - t, t_length = None, None - else: - assert self.t_length is not None - t, t_length = _to_raster(self.t, self.t_length, pixel_size.t) - extra_dict = self.model_extra if self.model_extra else {} - - return RoiPixels( - name=self.name, - x=x, - y=y, - z=z, - t=t, - x_length=x_length, - y_length=y_length, - z_length=z_length, - t_length=t_length, - label=self.label, - unit=self.unit, - **extra_dict, + out_slices = self._apply_sym_ops( + self.slices, other.slices, op=lambda a, b: a.union(b) ) + if out_slices is None: + raise NgioValueError("Roi union failed: could not compute union") - def to_pixel_roi( - self, pixel_size: PixelSize, dimensions: Dimensions | None = None - ) -> "RoiPixels": - """Convert to raster coordinates.""" - warn( - "to_pixel_roi is deprecated and will be removed in a future release. 
" - "Use to_roi_pixels instead.", - DeprecationWarning, - stacklevel=2, + name = _join_roi_names(self.name, other.name) + label = _join_roi_labels(self.label, other.label) + return self.model_copy( + update={"name": name, "slices": out_slices, "label": label} ) - return self.to_roi_pixels(pixel_size=pixel_size) - - def zoom(self, zoom_factor: float = 1) -> "Roi": - """Zoom the ROI by a factor. - - Args: - zoom_factor: The zoom factor. If the zoom factor - is less than 1 the ROI will be zoomed in. - If the zoom factor is greater than 1 the ROI will be zoomed out. - If the zoom factor is 1 the ROI will not be changed. - """ - return zoom_roi(self, zoom_factor) - - def to_slicing_dict(self, pixel_size: PixelSize) -> dict[str, slice]: - """Convert to a slicing dictionary.""" - roi_pixels = self.to_roi_pixels(pixel_size) - return roi_pixels.to_slicing_dict(pixel_size) - - -class RoiPixels(GenericRoi): - """Region of interest (ROI) in pixel coordinates.""" - - x: float = 0 - y: float = 0 - unit: SpaceUnits | str | None = None - - def to_roi(self, pixel_size: PixelSize) -> "Roi": - """Convert to raster coordinates.""" - x = _raster_to_world(self.x, pixel_size.x) - x_length = _raster_to_world(self.x_length, pixel_size.x) - y = _raster_to_world(self.y, pixel_size.y) - y_length = _raster_to_world(self.y_length, pixel_size.y) + def zoom( + self, zoom_factor: float = 1.0, axes: tuple[str, ...] 
= ("x", "y") + ) -> Self: + new_slices = [] + for roi_slice in self.slices: + if roi_slice.axis_name in axes: + new_slices.append(roi_slice.zoom(zoom_factor=zoom_factor)) + else: + new_slices.append(roi_slice) + return self.model_copy(update={"slices": new_slices}) + + def to_world(self, pixel_size: PixelSize | None = None) -> Self: + if self.space == "world": + return self.model_copy() + if pixel_size is None: + raise NgioValueError( + "Pixel sizes must be provided to convert ROI from pixel to world" + ) + new_slices = [] + for roi_slice in self.slices: + pixel_size_ = pixel_size.get(roi_slice.axis_name, default=1.0) + new_slices.append(roi_slice.to_world(pixel_size=pixel_size_)) + return self.model_copy(update={"slices": new_slices, "space": "world"}) - if self.z is None: - z = None - else: - z = _raster_to_world(self.z, pixel_size.z) + def to_pixel(self, pixel_size: PixelSize | None = None) -> Self: + if self.space == "pixel": + return self.model_copy() - if self.z_length is None: - z_length = None - else: - z_length = _raster_to_world(self.z_length, pixel_size.z) - - if self.t is None: - t = None - else: - t = _raster_to_world(self.t, pixel_size.t) + if pixel_size is None: + raise NgioValueError( + "Pixel sizes must be provided to convert ROI from world to pixel" + ) - if self.t_length is None: - t_length = None - else: - t_length = _raster_to_world(self.t_length, pixel_size.t) - - extra_dict = self.model_extra if self.model_extra else {} - return Roi( - name=self.name, - x=x, - y=y, - z=z, - t=t, - x_length=x_length, - y_length=y_length, - z_length=z_length, - t_length=t_length, - label=self.label, - unit=self.unit, - **extra_dict, - ) + new_slices = [] + for roi_slice in self.slices: + pixel_size_ = pixel_size.get(roi_slice.axis_name, default=1.0) + new_slices.append(roi_slice.to_pixel(pixel_size=pixel_size_)) + return self.model_copy(update={"slices": new_slices, "space": "pixel"}) - def to_slicing_dict(self, pixel_size: PixelSize) -> dict[str, slice]: - 
"""Convert to a slicing dictionary.""" - x_slice = _to_slice(self.x, self.x_length) - y_slice = _to_slice(self.y, self.y_length) - z_slice = _to_slice(self.z, self.z_length) - t_slice = _to_slice(self.t, self.t_length) - return { - "x": x_slice, - "y": y_slice, - "z": z_slice, - "t": t_slice, - } - - -def zoom_roi(roi: Roi, zoom_factor: float = 1) -> Roi: - """Zoom the ROI by a factor. - - Args: - roi: The ROI to zoom. - zoom_factor: The zoom factor. If the zoom factor - is less than 1 the ROI will be zoomed in. - If the zoom factor is greater than 1 the ROI will be zoomed out. - If the zoom factor is 1 the ROI will not be changed. - """ - if zoom_factor <= 0: - raise NgioValueError("Zoom factor must be greater than 0.") - - # the zoom factor needs to be rescaled - # from the range [-1, inf) to [0, inf) - zoom_factor -= 1 - diff_x = roi.x_length * zoom_factor - diff_y = roi.y_length * zoom_factor - - new_x = max(roi.x - diff_x / 2, 0) - new_y = max(roi.y - diff_y / 2, 0) - - new_roi = Roi( - name=roi.name, - x=new_x, - y=new_y, - z=roi.z, - t=roi.t, - x_length=roi.x_length + diff_x, - y_length=roi.y_length + diff_y, - z_length=roi.z_length, - t_length=roi.t_length, - label=roi.label, - unit=roi.unit, - ) - return new_roi + def to_slicing_dict(self, pixel_size: PixelSize | None = None) -> dict[str, slice]: + roi = self.to_pixel(pixel_size=pixel_size) + return {roi_slice.axis_name: roi_slice.to_slice() for roi_slice in roi.slices} diff --git a/src/ngio/experimental/iterators/_feature.py b/src/ngio/experimental/iterators/_feature.py index c9e6a79e..53d891fd 100644 --- a/src/ngio/experimental/iterators/_feature.py +++ b/src/ngio/experimental/iterators/_feature.py @@ -4,7 +4,7 @@ import dask.array as da import numpy as np -from ngio.common import Roi, RoiPixels +from ngio.common import Roi from ngio.experimental.iterators._abstract_iterator import AbstractIteratorBuilder from ngio.images import Image, Label from ngio.images._image import ( @@ -18,8 +18,8 @@ 
TransformProtocol, ) -NumpyPipeType: TypeAlias = tuple[np.ndarray, np.ndarray, Roi | RoiPixels] -DaskPipeType: TypeAlias = tuple[da.Array, da.Array, Roi | RoiPixels] +NumpyPipeType: TypeAlias = tuple[np.ndarray, np.ndarray, Roi] +DaskPipeType: TypeAlias = tuple[da.Array, da.Array, Roi] class NumpyFeatureGetter(DataGetter[NumpyPipeType]): diff --git a/src/ngio/experimental/iterators/_rois_utils.py b/src/ngio/experimental/iterators/_rois_utils.py index eefdeefb..ef463810 100644 --- a/src/ngio/experimental/iterators/_rois_utils.py +++ b/src/ngio/experimental/iterators/_rois_utils.py @@ -1,4 +1,4 @@ -from ngio import Roi, RoiPixels +from ngio import Roi from ngio.images._abstract_image import AbstractImage @@ -48,18 +48,17 @@ def grid( for z in range(0, z_dim, stride_z): for y in range(0, y_dim, stride_y): for x in range(0, x_dim, stride_x): - roi = RoiPixels( + roi = Roi.from_values( name=base_name, - x=x, - y=y, - z=z, - t=t, - x_length=size_x, - y_length=size_y, - z_length=size_z, - t_length=size_t, + slices={ + "x": (x, size_x), + "y": (y, size_y), + "z": (z, size_z), + "t": (t, size_t), + }, + space="pixel", ) - new_rois.append(roi.to_roi(pixel_size=ref_image.pixel_size)) + new_rois.append(roi.to_world(pixel_size=ref_image.pixel_size)) return rois_product(rois, new_rois) diff --git a/src/ngio/hcs/_plate.py b/src/ngio/hcs/_plate.py index 643cf6a1..a05a88e7 100644 --- a/src/ngio/hcs/_plate.py +++ b/src/ngio/hcs/_plate.py @@ -15,15 +15,16 @@ list_image_tables_async, ) from ngio.ome_zarr_meta import ( + DefaultNgffVersion, ImageInWellPath, NgffVersions, NgioPlateMeta, NgioWellMeta, - find_plate_meta_handler, - find_well_meta_handler, - get_plate_meta_handler, - get_well_meta_handler, + PlateMetaHandler, + WellMetaHandler, path_in_well_validation, + update_ngio_plate_meta, + update_ngio_well_meta, ) from ngio.tables import ( ConditionTable, @@ -40,17 +41,23 @@ ) from ngio.utils import ( AccessModeLiteral, + NgioCache, + NgioError, NgioValueError, StoreOrGroup, 
ZarrGroupHandler, ) -def _default_table_container(handler: ZarrGroupHandler) -> TablesContainer | None: +def _try_get_table_container( + handler: ZarrGroupHandler, create_mode: bool = True +) -> TablesContainer | None: """Return a default table container.""" - success, table_handler = handler.safe_derive_handler("tables") - if success and isinstance(table_handler, ZarrGroupHandler): + try: + table_handler = handler.get_handler("tables", create_mode=create_mode) return TablesContainer(table_handler) + except NgioError: + return None # Mock lock class that does nothing @@ -76,7 +83,7 @@ def __init__(self, group_handler: ZarrGroupHandler) -> None: group_handler: The Zarr group handler that contains the Well. """ self._group_handler = group_handler - self._meta_handler = find_well_meta_handler(group_handler) + self._meta_handler = WellMetaHandler(group_handler) def __repr__(self) -> str: """Return a string representation of the well.""" @@ -90,7 +97,7 @@ def meta_handler(self): @property def meta(self): """Return the metadata.""" - return self._meta_handler.meta + return self._meta_handler.get_meta() @property def acquisition_ids(self) -> list[int]: @@ -136,7 +143,7 @@ def get_image(self, image_path: str) -> OmeZarrContainer: Returns: OmeZarrContainer: The image. """ - handler = self._group_handler.derive_handler(image_path) + handler = self._group_handler.get_handler(image_path) return OmeZarrContainer(handler) def _add_image( @@ -158,7 +165,7 @@ def _add_image( meta = self.meta.add_image( path=image_path, acquisition=acquisition_id, strict=strict ) - self.meta_handler.write_meta(meta) + self.meta_handler.update_meta(meta) self.meta_handler._group_handler.clean_cache() return self._group_handler.get_group(image_path, create_mode=True) @@ -237,8 +244,14 @@ def __init__( table_container: The tables container that contains plate level tables. 
""" self._group_handler = group_handler - self._meta_handler = find_plate_meta_handler(group_handler) + self._meta_handler = PlateMetaHandler(group_handler) self._tables_container = table_container + self._wells_cache: NgioCache[OmeZarrWell] = NgioCache( + use_cache=self._group_handler.use_cache + ) + self._images_cache: NgioCache[OmeZarrContainer] = NgioCache( + use_cache=self._group_handler.use_cache + ) def __repr__(self) -> str: """Return a string representation of the plate.""" @@ -252,7 +265,7 @@ def meta_handler(self): @property def meta(self): """Return the metadata.""" - return self._meta_handler.meta + return self._meta_handler.get_meta() @property def columns(self) -> list[str]: @@ -356,6 +369,24 @@ def get_image_acquisition_id( well = self.get_well(row=row, column=column) return well.get_image_acquisition_id(image_path=image_path) + def _get_well(self, well_path: str) -> OmeZarrWell: + """Get a well from the plate by its path. + + Args: + well_path (str): The path of the well. + + Returns: + OmeZarrWell: The well. + + """ + cached_well = self._wells_cache.get(well_path) + if cached_well is not None: + return cached_well + + group_handler = self._group_handler.get_handler(well_path) + self._wells_cache.set(well_path, OmeZarrWell(group_handler)) + return OmeZarrWell(group_handler) + def get_well(self, row: str, column: int | str) -> OmeZarrWell: """Get a well from the plate. @@ -367,8 +398,7 @@ def get_well(self, row: str, column: int | str) -> OmeZarrWell: OmeZarrWell: The well. """ well_path = self._well_path(row=row, column=column) - group_handler = self._group_handler.derive_handler(well_path) - return OmeZarrWell(group_handler) + return self._get_well(well_path=well_path) async def get_wells_async(self) -> dict[str, OmeZarrWell]: """Get all wells in the plate asynchronously. 
@@ -380,26 +410,17 @@ async def get_wells_async(self) -> dict[str, OmeZarrWell]: dict[str, OmeZarrWell]: A dictionary of wells, where the key is the well path and the value is the well object. """ - wells = self._group_handler.get_from_cache("wells") - if wells is not None: - assert isinstance(wells, dict) - return wells - - def process_well(well_path): - group_handler = self._group_handler.derive_handler(well_path) - well = OmeZarrWell(group_handler) - return well_path, well - wells, tasks = {}, [] for well_path in self.wells_paths(): - task = asyncio.to_thread(process_well, well_path) + task = asyncio.to_thread( + lambda well_path: (well_path, self._get_well(well_path)), well_path + ) tasks.append(task) results = await asyncio.gather(*tasks) for well_path, well in results: wells[well_path] = well - self._group_handler.add_to_cache("wells", wells) return wells def get_wells(self) -> dict[str, OmeZarrWell]: @@ -409,24 +430,25 @@ def get_wells(self) -> dict[str, OmeZarrWell]: dict[str, OmeZarrWell]: A dictionary of wells, where the key is the well path and the value is the well object. """ - wells = self._group_handler.get_from_cache("wells") - if wells is not None: - assert isinstance(wells, dict) - return wells - - def process_well(well_path): - group_handler = self._group_handler.derive_handler(well_path) - well = OmeZarrWell(group_handler) - return well_path, well - wells = {} for well_path in self.wells_paths(): - _, well = process_well(well_path) - wells[well_path] = well - - self._group_handler.add_to_cache("wells", wells) + wells[well_path] = self._get_well(well_path) return wells + def _get_image(self, image_path: str) -> OmeZarrContainer: + """Get an image from the plate by its path. + + Args: + image_path (str): The path of the image. 
+ """ + cached_image = self._images_cache.get(image_path) + if cached_image is not None: + return cached_image + img_group_handler = self._group_handler.get_handler(image_path) + image = OmeZarrContainer(img_group_handler) + self._images_cache.set(image_path, image) + return image + async def get_images_async( self, acquisition: int | None = None ) -> dict[str, OmeZarrContainer]: @@ -442,30 +464,19 @@ async def get_images_async( dict[str, OmeZarrContainer]: A dictionary of images, where the key is the image path and the value is the image object. """ - images = self._group_handler.get_from_cache("images") - if images is not None: - assert isinstance(images, dict) - return images - paths = await self.images_paths_async(acquisition=acquisition) - def process_image(image_path): - """Process a single image and return the image path and image object.""" - img_group_handler = self._group_handler.derive_handler(image_path) - image = OmeZarrContainer(img_group_handler) - return image_path, image - images, tasks = {}, [] for image_path in paths: - task = asyncio.to_thread(process_image, image_path) + task = asyncio.to_thread( + lambda image_path: (image_path, self._get_image(image_path)), image_path + ) tasks.append(task) results = await asyncio.gather(*tasks) for image_path, image in results: images[image_path] = image - - self._group_handler.add_to_cache("images", images) return images def get_images(self, acquisition: int | None = None) -> dict[str, OmeZarrContainer]: @@ -474,24 +485,11 @@ def get_images(self, acquisition: int | None = None) -> dict[str, OmeZarrContain Args: acquisition: The acquisition id to filter the images. 
""" - images = self._group_handler.get_from_cache("images") - if images is not None: - assert isinstance(images, dict) - return images paths = self.images_paths(acquisition=acquisition) - - def process_image(image_path): - """Process a single image and return the image path and image object.""" - img_group_handler = self._group_handler.derive_handler(image_path) - image = OmeZarrContainer(img_group_handler) - return image_path, image - images = {} for image_path in paths: - _, image = process_image(image_path) - images[image_path] = image + images[image_path] = self._get_image(image_path) - self._group_handler.add_to_cache("images", images) return images def get_image( @@ -508,8 +506,7 @@ def get_image( OmeZarrContainer: The image. """ image_path = self._image_path(row=row, column=column, path=image_path) - group_handler = self._group_handler.derive_handler(image_path) - return OmeZarrContainer(group_handler) + return self._get_image(image_path) def get_image_store( self, row: str, column: int | str, image_path: str @@ -538,7 +535,7 @@ def get_well_images( for image_paths in self.well_images_paths( row=row, column=column, acquisition=acquisition ): - group_handler = self._group_handler.derive_handler(image_paths) + group_handler = self._group_handler.get_handler(image_paths) images[image_paths] = OmeZarrContainer(group_handler) return images @@ -567,11 +564,11 @@ def _add_image( meta = meta.add_acquisition( acquisition_id=acquisition_id, acquisition_name=acquisition_name ) - self.meta_handler.write_meta(meta) + self.meta_handler.update_meta(meta) self.meta_handler._group_handler.clean_cache() well_path = self.meta.get_well_path(row=row, column=column) - group_handler = self._group_handler.derive_handler(well_path) + group_handler = self._group_handler.get_handler(well_path) if atomic: well_lock = group_handler.lock @@ -586,18 +583,19 @@ def _add_image( well_meta = NgioWellMeta.default_init() version = self.meta.plate.version version = version if version is not None 
else "0.4" - meta_handler = get_well_meta_handler(group_handler, version=version) + update_ngio_well_meta(group_handler, well_meta) + meta_handler = WellMetaHandler(group_handler=group_handler) else: - meta_handler = find_well_meta_handler(group_handler) - well_meta = meta_handler.meta + meta_handler = WellMetaHandler(group_handler=group_handler) + well_meta = meta_handler.get_meta() - group_handler = self._group_handler.derive_handler(well_path) + group_handler = self._group_handler.get_handler(well_path) if image_path is not None: well_meta = well_meta.add_image( path=image_path, acquisition=acquisition_id, strict=False ) - meta_handler.write_meta(well_meta) + meta_handler.update_meta(well_meta) meta_handler._group_handler.clean_cache() if image_path is not None: @@ -674,7 +672,7 @@ def add_column( ) -> "OmeZarrPlate": """Add a column to an ome-zarr plate.""" meta, _ = self.meta.add_column(column) - self.meta_handler.write_meta(meta) + self.meta_handler.update_meta(meta) self.meta_handler._group_handler.clean_cache() return self @@ -684,7 +682,7 @@ def add_row( ) -> "OmeZarrPlate": """Add a row to an ome-zarr plate.""" meta, _ = self.meta.add_row(row) - self.meta_handler.write_meta(meta) + self.meta_handler.update_meta(meta) self.meta_handler._group_handler.clean_cache() return self @@ -704,7 +702,7 @@ def add_acquisition( meta = self.meta.add_acquisition( acquisition_id=acquisition_id, acquisition_name=acquisition_name ) - self.meta_handler.write_meta(meta) + self.meta_handler.update_meta(meta) self.meta_handler._group_handler.clean_cache() return self @@ -723,7 +721,7 @@ def _remove_well( with plate_lock: meta = self.meta meta = meta.remove_well(row, column) - self.meta_handler.write_meta(meta) + self.meta_handler.update_meta(meta) self.meta_handler._group_handler.clean_cache() def _remove_image( @@ -744,7 +742,7 @@ def _remove_image( with well_lock: well_meta = well.meta well_meta = well_meta.remove_image(path=image_path) - 
well.meta_handler.write_meta(well_meta) + well.meta_handler.update_meta(well_meta) well.meta_handler._group_handler.clean_cache() if len(well_meta.paths()) == 0: self._remove_well(row, column, atomic=atomic) @@ -781,41 +779,42 @@ def derive_plate( self, store: StoreOrGroup, plate_name: str | None = None, - version: NgffVersions = "0.4", + version: NgffVersions | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, keep_acquisitions: bool = False, cache: bool = False, overwrite: bool = False, - parallel_safe: bool = True, ) -> "OmeZarrPlate": """Derive a new OME-Zarr plate from an existing one. Args: store (StoreOrGroup): The Zarr store or group that stores the plate. plate_name (str | None): The name of the new plate. - version (NgffVersion): The version of the new plate. + version (NgffVersion | None): Deprecated. Please use 'ngff_version' instead. + ngff_version (NgffVersion): The NGFF version to use for the new plate. keep_acquisitions (bool): Whether to keep the acquisitions in the new plate. cache (bool): Whether to use a cache for the zarr group metadata. overwrite (bool): Whether to overwrite the existing plate. - parallel_safe (bool): Whether the group handler is parallel safe. 
""" return derive_ome_zarr_plate( ome_zarr_plate=self, store=store, plate_name=plate_name, + ngff_version=ngff_version, version=version, keep_acquisitions=keep_acquisitions, cache=cache, overwrite=overwrite, - parallel_safe=parallel_safe, ) - def _get_tables_container(self) -> TablesContainer | None: + def _get_tables_container(self, create_mode: bool = True) -> TablesContainer | None: """Return the tables container.""" - if self._tables_container is None: - _tables_container = _default_table_container(self._group_handler) - if _tables_container is None: - return None - self._tables_container = _tables_container + if self._tables_container is not None: + return self._tables_container + _tables_container = _try_get_table_container( + self._group_handler, create_mode=create_mode + ) + self._tables_container = _tables_container return self._tables_container @property @@ -830,17 +829,20 @@ def tables_container(self) -> TablesContainer: def list_tables(self, filter_types: TypedTable | str | None = None) -> list[str]: """List all tables in the image.""" + _tables_container = self._get_tables_container(create_mode=False) + if _tables_container is None: + return [] return self.tables_container.list(filter_types=filter_types) def list_roi_tables(self) -> list[str]: """List all ROI tables in the image.""" - masking_roi = self.tables_container.list( - filter_types="masking_roi_table", - ) roi = self.tables_container.list( filter_types="roi_table", ) - return masking_roi + roi + masking_roi = self.tables_container.list( + filter_types="masking_roi_table", + ) + return roi + masking_roi def get_roi_table(self, name: str) -> RoiTable: """Get a ROI table from the image. @@ -957,6 +959,25 @@ def add_table( name=name, table=table, backend=backend, overwrite=overwrite ) + def delete_table(self, name: str, missing_ok: bool = False) -> None: + """Delete a table from the group. + + Args: + name (str): The name of the table to delete. 
+ missing_ok (bool): If True, do not raise an error if the table does not + exist. + + """ + table_container = self._get_tables_container(create_mode=False) + if table_container is None and missing_ok: + return + if table_container is None: + raise NgioValueError( + f"No tables found in the image, cannot delete {name}. " + "Set missing_ok=True to ignore this error." + ) + table_container.delete(name=name, missing_ok=missing_ok) + def list_image_tables( self, acquisition: int | None = None, @@ -1143,7 +1164,6 @@ def open_ome_zarr_plate( store: StoreOrGroup, cache: bool = False, mode: AccessModeLiteral = "r+", - parallel_safe: bool = True, ) -> OmeZarrPlate: """Open an OME-Zarr plate. @@ -1152,27 +1172,20 @@ def open_ome_zarr_plate( cache (bool): Whether to use a cache for the zarr group metadata. mode (AccessModeLiteral): The access mode for the image. Defaults to "r+". - parallel_safe (bool): Whether the group handler is parallel safe. """ - group_handler = ZarrGroupHandler( - store=store, cache=cache, mode=mode, parallel_safe=parallel_safe - ) + group_handler = ZarrGroupHandler(store=store, cache=cache, mode=mode) return OmeZarrPlate(group_handler) def _create_empty_plate_from_meta( store: StoreOrGroup, meta: NgioPlateMeta, - version: NgffVersions = "0.4", overwrite: bool = False, ) -> ZarrGroupHandler: """Create an empty OME-Zarr plate from metadata.""" mode = "w" if overwrite else "w-" - group_handler = ZarrGroupHandler( - store=store, cache=True, mode=mode, parallel_safe=False - ) - meta_handler = get_plate_meta_handler(group_handler, version=version) - meta_handler.write_meta(meta) + group_handler = ZarrGroupHandler(store=store, cache=True, mode=mode) + update_ngio_plate_meta(group_handler, meta) return group_handler @@ -1180,20 +1193,38 @@ def create_empty_plate( store: StoreOrGroup, name: str, images: list[ImageInWellPath] | None = None, - version: NgffVersions = "0.4", + version: NgffVersions | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, 
cache: bool = False, overwrite: bool = False, - parallel_safe: bool = True, ) -> OmeZarrPlate: - """Initialize and create an empty OME-Zarr plate.""" + """Initialize and create an empty OME-Zarr plate. + + Args: + store (StoreOrGroup): The Zarr store or group that stores the plate. + name (str): The name of the plate. + images (list[ImageInWellPath] | None): A list of images to add to the plate. + If None, no images are added. Defaults to None. + version (NgffVersion | None): Deprecated. Please use 'ngff_version' instead. + ngff_version (NgffVersion): The NGFF version to use for the new plate. + cache (bool): Whether to use a cache for the zarr group metadata. + overwrite (bool): Whether to overwrite the existing plate. + """ + if version is not None: + warnings.warn( + "The 'version' argument is deprecated, and will be removed in ngio=0.3. " + "Please use 'ngff_version' instead.", + DeprecationWarning, + stacklevel=2, + ) + ngff_version = version plate_meta = NgioPlateMeta.default_init( name=name, - version=version, + ngff_version=ngff_version, ) group_handler = _create_empty_plate_from_meta( store=store, meta=plate_meta, - version=version, overwrite=overwrite, ) @@ -1211,7 +1242,6 @@ def create_empty_plate( store=store, cache=cache, mode="r+", - parallel_safe=parallel_safe, ) @@ -1219,11 +1249,11 @@ def derive_ome_zarr_plate( ome_zarr_plate: OmeZarrPlate, store: StoreOrGroup, plate_name: str | None = None, - version: NgffVersions = "0.4", + version: NgffVersions | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, keep_acquisitions: bool = False, cache: bool = False, overwrite: bool = False, - parallel_safe: bool = True, ) -> OmeZarrPlate: """Derive a new OME-Zarr plate from an existing one. @@ -1231,31 +1261,38 @@ def derive_ome_zarr_plate( ome_zarr_plate (OmeZarrPlate): The existing OME-Zarr plate. store (StoreOrGroup): The Zarr store or group that stores the plate. plate_name (str | None): The name of the new plate. 
- version (NgffVersion): The version of the new plate. + version (NgffVersion | None): Deprecated. Please use 'ngff_version' instead. + ngff_version (NgffVersion): The NGFF version to use for the new plate. keep_acquisitions (bool): Whether to keep the acquisitions in the new plate. cache (bool): Whether to use a cache for the zarr group metadata. overwrite (bool): Whether to overwrite the existing plate. - parallel_safe (bool): Whether the group handler is parallel safe. """ + if version is not None: + warnings.warn( + "The 'version' argument is deprecated, and will be removed in ngio=0.3. " + "Please use 'ngff_version' instead.", + DeprecationWarning, + stacklevel=2, + ) + ngff_version = version + if plate_name is None: plate_name = ome_zarr_plate.meta.plate.name new_meta = ome_zarr_plate.meta.derive( name=plate_name, - version=version, + ngff_version=ngff_version, keep_acquisitions=keep_acquisitions, ) _ = _create_empty_plate_from_meta( store=store, meta=new_meta, overwrite=overwrite, - version=version, ) return open_ome_zarr_plate( store=store, cache=cache, mode="r+", - parallel_safe=parallel_safe, ) @@ -1263,7 +1300,6 @@ def open_ome_zarr_well( store: StoreOrGroup, cache: bool = False, mode: AccessModeLiteral = "r+", - parallel_safe: bool = True, ) -> OmeZarrWell: """Open an OME-Zarr well. @@ -1271,40 +1307,48 @@ def open_ome_zarr_well( store (StoreOrGroup): The Zarr store or group that stores the plate. cache (bool): Whether to use a cache for the zarr group metadata. mode (AccessModeLiteral): The access mode for the image. Defaults to "r+". - parallel_safe (bool): Whether the group handler is parallel safe. 
""" group_handler = ZarrGroupHandler( - store=store, cache=cache, mode=mode, parallel_safe=parallel_safe + store=store, + cache=cache, + mode=mode, ) return OmeZarrWell(group_handler) def create_empty_well( store: StoreOrGroup, - version: NgffVersions = "0.4", + version: NgffVersions | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, cache: bool = False, overwrite: bool = False, - parallel_safe: bool = True, ) -> OmeZarrWell: """Create an empty OME-Zarr well. Args: store (StoreOrGroup): The Zarr store or group that stores the well. - version (NgffVersion): The version of the new well. + version (NgffVersion | None): Deprecated. Please use 'ngff_version' instead. + ngff_version (NgffVersion): The version of the new well. cache (bool): Whether to use a cache for the zarr group metadata. overwrite (bool): Whether to overwrite the existing well. - parallel_safe (bool): Whether the group handler is parallel safe. """ + if version is not None: + warnings.warn( + "The 'version' argument is deprecated, and will be removed in ngio=0.3. 
" + "Please use 'ngff_version' instead.", + DeprecationWarning, + stacklevel=2, + ) + ngff_version = version group_handler = ZarrGroupHandler( - store=store, cache=True, mode="w" if overwrite else "w-", parallel_safe=False + store=store, cache=True, mode="w" if overwrite else "w-" + ) + update_ngio_well_meta( + group_handler, NgioWellMeta.default_init(ngff_version=ngff_version) ) - meta_handler = get_well_meta_handler(group_handler, version=version) - meta = NgioWellMeta.default_init() - meta_handler.write_meta(meta) return open_ome_zarr_well( store=store, cache=cache, mode="r+", - parallel_safe=parallel_safe, ) diff --git a/src/ngio/images/_abstract_image.py b/src/ngio/images/_abstract_image.py index de7d9a36..fe3ef47c 100644 --- a/src/ngio/images/_abstract_image.py +++ b/src/ngio/images/_abstract_image.py @@ -1,19 +1,26 @@ """Generic class to handle Image-like data in a OME-NGFF file.""" -from collections.abc import Sequence -from typing import Generic, Literal, TypeVar +import warnings +from abc import ABC, abstractmethod +from collections.abc import Mapping, Sequence +from typing import Any, Literal import dask.array as da import numpy as np import zarr +from zarr.core.array import CompressorLike from ngio.common import ( Dimensions, InterpolationOrder, Roi, - RoiPixels, consolidate_pyramid, ) +from ngio.common._pyramid import ChunksLike, ShardsLike, shapes_from_scaling_factors +from ngio.images._create_utils import ( + _image_or_label_meta, + init_image_like_from_shapes, +) from ngio.io_pipes import ( DaskGetter, DaskRoiGetter, @@ -31,15 +38,25 @@ Dataset, ImageMetaHandler, LabelMetaHandler, + NgioImageMeta, PixelSize, ) +from ngio.ome_zarr_meta.ngio_specs import ( + Channel, + NgffVersions, + NgioLabelMeta, +) from ngio.tables import RoiTable -from ngio.utils import NgioFileExistsError, ZarrGroupHandler - -_image_handler = TypeVar("_image_handler", ImageMetaHandler, LabelMetaHandler) +from ngio.utils import ( + NgioFileExistsError, + NgioValueError, + 
StoreOrGroup, + ZarrGroupHandler, +) +from ngio.utils._zarr_utils import find_dimension_separator -class AbstractImage(Generic[_image_handler]): +class AbstractImage(ABC): """A class to handle a single image (or level) in an OME-Zarr image. This class is meant to be subclassed by specific image types. @@ -49,7 +66,7 @@ def __init__( self, group_handler: ZarrGroupHandler, path: str, - meta_handler: _image_handler, + meta_handler: ImageMetaHandler | LabelMetaHandler, ) -> None: """Initialize the Image at a single level. @@ -78,14 +95,21 @@ def path(self) -> str: return self._path @property - def meta_handler(self) -> _image_handler: + @abstractmethod + def meta_handler(self) -> ImageMetaHandler | LabelMetaHandler: """Return the metadata.""" - return self._meta_handler + pass + + @property + @abstractmethod + def meta(self) -> NgioImageMeta | NgioLabelMeta: + """Return the metadata.""" + pass @property def dataset(self) -> Dataset: """Return the dataset of the image.""" - return self.meta_handler.meta.get_dataset(path=self.path) + return self.meta_handler.get_meta().get_dataset(path=self.path) @property def dimensions(self) -> Dimensions: @@ -202,7 +226,7 @@ def _get_as_numpy( def _get_roi_as_numpy( self, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, **slicing_kwargs: SlicingInputType, @@ -252,7 +276,7 @@ def _get_as_dask( def _get_roi_as_dask( self, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, **slicing_kwargs: SlicingInputType, @@ -309,7 +333,7 @@ def _get_array( def _get_roi( self, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, mode: Literal["numpy", "dask"] = "numpy", @@ -385,7 +409,7 @@ def _set_array( def _set_roi( self, - roi: Roi | RoiPixels, + roi: Roi, patch: np.ndarray | da.Array, axes_order: 
Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, @@ -444,25 +468,14 @@ def _consolidate( def roi(self, name: str | None = "image") -> Roi: """Return the ROI covering the entire image.""" - dim_x = self.dimensions.get("x") - dim_y = self.dimensions.get("y") - assert dim_x is not None and dim_y is not None - dim_z = self.dimensions.get("z") - z = None if dim_z is None else 0 - dim_t = self.dimensions.get("t") - t = None if dim_t is None else 0 - roi_px = RoiPixels( - name=name, - x=0, - y=0, - z=z, - t=t, - x_length=dim_x, - y_length=dim_y, - z_length=dim_z, - t_length=dim_t, - ) - return roi_px.to_roi(pixel_size=self.pixel_size) + slices = {} + for ax_name in ["t", "z", "y", "x"]: + axis_size = self.dimensions.get(ax_name, default=None) + if axis_size is None: + continue + slices[ax_name] = slice(0, axis_size) + roi_px = Roi.from_values(name=name, slices=slices, space="pixel") + return roi_px.to_world(pixel_size=self.pixel_size) def build_image_roi_table(self, name: str | None = "image") -> RoiTable: """Build the ROI table containing the ROI covering the entire image.""" @@ -576,7 +589,7 @@ def consolidate_image( mode: Literal["dask", "numpy", "coarsen"] = "dask", ) -> None: """Consolidate the image on disk.""" - target_paths = image._meta_handler.meta.paths + target_paths = image.meta_handler.get_meta().paths targets = [ image._group_handler.get_array(path) for path in target_paths @@ -585,3 +598,373 @@ def consolidate_image( consolidate_pyramid( source=image.zarr_array, targets=targets, order=order, mode=mode ) + + +def _shapes_from_ref_image( + ref_image: AbstractImage, +) -> list[tuple[int, ...]]: + """Rebuild base shape based on a new shape.""" + paths = ref_image.meta.paths + index_path = paths.index(ref_image.path) + sub_paths = paths[index_path:] + group_handler = ref_image._group_handler + shapes = [] + for path in sub_paths: + zarr_array = group_handler.get_array(path) + shapes.append(zarr_array.shape) + if len(shapes) == 
len(paths): + return shapes + missing_levels = len(paths) - len(shapes) + # Extend from the last shape so the pyramid has as many levels as the reference. + extended_shapes = shapes_from_scaling_factors( + base_shape=shapes[-1], + scaling_factors=ref_image.meta.scaling_factor(), + num_levels=missing_levels + 1, + ) + shapes.extend(extended_shapes[1:]) + return shapes + + +def _shapes_from_new_shape( + ref_image: AbstractImage, + shape: Sequence[int], +) -> list[tuple[int, ...]]: + """Rebuild pyramid shapes based on a new base shape.""" + if len(shape) != len(ref_image.shape): + raise NgioValueError( + "The shape of the new image does not match the reference image. " + f"Got shape {shape} for reference shape {ref_image.shape}." + ) + base_shape = tuple(shape) + scaling_factors = ref_image.meta.scaling_factor() + num_levels = len(ref_image.meta.paths) + return shapes_from_scaling_factors( + base_shape=base_shape, + scaling_factors=scaling_factors, + num_levels=num_levels, + ) + + +def _compute_pyramid_shapes( + ref_image: AbstractImage, + shape: Sequence[int] | None, +) -> list[tuple[int, ...]]: + """Compute pyramid shapes from the reference image or from a new base shape.""" + if shape is None: + return _shapes_from_ref_image(ref_image=ref_image) + return _shapes_from_new_shape(ref_image=ref_image, shape=shape) + + +def _check_chunks_and_shards_compatibility( + ref_shape: tuple[int, ...], + chunks: ChunksLike, + shards: ShardsLike | None, +) -> None: + """Check if the chunks and shards are compatible with the reference shape. + + Args: + ref_shape: The reference shape. + chunks: The chunks to check. + shards: The shards to check. + """ + if chunks != "auto": + if len(chunks) != len(ref_shape): + raise NgioValueError( + "The length of the chunks must be the same as the number of dimensions." + ) + if shards is not None and shards != "auto": + if len(shards) != len(ref_shape): + raise NgioValueError( + "The length of the shards must be the same as the number of dimensions." 
+ ) + + +def _apply_channel_policy( + ref_image: AbstractImage, + channels_policy: Literal["squeeze", "same", "singleton"] | int, + shapes: list[tuple[int, ...]], + axes: tuple[str, ...], + chunks: ChunksLike, + shards: ShardsLike | None, +) -> tuple[list[tuple[int, ...]], tuple[str, ...], ChunksLike, ShardsLike | None]: + """Apply the channel policy to the shapes and axes. + + Args: + ref_image: The reference image. + channels_policy: The channels policy to apply. + shapes: The shapes of the pyramid levels. + axes: The axes of the image. + chunks: The chunks of the image. + shards: The shards of the image. + + Returns: + The new shapes and axes after applying the channel policy. + """ + if channels_policy == "same": + return shapes, axes, chunks, shards + + if channels_policy == "singleton": + # Treat 'singleton' as setting channel size to 1 + channels_policy = 1 + + channel_index = ref_image.axes_handler.get_index("c") + if channel_index is None: + if channels_policy == "squeeze": + return shapes, axes, chunks, shards + raise NgioValueError( + f"Cannot apply channel policy {channels_policy=} to an image " + "without channels axis." 
+ ) + if channels_policy == "squeeze": + new_shapes = [] + for shape in shapes: + new_shape = shape[:channel_index] + shape[channel_index + 1 :] + new_shapes.append(new_shape) + new_axes = axes[:channel_index] + axes[channel_index + 1 :] + if chunks == "auto": + new_chunks: ChunksLike = "auto" + else: + new_chunks = chunks[:channel_index] + chunks[channel_index + 1 :] + if shards == "auto" or shards is None: + new_shards: ShardsLike | None = shards + else: + new_shards = shards[:channel_index] + shards[channel_index + 1 :] + return new_shapes, new_axes, new_chunks, new_shards + elif isinstance(channels_policy, int): + new_shapes = [] + for shape in shapes: + new_shape = ( + *shape[:channel_index], + channels_policy, + *shape[channel_index + 1 :], + ) + new_shapes.append(new_shape) + return new_shapes, axes, chunks, shards + else: + raise NgioValueError( + f"Invalid channels policy: {channels_policy}. " + "Must be 'squeeze', 'same', 'singleton', or an integer." + ) + + +def _check_channels_meta_compatibility( + meta_type: type[_image_or_label_meta], + ref_image: AbstractImage, + channels_meta: Sequence[str | Channel] | None, +) -> Sequence[str | Channel] | None: + """Check if the channels metadata is compatible with the reference image. + + Args: + meta_type: The metadata type. + ref_image: The reference image. + channels_meta: The channels metadata to check. + + Returns: + The channels metadata if compatible, None otherwise. 
+ """ + if issubclass(meta_type, NgioLabelMeta): + if channels_meta is not None: + raise NgioValueError("Cannot set channels_meta for a label image.") + return None + if channels_meta is not None: + return channels_meta + assert isinstance(ref_image.meta, NgioImageMeta) + ref_meta = ref_image.meta + index_c = ref_meta.axes_handler.get_index("c") + if index_c is None: + return None + + # If the channels number does not match, return None + # Else return the channels metadata from the reference image + ref_shape = ref_image.shape + ref_num_channels = ref_shape[index_c] if index_c is not None else 1 + channels_ = ref_meta.channels_meta.channels if ref_meta.channels_meta else [] + # Reset to None if number of channels do not match + channels_meta_ = channels_ if ref_num_channels == len(channels_) else None + return channels_meta_ + + +def abstract_derive( + *, + ref_image: AbstractImage, + meta_type: type[_image_or_label_meta], + store: StoreOrGroup, + overwrite: bool = False, + # Metadata parameters + shape: Sequence[int] | None = None, + pixelsize: float | tuple[float, float] | None = None, + z_spacing: float | None = None, + time_spacing: float | None = None, + name: str | None = None, + channels_policy: Literal["squeeze", "same", "singleton"] | int = "same", + channels_meta: Sequence[str | Channel] | None = None, + ngff_version: NgffVersions | None = None, + # Zarr Array parameters + chunks: ChunksLike | None = None, + shards: ShardsLike | None = None, + dtype: str | None = None, + dimension_separator: Literal[".", "/"] | None = None, + compressors: CompressorLike | None = None, + extra_array_kwargs: Mapping[str, Any] | None = None, + # Deprecated arguments + labels: Sequence[str] | None = None, + pixel_size: PixelSize | None = None, +) -> ZarrGroupHandler: + """Create an empty OME-Zarr image from an existing image. + + If a kwarg is not provided, the value from the reference image will be used. 
+ + Args: + ref_image (AbstractImage): The reference image to derive from. + meta_type (type[_image_or_label_meta]): The metadata type to use. + store (StoreOrGroup): The Zarr store or group to create the image in. + overwrite (bool): Whether to overwrite an existing image. + shape (Sequence[int] | None): The shape of the new image. It must have + the same number of dimensions as the reference image. + pixelsize (float | tuple[float, float] | None): The pixel size of the new image. + z_spacing (float | None): The z spacing of the new image. + time_spacing (float | None): The time spacing of the new image. + name (str | None): The name of the new image. + channels_policy (Literal["squeeze", "same", "singleton"] | int): + Possible policies: + - If "squeeze", the channels axis will be removed (no matter its size). + - If "same", the channels axis will be kept as is (if it exists). + - If "singleton", the channels axis will be set to size 1. + - If an integer is provided, the channels axis will be changed to have that + size. + channels_meta (Sequence[str | Channel] | None): The channels metadata + of the new image. + ngff_version (NgffVersions | None): The NGFF version to use. + chunks (ChunksLike | None): The chunk shape of the new image. + shards (ShardsLike | None): The shard shape of the new image. + dtype (str | None): The data type of the new image. + dimension_separator (Literal[".", "/"] | None): The separator to use for + dimensions. + compressors (CompressorLike | None): The compressors to use. + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. + labels (Sequence[str] | None): The labels of the new image. + This argument is DEPRECATED; please use channels_meta instead. + pixel_size (PixelSize | None): The pixel size of the new image. + This argument is DEPRECATED; please use pixelsize, z_spacing, + and time_spacing instead. + + Returns: + ZarrGroupHandler: The group handler of the new derived image. 
+ + """ + # TODO: remove in ngio 0.6 + if labels is not None: + warnings.warn( + "The 'labels' argument is deprecated and will be removed in " + "a future release.", + DeprecationWarning, + stacklevel=2, + ) + channels_meta = list(labels) + if pixel_size is not None: + warnings.warn( + "The 'pixel_size' argument is deprecated and will be removed in " + "a future release.", + DeprecationWarning, + stacklevel=2, + ) + pixelsize_ = (pixel_size.y, pixel_size.x) + z_spacing_ = pixel_size.z + time_spacing_ = pixel_size.t + else: + if pixelsize is None: + pixelsize_ = (ref_image.pixel_size.y, ref_image.pixel_size.x) + else: + pixelsize_ = pixelsize + + if z_spacing is None: + z_spacing_ = ref_image.pixel_size.z + else: + z_spacing_ = z_spacing + + if time_spacing is None: + time_spacing_ = ref_image.pixel_size.t + else: + time_spacing_ = time_spacing + ref_meta = ref_image.meta + + shapes = _compute_pyramid_shapes( + shape=shape, + ref_image=ref_image, + ) + ref_shape = next(iter(shapes)) + + if pixelsize is None: + pixelsize = (ref_image.pixel_size.y, ref_image.pixel_size.x) + + if z_spacing is None: + z_spacing = ref_image.pixel_size.z + + if time_spacing is None: + time_spacing = ref_image.pixel_size.t + + if name is None: + name = ref_meta.name + + if dtype is None: + dtype = ref_image.dtype + + if dimension_separator is None: + dimension_separator = find_dimension_separator(ref_image.zarr_array) + + if compressors is None: + compressors = ref_image.zarr_array.compressors # type: ignore + + if chunks is None: + chunks = ref_image.zarr_array.chunks + + if shards is None: + shards = ref_image.zarr_array.shards + + _check_chunks_and_shards_compatibility( + ref_shape=ref_shape, + chunks=chunks, + shards=shards, + ) + + if ngff_version is None: + ngff_version = ref_meta.version + + shapes, axes, chunks, shards = _apply_channel_policy( + ref_image=ref_image, + channels_policy=channels_policy, + shapes=shapes, + axes=ref_image.axes, + chunks=chunks, + 
shards=shards, + ) + channels_meta_ = _check_channels_meta_compatibility( + meta_type=meta_type, + ref_image=ref_image, + channels_meta=channels_meta, + ) + + handler = init_image_like_from_shapes( + store=store, + meta_type=meta_type, + shapes=shapes, + pixelsize=pixelsize_, + z_spacing=z_spacing_, + time_spacing=time_spacing_, + levels=ref_meta.paths, + time_unit=ref_image.time_unit, + space_unit=ref_image.space_unit, + axes_names=axes, + name=name, + channels_meta=channels_meta_, + chunks=chunks, + shards=shards, + dtype=dtype, + dimension_separator=dimension_separator, + compressors=compressors, + overwrite=overwrite, + ngff_version=ngff_version, + extra_array_kwargs=extra_array_kwargs, + ) + return handler diff --git a/src/ngio/images/_create.py b/src/ngio/images/_create.py deleted file mode 100644 index 0adabc51..00000000 --- a/src/ngio/images/_create.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Utility functions for working with OME-Zarr images.""" - -from collections.abc import Sequence -from typing import TypeVar - -from zarr.types import DIMENSION_SEPARATOR - -from ngio.common._pyramid import init_empty_pyramid -from ngio.ome_zarr_meta import ( - NgioImageMeta, - NgioLabelMeta, - PixelSize, - get_image_meta_handler, - get_label_meta_handler, -) -from ngio.ome_zarr_meta.ngio_specs import ( - DefaultNgffVersion, - DefaultSpaceUnit, - DefaultTimeUnit, - NgffVersions, - SpaceUnits, - TimeUnits, - canonical_axes_order, - canonical_label_axes_order, -) -from ngio.utils import NgioValueError, StoreOrGroup, ZarrGroupHandler - -_image_or_label_meta = TypeVar("_image_or_label_meta", NgioImageMeta, NgioLabelMeta) - - -def _init_generic_meta( - meta_type: type[_image_or_label_meta], - pixelsize: float, - axes_names: Sequence[str], - z_spacing: float = 1.0, - time_spacing: float = 1.0, - levels: int | list[str] = 5, - yx_scaling_factor: float | tuple[float, float] = 2.0, - z_scaling_factor: float = 1.0, - space_unit: SpaceUnits | str | None = DefaultSpaceUnit, - time_unit: 
TimeUnits | str | None = DefaultTimeUnit, - name: str | None = None, - version: NgffVersions = DefaultNgffVersion, -) -> tuple[_image_or_label_meta, list[float]]: - """Initialize the metadata for an image or label.""" - scaling_factors = [] - for ax in axes_names: - if ax == "z": - scaling_factors.append(z_scaling_factor) - elif ax in ["x"]: - if isinstance(yx_scaling_factor, tuple): - scaling_factors.append(yx_scaling_factor[1]) - else: - scaling_factors.append(yx_scaling_factor) - elif ax in ["y"]: - if isinstance(yx_scaling_factor, tuple): - scaling_factors.append(yx_scaling_factor[0]) - else: - scaling_factors.append(yx_scaling_factor) - else: - scaling_factors.append(1.0) - - pixel_sizes = PixelSize( - x=pixelsize, - y=pixelsize, - z=z_spacing, - t=time_spacing, - space_unit=space_unit, - time_unit=time_unit, - ) - - meta = meta_type.default_init( - name=name, - levels=levels, - axes_names=axes_names, - pixel_size=pixel_sizes, - scaling_factors=scaling_factors, - version=version, - ) - return meta, scaling_factors - - -def create_empty_label_container( - store: StoreOrGroup, - shape: Sequence[int], - pixelsize: float, - z_spacing: float = 1.0, - time_spacing: float = 1.0, - levels: int | list[str] = 5, - yx_scaling_factor: float | tuple[float, float] = 2.0, - z_scaling_factor: float = 1.0, - space_unit: SpaceUnits | str | None = DefaultSpaceUnit, - time_unit: TimeUnits | str | None = DefaultTimeUnit, - axes_names: Sequence[str] | None = None, - name: str | None = None, - chunks: Sequence[int] | None = None, - dtype: str = "uint32", - dimension_separator: DIMENSION_SEPARATOR = "/", - compressor="default", - overwrite: bool = False, - version: NgffVersions = DefaultNgffVersion, -) -> ZarrGroupHandler: - """Create an empty label with the given shape and metadata. - - Args: - store (StoreOrGroup): The Zarr store or group to create the image in. - shape (Sequence[int]): The shape of the image. - pixelsize (float): The pixel size in x and y dimensions. 
- z_spacing (float, optional): The spacing between z slices. Defaults to 1.0. - time_spacing (float, optional): The spacing between time points. - Defaults to 1.0. - levels (int | list[str], optional): The number of levels in the pyramid or a - list of level names. Defaults to 5. - yx_scaling_factor (float, optional): The down-scaling factor in x and y - dimensions. Defaults to 2.0. - z_scaling_factor (float, optional): The down-scaling factor in z dimension. - Defaults to 1.0. - space_unit (SpaceUnits, optional): The unit of space. Defaults to - DefaultSpaceUnit. - time_unit (TimeUnits, optional): The unit of time. Defaults to - DefaultTimeUnit. - axes_names (Sequence[str] | None, optional): The names of the axes. - If None the canonical names are used. Defaults to None. - name (str | None, optional): The name of the image. Defaults to None. - chunks (Sequence[int] | None, optional): The chunk shape. If None the shape - is used. Defaults to None. - dimension_separator (DIMENSION_SEPARATOR): The separator to use for - dimensions. Defaults to "/". - compressor: The compressor to use. Defaults to "default". - dtype (str, optional): The data type of the image. Defaults to "uint16". - overwrite (bool, optional): Whether to overwrite an existing image. - Defaults to True. - version (str, optional): The version of the OME-Zarr specification. - Defaults to DefaultVersion. - - """ - if axes_names is None: - axes_names = canonical_label_axes_order()[-len(shape) :] - - if len(axes_names) != len(shape): - raise NgioValueError( - f"Number of axes names {axes_names} does not match the number of " - f"dimensions {shape}." 
- ) - - meta, scaling_factors = _init_generic_meta( - meta_type=NgioLabelMeta, - pixelsize=pixelsize, - z_spacing=z_spacing, - time_spacing=time_spacing, - levels=levels, - yx_scaling_factor=yx_scaling_factor, - z_scaling_factor=z_scaling_factor, - space_unit=space_unit, - time_unit=time_unit, - axes_names=axes_names, - name=name, - version=version, - ) - - mode = "w" if overwrite else "w-" - group_handler = ZarrGroupHandler(store=store, mode=mode, cache=False) - image_handler = get_label_meta_handler(version=version, group_handler=group_handler) - image_handler.write_meta(meta) - - init_empty_pyramid( - store=store, - paths=meta.paths, - scaling_factors=scaling_factors, - ref_shape=shape, - chunks=chunks, - dtype=dtype, - mode="a", - dimension_separator=dimension_separator, - compressor=compressor, - ) - group_handler._mode = "r+" - return group_handler - - -def create_empty_image_container( - store: StoreOrGroup, - shape: Sequence[int], - pixelsize: float, - z_spacing: float = 1.0, - time_spacing: float = 1.0, - levels: int | list[str] = 5, - yx_scaling_factor: float | tuple[float, float] = 2, - z_scaling_factor: float = 1.0, - space_unit: SpaceUnits | str | None = DefaultSpaceUnit, - time_unit: TimeUnits | str | None = DefaultTimeUnit, - axes_names: Sequence[str] | None = None, - name: str | None = None, - chunks: Sequence[int] | None = None, - dtype: str = "uint16", - dimension_separator: DIMENSION_SEPARATOR = "/", - compressor="default", - overwrite: bool = False, - version: NgffVersions = DefaultNgffVersion, -) -> ZarrGroupHandler: - """Create an empty OME-Zarr image with the given shape and metadata. - - Args: - store (StoreOrGroup): The Zarr store or group to create the image in. - shape (Sequence[int]): The shape of the image. - pixelsize (float): The pixel size in x and y dimensions. - z_spacing (float, optional): The spacing between z slices. Defaults to 1.0. - time_spacing (float, optional): The spacing between time points. - Defaults to 1.0. 
- levels (int | list[str], optional): The number of levels in the pyramid or a - list of level names. Defaults to 5. - yx_scaling_factor (float, optional): The down-scaling factor in x and y - dimensions. Defaults to 2.0. - z_scaling_factor (float, optional): The down-scaling factor in z dimension. - Defaults to 1.0. - space_unit (SpaceUnits, optional): The unit of space. Defaults to - DefaultSpaceUnit. - time_unit (TimeUnits, optional): The unit of time. Defaults to - DefaultTimeUnit. - axes_names (Sequence[str] | None, optional): The names of the axes. - If None the canonical names are used. Defaults to None. - name (str | None, optional): The name of the image. Defaults to None. - chunks (Sequence[int] | None, optional): The chunk shape. If None the shape - is used. Defaults to None. - dtype (str, optional): The data type of the image. Defaults to "uint16". - dimension_separator (DIMENSION_SEPARATOR): The separator to use for - dimensions. Defaults to "/". - compressor: The compressor to use. Defaults to "default". - overwrite (bool, optional): Whether to overwrite an existing image. - Defaults to True. - version (str, optional): The version of the OME-Zarr specification. - Defaults to DefaultVersion. - - """ - if axes_names is None: - axes_names = canonical_axes_order()[-len(shape) :] - - if len(axes_names) != len(shape): - raise NgioValueError( - f"Number of axes names {axes_names} does not match the number of " - f"dimensions {shape}." 
- ) - - meta, scaling_factors = _init_generic_meta( - meta_type=NgioImageMeta, - pixelsize=pixelsize, - z_spacing=z_spacing, - time_spacing=time_spacing, - levels=levels, - yx_scaling_factor=yx_scaling_factor, - z_scaling_factor=z_scaling_factor, - space_unit=space_unit, - time_unit=time_unit, - axes_names=axes_names, - name=name, - version=version, - ) - mode = "w" if overwrite else "w-" - group_handler = ZarrGroupHandler(store=store, mode=mode, cache=False) - image_handler = get_image_meta_handler(version=version, group_handler=group_handler) - image_handler.write_meta(meta) - - init_empty_pyramid( - store=store, - paths=meta.paths, - scaling_factors=scaling_factors, - ref_shape=shape, - chunks=chunks, - dtype=dtype, - mode="a", - dimension_separator=dimension_separator, - compressor=compressor, - ) - - group_handler._mode = "r+" - return group_handler diff --git a/src/ngio/images/_create_synt_container.py b/src/ngio/images/_create_synt_container.py index 8f5901a8..e82d447d 100644 --- a/src/ngio/images/_create_synt_container.py +++ b/src/ngio/images/_create_synt_container.py @@ -1,14 +1,17 @@ """Abstract class for handling OME-NGFF images.""" -from collections.abc import Sequence +from collections.abc import Mapping, Sequence +from typing import Any, Literal import numpy as np import PIL.Image -from zarr.types import DIMENSION_SEPARATOR +from zarr.core.array import CompressorLike +from ngio.common._pyramid import ChunksLike, ShardsLike from ngio.common._synt_images_utils import fit_to_shape from ngio.images._ome_zarr_container import OmeZarrContainer, create_ome_zarr_from_array from ngio.ome_zarr_meta.ngio_specs import ( + Channel, DefaultNgffVersion, NgffVersions, ) @@ -27,52 +30,45 @@ def create_synthetic_ome_zarr( shape: Sequence[int], reference_sample: AVAILABLE_SAMPLES | SampleInfo = "Cardiomyocyte", levels: int | list[str] = 5, - xy_scaling_factor: float = 2, - z_scaling_factor: float = 1.0, - axes_names: Sequence[str] | None = None, - chunks: Sequence[int] 
| None = None, - channel_labels: list[str] | None = None, - channel_wavelengths: list[str] | None = None, - channel_colors: Sequence[str] | None = None, - channel_active: Sequence[bool] | None = None, table_backend: TableBackend = DefaultTableBackend, - dimension_separator: DIMENSION_SEPARATOR = "/", - compressor="default", + scaling_factors: Sequence[float] | Literal["auto"] = "auto", + axes_names: Sequence[str] | None = None, + channels_meta: Sequence[str | Channel] | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, + chunks: ChunksLike = "auto", + shards: ShardsLike | None = None, + dimension_separator: Literal[".", "/"] = "/", + compressors: CompressorLike = "auto", + extra_array_kwargs: Mapping[str, Any] | None = None, overwrite: bool = False, - version: NgffVersions = DefaultNgffVersion, ) -> OmeZarrContainer: - """Create an empty OME-Zarr image with the given shape and metadata. + """Create a synthetic OME-Zarr image with the given shape and metadata. Args: store (StoreOrGroup): The Zarr store or group to create the image in. shape (Sequence[int]): The shape of the image. reference_sample (AVAILABLE_SAMPLES | SampleInfo): The reference sample to use. - levels (int | list[str], optional): The number of levels in the pyramid or a - list of level names. Defaults to 5. - xy_scaling_factor (float, optional): The down-scaling factor in x and y - dimensions. Defaults to 2.0. - z_scaling_factor (float, optional): The down-scaling factor in z dimension. - Defaults to 1.0. - axes_names (Sequence[str] | None, optional): The names of the axes. - If None the canonical names are used. Defaults to None. - chunks (Sequence[int] | None, optional): The chunk shape. If None the shape - is used. Defaults to None. - channel_labels (list[str] | None, optional): The labels of the channels. - Defaults to None. - channel_wavelengths (list[str] | None, optional): The wavelengths of the - channels. Defaults to None. 
- channel_colors (Sequence[str] | None, optional): The colors of the channels. + Defaults to "Cardiomyocyte". + levels (int | list[str]): The number of levels in the pyramid or a list of + level names. Defaults to 5. + table_backend (TableBackend): Table backend to be used to store tables. + Defaults to DefaultTableBackend. + scaling_factors (Sequence[float] | Literal["auto"]): The down-scaling factors + for the pyramid levels. Defaults to "auto". + axes_names (Sequence[str] | None): The names of the axes. If None the + canonical names are used. Defaults to None. + channels_meta (Sequence[str | Channel] | None): The channels metadata. Defaults to None. - channel_active (Sequence[bool] | None, optional): Whether the channels are - active. Defaults to None. - table_backend (TableBackend): Table backend to be used to store tables - dimension_separator (DIMENSION_SEPARATOR): The separator to use for - dimensions. Defaults to "/". - compressor: The compressor to use. Defaults to "default". - overwrite (bool, optional): Whether to overwrite an existing image. - Defaults to True. - version (NgffVersion, optional): The version of the OME-Zarr specification. + ngff_version (NgffVersions): The version of the OME-Zarr specification. Defaults to DefaultNgffVersion. + chunks (ChunksLike): The chunk shape. Defaults to "auto". + shards (ShardsLike | None): The shard shape. Defaults to None. + dimension_separator (Literal[".", "/"]): The separator to use for + dimensions. Defaults to "/". + compressors (CompressorLike): The compressors to use. Defaults to "auto". + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. Defaults to None. + overwrite (bool): Whether to overwrite an existing image. Defaults to False. 
""" if isinstance(reference_sample, str): sample_info = get_sample_info(reference_sample) @@ -86,25 +82,23 @@ def create_synthetic_ome_zarr( ome_zarr = create_ome_zarr_from_array( store=store, array=raw, - xy_pixelsize=sample_info.xy_pixelsize, + pixelsize=sample_info.xy_pixelsize, z_spacing=sample_info.z_spacing, time_spacing=sample_info.time_spacing, levels=levels, - xy_scaling_factor=xy_scaling_factor, - z_scaling_factor=z_scaling_factor, space_unit=sample_info.space_unit, time_unit=sample_info.time_unit, axes_names=axes_names, - channel_labels=channel_labels, - channel_wavelengths=channel_wavelengths, - channel_colors=channel_colors, - channel_active=channel_active, + channels_meta=channels_meta, + scaling_factors=scaling_factors, + extra_array_kwargs=extra_array_kwargs, name=sample_info.name, chunks=chunks, + shards=shards, overwrite=overwrite, dimension_separator=dimension_separator, - compressor=compressor, - version=version, + compressors=compressors, + ngff_version=ngff_version, ) image = ome_zarr.get_image() diff --git a/src/ngio/images/_create_utils.py b/src/ngio/images/_create_utils.py new file mode 100644 index 00000000..cc9fcbcf --- /dev/null +++ b/src/ngio/images/_create_utils.py @@ -0,0 +1,423 @@ +"""Utility functions for working with OME-Zarr images.""" + +import warnings +from collections.abc import Mapping, Sequence +from typing import Any, Literal, TypeVar + +from zarr.core.array import CompressorLike + +from ngio.common._pyramid import ChunksLike, ImagePyramidBuilder, ShardsLike +from ngio.ome_zarr_meta import ( + NgioImageMeta, + NgioLabelMeta, + update_ngio_meta, +) +from ngio.ome_zarr_meta.ngio_specs import ( + AxesHandler, + Channel, + ChannelsMeta, + DefaultNgffVersion, + DefaultSpaceUnit, + DefaultTimeUnit, + NgffVersions, + SpaceUnits, + TimeUnits, + build_canonical_axes_handler, + canonical_axes_order, + canonical_label_axes_order, +) +from ngio.ome_zarr_meta.ngio_specs._axes import AxesSetup +from ngio.utils import NgioValueError, 
StoreOrGroup, ZarrGroupHandler + +_image_or_label_meta = TypeVar("_image_or_label_meta", NgioImageMeta, NgioLabelMeta) + + +def _build_axes_handler( + *, + shape: tuple[int, ...], + axes_names: Sequence[str] | None, + default_channel_order: tuple[str, ...], + space_units: SpaceUnits | str | None = DefaultSpaceUnit, + time_units: TimeUnits | str | None = DefaultTimeUnit, + axes_setup: AxesSetup | None = None, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = False, +) -> AxesHandler: + """Compute axes names for given shape.""" + if axes_names is None: + axes_names = default_channel_order[-len(shape) :] + # Validate length + if len(axes_names) != len(shape): + raise NgioValueError( + f"Number of axes names {axes_names} does not match the number of " + f"dimensions {shape}." + ) + return build_canonical_axes_handler( + axes_names=axes_names, + space_units=space_units, + time_units=time_units, + axes_setup=axes_setup, + allow_non_canonical_axes=allow_non_canonical_axes, + strict_canonical_order=strict_canonical_order, + ) + + +def _align_to_axes( + *, + values: dict[str, float], + axes_handler: AxesHandler, + default_value: float = 1.0, +) -> tuple[float, ...]: + """Align given values to axes names.""" + aligned_values = [default_value] * len(axes_handler.axes_names) + for ax, value in values.items(): + index = axes_handler.get_index(ax) + if index is not None: + aligned_values[index] = value + return tuple(aligned_values) + + +def _check_deprecated_scaling_factors( + *, + yx_scaling_factor: float | tuple[float, float] | None = None, + z_scaling_factor: float | None = None, + scaling_factors: Sequence[float] | Literal["auto"] = "auto", + shape: tuple[int, ...], +) -> Sequence[float] | Literal["auto"]: + if yx_scaling_factor is not None or z_scaling_factor is not None: + warnings.warn( + "The 'yx_scaling_factor' and 'z_scaling_factor' arguments are deprecated " + "and will be removed in future versions. 
Please use the 'scaling_factors' " + "argument instead.", + DeprecationWarning, + stacklevel=2, + ) + if scaling_factors != "auto": + raise NgioValueError( + "Cannot use both 'scaling_factors' and deprecated " + "'yx_scaling_factor'/'z_scaling_factor' arguments." + ) + if isinstance(yx_scaling_factor, tuple): + if len(yx_scaling_factor) != 2: + raise NgioValueError( + "yx_scaling_factor tuple must have length 2 for y and x scaling." + ) + y_scale = yx_scaling_factor[0] + x_scale = yx_scaling_factor[1] + else: + y_scale = yx_scaling_factor if yx_scaling_factor is not None else 2.0 + x_scale = yx_scaling_factor if yx_scaling_factor is not None else 2.0 + z_scale = z_scaling_factor if z_scaling_factor is not None else 1.0 + scaling_factors = (z_scale, y_scale, x_scale) + if len(scaling_factors) < len(shape): + padding = (1.0,) * (len(shape) - len(scaling_factors)) + scaling_factors = padding + scaling_factors + + return scaling_factors + return scaling_factors + + +def _compute_scaling_factors( + *, + scaling_factors: Sequence[float] | Literal["auto"], + shape: tuple[int, ...], + axes_handler: AxesHandler, + xy_scaling_factor: float | tuple[float, float] | None = None, + z_scaling_factor: float | None = None, +) -> tuple[float, ...]: + """Compute scaling factors for given axes names.""" + # TODO remove with ngio 0.6 + scaling_factors = _check_deprecated_scaling_factors( + yx_scaling_factor=xy_scaling_factor, + z_scaling_factor=z_scaling_factor, + scaling_factors=scaling_factors, + shape=shape, + ) + if scaling_factors == "auto": + return _align_to_axes( + values={ + "x": 2.0, + "y": 2.0, + "z": 1.0, + }, + axes_handler=axes_handler, + ) + if len(scaling_factors) != len(shape): + raise NgioValueError( + "Length of scaling_factors does not match the number of dimensions." 
+ ) + return tuple(scaling_factors) + + +def _compute_base_scale( + *, + pixelsize: float | tuple[float, float], + z_spacing: float, + time_spacing: float, + axes_handler: AxesHandler, +) -> tuple[float, ...]: + """Compute base scale for given axes names.""" + if isinstance(pixelsize, tuple): + if len(pixelsize) != 2: + raise NgioValueError( + "pixelsize tuple must have length 2 for y and x pixel sizes." + ) + x_size = pixelsize[1] + y_size = pixelsize[0] + else: + x_size = pixelsize + y_size = pixelsize + return _align_to_axes( + values={ + "x": x_size, + "y": y_size, + "z": z_spacing, + "t": time_spacing, + }, + axes_handler=axes_handler, + ) + + +def _create_image_like_group( + *, + store: StoreOrGroup, + pyramid_builder: ImagePyramidBuilder, + meta: _image_or_label_meta, + overwrite: bool = False, +) -> ZarrGroupHandler: + """Advanced create empty image container function placeholder.""" + mode = "w" if overwrite else "w-" + group_handler = ZarrGroupHandler( + store=store, mode=mode, cache=False, zarr_format=meta.zarr_format + ) + update_ngio_meta(group_handler, meta) + # Reopen in r+ mode + group_handler = group_handler.reopen_handler() + # Write the pyramid + pyramid_builder.to_zarr(group=group_handler.group) + return group_handler + + +def _add_channels_meta( + *, + meta: _image_or_label_meta, + channels_meta: Sequence[str | Channel] | None = None, +) -> _image_or_label_meta: + """Create ChannelsMeta from given channels_meta input.""" + if isinstance(meta, NgioLabelMeta): + if channels_meta is not None: + raise NgioValueError( + "Cannot add channels_meta to NgioLabelMeta. " + "Labels do not have channels." + ) + else: + return meta + if channels_meta is None: + return meta + list_of_channels = [] + for c in channels_meta: + if isinstance(c, str): + channel = Channel.default_init(label=c) + elif isinstance(c, Channel): + channel = c + else: + raise NgioValueError( + "channels_meta must be a list of strings or Channel objects." 
+ ) + list_of_channels.append(channel) + + channels_meta_ = ChannelsMeta(channels=list_of_channels) + meta.set_channels_meta(channels_meta=channels_meta_) + return meta + + +def init_image_like( + *, + # Where to create the image + store: StoreOrGroup, + # Ngff image parameters + meta_type: type[_image_or_label_meta], + shape: Sequence[int], + pixelsize: float | tuple[float, float], + z_spacing: float = 1.0, + time_spacing: float = 1.0, + scaling_factors: Sequence[float] | Literal["auto"] = "auto", + levels: int | list[str] = 5, + space_unit: SpaceUnits | str | None = DefaultSpaceUnit, + time_unit: TimeUnits | str | None = DefaultTimeUnit, + axes_names: Sequence[str] | None = None, + name: str | None = None, + channels_meta: Sequence[str | Channel] | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, + # Zarr Array parameters + chunks: ChunksLike = "auto", + shards: ShardsLike | None = None, + dtype: str = "uint16", + dimension_separator: Literal[".", "/"] = "/", + compressors: CompressorLike = "auto", + extra_array_kwargs: Mapping[str, Any] | None = None, + # internal axes configuration for advanced use cases + axes_setup: AxesSetup | None = None, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = False, + # Whether to overwrite existing image + overwrite: bool = False, + # Deprecated arguments + yx_scaling_factor: float | tuple[float, float] | None = None, + z_scaling_factor: float | None = None, +) -> ZarrGroupHandler: + """Create an empty OME-Zarr image with the given shape and metadata.""" + shape = tuple(shape) + if meta_type is NgioImageMeta: + default_axes_order = canonical_axes_order() + else: + default_axes_order = canonical_label_axes_order() + + axes_handler = _build_axes_handler( + shape=shape, + axes_names=axes_names, + default_channel_order=default_axes_order, + space_units=space_unit, + time_units=time_unit, + axes_setup=axes_setup, + allow_non_canonical_axes=allow_non_canonical_axes, + 
strict_canonical_order=strict_canonical_order, + ) + base_scale = _compute_base_scale( + pixelsize=pixelsize, + z_spacing=z_spacing, + time_spacing=time_spacing, + axes_handler=axes_handler, + ) + scaling_factors = _compute_scaling_factors( + scaling_factors=scaling_factors, + shape=shape, + axes_handler=axes_handler, + xy_scaling_factor=yx_scaling_factor, + z_scaling_factor=z_scaling_factor, + ) + if isinstance(levels, int): + levels_paths = tuple(str(i) for i in range(levels)) + else: + levels_paths = tuple(levels) + + pyramid_builder = ImagePyramidBuilder.from_scaling_factors( + levels_paths=levels_paths, + scaling_factors=scaling_factors, + base_shape=shape, + base_scale=base_scale, + axes=axes_handler.axes_names, + chunks=chunks, + data_type=dtype, + dimension_separator=dimension_separator, + compressors=compressors, + shards=shards, + zarr_format=2 if ngff_version == "0.4" else 3, + other_array_kwargs=extra_array_kwargs, + ) + meta = meta_type.default_init( + levels=[p.path for p in pyramid_builder.levels], + axes_handler=axes_handler, + scales=[p.scale for p in pyramid_builder.levels], + translations=[None for _ in pyramid_builder.levels], + name=name, + version=ngff_version, + ) + meta = _add_channels_meta(meta=meta, channels_meta=channels_meta) + # Keep this creation at the end to avoid partial creations on errors + return _create_image_like_group( + store=store, + pyramid_builder=pyramid_builder, + meta=meta, + overwrite=overwrite, + ) + + +def init_image_like_from_shapes( + *, + # Where to create the image + store: StoreOrGroup, + # Ngff image parameters + meta_type: type[_image_or_label_meta], + shapes: Sequence[tuple[int, ...]], + pixelsize: float | tuple[float, float], + z_spacing: float = 1.0, + time_spacing: float = 1.0, + levels: list[str] | None = None, + space_unit: SpaceUnits | str | None = DefaultSpaceUnit, + time_unit: TimeUnits | str | None = DefaultTimeUnit, + axes_names: Sequence[str] | None = None, + name: str | None = None, + 
channels_meta: Sequence[str | Channel] | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, + # Zarr Array parameters + chunks: ChunksLike = "auto", + shards: ShardsLike | None = None, + dtype: str = "uint16", + dimension_separator: Literal[".", "/"] = "/", + compressors: CompressorLike = "auto", + extra_array_kwargs: Mapping[str, Any] | None = None, + # internal axes configuration for advanced use cases + axes_setup: AxesSetup | None = None, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = False, + # Whether to overwrite existing image + overwrite: bool = False, +) -> ZarrGroupHandler: + """Create an empty OME-Zarr image with the given shape and metadata.""" + base_shape = shapes[0] + if meta_type is NgioImageMeta: + default_axes_order = canonical_axes_order() + else: + default_axes_order = canonical_label_axes_order() + + axes_handler = _build_axes_handler( + shape=base_shape, + axes_names=axes_names, + default_channel_order=default_axes_order, + space_units=space_unit, + time_units=time_unit, + axes_setup=axes_setup, + allow_non_canonical_axes=allow_non_canonical_axes, + strict_canonical_order=strict_canonical_order, + ) + base_scale = _compute_base_scale( + pixelsize=pixelsize, + z_spacing=z_spacing, + time_spacing=time_spacing, + axes_handler=axes_handler, + ) + if levels is None: + levels_paths = tuple(str(i) for i in range(len(shapes))) + else: + levels_paths = tuple(levels) + + pyramid_builder = ImagePyramidBuilder.from_shapes( + shapes=shapes, + base_scale=base_scale, + levels_paths=levels_paths, + axes=axes_handler.axes_names, + chunks=chunks, + data_type=dtype, + dimension_separator=dimension_separator, + compressors=compressors, + shards=shards, + zarr_format=2 if ngff_version == "0.4" else 3, + other_array_kwargs=extra_array_kwargs, + ) + meta = meta_type.default_init( + levels=[p.path for p in pyramid_builder.levels], + axes_handler=axes_handler, + scales=[p.scale for p in pyramid_builder.levels], + 
translations=[None for _ in pyramid_builder.levels], + name=name, + version=ngff_version, + ) + meta = _add_channels_meta(meta=meta, channels_meta=channels_meta) + # Keep this creation at the end to avoid partial creations on errors + return _create_image_like_group( + store=store, + pyramid_builder=pyramid_builder, + meta=meta, + overwrite=overwrite, + ) diff --git a/src/ngio/images/_image.py b/src/ngio/images/_image.py index ca3a6a60..f2258cde 100644 --- a/src/ngio/images/_image.py +++ b/src/ngio/images/_image.py @@ -1,21 +1,20 @@ """Generic class to handle Image-like data in a OME-NGFF file.""" -from collections.abc import Sequence -from typing import Literal +from collections.abc import Mapping, Sequence +from typing import Any, Literal import dask.array as da import numpy as np from pydantic import BaseModel, model_validator -from zarr.types import DIMENSION_SEPARATOR +from zarr.core.array import CompressorLike from ngio.common import ( Dimensions, InterpolationOrder, Roi, - RoiPixels, ) -from ngio.images._abstract_image import AbstractImage -from ngio.images._create import create_empty_image_container +from ngio.common._pyramid import ChunksLike, ShardsLike +from ngio.images._abstract_image import AbstractImage, abstract_derive from ngio.io_pipes import ( SlicingInputType, TransformProtocol, @@ -24,7 +23,6 @@ ImageMetaHandler, NgioImageMeta, PixelSize, - find_image_meta_handler, ) from ngio.ome_zarr_meta.ngio_specs import ( Channel, @@ -32,11 +30,12 @@ ChannelVisualisation, DefaultSpaceUnit, DefaultTimeUnit, + NgffVersions, SpaceUnits, TimeUnits, ) from ngio.utils import ( - NgioValidationError, + NgioValueError, StoreOrGroup, ZarrGroupHandler, ) @@ -88,7 +87,7 @@ def _check_channel_meta(meta: NgioImageMeta, dimension: Dimensions) -> ChannelsM return ChannelsMeta.default_init(labels=c_dim) if len(meta.channels_meta.channels) != c_dim: - raise NgioValidationError( + raise NgioValueError( "The number of channels does not match the image. 
" f"Expected {len(meta.channels_meta.channels)} channels, got {c_dim}." ) @@ -96,7 +95,7 @@ def _check_channel_meta(meta: NgioImageMeta, dimension: Dimensions) -> ChannelsM return meta.channels_meta -class Image(AbstractImage[ImageMetaHandler]): +class Image(AbstractImage): """A class to handle a single image (or level) in an OME-Zarr image. This class is meant to be subclassed by specific image types. @@ -117,15 +116,23 @@ def __init__( """ if meta_handler is None: - meta_handler = find_image_meta_handler(group_handler) + meta_handler = ImageMetaHandler(group_handler) super().__init__( group_handler=group_handler, path=path, meta_handler=meta_handler ) + @property + def meta_handler(self) -> ImageMetaHandler: + """Return the metadata handler.""" + assert isinstance(self._meta_handler, ImageMetaHandler) + return self._meta_handler + @property def meta(self) -> NgioImageMeta: """Return the metadata.""" - return self._meta_handler.meta + meta = self.meta_handler.get_meta() + assert isinstance(meta, NgioImageMeta) + return meta @property def channels_meta(self) -> ChannelsMeta: @@ -185,7 +192,7 @@ def get_as_numpy( def get_roi_as_numpy( self, - roi: Roi | RoiPixels, + roi: Roi, channel_selection: ChannelSlicingInputType = None, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, @@ -239,7 +246,7 @@ def get_as_dask( def get_roi_as_dask( self, - roi: Roi | RoiPixels, + roi: Roi, channel_selection: ChannelSlicingInputType = None, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, @@ -296,7 +303,7 @@ def get_array( def get_roi( self, - roi: Roi | RoiPixels, + roi: Roi, channel_selection: ChannelSlicingInputType = None, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, @@ -356,7 +363,7 @@ def set_array( def set_roi( self, - roi: Roi | RoiPixels, + roi: Roi, patch: np.ndarray | da.Array, channel_selection: ChannelSlicingInputType = None, 
axes_order: Sequence[str] | None = None, @@ -399,22 +406,22 @@ class ImagesContainer: def __init__(self, group_handler: ZarrGroupHandler) -> None: """Initialize the LabelGroupHandler.""" self._group_handler = group_handler - self._meta_handler = find_image_meta_handler(group_handler) + self._meta_handler = ImageMetaHandler(group_handler) @property def meta(self) -> NgioImageMeta: """Return the metadata.""" - return self._meta_handler.meta + return self._meta_handler.get_meta() @property def levels(self) -> int: """Return the number of levels in the image.""" - return self._meta_handler.meta.levels + return self._meta_handler.get_meta().levels @property def levels_paths(self) -> list[str]: """Return the paths of the levels in the image.""" - return self._meta_handler.meta.paths + return self._meta_handler.get_meta().paths @property def num_channels(self) -> int: @@ -454,6 +461,15 @@ def get_channel_idx( channel_label=channel_label, wavelength_id=wavelength_id ) + def _set_channel_meta( + self, + channels_meta: ChannelsMeta, + ) -> None: + """Set the channels metadata.""" + meta = self.meta + meta.set_channels_meta(channels_meta) + self._meta_handler.update_meta(meta) + def set_channel_meta( self, labels: Sequence[str | None] | int | None = None, @@ -490,16 +506,12 @@ def set_channel_meta( ref_image = self.get(path=low_res_dataset.path) if start is not None and end is None: - raise NgioValidationError( - "If start is provided, end must be provided as well." - ) + raise NgioValueError("If start is provided, end must be provided as well.") if end is not None and start is None: - raise NgioValidationError( - "If end is provided, start must be provided as well." - ) + raise NgioValueError("If end is provided, start must be provided as well.") if start is not None and percentiles is not None: - raise NgioValidationError( + raise NgioValueError( "If start and end are provided, percentiles must be None." 
) @@ -511,11 +523,11 @@ def set_channel_meta( ) elif start is not None and end is not None: if len(start) != len(end): - raise NgioValidationError( + raise NgioValueError( "The start and end lists must have the same length." ) if len(start) != self.num_channels: - raise NgioValidationError( + raise NgioValueError( "The start and end lists must have the same length as " "the number of channels." ) @@ -539,10 +551,7 @@ def set_channel_meta( data_type=ref_image.dtype, **omero_kwargs, ) - - meta = self.meta - meta.set_channels_meta(channel_meta) - self._meta_handler.write_meta(meta) + self._set_channel_meta(channel_meta) def set_channel_percentiles( self, @@ -551,7 +560,7 @@ def set_channel_percentiles( ) -> None: """Update the percentiles of the channels.""" if self.meta._channels_meta is None: - raise NgioValidationError("The channels meta is not initialized.") + raise NgioValueError("The channels meta is not initialized.") low_res_dataset = self.meta.get_lowest_resolution_dataset() ref_image = self.get(path=low_res_dataset.path) @@ -576,7 +585,7 @@ def set_channel_percentiles( meta = self.meta meta.set_channels_meta(new_meta) - self._meta_handler.write_meta(meta) + self._meta_handler.update_meta(meta) def set_axes_unit( self, @@ -591,59 +600,99 @@ def set_axes_unit( """ meta = self.meta meta = meta.to_units(space_unit=space_unit, time_unit=time_unit) - self._meta_handler.write_meta(meta) + self._meta_handler.update_meta(meta) def derive( self, store: StoreOrGroup, ref_path: str | None = None, + # Metadata parameters shape: Sequence[int] | None = None, - labels: Sequence[str] | None = None, - pixel_size: PixelSize | None = None, - axes_names: Sequence[str] | None = None, + pixelsize: float | tuple[float, float] | None = None, + z_spacing: float | None = None, + time_spacing: float | None = None, name: str | None = None, - chunks: Sequence[int] | None = None, - dtype: str | None = None, - dimension_separator: DIMENSION_SEPARATOR | None = None, - compressor: str | None 
= None, + channels_meta: Sequence[str | Channel] | None = None, + channels_policy: Literal["same", "squeeze", "singleton"] | int = "same", + ngff_version: NgffVersions | None = None, + # Zarr Array parameters + chunks: ChunksLike | None = None, + shards: ShardsLike | None = None, + dtype: str = "uint16", + dimension_separator: Literal[".", "/"] = "/", + compressors: CompressorLike = "auto", + extra_array_kwargs: Mapping[str, Any] | None = None, overwrite: bool = False, + # Deprecated arguments + labels: Sequence[str] | None = None, + pixel_size: PixelSize | None = None, ) -> "ImagesContainer": """Create an empty OME-Zarr image from an existing image. + If a kwarg is not provided, the value from the reference image will be used. + Args: + image_container (ImagesContainer): The image container to derive the new + image. store (StoreOrGroup): The Zarr store or group to create the image in. - ref_path (str | None): The path to the reference image in - the image container. + ref_path (str | None): The path to the reference image in the image + container. shape (Sequence[int] | None): The shape of the new image. - labels (Sequence[str] | None): The labels of the new image. - pixel_size (PixelSize | None): The pixel size of the new image. - axes_names (Sequence[str] | None): The axes names of the new image. + pixelsize (float | tuple[float, float] | None): The pixel size of the new + image. + z_spacing (float | None): The z spacing of the new image. + time_spacing (float | None): The time spacing of the new image. name (str | None): The name of the new image. - chunks (Sequence[int] | None): The chunk shape of the new image. - dimension_separator (DIMENSION_SEPARATOR | None): The separator to use for - dimensions. If None it will use the same as the reference image. - compressor (str | None): The compressor to use. If None it will use - the same as the reference image. + channels_meta (Sequence[str | Channel] | None): The channels metadata + of the new image. 
+ channels_policy (Literal["same", "squeeze", "singleton"] | int): + Possible policies: + - If "squeeze", the channels axis will be removed (no matter its size). + - If "same", the channels axis will be kept as is (if it exists). + - If "singleton", the channels axis will be set to size 1. + - If an integer is provided, the channels axis will be changed to have + that size. + ngff_version (NgffVersions | None): The NGFF version to use. + chunks (ChunksLike | None): The chunk shape of the new image. + shards (ShardsLike | None): The shard shape of the new image. dtype (str | None): The data type of the new image. + dimension_separator (DIMENSION_SEPARATOR | None): The separator to use for + dimensions. + compressors (CompressorLike | None): The compressors to use. + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. overwrite (bool): Whether to overwrite an existing image. + labels (Sequence[str] | None): The labels of the new image. + This argument is deprecated please use channels_meta instead. + pixel_size (PixelSize | None): The pixel size of the new image. + This argument is deprecated please use pixelsize, z_spacing, + and time_spacing instead. Returns: - ImagesContainer: The new image + ImagesContainer: The new derived image. + """ return derive_image_container( image_container=self, store=store, ref_path=ref_path, shape=shape, - labels=labels, - pixel_size=pixel_size, - axes_names=axes_names, + pixelsize=pixelsize, + z_spacing=z_spacing, + time_spacing=time_spacing, name=name, + channels_meta=channels_meta, + channels_policy=channels_policy, + ngff_version=ngff_version, chunks=chunks, + shards=shards, dtype=dtype, dimension_separator=dimension_separator, - compressor=compressor, + compressors=compressors, + extra_array_kwargs=extra_array_kwargs, overwrite=overwrite, + labels=labels, + pixel_size=pixel_size, ) def get( @@ -662,7 +711,7 @@ def get( closest pixel size level will be returned. 
""" - dataset = self._meta_handler.meta.get_dataset( + dataset = self._meta_handler.get_meta().get_dataset( path=path, pixel_size=pixel_size, strict=strict ) return Image( @@ -715,142 +764,98 @@ def compute_image_percentile( def derive_image_container( + *, image_container: ImagesContainer, store: StoreOrGroup, ref_path: str | None = None, + # Metadata parameters shape: Sequence[int] | None = None, - labels: Sequence[str] | None = None, - pixel_size: PixelSize | None = None, - axes_names: Sequence[str] | None = None, + pixelsize: float | tuple[float, float] | None = None, + z_spacing: float | None = None, + time_spacing: float | None = None, name: str | None = None, - chunks: Sequence[int] | None = None, + channels_policy: Literal["same", "squeeze", "singleton"] | int = "same", + channels_meta: Sequence[str | Channel] | None = None, + ngff_version: NgffVersions | None = None, + # Zarr Array parameters + chunks: ChunksLike | None = None, + shards: ShardsLike | None = None, dtype: str | None = None, - dimension_separator: DIMENSION_SEPARATOR | None = None, - compressor=None, + dimension_separator: Literal[".", "/"] | None = None, + compressors: CompressorLike | None = None, + extra_array_kwargs: Mapping[str, Any] | None = None, overwrite: bool = False, + # Deprecated arguments + labels: Sequence[str] | None = None, + pixel_size: PixelSize | None = None, ) -> ImagesContainer: - """Create an empty OME-Zarr image from an existing image. + """Derive a new OME-Zarr image container from an existing image. + + If a kwarg is not provided, the value from the reference image will be used. Args: - image_container (ImagesContainer): The image container to derive the new image. + image_container (ImagesContainer): The image container to derive the new image + from. store (StoreOrGroup): The Zarr store or group to create the image in. ref_path (str | None): The path to the reference image in the image container. shape (Sequence[int] | None): The shape of the new image. 
- labels (Sequence[str] | None): The labels of the new image. - pixel_size (PixelSize | None): The pixel size of the new image. - axes_names (Sequence[str] | None): The axes names of the new image. + pixelsize (float | tuple[float, float] | None): The pixel size of the new image. + z_spacing (float | None): The z spacing of the new image. + time_spacing (float | None): The time spacing of the new image. name (str | None): The name of the new image. - chunks (Sequence[int] | None): The chunk shape of the new image. - dimension_separator (DIMENSION_SEPARATOR | None): The separator to use for - dimensions. If None it will use the same as the reference image. - compressor: The compressor to use. If None it will use - the same as the reference image. + channels_policy (Literal["squeeze", "same", "singleton"] | int): Possible + policies: + - If "squeeze", the channels axis will be removed (no matter its size). + - If "same", the channels axis will be kept as is (if it exists). + - If "singleton", the channels axis will be set to size 1. + - If an integer is provided, the channels axis will be changed to have + that size. + channels_meta (Sequence[str | Channel] | None): The channels metadata + of the new image. + ngff_version (NgffVersions | None): The NGFF version to use. + chunks (ChunksLike | None): The chunk shape of the new image. + shards (ShardsLike | None): The shard shape of the new image. dtype (str | None): The data type of the new image. - overwrite (bool): Whether to overwrite an existing image. + dimension_separator (Literal[".", "/"] | None): The separator to use for + dimensions. + compressors (CompressorLike | None): The compressors to use. + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. + overwrite (bool): Whether to overwrite an existing image. Defaults to False. + labels (Sequence[str] | None): Deprecated. This argument is deprecated, + please use channels_meta instead. 
+ pixel_size (PixelSize | None): Deprecated. The pixel size of the new image. + This argument is deprecated, please use pixelsize, z_spacing, + and time_spacing instead. Returns: - ImagesContainer: The new image + ImagesContainer: The new derived image container. """ - if ref_path is None: - ref_image = image_container.get() - else: - ref_image = image_container.get(path=ref_path) - - ref_meta = ref_image.meta - - if shape is None: - shape = ref_image.shape - - if pixel_size is None: - pixel_size = ref_image.pixel_size - - if axes_names is None: - axes_names = ref_meta.axes_handler.axes_names - - if len(axes_names) != len(shape): - raise NgioValidationError( - "The axes names of the new image does not match the reference image." - f"Got {axes_names} for shape {shape}." - ) - - if chunks is None: - chunks = ref_image.chunks - - if len(chunks) != len(shape): - raise NgioValidationError( - "The chunks of the new image does not match the reference image." - f"Got {chunks} for shape {shape}." 
- ) - - if name is None: - name = ref_meta.name - - if dtype is None: - dtype = ref_image.dtype - - if dimension_separator is None: - dimension_separator = ref_image.zarr_array._dimension_separator # type: ignore - - if compressor is None: - compressor = ref_image.zarr_array.compressor # type: ignore - - handler = create_empty_image_container( + ref_image = image_container.get(path=ref_path) + group_handler = abstract_derive( + ref_image=ref_image, + meta_type=NgioImageMeta, store=store, shape=shape, - pixelsize=pixel_size.x, - z_spacing=pixel_size.z, - time_spacing=pixel_size.t, - levels=ref_meta.paths, - yx_scaling_factor=ref_meta.yx_scaling(), - z_scaling_factor=ref_meta.z_scaling(), - time_unit=pixel_size.time_unit, - space_unit=pixel_size.space_unit, - axes_names=axes_names, + pixelsize=pixelsize, + z_spacing=z_spacing, + time_spacing=time_spacing, name=name, + channels_meta=channels_meta, + channels_policy=channels_policy, + ngff_version=ngff_version, chunks=chunks, + shards=shards, dtype=dtype, - dimension_separator=dimension_separator, # type: ignore - compressor=compressor, # type: ignore + dimension_separator=dimension_separator, + compressors=compressors, + extra_array_kwargs=extra_array_kwargs, overwrite=overwrite, - version=ref_meta.version, - ) - image_container = ImagesContainer(handler) - - if ref_image.num_channels == image_container.num_channels: - _labels = ref_image.channel_labels - wavelength_id = ref_image.wavelength_ids - - channel_meta = ref_image.channels_meta - colors = [c.channel_visualisation.color for c in channel_meta.channels] - active = [c.channel_visualisation.active for c in channel_meta.channels] - start = [c.channel_visualisation.start for c in channel_meta.channels] - end = [c.channel_visualisation.end for c in channel_meta.channels] - else: - _labels = None - wavelength_id = None - colors = None - active = None - start = None - end = None - - if labels is not None: - if len(labels) != image_container.num_channels: - raise 
NgioValidationError( - "The number of labels does not match the number of channels." - ) - _labels = labels - - image_container.set_channel_meta( - labels=_labels, - wavelength_id=wavelength_id, - percentiles=None, - colors=colors, - active=active, - start=start, - end=end, + labels=labels, + pixel_size=pixel_size, ) - return image_container + return ImagesContainer(group_handler=group_handler) def _parse_str_or_model( @@ -859,9 +864,9 @@ def _parse_str_or_model( """Parse a string or ChannelSelectionModel to an integer channel index.""" if isinstance(channel_selection, int): if channel_selection < 0: - raise NgioValidationError("Channel index must be a non-negative integer.") + raise NgioValueError("Channel index must be a non-negative integer.") if channel_selection >= image.num_channels: - raise NgioValidationError( + raise NgioValueError( "Channel index must be less than the number " f"of channels ({image.num_channels})." ) @@ -879,7 +884,7 @@ def _parse_str_or_model( ) elif channel_selection.mode == "index": return int(channel_selection.identifier) - raise NgioValidationError( + raise NgioValueError( "Invalid channel selection type. " f"{channel_selection} is of type {type(channel_selection)} ", "supported types are str, ChannelSelectionModel, and int.", @@ -898,7 +903,7 @@ def _parse_channel_selection( elif isinstance(channel_selection, Sequence): _sequence = [_parse_str_or_model(image, cs) for cs in channel_selection] return {"c": _sequence} - raise NgioValidationError( + raise NgioValueError( f"Invalid channel selection type {type(channel_selection)}. " "Supported types are int, str, ChannelSelectionModel, and Sequence." 
) @@ -912,7 +917,7 @@ def add_channel_selection_to_slicing_dict( """Add channel selection information to the slicing dictionary.""" channel_info = _parse_channel_selection(image, channel_selection) if "c" in slicing_dict and channel_info: - raise NgioValidationError( + raise NgioValueError( "Both channel_selection and 'c' in slicing_kwargs are provided. " "Which channel selection should be used is ambiguous. " "Please provide only one." diff --git a/src/ngio/images/_label.py b/src/ngio/images/_label.py index 427d356b..fbcc4c5c 100644 --- a/src/ngio/images/_label.py +++ b/src/ngio/images/_label.py @@ -1,23 +1,26 @@ """A module for handling label images in OME-NGFF files.""" -from collections.abc import Sequence -from typing import Literal +from collections.abc import Mapping, Sequence +from typing import Any, Literal -from zarr.types import DIMENSION_SEPARATOR +from zarr.core.array import CompressorLike from ngio.common import compute_masking_roi -from ngio.images._abstract_image import AbstractImage -from ngio.images._create import create_empty_label_container +from ngio.common._pyramid import ChunksLike, ShardsLike +from ngio.images._abstract_image import AbstractImage, abstract_derive from ngio.images._image import Image from ngio.ome_zarr_meta import ( LabelMetaHandler, + LabelsGroupMetaHandler, NgioLabelMeta, + NgioLabelsGroupMeta, PixelSize, - find_label_meta_handler, + update_ngio_labels_group_meta, ) from ngio.ome_zarr_meta.ngio_specs import ( DefaultSpaceUnit, DefaultTimeUnit, + NgffVersions, SpaceUnits, TimeUnits, ) @@ -30,7 +33,7 @@ ) -class Label(AbstractImage[LabelMetaHandler]): +class Label(AbstractImage): """Placeholder class for a label.""" get_as_numpy = AbstractImage._get_as_numpy @@ -57,7 +60,7 @@ def __init__( """ if meta_handler is None: - meta_handler = find_label_meta_handler(group_handler) + meta_handler = LabelMetaHandler(group_handler) super().__init__( group_handler=group_handler, path=path, meta_handler=meta_handler ) @@ -66,10 +69,18 @@ 
def __repr__(self) -> str: """Return the string representation of the label.""" return f"Label(path={self.path}, {self.dimensions})" + @property + def meta_handler(self) -> LabelMetaHandler: + """Return the metadata handler.""" + assert isinstance(self._meta_handler, LabelMetaHandler) + return self._meta_handler + @property def meta(self) -> NgioLabelMeta: """Return the metadata.""" - return self._meta_handler.meta + meta = self.meta_handler.get_meta() + assert isinstance(meta, NgioLabelMeta) + return meta def set_axes_unit( self, @@ -84,7 +95,7 @@ def set_axes_unit( """ meta = self.meta meta = meta.to_units(space_unit=space_unit, time_unit=time_unit) - self._meta_handler.write_meta(meta) + self.meta_handler.update_meta(meta) def build_masking_roi_table(self) -> MaskingRoiTable: """Compute the masking ROI table.""" @@ -104,30 +115,39 @@ def consolidate( class LabelsContainer: """A class to handle the /labels group in an OME-NGFF file.""" - def __init__(self, group_handler: ZarrGroupHandler) -> None: + def __init__( + self, + group_handler: ZarrGroupHandler, + ngff_version: NgffVersions | None = None, + ) -> None: """Initialize the LabelGroupHandler.""" self._group_handler = group_handler - - # Validate the group - # Either contains a labels attribute or is empty - attrs = self._group_handler.load_attrs() - if len(attrs) == 0: - # It's an empty group - pass - elif "labels" in attrs and isinstance(attrs["labels"], list): - # It's a valid group - pass - else: - raise NgioValidationError( - f"Invalid /labels group. " - f"Expected a single labels attribute with a list of label names. " - f"Found: {attrs}" + # If the group is empty, initialize the metadata + try: + self._meta_handler = LabelsGroupMetaHandler(group_handler) + except NgioValidationError: + if ngff_version is None: + raise NgioValueError( + "The /labels group is missing metadata. " + "Please provide the ngff_version to initialize it." 
+ ) from None + meta = NgioLabelsGroupMeta(labels=[], version=ngff_version) + update_ngio_labels_group_meta( + group_handler=group_handler, + ngio_meta=meta, ) + self._group_handler = self._group_handler.reopen_handler() + self._meta_handler = LabelsGroupMetaHandler(group_handler) + + @property + def meta(self) -> NgioLabelsGroupMeta: + """Return the metadata.""" + meta = self._meta_handler.get_meta() + return meta def list(self) -> list[str]: """Create the /labels group if it doesn't exist.""" - attrs = self._group_handler.load_attrs() - return attrs.get("labels", []) + return self.meta.labels def get( self, @@ -153,49 +173,106 @@ def get( f"Available labels: {self.list()}" ) - group_handler = self._group_handler.derive_handler(name) - label_meta_handler = find_label_meta_handler(group_handler) - path = label_meta_handler.meta.get_dataset( - path=path, pixel_size=pixel_size, strict=strict - ).path + group_handler = self._group_handler.get_handler(name) + label_meta_handler = LabelMetaHandler(group_handler) + path = ( + label_meta_handler.get_meta() + .get_dataset(path=path, pixel_size=pixel_size, strict=strict) + .path + ) return Label(group_handler, path, label_meta_handler) + def delete(self, name: str, missing_ok: bool = False) -> None: + """Delete a label from the group. + + Args: + name (str): The name of the label to delete. + missing_ok (bool): If True, do not raise an error if the label does not + exist. + + """ + existing_labels = self.list() + if name not in existing_labels: + if missing_ok: + return + raise NgioValueError( + f"Label '{name}' not found in the Labels group. 
" + f"Available labels: {existing_labels}" + ) + + self._group_handler.delete_group(name) + existing_labels.remove(name) + update_meta = NgioLabelsGroupMeta( + labels=existing_labels, version=self.meta.version + ) + self._meta_handler.update_meta(update_meta) + def derive( self, name: str, ref_image: Image | Label, + # Metadata parameters shape: Sequence[int] | None = None, - pixel_size: PixelSize | None = None, - axes_names: Sequence[str] | None = None, - chunks: Sequence[int] | None = None, - dtype: str = "uint32", - dimension_separator: DIMENSION_SEPARATOR | None = None, - compressor=None, + pixelsize: float | tuple[float, float] | None = None, + z_spacing: float | None = None, + time_spacing: float | None = None, + channels_policy: Literal["same", "squeeze", "singleton"] | int = "squeeze", + ngff_version: NgffVersions | None = None, + # Zarr Array parameters + chunks: ChunksLike | None = None, + shards: ShardsLike | None = None, + dtype: str | None = None, + dimension_separator: Literal[".", "/"] | None = None, + compressors: CompressorLike | None = None, + extra_array_kwargs: Mapping[str, Any] | None = None, overwrite: bool = False, + # Deprecated arguments + labels: Sequence[str] | None = None, + pixel_size: PixelSize | None = None, ) -> "Label": - """Create an empty OME-Zarr label from a reference image. + """Create an empty OME-Zarr image from an existing image. - And add the label to the /labels group. + If a kwarg is not provided, the value from the reference image will be used. Args: store (StoreOrGroup): The Zarr store or group to create the image in. - ref_image (Image | Label): A reference image that will be used to create - the new image. - name (str): The name of the new image. + ref_image (Image | Label): The reference image to derive the new image from. shape (Sequence[int] | None): The shape of the new image. - pixel_size (PixelSize | None): The pixel size of the new image. 
+ pixelsize (float | tuple[float, float] | None): The pixel size of the new + image. + z_spacing (float | None): The z spacing of the new image. + time_spacing (float | None): The time spacing of the new image. + scaling_factors (Sequence[float] | Literal["auto"] | None): The scaling + factors of the new image. axes_names (Sequence[str] | None): The axes names of the new image. - For labels, the channel axis is not allowed. - chunks (Sequence[int] | None): The chunk shape of the new image. - dtype (str): The data type of the new label. + name (str | None): The name of the new image. + channels_meta (Sequence[str | Channel] | None): The channels metadata + of the new image. + channels_policy (Literal["squeeze", "same", "singleton"] | int): + Possible policies: + - If "squeeze", the channels axis will be removed (no matter its size). + - If "same", the channels axis will be kept as is (if it exists). + - If "singleton", the channels axis will be set to size 1. + - If an integer is provided, the channels axis will be changed to have + that size. + ngff_version (NgffVersions | None): The NGFF version to use. + chunks (ChunksLike | None): The chunk shape of the new image. + shards (ShardsLike | None): The shard shape of the new image. + dtype (str | None): The data type of the new image. dimension_separator (DIMENSION_SEPARATOR | None): The separator to use for - dimensions. If None it will use the same as the reference image. - compressor: The compressor to use. If None it will use - the same as the reference image. + dimensions. + compressors (CompressorLike | None): The compressors to use. + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. overwrite (bool): Whether to overwrite an existing image. + labels (Sequence[str] | None): The labels of the new image. + This argument is deprecated, please use channels_meta instead. + pixel_size (PixelSize | None): The pixel size of the new image.
+ This argument is deprecated please use pixelsize, z_spacing, + and time_spacing instead. Returns: - Label: The new label. + Label: The new derived label. """ existing_labels = self.list() @@ -208,131 +285,124 @@ def derive( label_group = self._group_handler.get_group(name, create_mode=True) derive_label( - store=label_group, ref_image=ref_image, - name=name, + store=label_group, shape=shape, - pixel_size=pixel_size, - axes_names=axes_names, + pixelsize=pixelsize, + z_spacing=z_spacing, + time_spacing=time_spacing, + name=name, + channels_policy=channels_policy, + ngff_version=ngff_version, chunks=chunks, + shards=shards, dtype=dtype, dimension_separator=dimension_separator, - compressor=compressor, + compressors=compressors, + extra_array_kwargs=extra_array_kwargs, overwrite=overwrite, + labels=labels, + pixel_size=pixel_size, ) if name not in existing_labels: existing_labels.append(name) - self._group_handler.write_attrs({"labels": existing_labels}) + update_meta = NgioLabelsGroupMeta( + labels=existing_labels, version=self.meta.version + ) + self._meta_handler.update_meta(update_meta) return self.get(name) def derive_label( + *, store: StoreOrGroup, ref_image: Image | Label, - name: str, + # Metadata parameters shape: Sequence[int] | None = None, - pixel_size: PixelSize | None = None, - axes_names: Sequence[str] | None = None, - chunks: Sequence[int] | None = None, - dimension_separator: DIMENSION_SEPARATOR | None = None, - compressor=None, - dtype: str = "uint32", + pixelsize: float | tuple[float, float] | None = None, + z_spacing: float | None = None, + time_spacing: float | None = None, + name: str | None = None, + channels_policy: Literal["same", "squeeze", "singleton"] | int = "squeeze", + ngff_version: NgffVersions | None = None, + # Zarr Array parameters + chunks: ChunksLike | None = None, + shards: ShardsLike | None = None, + dtype: str | None = None, + dimension_separator: Literal[".", "/"] | None = None, + compressors: CompressorLike | None = None, + 
extra_array_kwargs: Mapping[str, Any] | None = None, overwrite: bool = False, -) -> None: - """Create an empty OME-Zarr label from a reference image. + # Deprecated arguments + labels: Sequence[str] | None = None, + pixel_size: PixelSize | None = None, +) -> ZarrGroupHandler: + """Derive a new OME-Zarr label from an existing image or label. + + If a kwarg is not provided, the value from the reference image will be used. Args: - store (StoreOrGroup): The Zarr store or group to create the image in. - ref_image (Image | Label): A reference image that will be used to - create the new image. - name (str): The name of the new image. - shape (Sequence[int] | None): The shape of the new image. - pixel_size (PixelSize | None): The pixel size of the new image. - axes_names (Sequence[str] | None): The axes names of the new image. - For labels, the channel axis is not allowed. - chunks (Sequence[int] | None): The chunk shape of the new image. - dtype (str): The data type of the new label. - dimension_separator (DIMENSION_SEPARATOR | None): The separator to use for - dimensions. If None it will use the same as the reference image. - compressor: The compressor to use. If None it will use - the same as the reference image. - overwrite (bool): Whether to overwrite an existing image. + store (StoreOrGroup): The Zarr store or group to create the label in. + ref_image (Image | Label): The reference image to derive the new label from. + shape (Sequence[int] | None): The shape of the new label. + pixelsize (float | tuple[float, float] | None): The pixel size of the new label. + z_spacing (float | None): The z spacing of the new label. + time_spacing (float | None): The time spacing of the new label. + name (str | None): The name of the new label. + channels_policy (Literal["squeeze", "same", "singleton"] | int): Possible + policies: + - If "squeeze", the channels axis will be removed (no matter its size). + - If "same", the channels axis will be kept as is (if it exists). 
+ - If "singleton", the channels axis will be set to size 1. + - If an integer is provided, the channels axis will be changed to have that + size. + ngff_version (NgffVersions | None): The NGFF version to use. + chunks (ChunksLike | None): The chunk shape of the new label. + shards (ShardsLike | None): The shard shape of the new label. + dtype (str | None): The data type of the new label. + dimension_separator (Literal[".", "/"] | None): The separator to use for + dimensions. + compressors (CompressorLike | None): The compressors to use. + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. + overwrite (bool): Whether to overwrite an existing label. Defaults to False. + labels (Sequence[str] | None): Deprecated. This argument is deprecated, + please use channels_meta instead. + pixel_size (PixelSize | None): Deprecated. The pixel size of the new label. + This argument is deprecated, please use pixelsize, z_spacing, + and time_spacing instead. Returns: - None + ZarrGroupHandler: The group handler of the new label. """ - ref_meta = ref_image.meta - - if shape is None: - shape = ref_image.shape - - if pixel_size is None: - pixel_size = ref_image.pixel_size - - if axes_names is None: - axes_names = ref_meta.axes_handler.axes_names - c_axis = ref_meta.axes_handler.get_index("c") - else: - if "c" in axes_names: - raise NgioValidationError( - "Labels cannot have a channel axis. " - "Please remove the channel axis from the axes names." - ) - c_axis = None - - if len(axes_names) != len(shape): - raise NgioValidationError( - "The axes names of the new image does not match the reference image." - f"Got {axes_names} for shape {shape}." - ) - - if chunks is None: - chunks = ref_image.chunks - - if len(chunks) != len(shape): - raise NgioValidationError( - "The chunks of the new image does not match the reference image." - f"Got {chunks} for shape {shape}." 
- ) - - if c_axis is not None: - # remove channel if present - shape = list(shape) - shape = shape[:c_axis] + shape[c_axis + 1 :] - chunks = list(chunks) - chunks = chunks[:c_axis] + chunks[c_axis + 1 :] - axes_names = list(axes_names) - axes_names = axes_names[:c_axis] + axes_names[c_axis + 1 :] - - if dimension_separator is None: - dimension_separator = ref_image.zarr_array._dimension_separator # type: ignore - if compressor is None: - compressor = ref_image.zarr_array.compressor # type: ignore - - _ = create_empty_label_container( + if dtype is None and isinstance(ref_image, Image): + dtype = "uint32" + group_handler = abstract_derive( + ref_image=ref_image, + meta_type=NgioLabelMeta, store=store, shape=shape, - pixelsize=ref_image.pixel_size.x, - z_spacing=ref_image.pixel_size.z, - time_spacing=ref_image.pixel_size.t, - levels=ref_meta.paths, - yx_scaling_factor=ref_meta.yx_scaling(), - z_scaling_factor=ref_meta.z_scaling(), - time_unit=ref_image.pixel_size.time_unit, - space_unit=ref_image.pixel_size.space_unit, - axes_names=axes_names, + pixelsize=pixelsize, + z_spacing=z_spacing, + time_spacing=time_spacing, + name=name, + channels_meta=None, + channels_policy=channels_policy, + ngff_version=ngff_version, chunks=chunks, + shards=shards, dtype=dtype, - dimension_separator=dimension_separator, # type: ignore - compressor=compressor, # type: ignore + dimension_separator=dimension_separator, + compressors=compressors, + extra_array_kwargs=extra_array_kwargs, overwrite=overwrite, - version=ref_meta.version, - name=name, + labels=labels, + pixel_size=pixel_size, ) - return None + return group_handler def build_masking_roi_table(label: Label) -> MaskingRoiTable: diff --git a/src/ngio/images/_ome_zarr_container.py b/src/ngio/images/_ome_zarr_container.py index c4050e73..3853ddbc 100644 --- a/src/ngio/images/_ome_zarr_container.py +++ b/src/ngio/images/_ome_zarr_container.py @@ -1,17 +1,24 @@ """Abstract class for handling OME-NGFF images.""" import warnings -from 
collections.abc import Sequence +from collections.abc import Mapping, Sequence +from typing import Any, Literal import numpy as np -from zarr.types import DIMENSION_SEPARATOR +from zarr.core.array import CompressorLike -from ngio.images._create import create_empty_image_container +from ngio.common._pyramid import ChunksLike, ShardsLike +from ngio.images._create_utils import init_image_like from ngio.images._image import Image, ImagesContainer from ngio.images._label import Label, LabelsContainer from ngio.images._masked_image import MaskedImage, MaskedLabel -from ngio.ome_zarr_meta import NgioImageMeta, PixelSize, find_label_meta_handler +from ngio.ome_zarr_meta import ( + LabelMetaHandler, + NgioImageMeta, + PixelSize, +) from ngio.ome_zarr_meta.ngio_specs import ( + Channel, DefaultNgffVersion, DefaultSpaceUnit, DefaultTimeUnit, @@ -34,6 +41,7 @@ ) from ngio.utils import ( AccessModeLiteral, + NgioError, NgioValidationError, NgioValueError, StoreOrGroup, @@ -41,18 +49,26 @@ ) -def _default_table_container(handler: ZarrGroupHandler) -> TablesContainer | None: +def _try_get_table_container( + handler: ZarrGroupHandler, create_mode: bool = True +) -> TablesContainer | None: """Return a default table container.""" - success, table_handler = handler.safe_derive_handler("tables") - if success and isinstance(table_handler, ZarrGroupHandler): + try: + table_handler = handler.get_handler("tables", create_mode=create_mode) return TablesContainer(table_handler) + except NgioError: + return None -def _default_label_container(handler: ZarrGroupHandler) -> LabelsContainer | None: +def _try_get_label_container( + handler: ZarrGroupHandler, ngff_version: NgffVersions, create_mode: bool = True +) -> LabelsContainer | None: """Return a default label container.""" - success, label_handler = handler.safe_derive_handler("labels") - if success and isinstance(label_handler, ZarrGroupHandler): - return LabelsContainer(label_handler) + try: + label_handler = handler.get_handler("labels", 
create_mode=create_mode) + return LabelsContainer(label_handler, ngff_version=ngff_version) + except FileNotFoundError: + return None class OmeZarrContainer: @@ -127,13 +143,17 @@ def images_container(self) -> ImagesContainer: """ return self._images_container - def _get_labels_container(self) -> LabelsContainer | None: + def _get_labels_container(self, create_mode: bool = True) -> LabelsContainer | None: """Return the labels container.""" - if self._labels_container is None: - _labels_container = _default_label_container(self._group_handler) - if _labels_container is None: - return None - self._labels_container = _labels_container + if self._labels_container is not None: + return self._labels_container + + _labels_container = _try_get_label_container( + self._group_handler, + create_mode=create_mode, + ngff_version=self.image_meta.version, + ) + self._labels_container = _labels_container return self._labels_container @property @@ -144,13 +164,15 @@ def labels_container(self) -> LabelsContainer: raise NgioValidationError("No labels found in the image.") return _labels_container - def _get_tables_container(self) -> TablesContainer | None: + def _get_tables_container(self, create_mode: bool = True) -> TablesContainer | None: """Return the tables container.""" - if self._tables_container is None: - _tables_container = _default_table_container(self._group_handler) - if _tables_container is None: - return None - self._tables_container = _tables_container + if self._tables_container is not None: + return self._tables_container + + _tables_container = _try_get_table_container( + self._group_handler, create_mode=create_mode + ) + self._tables_container = _tables_container return self._tables_container @property @@ -402,83 +424,117 @@ def derive_image( self, store: StoreOrGroup, ref_path: str | None = None, + # Metadata parameters shape: Sequence[int] | None = None, - labels: Sequence[str] | None = None, - pixel_size: PixelSize | None = None, - axes_names: Sequence[str] | 
None = None, + pixelsize: float | tuple[float, float] | None = None, + z_spacing: float | None = None, + time_spacing: float | None = None, name: str | None = None, - chunks: Sequence[int] | None = None, - dtype: str | None = None, - dimension_separator: DIMENSION_SEPARATOR | None = None, - compressor=None, + channels_policy: Literal["squeeze", "same", "singleton"] | int = "same", + channels_meta: Sequence[str | Channel] | None = None, + ngff_version: NgffVersions | None = None, + # Zarr Array parameters + chunks: ChunksLike | None = None, + shards: ShardsLike | None = None, + dtype: str = "uint16", + dimension_separator: Literal[".", "/"] = "/", + compressors: CompressorLike = "auto", + extra_array_kwargs: Mapping[str, Any] | None = None, + overwrite: bool = False, + # Copy from current image copy_labels: bool = False, copy_tables: bool = False, - overwrite: bool = False, + # Deprecated arguments + labels: Sequence[str] | None = None, + pixel_size: PixelSize | None = None, ) -> "OmeZarrContainer": - """Create an empty OME-Zarr container from an existing image. + """Derive a new OME-Zarr container from the current image. + + If a kwarg is not provided, the value from the reference image will be used. Args: store (StoreOrGroup): The Zarr store or group to create the image in. - ref_path (str | None): The path to the reference image in - the image container. + ref_path (str | None): The path to the reference image in the image + container. shape (Sequence[int] | None): The shape of the new image. - labels (Sequence[str] | None): The labels of the new image. - pixel_size (PixelSize | None): The pixel size of the new image. - axes_names (Sequence[str] | None): The axes names of the new image. - chunks (Sequence[int] | None): The chunk shape of the new image. - dtype (str | None): The data type of the new image. + pixelsize (float | tuple[float, float] | None): The pixel size of the new + image. + z_spacing (float | None): The z spacing of the new image. 
+ time_spacing (float | None): The time spacing of the new image. name (str | None): The name of the new image. - dimension_separator (DIMENSION_SEPARATOR | None): The dimension - separator to use. If None, the dimension separator of the - reference image will be used. - compressor: The compressor to use. If None, the compressor of the - reference image will be used. - copy_labels (bool): Whether to copy the labels from the reference image. - copy_tables (bool): Whether to copy the tables from the reference image. - overwrite (bool): Whether to overwrite an existing image. + channels_policy (Literal["squeeze", "same", "singleton"] | int): Policies: + - If "squeeze", the channels axis will be removed (no matter its size). + - If "same", the channels axis will be kept as is (if it exists). + - If "singleton", the channels axis will be set to size 1. + - If an integer is provided, the channels axis will be changed to have + that size. + channels_meta (Sequence[str | Channel] | None): The channels metadata + of the new image. + ngff_version (NgffVersions | None): The NGFF version to use. + chunks (ChunksLike | None): The chunk shape of the new image. + shards (ShardsLike | None): The shard shape of the new image. + dtype (str): The data type of the new image. Defaults to "uint16". + dimension_separator (Literal[".", "/"]): The separator to use for + dimensions. Defaults to "/". + compressors (CompressorLike): The compressors to use. Defaults to "auto". + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. + overwrite (bool): Whether to overwrite an existing image. Defaults to False. + copy_labels (bool): Whether to copy the labels from the current image. + Defaults to False. + copy_tables (bool): Whether to copy the tables from the current image. + Defaults to False. + labels (Sequence[str] | None): Deprecated. This argument is deprecated, + please use channels_meta instead. + pixel_size (PixelSize | None): Deprecated.
The pixel size of the new image. + This argument is deprecated, please use pixelsize, z_spacing, + and time_spacing instead. Returns: - OmeZarrContainer: The new image container. + OmeZarrContainer: The new derived OME-Zarr container. """ - _ = self._images_container.derive( + new_container = self._images_container.derive( store=store, ref_path=ref_path, shape=shape, - labels=labels, - pixel_size=pixel_size, - axes_names=axes_names, + pixelsize=pixelsize, + z_spacing=z_spacing, + time_spacing=time_spacing, name=name, + channels_meta=channels_meta, + channels_policy=channels_policy, + ngff_version=ngff_version, chunks=chunks, + shards=shards, dtype=dtype, dimension_separator=dimension_separator, - compressor=compressor, + compressors=compressors, + extra_array_kwargs=extra_array_kwargs, overwrite=overwrite, - ) - - handler = ZarrGroupHandler( - store, cache=self._group_handler.use_cache, mode=self._group_handler.mode + labels=labels, + pixel_size=pixel_size, ) new_ome_zarr = OmeZarrContainer( - group_handler=handler, + group_handler=new_container._group_handler, validate_paths=False, ) if copy_labels: - self.labels_container._group_handler.copy_handler( - new_ome_zarr.labels_container._group_handler + self.labels_container._group_handler.copy_group( + new_ome_zarr.labels_container._group_handler.group ) if copy_tables: - self.tables_container._group_handler.copy_handler( - new_ome_zarr.tables_container._group_handler + self.tables_container._group_handler.copy_group( + new_ome_zarr.tables_container._group_handler.group ) return new_ome_zarr def list_tables(self, filter_types: TypedTable | str | None = None) -> list[str]: """List all tables in the image.""" - table_container = self._get_tables_container() + table_container = self._get_tables_container(create_mode=False) if table_container is None: return [] @@ -619,9 +675,28 @@ def add_table( name=name, table=table, backend=backend, overwrite=overwrite ) + def delete_table(self, name: str, missing_ok: bool = False) 
-> None: + """Delete a table from the group. + + Args: + name (str): The name of the table to delete. + missing_ok (bool): If True, do not raise an error if the table does not + exist. + + """ + table_container = self._get_tables_container(create_mode=False) + if table_container is None and missing_ok: + return + if table_container is None: + raise NgioValueError( + f"No tables found in the image, cannot delete {name}. " + "Set missing_ok=True to ignore this error." + ) + table_container.delete(name=name, missing_ok=missing_ok) + def list_labels(self) -> list[str]: """List all labels in the image.""" - label_container = self._get_labels_container() + label_container = self._get_labels_container(create_mode=False) if label_container is None: return [] return label_container.list() @@ -684,42 +759,86 @@ def get_masked_label( masking_roi_table=masking_table, ) + def delete_label(self, name: str, missing_ok: bool = False) -> None: + """Delete a label from the group. + + Args: + name (str): The name of the label to delete. + missing_ok (bool): If True, do not raise an error if the label does not + exist. + + """ + label_container = self._get_labels_container(create_mode=False) + if label_container is None and missing_ok: + return + if label_container is None: + raise NgioValueError( + f"No labels found in the image, cannot delete {name}. " + "Set missing_ok=True to ignore this error." 
+ ) + label_container.delete(name=name, missing_ok=missing_ok) + def derive_label( self, name: str, ref_image: Image | Label | None = None, + # Metadata parameters shape: Sequence[int] | None = None, - pixel_size: PixelSize | None = None, - axes_names: Sequence[str] | None = None, - chunks: Sequence[int] | None = None, - dtype: str = "uint32", - dimension_separator: DIMENSION_SEPARATOR | None = None, - compressor=None, + pixelsize: float | tuple[float, float] | None = None, + z_spacing: float | None = None, + time_spacing: float | None = None, + channels_policy: Literal["same", "squeeze", "singleton"] | int = "squeeze", + ngff_version: NgffVersions | None = None, + # Zarr Array parameters + chunks: ChunksLike | None = None, + shards: ShardsLike | None = None, + dtype: str | None = None, + dimension_separator: Literal[".", "/"] | None = None, + compressors: CompressorLike | None = None, + extra_array_kwargs: Mapping[str, Any] | None = None, overwrite: bool = False, + # Deprecated arguments + labels: Sequence[str] | None = None, + pixel_size: PixelSize | None = None, ) -> "Label": - """Create an empty OME-Zarr label from a reference image. + """Derive a new label from an existing image or label. - And add the label to the /labels group. + If a kwarg is not provided, the value from the reference image will be used. Args: - name (str): The name of the new image. - ref_image (Image | Label | None): A reference image that will be used - to create the new image. - shape (Sequence[int] | None): The shape of the new image. - pixel_size (PixelSize | None): The pixel size of the new image. - axes_names (Sequence[str] | None): The axes names of the new image. - For labels, the channel axis is not allowed. - chunks (Sequence[int] | None): The chunk shape of the new image. - dtype (str): The data type of the new label. - dimension_separator (DIMENSION_SEPARATOR | None): The dimension - separator to use. If None, the dimension separator of the - reference image will be used. 
- compressor: The compressor to use. If None, the compressor of the - reference image will be used. - overwrite (bool): Whether to overwrite an existing image. + name (str): The name of the new label. + ref_image (Image | Label | None): The reference image to derive the new + label from. If None, the first level image will be used. + shape (Sequence[int] | None): The shape of the new label. + pixelsize (float | tuple[float, float] | None): The pixel size of the new + label. + z_spacing (float | None): The z spacing of the new label. + time_spacing (float | None): The time spacing of the new label. + channels_policy (Literal["same", "squeeze", "singleton"] | int): Policies: + - If "squeeze", the channels axis will be removed (no matter its size). + - If "same", the channels axis will be kept as is (if it exists). + - If "singleton", the channels axis will be set to size 1. + - If an integer is provided, the channels axis will be changed to have + that size. + Defaults to "squeeze". + ngff_version (NgffVersions | None): The NGFF version to use. + chunks (ChunksLike | None): The chunk shape of the new label. + shards (ShardsLike | None): The shard shape of the new label. + dtype (str | None): The data type of the new label. + dimension_separator (Literal[".", "/"] | None): The separator to use for + dimensions. + compressors (CompressorLike | None): The compressors to use. + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. + overwrite (bool): Whether to overwrite an existing label. Defaults to False. + labels (Sequence[str] | None): Deprecated. This argument is deprecated, + please use channels_meta instead. + pixel_size (PixelSize | None): Deprecated. The pixel size of the new label. + This argument is deprecated, please use pixelsize, z_spacing, + and time_spacing instead. Returns: - Label: The new label. + Label: The new derived label.
""" if ref_image is None: @@ -728,13 +847,20 @@ def derive_label( name=name, ref_image=ref_image, shape=shape, - pixel_size=pixel_size, - axes_names=axes_names, + pixelsize=pixelsize, + z_spacing=z_spacing, + time_spacing=time_spacing, + channels_policy=channels_policy, + ngff_version=ngff_version, chunks=chunks, + shards=shards, dtype=dtype, dimension_separator=dimension_separator, - compressor=compressor, + compressors=compressors, + extra_array_kwargs=extra_array_kwargs, overwrite=overwrite, + labels=labels, + pixel_size=pixel_size, ) @@ -773,7 +899,7 @@ def open_image( mode (AccessModeLiteral): The access mode for the image. Defaults to "r+". """ - group_handler = ZarrGroupHandler(store, cache, mode) + group_handler = ZarrGroupHandler(store=store, cache=cache, mode=mode) images_container = ImagesContainer(group_handler) return images_container.get( path=path, @@ -806,12 +932,14 @@ def open_label( mode (AccessModeLiteral): The access mode for the image. Defaults to "r+". """ - group_handler = ZarrGroupHandler(store, cache, mode) + group_handler = ZarrGroupHandler(store=store, cache=cache, mode=mode) if name is None: - label_meta_handler = find_label_meta_handler(group_handler) - path = label_meta_handler.meta.get_dataset( - path=path, pixel_size=pixel_size, strict=strict - ).path + label_meta_handler = LabelMetaHandler(group_handler) + path = ( + label_meta_handler.get_meta() + .get_dataset(path=path, pixel_size=pixel_size, strict=strict) + .path + ) return Label(group_handler, path, label_meta_handler) labels_container = LabelsContainer(group_handler) @@ -826,196 +954,282 @@ def open_label( def create_empty_ome_zarr( store: StoreOrGroup, shape: Sequence[int], - xy_pixelsize: float, + pixelsize: float | tuple[float, float] | None = None, z_spacing: float = 1.0, time_spacing: float = 1.0, + scaling_factors: Sequence[float] | Literal["auto"] = "auto", levels: int | list[str] = 5, - xy_scaling_factor: float = 2, - z_scaling_factor: float = 1.0, space_unit: 
SpaceUnits = DefaultSpaceUnit, time_unit: TimeUnits = DefaultTimeUnit, axes_names: Sequence[str] | None = None, + channels_meta: Sequence[str | Channel] | None = None, name: str | None = None, - chunks: Sequence[int] | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, + chunks: ChunksLike = "auto", + shards: ShardsLike | None = None, dtype: str = "uint16", - dimension_separator: DIMENSION_SEPARATOR = "/", - compressor="default", + dimension_separator: Literal[".", "/"] = "/", + compressors: CompressorLike = "auto", + extra_array_kwargs: Mapping[str, Any] | None = None, + overwrite: bool = False, + # Deprecated arguments + xy_pixelsize: float | None = None, + xy_scaling_factor: float | None = None, + z_scaling_factor: float | None = None, channel_labels: list[str] | None = None, channel_wavelengths: list[str] | None = None, channel_colors: Sequence[str] | None = None, channel_active: Sequence[bool] | None = None, - overwrite: bool = False, - version: NgffVersions = DefaultNgffVersion, ) -> OmeZarrContainer: """Create an empty OME-Zarr image with the given shape and metadata. Args: store (StoreOrGroup): The Zarr store or group to create the image in. shape (Sequence[int]): The shape of the image. - xy_pixelsize (float): The pixel size in x and y dimensions. - z_spacing (float, optional): The spacing between z slices. Defaults to 1.0. - time_spacing (float, optional): The spacing between time points. - Defaults to 1.0. - levels (int | list[str], optional): The number of levels in the pyramid or a - list of level names. Defaults to 5. - xy_scaling_factor (float, optional): The down-scaling factor in x and y - dimensions. Defaults to 2.0. - z_scaling_factor (float, optional): The down-scaling factor in z dimension. - Defaults to 1.0. - space_unit (SpaceUnits, optional): The unit of space. Defaults to - DefaultSpaceUnit. - time_unit (TimeUnits, optional): The unit of time. Defaults to - DefaultTimeUnit. 
- axes_names (Sequence[str] | None, optional): The names of the axes. - If None the canonical names are used. Defaults to None. - name (str | None, optional): The name of the image. Defaults to None. - chunks (Sequence[int] | None, optional): The chunk shape. If None the shape - is used. Defaults to None. - dtype (str, optional): The data type of the image. Defaults to "uint16". - dimension_separator (DIMENSION_SEPARATOR): The dimension - separator to use. Defaults to "/". - compressor: The compressor to use. Defaults to "default". - channel_labels (list[str] | None, optional): The labels of the channels. + pixelsize (float | tuple[float, float] | None): The pixel size in x and y + dimensions. + z_spacing (float): The spacing between z slices. Defaults to 1.0. + time_spacing (float): The spacing between time points. Defaults to 1.0. + scaling_factors (Sequence[float] | Literal["auto"]): The down-scaling factors + for the pyramid levels. Defaults to "auto". + levels (int | list[str]): The number of levels in the pyramid or a list of + level names. Defaults to 5. + space_unit (SpaceUnits): The unit of space. Defaults to DefaultSpaceUnit. + time_unit (TimeUnits): The unit of time. Defaults to DefaultTimeUnit. + axes_names (Sequence[str] | None): The names of the axes. If None the + canonical names are used. Defaults to None. + channels_meta (Sequence[str | Channel] | None): The channels metadata. Defaults to None. - channel_wavelengths (list[str] | None, optional): The wavelengths of the - channels. Defaults to None. - channel_colors (Sequence[str] | None, optional): The colors of the channels. - Defaults to None. - channel_active (Sequence[bool] | None, optional): Whether the channels are - active. Defaults to None. - overwrite (bool, optional): Whether to overwrite an existing image. - Defaults to True. - version (NgffVersion, optional): The version of the OME-Zarr specification. + name (str | None): The name of the image. Defaults to None. 
+ ngff_version (NgffVersions): The version of the OME-Zarr specification. Defaults to DefaultNgffVersion. + chunks (ChunksLike): The chunk shape. Defaults to "auto". + shards (ShardsLike | None): The shard shape. Defaults to None. + dtype (str): The data type of the image. Defaults to "uint16". + dimension_separator (Literal[".", "/"]): The dimension separator to use. + Defaults to "/". + compressors (CompressorLike): The compressor to use. Defaults to "auto". + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. Defaults to None. + overwrite (bool): Whether to overwrite an existing image. Defaults to False. + xy_pixelsize (float | None): Deprecated. Use pixelsize instead. + xy_scaling_factor (float | None): Deprecated. Use scaling_factors instead. + z_scaling_factor (float | None): Deprecated. Use scaling_factors instead. + channel_labels (list[str] | None): Deprecated. Use channels_meta instead. + channel_wavelengths (list[str] | None): Deprecated. Use channels_meta instead. + channel_colors (Sequence[str] | None): Deprecated. Use channels_meta instead. + channel_active (Sequence[bool] | None): Deprecated. Use channels_meta instead. """ - handler = create_empty_image_container( + if xy_pixelsize is not None: + warnings.warn( + "'xy_pixelsize' is deprecated and will be removed in a future " + "version. Please use 'pixelsize' instead.", + DeprecationWarning, + stacklevel=2, + ) + pixelsize = xy_pixelsize + if xy_scaling_factor is not None or z_scaling_factor is not None: + warnings.warn( + "'xy_scaling_factor' and 'z_scaling_factor' are deprecated and will be " + "removed in a future version. 
Please use 'scaling_factors' instead.", + DeprecationWarning, + stacklevel=2, + ) + xy_scaling_factor_ = xy_scaling_factor or 2.0 + z_scaling_factor_ = z_scaling_factor or 1.0 + if len(shape) == 2: + scaling_factors = (xy_scaling_factor_, xy_scaling_factor_) + else: + zyx_factors = (z_scaling_factor_, xy_scaling_factor_, xy_scaling_factor_) + scaling_factors = (1.0,) * (len(shape) - 3) + zyx_factors + + if channel_labels is not None: + warnings.warn( + "'channel_labels' is deprecated and will be removed in a future " + "version. Please use 'channels_meta' instead.", + DeprecationWarning, + stacklevel=2, + ) + channels_meta = channel_labels + + if channel_wavelengths is not None: + warnings.warn( + "'channel_wavelengths' is deprecated and will be removed in a future " + "version. Please use 'channels_meta' instead.", + DeprecationWarning, + stacklevel=2, + ) + if channel_colors is not None: + warnings.warn( + "'channel_colors' is deprecated and will be removed in a future " + "version. Please use 'channels_meta' instead.", + DeprecationWarning, + stacklevel=2, + ) + if channel_active is not None: + warnings.warn( + "'channel_active' is deprecated and will be removed in a future " + "version. 
Please use 'channels_meta' instead.", + DeprecationWarning, + stacklevel=2, + ) + + if pixelsize is None: + raise NgioValueError("pixelsize must be provided.") + + handler = init_image_like( store=store, + meta_type=NgioImageMeta, shape=shape, - pixelsize=xy_pixelsize, + pixelsize=pixelsize, z_spacing=z_spacing, time_spacing=time_spacing, + scaling_factors=scaling_factors, levels=levels, - yx_scaling_factor=xy_scaling_factor, - z_scaling_factor=z_scaling_factor, space_unit=space_unit, time_unit=time_unit, axes_names=axes_names, + channels_meta=channels_meta, name=name, + ngff_version=ngff_version, chunks=chunks, + shards=shards, dtype=dtype, dimension_separator=dimension_separator, - compressor=compressor, + compressors=compressors, + extra_array_kwargs=extra_array_kwargs, overwrite=overwrite, - version=version, ) ome_zarr = OmeZarrContainer(group_handler=handler) - ome_zarr.set_channel_meta( - labels=channel_labels, - wavelength_id=channel_wavelengths, - percentiles=None, - colors=channel_colors, - active=channel_active, - ) + if ( + channel_wavelengths is not None + or channel_colors is not None + or channel_active is not None + ): + channel_names = ome_zarr.channel_labels + ome_zarr.set_channel_meta( + labels=channel_names, + wavelength_id=channel_wavelengths, + percentiles=None, + colors=channel_colors, + active=channel_active, + ) + else: + ome_zarr.set_channel_meta( + labels=ome_zarr.channel_labels, + percentiles=None, + ) return ome_zarr def create_ome_zarr_from_array( store: StoreOrGroup, array: np.ndarray, - xy_pixelsize: float, + pixelsize: float | tuple[float, float] | None = None, z_spacing: float = 1.0, time_spacing: float = 1.0, + scaling_factors: Sequence[float] | Literal["auto"] = "auto", levels: int | list[str] = 5, - xy_scaling_factor: float = 2.0, - z_scaling_factor: float = 1.0, space_unit: SpaceUnits = DefaultSpaceUnit, time_unit: TimeUnits = DefaultTimeUnit, axes_names: Sequence[str] | None = None, + channels_meta: Sequence[str | Channel] | 
None = None, + percentiles: tuple[float, float] = (0.1, 99.9), + name: str | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, + chunks: ChunksLike = "auto", + shards: ShardsLike | None = None, + dimension_separator: Literal[".", "/"] = "/", + compressors: CompressorLike = "auto", + extra_array_kwargs: Mapping[str, Any] | None = None, + overwrite: bool = False, + # Deprecated arguments + xy_pixelsize: float | None = None, + xy_scaling_factor: float | None = None, + z_scaling_factor: float | None = None, channel_labels: list[str] | None = None, channel_wavelengths: list[str] | None = None, - percentiles: tuple[float, float] | None = (0.1, 99.9), channel_colors: Sequence[str] | None = None, channel_active: Sequence[bool] | None = None, - name: str | None = None, - chunks: Sequence[int] | None = None, - dimension_separator: DIMENSION_SEPARATOR = "/", - compressor: str = "default", - overwrite: bool = False, - version: NgffVersions = DefaultNgffVersion, ) -> OmeZarrContainer: """Create an OME-Zarr image from a numpy array. Args: store (StoreOrGroup): The Zarr store or group to create the image in. array (np.ndarray): The image data. - xy_pixelsize (float): The pixel size in x and y dimensions. - z_spacing (float, optional): The spacing between z slices. Defaults to 1.0. - time_spacing (float, optional): The spacing between time points. - Defaults to 1.0. - levels (int | list[str], optional): The number of levels in the pyramid or a - list of level names. Defaults to 5. - xy_scaling_factor (float, optional): The down-scaling factor in x and y - dimensions. Defaults to 2.0. - z_scaling_factor (float, optional): The down-scaling factor in z dimension. - Defaults to 1.0. - space_unit (SpaceUnits, optional): The unit of space. Defaults to - DefaultSpaceUnit. - time_unit (TimeUnits, optional): The unit of time. Defaults to - DefaultTimeUnit. - axes_names (Sequence[str] | None, optional): The names of the axes. - If None the canonical names are used. 
Defaults to None. - name (str | None, optional): The name of the image. Defaults to None. - chunks (Sequence[int] | None, optional): The chunk shape. If None the shape - is used. Defaults to None. - channel_labels (list[str] | None, optional): The labels of the channels. - Defaults to None. - channel_wavelengths (list[str] | None, optional): The wavelengths of the - channels. Defaults to None. - percentiles (tuple[float, float] | None, optional): The percentiles of the - channels. Defaults to None. - channel_colors (Sequence[str] | None, optional): The colors of the channels. + pixelsize (float | tuple[float, float] | None): The pixel size in x and y + dimensions. + z_spacing (float): The spacing between z slices. Defaults to 1.0. + time_spacing (float): The spacing between time points. Defaults to 1.0. + scaling_factors (Sequence[float] | Literal["auto"]): The down-scaling factors + for the pyramid levels. Defaults to "auto". + levels (int | list[str]): The number of levels in the pyramid or a list of + level names. Defaults to 5. + space_unit (SpaceUnits): The unit of space. Defaults to DefaultSpaceUnit. + time_unit (TimeUnits): The unit of time. Defaults to DefaultTimeUnit. + axes_names (Sequence[str] | None): The names of the axes. If None the + canonical names are used. Defaults to None. + channels_meta (Sequence[str | Channel] | None): The channels metadata. Defaults to None. - channel_active (Sequence[bool] | None, optional): Whether the channels are - active. Defaults to None. - dimension_separator (DIMENSION_SEPARATOR): The separator to use for - dimensions. Defaults to "/". - compressor: The compressor to use. Defaults to "default". - overwrite (bool, optional): Whether to overwrite an existing image. - Defaults to True. - version (str, optional): The version of the OME-Zarr specification. + percentiles (tuple[float, float]): The percentiles of the channels for + computing display ranges. Defaults to (0.1, 99.9). 
+ name (str | None): The name of the image. Defaults to None. + ngff_version (NgffVersions): The version of the OME-Zarr specification. Defaults to DefaultNgffVersion. + chunks (ChunksLike): The chunk shape. Defaults to "auto". + shards (ShardsLike | None): The shard shape. Defaults to None. + dimension_separator (Literal[".", "/"]): The separator to use for + dimensions. Defaults to "/". + compressors (CompressorLike): The compressors to use. Defaults to "auto". + extra_array_kwargs (Mapping[str, Any] | None): Extra arguments to pass to + the zarr array creation. Defaults to None. + overwrite (bool): Whether to overwrite an existing image. Defaults to False. + xy_pixelsize (float | None): Deprecated. Use pixelsize instead. + xy_scaling_factor (float | None): Deprecated. Use scaling_factors instead. + z_scaling_factor (float | None): Deprecated. Use scaling_factors instead. + channel_labels (list[str] | None): Deprecated. Use channels_meta instead. + channel_wavelengths (list[str] | None): Deprecated. Use channels_meta instead. + channel_colors (Sequence[str] | None): Deprecated. Use channels_meta instead. + channel_active (Sequence[bool] | None): Deprecated. Use channels_meta instead. 
""" - handler = create_empty_image_container( + ome_zarr = create_empty_ome_zarr( store=store, shape=array.shape, - pixelsize=xy_pixelsize, + pixelsize=pixelsize, z_spacing=z_spacing, time_spacing=time_spacing, + scaling_factors=scaling_factors, levels=levels, - yx_scaling_factor=xy_scaling_factor, - z_scaling_factor=z_scaling_factor, space_unit=space_unit, time_unit=time_unit, axes_names=axes_names, + channels_meta=channels_meta, name=name, + ngff_version=ngff_version, chunks=chunks, - dtype=str(array.dtype), - overwrite=overwrite, + shards=shards, dimension_separator=dimension_separator, - compressor=compressor, - version=version, + compressors=compressors, + extra_array_kwargs=extra_array_kwargs, + overwrite=overwrite, + xy_pixelsize=xy_pixelsize, + xy_scaling_factor=xy_scaling_factor, + z_scaling_factor=z_scaling_factor, + channel_labels=channel_labels, + channel_wavelengths=channel_wavelengths, + channel_colors=channel_colors, + channel_active=channel_active, ) - - ome_zarr = OmeZarrContainer(group_handler=handler) image = ome_zarr.get_image() image.set_array(array) image.consolidate() - ome_zarr.set_channel_meta( - labels=channel_labels, - wavelength_id=channel_wavelengths, - percentiles=percentiles, - colors=channel_colors, - active=channel_active, + if len(percentiles) != 2: + raise NgioValueError( + f"'percentiles' must be a tuple of two values. 
Got {percentiles}" + ) + ome_zarr.set_channel_percentiles( + start_percentile=percentiles[0], + end_percentile=percentiles[1], ) return ome_zarr diff --git a/src/ngio/io_pipes/_io_pipes.py b/src/ngio/io_pipes/_io_pipes.py index 7e562b32..33153d54 100644 --- a/src/ngio/io_pipes/_io_pipes.py +++ b/src/ngio/io_pipes/_io_pipes.py @@ -7,7 +7,7 @@ from dask.array import Array as DaskArray from ngio.common._dimensions import Dimensions -from ngio.common._roi import Roi, RoiPixels +from ngio.common._roi import Roi from ngio.io_pipes._ops_axes import ( AxesOps, build_axes_ops, @@ -72,7 +72,7 @@ def __init__( slicing_ops: SlicingOps, axes_ops: AxesOps, transforms: Sequence[TransformProtocol] | None = None, - roi: Roi | RoiPixels | None = None, + roi: Roi | None = None, ) -> None: self._zarr_array = zarr_array self._slicing_ops = slicing_ops @@ -106,7 +106,7 @@ def transforms(self) -> Sequence[TransformProtocol] | None: return self._transforms @property - def roi(self) -> Roi | RoiPixels: + def roi(self) -> Roi: if self._roi is None: name = self.__class__.__name__ raise ValueError(f"No ROI defined for {name}.") @@ -127,7 +127,7 @@ def __init__( slicing_ops: SlicingOps, axes_ops: AxesOps, transforms: Sequence[TransformProtocol] | None = None, - roi: Roi | RoiPixels | None = None, + roi: Roi | None = None, ) -> None: self._zarr_array = zarr_array self._slicing_ops = slicing_ops @@ -161,7 +161,7 @@ def transforms(self) -> Sequence[TransformProtocol] | None: return self._transforms @property - def roi(self) -> Roi | RoiPixels: + def roi(self) -> Roi: if self._roi is None: name = self.__class__.__name__ raise ValueError(f"No ROI defined for {name}.") @@ -185,7 +185,7 @@ def __init__( transforms: Sequence[TransformProtocol] | None = None, slicing_dict: dict[str, SlicingInputType] | None = None, remove_channel_selection: bool = False, - roi: Roi | RoiPixels | None = None, + roi: Roi | None = None, ) -> None: """Build a pipe to get a numpy or dask array from a zarr array.""" 
slicing_ops, axes_ops = setup_io_pipe( @@ -225,7 +225,7 @@ def __init__( transforms: Sequence[TransformProtocol] | None = None, slicing_dict: dict[str, SlicingInputType] | None = None, remove_channel_selection: bool = False, - roi: Roi | RoiPixels | None = None, + roi: Roi | None = None, ) -> None: """Build a pipe to get a numpy or dask array from a zarr array.""" slicing_ops, axes_ops = setup_io_pipe( @@ -279,7 +279,7 @@ def __init__( transforms: Sequence[TransformProtocol] | None = None, slicing_dict: dict[str, SlicingInputType] | None = None, remove_channel_selection: bool = False, - roi: Roi | RoiPixels | None = None, + roi: Roi | None = None, ) -> None: """Build a pipe to get a numpy or dask array from a zarr array.""" slicing_ops, axes_ops = setup_io_pipe( @@ -325,7 +325,7 @@ def __init__( transforms: Sequence[TransformProtocol] | None = None, slicing_dict: dict[str, SlicingInputType] | None = None, remove_channel_selection: bool = False, - roi: Roi | RoiPixels | None = None, + roi: Roi | None = None, ) -> None: """Build a pipe to get a numpy or dask array from a zarr array.""" slicing_ops, axes_ops = setup_io_pipe( diff --git a/src/ngio/io_pipes/_io_pipes_masked.py b/src/ngio/io_pipes/_io_pipes_masked.py index c980a570..388baa53 100644 --- a/src/ngio/io_pipes/_io_pipes_masked.py +++ b/src/ngio/io_pipes/_io_pipes_masked.py @@ -6,7 +6,7 @@ from dask.array import Array as DaskArray from ngio.common._dimensions import Dimensions -from ngio.common._roi import Roi, RoiPixels +from ngio.common._roi import Roi from ngio.io_pipes._io_pipes import ( DaskGetter, DaskSetter, @@ -57,7 +57,7 @@ def _setup_numpy_getters( dimensions: Dimensions, label_zarr_array: zarr.Array, label_dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, label_transforms: Sequence[TransformProtocol] | None = None, @@ -117,7 +117,7 @@ def __init__( dimensions: Dimensions, label_zarr_array: 
zarr.Array, label_dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, label_transforms: Sequence[TransformProtocol] | None = None, @@ -187,7 +187,7 @@ def __init__( dimensions: Dimensions, label_zarr_array: zarr.Array, label_dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, label_transforms: Sequence[TransformProtocol] | None = None, @@ -290,7 +290,7 @@ def _setup_dask_getters( dimensions: Dimensions, label_zarr_array: zarr.Array, label_dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, label_transforms: Sequence[TransformProtocol] | None = None, @@ -350,7 +350,7 @@ def __init__( dimensions: Dimensions, label_zarr_array: zarr.Array, label_dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, label_transforms: Sequence[TransformProtocol] | None = None, @@ -418,7 +418,7 @@ def __init__( dimensions: Dimensions, label_zarr_array: zarr.Array, label_dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, label_transforms: Sequence[TransformProtocol] | None = None, diff --git a/src/ngio/io_pipes/_io_pipes_roi.py b/src/ngio/io_pipes/_io_pipes_roi.py index 7737e9b2..898c678b 100644 --- a/src/ngio/io_pipes/_io_pipes_roi.py +++ b/src/ngio/io_pipes/_io_pipes_roi.py @@ -3,7 +3,7 @@ import zarr from ngio.common._dimensions import Dimensions -from ngio.common._roi import Roi, RoiPixels +from ngio.common._roi import Roi from ngio.io_pipes._io_pipes import ( DaskGetter, DaskSetter, @@ -17,7 +17,7 @@ def roi_to_slicing_dict( *, - roi: Roi | RoiPixels, + roi: Roi, pixel_size: 
PixelSize, slicing_dict: dict[str, SlicingInputType] | None = None, ) -> dict[str, SlicingInputType]: @@ -40,7 +40,7 @@ def __init__( *, zarr_array: zarr.Array, dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, slicing_dict: dict[str, SlicingInputType] | None = None, @@ -68,7 +68,7 @@ def __init__( *, zarr_array: zarr.Array, dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, slicing_dict: dict[str, SlicingInputType] | None = None, @@ -96,7 +96,7 @@ def __init__( *, zarr_array: zarr.Array, dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, slicing_dict: dict[str, SlicingInputType] | None = None, @@ -124,7 +124,7 @@ def __init__( *, zarr_array: zarr.Array, dimensions: Dimensions, - roi: Roi | RoiPixels, + roi: Roi, axes_order: Sequence[str] | None = None, transforms: Sequence[TransformProtocol] | None = None, slicing_dict: dict[str, SlicingInputType] | None = None, diff --git a/src/ngio/io_pipes/_io_pipes_types.py b/src/ngio/io_pipes/_io_pipes_types.py index e69ec97f..2551941f 100644 --- a/src/ngio/io_pipes/_io_pipes_types.py +++ b/src/ngio/io_pipes/_io_pipes_types.py @@ -3,7 +3,7 @@ import zarr -from ngio.common._roi import Roi, RoiPixels +from ngio.common._roi import Roi from ngio.io_pipes._ops_axes import AxesOps from ngio.io_pipes._ops_slices import SlicingOps from ngio.io_pipes._ops_transforms import TransformProtocol @@ -26,7 +26,7 @@ def axes_ops(self) -> AxesOps: ... def transforms(self) -> Sequence[TransformProtocol] | None: ... @property - def roi(self) -> Roi | RoiPixels: ... + def roi(self) -> Roi: ... def __call__(self) -> GetterDataType: return self.get() @@ -48,7 +48,7 @@ def axes_ops(self) -> AxesOps: ... 
def transforms(self) -> Sequence[TransformProtocol] | None: ... @property - def roi(self) -> Roi | RoiPixels: ... + def roi(self) -> Roi: ... def __call__(self, patch: SetterDataType) -> None: return self.set(patch) diff --git a/src/ngio/io_pipes/_match_shape.py b/src/ngio/io_pipes/_match_shape.py index 2cd194f2..e20d48f7 100644 --- a/src/ngio/io_pipes/_match_shape.py +++ b/src/ngio/io_pipes/_match_shape.py @@ -1,10 +1,11 @@ +import warnings from collections.abc import Sequence from enum import Enum import dask.array as da import numpy as np -from ngio.utils import NgioValueError, ngio_logger +from ngio.utils import NgioValueError class Action(str, Enum): @@ -28,7 +29,7 @@ def _compute_pad_widths( pad_def.append((before, after)) else: pad_def.append((0, 0)) - ngio_logger.warning( + warnings.warn( f"Images have a different shape ({array_shape} vs {target_shape}). " f"Resolving by padding: {pad_def}", stacklevel=2, @@ -75,7 +76,7 @@ def _compute_trim_slices( else: slices.append(slice(0, s)) - ngio_logger.warning( + warnings.warn( f"Images have a different shape ({array_shape} vs {target_shape}). " f"Resolving by trimming: {slices}", stacklevel=2, @@ -117,7 +118,7 @@ def _compute_rescaling_shape( rescaling_shape.append(s) factor.append(1.0) - ngio_logger.warning( + warnings.warn( f"Images have a different shape ({array_shape} vs {target_shape}). 
" f"Resolving by scaling with factors {factor}.", stacklevel=2, diff --git a/src/ngio/io_pipes/_ops_slices_utils.py b/src/ngio/io_pipes/_ops_slices_utils.py index a0ddcfd6..261860e3 100644 --- a/src/ngio/io_pipes/_ops_slices_utils.py +++ b/src/ngio/io_pipes/_ops_slices_utils.py @@ -1,8 +1,9 @@ +import warnings from collections.abc import Iterable, Iterator from itertools import product from typing import TypeAlias, TypeVar -from ngio.utils import NgioValueError, ngio_logger +from ngio.utils import NgioValueError T = TypeVar("T") @@ -85,9 +86,10 @@ def check_if_regions_overlap(slices: Iterable[tuple[SlicingType, ...]]) -> bool: return True if it == 10_000: - ngio_logger.warning( + warnings.warn( "Performance Warning check_for_overlaps is O(n^2) and may be slow for " - "large numbers of regions." + "large numbers of regions.", + stacklevel=2, ) return False @@ -189,8 +191,9 @@ def check_if_chunks_overlap( if si & sj: return True if it == 10_000: - ngio_logger.warning( + warnings.warn( "Performance Warning check_for_chunks_overlaps is O(n^2) and may be " - "slow for large numbers of regions." 
+ "slow for large numbers of regions.", + stacklevel=2, ) return False diff --git a/src/ngio/ome_zarr_meta/__init__.py b/src/ngio/ome_zarr_meta/__init__.py index dc526ecd..d9882990 100644 --- a/src/ngio/ome_zarr_meta/__init__.py +++ b/src/ngio/ome_zarr_meta/__init__.py @@ -3,22 +3,25 @@ from ngio.ome_zarr_meta._meta_handlers import ( ImageMetaHandler, LabelMetaHandler, - find_image_meta_handler, - find_label_meta_handler, - find_plate_meta_handler, - find_well_meta_handler, - get_image_meta_handler, - get_label_meta_handler, - get_plate_meta_handler, - get_well_meta_handler, + LabelsGroupMetaHandler, + PlateMetaHandler, + WellMetaHandler, + update_ngio_image_meta, + update_ngio_label_meta, + update_ngio_labels_group_meta, + update_ngio_meta, + update_ngio_plate_meta, + update_ngio_well_meta, ) from ngio.ome_zarr_meta.ngio_specs import ( AxesHandler, Dataset, + DefaultNgffVersion, ImageInWellPath, NgffVersions, NgioImageMeta, NgioLabelMeta, + NgioLabelsGroupMeta, NgioPlateMeta, NgioWellMeta, PixelSize, @@ -29,26 +32,28 @@ __all__ = [ "AxesHandler", "Dataset", + "DefaultNgffVersion", "ImageInWellPath", "ImageMetaHandler", - "ImageMetaHandler", - "LabelMetaHandler", "LabelMetaHandler", + "LabelsGroupMetaHandler", "NgffVersions", "NgffVersions", "NgioImageMeta", "NgioLabelMeta", + "NgioLabelsGroupMeta", "NgioPlateMeta", "NgioWellMeta", "PixelSize", + "PlateMetaHandler", + "PlateMetaHandler", + "WellMetaHandler", "build_canonical_axes_handler", - "find_image_meta_handler", - "find_label_meta_handler", - "find_plate_meta_handler", - "find_well_meta_handler", - "get_image_meta_handler", - "get_label_meta_handler", - "get_plate_meta_handler", - "get_well_meta_handler", "path_in_well_validation", + "update_ngio_image_meta", + "update_ngio_label_meta", + "update_ngio_labels_group_meta", + "update_ngio_meta", + "update_ngio_plate_meta", + "update_ngio_well_meta", ] diff --git a/src/ngio/ome_zarr_meta/_meta_handlers.py b/src/ngio/ome_zarr_meta/_meta_handlers.py index 
9f37d552..59dcb859 100644 --- a/src/ngio/ome_zarr_meta/_meta_handlers.py +++ b/src/ngio/ome_zarr_meta/_meta_handlers.py @@ -1,799 +1,536 @@ """Base class for handling OME-NGFF metadata in Zarr groups.""" -from typing import Generic, Protocol, TypeVar - -from pydantic import ValidationError +from collections.abc import Callable +from typing import TypeVar from ngio.ome_zarr_meta.ngio_specs import ( AxesSetup, NgioImageMeta, NgioLabelMeta, + NgioLabelsGroupMeta, NgioPlateMeta, NgioWellMeta, ) +from ngio.ome_zarr_meta.ngio_specs._ngio_image import NgffVersions from ngio.ome_zarr_meta.v04 import ( ngio_to_v04_image_meta, ngio_to_v04_label_meta, + ngio_to_v04_labels_group_meta, ngio_to_v04_plate_meta, ngio_to_v04_well_meta, v04_to_ngio_image_meta, v04_to_ngio_label_meta, + v04_to_ngio_labels_group_meta, v04_to_ngio_plate_meta, v04_to_ngio_well_meta, ) +from ngio.ome_zarr_meta.v05 import ( + ngio_to_v05_image_meta, + ngio_to_v05_label_meta, + ngio_to_v05_labels_group_meta, + ngio_to_v05_plate_meta, + ngio_to_v05_well_meta, + v05_to_ngio_image_meta, + v05_to_ngio_label_meta, + v05_to_ngio_labels_group_meta, + v05_to_ngio_plate_meta, + v05_to_ngio_well_meta, +) from ngio.utils import ( NgioValidationError, NgioValueError, ZarrGroupHandler, ) -ConverterError = ValidationError | Exception | None - -########################################################################### -# -# The code below implements a generic class for handling OME-Zarr metadata -# in Zarr groups. -# -########################################################################### - - -class ImageMetaImporter(Protocol): - @staticmethod - def __call__( - metadata: dict, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, - ) -> tuple[bool, NgioImageMeta | ConverterError]: - """Convert the metadata to a NgioImageMeta object. - - Args: - metadata (dict): The metadata (typically from a Zarr group .attrs). 
- axes_setup (AxesSetup, optional): The axes setup. - This is used to map axes with non-canonical names. - allow_non_canonical_axes (bool, optional): Whether to allow non-canonical - axes. - strict_canonical_order (bool, optional): Whether to enforce a strict - canonical order. - - Returns: - tuple[bool, NgioImageMeta | ConverterError]: A tuple with a boolean - indicating whether the conversion was successful and the - NgioImageMeta object or an error. - - """ - ... - - -class ImageMetaExporter(Protocol): - def __call__(self, metadata: NgioImageMeta) -> dict: ... - - -class LabelMetaImporter(Protocol): - @staticmethod - def __call__( - metadata: dict, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, - ) -> tuple[bool, NgioLabelMeta | ConverterError]: - """Convert the metadata to a NgioLabelMeta object. - - Args: - metadata (dict): The metadata (typically from a Zarr group .attrs). - axes_setup (AxesSetup, optional): The axes setup. - This is used to map axes with non-canonical names. - allow_non_canonical_axes (bool, optional): Whether to allow non-canonical - axes. - strict_canonical_order (bool, optional): Whether to enforce a strict - canonical order. - - Returns: - tuple[bool, NgioLabelMeta | ConverterError]: A tuple with a boolean - indicating whether the conversion was successful and the - NgioLabelMeta object or an error. - - """ - ... - - -class LabelMetaExporter(Protocol): - def __call__(self, metadata: NgioLabelMeta) -> dict: ... - - -class WellMetaImporter(Protocol): - def __call__( - self, metadata: dict - ) -> tuple[bool, NgioWellMeta | ConverterError]: ... 
+# This could be replaced with a more dynamic registry if needed in the future +_image_encoder_registry = {"0.4": ngio_to_v04_image_meta, "0.5": ngio_to_v05_image_meta} +_image_decoder_registry = {"0.4": v04_to_ngio_image_meta, "0.5": v05_to_ngio_image_meta} +_label_encoder_registry = {"0.4": ngio_to_v04_label_meta, "0.5": ngio_to_v05_label_meta} +_label_decoder_registry = {"0.4": v04_to_ngio_label_meta, "0.5": v05_to_ngio_label_meta} +_plate_encoder_registry = {"0.4": ngio_to_v04_plate_meta, "0.5": ngio_to_v05_plate_meta} +_plate_decoder_registry = {"0.4": v04_to_ngio_plate_meta, "0.5": v05_to_ngio_plate_meta} +_well_encoder_registry = {"0.4": ngio_to_v04_well_meta, "0.5": ngio_to_v05_well_meta} +_well_decoder_registry = {"0.4": v04_to_ngio_well_meta, "0.5": v05_to_ngio_well_meta} +_labels_group_encoder_registry = { + "0.4": ngio_to_v04_labels_group_meta, + "0.5": ngio_to_v05_labels_group_meta, +} +_labels_group_decoder_registry = { + "0.4": v04_to_ngio_labels_group_meta, + "0.5": v05_to_ngio_labels_group_meta, +} + +_meta_type = TypeVar( + "_meta_type", + NgioImageMeta, + NgioLabelMeta, + NgioLabelsGroupMeta, + NgioPlateMeta, + NgioWellMeta, +) -class WellMetaExporter(Protocol): - def __call__(self, metadata: NgioWellMeta) -> dict: ... +def _find_encoder_registry( + ngio_meta: _meta_type, +) -> dict[str, Callable]: + if isinstance(ngio_meta, NgioImageMeta): + return _image_encoder_registry + elif isinstance(ngio_meta, NgioLabelMeta): + return _label_encoder_registry + elif isinstance(ngio_meta, NgioPlateMeta): + return _plate_encoder_registry + elif isinstance(ngio_meta, NgioWellMeta): + return _well_encoder_registry + elif isinstance(ngio_meta, NgioLabelsGroupMeta): + return _labels_group_encoder_registry + else: + raise NgioValueError(f"Unsupported NGIO metadata type: {type(ngio_meta)}") + + +def update_ngio_meta( + group_handler: ZarrGroupHandler, + ngio_meta: _meta_type, +) -> None: + """Update the metadata in the Zarr group. 
+ + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + ngio_meta (_meta_type): The new NGIO metadata. + + """ + registry = _find_encoder_registry(ngio_meta) + exporter = registry.get(ngio_meta.version) + if exporter is None: + raise NgioValueError(f"Unsupported NGFF version: {ngio_meta.version}") + + zarr_meta = exporter(ngio_meta) + group_handler.write_attrs(zarr_meta) + + +def _find_decoder_registry( + meta_type: type[_meta_type], +) -> dict[str, Callable]: + if meta_type is NgioImageMeta: + return _image_decoder_registry + elif meta_type is NgioLabelMeta: + return _label_decoder_registry + elif meta_type is NgioPlateMeta: + return _plate_decoder_registry + elif meta_type is NgioWellMeta: + return _well_decoder_registry + elif meta_type is NgioLabelsGroupMeta: + return _labels_group_decoder_registry + else: + raise NgioValueError(f"Unsupported NGIO metadata type: {meta_type}") + + +def get_ngio_meta( + group_handler: ZarrGroupHandler, + meta_type: type[_meta_type], + version: str | None = None, + **kwargs, +) -> _meta_type: + """Retrieve the NGIO metadata from the Zarr group. + + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + meta_type (type[_meta_type]): The type of NGIO metadata to retrieve. + version (str | None): Optional NGFF version to use for decoding. + **kwargs: Additional arguments to pass to the decoder. + + Returns: + _meta_type: The NGIO metadata. 
+ """ + registry = _find_decoder_registry(meta_type) + if version is not None: + decoder = registry.get(version) + if decoder is None: + raise NgioValueError(f"Unsupported NGFF version: {version}") + versions_to_try = {version: decoder} + else: + versions_to_try = registry + + attrs = group_handler.load_attrs() + all_errors = [] + for version, decoder in versions_to_try.items(): + try: + ngio_meta = decoder(attrs, **kwargs) + return ngio_meta + except Exception as e: + all_errors.append(f"Version {version}: {e}") + error_message = ( + f"Failed to decode NGIO {meta_type.__name__} metadata:\n" + + "\n".join(all_errors) + ) + raise NgioValidationError(error_message) -class PlateMetaImporter(Protocol): - def __call__( - self, metadata: dict - ) -> tuple[bool, NgioPlateMeta | ConverterError]: ... +################################################## +# +# Concrete implementations for NGIO metadata types +# +################################################## -class PlateMetaExporter(Protocol): - def __call__(self, metadata: NgioPlateMeta) -> dict: ... +def get_ngio_image_meta( + group_handler: ZarrGroupHandler, + version: str | None = None, + axes_setup: AxesSetup | None = None, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = True, +) -> NgioImageMeta: + """Retrieve the NGIO image metadata from the Zarr group. + + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + version (str | None): Optional NGFF version to use for decoding. + axes_setup (AxesSetup | None): Optional axes setup for validation. + allow_non_canonical_axes (bool): Whether to allow non-canonical axes. + strict_canonical_order (bool): Whether to enforce strict canonical order. + + Returns: + NgioImageMeta: The NGIO image metadata. 
+ """ + return get_ngio_meta( + group_handler=group_handler, + meta_type=NgioImageMeta, + version=version, + axes_setup=axes_setup, + allow_non_canonical_axes=allow_non_canonical_axes, + strict_canonical_order=strict_canonical_order, + ) -########################################################################### -# -# Image and label metadata handlers -# -########################################################################### +def update_ngio_image_meta( + group_handler: ZarrGroupHandler, + ngio_meta: NgioImageMeta, +) -> None: + """Update the NGIO image metadata in the Zarr group. -_image_meta = TypeVar("_image_meta", NgioImageMeta, NgioLabelMeta) -_image_meta_importer = TypeVar( - "_image_meta_importer", ImageMetaImporter, LabelMetaImporter -) -_image_meta_exporter = TypeVar( - "_image_meta_exporter", ImageMetaExporter, LabelMetaExporter -) + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + ngio_meta (NgioImageMeta): The new NGIO image metadata. + """ + update_ngio_meta( + group_handler=group_handler, + ngio_meta=ngio_meta, + ) -class GenericMetaHandler( - Generic[_image_meta, _image_meta_importer, _image_meta_exporter] -): - """Generic class for handling OME-Zarr metadata in Zarr groups.""" +class ImageMetaHandler: def __init__( self, - meta_importer: _image_meta_importer, - meta_exporter: _image_meta_exporter, group_handler: ZarrGroupHandler, + version: str | None = None, axes_setup: AxesSetup | None = None, allow_non_canonical_axes: bool = False, strict_canonical_order: bool = True, ): - """Initialize the handler. - - Args: - meta_importer (MetaImporter): The metadata importer. - meta_exporter (MetaExporter): The metadata exporter. - group_handler (ZarrGroupHandler): The Zarr group handler. - axes_setup (AxesSetup, optional): The axes setup. - This is used to map axes with non-canonical names. - allow_non_canonical_axes (bool, optional): Whether to allow non-canonical - axes. 
- strict_canonical_order (bool, optional): Whether to enforce a strict - canonical order. - """ self._group_handler = group_handler - self._meta_importer = meta_importer - self._meta_exporter = meta_exporter + self._version = version self._axes_setup = axes_setup self._allow_non_canonical_axes = allow_non_canonical_axes self._strict_canonical_order = strict_canonical_order - def _load_meta(self, return_error: bool = False): - """Load the metadata from the store.""" - attrs = self._group_handler.load_attrs() - is_valid, meta_or_error = self._meta_importer( - metadata=attrs, + # Validate metadata + meta = self.get_meta() + # Store the resolved version + self._version = meta.version + + def get_meta(self) -> NgioImageMeta: + """Retrieve the NGIO image metadata.""" + return get_ngio_image_meta( + group_handler=self._group_handler, + version=self._version, axes_setup=self._axes_setup, allow_non_canonical_axes=self._allow_non_canonical_axes, strict_canonical_order=self._strict_canonical_order, ) - if is_valid: - return meta_or_error - - if return_error: - return meta_or_error - - raise NgioValueError(f"Could not load metadata: {meta_or_error}") - - def _write_meta(self, meta) -> None: - """Write the metadata to the store.""" - _meta = self._meta_exporter(metadata=meta) - self._group_handler.write_attrs(_meta) - - def write_meta(self, meta: _image_meta) -> None: - self._write_meta(meta) - - @property - def meta(self) -> _image_meta: - """Return the metadata.""" - raise NotImplementedError("This method should be implemented in a subclass.") - - -class ImageMetaHandler( - GenericMetaHandler[NgioImageMeta, ImageMetaImporter, ImageMetaExporter] -): - """Generic class for handling OME-Zarr metadata in Zarr groups.""" - - @property - def meta(self) -> NgioImageMeta: - meta = self._load_meta() - if isinstance(meta, NgioImageMeta): - return meta - raise NgioValueError(f"Could not load metadata: {meta}") - def safe_load_meta(self) -> NgioImageMeta | ConverterError: - """Load the 
metadata from the store.""" - return self._load_meta(return_error=True) - - -class LabelMetaHandler( - GenericMetaHandler[NgioLabelMeta, LabelMetaImporter, LabelMetaExporter] -): - """Generic class for handling OME-Zarr metadata in Zarr groups.""" + def update_meta(self, ngio_meta: NgioImageMeta) -> None: + """Update the NGIO image metadata.""" + update_ngio_meta( + group_handler=self._group_handler, + ngio_meta=ngio_meta, + ) - @property - def meta(self) -> NgioLabelMeta: - meta = self._load_meta() - if isinstance(meta, NgioLabelMeta): - return meta - raise NgioValueError(f"Could not load metadata: {meta}") - def safe_load_meta(self) -> NgioLabelMeta | ConverterError: - """Load the metadata from the store.""" - return self._load_meta(return_error=True) +def get_ngio_label_meta( + group_handler: ZarrGroupHandler, + version: str | None = None, + axes_setup: AxesSetup | None = None, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = True, +) -> NgioLabelMeta: + """Retrieve the NGIO label metadata from the Zarr group. + + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + version (str | None): Optional NGFF version to use for decoding. + axes_setup (AxesSetup | None): Optional axes setup for validation. + allow_non_canonical_axes (bool): Whether to allow non-canonical axes. + strict_canonical_order (bool): Whether to enforce strict canonical order. + + Returns: + NgioLabelMeta: The NGIO label metadata. 
+ """ + return get_ngio_meta( + group_handler=group_handler, + meta_type=NgioLabelMeta, + version=version, + axes_setup=axes_setup, + allow_non_canonical_axes=allow_non_canonical_axes, + strict_canonical_order=strict_canonical_order, + ) -########################################################################### -# -# Well and plate metadata handlers -# -########################################################################### +def update_ngio_label_meta( + group_handler: ZarrGroupHandler, + ngio_meta: NgioLabelMeta, +) -> None: + """Update the NGIO label metadata in the Zarr group. -_hcs_meta = TypeVar("_hcs_meta", NgioWellMeta, NgioPlateMeta) -_hcs_meta_importer = TypeVar("_hcs_meta_importer", WellMetaImporter, PlateMetaImporter) -_hcs_meta_exporter = TypeVar("_hcs_meta_exporter", WellMetaExporter, PlateMetaExporter) + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + ngio_meta (NgioLabelMeta): The new NGIO label metadata. + """ + update_ngio_meta( + group_handler=group_handler, + ngio_meta=ngio_meta, + ) -class GenericHCSMetaHandler(Generic[_hcs_meta, _hcs_meta_importer, _hcs_meta_exporter]): - """Generic class for handling OME-Zarr metadata in Zarr groups.""" +class LabelMetaHandler: def __init__( self, - meta_importer: _hcs_meta_importer, - meta_exporter: _hcs_meta_exporter, group_handler: ZarrGroupHandler, + version: str | None = None, + axes_setup: AxesSetup | None = None, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = True, ): self._group_handler = group_handler - self._meta_importer = meta_importer - self._meta_exporter = meta_exporter - - def _load_meta(self, return_error: bool = False): - """Load the metadata from the store.""" - attrs = self._group_handler.load_attrs() - is_valid, meta_or_error = self._meta_importer(metadata=attrs) - if is_valid: - return meta_or_error - - if return_error: - return meta_or_error - - raise NgioValueError(f"Could not load metadata: {meta_or_error}") - - def 
_write_meta(self, meta) -> None: - _meta = self._meta_exporter(metadata=meta) - self._group_handler.write_attrs(_meta) - - def write_meta(self, meta: _hcs_meta) -> None: - self._write_meta(meta) - - @property - def meta(self) -> _hcs_meta: - raise NotImplementedError("This method should be implemented in a subclass.") - - -class WellMetaHandler( - GenericHCSMetaHandler[NgioWellMeta, WellMetaImporter, WellMetaExporter] -): - """Generic class for handling OME-Zarr metadata in Zarr groups.""" - - @property - def meta(self) -> NgioWellMeta: - meta = self._load_meta() - if isinstance(meta, NgioWellMeta): - return meta - raise NgioValueError(f"Could not load metadata: {meta}") + self._version = version + self._axes_setup = axes_setup + self._allow_non_canonical_axes = allow_non_canonical_axes + self._strict_canonical_order = strict_canonical_order - def safe_load_meta(self) -> NgioWellMeta | ConverterError: - """Load the metadata from the store.""" - return self._load_meta(return_error=True) + # Validate metadata + meta = self.get_meta() + # Store the resolved version + self._version = meta.version + def get_meta(self) -> NgioLabelMeta: + """Retrieve the NGIO label metadata.""" + return get_ngio_label_meta( + group_handler=self._group_handler, + version=self._version, + axes_setup=self._axes_setup, + allow_non_canonical_axes=self._allow_non_canonical_axes, + strict_canonical_order=self._strict_canonical_order, + ) -class PlateMetaHandler( - GenericHCSMetaHandler[NgioPlateMeta, PlateMetaImporter, PlateMetaExporter] -): - """Generic class for handling OME-Zarr metadata in Zarr groups.""" + def update_meta(self, ngio_meta: NgioLabelMeta) -> None: + """Update the NGIO label metadata.""" + update_ngio_meta( + group_handler=self._group_handler, + ngio_meta=ngio_meta, + ) - @property - def meta(self) -> NgioPlateMeta: - meta = self._load_meta() - if isinstance(meta, NgioPlateMeta): - return meta - raise NgioValueError(f"Could not load metadata: {meta}") - def 
safe_load_meta(self) -> NgioPlateMeta | ConverterError: - """Load the metadata from the store.""" - return self._load_meta(return_error=True) +def get_ngio_plate_meta( + group_handler: ZarrGroupHandler, + version: str | None = None, +) -> NgioPlateMeta: + """Retrieve the NGIO plate metadata from the Zarr group. + + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + version (str | None): Optional NGFF version to use for decoding. + + Returns: + NgioPlateMeta: The NGIO plate metadata. + """ + return get_ngio_meta( + group_handler=group_handler, + meta_type=NgioPlateMeta, + version=version, + ) -########################################################################### -# -# Metadata importer/exporter registration & builder classes -# -########################################################################### +def update_ngio_plate_meta( + group_handler: ZarrGroupHandler, + ngio_meta: NgioPlateMeta, +) -> None: + """Update the NGIO plate metadata in the Zarr group. + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + ngio_meta (NgioPlateMeta): The new NGIO plate metadata. 
-_meta_exporter = TypeVar( - "_meta_exporter", - ImageMetaExporter, - LabelMetaExporter, - WellMetaExporter, - PlateMetaExporter, -) -_meta_importer = TypeVar( - "_meta_importer", - ImageMetaImporter, - LabelMetaImporter, - WellMetaImporter, - PlateMetaImporter, -) + """ + update_ngio_meta( + group_handler=group_handler, + ngio_meta=ngio_meta, + ) -class _ImporterExporter(Generic[_meta_importer, _meta_exporter]): +class PlateMetaHandler: def __init__( - self, - version: str, - importer: _meta_importer, - exporter: _meta_exporter, - ): - self.importer = importer - self.exporter = exporter - self.version = version - - -ImageImporterExporter = _ImporterExporter[ImageMetaImporter, ImageMetaExporter] -LabelImporterExporter = _ImporterExporter[LabelMetaImporter, LabelMetaExporter] -WellImporterExporter = _ImporterExporter[WellMetaImporter, WellMetaExporter] -PlateImporterExporter = _ImporterExporter[PlateMetaImporter, PlateMetaExporter] - -_importer_exporter = TypeVar( - "_importer_exporter", - ImageImporterExporter, - LabelImporterExporter, - WellImporterExporter, - PlateImporterExporter, -) -_image_handler = TypeVar("_image_handler", ImageMetaHandler, LabelMetaHandler) -_hcs_handler = TypeVar("_hcs_handler", WellMetaHandler, PlateMetaHandler) - - -class ImplementedMetaImporterExporter: - _instance = None - _image_ie: dict[str, ImageImporterExporter] - _label_ie: dict[str, LabelImporterExporter] - _well_ie: dict[str, WellImporterExporter] - _plate_ie: dict[str, PlateImporterExporter] - - def __new__(cls): - """Create a new instance of the class if it does not exist.""" - if cls._instance is None: - cls._instance = super().__new__(cls) - cls._image_ie = {} - cls._label_ie = {} - cls._well_ie = {} - cls._plate_ie = {} - return cls._instance - - def _find_image_handler( self, group_handler: ZarrGroupHandler, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, - _ie_name: str = "_image_ie", - _handler: 
type[_image_handler] = ImageMetaHandler, - ) -> _image_handler: - """Get an image metadata handler.""" - _errors = {} - - dict_ie = self.__getattribute__(_ie_name) - - for ie in reversed(dict_ie.values()): - handler = _handler( - meta_importer=ie.importer, - meta_exporter=ie.exporter, - group_handler=group_handler, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, - ) - meta = handler.safe_load_meta() - if isinstance(meta, ValidationError): - _errors[ie.version] = meta - continue - return handler - - raise NgioValidationError( - f"Could not load metadata from any known version. Errors: {_errors}" - ) - - def find_image_handler( - self, - group_handler: ZarrGroupHandler, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, - ) -> ImageMetaHandler: - """Get an image metadata handler.""" - return self._find_image_handler( - group_handler=group_handler, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, - _ie_name="_image_ie", - _handler=ImageMetaHandler, - ) - - def get_image_meta_handler( - self, - group_handler: ZarrGroupHandler, - version: str, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, - ) -> ImageMetaHandler: - """Get an image metadata handler.""" - if version not in self._image_ie: - raise NgioValueError(f"Image handler for version {version} does not exist.") - - image_ie = self._image_ie[version] - return ImageMetaHandler( - meta_importer=image_ie.importer, - meta_exporter=image_ie.exporter, - group_handler=group_handler, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, - ) - - def _register( - self, - version: str, - importer: _importer_exporter, - overwrite: bool = False, - _ie_name: str = 
"_image_ie", - ): - """Register an importer/exporter.""" - ie_dict = self.__getattribute__(_ie_name) - if version in ie_dict and not overwrite: - raise NgioValueError( - f"Importer/exporter for version {version} already exists. " - "Use 'overwrite=True' to overwrite." - ) - ie_dict[version] = importer - - def register_image_ie( - self, - version: str, - importer: ImageMetaImporter, - exporter: ImageMetaExporter, - overwrite: bool = False, + version: str | None = None, ): - """Register an importer/exporter.""" - importer_exporter = ImageImporterExporter( - version=version, importer=importer, exporter=exporter - ) - self._register( - version=version, - importer=importer_exporter, - overwrite=overwrite, - _ie_name="_image_ie", + self._group_handler = group_handler + self._version = version + + # Validate metadata + _ = self.get_meta() + # Store the resolved version + # self._version = meta.version + + def get_meta(self) -> NgioPlateMeta: + """Retrieve the NGIO plate metadata.""" + return get_ngio_plate_meta( + group_handler=self._group_handler, + version=self._version, ) - def find_label_handler( - self, - group_handler: ZarrGroupHandler, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, - ) -> LabelMetaHandler: - """Get a label metadata handler.""" - return self._find_image_handler( - group_handler=group_handler, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, - _ie_name="_label_ie", - _handler=LabelMetaHandler, + def update_meta(self, ngio_meta: NgioPlateMeta) -> None: + """Update the NGIO plate metadata.""" + update_ngio_meta( + group_handler=self._group_handler, + ngio_meta=ngio_meta, ) - def get_label_meta_handler( - self, - group_handler: ZarrGroupHandler, - version: str, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, - ) -> LabelMetaHandler: - 
"""Get a label metadata handler.""" - if version not in self._label_ie: - raise NgioValueError(f"Label handler for version {version} does not exist.") - - label_ie = self._label_ie[version] - return LabelMetaHandler( - meta_importer=label_ie.importer, - meta_exporter=label_ie.exporter, - group_handler=group_handler, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, - ) - def register_label_ie( - self, - version: str, - importer: LabelMetaImporter, - exporter: LabelMetaExporter, - overwrite: bool = False, - ): - """Register an importer/exporter.""" - importer_exporter = LabelImporterExporter( - version=version, importer=importer, exporter=exporter - ) - self._register( - version=version, - importer=importer_exporter, - overwrite=overwrite, - _ie_name="_label_ie", - ) +def get_ngio_well_meta( + group_handler: ZarrGroupHandler, + version: str | None = None, +) -> NgioWellMeta: + """Retrieve the NGIO well metadata from the Zarr group. + + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + version (str | None): Optional NGFF version to use for decoding. + + Returns: + NgioWellMeta: The NGIO well metadata. + """ + return get_ngio_meta( + group_handler=group_handler, + meta_type=NgioWellMeta, + version=version, + ) - def _find_hcs_handler( - self, - group_handler: ZarrGroupHandler, - _ie_name: str = "_well_ie", - _handler: type[_hcs_handler] = WellMetaHandler, - ) -> _hcs_handler: - """Get a handler for a HCS metadata.""" - _errors = {} - - dict_ie = self.__getattribute__(_ie_name) - - for ie in reversed(dict_ie.values()): - handler = _handler( - meta_importer=ie.importer, - meta_exporter=ie.exporter, - group_handler=group_handler, - ) - meta = handler.safe_load_meta() - if isinstance(meta, ValidationError): - _errors[ie.version] = meta - continue - return handler - - raise NgioValidationError( - f"Could not load metadata from any known version. 
Errors: {_errors}" - ) - def find_well_handler( - self, - group_handler: ZarrGroupHandler, - ) -> WellMetaHandler: - """Get a well metadata handler.""" - return self._find_hcs_handler( - group_handler=group_handler, - _ie_name="_well_ie", - _handler=WellMetaHandler, - ) +def update_ngio_well_meta( + group_handler: ZarrGroupHandler, + ngio_meta: NgioWellMeta, +) -> None: + """Update the NGIO well metadata in the Zarr group. - def get_well_meta_handler( - self, - group_handler: ZarrGroupHandler, - version: str, - ) -> WellMetaHandler: - """Get a well metadata handler.""" - if version not in self._well_ie: - raise NgioValueError(f"Well handler for version {version} does not exist.") - - well_ie = self._well_ie[version] - return WellMetaHandler( - meta_importer=well_ie.importer, - meta_exporter=well_ie.exporter, - group_handler=group_handler, - ) + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + ngio_meta (NgioWellMeta): The new NGIO well metadata. - def register_well_ie( - self, - version: str, - importer: WellMetaImporter, - exporter: WellMetaExporter, - overwrite: bool = False, - ): - """Register an importer/exporter.""" - importer_exporter = WellImporterExporter( - version=version, importer=importer, exporter=exporter - ) - self._register( - version=version, - importer=importer_exporter, - overwrite=overwrite, - _ie_name="_well_ie", - ) + """ + update_ngio_meta( + group_handler=group_handler, + ngio_meta=ngio_meta, + ) - def find_plate_handler( - self, - group_handler: ZarrGroupHandler, - ) -> PlateMetaHandler: - """Get a plate metadata handler.""" - return self._find_hcs_handler( - group_handler=group_handler, - _ie_name="_plate_ie", - _handler=PlateMetaHandler, - ) - def get_plate_meta_handler( +class WellMetaHandler: + def __init__( self, group_handler: ZarrGroupHandler, - version: str, - ) -> PlateMetaHandler: - """Get a plate metadata handler.""" - if version not in self._plate_ie: - raise NgioValueError(f"Plate handler for version 
{version} does not exist.") - - plate_ie = self._plate_ie[version] - return PlateMetaHandler( - meta_importer=plate_ie.importer, - meta_exporter=plate_ie.exporter, - group_handler=group_handler, - ) - - def register_plate_ie( - self, - version: str, - importer: PlateMetaImporter, - exporter: PlateMetaExporter, - overwrite: bool = False, + version: str | None = None, ): - """Register an importer/exporter.""" - importer_exporter = PlateImporterExporter( - version=version, importer=importer, exporter=exporter - ) - self._register( - version=version, - importer=importer_exporter, - overwrite=overwrite, - _ie_name="_plate_ie", + self._group_handler = group_handler + self._version = version + + # Validate metadata + _ = self.get_meta() + # Store the resolved version + # self._version = meta.version + + def get_meta(self) -> NgioWellMeta: + """Retrieve the NGIO well metadata.""" + return get_ngio_well_meta( + group_handler=self._group_handler, + version=self._version, ) - -########################################################################### -# -# Register metadata importers/exporters -# -########################################################################### - - -ImplementedMetaImporterExporter().register_image_ie( - version="0.4", - importer=v04_to_ngio_image_meta, - exporter=ngio_to_v04_image_meta, -) -ImplementedMetaImporterExporter().register_label_ie( - version="0.4", - importer=v04_to_ngio_label_meta, - exporter=ngio_to_v04_label_meta, -) -ImplementedMetaImporterExporter().register_well_ie( - version="0.4", importer=v04_to_ngio_well_meta, exporter=ngio_to_v04_well_meta -) -ImplementedMetaImporterExporter().register_plate_ie( - version="0.4", importer=v04_to_ngio_plate_meta, exporter=ngio_to_v04_plate_meta -) - - -########################################################################### -# -# Public functions to avoid direct access to the importer/exporter -# registration methods -# 
-########################################################################### - - -def find_image_meta_handler( - group_handler: ZarrGroupHandler, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, -) -> ImageMetaHandler: - """Open an image metadata handler.""" - return ImplementedMetaImporterExporter().find_image_handler( - group_handler=group_handler, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, - ) + def update_meta(self, ngio_meta: NgioWellMeta) -> None: + """Update the NGIO well metadata.""" + update_ngio_meta( + group_handler=self._group_handler, + ngio_meta=ngio_meta, + ) -def get_image_meta_handler( +def get_ngio_labels_group_meta( group_handler: ZarrGroupHandler, - version: str, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, -) -> ImageMetaHandler: - """Open an image metadata handler.""" - return ImplementedMetaImporterExporter().get_image_meta_handler( + version: str | None = None, +) -> NgioLabelsGroupMeta: + """Retrieve the NGIO labels group metadata from the Zarr group. + + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + version (str | None): Optional NGFF version to use for decoding. + + Returns: + NgioLabelsGroupMeta: The NGIO labels group metadata. 
+ """ + return get_ngio_meta( group_handler=group_handler, + meta_type=NgioLabelsGroupMeta, version=version, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, - ) - - -def find_label_meta_handler( - group_handler: ZarrGroupHandler, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, -) -> LabelMetaHandler: - """Open a label metadata handler.""" - return ImplementedMetaImporterExporter().find_label_handler( - group_handler=group_handler, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, ) -def get_label_meta_handler( +def update_ngio_labels_group_meta( group_handler: ZarrGroupHandler, - version: str, - axes_setup: AxesSetup | None = None, - allow_non_canonical_axes: bool = False, - strict_canonical_order: bool = True, -) -> LabelMetaHandler: - """Open a label metadata handler.""" - return ImplementedMetaImporterExporter().get_label_meta_handler( - group_handler=group_handler, - version=version, - axes_setup=axes_setup, - allow_non_canonical_axes=allow_non_canonical_axes, - strict_canonical_order=strict_canonical_order, - ) + ngio_meta: NgioLabelsGroupMeta, +) -> None: + """Update the NGIO labels group metadata in the Zarr group. + Args: + group_handler (ZarrGroupHandler): The Zarr group handler. + ngio_meta (NgioLabelsGroupMeta): The new NGIO labels group metadata. 
-def find_well_meta_handler(group_handler: ZarrGroupHandler) -> WellMetaHandler: - """Open a well metadata handler.""" - return ImplementedMetaImporterExporter().find_well_handler( + """ + update_ngio_meta( group_handler=group_handler, + ngio_meta=ngio_meta, ) -def get_well_meta_handler( - group_handler: ZarrGroupHandler, - version: str, -) -> WellMetaHandler: - """Open a well metadata handler.""" - return ImplementedMetaImporterExporter().get_well_meta_handler( - group_handler=group_handler, - version=version, - ) +class LabelsGroupMetaHandler: + def __init__( + self, + group_handler: ZarrGroupHandler, + version: NgffVersions | None = None, + ): + self._group_handler = group_handler + self._version = version + meta = self.get_meta() + self._version = meta.version -def find_plate_meta_handler(group_handler: ZarrGroupHandler) -> PlateMetaHandler: - """Open a plate metadata handler.""" - return ImplementedMetaImporterExporter().find_plate_handler( - group_handler=group_handler - ) - + def get_meta(self) -> NgioLabelsGroupMeta: + """Retrieve the NGIO labels group metadata.""" + return get_ngio_labels_group_meta( + group_handler=self._group_handler, + version=self._version, + ) -def get_plate_meta_handler( - group_handler: ZarrGroupHandler, - version: str, -) -> PlateMetaHandler: - """Open a plate metadata handler.""" - return ImplementedMetaImporterExporter().get_plate_meta_handler( - group_handler=group_handler, - version=version, - ) + def update_meta(self, ngio_meta: NgioLabelsGroupMeta) -> None: + """Update the NGIO labels group metadata.""" + update_ngio_labels_group_meta( + group_handler=self._group_handler, + ngio_meta=ngio_meta, + ) diff --git a/src/ngio/ome_zarr_meta/ngio_specs/__init__.py b/src/ngio/ome_zarr_meta/ngio_specs/__init__.py index e696a33a..16181947 100644 --- a/src/ngio/ome_zarr_meta/ngio_specs/__init__.py +++ b/src/ngio/ome_zarr_meta/ngio_specs/__init__.py @@ -40,6 +40,7 @@ NgioImageLabelMeta, NgioImageMeta, NgioLabelMeta, + NgioLabelsGroupMeta, 
) from ngio.ome_zarr_meta.ngio_specs._pixel_size import PixelSize @@ -62,6 +63,7 @@ "NgioImageLabelMeta", "NgioImageMeta", "NgioLabelMeta", + "NgioLabelsGroupMeta", "NgioPlateMeta", "NgioWellMeta", "PixelSize", diff --git a/src/ngio/ome_zarr_meta/ngio_specs/_axes.py b/src/ngio/ome_zarr_meta/ngio_specs/_axes.py index 16b34d0b..68c1d523 100644 --- a/src/ngio/ome_zarr_meta/ngio_specs/_axes.py +++ b/src/ngio/ome_zarr_meta/ngio_specs/_axes.py @@ -352,6 +352,7 @@ def axes(self) -> tuple[Axis, ...]: @property def axes_names(self) -> tuple[str, ...]: + """On disk axes names.""" return tuple(ax.name for ax in self._axes) @property diff --git a/src/ngio/ome_zarr_meta/ngio_specs/_dataset.py b/src/ngio/ome_zarr_meta/ngio_specs/_dataset.py index cb7eed17..37a5ddd1 100644 --- a/src/ngio/ome_zarr_meta/ngio_specs/_dataset.py +++ b/src/ngio/ome_zarr_meta/ngio_specs/_dataset.py @@ -60,11 +60,20 @@ def axes_handler(self) -> AxesHandler: @property def pixel_size(self) -> PixelSize: """Return the pixel size for the dataset.""" + scale = self._scale + pix_size_dict = {} + # Mandatory axes: x, y + for ax in ["x", "y"]: + index = self.axes_handler.get_index(ax) + assert index is not None + pix_size_dict[ax] = scale[index] + + for ax in ["z", "t"]: + index = self.axes_handler.get_index(ax) + pix_size_dict[ax] = scale[index] if index is not None else 1.0 + return PixelSize( - x=self.get_scale("x", default=1.0), - y=self.get_scale("y", default=1.0), - z=self.get_scale("z", default=1.0), - t=self.get_scale("t", default=1.0), + **pix_size_dict, space_unit=self.axes_handler.space_unit, time_unit=self.axes_handler.time_unit, ) @@ -78,21 +87,3 @@ def scale(self) -> tuple[float, ...]: def translation(self) -> tuple[float, ...]: """Return the translation as a tuple.""" return tuple(self._translation) - - def get_scale(self, axis_name: str, default: float | None = None) -> float: - """Return the scale for a given axis.""" - idx = self.axes_handler.get_index(axis_name) - if idx is None: - if default 
is not None: - return default - raise ValueError(f"Axis {axis_name} not found in axes {self.axes_handler}.") - return self._scale[idx] - - def get_translation(self, axis_name: str, default: float | None = None) -> float: - """Return the translation for a given axis.""" - idx = self.axes_handler.get_index(axis_name) - if idx is None: - if default is not None: - return default - raise ValueError(f"Axis {axis_name} not found in axes {self.axes_handler}.") - return self._translation[idx] diff --git a/src/ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py b/src/ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py index 5376321b..eaef6ee2 100644 --- a/src/ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py +++ b/src/ngio/ome_zarr_meta/ngio_specs/_ngio_hcs.py @@ -1,32 +1,31 @@ """HCS (High Content Screening) specific metadata classes for NGIO.""" +import warnings from typing import Annotated -from ome_zarr_models.v04.hcs import HCSAttrs -from ome_zarr_models.v04.plate import ( +from ome_zarr_models.common.plate import ( Acquisition, Column, - Plate, + PlateBase, Row, WellInPlate, ) -from ome_zarr_models.v04.well import WellAttrs as WellAttrs04 -from ome_zarr_models.v04.well_types import WellImage as WellImage04 -from ome_zarr_models.v04.well_types import WellMeta as WellMeta04 +from ome_zarr_models.common.well_types import WellImage as WellImageCommon from pydantic import BaseModel, SkipValidation, field_serializer from ngio.ome_zarr_meta.ngio_specs._ngio_image import DefaultNgffVersion, NgffVersions -from ngio.utils import NgioValueError, ngio_logger +from ngio.utils import NgioValueError def path_in_well_validation(path: str) -> str: """Validate the path in the well.""" # Check if the value contains only alphanumeric characters if not path.isalnum(): - ngio_logger.warning( + warnings.warn( f"Path '{path}' contains non-alphanumeric characters. " "This may cause issues with some tools. " - "Consider using only alphanumeric characters in the path." 
+ "Consider using only alphanumeric characters in the path.", + stacklevel=2, ) return path @@ -41,7 +40,7 @@ class ImageInWellPath(BaseModel): acquisition_name: str | None = None -class CustomWellImage(WellImage04): +class CustomWellImage(WellImageCommon): path: Annotated[str, SkipValidation] @field_serializer("path") @@ -50,32 +49,29 @@ def serialize_path(self, value: str) -> str: return path_in_well_validation(value) -class CustomWellMeta(WellMeta04): - images: list[CustomWellImage] # type: ignore (override of WellMeta04.images) - - -class CustomWellAttrs(WellAttrs04): - well: CustomWellMeta # type: ignore (override of WellAttrs04.well) +class PlateWithVersion(PlateBase): + version: NgffVersions -class NgioWellMeta(CustomWellAttrs): +class NgioWellMeta(BaseModel): """HCS well metadata.""" + images: list[CustomWellImage] # type: ignore (override of WellMeta04.images) + version: NgffVersions + @classmethod def default_init( cls, - version: NgffVersions | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, ) -> "NgioWellMeta": - if version is None: - version = DefaultNgffVersion - well = cls(well=CustomWellMeta(images=[], version=version)) + well = cls(images=[], version=ngff_version) return well @property def acquisition_ids(self) -> list[int]: """Return the acquisition ids in the well.""" acquisitions = [] - for images in self.well.images: + for images in self.images: if ( images.acquisition is not None and images.acquisition not in acquisitions @@ -85,7 +81,7 @@ def acquisition_ids(self) -> list[int]: def get_image_acquisition_id(self, image_path: str) -> int | None: """Return the acquisition id for the given image path.""" - for images in self.well.images: + for images in self.images: if images.path == image_path: return images.acquisition raise NgioValueError(f"Image at path {image_path} not found in the well.") @@ -100,11 +96,9 @@ def paths(self, acquisition: int | None = None) -> list[str]: acquisition (int | None): The acquisition id to filter 
the images. """ if acquisition is None: - return [images.path for images in self.well.images] + return [images.path for images in self.images] return [ - images.path - for images in self.well.images - if images.acquisition == acquisition + images.path for images in self.images if images.acquisition == acquisition ] def add_image( @@ -118,7 +112,7 @@ def add_image( strict (bool): If True, check if the image already exists in the well. If False, do not check if the image already exists in the well. """ - list_of_images = self.well.images + list_of_images = self.images for image in list_of_images: if image.path == path: raise NgioValueError( @@ -137,9 +131,7 @@ def add_image( new_image = CustomWellImage(path=path, acquisition=acquisition) list_of_images.append(new_image) - return NgioWellMeta( - well=CustomWellMeta(images=list_of_images, version=self.well.version) - ) + return NgioWellMeta(images=list_of_images, version=self.version) def remove_image(self, path: str) -> "NgioWellMeta": """Remove an image from the well. @@ -147,15 +139,11 @@ def remove_image(self, path: str) -> "NgioWellMeta": Args: path (str): The path of the image. 
""" - list_of_images = self.well.images + list_of_images = self.images for image in list_of_images: if image.path == path: list_of_images.remove(image) - return NgioWellMeta( - well=CustomWellMeta( - images=list_of_images, version=self.well.version - ) - ) + return NgioWellMeta(images=list_of_images, version=self.version) raise NgioValueError(f"Image at path {path} not found in the well.") @@ -218,26 +206,30 @@ def _relabel_wells( return new_wells -class NgioPlateMeta(HCSAttrs): +class NgioPlateMeta(BaseModel): """HCS plate metadata.""" + plate: PlateWithVersion + version: NgffVersions + @classmethod def default_init( cls, images: list[ImageInWellPath] | None = None, name: str | None = None, - version: NgffVersions | None = None, + ngff_version: NgffVersions = DefaultNgffVersion, ) -> "NgioPlateMeta": plate = cls( - plate=Plate( + plate=PlateWithVersion( rows=[], columns=[], acquisitions=None, wells=[], field_count=None, - version=version, name=name, - ) + version=ngff_version, + ), + version=ngff_version, ) if images is None: @@ -346,16 +338,16 @@ def add_row(self, row: str) -> "tuple[NgioPlateMeta, int]": else: wells = self.plate.wells - new_plate = Plate( + new_plate = PlateWithVersion( rows=rows, columns=self.plate.columns, acquisitions=self.plate.acquisitions, wells=wells, field_count=self.plate.field_count, name=self.plate.name, - version=self.plate.version, + version=self.version, ) - return NgioPlateMeta(plate=new_plate), row_idx + return NgioPlateMeta(plate=new_plate, version=self.version), row_idx def add_column(self, column: str | int) -> "tuple[NgioPlateMeta, int]": """Add a column to the plate. 
@@ -384,16 +376,16 @@ def add_column(self, column: str | int) -> "tuple[NgioPlateMeta, int]": else: wells = self.plate.wells - new_plate = Plate( + new_plate = PlateWithVersion( rows=self.plate.rows, columns=columns, acquisitions=self.plate.acquisitions, wells=wells, field_count=self.plate.field_count, name=self.plate.name, - version=self.plate.version, + version=self.version, ) - return NgioPlateMeta(plate=new_plate), column_idx + return NgioPlateMeta(plate=new_plate, version=self.version), column_idx def add_well( self, @@ -422,16 +414,16 @@ def add_well( ) ) - new_plate = Plate( + new_plate = PlateWithVersion( rows=plate.plate.rows, columns=plate.plate.columns, acquisitions=plate.plate.acquisitions, wells=wells, field_count=plate.plate.field_count, name=plate.plate.name, - version=plate.plate.version, + version=plate.version, ) - return NgioPlateMeta(plate=new_plate) + return NgioPlateMeta(plate=new_plate, version=plate.version) def add_acquisition( self, @@ -461,16 +453,16 @@ def add_acquisition( Acquisition(id=acquisition_id, name=acquisition_name, **acquisition_kwargs) ) - new_plate = Plate( + new_plate = PlateWithVersion( rows=self.plate.rows, columns=self.plate.columns, acquisitions=acquisitions, wells=self.plate.wells, field_count=self.plate.field_count, name=self.plate.name, - version=self.plate.version, + version=self.version, ) - return NgioPlateMeta(plate=new_plate) + return NgioPlateMeta(plate=new_plate, version=self.version) def remove_well(self, row: str, column: str | int) -> "NgioPlateMeta": """Remove a well from the plate. @@ -497,28 +489,28 @@ def remove_well(self, row: str, column: str | int) -> "NgioPlateMeta": f"Well at row {row} and column {column} not found in the plate." 
) - new_plate = Plate( + new_plate = PlateWithVersion( rows=self.plate.rows, columns=self.plate.columns, acquisitions=self.plate.acquisitions, wells=wells, field_count=self.plate.field_count, name=self.plate.name, - version=self.plate.version, + version=self.version, ) - return NgioPlateMeta(plate=new_plate) + return NgioPlateMeta(plate=new_plate, version=self.version) def derive( self, name: str | None = None, - version: NgffVersions | None = None, + ngff_version: NgffVersions | None = None, keep_acquisitions: bool = False, ) -> "NgioPlateMeta": """Derive the plate metadata. Args: name (str): The name of the derived plate. - version (NgffVersion | None): The version of the derived plate. + ngff_version (NgffVersion | None): The version of the derived plate. If None, use the version of the original plate. keep_acquisitions (bool): If True, keep the acquisitions in the plate. """ @@ -530,17 +522,18 @@ def derive( else: acquisitions = None - if version is None: - version = self.plate.version # type: ignore (version is NgffVersions or None) + if ngff_version is None: + ngff_version = self.version return NgioPlateMeta( - plate=Plate( + plate=PlateWithVersion( rows=rows, columns=columns, acquisitions=acquisitions, wells=[], field_count=self.plate.field_count, - version=version, name=name, - ) + version=ngff_version, + ), + version=ngff_version, ) diff --git a/src/ngio/ome_zarr_meta/ngio_specs/_ngio_image.py b/src/ngio/ome_zarr_meta/ngio_specs/_ngio_image.py index d64a7b04..35be8e9e 100644 --- a/src/ngio/ome_zarr_meta/ngio_specs/_ngio_image.py +++ b/src/ngio/ome_zarr_meta/ngio_specs/_ngio_image.py @@ -13,11 +13,11 @@ from pydantic import BaseModel from ngio.ome_zarr_meta.ngio_specs._axes import ( + AxesHandler, DefaultSpaceUnit, DefaultTimeUnit, SpaceUnits, TimeUnits, - build_canonical_axes_handler, ) from ngio.ome_zarr_meta.ngio_specs._channels import ChannelsMeta from ngio.ome_zarr_meta.ngio_specs._dataset import Dataset @@ -25,7 +25,7 @@ from ngio.utils import 
NgioValidationError, NgioValueError T = TypeVar("T") -NgffVersions = Literal["0.4"] +NgffVersions = Literal["0.4", "0.5"] DefaultNgffVersion: Literal["0.4"] = "0.4" @@ -41,6 +41,13 @@ def default_init(cls, version: NgffVersions) -> "ImageLabelSource": return cls(version=version, source={"image": "../../"}) +class NgioLabelsGroupMeta(BaseModel): + """Metadata model for the /labels group in OME-NGFF.""" + + version: NgffVersions + labels: list[str] + + class AbstractNgioImageMeta: """Base class for ImageMeta and LabelMeta.""" @@ -66,42 +73,23 @@ def __repr__(self): @classmethod def default_init( cls, - levels: int | Sequence[str], - axes_names: Sequence[str], - pixel_size: PixelSize, - scaling_factors: Sequence[float] | None = None, + levels: Sequence[str], + axes_handler: AxesHandler, + scales: Sequence[tuple[float, ...]], + translations: Sequence[tuple[float, ...] | None], name: str | None = None, version: NgffVersions = DefaultNgffVersion, ): """Initialize the ImageMeta object.""" - axes_handler = build_canonical_axes_handler( - axes_names, - space_units=pixel_size.space_unit, - time_units=pixel_size.time_unit, - ) - - px_size_dict = pixel_size.as_dict() - scale = [px_size_dict.get(name, 1.0) for name in axes_handler.axes_names] - - if scaling_factors is None: - _default = {"x": 2.0, "y": 2.0} - scaling_factors = [ - _default.get(name, 1.0) for name in axes_handler.axes_names - ] - - if isinstance(levels, int): - levels = [str(i) for i in range(levels)] - datasets = [] - for level in levels: + for level, scale, translation in zip(levels, scales, translations, strict=True): dataset = Dataset( path=level, axes_handler=axes_handler, scale=scale, - translation=None, + translation=translation, ) datasets.append(dataset) - scale = [s * f for s, f in zip(scale, scaling_factors, strict=True)] return cls( version=version, @@ -146,6 +134,17 @@ def version(self) -> NgffVersions: """Version of the OME-NFF metadata used to build the object.""" return self._version # type: 
ignore (version is a Literal type) + @property + def zarr_format(self) -> Literal[2, 3]: + """Zarr version used to store the data.""" + match self.version: + case "0.4": + return 2 + case "0.5": + return 3 + case _: + raise NgioValueError(f"Unsupported NGFF version: {self.version}") + @property def name(self) -> str | None: """Name of the image.""" @@ -326,50 +325,15 @@ def _get_closest_datasets(self, path: str | None = None) -> tuple[Dataset, Datas ) return dataset, lr_dataset - def scaling_factor(self, path: str | None = None) -> list[float]: - """Get the scaling factors from a dataset to its lower resolution.""" - if self.levels == 1: - return [1.0] * len(self.axes_handler.axes_names) - dataset, lr_dataset = self._get_closest_datasets(path=path) - - scaling_factors = [] - for ax_name in self.axes_handler.axes_names: - s_d = dataset.get_scale(ax_name) - s_lr_d = lr_dataset.get_scale(ax_name) - scaling_factors.append(s_lr_d / s_d) - return scaling_factors - - def yx_scaling(self, path: str | None = None) -> tuple[float, float]: - """Get the scaling factor from a dataset to its lower resolution.""" - if self.levels == 1: - return 1.0, 1.0 - dataset, lr_dataset = self._get_closest_datasets(path=path) - - if lr_dataset is None: - raise NgioValueError( - "No lower resolution dataset found. " - "This is the lowest resolution dataset." 
- ) - - s_d = dataset.get_scale("y") - s_lr_d = lr_dataset.get_scale("y") - scale_y = s_lr_d / s_d - - s_d = dataset.get_scale("x") - s_lr_d = lr_dataset.get_scale("x") - scale_x = s_lr_d / s_d - - return scale_y, scale_x - - def z_scaling(self, path: str | None = None) -> float: - """Get the scaling factor from a dataset to its lower resolution.""" + def scaling_factor(self, path: str | None = None) -> tuple[float, ...]: + """Get the scaling factors to downscale to the next lower resolution dataset.""" if self.levels == 1: - return 1.0 + return (1.0,) * len(self.axes_handler.axes_names) dataset, lr_dataset = self._get_closest_datasets(path=path) - - s_d = dataset.get_scale("z", default=1.0) - s_lr_d = lr_dataset.get_scale("z", default=1.0) - return s_lr_d / s_d + scale = dataset.scale + lr_scale = lr_dataset.scale + scaling_factors = [s / s_lr for s_lr, s in zip(scale, lr_scale, strict=True)] + return tuple(scaling_factors) class NgioLabelMeta(AbstractNgioImageMeta): diff --git a/src/ngio/ome_zarr_meta/v04/__init__.py b/src/ngio/ome_zarr_meta/v04/__init__.py index adeb9be9..76acf41c 100644 --- a/src/ngio/ome_zarr_meta/v04/__init__.py +++ b/src/ngio/ome_zarr_meta/v04/__init__.py @@ -1,12 +1,14 @@ """Utility to read/write OME-Zarr metadata v0.4.""" -from ngio.ome_zarr_meta.v04._v04_spec_utils import ( +from ngio.ome_zarr_meta.v04._v04_spec import ( ngio_to_v04_image_meta, ngio_to_v04_label_meta, + ngio_to_v04_labels_group_meta, ngio_to_v04_plate_meta, ngio_to_v04_well_meta, v04_to_ngio_image_meta, v04_to_ngio_label_meta, + v04_to_ngio_labels_group_meta, v04_to_ngio_plate_meta, v04_to_ngio_well_meta, ) @@ -14,10 +16,12 @@ __all__ = [ "ngio_to_v04_image_meta", "ngio_to_v04_label_meta", + "ngio_to_v04_labels_group_meta", "ngio_to_v04_plate_meta", "ngio_to_v04_well_meta", "v04_to_ngio_image_meta", "v04_to_ngio_label_meta", + "v04_to_ngio_labels_group_meta", "v04_to_ngio_plate_meta", "v04_to_ngio_well_meta", ] diff --git a/src/ngio/ome_zarr_meta/v04/_v04_spec_utils.py 
b/src/ngio/ome_zarr_meta/v04/_v04_spec.py similarity index 85% rename from src/ngio/ome_zarr_meta/v04/_v04_spec_utils.py rename to src/ngio/ome_zarr_meta/v04/_v04_spec.py index 0f03c9dd..33109268 100644 --- a/src/ngio/ome_zarr_meta/v04/_v04_spec_utils.py +++ b/src/ngio/ome_zarr_meta/v04/_v04_spec.py @@ -9,7 +9,6 @@ - A function to convert a ngio image metadata to a v04 image metadata. """ -from ome_zarr_models.common.multiscales import ValidTransform as ValidTransformV04 from ome_zarr_models.v04.axes import Axis as AxisV04 from ome_zarr_models.v04.coordinate_transformations import VectorScale as VectorScaleV04 from ome_zarr_models.v04.coordinate_transformations import ( @@ -17,13 +16,14 @@ ) from ome_zarr_models.v04.hcs import HCSAttrs as HCSAttrsV04 from ome_zarr_models.v04.image import ImageAttrs as ImageAttrsV04 -from ome_zarr_models.v04.image_label import ImageLabelAttrs as LabelAttrsV04 +from ome_zarr_models.v04.image_label import ImageLabelAttrs as ImageLabelAttrsV04 +from ome_zarr_models.v04.labels import LabelsAttrs as LabelsAttrsV04 from ome_zarr_models.v04.multiscales import Dataset as DatasetV04 from ome_zarr_models.v04.multiscales import Multiscale as MultiscaleV04 +from ome_zarr_models.v04.multiscales import ValidTransform as ValidTransformV04 from ome_zarr_models.v04.omero import Channel as ChannelV04 from ome_zarr_models.v04.omero import Omero as OmeroV04 from ome_zarr_models.v04.omero import Window as WindowV04 -from pydantic import ValidationError from ngio.ome_zarr_meta.ngio_specs import ( AxesHandler, @@ -37,6 +37,7 @@ ImageLabelSource, NgioImageMeta, NgioLabelMeta, + NgioLabelsGroupMeta, NgioPlateMeta, NgioWellMeta, default_channel_name, @@ -44,37 +45,6 @@ from ngio.ome_zarr_meta.v04._custom_models import CustomWellAttrs as WellAttrsV04 -def _is_v04_image_meta(metadata: dict) -> ImageAttrsV04 | ValidationError: - """Check if the metadata is a valid OME-Zarr v04 metadata. - - Args: - metadata (dict): The metadata to check. 
- - Returns: - bool: True if the metadata is a valid OME-Zarr v04 metadata, False otherwise. - """ - try: - return ImageAttrsV04(**metadata) - except ValidationError as e: - return e - - -def _is_v04_label_meta(metadata: dict) -> LabelAttrsV04 | ValidationError: - """Check if the metadata is a valid OME-Zarr v04 metadata. - - Args: - metadata (dict): The metadata to check. - - Returns: - bool: True if the metadata is a valid OME-Zarr v04 metadata, False otherwise. - """ - try: - return LabelAttrsV04(**metadata) - except ValidationError as e: - return e - raise RuntimeError("Unreachable code") - - def _v04_omero_to_channels(v04_omero: OmeroV04 | None) -> ChannelsMeta | None: if v04_omero is None: return None @@ -169,7 +139,7 @@ def _v04_to_ngio_datasets( unit = str(unit) axes.append( Axis( - name=v04_axis.name, + name=str(v04_axis.name), axis_type=AxisType(v04_axis.type), # (for some reason the type is a generic JsonValue, # but it should be a string or None) @@ -203,7 +173,7 @@ def v04_to_ngio_image_meta( axes_setup: AxesSetup | None = None, allow_non_canonical_axes: bool = False, strict_canonical_order: bool = True, -) -> tuple[bool, NgioImageMeta | ValidationError]: +) -> NgioImageMeta: """Convert a v04 image metadata to a ngio image metadata. Args: @@ -216,9 +186,7 @@ def v04_to_ngio_image_meta( Returns: NgioImageMeta: The ngio image metadata. 
""" - v04_image = _is_v04_image_meta(metadata) - if isinstance(v04_image, ValidationError): - return False, v04_image + v04_image = ImageAttrsV04(**metadata) if len(v04_image.multiscales) > 1: raise NotImplementedError( @@ -239,7 +207,7 @@ def v04_to_ngio_image_meta( name = v04_muliscale.name if name is not None and not isinstance(name, str): name = str(name) - return True, NgioImageMeta( + return NgioImageMeta( version="0.4", name=name, datasets=datasets, @@ -252,7 +220,7 @@ def v04_to_ngio_label_meta( axes_setup: AxesSetup | None = None, allow_non_canonical_axes: bool = False, strict_canonical_order: bool = True, -) -> tuple[bool, NgioLabelMeta | ValidationError]: +) -> NgioLabelMeta: """Convert a v04 image metadata to a ngio image metadata. Args: @@ -265,9 +233,7 @@ def v04_to_ngio_label_meta( Returns: NgioImageMeta: The ngio image metadata. """ - v04_label = _is_v04_label_meta(metadata) - if isinstance(v04_label, ValidationError): - return False, v04_label + v04_label = ImageLabelAttrsV04(**metadata) if len(v04_label.multiscales) > 1: raise NotImplementedError( @@ -301,7 +267,7 @@ def v04_to_ngio_label_meta( if name is not None and not isinstance(name, str): name = str(name) - return True, NgioLabelMeta( + return NgioLabelMeta( version="0.4", name=name, datasets=datasets, @@ -417,48 +383,55 @@ def ngio_to_v04_label_meta(metadata: NgioLabelMeta) -> dict: "multiscales": [v04_muliscale], "image-label": metadata.image_label.model_dump(), } - v04_label = LabelAttrsV04(**labels_meta) + v04_label = ImageLabelAttrsV04(**labels_meta) return v04_label.model_dump(exclude_none=True, by_alias=True) +def v04_to_ngio_labels_group_meta( + metadata: dict, +) -> NgioLabelsGroupMeta: + """Convert a v04 label group metadata to a ngio label group metadata. + + Args: + metadata (dict): The v04 label group metadata. + + Returns: + NgioLabelGroupMeta: The ngio label group metadata. 
+ """ + v04_label_group = LabelsAttrsV04(**metadata).model_dump() + labels = v04_label_group.get("labels", []) + return NgioLabelsGroupMeta(labels=labels, version="0.4") + + def v04_to_ngio_well_meta( metadata: dict, -) -> tuple[bool, NgioWellMeta | ValidationError]: +) -> NgioWellMeta: """Convert a v04 well metadata to a ngio well metadata. Args: metadata (dict): The v04 well metadata. Returns: - result (bool): True if the conversion was successful, False otherwise. - ngio_well_meta (NgioWellMeta): The ngio well metadata. + NgioWellMeta: The ngio well metadata. """ - try: - v04_well = WellAttrsV04(**metadata) - except ValidationError as e: - return False, e - - return True, NgioWellMeta(**v04_well.model_dump()) + v04_well = WellAttrsV04(**metadata).well.model_dump() + images = v04_well.get("images", []) + return NgioWellMeta(images=images, version="0.4") def v04_to_ngio_plate_meta( metadata: dict, -) -> tuple[bool, NgioPlateMeta | ValidationError]: +) -> NgioPlateMeta: """Convert a v04 plate metadata to a ngio plate metadata. Args: metadata (dict): The v04 plate metadata. Returns: - result (bool): True if the conversion was successful, False otherwise. - ngio_plate_meta (NgioPlateMeta): The ngio plate metadata. + NgioPlateMeta: The ngio plate metadata. """ - try: - v04_plate = HCSAttrsV04(**metadata) - except ValidationError as e: - return False, e - - return True, NgioPlateMeta(**v04_plate.model_dump()) + v04_plate = HCSAttrsV04(**metadata).plate.model_dump() + return NgioPlateMeta(plate=v04_plate, version="0.4") # type: ignore def ngio_to_v04_well_meta(metadata: NgioWellMeta) -> dict: @@ -470,7 +443,7 @@ def ngio_to_v04_well_meta(metadata: NgioWellMeta) -> dict: Returns: dict: The v04 well metadata. 
""" - v04_well = WellAttrsV04(**metadata.model_dump()) + v04_well = WellAttrsV04(well=metadata.model_dump()) # type: ignore return v04_well.model_dump(exclude_none=True, by_alias=True) @@ -485,3 +458,16 @@ def ngio_to_v04_plate_meta(metadata: NgioPlateMeta) -> dict: """ v04_plate = HCSAttrsV04(**metadata.model_dump()) return v04_plate.model_dump(exclude_none=True, by_alias=True) + + +def ngio_to_v04_labels_group_meta(metadata: NgioLabelsGroupMeta) -> dict: + """Convert a ngio label group metadata to a v04 label group metadata. + + Args: + metadata (NgioLabelsGroupMeta): The ngio label group metadata. + + Returns: + dict: The v04 label group metadata. + """ + v04_label_group = LabelsAttrsV04(labels=metadata.labels) + return v04_label_group.model_dump(exclude_none=True, by_alias=True) diff --git a/src/ngio/ome_zarr_meta/v05/__init__.py b/src/ngio/ome_zarr_meta/v05/__init__.py new file mode 100644 index 00000000..a42ace47 --- /dev/null +++ b/src/ngio/ome_zarr_meta/v05/__init__.py @@ -0,0 +1,27 @@ +"""Utility to read/write OME-Zarr metadata v0.5.""" + +from ngio.ome_zarr_meta.v05._v05_spec import ( + ngio_to_v05_image_meta, + ngio_to_v05_label_meta, + ngio_to_v05_labels_group_meta, + ngio_to_v05_plate_meta, + ngio_to_v05_well_meta, + v05_to_ngio_image_meta, + v05_to_ngio_label_meta, + v05_to_ngio_labels_group_meta, + v05_to_ngio_plate_meta, + v05_to_ngio_well_meta, +) + +__all__ = [ + "ngio_to_v05_image_meta", + "ngio_to_v05_label_meta", + "ngio_to_v05_labels_group_meta", + "ngio_to_v05_plate_meta", + "ngio_to_v05_well_meta", + "v05_to_ngio_image_meta", + "v05_to_ngio_label_meta", + "v05_to_ngio_labels_group_meta", + "v05_to_ngio_plate_meta", + "v05_to_ngio_well_meta", +] diff --git a/src/ngio/ome_zarr_meta/v05/_custom_models.py b/src/ngio/ome_zarr_meta/v05/_custom_models.py new file mode 100644 index 00000000..0019fbbd --- /dev/null +++ b/src/ngio/ome_zarr_meta/v05/_custom_models.py @@ -0,0 +1,18 @@ +from typing import Annotated + +from ome_zarr_models.v05.well import 
WellAttrs as WellAttrs05 +from ome_zarr_models.v05.well_types import WellImage as WellImage05 +from ome_zarr_models.v05.well_types import WellMeta as WellMeta05 +from pydantic import SkipValidation + + +class CustomWellImage(WellImage05): + path: Annotated[str, SkipValidation] + + +class CustomWellMeta(WellMeta05): + images: list[CustomWellImage] # type: ignore[valid-type] + + +class CustomWellAttrs(WellAttrs05): + well: CustomWellMeta # type: ignore[valid-type] diff --git a/src/ngio/ome_zarr_meta/v05/_v05_spec.py b/src/ngio/ome_zarr_meta/v05/_v05_spec.py new file mode 100644 index 00000000..b1d2a530 --- /dev/null +++ b/src/ngio/ome_zarr_meta/v05/_v05_spec.py @@ -0,0 +1,511 @@ +"""Utilities for OME-Zarr v05 specs. + +This module provides a set of classes to internally handle the metadata +of the OME-Zarr v05 specification. + +For Images and Labels implements the following functionalities: +- A function to find if a dict view of the metadata is a valid OME-Zarr v05 metadata. +- A function to convert a v05 image metadata to a ngio image metadata. +- A function to convert a ngio image metadata to a v05 image metadata. 
+""" + +from ome_zarr_models.common.omero import Channel as ChannelV05 +from ome_zarr_models.common.omero import Omero as OmeroV05 +from ome_zarr_models.common.omero import Window as WindowV05 +from ome_zarr_models.v05.axes import Axis as AxisV05 +from ome_zarr_models.v05.coordinate_transformations import VectorScale as VectorScaleV05 +from ome_zarr_models.v05.coordinate_transformations import ( + VectorTranslation as VectorTranslationV05, +) +from ome_zarr_models.v05.hcs import HCSAttrs as HCSAttrsV05 +from ome_zarr_models.v05.image import ImageAttrs as ImageAttrsV05 +from ome_zarr_models.v05.image_label import ImageLabelAttrs as ImageLabelAttrsV05 +from ome_zarr_models.v05.labels import Labels as Labels +from ome_zarr_models.v05.labels import LabelsAttrs as LabelsAttrsV05 +from ome_zarr_models.v05.multiscales import Dataset as DatasetV05 +from ome_zarr_models.v05.multiscales import Multiscale as MultiscaleV05 +from ome_zarr_models.v05.multiscales import ValidTransform as ValidTransformV05 +from pydantic import BaseModel + +from ngio.ome_zarr_meta.ngio_specs import ( + AxesHandler, + AxesSetup, + Axis, + AxisType, + Channel, + ChannelsMeta, + ChannelVisualisation, + Dataset, + ImageLabelSource, + NgioImageMeta, + NgioLabelMeta, + NgioLabelsGroupMeta, + NgioPlateMeta, + NgioWellMeta, + default_channel_name, +) +from ngio.ome_zarr_meta.v05._custom_models import CustomWellAttrs as WellAttrsV05 + + +class ImageV05AttrsWithOmero(ImageAttrsV05): + omero: OmeroV05 | None = None + + +class ImageV05WithOmero(BaseModel): + ome: ImageV05AttrsWithOmero + + +class ImageLabelV05(BaseModel): + ome: ImageLabelAttrsV05 + + +def _v05_omero_to_channels(v05_omero: OmeroV05 | None) -> ChannelsMeta | None: + if v05_omero is None: + return None + + ngio_channels = [] + for idx, v05_channel in enumerate(v05_omero.channels): + channel_extra = v05_channel.model_extra + + if channel_extra is None: + channel_extra = {} + + if "label" in channel_extra: + label = channel_extra.pop("label") + 
else: + label = default_channel_name(idx) + + if "wavelength_id" in channel_extra: + wavelength_id = channel_extra.pop("wavelength_id") + else: + wavelength_id = label + + if "active" in channel_extra: + active = channel_extra.pop("active") + else: + active = True + + channel_visualisation = ChannelVisualisation( + color=v05_channel.color, + start=v05_channel.window.start, + end=v05_channel.window.end, + min=v05_channel.window.min, + max=v05_channel.window.max, + active=active, + **channel_extra, + ) + + ngio_channels.append( + Channel( + label=label, + wavelength_id=wavelength_id, + channel_visualisation=channel_visualisation, + ) + ) + + v05_omero_extra = v05_omero.model_extra if v05_omero.model_extra is not None else {} + return ChannelsMeta(channels=ngio_channels, **v05_omero_extra) + + +def _compute_scale_translation( + v05_transforms: ValidTransformV05, + scale: list[float], + translation: list[float], +) -> tuple[list[float], list[float]]: + for v05_transform in v05_transforms: + if isinstance(v05_transform, VectorScaleV05): + scale = [t1 * t2 for t1, t2 in zip(scale, v05_transform.scale, strict=True)] + + elif isinstance(v05_transform, VectorTranslationV05): + translation = [ + t1 + t2 + for t1, t2 in zip(translation, v05_transform.translation, strict=True) + ] + else: + raise NotImplementedError( + f"Coordinate transformation {v05_transform} is not supported." 
+ ) + return scale, translation + + +def _v05_to_ngio_datasets( + v05_multiscale: MultiscaleV05, + axes_setup: AxesSetup, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = True, +) -> list[Dataset]: + """Convert a v05 multiscale to a list of ngio datasets.""" + datasets = [] + + global_scale = [1.0] * len(v05_multiscale.axes) + global_translation = [0.0] * len(v05_multiscale.axes) + + if v05_multiscale.coordinateTransformations is not None: + global_scale, global_translation = _compute_scale_translation( + v05_multiscale.coordinateTransformations, global_scale, global_translation + ) + + # Prepare axes handler + axes = [] + for v05_axis in v05_multiscale.axes: + unit = v05_axis.unit + if unit is not None and not isinstance(unit, str): + unit = str(unit) + axes.append( + Axis( + name=str(v05_axis.name), + axis_type=AxisType(v05_axis.type), + # (for some reason the type is a generic JsonValue, + # but it should be a string or None) + unit=v05_axis.unit, # type: ignore + ) + ) + axes_handler = AxesHandler( + axes=axes, + axes_setup=axes_setup, + allow_non_canonical_axes=allow_non_canonical_axes, + strict_canonical_order=strict_canonical_order, + ) + + for v05_dataset in v05_multiscale.datasets: + _scale, _translation = _compute_scale_translation( + v05_dataset.coordinateTransformations, global_scale, global_translation + ) + datasets.append( + Dataset( + path=v05_dataset.path, + axes_handler=axes_handler, + scale=_scale, + translation=_translation, + ) + ) + return datasets + + +def v05_to_ngio_image_meta( + metadata: dict, + axes_setup: AxesSetup | None = None, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = True, +) -> NgioImageMeta: + """Convert a v05 image metadata to a ngio image metadata. + + Args: + metadata (dict): The v05 image metadata. + axes_setup (AxesSetup, optional): The axes setup. This is + required to convert image with non-canonical axes names. 
+ allow_non_canonical_axes (bool, optional): Allow non-canonical axes. + strict_canonical_order (bool, optional): Strict canonical order. + + Returns: + NgioImageMeta: The ngio image metadata. + """ + v05_image = ImageV05WithOmero(**metadata) + v05_image = v05_image.ome + if len(v05_image.multiscales) > 1: + raise NotImplementedError( + "Multiple multiscales in a single image are not supported in ngio." + ) + + v05_multiscale = v05_image.multiscales[0] + + channels_meta = _v05_omero_to_channels(v05_image.omero) + axes_setup = axes_setup if axes_setup is not None else AxesSetup() + datasets = _v05_to_ngio_datasets( + v05_multiscale, + axes_setup=axes_setup, + allow_non_canonical_axes=allow_non_canonical_axes, + strict_canonical_order=strict_canonical_order, + ) + + name = v05_multiscale.name + if name is not None and not isinstance(name, str): + name = str(name) + return NgioImageMeta( + version="0.5", + name=name, + datasets=datasets, + channels=channels_meta, + ) + + +def v05_to_ngio_label_meta( + metadata: dict, + axes_setup: AxesSetup | None = None, + allow_non_canonical_axes: bool = False, + strict_canonical_order: bool = True, +) -> NgioLabelMeta: + """Convert a v05 label metadata to a ngio label metadata. + + Args: + metadata (dict): The v05 label metadata. + axes_setup (AxesSetup, optional): The axes setup. This is + required to convert image with non-canonical axes names. + allow_non_canonical_axes (bool, optional): Allow non-canonical axes. + strict_canonical_order (bool, optional): Strict canonical order. + + Returns: + NgioLabelMeta: The ngio label metadata. + """ + v05_label = ImageLabelV05(**metadata) + v05_label = v05_label.ome + + if len(v05_label.multiscales) > 1: + raise NotImplementedError( + "Multiple multiscales in a single image are not supported in ngio." 
+ ) + + v05_multiscale = v05_label.multiscales[0] + + axes_setup = axes_setup if axes_setup is not None else AxesSetup() + datasets = _v05_to_ngio_datasets( + v05_multiscale, + axes_setup=axes_setup, + allow_non_canonical_axes=allow_non_canonical_axes, + strict_canonical_order=strict_canonical_order, + ) + + if v05_label.image_label is not None: + source = v05_label.image_label.source + if source is None: + image_label_source = None + else: + source = v05_label.image_label.source + if source is None: + image_label_source = None + else: + image_label_source = source.image + image_label_source = ImageLabelSource( + version="0.5", + source={"image": image_label_source}, + ) + else: + image_label_source = None + name = v05_multiscale.name + if name is not None and not isinstance(name, str): + name = str(name) + + return NgioLabelMeta( + version="0.5", + name=name, + datasets=datasets, + image_label=image_label_source, + ) + + +def _ngio_to_v05_multiscale(name: str | None, datasets: list[Dataset]) -> MultiscaleV05: + """Convert a ngio multiscale to a v05 multiscale. + + Args: + name (str | None): The name of the multiscale. + datasets (list[Dataset]): The ngio datasets. + + Returns: + MultiscaleV05: The v05 multiscale. 
+ """ + ax_mapper = datasets[0].axes_handler + v05_axes = [] + for axis in ax_mapper.axes: + v05_axes.append( + AxisV05( + name=axis.name, + type=axis.axis_type.value if axis.axis_type is not None else None, + unit=axis.unit if axis.unit is not None else None, + ) + ) + + v05_datasets = [] + for dataset in datasets: + transform = [VectorScaleV05(type="scale", scale=list(dataset._scale))] + if sum(dataset._translation) > 0: + transform = ( + VectorScaleV05(type="scale", scale=list(dataset._scale)), + VectorTranslationV05( + type="translation", translation=list(dataset._translation) + ), + ) + else: + transform = (VectorScaleV05(type="scale", scale=list(dataset._scale)),) + + v05_datasets.append( + DatasetV05(path=dataset.path, coordinateTransformations=transform) + ) + return MultiscaleV05(axes=v05_axes, datasets=tuple(v05_datasets), name=name) + + +def _ngio_to_v05_omero(channels: ChannelsMeta | None) -> OmeroV05 | None: + """Convert a ngio channels to a v05 omero.""" + if channels is None: + return None + + v05_channels = [] + for channel in channels.channels: + _model_extra = { + "label": channel.label, + "wavelength_id": channel.wavelength_id, + "active": channel.channel_visualisation.active, + } + if channel.channel_visualisation.model_extra is not None: + _model_extra.update(channel.channel_visualisation.model_extra) + + v05_channels.append( + ChannelV05( + color=channel.channel_visualisation.valid_color, + window=WindowV05( + start=channel.channel_visualisation.start, + end=channel.channel_visualisation.end, + min=channel.channel_visualisation.min, + max=channel.channel_visualisation.max, + ), + **_model_extra, + ) + ) + + _model_extra = channels.model_extra if channels.model_extra is not None else {} + return OmeroV05(channels=v05_channels, **_model_extra) + + +def ngio_to_v05_image_meta(metadata: NgioImageMeta) -> dict: + """Convert a ngio image metadata to a v05 image metadata. + + Args: + metadata (NgioImageMeta): The ngio image metadata. 
+ + Returns: + dict: The v05 image metadata. + """ + v05_muliscale = _ngio_to_v05_multiscale( + name=metadata.name, datasets=metadata.datasets + ) + v05_omero = _ngio_to_v05_omero(metadata._channels_meta) + + v05_image_attrs = ImageV05AttrsWithOmero( + multiscales=[v05_muliscale], omero=v05_omero, version="0.5" + ) + v05_image = ImageV05WithOmero( + ome=v05_image_attrs, + ) + return v05_image.model_dump(exclude_none=True, by_alias=True) + + +def ngio_to_v05_label_meta(metadata: NgioLabelMeta) -> dict: + """Convert a ngio image metadata to a v05 image metadata. + + Args: + metadata (NgioImageMeta): The ngio image metadata. + + Returns: + dict: The v05 image metadata. + """ + v05_muliscale = _ngio_to_v05_multiscale( + name=metadata.name, datasets=metadata.datasets + ) + labels_meta = { + "multiscales": [v05_muliscale], + "image-label": metadata.image_label.model_dump(), + } + v05_label = ImageLabelAttrsV05(**labels_meta, version="0.5") + v05_label = ImageLabelV05( + ome=v05_label, + ) + return v05_label.model_dump(exclude_none=True, by_alias=True) + + +class WellV05(BaseModel): + ome: WellAttrsV05 + + +class HCSV05(BaseModel): + ome: HCSAttrsV05 + + +def v05_to_ngio_well_meta( + metadata: dict, +) -> NgioWellMeta: + """Convert a v05 well metadata to a ngio well metadata. + + Args: + metadata (dict): The v05 well metadata. + + Returns: + NgioWellMeta: The ngio well metadata. + """ + v05_well = WellV05(**metadata).ome.well.model_dump() + images = v05_well.get("images", []) + return NgioWellMeta(images=images, version="0.5") + + +def v05_to_ngio_plate_meta( + metadata: dict, +) -> NgioPlateMeta: + """Convert a v05 plate metadata to a ngio plate metadata. + + Args: + metadata (dict): The v05 plate metadata. + + Returns: + NgioPlateMeta: The ngio plate metadata. 
+ """ + v05_plate = HCSV05(**metadata).ome.plate.model_dump() + return NgioPlateMeta(plate=v05_plate, version="0.5") # type: ignore + + +def ngio_to_v05_well_meta(metadata: NgioWellMeta) -> dict: + """Convert a ngio well metadata to a v05 well metadata. + + Args: + metadata (NgioWellMeta): The ngio well metadata. + + Returns: + dict: The v05 well metadata. + """ + v05_well = WellAttrsV05(well=metadata.model_dump()) # type: ignore + v05_well = WellV05(ome=v05_well) + return v05_well.model_dump(exclude_none=True, by_alias=True) + + +def ngio_to_v05_plate_meta(metadata: NgioPlateMeta) -> dict: + """Convert a ngio plate metadata to a v05 plate metadata. + + Args: + metadata (NgioPlateMeta): The ngio plate metadata. + + Returns: + dict: The v05 plate metadata. + """ + v05_plate = HCSAttrsV05(**metadata.model_dump()) + v05_plate = HCSV05(ome=v05_plate) + return v05_plate.model_dump(exclude_none=True, by_alias=True) + + +class LabelsV05(BaseModel): + ome: LabelsAttrsV05 + + +def v05_to_ngio_labels_group_meta( + metadata: dict, +) -> NgioLabelsGroupMeta: + """Convert a v04 label group metadata to a ngio label group metadata. + + Args: + metadata (dict): The v04 label group metadata. + + Returns: + NgioLabelGroupMeta: The ngio label group metadata. + """ + v05_label_group = LabelsV05(**metadata) + return NgioLabelsGroupMeta(labels=v05_label_group.ome.labels, version="0.5") + + +def ngio_to_v05_labels_group_meta(metadata: NgioLabelsGroupMeta) -> dict: + """Convert a ngio label group metadata to a v05 label group metadata. + + Args: + metadata (NgioLabelsGroupMeta): The ngio label group metadata. + + Returns: + dict: The v05 label group metadata. 
+ """ + v05_labels_attrs = LabelsAttrsV05(labels=metadata.labels, version="0.5") + v05_labels_group = LabelsV05(ome=v05_labels_attrs) + return v05_labels_group.model_dump(exclude_none=True, by_alias=True) diff --git a/src/ngio/tables/_tables_container.py b/src/ngio/tables/_tables_container.py index e727ee35..03cf1db6 100644 --- a/src/ngio/tables/_tables_container.py +++ b/src/ngio/tables/_tables_container.py @@ -258,7 +258,7 @@ def _get_tables_list(self) -> list[str]: def _get_table_group_handler(self, name: str) -> ZarrGroupHandler: """Get the group handler for a table.""" - handler = self._group_handler.derive_handler(path=name) + handler = self._group_handler.get_handler(path=name) return handler def list(self, filter_types: TypedTable | str | None = None) -> list[str]: @@ -311,6 +311,27 @@ def get_as( backend=backend, ) # type: ignore[return-value] + def delete(self, name: str, missing_ok: bool = False) -> None: + """Delete a table from the group. + + Args: + name (str): The name of the table to delete. + missing_ok (bool): If True, do not raise an error if + the table does not exist. + """ + existing_tables = self._get_tables_list() + if name not in existing_tables: + if missing_ok: + return + raise NgioValueError( + f"Table '{name}' not found in the Tables group. " + f"Available tables: {existing_tables}" + ) + + self._group_handler.delete_group(name) + existing_tables.remove(name) + self._group_handler.write_attrs({"tables": existing_tables}) + def add( self, name: str, @@ -326,9 +347,7 @@ def add( "Use overwrite=True to replace it." 
) - table_handler = self._group_handler.derive_handler( - path=name, overwrite=overwrite - ) + table_handler = self._group_handler.get_handler(path=name, overwrite=overwrite) if backend is None: backend = table.backend_name @@ -360,12 +379,9 @@ def open_tables_container( store: StoreOrGroup, cache: bool = False, mode: AccessModeLiteral = "r+", - parallel_safe: bool = False, ) -> TablesContainer: """Open a table handler from a Zarr store.""" - handler = ZarrGroupHandler( - store=store, cache=cache, mode=mode, parallel_safe=parallel_safe - ) + handler = ZarrGroupHandler(store=store, cache=cache, mode=mode) return TablesContainer(handler) @@ -374,11 +390,12 @@ def open_table( backend: TableBackend | None = None, cache: bool = False, mode: AccessModeLiteral = "r+", - parallel_safe: bool = False, ) -> Table: """Open a table from a Zarr store.""" handler = ZarrGroupHandler( - store=store, cache=cache, mode=mode, parallel_safe=parallel_safe + store=store, + cache=cache, + mode=mode, ) meta = _get_meta(handler) return ImplementedTables().get_table( @@ -392,11 +409,12 @@ def open_table_as( backend: TableBackend | None = None, cache: bool = False, mode: AccessModeLiteral = "r+", - parallel_safe: bool = False, ) -> TableType: """Open a table from a Zarr store as a specific type.""" handler = ZarrGroupHandler( - store=store, cache=cache, mode=mode, parallel_safe=parallel_safe + store=store, + cache=cache, + mode=mode, ) return table_cls.from_handler( handler=handler, @@ -410,12 +428,20 @@ def write_table( backend: TableBackend = DefaultTableBackend, cache: bool = False, mode: AccessModeLiteral = "a", - parallel_safe: bool = False, ) -> None: - """Write a table to a Zarr store.""" - handler = ZarrGroupHandler( - store=store, cache=cache, mode=mode, parallel_safe=parallel_safe - ) + """Write a table to a Zarr store. + + A table will be created at the given store location. + + Args: + store (StoreOrGroup): The Zarr store or group to write the table to. 
+ table (Table): The table to write. + backend (TableBackend): The backend to use for writing the table. + cache (bool): Whether to use caching for the Zarr group handler. + mode (AccessModeLiteral): The access mode to use for the Zarr group handler. + + """ + handler = ZarrGroupHandler(store=store, cache=cache, mode=mode) table.set_backend( handler=handler, backend=backend, diff --git a/src/ngio/tables/backends/_abstract_backend.py b/src/ngio/tables/backends/_abstract_backend.py index ad02bcc5..906980ea 100644 --- a/src/ngio/tables/backends/_abstract_backend.py +++ b/src/ngio/tables/backends/_abstract_backend.py @@ -198,6 +198,13 @@ def write_metadata(self, metadata: dict | None = None) -> None: if metadata is None: metadata = {} + attrs = self._group_handler.reopen_group().attrs.asdict() + # This is required by anndata to identify the format + if "encoding-type" in attrs: + metadata["encoding-type"] = attrs["encoding-type"] + if "encoding-version" in attrs: + metadata["encoding-version"] = attrs["encoding-version"] + backend_metadata = BackendMeta( backend=self.backend_name(), index_key=self.index_key, diff --git a/src/ngio/tables/backends/_anndata.py b/src/ngio/tables/backends/_anndata.py index fad51982..40a092f1 100644 --- a/src/ngio/tables/backends/_anndata.py +++ b/src/ngio/tables/backends/_anndata.py @@ -1,7 +1,10 @@ +import zarr from anndata import AnnData +from anndata._settings import settings from pandas import DataFrame from polars import DataFrame as PolarsDataFrame from polars import LazyFrame +from zarr.storage import FsspecStore, LocalStore, MemoryStore from ngio.tables.backends._abstract_backend import AbstractTableBackend from ngio.tables.backends._anndata_utils import ( @@ -12,7 +15,7 @@ convert_polars_to_anndata, normalize_anndata, ) -from ngio.utils import NgioValueError +from ngio.utils import NgioValueError, copy_group class AnnDataBackend(AbstractTableBackend): @@ -40,6 +43,7 @@ def implements_polars() -> bool: def load_as_anndata(self) -> 
AnnData: """Load the table as an AnnData object.""" + settings.zarr_write_format = self._group_handler.zarr_format anndata = custom_anndata_read_zarr(self._group_handler._group) anndata = normalize_anndata(anndata, index_key=self.index_key) return anndata @@ -48,17 +52,66 @@ def load(self) -> AnnData: """Load the table as an AnnData object.""" return self.load_as_anndata() + def _write_to_local_store( + self, store: LocalStore, path: str, table: AnnData + ) -> None: + """Write the AnnData table to a LocalStore.""" + store_path = f"{store.root}/{path}" + table.write_zarr(store_path) + + def _write_to_fsspec_store( + self, store: FsspecStore, path: str, table: AnnData + ) -> None: + """Write the AnnData table to a FsspecStore.""" + full_url = f"{store.path}/{path}" + fs = store.fs + mapper = fs.get_mapper(full_url) + table.write_zarr(mapper) + + def _write_to_memory_store( + self, store: MemoryStore, path: str, table: AnnData + ) -> None: + """Write the AnnData table to a MemoryStore.""" + store = MemoryStore() + table.write_zarr(store) + anndata_group = zarr.open_group(store, mode="r") + copy_group( + anndata_group, + self._group_handler._group, + suppress_warnings=True, + ) + def write_from_anndata(self, table: AnnData) -> None: """Serialize the table from an AnnData object.""" - full_url = self._group_handler.full_url - if full_url is None: + # Make sure to use the correct zarr format + settings.zarr_write_format = self._group_handler.zarr_format + store = self._group_handler.store + path = self._group_handler.group.path + if isinstance(store, LocalStore): + self._write_to_local_store( + store, + path, + table, + ) + elif isinstance(store, FsspecStore): + self._write_to_fsspec_store( + store, + path, + table, + ) + elif isinstance(store, MemoryStore): + self._write_to_memory_store( + store, + path, + table, + ) + else: raise NgioValueError( - f"Ngio does not support writing file from a " - f"store of type {type(self._group_handler)}. 
" + f"Ngio does not support writing an AnnData table to a " + f"store of type {type(store)}. " "Please make sure to use a compatible " - "store like a zarr.DirectoryStore." + "store like a LocalStore, or FsspecStore." ) - table.write_zarr(full_url) # type: ignore (AnnData writer requires a str path) def write_from_pandas(self, table: DataFrame) -> None: """Serialize the table from a pandas DataFrame.""" diff --git a/src/ngio/tables/backends/_anndata_utils.py b/src/ngio/tables/backends/_anndata_utils.py index 86c5968d..997dda4e 100644 --- a/src/ngio/tables/backends/_anndata_utils.py +++ b/src/ngio/tables/backends/_anndata_utils.py @@ -34,10 +34,6 @@ def custom_anndata_read_zarr( elem_to_read (Sequence[str] | None): The elements to read from the store. """ group = open_group_wrapper(store=store, mode="r") - - if not isinstance(group.store, zarr.DirectoryStore): - elem_to_read = ["X", "obs", "var"] - if elem_to_read is None: elem_to_read = [ "X", @@ -87,6 +83,8 @@ def callback(func: Callable, elem_name: str, elem: Any, iospec: Any) -> Any: if isinstance(group["obs"], zarr.Array): _clean_uns(adata) + if isinstance(adata, dict): + adata = AnnData(**adata) # type: ignore if not isinstance(adata, AnnData): raise NgioValueError(f"Expected an AnnData object, but got {type(adata)}") return adata diff --git a/src/ngio/tables/backends/_csv.py b/src/ngio/tables/backends/_csv.py index 49e1e757..f89683b6 100644 --- a/src/ngio/tables/backends/_csv.py +++ b/src/ngio/tables/backends/_csv.py @@ -1,20 +1,7 @@ -import pandas as pd -import polars as pl +from ngio.tables.backends._py_arrow_backends import PyArrowBackend -from ngio.tables.backends._non_zarr_backends import NonZarrBaseBackend - -def write_lf_to_csv(path: str, table: pl.DataFrame) -> None: - """Write a polars DataFrame to a CSV file.""" - table.write_csv(path) - - -def write_df_to_csv(path: str, table: pd.DataFrame) -> None: - """Write a pandas DataFrame to a CSV file.""" - table.to_csv(path, index=False) - - -class 
CsvTableBackend(NonZarrBaseBackend): +class CsvTableBackend(PyArrowBackend): """A class to load and write small tables in CSV format.""" def __init__( @@ -22,11 +9,8 @@ def __init__( ): """Initialize the CsvTableBackend.""" super().__init__( - lf_reader=pl.scan_csv, - df_reader=pd.read_csv, - lf_writer=write_lf_to_csv, - df_writer=write_df_to_csv, table_name="table.csv", + table_format="csv", ) @staticmethod diff --git a/src/ngio/tables/backends/_json.py b/src/ngio/tables/backends/_json.py index 56d9dfa4..e789c100 100644 --- a/src/ngio/tables/backends/_json.py +++ b/src/ngio/tables/backends/_json.py @@ -8,7 +8,7 @@ normalize_pandas_df, normalize_polars_lf, ) -from ngio.utils import NgioFileNotFoundError +from ngio.utils import NgioError class JsonTableBackend(AbstractTableBackend): @@ -37,22 +37,19 @@ def implements_polars() -> bool: def _get_table_group(self): """Get the table group, creating it if it doesn't exist.""" try: - table_group = self._group_handler.get_group(path="table") - except NgioFileNotFoundError: - table_group = self._group_handler.group.create_group("table") + table_group = self._group_handler.get_group(path="table", create_mode=True) + except NgioError as e: + raise NgioError( + "Could not get or create a 'table' group in the store " + f"{self._group_handler.store} path " + f"{self._group_handler.group.path}/table." 
+ ) from e return table_group - def _load_as_pandas_df(self) -> DataFrame: - """Load the table as a pandas DataFrame.""" - table_group = self._get_table_group() - table_dict = dict(table_group.attrs) - - data_frame = pd.DataFrame.from_dict(table_dict) - return data_frame - def load_as_pandas_df(self) -> DataFrame: """Load the table as a pandas DataFrame.""" - data_frame = self._load_as_pandas_df() + table_dict = self._get_table_group().attrs.asdict() + data_frame = pd.DataFrame.from_dict(table_dict) data_frame = normalize_pandas_df( data_frame, index_key=self.index_key, diff --git a/src/ngio/tables/backends/_non_zarr_backends.py b/src/ngio/tables/backends/_non_zarr_backends.py deleted file mode 100644 index 155aa889..00000000 --- a/src/ngio/tables/backends/_non_zarr_backends.py +++ /dev/null @@ -1,196 +0,0 @@ -import io -from collections.abc import Callable -from typing import Any - -from pandas import DataFrame -from polars import DataFrame as PolarsDataFrame -from polars import LazyFrame -from zarr.storage import DirectoryStore, FSStore - -from ngio.tables.backends._abstract_backend import AbstractTableBackend -from ngio.tables.backends._utils import normalize_pandas_df, normalize_polars_lf -from ngio.utils import NgioFileNotFoundError, NgioValueError - - -class NonZarrBaseBackend(AbstractTableBackend): - """A class to load and write small tables in CSV format.""" - - def __init__( - self, - df_reader: Callable[[Any], DataFrame], - lf_reader: Callable[[Any], LazyFrame], - df_writer: Callable[[str, DataFrame], None], - lf_writer: Callable[[str, PolarsDataFrame], None], - table_name: str, - ): - self.df_reader = df_reader - self.lf_reader = lf_reader - self.df_writer = df_writer - self.lf_writer = lf_writer - self.table_name = table_name - - @staticmethod - def implements_anndata() -> bool: - """Whether the handler implements the anndata protocol.""" - return False - - @staticmethod - def implements_pandas() -> bool: - """Whether the handler implements the 
dataframe protocol.""" - return True - - @staticmethod - def implements_polars() -> bool: - """Whether the handler implements the polars protocol.""" - return True - - @staticmethod - def backend_name() -> str: - """Return the name of the backend.""" - raise NotImplementedError( - "The backend_name method must be implemented in the subclass." - ) - - def _load_from_directory_store(self, reader): - """Load the table from a directory store.""" - url = self._group_handler.full_url - if url is None: - ext = self.table_name.split(".")[-1] - raise NgioValueError( - f"Ngio does not support reading a {ext} table from a " - f"store of type {type(self._group_handler)}. " - "Please make sure to use a compatible " - "store like a zarr.DirectoryStore." - ) - table_path = f"{url}/{self.table_name}" - dataframe = reader(table_path) - return dataframe - - def _load_from_fs_store_df(self, reader): - """Load the table from an FS store.""" - path = self._group_handler.group.path - table_path = f"{path}/{self.table_name}" - bytes_table = self._group_handler.store.get(table_path) - if bytes_table is None: - raise NgioFileNotFoundError(f"No table found at {table_path}. 
") - dataframe = reader(io.BytesIO(bytes_table)) - return dataframe - - def _load_from_fs_store_lf(self, reader): - """Load the table from an FS store.""" - full_url = self._group_handler.full_url - parquet_path = f"{full_url}/{self.table_name}" - store_fs = self._group_handler.store.fs # type: ignore (in this context, store_fs is a fs.FSStore) - with store_fs.open(parquet_path, "rb") as f: - dataframe = reader(f) - return dataframe - - def load_as_pandas_df(self) -> DataFrame: - """Load the table as a pandas DataFrame.""" - store = self._group_handler.store - if isinstance(store, DirectoryStore): - dataframe = self._load_from_directory_store(reader=self.df_reader) - elif isinstance(store, FSStore): - dataframe = self._load_from_fs_store_df(reader=self.df_reader) - else: - ext = self.table_name.split(".")[-1] - raise NgioValueError( - f"Ngio does not support reading a {ext} table from a " - f"store of type {type(store)}. " - "Please make sure to use a compatible " - "store like a zarr.DirectoryStore or " - "zarr.FSStore." - ) - - dataframe = normalize_pandas_df( - dataframe, - index_key=self.index_key, - index_type=self.index_type, - reset_index=False, - ) - return dataframe - - def load(self) -> DataFrame: - """Load the table as a pandas DataFrame.""" - return self.load_as_pandas_df() - - def load_as_polars_lf(self) -> LazyFrame: - """Load the table as a polars LazyFrame.""" - store = self._group_handler.store - if isinstance(store, DirectoryStore): - lazy_frame = self._load_from_directory_store(reader=self.lf_reader) - elif isinstance(store, FSStore): - lazy_frame = self._load_from_fs_store_lf(reader=self.lf_reader) - else: - ext = self.table_name.split(".")[-1] - raise NgioValueError( - f"Ngio does not support reading a {ext} from a " - f"store of type {type(store)}. " - "Please make sure to use a compatible " - "store like a zarr.DirectoryStore or " - "zarr.FSStore." 
- ) - if not isinstance(lazy_frame, LazyFrame): - raise NgioValueError( - "Table is not a lazy frame. Please report this issue as an ngio bug." - f" {type(lazy_frame)}" - ) - - lazy_frame = normalize_polars_lf( - lazy_frame, - index_key=self.index_key, - index_type=self.index_type, - ) - return lazy_frame - - def _get_store_url(self) -> str: - """Get the store URL.""" - store = self._group_handler.store - if isinstance(store, DirectoryStore): - full_url = self._group_handler.full_url - else: - ext = self.table_name.split(".")[-1] - raise NgioValueError( - f"Ngio does not support writing a {ext} file to a " - f"store of type {type(store)}. " - "Please make sure to use a compatible " - "store like a zarr.DirectoryStore or " - "zarr.FSStore." - ) - if full_url is None: - ext = self.table_name.split(".")[-1] - raise NgioValueError( - f"Ngio does not support writing a {ext} file to a " - f"store of type {type(store)}. " - "Please make sure to use a compatible " - "store like a zarr.DirectoryStore or " - "zarr.FSStore." 
- ) - return full_url - - def write_from_pandas(self, table: DataFrame) -> None: - """Write the table from a pandas DataFrame.""" - table = normalize_pandas_df( - table, - index_key=self.index_key, - index_type=self.index_type, - reset_index=True, - ) - full_url = self._get_store_url() - table_path = f"{full_url}/{self.table_name}" - self.df_writer(table_path, table) - - def write_from_polars(self, table: PolarsDataFrame | LazyFrame) -> None: - """Write the table from a polars DataFrame or LazyFrame.""" - table = normalize_polars_lf( - table, - index_key=self.index_key, - index_type=self.index_type, - ) - - if isinstance(table, LazyFrame): - table = table.collect() - - full_url = self._get_store_url() - table_path = f"{full_url}/{self.table_name}" - self.lf_writer(table_path, table) diff --git a/src/ngio/tables/backends/_parquet.py b/src/ngio/tables/backends/_parquet.py index 058f3f0b..fa098399 100644 --- a/src/ngio/tables/backends/_parquet.py +++ b/src/ngio/tables/backends/_parquet.py @@ -1,32 +1,7 @@ -import pandas as pd -import polars as pl +from ngio.tables.backends._py_arrow_backends import PyArrowBackend -from ngio.tables.backends._non_zarr_backends import NonZarrBaseBackend - -def write_lf_to_parquet(path: str, table: pl.DataFrame) -> None: - """Write a polars DataFrame to a Parquet file.""" - # make categorical into string (for pandas compatibility) - schema = table.collect_schema() - - categorical_columns = [] - for name, dtype in zip(schema.names(), schema.dtypes(), strict=True): - if dtype == pl.Categorical: - categorical_columns.append(name) - - for col in categorical_columns: - table = table.with_columns(pl.col(col).cast(pl.Utf8)) - - # write to parquet - table.write_parquet(path) - - -def write_df_to_parquet(path: str, table: pd.DataFrame) -> None: - """Write a pandas DataFrame to a Parquet file.""" - table.to_parquet(path, index=False) - - -class ParquetTableBackend(NonZarrBaseBackend): +class ParquetTableBackend(PyArrowBackend): """A class to load 
and write small tables in Parquet format.""" def __init__( @@ -34,11 +9,8 @@ def __init__( ): """Initialize the ParquetTableBackend.""" super().__init__( - lf_reader=pl.scan_parquet, - df_reader=pd.read_parquet, - lf_writer=write_lf_to_parquet, - df_writer=write_df_to_parquet, table_name="table.parquet", + table_format="parquet", ) @staticmethod diff --git a/src/ngio/tables/backends/_py_arrow_backends.py b/src/ngio/tables/backends/_py_arrow_backends.py new file mode 100644 index 00000000..325b854a --- /dev/null +++ b/src/ngio/tables/backends/_py_arrow_backends.py @@ -0,0 +1,222 @@ +from typing import Literal + +import polars as pl +import pyarrow as pa +import pyarrow.csv as pa_csv +import pyarrow.dataset as pa_ds +import pyarrow.fs as pa_fs +import pyarrow.parquet as pa_parquet +from pandas import DataFrame +from polars import DataFrame as PolarsDataFrame +from polars import LazyFrame +from zarr.storage import FsspecStore, LocalStore, MemoryStore, ZipStore + +from ngio.tables.backends._abstract_backend import AbstractTableBackend +from ngio.tables.backends._utils import normalize_pandas_df, normalize_polars_lf +from ngio.utils import NgioValueError +from ngio.utils._zarr_utils import _make_sync_fs + + +class PyArrowBackend(AbstractTableBackend): + """A class to load and write small tables in CSV format.""" + + def __init__( + self, + table_name: str, + table_format: Literal["csv", "parquet"] = "parquet", + ): + self.table_name = table_name + self.table_format = table_format + + @staticmethod + def implements_anndata() -> bool: + """Whether the handler implements the anndata protocol.""" + return False + + @staticmethod + def implements_pandas() -> bool: + """Whether the handler implements the dataframe protocol.""" + return True + + @staticmethod + def implements_polars() -> bool: + """Whether the handler implements the polars protocol.""" + return True + + @staticmethod + def backend_name() -> str: + """Return the name of the backend.""" + raise 
NotImplementedError( + "The backend_name method must be implemented in the subclass." + ) + + def _raise_store_type_not_supported(self): + """Raise an error for unsupported store types.""" + ext = self.table_name.split(".")[-1] + store = self._group_handler.store + raise NgioValueError( + f"Ngio does not support reading a {ext} table from a " + f"store of type {type(store)}. " + "Please make sure to use a compatible " + "store like a LocalStore, or " + "FsspecStore, or MemoryStore, or ZipStore." + ) + + def _load_from_local_store(self, store: LocalStore, path: str) -> pa_ds.Dataset: + """Load the table from a directory store.""" + root_path = store.root + table_path = f"{root_path}/{path}/{self.table_name}" + dataset = pa_ds.dataset(table_path, format=self.table_format) + return dataset + + def _load_from_fsspec_store(self, store: FsspecStore, path: str) -> pa_ds.Dataset: + """Load the table from an FS store.""" + table_path = f"{store.path}/{path}/{self.table_name}" + fs = _make_sync_fs(store.fs) + dataset = pa_ds.dataset(table_path, format=self.table_format, filesystem=fs) + return dataset + + def _load_from_in_memory_store( + self, store: MemoryStore, path: str + ) -> pa_ds.Dataset: + """Load the table from an in-memory store.""" + table_path = f"{path}/{self.table_name}" + table = store._store_dict.get(table_path, None) + if table is None: + raise NgioValueError( + f"Table {self.table_name} not found in the in-memory store at " + f"path {path}." 
+ ) + assert isinstance(table, pa.Table) + dataset = pa_ds.dataset(table) + return dataset + + def _load_from_zip_store(self, store: ZipStore, path: str) -> pa_ds.Dataset: + """Load the table from a zip store.""" + raise NotImplementedError("Zip store loading is not implemented yet.") + + def _load_pyarrow_dataset(self) -> pa_ds.Dataset: + """Load the table as a pyarrow Dataset.""" + store = self._group_handler.store + path = self._group_handler.group.path + if isinstance(store, LocalStore): + return self._load_from_local_store(store, path) + elif isinstance(store, FsspecStore): + return self._load_from_fsspec_store(store, path) + elif isinstance(store, MemoryStore): + return self._load_from_in_memory_store(store, path) + elif isinstance(store, ZipStore): + return self._load_from_zip_store(store, path) + self._raise_store_type_not_supported() + + def load_as_pandas_df(self) -> DataFrame: + """Load the table as a pandas DataFrame.""" + dataset = self._load_pyarrow_dataset() + dataframe = dataset.to_table().to_pandas() + dataframe = normalize_pandas_df( + dataframe, + index_key=self.index_key, + index_type=self.index_type, + reset_index=False, + ) + return dataframe + + def load(self) -> DataFrame: + """Load the table as a pandas DataFrame.""" + return self.load_as_pandas_df() + + def load_as_polars_lf(self) -> LazyFrame: + """Load the table as a polars LazyFrame.""" + dataset = self._load_pyarrow_dataset() + lazy_frame = pl.scan_pyarrow_dataset(dataset) + if not isinstance(lazy_frame, LazyFrame): + raise NgioValueError( + "Table is not a lazy frame. Please report this issue as an ngio bug." 
+ f" {type(lazy_frame)}" + ) + + lazy_frame = normalize_polars_lf( + lazy_frame, + index_key=self.index_key, + index_type=self.index_type, + ) + return lazy_frame + + def _write_to_stream(self, stream, table: pa.Table) -> None: + """Write the table to a stream.""" + if self.table_format == "parquet": + pa_parquet.write_table(table, stream) + elif self.table_format == "csv": + pa_csv.write_csv(table, stream) + else: + raise NgioValueError( + f"Unsupported table format: {self.table_format}. " + "Supported formats are 'parquet' and 'csv'." + ) + + def _write_to_local_store( + self, store: LocalStore, path: str, table: pa.Table + ) -> None: + """Write the table to a directory store.""" + root_path = store.root + table_path = f"{root_path}/{path}/{self.table_name}" + self._write_to_stream(table_path, table) + + def _write_to_fsspec_store( + self, store: FsspecStore, path: str, table: pa.Table + ) -> None: + """Write the table to an FS store.""" + table_path = f"{store.path}/{path}/{self.table_name}" + fs = _make_sync_fs(store.fs) + fs = pa_fs.PyFileSystem(pa_fs.FSSpecHandler(fs)) + with fs.open_output_stream(table_path) as out_stream: + self._write_to_stream(out_stream, table) + + def _write_to_in_memory_store( + self, store: MemoryStore, path: str, table: pa.Table + ) -> None: + """Write the table to an in-memory store.""" + table_path = f"{path}/{self.table_name}" + store._store_dict[table_path] = table + + def _write_to_zip_store(self, store: ZipStore, path: str, table: pa.Table) -> None: + """Write the table to a zip store.""" + raise NotImplementedError("Writing to zip store is not implemented yet.") + + def _write_pyarrow_dataset(self, dataset: pa.Table) -> None: + """Write the table from a pyarrow Dataset.""" + store = self._group_handler.store + path = self._group_handler.group.path + if isinstance(store, LocalStore): + return self._write_to_local_store(store=store, path=path, table=dataset) + elif isinstance(store, FsspecStore): + return 
self._write_to_fsspec_store(store=store, path=path, table=dataset) + elif isinstance(store, MemoryStore): + return self._write_to_in_memory_store(store=store, path=path, table=dataset) + elif isinstance(store, ZipStore): + return self._write_to_zip_store(store=store, path=path, table=dataset) + self._raise_store_type_not_supported() + + def write_from_pandas(self, table: DataFrame) -> None: + """Write the table from a pandas DataFrame.""" + table = normalize_pandas_df( + table, + index_key=self.index_key, + index_type=self.index_type, + reset_index=True, + ) + table = pa.Table.from_pandas(table, preserve_index=False) + self._write_pyarrow_dataset(table) + + def write_from_polars(self, table: PolarsDataFrame | LazyFrame) -> None: + """Write the table from a polars DataFrame or LazyFrame.""" + table = normalize_polars_lf( + table, + index_key=self.index_key, + index_type=self.index_type, + ) + + if isinstance(table, LazyFrame): + table = table.collect() + table = table.to_arrow() + self._write_pyarrow_dataset(table) diff --git a/src/ngio/tables/backends/_utils.py b/src/ngio/tables/backends/_utils.py index e3698854..101483cf 100644 --- a/src/ngio/tables/backends/_utils.py +++ b/src/ngio/tables/backends/_utils.py @@ -403,7 +403,7 @@ def convert_anndata_to_pandas( DataFrame: Converted and normalized pandas DataFrame. 
""" pandas_df = anndata.to_df() - pandas_df[anndata.obs_keys()] = anndata.obs + pandas_df[anndata.obs.columns.to_list()] = anndata.obs pandas_df = normalize_pandas_df( pandas_df, index_key=index_key, diff --git a/src/ngio/tables/v1/_roi_table.py b/src/ngio/tables/v1/_roi_table.py index b645f53f..af4cce63 100644 --- a/src/ngio/tables/v1/_roi_table.py +++ b/src/ngio/tables/v1/_roi_table.py @@ -4,6 +4,7 @@ https://fractal-analytics-platform.github.io/fractal-tasks-core/tables/ """ +import warnings from collections.abc import Iterable from typing import Literal from uuid import uuid4 @@ -26,7 +27,6 @@ NgioTableValidationError, NgioValueError, ZarrGroupHandler, - ngio_warn, ) REQUIRED_COLUMNS = [ @@ -77,7 +77,9 @@ def _check_optional_columns(col_name: str) -> None: """Check if the column name is in the optional columns.""" if col_name not in OPTIONAL_COLUMNS + TIME_COLUMNS: - ngio_warn(f"Column {col_name} is not in the optional columns.") + warnings.warn( + f"Column {col_name} is not in the optional columns.", stacklevel=2 + ) def _dataframe_to_rois( @@ -120,17 +122,17 @@ def _dataframe_to_rois( else: label = getattr(row, "label", None) - roi = Roi( + slices = { + "x": (row.x_micrometer, row.len_x_micrometer), + "y": (row.y_micrometer, row.len_y_micrometer), + "z": (z_micrometer, z_length_micrometer), + } + if t_second is not None or t_length_second is not None: + slices["t"] = (t_second, t_length_second) + roi = Roi.from_values( name=str(row.Index), - x=row.x_micrometer, # type: ignore (type can not be known here) - y=row.y_micrometer, # type: ignore (type can not be known here) - z=z_micrometer, - t=t_second, - x_length=row.len_x_micrometer, # type: ignore (type can not be known here) - y_length=row.len_y_micrometer, # type: ignore (type can not be known here) - z_length=z_length_micrometer, - t_length=t_length_second, - unit="micrometer", + slices=slices, + space="world", label=label, **extras, ) @@ -143,24 +145,39 @@ def _rois_to_dataframe(rois: dict[str, Roi], 
index_key: str | None) -> pd.DataFr data = [] for roi in rois.values(): # This normalization is necessary for backward compatibility - z_micrometer = roi.z if roi.z is not None else 0.0 - len_z_micrometer = roi.z_length if roi.z_length is not None else 1.0 + if roi.space != "world": + raise NotImplementedError( + "Only ROIs in world coordinates can be serialized." + ) + z_slice = roi.get("z") + if z_slice is None: + z_micrometer = 0.0 + len_z_micrometer = 1.0 + else: + z_micrometer = z_slice.start if z_slice.start is not None else 0.0 + len_z_micrometer = z_slice.length if z_slice.length is not None else 1.0 + + x_slice = roi.get("x") + if x_slice is None: + raise NgioValueError("ROI is missing 'x' slice.") + y_slice = roi.get("y") + if y_slice is None: + raise NgioValueError("ROI is missing 'y' slice.") row = { index_key: roi.get_name(), - "x_micrometer": roi.x, - "y_micrometer": roi.y, + "x_micrometer": x_slice.start if x_slice.start is not None else 0.0, + "y_micrometer": y_slice.start if y_slice.start is not None else 0.0, "z_micrometer": z_micrometer, - "len_x_micrometer": roi.x_length, - "len_y_micrometer": roi.y_length, + "len_x_micrometer": x_slice.length if x_slice.length is not None else 1.0, + "len_y_micrometer": y_slice.length if y_slice.length is not None else 1.0, "len_z_micrometer": len_z_micrometer, } - if roi.t is not None: - row["t_second"] = roi.t - - if roi.t_length is not None: - row["len_t_second"] = roi.t_length + t_slice = roi.get("t") + if t_slice is not None: + row["t_second"] = t_slice.start if t_slice.start is not None else 0.0 + row["len_t_second"] = t_slice.length if t_slice.length is not None else 1.0 if roi.label is not None and index_key != "label": row["label"] = roi.label @@ -183,7 +200,7 @@ def __init__(self, rois: Iterable[Roi]) -> None: self._rois_by_name = {} self._rois_by_label = {} for roi in rois: - name = roi.get_name() + name = roi.name if name in self._rois_by_name: name = f"{name}_{uuid4().hex[:8]}" 
self._rois_by_name[name] = roi diff --git a/src/ngio/utils/__init__.py b/src/ngio/utils/__init__.py index edadfbff..a6ad64ae 100644 --- a/src/ngio/utils/__init__.py +++ b/src/ngio/utils/__init__.py @@ -1,13 +1,12 @@ """Various utilities for the ngio package.""" -import os - from ngio.utils._datasets import ( download_ome_zarr_dataset, list_ome_zarr_datasets, print_datasets_infos, ) from ngio.utils._errors import ( + NgioError, NgioFileExistsError, NgioFileNotFoundError, NgioTableValidationError, @@ -15,35 +14,32 @@ NgioValueError, ) from ngio.utils._fractal_fsspec_store import fractal_fsspec_store -from ngio.utils._logger import ngio_logger, ngio_warn, set_logger_level from ngio.utils._zarr_utils import ( AccessModeLiteral, + NgioCache, + NgioSupportedStore, StoreOrGroup, ZarrGroupHandler, + copy_group, open_group_wrapper, ) -set_logger_level(os.getenv("NGIO_LOGGER_LEVEL", "WARNING")) - __all__ = [ - # Zarr "AccessModeLiteral", - # Errors + "NgioCache", + "NgioError", "NgioFileExistsError", "NgioFileNotFoundError", + "NgioSupportedStore", "NgioTableValidationError", "NgioValidationError", "NgioValueError", "StoreOrGroup", "ZarrGroupHandler", - # Other + "copy_group", "download_ome_zarr_dataset", "fractal_fsspec_store", "list_ome_zarr_datasets", - # Logger - "ngio_logger", - "ngio_warn", "open_group_wrapper", "print_datasets_infos", - "set_logger_level", ] diff --git a/src/ngio/utils/_cache.py b/src/ngio/utils/_cache.py new file mode 100644 index 00000000..6ee0c07f --- /dev/null +++ b/src/ngio/utils/_cache.py @@ -0,0 +1,48 @@ +from typing import Generic, TypeVar + +T = TypeVar("T") + + +class NgioCache(Generic[T]): + """A simple cache for NGIO objects.""" + + def __init__(self, use_cache: bool = True): + self._cache: dict[str, T] = {} + self._use_cache = use_cache + + def _cache_sanity_check(self) -> None: + if len(self._cache) > 0: + raise RuntimeError( + "Cache is disabled, but cache contains items. " + "This indicates a logic error." 
+ ) + + @property + def use_cache(self) -> bool: + return self._use_cache + + @property + def cache(self) -> dict[str, T]: + return self._cache + + @property + def is_empty(self) -> bool: + return len(self._cache) == 0 + + def get(self, key: str, default: T | None = None) -> T | None: + if not self._use_cache: + self._cache_sanity_check() + return default + return self._cache.get(key, default) + + def set(self, key: str, value: T, overwrite: bool = True) -> None: + if not self._use_cache: + self._cache_sanity_check() + return + self._cache[key] = value + + def clear(self) -> None: + if not self._use_cache: + self._cache_sanity_check() + return + self._cache.clear() diff --git a/src/ngio/utils/_logger.py b/src/ngio/utils/_logger.py deleted file mode 100644 index 9886cc4e..00000000 --- a/src/ngio/utils/_logger.py +++ /dev/null @@ -1,50 +0,0 @@ -import logging -import time -from functools import cache - -from ngio.utils._errors import NgioValueError - -# Configure the logger -ngio_logger = logging.getLogger("NgioLogger") -ngio_logger.setLevel(logging.ERROR) - -# Set up a console handler with a custom format -console_handler = logging.StreamHandler() -formatter = logging.Formatter( - "%(asctime)s - %(levelname)s - %(name)s - " - "[%(module)s.%(funcName)s:%(lineno)d]: %(message)s" -) -console_handler.setFormatter(formatter) - -# Add the handler to the logger -ngio_logger.addHandler(console_handler) - - -def set_logger_level(level: str) -> None: - """Set the logger level. - - Args: - level: The level to set the logger to. - Must be one of "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL". 
- """ - if level not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]: - raise NgioValueError(f"Invalid log level: {level}") - - ngio_logger.setLevel(level) - - -@cache -def _warn(message: str, ttl_hash: int) -> None: - """Log a warning message with a time-to-live (TTL) hash.""" - ngio_logger.warning(message, stacklevel=3) - - -def ngio_warn(message: str, cooldown: int = 2) -> None: - """Log a warning message. - - Args: - message: The warning message to log. - cooldown: The cooldown period in seconds to avoid repeated logging. - """ - ttl_hash = time.time() // cooldown - _warn(message, ttl_hash) diff --git a/src/ngio/utils/_zarr_utils.py b/src/ngio/utils/_zarr_utils.py index a94ca622..b8b10cbf 100644 --- a/src/ngio/utils/_zarr_utils.py +++ b/src/ngio/utils/_zarr_utils.py @@ -1,61 +1,75 @@ """Common utilities for working with Zarr groups in consistent ways.""" +import json +import warnings from pathlib import Path -from typing import Literal +from typing import Literal, TypeAlias +import dask.array as da import fsspec import zarr from filelock import BaseFileLock, FileLock -from zarr.errors import ContainsGroupError, GroupNotFoundError -from zarr.storage import DirectoryStore, FSStore, MemoryStore, Store, StoreLike -from zarr.types import DIMENSION_SEPARATOR - -from ngio.utils import NgioFileExistsError, NgioFileNotFoundError, NgioValueError -from ngio.utils._errors import NgioError +from pydantic_zarr.v2 import ArraySpec as AnyArraySpecV2 +from pydantic_zarr.v3 import ArraySpec as AnyArraySpecV3 +from zarr.abc.store import Store +from zarr.errors import ContainsGroupError +from zarr.storage import FsspecStore, LocalStore, MemoryStore, ZipStore + +from ngio.utils._cache import NgioCache +from ngio.utils._errors import ( + NgioFileExistsError, + NgioFileNotFoundError, + NgioValueError, +) AccessModeLiteral = Literal["r", "r+", "w", "w-", "a"] # StoreLike is more restrictive than it could be # but to make sure we can handle the store correctly # we need to be more 
restrictive -NgioSupportedStore = ( - str | Path | fsspec.mapping.FSMap | FSStore | DirectoryStore | MemoryStore +NgioSupportedStore: TypeAlias = ( + str | Path | fsspec.mapping.FSMap | FsspecStore | MemoryStore | dict | LocalStore ) -GenericStore = Store | NgioSupportedStore -StoreOrGroup = GenericStore | zarr.Group +GenericStore: TypeAlias = NgioSupportedStore | Store +StoreOrGroup: TypeAlias = NgioSupportedStore | zarr.Group def _check_store(store) -> NgioSupportedStore: """Check the store and return a valid store.""" - if isinstance(store, NgioSupportedStore): - return store - - raise NotImplementedError( - f"Store type {type(store)} is not supported. " - f"Supported types are: {NgioSupportedStore}" - ) + if not isinstance(store, NgioSupportedStore): + warnings.warn( + f"Store type {type(store)} is not explicitly supported. " + f"Supported types are: {NgioSupportedStore}. " + "Proceeding, but this may lead to unexpected behavior.", + UserWarning, + stacklevel=2, + ) + return store -def _check_group(group: zarr.Group, mode: AccessModeLiteral) -> zarr.Group: +def _check_group( + group: zarr.Group, mode: AccessModeLiteral | None = None +) -> zarr.Group: """Check the group and return a valid group.""" - is_read_only = getattr(group, "_read_only", False) - if is_read_only and mode in ["w", "w-"]: - raise NgioValueError( - "The group is read only. Cannot open in write mode ['w', 'w-']" - ) + if group.read_only and mode not in [None, "r"]: + raise NgioValueError(f"The group is read only. 
Cannot open in mode {mode}.") - if mode == "r" and not is_read_only: + if mode == "r" and not group.read_only: # let's make sure we don't accidentally write to the group group = zarr.open_group(store=group.store, path=group.path, mode="r") - return group -def open_group_wrapper(store: StoreOrGroup, mode: AccessModeLiteral) -> zarr.Group: +def open_group_wrapper( + store: StoreOrGroup, + mode: AccessModeLiteral | None = None, + zarr_format: Literal[2, 3] | None = None, +) -> zarr.Group: """Wrapper around zarr.open_group with some additional checks. Args: store (StoreOrGroup): The store or group to open. - mode (ReadOrEdirLiteral): The mode to open the group in. + mode (AccessModeLiteral): The mode to open the group in. + zarr_format (int): The Zarr format version to use. Returns: zarr.Group: The opened Zarr group. @@ -67,16 +81,22 @@ def open_group_wrapper(store: StoreOrGroup, mode: AccessModeLiteral) -> zarr.Gro try: _check_store(store) - group = zarr.open_group(store=store, mode=mode) + mode = mode if mode is not None else "a" + group = zarr.open_group(store=store, mode=mode, zarr_format=zarr_format) - except ContainsGroupError as e: + except FileExistsError as e: raise NgioFileExistsError( f"A Zarr group already exists at {store}, consider setting overwrite=True." ) from e - except GroupNotFoundError as e: + except FileNotFoundError as e: raise NgioFileNotFoundError(f"No Zarr group found at {store}") from e + except ContainsGroupError as e: + raise NgioFileExistsError( + f"A Zarr group already exists at {store}, consider setting overwrite=True." + ) from e + return group @@ -86,178 +106,185 @@ class ZarrGroupHandler: def __init__( self, store: StoreOrGroup, + zarr_format: Literal[2, 3] | None = None, cache: bool = False, - mode: AccessModeLiteral = "a", - parallel_safe: bool = False, - parent: "ZarrGroupHandler | None" = None, + mode: AccessModeLiteral | None = None, ): """Initialize the handler. 
Args: store (StoreOrGroup): The Zarr store or group containing the image data. meta_mode (str): The mode of the metadata handler. + zarr_format (int | None): The Zarr format version to use. cache (bool): Whether to cache the metadata. - mode (str): The mode of the store. - parallel_safe (bool): If True, the handler will create a lock file to make - that can be used to make the handler parallel safe. - Be aware that the lock needs to be used manually. - parent (ZarrGroupHandler | None): The parent handler. + mode (str | None): The mode of the store. """ - if mode not in ["r", "r+", "w", "w-", "a"]: + if mode not in ["r", "r+", "w", "w-", "a", None]: raise NgioValueError(f"Mode {mode} is not supported.") - if parallel_safe and cache: - raise NgioValueError( - "The cache and parallel_safe options are mutually exclusive." - "If you want to use the lock mechanism, you should not use the cache." - ) - - group = open_group_wrapper(store, mode) - _store = group.store - - # Make sure the cache is set in the attrs - # in the same way as the cache in the handler - group.attrs.cache = cache - - if parallel_safe: - if not isinstance(_store, DirectoryStore): - raise NgioValueError( - "The store needs to be a DirectoryStore to use the lock mechanism. " - f"Instead, got {_store.__class__.__name__}." 
- ) - store_path = Path(_store.path) / group.path - self._lock_path = store_path.with_suffix(".lock") - self._lock = FileLock(self._lock_path, timeout=10) - - else: - self._lock_path = None - self._lock = None - + group = open_group_wrapper(store=store, mode=mode, zarr_format=zarr_format) self._group = group - self._mode = mode self.use_cache = cache - self._parallel_safe = parallel_safe - self._cache = {} - self._parent = parent + + self._group_cache: NgioCache[zarr.Group] = NgioCache(use_cache=cache) + self._array_cache: NgioCache[zarr.Array] = NgioCache(use_cache=cache) + self._handlers_cache: NgioCache[ZarrGroupHandler] = NgioCache(use_cache=cache) + self._lock: tuple[Path, BaseFileLock] | None = None def __repr__(self) -> str: """Return a string representation of the handler.""" return ( - f"ZarrGroupHandler(full_url={self.full_url}, mode={self.mode}, " + f"ZarrGroupHandler(full_url={self.full_url}, read_only={self.read_only}, " f"cache={self.use_cache}" ) @property - def store(self) -> StoreLike: + def store(self) -> Store: """Return the store of the group.""" - return self.group.store + return self._group.store @property def full_url(self) -> str | None: """Return the store path.""" - if isinstance(self.store, DirectoryStore | FSStore): - _store_path = str(self.store.path) - _store_path = _store_path.rstrip("/") - return f"{self.store.path}/{self._group.path}" + if isinstance(self.store, LocalStore): + return (self.store.root / self.group.path).as_posix() + elif isinstance(self.store, FsspecStore): + return f"{self.store.path}/{self.group.path}" + elif isinstance(self.store, ZipStore): + return (self.store.path / self.group.path).as_posix() + elif isinstance(self.store, MemoryStore): + return None + warnings.warn( + f"Cannot determine full URL for store type {type(self.store)}. 
", + UserWarning, + stacklevel=2, + ) return None @property - def mode(self) -> AccessModeLiteral: - """Return the mode of the group.""" - return self._mode # type: ignore (return type is Literal) + def zarr_format(self) -> Literal[2, 3]: + """Return the Zarr format version.""" + return self._group.metadata.zarr_format + + @property + def read_only(self) -> bool: + """Return whether the group is read only.""" + return self._group.read_only + + def _create_lock(self) -> tuple[Path, BaseFileLock]: + """Create the lock.""" + if self._lock is not None: + return self._lock + + if self.use_cache is True: + raise NgioValueError( + "Lock mechanism is not compatible with caching. " + "Please set cache=False to use the lock mechanism." + ) + + if not isinstance(self.store, LocalStore): + raise NgioValueError( + "The store needs to be a LocalStore to use the lock mechanism. " + f"Instead, got {self.store.__class__.__name__}." + ) + + store_path = Path(self.store.root) / self.group.path + _lock_path = store_path.with_suffix(".lock") + _lock = FileLock(_lock_path, timeout=10) + return _lock_path, _lock @property def lock(self) -> BaseFileLock: """Return the lock.""" if self._lock is None: - raise NgioValueError( - "The handler is not parallel safe. " - "Reopen the handler with parallel_safe=True." 
- ) - return self._lock + self._lock = self._create_lock() + return self._lock[1] @property - def parent(self) -> "ZarrGroupHandler | None": - """Return the parent handler.""" - return self._parent + def lock_path(self) -> Path: + """Return the lock path.""" + if self._lock is None: + self._lock = self._create_lock() + return self._lock[0] def remove_lock(self) -> None: """Return the lock.""" - if self._lock is None or self._lock_path is None: + if self._lock is None: return None - lock_path = Path(self._lock_path) - if lock_path.exists() and self._lock.lock_counter == 0: + lock_path, lock = self._lock + if lock_path.exists() and lock.lock_counter == 0: lock_path.unlink() self._lock = None - self._lock_path = None return None raise NgioValueError("The lock is still in use. Cannot remove it.") - @property - def group(self) -> zarr.Group: - """Return the group.""" - return self._group + def reopen_group(self) -> zarr.Group: + """Reopen the group. - def add_to_cache(self, key: str, value: object) -> None: - """Add an object to the cache.""" - if not self.use_cache: - return None - self._cache[key] = value + This is useful when the group has been modified + outside of the handler. + """ + mode = "r" if self.read_only else "r+" + return zarr.open_group( + store=self._group.store, + path=self._group.path, + mode=mode, + zarr_format=self._group.metadata.zarr_format, + ) - def get_from_cache(self, key: str) -> object | None: - """Get an object from the cache.""" - if not self.use_cache: - return None - return self._cache.get(key, None) + def reopen_handler(self) -> "ZarrGroupHandler": + """Reopen the handler. + + This is useful when the group has been modified + outside of the handler. 
+ """ + mode = "r" if self.read_only else "r+" + group = self.reopen_group() + return ZarrGroupHandler( + store=group, + zarr_format=group.metadata.zarr_format, + cache=self.use_cache, + mode=mode, + ) def clean_cache(self) -> None: """Clear the cached metadata.""" - self._cache = {} + group = self.reopen_group() + self.__init__( + store=group, + zarr_format=group.metadata.zarr_format, + cache=self.use_cache, + mode="r" if self.read_only else "r+", + ) + + @property + def group(self) -> zarr.Group: + """Return the group.""" + if self.use_cache is False: + # If we are not using cache, we need to reopen the group + # to make sure that the attributes are up to date + return self.reopen_group() + return self._group def load_attrs(self) -> dict: """Load the attributes of the group.""" - attrs = self.get_from_cache("attrs") - if attrs is not None and isinstance(attrs, dict): - return attrs - - attrs = dict(self.group.attrs) - - self.add_to_cache("attrs", attrs) - return attrs - - def _write_attrs(self, attrs: dict, overwrite: bool = False) -> None: - """Write the metadata to the store.""" - is_read_only = getattr(self._group, "_read_only", False) - if is_read_only: - raise NgioValueError("The group is read only. Cannot write metadata.") - - # we need to invalidate the current attrs cache - self.add_to_cache("attrs", None) - if overwrite: - self.group.attrs.clear() - - self.group.attrs.update(attrs) + return self.reopen_group().attrs.asdict() def write_attrs(self, attrs: dict, overwrite: bool = False) -> None: """Write the metadata to the store.""" # Maybe we should use the lock here - self._write_attrs(attrs, overwrite) - - def _obj_get(self, path: str): - """Get a group from the group.""" - group_or_array = self.get_from_cache(path) - if group_or_array is not None: - return group_or_array - - group_or_array = self.group.get(path, None) - self.add_to_cache(path, group_or_array) - return group_or_array + if self.read_only: + raise NgioValueError("The group is read only. 
Cannot write metadata.") + group = self.reopen_group() + if overwrite: + group.attrs.clear() + group.attrs.update(attrs) def create_group(self, path: str, overwrite: bool = False) -> zarr.Group: """Create a group in the group.""" - if self.mode == "r": + if self.group.read_only: raise NgioValueError("Cannot create a group in read only mode.") try: @@ -267,7 +294,7 @@ def create_group(self, path: str, overwrite: bool = False) -> zarr.Group: f"A Zarr group already exists at {path}, " "consider setting overwrite=True." ) from e - self.add_to_cache(path, group) + self._group_cache.set(path, group, overwrite=overwrite) return group def get_group( @@ -293,123 +320,215 @@ def get_group( if overwrite: return self.create_group(path, overwrite=overwrite) - group = self._obj_get(path) + group = self._group_cache.get(path) if isinstance(group, zarr.Group): return group - if group is not None: - raise NgioValueError( - f"The object at {path} is not a group, but a {type(group)}" - ) + group = self.group.get(path, default=None) + if isinstance(group, zarr.Group): + self._group_cache.set(path, group, overwrite=overwrite) + return group + + if isinstance(group, zarr.Array): + raise NgioValueError(f"The object at {path} is not a group, but an array.") if not create_mode: raise NgioFileNotFoundError(f"No group found at {path}") group = self.create_group(path) + self._group_cache.set(path, group, overwrite=overwrite) return group - def safe_get_group( - self, path: str, create_mode: bool = False - ) -> tuple[bool, zarr.Group | NgioError]: - """Get a group from the group. 
+ def get_array(self, path: str) -> zarr.Array: + """Get an array from the group.""" + array = self._array_cache.get(path) + if isinstance(array, zarr.Array): + return array + array = self.group.get(path, default=None) + if isinstance(array, zarr.Array): + self._array_cache.set(path, array) + return array + + if isinstance(array, zarr.Group): + raise NgioValueError(f"The object at {path} is not an array, but a group.") + raise NgioFileNotFoundError(f"No array found at {path}") + + def get_handler( + self, + path: str, + create_mode: bool = True, + overwrite: bool = False, + ) -> "ZarrGroupHandler": + """Get a new handler for a group in the current handler group. Args: path (str): The path to the group. create_mode (bool): If True, create the group if it does not exist. + overwrite (bool): If True, overwrite the group if it exists. + """ + handler = self._handlers_cache.get(path) + if handler is not None: + return handler + group = self.get_group(path, create_mode=create_mode, overwrite=overwrite) + mode = "r" if group.read_only else "r+" + handler = ZarrGroupHandler( + store=group, zarr_format=self.zarr_format, cache=self.use_cache, mode=mode + ) + self._handlers_cache.set(path, handler) + return handler - Returns: - zarr.Group | None: The Zarr group or None if it does not exist - or an error occurs. + @property + def is_listable(self) -> bool: + return is_group_listable(self.group) + + def delete_group(self, path: str) -> None: + """Delete a group from the current group. + Args: + path (str): The path to the group to delete. 
""" - try: - return True, self.get_group(path, create_mode) - except NgioError as e: - return False, e + if self.group.read_only: + raise NgioValueError("Cannot delete a group in read only mode.") + self.group.__delitem__(path) + self._group_cache._cache.pop(path, None) + self._handlers_cache._cache.pop(path, None) + + def delete_self(self) -> None: + """Delete the current group.""" + if self.group.read_only: + raise NgioValueError("Cannot delete a group in read only mode.") + self.group.__delitem__("/") + + def copy_group(self, dest_group: zarr.Group): + """Copy the group to a new store.""" + copy_group(self.group, dest_group) - def get_array(self, path: str) -> zarr.Array: - """Get an array from the group.""" - array = self._obj_get(path) - if array is None: - raise NgioFileNotFoundError(f"No array found at {path}") - if not isinstance(array, zarr.Array): + +def find_dimension_separator(array: zarr.Array) -> Literal[".", "/"]: + """Find the dimension separator used in the Zarr store. + + Args: + array (zarr.Array): The Zarr array to check. + + Returns: + Literal[".", "/"]: The dimension separator used in the store. + """ + from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding + + if array.metadata.zarr_format == 2: + separator = array.metadata.dimension_separator + else: + separator = array.metadata.chunk_key_encoding + if not isinstance(separator, DefaultChunkKeyEncoding): raise NgioValueError( - f"The object at {path} is not an array, but a {type(array)}" + "Only DefaultChunkKeyEncoding is supported in this example." ) - return array + separator = separator.separator + return separator - def create_array( - self, - path: str, - shape: tuple[int, ...], - dtype: str, - chunks: tuple[int, ...] 
| None = None, - dimension_separator: DIMENSION_SEPARATOR = "/", - compressor: str = "default", - overwrite: bool = False, - ) -> zarr.Array: - if self.mode == "r": - raise NgioValueError("Cannot create an array in read only mode.") - try: - return self.group.zeros( - name=path, - shape=shape, - dtype=dtype, - chunks=chunks, - dimension_separator=dimension_separator, - compressor=compressor, - overwrite=overwrite, - ) - except ContainsGroupError as e: - raise NgioFileExistsError( - f"A Zarr array already exists at {path}, " - "consider setting overwrite=True." - ) from e - except Exception as e: - raise NgioValueError(f"Error creating array at {path}") from e +def is_group_listable(group: zarr.Group) -> bool: + """Check if a Zarr group is listable. - def derive_handler( - self, - path: str, - overwrite: bool = False, - ) -> "ZarrGroupHandler": - """Derive a new handler from the current handler. + A group is considered listable if it contains at least one array or subgroup. - Args: - path (str): The path to the group. - overwrite (bool): If True, overwrite the group if it exists. - """ - group = self.get_group(path, create_mode=True, overwrite=overwrite) - return ZarrGroupHandler( - store=group, - cache=self.use_cache, - mode=self.mode, - parallel_safe=self._parallel_safe, - parent=self, + Args: + group (zarr.Group): The Zarr group to check. + + Returns: + bool: True if the group is listable, False otherwise. 
+ """ + if not group.store.supports_listing: + # If the store does not support listing + # then for sure it is not listable + return False + try: + next(group.keys()) + return True + except StopIteration: + # Group is listable but empty + return True + except Exception as _: + # Some stores may raise errors when listing + # consider those not listable + return False + + +def _make_sync_fs(fs: fsspec.AbstractFileSystem) -> fsspec.AbstractFileSystem: + fs_dict = json.loads(fs.to_json()) + fs_dict["asynchronous"] = False + return fsspec.AbstractFileSystem.from_json(json.dumps(fs_dict)) + + +def _get_mapper(store: LocalStore | FsspecStore, path: str): + if isinstance(store, LocalStore): + fs = fsspec.filesystem("file") + full_path = (store.root / path).as_posix() + else: + fs = _make_sync_fs(store.fs) + full_path = f"{store.path}/{path}" + return fs.get_mapper(full_path) + + +def _fsspec_copy( + src_fs: LocalStore | FsspecStore, + src_path: str, + dest_fs: LocalStore | FsspecStore, + dest_path: str, +): + src_mapper = _get_mapper(src_fs, src_path) + dest_mapper = _get_mapper(dest_fs, dest_path) + for key in src_mapper.keys(): + dest_mapper[key] = src_mapper[key] + + +def _zarr_python_copy(src_group: zarr.Group, dest_group: zarr.Group): + # Copy attributes + dest_group.attrs.put(src_group.attrs.asdict()) + # Copy arrays + for name, array in src_group.arrays(): + if array.metadata.zarr_format == 2: + spec = AnyArraySpecV2.from_zarr(array) + else: + spec = AnyArraySpecV3.from_zarr(array) + dst = spec.to_zarr( + store=dest_group.store, + path=f"{dest_group.path}/{name}", + overwrite=True, + ) + if array.ndim > 0: + dask_array = da.from_zarr(array) + da.to_zarr(dask_array, dst, overwrite=False) + # Copy subgroups + for name, subgroup in src_group.groups(): + dest_subgroup = dest_group.create_group(name, overwrite=True) + _zarr_python_copy(subgroup, dest_subgroup) + + +def copy_group( + src_group: zarr.Group, dest_group: zarr.Group, suppress_warnings: bool = False +): + if 
src_group.metadata.zarr_format != dest_group.metadata.zarr_format: + raise NgioValueError( + "Different Zarr format versions between source and destination, " + "cannot copy." ) - def safe_derive_handler( - self, - path: str, - overwrite: bool = False, - ) -> tuple[bool, "ZarrGroupHandler | NgioError"]: - """Derive a new handler from the current handler.""" - try: - return True, self.derive_handler(path, overwrite=overwrite) - except NgioError as e: - return False, e + if not is_group_listable(src_group): + raise NgioValueError("Source group is not listable, cannot copy.") - def copy_handler(self, handler: "ZarrGroupHandler") -> None: - """Copy the group to a new store.""" - _, n_skipped, _ = zarr.copy_store( - source=self.group.store, - dest=handler.group.store, - source_path=self.group.path, - dest_path=handler.group.path, - if_exists="replace", + if dest_group.read_only: + raise NgioValueError("Destination group is read only, cannot copy.") + if isinstance(src_group.store, LocalStore | FsspecStore) and isinstance( + dest_group.store, LocalStore | FsspecStore + ): + _fsspec_copy(src_group.store, src_group.path, dest_group.store, dest_group.path) + return + if not suppress_warnings: + warnings.warn( + "Fsspec copy not possible, falling back to Zarr Python API for the copy. " + "This will preserve some tabular data non-zarr native (parquet, and csv), " + "and it will be slower for large datasets.", + UserWarning, + stacklevel=2, ) - if n_skipped > 0: - raise NgioValueError( - f"Error copying group to {handler.full_url}, " - f"#{n_skipped} files where skipped." 
- ) + _zarr_python_copy(src_group, dest_group) diff --git a/tests/conftest.py b/tests/conftest.py index 0faeb47c..9c16cdf0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,9 +32,15 @@ def cardiomyocyte_small_mip_path(tmp_path: Path) -> Path: @pytest.fixture -def images_v04(tmp_path: Path) -> dict[str, Path]: - source = Path("tests/data/v04/images/") - dest = tmp_path / "v04" / "images" - dest.mkdir(parents=True, exist_ok=True) - shutil.copytree(source, dest, dirs_exist_ok=True) - return {file.name: file for file in dest.glob("*.zarr")} +def images_all_versions(tmp_path: Path) -> dict[str, Path]: + dest_base = tmp_path / "all_versions" / "images" + dest_base.mkdir(parents=True, exist_ok=True) + paths = {} + for version in ["v04", "v05"]: + source = Path(f"tests/data/{version}/images/") + dest = dest_base / version + dest.mkdir(parents=True, exist_ok=True) + shutil.copytree(source, dest, dirs_exist_ok=True) + for file in dest.glob("*.zarr"): + paths[f"{version}/{file.name}"] = file + return paths diff --git a/tests/create_test_data.py b/tests/create_test_data.py new file mode 100644 index 00000000..12ebfd5f --- /dev/null +++ b/tests/create_test_data.py @@ -0,0 +1,78 @@ +from pathlib import Path + +from ngio import NgffVersions, create_empty_ome_zarr + +DATA_DIR = Path("tests/data") + +IMAGE_SPECS = [ + { + "name": "test_image_tcyx.zarr", + "shape": (4, 3, 64, 64), + "axes": "tcyx", + }, + { + "name": "test_image_tczyx.zarr", + "shape": (4, 2, 10, 64, 64), + "axes": "tczyx", + }, + { + "name": "test_image_zyx.zarr", + "shape": (10, 64, 64), + "axes": "zyx", + }, + { + "name": "test_image_c1yx.zarr", + "shape": (2, 1, 64, 64), + "axes": "czyx", + }, + { + "name": "test_image_tyx.zarr", + "shape": (4, 64, 64), + "axes": "tyx", + }, + { + "name": "test_image_tzyx.zarr", + "shape": (4, 10, 64, 64), + "axes": "tzyx", + }, + { + "name": "test_image_cyx.zarr", + "shape": (2, 64, 64), + "axes": "cyx", + }, + { + "name": "test_image_yx.zarr", + "shape": (64, 64), 
+ "axes": "yx", + }, + { + "name": "test_image_czyx.zarr", + "shape": (2, 10, 64, 64), + "axes": "czyx", + }, +] + + +def create_test_images_dataset(version: NgffVersions) -> None: + version_str = "".join(version.split(".")) + base_dir = DATA_DIR / f"v{version_str}" / "images" + base_dir.mkdir(parents=True, exist_ok=True) + for spec in IMAGE_SPECS: + image_path = base_dir / spec["name"] + ome_zarr = create_empty_ome_zarr( + store=image_path, + xy_pixelsize=0.5, + shape=spec["shape"], + axes_names=spec["axes"], + ngff_version=version, + levels=2, + overwrite=True, + ) + ome_zarr.derive_label("label") + well_roi_table = ome_zarr.build_image_roi_table() + ome_zarr.add_table(name="well_ROI_table", table=well_roi_table, backend="csv") + + +if __name__ == "__main__": + for version in ["0.4", "0.5"]: + create_test_images_dataset(version) # type: ignore diff --git a/tests/data/v04/images/test_image_c1yx.zarr/.zattrs b/tests/data/v04/images/test_image_c1yx.zarr/.zattrs index b9773360..cb00ca1b 100644 --- a/tests/data/v04/images/test_image_c1yx.zarr/.zattrs +++ b/tests/data/v04/images/test_image_c1yx.zarr/.zattrs @@ -1,86 +1,86 @@ { - "multiscales": [ + "multiscales": [ + { + "axes": [ { - "axes": [ - { - "name": "c", - "type": "channel" - }, - { - "name": "z", - "type": "space", - "unit": "micrometer" - }, - { - "name": "y", - "type": "space", - "unit": "micrometer" - }, - { - "name": "x", - "type": "space", - "unit": "micrometer" - } - ], - "datasets": [ - { - "coordinateTransformations": [ - { - "scale": [ - 1.0, - 1.0, - 0.5, - 0.5 - ], - "type": "scale" - } - ], - "path": "0" - }, - { - "coordinateTransformations": [ - { - "scale": [ - 1.0, - 1.0, - 1.0, - 1.0 - ], - "type": "scale" - } - ], - "path": "1" - } - ], - "version": "0.4" + "name": "c", + "type": "channel" + }, + { + "name": "z", + "type": "space", + "unit": "micrometer" + }, + { + "name": "y", + "type": "space", + "unit": "micrometer" + }, + { + "name": "x", + "type": "space", + "unit": "micrometer" } - ], - 
"omero": { - "channels": [ + ], + "datasets": [ + { + "path": "0", + "coordinateTransformations": [ { - "active": true, - "color": "00FFFF", - "label": "channel1", - "wavelength_id": "channel1", - "window": { - "end": 255.0, - "max": 255.0, - "min": 0.0, - "start": 0.0 - } - }, + "type": "scale", + "scale": [ + 1.0, + 1.0, + 0.5, + 0.5 + ] + } + ] + }, + { + "path": "1", + "coordinateTransformations": [ { - "active": true, - "color": "00FFFF", - "label": "channel2", - "wavelength_id": "channel2", - "window": { - "end": 255.0, - "max": 255.0, - "min": 0.0, - "start": 0.0 - } + "type": "scale", + "scale": [ + 1.0, + 1.0, + 1.0, + 1.0 + ] } - ] + ] + } + ], + "version": "0.4" } + ], + "omero": { + "channels": [ + { + "color": "00FFFF", + "window": { + "max": 65535.0, + "min": 0.0, + "start": 0.0, + "end": 65535.0 + }, + "label": "channel_0", + "wavelength_id": "channel_0", + "active": true + }, + { + "color": "FF00FF", + "window": { + "max": 65535.0, + "min": 0.0, + "start": 0.0, + "end": 65535.0 + }, + "label": "channel_1", + "wavelength_id": "channel_1", + "active": true + } + ] + } } \ No newline at end of file diff --git a/tests/data/v04/images/test_image_c1yx.zarr/.zgroup b/tests/data/v04/images/test_image_c1yx.zarr/.zgroup index 3b7daf22..cab13da6 100644 --- a/tests/data/v04/images/test_image_c1yx.zarr/.zgroup +++ b/tests/data/v04/images/test_image_c1yx.zarr/.zgroup @@ -1,3 +1,3 @@ { - "zarr_format": 2 + "zarr_format": 2 } \ No newline at end of file diff --git a/tests/data/v04/images/test_image_c1yx.zarr/0/.zarray b/tests/data/v04/images/test_image_c1yx.zarr/0/.zarray index f1c0ee41..83ed1a14 100644 --- a/tests/data/v04/images/test_image_c1yx.zarr/0/.zarray +++ b/tests/data/v04/images/test_image_c1yx.zarr/0/.zarray @@ -1,27 +1,27 @@ { - "chunks": [ - 2, - 1, - 64, - 64 - ], - "compressor": { - "blocksize": 0, - "clevel": 5, - "cname": "lz4", - "id": "blosc", - "shuffle": 1 - }, - "dimension_separator": "/", - "dtype": "|u1", - "fill_value": 0, - "filters": 
null, - "order": "C", - "shape": [ - 2, - 1, - 64, - 64 - ], - "zarr_format": 2 + "shape": [ + 2, + 1, + 64, + 64 + ], + "chunks": [ + 2, + 1, + 64, + 64 + ], + "dtype": " bool: + return os.getenv("GITHUB_ACTIONS") == "true" or os.getenv("CI") == "true" + + +if sys.platform == "darwin" and _running_on_github_ci(): + # The store tests require local servers which seem to have issues on macOS CI. + # Skip the whole module in this case. + pytestmark = pytest.skip( + reason="Integration tests (local servers) are skipped on macOS CI", + allow_module_level=True, + ) + + +def _wait_for_port(proc, host, port, timeout=20): + """Wait until a TCP port starts accepting connections, or the process dies.""" + start = time.time() + while time.time() - start < timeout: + # bail out early if the process crashed + if proc.poll() is not None: + raise RuntimeError("moto server process exited early") + + try: + with socket.create_connection((host, port), timeout=0.5): + return + except OSError: + time.sleep(0.2) + + raise RuntimeError(f"Port {port} did not open in time") + + +@pytest.fixture(scope="session") +def moto_s3_server(tmp_path_factory): + host = "127.0.0.1" + port = 5005 + + log_dir = tmp_path_factory.mktemp("moto_logs") + log_file = log_dir / "server.log" + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + # only start S3 service f + env["MOTO_SERVICE"] = "s3" + + # NOTE: no "s3" argument here anymore + cmd = [ + sys.executable, + "-m", + "moto.server", + "-p", + str(port), + ] + + with log_file.open("wb") as lf: + proc = subprocess.Popen( + cmd, + stdout=lf, + stderr=subprocess.STDOUT, + env=env, + ) + + try: + _wait_for_port(proc, host, port, timeout=20) + except Exception as e: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + log_text = log_file.read_text(errors="replace") + raise RuntimeError( + f"Failed to start moto server: {e}\n--- moto log ---\n{log_text}" + ) from e + + s3_endpoint_url = 
f"http://{host}:{port}" + bucket_name = "s3-ci-test-bucket" + s3 = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id="test", + aws_secret_access_key="test", + endpoint_url=s3_endpoint_url, + ) + s3.create_bucket(Bucket=bucket_name) + yield {"endpoint_url": s3_endpoint_url, "bucket_name": bucket_name} + + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + + +def _find_free_port(host="127.0.0.1"): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind((host, 0)) + return s.getsockname()[1] + + +@pytest.fixture(scope="session") +def http_static_server(tmp_path_factory): + """ + Serve a temporary directory via `python -m http.server`. + + From the test code's perspective this is read-only: you write files + directly into `root` on disk, and then access them via HTTP. + """ + host = "127.0.0.1" + port = _find_free_port(host) + + root = tmp_path_factory.mktemp("http_static_root") + + cmd = [ + sys.executable, + "-m", + "http.server", + str(port), + "--bind", + host, + ] + + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + + proc = subprocess.Popen( + cmd, + cwd=str(root), # serve this directory + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + env=env, + ) + + try: + _wait_for_port(proc, host, port, timeout=10) + except Exception as e: + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() + raise RuntimeError(f"Failed to start http server: {e}") from e + + yield {"url": f"http://{host}:{port}", "root": Path(root)} + + proc.terminate() + try: + proc.wait(timeout=5) + except subprocess.TimeoutExpired: + proc.kill() diff --git a/tests/stores/test_http_store.py b/tests/stores/test_http_store.py new file mode 100644 index 00000000..79f0a2fe --- /dev/null +++ b/tests/stores/test_http_store.py @@ -0,0 +1,80 @@ +from pathlib import Path + +from utils import ( + check_ome_zarr, + create_sample_ome_zarr, + derive_image, + 
get_http_mapper, + get_s3_mapper, + random_zarr_path, +) + +from ngio import open_ome_zarr_container + +HTTP_STORE_SUPPORTED_BACKENDS = ["anndata", "json", "csv", "parquet"] + + +def test_http_store(http_static_server: dict) -> None: + # create boto3 client pointing at moto server + url, root = http_static_server["url"], http_static_server["root"] + + zarr_path = random_zarr_path() + local_store = root / zarr_path + _ = create_sample_ome_zarr( + store=local_store, supported_backends=HTTP_STORE_SUPPORTED_BACKENDS + ) + http_mapper = get_http_mapper(url, zarr_path) + ome_zarr = open_ome_zarr_container(store=http_mapper) + check_ome_zarr(ome_zarr, supported_backends=HTTP_STORE_SUPPORTED_BACKENDS) + + +def test_http_store_derive_to_s3_store( + http_static_server: dict, moto_s3_server: dict +) -> None: + url, root = http_static_server["url"], http_static_server["root"] + zarr_path = random_zarr_path() + local_store = root / zarr_path + _ = create_sample_ome_zarr( + store=local_store, supported_backends=HTTP_STORE_SUPPORTED_BACKENDS + ) + http_mapper = get_http_mapper(url, zarr_path) + ome_zarr = open_ome_zarr_container(store=http_mapper) + other_store = get_s3_mapper( + base_url=moto_s3_server["endpoint_url"], + bucket_name=moto_s3_server["bucket_name"], + zarr_path=random_zarr_path(), + ) + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=HTTP_STORE_SUPPORTED_BACKENDS) + + +def test_http_store_derive_to_local_store( + http_static_server: dict, tmp_path: Path +) -> None: + url, root = http_static_server["url"], http_static_server["root"] + zarr_path = random_zarr_path() + local_store = root / zarr_path + _ = create_sample_ome_zarr( + store=local_store, supported_backends=HTTP_STORE_SUPPORTED_BACKENDS + ) + http_mapper = get_http_mapper(url, zarr_path) + ome_zarr = open_ome_zarr_container(store=http_mapper) + + other_store = tmp_path / "http_local_store_test" / random_zarr_path() + derived_ome_zarr = 
derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=HTTP_STORE_SUPPORTED_BACKENDS) + + +def test_http_store_derive_to_memory_store(http_static_server: dict) -> None: + url, root = http_static_server["url"], http_static_server["root"] + zarr_path = random_zarr_path() + local_store = root / zarr_path + _ = create_sample_ome_zarr( + store=local_store, supported_backends=HTTP_STORE_SUPPORTED_BACKENDS + ) + http_mapper = get_http_mapper(url, zarr_path) + ome_zarr = open_ome_zarr_container(store=http_mapper) + + other_store = {} + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=["anndata", "json"]) diff --git a/tests/stores/test_local_store.py b/tests/stores/test_local_store.py new file mode 100644 index 00000000..9575ca5d --- /dev/null +++ b/tests/stores/test_local_store.py @@ -0,0 +1,54 @@ +from pathlib import Path + +from utils import ( + check_ome_zarr, + create_sample_ome_zarr, + derive_image, + get_s3_mapper, + random_zarr_path, +) + +LOCAL_STORE_SUPPORTED_BACKENDS = ["anndata", "json", "csv", "parquet"] + + +def test_local_store(tmp_path: Path) -> None: + store_path = tmp_path / "local_store_test" / "test.zarr" + ome_zarr = create_sample_ome_zarr( + store=store_path, supported_backends=LOCAL_STORE_SUPPORTED_BACKENDS + ) + check_ome_zarr(ome_zarr, supported_backends=LOCAL_STORE_SUPPORTED_BACKENDS) + + +def test_local_store_derive_to_local_store(tmp_path: Path) -> None: + store_path = tmp_path / "local_store_test" / "test.zarr" + ome_zarr = create_sample_ome_zarr( + store=store_path, supported_backends=LOCAL_STORE_SUPPORTED_BACKENDS + ) + other_store = tmp_path / "local_store_test" / "derived_test.zarr" + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=LOCAL_STORE_SUPPORTED_BACKENDS) + + +def test_local_store_derive_to_s3_store(tmp_path: Path, moto_s3_server: dict) -> None: 
+ store_path = tmp_path / "local_store_test" / "test.zarr" + ome_zarr = create_sample_ome_zarr( + store=store_path, supported_backends=LOCAL_STORE_SUPPORTED_BACKENDS + ) + # create boto3 client pointing at moto server + other_store = get_s3_mapper( + base_url=moto_s3_server["endpoint_url"], + bucket_name=moto_s3_server["bucket_name"], + zarr_path=random_zarr_path(), + ) + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=LOCAL_STORE_SUPPORTED_BACKENDS) + + +def test_local_store_derive_to_memory_store(tmp_path: Path) -> None: + store_path = tmp_path / "local_store_test" / "test.zarr" + ome_zarr = create_sample_ome_zarr( + store=store_path, supported_backends=LOCAL_STORE_SUPPORTED_BACKENDS + ) + other_store = {} + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=["anndata", "json"]) diff --git a/tests/stores/test_memory_store.py b/tests/stores/test_memory_store.py new file mode 100644 index 00000000..fd349e7b --- /dev/null +++ b/tests/stores/test_memory_store.py @@ -0,0 +1,54 @@ +from pathlib import Path + +from utils import ( + check_ome_zarr, + create_sample_ome_zarr, + derive_image, + get_s3_mapper, + random_zarr_path, +) + +MEMORY_STORE_SUPPORTED_BACKENDS = ["anndata", "json", "csv", "parquet"] + + +def test_memory_store() -> None: + store_path = {} + ome_zarr = create_sample_ome_zarr( + store=store_path, supported_backends=MEMORY_STORE_SUPPORTED_BACKENDS + ) + check_ome_zarr(ome_zarr, supported_backends=MEMORY_STORE_SUPPORTED_BACKENDS) + + +def test_memory_store_derive_to_memory_store() -> None: + store_path = {} + ome_zarr = create_sample_ome_zarr( + store=store_path, supported_backends=MEMORY_STORE_SUPPORTED_BACKENDS + ) + other_store = {} + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=["anndata", "json"]) + + +def 
test_memory_store_derive_to_local_store(tmp_path: Path) -> None: + store_path = {} + ome_zarr = create_sample_ome_zarr( + store=store_path, supported_backends=MEMORY_STORE_SUPPORTED_BACKENDS + ) + other_store = tmp_path / "local_store_test" / "derived_test.zarr" + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=["anndata", "json"]) + + +def test_memory_store_derive_to_s3_store(moto_s3_server: dict) -> None: + store_path = {} + ome_zarr = create_sample_ome_zarr( + store=store_path, supported_backends=MEMORY_STORE_SUPPORTED_BACKENDS + ) + # create boto3 client pointing at moto server + other_store = get_s3_mapper( + base_url=moto_s3_server["endpoint_url"], + bucket_name=moto_s3_server["bucket_name"], + zarr_path=random_zarr_path(), + ) + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=["anndata", "json"]) diff --git a/tests/stores/test_s3_store.py b/tests/stores/test_s3_store.py new file mode 100644 index 00000000..36c39046 --- /dev/null +++ b/tests/stores/test_s3_store.py @@ -0,0 +1,88 @@ +# tests/stores/test_s3_store.py + +from pathlib import Path + +from utils import ( + check_ome_zarr, + create_sample_ome_zarr, + derive_image, + get_s3_mapper, + random_zarr_path, +) + +S3_STORE_SUPPORTED_BACKENDS = ["anndata", "json"] +# CSV and Parquet work locally on S3 store, but not when using moto server for testing. +# To be investigated. 
+ + +def test_s3_store(moto_s3_server: dict) -> None: + # create boto3 client pointing at moto server + endpoint_url, bucket_name = ( + moto_s3_server["endpoint_url"], + moto_s3_server["bucket_name"], + ) + store = get_s3_mapper( + endpoint_url, bucket_name=bucket_name, zarr_path=random_zarr_path() + ) + ome_zarr = create_sample_ome_zarr( + store=store, supported_backends=S3_STORE_SUPPORTED_BACKENDS + ) + check_ome_zarr(ome_zarr, supported_backends=S3_STORE_SUPPORTED_BACKENDS) + + +def test_s3_store_derive_to_s3_store(moto_s3_server: dict) -> None: + endpoint_url, bucket_name = ( + moto_s3_server["endpoint_url"], + moto_s3_server["bucket_name"], + ) + store = get_s3_mapper( + endpoint_url, + bucket_name=bucket_name, + zarr_path=random_zarr_path(), + ) + ome_zarr = create_sample_ome_zarr( + store=store, supported_backends=S3_STORE_SUPPORTED_BACKENDS + ) + other_store = get_s3_mapper( + endpoint_url, + bucket_name=bucket_name, + zarr_path=random_zarr_path(), + ) + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=S3_STORE_SUPPORTED_BACKENDS) + + +def test_s3_store_derive_to_local_store(moto_s3_server: dict, tmp_path: Path) -> None: + endpoint_url, bucket_name = ( + moto_s3_server["endpoint_url"], + moto_s3_server["bucket_name"], + ) + store = get_s3_mapper( + endpoint_url, + bucket_name=bucket_name, + zarr_path=random_zarr_path(), + ) + ome_zarr = create_sample_ome_zarr( + store=store, supported_backends=S3_STORE_SUPPORTED_BACKENDS + ) + other_store = tmp_path / "s3_local_store_test" / random_zarr_path() + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=S3_STORE_SUPPORTED_BACKENDS) + + +def test_s3_store_derive_to_memory_store(moto_s3_server: dict) -> None: + endpoint_url, bucket_name = ( + moto_s3_server["endpoint_url"], + moto_s3_server["bucket_name"], + ) + store = get_s3_mapper( + endpoint_url, + bucket_name=bucket_name, 
+ zarr_path=random_zarr_path(), + ) + ome_zarr = create_sample_ome_zarr( + store=store, supported_backends=S3_STORE_SUPPORTED_BACKENDS + ) + other_store = {} + derived_ome_zarr = derive_image(ome_zarr, other_store=other_store) + check_ome_zarr(derived_ome_zarr, supported_backends=S3_STORE_SUPPORTED_BACKENDS) diff --git a/tests/stores/utils.py b/tests/stores/utils.py new file mode 100644 index 00000000..53b27320 --- /dev/null +++ b/tests/stores/utils.py @@ -0,0 +1,168 @@ +from uuid import uuid4 + +import fsspec.implementations.http +import numpy as np +import pandas as pd +import s3fs + +from ngio import OmeZarrContainer, create_empty_ome_zarr +from ngio.tables import FeatureTable + +TEST_IMAGE = np.ones((3, 5, 64, 64), dtype=np.uint16) +TEST_LABEL = np.zeros((5, 64, 64), dtype=np.uint16) +TEST_LABEL[0, 10:30, 10:30] = 1 +TEST_LABEL[0, 35:55, 35:55] = 2 +TEST_LABEL[0, 20:40, 40:60] = 3 + +TEST_TABLE = pd.DataFrame( + { + "label": [1, 2, 3], + "area": [150, 200, 250], + "mean_intensity": [120.5, 130.0, 140.5], + } +) +TEST_TABLE.set_index("label", inplace=True) + + +def set_to_image(ome_zarr: OmeZarrContainer) -> OmeZarrContainer: + """Create and return an empty OME-Zarr container in the given store.""" + image = ome_zarr.get_image(path="1") + image.set_array(patch=TEST_IMAGE) + image.consolidate() + return ome_zarr + + +def check_image_data( + ome_zarr: OmeZarrContainer, +) -> None: + """Retrieve the array data from the first image at the specified pyramid level.""" + image = ome_zarr.get_image() + array_data = image.get_as_numpy() + np.testing.assert_array_equal(array_data, TEST_IMAGE) + + +def derive_label( + ome_zarr: OmeZarrContainer, +) -> OmeZarrContainer: + """Derive a label array from the first image in the OME-Zarr container.""" + label_array = ome_zarr.derive_label(name="labels") + label_array.set_array(patch=TEST_LABEL) + label_array.consolidate() + return ome_zarr + + +def check_label_data( + ome_zarr: OmeZarrContainer, +) -> None: + """Retrieve the 
label data from the specified label array.""" + label_array = ome_zarr.get_label(name="labels") + label_data = label_array.get_as_numpy() + np.testing.assert_array_equal(label_data, TEST_LABEL) + + +def add_table_to_ome_zarr( + ome_zarr: OmeZarrContainer, + backend: str = "anndata", +) -> OmeZarrContainer: + """Add a table to the OME-Zarr container.""" + test_feature_table = FeatureTable(table_data=TEST_TABLE, reference_label="labels") + name = f"test_table_{backend}" + ome_zarr.add_table( + name=name, + table=test_feature_table, + backend=backend, + ) + return ome_zarr + + +def check_table_data( + ome_zarr: OmeZarrContainer, + table_name: str, +) -> None: + """Retrieve the table data from the specified table in the OME-Zarr container.""" + table = ome_zarr.get_table(name=table_name) + dataframe = table.dataframe + pd.testing.assert_frame_equal( + dataframe.sort_index(axis=1), TEST_TABLE.sort_index(axis=1) + ) + + +def derive_image( + ome_zarr: OmeZarrContainer, + other_store, + copy_labels: bool = True, + copy_tables: bool = True, +) -> OmeZarrContainer: + """Derive a new image from the first image in the OME-Zarr container.""" + derived_ome_zarr = ome_zarr.derive_image( + store=other_store, + copy_labels=copy_labels, + copy_tables=copy_tables, + ) + return derived_ome_zarr + + +def create_sample_ome_zarr( + store, supported_backends: list[str] | None = None +) -> OmeZarrContainer: + """Create a sample OME-Zarr structure in the given store for testing.""" + ome_zarr = create_empty_ome_zarr( + store=store, + shape=(3, 5, 64, 64), + pixelsize=(0.65, 0.65), + channels_meta=["Channel 1", "Channel 2", "Channel 3"], + levels=3, + axes_names=["c", "z", "y", "x"], + ) + ome_zarr = set_to_image(ome_zarr) + check_image_data(ome_zarr) + ome_zarr = derive_label(ome_zarr) + if supported_backends is None: + supported_backends = ["anndata", "json", "csv", "parquet"] + + for backend in supported_backends: + ome_zarr = add_table_to_ome_zarr(ome_zarr, backend=backend) + return 
ome_zarr + + +def check_ome_zarr( + ome_zarr: OmeZarrContainer, + check_tables: bool = True, + check_labels: bool = True, + supported_backends: list[str] | None = None, +) -> None: + """Check that all tables in the OME-Zarr container match the test table.""" + if check_labels: + check_label_data(ome_zarr) + if check_tables: + if supported_backends is None: + supported_backends = ["anndata", "json", "csv", "parquet"] + + table_list = ome_zarr.list_tables() + assert len(table_list) >= len(supported_backends) + for table_name in table_list: + backend = table_name.split("_")[-1] + if backend in supported_backends: + check_table_data(ome_zarr, table_name=table_name) + + +def get_s3_mapper(base_url: str, bucket_name: str, zarr_path: str): + s3_fs = s3fs.S3FileSystem( + key="test", + secret="test", + client_kwargs={ + "endpoint_url": base_url, + "region_name": "us-east-1", + }, + ) + return s3_fs.get_mapper(f"{bucket_name}/{zarr_path}") + + +def get_http_mapper(base_url: str, zarr_path: str): + http_fs = fsspec.implementations.http.HTTPFileSystem() + return http_fs.get_mapper(f"{base_url}/{zarr_path}") + + +def random_zarr_path() -> str: + """Generate a random zarr path for testing.""" + return f"test_zarr_{uuid4()}.zarr" diff --git a/tests/unit/common/test_pyramid.py b/tests/unit/common/test_pyramid.py index 815e55fd..9329be17 100644 --- a/tests/unit/common/test_pyramid.py +++ b/tests/unit/common/test_pyramid.py @@ -22,9 +22,9 @@ def test_on_disk_zooms( tmp_path: Path, order: InterpolationOrder, mode: Literal["dask", "numpy", "coarsen"] ): source = tmp_path / "source.zarr" - source_array = zarr.open_array(source, shape=(16, 128, 128), dtype="uint8") + source_array = zarr.create_array(source, shape=(16, 128, 128), dtype="uint8") target = tmp_path / "target.zarr" - target_array = zarr.open_array(target, shape=(16, 64, 64), dtype="uint8") + target_array = zarr.create_array(target, shape=(16, 64, 64), dtype="uint8") on_disk_zoom(source_array, target_array, order=order, 
mode=mode) diff --git a/tests/unit/common/test_roi.py b/tests/unit/common/test_roi.py index 4592d21a..92e6fc03 100644 --- a/tests/unit/common/test_roi.py +++ b/tests/unit/common/test_roi.py @@ -6,43 +6,50 @@ def test_basic_rois_ops(): - roi = Roi( + roi = Roi.from_values( name="test", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - label=1, - unit="micrometer", # type: ignore + slices={ + "x": (0, 1), + "y": (0, 1), + "z": (0, 1), + }, + space="world", other="other", # type: ignore + label=1, ) - assert roi.x == 0.0 + slixe_x = roi.get("x") + assert slixe_x is not None + assert slixe_x.axis_name == "x" + assert slixe_x.start == 0 + assert slixe_x.length == 1 pixel_size = PixelSize(x=1.0, y=1.0, z=1.0) - raster_roi = roi.to_roi_pixels(pixel_size) + raster_roi = roi.to_pixel(pixel_size) assert roi.__str__() assert roi.__repr__() assert raster_roi.to_slicing_dict(pixel_size=pixel_size) == { - "x": slice(0, 1), - "y": slice(0, 1), - "z": slice(0, 1), - "t": slice(None), + "x": slice(0.0, 1.0), + "y": slice(0.0, 1.0), + "z": slice(0.0, 1.0), } assert roi.model_extra is not None assert roi.model_extra["other"] == "other" - world_roi_2 = raster_roi.to_roi(pixel_size) + world_roi_2 = raster_roi.to_world(pixel_size) + + x_slice_2 = world_roi_2.get("x") + assert x_slice_2 is not None + assert x_slice_2.axis_name == "x" + assert x_slice_2.start == 0 + assert x_slice_2.length == 1 - assert world_roi_2.x == 0.0 - assert world_roi_2.y == 0.0 - assert world_roi_2.z == 0.0 - assert world_roi_2.x_length == 1.0 - assert world_roi_2.y_length == 1.0 - assert world_roi_2.z_length == 1.0 + y_slice_2 = world_roi_2.get("y") + assert y_slice_2 is not None + assert y_slice_2.axis_name == "y" + assert y_slice_2.start == 0 + assert y_slice_2.length == 1 assert world_roi_2.other == "other" # type: ignore roi_zoomed = roi.zoom(2.0) @@ -50,21 +57,20 @@ def test_basic_rois_ops(): roi.zoom(-1.0) assert roi_zoomed.to_slicing_dict(pixel_size) == { - "x": slice(0, 2), - 
"y": slice(0, 2), - "z": slice(0, 1), - "t": slice(None), + "x": slice(0.0, 2.0), + "y": slice(0.0, 2.0), + "z": slice(0.0, 1.0), } - roi2 = Roi( + roi2 = Roi.from_values( name="test2", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={ + "x": (0.0, 1.0), + "y": (0.0, 1.0), + "z": (0.0, 1.0), + }, + space="world", + # type: ignore label=1, ) roi_i = roi.intersection(roi2) @@ -81,59 +87,54 @@ def test_basic_rois_ops(): [ ( # Basic intersection - Roi( + Roi.from_values( name="ref", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={ + "x": (0.0, 1.0), + "y": (0.0, 1.0), + "z": (0.0, 1.0), + }, + space="world", ), - Roi( + Roi.from_values( name="other", - x=0.5, - y=0.5, - z=0.5, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={ + "x": (0.5, 1.0), + "y": (0.5, 1.0), + "z": (0.5, 1.0), + }, + space="world", ), - Roi( + Roi.from_values( name="ref:other", - x=0.5, - y=0.5, - z=0.5, - x_length=0.5, - y_length=0.5, - z_length=0.5, - unit="micrometer", # type: ignore + slices={ + "x": (0.5, 0.5), + "y": (0.5, 0.5), + "z": (0.5, 0.5), + }, + space="world", ), "ref:other", ), ( # No intersection - Roi( + Roi.from_values( name="ref", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={ + "x": (0.0, 1.0), + "y": (0.0, 1.0), + "z": (0.0, 1.0), + }, + space="world", ), - Roi( + Roi.from_values( name="other", - x=2.0, - y=2.0, - z=2.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={ + "x": (2.0, 1.0), + "y": (2.0, 1.0), + "z": (2.0, 1.0), + }, + space="world", ), None, "", @@ -141,41 +142,34 @@ def test_basic_rois_ops(): ( # Intersection with z=None (expected behaves like infinite z) # t=None (expected behaves like infinite t) - Roi( + Roi.from_values( name="ref", - x=0.0, - y=0.0, - 
z=None, - t=0, - x_length=2.0, - y_length=2.0, - z_length=None, - t_length=2.0, - unit="micrometer", # type: ignore + slices={ + "x": (0.0, 1.0), + "y": (0.0, 1.0), + }, + space="world", ), - Roi( + Roi.from_values( name=None, - x=-1.0, - y=-1.0, - z=-1.0, - t=None, - x_length=2.0, - y_length=2.0, - z_length=2.0, - t_length=None, - unit="micrometer", # type: ignore + slices={ + "x": (0.5, 1.0), + "y": (0.5, 1.0), + "z": (-1.0, 2.0), + "t": (0.0, 2.0), + }, + space="world", + unit="micrometer", ), - Roi( + Roi.from_values( name="ref", - x=0.0, - y=0.0, - z=-1.0, - t=0, - x_length=1.0, - y_length=1.0, - z_length=2.0, - t_length=2.0, - unit="micrometer", # type: ignore + slices={ + "x": (0.5, 0.5), + "y": (0.5, 0.5), + "z": (-1.0, 2.0), + "t": (0.0, 2.0), + }, + space="world", ), "ref", ), @@ -193,9 +187,93 @@ def test_rois_intersection( else: assert intersection is not None assert intersection.name == expected_name - assert intersection.x == expected_intersection.x - assert intersection.y == expected_intersection.y - assert intersection.z == expected_intersection.z - assert intersection.x_length == expected_intersection.x_length - assert intersection.y_length == expected_intersection.y_length - assert intersection.z_length == expected_intersection.z_length + assert intersection.get("x") == expected_intersection.get("x") + assert intersection.get("y") == expected_intersection.get("y") + assert intersection.get("z") == expected_intersection.get("z") + + assert intersection.get("t") == expected_intersection.get("t") + + +@pytest.mark.parametrize( + "roi_ref,roi_other,expected_union,expected_name", + [ + ( + # Basic intersection + Roi.from_values( + name="ref", + slices={ + "x": (0.0, 1.0), + "y": (0.0, 1.0), + "z": (0.0, 1.0), + }, + space="world", + ), + Roi.from_values( + name="other", + slices={ + "x": (0.5, 1.0), + "y": (0.5, 1.0), + "z": (0.5, 1.0), + }, + space="world", + ), + Roi.from_values( + name="ref:other", + slices={ + "x": (0.0, 1.5), + "y": (0.0, 1.5), + 
"z": (0.0, 1.5), + }, + space="world", + ), + "ref:other", + ), + ( + # Intersection with z=None (expected behaves like infinite z) + # t=None (expected behaves like infinite t) + Roi.from_values( + name="ref", + slices={ + "x": (0.0, 1.0), + "y": (0.0, 1.0), + }, + space="world", + ), + Roi.from_values( + name=None, + slices={ + "x": (0.5, 1.0), + "y": (0.5, 1.0), + "z": (-1.0, 2.0), + "t": (0.0, 2.0), + }, + space="world", + unit="micrometer", + ), + Roi.from_values( + name="ref", + slices={ + "x": (0.0, 1.5), + "y": (0.0, 1.5), + "z": (-1.0, 2.0), + "t": (0.0, 2.0), + }, + space="world", + ), + "ref", + ), + ], +) +def test_rois_union( + roi_ref: Roi, + roi_other: Roi, + expected_union: Roi, + expected_name: str, +): + union = roi_ref.union(roi_other) + assert union is not None + assert union.name == expected_name + assert union.get("x") == expected_union.get("x") + assert union.get("y") == expected_union.get("y") + assert union.get("z") == expected_union.get("z") + assert union.get("t") == expected_union.get("t") diff --git a/tests/unit/common/test_transforms.py b/tests/unit/common/test_transforms.py index f519215a..4554a4ae 100644 --- a/tests/unit/common/test_transforms.py +++ b/tests/unit/common/test_transforms.py @@ -24,7 +24,7 @@ def test_zoom_from_dimensions(tmp_path: Path): target_image=full_res_img, order="nearest", ) - roi = Roi(name=None, x=0, y=0, x_length=21, y_length=21) + roi = Roi.from_values(name=None, slices={"x": (0, 21), "y": (0, 21)}) full_res_data = full_res_img.get_roi_as_numpy(roi=roi) rescaled_data = img.get_roi_as_numpy(roi=roi, transforms=[zoom]) @@ -34,7 +34,7 @@ def test_zoom_from_dimensions(tmp_path: Path): except Exception as e: raise AssertionError(f"Failed inbound test: {e}") from e - roi = Roi(name=None, x=80, y=80, x_length=21, y_length=21) + roi = Roi.from_values(name=None, slices={"x": (80, 21), "y": (80, 21)}) full_res_data = full_res_img.get_roi_as_numpy(roi=roi) rescaled_data = img.get_roi_as_numpy(roi=roi, 
transforms=[zoom]) diff --git a/tests/unit/hcs/test_plate.py b/tests/unit/hcs/test_plate.py index 4caaa3e1..ad5324ea 100644 --- a/tests/unit/hcs/test_plate.py +++ b/tests/unit/hcs/test_plate.py @@ -1,5 +1,6 @@ import asyncio from pathlib import Path +from typing import Literal import pandas.testing as pdt import pytest @@ -59,8 +60,11 @@ def test_open_real_ome_zarr_plate(cardiomyocyte_tiny_path: Path): assert well.get_image_acquisition_id("0") is None -def test_create_and_edit_plate(tmp_path: Path): - test_plate = create_empty_plate(tmp_path / "test_plate.zarr", name="test_plate") +@pytest.mark.parametrize("ngff_version", ["0.4", "0.5"]) +def test_create_and_edit_plate(tmp_path: Path, ngff_version: Literal["0.4", "0.5"]): + test_plate = create_empty_plate( + tmp_path / "test_plate.zarr", name="test_plate", ngff_version=ngff_version + ) assert test_plate.columns == [] assert test_plate.rows == [] assert test_plate.acquisition_ids == [] @@ -168,7 +172,11 @@ def test_tables_api(tmp_path: Path): test_plate.add_table("test_table", test_table, backend="csv") test_roi_table = RoiTable( - rois=[Roi(name="roi_1", x_length=10, y_length=10, z_length=10)] # type: ignore + rois=[ + Roi.from_values( + name="roi_1", slices={"x": (0, 10), "y": (0, 10), "z": (0, 10)} + ) + ] # type: ignore ) test_plate.add_table("test_roi_table", test_roi_table) assert test_plate.list_tables() == ["test_table", "test_roi_table"] @@ -179,6 +187,18 @@ def test_tables_api(tmp_path: Path): test_df, check_names=False, ) + test_plate.delete_table("test_table") + assert "test_table" not in test_plate.list_tables() + test_plate.delete_table("test_table", missing_ok=True) + with pytest.raises(NgioValueError): + test_plate.delete_table("test_table", missing_ok=False) + + test_plate = create_empty_plate( + tmp_path / "test_plate.zarr", name="test_plate", overwrite=True + ) + with pytest.raises(NgioValueError): + test_plate.delete_table("non_existing_table") + test_plate.delete_table("non_existing_table", 
missing_ok=True) @pytest.mark.filterwarnings("ignore::anndata._warnings.ImplicitModificationWarning") diff --git a/tests/unit/images/test_create.py b/tests/unit/images/test_create.py index 26c71977..80be8d88 100644 --- a/tests/unit/images/test_create.py +++ b/tests/unit/images/test_create.py @@ -2,12 +2,17 @@ import numpy as np import pytest +from pydantic import ValidationError +from zarr.storage import MemoryStore from ngio import ( + OmeZarrContainer, create_empty_ome_zarr, create_ome_zarr_from_array, create_synthetic_ome_zarr, ) +from ngio.images._create_utils import init_image_like_from_shapes +from ngio.ome_zarr_meta import NgioImageMeta from ngio.utils import NgioValueError @@ -128,7 +133,7 @@ def test_create_fail(tmp_path: Path): overwrite=True, ) - with pytest.raises(NgioValueError): + with pytest.raises(ValidationError): create_ome_zarr_from_array( array=np.random.randint(0, 255, (2, 64, 64), dtype="uint8"), store=tmp_path / "fail.zarr", @@ -138,3 +143,116 @@ def test_create_fail(tmp_path: Path): chunks=(1, 64, 64, 64), # should fail expected 3 axes overwrite=True, ) + + +def test_derive_label_channels_policy(): + store = MemoryStore() + ome_zarr = create_synthetic_ome_zarr(store, shape=(3, 1, 64, 65)) + + label = ome_zarr.derive_label("test-label-singleton", channels_policy="singleton") + assert label.dimensions.get("c") == 1 + label = ome_zarr.derive_label("test-label-same", channels_policy="same") + assert label.dimensions.get("c") == 3 + label = ome_zarr.derive_label("test-label-squeeze", channels_policy="squeeze") + assert "c" not in label.axes + assert label.dimensions.get("c") is None + label = ome_zarr.derive_label("test-label-int", channels_policy=2) + assert label.dimensions.get("c") == 2 + + +def test_derive_from_non_dishogeneus_shapes(): + # Yes those shapes are intentionally weird + shapes = [ + (4, 3, 64, 65), + (4, 3, 64, 50), + (4, 3, 32, 25), + ] + store = MemoryStore() + image_handler = init_image_like_from_shapes( + store=store, + 
meta_type=NgioImageMeta, + shapes=shapes, + pixelsize=0.5, + ) + ome_zarr = OmeZarrContainer(group_handler=image_handler) + ome_zarr.derive_label("test-label-same", channels_policy="same") + for path in ome_zarr.levels_paths: + img = ome_zarr.get_image(path=path) + lbl = ome_zarr.get_label(name="test-label-same", path=path) + assert img.shape == lbl.shape + + image = ome_zarr.get_image(path="1") + ome_zarr.derive_label("test-label-level-1", ref_image=image, channels_policy="same") + + for path_img, path_lbl in zip(["1", "2"], ["0", "1"], strict=True): + img = ome_zarr.get_image(path=path_img) + lbl = ome_zarr.get_label(name="test-label-level-1", path=path_lbl) + assert img.shape == lbl.shape + + lbl = ome_zarr.get_label(name="test-label-level-1", path="2") + scaling_factor = tuple(s1 / s2 for s1, s2 in zip(shapes[0], shapes[1], strict=True)) + assert image.meta.scaling_factor() == scaling_factor + assert lbl.shape == (4, 3, 32, 19) + + +def test_create_with_sharding(tmp_path: Path): + store = tmp_path / "test_image_sharded.zarr" + ome_zarr = create_empty_ome_zarr( + store=store, + shape=(4, 3, 64, 64), + pixelsize=0.5, + chunks=(2, 1, 32, 32), + shards=(4, 3, 64, 64), + dtype="uint8", + levels=3, + overwrite=True, + ngff_version="0.5", + ) + ome_zarr.derive_label("label_sharded") + img = ome_zarr.get_image(path="0") + assert img.zarr_array.shards is not None + assert img.zarr_array.chunks == (2, 1, 32, 32) + assert img.zarr_array.shards == (4, 3, 64, 64) + + img = ome_zarr.get_image(path="2") + assert img.zarr_array.shards is not None + # Check clipping of chunks/shards at the smallest level + assert img.zarr_array.chunks == (2, 1, 16, 16) + assert img.zarr_array.shards == (4, 3, 16, 16) + + +def test_fail_derive_singleton(): + store = MemoryStore() + ome_zarr = create_empty_ome_zarr(store=store, shape=(1, 1, 64, 4), pixelsize=0.5) + expected_shapes = [ + (1, 1, 64, 4), + (1, 1, 32, 2), + (1, 1, 16, 1), + (1, 1, 8, 1), + (1, 1, 4, 1), + ] + expected_pixel_size_x = 
[0.5, 1.0, 2.0, 2.0, 2.0] + for path, shape, px_x in zip( + ome_zarr.levels_paths, expected_shapes, expected_pixel_size_x, strict=True + ): + img = ome_zarr.get_image(path=path) + assert img.shape == shape + assert img.pixel_size.x == px_x + + +def test_fail_create_from_non_decreasing_shapes(): + # Yes those shapes are intentionally weird + shapes = [ + (4, 3, 64, 64), + (4, 3, 128, 128), + (4, 3, 256, 256), + ] + store = MemoryStore() + + with pytest.raises(NgioValueError): + _ = init_image_like_from_shapes( + store=store, + meta_type=NgioImageMeta, + shapes=shapes, + pixelsize=0.5, + ) diff --git a/tests/unit/images/test_images.py b/tests/unit/images/test_images.py index 50988c9c..39f5c999 100644 --- a/tests/unit/images/test_images.py +++ b/tests/unit/images/test_images.py @@ -10,19 +10,28 @@ @pytest.mark.parametrize( "zarr_name", [ - "test_image_yx.zarr", - "test_image_cyx.zarr", - "test_image_zyx.zarr", - "test_image_czyx.zarr", - "test_image_c1yx.zarr", - "test_image_tyx.zarr", - "test_image_tcyx.zarr", - "test_image_tzyx.zarr", - "test_image_tczyx.zarr", + "v04/test_image_yx.zarr", + "v04/test_image_cyx.zarr", + "v04/test_image_zyx.zarr", + "v04/test_image_czyx.zarr", + "v04/test_image_c1yx.zarr", + "v04/test_image_tyx.zarr", + "v04/test_image_tcyx.zarr", + "v04/test_image_tzyx.zarr", + "v04/test_image_tczyx.zarr", + "v05/test_image_yx.zarr", + "v05/test_image_cyx.zarr", + "v05/test_image_zyx.zarr", + "v05/test_image_czyx.zarr", + "v05/test_image_c1yx.zarr", + "v05/test_image_tyx.zarr", + "v05/test_image_tcyx.zarr", + "v05/test_image_tzyx.zarr", + "v05/test_image_tczyx.zarr", ], ) -def test_open_image(images_v04: dict[str, Path], zarr_name: str): - path = images_v04[zarr_name] +def test_open_image(images_all_versions: dict[str, Path], zarr_name: str): + path = images_all_versions[zarr_name] image = open_image(path) assert isinstance(image, Image) @@ -216,6 +225,7 @@ def test_zoom_virtual_axes( assert img2_data.shape[0] == 1 # Virtual channel axis roi = 
img1.build_image_roi_table().rois()[0] + print(roi) img2_roi_data = img2.get_roi_as_numpy(roi, transforms=[zoom], axes_order="czyx") # Roi data should match exactly except for virtual axis assert img1_data.shape[1:] == img2_roi_data.shape[1:] diff --git a/tests/unit/images/test_omezarr_container.py b/tests/unit/images/test_omezarr_container.py index 3ca9eeaf..bec2e4ec 100644 --- a/tests/unit/images/test_omezarr_container.py +++ b/tests/unit/images/test_omezarr_container.py @@ -4,11 +4,15 @@ import numpy as np import pytest -from ngio import create_empty_ome_zarr, open_ome_zarr_container +from ngio import ( + create_empty_ome_zarr, + create_synthetic_ome_zarr, + open_ome_zarr_container, +) from ngio.images._image import ChannelSelectionModel from ngio.io_pipes._ops_axes import AxesOps from ngio.io_pipes._ops_slices import SlicingOps -from ngio.utils import fractal_fsspec_store +from ngio.utils import NgioValueError, fractal_fsspec_store class IdentityTransform: @@ -40,19 +44,28 @@ def set_as_dask_transform( @pytest.mark.parametrize( "zarr_name", [ - "test_image_yx.zarr", - "test_image_cyx.zarr", - "test_image_zyx.zarr", - "test_image_czyx.zarr", - "test_image_c1yx.zarr", - "test_image_tyx.zarr", - "test_image_tcyx.zarr", - "test_image_tzyx.zarr", - "test_image_tczyx.zarr", + "v04/test_image_yx.zarr", + "v04/test_image_cyx.zarr", + "v04/test_image_zyx.zarr", + "v04/test_image_czyx.zarr", + "v04/test_image_c1yx.zarr", + "v04/test_image_tyx.zarr", + "v04/test_image_tcyx.zarr", + "v04/test_image_tzyx.zarr", + "v04/test_image_tczyx.zarr", + "v05/test_image_yx.zarr", + "v05/test_image_cyx.zarr", + "v05/test_image_zyx.zarr", + "v05/test_image_czyx.zarr", + "v05/test_image_c1yx.zarr", + "v05/test_image_tyx.zarr", + "v05/test_image_tcyx.zarr", + "v05/test_image_tzyx.zarr", + "v05/test_image_tczyx.zarr", ], ) -def test_open_ome_zarr_container(images_v04: dict[str, Path], zarr_name: str): - path = images_v04[zarr_name] +def test_open_ome_zarr_container(images_all_versions: 
dict[str, Path], zarr_name: str): + path = images_all_versions[zarr_name] ome_zarr = open_ome_zarr_container(path) whole_image_roi = ome_zarr.build_image_roi_table().get("image") @@ -225,7 +238,7 @@ def test_remote_ome_zarr_container(): # ] _ = ome_zarr.get_label("nuclei", path="0") - _ = ome_zarr.get_table("well_ROI_table") + _ = ome_zarr.get_table("well_ROI_table").dataframe def test_get_and_squeeze(tmp_path: Path): @@ -311,3 +324,55 @@ def test_derive_image_and_labels(tmp_path: Path): ) derived_ome_zarr = ome_zarr.derive_image(tmp_path / "derived.zarr") _ = derived_ome_zarr.derive_label("derived_label") + + +def test_derive_copy_labels_and_tables(tmp_path: Path): + # Testing for #116 + store = tmp_path / "ome_zarr.zarr" + ome_zarr = create_synthetic_ome_zarr( + store, + shape=(3, 20, 30), + levels=1, + axes_names=["c", "y", "x"], + ) + derived_ome_zarr = ome_zarr.derive_image( + tmp_path / "derived.zarr", copy_labels=True, copy_tables=True + ) + assert ome_zarr.list_labels() == derived_ome_zarr.list_labels() + assert ome_zarr.list_tables() == derived_ome_zarr.list_tables() + + +def test_delete_label_and_table(tmp_path: Path): + store = tmp_path / "ome_zarr.zarr" + ome_zarr = create_synthetic_ome_zarr( + store, + shape=(3, 20, 30), + levels=1, + axes_names=["c", "y", "x"], + ) + ome_zarr.derive_label("label_to_delete") + assert "label_to_delete" in ome_zarr.list_labels() + ome_zarr.delete_label("label_to_delete") + assert "label_to_delete" not in ome_zarr.list_labels() + ome_zarr.delete_label("label_to_delete", missing_ok=True) + with pytest.raises(NgioValueError): + ome_zarr.delete_label("label_to_delete", missing_ok=False) + + new_table = ome_zarr.build_image_roi_table() + ome_zarr.add_table("table_to_delete", new_table) + assert "table_to_delete" in ome_zarr.list_tables() + ome_zarr.delete_table("table_to_delete") + assert "table_to_delete" not in ome_zarr.list_tables() + ome_zarr.delete_table("table_to_delete", missing_ok=True) + with 
pytest.raises(NgioValueError): + ome_zarr.delete_table("table_to_delete", missing_ok=False) + + ome_zarr = create_empty_ome_zarr( + store, shape=(3, 20, 30), pixelsize=0.5, overwrite=True + ) + with pytest.raises(NgioValueError): + ome_zarr.delete_label("non_existing_label") + ome_zarr.delete_label("non_existing_label", missing_ok=True) + with pytest.raises(NgioValueError): + ome_zarr.delete_table("non_existing_table") + ome_zarr.delete_table("non_existing_table", missing_ok=True) diff --git a/tests/unit/iterators/test_iterators.py b/tests/unit/iterators/test_iterators.py index 9c624c7a..c21e8bb4 100644 --- a/tests/unit/iterators/test_iterators.py +++ b/tests/unit/iterators/test_iterators.py @@ -18,20 +18,29 @@ @pytest.mark.parametrize( "zarr_name", [ - "test_image_yx.zarr", - "test_image_cyx.zarr", - "test_image_zyx.zarr", - "test_image_czyx.zarr", - "test_image_c1yx.zarr", - "test_image_tyx.zarr", - "test_image_tcyx.zarr", - "test_image_tzyx.zarr", - "test_image_tczyx.zarr", + "v04/test_image_yx.zarr", + "v04/test_image_cyx.zarr", + "v04/test_image_zyx.zarr", + "v04/test_image_czyx.zarr", + "v04/test_image_c1yx.zarr", + "v04/test_image_tyx.zarr", + "v04/test_image_tcyx.zarr", + "v04/test_image_tzyx.zarr", + "v04/test_image_tczyx.zarr", + "v05/test_image_yx.zarr", + "v05/test_image_cyx.zarr", + "v05/test_image_zyx.zarr", + "v05/test_image_czyx.zarr", + "v05/test_image_c1yx.zarr", + "v05/test_image_tyx.zarr", + "v05/test_image_tcyx.zarr", + "v05/test_image_tzyx.zarr", + "v05/test_image_tczyx.zarr", ], ) -def test_segmentation_iterator(images_v04: dict[str, Path], zarr_name: str): +def test_segmentation_iterator(images_all_versions: dict[str, Path], zarr_name: str): # Base test only the API, not the actual segmentation logic - path = images_v04[zarr_name] + path = images_all_versions[zarr_name] ome_zarr = open_ome_zarr_container(path) image = ome_zarr.get_image() label = ome_zarr.get_label("label") @@ -69,20 +78,31 @@ def test_segmentation_iterator(images_v04: 
dict[str, Path], zarr_name: str): @pytest.mark.parametrize( "zarr_name", [ - "test_image_yx.zarr", - "test_image_cyx.zarr", - "test_image_zyx.zarr", - "test_image_czyx.zarr", - "test_image_c1yx.zarr", - "test_image_tyx.zarr", - "test_image_tcyx.zarr", - "test_image_tzyx.zarr", - "test_image_tczyx.zarr", + "v04/test_image_yx.zarr", + "v04/test_image_cyx.zarr", + "v04/test_image_zyx.zarr", + "v04/test_image_czyx.zarr", + "v04/test_image_c1yx.zarr", + "v04/test_image_tyx.zarr", + "v04/test_image_tcyx.zarr", + "v04/test_image_tzyx.zarr", + "v04/test_image_tczyx.zarr", + "v05/test_image_yx.zarr", + "v05/test_image_cyx.zarr", + "v05/test_image_zyx.zarr", + "v05/test_image_czyx.zarr", + "v05/test_image_c1yx.zarr", + "v05/test_image_tyx.zarr", + "v05/test_image_tcyx.zarr", + "v05/test_image_tzyx.zarr", + "v05/test_image_tczyx.zarr", ], ) -def test_masked_segmentation_iterator(images_v04: dict[str, Path], zarr_name: str): +def test_masked_segmentation_iterator( + images_all_versions: dict[str, Path], zarr_name: str +): # Base test only the API, not the actual segmentation logic - path = images_v04[zarr_name] + path = images_all_versions[zarr_name] ome_zarr = open_ome_zarr_container(path) masked_label = ome_zarr.derive_label("masking_label") @@ -117,20 +137,29 @@ def test_masked_segmentation_iterator(images_v04: dict[str, Path], zarr_name: st @pytest.mark.parametrize( "zarr_name", [ - "test_image_yx.zarr", - "test_image_cyx.zarr", - "test_image_zyx.zarr", - "test_image_czyx.zarr", - "test_image_c1yx.zarr", - "test_image_tyx.zarr", - "test_image_tcyx.zarr", - "test_image_tzyx.zarr", - "test_image_tczyx.zarr", + "v04/test_image_yx.zarr", + "v04/test_image_cyx.zarr", + "v04/test_image_zyx.zarr", + "v04/test_image_czyx.zarr", + "v04/test_image_c1yx.zarr", + "v04/test_image_tyx.zarr", + "v04/test_image_tcyx.zarr", + "v04/test_image_tzyx.zarr", + "v04/test_image_tczyx.zarr", + "v05/test_image_yx.zarr", + "v05/test_image_cyx.zarr", + "v05/test_image_zyx.zarr", + 
"v05/test_image_czyx.zarr", + "v05/test_image_c1yx.zarr", + "v05/test_image_tyx.zarr", + "v05/test_image_tcyx.zarr", + "v05/test_image_tzyx.zarr", + "v05/test_image_tczyx.zarr", ], ) -def test_img_processing_iterator(images_v04: dict[str, Path], zarr_name: str): +def test_img_processing_iterator(images_all_versions: dict[str, Path], zarr_name: str): # Base test only the API, not the actual segmentation logic - path = images_v04[zarr_name] + path = images_all_versions[zarr_name] ome_zarr = open_ome_zarr_container(path) image = ome_zarr.get_image() t_ome_zarr = ome_zarr.derive_image(store=MemoryStore()) @@ -163,20 +192,29 @@ def test_img_processing_iterator(images_v04: dict[str, Path], zarr_name: str): @pytest.mark.parametrize( "zarr_name", [ - "test_image_yx.zarr", - "test_image_cyx.zarr", - "test_image_zyx.zarr", - "test_image_czyx.zarr", - "test_image_c1yx.zarr", - "test_image_tyx.zarr", - "test_image_tcyx.zarr", - "test_image_tzyx.zarr", - "test_image_tczyx.zarr", + "v04/test_image_yx.zarr", + "v04/test_image_cyx.zarr", + "v04/test_image_zyx.zarr", + "v04/test_image_czyx.zarr", + "v04/test_image_c1yx.zarr", + "v04/test_image_tyx.zarr", + "v04/test_image_tcyx.zarr", + "v04/test_image_tzyx.zarr", + "v04/test_image_tczyx.zarr", + "v05/test_image_yx.zarr", + "v05/test_image_cyx.zarr", + "v05/test_image_zyx.zarr", + "v05/test_image_czyx.zarr", + "v05/test_image_c1yx.zarr", + "v05/test_image_tyx.zarr", + "v05/test_image_tcyx.zarr", + "v05/test_image_tzyx.zarr", + "v05/test_image_tczyx.zarr", ], ) -def test_features_iterator(images_v04: dict[str, Path], zarr_name: str): +def test_features_iterator(images_all_versions: dict[str, Path], zarr_name: str): # Base test only the API, not the actual segmentation logic - path = images_v04[zarr_name] + path = images_all_versions[zarr_name] ome_zarr = open_ome_zarr_container(path) image = ome_zarr.get_image() diff --git a/tests/unit/ome_zarr_meta/test_image_handler.py b/tests/unit/ome_zarr_meta/test_image_handler.py index 
beaef87e..d1f53c6e 100644 --- a/tests/unit/ome_zarr_meta/test_image_handler.py +++ b/tests/unit/ome_zarr_meta/test_image_handler.py @@ -1,6 +1,6 @@ from pathlib import Path -from ngio.ome_zarr_meta import NgioImageMeta, find_image_meta_handler +from ngio.ome_zarr_meta import ImageMetaHandler, NgioImageMeta from ngio.utils import ZarrGroupHandler @@ -9,7 +9,7 @@ def test_get_image_handler(cardiomyocyte_tiny_path: Path): # The pooch cache is giving us trouble here cardiomyocyte_tiny_path = cardiomyocyte_tiny_path / "B" / "03" / "0" group_handler = ZarrGroupHandler(cardiomyocyte_tiny_path) - handler = find_image_meta_handler(group_handler) - meta = handler.safe_load_meta() + handler = ImageMetaHandler(group_handler) + meta = handler.get_meta() assert isinstance(meta, NgioImageMeta) - handler.write_meta(meta) + handler.update_meta(meta) diff --git a/tests/unit/ome_zarr_meta/test_unit_ngio_specs.py b/tests/unit/ome_zarr_meta/test_unit_ngio_specs.py index fca7a27b..dfeede79 100644 --- a/tests/unit/ome_zarr_meta/test_unit_ngio_specs.py +++ b/tests/unit/ome_zarr_meta/test_unit_ngio_specs.py @@ -192,9 +192,9 @@ def test_dataset(): ) assert ds.path == "0" - assert ds.get_scale("x") == 0.5 assert ds.axes_handler.get_index("x") == 4 - assert ds.get_translation("x") == 0.0 + assert ds.scale == tuple(scale) + assert ds.translation == tuple(translation) ps = ds.pixel_size assert ps.x == 0.5 diff --git a/tests/unit/ome_zarr_meta/test_unit_v04_utils.py b/tests/unit/ome_zarr_meta/test_unit_v04_utils.py index 820f7afd..128bbc29 100644 --- a/tests/unit/ome_zarr_meta/test_unit_v04_utils.py +++ b/tests/unit/ome_zarr_meta/test_unit_v04_utils.py @@ -5,9 +5,7 @@ from ome_zarr_models.v04.well import WellAttrs as WellAttrsV04 from ngio.ome_zarr_meta import NgioImageMeta, NgioLabelMeta, NgioWellMeta -from ngio.ome_zarr_meta.v04._v04_spec_utils import ( - _is_v04_image_meta, - _is_v04_label_meta, +from ngio.ome_zarr_meta.v04._v04_spec import ( ngio_to_v04_image_meta, ngio_to_v04_label_meta, 
ngio_to_v04_well_meta, @@ -22,9 +20,7 @@ def test_image_round_trip(): with open(path) as f: input_metadata = json.load(f) - assert _is_v04_image_meta(input_metadata) - is_valid, ngio_image = v04_to_ngio_image_meta(input_metadata) - assert is_valid + ngio_image = v04_to_ngio_image_meta(input_metadata) assert isinstance(ngio_image, NgioImageMeta) output_metadata = ngio_to_v04_image_meta(ngio_image) assert ImageAttrsV04(**output_metadata) == ImageAttrsV04(**input_metadata) @@ -35,10 +31,7 @@ def test_label_round_trip(): with open(path) as f: metadata = json.load(f) - assert _is_v04_label_meta(metadata) - - is_valid, ngio_label = v04_to_ngio_label_meta(metadata) - assert is_valid + ngio_label = v04_to_ngio_label_meta(metadata) assert isinstance(ngio_label, NgioLabelMeta) output_metadata = ngio_to_v04_label_meta(ngio_label) assert LabelAttrsV04(**output_metadata) == LabelAttrsV04(**metadata) @@ -49,8 +42,7 @@ def test_well_meta(): with open(path) as f: metadata = json.load(f) - is_valid, ngio_well = v04_to_ngio_well_meta(metadata) - assert is_valid + ngio_well = v04_to_ngio_well_meta(metadata) assert isinstance(ngio_well, NgioWellMeta) output_metadata = ngio_to_v04_well_meta(ngio_well) assert isinstance(output_metadata, dict) @@ -62,8 +54,7 @@ def test_well_meta_path_normalization(): with open(path) as f: metadata = json.load(f) - is_valid, ngio_well = v04_to_ngio_well_meta(metadata) - assert is_valid + ngio_well = v04_to_ngio_well_meta(metadata) assert isinstance(ngio_well, NgioWellMeta) output_metadata = ngio_to_v04_well_meta(ngio_well) assert isinstance(output_metadata, dict) diff --git a/tests/unit/tables/test_backends.py b/tests/unit/tables/test_backends.py index 71dc3690..8cf30168 100644 --- a/tests/unit/tables/test_backends.py +++ b/tests/unit/tables/test_backends.py @@ -106,11 +106,13 @@ def test_csv_backend(tmp_path: Path): assert backend.implements_pandas() test_table = pd.DataFrame( - {"a": [1, 2, 3], "b": [4.0, 5.0, 6.0], "c": ["a", "b", "c"]} + {"a": [1, 2, 
3], "b": [4.1, 5.1, 6.1], "c": ["a", "b", "c"]} ) backend.write(test_table, metadata={"test": "test"}) loaded_table = backend.load_as_pandas_df() + print(loaded_table) + print(test_table) assert loaded_table.equals(test_table), loaded_table meta = backend._group_handler.load_attrs() assert meta["test"] == "test" @@ -155,7 +157,7 @@ def test_parquet_backend(tmp_path: Path): def test_anndata_backend(tmp_path: Path): store = tmp_path / "test_anndata_backend.zarr" - handler = ZarrGroupHandler(store=store, cache=True, mode="a") + handler = ZarrGroupHandler(store=store, cache=True, mode="a", zarr_format=2) backend = AnnDataBackend() backend.set_group_handler(handler, index_type="int") diff --git a/tests/unit/tables/test_masking_roi_table_v1.py b/tests/unit/tables/test_masking_roi_table_v1.py index 930cdc94..d5456db4 100644 --- a/tests/unit/tables/test_masking_roi_table_v1.py +++ b/tests/unit/tables/test_masking_roi_table_v1.py @@ -9,15 +9,9 @@ def test_masking_roi_table_v1(tmp_path: Path): rois = { - 1: Roi( + 1: Roi.from_values( name="1", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={"x": slice(0, 10), "y": slice(0, 10), "z": slice(0, 5)}, ) } @@ -28,29 +22,17 @@ def test_masking_roi_table_v1(tmp_path: Path): assert table.meta.region.path == "../labels/label" table.add( - roi=Roi( + roi=Roi.from_values( name="2", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={"x": slice(0, 10), "y": slice(0, 10), "z": slice(0, 5)}, ) ) with pytest.raises(NgioValueError): table.add( - roi=Roi( + roi=Roi.from_values( name="2", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={"x": slice(0, 10), "y": slice(0, 10), "z": slice(0, 5)}, ) ) diff --git a/tests/unit/tables/test_roi_table_v1.py b/tests/unit/tables/test_roi_table_v1.py index 4cee744f..f100dc6b 100644 --- 
a/tests/unit/tables/test_roi_table_v1.py +++ b/tests/unit/tables/test_roi_table_v1.py @@ -14,15 +14,9 @@ def test_roi_table_v1(tmp_path: Path): rois = [ - Roi( + Roi.from_values( name="roi1", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={"x": slice(0, 1), "y": slice(0, 1), "z": slice(0, 1)}, ) ] @@ -30,43 +24,25 @@ def test_roi_table_v1(tmp_path: Path): assert isinstance(table.__repr__(), str) table.add( - roi=Roi( + roi=Roi.from_values( name="roi2", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={"x": slice(0, 1), "y": slice(0, 1), "z": slice(0, 1)}, ) ) with pytest.raises(NgioValueError): # ROI name already exists table.add( - roi=Roi( + roi=Roi.from_values( name="roi2", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={"x": slice(0, 1), "y": slice(0, 1), "z": slice(0, 1)}, ) ) table.add( - roi=Roi( + roi=Roi.from_values( name="roi2", - x=0.0, - y=0.0, - z=0.0, - x_length=1.0, - y_length=1.0, - z_length=1.0, - unit="micrometer", # type: ignore + slices={"x": slice(0, 1), "y": slice(0, 1), "z": slice(0, 1)}, ), overwrite=True, ) diff --git a/tests/unit/utils/test_zarr_utils.py b/tests/unit/utils/test_zarr_utils.py index 772d5937..34893368 100644 --- a/tests/unit/utils/test_zarr_utils.py +++ b/tests/unit/utils/test_zarr_utils.py @@ -7,6 +7,7 @@ import numpy as np import pytest import zarr +from zarr.storage import LocalStore from ngio.utils import ( NgioFileExistsError, @@ -23,8 +24,8 @@ def test_group_handler_creation(tmp_path: Path, cache: bool): handler = ZarrGroupHandler(store=store, cache=cache, mode="a") _store = handler.group.store - assert isinstance(_store, zarr.DirectoryStore) - assert Path(_store.path) == store + assert isinstance(_store, LocalStore) + assert Path(_store.root.as_posix()) == store assert handler.use_cache == cache attrs = 
handler.load_attrs() @@ -32,10 +33,7 @@ def test_group_handler_creation(tmp_path: Path, cache: bool): attrs = {"a": 1, "b": 2, "c": 3} handler.write_attrs(attrs) assert handler.load_attrs() == attrs - if cache: - assert handler.get_from_cache("attrs") == attrs handler.clean_cache() - assert handler.get_from_cache("attrs") is None handler.write_attrs({"a": 2}, overwrite=False) assert handler.load_attrs()["a"] == 2 @@ -53,6 +51,11 @@ def test_group_handler_creation(tmp_path: Path, cache: bool): with pytest.raises(NgioFileExistsError): handler.create_group("new_group", overwrite=False) + # Delete the group + handler.delete_group("new_group") + with pytest.raises(NgioFileNotFoundError): + handler.get_group("new_group") + def test_group_handler_from_group(tmp_path: Path): store = tmp_path / "test_group_handler_from_group.zarr" @@ -62,22 +65,45 @@ def test_group_handler_from_group(tmp_path: Path): assert handler.group == group +def test_group_handler_delete(tmp_path: Path): + store = tmp_path / "test_group_handler_from_group.zarr" + group = zarr.group(store=store, overwrite=True) + group.create_group("to_be_deleted") + handler = ZarrGroupHandler(store=group, cache=True, mode="a") + assert isinstance(handler.get_group("to_be_deleted"), zarr.Group) + handler.delete_group("to_be_deleted") + with pytest.raises(NgioFileNotFoundError): + handler.get_group("to_be_deleted") + assert store.exists() + handler.delete_self() + assert not store.exists() + + store = tmp_path / "test_group_handler_from_group.zarr" + group = zarr.group(store=store, overwrite=True) + group.create_group("to_be_deleted") + handler = ZarrGroupHandler(store=group, cache=True, mode="r") + with pytest.raises(NgioValueError): + handler.delete_group("to_be_deleted") + with pytest.raises(NgioValueError): + handler.delete_self() + + def test_group_handler_read(tmp_path: Path): store = tmp_path / "test_group_handler_read.zarr" - group = zarr.group(store=store, overwrite=True) + group = 
zarr.create_group(store=store, overwrite=True) input_attrs = {"a": 1, "b": 2, "c": 3} group.attrs.update(input_attrs) group.create_group("group1") - group.create_dataset("array1", shape=(10, 10), dtype="int32") + group.create_array("array1", shape=(10, 10), dtype="int32") handler = ZarrGroupHandler(store=store, cache=True, mode="r") assert handler.load_attrs() == input_attrs assert isinstance(handler.get_array("array1"), zarr.Array) assert isinstance(handler.get_group("group1"), zarr.Group) - assert handler.mode == "r" + assert handler.read_only with pytest.raises(NgioFileNotFoundError): handler.get_array("array2") @@ -97,10 +123,10 @@ def test_group_handler_read(tmp_path: Path): def test_open_fail(tmp_path: Path): store = tmp_path / "test_open_fail.zarr" - group = zarr.group(store=store, overwrite=True) + group = zarr.create_group(store=store, overwrite=True) read_only_group = open_group_wrapper(store=group, mode="r") - assert read_only_group._read_only + assert read_only_group.read_only with pytest.raises(NgioFileExistsError): open_group_wrapper(store=store, mode="w-") @@ -117,9 +143,7 @@ def test_multiprocessing_safety(tmp_path: Path): @dask.delayed # type: ignore def add_item(i): - handler = ZarrGroupHandler( - zarr_store, cache=False, mode="a", parallel_safe=True - ) + handler = ZarrGroupHandler(zarr_store, cache=False, mode="a") assert handler.lock is not None with handler.lock: @@ -129,7 +153,7 @@ def add_item(i): return i - handler = ZarrGroupHandler(zarr_store, cache=False, mode="w", parallel_safe=True) + handler = ZarrGroupHandler(zarr_store, cache=False, mode="w") attrs = handler.load_attrs() attrs = {"test_list": []} handler.write_attrs(attrs, overwrite=True) @@ -145,7 +169,7 @@ def add_item(i): assert len(counts) == num_items assert np.all(counts == 1) - assert handler._lock_path is not None + assert handler.lock_path is not None if sys.platform.startswith("win"): # The file lock path is not removed on Windows @@ -153,18 +177,26 @@ def add_item(i): # 
even though the file should exist (or at least it does on Mac/Linux) return None - assert Path(handler._lock_path).exists() - lock_path = Path(handler._lock_path) + assert handler.lock_path.exists() + lock_path = handler.lock_path handler.remove_lock() assert not lock_path.exists() handler.remove_lock() - handler = ZarrGroupHandler(zarr_store, cache=False, mode="w", parallel_safe=True) + handler = ZarrGroupHandler(zarr_store, cache=False, mode="w") assert handler.lock is not None with pytest.raises(NgioValueError): + # Attempt to remove the lock while it is in use with handler.lock: handler.remove_lock() + # If cache is used, raise error when creating lock + # Since caching creates a temporary local copy + # which cannot be locked properly + handler = ZarrGroupHandler(zarr_store, cache=True, mode="r") + with pytest.raises(NgioValueError): + handler._create_lock() + def test_remote_storage(): url = ( @@ -180,9 +212,35 @@ def test_remote_storage(): assert handler.load_attrs() assert isinstance(handler.get_array("0"), zarr.Array) assert isinstance(handler.get_group("labels"), zarr.Group) + assert not handler.is_listable - # Check if the fsspec store based group is handled correctly - open_group_wrapper(store=handler.group, mode="r") - with pytest.raises(NgioValueError): - ZarrGroupHandler(store=store, parallel_safe=True) +@pytest.mark.parametrize( + "src_store,dest_store", + [ + (Path("src.zarr"), Path("dest.zarr")), + (Path("src.zarr"), {}), + (Path("dest.zarr"), {}), + ], +) +def test_copy_group(tmp_path: Path, src_store, dest_store): + if isinstance(src_store, Path): + src_store = tmp_path / src_store + if isinstance(dest_store, Path): + dest_store = tmp_path / dest_store + + src_group = zarr.create_group(store=src_store, overwrite=True) + src_group.attrs.update({"a": 1, "b": 2, "c": 3}) + src_group.create_array("array1", shape=(10, 10), dtype="int32") + sub_group = src_group.create_group("group1") + sub_group.create_array("sub_array1", shape=(5, 5), 
dtype="float32") + handler = ZarrGroupHandler(store=src_group, cache=False, mode="r") + + dest_group = zarr.group(store=dest_store, overwrite=True) + handler.copy_group(dest_group=dest_group) + # Reopen dest group to ensure all data is read from store + dest_group = zarr.open_group(dest_store, mode="r") + assert dest_group.attrs.asdict() == src_group.attrs.asdict() + assert "array1" in dest_group + assert "group1" in dest_group + assert "sub_array1" in dest_group["group1"]