feat(dlpack): support dlpack for zero copy sharing #3964

@HaoZeke

Description

Feature description

Support for consuming and exporting DLPack (DLManagedTensor) tensors. This would allow zero-copy tensor sharing between Burn and other frameworks.

https://dmlc.github.io/dlpack/latest/
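
For context, the ABI defined in dlpack.h is small; a hand-translated Rust mirror of the relevant structs looks roughly like this (illustrative only, not existing Burn code):

```rust
use std::os::raw::c_void;

// Hand-translated from dlpack.h; layout must match the C ABI exactly.
#[repr(C)]
#[derive(Clone, Copy)]
pub struct DLDevice {
    pub device_type: i32, // DLDeviceType: 1 = CPU, 2 = CUDA, 7 = Vulkan, 8 = Metal, ...
    pub device_id: i32,
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct DLDataType {
    pub code: u8,   // DLDataTypeCode: 0 = int, 1 = uint, 2 = float, 4 = bfloat, ...
    pub bits: u8,   // e.g. 32 for f32
    pub lanes: u16, // 1 for scalar types
}

#[repr(C)]
pub struct DLTensor {
    pub data: *mut c_void,
    pub device: DLDevice,
    pub ndim: i32,
    pub dtype: DLDataType,
    pub shape: *mut i64,
    pub strides: *mut i64, // null means compact row-major
    pub byte_offset: u64,
}

#[repr(C)]
pub struct DLManagedTensor {
    pub dl_tensor: DLTensor,
    pub manager_ctx: *mut c_void,
    // The consumer must call this exactly once when done with the tensor.
    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
}
```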

Feature motivation

We're building libraries (https://metatensor.org/) that bridge multiple frameworks (PyTorch, JAX, and hopefully Burn via our Rust core). DLPack is the de facto standard for zero-copy tensor sharing across frameworks.

Right now, moving a JAX tensor to Burn would require either a jax -> numpy (cpu) -> burn (gpu) copy or a clumsy jax -> torch -> dlpack -> metatensor-torch flow.

Direct DLPack support in Burn would let us (and others) build backends that can operate on JAX, PyTorch, or CuPy memory in-place, which is critical for performance.

(Optional) Suggest a Solution

This will be unsafe and highly backend-specific.

  1. A new constructor on burn-tensor (or backend-specific tensor primitives) that can be initialized from a DLManagedTensor pointer.

  2. This implementation would have to (see the first sketch after this list):

    • Read the DLManagedTensor struct (device type, device id, pointer, shape, strides).

    • Find the corresponding Burn Device.

    • The hard part: Use unsafe functions to wrap the existing native GPU buffer.

  3. For burn-wgpu, this means using the underlying wgpu instance to import a native Vulkan VkBuffer, Metal MTLBuffer, etc. This seems related to the underlying-API interoperability work discussed in wgpu: Proposal for Underlying Api Interoperability (gfx-rs/wgpu#4067).

  4. An equivalent to_dlpack() method would be needed to export a Burn tensor, packaging its own buffer handle into a DLManagedTensor struct with a valid deleter (see the second sketch below).
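
To make step 2 concrete: reading and validating the struct is plain FFI work that could be prototyped independently of any backend. A minimal sketch building on the struct definitions above; read_dlpack and DlpackView are hypothetical names, not Burn APIs, and the backend-specific buffer import that would follow is omitted:

```rust
use std::os::raw::c_void;
// Reuses the DLManagedTensor/DLTensor/DLDataType FFI definitions above.

/// Metadata pulled out of a DLManagedTensor before any backend work.
pub struct DlpackView {
    pub data: *mut c_void,
    pub device_type: i32,
    pub device_id: i32,
    pub shape: Vec<i64>,
    pub strides: Option<Vec<i64>>, // None = compact row-major per the spec
    pub dtype: DLDataType,
}

/// Hypothetical first half of a from_dlpack constructor: read and validate
/// the struct. The caller keeps ownership; the deleter is not invoked here.
pub unsafe fn read_dlpack(managed: *mut DLManagedTensor) -> Option<DlpackView> {
    if managed.is_null() {
        return None;
    }
    let t = &(*managed).dl_tensor;
    if t.ndim < 0 || t.shape.is_null() {
        return None;
    }
    let ndim = t.ndim as usize;
    let shape = std::slice::from_raw_parts(t.shape, ndim).to_vec();
    let strides = if t.strides.is_null() {
        None // null strides means compact row-major by the DLPack spec
    } else {
        Some(std::slice::from_raw_parts(t.strides, ndim).to_vec())
    };
    Some(DlpackView {
        // byte_offset must be applied before handing the pointer to a backend.
        data: (t.data as *mut u8).add(t.byte_offset as usize) as *mut c_void,
        device_type: t.device.device_type,
        device_id: t.device.device_id,
        shape,
        strides,
        dtype: t.dtype,
    })
}
```

Mapping (device_type, device_id) to the corresponding Burn Device would happen on this view before any unsafe buffer wrapping.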
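
For step 4, the key requirement is a deleter that keeps the exported buffer alive until the consumer releases it. A minimal sketch, assuming a type-erased keep_alive value standing in for whatever owns the tensor's memory; again, none of these names are real Burn APIs:

```rust
use std::any::Any;
use std::os::raw::c_void;
// Reuses the DLManagedTensor/DLTensor/DLDevice/DLDataType definitions above.

/// Everything the deleter must reclaim: the owning handle plus the
/// shape/strides arrays that dl_tensor points into.
struct ExportCtx {
    _keep_alive: Box<dyn Any>,
    shape: Vec<i64>,
    strides: Vec<i64>,
}

unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
    if managed.is_null() {
        return;
    }
    // Reclaim the context (drops the buffer handle and metadata), then the struct.
    drop(Box::from_raw((*managed).manager_ctx as *mut ExportCtx));
    drop(Box::from_raw(managed));
}

/// Hypothetical export path: package a raw buffer pointer into a
/// DLManagedTensor that the consumer frees by calling `deleter` once.
pub fn to_dlpack(
    data: *mut c_void,
    device: DLDevice,
    dtype: DLDataType,
    shape: Vec<i64>,
    strides: Vec<i64>,
    keep_alive: Box<dyn Any>,
) -> *mut DLManagedTensor {
    let ctx = Box::new(ExportCtx { _keep_alive: keep_alive, shape, strides });
    let dl_tensor = DLTensor {
        data,
        device,
        ndim: ctx.shape.len() as i32,
        dtype,
        // The Vec buffers live as long as ExportCtx, i.e. until the deleter runs.
        shape: ctx.shape.as_ptr() as *mut i64,
        strides: ctx.strides.as_ptr() as *mut i64,
        byte_offset: 0,
    };
    Box::into_raw(Box::new(DLManagedTensor {
        dl_tensor,
        manager_ctx: Box::into_raw(ctx) as *mut c_void,
        deleter: Some(dlpack_deleter),
    }))
}
```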
