
Conversation

@BenjaminBossan
Member

The MetaMathQA benchmark already supported enabling torch.compile, but the implementation was not very good. The new changes are:

- call compile after applying PEFT, not before
- compile with dynamic=True
- avoid model.eval() + model.train() calls

These changes prevent graph breaks and recompiles; a context manager now ensures that neither occurs (see the sketch below).

Some unrelated changes:

- improve some type annotations
- use dtype argument instead of deprecated torch_dtype
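
Roughly, the resulting setup looks like the following sketch. The model id and LoRA config are placeholders, not the benchmark's actual values; it only illustrates the order of operations and the arguments mentioned above.

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# load with the `dtype` argument (the deprecated `torch_dtype` is no longer used)
model = AutoModelForCausalLM.from_pretrained(
    "some/causal-lm",  # placeholder model id
    dtype=torch.bfloat16,
)

# apply PEFT first ...
model = get_peft_model(model, LoraConfig(task_type="CAUSAL_LM"))

# ... then compile, with dynamic shapes so that varying sequence lengths
# don't trigger recompiles
model = torch.compile(model, dynamic=True)
```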

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan
Member Author

Note to self: How to deal with dropout? It's not deactivated by torch.inference_mode().
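
A minimal standalone illustration of the issue (not benchmark code): inference_mode only disables autograd bookkeeping, while dropout is controlled by the module's train/eval flag.

```python
import torch
from torch import nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

with torch.inference_mode():
    print(drop(x))  # dropout still active: roughly half the entries are zeroed

drop.eval()
with torch.inference_mode():
    print(drop(x))  # identical to x: only eval() switches dropout off
```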

@BenjaminBossan BenjaminBossan marked this pull request as draft November 20, 2025 14:32
@BenjaminBossan
Member Author

After some testing: When running evaluation, we would generally want to put the model into eval mode (to disable dropout). However, this triggers a re-compile when the model is put back into train mode (i.e. a total of two compiles happen). We could skip this train/eval toggle during training to avoid the re-compile. This would mean that the model is in train mode when evaluating, but arguably that is not a big deal. Obviously, when it comes to the test set, we do put the model into eval mode first.
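
For illustration, the two variants look roughly like this (names and the forward call are made up, not the benchmark's actual code):

```python
import torch

@torch.inference_mode()
def evaluate(model, batches, switch_modes: bool):
    if switch_modes:
        model.eval()   # correct semantics (disables dropout) ...
    outputs = [model(**batch) for batch in batches]
    if switch_modes:
        model.train()  # ... but toggling back triggers a re-compile
    return outputs
```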

Here are the numbers for no compile, compile with train/eval switch, and compile without train/eval switch:

| metric | no compile | compile w/ train/eval switch | compile w/o switch |
|---|---|---|---|
| reserved mem max | 22.3 GB | 16.4 GB | 16.4 GB |
| reserved mem avg | 14.4 GB | 11.2 GB | 11.2 GB |
| reserved mem 99th percentile | 20.1 GB | 14.6 GB | 14.6 GB |
| number of compiles | 0 | 1 | 2 |
| train time / step | ~29 sec | ~24 sec | ~24 sec |
| final train loss (sanity check) | 0.60717 | 0.60710 | 0.60695 |

Validation accuracy varies a bit, but that's to be expected given the rather small validation set and the fact that accuracy is measured on generations.

I tried a couple of mitigations, like running eval with torch.compiler.disable or compiling the eval function separately (which wouldn't really save time, as we'd just replace a recompile with another compile), but nothing I tried helped.
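
For reference, the first of those mitigations amounts to something like the sketch below (illustrative only, and it did not help in practice):

```python
import torch

@torch.compiler.disable
def evaluate_uncompiled(model, batches):
    # the idea: run the whole eval pass eagerly so the mode switch
    # doesn't interact with the compiled training graph
    model.eval()
    with torch.inference_mode():
        outputs = [model(**batch) for batch in batches]
    model.train()
    return outputs
```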

@githubnemo What would you prefer: Live with the recompilation or avoid the train/eval switch?

@BenjaminBossan BenjaminBossan marked this pull request as ready for review November 25, 2025 11:05
@githubnemo
Collaborator

Thanks for investigating!

If I understand correctly, there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that, since dropout is only one potential candidate for train/eval mismatches.

@BenjaminBossan
Member Author

> If I understand correctly, there's almost no time penalty for using the more correct (recompilation) variant, so I'd opt for that, since dropout is only one potential candidate for train/eval mismatches.

Yes, we could do that. It means, however, that we have to remove the error_on_recompile context, which could prevent us from detecting other recompilation issues. LMK if that sounds acceptable.
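
For context, the guard in question is roughly the following sketch (how exactly the benchmark sets the flag may differ; torch._dynamo.config.patch is a private API and the loop body is a hypothetical stand-in):

```python
import torch._dynamo

def run_training(model, batches):
    # any recompile inside the patched scope raises instead of passing
    # silently; accepting the eval-induced recompile means dropping or
    # relaxing this guard
    with torch._dynamo.config.patch(error_on_recompile=True):
        for batch in batches:
            model(**batch)
```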
