ENH: Improve torch.compile support in MetaMath #2900
Conversation
The MetaMathQA benchmark already had support for enabling `torch.compile`, but it was not very well implemented. The new changes are:

- call `torch.compile` after applying PEFT, not before
- compile with `dynamic=True`
- avoid `model.eval()` + `model.train()` calls

These changes prevent graph breaks and recompiles; a sketch of the new compile order follows below. A context manager is now used to ensure that they don't happen.

Some unrelated changes:

- improve some type annotations
- use the `dtype` argument instead of the deprecated `torch_dtype`
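A minimal sketch of the compile order described above, assuming a LoRA adapter applied via PEFT's `get_peft_model`; the toy model and target module names are illustrative, not the benchmark's actual setup:

```python
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

# Toy stand-in for the benchmark model; a real run would load it with e.g.
# AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.bfloat16),
# using the `dtype` argument rather than the deprecated `torch_dtype`.
base_model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))

# Apply PEFT first: compiling the base model and injecting the LoRA layers
# afterwards would invalidate the compiled graph and force a recompile.
peft_model = get_peft_model(base_model, LoraConfig(target_modules=["0", "2"]))

# dynamic=True avoids shape-specialized recompiles when the batch size or
# sequence length varies between calls.
compiled_model = torch.compile(peft_model, dynamic=True)
out = compiled_model(torch.randn(4, 16))
```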
Note to self: How to deal with dropout? It's not deactivated now that we skip the `model.eval()` call.
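For context, a small sketch of why the `eval()`/`train()` switch interacts badly with `torch.compile`: Dynamo guards on the module's `training` flag, so flipping it invalidates the compiled graph (the exact guard behavior is version-dependent and an assumption here):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1))
compiled = torch.compile(model)
x = torch.randn(2, 8)

compiled(x)    # first call: compiles the train-mode graph, dropout active
model.eval()   # dropout becomes a no-op, but the training flag changed...
compiled(x)    # ...so this call triggers a recompile for the eval-mode graph
```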
After some testing: When running evaluation, we would generally want to put the model into eval mode, but with `torch.compile`, switching between train and eval mode triggers a recompile. Here are the numbers for no compile, compile with train/eval switch, and compile without the switch:
Validation accuracy varies a bit, but that's to be expected with a rather small validation set size and with measuring on generations. I tried a couple of mitigations, like running eval with dropout manually disabled. @githubnemo What would you prefer: live with the recompilation or avoid the train/eval switch?
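One such mitigation could look like the hypothetical helper below: keep the model in train mode (so the `training` guard never flips) but zero out dropout by hand for the eval pass. `set_dropout_p` is an illustrative name, not code from this PR; note the caveat in the comments:

```python
import torch.nn as nn

def set_dropout_p(model: nn.Module, p: float) -> None:
    """Set the dropout probability on every nn.Dropout in the model."""
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = p

# set_dropout_p(model, 0.0) before eval, then restore the original value.
# Caveat: Dynamo may also guard on `p`, in which case changing it triggers
# a recompile anyway, which would make this mitigation moot.
```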
Thanks for investigating! If I understand correctly, there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that, since dropout is only one potential candidate for train/eval mismatches.
Yes, we could do that. It means, however, that we have to remove the context manager that ensures no recompiles happen.
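The context manager itself isn't shown in this thread; a minimal sketch of what such a recompilation guard could look like, assuming the private `torch._dynamo.config.error_on_recompile` flag (version-dependent):

```python
import contextlib
import torch

@contextlib.contextmanager
def assert_no_recompile():
    """Raise on any recompilation while the block is active."""
    prev = torch._dynamo.config.error_on_recompile
    torch._dynamo.config.error_on_recompile = True
    try:
        yield
    finally:
        torch._dynamo.config.error_on_recompile = prev
```

Opting for the recompilation variant would mean dropping this guard, or relaxing it around the expected train/eval switches.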