Conversation

@levendlee
Member

What does this PR do?

Fixes # (issue).

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in a GitHub issue, there's a high chance it will not be merged.

ngoyal2707 and others added 21 commits March 29, 2024 15:12
This commit works with a 4-GPU run on the SMALL model with FSDP and PP
enabled.
- Clean up flatten and non_flatten parameter generation logic.
- Avoid checking that the `main_grad` attribute is all zeros.
- Clean up amax and scale update logic. Amax and scale should be
  updated for the weights as well, so the update should happen at the
  forward of each microbatch.
- Consolidate the `cast_params` and `all_gather` streams.
…kresearch/fairscale into shikaili_fp8_allgather_no_pp_fix
@facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on May 20, 2024.
@awgu left a comment

Thanks @levendlee for the great work! I left some comments for my own learning.

        and all(_is_te_module_with_weights(info[1]) for info in p._param_infos))
    if fused_wgard_accumulation:
        if getattr(p, "main_grad", None) is None:
            p.main_grad = torch.empty_like(p, dtype=torch.float32)

For my understanding, why empty_like instead of zeros_like?
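The distinction the reviewer is asking about: `zeros_like` pays for a fill kernel, while `empty_like` leaves the buffer uninitialized, which is safe only if the first gradient write *assigns* into `main_grad` rather than accumulates. A minimal sketch of that invariant (plain PyTorch, not the PR's actual code path):

```python
import torch

p = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.bfloat16))

# zeros_like initializes the buffer; empty_like skips the fill kernel.
zeroed = torch.zeros_like(p, dtype=torch.float32)
uninit = torch.empty_like(p, dtype=torch.float32)  # contents are arbitrary
assert zeroed.sum().item() == 0.0

# Simulate the first microbatch: an *assignment* into the uninitialized
# buffer vs. an *accumulation* into the zeroed one — results now agree,
# so empty_like is correct iff the first write is a copy, not an add.
grad = torch.ones_like(p, dtype=torch.float32)
uninit.copy_(grad)
zeroed.add_(grad)
assert torch.equal(uninit, zeroed)
```

If any code path accumulates into `main_grad` before a full overwrite, `zeros_like` would be required.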

"""Update Amax and scales associated with FP8 parameters."""
if params is None:
params = self.params
with torch.cuda.stream(self._streams["fp32_to_fp16"]):

Curious: why did you use the "all_gather" stream instead of the "fp32_to_fp16" stream?
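For context on the stream question: the point of running the cast on a side stream is to overlap it with other work, with the consumer stream waiting on the producer before reading the result. A hypothetical sketch (stream names and the `* 2` stand-in are illustrative, not the PR's code), guarded so it degrades to a plain CPU cast without CUDA:

```python
import torch

if torch.cuda.is_available():
    cast_stream = torch.cuda.Stream()
    all_gather_stream = torch.cuda.Stream()
    x = torch.randn(1024, device="cuda")
    with torch.cuda.stream(cast_stream):
        y = x.to(torch.float16)          # stand-in for the FP8 cast
    # The consumer stream must wait on the producer before reading y,
    # otherwise it may observe a partially written buffer.
    all_gather_stream.wait_stream(cast_stream)
    with torch.cuda.stream(all_gather_stream):
        z = y * 2                        # stand-in for the all-gather
    torch.cuda.synchronize()
else:
    z = torch.randn(1024).to(torch.float16) * 2  # CPU fallback, no streams
```

Reusing one stream for both cast and all-gather (as the commit message's "consolidate streams" suggests) trades some overlap for not needing this cross-stream synchronization.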

    self.has_full_params = False

    if self.fp8_all_gather:
        self._update_amax_and_scale_fwd(is_first_microbatch_fwd=is_first_microbatch_fwd)

For my understanding, is there a reason that this is not done together with _cast_params_for_all_gather? (For example, could this call be delayed a few lines to below where _cast_params_for_all_gather is called?)
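Background for why the update runs at every microbatch forward: FP8 delayed scaling derives the next scale from a running amax history, so stale updates mean a stale scale. A simplified pure-Python sketch of the idea (the window size, margin, and function names are illustrative assumptions, not Transformer Engine's actual implementation):

```python
# Hypothetical sketch of FP8 delayed scaling.
FP8_E4M3_MAX = 448.0  # largest representable magnitude in E4M3

def update_amax_and_scale(amax_history, new_amax, margin=0):
    # Record the latest observed amax, then derive the scale from the
    # maximum over a bounded history window.
    amax_history.append(new_amax)
    amax = max(amax_history[-16:])
    return (FP8_E4M3_MAX / amax) / (2 ** margin)

history = []
s1 = update_amax_and_scale(history, 2.0)  # 448 / 2  -> 224.0
s2 = update_amax_and_scale(history, 8.0)  # amax grew -> scale shrinks to 56.0
```

Because each microbatch's forward can change the observed amax, deferring the update (e.g., into `_cast_params_for_all_gather`) would be correct only if nothing between the two points reads the scale.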


@torch.no_grad()
def _rebuild_full_params(
    self, force_full_precision: bool = False, wait_for_all_gather: bool = True
) -> Optional[List[Tuple[torch.Tensor, bool]]]:

For fp8_all_gather=True, what happens when this method is called without the TE autocast context?
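The failure mode behind this question is that autocast-style contexts publish state globally, and code called outside the context silently sees it disabled. A simplified stand-in (this is not Transformer Engine's `fp8_autocast` implementation, just an illustration of the pattern):

```python
from contextlib import contextmanager

_FP8_STATE = {"enabled": False}

@contextmanager
def fp8_autocast_sketch(enabled=True):
    # Modules read global state at call time, so calls made outside the
    # context observe it disabled (or stale), which is the hazard here.
    prev = _FP8_STATE["enabled"]
    _FP8_STATE["enabled"] = enabled
    try:
        yield
    finally:
        _FP8_STATE["enabled"] = prev

def rebuild_full_params_sketch():
    # A real fp8_all_gather rebuild would consult this flag and either
    # fall back to the non-FP8 path or hit stale FP8 metadata.
    return _FP8_STATE["enabled"]

with fp8_autocast_sketch():
    inside = rebuild_full_params_sketch()   # True
outside = rebuild_full_params_sketch()      # False
```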

# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self._rebuild_full_params()
self.module.has_unflatten_views = getattr(self.module, "has_unflatten_views", False)

Why do we need this?
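For readers unfamiliar with the idiom on the quoted line: `getattr` with a default lazily initializes a flag that may not exist yet, and reassigning it pins the attribute so later code can read it unconditionally. A minimal sketch:

```python
class Module:
    pass

m = Module()
# Read a flag that may not exist yet, defaulting to False, and pin it
# on the module; a second pass leaves an already-set value untouched.
m.has_unflatten_views = getattr(m, "has_unflatten_views", False)
assert m.has_unflatten_views is False

m.has_unflatten_views = True
m.has_unflatten_views = getattr(m, "has_unflatten_views", False)
assert m.has_unflatten_views is True
```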
