diff --git a/src/setfit/data.py b/src/setfit/data.py index abb6d8a2..275591c0 100644 --- a/src/setfit/data.py +++ b/src/setfit/data.py @@ -1,3 +1,5 @@ +import math +import random from typing import TYPE_CHECKING, Dict, List, Tuple import pandas as pd @@ -280,3 +282,42 @@ def collate_fn(batch): labels = torch.Tensor(labels).long() return features, labels + + +class NoDuplicateClassesDataLoader: + def __init__(self, train_examples, batch_size): + self.batch_size = batch_size + self.collate_fn = None + self.train_examples = train_examples + + # TODO: add assert batch_size <= num_classes + + def __iter__(self): + label_class_dict = {} + random.shuffle(self.train_examples) + for example in self.train_examples: + example_label_list = label_class_dict.get(example.label, []) + example_label_list.append(example) + label_class_dict[example.label] = example_label_list + + for _ in range(self.__len__()): + batch = [] + classes_in_batch = set() + + while len(batch) < self.batch_size: + class_to_add = random.choice(label_class_dict.keys()) + if class_to_add not in classes_in_batch: + example = label_class_dict[class_to_add].pop(0) + batch.append(example) + + # list of examples for this class is empty and needs to be refilled + if len(label_class_dict[class_to_add]) == 0: + random.shuffle(self.train_examples) + for example in self.train_examples: + if example.label == class_to_add: + label_class_dict[class_to_add].append(example) + + yield self.collate_fn(batch) if self.collate_fn is not None else batch + + def __len__(self): + return math.floor(len(self.train_examples) / self.batch_size)