diff --git a/heat/utils/data/datatools.py b/heat/utils/data/datatools.py index 044195ccbe..6462931a4f 100644 --- a/heat/utils/data/datatools.py +++ b/heat/utils/data/datatools.py @@ -2,15 +2,34 @@ Function and classes useful for loading data into neural networks """ +import itertools +import random +import warnings +import mpi4py import torch +import torch.distributed from torch.utils import data as torch_data -from typing import Callable, List, Iterator, Union, Optional, Sized +from typing import Callable, List, Iterator, Literal, Union, Optional, Sized +from mpi4py import MPI +from functools import reduce + +import torch.utils +import torchvision from ...core.dndarray import DNDarray -from ...core.communication import MPI_WORLD +from ...core.communication import GPU_AWARE_MPI, MPI_WORLD, MPICommunication +from ...core.random import permutation from . import partial_dataset -__all__ = ["DataLoader", "Dataset", "dataset_shuffle", "dataset_ishuffle"] +__all__ = [ + "DataLoader", + "Dataset", + "dataset_shuffle", + "dataset_ishuffle", + "DistributedDataset", + "DistributedSampler", + "create_train_val_split", +] class DataLoader: @@ -244,6 +263,409 @@ def Ishuffle(self): dataset_ishuffle(dataset=self, attrs=[["data", "htdata"]]) +class DistributedDataset(torch_data.Dataset): + """ + A DistributedDataset for usage in PyTorch. Saves the dndarray and the larray tensor. Uses the larray tensor + for the distribution and getting the items. Intented to be used with DistributedSampler. + """ + + def __init__(self, dndarray: DNDarray, transforms: torchvision.transforms.Compose = None): + if not isinstance(dndarray, DNDarray): + raise TypeError(f"Expected DNDarray but got {type(dndarray)}") + if dndarray.split != 0: + raise ValueError("DistributedDataset only works with a DNDarray split of 0") + + self.dndarray = dndarray + self.transforms = transforms + + def __len__(self) -> int: + return len(self.dndarray.larray) + + def __getitem__(self, index): + item = self.dndarray.larray[index] + if self.transforms is not None: + return self.transforms(item) + return item + + def __getitems__(self, indices): + if self.transforms is not None: + return tuple(self.transforms(self.dndarray.larray[index]) for index in indices) + return tuple(self.dndarray.larray[index] for index in indices) + + +class DistributedSampler(torch_data.Sampler): + """ + A DistributedSampler for usage in PyTorch with Heat Arrays. Uses the nature of the Heat DNDArray + to give the locally stored data on the larray. Shuffling is done by shuffling the indices. + The given Indices corrospond to the index of the larray tensor. + Works only with DNDarray that are split on axis 0 + """ + + def __init__( + self, + dataset: DistributedDataset, + shuffle: bool = False, + seed: Optional[int] = None, + shuffle_type: Literal["global"] | Literal["local"] = "global", + correction: bool = False, + ) -> None: + """ + Parameters + ---------- + dataset : DistributedDataset + Dataset to be shuffled + shuffle : bool, optional + If the underlying DNDarray should be shuffled, by default False + seed : int, optional + seed for shuffling, by default None + shuffle_type : Literal["global"] | Literal["local"], optional + Wether to shuffle process local or get new data using by shuffling globally across all processes, by default "global" + correction : bool, optional + If index correction is wanted after an global shuffle, by default False + """ + if not isinstance(dataset, DistributedDataset): + raise TypeError(f"Expected DistributedDataset for dataset not {type(dataset)}") + if not isinstance(shuffle, bool): + raise TypeError(f"Expected bool for shuffle not {type(shuffle)}") + if not isinstance(seed, int) and seed is not None: + raise TypeError(f"Expected int or None for seed not {type(shuffle)}") + if not isinstance(shuffle_type, str): + raise TypeError("Shuffle Type needs to be an string") + if not isinstance(correction, bool): + raise TypeError("Correction Parameter needs to be an bool") + + self.dataset = dataset + self.dndarray = dataset.dndarray + self.shuffle = shuffle + self.linked_sampler = None + self.correction = correction + self.set_shuffle_type(shuffle_type) + self.set_seed(seed) + + if self.dndarray.split != 0: + raise ValueError("DistributedSampler only works with a DNDarray split of 0") + + @staticmethod + def _in_slice(idx: int, a_slice: slice) -> bool: + """Check if the given index is inside the given slice + + Parameters + ---------- + idx : int + Index to check + a_slice : slice + Slice to check + + Returns + ------- + bool + Wether index is in slice + """ + if idx < a_slice.start or idx >= a_slice.stop: + return False + step = a_slice.step if a_slice.step else 1 + if (idx - a_slice.start) % step == 0: + return True + else: + return False + + def _shuffle(self) -> None: + """Shuffles the given dndarray at creation across processes.""" + if self.shuffle_type == "local": + rand_perm = torch.randperm(self.dndarray.larray.shape[0]) + self.dndarray.larray = self.dndarray.larray[rand_perm] + return + + if self.shuffle_type != "global": + raise ValueError("Shuffle type is not 'local' nor 'global'") + + # TODO: Find out which implementation is better + # self.dndarray = permutation(self.dndarray) + # self.dataset.dndarray = self.dndarray + self._alltoall_shuffle() + + def _alltoall_shuffle(self) -> None: + # Exchanges the data using Indexed data types and i iaj + dtype = self.dndarray.dtype.torch_type() + comm: MPICommunication = self.dndarray.comm + rank: int = comm.rank + world_size: int = comm.size + N: int = self.dndarray.gshape[0] + mpi_type: mpi4py.MPI.Datatype = comm._MPICommunication__mpi_type_mappings[dtype] + + if rank == 0: + indices = torch.randperm(N, dtype=torch.int64) + else: + indices = torch.empty(N, dtype=torch.int64) + mpi4py.MPI.COMM_WORLD.Bcast(indices, root=0) + + indice_buffers: List[List[int]] = [list() for _ in range(world_size)] + rank_slices: List[slice] = [ + comm.chunk((N,), split=0, rank=i)[-1][0] for i in range(world_size) + ] + + block_length: int = reduce(lambda a, b: a * b, self.dndarray.gshape[1:], 1) + local_slice: slice = rank_slices[rank] + local_displacement: int = self.dndarray.counts_displs()[1][rank] * block_length + + # Now figure out which rank needs to send what to each rank and what this rank will receive + for i, idx in enumerate(indices): + idx = idx.item() + for data_send_rank, tslice in enumerate(rank_slices): + if not self._in_slice(idx, tslice): + continue + break + for data_recv_rank, tslice in enumerate(rank_slices): + if not self._in_slice(i, tslice): + continue + break + if data_recv_rank == rank: + indice_buffers[rank].append(idx) + elif data_send_rank == rank: + indice_buffers[data_recv_rank].append(idx) + + # print("RECV BUFFER creating...", flush=True) + send_elems_dtype: List[mpi4py.MPI.Datatype] = list() + local_recv_buffer: torch.Tensor = torch.empty(self.dndarray.larray.shape, dtype=dtype) + + for current_rank in range(world_size): + if current_rank == rank: + send_indice = [ + idx for idx in indice_buffers[current_rank] if self._in_slice(idx, local_slice) + ] + else: + send_indice = indice_buffers[current_rank] + displacements = [ + mpi_type.Get_size() * (disp * block_length - local_displacement) + for disp in send_indice + ] + block_lengths = [block_length] * len(displacements) + send_type = mpi_type.Create_struct( + blocklengths=block_lengths, + displacements=displacements, + datatypes=[mpi_type] * len(displacements), + ) + send_type.Commit() + send_elems_dtype.append(send_type) + + recv_counts = torch.zeros(world_size, dtype=torch.int64) + for idx in indice_buffers[rank]: + for i, tslice in enumerate(rank_slices): + if not self._in_slice(idx, tslice): + continue + recv_counts[i] += 1 + break + + send_elems = self.dndarray.larray + send_elems = send_elems if GPU_AWARE_MPI else send_elems.cpu() + + recv_types: List[mpi4py.MPI.Datatype] = [] + + total_displ = 0 + + for i in range(world_size): + if recv_counts[i] == 0: + recv_type = mpi_type.Create_contiguous(0) + else: + types = [mpi_type.Create_contiguous(block_length) for _ in range(recv_counts[i])] + + displ = torch.zeros(len(types), dtype=torch.int64) + displ[1:] = torch.cumsum(torch.tensor([t.Get_size() for t in types])[:-1], 0) + displ += total_displ + + recv_type = mpi_type.Create_struct( + blocklengths=[1] * len(types), displacements=displ.tolist(), datatypes=types + ) + total_displ += sum([t.Get_size() for t in types]) + + recv_type.Commit() + recv_types.append(recv_type) + + mpi4py.MPI.COMM_WORLD.Alltoallw( + (send_elems, send_elems_dtype), + (local_recv_buffer, recv_types), + ) + + for elem in itertools.chain(recv_types, send_elems_dtype): + elem.Free() + + # As MPI indirectly sorts the data according to the rank we need + # to change that to represent the permutation + if self.correction: + + def get_from_rank(idx): + for i, rslice in enumerate(rank_slices): + if self._in_slice(idx, rslice): + return i + raise RuntimeError("IDX not found in slices") + + idx_to_rank_map = [get_from_rank(idx) for idx in indices[local_slice]] + + sort_idx = torch.argsort(torch.tensor(idx_to_rank_map), stable=True) + local_slices_sorted = indices[local_slice][sort_idx] + + reverse_index = {idx.item(): i for i, idx in enumerate(indices[local_slice])} + idxmap = {i: reverse_index[idx.item()] for i, idx in enumerate(local_slices_sorted)} + + for i, dest in idxmap.items(): + self.dndarray.larray[dest] = local_recv_buffer[i].to(self.dndarray.larray.device) + else: + self.dndarray.larray = local_recv_buffer.to(self.dndarray.larray.device) + + def set_shuffle_type(self, shuffle_type: Literal["global"] | Literal["local"]) -> None: + """Sets the Shuffle type for the Sampler. + + Parameters + ---------- + shuffle_type : Literal["global"] | Literal["local"] + - Local Shuffle means the shuffle of the larray only. + - Global Shuffle means the shuffle across all processes + + Raises + ------ + TypeError + Shuffle type needs to be a string + ValueError + Only Global/Local shuffle types exist + """ + if not isinstance(shuffle_type, str): + raise TypeError("Shuffle type needs to be an string") + if not (shuffle_type == "global" or shuffle_type == "local"): + raise ValueError("only 'global' or 'local' allowed as shuffle type") + + self.shuffle_type: Literal["global"] | Literal["local"] = shuffle_type + + if self.linked_sampler is not None: + self.linked_sampler.set_shuffle_type(shuffle_type) + + def set_seed(self, value: int | None) -> None: + """Sets the seed for the torch.randperm + + Parameters + ---------- + value : int + seed to set + """ + self._seed = value + if value is not None: + torch.manual_seed(value) + if self.shuffle: + self._shuffle() + + if self.linked_sampler is not None: + self.linked_sampler.set_seed(value) + + def link(self, sampler: "DistributedSampler") -> None: + """ + Links another DistributedSampler to this one, to automatically sets the seed/shuffle_type of this and the linked one, + rather than manually setting both seperately. Usefull when one Sampler contains training data and the + linked one the label data. + """ + if not isinstance(sampler, DistributedSampler): + raise TypeError(f"Sampler of type {type(sampler)} needs to be an DistributedSampler") + self.linked_sampler = sampler + + def unlink(self) -> None: + """ + Removes an established link. For more info view :link: function + """ + self.linked_sampler = None + + def __iter__(self) -> Iterator[int]: + if self.shuffle_type == "local": + self.indices = torch.randperm(len(self.dndarray.larray)).tolist() + else: + self.indices = list(range(len(self.dndarray.larray))) + return iter(self.indices) + + def __len__(self) -> int: + return len(self.dndarray.larray) + + +def create_train_val_split( + X: DNDarray, y: DNDarray, p: float = 0.95, seed: int | None = None +) -> tuple[DNDarray, DNDarray, DNDarray, DNDarray]: + """Shuffles the data and then creates the train val split. + + Parameters + ---------- + X : DNDarray + Training Data + y : DNDarray + Training Labels + p : float, optional + How much the training should contain, by default 0.95 + seed : int | None, optional + Random Seed to be used, by default None + + Returns + ------- + tuple[DNDarray, DNDarray, DNDarray, DNDarray] + returns tuple of (train_arr, train_labels_arr, val_arr, val_labels_arr) + """ + if seed is None: + seed = random.randint(-0x8000_0000_0000_0000, 0xFFFF_FFFF_FFFF_FFFF) + + for arr in [X, y]: + dset = DistributedDataset(arr) + _ = DistributedSampler(dset, shuffle=True, seed=seed) + + train_rows = int(X.lshape[0] * p) + val_rows = X.lshape[0] - train_rows + + perm = torch.randperm(X.lshape[0]) + + train_idx = perm[:train_rows] + val_idx = perm[-val_rows:] + + assert len(train_idx) + len(val_idx) == X.lshape[0] + + comm = MPI.COMM_WORLD + + total_train_rows = comm.allreduce(train_rows, MPI.SUM) + total_val_rows = comm.allreduce(val_rows, MPI.SUM) + + train_gshape = tuple([total_train_rows, *X.gshape[1:]]) + val_gshape = tuple([total_val_rows, *X.gshape[1:]]) + + train_arr = DNDarray( + X.larray[train_idx], + train_gshape, + X.dtype, + split=0, + device=X.device, + comm=X.comm, + balanced=True, + ) + val_arr = DNDarray( + X.larray[val_idx], val_gshape, X.dtype, split=0, device=X.device, comm=X.comm, balanced=True + ) + + train_labels_gshape = tuple([total_train_rows, *y.gshape[1:]]) + val_labels_gshape = tuple([total_val_rows, *y.gshape[1:]]) + + train_labels_arr = DNDarray( + y.larray[train_idx], + train_labels_gshape, + y.dtype, + split=0, + device=y.device, + comm=y.comm, + balanced=True, + ) + val_labels_arr = DNDarray( + y.larray[val_idx], + val_labels_gshape, + y.dtype, + split=0, + device=y.device, + comm=y.comm, + balanced=True, + ) + + return train_arr, train_labels_arr, val_arr, val_labels_arr + + def dataset_shuffle(dataset: Union[Dataset, torch_data.Dataset], attrs: List[list]): """ Shuffle the given attributes of a dataset across multiple processes. This will send half of the data to rank + 1. diff --git a/heat/utils/data/tests/test_distributed_data.py b/heat/utils/data/tests/test_distributed_data.py new file mode 100644 index 0000000000..2b59d35c36 --- /dev/null +++ b/heat/utils/data/tests/test_distributed_data.py @@ -0,0 +1,92 @@ +from typing import Optional +import heat as ht +from heat.utils.data.datatools import DistributedDataset, DistributedSampler +import torch +import unittest + + +class SeedEnviroment: + """ + Class to be used in a `with` Enviroment. + Changes the torch seed to the given and then resets it to the previous one when exiting. + """ + + def __init__(self, seed: Optional[int] = None): + self.seed = seed + + def __enter__(self): + self.state = torch.random.get_rng_state() + + if self.seed is not None: + torch.random.manual_seed(self.seed) + + def __exit__(self, *args, **kwargs): + torch.random.set_rng_state(self.state) + + +class TestDistbributedData(unittest.TestCase): + def test_dataset_and_sampler(self) -> bool: + + reference = ht.arange(100, dtype=torch.int32).reshape(20, 5) + + heat_array = ht.copy(reference).resplit_(0) + dset = DistributedDataset(heat_array) + dsampler = DistributedSampler(dset, shuffle=True) + dsampler._shuffle() + + # To test this, the resulting array should be balanced, have the same number of elements as the original one, and the sum of all the columns should be the same + # And the elements should not be equal to each other. + self.assertTrue(dset.dndarray.size == reference.size) + self.assertTrue(dset.dndarray.shape == reference.shape) + self.assertTrue(dset.dndarray.balanced) + + ref_col_sum = reference.sum(0) + col_sum = dset.dndarray.sum(0) + + self.assertTrue(ht.equal(col_sum, ref_col_sum)) + self.assertFalse(ht.equal(reference, dset.dndarray)) + + def test_batches(self) -> bool: + reference = ht.array( + [ + [10, 11, 12, 13, 14], + [20, 21, 22, 23, 24], + [15, 16, 17, 18, 19], + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + ], + split=0, + dtype=ht.int32, + ) + + with SeedEnviroment(): + arr = ht.arange(25, dtype=ht.int32, split=0).reshape(5, 5) + dset = DistributedDataset(arr) + dsampler = DistributedSampler(dset, shuffle=True, seed=42) + + dataloader = torch.utils.data.DataLoader( + dset, batch_size=1, shuffle=False, sampler=dsampler + ) + + for batch in dataloader: + found = False + for larray in reference.larray: + if not torch.isclose(batch, larray).all(): + continue + found = True + break + self.assertTrue(found) + + def test_dataset_exceptions(self) -> bool: + with self.assertRaises(TypeError): + DistributedDataset("") + with self.assertRaises(ValueError): + DistributedDataset(ht.zeros(2, split=1)) + + def test_data_sampler_exceptions(self) -> bool: + with self.assertRaises(TypeError): + DistributedSampler(ht.zeros(10)) + with self.assertRaises(TypeError): + DistributedSampler(DistributedDataset(ht.zeros(2, split=0)), shuffle="") + with self.assertRaises(TypeError): + DistributedSampler(DistributedDataset(ht.zeros(2, split=0)), shuffle=True, seed="")