2 changes: 2 additions & 0 deletions .gitignore
@@ -29,9 +29,11 @@ scripts/combined_db*
*_play.py
src/lobster/hydra_config/experiment/*
src/lobster/mcp/claude_desktop_config.json
*.ipynb_checkpoints

notebooks/nathan/*
notebooks/karina/*
notebooks/amyxlu/*

models/*

40 changes: 40 additions & 0 deletions slurm/scripts/save_token_losses.sh
@@ -0,0 +1,40 @@
#!/usr/bin/env bash

#SBATCH --job-name token_loss
#SBATCH --nodes 1
#SBATCH --gpus-per-node 1
#SBATCH --partition gpu2
#SBATCH --cpus-per-gpu 4
#SBATCH --mem 150G
#SBATCH --time=1-00:00:00

source ~/.bashrc
eval "$(mamba shell hook --shell bash)"

echo "SLURM_JOB_NODELIST = ${SLURM_JOB_NODELIST}"
echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
echo "SLURMD_NODENAME = ${SLURMD_NODENAME}"
echo "SLURM_JOB_NUM_NODES = ${SLURM_JOB_NUM_NODES}"

# LOBSTER_PROJECT_DIR must already be set in the submitting environment
cd "$LOBSTER_PROJECT_DIR"

# use uv, which should already be set up
source .venv/bin/activate

echo "SLURM_JOB_NODELIST = ${SLURM_JOB_NODELIST}"
echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
echo "SLURMD_NODENAME = ${SLURMD_NODENAME}"
echo "SLURM_JOB_NUM_NODES = ${SLURM_JOB_NUM_NODES}"

nvidia-smi
mamba activate plaid
mamba env list
echo $CONDA_PREFIX
which python

# see save_token_losses.py for the default parser arguments
srun torchrun token_selection/scripts/save_token_losses.py \
--fasta_file /data/bucket/freyn6/data/uniref50.fasta \
--output_dir /data2/lux70/data/uniref50/per_token_losses \
--max_num_per_shard 10000
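
Only the three flags passed above are visible in this diff; the script itself (token_selection/scripts/save_token_losses.py) is not part of the change. As a rough sketch of the CLI those flags imply -- argument names taken from the sbatch call, everything else (description, help text, the default value) assumed:

# Hypothetical sketch of the parser consumed by the sbatch script above; only
# --fasta_file, --output_dir, and --max_num_per_shard are confirmed by this diff.
import argparse


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Save per-token losses in shards.")
    parser.add_argument("--fasta_file", type=str, required=True, help="Input FASTA file.")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to write loss shards to.")
    parser.add_argument("--max_num_per_shard", type=int, default=10000, help="Records per output shard.")
    return parser.parse_args()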
31 changes: 26 additions & 5 deletions src/lobster/data/_fasta_datamodule.py
@@ -4,6 +4,7 @@
from typing import Any, TypeVar

import pandas as pd
import numpy as np
import torch.utils.data

# from beignet.datasets import FASTADataset
@@ -43,6 +44,7 @@ def __init__(
    is_relative_model: bool = False,
    tokenizer_dir: str | None = "pmlm_tokenizer",
    mlm: bool = True,
    offsets_arr: np.ndarray | None = None,
) -> None:
    """
    :param path_to_fasta: path to fasta file
@@ -139,6 +141,7 @@ def __init__(
    self._is_relative_model = is_relative_model
    self._tokenizer_dir = tokenizer_dir
    self._mlm = mlm
    self._offsets_arr = offsets_arr

    path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
    self._transform_fn = transform_fn or PmlmTokenizerTransform(
@@ -159,16 +162,31 @@ def setup(self, stage: str = "fit") -> None: # noqa: ARG002
if stage == "fit":
if any(["train" in self._path_to_fasta]): # pre computed splits
self._train_dataset = torch.utils.data.ConcatDataset(
[FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "train" in p]
[
FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
for p in self._path_to_fasta
if "train" in p
]
)
self._val_dataset = torch.utils.data.ConcatDataset(
[FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "val" in p]
[
FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
for p in self._path_to_fasta
if "val" in p
]
)
self._test_dataset = torch.utils.data.ConcatDataset(
[FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "test" in p]
[
FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
for p in self._path_to_fasta
if "test" in p
]
)
else: # iid split
datasets = [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta]
datasets = [
FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
for p in self._path_to_fasta
]
dataset = torch.utils.data.ConcatDataset(datasets)
(
self._train_dataset,
@@ -181,7 +199,10 @@ def setup(self, stage: str = "fit") -> None: # noqa: ARG002
            )

    if stage == "predict":
        datasets = [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta]
        datasets = [
            FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
            for p in self._path_to_fasta
        ]
        dataset = torch.utils.data.ConcatDataset(datasets)
        self._predict_dataset = dataset

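With offsets_arr threaded through every FASTADataset built in setup, callers can load the offsets index once and share it across all splits. A minimal usage sketch -- the class name FastaLightningDataModule and its import path are assumptions (this diff only shows _fasta_datamodule.py internals), while the UniRef50 paths come from the sbatch script above:

import numpy as np

from lobster.data import FastaLightningDataModule  # hypothetical name/path

# Load the precomputed 2 x N index once (row 0: byte offsets, row 1: sizes)
# instead of letting each split's FASTADataset re-read or rebuild it.
offsets_arr = np.load("/data/bucket/freyn6/data/uniref50.fasta.offsets.npy")

dm = FastaLightningDataModule(
    path_to_fasta=["/data/bucket/freyn6/data/uniref50.fasta"],
    offsets_arr=offsets_arr,
)
dm.setup(stage="predict")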
3 changes: 3 additions & 0 deletions src/lobster/datasets/__init__.py
@@ -14,8 +14,10 @@
from ._shuffled_iterable_dataset import ShuffledIterableDataset
from ._ume_streaming_dataset import UMEStreamingDataset
from ._zinc_dataset import ZINCIterableDataset
from ._sharded_parquet_dataset import ShardedParquetDataset
from ._ptm_dataset import PTMDataset


__all__ = [
"CalmDataset",
"AtomicaDataset",
@@ -37,5 +39,6 @@
"ZINCIterableDataset",
"OpenGenome2IterableDataset",
"UMEStreamingDataset",
"ShardedParquetDataset",
"PTMDataset",
]
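
With these exports in place, both new datasets are importable from the package root:

from lobster.datasets import PTMDataset, ShardedParquetDataset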
16 changes: 10 additions & 6 deletions src/lobster/datasets/_fasta_dataset.py
@@ -17,6 +17,7 @@ def __init__(
    *,
    transform: Callable | None = None,
    use_text_descriptions: bool = True,
    offsets_arr: numpy.ndarray | None = None,
) -> None:
    if isinstance(root, str):
        root = Path(root)
@@ -32,14 +33,17 @@

    self.data = ThreadSafeFile(self.root, open)

    offsets = Path(f"{self.root}.offsets.npy")

    if offsets.exists():
        self.offsets, sizes = numpy.load(f"{offsets}")
    else:
        self.offsets, sizes = self._build_index()

        numpy.save(f"{offsets}", numpy.stack([self.offsets, sizes]))
    if offsets_arr is None:
        offsets_path = Path(f"{self.root}.offsets.npy")
        if offsets_path.exists():
            self.offsets, sizes = numpy.load(f"{offsets_path}")
        else:
            self.offsets, sizes = self._build_index()
            numpy.save(f"{offsets_path}", numpy.stack([self.offsets, sizes]))
    else:
        self.offsets = offsets_arr[0, :]
        sizes = offsets_arr[1, :]

    self.transform = transform

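The offsets_arr contract mirrors the on-disk .offsets.npy file that _build_index produces via numpy.stack([self.offsets, sizes]): shape (2, N), byte offsets in row 0, record sizes in row 1. A minimal sketch, assuming the private-module import path (the public export is not confirmed by this diff):

import numpy

from lobster.datasets._fasta_dataset import FASTADataset  # import path assumed

# shape (2, num_records): row 0 = byte offsets, row 1 = record sizes
index = numpy.load("uniref50.fasta.offsets.npy")

# Passing the preloaded index skips the offsets-file existence check and any
# _build_index() pass in __init__.
dataset = FASTADataset("uniref50.fasta", offsets_arr=index)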