diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
index ae5038b2022d2..b37631935b51f 100644
--- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
@@ -83,6 +83,33 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None, weights_only: Op
         loaded_checkpoint = self.trainer.strategy.load_checkpoint(checkpoint_path, weights_only=weights_only)
         self._loaded_checkpoint = _pl_migrate_checkpoint(loaded_checkpoint, checkpoint_path)
 
+        try:
+            batch_progress = getattr(self.trainer, "batch_progress", None)
+            if batch_progress is not None:
+                global_step = int(getattr(self.trainer, "global_step", 0))
+
+                # Align total counters defensively so they are at least the restored global_step
+                try:
+                    batch_progress.total_ready = max(int(getattr(batch_progress, "total_ready", 0)), global_step)
+                    batch_progress.total_completed = max(
+                        int(getattr(batch_progress, "total_completed", 0)), global_step
+                    )
+                except Exception as exc:
+                    log.debug(f"BatchProgress restore fallback triggered: {exc}")
+
+                # Try to compute within-epoch counters when possible
+                try:
+                    epoch_len = getattr(self.trainer, "limit_train_batches", None)
+                    if isinstance(epoch_len, int) and epoch_len > 0:
+                        epoch_size = int(epoch_len)
+                        batch_progress.current_completed = global_step % max(1, epoch_size)
+                        batch_progress.current_ready = batch_progress.current_completed
+                except Exception as exc:
+                    log.debug(f"BatchProgress restore fallback triggered: {exc}")
+
+        except Exception as exc:
+            log.debug(f"BatchProgress restore fallback triggered: {exc}")
+
     def _select_ckpt_path(
         self, state_fn: TrainerFn, ckpt_path: Optional[_PATH], model_provided: bool, model_connected: bool
     ) -> Optional[_PATH]:
@@ -291,7 +318,8 @@ def restore_model(self) -> None:
     def restore_training_state(self) -> None:
         """Restore the trainer state from the pre-loaded checkpoint.
 
-        This includes the precision settings, loop progress, optimizer states and learning rate scheduler states.
+        This includes the precision settings, loop progress, optimizer states and learning rate scheduler
+        states.
""" if not self._loaded_checkpoint: diff --git a/tests/test_resume_batch_progress.py b/tests/test_resume_batch_progress.py new file mode 100644 index 0000000000000..22e53c32fea1b --- /dev/null +++ b/tests/test_resume_batch_progress.py @@ -0,0 +1,179 @@ +import logging + +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader, random_split + +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 + + +class TinyModel(L.LightningModule): + def __init__(self, vocab_size): + super().__init__() + self.model = Transformer(vocab_size=vocab_size) + + def training_step(self, batch, batch_idx): + x, y = batch + out = self.model(x, y) + return F.nll_loss(out, y.view(-1)) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.1) + + +def _find_batch_progress(trainer): + """Try common places for the BatchProgress object across Lightning versions.""" + log = logging.getLogger(__name__) + candidates = [ + getattr(trainer, "batch_progress", None), + getattr(getattr(trainer, "fit_loop", None), "batch_progress", None), + getattr(getattr(getattr(trainer, "fit_loop", None), "epoch_loop", None), "batch_progress", None), + ] + for candidate in candidates: + if candidate is not None: + return candidate + + # heuristic scan + for name in dir(trainer): + try: + attr = getattr(trainer, name) + if attr is None: + continue + if any(n in dir(attr) for n in ("current_completed", "current", "completed")): + return attr + except Exception as exc: + log.debug(f"BatchProgress restore fallback triggered: {exc}") + return None + + +def _extract_int(candidate): + """ + Try to extract an integer from candidate which may be: + - an int-like + - an object with attributes like 'completed', 'ready', 'count', 'value', 'n' + - a tuple/list like (ready, completed) + - something convertible via int() + Raise ValueError if not possible. + """ + log = logging.getLogger(__name__) + + if isinstance(candidate, int): + return candidate + + for attribute in ("completed", "ready", "count", "value", "n", "total"): + if hasattr(candidate, attribute): + try: + return int(getattr(candidate, attribute)) + except Exception as exc: + log.debug(f"BatchProgress restore fallback triggered: {exc}") + + if isinstance(candidate, (tuple, list)) and len(candidate) > 0: + for element in candidate: + try: + return int(element) + except Exception: + for attribute in ("completed", "ready", "count", "value", "n"): + if hasattr(element, attribute): + try: + return int(getattr(element, attribute)) + except Exception as exc: + log.debug(f"BatchProgress restore fallback triggered: {exc}") + try: + return int(candidate[0]) + except Exception as exc: + log.debug(f"BatchProgress restore fallback triggered: {exc}") + + try: + return int(candidate) + except Exception: + typename = type(candidate).__name__ + sample_attrs = ", ".join(sorted(dir(candidate))[:40]) + raise ValueError(f"Unable to coerce candidate (type {typename}) to int. 
+
+
+def test_resume_mid_epoch_batch_progress(tmp_path):
+    L.seed_everything(42)
+    log = logging.getLogger(__name__)
+    dataset = WikiText2()
+    dataset_len = len(dataset)
+    train_ds, val_ds, _ = random_split(dataset, [dataset_len - 200, 100, 100])
+
+    train_loader = DataLoader(train_ds, batch_size=1, shuffle=False)
+    val_loader = DataLoader(val_ds, batch_size=1, shuffle=False)
+
+    # 1) Run short training to produce a mid-epoch checkpoint (step 5)
+    model = TinyModel(vocab_size=dataset.vocab_size)
+    trainer_short = L.Trainer(max_steps=5, enable_progress_bar=False)
+    trainer_short.fit(model, train_loader, val_loader)
+    assert trainer_short.global_step >= 5, f"short trainer didn't reach step 5 (gs={trainer_short.global_step})"
+
+    ckpt_mid = tmp_path / "mid_epoch.ckpt"
+    trainer_short.save_checkpoint(str(ckpt_mid))
+    assert ckpt_mid.exists(), f"failed to create checkpoint at {ckpt_mid}"
+
+    # 2) Resume from that checkpoint with a fresh trainer
+    trainer_resume = L.Trainer(max_steps=10, enable_progress_bar=False)
+    model2 = TinyModel(vocab_size=dataset.vocab_size)
+    trainer_resume.fit(model2, train_loader, val_loader, ckpt_path=str(ckpt_mid))
+
+    batch_progress = _find_batch_progress(trainer_resume)
+    assert batch_progress is not None, "BatchProgress object not found on Trainer; see earlier logs."
+
+    possible_total_names = [
+        "total_completed",
+        "total_completed_batches",
+        "total_steps_completed",
+        "completed",
+        "total",
+        "total_done",
+        "total_ready",
+        "ready",
+    ]
+    possible_current_names = [
+        "current_completed",
+        "current",
+        "current_batch",
+        "current_index",
+        "completed_in_epoch",
+        "in_progress",
+        "current_ready",
+    ]
+
+    total_candidate = None
+    for name in possible_total_names:
+        if hasattr(batch_progress, name):
+            total_candidate = getattr(batch_progress, name)
+            break
+    if total_candidate is None:
+        total_candidate = batch_progress
+
+    current_candidate = None
+    for name in possible_current_names:
+        if hasattr(batch_progress, name):
+            current_candidate = getattr(batch_progress, name)
+            break
+    if current_candidate is None:
+        for name in dir(batch_progress):
+            if name.lower().startswith("current"):
+                try:
+                    current_candidate = getattr(batch_progress, name)
+                    break
+                except Exception as exc:
+                    log.debug(f"could not read attribute {name!r} from batch_progress: {exc}")
+    if current_candidate is None:
+        current_candidate = 0
+
+    total_completed = _extract_int(total_candidate)
+    current_completed = _extract_int(current_candidate)
+
+    global_step = trainer_resume.global_step
+
+    assert total_completed >= 0, "negative total_completed found"
+    assert current_completed >= 0, "negative current_completed found"
+
+    assert total_completed >= global_step or total_completed == 0, (
+        f"unexpected total_completed={total_completed} < global_step={global_step}"
+    )
+
+    assert current_completed <= total_completed
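
Reviewer note: the following is a minimal standalone sketch (not part of the patch) of the counter arithmetic the resume_start() hunk applies; FakeBatchProgress and align_after_resume are hypothetical stand-ins for illustration, not Lightning's real BatchProgress API.

from dataclasses import dataclass


@dataclass
class FakeBatchProgress:
    # assumed mirror of the four counters the hunk touches
    total_ready: int = 0
    total_completed: int = 0
    current_ready: int = 0
    current_completed: int = 0


def align_after_resume(progress, global_step, epoch_len=None):
    # Totals are clamped so they are never behind the restored global_step.
    progress.total_ready = max(progress.total_ready, global_step)
    progress.total_completed = max(progress.total_completed, global_step)
    # Within-epoch counters are only derived when the epoch length is a known positive int.
    if isinstance(epoch_len, int) and epoch_len > 0:
        progress.current_completed = global_step % max(1, epoch_len)
        progress.current_ready = progress.current_completed
    return progress


if __name__ == "__main__":
    # Resuming at global_step=5 with a 3-batch epoch leaves 2 batches completed in the current epoch.
    p = align_after_resume(FakeBatchProgress(), global_step=5, epoch_len=3)
    assert (p.total_completed, p.current_completed) == (5, 2)
    print(p)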