import functools
+ import math
import random
from collections import OrderedDict
- from typing import Callable, List, Sequence, Tuple
+ from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
import torch
+ import torch.distributed as dist
from esm.data import BatchConverter
+ from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, Sampler


@@ -187,6 +190,7 @@ def get_batch_indices(
    sequence_strs: List[str],
    toks_per_batch: int,
    crop_sizes: Tuple[int, int] = (600, 1200),
+     seed: int = 0,
) -> List[List[List[Tuple[int, int]]]]:
    """
    This sampler aims to create batches that do not contain fixed number of sequences
@@ -208,31 +212,37 @@ def get_batch_indices(

    Args:
        sequence_strs: list of string
-         toks_per_batch: maximum number of token per batch
-         crop_sizes: min and max sequence lengths when cropping
+         toks_per_batch (int): Maximum number of tokens per batch
+         crop_sizes (Tuple[int, int]): min and max sequence lengths when cropping
+         seed (int): seed to be used for the random generator

    Returns:
        List: List of batches indexes and lengths
    """
    min_size, max_size = crop_sizes
    buffer_type = List[Tuple[int, int]]

-     def crop_length(length: int) -> int:
-         crop_size = random.randint(min_size, max_size) - 2
+     rand_generator = random.Random(seed)
+
+     def crop_length(length: int, random_generator: random.Random) -> int:
+         crop_size = random_generator.randint(min_size, max_size) - 2
        if length > crop_size:
            return crop_size
        else:
            return length

-     sizes = [(crop_length(len(s)), i) for i, s in enumerate(sequence_strs)]
+     sizes = [
+         (crop_length(len(s), rand_generator), i) for i, s in enumerate(sequence_strs)
+     ]
    min_length, max_length = min([t[0] for t in sizes]), max([t[0] for t in sizes])

    # if there is a large gap between min and max size, sort the list
    if min_length < 0.8 * max_length:
        sizes.sort()
    # otherwise shuffle it
    else:
-         random.shuffle(sizes)
+         rand_generator.shuffle(sizes)

    batches: List[List[buffer_type]] = []
    buffer: buffer_type = []
@@ -257,7 +267,7 @@ def _flush_current_buf():

    _flush_current_buf()

-     random.shuffle(batches)
+     rand_generator.shuffle(batches)
    return batches


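
To illustrate the effect of the new seed argument, here is a minimal sketch (the toy sequences and parameter values are placeholders, not part of this change). Since the seeded random.Random instance drives both the cropping and the shuffling shown above, repeated calls with the same seed should return identical batches of (crop length, sequence index) pairs.

toy_sequences = ["MKTAYIAKQR", "MKVLAAGITT", "MNDLLQ", "MKKLVLSLSL", "MA"]

batches_a = get_batch_indices(
    sequence_strs=toy_sequences, toks_per_batch=32, crop_sizes=(4, 8), seed=0
)
batches_b = get_batch_indices(
    sequence_strs=toy_sequences, toks_per_batch=32, crop_sizes=(4, 8), seed=0
)
# Same seed, same cropping and shuffling decisions, hence the same batches.
assert batches_a == batches_b
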
@@ -297,6 +307,98 @@ def __iter__(self):
        )


+ class DistributedBatchWithConstantNumberTokensSampler(Sampler):
+     """
+     Sampler that returns batches of sequence indices from the dataset, ensuring
+     not a fixed number of sequences per batch but rather a fixed number of tokens
+     per batch. This sampler also takes into account that sequences may be cropped
+     dynamically when sampling, and thus returns, in addition to the indices, the
+     desired crop lengths to inform the dataloader. This version of the sampler is
+     distributed so that it can be used with the DDP accelerator.
+     """
+
+     def __init__(
+         self,
+         sequence_strs: List[str],
+         toks_per_batch: int,
+         crop_sizes: Tuple[int, int] = (512, 1024),
+         num_replicas: Optional[int] = None,
+         rank: Optional[int] = None,
+         seed: int = 0,
+     ):
+         Sampler.__init__(self, data_source=None)
+
+         # Replicate Torch Distributed Sampler logic
+         if num_replicas is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             num_replicas = dist.get_world_size()
+         if rank is None:
+             if not dist.is_available():
+                 raise RuntimeError("Requires distributed package to be available")
+             rank = dist.get_rank()
+         if rank >= num_replicas or rank < 0:
+             raise ValueError(
+                 "Invalid rank {}, rank should be in the interval"
+                 " [0, {}]".format(rank, num_replicas - 1)
+             )
+
+         self._num_replicas = num_replicas
+         self._rank = rank
+         self._epoch = 0
+         self._seed = seed
+
+         self._sequence_strs = sequence_strs
+         self._toks_per_batch = toks_per_batch
+         self._crop_sizes = crop_sizes
+         self._init_batches = get_batch_indices(
+             sequence_strs=sequence_strs,
+             toks_per_batch=toks_per_batch,
+             crop_sizes=crop_sizes,
+             seed=self._seed + self._epoch,
+         )
+         self._num_samples = math.ceil(len(self._init_batches) / self._num_replicas)
+         self._total_size = self._num_samples * self._num_replicas
+
+     def __len__(self) -> int:
+         return self._num_samples
+
+     def set_epoch(self, epoch: int) -> None:
+         self._epoch = epoch
+
+     def __iter__(self):
+
+         # generate batches with a constant number of tokens
+         batches = get_batch_indices(
+             sequence_strs=self._sequence_strs,
+             toks_per_batch=self._toks_per_batch,
+             crop_sizes=self._crop_sizes,
+             seed=self._seed + self._epoch,
+         )
+
+         # shuffle the indices
+         rng = np.random.default_rng(seed=self._seed + self._epoch)
+         indices = list(rng.permutation(len(batches)))
+
+         # add extra samples to make the batch count evenly divisible
+         padding_size = self._total_size - len(indices)
+         if padding_size <= len(indices):
+             indices += indices[:padding_size]
+         else:
+             indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+         assert len(indices) == self._total_size
+
+         # subsample (to get batches for this worker)
+         indices = indices[self._rank : self._total_size : self._num_replicas]
+         assert len(indices) == self._num_samples
+
+         # get corresponding batches
+         batches = [batches[i] for i in indices]
+
+         # return iterator
+         yield from batches
+
+
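
As a rough illustration of how this sampler splits batches across processes, here is a sketch under stated assumptions: the toy sequences and parameter values are placeholders, and num_replicas/rank are passed explicitly so that torch.distributed does not need to be initialised.

toy_sequences = ["MKTAYIAKQR" * 5, "MKVLAAGITT" * 8, "MNDLLQ" * 10, "MKKLVLSLSL" * 6]

per_rank_batches = []
for rank in range(2):
    sampler = DistributedBatchWithConstantNumberTokensSampler(
        sequence_strs=toy_sequences,
        toks_per_batch=64,
        crop_sizes=(8, 16),
        num_replicas=2,
        rank=rank,
        seed=0,
    )
    # Each rank takes every other entry of the same shuffled (and padded) list of
    # batches, so all ranks see the same number of batches per epoch.
    per_rank_batches.append(list(sampler))

assert len(per_rank_batches[0]) == len(per_rank_batches[1]) == len(sampler)
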
class BatchWithConstantNumberTokensDataset(Dataset):
    """
    Dataset class to work in pair with the BatchWithConstantNumberTokensSampler.
@@ -319,52 +421,58 @@ def __getitem__(self, sampler_out) -> List[str]:
        return sequences


- def create_dataloader(
-     sequences: List[str],
-     alphabet: AlphabetDataLoader,
-     masking_ratio: float,
-     masking_prob: float,
-     random_token_prob: float,
-     num_workers: int,
-     toks_per_batch: int,
-     crop_sizes: Tuple[int, int] = (512, 1024),
- ) -> DataLoader:
-     """Create the PyTorch Dataloader.
+ class BatchWithConstantNumberTokensDataModule(LightningDataModule):
+     def __init__(
+         self,
+         train_sequences: List[str],
+         validation_sequences: List[str],
+         alphabet: AlphabetDataLoader,
+         masking_ratio: float,
+         masking_prob: float,
+         random_token_prob: float,
+         num_workers: int,
+         toks_per_batch: int,
+         crop_sizes: Tuple[int, int] = (512, 1024),
+     ):
+         LightningDataModule.__init__(self)
+         self._train_sequences = train_sequences
+         self._validation_sequences = validation_sequences
+         self._alphabet = alphabet
+         self._masking_ratio = masking_ratio
+         self._masking_prob = masking_prob
+         self._random_token_prob = random_token_prob
+         self._num_workers = num_workers
+         self._toks_per_batch = toks_per_batch
+         self._crop_sizes = crop_sizes

-     Args:
-         filenames: list of sequences
-         alphabet: facebook alphabet.
-         filter_len: whether filter data wrt len.batch_seq
-         num_workers: num of parallel data samplers
-         masking_ratio: ratio of tokens to be masked.
-         masking_prob: probability that the chose token is replaced with a mask token.
-         random_token_prob: probability that the chose token is replaced with a random token.
-         toks_per_batch: number of tokens per batch
-         crop_sizes: range of values to crop dynamically sequences when sampling them
+     def _get_dataloader(self, sequences: List[str]) -> DataLoader:
+         dataset = BatchWithConstantNumberTokensDataset(sequences)
+         batch_sampler = DistributedBatchWithConstantNumberTokensSampler(
+             sequence_strs=sequences,
+             toks_per_batch=self._toks_per_batch,
+             crop_sizes=self._crop_sizes,
+         )

-     Returns:
-         torch DataLoader
-     """
+         loader = DataLoader(
+             dataset,
+             num_workers=self._num_workers,
+             collate_fn=functools.partial(
+                 collate_fn,
+                 tokenizer=self._alphabet.tokenizer(),
+                 alphabet=self._alphabet,
+                 masking_ratio=self._masking_ratio,
+                 masking_prob=self._masking_prob,
+                 random_token_prob=self._random_token_prob,
+             ),
+             pin_memory=True,
+             worker_init_fn=worker_init_fn,
+             batch_sampler=batch_sampler,
+             sampler=None,
+         )
+         return loader

-     dataset = BatchWithConstantNumberTokensDataset(sequences)
-     batch_sampler = BatchWithConstantNumberTokensSampler(
-         sequence_strs=sequences, toks_per_batch=toks_per_batch, crop_sizes=crop_sizes
-     )
+     def train_dataloader(self):
+         return self._get_dataloader(self._train_sequences)

-     loader = DataLoader(
-         dataset,
-         num_workers=num_workers,
-         collate_fn=functools.partial(
-             collate_fn,
-             tokenizer=alphabet.tokenizer(),
-             alphabet=alphabet,
-             masking_ratio=masking_ratio,
-             masking_prob=masking_prob,
-             random_token_prob=random_token_prob,
-         ),
-         pin_memory=True,
-         worker_init_fn=worker_init_fn,
-         batch_sampler=batch_sampler,
-         sampler=None,
-     )
-     return loader
+     def val_dataloader(self):
+         return self._get_dataloader(self._validation_sequences)
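
Finally, a hedged usage sketch of the new DataModule with a Lightning Trainer: the alphabet, model, sequence lists and masking values below are placeholders, since building an AlphabetDataLoader and a LightningModule is outside this diff. Under the DDP strategy, torch.distributed is initialised by the Trainer before the distributed sampler queries the world size and rank.

import pytorch_lightning as pl

alphabet: AlphabetDataLoader = ...  # placeholder: built elsewhere from the ESM alphabet
model: pl.LightningModule = ...     # placeholder: the model to train

datamodule = BatchWithConstantNumberTokensDataModule(
    train_sequences=train_seqs,      # placeholder list of training sequences
    validation_sequences=val_seqs,   # placeholder list of validation sequences
    alphabet=alphabet,
    masking_ratio=0.15,              # illustrative masking settings
    masking_prob=0.8,
    random_token_prob=0.1,
    num_workers=4,
    toks_per_batch=4096,
    crop_sizes=(512, 1024),
)

trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp")
trainer.fit(model, datamodule=datamodule)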