32 changes: 17 additions & 15 deletions README.md
@@ -107,21 +107,23 @@ zarr.config.set(
{"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"}
)

ds = ZarrSparseDataset(
batch_size=4096,
chunk_size=32,
preload_nchunks=256,
).add_anndatas(
[
ad.AnnData(
# note that you can open an AnnData file using any type of zarr store
X=ad.io.sparse_dataset(zarr.open(p)["X"]),
obs=ad.io.read_elem(zarr.open(p)["obs"]),
)
for p in Path("path/to/output/collection").glob("*.zarr")
],
obs_keys=["label_column", "batch_column"],
)
# This settings override ensures that you don't lose or alter your categorical codes when reading the data in!
with ad.settings.override(remove_unused_categories=False):
ds = ZarrSparseDataset(
batch_size=4096,
chunk_size=32,
preload_nchunks=256,
).add_anndatas(
[
ad.AnnData(
# note that you can open an AnnData file using any type of zarr store
X=ad.io.sparse_dataset(zarr.open(p)["X"]),
obs=ad.io.read_elem(zarr.open(p)["obs"]),
)
for p in Path("path/to/output/collection").glob("*.zarr")
],
obs_keys=["label_column", "batch_column"],
)

# Iterate over dataloader (plugin replacement for torch.utils.DataLoader)
for batch in ds:
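As background for the `ad.settings.override(remove_unused_categories=False)` block above, a minimal sketch using plain anndata (independent of annbatch; the toy data and column name are hypothetical) of why keeping unused categories keeps categorical codes stable across shards:

```python
import anndata as ad
import numpy as np
import pandas as pd

# Hypothetical shard: "label" declares three categories but only uses two of them.
obs = pd.DataFrame(
    {"label": pd.Categorical(["a", "b"], categories=["a", "b", "c"])},
    index=["cell0", "cell1"],
)
adata = ad.AnnData(X=np.zeros((2, 3), dtype="f4"), obs=obs)

with ad.settings.override(remove_unused_categories=False):
    subset = adata[:1]
    # The unused category "c" is kept, so the category -> code mapping stays
    # identical across shards instead of being silently renumbered.
    assert list(subset.obs["label"].cat.categories) == ["a", "b", "c"]
```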
90 changes: 61 additions & 29 deletions src/annbatch/io.py
@@ -4,6 +4,7 @@
import random
import warnings
from collections import defaultdict
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING

@@ -109,7 +110,7 @@ def _check_for_mismatched_keys(paths_or_anndatas: Iterable[PathLike[str] | ad.An
"obsm": defaultdict(lambda: 0),
"obs": defaultdict(lambda: 0),
}
for path_or_anndata in paths_or_anndatas:
for path_or_anndata in tqdm(paths_or_anndatas, desc="checking for mismatched keys"):
if not isinstance(path_or_anndata, ad.AnnData):
adata = ad.experimental.read_lazy(path_or_anndata)
else:
@@ -141,16 +142,37 @@ def _lazy_load_anndatas(
load_adata: Callable[[PathLike[str] | str], ad.AnnData] = ad.experimental.read_lazy,
):
adatas = []
for i, path in enumerate(paths):
categoricals_in_all_adatas = {}
for i, path in tqdm(enumerate(paths), desc="loading"):
adata = load_adata(path)
# Concatenating Dataset2D drops categoricals
# Track the source file for this AnnData object
adata.obs["src_path"] = pd.Categorical.from_codes(
np.ones((adata.shape[0],), dtype="int") * i, categories=[str(p) for p in paths]
)
# Concatenating Dataset2D drops categoricals so we need to track them
if isinstance(adata.obs, Dataset2D):
adata.obs = adata.obs.to_memory()
adata.obs["src_path"] = pd.Categorical.from_codes([i] * adata.shape[0], categories=[str(p) for p in paths])
categorical_cols_in_this_adata = {
col: set(adata.obs[col].dtype.categories)
for col in adata.obs.columns
if adata.obs[col].dtype == "category"
}
if not categoricals_in_all_adatas:
categoricals_in_all_adatas = {
**categorical_cols_in_this_adata,
"src_path": set(adata.obs["src_path"].dtype.categories),
}
else:
for k in categoricals_in_all_adatas.keys() & categorical_cols_in_this_adata.keys():
categoricals_in_all_adatas[k] = set(categoricals_in_all_adatas[k]).union(
set(categorical_cols_in_this_adata[k])
)
adatas.append(adata)
if len(adatas) == 1:
return adatas[0]
return ad.concat(adatas, join="outer")
adata = ad.concat(adatas, join="outer")
if len(categoricals_in_all_adatas) > 0:
adata.uns["dataset2d_categoricals_to_convert"] = categoricals_in_all_adatas
return adata


def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_n_obs_per_dataset: int = 1_048_576, shuffle: bool = True):
@@ -168,27 +190,33 @@ def _create_chunks_for_shuffling(adata: ad.AnnData, shuffle_n_obs_per_dataset: i

def _compute_blockwise(x: DaskArray) -> sp.spmatrix:
""".compute() for large datasets is bad: https://github.com/scverse/annbatch/pull/75"""
return sp.vstack(da.compute(*list(x.blocks)))
if isinstance(x._meta, sp.csr_matrix | sp.csr_array):
return sp.vstack(da.compute(*list(x.blocks)))
return x.compute()


def _to_categorical_obs(adata: ad.AnnData) -> ad.AnnData:
"""Convert columns marked as categorical in `uns` to categories, accounting for `concat` on `Dataset2D` lost dtypes"""
if "dataset2d_categoricals_to_convert" in adata.uns:
for col, categories in adata.uns["dataset2d_categoricals_to_convert"].items():
adata.obs[col] = pd.Categorical(np.array(adata.obs[col]), categories=categories)
del adata.uns["dataset2d_categoricals_to_convert"]
return adata


def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData:
if isinstance(adata.X, DaskArray):
if isinstance(adata.X._meta, sp.csr_matrix | sp.csr_array):
adata.X = _compute_blockwise(adata.X)
else:
adata.X = adata.X.compute()
adata.X = _compute_blockwise(adata.X)
if isinstance(adata.obs, Dataset2D):
adata.obs = adata.obs.to_memory()
adata = _to_categorical_obs(adata)
if isinstance(adata.var, Dataset2D):
adata.var = adata.var.to_memory()

if adata.raw is not None:
adata_raw = adata.raw.to_adata()
if isinstance(adata_raw.X, DaskArray):
if isinstance(adata_raw.X._meta, sp.csr_array | sp.csr_matrix):
adata_raw.X = _compute_blockwise(adata_raw.X)
else:
adata_raw.X = adata_raw.X.compute()
adata_raw.X = _compute_blockwise(adata_raw.X)
if isinstance(adata_raw.var, Dataset2D):
adata_raw.var = adata_raw.var.to_memory()
if isinstance(adata_raw.obs, Dataset2D):
@@ -199,24 +227,28 @@ def _persist_adata_in_memory(adata: ad.AnnData) -> ad.AnnData:
for k, elem in adata.obsm.items():
# TODO: handle `Dataset2D` in `obsm` and `varm` that are
if isinstance(elem, DaskArray):
if isinstance(elem, sp.csr_matrix | sp.csr_array):
adata.obsm[k] = _compute_blockwise(elem)
else:
adata.obsm[k] = elem.compute()
adata.obsm[k] = _compute_blockwise(elem)

for k, elem in adata.layers.items():
if isinstance(elem, DaskArray):
if isinstance(elem, sp.csr_matrix | sp.csr_array):
adata.layers[k] = _compute_blockwise(elem)
else:
adata.layers[k] = elem.compute()
adata.layers[k] = _compute_blockwise(elem)

return adata


DATASET_PREFIX = "dataset"


def _with_settings(func):
@wraps(func)
def wrapper(*args, **kwargs):
with ad.settings.override(zarr_write_format=3, remove_unused_categories=False):
return func(*args, **kwargs)

return wrapper


@_with_settings
def create_anndata_collection(
adata_paths: Iterable[PathLike[str]] | Iterable[str],
output_path: PathLike[str] | str,
@@ -305,7 +337,6 @@ def create_anndata_collection(
...)
"""
Path(output_path).mkdir(parents=True, exist_ok=True)
ad.settings.zarr_write_format = 3
_check_for_mismatched_keys(adata_paths)
adata_concat = _lazy_load_anndatas(adata_paths, load_adata=load_adata)
adata_concat.obs_names_make_unique()
@@ -314,7 +345,7 @@ def create_anndata_collection(
if var_subset is None:
var_subset = adata_concat.var_names

for i, chunk in enumerate(tqdm(chunks)):
for i, chunk in enumerate(tqdm(chunks, desc="processing chunks")):
var_mask = adata_concat.var_names.isin(var_subset)
# np.sort: It's more efficient to access elements sequentially from dask arrays
# The data will be shuffled later on; we just want the elements at this point
@@ -356,6 +387,7 @@ def _get_array_encoding_type(path: PathLike[str] | str) -> str:
return encoding["attributes"]["encoding-type"]


@_with_settings
def add_to_collection(
adata_paths: Iterable[PathLike[str]] | Iterable[str],
output_path: PathLike[str] | str,
@@ -447,16 +479,16 @@ def add_to_collection(
sp.csr_matrix, meta=sp.csr_matrix(np.array([0], dtype=adata_concat.X.dtype))
)

for shard, chunk in tqdm(zip(shards, chunks, strict=False), total=len(shards)):
for shard, chunk in tqdm(zip(shards, chunks, strict=False), total=len(shards), desc="processing chunks"):
if should_sparsify_output_in_memory and encoding == "array":
adata_shard = _lazy_load_anndatas([shard])
adata_shard.X = adata_shard.X.map_blocks(sp.csr_matrix).compute()
else:
adata_shard = ad.read_zarr(shard)

adata = ad.concat(
[adata_shard, adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_shard.var.index)]], join="outer"
subset_adata = _to_categorical_obs(
adata_concat[chunk, :][:, adata_concat.var.index.isin(adata_shard.var.index)]
)
adata = ad.concat([adata_shard, subset_adata], join="outer")
idxs_shuffled = np.random.default_rng().permutation(len(adata))
adata = adata[idxs_shuffled, :].copy() # this significantly speeds up writing to disk
if should_sparsify_output_in_memory and encoding == "array":
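To clarify the intent of `_to_categorical_obs` and the `dataset2d_categoricals_to_convert` bookkeeping above, a minimal pandas-only sketch of the restore step; the column name and category values are hypothetical:

```python
import numpy as np
import pandas as pd

# Category sets collected per column before concatenation (hypothetical values).
recorded_categories = {"label": ["a", "b", "c"]}

# After an outer concat the column may come back as a plain object column.
obs = pd.DataFrame({"label": np.array(["a", "c"], dtype=object)})

# Re-apply the recorded categories so codes are consistent across every shard.
for col, categories in recorded_categories.items():
    obs[col] = pd.Categorical(np.array(obs[col]), categories=categories)

assert obs["label"].dtype == "category"
assert list(obs["label"].cat.categories) == ["a", "b", "c"]
```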
11 changes: 5 additions & 6 deletions tests/conftest.py
@@ -18,11 +18,6 @@
from collections.abc import Generator


@pytest.fixture(autouse=True)
def anndata_settings():
ad.settings.zarr_write_format = 3 # Needed to support sharding in Zarr


@pytest.fixture(params=[False, True], ids=["zarr-python", "zarrs"])
def use_zarrs(request):
return request.param
@@ -85,7 +80,11 @@ def adata_with_h5_path_different_var_space(
adata = ad.AnnData(
X=sparse_random(m, n, density=0.1, format="csr", dtype="f4"),
obs=pd.DataFrame(
{"label": np.random.default_rng().integers(0, 5, size=m), "store_id": [i] * m},
{
"label": pd.Categorical([str(m), str(m), *(["a"] * (m - 2))]),
"store_id": [i] * m,
"numeric": np.arange(m),
},
index=np.arange(m).astype(str),
),
var=pd.DataFrame(
38 changes: 23 additions & 15 deletions tests/test_store_creation.py
@@ -185,8 +185,8 @@ def test_store_creation_drop_elem(
assert adata_output.raw is None


@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("densify", [True, False])
@pytest.mark.parametrize("shuffle", [pytest.param(True, id="shuffle"), pytest.param(False, id="no_shuffle")])
@pytest.mark.parametrize("densify", [pytest.param(True, id="densify"), pytest.param(False, id="no_densify")])
def test_store_creation(
adata_with_h5_path_different_var_space: tuple[ad.AnnData, Path],
shuffle: bool,
@@ -225,18 +225,22 @@
sorted(adata_orig.var.index),
)
assert "arr" in adata.obsm
if not shuffle:
np.testing.assert_array_equal(
adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray(),
adata_orig.X if isinstance(adata_orig.X, np.ndarray) else adata_orig.X.toarray(),
)
np.testing.assert_array_equal(
adata.raw.X if isinstance(adata.raw.X, np.ndarray) else adata.raw.X.toarray(),
adata_orig.raw.X if isinstance(adata_orig.raw.X, np.ndarray) else adata_orig.raw.X.toarray(),
)
np.testing.assert_array_equal(adata.obsm["arr"], adata_orig.obsm["arr"])
adata.obs.index = adata_orig.obs.index # correct for concat
pd.testing.assert_frame_equal(adata.obs, adata_orig.obs)
if shuffle:
adata = adata[adata_orig.obs_names].copy()
np.testing.assert_array_equal(
adata.X if isinstance(adata.X, np.ndarray) else adata.X.toarray(),
adata_orig.X if isinstance(adata_orig.X, np.ndarray) else adata_orig.X.toarray(),
)
np.testing.assert_array_equal(
adata.raw.X if isinstance(adata.raw.X, np.ndarray) else adata.raw.X.toarray(),
adata_orig.raw.X if isinstance(adata_orig.raw.X, np.ndarray) else adata_orig.raw.X.toarray(),
)
np.testing.assert_array_equal(adata.obsm["arr"], adata_orig.obsm["arr"])

# correct for concat misordering the categories
adata.obs["label"] = adata.obs["label"].cat.reorder_categories(adata_orig.obs["label"].dtype.categories)

pd.testing.assert_frame_equal(adata.obs, adata_orig.obs)
z = zarr.open(output_path / "dataset_0.zarr")
assert z["obsm"]["arr"].chunks[0] == 5, z["obsm"]["arr"]
if not densify:
@@ -326,11 +330,15 @@ def test_store_extension(
zarr_dense_shard_size=10,
)

adata = ad.concat([ad.read_zarr(zarr_path) for zarr_path in sorted(store_path.iterdir())])
adatas_on_disk = [ad.read_zarr(zarr_path) for zarr_path in sorted(store_path.iterdir())]
adata = ad.concat(adatas_on_disk)
adata_orig = adata_with_h5_path_different_var_space[0]
expected_adata = ad.concat([adata_orig, adata_orig[adata_orig.obs["store_id"] >= 4]], join="outer")
assert adata.X.shape[1] == expected_adata.X.shape[1]
assert adata.X.shape[0] == expected_adata.X.shape[0]
# check categoricals to make sure the dtypes match
for a in [*adatas_on_disk, adata]:
assert a.obs["label"].dtype == expected_adata.obs["label"].dtype
assert "arr" in adata.obsm
z = zarr.open(store_path / "dataset_0.zarr")
assert z["obsm"]["arr"].chunks == (5, z["obsm"]["arr"].shape[1])
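For context on the `reorder_categories` realignment in `test_store_creation` above, a small standalone sketch (hypothetical values): `concat` can leave the same category set in a different order, which `pd.testing` comparisons treat as a dtype mismatch, so the order is realigned before comparing:

```python
import pandas as pd

left = pd.Series(pd.Categorical(["a", "b"], categories=["b", "a"]))
right = pd.Series(pd.Categorical(["a", "b"], categories=["a", "b"]))

# Same categories, different order -> the comparison would fail on dtype,
# so realign the category order before asserting equality.
left = left.cat.reorder_categories(right.dtype.categories)
pd.testing.assert_series_equal(left, right)
```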