| 5 | 5 | import warnings | 
| 6 | 6 | from datetime import timedelta | 
| 7 | 7 | from functools import partial | 
| 8 |  | -from typing import Any, Optional | 
|  | 8 | +from typing import Optional | 
| 9 | 9 | 
 | 
| 10 | 10 | import torch | 
| 11 | 11 | from huggingface_hub import hf_hub_download | 
| @@ -93,7 +93,9 @@ def __init__( | 
| 93 | 93 |                 separate parameter. | 
| 94 | 94 |             perceiver_ln_eps (float, optional): Epsilon in the perceiver layer norm layers. Used | 
| 95 | 95 |                 to stabilise the model. | 
| 96 |  | -            max_history_size (int, optional): Maximum number of history steps. | 
|  | 96 | +            max_history_size (int, optional): Maximum number of history steps. You can load | 
|  | 97 | +                checkpoints with a smaller `max_history_size`, but you cannot load checkpoints | 
|  | 98 | +                with a larger `max_history_size`. | 
| 97 | 99 |             use_lora (bool, optional): Use LoRA adaptation. | 
| 98 | 100 |             lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out | 
| 99 | 101 |                 steps. | 
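A minimal usage sketch of the `max_history_size` compatibility rule documented in this hunk. The class name `Aurora`, the import path, and the checkpoint paths are assumptions for illustration; only `max_history_size` and `load_checkpoint_local` appear in the diff itself:

```python
from aurora import Aurora  # assumed import path for the enclosing model class

# A checkpoint trained with a smaller history size (e.g. 2) can be loaded into a model
# configured with a larger one: the embedding weights are zero-padded along the history
# dimension (see `adapt_checkpoint_max_history_size` further down).
model = Aurora(max_history_size=4)
model.load_checkpoint_local("checkpoints/aurora-history2.ckpt")  # hypothetical path

# The reverse is rejected: a checkpoint whose history size exceeds the model's
# `max_history_size` raises an AssertionError.
small_model = Aurora(max_history_size=1)
# small_model.load_checkpoint_local("checkpoints/aurora-history2.ckpt")  # AssertionError
```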
| @@ -316,54 +318,54 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None: | 
| 316 | 318 |                 d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i] | 
| 317 | 319 |                 d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i] | 
| 318 | 320 | 
 | 
| 319 |  | -        # check if history size is compatible and adjust weights if necessary | 
| 320 |  | -        if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]: | 
| 321 |  | -            d = self.adapt_checkpoint_max_history_size(d) | 
| 322 |  | -        elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]: | 
| 323 |  | -            raise AssertionError(f"Cannot load checkpoint with max_history_size \ | 
| 324 |  | -                {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \ | 
| 325 |  | -                into model with max_history_size {self.max_history_size}") | 
|  | 321 | +        # Check if the history size is compatible and adjust weights if necessary. | 
|  | 322 | +        current_history_size = d["encoder.surf_token_embeds.weights.2t"].shape[2] | 
|  | 323 | +        if self.max_history_size > current_history_size: | 
|  | 324 | +            self.adapt_checkpoint_max_history_size(d) | 
|  | 325 | +        elif self.max_history_size < current_history_size: | 
|  | 326 | +            raise AssertionError( | 
|  | 327 | +                f"Cannot load checkpoint with `max_history_size` {current_history_size} " | 
|  | 328 | +                f"into model with `max_history_size` {self.max_history_size}." | 
|  | 329 | +            ) | 
| 326 | 330 | 
 | 
| 327 | 331 |         self.load_state_dict(d, strict=strict) | 
| 328 | 332 | 
 | 
| 329 |  | -    def adapt_checkpoint_max_history_size(self, checkpoint) -> Any: | 
| 330 |  | -        """Adapt a checkpoint with smaller max_history_size to a model with a larger | 
| 331 |  | -        max_history_size than the current model. | 
|  | 333 | +    def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor]) -> None: | 
|  | 334 | +        """Adapt a checkpoint trained with a smaller `max_history_size` to the larger | 
|  | 335 | +        `max_history_size` of the current model. | 
| 332 | 336 | 
 | 
| 333 |  | -        If a checkpoint was trained with a larger max_history_size than the current model, | 
|  | 337 | +        If a checkpoint was trained with a larger `max_history_size` than the current model, | 
| 334 | 338 |         this function raises an AssertionError instead of loading the checkpoint, because | 
| 335 | 339 |         loading such a checkpoint would likely degrade the model's | 
| 336 | 340 |         performance. | 
| 337 | 341 | 
 | 
| 338 |  | -        This implementation copies weights from the checkpoint to the model and fills 0 | 
| 339 |  | -        for the new history width dimension. | 
|  | 342 | +        This implementation copies the checkpoint's weights into zero-initialised tensors | 
|  | 343 | +        sized for the model's larger history dimension. It mutates `checkpoint` in place. | 
| 340 | 344 |         """ | 
| 341 |  | -        # Find all weights with prefix "encoder.surf_token_embeds.weights." | 
| 342 | 345 |         for name, weight in list(checkpoint.items()): | 
| 343 |  | -            if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith( | 
| 344 |  | -                "encoder.atmos_token_embeds.weights." | 
| 345 |  | -            ): | 
|  | 346 | +            # We only need to adapt the patch embedding in the encoder. | 
|  | 347 | +            enc_surf_embedding = name.startswith("encoder.surf_token_embeds.weights.") | 
|  | 348 | +            enc_atmos_embedding = name.startswith("encoder.atmos_token_embeds.weights.") | 
|  | 349 | +            if enc_surf_embedding or enc_atmos_embedding: | 
| 346 | 350 |                 # This shouldn't get called with current logic but leaving here for future proofing | 
| 347 |  | -                # and in cases where its called outside current context | 
| 348 |  | -                assert ( | 
| 349 |  | -                    weight.shape[2] <= self.max_history_size | 
| 350 |  | -                ), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \ | 
| 351 |  | -                    into model with max_history_size {self.max_history_size} for weight {name}" | 
| 352 |  | - | 
| 353 |  | -                # Initialize the new weight tensor | 
|  | 351 | +                # and in cases where it's called outside the current context. | 
|  | 352 | +                if not (weight.shape[2] <= self.max_history_size): | 
|  | 353 | +                    raise AssertionError( | 
|  | 354 | +                        f"Cannot load checkpoint with `max_history_size` {weight.shape[2]} " | 
|  | 355 | +                        f"into model with `max_history_size` {self.max_history_size}." | 
|  | 356 | +                    ) | 
|  | 357 | + | 
|  | 358 | +                # Initialize the new weight tensor. | 
| 354 | 359 |                 new_weight = torch.zeros( | 
| 355 | 360 |                     (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]), | 
| 356 | 361 |                     device=weight.device, | 
| 357 | 362 |                     dtype=weight.dtype, | 
| 358 | 363 |                 ) | 
| 359 |  | - | 
| 360 | 364 |                 # Copy the existing weights to the new tensor by duplicating the histories provided | 
| 361 |  | -                # into any new history dimensions | 
| 362 |  | -                for j in range(weight.shape[2]): | 
| 363 |  | -                    # only fill existing weights, others are zeros | 
| 364 |  | -                    new_weight[:, :, j, :, :] = weight[:, :, j, :, :] | 
|  | 365 | +                # into the corresponding history positions. Any new positions remain at zero. | 
|  | 366 | +                new_weight[:, :, : weight.shape[2]] = weight | 
|  | 367 | + | 
| 365 | 368 |                 checkpoint[name] = new_weight | 
| 366 |  | -        return checkpoint | 
| 367 | 369 | 
 | 
| 368 | 370 |     def configure_activation_checkpointing(self): | 
| 369 | 371 |         """Configure activation checkpointing. | 
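For reference, a self-contained sketch of the zero-padding performed in `adapt_checkpoint_max_history_size`: the checkpoint's patch-embedding weight is copied into its original history positions of a zero tensor sized for the larger model, and the extra positions stay at zero. The shapes are toy values chosen for illustration:

```python
import torch

max_history_size = 4  # the (larger) history size of the current model

# Toy embedding weight as stored in the checkpoint: (embed_dim, 1, history, patch, patch).
weight = torch.randn(8, 1, 2, 4, 4)

new_weight = torch.zeros(
    (weight.shape[0], 1, max_history_size, weight.shape[3], weight.shape[4]),
    device=weight.device,
    dtype=weight.dtype,
)
new_weight[:, :, : weight.shape[2]] = weight  # copy the existing history steps

assert torch.equal(new_weight[:, :, :2], weight)  # original steps preserved
assert torch.all(new_weight[:, :, 2:] == 0)  # newly added steps are zeros
```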