src/lightning/pytorch/CHANGELOG.md (2 additions, 0 deletions)

@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed
 
 - Fixed ``ModelParallelStrategy`` single-file checkpointing when ``torch.compile`` wraps the model so optimizer states no longer raise ``KeyError`` during save ([#21357](https://github.com/Lightning-AI/pytorch-lightning/issues/21357))
+
+- Fixed gradient clipping not working with fused optimizers when using ``bf16-mixed`` precision ([#21435](https://github.com/Lightning-AI/pytorch-lightning/issues/21435))
 -
 
 ### Deprecated
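For context, the `bf16-mixed` entry corresponds to runs like the sketch below. It is not taken from this PR: the model, data, and hyperparameters are placeholders, and the fused optimizer requires CUDA parameters. Before this fix, combining a fused optimizer with `bf16-mixed` and `gradient_clip_val` raised the `RuntimeError` shown in the `amp.py` diff that follows.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import lightning.pytorch as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        # Fused Adam/AdamW unscale gradients inside step() and advertise this via
        # the `_step_supports_amp_scaling` attribute that the precision plugin checks.
        return torch.optim.AdamW(self.parameters(), lr=1e-3, fused=True)


if __name__ == "__main__":
    train_loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(
        accelerator="cuda",
        devices=1,
        precision="bf16-mixed",  # no GradScaler is created for bf16
        gradient_clip_val=1.0,   # previously rejected with fused optimizers, now clips normally
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(TinyModel(), train_loader)
```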
src/lightning/pytorch/plugins/precision/amp.py (1 addition, 1 deletion)

@@ -104,7 +104,7 @@ def clip_gradients(
         clip_val: Union[int, float] = 0.0,
         gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
     ) -> None:
-        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
+        if clip_val > 0 and self.scaler is not None and _optimizer_handles_unscaling(optimizer):
             raise RuntimeError(
                 f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                 " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
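The practical effect of the extra `self.scaler is not None` check, mirroring the parametrized test below (a sketch, not a recommended usage pattern; the `Mock` optimizer stands in for a fused optimizer, and `clip_grad_by_norm` is mocked so no real parameters are touched):

```python
from unittest.mock import Mock

import pytest
from lightning.pytorch.plugins.precision import MixedPrecision

# A fused optimizer advertises that it unscales gradients inside step().
fused_optimizer = Mock(_step_supports_amp_scaling=True)

# 16-mixed: a GradScaler is in play, so gradients are still scaled at the point
# where clipping would have to run -> the guard still raises.
fp16 = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
with pytest.raises(RuntimeError, match="does not allow for gradient clipping"):
    fp16.clip_gradients(fused_optimizer, clip_val=1.0)

# bf16-mixed: no scaler exists and gradients are never scaled, so clipping now
# falls through to the base-class implementation instead of raising.
bf16 = MixedPrecision(precision="bf16-mixed", device="cuda:0")
assert bf16.scaler is None
bf16.clip_grad_by_norm = Mock()
bf16.clip_gradients(fused_optimizer, clip_val=1.0)
bf16.clip_grad_by_norm.assert_called_once()
```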
tests/tests_pytorch/plugins/precision/test_amp.py (22 additions, 6 deletions)

@@ -44,15 +44,31 @@ def test_clip_gradients():
     precision.clip_grad_by_norm.assert_called_once()
 
 
-def test_optimizer_amp_scaling_support_in_step_method():
-    """Test that the plugin checks if the optimizer takes over unscaling in its step, making it incompatible with
-    gradient clipping (example: fused Adam)."""
+@pytest.mark.parametrize(
+    ("precision", "scaler", "should_error"),
+    [
+        ("16-mixed", Mock(), True),  # fp16 with scaler: fused optimizer + clip = error
+        ("bf16-mixed", None, False),  # bf16 no scaler: fused optimizer + clip = ok
+    ],
+)
+def test_optimizer_amp_scaling_support_in_step_method(precision, scaler, should_error):
+    """Test that gradient clipping with fused optimizers is only blocked when a scaler is present.
+
+    The `_step_supports_amp_scaling` flag indicates the optimizer handles unscaling internally (e.g., fused Adam).
+    This is incompatible with gradient clipping only when using a GradScaler (16-mixed), since we can't unscale
+    before clipping. With bf16-mixed there's no scaler, so gradient clipping works normally.
+
+    """
     optimizer = Mock(_step_supports_amp_scaling=True)
-    precision = MixedPrecision(precision="16-mixed", device="cuda:0", scaler=Mock())
+    plugin = MixedPrecision(precision=precision, device="cuda:0", scaler=scaler)
+    plugin.clip_grad_by_norm = Mock()
 
-    with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
-        precision.clip_gradients(optimizer, clip_val=1.0)
+    if should_error:
+        with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
+            plugin.clip_gradients(optimizer, clip_val=1.0)
+    else:
+        plugin.clip_gradients(optimizer, clip_val=1.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
+        plugin.clip_grad_by_norm.assert_called_once()
 
 
 def test_amp_with_no_grad():
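For readers less familiar with AMP, the distinction the new test encodes comes from the standard PyTorch recipes sketched below (hand-written, not Lightning code; `model`, the loss, and hyperparameters are placeholders). With `float16`, a `GradScaler` scales gradients, so they must be unscaled before clipping, and a fused optimizer moves that unscaling inside `step()` where clipping can no longer precede it. With `bfloat16` there is no scaler, so gradients always have their true magnitude and clipping works the same with fused and non-fused optimizers.

```python
import torch

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # non-fused: unscaling happens outside step()
scaler = torch.amp.GradScaler("cuda")


def fp16_step(x, y):
    # 16-mixed: the loss (and therefore the gradients) are scaled by GradScaler,
    # so they must be unscaled before the norm is computed for clipping.
    with torch.autocast("cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # a fused optimizer performs this inside step() instead
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def bf16_step(x, y):
    # bf16-mixed: no GradScaler, gradients already have their true magnitude,
    # so clipping can always run right before step().
    with torch.autocast("cuda", dtype=torch.bfloat16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
```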