|
23 | 23 | from aurora.model.decoder import Perceiver3DDecoder |
24 | 24 | from aurora.model.encoder import Perceiver3DEncoder |
25 | 25 | from aurora.model.lora import LoRAMode |
26 | | -from aurora.model.swin3d import BasicLayer3D, Swin3DTransformerBackbone |
| 26 | +from aurora.model.swin3d import Swin3DTransformerBackbone |
27 | 27 |
|
28 | 28 | __all__ = [ |
29 | 29 | "Aurora", |
@@ -79,6 +79,7 @@ def __init__( |
79 | 79 | lora_mode: LoRAMode = "single", |
80 | 80 | surf_stats: Optional[dict[str, tuple[float, float]]] = None, |
81 | 81 | autocast: bool = False, |
| 82 | + bf16_mode: bool = False, |
82 | 83 | level_condition: Optional[tuple[int | float, ...]] = None, |
83 | 84 | dynamic_vars: bool = False, |
84 | 85 | atmos_static_vars: bool = False, |
@@ -141,8 +142,10 @@ def __init__( |
141 | 142 | surf_stats (dict[str, tuple[float, float]], optional): For these surface-level |
142 | 143 | variables, adjust the normalisation to the given tuple consisting of a new location |
143 | 144 | and scale. |
144 | | - autocast (bool, optional): Use `torch.autocast` to reduce memory usage. Defaults to |
145 | | - `False`. |
| 145 | + bf16_mode (bool, optional): To reduce memory usage, convert the tokens to BF16, run |
| 146 | + the backbone in pure BF16, and run the decoder in FP16 AMP. This should make gradient
| 147 | + computation feasible. USE AT YOUR OWN RISK. THIS WAS NOT USED DURING THE DEVELOPMENT
| 148 | + OF AURORA AND IS PURELY PROVIDED AS A STARTING POINT FOR FINE-TUNING. |
146 | 149 | level_condition (tuple[int | float, ...], optional): Make the patch embeddings dependent |
147 | 150 | on pressure level. If you want to enable this feature, provide a tuple of all |
148 | 151 | possible pressure levels. |
@@ -176,7 +179,6 @@ def __init__( |
176 | 179 | self.atmos_vars = atmos_vars |
177 | 180 | self.patch_size = patch_size |
178 | 181 | self.surf_stats = surf_stats or dict() |
179 | | - self.autocast = autocast |
180 | 182 | self.max_history_size = max_history_size |
181 | 183 | self.timestep = timestep |
182 | 184 | self.use_lora = use_lora |
@@ -246,6 +248,19 @@ def __init__( |
246 | 248 | modulation_head=modulation_head, |
247 | 249 | ) |
248 | 250 |
|
| 251 | + if autocast and not bf16_mode: |
| 252 | + warnings.warn( |
| 253 | + "Due to its limited utility, the argument `autocast` is deprecated and no longer "
| 254 | + "has any effect. Use `bf16_mode` instead.",
| 255 | + stacklevel=2, |
| 256 | + ) |
| 257 | + |
| 258 | + self.bf16_mode = bf16_mode |
| 259 | + |
| 260 | + if self.bf16_mode: |
| 261 | + # We run the backbone in pure BF16. |
| 262 | + self.backbone.to(torch.bfloat16) |
| 263 | + |
249 | 264 | def forward(self, batch: Batch) -> Batch: |
250 | 265 | """Forward pass. |
251 | 266 |
|
@@ -302,24 +317,44 @@ def forward(self, batch: Batch) -> Batch: |
302 | 317 |
|
303 | 318 | transformed_batch = self._pre_encoder_hook(transformed_batch) |
304 | 319 |
|
| 320 | + # The encoder is always run as-is, regardless of `bf16_mode`.
305 | 321 | x = self.encoder( |
306 | 322 | transformed_batch, |
307 | 323 | lead_time=self.timestep, |
308 | 324 | ) |
309 | | - with torch.autocast(device_type="cuda") if self.autocast else contextlib.nullcontext(): |
310 | | - x = self.backbone( |
311 | | - x, |
312 | | - lead_time=self.timestep, |
313 | | - patch_res=patch_res, |
314 | | - rollout_step=batch.metadata.rollout_step, |
315 | | - ) |
316 | | - pred = self.decoder( |
| 325 | + |
| 326 | + # In BF16 mode, the backbone is run in pure BF16. |
| 327 | + if self.bf16_mode: |
| 328 | + x = x.to(torch.bfloat16) |
| 329 | + x = self.backbone( |
317 | 330 | x, |
318 | | - batch, |
319 | 331 | lead_time=self.timestep, |
320 | 332 | patch_res=patch_res, |
| 333 | + rollout_step=batch.metadata.rollout_step, |
321 | 334 | ) |
322 | 335 |
|
| 336 | + # In BF16 mode, the decoder is run in FP16 AMP, and the output is converted back to FP32.
| 337 | + # We run in FP16 as opposed to BF16 for improved relative precision.
| 338 | + if self.bf16_mode: |
| 339 | + context = torch.autocast(device_type="cuda", dtype=torch.float16) |
| 340 | + x = x.to(torch.float16) |
| 341 | + else: |
| 342 | + context = contextlib.nullcontext() |
| 343 | + with context: |
| 344 | + pred = self.decoder( |
| 345 | + x, |
| 346 | + batch, |
| 347 | + lead_time=self.timestep, |
| 348 | + patch_res=patch_res, |
| 349 | + ) |
| 350 | + if self.bf16_mode: |
| 351 | + pred = dataclasses.replace( |
| 352 | + pred, |
| 353 | + surf_vars={k: v.float() for k, v in pred.surf_vars.items()}, |
| 354 | + static_vars={k: v.float() for k, v in pred.static_vars.items()}, |
| 355 | + atmos_vars={k: v.float() for k, v in pred.atmos_vars.items()}, |
| 356 | + ) |
| 357 | + |
323 | 358 | # Remove batch and history dimension from static variables. |
324 | 359 | pred = dataclasses.replace( |
325 | 360 | pred, |
@@ -476,7 +511,21 @@ def configure_activation_checkpointing(self): |
476 | 511 |
|
477 | 512 | This is required in order to compute gradients without running out of memory. |
478 | 513 | """ |
479 | | - apply_activation_checkpointing(self, check_fn=lambda x: isinstance(x, BasicLayer3D)) |
| 514 | + # Checkpoint these modules: |
| 515 | + module_names = ( |
| 516 | + "Perceiver3DEncoder", |
| 517 | + "Swin3DTransformerBackbone", |
| 518 | + "Basic3DEncoderLayer", |
| 519 | + "Basic3DDecoderLayer", |
| 520 | + "Perceiver3DDecoder", |
| 521 | + "LinearPatchReconstruction", |
| 522 | + ) |
| 523 | + |
| 524 | + def check(x: torch.nn.Module) -> bool: |
| 525 | + name = x.__class__.__name__ |
| 526 | + return name in module_names |
| 527 | + |
| 528 | + apply_activation_checkpointing(self, check_fn=check) |
480 | 529 |
|
481 | 530 |
|
482 | 531 | class AuroraPretrained(Aurora): |
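
For context, a minimal fine-tuning sketch showing how the new `bf16_mode` flag and `configure_activation_checkpointing` are meant to be combined. This is not the training setup used for Aurora: the toy 17 x 32 batch follows the pattern of the repository's README example, the model uses random weights instead of a loaded checkpoint, and the loss and optimiser are placeholders.

```python
from datetime import datetime

import torch

from aurora import Aurora, Batch, Metadata

# Assumed: `bf16_mode` is the new constructor argument from this commit. A random-weight
# model is used here; for real fine-tuning you would load a pretrained checkpoint first.
model = Aurora(use_lora=False, bf16_mode=True)
model.configure_activation_checkpointing()  # Checkpoint encoder, backbone, and decoder layers.
model = model.to("cuda")
model.train()

# Toy batch on a 17 x 32 grid with two history steps, mirroring the README example.
batch = Batch(
    surf_vars={k: torch.randn(1, 2, 17, 32) for k in ("2t", "10u", "10v", "msl")},
    static_vars={k: torch.randn(17, 32) for k in ("lsm", "z", "slt")},
    atmos_vars={k: torch.randn(1, 2, 4, 17, 32) for k in ("z", "u", "v", "t", "q")},
    metadata=Metadata(
        lat=torch.linspace(90, -90, 17),
        lon=torch.linspace(0, 360, 32 + 1)[:-1],
        time=(datetime(2020, 6, 1, 12, 0),),
        atmos_levels=(100, 250, 500, 850),
    ),
)

optimiser = torch.optim.AdamW(model.parameters(), lr=1e-5)

# The backbone runs in pure BF16 and the decoder in FP16 AMP, but the prediction comes back
# in FP32, so a plain FP32 loss and backward pass work. A placeholder loss is used here; in
# practice a gradient scaler may help with the FP16 decoder.
pred = model.forward(batch.to("cuda"))
loss = pred.surf_vars["2t"].mean()
loss.backward()
optimiser.step()
```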
|