Feature description
Support for consuming and exporting DLPack (DLManagedTensor) tensors. This would allow for zero-copy tensor sharing between Burn and other frameworks.
https://dmlc.github.io/dlpack/latest/
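For reference, the ABI that would need to be consumed is small. Below is a minimal sketch of the relevant DLPack structs expressed as Rust FFI types; the layout and field names follow dlpack.h, but these bindings are illustrative, not existing Burn code.

```rust
use std::ffi::c_void;

/// Device type codes from dlpack.h (subset shown; a robust binding would
/// keep the raw i32 to tolerate codes it does not know about).
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DLDeviceType {
    Cpu = 1,
    Cuda = 2,
    CudaHost = 3,
    OpenCl = 4,
    Vulkan = 7,
    Metal = 8,
    Rocm = 10,
    WebGpu = 15,
}

/// A (device type, device id) pair, e.g. (Cuda, 0).
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct DLDevice {
    pub device_type: DLDeviceType,
    pub device_id: i32,
}

/// Element type: type code (int / uint / float / bfloat / ...), bit width, lanes.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct DLDataType {
    pub code: u8,
    pub bits: u8,
    pub lanes: u16,
}

/// The unmanaged tensor view: raw data pointer plus metadata.
#[repr(C)]
pub struct DLTensor {
    pub data: *mut c_void,
    pub device: DLDevice,
    pub ndim: i32,
    pub dtype: DLDataType,
    pub shape: *mut i64,
    /// May be null, which means a compact row-major layout.
    pub strides: *mut i64,
    pub byte_offset: u64,
}

/// The owned capsule that producers hand out; the consumer must call
/// `deleter` exactly once when it is done with the memory.
#[repr(C)]
pub struct DLManagedTensor {
    pub dl_tensor: DLTensor,
    pub manager_ctx: *mut c_void,
    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
}
```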
Feature motivation
We're building libraries (https://metatensor.org/) that bridge multiple frameworks (PyTorch, JAX, and hopefully Burn via our Rust core). DLPack is the standard for zero-copy sharing.
Right now, moving a JAX tensor to Burn would require either a jax -> numpy (cpu) -> burn (gpu) copy, or a clumsy jax -> torch -> dlpack -> metatensor-torch flow.
Direct DLPack support in Burn would let us (and others) build backends that can operate on JAX, PyTorch, or CuPy memory in-place, which is critical for performance.
(Optional) Suggest a Solution
This will be unsafe and highly backend-specific.
- A new constructor on `burn-tensor` (or backend-specific tensor primitives) that can be initialized from a `DLManagedTensor` pointer (a rough sketch follows after this list).
- This implementation would have to:
  - Read the `DLManagedTensor` struct (device type, device id, pointer, shape, strides).
  - Find the corresponding Burn `Device`.
  - The hard part: use `unsafe` functions to wrap the existing native GPU buffer.
    - For `burn-wgpu`, this means using the underlying `wgpu` instance to import a native Vulkan `VkBuffer`, Metal `MTLBuffer`, etc. This seems related to the work discussed in `wgpu` for underlying API interop: Proposal for Underlying Api Interoperability gfx-rs/wgpu#4067
- An equivalent `to_dlpack()` method would be needed to export a Burn tensor, which would package its own buffer handle into a `DLManagedTensor` struct with a valid deleter.
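To make the consuming side concrete, here is a hedged sketch of a thin, owning wrapper over a `DLManagedTensor` pointer, assuming the FFI bindings from the sketch above are in scope. The actual buffer import (wrapping `data` into a wgpu/CUDA buffer and finding the matching Burn `Device`) is the backend-specific, unsafe part and is deliberately left out; the type name `DlpackCapsule` and all of its methods are hypothetical.

```rust
use std::ffi::c_void;

/// Owning wrapper over a consumed DLPack capsule. Per the DLPack contract,
/// the consumer must call `deleter` exactly once when it no longer needs
/// the memory; `Drop` takes care of that here.
pub struct DlpackCapsule {
    ptr: *mut DLManagedTensor,
}

impl DlpackCapsule {
    /// # Safety
    /// `ptr` must point to a valid, live `DLManagedTensor` produced by
    /// another framework (e.g. via `__dlpack__` / `to_dlpack`).
    pub unsafe fn from_raw(ptr: *mut DLManagedTensor) -> Self {
        Self { ptr }
    }

    pub fn device(&self) -> DLDevice {
        unsafe { (*self.ptr).dl_tensor.device }
    }

    pub fn shape(&self) -> &[i64] {
        let t = unsafe { &(*self.ptr).dl_tensor };
        unsafe { std::slice::from_raw_parts(t.shape, t.ndim as usize) }
    }

    /// Strides in elements; `None` means a compact row-major layout.
    pub fn strides(&self) -> Option<&[i64]> {
        let t = unsafe { &(*self.ptr).dl_tensor };
        if t.strides.is_null() {
            None
        } else {
            Some(unsafe { std::slice::from_raw_parts(t.strides, t.ndim as usize) })
        }
    }

    /// Raw device pointer (data + byte_offset); wrapping this into a
    /// backend buffer (wgpu, CUDA, ...) is the hard step not shown here.
    pub fn data_ptr(&self) -> *mut c_void {
        unsafe {
            (*self.ptr)
                .dl_tensor
                .data
                .cast::<u8>()
                .add((*self.ptr).dl_tensor.byte_offset as usize)
                .cast::<c_void>()
        }
    }
}

impl Drop for DlpackCapsule {
    fn drop(&mut self) {
        // Hand the memory back to the producing framework exactly once.
        unsafe {
            if let Some(deleter) = (*self.ptr).deleter {
                deleter(self.ptr);
            }
        }
    }
}
```

For the export direction, `to_dlpack()` would leak a boxed `DLManagedTensor` plus a context that keeps the shape/stride storage (and, in a real implementation, the Burn buffer handle) alive, and install a deleter that reclaims both. This is again only a sketch; `ExportCtx`, `burn_dlpack_deleter`, and `to_dlpack_sketch` are hypothetical names.

```rust
use std::ffi::c_void;

/// Hypothetical context owned by an exported capsule. A real `to_dlpack()`
/// would also store whatever keeps the Burn buffer alive (e.g. an Arc to
/// the backend allocation) so the memory outlives the Burn tensor.
struct ExportCtx {
    shape: Vec<i64>,
    strides: Vec<i64>,
}

/// Deleter installed on exported capsules: reclaims the boxes leaked by
/// `to_dlpack_sketch`, freeing shape/stride storage (and the buffer
/// handle) exactly once, when the consumer is done with the tensor.
unsafe extern "C" fn burn_dlpack_deleter(managed: *mut DLManagedTensor) {
    unsafe {
        let managed = Box::from_raw(managed);
        drop(Box::from_raw(managed.manager_ctx as *mut ExportCtx));
    }
}

/// Package an existing device pointer and its metadata into a capsule.
fn to_dlpack_sketch(
    data: *mut c_void,
    device: DLDevice,
    dtype: DLDataType,
    shape: Vec<i64>,
    strides: Vec<i64>,
) -> *mut DLManagedTensor {
    let mut ctx = Box::new(ExportCtx { shape, strides });
    let dl_tensor = DLTensor {
        data,
        device,
        ndim: ctx.shape.len() as i32,
        dtype,
        // The Vec buffers live on the heap inside `ctx`, so these pointers
        // stay valid after `ctx` is turned into a raw pointer below.
        shape: ctx.shape.as_mut_ptr(),
        strides: ctx.strides.as_mut_ptr(),
        byte_offset: 0,
    };
    Box::into_raw(Box::new(DLManagedTensor {
        dl_tensor,
        manager_ctx: Box::into_raw(ctx) as *mut c_void,
        deleter: Some(burn_dlpack_deleter),
    }))
}
```

The key invariant on both sides is that the consumer calls the deleter exactly once when it is done, so everything the producer needs to free must be reachable from `manager_ctx`.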