Commit cb54bbb

fix data changes
2 parents 5efa4e7 + f36deb0 commit cb54bbb

File tree

6 files changed: +200 −151 lines changed


biotransformers/lightning_utils/data.py

Lines changed: 161 additions & 53 deletions
@@ -1,11 +1,14 @@
 import functools
+import math
 import random
 from collections import OrderedDict
-from typing import Callable, List, Sequence, Tuple
+from typing import Callable, List, Optional, Sequence, Tuple

 import numpy as np
 import torch
+import torch.distributed as dist
 from esm.data import BatchConverter
+from pytorch_lightning import LightningDataModule
 from torch.utils.data import DataLoader, Dataset, Sampler
@@ -187,6 +190,7 @@ def get_batch_indices(
     sequence_strs: List[str],
     toks_per_batch: int,
     crop_sizes: Tuple[int, int] = (600, 1200),
+    seed: int = 0,
 ) -> List[List[List[Tuple[int, int]]]]:
     """
     This sampler aims to create batches that do not contain fixed number of sequences
@@ -208,31 +212,37 @@ def get_batch_indices(

     Args:
         sequence_strs: list of string
-        toks_per_batch: maximum number of token per batch
-        crop_sizes: min and max sequence lengths when cropping
+        toks_per_batch (int): Maximum number of token per batch
+        extra_toks_per_seq (int, optional): . Defaults to 0.
+        crop_sizes (Tuple[int, int]): min and max sequence lengths when cropping
+        seed (int): seed to be used for random generator

     Returns:
         List: List of batches indexes and lengths
     """
     min_size, max_size = crop_sizes
     buffer_type = List[Tuple[int, int]]

-    def crop_length(length: int) -> int:
-        crop_size = random.randint(min_size, max_size) - 2
+    rand_generator = random.Random(seed)
+
+    def crop_length(length: int, random_generator: random.Random) -> int:
+        crop_size = random_generator.randint(min_size, max_size) - 2
         if length > crop_size:
             return crop_size
         else:
             return length

-    sizes = [(crop_length(len(s)), i) for i, s in enumerate(sequence_strs)]
+    sizes = [
+        (crop_length(len(s), rand_generator), i) for i, s in enumerate(sequence_strs)
+    ]
     min_length, max_length = min([t[0] for t in sizes]), max([t[0] for t in sizes])

     # if there is a large gap between min and max size, sort the list
     if min_length < 0.8 * max_length:
         sizes.sort()
     # otherwise shuffle it
     else:
-        random.shuffle(sizes)
+        rand_generator.shuffle(sizes)

     batches: List[List[buffer_type]] = []
     buffer: buffer_type = []
@@ -257,7 +267,7 @@ def _flush_current_buf():

     _flush_current_buf()

-    random.shuffle(batches)
+    rand_generator.shuffle(batches)
     return batches
@@ -297,6 +307,98 @@ def __iter__(self):
         )


+class DistributedBatchWithConstantNumberTokensSampler(Sampler):
+    """
+    Sampler that returns batches of sequences indices in the dataset so that to ensure
+    not a fixed number of sequences per batch but rather a fixed number of tokens per
+    batch. This sampler also takes into account that we may want to crop dynamically
+    sequences when sampling and thus returns in addition to indices, desired cropping
+    lengths to inform the dataloader. This version of the sampler is distributed to
+    be used with DDP accelerator.
+    """
+
+    def __init__(
+        self,
+        sequence_strs: List[str],
+        toks_per_batch: int,
+        crop_sizes: Tuple[int, int] = (512, 1024),
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        seed: int = 0,
+    ):
+        Sampler.__init__(self, data_source=None)
+
+        # Replicate Torch Distributed Sampler logic
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        if rank >= num_replicas or rank < 0:
+            raise ValueError(
+                "Invalid rank {}, rank should be in the interval"
+                " [0, {}]".format(rank, num_replicas - 1)
+            )
+
+        self._num_replicas = num_replicas
+        self._rank = rank
+        self._epoch = 0
+        self._seed = seed
+
+        self._sequence_strs = sequence_strs
+        self._toks_per_batch = toks_per_batch
+        self._crop_sizes = crop_sizes
+        self._init_batches = get_batch_indices(
+            sequence_strs=sequence_strs,
+            toks_per_batch=toks_per_batch,
+            crop_sizes=crop_sizes,
+            seed=self._seed + self._epoch,
+        )
+        self._num_samples = math.ceil(len(self._init_batches) / self._num_replicas)
+        self._total_size = self._num_samples * self._num_replicas
+
+    def __len__(self) -> int:
+        return self._num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        self._epoch = epoch
+
+    def __iter__(self):
+
+        # generate batches with constant number of tokens
+        batches = get_batch_indices(
+            sequence_strs=self._sequence_strs,
+            toks_per_batch=self._toks_per_batch,
+            crop_sizes=self._crop_sizes,
+            seed=self._seed + self._epoch,
+        )
+
+        # shuffle the indices
+        rng = np.random.default_rng(seed=self._seed + self._epoch)
+        indices = list(rng.permutation(len(batches)))
+
+        # add extra samples to make it evenly divisible
+        padding_size = self._total_size - len(indices)
+        if padding_size <= len(indices):
+            indices += indices[:padding_size]
+        else:
+            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+        assert len(indices) == self._total_size
+
+        # subsample (to get batches for this worker)
+        indices = indices[self._rank : self._total_size : self._num_replicas]
+        assert len(indices) == self._num_samples
+
+        # get corresponding batches
+        batches = [batches[i] for i in indices]
+
+        # return iterator
+        yield from batches
+
+
 class BatchWithConstantNumberTokensDataset(Dataset):
     """
     Dataset class to work in pair with the BatchWithConstantNumberTokensSampler.
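A hedged sketch of how the new distributed sampler might be exercised outside a full DDP run, passing num_replicas and rank explicitly so that torch.distributed does not need to be initialised; the sequence list below is a placeholder, not project data:

    from biotransformers.lightning_utils.data import DistributedBatchWithConstantNumberTokensSampler

    sequences = ["MKTAYIAKQR" * 30, "GASTGQNA" * 80, "LLSPQRT" * 55]  # placeholder data
    sampler = DistributedBatchWithConstantNumberTokensSampler(
        sequence_strs=sequences,
        toks_per_batch=1024,
        crop_sizes=(512, 1024),
        num_replicas=2,  # explicit values skip the dist.get_world_size()/get_rank() lookup
        rank=0,
        seed=0,
    )
    sampler.set_epoch(1)  # shifts the seed offset, so crops and shuffling change per epoch
    for batch in sampler:
        # each batch is a nested list of (crop_length, index) pairs produced by get_batch_indices
        pass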
@@ -319,52 +421,58 @@ def __getitem__(self, sampler_out) -> List[str]:
         return sequences


-def create_dataloader(
-    sequences: List[str],
-    alphabet: AlphabetDataLoader,
-    masking_ratio: float,
-    masking_prob: float,
-    random_token_prob: float,
-    num_workers: int,
-    toks_per_batch: int,
-    crop_sizes: Tuple[int, int] = (512, 1024),
-) -> DataLoader:
-    """Create the PyTorch Dataloader.
+class BatchWithConstantNumberTokensDataModule(LightningDataModule):
+    def __init__(
+        self,
+        train_sequences: List[str],
+        validation_sequences: List[str],
+        alphabet: AlphabetDataLoader,
+        masking_ratio: float,
+        masking_prob: float,
+        random_token_prob: float,
+        num_workers: int,
+        toks_per_batch: int,
+        crop_sizes: Tuple[int, int] = (512, 1024),
+    ):
+        LightningDataModule.__init__(self)
+        self._train_sequences = train_sequences
+        self._validation_sequences = validation_sequences
+        self._alphabet = alphabet
+        self._masking_ratio = masking_ratio
+        self._masking_prob = masking_prob
+        self._random_token_prob = random_token_prob
+        self._num_workers = num_workers
+        self._toks_per_batch = toks_per_batch
+        self._crop_sizes = crop_sizes

-    Args:
-        filenames: list of sequences
-        alphabet: facebook alphabet.
-        filter_len: whether filter data wrt len.batch_seq
-        num_workers: num of parallel data samplers
-        masking_ratio: ratio of tokens to be masked.
-        masking_prob: probability that the chose token is replaced with a mask token.
-        random_token_prob: probability that the chose token is replaced with a random token.
-        toks_per_batch: number of tokens per batch
-        crop_sizes: range of values to crop dynamically sequences when sampling them
+    def _get_dataloader(self, sequences: List[str]) -> DataLoader:
+        dataset = BatchWithConstantNumberTokensDataset(sequences)
+        batch_sampler = DistributedBatchWithConstantNumberTokensSampler(
+            sequence_strs=sequences,
+            toks_per_batch=self._toks_per_batch,
+            crop_sizes=self._crop_sizes,
+        )

-    Returns:
-        torch DataLoader
-    """
+        loader = DataLoader(
+            dataset,
+            num_workers=self._num_workers,
+            collate_fn=functools.partial(
+                collate_fn,
+                tokenizer=self._alphabet.tokenizer(),
+                alphabet=self._alphabet,
+                masking_ratio=self._masking_ratio,
+                masking_prob=self._masking_prob,
+                random_token_prob=self._random_token_prob,
+            ),
+            pin_memory=True,
+            worker_init_fn=worker_init_fn,
+            batch_sampler=batch_sampler,
+            sampler=None,
+        )
+        return loader

-    dataset = BatchWithConstantNumberTokensDataset(sequences)
-    batch_sampler = BatchWithConstantNumberTokensSampler(
-        sequence_strs=sequences, toks_per_batch=toks_per_batch, crop_sizes=crop_sizes
-    )
+    def train_dataloader(self):
+        return self._get_dataloader(self._train_sequences)

-    loader = DataLoader(
-        dataset,
-        num_workers=num_workers,
-        collate_fn=functools.partial(
-            collate_fn,
-            tokenizer=alphabet.tokenizer(),
-            alphabet=alphabet,
-            masking_ratio=masking_ratio,
-            masking_prob=masking_prob,
-            random_token_prob=random_token_prob,
-        ),
-        pin_memory=True,
-        worker_init_fn=worker_init_fn,
-        batch_sampler=batch_sampler,
-        sampler=None,
-    )
-    return loader
+    def val_dataloader(self):
+        return self._get_dataloader(self._validation_sequences)
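A minimal sketch of how the new LightningDataModule might replace former create_dataloader call sites; train_seqs, val_seqs, alphabet, model, and the masking values are placeholders, and the Trainer flags are only one possible DDP configuration (argument names vary across pytorch_lightning versions):

    import pytorch_lightning as pl
    from biotransformers.lightning_utils.data import BatchWithConstantNumberTokensDataModule

    data_module = BatchWithConstantNumberTokensDataModule(
        train_sequences=train_seqs,        # placeholder list of training sequences
        validation_sequences=val_seqs,     # placeholder list of validation sequences
        alphabet=alphabet,                 # AlphabetDataLoader provided by the model wrapper
        masking_ratio=0.15,
        masking_prob=0.8,
        random_token_prob=0.1,
        num_workers=4,
        toks_per_batch=1024,
    )
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(model, datamodule=data_module)  # model is the LightningModule built elsewhere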

biotransformers/wrappers/esm_wrappers.py

Lines changed: 5 additions & 18 deletions
@@ -7,10 +7,7 @@

 import esm
 import torch
-from biotransformers.lightning_utils.data import (
-    AlphabetDataLoader,
-    convert_ckpt_to_statedict,
-)
+from biotransformers.lightning_utils.data import AlphabetDataLoader
 from biotransformers.utils.constant import DEFAULT_ESM_MODEL, ESM_LIST
 from biotransformers.utils.logger import logger # noqa
 from biotransformers.utils.utils import _generate_chunks, _get_num_batch_iter
from biotransformers.utils.utils import _generate_chunks, _get_num_batch_iter
@@ -49,6 +46,10 @@ def model(self) -> torch.nn.Module:
         """Return torch model."""
         return self._model

+    def set_model(self, model: torch.nn.Module):
+        """Set torch model."""
+        self._model = model.to(self._model.device)
+
     @property
     def clean_model_id(self) -> str:
         """Clean model ID (in case the model directory is not)"""
@@ -118,20 +119,6 @@ def process_sequences_and_tokens(
         }
         return encoded_inputs

-    def _load_model(self, path_model: str, map_location=None):
-        """Load model."""
-        if path_model.endswith(".pt"):
-            loaded_model = torch.load(path_model)
-        elif path_model.endswith(".ckpt"):
-            loaded_model = convert_ckpt_to_statedict(
-                torch.load(path_model)["state_dict"]
-            )
-        else:
-            raise ValueError("Expecting a .pt or .ckpt file")
-        self._model.load_state_dict(loaded_model, map_location)
-        self._model.eval()
-        log.info("Load model %s" % path_model)
-
     def model_pass(
         self,
         model_inputs: Dict[str, torch.Tensor],

biotransformers/wrappers/language_model.py

Lines changed: 2 additions & 2 deletions
@@ -99,8 +99,8 @@ def model(self) -> torch.nn.Module:
         pass

     @abstractmethod
-    def _load_model(self, path: str):
-        """Load model."""
+    def set_model(self, model: torch.nn.Module):
+        """Set torch model."""
         pass

     @abstractmethod

biotransformers/wrappers/rostlab_wrapper.py

Lines changed: 5 additions & 18 deletions
@@ -9,10 +9,7 @@

 import torch
 import copy
-from biotransformers.lightning_utils.data import (
-    AlphabetDataLoader,
-    convert_ckpt_to_statedict,
-)
+from biotransformers.lightning_utils.data import AlphabetDataLoader
 from biotransformers.utils.constant import DEFAULT_ROSTLAB_MODEL, ROSTLAB_LIST
 from biotransformers.utils.logger import logger # noqa
 from biotransformers.utils.utils import _generate_chunks, _get_num_batch_iter
@@ -52,6 +49,10 @@ def model(self) -> torch.nn.Module:
         """Return torch model."""
         return self._model

+    def set_model(self, model: torch.nn.Module):
+        """Set torch model."""
+        self._model = model.to(self._model.device)
+
     @property
     def clean_model_id(self) -> str:
         """Clean model ID (in case the model directory is not)"""
@@ -102,20 +103,6 @@ def embeddings_size(self) -> int:
         """Returns size of the embeddings"""
         return self.hidden_size

-    def _load_model(self, path_model: str, map_location=None):
-        """Load model."""
-        if path_model.endswith(".pt"):
-            loaded_model = torch.load(path_model)
-        elif path_model.endswith(".ckpt"):
-            loaded_model = convert_ckpt_to_statedict(
-                torch.load(path_model)["state_dict"]
-            )
-        else:
-            raise ValueError("Expecting a .pt or .ckpt file")
-        self._model.load_state_dict(loaded_model, map_location)
-        self._model.eval()
-        log.info("Load model %s" % path_model)
-
     def process_sequences_and_tokens(
         self,
         sequences_list: List[str],
