diff --git a/.gitignore b/.gitignore
index 00fae272..2e781df0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,9 +29,11 @@
 scripts/combined_db*
 *_play.py
 src/lobster/hydra_config/experiment/*
 src/lobster/mcp/claude_desktop_config.json
+*.ipynb_checkpoints
 notebooks/nathan/*
 notebooks/karina/*
+notebooks/amyxlu/*
 models/*
diff --git a/slurm/scripts/save_token_losses.sh b/slurm/scripts/save_token_losses.sh
new file mode 100644
index 00000000..09b71485
--- /dev/null
+++ b/slurm/scripts/save_token_losses.sh
@@ -0,0 +1,40 @@
+#!/usr/bin/env bash
+
+#SBATCH --job-name token_loss
+#SBATCH --nodes 1
+#SBATCH --gpus-per-node 1
+#SBATCH --partition gpu2
+#SBATCH --cpus-per-gpu 4
+#SBATCH --mem 150G
+#SBATCH --time=1-00:00:00
+
+source ~/.bashrc
+eval "$(mamba shell hook --shell bash)"
+
+echo "SLURM_JOB_NODELIST = ${SLURM_JOB_NODELIST}"
+echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
+echo "SLURMD_NODENAME = ${SLURMD_NODENAME}"
+echo "SLURM_JOB_NUM_NODES = ${SLURM_JOB_NUM_NODES}"
+
+# make sure that this is already set!
+cd $LOBSTER_PROJECT_DIR
+
+# use uv, which should already be set up
+source .venv/bin/activate
+
+echo "SLURM_JOB_NODELIST = ${SLURM_JOB_NODELIST}"
+echo "SLURM_JOB_ID = ${SLURM_JOB_ID}"
+echo "SLURMD_NODENAME = ${SLURMD_NODENAME}"
+echo "SLURM_JOB_NUM_NODES = ${SLURM_JOB_NUM_NODES}"
+
+nvidia-smi
+mamba activate plaid
+mamba env list
+echo $CONDA_PREFIX
+which python
+
+# see save_token_losses.py for the default parser arguments
+srun torchrun token_selection/scripts/save_token_losses.py \
+    --fasta_file /data/bucket/freyn6/data/uniref50.fasta \
+    --output_dir /data2/lux70/data/uniref50/per_token_losses \
+    --max_num_per_shard 10000
\ No newline at end of file
diff --git a/src/lobster/data/_fasta_datamodule.py b/src/lobster/data/_fasta_datamodule.py
index 53fbea19..081fb3e6 100644
--- a/src/lobster/data/_fasta_datamodule.py
+++ b/src/lobster/data/_fasta_datamodule.py
@@ -4,6 +4,7 @@
 from typing import Any, TypeVar
 
 import pandas as pd
+import numpy as np
 import torch.utils.data
 
 # from beignet.datasets import FASTADataset
@@ -43,6 +44,7 @@ def __init__(
         is_relative_model: bool = False,
         tokenizer_dir: str | None = "pmlm_tokenizer",
         mlm: bool = True,
+        offsets_arr: np.ndarray | None = None,
     ) -> None:
         """
         :param path_to_fasta: path to fasta file
@@ -139,6 +141,7 @@ def __init__(
         self._is_relative_model = is_relative_model
         self._tokenizer_dir = tokenizer_dir
         self._mlm = mlm
+        self._offsets_arr = offsets_arr
 
         path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
         self._transform_fn = transform_fn or PmlmTokenizerTransform(
@@ -159,16 +162,31 @@ def setup(self, stage: str = "fit") -> None:  # noqa: ARG002
         if stage == "fit":
             if any(["train" in self._path_to_fasta]):  # pre computed splits
                 self._train_dataset = torch.utils.data.ConcatDataset(
-                    [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "train" in p]
+                    [
+                        FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
+                        for p in self._path_to_fasta
+                        if "train" in p
+                    ]
                 )
                 self._val_dataset = torch.utils.data.ConcatDataset(
-                    [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "val" in p]
+                    [
+                        FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
+                        for p in self._path_to_fasta
+                        if "val" in p
+                    ]
                 )
                 self._test_dataset = torch.utils.data.ConcatDataset(
-                    [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta if "test" in p]
+                    [
+                        FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
+                        for p in self._path_to_fasta
+                        if "test" in p
+                    ]
                 )
             else:  # iid split
-                datasets = [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta]
+                datasets = [
+                    FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
+                    for p in self._path_to_fasta
+                ]
                 dataset = torch.utils.data.ConcatDataset(datasets)
                 (
                     self._train_dataset,
@@ -181,7 +199,10 @@ def setup(self, stage: str = "fit") -> None:  # noqa: ARG002
                 )
 
         if stage == "predict":
-            datasets = [FASTADataset(root=p, transform=self._transform_fn) for p in self._path_to_fasta]
+            datasets = [
+                FASTADataset(root=p, transform=self._transform_fn, offsets_arr=self._offsets_arr)
+                for p in self._path_to_fasta
+            ]
             dataset = torch.utils.data.ConcatDataset(datasets)
             self._predict_dataset = dataset
diff --git a/src/lobster/datasets/__init__.py b/src/lobster/datasets/__init__.py
index 8edb834a..e0e5115f 100644
--- a/src/lobster/datasets/__init__.py
+++ b/src/lobster/datasets/__init__.py
@@ -14,8 +14,10 @@
 from ._shuffled_iterable_dataset import ShuffledIterableDataset
 from ._ume_streaming_dataset import UMEStreamingDataset
 from ._zinc_dataset import ZINCIterableDataset
+from ._sharded_parquet_dataset import ShardedParquetDataset
 from ._ptm_dataset import PTMDataset
 
+
 __all__ = [
     "CalmDataset",
     "AtomicaDataset",
@@ -37,5 +39,6 @@
     "ZINCIterableDataset",
     "OpenGenome2IterableDataset",
     "UMEStreamingDataset",
+    "ShardedParquetDataset",
     "PTMDataset",
 ]
diff --git a/src/lobster/datasets/_fasta_dataset.py b/src/lobster/datasets/_fasta_dataset.py
index e78ad354..0e36ef53 100644
--- a/src/lobster/datasets/_fasta_dataset.py
+++ b/src/lobster/datasets/_fasta_dataset.py
@@ -17,6 +17,7 @@ def __init__(
         *,
         transform: Callable | None = None,
         use_text_descriptions: bool = True,
+        offsets_arr: numpy.ndarray | None = None,
     ) -> None:
         if isinstance(root, str):
             root = Path(root)
@@ -32,14 +33,17 @@ def __init__(
 
         self.data = ThreadSafeFile(self.root, open)
 
-        offsets = Path(f"{self.root}.offsets.npy")
+        if offsets_arr is None:
+            offsets_path = Path(f"{self.root}.offsets.npy")
+            if offsets_path.exists():
+                self.offsets, sizes = numpy.load(f"{offsets_path}")
+            else:
+                self.offsets, sizes = self._build_index()
+                numpy.save(f"{offsets_path}", numpy.stack([self.offsets, sizes]))
 
-        if offsets.exists():
-            self.offsets, sizes = numpy.load(f"{offsets}")
         else:
-            self.offsets, sizes = self._build_index()
-
-            numpy.save(f"{offsets}", numpy.stack([self.offsets, sizes]))
+            self.offsets = offsets_arr[0, :]
+            sizes = offsets_arr[1, :]
 
         self.transform = transform
diff --git a/src/lobster/datasets/_sharded_parquet_dataset.py b/src/lobster/datasets/_sharded_parquet_dataset.py
new file mode 100644
index 00000000..4556a72e
--- /dev/null
+++ b/src/lobster/datasets/_sharded_parquet_dataset.py
@@ -0,0 +1,245 @@
+import os
+import glob
+import numpy as np
+import pandas as pd
+import torch
+import torch.distributed as dist
+from torch.utils.data import Dataset, DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+
+class ShardedParquetDataset(Dataset):
+    def __init__(
+        self, parquet_dir, percentile_threshold=90, loss_threshold=None, stats_file=None, rank=None, world_size=None
+    ):
+        """
+        Distributed dataset for sharded parquet files.
+
+        Args:
+            parquet_dir: Directory containing parquet shards
+            percentile_threshold: Only include tokens with loss below this percentile
+            loss_threshold: Optional explicit loss threshold (if pre-computed)
+            stats_file: Path to pre-computed statistics file
+            rank: Process rank in distributed training
+            world_size: Total number of processes
+        """
+        self.parquet_dir = parquet_dir
+        self.percentile_threshold = percentile_threshold
+
+        # Get list of all shard files
+        self.shard_files = sorted(glob.glob(f"{parquet_dir}/partition_id=*/part-*.parquet"))
+
+        # If running distributed, only use shards for this rank
+        if rank is not None and world_size is not None:
+            # Distribute shards across workers
+            self.shard_files = [f for i, f in enumerate(self.shard_files) if i % world_size == rank]
+
+        # Set loss threshold either from argument or by loading stats
+        if loss_threshold is not None:
+            self.loss_threshold = loss_threshold
+        elif stats_file and os.path.exists(stats_file):
+            # Load pre-computed statistics
+            import json
+
+            with open(stats_file) as f:
+                stats = json.load(f)
+            self.loss_threshold = stats["percentiles"][str(percentile_threshold)]
+        else:
+            # Calculate threshold (ideally, this is pre-computed)
+            self.loss_threshold = self._calculate_percentile()
+
+        # Load sequence metadata from all assigned shards
+        self.sequence_data = self._load_sequence_metadata()
+
+    def _calculate_percentile(self):
+        """Calculate percentile threshold from samples."""
+        # Only calculate on rank 0 and broadcast if distributed
+        if dist.is_initialized() and dist.get_rank() != 0:
+            # Non-root processes wait for result
+            threshold = torch.zeros(1, dtype=torch.float32).cuda()
+            dist.broadcast(threshold, 0)
+            return threshold.item()
+
+        # Root process (or non-distributed) calculates
+        print(f"Calculating {self.percentile_threshold}th percentile threshold...")
+        samples = []
+
+        # Sample from each shard
+        for shard in self.shard_files[:10]:  # Limit to 10 shards for efficiency
+            df = pd.read_parquet(shard, columns=["loss"])
+            # Take a sample proportional to size
+            sample_size = min(10000, len(df))
+            if sample_size > 0:
+                samples.append(df.sample(sample_size)["loss"].values)
+
+        # Calculate threshold from samples
+        if samples:
+            all_samples = np.concatenate(samples)
+            threshold = float(np.percentile(all_samples, self.percentile_threshold))
+        else:
+            threshold = float("inf")  # No samples available
+
+        # Broadcast result if distributed
+        if dist.is_initialized():
+            threshold_tensor = torch.tensor([threshold], dtype=torch.float32).cuda()
+            dist.broadcast(threshold_tensor, 0)
+            threshold = threshold_tensor.item()
+
+        print(f"Using loss threshold: {threshold}")
+        return threshold
+
+    def _load_sequence_metadata(self):
+        """Load sequence metadata from assigned shards."""
+        sequences = []
+
+        for shard_file in self.shard_files:
+            # Read just sequence metadata for efficiency
+            try:
+                # Group by sequence_id and get sizes
+                df = pd.read_parquet(shard_file, columns=["sequence_id", "position"])
+                seq_info = df.groupby("sequence_id").agg({"position": "max"})
+
+                for seq_id, max_pos in seq_info.itertuples():
+                    sequences.append(
+                        {
+                            "sequence_id": seq_id,
+                            "length": max_pos + 1,  # Convert to length
+                            "shard_file": shard_file,
+                        }
+                    )
+            except Exception as e:
+                print(f"Error loading metadata from {shard_file}: {e}")
+
+        return sequences
+
+    def __len__(self):
+        return len(self.sequence_data)
+
+    def __getitem__(self, idx):
+        """Get a filtered sequence by index."""
+        seq_info = self.sequence_data[idx]
+        seq_id = seq_info["sequence_id"]
+        shard_file = seq_info["shard_file"]
+
+        # Read this sequence with filtering
+        try:
+            # Use PyArrow filter pushdown for efficiency
+            df = pd.read_parquet(
+                shard_file, filters=[("sequence_id", "=", seq_id), ("loss", "<=", self.loss_threshold)]
+            )
+
+            # Sort by position to maintain sequence order
+            if not df.empty:
+                df = df.sort_values("position")
+
+                return {
+                    "sequence_id": seq_id,
+                    "tokens": df["token"].values,
+                    "positions": df["position"].values,
+                    "losses": df["loss"].values,
+                }
+            else:
+                # No tokens passed the filter
+                return {
+                    "sequence_id": seq_id,
+                    "tokens": np.array([], dtype=np.int64),
+                    "positions": np.array([], dtype=np.int64),
+                    "losses": np.array([], dtype=np.float32),
+                }
+
+        except Exception as e:
+            print(f"Error loading sequence {seq_id}: {e}")
+            # Return empty sequence on error
+            return {
+                "sequence_id": seq_id,
+                "tokens": np.array([], dtype=np.int64),
+                "positions": np.array([], dtype=np.int64),
+                "losses": np.array([], dtype=np.float32),
+            }
+
+
+def collate_variable_length_sequences(batch):
+    """Custom collate function for variable-length sequences."""
+    # Filter out empty sequences
+    non_empty = [b for b in batch if len(b["tokens"]) > 0]
+
+    if not non_empty:
+        # All sequences were empty after filtering
+        return {
+            "sequence_ids": [],
+            "tokens": torch.zeros(0, dtype=torch.int64),
+            "positions": torch.zeros(0, dtype=torch.int64),
+            "losses": torch.zeros(0, dtype=torch.float32),
+            "batch_indices": torch.zeros(0, dtype=torch.int64),
+        }
+
+    # Gather data
+    sequence_ids = [b["sequence_id"] for b in non_empty]
+    tokens_list = [torch.tensor(b["tokens"], dtype=torch.int64) for b in non_empty]
+    positions_list = [torch.tensor(b["positions"], dtype=torch.int64) for b in non_empty]
+    losses_list = [torch.tensor(b["losses"], dtype=torch.float32) for b in non_empty]
+
+    # Create batch indices for reconstructing sequences later
+    batch_sizes = [len(t) for t in tokens_list]
+    batch_indices = torch.cat([torch.full((size,), i, dtype=torch.int64) for i, size in enumerate(batch_sizes)])
+
+    # Concatenate all tokens
+    tokens = torch.cat(tokens_list)
+    positions = torch.cat(positions_list)
+    losses = torch.cat(losses_list)
+
+    return {
+        "sequence_ids": sequence_ids,
+        "tokens": tokens,
+        "positions": positions,
+        "losses": losses,
+        "batch_indices": batch_indices,
+    }
+
+
+def setup_distributed():
+    """Initialize distributed training environment."""
+    # Initialize process group
+    dist.init_process_group(
+        backend="nccl",  # Use 'gloo' for CPU-only
+        init_method="env://",
+    )
+
+    # Get global rank and world size
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    # Set device for this process
+    torch.cuda.set_device(rank % torch.cuda.device_count())
+
+    return rank, world_size
+
+
+def create_distributed_dataloader(parquet_dir, percentile_threshold=90, batch_size=32, num_workers=4):
+    """Create a distributed dataloader for sharded parquet files."""
+    # Setup distributed environment
+    rank, world_size = setup_distributed()
+
+    # Create the dataset with the full shard list. The DistributedSampler below
+    # already partitions indices across ranks, so passing rank/world_size here as
+    # well would partition the data twice and silently drop sequences.
+    dataset = ShardedParquetDataset(
+        parquet_dir=parquet_dir,
+        percentile_threshold=percentile_threshold,
+        stats_file=f"{parquet_dir}/stats.json",
+    )
+
+    # Create distributed sampler to handle partitioning
+    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True, drop_last=False)
+
+    # Create dataloader with custom collate function
+    dataloader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        sampler=sampler,
+        num_workers=num_workers,
+        collate_fn=collate_variable_length_sequences,
+        pin_memory=True,
+    )
+
+    return dataloader, rank, world_size
diff --git a/src/lobster/model/_clm.py b/src/lobster/model/_clm.py
index 267b3613..fd6a342f 100644
--- a/src/lobster/model/_clm.py
+++ b/src/lobster/model/_clm.py
@@ -4,7 +4,7 @@
 import lightning.pytorch as pl
 import torch
 from torch.nn import CrossEntropyLoss
-from transformers import LlamaConfig, LlamaForCausalLM, get_scheduler, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM, get_scheduler, pipeline
 
 from lobster.constants import SchedulerType
 from lobster.tokenization import PmlmTokenizer, PmlmTokenizerTransform
@@ -13,6 +13,9 @@
 from ._clm_configuration import PCLM_CONFIG_ARGS
 
 
+ALLOWABLE_MODEL_NAMES = list(PCLM_CONFIG_ARGS.keys()) + ["ProtGPT2"]
+
+
 class LobsterPCLM(pl.LightningModule):
     def __init__(
         self,
@@ -68,36 +71,45 @@ def __init__(
         self.scheduler_kwargs = scheduler_kwargs or {}
         model_kwargs = model_kwargs or {}
 
-        if self._tokenizer_dir is not None:
-            path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
-            self.tokenizer = PmlmTokenizer.from_pretrained(path, do_lower_case=False)
-            self._transform_fn = transform_fn or PmlmTokenizerTransform(
-                path,
-                padding="max_length",
-                truncation=True,
-                max_length=self._max_length,
-                mlm=False,
+        assert model_name in ALLOWABLE_MODEL_NAMES, f"model_name must be one of {ALLOWABLE_MODEL_NAMES}"
+
+        if model_name == "ProtGPT2":
+            self.tokenizer = AutoTokenizer.from_pretrained("nferruz/ProtGPT2")
+            self.model = AutoModelForCausalLM.from_pretrained("nferruz/ProtGPT2")
+            self.config = self.model.config
+
+        else:
+            # Create PCLM model
+            if self._tokenizer_dir is not None:
+                path = importlib.resources.files("lobster") / "assets" / self._tokenizer_dir
+                self.tokenizer = PmlmTokenizer.from_pretrained(path, do_lower_case=False)
+                self._transform_fn = transform_fn or PmlmTokenizerTransform(
+                    path,
+                    padding="max_length",
+                    truncation=True,
+                    max_length=self._max_length,
+                    mlm=False,
+                )
+
+            config_args = PCLM_CONFIG_ARGS[model_name]
+            if num_key_value_heads is None:
+                num_key_value_heads = config_args["num_attention_heads"]
+            self._num_key_value_heads = num_key_value_heads
+
+            config = LlamaConfig(
+                **config_args,
+                mask_token_id=self.tokenizer.mask_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+                cls_token_id=self.tokenizer.cls_token_id,
+                eos_token_id=self.tokenizer.eos_token_id,
+                vocab_size=len(self.tokenizer.get_vocab()),
+                max_position_embeddings=self._max_length,
+                num_key_value_heads=self._num_key_value_heads,
+                attention_bias=self._attention_bias,
+                **model_kwargs,
             )
-
-        config_args = PCLM_CONFIG_ARGS[model_name]
-        if num_key_value_heads is None:
-            num_key_value_heads = config_args["num_attention_heads"]
-        self._num_key_value_heads = num_key_value_heads
-
-        config = LlamaConfig(
-            **config_args,
-            mask_token_id=self.tokenizer.mask_token_id,
-            pad_token_id=self.tokenizer.pad_token_id,
-            cls_token_id=self.tokenizer.cls_token_id,
-            eos_token_id=self.tokenizer.eos_token_id,
-            vocab_size=len(self.tokenizer.get_vocab()),
-            max_position_embeddings=self._max_length,
-            num_key_value_heads=self._num_key_value_heads,
-            attention_bias=self._attention_bias,
-            **model_kwargs,
-        )
-        self.model = LlamaForCausalLM(config)
-        self.config = self.model.config
+            self.model = LlamaForCausalLM(config)
+            self.config = self.model.config
 
         self.save_hyperparameters(logger=False)
diff --git a/token_selection/README.md b/token_selection/README.md
new file mode 100644
index 00000000..b3e0b682
--- /dev/null
+++ b/token_selection/README.md
@@ -0,0 +1,20 @@
+# Selective Token Modeling
+
+This directory contains experiments for calculating per-token losses under an existing pretrained model for the purpose of selective token modeling (called Selective Language Modeling, SLM, in the Rho-1 [paper](https://arxiv.org/abs/2404.07965) by Lin et al.).
+The core idea is that not all tokens are equally difficult for a model to learn; in English, for example, a token such as `the` is trivially easy. Faster convergence, better performance, and/or a smaller model can be achieved by selectively training on useful tokens that align with the desired distribution, and we can use previously trained models to determine this notion of "in-distribution".
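+
+The selection step itself is simple. As a rough sketch (the tensors below are toy stand-ins; the actual pipeline stores the reference losses in Parquet shards, as described below), only tokens whose reference-model loss falls below a chosen percentile contribute to the training objective:
+
+```python
+import torch
+
+# toy stand-ins: per-token losses from a frozen reference model and from the model being trained
+ref_loss = torch.rand(8, 512)
+train_loss = torch.rand(8, 512)
+
+threshold = torch.quantile(ref_loss, 0.90)  # keep tokens below the 90th percentile of reference loss
+mask = ref_loss <= threshold
+selective_loss = (train_loss * mask).sum() / mask.sum().clamp(min=1)
+```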
+
+From the project root directory, running
+```
+export LOBSTER_PROJECT_DIR=$(pwd)
+sbatch slurm/scripts/save_token_losses.sh
+```
+
+will launch a multi-GPU inference job that saves the per-token losses of every sequence in a FASTA file under a specified model (the autoregressive [RITA-Large](https://arxiv.org/abs/2205.05789) model is used by default) to Parquet format.
+
+Model training with selective token percentages can be done using the dataloader in `datasets/_sharded_parquet_dataset.py`, as sketched below.
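+
+A minimal single-process sketch (the paths below are placeholders; `create_distributed_dataloader` in the same module covers the multi-GPU case):
+
+```python
+from torch.utils.data import DataLoader
+
+from lobster.datasets import ShardedParquetDataset
+from lobster.datasets._sharded_parquet_dataset import collate_variable_length_sequences
+
+# keep only tokens whose reference-model loss falls below the 90th percentile
+dataset = ShardedParquetDataset(
+    parquet_dir="per_token_losses/",  # directory containing the parquet shards
+    percentile_threshold=90,
+    stats_file="per_token_losses/stats.json",  # optional pre-computed percentiles
+)
+loader = DataLoader(
+    dataset,
+    batch_size=32,
+    shuffle=True,
+    collate_fn=collate_variable_length_sequences,
+)
+```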
+
+## Extensions
+- [ ] Perform the same experiment for other modalities for data mixture determination
+- [ ] Perform the same experiment on downstream tasks to determine which tasks are more difficult for the model
+- [ ] Perform ablation experiments by incorporating data at different loss percentages.
+- [ ] Repeat on masked language models to see if the pattern is different. Note: this will require O(L) forward passes.
diff --git a/token_selection/scripts/save_token_losses.py b/token_selection/scripts/save_token_losses.py
new file mode 100644
index 00000000..8d1ad32f
--- /dev/null
+++ b/token_selection/scripts/save_token_losses.py
@@ -0,0 +1,219 @@
+from typing import Dict, List, Any
+import os
+import argparse
+from pathlib import Path
+
+
+import numpy as np
+import pandas as pd
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel
+from torch.nn.functional import cross_entropy
+
+from lobster.datasets import FASTADataset
+
+
+torch.set_float32_matmul_precision("high")
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--fasta_file",
+        type=str,
+        help="Path to the FASTA file containing sequences.",
+    )
+    parser.add_argument(
+        "--offset_array_path",
+        type=str,
+        help="Path to the numpy array containing offsets for the FASTA file.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        help="Directory to save the output files.",
+    )
+    parser.add_argument(
+        "--model_name",
+        type=str,
+        default="lightonai/RITA_l",
+        help="Name of the autoregressive model to use for token loss computation.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=512,
+        help="Batch size for processing sequences. Adjust based on GPU memory.",
+    )
+    parser.add_argument(
+        "--max_length",
+        type=int,
+        default=512,
+        help="Maximum sequence length for the model. Sequences longer than this will be truncated.",
+    )
+    parser.add_argument(
+        "--max_num_per_shard",
+        type=int,
+        default=100_000,
+        help="Maximum number of sequences to save in each output shard.",
+    )
+    parser.add_argument(
+        "--cur_num_in_shard",
+        type=int,
+        default=0,
+        help="Current number of sequences processed in the current shard. Used for resuming processing.",
+    )
+    parser.add_argument(
+        "--cur_shard_num",
+        type=int,
+        default=0,
+        help="Current shard number. Used for resuming processing.",
+    )
+    return parser.parse_args()
+
+
+def setup(rank, world_size):
+    # Initialize process group
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def load_model(model_name: str = "lightonai/RITA_xl", max_length: int = 512) -> torch.nn.Module:
+    """Load the model and tokenizer."""
+    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    tokenizer.pad_token = ""
+    tokenizer.pad_token_id = tokenizer.vocab[""]
+    tokenizer.max_length = max_length
+    model.eval()
+    return model, tokenizer
+
+
+def get_model_device(model: torch.nn.Module) -> torch.device:
+    return next(model.parameters()).device
+
+
+def compute_loss(batch, model, tokenizer, max_length, device=None) -> List[Dict[str, Any]]:
+    sequences, headers = batch
+    if device is None:
+        device = get_model_device(model)
+
+    sequences = [s[:max_length] for s in sequences]
+    inputs = tokenizer(sequences, return_tensors="pt", padding=True, truncation=False)
+    input_ids = inputs["input_ids"].to(device)
+    attn_mask = inputs["attention_mask"].to(device)
+    N, L = input_ids.shape[0], input_ids.shape[1] - 1  # remove EOS token
+
+    with torch.no_grad():
+        output = model(input_ids=input_ids, attention_mask=attn_mask)
+
+    targets = input_ids[:, 1:].reshape(-1)
+    logits = output["logits"]
+    logits = logits[:, :-1, :].reshape(-1, logits.shape[-1])
+    per_token_loss = cross_entropy(logits, targets, reduction="none")
+    per_token_loss = per_token_loss.reshape(-1, L).half()  # store as float16
+
+    processed = [
+        {
+            "sequence": sequences[i],
+            "header": headers[i],
+            "per_token_loss": per_token_loss[i, : min(len(sequences[i]), max_length)].cpu().tolist(),
+        }
+        for i in range(len(sequences))
+    ]
+    return processed
+
+
+def main(rank, args, world_size):
+    if world_size > 1:
+        setup(rank, world_size)
+
+    output_dir = Path(args.output_dir) / args.model_name.replace("/", "_")
+    if not output_dir.exists():
+        output_dir.mkdir(parents=True)
+
+    # the fasta loader relies on offsets to do file.seek operations;
+    # we can parallelize by splitting the offset array into subsections for each GPU
+    offset_array = np.load(args.offset_array_path)
+    print("Original offset array shape:", offset_array.shape)
+    assert len(offset_array.shape) == 2
+    assert offset_array.shape[0] == 2
+
+    # Partition data for this GPU
+    per_gpu_size = offset_array.shape[1] // world_size
+    start_idx = rank * per_gpu_size
+    end_idx = start_idx + per_gpu_size if rank < world_size - 1 else offset_array.shape[1]
+
+    local_offsets = offset_array[:, start_idx:end_idx]
+    print(f"Rank {rank} processing offsets from {start_idx} to {end_idx}")
+    print(f"Rank {rank} processing {local_offsets.shape[1]} sequences from {args.fasta_file}")
+
+    # Create dataset and dataloader for this GPU. The offset array was already
+    # partitioned per rank above, so each rank's dataset is disjoint and a plain
+    # sequential loader is used (a DistributedSampler here would re-partition the
+    # already-disjoint data and skip sequences).
+    dataset = FASTADataset(root=args.fasta_file, offsets_arr=local_offsets, use_text_descriptions=True)
+
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
+
+    # Create model
+    model, tokenizer = load_model(args.model_name, args.max_length)
+    device = torch.device("cuda", rank)
+    model.to(device)
+
+    # wrap in DDP and compile
+    if world_size > 1:
+        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        ddp_model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
+    else:
+        ddp_model = model
+
+    ddp_model = torch.compile(ddp_model)
+
+    # Inference loop
+    results_tmp_list = []
+    cur_shard_num = 0
+    cur_num_in_shard = 0
+
+    for batch in dataloader:
+        with torch.no_grad():
+            outputs = compute_loss(batch, ddp_model, tokenizer, args.max_length, device)
+        results_tmp_list.extend(outputs)
+        cur_num_in_shard += len(outputs)
+
+        if cur_num_in_shard >= args.max_num_per_shard:
+            output_file = output_dir / f"rank_{rank:02}_shard_{cur_shard_num:06}.parquet"
+            print(f"Saving shard {cur_shard_num} to {output_file}...")
+            pd.DataFrame(results_tmp_list).to_parquet(output_file, engine="pyarrow", index=False)
+
+            cur_shard_num += 1
+            cur_num_in_shard = 0
+            results_tmp_list = []
+
+        else:
+            print(f"Rank {rank} processed {cur_num_in_shard} sequences in shard {cur_shard_num}")
+
+    # Flush any remaining sequences that did not fill a complete shard
+    if results_tmp_list:
+        output_file = output_dir / f"rank_{rank:02}_shard_{cur_shard_num:06}.parquet"
+        print(f"Saving final shard {cur_shard_num} to {output_file}...")
+        pd.DataFrame(results_tmp_list).to_parquet(output_file, engine="pyarrow", index=False)
+
+    if world_size > 1:
+        cleanup()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    world_size = torch.cuda.device_count()
+
+    if world_size == 1:
+        print("Only one GPU available. Running without DDP.")
+        main(0, args, world_size)
+        exit()
+
+    else:
+        print(f"Using {world_size} GPUs for DDP.")
+        mp.spawn(main, (args, world_size), world_size, join=True)