Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/lobster/callbacks/_calm_linear_probe_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class CalmLinearProbeCallback(LinearProbeCallback):
Fraction of data to use for testing.
max_samples : int, default=3000
Maximum number of samples to use from each dataset.
random_seed : int, default=42
Seed passed to ``np.random.seed`` before a new train/test split is created, so that splits are reproducible across runs.

Attributes
----------
Expand All @@ -78,6 +80,7 @@ def __init__(
run_every_n_epochs: int | None = None,
test_size: float = 0.2,
max_samples: int = 3000,
random_seed: int = 42,
):
tokenizer_transform = UMETokenizerTransform(
modality="nucleotide",
Expand All @@ -96,6 +99,7 @@ def __init__(

self.test_size = test_size
self.max_samples = max_samples
self.random_seed = random_seed

self.dataset_splits = {}
self.aggregate_metrics = defaultdict(list)
Expand All @@ -122,6 +126,9 @@ def _create_split_datasets(
if split_key in self.dataset_splits:
return self.dataset_splits[split_key]

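# Reseed NumPy's global RNG so the train/test split is reproducible.
# NOTE(review): np.random.seed mutates process-wide RNG state; a local
# np.random.default_rng(self.random_seed) would avoid that side effect — confirm.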
np.random.seed(self.random_seed)

dataset = CalmPropertyDataset(task=task, species=species, transform_fn=self.transform_fn)

indices = np.arange(len(dataset))
Expand Down
4 changes: 4 additions & 0 deletions src/lobster/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ._atomica_dataset import AtomicaDataset
from ._calm_dataset import CalmDataset, CalmIterableDataset
from ._calm_property_dataset import CalmPropertyDataset
from ._calm_property_unlabeled_dataset import CalmPropertyUnlabeledDataset
from ._fasta_dataset import FASTADataset
from ._huggingface_iterable_dataset import HuggingFaceIterableDataset
from ._latent_generator_3d_coordinates_dataset import LatentGeneratorPinderIterableDataset
Expand All @@ -10,6 +11,7 @@
from ._multiplexed_sampling_dataset import MultiplexedSamplingDataset
from ._open_genome_2 import OpenGenome2IterableDataset
from ._peer_dataset import PEERDataset
from ._peer_unlabeled_dataset import PEERUnlabeledDataset
from ._round_robin_concat_iterable_dataset import RoundRobinConcatIterableDataset
from ._shuffled_iterable_dataset import ShuffledIterableDataset
from ._ume_streaming_dataset import UMEStreamingDataset
Expand All @@ -20,6 +22,7 @@
"AtomicaDataset",
"CalmIterableDataset",
"CalmPropertyDataset",
"CalmPropertyUnlabeledDataset",
"FASTADataset",
"M320MDataset",
"M320MIterableDataset",
Expand All @@ -31,6 +34,7 @@
"HuggingFaceIterableDataset",
"RoundRobinConcatIterableDataset",
"PEERDataset",
"PEERUnlabeledDataset",
"RoundRobinConcatIterableDataset",
"LatentGeneratorPinderIterableDataset",
"ZINCIterableDataset",
Expand Down
Loading
Loading