69 changes: 56 additions & 13 deletions README.md
@@ -87,17 +87,25 @@ The cell lines and perturbations specified in the TOML should match the values a
you can use the `tx predict` command:

```bash
state tx predict --output_dir $HOME/state/test/ --checkpoint final.ckpt
state tx predict \
--output-dir $HOME/state/test/ \
--checkpoint final.ckpt
```

It will look in the `output_dir` above, for a `checkpoints` folder.
It will look in the `output-dir` above for a `checkpoints` folder.

If you instead want to use a trained checkpoint for inference (e.g. on data not specified in the TOML file):

```bash
state tx infer --output $HOME/state/test/ --output_dir /path/to/model/ --checkpoint /path/to/model/final.ckpt --adata /path/to/anndata/processed.h5 --pert_col gene --embed_key X_hvg
state tx infer \
--output $HOME/state/test/ \
--output-dir /path/to/model/ \
--checkpoint /path/to/model/final.ckpt \
--adata /path/to/anndata/processed.h5 \
--pert-col gene \
--embed-key X_hvg
```

Here, `/path/to/model/` is the folder downloaded from [HuggingFace](https://huggingface.co/arcinstitute).
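
> One convenient way to fetch such a folder is the `huggingface_hub` Python package; a sketch (the repo id below is only an example, substitute the model you actually need):

```python
from huggingface_hub import snapshot_download

# Download a model repo from the arcinstitute org into a local folder.
# "arcinstitute/SE-600M" is an illustrative repo id, not a recommendation.
snapshot_download(repo_id="arcinstitute/SE-600M", local_dir="/path/to/model/")
```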
@@ -108,13 +116,13 @@ State provides two preprocessing commands to prepare data for training and infer

#### Training Data Preprocessing

Use `preprocess_train` to normalize, log-transform, and select highly variable genes from your training data:
Use `preprocess-train` to normalize, log-transform, and select highly variable genes from your training data:

```bash
state tx preprocess_train \
state tx preprocess-train \
--adata /path/to/raw_data.h5ad \
--output /path/to/preprocessed_training_data.h5ad \
--num_hvgs 2000
--num-hvgs 2000
```
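
> Conceptually this mirrors a standard `scanpy` pipeline; a minimal sketch of the same steps (assuming common defaults such as `target_sum=1e4`; the command's actual internals may differ):

```python
import scanpy as sc

adata = sc.read_h5ad("/path/to/raw_data.h5ad")
sc.pp.normalize_total(adata, target_sum=1e4)          # depth-normalize each cell
sc.pp.log1p(adata)                                    # log(1 + x) transform
sc.pp.highly_variable_genes(adata, n_top_genes=2000)  # analogous to --num-hvgs
adata.write_h5ad("/path/to/preprocessed_training_data.h5ad")
```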

This command:
@@ -125,14 +133,14 @@ This command:

#### Inference Data Preprocessing

Use `preprocess_infer` to create a "control template" for model inference:
Use `preprocess-infer` to create a "control template" for model inference:

```bash
state tx preprocess_infer \
state tx preprocess-infer \
--adata /path/to/real_data.h5ad \
--output /path/to/control_template.h5ad \
--control_condition "DMSO" \
--pert_col "treatment" \
--control-condition "DMSO" \
--pert-col "treatment" \
--seed 42
```
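
> In rough terms, a control template is a reproducible sample of control cells; a sketch of the idea (field names follow the flags above; this is not the command's actual implementation):

```python
import anndata as ad
import numpy as np

adata = ad.read_h5ad("/path/to/real_data.h5ad")

# Keep only control cells, i.e. rows whose `treatment` column equals "DMSO".
controls = adata[adata.obs["treatment"] == "DMSO"].copy()

# Shuffle them with a fixed seed so the template is reproducible.
rng = np.random.default_rng(42)
template = controls[rng.permutation(controls.n_obs)].copy()
template.write_h5ad("/path/to/control_template.h5ad")
```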

@@ -301,16 +309,19 @@ state emb transform \
```

Running this command multiple times with the same `--lancedb` path appends the new data to the provided database.
Existing cell records will be updated with the new embeddings.

#### Query the database

> For this example, we will use the same dataset (SRX27532045), so the top hit should be the same cell.

Obtain the embeddings:

```bash
state emb transform \
--model-folder /large_storage/ctc/userspace/aadduri/SE-600M \
--input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532046.h5ad \
--output tmp/SRX27532046.h5ad \
--input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532045.h5ad \
--output tmp/SRX27532045.h5ad \
--gene-column gene_symbols
```
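
> As a quick sanity check, the embeddings should land in the output file's `obsm`, presumably under the `X_state` key that `--embed-key` defaults to elsewhere in this CLI:

```python
import anndata as ad

adata = ad.read_h5ad("tmp/SRX27532045.h5ad")
print(adata.obsm["X_state"].shape)  # (n_cells, embedding_dim)
```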

@@ -319,9 +330,41 @@ Query the database with the embeddings:
```bash
state emb query \
--lancedb tmp/state_embeddings.lancedb \
--input tmp/SRX27532046.h5ad \
--input tmp/SRX27532045.h5ad \
--output tmp/similar_cells.csv \
--k 3
```

Output:
- `query_cell_id` : The cell id of the query cell
- `subject_rank` : The rank of the hit (1 = smallest distance to the query)
- `query_subject_distance` : The distance between the query and subject cell vectors
- `subject_cell_id` : The cell id of the hit cell
- `subject_dataset` : The dataset of the hit cell
- `embedding_key` : The embedding key of the hit cell
- `...` : Other `obs` metadata columns from the query cell
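
> A quick way to inspect the CSV, using the columns listed above:

```python
import pandas as pd

results = pd.read_csv("tmp/similar_cells.csv")

# subject_rank == 1 marks each query's nearest neighbour.
top_hits = results[results["subject_rank"] == 1]
print(top_hits[["query_cell_id", "subject_cell_id", "query_subject_distance"]])
```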

#### Summarize the vector database

Get comprehensive statistics about your vector database:

```bash
state emb vectordb \
--lancedb tmp/state_embeddings.lancedb \
--format table
```

Output formats:
- `table` (default): Human-readable table format with emojis
- `json`: Machine-readable JSON format
- `yaml`: YAML format

The summary includes:
- Total number of cells and datasets
- Number of unique embedding keys
- Embedding vector dimensions
- Cell count breakdown by dataset
- List of all embedding keys
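
> For scripted use, the JSON format can be parsed directly; a sketch (the exact keys are not specified here, so inspect the output first):

```python
import json
import subprocess

out = subprocess.run(
    ["state", "emb", "vectordb", "--lancedb", "tmp/state_embeddings.lancedb", "--format", "json"],
    capture_output=True, text=True, check=True,
)
summary = json.loads(out.stdout)
print(summary)  # e.g. cell counts, datasets, embedding keys
```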

# Singularity

44 changes: 31 additions & 13 deletions src/state/__main__.py
@@ -3,12 +3,14 @@
from hydra import compose, initialize
from omegaconf import DictConfig

from ._cli._utils import CustomFormatter
from ._cli import (
add_arguments_emb,
add_arguments_tx,
run_emb_fit,
run_emb_transform,
run_emb_query,
run_emb_vectordb,
run_tx_infer,
run_tx_predict,
run_tx_preprocess_infer,
@@ -19,10 +21,25 @@

def get_args() -> ap.Namespace:
"""Parse command-line arguments; Hydra overrides for `tx` are captured by its subparser"""
parser = ap.ArgumentParser()
desc = """description:
Entry point for the STATE command line interface.
Use these commands to train models, compute embeddings, and run inference.
Run `state <command> --help` for details on each command."""
parser = ap.ArgumentParser(description=desc, formatter_class=CustomFormatter)
parser.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level")
subparsers = parser.add_subparsers(required=True, dest="command")
add_arguments_emb(subparsers.add_parser("emb"))
add_arguments_tx(subparsers.add_parser("tx"))

# emb
desc = """description:
Commands for generating and querying STATE embeddings.
See `state emb <command> --help` for subcommand options."""
add_arguments_emb(subparsers.add_parser("emb", description=desc, formatter_class=CustomFormatter))

# tx
desc = """description:
Train and evaluate perturbation models with Hydra configuration.
Overrides can be passed via `state tx <subcommand> param=value`."""
add_arguments_tx(subparsers.add_parser("tx", description=desc, formatter_class=CustomFormatter))

# Hydra overrides are collected by the `tx` subparser (as `args.hydra_overrides`),
# so a plain parse_args() is sufficient here
return parser.parse_args()
@@ -62,21 +79,20 @@ def show_hydra_help(method: str):
print()
print("Usage examples:")
print(" Override single parameter:")
print(f" uv run state tx train data.batch_size=64")
print(" uv run state tx train data.batch_size=64")
print()
print(" Override nested parameter:")
print(f" uv run state tx train model.kwargs.hidden_dim=512")
print(" uv run state tx train model.kwargs.hidden_dim=512")
print()
print(" Override multiple parameters:")
print(f" uv run state tx train data.batch_size=64 training.lr=0.001")
print(" uv run state tx train data.batch_size=64 training.lr=0.001")
print()
print(" Change config group:")
print(f" uv run state tx train data=custom_data model=custom_model")
print(" uv run state tx train data=custom_data model=custom_model")
print()
print("Available config groups:")

# Show available config groups
import os
from pathlib import Path

config_dir = Path(__file__).parent / "configs"
@@ -103,6 +119,8 @@ def main():
run_emb_transform(args)
case "query":
run_emb_query(args)
case "vectordb":
run_emb_vectordb(args)
case "tx":
match args.subcommand:
case "train":
@@ -112,19 +130,19 @@
else:
# Load Hydra config with overrides for sets training
cfg = load_hydra_config("tx", args.hydra_overrides)
run_tx_train(cfg)
run_tx_train(cfg, args)
case "predict":
# For now, predict uses argparse and not hydra
run_tx_predict(args)
case "infer":
# Run inference using argparse, similar to predict
run_tx_infer(args)
case "preprocess_train":
case "preprocess-train":
# Run preprocessing using argparse
run_tx_preprocess_train(args.adata, args.output, args.num_hvgs)
case "preprocess_infer":
run_tx_preprocess_train(args.adata, args.output, args.num_hvgs, args.log_level)
case "preprocess-infer":
# Run inference preprocessing using argparse
run_tx_preprocess_infer(args.adata, args.output, args.control_condition, args.pert_col, args.seed)
run_tx_preprocess_infer(args.adata, args.output, args.control_condition, args.pert_col, args.seed, args.log_level)


if __name__ == "__main__":
3 changes: 2 additions & 1 deletion src/state/_cli/__init__.py
@@ -1,4 +1,4 @@
from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query
from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_vectordb
from ._tx import add_arguments_tx, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, run_tx_preprocess_train, run_tx_train

__all__ = [
@@ -12,4 +12,5 @@
"run_emb_fit",
"run_emb_query",
"run_emb_transform",
"run_emb_vectordb",
]
39 changes: 34 additions & 5 deletions src/state/_cli/_emb/__init__.py
@@ -3,13 +3,42 @@
from ._fit import add_arguments_fit, run_emb_fit
from ._transform import add_arguments_transform, run_emb_transform
from ._query import add_arguments_query, run_emb_query
from ._vectordb import add_arguments_vectordb, run_emb_vectordb
from .._utils import CustomFormatter

__all__ = ["run_emb_fit", "run_emb_transform", "run_emb_query", "add_arguments_emb"]
__all__ = ["run_emb_fit", "run_emb_transform", "run_emb_query", "run_emb_vectordb", "add_arguments_emb"]


def add_arguments_emb(parser: ap.ArgumentParser):
""""""
"""Add embedding commands to the parser"""
subparsers = parser.add_subparsers(required=True, dest="subcommand")
add_arguments_fit(subparsers.add_parser("fit"))
add_arguments_transform(subparsers.add_parser("transform"))
add_arguments_query(subparsers.add_parser("query"))

# fit
desc = """description:
Train an embedding model on a reference dataset.
Provide Hydra overrides to adjust training parameters."""
add_arguments_fit(
subparsers.add_parser("fit", description=desc, formatter_class=CustomFormatter)
)

# transform
desc = """description:
Encode an input dataset with a trained embedding model.
Results can be saved locally and inserted into a LanceDB vector store."""
add_arguments_transform(
subparsers.add_parser("transform", description=desc, formatter_class=CustomFormatter)
)

# query
desc = """description:
Search a LanceDB vector store (created with `transform`) for cells with similar embeddings."""
add_arguments_query(
subparsers.add_parser("query", description=desc, formatter_class=CustomFormatter)
)

# vectordb
desc = """description:
Get summary statistics about a LanceDB vector database including datasets, cell counts, and embeddings."""
add_arguments_vectordb(
subparsers.add_parser("vectordb", description=desc, formatter_class=CustomFormatter)
)
1 change: 1 addition & 0 deletions src/state/_cli/_emb/_fit.py
@@ -21,6 +21,7 @@ def run_emb_fit(cfg, args):

from ...emb.train.trainer import main as trainer_main

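# Map the --log-level string (e.g. "INFO") to the corresponding logging constant; fall back to INFO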
logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO))
log = logging.getLogger(__name__)

# Load the base configuration
31 changes: 22 additions & 9 deletions src/state/_cli/_emb/_query.py
@@ -1,3 +1,4 @@
import os
import argparse as ap
import logging
import pandas as pd
@@ -14,15 +15,16 @@ def add_arguments_query(parser: ap.ArgumentParser):
parser.add_argument("--embed-key", default="X_state", help="Key containing embeddings in input file")
parser.add_argument("--exclude-distances", action="store_true",
help="Exclude vector distances in results")
parser.add_argument("--filter", type=str, help="Filter expression (e.g., 'cell_type==\"B cell\"')")
parser.add_argument("--batch-size", type=int, default=100,
help="Batch size for query operations")
parser.add_argument("--filter", type=str,
help="Filter expression (e.g., 'cell_type==\"B cell\"', assuming a 'cell_type' column exists in the database)")
parser.add_argument("--batch-size", type=int, default=100, help="Batch size for query operations")
parser.add_argument("--max-workers", type=int, default=os.cpu_count(), help="Maximum number of workers for parallel processing")

def run_emb_query(args: ap.Namespace):
"""
Query a LanceDB database for similar cells.
"""
logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO))
logger = logging.getLogger(__name__)

from ...emb.vectordb import StateVectorDB
@@ -59,18 +61,29 @@
filter=args.filter,
include_distance=not args.exclude_distances,
batch_size=args.batch_size,
max_workers=args.max_workers,
show_progress=True
)

# Add query cell IDs and ranks to results
all_results = []
for query_idx, result_df in enumerate(results_list):
result_df['query_cell_id'] = query_adata.obs.index[query_idx]
result_df['query_rank'] = range(1, len(result_df) + 1)
result_df['subject_rank'] = range(1, len(result_df) + 1)
all_results.append(result_df)

# Combine results
final_results = pd.concat(all_results, ignore_index=True)

# Format the results table
## Move certain columns to the start, if they exist
to_move = ['query_cell_id', 'subject_rank', 'query_subject_distance', 'cell_id', 'dataset', 'embedding_key']
to_move = [col for col in to_move if col in final_results.columns]
final_results = final_results[to_move + [col for col in final_results.columns if col not in to_move]]
## Rename `cell_id` to `subject_cell_id` and `dataset` to `subject_dataset`
rn_dict = {'cell_id': 'subject_cell_id', 'dataset': 'subject_dataset'}
rn_dict = {k:v for k,v in rn_dict.items() if k in final_results.columns}
final_results = final_results.rename(columns=rn_dict)

# Save results
output_path = Path(args.output)
@@ -96,11 +109,11 @@ def create_result_anndata(query_adata, results_df, k):
cell_ids_array = np.array(cell_ids_pivot.values, dtype=str)

# Handle distances - convert to float64 and handle missing values
if 'vector_distance' in results_df:
if 'query_subject_distance' in results_df:
distances_pivot = results_df.pivot(
index='query_cell_id',
columns='query_rank',
values='vector_distance'
values='query_subject_distance'
)
distances_array = np.array(distances_pivot.values, dtype=np.float64)
else:
@@ -118,5 +131,5 @@
# Create result anndata
result_adata = query_adata.copy()
result_adata.uns['lancedb_query_results'] = uns_data
return result_adata

return result_adata