
Commit 99f9863

Keep the training data continuous and the total batch size constant regardless of changes in the replica world size.
1 parent 024f850 commit 99f9863
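
With this change, the number of micro-batches each replica pulls per step is derived from a fixed global batch size rather than being hard-coded, roughly num_batches = total_batch_size // (batch_size * replica_world_size) (see get_batch_samples in torchft/manager.py below). As a worked example (not stated in the commit itself): with a global batch of 1024 samples and per-replica micro-batches of 32, 8 replicas each accumulate 4 micro-batches per step, while 4 replicas each accumulate 8, so the committed batch stays at 1024 samples as replicas join or drop out (integer division rounds down when the sizes do not divide evenly).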

File tree

6 files changed: +619 -10 lines changed


torchft/data.py

Lines changed: 105 additions & 2 deletions
@@ -14,11 +14,114 @@
 dataloader frequently to avoid duplicate batches.
 """
 
-from typing import Optional
-
+import torch
 import torch.distributed as dist
+from torch.utils.data.dataset import Dataset
+from torch.utils.data.sampler import Sampler
 from torch.utils import data
 
+import math
+from collections.abc import Iterator
+from typing import Optional, TypeVar
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+class SkipDistributedSampler(Sampler[_T_co]):
+    def __init__(
+        self,
+        dataset: Dataset,
+        num_replicas: Optional[int] = None,
+        rank: Optional[int] = None,
+        shuffle: bool = True,
+        seed: int = 0,
+        drop_last: bool = False,
+        skip_samples: int = 0,
+    ) -> None:
+        if num_replicas is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            num_replicas = dist.get_world_size()
+        if rank is None:
+            if not dist.is_available():
+                raise RuntimeError("Requires distributed package to be available")
+            rank = dist.get_rank()
+        if rank >= num_replicas or rank < 0:
+            raise ValueError(
+                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
+            )
+        self.dataset = dataset
+        self.num_replicas = num_replicas
+        self.rank = rank
+        self.epoch = 0
+        self.drop_last = drop_last
+        self.skip_samples = skip_samples
+        # If the dataset length is evenly divisible by # of replicas, then there
+        # is no need to drop any data, since the dataset will be split equally.
+        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
+            # Split to nearest available length that is evenly divisible.
+            # This is to ensure each rank receives the same amount of data when
+            # using this Sampler.
+            self.num_samples = math.ceil(
+                (len(self.dataset) - self.skip_samples - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
+            )
+        else:
+            self.num_samples = math.ceil((len(self.dataset) - self.skip_samples) / self.num_replicas)  # type: ignore[arg-type]
+        self.total_size = self.num_samples * self.num_replicas
+        self.shuffle = shuffle
+        self.seed = seed
+
+    def __iter__(self) -> Iterator[_T_co]:
+        if self.shuffle:
+            # deterministically shuffle based on epoch and seed
+            g = torch.Generator()
+            g.manual_seed(self.seed + self.epoch)
+            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
+        else:
+            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]
+
+        if not self.drop_last:
+            indices = indices[self.skip_samples : len(indices)]
+            # add extra samples to make it evenly divisible
+            padding_size = self.total_size - len(indices)
+            if padding_size <= len(indices):
+                indices += indices[:padding_size]
+            else:
+                indices += (indices * math.ceil(padding_size / len(indices)))[
+                    :padding_size
+                ]
+        else:
+            # remove tail of data to make it evenly divisible.
+            indices = indices[self.skip_samples : self.skip_samples + self.total_size]
+        if len(indices) != self.total_size:
+            raise AssertionError(
+                f"Number of indices ({len(indices)}) does not match total_size ({self.total_size})"
+            )
+
+        # subsample
+        indices = indices[self.rank : self.total_size : self.num_replicas]
+        if len(indices) != self.num_samples:
+            raise AssertionError(
+                f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})"
+            )
+
+        # pyrefly: ignore # bad-return
+        return iter(indices)
+
+    def __len__(self) -> int:
+        return self.num_samples
+
+    def set_epoch(self, epoch: int) -> None:
+        r"""
+        Set the epoch for this sampler.
+
+        When :attr:`shuffle=True`, this ensures all replicas
+        use a different random ordering for each epoch. Otherwise, the next iteration of this
+        sampler will yield the same ordering.
+
+        Args:
+            epoch (int): Epoch number.
+        """
+        self.epoch = epoch
 
 # pyre-fixme[24]: expected generic parameter
 class DistributedSampler(data.distributed.DistributedSampler):
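
A minimal usage sketch (not part of the commit, assuming a plain map-style dataset): skip_samples is the number of samples the current epoch has already consumed, so the remaining data is re-sharded across whatever the replica world size now is, keeping the training data stream continuous.

import torch
from torch.utils.data import DataLoader, TensorDataset

from torchft.data import SkipDistributedSampler

dataset = TensorDataset(torch.arange(100))

# Suppose 30 samples of this epoch were already trained on before the replica
# world size changed; the remaining 70 samples are split across the 2
# replicas that are now participating.
sampler = SkipDistributedSampler(
    dataset,
    num_replicas=2,   # new replica world size
    rank=0,           # this replica's rank
    shuffle=False,    # deterministic ordering for the example
    skip_samples=30,  # samples already consumed in this epoch
)
loader = DataLoader(dataset, batch_size=5, sampler=sampler)

first_batch = next(iter(loader))[0]
print(first_batch)  # tensor([30, 32, 34, 36, 38]) -- rank 0 takes every 2nd remaining sample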

torchft/data_test.py

Lines changed: 60 additions & 1 deletion
@@ -8,7 +8,7 @@
 
 from torch.utils.data import Dataset
 
-from torchft.data import DistributedSampler
+from torchft.data import DistributedSampler, SkipDistributedSampler
 
 
 class DummyDataset(Dataset):
@@ -37,3 +37,62 @@ def test_distributed_sampler(self) -> None:
 
         sampler_iter = iter(sampler)
         self.assertEqual(next(sampler_iter), 500)
+
+    def test_skip_distributed_sampler(self) -> None:
+        dataset_length = 100
+        dataset = DummyDataset(dataset_length)
+
+        # Case 1: no samples are skipped
+        for drop_last in [True, False]:
+            num_replicas = 7
+            for rank in range(num_replicas):
+                sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas,
+                                                 rank=rank, shuffle=False, drop_last=drop_last)
+                cur = rank
+                for idx in sampler:
+                    self.assertEqual(idx, (cur % dataset_length), f"idx={idx}, cur={cur}")
+                    cur += num_replicas
+                # If drop_last is True, ceil((100-7)/7)*7 = 98 samples are read in total.
+                # If drop_last is False, ceil(100/7)*7 = 105 samples are read in total.
+                if drop_last:
+                    self.assertEqual(cur, 98 + rank, f"rank={rank}, cur={cur}")
+                else:
+                    self.assertEqual(cur, 105 + rank, f"rank={rank}, cur={cur}")
+
+        # Case 2: the first skip_samples samples are skipped
+        for drop_last in [True, False]:
+            num_replicas = 7
+            skip_samples = 10
+            for rank in range(num_replicas):
+                sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas,
+                                                 rank=rank, shuffle=False, drop_last=drop_last,
+                                                 skip_samples=skip_samples)
+                cur = rank
+                for idx in sampler:
+                    expected = ((cur + skip_samples) % dataset_length + skip_samples) \
+                        if (cur + skip_samples) >= dataset_length else (cur + skip_samples)
+                    self.assertEqual(idx, expected, f"idx={idx}, expected={expected}")
+                    cur += num_replicas
+                # If drop_last is True, ceil((100-10-7)/7)*7 = 84 samples are read in total.
+                # If drop_last is False, ceil((100-10)/7)*7 = 91 samples are read in total.
+                if drop_last:
+                    self.assertEqual(cur, 84 + rank, f"rank={rank}, cur={cur}")
+                else:
+                    self.assertEqual(cur, 91 + rank, f"rank={rank}, cur={cur}")
+
+        # Case 3: drop_last is False and the padding size is larger than the number of indices.
+        # With skip_samples=90 and num_replicas=31, the remaining indices are [90, 91, ..., 99].
+        # Only 10 samples are left, so the padding size is 21, which is larger than 10.
+        num_replicas = 31
+        skip_samples = 90
+        expected = list(range(90, 100))
+        expected = (expected * 4)[:31]
+        for rank in range(num_replicas):
+            sampler = SkipDistributedSampler(dataset=dataset, num_replicas=num_replicas,
+                                             rank=rank, shuffle=False, drop_last=False,
+                                             skip_samples=skip_samples)
+            cnt = 0
+            for idx in sampler:
+                self.assertEqual(idx, expected[rank], f"idx={idx}, rank={rank}, expected={expected}")
+                cnt += 1
+            self.assertEqual(cnt, 1)

torchft/manager.py

Lines changed: 87 additions & 5 deletions
@@ -26,6 +26,7 @@
 """
 
 import concurrent.futures
+import gc
 import logging
 import os
 import socket
@@ -185,6 +186,7 @@ def __init__(
         init_sync: bool = True,
         max_retries: Optional[int] = None,
         quorum_retries: int = 0,
+        dataloader_fn: Optional[Callable[[int, int, int], None]] = None,
     ) -> None:
         """
         Args:
@@ -365,6 +367,17 @@ def __init__(
 
         self._update_fr_path()
 
+        # The number of batches committed in the current epoch. Unlike _batches_committed,
+        # _current_batches_committed is reset to 0 when the next epoch starts.
+        self._current_batches_committed = 0
+        self._epoch = 0
+        self._loaded_epoch = 0
+        self._loaded_current_batches_committed = 0
+        self._dataloader_fn = dataloader_fn
+        self._dataloader_dirty = False
+        self._dataloader_iter = None
+        self._accumulation_steps = 1
+
     def allow_state_dict_read(self) -> None:
         if self._is_state_dict_read_allowed:
             return
@@ -438,6 +451,12 @@ def allreduce(
             return _DummyWork(tensor)
 
         self.wait_quorum()
+
+        # If the dataloader is dirty, the result will not be committed, so return a zeroed tensor.
+        if self._dataloader_dirty:
+            work = _DummyWork(torch.zeros_like(tensor))
+            return _ManagedWork(self, work, tensor)
+
         num_participants: int = self.num_participants()
 
         if not self.is_participating():
@@ -678,6 +697,8 @@ def _async_quorum(
            if self._use_async_quorum or not allow_heal
            else (replica_rank, replica_world_size)
        )
+        self._replica_rank = replica_rank
+        self._replica_world_size = replica_world_size
 
        # For fixed with spares we need to ensure that we don't have more
        # participating replicas than the min replica size.
@@ -691,6 +712,7 @@ def _async_quorum(
        ):
            self._participating_replica_rank = None
 
+        quorum_changed = False
        if quorum_id != self._quorum_id:
            self.quorum_logger.info(
                "",
@@ -737,6 +759,7 @@ def _async_quorum(
                self._logger.exception(f"got exception in pg configure: {e}")
                self.report_error(e)
                return
+            quorum_changed = True
 
        if allow_heal:
            # run recovery on the recovery stream if available
@@ -807,6 +830,38 @@ def _async_quorum(
                else None
            )
 
+        # Reconfigure the dataloader after healing so that we can pick up the
+        # offset from the other replica group.
+        if quorum_changed and self._dataloader_fn:
+            self.reconfigure_dataloader()
+            self._dataloader_dirty = True
+
+    def get_batch_samples(self, epoch=0, num_batches=None, batch_size=None, total_batch_size=None):
+        # In general, `start_quorum` might not have been called during the first loop,
+        # and the dataloader might not have been initialized yet. In this case, we should
+        # return immediately and set the dirty flag to avoid computation and commit.
+        if not self._dataloader_iter:
+            self._dataloader_dirty = True
+            return []
+        # If the recovering worker is behind the current epoch, skip computation and commit.
+        if epoch < self._loaded_epoch:
+            return None
+
+        if total_batch_size is not None and batch_size is not None:
+            num_batches = total_batch_size // (batch_size * self._replica_world_size)
+
+        assert num_batches is not None, ("num_batches must be specified, or "
+                                         "total_batch_size and batch_size must be specified")
+
+        batch_samples = []
+        for _ in range(num_batches):
+            try:
+                batch_samples.append(next(self._dataloader_iter))
+            except StopIteration:
+                break
+        self._dataloader_dirty = False
+        self._accumulation_steps = len(batch_samples)
+        return batch_samples if batch_samples else None
+
     def _update_fr_path(self) -> None:
         """
         Update the path that flight recorder will dump the traces to.
@@ -921,9 +976,14 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
 
         # decide whether we're in a healthy state to increase the step count
        if should_commit:
-            self._step += 1
-            self._batches_committed += self.num_participants()
            self._commit_failures = 0  # Reset failure counter on success
+            if not self._dataloader_dirty:
+                self._step += 1
+                self._batches_committed += self.num_participants() * self._accumulation_steps
+                self._current_batches_committed += self.num_participants() * self._accumulation_steps
+                return True
+            else:
+                return False
        else:
            self._commit_failures += 1
            # Check if we've hit max retries
@@ -934,8 +994,7 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
                msg = f"should_commit failed {self._commit_failures} times consecutively, exceeding max_retries={self._max_retries}"
                self._logger.exception(msg)
                raise RuntimeError(msg)
-
-        return should_commit
+            return False
 
     def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         """
@@ -948,6 +1007,11 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
         """
         self._step = state_dict["step"]
         self._batches_committed = state_dict["batches_committed"]
+        self._loaded_epoch = state_dict["epoch"]
+        self._loaded_current_batches_committed = state_dict["current_batches_committed"]
+        if self._loaded_epoch == 0:
+            self._epoch = 0
+            self._current_batches_committed = self._loaded_current_batches_committed
 
     def _manager_state_dict(self) -> Dict[str, object]:
         with self._state_dict_lock.r_lock():
@@ -969,7 +1033,8 @@ def state_dict(self) -> Dict[str, int]:
         Returns:
             the state dict for this manager
         """
-        return {"step": self._step, "batches_committed": self._batches_committed}
+        return {"step": self._step, "batches_committed": self._batches_committed,
+                "epoch": self._epoch, "current_batches_committed": self._current_batches_committed}
 
     def current_step(self) -> int:
         """
@@ -1047,6 +1112,23 @@ def is_participating(self) -> bool:
            return False
        return True
 
+    def reconfigure_dataloader(self):
+        dataloader = self._dataloader_fn(self._replica_world_size,
+                                         self._replica_rank, self._current_batches_committed)
+        dataloader.sampler.set_epoch(self._epoch)
+        self._dataloader_iter = iter(dataloader)
+        # clean up the old dataloader
+        gc.collect()
+
+    def next_epoch(self):
+        self._epoch += 1
+        if self._loaded_epoch == self._epoch:
+            self._current_batches_committed = self._loaded_current_batches_committed
+        else:
+            self._current_batches_committed = 0
+        if self._dataloader_fn:
+            self.reconfigure_dataloader()
+        self._dataloader_dirty = False
 
 class _ManagerLogger:
     def __init__(self, manager: Manager, replica_id: str, group_rank: int) -> None:
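
A rough end-to-end sketch (not part of the commit) of how the pieces above are meant to compose. The names train_dataset, train_step, optimizer, num_epochs, and build_dataloader are hypothetical, and the other Manager constructor arguments are elided: the manager rebuilds the dataloader from (replica_world_size, replica_rank, current_batches_committed) whenever the quorum changes, and get_batch_samples derives the number of micro-batches from the fixed total batch size so the committed global batch stays roughly constant.

from torch.utils.data import DataLoader

from torchft.data import SkipDistributedSampler
from torchft.manager import Manager

BATCH_SIZE = 32          # per-replica micro-batch size
TOTAL_BATCH_SIZE = 1024  # global batch size to hold constant across world-size changes

def build_dataloader(replica_world_size: int, replica_rank: int, batches_committed: int) -> DataLoader:
    # batches_committed counts micro-batches across all participating replicas,
    # so the epoch has already consumed batches_committed * BATCH_SIZE samples.
    sampler = SkipDistributedSampler(
        train_dataset,  # hypothetical dataset
        num_replicas=replica_world_size,
        rank=replica_rank,
        shuffle=True,
        skip_samples=batches_committed * BATCH_SIZE,
    )
    return DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)

manager = Manager(
    # ... the usual Manager arguments (process group, state_dict hooks, etc.) ...
    dataloader_fn=build_dataloader,
)

for epoch in range(num_epochs):
    manager.next_epoch()
    while True:
        manager.start_quorum()
        # Pull enough micro-batches that BATCH_SIZE * replica_world_size * accumulation_steps
        # approximates TOTAL_BATCH_SIZE for the current replica world size.
        batches = manager.get_batch_samples(
            epoch=epoch, batch_size=BATCH_SIZE, total_batch_size=TOTAL_BATCH_SIZE
        )
        if batches is None:
            break      # epoch exhausted, or this worker is behind the checkpointed epoch
        if not batches:
            continue   # dataloader not configured yet; retry after the next quorum
        for batch in batches:
            train_step(batch)  # hypothetical forward/backward; gradients go through manager.allreduce
        if manager.should_commit():
            optimizer.step()   # hypothetical optimizer step, taken only on healthy, non-dirty commits

Note the two sentinel returns from get_batch_samples: None ends the inner loop (the sampler is exhausted or this worker is behind the loaded epoch), while an empty list means the dataloader has not been configured yet, so the sketch retries rather than committing.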
