64 changes: 59 additions & 5 deletions CMakeLists.txt
@@ -226,13 +226,19 @@ elseif(BUILD_MPS)
     string(APPEND BNB_OUTPUT_NAME "_mps")
     add_compile_definitions(BUILD_MPS)
     file(MAKE_DIRECTORY "build")
-    add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib"
-        COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES}
-        COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib"
+    set(METAL_AIR "${CMAKE_BINARY_DIR}/bitsandbytes.air")
+    set(METAL_LIB "${PROJECT_SOURCE_DIR}/bitsandbytes/bitsandbytes.metallib")
+    set(METAL_SOURCES "")
+    foreach(METAL_FILE ${METAL_FILES})
+        list(APPEND METAL_SOURCES "${PROJECT_SOURCE_DIR}/${METAL_FILE}")
+    endforeach()
+    add_custom_command(OUTPUT "${METAL_LIB}"
+        COMMAND xcrun metal -c ${METAL_SOURCES} -o "${METAL_AIR}"
+        COMMAND xcrun metallib "${METAL_AIR}" -o "${METAL_LIB}"
         DEPENDS "${METAL_FILES}"
         COMMENT "Compiling Metal kernels"
         VERBATIM)
-    add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
+    add_custom_target(metallib DEPENDS "${METAL_LIB}")
 elseif(BUILD_XPU)
     list(APPEND SRC_FILES ${XPU_FILES})
     string(APPEND BNB_OUTPUT_NAME "_xpu")
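
For orientation, the rewritten custom command runs a two-step pipeline: compile every Metal source into a single .air object, then package it into the .metallib the Python package loads at runtime. A rough Python equivalent of what it invokes (a sketch; the source list and output paths are placeholders for the ${METAL_SOURCES}, ${METAL_AIR}, and ${METAL_LIB} values CMake fills in):

# Sketch of the metal -> air -> metallib pipeline driven by the custom command above.
import subprocess

metal_sources = ["csrc/mps_kernels.metal"]  # placeholder; the real list is ${METAL_SOURCES}
air_path = "build/bitsandbytes.air"
lib_path = "bitsandbytes/bitsandbytes.metallib"

subprocess.run(["xcrun", "metal", "-c", *metal_sources, "-o", air_path], check=True)
subprocess.run(["xcrun", "metallib", air_path, "-o", lib_path], check=True)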
@@ -257,10 +263,57 @@ if(MSVC)
     set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /arch:AVX2 /fp:fast")
 endif()
 
+find_package(Python3 COMPONENTS Interpreter Development)
+message(STATUS "Python3 found: ${Python3_FOUND}")
+
+if(NOT Torch_DIR)
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import torch, pathlib; print(pathlib.Path(torch.__file__).resolve().parent / 'share/cmake/Torch')"
+        OUTPUT_VARIABLE Torch_DIR
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+endif()
+message(STATUS "Torch_DIR=${Torch_DIR}")
+find_package(Torch REQUIRED CONFIG)
+
 set_source_files_properties(${CPP_FILES} PROPERTIES LANGUAGE CXX)
 add_library(bitsandbytes SHARED ${SRC_FILES})
 target_compile_features(bitsandbytes PUBLIC cxx_std_17)
 target_include_directories(bitsandbytes PUBLIC csrc include)
+if(Python3_FOUND)
+    message(STATUS "Python include dirs: ${Python3_INCLUDE_DIRS}")
+    target_include_directories(bitsandbytes PRIVATE ${Python3_INCLUDE_DIRS})
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_paths()['include'])"
+        OUTPUT_VARIABLE PYTHON_SYSTEM_INCLUDE
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+    )
+    if(PYTHON_SYSTEM_INCLUDE)
+        target_include_directories(bitsandbytes PRIVATE ${PYTHON_SYSTEM_INCLUDE})
+    endif()
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import torch\nfrom torch.utils.cpp_extension import include_paths\nprint(';'.join(include_paths()))"
+        OUTPUT_VARIABLE TORCH_INCLUDE_DIRS
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+        ERROR_QUIET
+    )
+    if(TORCH_INCLUDE_DIRS)
+        string(REPLACE "\\n" ";" TORCH_INCLUDE_DIRS "${TORCH_INCLUDE_DIRS}")
+        target_include_directories(bitsandbytes PRIVATE ${TORCH_INCLUDE_DIRS})
+    endif()
+    execute_process(
+        COMMAND ${Python3_EXECUTABLE} -c "import torch\nfrom torch.utils.cpp_extension import library_paths\nprint(';'.join(library_paths()))"
+        OUTPUT_VARIABLE TORCH_LIBRARY_DIRS
+        OUTPUT_STRIP_TRAILING_WHITESPACE
+        ERROR_QUIET
+    )
+    if(TORCH_LIBRARY_DIRS)
+        string(REPLACE "\\n" ";" TORCH_LIBRARY_DIRS "${TORCH_LIBRARY_DIRS}")
+        target_link_directories(bitsandbytes PRIVATE ${TORCH_LIBRARY_DIRS})
+        target_link_libraries(bitsandbytes PRIVATE torch torch_cpu torch_python c10)
+    endif()
+    target_link_libraries(bitsandbytes PRIVATE ${Python3_LIBRARIES})
+endif()
 
 
 if(BUILD_CUDA)
@@ -308,7 +361,8 @@ if(BUILD_HIP)
 endif()
 if(BUILD_MPS)
     add_dependencies(bitsandbytes metallib)
-    target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
+    target_compile_options(bitsandbytes PRIVATE "-fno-objc-arc")
+    target_link_libraries(bitsandbytes PRIVATE objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
 endif()
 if(BUILD_XPU)
     set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
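The Torch discovery added above shells out to the Python interpreter at configure time. A standalone sketch of what those one-liners resolve to (assumes torch is importable in the build environment; include_paths and library_paths come from torch.utils.cpp_extension):

# Sketch: the values the execute_process() calls capture at configure time.
import pathlib

import torch
from torch.utils.cpp_extension import include_paths, library_paths

# Fallback Torch_DIR: <site-packages>/torch/share/cmake/Torch
print(pathlib.Path(torch.__file__).resolve().parent / "share/cmake/Torch")

# Header and library search paths consumed as TORCH_INCLUDE_DIRS / TORCH_LIBRARY_DIRS.
print(";".join(include_paths()))
print(";".join(library_paths()))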
3 changes: 3 additions & 0 deletions bitsandbytes/__init__.py
@@ -38,6 +38,9 @@
 if hasattr(torch, "xpu") and torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops
 
+if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+    from .backends.mps import ops as mps_ops
+
 if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
     # In case not automatically imported
     import habana_frameworks.torch
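The new import is for its side effect: loading bitsandbytes.backends.mps.ops registers the MPS kernels introduced below. A quick check that the guard fires on a given machine (a sketch):

# Sketch: the same availability check used by the guard above.
import torch

has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
print(f"MPS ops will be registered: {has_mps}")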
14 changes: 12 additions & 2 deletions bitsandbytes/backends/default/ops.py
@@ -2,9 +2,19 @@
 from math import prod, sqrt
 from typing import Optional
 
+import importlib.util
+
 import torch
 
+_HAS_TRITON = importlib.util.find_spec("triton") is not None
+
+
+def _maybe_compile(fn):
+    if not _HAS_TRITON:
+        return fn
+    return torch.compile(fn)
+
 from ..._ops import register_kernel
(Check failure, GitHub Actions / Lint: Ruff E402 at bitsandbytes/backends/default/ops.py:16:1, "Module level import not at top of file")
 from ..utils import CODE
(Check failure, GitHub Actions / Lint: Ruff E402 at bitsandbytes/backends/default/ops.py:17:1, "Module level import not at top of file")
 
 
@@ -321,7 +331,7 @@
 }
 
 
-@torch.compile
+@_maybe_compile
 def _optimizer_precondition_32bit(
     g: torch.Tensor,
     p: torch.Tensor,
@@ -382,7 +392,7 @@
     unorm_vec.add_(total_norm)
 
 
-@torch.compile
+@_maybe_compile
 def _optimizer_update_32bit(
     g: torch.Tensor,
     p: torch.Tensor,
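The net effect of _maybe_compile is that the 32-bit optimizer kernels keep working in eager mode when Triton is absent, instead of failing inside torch.compile. A minimal standalone sketch of the pattern (the function scaled_add is illustrative only):

# Conditional compilation: compile when Triton is installed, otherwise run eagerly.
import importlib.util

import torch

_HAS_TRITON = importlib.util.find_spec("triton") is not None


def _maybe_compile(fn):
    return torch.compile(fn) if _HAS_TRITON else fn


@_maybe_compile
def scaled_add(p: torch.Tensor, g: torch.Tensor, lr: float) -> torch.Tensor:
    return p - lr * g


print(scaled_add(torch.ones(4), torch.ones(4), 0.1))  # same result either way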
117 changes: 117 additions & 0 deletions bitsandbytes/backends/mps/ops.py
@@ -0,0 +1,117 @@
from __future__ import annotations

import ctypes as ct
from typing import Sequence, Tuple

import torch

from ..._ops import register_kernel
from ...cextension import lib

_ALLOWED_BLOCKS = (64, 128, 256, 512, 1024, 2048, 4096)
_SUPPORTED_DTYPES = (torch.float16, torch.float32)


lib.cquantize_blockwise_fp16_nf4_tensor.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
lib.cquantize_blockwise_fp16_nf4_tensor.restype = None
lib.cquantize_blockwise_fp32_nf4_tensor.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
lib.cquantize_blockwise_fp32_nf4_tensor.restype = None
lib.cdequantize_blockwise_fp16_nf4_tensor.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
lib.cdequantize_blockwise_fp16_nf4_tensor.restype = None
lib.cdequantize_blockwise_fp32_nf4_tensor.argtypes = [ct.py_object, ct.py_object, ct.py_object, ct.c_int32]
lib.cdequantize_blockwise_fp32_nf4_tensor.restype = None


def _quantize_nf4(
A: torch.Tensor, blocksize: int, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
torch._check(blocksize in _ALLOWED_BLOCKS)
torch._check(quant_storage == torch.uint8, lambda: "Only uint8 storage is supported for NF4 on MPS.")

A = A.contiguous()
n = A.numel()
blocks = -(n // -blocksize)

absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
out = torch.empty(((n + 1) // 2, 1), device=A.device, dtype=quant_storage)

if A.dtype == torch.float16:
lib.cquantize_blockwise_fp16_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
elif A.dtype == torch.float32:
lib.cquantize_blockwise_fp32_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
else:
torch._check(False, lambda: f"NF4 quantization on MPS supports {list(_SUPPORTED_DTYPES)}, got {A.dtype}")

return out, absmax


def _dequantize_nf4(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
torch._check(blocksize in _ALLOWED_BLOCKS)

A = A.contiguous()
absmax = absmax.contiguous()
torch._check(out.is_contiguous(), lambda: "Output tensor must be contiguous for NF4 dequantization on MPS.")

if dtype == torch.float16:
lib.cdequantize_blockwise_fp16_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
elif dtype == torch.float32:
lib.cdequantize_blockwise_fp32_nf4_tensor(ct.py_object(A), ct.py_object(absmax), ct.py_object(out), ct.c_int32(blocksize))
else:
torch._check(False, lambda: f"NF4 dequantization on MPS supports {list(_SUPPORTED_DTYPES)}, got {dtype}")


@register_kernel("bitsandbytes::quantize_4bit", "mps")
def _(
A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
if quant_type != "nf4" or A.dtype not in _SUPPORTED_DTYPES:
return torch.ops.bitsandbytes.quantize_4bit.default(A, blocksize, quant_type, quant_storage)
return _quantize_nf4(A, blocksize, quant_storage)


@register_kernel("bitsandbytes::dequantize_4bit", "mps")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
if quant_type != "nf4" or dtype not in _SUPPORTED_DTYPES:
return torch.ops.bitsandbytes.dequantize_4bit.default(A, absmax, blocksize, quant_type, shape, dtype)
out = torch.empty(shape, dtype=dtype, device=A.device)
_dequantize_nf4(A, absmax, blocksize, dtype, out)
return out


@register_kernel("bitsandbytes::dequantize_4bit.out", "mps")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
out: torch.Tensor,
) -> None:
if quant_type != "nf4" or dtype not in _SUPPORTED_DTYPES:
torch.ops.bitsandbytes.dequantize_4bit.out.default(
A,
absmax,
blocksize,
quant_type,
shape,
dtype,
out,
)
return

torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
_dequantize_nf4(A, absmax, blocksize, dtype, out)
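
A hedged round-trip check of the kernels this file registers (assumes Apple silicon, a native library built with -DBUILD_MPS=ON, and that importing bitsandbytes ran the registrations; the call signatures mirror the register_kernel entries above):

# Sketch: NF4 quantize/dequantize round trip through the MPS kernels.
import torch

import bitsandbytes  # noqa: F401  (import side effect registers backend kernels)

A = torch.randn(4096, device="mps", dtype=torch.float16)
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8)
out = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, 64, "nf4", A.shape, A.dtype)
print((A - out).abs().max())  # NF4 is lossy; expect a small but nonzero error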
6 changes: 5 additions & 1 deletion bitsandbytes/cextension.py
@@ -283,7 +283,9 @@ def get_native_library() -> BNBNativeLibrary:
 
         binary_path = cuda_binary_path
 
-    if torch._C._has_xpu:
+    if BNB_BACKEND == "MPS":
+        binary_path = PACKAGE_DIR / f"libbitsandbytes_mps{DYNAMIC_LIBRARY_SUFFIX}"
+    elif torch._C._has_xpu:
         binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}"
 
     logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
@@ -306,6 +308,8 @@ def get_native_library() -> BNBNativeLibrary:
     BNB_BACKEND = "ROCm"
 elif torch.cuda.is_available():
     BNB_BACKEND = "CUDA"
+elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+    BNB_BACKEND = "MPS"
 elif torch._C._has_xpu:
     BNB_BACKEND = "XPU"
 
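With the cextension changes, backend selection happens once at import time, in the order ROCm, CUDA, MPS, XPU. A quick sanity check of which native library variant was picked (a sketch; BNB_BACKEND is the module-level name set above):

# Sketch: report the backend bitsandbytes selected when it loaded its native library.
from bitsandbytes.cextension import BNB_BACKEND

print(BNB_BACKEND)  # "MPS" on Apple silicon when torch.backends.mps is available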