
Conversation

@ighoshsubho

This implementation is directly inspired by this paper: https://arxiv.org/pdf/2510.14807v1

I'm still actively testing it and will share results; the top-K sampling metrics during training are not yet logged to W&B.

jquesnelle and others added 25 commits August 28, 2025 04:52
…into add-grpo

# Conflicts:
#	.github/CODEOWNERS
#	.github/workflows/integration_test_8gpu_h100.yaml
#	.github/workflows/integration_test_8gpu_models.yaml
#	.github/workflows/integration_test_8gpu_torchft.yaml
#	torchtitan/components/checkpoint.py
#	torchtitan/experiments/__init__.py
#	torchtitan/models/attention.py
#	torchtitan/models/deepseek_v3/infra/parallelize.py
#	torchtitan/models/llama3/__init__.py
#	torchtitan/models/llama3/model/args.py
#	torchtitan/models/llama3/model/model.py
#	torchtitan/tools/logging.py
#	torchtitan/train.py
…ing; add compute_token_entropy and apply_simko_adjustment functions in GRPO step.
@ighoshsubho ighoshsubho requested a review from dmahan93 October 29, 2025 11:28

return loss

def compute_token_entropy(pred, mask):


We have a vocab-parallel entropy fn in utils; that should be used instead.

Author


Sure, will do in the following commit.
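
For reference, the dense (non-vocab-parallel) version of that entropy is just the usual -sum p log p over the vocab dimension; a minimal single-device sketch, assuming `pred` is `[batch, seq, vocab]` logits and `mask` is a `[batch, seq]` response mask (the vocab-parallel utility in utils would replace this in the actual code):

```python
import torch
import torch.nn.functional as F

def compute_token_entropy(pred: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Per-token entropy H = -sum_v p_v * log p_v over the vocab dim.

    pred: [batch, seq, vocab] logits; mask: [batch, seq] 0/1 response mask.
    """
    log_probs = F.log_softmax(pred, dim=-1)        # [B, S, V]
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1)     # [B, S]
    return entropy * mask
```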

def apply_simko_adjustment(
ratio, pred, labels, reward, mask,
alpha=0.01, K=3, lambda_top1=1.1, tau_percentile=80
):


This fn does not work with tensor parallelism.

Author


Ok, fixing it.
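
For context on the TP issue: with the vocab dimension sharded across the tensor-parallel group, a local `torch.topk` only sees one shard, so the per-shard candidates have to be gathered and re-selected. A rough sketch of that pattern, assuming contiguous, equal-sized vocab shards (the function name and group handling here are illustrative, not torchtitan's API; with the no_grad usage discussed below, plain `all_gather` is fine):

```python
import torch
import torch.distributed as dist

def vocab_parallel_topk(local_logits: torch.Tensor, k: int, tp_group):
    """Global top-k over a vocab dimension sharded across the TP group.

    local_logits: [B, S, V_local], each rank holding a contiguous vocab shard.
    Returns (values, global_indices), both [B, S, k].
    """
    tp_rank = dist.get_rank(tp_group)
    tp_size = dist.get_world_size(tp_group)
    v_local = local_logits.size(-1)

    # Local top-k, then shift indices into the global vocab space.
    local_vals, local_idx = torch.topk(local_logits, k=k, dim=-1)
    local_idx = local_idx + tp_rank * v_local

    # Gather every rank's candidates and take the top-k of the union.
    vals_list = [torch.empty_like(local_vals) for _ in range(tp_size)]
    idx_list = [torch.empty_like(local_idx) for _ in range(tp_size)]
    dist.all_gather(vals_list, local_vals, group=tp_group)
    dist.all_gather(idx_list, local_idx, group=tp_group)
    all_vals = torch.cat(vals_list, dim=-1)    # [B, S, k * tp_size]
    all_idx = torch.cat(idx_list, dim=-1)

    top_vals, pos = torch.topk(all_vals, k=k, dim=-1)
    top_idx = torch.gather(all_idx, -1, pos)
    return top_vals, top_idx
```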

Returns:
adjusted_ratio: Ratio with SimKO adjustments
"""
# 1. Identify forking tokens (high-entropy)


Does this need to be backpropagated through, or can you wrap it with no_grad?

Author


No, this won't need backprop, and I haven't seen the paper mention it anywhere, so yeah, no_grad is the way to go.
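
Since nothing in the forking-token selection needs gradients, it can live entirely under no_grad; a minimal sketch, with the function name and the percentile threshold mirroring the `tau_percentile` argument above (both illustrative):

```python
import torch
import torch.nn.functional as F

def forking_token_mask(pred: torch.Tensor, mask: torch.Tensor,
                       tau_percentile: float = 80.0) -> torch.Tensor:
    """Boolean mask of high-entropy ("forking") tokens; no graph is built.

    pred: [B, S, V] logits; mask: [B, S] 0/1 response mask.
    """
    with torch.no_grad():
        log_probs = F.log_softmax(pred, dim=-1)
        entropy = -(log_probs.exp() * log_probs).sum(dim=-1)   # [B, S]
        # Percentile threshold computed over the valid (masked-in) tokens only.
        tau = torch.quantile(entropy[mask.bool()], tau_percentile / 100.0)
        return (entropy > tau) & mask.bool()
```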


# 2. Get top-K tokens and compute their ratios
_, topk_indices = torch.topk(pred, k=K, dim=-1)
new_log_probs_full = torch.nn.functional.log_softmax(pred, dim=-1)


This uses up a lot of memory.

Author


Sure. Do you think doing the log softmax just once would do? If we're doing no_grad on this, it could be reused for the old-policy top-k probs as well.


Well, this one needs grads. Since the grads flow through this softmax, we need to combine the way we're currently getting logprobs with the top-k logprob method, as the logprobs we need for the GRPO/GSPO loss are not guaranteed to be within the top-k. The quick and dirty way may be to wrap it in a checkpoint and chunk it; that way we "save" the memory by having it re-execute on backward in manageable chunks. Even if it's not optimal, it should still be a tiny portion of the compute, so we can add it to the backlog, as I'm mostly worried about memory 😎
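
A quick sketch of that checkpoint-and-chunk idea, chunking along the sequence dimension so the full-vocab log-softmax is recomputed on backward instead of stored (shapes, the `chunk_size` default, and the function name are assumptions; this ignores vocab parallelism):

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def chunked_topk_log_probs(pred: torch.Tensor, labels: torch.Tensor,
                           k: int = 3, chunk_size: int = 1024):
    """Label log-probs plus top-k log-probs/indices, computed per sequence chunk.

    Checkpointing drops the full-vocab log-softmax activations and recomputes
    them on backward, one chunk at a time, trading a little compute for memory.
    pred: [B, S, V] logits (requires grad); labels: [B, S] token ids.
    """
    def _chunk(logits_c, labels_c):
        log_probs = F.log_softmax(logits_c, dim=-1)               # [B, C, V]
        label_lp = log_probs.gather(-1, labels_c.unsqueeze(-1)).squeeze(-1)
        topk_lp, topk_idx = torch.topk(log_probs, k=k, dim=-1)
        return label_lp, topk_lp, topk_idx

    label_out, topk_out, idx_out = [], [], []
    for start in range(0, pred.size(1), chunk_size):
        sl = slice(start, start + chunk_size)
        label_lp, topk_lp, topk_idx = checkpoint(
            _chunk, pred[:, sl], labels[:, sl], use_reentrant=False
        )
        label_out.append(label_lp)
        topk_out.append(topk_lp)
        idx_out.append(topk_idx)
    return (torch.cat(label_out, dim=1),
            torch.cat(topk_out, dim=1),
            torch.cat(idx_out, dim=1))
```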

