
Commit c0311d2

Merge pull request #39 from microsoft/wesselb/adjust-docs-and-minor-changes
Minor style changes and docstring adaptations
2 parents: c694fca + d1afde8

File tree: 2 files changed, +50 / -51 lines


aurora/model/aurora.py (34 additions, 32 deletions)
@@ -5,7 +5,7 @@
 import warnings
 from datetime import timedelta
 from functools import partial
-from typing import Any, Optional
+from typing import Optional
 
 import torch
 from huggingface_hub import hf_hub_download
@@ -93,7 +93,9 @@ def __init__(
                 separate parameter.
             perceiver_ln_eps (float, optional): Epsilon in the perceiver layer norm. layers. Used
                 to stabilise the model.
-            max_history_size (int, optional): Maximum number of history steps.
+            max_history_size (int, optional): Maximum number of history steps. You can load
+                checkpoints with a smaller `max_history_size`, but you cannot load checkpoints
+                with a larger `max_history_size`.
             use_lora (bool, optional): Use LoRA adaptation.
             lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out
                 steps.
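The docstring addition above documents an asymmetry that is easy to miss: a checkpoint trained with a smaller `max_history_size` can be loaded into a model configured with a larger one, but not the reverse. A minimal usage sketch of that behaviour follows; the constructor argument and checkpoint name are taken from the Aurora README rather than from this diff, so treat them as assumptions.

from aurora import Aurora

# Hedged sketch: the pretrained checkpoint uses a history size of 2, so a model
# configured with a larger max_history_size can still load it. Loading a checkpoint
# whose history size exceeds the model's max_history_size raises an AssertionError.
model = Aurora(max_history_size=4)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")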
@@ -316,54 +318,54 @@ def load_checkpoint_local(self, path: str, strict: bool = True) -> None:
             d[f"decoder.atmos_heads.{name}.weight"] = weight[:, i]
             d[f"decoder.atmos_heads.{name}.bias"] = bias[:, i]
 
-        # check if history size is compatible and adjust weights if necessary
-        if self.max_history_size > d["encoder.surf_token_embeds.weights.2t"].shape[2]:
-            d = self.adapt_checkpoint_max_history_size(d)
-        elif self.max_history_size < d["encoder.surf_token_embeds.weights.2t"].shape[2]:
-            raise AssertionError(f"Cannot load checkpoint with max_history_size \
-                {d['encoder.surf_token_embeds.weights.2t'].shape[2]} \
-                into model with max_history_size {self.max_history_size}")
+        # Check if the history size is compatible and adjust weights if necessary.
+        current_history_size = d["encoder.surf_token_embeds.weights.2t"].shape[2]
+        if self.max_history_size > current_history_size:
+            self.adapt_checkpoint_max_history_size(d)
+        elif self.max_history_size < current_history_size:
+            raise AssertionError(
+                f"Cannot load checkpoint with `max_history_size` {current_history_size} "
+                f"into model with `max_history_size` {self.max_history_size}."
+            )
 
         self.load_state_dict(d, strict=strict)
 
-    def adapt_checkpoint_max_history_size(self, checkpoint) -> Any:
-        """Adapt a checkpoint with smaller max_history_size to a model with a larger
-        max_history_size than the current model.
+    def adapt_checkpoint_max_history_size(self, checkpoint: dict[str, torch.Tensor]) -> None:
+        """Adapt a checkpoint with smaller `max_history_size` to a model with a larger
+        `max_history_size` than the current model.
 
-        If a checkpoint was trained with a larger max_history_size than the current model,
+        If a checkpoint was trained with a larger `max_history_size` than the current model,
         this function will assert fail to prevent loading the checkpoint. This is to
         prevent loading a checkpoint which will likely cause the checkpoint to degrade is
         performance.
 
-        This implementation copies weights from the checkpoint to the model and fills 0
-        for the new history width dimension.
+        This implementation copies weights from the checkpoint to the model and fills zeros
+        for the new history width dimension. It mutates `checkpoint`.
        """
-        # Find all weights with prefix "encoder.surf_token_embeds.weights."
         for name, weight in list(checkpoint.items()):
-            if name.startswith("encoder.surf_token_embeds.weights.") or name.startswith(
-                "encoder.atmos_token_embeds.weights."
-            ):
+            # We only need to adapt the patch embedding in the encoder.
+            enc_surf_embedding = name.startswith("encoder.surf_token_embeds.weights.")
+            enc_atmos_embedding = name.startswith("encoder.atmos_token_embeds.weights.")
+            if enc_surf_embedding or enc_atmos_embedding:
                 # This shouldn't get called with current logic but leaving here for future proofing
-                # and in cases where its called outside current context
-                assert (
-                    weight.shape[2] <= self.max_history_size
-                ), f"Cannot load checkpoint with max_history_size {weight.shape[2]} \
-                    into model with max_history_size {self.max_history_size} for weight {name}"
-
-                # Initialize the new weight tensor
+                # and in cases where its called outside current context.
+                if not (weight.shape[2] <= self.max_history_size):
+                    raise AssertionError(
+                        f"Cannot load checkpoint with `max_history_size` {weight.shape[2]} "
+                        f"into model with `max_history_size` {self.max_history_size}."
+                    )
+
+                # Initialize the new weight tensor.
                 new_weight = torch.zeros(
                     (weight.shape[0], 1, self.max_history_size, weight.shape[3], weight.shape[4]),
                     device=weight.device,
                     dtype=weight.dtype,
                 )
-
                 # Copy the existing weights to the new tensor by duplicating the histories provided
-                # into any new history dimensions
-                for j in range(weight.shape[2]):
-                    # only fill existing weights, others are zeros
-                    new_weight[:, :, j, :, :] = weight[:, :, j, :, :]
+                # into any new history dimensions. The rest remains at zero.
+                new_weight[:, :, : weight.shape[2]] = weight
+
                 checkpoint[name] = new_weight
-        return checkpoint
 
     def configure_activation_checkpointing(self):
         """Configure activation checkpointing.
Second file (tests): 16 additions, 19 deletions
@@ -1,5 +1,6 @@
 """Copyright (c) Microsoft Corporation. Licensed under the MIT license."""
 
+import numpy as np
 import pytest
 import torch
 
@@ -19,45 +20,41 @@ def checkpoint():
     }
 
 
-# check both history sizes which are divisible by 2 (original shape) and not
+# Check both history sizes which are divisible by 2 (original shape) and not.
 @pytest.mark.parametrize("model", [4, 5], indirect=True)
 def test_adapt_checkpoint_max_history(model, checkpoint):
-    # checkpoint starts with history dim, shape[2], as size 2
+    # Checkpoint starts with history dim., `shape[2]`, equal to 2.
     assert checkpoint["encoder.surf_token_embeds.weights.0"].shape[2] == 2
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
+    model.adapt_checkpoint_max_history_size(checkpoint)
 
-    for name, weight in adapted_checkpoint.items():
+    for name, weight in checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
             if j >= checkpoint[name].shape[2]:
-                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
+                np.testing.assert_allclose(weight[:, :, j, :, :], 0 * weight[:, :, j, :, :])
             else:
-                assert torch.equal(
-                    weight[:, :, j, :, :],
-                    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
-                )
+                np.testing.assert_allclose(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])
 
 
-# check that assert is thrown when trying to load a larger checkpoint to a smaller history size
 @pytest.mark.parametrize("model", [1], indirect=True)
 def test_adapt_checkpoint_max_history_fail(model, checkpoint):
+    """Check that an assertion error is thrown when trying to load a larger checkpoint to a
+    smaller history size."""
     with pytest.raises(AssertionError):
         model.adapt_checkpoint_max_history_size(checkpoint)
 
 
-# test adapting the checkpoint twice to ensure that the second time should not change the weights
 @pytest.mark.parametrize("model", [4], indirect=True)
 def test_adapt_checkpoint_max_history_twice(model, checkpoint):
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(checkpoint)
-    adapted_checkpoint = model.adapt_checkpoint_max_history_size(adapted_checkpoint)
+    """Test adapting the checkpoint twice to ensure that the second time should not change the
+    weights."""
+    model.adapt_checkpoint_max_history_size(checkpoint)
+    model.adapt_checkpoint_max_history_size(checkpoint)
 
-    for name, weight in adapted_checkpoint.items():
+    for name, weight in checkpoint.items():
         assert weight.shape[2] == model.max_history_size
         for j in range(weight.shape[2]):
             if j >= checkpoint[name].shape[2]:
-                assert torch.equal(weight[:, :, j, :, :], torch.zeros_like(weight[:, :, j, :, :]))
+                np.testing.assert_allclose(weight[:, :, j, :, :], 0 * weight[:, :, j, :, :])
             else:
-                assert torch.equal(
-                    weight[:, :, j, :, :],
-                    checkpoint[name][:, :, j % checkpoint[name].shape[2], :, :],
-                )
+                np.testing.assert_allclose(weight[:, :, j, :, :], checkpoint[name][:, :, j, :, :])
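These tests rely on `model` and `checkpoint` fixtures defined earlier in the file; only the closing brace of the `checkpoint` fixture is visible in this hunk. A rough sketch of what such fixtures could look like is given below. The tensor shapes, the parameter names, and the idea of constructing a full `Aurora` model in the fixture are assumptions for illustration; the real fixtures may differ.

import pytest
import torch

from aurora import Aurora


@pytest.fixture
def model(request):
    # Indirect parametrization: request.param is the max_history_size under test.
    return Aurora(max_history_size=request.param)


@pytest.fixture
def checkpoint():
    # Hypothetical checkpoint dictionary whose history dimension (shape[2]) is 2.
    return {
        "encoder.surf_token_embeds.weights.0": torch.rand(16, 1, 2, 4, 4),
        "encoder.atmos_token_embeds.weights.0": torch.rand(16, 1, 2, 4, 4),
    }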
