Implement SimKO to add entropy in TopK token sampling during RL #13
base: dev-updated-again
Conversation
…into add-grpo

# Conflicts:
#   .github/CODEOWNERS
#   .github/workflows/integration_test_8gpu_h100.yaml
#   .github/workflows/integration_test_8gpu_models.yaml
#   .github/workflows/integration_test_8gpu_torchft.yaml
#   torchtitan/components/checkpoint.py
#   torchtitan/experiments/__init__.py
#   torchtitan/models/attention.py
#   torchtitan/models/deepseek_v3/infra/parallelize.py
#   torchtitan/models/llama3/__init__.py
#   torchtitan/models/llama3/model/args.py
#   torchtitan/models/llama3/model/model.py
#   torchtitan/tools/logging.py
#   torchtitan/train.py
add grpo :)
add inference logp IS
feat: qwen3-next
feat: seed-oss support
…ing; add compute_token_entropy and apply_simko_adjustment functions in GRPO step.
    return loss


def compute_token_entropy(pred, mask):
we have a vocab parallel entropy fn in utils, that should be used instead
sure will do in the following commit.
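For reference, a minimal sketch of what a vocab-parallel entropy helper can look like when the vocab dimension of the logits is sharded across a tensor-parallel group. This is not the actual fn in torchtitan's utils; the function name and the `tp_group` argument are assumptions, and it is intended to run under `no_grad` as discussed below.

```python
import torch
import torch.distributed as dist


def vocab_parallel_entropy(logits: torch.Tensor, tp_group=None) -> torch.Tensor:
    """Per-token entropy when `logits` is sharded along the vocab dim.
    Hypothetical sketch; pass tp_group=None for the unsharded case."""
    # Global max across vocab shards for numerical stability.
    local_max = logits.max(dim=-1, keepdim=True).values
    if tp_group is not None:
        dist.all_reduce(local_max, op=dist.ReduceOp.MAX, group=tp_group)
    exp_shifted = (logits - local_max).exp()

    # Partial sums on each shard, then reduce to the global sums.
    sum_exp = exp_shifted.sum(dim=-1, keepdim=True)
    sum_exp_logits = (exp_shifted * logits).sum(dim=-1, keepdim=True)
    if tp_group is not None:
        dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=tp_group)
        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=tp_group)

    # H = logZ - E_p[logits], with logZ = max + log(sum_v exp(l_v - max)).
    log_z = local_max + sum_exp.log()
    entropy = log_z - sum_exp_logits / sum_exp
    return entropy.squeeze(-1)
```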
def apply_simko_adjustment(
    ratio, pred, labels, reward, mask,
    alpha=0.01, K=3, lambda_top1=1.1, tau_percentile=80
):
this fn does not work with tensor parallel
ok fixing it
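One way to make the top-K step tensor-parallel-safe is to take a local top-K on each vocab shard, all-gather the candidates, and re-select globally. A rough sketch under the assumption of contiguous, equal-size vocab shards; the names are hypothetical and this is not the PR's final fix. Since only the indices/selection matter here, it can also run under `no_grad`.

```python
import torch
import torch.distributed as dist


def vocab_parallel_topk(logits: torch.Tensor, k: int, tp_group=None):
    """Top-k over vocab-sharded logits: local top-k per shard, all-gather
    the candidates, then reduce to the global top-k. Indices are returned
    in the global vocab numbering. Hypothetical sketch."""
    local_vals, local_idx = torch.topk(logits, k=k, dim=-1)
    if tp_group is None:
        return local_vals, local_idx

    rank = dist.get_rank(tp_group)
    world = dist.get_world_size(tp_group)
    # Assumes each rank holds a contiguous, equal-size slice of the vocab.
    local_idx = local_idx + rank * logits.size(-1)

    gathered_vals = [torch.empty_like(local_vals) for _ in range(world)]
    gathered_idx = [torch.empty_like(local_idx) for _ in range(world)]
    dist.all_gather(gathered_vals, local_vals, group=tp_group)
    dist.all_gather(gathered_idx, local_idx, group=tp_group)

    all_vals = torch.cat(gathered_vals, dim=-1)  # [..., k * world]
    all_idx = torch.cat(gathered_idx, dim=-1)
    top_vals, pos = torch.topk(all_vals, k=k, dim=-1)
    top_idx = torch.gather(all_idx, -1, pos)
    return top_vals, top_idx
```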
    Returns:
        adjusted_ratio: Ratio with SimKO adjustments
    """
    # 1. Identify forking tokens (high-entropy)
does this need to be backpropagated through, or can you wrap it with no_grad?
No, this won't need backprop, and I haven't seen the paper mention it anywhere, so yeah, no_grad is the way to go
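Something like the following is what that suggests, assuming the entropy step is only used to build a mask of forking tokens. `compute_token_entropy` and `tau_percentile` come from the PR; the quantile-based thresholding shown here is an assumption about how the percentile is applied.

```python
import torch

# Sketch: forking-token identification without gradient tracking.
with torch.no_grad():
    token_entropy = compute_token_entropy(pred, mask)  # per-token entropy, [B, T]
    # Threshold at the tau_percentile-th percentile over valid tokens.
    tau = torch.quantile(token_entropy[mask.bool()].float(), tau_percentile / 100.0)
    forking_mask = (token_entropy >= tau) & mask.bool()  # high-entropy ("forking") tokens
```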
    # 2. Get top-K tokens and compute their ratios
    _, topk_indices = torch.topk(pred, k=K, dim=-1)
    new_log_probs_full = torch.nn.functional.log_softmax(pred, dim=-1)
this uses up a lot of memory
sure, do you think doing the log_softmax just once would do? If we are doing no_grad on this, it could be reused for the old-policy top-K probs as well
well, this needs grads, since the grads flow through this softmax. we need to combine the way we're currently getting logprobs with the top-k logprob method, as the logprobs we need for the GRPO/GSPO loss are not guaranteed to be within the top-k. the quick and dirty way may be to wrap it in a checkpoint and chunk it, so that we "save" the memory by having it re-execute on backward in manageable chunks. even if it's not optimal, it should still be just a tiny portion of the compute, so we can add it to the backlog as i'm mostly worried about memory 😎
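A sketch of that quick-and-dirty option: chunk the full-vocab log_softmax and wrap each chunk in activation checkpointing, so the log-probs are recomputed on backward rather than kept alive. The function name, chunking over the flattened batch/sequence dim, and chunk size are assumptions, not the PR's final code.

```python
import torch
from torch.utils.checkpoint import checkpoint


def chunked_topk_logprobs(pred: torch.Tensor, k: int, chunk_size: int = 1024):
    """Memory-saving sketch: log_softmax + top-k over chunks of the flattened
    [B*T, V] logits, each wrapped in activation checkpointing so the full-vocab
    log-probs are re-executed on backward instead of being stored."""

    def _chunk_fn(chunk):
        logp = torch.nn.functional.log_softmax(chunk.float(), dim=-1)
        return torch.topk(logp, k=k, dim=-1)  # (values, indices)

    flat = pred.reshape(-1, pred.size(-1))  # [B*T, V]
    vals_out, idx_out = [], []
    for chunk in flat.split(chunk_size, dim=0):
        vals, idx = checkpoint(_chunk_fn, chunk, use_reentrant=False)
        vals_out.append(vals)
        idx_out.append(idx)

    top_vals = torch.cat(vals_out, dim=0).reshape(*pred.shape[:-1], k)
    top_idx = torch.cat(idx_out, dim=0).reshape(*pred.shape[:-1], k)
    # Note: the labels' logprobs still need to be gathered separately,
    # since they are not guaranteed to fall within the top-k.
    return top_vals, top_idx
```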
921ffb2 to 94d6abd
This implementation is directly inspired by this paper: https://arxiv.org/pdf/2510.14807v1
I'm still actively testing it out and will share results; the top-K sampling metrics during training are yet to be logged in W&B.