
Commit e6aeac6

Minor changes for revision of paper (#56)
* Update citation and mention wave model
* Expose internal timestep of the model
* Fix title in README
* Expose level aggregation stabilisation tweak
* Test that changing the new flags changes the output
* Add more fine-tuning advice
* Make naming more consistent
* Test that model runs under DDP wrapping
* Add `Batch.{to,from}_netcdf`
1 parent 8b11659 commit e6aeac6

11 files changed (+279, -21 lines)


README.md

Lines changed: 11 additions & 10 deletions
@@ -2,16 +2,16 @@
 <img src="docs/aurora.jpg" alt="Aurora logo" width="200"/>
 </p>

-# Aurora: A Foundation Model of the Atmosphere
+# Aurora: A Foundation Model for the Earth System

 [![CI](https://github.com/microsoft/Aurora/actions/workflows/ci.yaml/badge.svg)](https://github.com/microsoft/Aurora/actions/workflows/ci.yaml)
 [![Documentation](https://img.shields.io/badge/docs-latest-blue.svg)](https://microsoft.github.io/aurora)
 [![Paper](https://img.shields.io/badge/arXiv-2405.13063-blue)](https://arxiv.org/abs/2405.13063)

-Implementation of the Aurora model for atmospheric forecasting.
+Implementation of the Aurora model for Earth system forecasting.

 _The package currently includes the pretrained model and the fine-tuned version for high-resolution weather forecasting._
-_We are working on the fine-tuned version for air pollution forecasting, which will be included in due time._
+_We are working on the fine-tuned versions for air pollution and ocean wave forecasting, which will be included in due time._

 [Link to the paper on arXiv.](https://arxiv.org/abs/2405.13063)

@@ -22,8 +22,8 @@ Cite us as follows:

 ```
 @misc{bodnar2024aurora,
-    title = {Aurora: A Foundation Model of the Atmosphere},
-    author = {Cristian Bodnar and Wessel P. Bruinsma and Ana Lucic and Megan Stanley and Johannes Brandstetter and Patrick Garvan and Maik Riechert and Jonathan Weyn and Haiyu Dong and Anna Vaughan and Jayesh K. Gupta and Kit Tambiratnam and Alex Archibald and Elizabeth Heider and Max Welling and Richard E. Turner and Paris Perdikaris},
+    title = {Aurora: A Foundation Model for the Earth System},
+    author = {Cristian Bodnar and Wessel P. Bruinsma and Ana Lucic and Megan Stanley and Anna Vaughan and Johannes Brandstetter and Patrick Garvan and Maik Riechert and Jonathan A. Weyn and Haiyu Dong and Jayesh K. Gupta and Kit Thambiratnam and Alexander T. Archibald and Chun-Chieh Wu and Elizabeth Heider and Max Welling and Richard E. Turner and Paris Perdikaris},
     year = {2024},
     url = {https://arxiv.org/abs/2405.13063},
     eprint = {2405.13063},
@@ -48,10 +48,11 @@ Contents:
 Aurora is a machine learning model that can predict atmospheric variables, such as temperature.
 It is a _foundation model_, which means that it was first generally trained on a lot of data,
 and then can be adapted to specialised atmospheric forecasting tasks with relatively little data.
-We provide three such specialised versions:
+We provide four such specialised versions:
 one for medium-resolution weather prediction,
 one for high-resolution weather prediction,
-and one for air pollution prediction.
+one for air pollution prediction,
+and one for ocean wave prediction.

 ## Getting Started

@@ -127,7 +128,7 @@ Our goal in publishing this code is
 This code has not been developed nor tested for non-academic purposes and hence should not be used as such.

 ### Limitations
-Although Aurora was trained to accurately predict future weather and air pollution,
+Although Aurora was trained to accurately predict future weather, air pollution, and ocean waves,
 Aurora is based on neural networks, which means that there are no strict guarantees that predictions will always be accurate.
 Altering the inputs, providing a sample that was not in the training set,
 or even providing a sample that was in the training set but is simply unlucky may result in arbitrarily poor predictions.
@@ -183,7 +184,7 @@ make docs

 To locally view the documentation, open `docs/_build/index.html` in your browser.

-### Why is the fine-tuned version of Aurora for air quality forecasting missing?
+### Why are the fine-tuned versions of Aurora for air quality and ocean wave forecasting missing?

 The package currently includes the pretrained model and the fine-tuned version for high-resolution weather forecasting.
-We are working on the fine-tuned version for air pollution forecasting, which will be included in due time.
+We are working on the fine-tuned versions for air pollution and ocean wave forecasting, which will be included in due time.
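For orientation, the checkpoints that are already published load as in the fine-tuning examples further down this diff. A minimal sketch, assuming the package is installed and reusing the pretrained 0.25° checkpoint name that appears in those examples:

```python
from aurora import Aurora

# Load the pretrained 0.25-degree model without LoRA, mirroring the fine-tuning docs below.
model = Aurora(use_lora=False)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
model.eval()
```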

aurora/batch.py

Lines changed: 76 additions & 1 deletion
@@ -3,7 +3,8 @@
 import dataclasses
 from datetime import datetime
 from functools import partial
-from typing import Callable
+from pathlib import Path
+from typing import Callable, List

 import numpy as np
 import torch
@@ -220,6 +221,80 @@ def regrid(self, res: float) -> "Batch":
             ),
         )

+    def to_netcdf(self, path: str | Path) -> None:
+        """Write the batch to a file.
+
+        This requires `xarray` and `netcdf4` to be installed.
+        """
+        try:
+            import xarray as xr
+        except ImportError as e:
+            raise RuntimeError("`xarray` must be installed.") from e
+
+        ds = xr.Dataset(
+            {
+                **{
+                    f"surf_{k}": (("batch", "history", "latitude", "longitude"), _np(v))
+                    for k, v in self.surf_vars.items()
+                },
+                **{
+                    f"static_{k}": (("latitude", "longitude"), _np(v))
+                    for k, v in self.static_vars.items()
+                },
+                **{
+                    f"atmos_{k}": (("batch", "history", "level", "latitude", "longitude"), _np(v))
+                    for k, v in self.atmos_vars.items()
+                },
+            },
+            coords={
+                "latitude": _np(self.metadata.lat),
+                "longitude": _np(self.metadata.lon),
+                "time": list(self.metadata.time),
+                "level": list(self.metadata.atmos_levels),
+                "rollout_step": self.metadata.rollout_step,
+            },
+        )
+        ds.to_netcdf(path)
+
+    @classmethod
+    def from_netcdf(cls, path: str | Path) -> "Batch":
+        """Load a batch from a file."""
+        try:
+            import xarray as xr
+        except ImportError as e:
+            raise RuntimeError("`xarray` must be installed.") from e
+
+        ds = xr.load_dataset(path, engine="netcdf4")
+
+        surf_vars: List[str] = []
+        static_vars: List[str] = []
+        atmos_vars: List[str] = []
+
+        for k in ds:
+            if k.startswith("surf_"):
+                surf_vars.append(k.removeprefix("surf_"))
+            elif k.startswith("static_"):
+                static_vars.append(k.removeprefix("static_"))
+            elif k.startswith("atmos_"):
+                atmos_vars.append(k.removeprefix("atmos_"))
+
+        return Batch(
+            surf_vars={k: torch.from_numpy(ds[f"surf_{k}"].values) for k in surf_vars},
+            static_vars={k: torch.from_numpy(ds[f"static_{k}"].values) for k in static_vars},
+            atmos_vars={k: torch.from_numpy(ds[f"atmos_{k}"].values) for k in atmos_vars},
+            metadata=Metadata(
+                lat=torch.from_numpy(ds.latitude.values),
+                lon=torch.from_numpy(ds.longitude.values),
+                time=tuple(ds.time.values.astype("datetime64[s]").tolist()),
+                atmos_levels=tuple(ds.level.values),
+                rollout_step=int(ds.rollout_step.values),
+            ),
+        )
+
+
+def _np(x: torch.Tensor) -> np.ndarray:
+    return x.detach().cpu().numpy()
+

 def interpolate(
     v: torch.Tensor,
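A minimal round-trip sketch for the new `Batch.to_netcdf` and `Batch.from_netcdf` methods, assuming `xarray` and `netcdf4` are installed, that `Batch` is importable from the top-level `aurora` package, and that `batch` is an existing `Batch` (an input you prepared or a model prediction):

```python
import torch

from aurora import Batch

batch.to_netcdf("batch.nc")  # Writes surf_/static_/atmos_ variables plus coordinate metadata.
restored = Batch.from_netcdf("batch.nc")  # Rebuilds the Batch, including its Metadata.

# The surface variables should survive the round trip unchanged.
for name, values in batch.surf_vars.items():
    assert torch.equal(values, restored.surf_vars[name])
```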

aurora/model/aurora.py

Lines changed: 10 additions & 3 deletions
@@ -50,6 +50,8 @@ def __init__(
         dec_mlp_ratio: float = 2.0,
         perceiver_ln_eps: float = 1e-5,
         max_history_size: int = 2,
+        timestep: timedelta = timedelta(hours=6),
+        stabilise_level_agg: bool = False,
         use_lora: bool = True,
         lora_steps: int = 40,
         lora_mode: LoRAMode = "single",
@@ -96,6 +98,9 @@
             max_history_size (int, optional): Maximum number of history steps. You can load
                 checkpoints with a smaller `max_history_size`, but you cannot load checkpoints
                 with a larger `max_history_size`.
+            timestep (timedelta, optional): Timestep of the model. Defaults to 6 hours.
+            stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an
+                additional layer normalisation. Defaults to `False`.
             use_lora (bool, optional): Use LoRA adaptation.
             lora_steps (int, optional): Use different LoRA adaptation for the first so-many roll-out
                 steps.
@@ -115,6 +120,7 @@
         self.surf_stats = surf_stats or dict()
         self.autocast = autocast
         self.max_history_size = max_history_size
+        self.timestep = timestep

         if self.surf_stats:
             warnings.warn(
@@ -138,6 +144,7 @@
             latent_levels=latent_levels,
             max_history_size=max_history_size,
             perceiver_ln_eps=perceiver_ln_eps,
+            stabilise_level_agg=stabilise_level_agg,
         )

         self.backbone = Swin3DTransformerBackbone(
@@ -202,19 +209,19 @@ def forward(self, batch: Batch) -> Batch:

         x = self.encoder(
             batch,
-            lead_time=timedelta(hours=6),
+            lead_time=self.timestep,
         )
         with torch.autocast(device_type="cuda") if self.autocast else contextlib.nullcontext():
             x = self.backbone(
                 x,
-                lead_time=timedelta(hours=6),
+                lead_time=self.timestep,
                 patch_res=patch_res,
                 rollout_step=batch.metadata.rollout_step,
             )
         pred = self.decoder(
             x,
             batch,
-            lead_time=timedelta(hours=6),
+            lead_time=self.timestep,
             patch_res=patch_res,
         )

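The two constructor flags exposed above can be combined. A sketch of how they might be set when fine-tuning on data with a different temporal resolution; the 12-hour value is purely illustrative:

```python
from datetime import timedelta

from aurora import Aurora

model = Aurora(
    use_lora=False,
    timestep=timedelta(hours=12),  # Illustrative: internal lead time other than the 6-hour default.
    stabilise_level_agg=True,      # Extra layer norm in the level aggregation (see encoder changes).
)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
```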

aurora/model/encoder.py

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,7 @@
         mlp_ratio: float = 4.0,
         max_history_size: int = 2,
         perceiver_ln_eps: float = 1e-5,
+        stabilise_level_agg: bool = False,
     ) -> None:
         """Initialise.

@@ -67,6 +68,8 @@
                 to `2`.
             perceiver_ln_eps (float, optional): Epsilon value for layer normalisation in the
                 Perceiver. Defaults to 1e-5.
+            stabilise_level_agg (bool, optional): Stabilise the level aggregation by inserting an
+                additional layer normalisation. Defaults to `False`.
         """
         super().__init__()

@@ -120,6 +123,7 @@
             drop=drop_rate,
             mlp_ratio=mlp_ratio,
             ln_eps=perceiver_ln_eps,
+            ln_k_q=stabilise_level_agg,
         )

         # Drop patches after encoding.

aurora/model/perceiver.py

Lines changed: 19 additions & 1 deletion
@@ -97,6 +97,7 @@
         context_dim: int,
         head_dim: int = 64,
         num_heads: int = 8,
+        ln_k_q: bool = False,
     ) -> None:
         """Initialise.

@@ -105,6 +106,7 @@
             context_dim (int): Dimensionality of the context features also given as input.
             head_dim (int): Attention head dimensionality.
             num_heads (int): Number of heads.
+            ln_k_q (bool): Apply an extra layer norm. to the keys and queries.
         """
         super().__init__()
         self.num_heads = num_heads
@@ -115,6 +117,13 @@
         self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
         self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False)

+        if ln_k_q:
+            self.ln_k = nn.LayerNorm(num_heads * head_dim)
+            self.ln_q = nn.LayerNorm(num_heads * head_dim)
+        else:
+            self.ln_k = lambda x: x
+            self.ln_q = lambda x: x
+
     def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         """Run the cross-attention module.

@@ -131,6 +140,11 @@ def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:

         q = self.to_q(latents)  # (B, L1, D2) to (B, L1, D)
         k, v = self.to_kv(x).chunk(2, dim=-1)  # (B, L2, D1) to twice (B, L2, D)
+
+        # Apply LN before (!) splitting the heads.
+        k = self.ln_k(k)
+        q = self.ln_q(q)
+
         q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))

         out = F.scaled_dot_product_attention(q, k, v)
@@ -152,6 +166,7 @@ def __init__(
         drop: float = 0.0,
         residual_latent: bool = True,
         ln_eps: float = 1e-5,
+        ln_k_q: bool = False,
     ) -> None:
         """Initialise.

@@ -168,13 +183,15 @@
                 Defaults to `True`.
             ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to
                 `1e-5`.
+            ln_k_q (bool, optional): Apply an extra layer norm. to the keys and queries of the first
+                resampling layer. Defaults to `False`.
         """
         super().__init__()

         self.residual_latent = residual_latent
         self.layers = nn.ModuleList([])
         mlp_hidden_dim = int(latent_dim * mlp_ratio)
-        for _ in range(depth):
+        for i in range(depth):
             self.layers.append(
                 nn.ModuleList(
                     [
@@ -183,6 +200,7 @@
                             context_dim=context_dim,
                             head_dim=head_dim,
                             num_heads=num_heads,
+                            ln_k_q=ln_k_q if i == 0 else False,
                         ),
                         MLP(dim=latent_dim, hidden_features=mlp_hidden_dim, dropout=drop),
                         nn.LayerNorm(latent_dim, eps=ln_eps),
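For intuition, the `ln_k_q` flag layer-normalises the projected keys and queries before the heads are split, which keeps the dot products entering the softmax from growing without bound. A standalone sketch of the same idea in plain PyTorch (not the package's internal API; the shapes are made up):

```python
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

B, L1, L2, H, D = 2, 8, 16, 4, 32  # batch, latents, context length, heads, head dim

q = torch.randn(B, L1, H * D)  # projected queries, as after `to_q`
k = torch.randn(B, L2, H * D)  # projected keys, as after `to_kv`
v = torch.randn(B, L2, H * D)

ln_q, ln_k = nn.LayerNorm(H * D), nn.LayerNorm(H * D)

# Normalise before (!) splitting the heads, mirroring the change above.
q, k = ln_q(q), ln_k(k)
q, k, v = (rearrange(t, "b l (h d) -> b h l d", h=H) for t in (q, k, v))

out = F.scaled_dot_product_attention(q, k, v)  # (B, H, L1, D)
```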

docs/beware.md

Lines changed: 5 additions & 0 deletions
@@ -17,6 +17,11 @@ exactly the right variables
 at exactly the right pressure levels
 from exactly the right source.

+This also means that the performance of the model will be sensitive to how the
+data is regridded.
+For optimal performance, you should ensure that the data is regridded
+exactly like the data seen during pretraining and fine-tuning.
+
 (t0-vs-analysis)=
 ## HRES IFS T0 Versus HRES IFS Analysis

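Relatedly, the package's own `Batch.regrid` method (visible in the `aurora/batch.py` context above) is one way to move data onto a regular latitude-longitude grid; whether its interpolation matches the pipeline used during training is exactly the kind of detail the note above asks you to verify. A sketch, with 0.25° chosen purely as an example target resolution:

```python
# `batch` is an existing Batch on some other regular grid.
batch_quarter_degree = batch.regrid(res=0.25)
```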

docs/finetuning.md

Lines changed: 57 additions & 0 deletions
@@ -37,6 +37,37 @@ loss = ...
 loss.backward()
 ```

+## Exploding Gradients
+
+When fine-tuning, you may run into very large gradient values.
+Gradient clipping and internal layer normalisation layers mitigate the impact
+of large gradients,
+meaning that large gradients will not immediately lead to abnormal model outputs and loss values.
+Nevertheless, if gradients do blow up, the model will not learn anymore and eventually the loss value
+will also blow up.
+You should carefully monitor the value of the gradients to detect exploding gradients.
+
+One cause of exploding gradients is internal activations taking on excessively large values.
+Typically this can be fixed by judiciously inserting a layer normalisation layer.
+
+We have identified the level aggregation as a weak point of the model that can be susceptible
+to exploding gradients.
+You can stabilise the level aggregation of the model
+by setting the following flag in the constructor: `stabilise_level_agg=True`.
+Note that `stabilise_level_agg=True` will considerably perturb the model,
+so significant additional fine-tuning may be required to get to the desired level of performance.
+
+```python
+from aurora import Aurora
+from aurora.normalisation import locations, scales
+
+model = Aurora(
+    use_lora=False,
+    stabilise_level_agg=True,  # Insert extra layer norm. to mitigate exploding gradients.
+)
+model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
+```
+
 ## Extending Aurora with New Variables

 Aurora can be extended with new variables by adjusting the keyword arguments `surf_vars`,
@@ -66,6 +97,18 @@ scales["new_static_var"] = 1.0
 scales["new_atmos_var"] = 1.0
 ```

+To more efficiently learn new variables, it is recommended to use a separate learning rate for
+the patch embeddings of the new variables in the encoder and decoder.
+For example, if you are using Adam, you can try `1e-3` for the new patch embeddings
+and `3e-4` for the other parameters.
+
+By default, patch embeddings in the encoder for new variables are initialised randomly.
+This means that adding new variables to the model perturbs the predictions for the existing
+variables.
+If you do not want this, you can alternatively initialise the new patch embeddings in the encoder
+to zero.
+The relevant parameter dictionaries are `model.encoder.{surf,atmos}_token_embeds.weights`.
+
 ## Other Model Extensions

 It is possible to extend the model in any way you like.
@@ -83,3 +126,17 @@ model = Aurora(...)

 model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt", strict=False)
 ```
+
+## Triple Check Your Fine-Tuning Data!
+
+When fine-tuning the model, it is absolutely essential to carefully check your fine-tuning data.
+
+* Are the old (and possibly new) normalisation statistics appropriate for the new data?
+
+* Is any data missing?
+
+* Does the data contain zeros or NaNs?
+
+* Does the data contain any outliers that could possibly interfere with fine-tuning?
+
+_Et cetera._
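A sketch that makes the new-variable advice above concrete: zero-initialise the new encoder patch embedding and give it a higher learning rate via Adam parameter groups. It assumes `model.encoder.surf_token_embeds.weights` is a dict-like mapping from variable name to parameter, as the text above suggests, and that `"new_surf_var"` is the hypothetical name of the added variable; the decoder's new parameters, whose attribute path is not shown in this diff, would be added to the first group in the same way:

```python
import torch

new_var = "new_surf_var"  # Hypothetical name of the newly added surface variable.
new_embed = model.encoder.surf_token_embeds.weights[new_var]

# Optionally start the new encoder embedding at zero so existing predictions are untouched.
with torch.no_grad():
    new_embed.zero_()

# Higher learning rate for the new patch embedding, lower for everything else.
new_ids = {id(new_embed)}
other_params = [p for p in model.parameters() if id(p) not in new_ids]
optimiser = torch.optim.Adam(
    [
        {"params": [new_embed], "lr": 1e-3},
        {"params": other_params, "lr": 3e-4},
    ]
)
```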
