Commit 7765cd4

Merge pull request #116 from microsoft/wesselb/bf16-mode
Add `bf16_mode` to enable gradient computation
2 parents 03152ad + 007048a commit 7765cd4

File tree

3 files changed: +78 -20 lines changed


aurora/model/aurora.py

Lines changed: 63 additions & 14 deletions
@@ -23,7 +23,7 @@
 from aurora.model.decoder import Perceiver3DDecoder
 from aurora.model.encoder import Perceiver3DEncoder
 from aurora.model.lora import LoRAMode
-from aurora.model.swin3d import BasicLayer3D, Swin3DTransformerBackbone
+from aurora.model.swin3d import Swin3DTransformerBackbone
 
 __all__ = [
     "Aurora",
@@ -79,6 +79,7 @@ def __init__(
         lora_mode: LoRAMode = "single",
         surf_stats: Optional[dict[str, tuple[float, float]]] = None,
         autocast: bool = False,
+        bf16_mode: bool = False,
         level_condition: Optional[tuple[int | float, ...]] = None,
         dynamic_vars: bool = False,
         atmos_static_vars: bool = False,
@@ -141,8 +142,10 @@ def __init__(
             surf_stats (dict[str, tuple[float, float]], optional): For these surface-level
                 variables, adjust the normalisation to the given tuple consisting of a new location
                 and scale.
-            autocast (bool, optional): Use `torch.autocast` to reduce memory usage. Defaults to
-                `False`.
+            bf16_mode (bool, optional): To reduce memory usage, convert the tokens to BF16, run
+                the backbone in pure BF16, and run the decoder in FP16 AMP. This should enable
+                gradient computation. USE AT YOUR OWN RISK. THIS WAS NOT USED DURING THE
+                DEVELOPMENT OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT FOR FINE-TUNING.
             level_condition (tuple[int | float, ...], optional): Make the patch embeddings dependent
                 on pressure level. If you want to enable this feature, provide a tuple of all
                 possible pressure levels.
@@ -176,7 +179,6 @@ def __init__(
         self.atmos_vars = atmos_vars
         self.patch_size = patch_size
         self.surf_stats = surf_stats or dict()
-        self.autocast = autocast
         self.max_history_size = max_history_size
         self.timestep = timestep
         self.use_lora = use_lora
@@ -246,6 +248,19 @@ def __init__(
             modulation_head=modulation_head,
         )
 
+        if autocast and not bf16_mode:
+            warnings.warn(
+                "The argument `autocast` no longer does anything due to limited utility. "
+                "Consider instead using `bf16_mode`.",
+                stacklevel=2,
+            )
+
+        self.bf16_mode = bf16_mode
+
+        if self.bf16_mode:
+            # We run the backbone in pure BF16.
+            self.backbone.to(torch.bfloat16)
+
     def forward(self, batch: Batch) -> Batch:
         """Forward pass.
 
@@ -302,24 +317,44 @@ def forward(self, batch: Batch) -> Batch:
 
         transformed_batch = self._pre_encoder_hook(transformed_batch)
 
+        # The encoder is always just run.
         x = self.encoder(
             transformed_batch,
             lead_time=self.timestep,
         )
-        with torch.autocast(device_type="cuda") if self.autocast else contextlib.nullcontext():
-            x = self.backbone(
-                x,
-                lead_time=self.timestep,
-                patch_res=patch_res,
-                rollout_step=batch.metadata.rollout_step,
-            )
-        pred = self.decoder(
+
+        # In BF16 mode, the backbone is run in pure BF16.
+        if self.bf16_mode:
+            x = x.to(torch.bfloat16)
+        x = self.backbone(
             x,
-            batch,
             lead_time=self.timestep,
             patch_res=patch_res,
+            rollout_step=batch.metadata.rollout_step,
         )
 
+        # In BF16 mode, the decoder is run in FP16 AMP, and the output is converted back to FP32.
+        # We run in FP16 as opposed to BF16 for improved relative precision.
+        if self.bf16_mode:
+            context = torch.autocast(device_type="cuda", dtype=torch.float16)
+            x = x.to(torch.float16)
+        else:
+            context = contextlib.nullcontext()
+        with context:
+            pred = self.decoder(
+                x,
+                batch,
+                lead_time=self.timestep,
+                patch_res=patch_res,
+            )
+        if self.bf16_mode:
+            pred = dataclasses.replace(
+                pred,
+                surf_vars={k: v.float() for k, v in pred.surf_vars.items()},
+                static_vars={k: v.float() for k, v in pred.static_vars.items()},
+                atmos_vars={k: v.float() for k, v in pred.atmos_vars.items()},
+            )
+
         # Remove batch and history dimension from static variables.
         pred = dataclasses.replace(
             pred,
@@ -476,7 +511,21 @@ def configure_activation_checkpointing(self):
 
         This is required in order to compute gradients without running out of memory.
         """
-        apply_activation_checkpointing(self, check_fn=lambda x: isinstance(x, BasicLayer3D))
+        # Checkpoint these modules:
+        module_names = (
+            "Perceiver3DEncoder",
+            "Swin3DTransformerBackbone",
+            "Basic3DEncoderLayer",
+            "Basic3DDecoderLayer",
+            "Perceiver3DDecoder",
+            "LinearPatchReconstruction",
+        )
+
+        def check(x: torch.nn.Module) -> bool:
+            name = x.__class__.__name__
+            return name in module_names
+
+        apply_activation_checkpointing(self, check_fn=check)
 
 
 class AuroraPretrained(Aurora):

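For readers who want to see the precision pattern from `forward` in isolation, here is a minimal, hedged sketch using toy `nn.Linear` stand-ins rather than the actual Aurora modules: the first stage holds its weights in pure BF16, the second runs under FP16 autocast, and the output is cast back to FP32. The layer sizes and the CUDA device are illustrative assumptions.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the backbone and decoder (NOT the Aurora classes).
backbone = nn.Linear(64, 64).to("cuda", torch.bfloat16)  # weights held in pure BF16
decoder = nn.Linear(64, 64).to("cuda")                    # weights stay in FP32

x = torch.randn(8, 64, device="cuda")

# Backbone stage: cast the tokens to BF16 and run the module entirely in BF16.
x = backbone(x.to(torch.bfloat16))

# Decoder stage: run under FP16 autocast (better relative precision than BF16),
# then convert the output back to FP32.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    pred = decoder(x.to(torch.float16))
pred = pred.float()
```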
aurora/model/decoder.py

Lines changed: 7 additions & 3 deletions
@@ -20,6 +20,10 @@
 __all__ = ["Perceiver3DDecoder"]
 
 
+class LinearPatchReconstruction(nn.Linear):
+    """Linear layer for patch reconstruction."""
+
+
 class Perceiver3DDecoder(nn.Module):
     """Multi-scale multi-source multi-variable decoder based on the Perceiver architecture."""
 
@@ -110,17 +114,17 @@ def __init__(
         )
 
         self.surf_heads = nn.ParameterDict(
-            {name: nn.Linear(embed_dim, patch_size**2) for name in surf_vars}
+            {name: LinearPatchReconstruction(embed_dim, patch_size**2) for name in surf_vars}
         )
         if not self.level_condition:
             self.atmos_heads = nn.ParameterDict(
-                {name: nn.Linear(embed_dim, patch_size**2) for name in atmos_vars}
+                {name: LinearPatchReconstruction(embed_dim, patch_size**2) for name in atmos_vars}
            )
        else:
            self.atmos_heads = nn.ParameterDict(
                {
                    name: LevelConditioned(
-                        lambda: nn.Linear(embed_dim, patch_size**2),
+                        lambda: LinearPatchReconstruction(embed_dim, patch_size**2),
                        levels=self.level_condition,
                        levels_dim=-2,
                    )

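`LinearPatchReconstruction` adds no behaviour over `nn.Linear`; it appears to exist so that the reconstruction heads can be matched by class name in the checkpointing `check` function above without also matching every other `nn.Linear` in the model. A minimal sketch of that idea (the 512/16 sizes are illustrative assumptions):

```python
import torch.nn as nn

class LinearPatchReconstruction(nn.Linear):
    """Behaves exactly like nn.Linear; the subclass only serves as a marker."""

def check(module: nn.Module) -> bool:
    # Matches the patch-reconstruction heads by class name, but not plain nn.Linear layers.
    return module.__class__.__name__ == "LinearPatchReconstruction"

head = LinearPatchReconstruction(512, 16)  # e.g. embed_dim=512, patch_size**2 = 16
print(check(head), check(nn.Linear(512, 16)))  # True False
```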
docs/finetuning.md

Lines changed: 8 additions & 3 deletions
@@ -13,14 +13,19 @@ model.load_checkpoint()
 ## Computing Gradients
 
 To compute gradients, you will need an A100 with 80 GB of memory.
-In addition, you will need to use [PyTorch AMP](https://pytorch.org/docs/stable/amp.html)
-and gradient checkpointing.
+In addition, you will need to use reduced precision and gradient checkpointing.
 You can do this as follows:
 
 ```python
 from aurora import AuroraPretrained
 
-model = AuroraPretrained(autocast=True)  # Use AMP.
+model = AuroraPretrained(
+    # BF16 mode is an EXPERIMENTAL mode that saves memory by running the backbone in pure BF16
+    # and the decoder in FP16 AMP. This should enable gradient computation. USE AT YOUR OWN RISK.
+    # THIS WAS NOT USED IN THE DEVELOPMENT OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT
+    # FOR FINE-TUNING.
+    bf16_mode=True,
+)
 model.load_checkpoint()
 
 batch = ...  # Load some data.

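To complete the picture, here is a hedged sketch of how a single fine-tuning step might continue from the snippet above. The ground-truth batch `target`, the surface variable `"2t"`, and the mean-absolute-error loss are illustrative assumptions and not part of this commit.

```python
import torch

model = model.cuda()
model.train()
model.configure_activation_checkpointing()  # checkpoint encoder, backbone, and decoder modules

pred = model.forward(batch)  # in BF16 mode, predictions come back converted to FP32

# Illustrative loss: mean absolute error on one surface variable against an assumed target batch.
loss = torch.mean(torch.abs(pred.surf_vars["2t"] - target.surf_vars["2t"]))
loss.backward()
```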