mps: add nf4 dequantize/quantize kernel #1790
Open
This ports the CUDA NF4 support to Metal.
So far I've targeted NF4 quantize/dequantize, because it's one of the least-accessible formats for Mac users.
We use uint8 storage under the hood, since Metal (and the underlying hardware) lacks native fp8/fp4 support.
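For reference, here is a minimal NumPy sketch of what blockwise NF4 quantize/dequantize computes, independent of the Metal kernel. The codebook values and the high-nibble-first packing are my reading of the bitsandbytes CUDA implementation, so treat them as assumptions rather than a spec:

```python
import numpy as np

# NF4 codebook (16 NormalFloat4 levels); values taken from my reading of
# the bitsandbytes/QLoRA reference -- treat as an assumption.
NF4_CODE = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
], dtype=np.float32)

def quantize_nf4(x, blocksize=64):
    """Blockwise quantize: normalize each block by its absmax, snap each
    value to the nearest codebook entry, and pack two 4-bit indices per
    uint8 (first element in the high nibble -- assumed layout)."""
    x = np.asarray(x, dtype=np.float32).reshape(-1, blocksize)
    absmax = np.abs(x).max(axis=1, keepdims=True)
    absmax = np.where(absmax == 0, 1.0, absmax)  # avoid div-by-zero
    normalized = x / absmax
    idx = np.abs(normalized[:, :, None] - NF4_CODE[None, None, :]) \
            .argmin(axis=-1).astype(np.uint8)
    flat = idx.reshape(-1)
    packed = (flat[0::2] << 4) | flat[1::2]
    return packed, absmax.squeeze(1)

def dequantize_nf4(packed, absmax, blocksize=64):
    """Unpack the two 4-bit indices per byte, look them up in the
    codebook, and rescale each block by its stored absmax."""
    idx = np.empty(packed.size * 2, dtype=np.uint8)
    idx[0::2] = packed >> 4
    idx[1::2] = packed & 0x0F
    out = NF4_CODE[idx].reshape(-1, blocksize) * absmax[:, None]
    return out.reshape(-1)
```

Values that already sit on codebook entries round-trip exactly; everything else lands on the nearest of the 16 levels, scaled by the per-block absmax.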
Performance has not been the focus of this effort; most of the time went into working out how to plug the metallib into bitsandbytes and build it correctly.
I'd welcome feedback on this approach: given my inexperience with your build toolchain, it's likely I've done things in ways that can be improved.
I'm building on lessons learned while writing a PyTorch custom op for universal-metal-flash-attention, in particular how MTLBuffers are retrieved from torch MPSGraph objects, which required using the torch headers.