From a784fee162ec4ed8c1f41881fad4ade6c6d8633f Mon Sep 17 00:00:00 2001 From: Theo Lee Date: Tue, 9 Dec 2025 00:46:34 +0900 Subject: [PATCH 1/8] create quantized qwen3 moe module --- candle-transformers/src/models/mod.rs | 1 + candle-transformers/src/models/quantized_qwen3_moe.rs | 0 2 files changed, 1 insertion(+) create mode 100644 candle-transformers/src/models/quantized_qwen3_moe.rs diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index e77ba4a36f..2d93833581 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -94,6 +94,7 @@ pub mod quantized_phi; pub mod quantized_phi3; pub mod quantized_qwen2; pub mod quantized_qwen3; +pub mod quantized_qwen3_moe; pub mod quantized_recurrent_gemma; pub mod quantized_rwkv_v5; pub mod quantized_rwkv_v6; diff --git a/candle-transformers/src/models/quantized_qwen3_moe.rs b/candle-transformers/src/models/quantized_qwen3_moe.rs new file mode 100644 index 0000000000..e69de29bb2 From 945b688990829d757a97d886e1db197658692f35 Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:14:45 +0100 Subject: [PATCH 2/8] [Metal] unary and affine improvements (#3230) * Update unary.metal * update metal unary tests * Remove metal tiled unary kernels (now automated) * Optimize metal affine * Optimize metal powf * Optimize metal elu --- candle-core/src/metal_backend/mod.rs | 526 +++++++----------- candle-metal-kernels/src/kernels/affine.rs | 17 +- candle-metal-kernels/src/kernels/macros.rs | 24 - candle-metal-kernels/src/kernels/unary.rs | 62 +-- candle-metal-kernels/src/lib.rs | 2 +- .../src/metal_src/affine.metal | 253 +++++---- .../src/metal_src/unary.metal | 405 +++++++------- candle-metal-kernels/src/tests.rs | 6 +- candle-metal-kernels/src/utils.rs | 7 + candle-nn/src/ops.rs | 111 ++-- 10 files changed, 661 insertions(+), 752 deletions(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index d3ab0da902..e2f8224d60 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -141,6 +141,7 @@ impl BackendStorage for MetalStorage { &encoder, &device.kernels, name, + self.dtype.size_in_bytes(), el, src, &buffer, @@ -198,6 +199,7 @@ impl BackendStorage for MetalStorage { &encoder, &device.kernels, name, + self.dtype.size_in_bytes(), el, src, &buffer, @@ -250,6 +252,7 @@ impl BackendStorage for MetalStorage { &encoder, &device.kernels, name, + self.dtype.size_in_bytes(), el, src, &buffer, @@ -446,88 +449,68 @@ impl BackendStorage for MetalStorage { encoder.set_label("const-set"); let dst = buffer_o(&self_.buffer, l, self_.dtype); - match (el_count % 2, dtype, l.is_contiguous()) { - (0, DType::BF16 | DType::F16, true) => { - use candle_metal_kernels::unary::contiguous_tiled; - let kernel_name = match dtype { - DType::F16 => contiguous_tiled::const_set::HALF, - DType::BF16 => contiguous_tiled::const_set::BFLOAT, - _ => unreachable!(), - }; - candle_metal_kernels::call_const_set_contiguous_tiled( - &device.device, - &encoder, - &device.kernels, - kernel_name, - el_count, - s, - dst, - ) - .map_err(MetalError::from)?; - } - (_, _, true) => { - use candle_metal_kernels::unary::contiguous; - let kernel_name = match dtype { - DType::F16 => contiguous::const_set::HALF, - DType::BF16 => contiguous::const_set::BFLOAT, - DType::F32 => contiguous::const_set::FLOAT, - DType::I64 => contiguous::const_set::I64, - DType::U32 => contiguous::const_set::U32, - DType::U8 
=> contiguous::const_set::U8, - DType::F8E4M3 => crate::bail!("unsupported const-set f8e4m3"), - DType::F64 => crate::bail!("unsupported const-set f64"), - DType::F4 - | DType::F6E2M3 - | DType::F6E3M2 - | DType::F8E8M0 - | DType::I16 - | DType::I32 => { - return Err(Error::UnsupportedDTypeForOp(dtype, "const-set").bt()) - } - }; - candle_metal_kernels::call_const_set_contiguous( - &device.device, - &encoder, - &device.kernels, - kernel_name, - el_count, - s, - dst, - ) - .map_err(MetalError::from)?; - } - (_, _, false) => { - use candle_metal_kernels::unary::strided; - let kernel_name = match dtype { - DType::F16 => strided::const_set::HALF, - DType::BF16 => strided::const_set::BFLOAT, - DType::F32 => strided::const_set::FLOAT, - DType::I64 => strided::const_set::I64, - DType::U32 => strided::const_set::U32, - DType::U8 => strided::const_set::U8, - DType::F8E4M3 => crate::bail!("unsupported const-set f8e4m3"), - DType::F64 => crate::bail!("unsupported const-set f64"), - DType::F4 - | DType::F6E2M3 - | DType::F6E3M2 - | DType::F8E8M0 - | DType::I16 - | DType::I32 => { - return Err(Error::UnsupportedDTypeForOp(dtype, "const-set").bt()) - } - }; - candle_metal_kernels::call_const_set_strided( - &device.device, - &encoder, - &device.kernels, - kernel_name, - l.dims(), - s, - l.stride(), - dst, - ) - .map_err(MetalError::from)?; - } + if l.is_contiguous() { + use candle_metal_kernels::unary::contiguous; + let kernel_name = match dtype { + DType::F16 => contiguous::const_set::HALF, + DType::BF16 => contiguous::const_set::BFLOAT, + DType::F32 => contiguous::const_set::FLOAT, + DType::I64 => contiguous::const_set::I64, + DType::U32 => contiguous::const_set::U32, + DType::U8 => contiguous::const_set::U8, + DType::F8E4M3 => crate::bail!("unsupported const-set f8e4m3"), + DType::F64 => crate::bail!("unsupported const-set f64"), + DType::F4 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F8E8M0 + | DType::I16 + | DType::I32 => { + return Err(Error::UnsupportedDTypeForOp(dtype, "const-set").bt()) + } + }; + candle_metal_kernels::call_const_set_contiguous( + &device.device, + &encoder, + &device.kernels, + kernel_name, + dtype.size_in_bytes(), + el_count, + s, + dst, + ) + .map_err(MetalError::from)?; + } else { + use candle_metal_kernels::unary::strided; + let kernel_name = match dtype { + DType::F16 => strided::const_set::HALF, + DType::BF16 => strided::const_set::BFLOAT, + DType::F32 => strided::const_set::FLOAT, + DType::I64 => strided::const_set::I64, + DType::U32 => strided::const_set::U32, + DType::U8 => strided::const_set::U8, + DType::F8E4M3 => crate::bail!("unsupported const-set f8e4m3"), + DType::F64 => crate::bail!("unsupported const-set f64"), + DType::F4 + | DType::F6E2M3 + | DType::F6E3M2 + | DType::F8E8M0 + | DType::I16 + | DType::I32 => { + return Err(Error::UnsupportedDTypeForOp(dtype, "const-set").bt()) + } + }; + candle_metal_kernels::call_const_set_strided( + &device.device, + &encoder, + &device.kernels, + kernel_name, + l.dims(), + s, + l.stride(), + dst, + ) + .map_err(MetalError::from)?; } Ok(()) } @@ -670,235 +653,156 @@ impl BackendStorage for MetalStorage { encoder.set_label(B::KERNEL); let src = buffer_o(&self.buffer, layout, self.dtype); - match (el_count % 2, dtype, layout.is_contiguous()) { - (0, DType::BF16 | DType::F16, true) => { - use candle_metal_kernels::unary::contiguous_tiled; - let kernel_name = match (B::KERNEL, dtype) { - ("uabs", DType::F16) => contiguous_tiled::abs::HALF, - ("uabs", DType::F32) => contiguous_tiled::abs::FLOAT, - ("uabs", DType::BF16) => 
contiguous_tiled::abs::BFLOAT, - ("uceil", DType::F16) => contiguous_tiled::ceil::HALF, - ("uceil", DType::F32) => contiguous_tiled::ceil::FLOAT, - ("uceil", DType::BF16) => contiguous_tiled::ceil::BFLOAT, - ("ucos", DType::F16) => contiguous_tiled::cos::HALF, - ("ucos", DType::F32) => contiguous_tiled::cos::FLOAT, - ("ucos", DType::BF16) => contiguous_tiled::cos::BFLOAT, - ("uerf", DType::F16) => contiguous_tiled::erf::HALF, - ("uerf", DType::F32) => contiguous_tiled::erf::FLOAT, - ("uerf", DType::BF16) => contiguous_tiled::erf::BFLOAT, - ("uexp", DType::F16) => contiguous_tiled::exp::HALF, - ("uexp", DType::F32) => contiguous_tiled::exp::FLOAT, - ("uexp", DType::BF16) => contiguous_tiled::exp::BFLOAT, - ("ufloor", DType::F16) => contiguous_tiled::floor::HALF, - ("ufloor", DType::F32) => contiguous_tiled::floor::FLOAT, - ("ufloor", DType::BF16) => contiguous_tiled::floor::BFLOAT, - ("ugelu_erf", DType::F16) => contiguous_tiled::gelu_erf::HALF, - ("ugelu_erf", DType::F32) => contiguous_tiled::gelu_erf::FLOAT, - ("ugelu_erf", DType::BF16) => contiguous_tiled::gelu_erf::BFLOAT, - ("ugelu", DType::F16) => contiguous_tiled::gelu::HALF, - ("ugelu", DType::F32) => contiguous_tiled::gelu::FLOAT, - ("ugelu", DType::BF16) => contiguous_tiled::gelu::BFLOAT, - ("ulog", DType::F16) => contiguous_tiled::log::HALF, - ("ulog", DType::F32) => contiguous_tiled::log::FLOAT, - ("ulog", DType::BF16) => contiguous_tiled::log::BFLOAT, - ("uneg", DType::F16) => contiguous_tiled::neg::HALF, - ("uneg", DType::F32) => contiguous_tiled::neg::FLOAT, - ("uneg", DType::BF16) => contiguous_tiled::neg::BFLOAT, - ("urecip", DType::F16) => contiguous_tiled::recip::HALF, - ("urecip", DType::F32) => contiguous_tiled::recip::FLOAT, - ("urecip", DType::BF16) => contiguous_tiled::recip::BFLOAT, - ("urelu", DType::F16) => contiguous_tiled::relu::HALF, - ("urelu", DType::F32) => contiguous_tiled::relu::FLOAT, - ("urelu", DType::BF16) => contiguous_tiled::relu::BFLOAT, - ("uround", DType::F16) => contiguous_tiled::round::HALF, - ("uround", DType::F32) => contiguous_tiled::round::FLOAT, - ("uround", DType::BF16) => contiguous_tiled::round::BFLOAT, - ("usilu", DType::F16) => contiguous_tiled::silu::HALF, - ("usilu", DType::F32) => contiguous_tiled::silu::FLOAT, - ("usilu", DType::BF16) => contiguous_tiled::silu::BFLOAT, - ("usin", DType::F16) => contiguous_tiled::sin::HALF, - ("usin", DType::F32) => contiguous_tiled::sin::FLOAT, - ("usin", DType::BF16) => contiguous_tiled::sin::BFLOAT, - ("usqr", DType::F16) => contiguous_tiled::sqr::HALF, - ("usqr", DType::F32) => contiguous_tiled::sqr::FLOAT, - ("usqr", DType::BF16) => contiguous_tiled::sqr::BFLOAT, - ("usqrt", DType::F16) => contiguous_tiled::sqrt::HALF, - ("usqrt", DType::F32) => contiguous_tiled::sqrt::FLOAT, - ("usqrt", DType::BF16) => contiguous_tiled::sqrt::BFLOAT, - ("utanh", DType::F16) => contiguous_tiled::tanh::HALF, - ("utanh", DType::F32) => contiguous_tiled::tanh::FLOAT, - ("utanh", DType::BF16) => contiguous_tiled::tanh::BFLOAT, - ("usign", DType::F16) => contiguous_tiled::sign::HALF, - ("usign", DType::F32) => contiguous_tiled::sign::FLOAT, - ("usign", DType::BF16) => contiguous_tiled::sign::BFLOAT, - ("usign", DType::I64) => contiguous_tiled::sign::I64, - (name, dtype) => { - crate::bail!( - "Metal contiguous_tiled unary {name} {dtype:?} not implemented" - ) - } - }; - candle_metal_kernels::call_unary_contiguous_tiled( - &device.device, - &encoder, - &device.kernels, - kernel_name, - el_count, - src, - &buffer, - ) - .map_err(MetalError::from)?; - } - (_, _, true) 
=> { - use candle_metal_kernels::unary::contiguous; - let kernel_name = match (B::KERNEL, dtype) { - ("uabs", DType::F16) => contiguous::abs::HALF, - ("uabs", DType::F32) => contiguous::abs::FLOAT, - ("uabs", DType::BF16) => contiguous::abs::BFLOAT, - ("uceil", DType::F16) => contiguous::ceil::HALF, - ("uceil", DType::F32) => contiguous::ceil::FLOAT, - ("uceil", DType::BF16) => contiguous::ceil::BFLOAT, - ("ucos", DType::F16) => contiguous::cos::HALF, - ("ucos", DType::F32) => contiguous::cos::FLOAT, - ("ucos", DType::BF16) => contiguous::cos::BFLOAT, - ("uerf", DType::F16) => contiguous::erf::HALF, - ("uerf", DType::F32) => contiguous::erf::FLOAT, - ("uerf", DType::BF16) => contiguous::erf::BFLOAT, - ("uexp", DType::F16) => contiguous::exp::HALF, - ("uexp", DType::F32) => contiguous::exp::FLOAT, - ("uexp", DType::BF16) => contiguous::exp::BFLOAT, - ("ufloor", DType::F16) => contiguous::floor::HALF, - ("ufloor", DType::F32) => contiguous::floor::FLOAT, - ("ufloor", DType::BF16) => contiguous::floor::BFLOAT, - ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, - ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, - ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT, - ("ugelu", DType::F16) => contiguous::gelu::HALF, - ("ugelu", DType::F32) => contiguous::gelu::FLOAT, - ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT, - ("ulog", DType::F16) => contiguous::log::HALF, - ("ulog", DType::F32) => contiguous::log::FLOAT, - ("ulog", DType::BF16) => contiguous::log::BFLOAT, - ("uneg", DType::F16) => contiguous::neg::HALF, - ("uneg", DType::F32) => contiguous::neg::FLOAT, - ("uneg", DType::BF16) => contiguous::neg::BFLOAT, - ("urecip", DType::F16) => contiguous::recip::HALF, - ("urecip", DType::F32) => contiguous::recip::FLOAT, - ("urecip", DType::BF16) => contiguous::recip::BFLOAT, - ("urelu", DType::F16) => contiguous::relu::HALF, - ("urelu", DType::F32) => contiguous::relu::FLOAT, - ("urelu", DType::BF16) => contiguous::relu::BFLOAT, - ("uround", DType::F16) => contiguous::round::HALF, - ("uround", DType::F32) => contiguous::round::FLOAT, - ("uround", DType::BF16) => contiguous::round::BFLOAT, - ("usilu", DType::F16) => contiguous::silu::HALF, - ("usilu", DType::F32) => contiguous::silu::FLOAT, - ("usilu", DType::BF16) => contiguous::silu::BFLOAT, - ("usin", DType::F16) => contiguous::sin::HALF, - ("usin", DType::F32) => contiguous::sin::FLOAT, - ("usin", DType::BF16) => contiguous::sin::BFLOAT, - ("usqr", DType::F16) => contiguous::sqr::HALF, - ("usqr", DType::F32) => contiguous::sqr::FLOAT, - ("usqr", DType::BF16) => contiguous::sqr::BFLOAT, - ("usqrt", DType::F16) => contiguous::sqrt::HALF, - ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, - ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT, - ("utanh", DType::F16) => contiguous::tanh::HALF, - ("utanh", DType::F32) => contiguous::tanh::FLOAT, - ("utanh", DType::BF16) => contiguous::tanh::BFLOAT, - ("usign", DType::F16) => contiguous::sign::HALF, - ("usign", DType::F32) => contiguous::sign::FLOAT, - ("usign", DType::BF16) => contiguous::sign::BFLOAT, - ("usign", DType::I64) => contiguous::sign::I64, - (name, dtype) => { - crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") - } - }; - candle_metal_kernels::call_unary_contiguous( - &device.device, - &encoder, - &device.kernels, - kernel_name, - el_count, - src, - &buffer, - ) - .map_err(MetalError::from)?; - } - (_, _, false) => { - use candle_metal_kernels::unary::strided; - let kernel_name = match (B::KERNEL, dtype) { - ("ucos", DType::F32) => 
strided::cos::FLOAT, - ("usin", DType::F32) => strided::sin::FLOAT, - ("usqr", DType::F32) => strided::sqr::FLOAT, - ("usqrt", DType::F32) => strided::sqrt::FLOAT, - ("uneg", DType::F32) => strided::neg::FLOAT, - ("uexp", DType::F32) => strided::exp::FLOAT, - ("ulog", DType::F32) => strided::log::FLOAT, - ("ugelu", DType::F32) => strided::gelu::FLOAT, - ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, - ("uerf", DType::F32) => strided::erf::FLOAT, - ("usilu", DType::F32) => strided::silu::FLOAT, - ("uabs", DType::F32) => strided::abs::FLOAT, - ("uceil", DType::F32) => strided::ceil::FLOAT, - ("ufloor", DType::F32) => strided::floor::FLOAT, - ("urelu", DType::F32) => strided::relu::FLOAT, - ("uround", DType::F32) => strided::round::FLOAT, - ("utanh", DType::F32) => strided::tanh::FLOAT, - - ("ucos", DType::F16) => strided::cos::HALF, - ("usin", DType::F16) => strided::sin::HALF, - ("usqr", DType::F16) => strided::sqr::HALF, - ("usqrt", DType::F16) => strided::sqrt::HALF, - ("uneg", DType::F16) => strided::neg::HALF, - ("uexp", DType::F16) => strided::exp::HALF, - ("ulog", DType::F16) => strided::log::HALF, - ("ugelu", DType::F16) => strided::gelu::HALF, - ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, - ("uerf", DType::F16) => strided::erf::HALF, - ("usilu", DType::F16) => strided::silu::HALF, - ("uabs", DType::F16) => strided::abs::HALF, - ("uceil", DType::F16) => strided::ceil::HALF, - ("ufloor", DType::F16) => strided::floor::HALF, - ("urelu", DType::F16) => strided::relu::HALF, - ("uround", DType::F16) => strided::round::HALF, - ("utanh", DType::F16) => strided::tanh::HALF, - - ("ucos", DType::BF16) => strided::cos::BFLOAT, - ("usin", DType::BF16) => strided::sin::BFLOAT, - ("usqr", DType::BF16) => strided::sqr::BFLOAT, - ("usqrt", DType::BF16) => strided::sqrt::BFLOAT, - ("uneg", DType::BF16) => strided::neg::BFLOAT, - ("uexp", DType::BF16) => strided::exp::BFLOAT, - ("ulog", DType::BF16) => strided::log::BFLOAT, - ("ugelu", DType::BF16) => strided::gelu::BFLOAT, - ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT, - ("uerf", DType::BF16) => strided::erf::BFLOAT, - ("usilu", DType::BF16) => strided::silu::BFLOAT, - ("uabs", DType::BF16) => strided::abs::BFLOAT, - ("uceil", DType::BF16) => strided::ceil::BFLOAT, - ("ufloor", DType::BF16) => strided::floor::BFLOAT, - ("urelu", DType::BF16) => strided::relu::BFLOAT, - ("uround", DType::BF16) => strided::round::BFLOAT, - ("utanh", DType::BF16) => strided::tanh::BFLOAT, - - (name, dtype) => { - crate::bail!("Metal strided unary {name} {dtype:?} not implemented") - } - }; - let dst = BufferOffset::zero_offset(&buffer); - candle_metal_kernels::call_unary_strided( - &device.device, - &encoder, - &device.kernels, - kernel_name, - layout.dims(), - src, - layout.stride(), - dst, - ) - .map_err(MetalError::from)?; - } + if layout.is_contiguous() { + use candle_metal_kernels::unary::contiguous; + let kernel_name = match (B::KERNEL, dtype) { + ("uabs", DType::F16) => contiguous::abs::HALF, + ("uabs", DType::F32) => contiguous::abs::FLOAT, + ("uabs", DType::BF16) => contiguous::abs::BFLOAT, + ("uceil", DType::F16) => contiguous::ceil::HALF, + ("uceil", DType::F32) => contiguous::ceil::FLOAT, + ("uceil", DType::BF16) => contiguous::ceil::BFLOAT, + ("ucos", DType::F16) => contiguous::cos::HALF, + ("ucos", DType::F32) => contiguous::cos::FLOAT, + ("ucos", DType::BF16) => contiguous::cos::BFLOAT, + ("uerf", DType::F16) => contiguous::erf::HALF, + ("uerf", DType::F32) => contiguous::erf::FLOAT, + ("uerf", DType::BF16) => 
contiguous::erf::BFLOAT, + ("uexp", DType::F16) => contiguous::exp::HALF, + ("uexp", DType::F32) => contiguous::exp::FLOAT, + ("uexp", DType::BF16) => contiguous::exp::BFLOAT, + ("ufloor", DType::F16) => contiguous::floor::HALF, + ("ufloor", DType::F32) => contiguous::floor::FLOAT, + ("ufloor", DType::BF16) => contiguous::floor::BFLOAT, + ("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF, + ("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT, + ("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT, + ("ugelu", DType::F16) => contiguous::gelu::HALF, + ("ugelu", DType::F32) => contiguous::gelu::FLOAT, + ("ugelu", DType::BF16) => contiguous::gelu::BFLOAT, + ("ulog", DType::F16) => contiguous::log::HALF, + ("ulog", DType::F32) => contiguous::log::FLOAT, + ("ulog", DType::BF16) => contiguous::log::BFLOAT, + ("uneg", DType::F16) => contiguous::neg::HALF, + ("uneg", DType::F32) => contiguous::neg::FLOAT, + ("uneg", DType::BF16) => contiguous::neg::BFLOAT, + ("urecip", DType::F16) => contiguous::recip::HALF, + ("urecip", DType::F32) => contiguous::recip::FLOAT, + ("urecip", DType::BF16) => contiguous::recip::BFLOAT, + ("urelu", DType::F16) => contiguous::relu::HALF, + ("urelu", DType::F32) => contiguous::relu::FLOAT, + ("urelu", DType::BF16) => contiguous::relu::BFLOAT, + ("uround", DType::F16) => contiguous::round::HALF, + ("uround", DType::F32) => contiguous::round::FLOAT, + ("uround", DType::BF16) => contiguous::round::BFLOAT, + ("usilu", DType::F16) => contiguous::silu::HALF, + ("usilu", DType::F32) => contiguous::silu::FLOAT, + ("usilu", DType::BF16) => contiguous::silu::BFLOAT, + ("usin", DType::F16) => contiguous::sin::HALF, + ("usin", DType::F32) => contiguous::sin::FLOAT, + ("usin", DType::BF16) => contiguous::sin::BFLOAT, + ("usqr", DType::F16) => contiguous::sqr::HALF, + ("usqr", DType::F32) => contiguous::sqr::FLOAT, + ("usqr", DType::BF16) => contiguous::sqr::BFLOAT, + ("usqrt", DType::F16) => contiguous::sqrt::HALF, + ("usqrt", DType::F32) => contiguous::sqrt::FLOAT, + ("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT, + ("utanh", DType::F16) => contiguous::tanh::HALF, + ("utanh", DType::F32) => contiguous::tanh::FLOAT, + ("utanh", DType::BF16) => contiguous::tanh::BFLOAT, + ("usign", DType::F16) => contiguous::sign::HALF, + ("usign", DType::F32) => contiguous::sign::FLOAT, + ("usign", DType::BF16) => contiguous::sign::BFLOAT, + ("usign", DType::I64) => contiguous::sign::I64, + (name, dtype) => { + crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented") + } + }; + + candle_metal_kernels::call_unary_contiguous( + &device.device, + &encoder, + &device.kernels, + kernel_name, + dtype.size_in_bytes(), + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } else { + use candle_metal_kernels::unary::strided; + let kernel_name = match (B::KERNEL, dtype) { + ("ucos", DType::F32) => strided::cos::FLOAT, + ("usin", DType::F32) => strided::sin::FLOAT, + ("usqr", DType::F32) => strided::sqr::FLOAT, + ("usqrt", DType::F32) => strided::sqrt::FLOAT, + ("uneg", DType::F32) => strided::neg::FLOAT, + ("uexp", DType::F32) => strided::exp::FLOAT, + ("ulog", DType::F32) => strided::log::FLOAT, + ("ugelu", DType::F32) => strided::gelu::FLOAT, + ("ugelu_erf", DType::F32) => strided::gelu_erf::FLOAT, + ("uerf", DType::F32) => strided::erf::FLOAT, + ("usilu", DType::F32) => strided::silu::FLOAT, + ("uabs", DType::F32) => strided::abs::FLOAT, + ("uceil", DType::F32) => strided::ceil::FLOAT, + ("ufloor", DType::F32) => strided::floor::FLOAT, + ("urelu", DType::F32) 
=> strided::relu::FLOAT, + ("uround", DType::F32) => strided::round::FLOAT, + ("utanh", DType::F32) => strided::tanh::FLOAT, + + ("ucos", DType::F16) => strided::cos::HALF, + ("usin", DType::F16) => strided::sin::HALF, + ("usqr", DType::F16) => strided::sqr::HALF, + ("usqrt", DType::F16) => strided::sqrt::HALF, + ("uneg", DType::F16) => strided::neg::HALF, + ("uexp", DType::F16) => strided::exp::HALF, + ("ulog", DType::F16) => strided::log::HALF, + ("ugelu", DType::F16) => strided::gelu::HALF, + ("ugelu_erf", DType::F16) => strided::gelu_erf::HALF, + ("uerf", DType::F16) => strided::erf::HALF, + ("usilu", DType::F16) => strided::silu::HALF, + ("uabs", DType::F16) => strided::abs::HALF, + ("uceil", DType::F16) => strided::ceil::HALF, + ("ufloor", DType::F16) => strided::floor::HALF, + ("urelu", DType::F16) => strided::relu::HALF, + ("uround", DType::F16) => strided::round::HALF, + ("utanh", DType::F16) => strided::tanh::HALF, + + ("ucos", DType::BF16) => strided::cos::BFLOAT, + ("usin", DType::BF16) => strided::sin::BFLOAT, + ("usqr", DType::BF16) => strided::sqr::BFLOAT, + ("usqrt", DType::BF16) => strided::sqrt::BFLOAT, + ("uneg", DType::BF16) => strided::neg::BFLOAT, + ("uexp", DType::BF16) => strided::exp::BFLOAT, + ("ulog", DType::BF16) => strided::log::BFLOAT, + ("ugelu", DType::BF16) => strided::gelu::BFLOAT, + ("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT, + ("uerf", DType::BF16) => strided::erf::BFLOAT, + ("usilu", DType::BF16) => strided::silu::BFLOAT, + ("uabs", DType::BF16) => strided::abs::BFLOAT, + ("uceil", DType::BF16) => strided::ceil::BFLOAT, + ("ufloor", DType::BF16) => strided::floor::BFLOAT, + ("urelu", DType::BF16) => strided::relu::BFLOAT, + ("uround", DType::BF16) => strided::round::BFLOAT, + ("utanh", DType::BF16) => strided::tanh::BFLOAT, + + (name, dtype) => { + crate::bail!("Metal strided unary {name} {dtype:?} not implemented") + } + }; + let dst = BufferOffset::zero_offset(&buffer); + candle_metal_kernels::call_unary_strided( + &device.device, + &encoder, + &device.kernels, + kernel_name, + layout.dims(), + src, + layout.stride(), + dst, + ) + .map_err(MetalError::from)?; } Ok(Self::new(buffer, device.clone(), el_count, dtype)) diff --git a/candle-metal-kernels/src/kernels/affine.rs b/candle-metal-kernels/src/kernels/affine.rs index 21a179e433..818282fe47 100644 --- a/candle-metal-kernels/src/kernels/affine.rs +++ b/candle-metal-kernels/src/kernels/affine.rs @@ -1,5 +1,5 @@ -use crate::linear_split; use crate::utils::{BufferOffset, EncoderProvider}; +use crate::{get_tile_size, linear_split}; use crate::{set_params, Buffer, ComputeCommandEncoder, Device, Kernels, MetalKernelError, Source}; use objc2_metal::MTLResourceUsage; @@ -9,6 +9,7 @@ pub fn call_affine( ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, + dtype_size: usize, size: usize, input: BufferOffset, output: &Buffer, @@ -23,7 +24,9 @@ pub fn call_affine( set_params!(encoder, (size, mul, add, &input, output)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + let tile_size = get_tile_size(dtype_size); + let tiles = size.div_ceil(tile_size); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); encoder.use_resource(input.buffer, MTLResourceUsage::Read); encoder.use_resource(output, MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); @@ -77,6 +80,7 @@ pub fn call_powf( ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, + dtype_size: usize, size: usize, input: 
BufferOffset, output: &Buffer, @@ -90,7 +94,9 @@ pub fn call_powf( set_params!(encoder, (size, mul, &input, output)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + let tile_size = get_tile_size(dtype_size); + let tiles = size.div_ceil(tile_size); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); encoder.use_resource(input.buffer, MTLResourceUsage::Read); encoder.use_resource(output, MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); @@ -134,6 +140,7 @@ pub fn call_elu( ep: impl EncoderProvider, kernels: &Kernels, name: &'static str, + dtype_size: usize, size: usize, input: BufferOffset, output: &Buffer, @@ -147,7 +154,9 @@ pub fn call_elu( set_params!(encoder, (size, mul, &input, output)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, size); + let tile_size = get_tile_size(dtype_size); + let tiles = size.div_ceil(tile_size); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); encoder.use_resource(input.buffer, MTLResourceUsage::Read); encoder.use_resource(output, MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); diff --git a/candle-metal-kernels/src/kernels/macros.rs b/candle-metal-kernels/src/kernels/macros.rs index 5088e7dec6..9cff9671ed 100644 --- a/candle-metal-kernels/src/kernels/macros.rs +++ b/candle-metal-kernels/src/kernels/macros.rs @@ -25,30 +25,6 @@ macro_rules! ops{ } } - pub mod contiguous_tiled { - pub struct Kernel(pub &'static str); - $( - pub mod $name { - use super::Kernel; - pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_f32_tiled")); - pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_f16_tiled")); - pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bf16_tiled")); - pub const I64: Kernel = Kernel(concat!(stringify!($name), "_i64_tiled")); - pub const U32: Kernel = Kernel(concat!(stringify!($name), "_u32_tiled")); - pub const U8: Kernel = Kernel(concat!(stringify!($name), "_u8_tiled")); - } - )+ - pub mod copy { - use super::Kernel; - pub const FLOAT: Kernel = Kernel("copy_f32_tiled"); - pub const HALF: Kernel = Kernel("copy_f16_tiled"); - pub const BFLOAT: Kernel = Kernel("copy_bf16_tiled"); - pub const I64: Kernel = Kernel("copy_i64_tiled"); - pub const U32: Kernel = Kernel("copy_u32_tiled"); - pub const U8: Kernel = Kernel("copy_u8_tiled"); - } - } - pub mod strided { pub struct Kernel(pub &'static str); $( diff --git a/candle-metal-kernels/src/kernels/unary.rs b/candle-metal-kernels/src/kernels/unary.rs index 89a945e5ce..40fae63547 100644 --- a/candle-metal-kernels/src/kernels/unary.rs +++ b/candle-metal-kernels/src/kernels/unary.rs @@ -1,6 +1,6 @@ use crate::kernels::macros::ops; use crate::utils::{BufferOffset, EncoderProvider}; -use crate::{get_block_dims, linear_split}; +use crate::{get_block_dims, get_tile_size, linear_split}; use crate::{ set_params, Buffer, ComputeCommandEncoder, Device, EncoderParam, Kernels, MetalKernelError, Source, @@ -18,6 +18,7 @@ pub fn call_unary_contiguous( ep: impl EncoderProvider, kernels: &Kernels, kernel_name: contiguous::Kernel, + dtype_size: usize, length: usize, input: BufferOffset, output: &Buffer, @@ -30,33 +31,8 @@ pub fn call_unary_contiguous( set_params!(encoder, (length, &input, output)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(input.buffer, MTLResourceUsage::Read); - encoder.use_resource(output, 
MTLResourceUsage::Write); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -pub fn call_unary_contiguous_tiled( - device: &Device, - ep: impl EncoderProvider, - kernels: &Kernels, - kernel_name: contiguous_tiled::Kernel, - length: usize, - input: BufferOffset, - output: &Buffer, -) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; - let encoder = ep.encoder(); - let encoder: &ComputeCommandEncoder = encoder.as_ref(); - let tile_size = 2; + let tile_size = get_tile_size(dtype_size); let tiles = length.div_ceil(tile_size); - - encoder.set_compute_pipeline_state(&pipeline); - - set_params!(encoder, (length, &input, output)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); encoder.use_resource(input.buffer, MTLResourceUsage::Read); encoder.use_resource(output, MTLResourceUsage::Write); @@ -91,38 +67,13 @@ pub fn call_unary_strided( Ok(()) } -#[allow(clippy::too_many_arguments)] -pub fn call_const_set_contiguous_tiled( - device: &Device, - ep: impl EncoderProvider, - kernels: &Kernels, - kernel_name: contiguous_tiled::Kernel, - length: usize, - input: impl EncoderParam, - output: BufferOffset, -) -> Result<(), MetalKernelError> { - let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?; - let encoder = ep.encoder(); - let encoder: &ComputeCommandEncoder = encoder.as_ref(); - let tile_size = 2; - let tiles = length.div_ceil(tile_size); - - encoder.set_compute_pipeline_state(&pipeline); - - set_params!(encoder, (length, input, &output)); - - let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); - encoder.use_resource(output.buffer, MTLResourceUsage::Write); - encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - Ok(()) -} - #[allow(clippy::too_many_arguments)] pub fn call_const_set_contiguous( device: &Device, ep: impl EncoderProvider, kernels: &Kernels, kernel_name: contiguous::Kernel, + dtype_size: usize, length: usize, input: impl EncoderParam, output: BufferOffset, @@ -132,10 +83,11 @@ pub fn call_const_set_contiguous( let encoder: &ComputeCommandEncoder = encoder.as_ref(); encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (length, input, &output)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + let tile_size = get_tile_size(dtype_size); + let tiles = length.div_ceil(tile_size); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); encoder.use_resource(output.buffer, MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); Ok(()) diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 827d2837b0..4d947ceff5 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -19,7 +19,7 @@ use metal::{ use objc2_metal::{MTLCompileOptions, MTLMathFloatingPointFunctions, MTLMathMode, MTLSize}; use source::Source; pub use utils::BufferOffset; -use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider}; +use utils::{get_block_dims, get_tile_size, linear_split, EncoderParam, EncoderProvider}; pub const RESOURCE_OPTIONS: MTLResourceOptions = objc2_metal::MTLResourceOptions(MTLResourceOptions::StorageModeShared.bits()); diff --git a/candle-metal-kernels/src/metal_src/affine.metal b/candle-metal-kernels/src/metal_src/affine.metal index 7f4c6ccfbb..b03364dfdb 100644 --- 
a/candle-metal-kernels/src/metal_src/affine.metal +++ b/candle-metal-kernels/src/metal_src/affine.metal @@ -1,5 +1,7 @@ #include +using namespace metal; +// Utils METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -15,113 +17,162 @@ METAL_FUNC uint get_strided_index( return strided_i; } -using namespace metal; +#define MAX(x, y) ((x) > (y) ? (x) : (y)) -#define AFFINE(FN_NAME, T) \ -kernel void FN_NAME( \ - constant size_t &dim, \ - constant float &mul, \ - constant float &add, \ - device const T *input, \ - device T *output, \ - uint id [[ thread_position_in_grid ]] \ -) { \ - if (id >= dim) { \ - return; \ - } \ - output[id] = T(fma(float(input[id]), mul, add)); \ -} \ -kernel void FN_NAME##_strided( \ - constant size_t &dim, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant float &mul, \ - constant float &add, \ - device const T *input, \ - device T *output, \ - uint id [[ thread_position_in_grid ]] \ -) { \ - if (id >= dim) { \ - return; \ - } \ - output[id] = T(fma(float(input[get_strided_index(id, num_dims, dims, strides)]), mul, add)); \ +template +constexpr int work_per_thread() { + constexpr int wpt = 8 / sizeof(T); + return MAX(1, wpt); } -#define POWF(FN_NAME, TYPENAME) \ -kernel void FN_NAME( \ - constant size_t &dim, \ - constant float &mul, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ - uint id [[ thread_position_in_grid ]] \ -) { \ - if (id >= dim) { \ - return; \ - } \ - output[id] = TYPENAME(pow(input[id], TYPENAME(mul))); \ -} \ -kernel void FN_NAME##_strided( \ - constant size_t &dim, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant float &mul, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ - uint id [[ thread_position_in_grid ]] \ -) { \ - if (id >= dim) { \ - return; \ - } \ - output[id] = TYPENAME(pow(input[get_strided_index(id, num_dims, dims, strides)], TYPENAME(mul))); \ +// Kernels +template ()> +[[kernel]] void affine_kernel( + constant size_t &dim, + constant float &mul, + constant float &add, + device const T *input, + device T *output, + uint tid [[thread_position_in_grid]] +) { + tid *= W; + if (W > 1 && tid + W > dim) { + for (int i = 0; tid + i < dim; ++i) { + float result = fma(float(input[tid + i]), mul, add); + output[tid + i] = static_cast(result); + } + } else { + for (int i = 0; i < W; ++i) { + float result = fma(float(input[tid + i]), mul, add); + output[tid + i] = static_cast(result); + } + } +} + +template +[[kernel]] void affine_kernel_strided( + constant size_t &dim, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant float &mul, + constant float &add, + constant const T *input, + device T *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dim) return; + uint idx = get_strided_index(tid, num_dims, dims, strides); + float result = fma(float(input[idx]), mul, add); + output[tid] = static_cast(result); } -#define ELU(FN_NAME, TYPENAME) \ -kernel void FN_NAME( \ - constant size_t &dim, \ - constant float &mul, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ - uint id [[ thread_position_in_grid ]] \ -) { \ - if (id >= dim) { \ - return; \ - } \ - const TYPENAME x = input[id]; \ - output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \ -} \ -kernel void FN_NAME##_strided( \ - constant size_t &dim, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant float &mul, \ - 
device const TYPENAME *input, \ - device TYPENAME *output, \ - uint id [[ thread_position_in_grid ]] \ -) { \ - if (id >= dim) { \ - return; \ - } \ - const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \ - output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \ -} \ - - -AFFINE(affine_u8, uint8_t) -AFFINE(affine_u32, uint32_t) -AFFINE(affine_i64, int64_t) -AFFINE(affine_f32, float) -AFFINE(affine_f16, half) -POWF(powf_f32, float) -POWF(powf_f16, half) -ELU(elu_f32, float) -ELU(elu_f16, half) +template ()> +[[kernel]] void powf_kernel( + constant size_t &dim, + constant float &mul, + device const T *input, + device T *output, + uint tid [[thread_position_in_grid]] +) { + tid *= W; + if (W > 1 && tid + W > dim) { + for (int i = 0; tid + i < dim; ++i) { + output[tid + i] = static_cast(pow(static_cast(input[tid + i]), mul)); + } + } else { + for (int i = 0; i < W; ++i) { + output[tid + i] = static_cast(pow(static_cast(input[tid + i]), mul)); + } + } +} + +template +[[kernel]] void powf_kernel_strided( + constant size_t &dim, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant float &mul, + constant const T *input, + device T *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dim) return; + uint idx = get_strided_index(tid, num_dims, dims, strides); + output[tid] = static_cast(pow(static_cast(input[idx]), mul)); +} + +template ()> +[[kernel]] void elu_kernel( + constant size_t &dim, + constant float &mul, + device const T *input, + device T *output, + uint tid [[thread_position_in_grid]] +) { + tid *= W; + if (W > 1 && tid + W > dim) { + for (int i = 0; tid + i < dim; ++i) { + const T x = input[tid + i]; + output[tid + i] = static_cast((x > 0) ? x : mul * (exp(x) - 1)); + } + } else { + for (int i = 0; i < W; ++i) { + const T x = input[tid + i]; + output[tid + i] = static_cast((x > 0) ? x : mul * (exp(x) - 1)); + } + } +} + +template +[[kernel]] void elu_kernel_strided( + constant size_t &dim, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant float &mul, + constant const T *input, + device T *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dim) return; + uint idx = get_strided_index(tid, num_dims, dims, strides); + const T x = input[idx]; + output[tid] = static_cast((x > 0) ? x : mul * (exp(x) - 1)); +} + +// Macros to help initialize kernels +#define init_kernel(name, func, ...) 
\ + template [[host_name(name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; + +#define init_affine(tname, t) \ + init_kernel("affine_" #tname, affine_kernel, t) \ + init_kernel("affine_" #tname "_strided", affine_kernel_strided, t) + +#define init_powf(tname, t) \ + init_kernel("powf_" #tname, powf_kernel, t) \ + init_kernel("powf_" #tname "_strided", powf_kernel_strided, t) + +#define init_elu(tname, t) \ + init_kernel("elu_" #tname, elu_kernel, t) \ + init_kernel("elu_" #tname "_strided", elu_kernel_strided, t) + + +init_affine(u8, uint8_t); +init_affine(u32, uint32_t); +init_affine(i64, int64_t); +init_affine(f32, float); +init_affine(f16, half); + +init_powf(f32, float); +init_powf(f16, half); +init_elu(f32, float); +init_elu(f16, half); #if defined(__HAVE_BFLOAT__) -AFFINE(affine_bf16, bfloat); -POWF(powf_bf16, bfloat); -ELU(elu_bf16, bfloat); +init_affine(bf16, bfloat); +init_powf(bf16, bfloat); +init_elu(bf16, bfloat); #endif diff --git a/candle-metal-kernels/src/metal_src/unary.metal b/candle-metal-kernels/src/metal_src/unary.metal index 368b9f2077..a3dbd01ef9 100644 --- a/candle-metal-kernels/src/metal_src/unary.metal +++ b/candle-metal-kernels/src/metal_src/unary.metal @@ -1,8 +1,8 @@ #include #include -# using namespace metal; +// Utils METAL_FUNC uint get_strided_index( uint idx, constant size_t &num_dims, @@ -18,19 +18,112 @@ METAL_FUNC uint get_strided_index( return strided_i; } -template METAL_FUNC T sqr(T in){ return in * in; } -template METAL_FUNC T recip(T in){ return T(1.0 / in); } -template METAL_FUNC T neg(T in){ return -in; } +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +template +constexpr int work_per_thread() { + constexpr int wpt = 8 / sizeof(T); + return MAX(1, wpt); +} + +// Kernels +template ()> +[[kernel]] void unary_kernel( + constant size_t &dim, + device const T* input, + device U* output, + uint tid [[thread_position_in_grid]] +) { + tid *= W; + if (W > 1 && tid + W > dim) { + for (int i = 0; tid + i < dim; ++i) { + output[tid + i] = static_cast(unary()(input[tid + i])); + } + } else { + for (int i = 0; i < W; ++i) { + output[tid + i] = static_cast(unary()(input[tid + i])); + } + } +} + +template +[[kernel]] void unary_kernel_strided( + constant size_t &dim, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + constant const T *input, + device U *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dim) return; + uint idx = get_strided_index(tid, num_dims, dims, strides); + output[tid] = static_cast(unary()(input[idx])); +} + +template ()> +[[kernel]] void const_set( + constant size_t &dim, + device const T &input, + device T *output, + uint tid [[thread_position_in_grid]] +) { + tid *= W; + if (W > 1 && tid + W > dim) { + for (int i = 0; tid + i < dim; ++i) { + output[tid + i] = input; + } + } else { + for (int i = 0; i < W; ++i) { + output[tid + i] = input; + } + } +} + +template +[[kernel]] void const_set_strided( + constant size_t &dim, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *strides, + device const T &input, + device T *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dim) { + return; + } + uint idx = get_strided_index(tid, num_dims, dims, strides); + output[idx] = input; +} + +template +[[kernel]] void copy2d( + constant int64_t &d1, + constant int64_t &d2, + constant int64_t &src_s, + constant int64_t &dst_s, + device const T *input, + device T *output, + uint2 idx [[thread_position_in_grid]] +) { + if (idx.x >= d1 || idx.y >= d2) return; + 
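    // copy2d walks a (d1, d2) grid directly: one thread per element, with row
    // strides src_s / dst_s selecting the source and destination rows. The
    // element-wise kernels above instead process W = max(1, 8 / sizeof(T))
    // elements per thread (see work_per_thread), and only the last, possibly
    // partial, tile falls back to the bounds-checked loop.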
int64_t src_idx = idx.x * src_s + idx.y; + int64_t dst_idx = idx.x * dst_s + idx.y; + output[dst_idx] = input[src_idx]; +} + +// Unary functions template METAL_FUNC T erf(T in){ - float x = (float) in; // constants - float a1 = 0.254829592; - float a2 = -0.284496736; - float a3 = 1.421413741; - float a4 = -1.453152027; - float a5 = 1.061405429; - float p = 0.3275911; + constexpr const float a1 = 0.254829592; + constexpr const float a2 = -0.284496736; + constexpr const float a3 = 1.421413741; + constexpr const float a4 = -1.453152027; + constexpr const float a5 = 1.061405429; + constexpr const float p = 0.3275911; + + float x = static_cast(in); // Save the sign of x int sign = 1; @@ -46,7 +139,7 @@ template METAL_FUNC T erf(T in){ } template METAL_FUNC T id(T in) { return in; } template METAL_FUNC T gelu_erf(T x) { - return T(x * (1 + erf(x * M_SQRT1_2_F)) / 2); + return static_cast(x * (1 + erf(x * M_SQRT1_2_F)) / 2); } template METAL_FUNC T gelu(T x) { if (x > 5) { @@ -58,190 +151,130 @@ template METAL_FUNC T gelu(T x) { T beta = (static_cast(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); return static_cast(0.5) * x * (static_cast(1.0) + T(precise::tanh(beta))); } -template METAL_FUNC T relu(T in){ - if (in < 0) { - return 0; +template METAL_FUNC T relu(T x) { + if (x > 5) { + return x; } - return in; -} -template METAL_FUNC T silu(T in){ - return in / (static_cast(1) + exp(-in)); -} -template METAL_FUNC T sigmoid(T in) { - return recip(static_cast(1) + exp(-in)); -} - -#define TILE_SIZE 2 - -#define CONST_SET(TYPENAME, FN_NAME) \ -kernel void FN_NAME( \ - constant size_t &dim, \ - constant TYPENAME &input, \ - device TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ -) { \ - if (tid >= dim) { \ - return; \ - } \ - output[tid] = input; \ -} \ -kernel void FN_NAME##_##strided( \ - constant size_t &dim, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - constant TYPENAME &input, \ - device TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ -) { \ - if (tid >= dim) { \ - return; \ - } \ - output[get_strided_index(tid, num_dims, dims, strides)] = input; \ -} \ -kernel void FN_NAME##_##tiled( \ - constant size_t &dim, \ - constant TYPENAME &input, \ - device TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ -) { \ - for (uint i = 0; i < TILE_SIZE; i++) { \ - const uint idx = tid * TILE_SIZE + i; \ - output[idx] = input; \ - } \ + T x_sq = x * x; + T x_cube = x_sq * x; + T alpha = x + static_cast(0.044715) * x_cube; + T beta = (static_cast(M_2_SQRTPI_F * M_SQRT1_2_F) * alpha); + return static_cast(0.5) * x * (static_cast(1.0) + T(precise::tanh(beta))); } - -#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \ -kernel void FN_NAME( \ - constant size_t &dim, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ -) { \ - if (tid >= dim) { \ - return; \ - } \ - output[tid] = TYPENAME(FN(float(input[tid]))); \ -} \ -kernel void FN_NAME##_##strided( \ - constant size_t &dim, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *strides, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ -) { \ - if (tid >= dim) { \ - return; \ - } \ - output[tid] = TYPENAME(FN(float(input[get_strided_index(tid, num_dims, dims, strides)]))); \ -} \ -kernel void FN_NAME##_##tiled( \ - constant size_t &dim, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ 
-) { \ - for (uint i = 0; i < TILE_SIZE; i++) { \ - const uint idx = tid * TILE_SIZE + i; \ - output[idx] = TYPENAME(FN(float(input[idx]))); \ - } \ +template METAL_FUNC T recip(T x) { + return static_cast(1.0 / x); } - -#define UNARY_OP(NAME) \ -UNARY(NAME, float, NAME##_f32, NAME##_f32_strided); \ -UNARY(NAME, half, NAME##_f16, NAME##_f16_strided); - -#define BFLOAT_UNARY_OP(NAME) \ -UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided); - -#define COPY2D(FN_NAME, TYPENAME) \ -kernel void FN_NAME( \ - constant int64_t &d1, \ - constant int64_t &d2, \ - constant int64_t &src_s, \ - constant int64_t &dst_s, \ - device const TYPENAME *input, \ - device TYPENAME *output, \ - uint2 idx [[thread_position_in_grid]] \ -) { \ - if (idx.x >= d1 || idx.y >= d2) return; \ - int64_t src_idx = idx.x * src_s + idx.y; \ - int64_t dst_idx = idx.x * dst_s + idx.y; \ - output[dst_idx] = input[src_idx]; \ +template METAL_FUNC T sigmoid(T x) { + return static_cast(recip(1 + exp(-x))); } -COPY2D(copy2d_f32, float) -COPY2D(copy2d_f16, half) -COPY2D(copy2d_u8, uint8_t) -COPY2D(copy2d_u32, uint32_t) - -CONST_SET(float, const_set_f32) -CONST_SET(half, const_set_f16) -CONST_SET(uint8_t, const_set_u8) -CONST_SET(uint32_t, const_set_u32) - -UNARY_OP(cos) -UNARY_OP(sin) -UNARY_OP(sqr) -UNARY_OP(sqrt) -UNARY_OP(neg) -UNARY_OP(exp) -UNARY_OP(log) -UNARY_OP(gelu) -UNARY_OP(silu) -UNARY_OP(abs) -UNARY_OP(ceil) -UNARY_OP(floor) -UNARY_OP(round) -UNARY_OP(gelu_erf) -UNARY_OP(erf) -UNARY_OP(recip) -UNARY_OP(relu) -UNARY_OP(sign) -UNARY_OP(sigmoid) -UNARY(id, float, copy_f32, copy_f32_strided) -UNARY(id, half, copy_f16, copy_f16_strided) -UNARY(id, uint8_t, copy_u8, copy_u8_strided) -UNARY(id, uint32_t, copy_u32, copy_u32_strided) +// Define unary ops +#define define_unary_op(name, op) \ +struct name { \ + template \ + METAL_FUNC T operator()(T x) { \ + return static_cast(op); \ + } \ +}; +define_unary_op(usqr, x * x); +define_unary_op(urecip, recip(x)); +define_unary_op(uneg, -x); +define_unary_op(uid, x); +define_unary_op(ugelu, gelu(x)); +define_unary_op(urelu, x < 0 ? 0 : x); +define_unary_op(usilu, x / (1 + exp(-x))); +define_unary_op(ugelu_erf, gelu_erf(x)); +define_unary_op(usqrt, sqrt(x)); +define_unary_op(ucos, cos(x)); +define_unary_op(usin, sin(x)); +define_unary_op(uexp, exp(x)); +define_unary_op(ulog, log(x)); +define_unary_op(uabs, abs(static_cast(x))); +define_unary_op(uceil, ceil(x)); +define_unary_op(ufloor, floor(x)); +define_unary_op(uround, round(x)); +define_unary_op(uerf, erf(x)); +define_unary_op(usign, sign(x)); +define_unary_op(usigmoid, sigmoid(x)); // tanh may create NaN on large values, e.g. 45 rather than outputting 1. // This has been an issue for the encodec example. -UNARY(precise::tanh, float, tanh_f32, tanh_f32_strided); -UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided); +define_unary_op(utanh, precise::tanh(x)); -#if __METAL_VERSION__ >= 220 -UNARY(id, int64_t, copy_i64, copy_i64_strided) -COPY2D(copy2d_i64, int64_t) -CONST_SET(int64_t, const_set_i64) +// Macros to help initialize kernels +#define init_kernel(name, func, ...) 
\ + template [[host_name(name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; + +#define init_unary(op_name, unary_op, tname, t) \ + init_kernel(#op_name "_" #tname, unary_kernel, t, t, unary_op) \ + init_kernel(#op_name "_" #tname "_strided", unary_kernel_strided, t, t, unary_op) + +#if defined(__HAVE_BFLOAT__) +#define init_unary_float(op_name, unary_op) \ + init_unary(op_name, unary_op, f32, float) \ + init_unary(op_name, unary_op, f16, half) \ + init_unary(op_name, unary_op, bf16, bfloat) +#else +#define init_unary_float(op_name, unary_op) \ + init_unary(op_name, unary_op, f32, float) \ + init_unary(op_name, unary_op, f16, half) #endif +#define init_copy2d(tname, t) \ + init_kernel("copy2d_" #tname, copy2d, t) + +#define init_const_set(tname, t) \ + init_kernel("const_set_" #tname, const_set, t) \ + init_kernel("const_set_" #tname "_strided", const_set_strided, t) + +// Initialize all unary kernels for floating point types +init_unary_float(gelu_erf, ugelu_erf); +init_unary_float(sqrt, usqrt); +init_unary_float(sqr, usqr); +init_unary_float(neg, uneg); +init_unary_float(recip, urecip); +init_unary_float(copy, uid); +init_unary_float(silu, usilu); +init_unary_float(gelu, ugelu); +init_unary_float(relu, urelu); +init_unary_float(cos, ucos); +init_unary_float(sin, usin); +init_unary_float(exp, uexp); +init_unary_float(log, ulog); +init_unary_float(abs, uabs); +init_unary_float(ceil, uceil); +init_unary_float(floor, ufloor); +init_unary_float(round, uround); +init_unary_float(erf, uerf); +init_unary_float(sign, usign); +init_unary_float(sigmoid, usigmoid); +init_unary_float(tanh, utanh); + +// Initialize copy2d kernels +init_copy2d(f32, float); +init_copy2d(f16, half); + +// Initialize const_set kernels +init_const_set(f32, float); +init_const_set(f16, half); + #if defined(__HAVE_BFLOAT__) -BFLOAT_UNARY_OP(cos) -BFLOAT_UNARY_OP(sin) -BFLOAT_UNARY_OP(sqr) -BFLOAT_UNARY_OP(sqrt) -BFLOAT_UNARY_OP(neg) -BFLOAT_UNARY_OP(exp) -BFLOAT_UNARY_OP(log) -BFLOAT_UNARY_OP(gelu) -BFLOAT_UNARY_OP(silu) -BFLOAT_UNARY_OP(abs) -BFLOAT_UNARY_OP(ceil) -BFLOAT_UNARY_OP(floor) -BFLOAT_UNARY_OP(round) -BFLOAT_UNARY_OP(gelu_erf) -BFLOAT_UNARY_OP(erf) -BFLOAT_UNARY_OP(recip) -BFLOAT_UNARY_OP(relu) -BFLOAT_UNARY_OP(sign) -BFLOAT_UNARY_OP(sigmoid) - -UNARY(id, bfloat, copy_bf16, copy_bf16_strided) - -UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided); - -COPY2D(copy2d_bf16, bfloat) -CONST_SET(bfloat, const_set_bf16) +init_copy2d(bf16, bfloat); +init_const_set(bf16, bfloat); +#endif + +// Initialize unary kernels for integer dtypes +init_unary(copy, uid, u8, uint8_t); +init_unary(copy, uid, u32, uint32_t); + +init_copy2d(u8, uint8_t); +init_copy2d(u32, uint32_t); + +init_const_set(u8, uint8_t); +init_const_set(u32, uint32_t); + +#if __METAL_VERSION__ >= 220 +init_unary(copy, uid, i64, int64_t); +init_copy2d(i64, int64_t); +init_const_set(i64, int64_t); #endif diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 557a5a4859..e0455df715 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -57,6 +57,7 @@ fn run(v: &[T], name: unary::contiguous::Kernel) -> Vec { &command_buffer, &kernels, name, + size_of::(), v.len(), input, &output, @@ -238,9 +239,9 @@ fn gelu_f16() { .iter() .map(|v| f16::from_f32(*v)) .collect(); - let expected: Vec = vec![-0.0, -0.16, 0.0, 0.84, 1.96, 3.0, 10.0, 20.0]; + let expected: Vec = vec![-0.0, -0.159, 0.0, 0.841, 1.954, 2.996, 10.0, 20.0]; let results = run(&v, unary::contiguous::gelu::HALF); - 
assert_eq!(approx_f16(results, 2), expected); + assert_eq!(approx_f16(results, 3), expected); } #[test] @@ -541,6 +542,7 @@ fn run_affine(v: &[T], mul: f64, add: f64) -> Vec { &command_buffer, &kernels, "affine_f32", + size_of::(), size, BufferOffset::zero_offset(&input), &output, diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index 1ad647d79d..034d508068 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -64,6 +64,13 @@ pub fn get_block_dims(dim0: usize, dim1: usize, dim2: usize) -> MTLSize { } } +/// Calculate preferred tile size given the size of a data type in bytes. +/// f32 -> 2, f16 -> 4, u8 -> 8. +#[inline(always)] +pub fn get_tile_size(dtype_size: usize) -> usize { + 1.max(8 / dtype_size) +} + pub fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: P) {
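    // Note on the tiling helper above: `get_tile_size` decides how many elements
    // each GPU thread handles, e.g. get_tile_size(4) == 2 for f32,
    // get_tile_size(2) == 4 for f16/bf16, and get_tile_size(1) == 8 for u8.
    // Callers then dispatch size.div_ceil(tile_size) threads through
    // `linear_split` rather than one thread per element.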
    <P as EncoderParam>
::set_param(encoder, position, data) } diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 7f21aa9b21..b5189ff426 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -154,74 +154,49 @@ impl candle::CustomOp1 for Sigmoid { offset_in_bytes: layout.start_offset() * storage.dtype().size_in_bytes(), }; - match (el_count % 2, dtype, layout.is_contiguous()) { - (0, DType::BF16 | DType::F16, true) => { - use candle_metal_kernels::unary::contiguous_tiled; - let kernel_name = match dtype { - DType::F16 => contiguous_tiled::sigmoid::HALF, - DType::F32 => contiguous_tiled::sigmoid::FLOAT, - DType::BF16 => contiguous_tiled::sigmoid::BFLOAT, - dtype => { - candle::bail!( - "Metal contiguous_tiled unary sigmoid {dtype:?} not implemented" - ) - } - }; - candle_metal_kernels::call_unary_contiguous_tiled( - device.metal_device(), - &encoder, - device.kernels(), - kernel_name, - el_count, - src, - &buffer, - ) - .map_err(MetalError::from)?; - } - (_, _, true) => { - use candle_metal_kernels::unary::contiguous; - let kernel_name = match dtype { - DType::F16 => contiguous::sigmoid::HALF, - DType::F32 => contiguous::sigmoid::FLOAT, - DType::BF16 => contiguous::sigmoid::BFLOAT, - dtype => { - candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented") - } - }; - candle_metal_kernels::call_unary_contiguous( - device.metal_device(), - &encoder, - device.kernels(), - kernel_name, - el_count, - src, - &buffer, - ) - .map_err(MetalError::from)?; - } - (_, _, false) => { - use candle_metal_kernels::unary::strided; - let kernel_name = match dtype { - DType::F16 => strided::sigmoid::HALF, - DType::F32 => strided::sigmoid::FLOAT, - DType::BF16 => strided::sigmoid::BFLOAT, - dtype => { - candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented") - } - }; - let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer); - candle_metal_kernels::call_unary_strided( - device.metal_device(), - &encoder, - device.kernels(), - kernel_name, - layout.dims(), - src, - layout.stride(), - dst, - ) - .map_err(MetalError::from)?; - } + if layout.is_contiguous() { + use candle_metal_kernels::unary::contiguous; + let kernel_name = match dtype { + DType::F16 => contiguous::sigmoid::HALF, + DType::F32 => contiguous::sigmoid::FLOAT, + DType::BF16 => contiguous::sigmoid::BFLOAT, + dtype => { + candle::bail!("Metal contiguous unary sigmoid {dtype:?} not implemented") + } + }; + candle_metal_kernels::call_unary_contiguous( + device.metal_device(), + &encoder, + device.kernels(), + kernel_name, + dtype.size_in_bytes(), + el_count, + src, + &buffer, + ) + .map_err(MetalError::from)?; + } else { + use candle_metal_kernels::unary::strided; + let kernel_name = match dtype { + DType::F16 => strided::sigmoid::HALF, + DType::F32 => strided::sigmoid::FLOAT, + DType::BF16 => strided::sigmoid::BFLOAT, + dtype => { + candle::bail!("Metal strided unary sigmoid {dtype:?} not implemented") + } + }; + let dst = candle_metal_kernels::BufferOffset::zero_offset(&buffer); + candle_metal_kernels::call_unary_strided( + device.metal_device(), + &encoder, + device.kernels(), + kernel_name, + layout.dims(), + src, + layout.stride(), + dst, + ) + .map_err(MetalError::from)?; } let new_storage = candle::MetalStorage::new(buffer, device.clone(), el_count, dtype); From 659016c382b5d8901b365c48e460dda97f253846 Mon Sep 17 00:00:00 2001 From: ivarflakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 8 Dec 2025 07:44:52 +0100 Subject: [PATCH 3/8] [Metal] binary improvements (#3231) --- 
candle-core/benches/bench_main.rs | 1 + candle-core/benches/benchmarks/binary.rs | 57 +++++ candle-core/benches/benchmarks/mod.rs | 1 + candle-core/src/metal_backend/mod.rs | 63 +++-- candle-metal-kernels/src/kernels/binary.rs | 7 +- .../src/metal_src/binary.metal | 224 ++++++++++-------- 6 files changed, 231 insertions(+), 122 deletions(-) create mode 100644 candle-core/benches/benchmarks/binary.rs diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs index e6b7cac227..ec02e4bddb 100644 --- a/candle-core/benches/bench_main.rs +++ b/candle-core/benches/bench_main.rs @@ -4,6 +4,7 @@ use criterion::criterion_main; criterion_main!( benchmarks::affine::benches, + benchmarks::binary::benches, benchmarks::broadcast::benches, benchmarks::copy::benches, benchmarks::conv_transpose2d::benches, diff --git a/candle-core/benches/benchmarks/binary.rs b/candle-core/benches/benchmarks/binary.rs new file mode 100644 index 0000000000..46e2cf7f7f --- /dev/null +++ b/candle-core/benches/benchmarks/binary.rs @@ -0,0 +1,57 @@ +use crate::benchmarks::{BenchDevice, BenchDeviceHandler}; +use candle_core::{DType, Device, Tensor}; +use criterion::{criterion_group, Criterion, Throughput}; +use std::hint::black_box; +use std::time::Instant; + +fn run(lhs: &Tensor, rhs: &Tensor) -> Tensor { + lhs.mul(rhs).unwrap() +} + +fn run_unary_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) { + let b = 1; + let m = 1024; + let k = 1024; + + let lhs = Tensor::arange(0.0f32, (b * m * k) as f32, device) + .unwrap() + .to_dtype(dtype) + .unwrap() + .reshape((b, m, k)) + .unwrap(); + + let rhs = Tensor::arange(0.0f32, (b * m * k) as f32, device) + .unwrap() + .to_dtype(dtype) + .unwrap() + .reshape((b, m, k)) + .unwrap(); + + let flops = 2 * b * m * k * dtype.size_in_bytes(); + + let mut group = c.benchmark_group(device.bench_name(name)); + group.throughput(Throughput::Bytes(flops as u64)); + group.bench_function("iter", move |b| { + b.iter_custom(|iters| { + let start = Instant::now(); + for _i in 0..iters { + run(black_box(&lhs), black_box(&rhs)); + } + device.sync().unwrap(); + start.elapsed() + }) + }); + group.finish(); +} + +fn criterion_benchmark(c: &mut Criterion) { + let handler = BenchDeviceHandler::new().unwrap(); + for device in handler.devices { + for dtype in [DType::F32, DType::BF16, DType::F16] { + let name = format!("binary_mul_{:?}", dtype); + run_unary_benchmark(c, &device, dtype, &name); + } + } +} + +criterion_group!(benches, criterion_benchmark); diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs index bc98eb2ff8..3b45a83e5f 100644 --- a/candle-core/benches/benchmarks/mod.rs +++ b/candle-core/benches/benchmarks/mod.rs @@ -1,4 +1,5 @@ pub(crate) mod affine; +pub(crate) mod binary; pub(crate) mod broadcast; pub(crate) mod conv_transpose2d; pub(crate) mod copy; diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index e2f8224d60..e58b03bcbf 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -1793,14 +1793,16 @@ impl MetalStorage { let encoder = device.command_encoder()?; let lhs = buffer_o(&self.buffer, lhs_l, self.dtype); let rhs = buffer_o(&rhs.buffer, rhs_l, rhs.dtype); - let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() && &op[..1] != "b" { + let (buffer, dtype) = if lhs_l.is_contiguous() && rhs_l.is_contiguous() { use candle_metal_kernels::kernels::binary::contiguous; let (kernel_name, dtype) = match (op, self.dtype) 
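            // The op strings matched below are the full binary kernel names used by
            // candle ("badd", "bsub", "bmul", "bdiv", "bminimum", "bmaximum") plus the
            // comparison ops; the old code excluded every "b"-prefixed op from this
            // contiguous fast path, which now also covers the new min/max kernels.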
{ - ("add", DType::F32) => (contiguous::add::FLOAT, self.dtype), - ("sub", DType::F32) => (contiguous::sub::FLOAT, self.dtype), - ("mul", DType::F32) => (contiguous::mul::FLOAT, self.dtype), - ("div", DType::F32) => (contiguous::div::FLOAT, self.dtype), + ("badd", DType::F32) => (contiguous::add::FLOAT, self.dtype), + ("bsub", DType::F32) => (contiguous::sub::FLOAT, self.dtype), + ("bmul", DType::F32) => (contiguous::mul::FLOAT, self.dtype), + ("bdiv", DType::F32) => (contiguous::div::FLOAT, self.dtype), + ("bminimum", DType::F32) => (contiguous::min::FLOAT, self.dtype), + ("bmaximum", DType::F32) => (contiguous::max::FLOAT, self.dtype), ("eq", DType::F32) => (contiguous::eq::FLOAT, DType::U8), ("ne", DType::F32) => (contiguous::ne::FLOAT, DType::U8), ("le", DType::F32) => (contiguous::le::FLOAT, DType::U8), @@ -1808,10 +1810,12 @@ impl MetalStorage { ("ge", DType::F32) => (contiguous::ge::FLOAT, DType::U8), ("gt", DType::F32) => (contiguous::gt::FLOAT, DType::U8), - ("add", DType::F16) => (contiguous::add::HALF, self.dtype), - ("sub", DType::F16) => (contiguous::sub::HALF, self.dtype), - ("mul", DType::F16) => (contiguous::mul::HALF, self.dtype), - ("div", DType::F16) => (contiguous::div::HALF, self.dtype), + ("badd", DType::F16) => (contiguous::add::HALF, self.dtype), + ("bsub", DType::F16) => (contiguous::sub::HALF, self.dtype), + ("bmul", DType::F16) => (contiguous::mul::HALF, self.dtype), + ("bdiv", DType::F16) => (contiguous::div::HALF, self.dtype), + ("bminimum", DType::F16) => (contiguous::min::HALF, self.dtype), + ("bmaximum", DType::F16) => (contiguous::max::HALF, self.dtype), ("eq", DType::F16) => (contiguous::eq::HALF, DType::U8), ("ne", DType::F16) => (contiguous::ne::HALF, DType::U8), ("le", DType::F16) => (contiguous::le::HALF, DType::U8), @@ -1819,10 +1823,12 @@ impl MetalStorage { ("ge", DType::F16) => (contiguous::ge::HALF, DType::U8), ("gt", DType::F16) => (contiguous::gt::HALF, DType::U8), - ("add", DType::BF16) => (contiguous::add::BFLOAT, self.dtype), - ("sub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype), - ("mul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype), - ("div", DType::BF16) => (contiguous::div::BFLOAT, self.dtype), + ("badd", DType::BF16) => (contiguous::add::BFLOAT, self.dtype), + ("bsub", DType::BF16) => (contiguous::sub::BFLOAT, self.dtype), + ("bmul", DType::BF16) => (contiguous::mul::BFLOAT, self.dtype), + ("bdiv", DType::BF16) => (contiguous::div::BFLOAT, self.dtype), + ("bminimum", DType::BF16) => (contiguous::min::BFLOAT, self.dtype), + ("bmaximum", DType::BF16) => (contiguous::max::BFLOAT, self.dtype), ("eq", DType::BF16) => (contiguous::eq::BFLOAT, DType::U8), ("ne", DType::BF16) => (contiguous::ne::BFLOAT, DType::U8), ("le", DType::BF16) => (contiguous::le::BFLOAT, DType::U8), @@ -1830,10 +1836,12 @@ impl MetalStorage { ("ge", DType::BF16) => (contiguous::ge::BFLOAT, DType::U8), ("gt", DType::BF16) => (contiguous::gt::BFLOAT, DType::U8), - ("add", DType::I64) => (contiguous::add::I64, self.dtype), - ("sub", DType::I64) => (contiguous::sub::I64, self.dtype), - ("mul", DType::I64) => (contiguous::mul::I64, self.dtype), - ("div", DType::I64) => (contiguous::div::I64, self.dtype), + ("badd", DType::I64) => (contiguous::add::I64, self.dtype), + ("bsub", DType::I64) => (contiguous::sub::I64, self.dtype), + ("bmul", DType::I64) => (contiguous::mul::I64, self.dtype), + ("bdiv", DType::I64) => (contiguous::div::I64, self.dtype), + ("bminimum", DType::I64) => (contiguous::min::I64, self.dtype), + ("bmaximum", DType::I64) => 
(contiguous::max::I64, self.dtype), ("eq", DType::I64) => (contiguous::eq::I64, DType::U8), ("ne", DType::I64) => (contiguous::ne::I64, DType::U8), ("le", DType::I64) => (contiguous::le::I64, DType::U8), @@ -1841,10 +1849,12 @@ impl MetalStorage { ("ge", DType::I64) => (contiguous::ge::I64, DType::U8), ("gt", DType::I64) => (contiguous::gt::I64, DType::U8), - ("add", DType::U32) => (contiguous::add::U32, self.dtype), - ("sub", DType::U32) => (contiguous::sub::U32, self.dtype), - ("mul", DType::U32) => (contiguous::mul::U32, self.dtype), - ("div", DType::U32) => (contiguous::div::U32, self.dtype), + ("badd", DType::U32) => (contiguous::add::U32, self.dtype), + ("bsub", DType::U32) => (contiguous::sub::U32, self.dtype), + ("bmul", DType::U32) => (contiguous::mul::U32, self.dtype), + ("bdiv", DType::U32) => (contiguous::div::U32, self.dtype), + ("bminimum", DType::U32) => (contiguous::min::U32, self.dtype), + ("bmaximum", DType::U32) => (contiguous::max::U32, self.dtype), ("eq", DType::U32) => (contiguous::eq::U32, DType::U8), ("ne", DType::U32) => (contiguous::ne::U32, DType::U8), ("le", DType::U32) => (contiguous::le::U32, DType::U8), @@ -1852,10 +1862,12 @@ impl MetalStorage { ("ge", DType::U32) => (contiguous::ge::U32, DType::U8), ("gt", DType::U32) => (contiguous::gt::U32, DType::U8), - ("add", DType::U8) => (contiguous::add::U8, self.dtype), - ("sub", DType::U8) => (contiguous::sub::U8, self.dtype), - ("mul", DType::U8) => (contiguous::mul::U8, self.dtype), - ("div", DType::U8) => (contiguous::div::U8, self.dtype), + ("badd", DType::U8) => (contiguous::add::U8, self.dtype), + ("bsub", DType::U8) => (contiguous::sub::U8, self.dtype), + ("bmul", DType::U8) => (contiguous::mul::U8, self.dtype), + ("bdiv", DType::U8) => (contiguous::div::U8, self.dtype), + ("bminimum", DType::U8) => (contiguous::min::U8, self.dtype), + ("bmaximum", DType::U8) => (contiguous::max::U8, self.dtype), ("eq", DType::U8) => (contiguous::eq::U8, DType::U8), ("ne", DType::U8) => (contiguous::ne::U8, DType::U8), ("le", DType::U8) => (contiguous::le::U8, DType::U8), @@ -1873,6 +1885,7 @@ impl MetalStorage { &encoder, &device.kernels, kernel_name, + self.dtype.size_in_bytes(), el_count, lhs, rhs, diff --git a/candle-metal-kernels/src/kernels/binary.rs b/candle-metal-kernels/src/kernels/binary.rs index d91ec0e109..249e1592d4 100644 --- a/candle-metal-kernels/src/kernels/binary.rs +++ b/candle-metal-kernels/src/kernels/binary.rs @@ -1,6 +1,6 @@ use crate::kernels::macros::ops; -use crate::linear_split; use crate::utils::{BufferOffset, EncoderProvider}; +use crate::{get_tile_size, linear_split}; use crate::{set_params, Buffer, ComputeCommandEncoder, Device, Kernels, MetalKernelError, Source}; use objc2_metal::MTLResourceUsage; @@ -12,6 +12,7 @@ pub fn call_binary_contiguous( ep: impl EncoderProvider, kernels: &Kernels, kernel_name: contiguous::Kernel, + dtype_size: usize, length: usize, left: BufferOffset, right: BufferOffset, @@ -25,7 +26,9 @@ pub fn call_binary_contiguous( set_params!(encoder, (length, &left, &right, output)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); + let tile_size = get_tile_size(dtype_size); + let tiles = length.div_ceil(tile_size); + let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles); encoder.use_resource(left.buffer, MTLResourceUsage::Read); encoder.use_resource(right.buffer, MTLResourceUsage::Read); diff --git a/candle-metal-kernels/src/metal_src/binary.metal b/candle-metal-kernels/src/metal_src/binary.metal index 
e83498e40d..2c2d88724b 100644 --- a/candle-metal-kernels/src/metal_src/binary.metal +++ b/candle-metal-kernels/src/metal_src/binary.metal @@ -1,5 +1,7 @@ #include <metal_stdlib> +using namespace metal; +// Utils #define MAX(x, y) ((x) > (y) ? (x) : (y)) #define MIN(x, y) ((x) < (y) ? (x) : (y)) @@ -18,108 +20,140 @@ METAL_FUNC uint get_strided_index( return strided_i; } -using namespace metal; - -#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \ -kernel void FN_NAME( \ - constant size_t &dim, \ - device const TYPENAME *left, \ - device const TYPENAME *right, \ - device OUT_TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ -) { \ - if (tid >= dim) { \ - return; \ - } \ - TYPENAME x = left[tid]; \ - TYPENAME y = right[tid]; \ - output[tid] = OUT_TYPENAME(FN); \ -}\ -kernel void FN_NAME_STRIDED( \ - constant size_t &dim, \ - constant size_t &num_dims, \ - constant size_t *dims, \ - constant size_t *left_strides, \ - constant size_t *right_strides, \ - device const TYPENAME *left, \ - device const TYPENAME *right, \ - device OUT_TYPENAME *output, \ - uint tid [[ thread_position_in_grid ]] \ -) { \ - if (tid >= dim) { \ - return; \ - } \ - TYPENAME x = left[get_strided_index(tid, num_dims, dims, left_strides)]; \ - TYPENAME y = right[get_strided_index(tid, num_dims, dims, right_strides)]; \ - output[tid] = OUT_TYPENAME(FN); \ +template <typename T> +constexpr int work_per_thread() { + constexpr int wpt = 8 / sizeof(T); + return MAX(1, wpt); } -#define BINARY_OP(FN, NAME) \ -BINARY(FN, float, float, NAME##_f32, NAME##_f32_strided); \ -BINARY(FN, half, half, NAME##_f16, NAME##_f16_strided); \ -BINARY(FN, uint32_t, uint32_t, NAME##_u32, NAME##_u32_strided); \ -BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); - -#define BINARY_OP_OUT(NAME, FN) \ -BINARY(FN, float, uint8_t, NAME##_f32, NAME##_f32_strided); \ -BINARY(FN, half, uint8_t, NAME##_f16, NAME##_f16_strided); \ -BINARY(FN, uint32_t, uint8_t, NAME##_u32, NAME##_u32_strided); \ -BINARY(FN, uint8_t, uint8_t, NAME##_u8, NAME##_u8_strided); - -#define INT64_BINARY_OP(NAME, FN) \ -BINARY(FN, int64_t, int64_t, NAME##_i64, NAME##_i64_strided); - -#define INT64_BINARY_OP_OUT(NAME, FN) \ -BINARY(FN, int64_t, uint8_t, NAME##_i64, NAME##_i64_strided); - -#define BFLOAT_BINARY_OP(FN, NAME) \ -BINARY(FN, bfloat, bfloat, NAME##_bf16, NAME##_bf16_strided); - -#define BFLOAT_BINARY_OP_OUT(NAME, FN) \ -BINARY(FN, bfloat, uint8_t, NAME##_bf16, NAME##_bf16_strided); +// Kernels +template <typename T, typename U, typename binary, int W = work_per_thread<T>()> +[[kernel]] void binary_kernel( + constant size_t &dim, + device const T *left, + device const T *right, + device U *output, + uint tid [[thread_position_in_grid]] +) { + binary op; + + tid *= W; + if (W > 1 && tid + W > dim) { + for (int i = 0; tid + i < dim; ++i) { + output[tid + i] = static_cast<U>(op(left[tid + i], right[tid + i])); + } + } else { + for (int i = 0; i < W; ++i) { + output[tid + i] = static_cast<U>(op(left[tid + i], right[tid + i])); + } + } +} -BINARY_OP(x + y, add) -BINARY_OP(x - y, sub) -BINARY_OP(x * y, mul) -BINARY_OP(x / y, div) -BINARY_OP(MIN(x, y), min) -BINARY_OP(MAX(x, y), max) +template <typename T, typename U, typename binary> +[[kernel]] void binary_kernel_strided( + constant size_t &dim, + constant size_t &num_dims, + constant size_t *dims, + constant size_t *left_strides, + constant size_t *right_strides, + device const T *left, + device const T *right, + device U *output, + uint tid [[ thread_position_in_grid ]] +) { + if (tid >= dim) return; + binary op; + uint l_idx = get_strided_index(tid, num_dims, dims, left_strides); + uint r_idx = get_strided_index(tid, num_dims,
dims, right_strides); + output[tid] = static_cast<U>(op(left[l_idx], right[r_idx])); +} -BINARY_OP_OUT(eq, x == y) -BINARY_OP_OUT(ne, x != y) -BINARY_OP_OUT(le, x <= y) -BINARY_OP_OUT(lt, x < y) -BINARY_OP_OUT(ge, x >= y) -BINARY_OP_OUT(gt, x > y) +// Macros to help initialize kernels +#define init_kernel(name, func, ...) \ + template [[host_name(name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; -#if __METAL_VERSION__ >= 220 -INT64_BINARY_OP(add, x + y) -INT64_BINARY_OP(sub, x - y) -INT64_BINARY_OP(mul, x * y) -INT64_BINARY_OP(div, x / y) -INT64_BINARY_OP(min, MIN(x, y)) -INT64_BINARY_OP(max, MAX(x, y)) +#define init_binary_k(op_name, binary_op, tname, t, u) \ + init_kernel(#op_name "_" #tname, binary_kernel, t, u, binary_op) \ + init_kernel(#op_name "_" #tname "_strided", binary_kernel_strided, t, u, binary_op) -INT64_BINARY_OP_OUT(eq, x == y) -INT64_BINARY_OP_OUT(ne, x != y) -INT64_BINARY_OP_OUT(le, x <= y) -INT64_BINARY_OP_OUT(lt, x < y) -INT64_BINARY_OP_OUT(ge, x >= y) -INT64_BINARY_OP_OUT(gt, x > y) +#if defined(__HAVE_BFLOAT__) +#define init_binary(op_name, binary_op) \ + init_binary_k(op_name, binary_op, f32, float, float) \ + init_binary_k(op_name, binary_op, f16, half, half) \ + init_binary_k(op_name, binary_op, bf16, bfloat, bfloat) \ + init_binary_k(op_name, binary_op, u8, uint8_t, uint8_t) \ + init_binary_k(op_name, binary_op, u32, uint32_t, uint32_t) \ + init_binary_k(op_name, binary_op, i64, int64_t, int64_t) +#else +#define init_binary(op_name, binary_op) \ + init_binary_k(op_name, binary_op, f32, float, float) \ + init_binary_k(op_name, binary_op, f16, half, half) \ + init_binary_k(op_name, binary_op, bf16, bfloat, bfloat) \ + init_binary_k(op_name, binary_op, u8, uint8_t, uint8_t) \ + init_binary_k(op_name, binary_op, u32, uint32_t, uint32_t) \ + init_binary_k(op_name, binary_op, i64, int64_t, int64_t) #endif #if defined(__HAVE_BFLOAT__) -BFLOAT_BINARY_OP(x + y, add) -BFLOAT_BINARY_OP(x - y, sub) -BFLOAT_BINARY_OP(x * y, mul) -BFLOAT_BINARY_OP(x / y, div) -BFLOAT_BINARY_OP(MIN(x, y), min) -BFLOAT_BINARY_OP(MAX(x, y), max) - -BFLOAT_BINARY_OP_OUT(eq, x == y) -BFLOAT_BINARY_OP_OUT(ne, x != y) -BFLOAT_BINARY_OP_OUT(le, x <= y) -BFLOAT_BINARY_OP_OUT(lt, x < y) -BFLOAT_BINARY_OP_OUT(ge, x >= y) -BFLOAT_BINARY_OP_OUT(gt, x > y) +#define init_boolean_binary(op_name, binary_op) \ + init_binary_k(op_name, binary_op, f32, float, bool) \ + init_binary_k(op_name, binary_op, f16, half, bool) \ + init_binary_k(op_name, binary_op, bf16, bfloat, bool) \ + init_binary_k(op_name, binary_op, u8, uint8_t, bool) \ + init_binary_k(op_name, binary_op, u32, uint32_t, bool) \ + init_binary_k(op_name, binary_op, i64, int64_t, bool) +#else +#define init_boolean_binary(op_name, binary_op) \ + init_binary_k(op_name, binary_op, f32, float, bool) \ + init_binary_k(op_name, binary_op, f16, half, bool) \ + init_binary_k(op_name, binary_op, u8, uint8_t, bool) \ + init_binary_k(op_name, binary_op, u32, uint32_t, bool) \ + init_binary_k(op_name, binary_op, i64, int64_t, bool) #endif + +// Define binary ops +#define define_binary_op(name, op) \ +struct name { \ + template <typename T> \ + METAL_FUNC T operator()(T x, T y) { \ + return static_cast<T>(op); \ + } \ +}; +#define define_binary_bool_op(name, op) \ +struct name { \ + template <typename T> \ + METAL_FUNC bool operator()(T x, T y) { \ + return op; \ + } \ +}; + +// Define binary ops +define_binary_op(badd, x + y); +define_binary_op(bsub, x - y); +define_binary_op(bmul, x * y); +define_binary_op(bdiv, x / y); +define_binary_op(bmin, MIN(x, y)); +define_binary_op(bmax,
MAX(x, y)); + +// Define binary ops that return a bool +define_binary_bool_op(beq, x == y); +define_binary_bool_op(bne, x != y); +define_binary_bool_op(ble, x <= y); +define_binary_bool_op(blt, x < y); +define_binary_bool_op(bge, x >= y); +define_binary_bool_op(bgt, x > y); + +// Initialize kernels +init_binary(add, badd); +init_binary(sub, bsub); +init_binary(mul, bmul); +init_binary(div, bdiv); +init_binary(min, bmin); +init_binary(max, bmax); + +init_boolean_binary(eq, beq); +init_boolean_binary(ne, bne); +init_boolean_binary(le, ble); +init_boolean_binary(lt, blt); +init_boolean_binary(ge, bge); +init_boolean_binary(gt, bgt); From eb1978c8c8e0b513ff577c64add5c4e7f5f3003d Mon Sep 17 00:00:00 2001 From: Theo Lee Date: Tue, 9 Dec 2025 20:40:52 +0900 Subject: [PATCH 4/8] add indexed_moe_forward for QMatMul wrapper --- candle-transformers/src/models/with_tracing.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/candle-transformers/src/models/with_tracing.rs b/candle-transformers/src/models/with_tracing.rs index f4706c7e95..77887ff297 100644 --- a/candle-transformers/src/models/with_tracing.rs +++ b/candle-transformers/src/models/with_tracing.rs @@ -122,6 +122,11 @@ impl QMatMul { let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); Ok(Self { inner, span }) } + + pub fn indexed_moe_forward(&self, xs: &Tensor, ids: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + self.inner.indexed_moe_forward(xs, ids) + } } impl Module for QMatMul { From d13b74bbbf7a66daf9343f72e761d887f7fad112 Mon Sep 17 00:00:00 2001 From: Theo Lee Date: Tue, 9 Dec 2025 21:25:54 +0900 Subject: [PATCH 5/8] make MlpWeights public --- candle-transformers/src/models/quantized_qwen3.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs index 5d9f414658..e0bca8e578 100644 --- a/candle-transformers/src/models/quantized_qwen3.rs +++ b/candle-transformers/src/models/quantized_qwen3.rs @@ -45,8 +45,8 @@ impl<R: Read + Seek> Gguf<R> { } #[derive(Debug, Clone)] -struct MlpWeights { - gate_proj: QMatMul, +pub(crate) struct MlpWeights { + gate_proj: QMatMul, // matmul specialized for quantized weights up_proj: QMatMul, down_proj: QMatMul, act_fn: Activation, } impl MlpWeights { - fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> { + pub(crate) fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> { let gate_proj = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?; let up_proj = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?; let down_proj = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?; @@ -336,7 +336,7 @@ impl ModelWeights { }; let embed_tensor = gg.tensor("token_embd.weight")?; - let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size); + let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size); // dequantize here: the embedding layer just looks up tokens, and indexing is slow if the weights stay quantized.
let rotary = Arc::new(RotaryEmbedding::new( dtype, From 1765487613f2256bf57a5003e007ba36355d9510 Mon Sep 17 00:00:00 2001 From: Theo Lee Date: Wed, 10 Dec 2025 00:57:07 +0900 Subject: [PATCH 6/8] remove unnecessary comments --- .../src/models/quantized_qwen3.rs | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/candle-transformers/src/models/quantized_qwen3.rs b/candle-transformers/src/models/quantized_qwen3.rs index e0bca8e578..e616ef03a5 100644 --- a/candle-transformers/src/models/quantized_qwen3.rs +++ b/candle-transformers/src/models/quantized_qwen3.rs @@ -14,39 +14,39 @@ use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module}; use std::io::{Read, Seek}; use std::sync::Arc; -struct Gguf<R: Read + Seek> { +pub(crate) struct Gguf<R: Read + Seek> { ct: gguf_file::Content, reader: R, device: Device, } impl<R: Read + Seek> Gguf<R> { - fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self { + pub(crate) fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self { Self { ct, reader, device } } - fn qmatmul(&mut self, name: &str) -> Result<QMatMul> { + pub(crate) fn qmatmul(&mut self, name: &str) -> Result<QMatMul> { let ws = self.ct.tensor(&mut self.reader, name, &self.device)?; QMatMul::from_weights(ws.into()) } - fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> { + pub(crate) fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> { let ws = self.ct.tensor(&mut self.reader, name, &self.device)?; RmsNorm::from_qtensor(ws, eps) } - fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> { + pub(crate) fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> { &self.ct.metadata } - fn tensor(&mut self, name: &str) -> Result<QTensor> { + pub(crate) fn tensor(&mut self, name: &str) -> Result<QTensor> { self.ct.tensor(&mut self.reader, name, &self.device) } } #[derive(Debug, Clone)] pub(crate) struct MlpWeights { - gate_proj: QMatMul, // matmul specialized for quantized weights + gate_proj: QMatMul, up_proj: QMatMul, down_proj: QMatMul, act_fn: Activation, @@ -81,13 +81,13 @@ impl Module for MlpWeights { } #[derive(Debug, Clone)] -struct RotaryEmbedding { +pub(crate) struct RotaryEmbedding { sin: Tensor, cos: Tensor, } impl RotaryEmbedding { - fn new( + pub(crate) fn new( dtype: DType, head_dim: usize, max_position_embeddings: usize, @@ -113,7 +113,7 @@ impl RotaryEmbedding { } /// Apply RoPE (q, k shape: B x H x L x D) - fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { + pub(crate) fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> { let (_, _, seq_len, _) = q.dims4()?; let cos = self.cos.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?; let sin = self.sin.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?; @@ -124,7 +124,7 @@ impl RotaryEmbedding { } #[derive(Debug, Clone)] -struct AttentionWeights { +pub(crate) struct AttentionWeights { q_proj: QMatMul, k_proj: QMatMul, v_proj: QMatMul, @@ -141,7 +141,7 @@ struct AttentionWeights { } impl AttentionWeights { - fn new<R: Read + Seek>( + pub(crate) fn new<R: Read + Seek>( gg: &mut Gguf<R>, num_heads: usize, num_kv_heads: usize, @@ -181,7 +181,12 @@ impl AttentionWeights { }) } - fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> { + pub(crate) fn forward( + &mut self, + x: &Tensor, + attn_mask: Option<&Tensor>, + offset: usize, + ) -> Result<Tensor> { let _enter = self.span_attn.enter(); let (b, l, _) = x.dims3()?; @@ -234,7 +239,7 @@ impl AttentionWeights { self.o_proj.forward(&reshaped_ctx) } - fn clear_kv_cache(&mut self) { + pub(crate) fn clear_kv_cache(&mut self) { self.kv_cache.reset(); } } @@ -336,7 +341,7 @@ impl ModelWeights { };
let embed_tensor = gg.tensor("token_embd.weight")?; - let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size); // dequantize here: the embedding layer just looks up tokens, and indexing is slow if the weights stay quantized. + let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size); let rotary = Arc::new(RotaryEmbedding::new( dtype, From 3095e2ab1cf65279b50077bdb0799ef7545a879c Mon Sep 17 00:00:00 2001 From: Theo Lee Date: Wed, 10 Dec 2025 00:57:52 +0900 Subject: [PATCH 7/8] add quantized_qwen3_moe model --- .../src/models/quantized_qwen3_moe.rs | 332 ++++++++++++++++++ 1 file changed, 332 insertions(+) diff --git a/candle-transformers/src/models/quantized_qwen3_moe.rs b/candle-transformers/src/models/quantized_qwen3_moe.rs index e69de29bb2..47a83cc7e5 100644 --- a/candle-transformers/src/models/quantized_qwen3_moe.rs +++ b/candle-transformers/src/models/quantized_qwen3_moe.rs @@ -0,0 +1,332 @@ +//! Quantized Qwen3 MoE implementation. +//! +//! +//! References: +//! - [Qwen3 MoE Models](https://huggingface.co/docs/transformers/model_doc/qwen3_moe) (architecture based on the official implementation) +//! +use super::with_tracing::QMatMul; +use crate::models::quantized_qwen3::{AttentionWeights, Gguf, MlpWeights, RotaryEmbedding}; +use crate::quantized_nn::RmsNorm; +use candle::quantized::gguf_file; +use candle::{DType, Device, Result, Tensor, D}; +use candle_nn::{Activation, Embedding, Module}; +use std::io::{Read, Seek}; +use std::sync::Arc; + +// Sparse MoE block stores the concatenated weights of all experts, no split! +#[derive(Debug, Clone)] +struct SparseMoeBlockWeights { + gate: QMatMul, + experts_gate: QMatMul, + experts_up: QMatMul, + experts_down: QMatMul, + act: candle_nn::Activation, + norm_topk_prob: bool, + num_experts_per_tok: usize, + span: tracing::Span, +} + +impl SparseMoeBlockWeights { + fn new<R: Read + Seek>( + gg: &mut Gguf<R>, + prefix: &str, + act: Activation, + norm_topk_prob: bool, + num_experts_per_tok: usize, + ) -> Result<Self> { + let gate = gg.qmatmul(&format!("{prefix}.ffn_gate_inp.weight"))?; + let experts_gate = gg.qmatmul(&format!("{prefix}.ffn_gate_exps.weight"))?; + let experts_up = gg.qmatmul(&format!("{prefix}.ffn_up_exps.weight"))?; + let experts_down = gg.qmatmul(&format!("{prefix}.ffn_down_exps.weight"))?; + let span = tracing::span!(tracing::Level::TRACE, "MoEBlock"); + Ok(Self { + gate, + experts_gate, + experts_up, + experts_down, + act, + norm_topk_prob, + num_experts_per_tok, + span, + }) + } +} + +impl Module for SparseMoeBlockWeights { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b_size, seq_len, hidden_dim) = xs.dims3()?; + //[b_size * seq_len, hidden_dim] + let xs = xs.reshape(((), hidden_dim))?; + let original_dtype = xs.dtype(); + let (num_tokens, hidden_dim) = xs.dims2()?; + + let router_logits = self.gate.forward(&xs.to_dtype(DType::F32)?)?; + let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?; + + // Extract topk experts per token + let experts_per_tok = routing_weights + .arg_sort_last_dim(false)? + .narrow(D::Minus1, 0, self.num_experts_per_tok)?
+ .contiguous()?; + let mut routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?; + + if self.norm_topk_prob { + routing_weights = + routing_weights.broadcast_div(&routing_weights.sum_keepdim(D::Minus1)?)?; + } + + let ys = { + let xs = xs.reshape((num_tokens, 1, hidden_dim))?; + let gate = self + .experts_gate + .indexed_moe_forward(&xs, &experts_per_tok)?; + let up = self.experts_up.indexed_moe_forward(&xs, &experts_per_tok)?; + self.experts_down + .indexed_moe_forward(&(up * gate.apply(&self.act)?)?, &experts_per_tok)? + }; + + ys.broadcast_mul(&routing_weights.unsqueeze(D::Minus1)?)? + .sum(D::Minus2)? + .reshape((b_size, seq_len, hidden_dim))? + .to_dtype(original_dtype) + } +} + +#[derive(Debug, Clone)] +enum MoeOrMlpWeights { + Moe(SparseMoeBlockWeights), + Mlp(MlpWeights), +} + +impl Module for MoeOrMlpWeights { + fn forward(&self, xs: &Tensor) -> Result<Tensor> { + match self { + Self::Moe(m) => m.forward(xs), + Self::Mlp(m) => m.forward(xs), + } + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: AttentionWeights, + feed_forward: MoeOrMlpWeights, + ln1: RmsNorm, + ln2: RmsNorm, +} + +impl DecoderLayer { + fn new<R: Read + Seek>( + gg: &mut Gguf<R>, + num_attention_heads: usize, + num_key_value_heads: usize, + head_dim: usize, + rms_norm_eps: f64, + rotary: Arc<RotaryEmbedding>, + layer_idx: usize, + num_experts: usize, + decoder_sparse_step: usize, + norm_topk_prob: bool, + num_experts_per_tok: usize, + ) -> Result<Self> { + let prefix = format!("blk.{layer_idx}"); + + let ln1 = gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?; + let ln2 = gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?; + let self_attn = AttentionWeights::new( + gg, + num_attention_heads, + num_key_value_heads, + head_dim, + rms_norm_eps, + rotary, + &prefix, + )?; + let feed_forward = if num_experts > 0 && (layer_idx + 1).is_multiple_of(decoder_sparse_step) + { + MoeOrMlpWeights::Moe(SparseMoeBlockWeights::new( + gg, + &prefix, + candle_nn::Activation::Silu, + norm_topk_prob, + num_experts_per_tok, + )?) + } else { + MoeOrMlpWeights::Mlp(MlpWeights::new(gg, &prefix)?) + }; + Ok(Self { + self_attn, + feed_forward, + ln1, + ln2, + }) + } + + fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> { + let h = self.ln1.forward(x)?; + let h = self.self_attn.forward(&h, mask, offset)?; + let x = (x + h)?; + let h2 = self.ln2.forward(&x)?; + let h2 = h2.apply(&self.feed_forward)?; + x + h2 + } + + fn clear_kv_cache(&mut self) { + self.self_attn.clear_kv_cache(); + } +} + +#[derive(Debug, Clone)] +pub struct ModelWeights { + embed_tokens: Embedding, + layers: Vec<DecoderLayer>, + norm: RmsNorm, + lm_head: QMatMul, + device: Device, + dtype: DType, + span: tracing::Span, + span_output: tracing::Span, +} + +impl ModelWeights { + pub fn from_gguf<R: Read + Seek>( + ct: gguf_file::Content, + reader: &mut R, + device: &Device, + num_experts_per_tok: Option<usize>, + ) -> Result<Self> { + let mut gg = Gguf::new(ct, reader, device.clone()); + let md_get = |s: &str| match gg.metadata().get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + let num_attention_heads = md_get("qwen3moe.attention.head_count")?.to_u32()? as usize; + let num_kv_heads = md_get("qwen3moe.attention.head_count_kv")?.to_u32()? as usize; + let head_dim = md_get("qwen3moe.attention.key_length")?.to_u32()? as usize; + let num_layers = md_get("qwen3moe.block_count")?.to_u32()? as usize; + let hidden_size = md_get("qwen3moe.embedding_length")?.to_u32()?
as usize; + let max_position_embeddings = md_get("qwen3moe.context_length")?.to_u32()? as usize; + let rms_norm_eps = md_get("qwen3moe.attention.layer_norm_rms_epsilon")?.to_f32()? as f64; + let rope_freq_base = md_get("qwen3moe.rope.freq_base")?.to_f32()? as f64; + let decoder_sparse_step = 1; + let moe_intermediate_size = + md_get("qwen3moe.expert_feed_forward_length")?.to_u32()? as usize; + let num_experts_per_tok = if let Some(n) = num_experts_per_tok { + n + } else { + md_get("qwen3moe.expert_used_count")?.to_u32()? as usize + }; + let num_experts = md_get("qwen3moe.expert_count")?.to_u32()? as usize; + let norm_topk_prob = false; + + let dtype = match gg.metadata().get("general.dtype") { + Some(v) => match v.to_u32() { + Ok(0) => DType::F32, + Ok(1) => DType::F16, + _ => DType::F16, + }, + None => DType::F16, + }; + + let embed_tensor = gg.tensor("token_embd.weight")?; + let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size); + + let rotary = Arc::new(RotaryEmbedding::new( + dtype, + head_dim, + max_position_embeddings, + rope_freq_base, + device, + )?); + + let mut layers = Vec::with_capacity(num_layers); + for i in 0..num_layers { + layers.push(DecoderLayer::new( + &mut gg, + num_attention_heads, + num_kv_heads, + head_dim, + rms_norm_eps, + rotary.clone(), + i, + num_experts, + decoder_sparse_step, + norm_topk_prob, + num_experts_per_tok, + )?); + } + + let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?; + // Load output projection tensor, falling back to tied embeddings like gemma3 + let lm_head_tensor = match gg.tensor("output.weight") { + Ok(tensor) => tensor, + Err(_) => gg.tensor("token_embd.weight")?, + }; + let lm_head = QMatMul::from_weights(lm_head_tensor.into())?; + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + embed_tokens, + layers, + norm, + lm_head, + device: device.clone(), + dtype, + span, + span_output, + }) + } + + fn causal_mask( + &self, + b: usize, + tgt: usize, + offset: usize, + sw: Option<usize>, + ) -> Result<Tensor> { + let minf = f32::NEG_INFINITY; + let mask: Vec<_> = (0..tgt) + .flat_map(|i| { + (0..(tgt + offset)).map(move |j| { + let past_ok = j <= i + offset; + let sw_ok = match sw { + Some(w) => (i + offset) as i64 - j as i64 <= w as i64, + None => true, + }; + if past_ok && sw_ok { + 0. + } else { + minf + } + }) + }) + .collect(); + Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype) + } + + pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> { + let _enter = self.span.enter(); + let (b, l) = input.dims2()?; + let mut h = self.embed_tokens.forward(input)?; + let causal_mask = if l == 1 { + None + } else { + Some(self.causal_mask(b, l, offset, None)?)
+ }; + for layer in &mut self.layers { + h = layer.forward(&h, causal_mask.as_ref(), offset)?; + } + let h = self.norm.forward(&h)?; + let _enter = self.span_output.enter(); + let last_hidden = h.narrow(1, l - 1, 1)?; + self.lm_head.forward(&last_hidden)?.squeeze(1) + } + + pub fn clear_kv_cache(&mut self) { + for layer in &mut self.layers { + layer.clear_kv_cache(); + } + } +} From a531453a4aa01d76036cd67c6a8a8863f6139f3c Mon Sep 17 00:00:00 2001 From: Theo Lee Date: Wed, 10 Dec 2025 00:58:28 +0900 Subject: [PATCH 8/8] add quantized-qwen3-moe example --- .../examples/quantized-qwen3-moe/README.md | 22 ++ .../examples/quantized-qwen3-moe/main.rs | 306 ++++++++++++++++++ 2 files changed, 328 insertions(+) create mode 100644 candle-examples/examples/quantized-qwen3-moe/README.md create mode 100644 candle-examples/examples/quantized-qwen3-moe/main.rs diff --git a/candle-examples/examples/quantized-qwen3-moe/README.md b/candle-examples/examples/quantized-qwen3-moe/README.md new file mode 100644 index 0000000000..b0f25727c0 --- /dev/null +++ b/candle-examples/examples/quantized-qwen3-moe/README.md @@ -0,0 +1,22 @@ +# candle-quantized-qwen3-moe + +[Qwen3](https://qwenlm.github.io/blog/qwen3/) is an upgraded version of Qwen2.5, released by Alibaba Cloud. +This example runs GGUF-quantized versions of the Qwen3 MoE models. + +## Running the example + +```bash +cargo run --example quantized-qwen3-moe --release -- --prompt "Write a function to count prime numbers up to N." +``` + +The 30B model is used by default; the 235B model can be selected with the `--which` argument. + +```bash +cargo run --example quantized-qwen3-moe --release -- --which 235b --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?" +``` + +To run on CUDA (GPU): + +```bash +cargo run --example quantized-qwen3-moe --release --features cuda -- --prompt "Write a function to count prime numbers up to N." +``` diff --git a/candle-examples/examples/quantized-qwen3-moe/main.rs b/candle-examples/examples/quantized-qwen3-moe/main.rs new file mode 100644 index 0000000000..dd4bcdb0d5 --- /dev/null +++ b/candle-examples/examples/quantized-qwen3-moe/main.rs @@ -0,0 +1,306 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; +use std::io::Write; +use tokenizers::Tokenizer; + +use candle::quantized::gguf_file; +use candle::Tensor; +use candle_transformers::generation::{LogitsProcessor, Sampling}; + +use candle_examples::token_output_stream::TokenOutputStream; +use candle_transformers::models::quantized_qwen3_moe::ModelWeights as Qwen3_MoE; + +const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number."; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + #[value(name = "30b")] + W3_MoE_30b, + #[value(name = "235b")] + W3_MoE_235b, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp + #[arg(long)] + model: Option<String>, + + /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way + /// and 'chat' for an interactive mode where history of previous prompts and generated tokens + /// is preserved. + #[arg(long)] + prompt: Option<String>, + + /// The length of the sample to generate (in tokens). + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The tokenizer config in json format.
+ #[arg(long)] + tokenizer: Option<String>, + + /// The temperature used to generate samples, use 0 for greedy sampling. + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option<f64>, + + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option<usize>, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Process prompt elements separately. + #[arg(long)] + split_prompt: bool, + + /// Run on CPU rather than GPU even if a GPU is available. + #[arg(long)] + cpu: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The model size to use. + #[arg(long, default_value = "30b")] + which: Which, +} + +impl Args { + fn tokenizer(&self) -> anyhow::Result<Tokenizer> { + let tokenizer_path = match &self.tokenizer { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let repo = match self.which { + Which::W3_MoE_30b => "Qwen/Qwen3-30B-A3B", + Which::W3_MoE_235b => "Qwen/Qwen3-235B-A22B", + }; + let api = api.model(repo.to_string()); + api.get("tokenizer.json")? + } + }; + Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg) + } + + fn model(&self) -> anyhow::Result<std::path::PathBuf> { + let model_path = match &self.model { + Some(config) => std::path::PathBuf::from(config), + None => { + let (repo, filename, revision) = match self.which { + Which::W3_MoE_30b => ( + "unsloth/Qwen3-30B-A3B-GGUF", + "Qwen3-30B-A3B-Q4_K_M.gguf", + "main", + ), + Which::W3_MoE_235b => ( + "unsloth/Qwen3-235B-A22B-GGUF", + "Qwen3-235B-A22B-Q4_K_M-00001-of-00003.gguf", + "main", + ), + }; + let api = hf_hub::api::sync::Api::new()?; + api.repo(hf_hub::Repo::with_revision( + repo.to_string(), + hf_hub::RepoType::Model, + revision.to_string(), + )) + .get(filename)?
+ } + }; + Ok(model_path) + } +} + +fn format_size(size_in_bytes: usize) -> String { + if size_in_bytes < 1_000 { + format!("{size_in_bytes}B") + } else if size_in_bytes < 1_000_000 { + format!("{:.2}KB", size_in_bytes as f64 / 1e3) + } else if size_in_bytes < 1_000_000_000 { + format!("{:.2}MB", size_in_bytes as f64 / 1e6) + } else { + format!("{:.2}GB", size_in_bytes as f64 / 1e9) + } +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let model_path = args.model()?; + let mut file = std::fs::File::open(&model_path)?; + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + + let mut model = { + let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensor_infos.iter() { + let elem_count = tensor.shape.elem_count(); + total_size_in_bytes += + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensor_infos.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + Qwen3_MoE::from_gguf(model, &mut file, &device, None)? + }; + println!("model built"); + + let tokenizer = args.tokenizer()?; + let mut tos = TokenOutputStream::new(tokenizer); + let prompt_str = args + .prompt + .clone() + .unwrap_or_else(|| DEFAULT_PROMPT.to_string()); + + let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n"); + print!("formatted prompt: {}", &prompt_str); + + let tokens = tos + .tokenizer() + .encode(prompt_str, true) + .map_err(anyhow::Error::msg)?; + + let tokens = tokens.get_ids(); + + let to_sample = args.sample_len.saturating_sub(1); + + let mut all_tokens = vec![]; + + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + + let start_prompt_processing = std::time::Instant::now(); + + let mut next_token = if !args.split_prompt { + let input = Tensor::new(tokens, &device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? + } else { + let mut next_token = 0; + for (pos, token) in tokens.iter().enumerate() { + let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?; + next_token = logits_processor.sample(&logits)? 
+ } + next_token + }; + + let prompt_dt = start_prompt_processing.elapsed(); + + all_tokens.push(next_token); + + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + + let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap(); + + let start_post_prompt = std::time::Instant::now(); + + let mut sampled = 0; + for index in 0..to_sample { + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, tokens.len() + index)?; + let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + sampled += 1; + if next_token == eos_token { + break; + }; + } + + if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? { + print!("{rest}"); + } + + std::io::stdout().flush()?; + let dt = start_post_prompt.elapsed(); + println!( + "\n\n{:4} prompt tokens processed: {:.2} token/s", + tokens.len(), + tokens.len() as f64 / prompt_dt.as_secs_f64(), + ); + println!( + "{sampled:4} tokens generated: {:.2} token/s", + sampled as f64 / dt.as_secs_f64(), + ); + Ok(()) +}
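For readers who want to experiment with the routing logic from [PATCH 7/8] in isolation, the sketch below replays the expert-selection math of `SparseMoeBlockWeights::forward` on a hand-written router output. It is not part of the patch series: the tensor values, the `main` wrapper, and the chosen `num_experts_per_tok`/`norm_topk_prob` settings are invented for illustration, but every candle call it uses (`softmax_last_dim`, `arg_sort_last_dim`, `narrow`, `gather`, `sum_keepdim`, `broadcast_div`) already appears in the diff above, and the `candle`/`candle_nn` imports follow the same naming as the example crate.

```rust
use candle::{Device, Result, Tensor, D};

fn main() -> Result<()> {
    let device = Device::Cpu;

    // Hypothetical router logits for 2 tokens over 4 experts (values are made up).
    let router_logits = Tensor::new(&[[0.1f32, 2.0, 0.3, 1.5], [1.2f32, 0.2, 0.4, 3.0]], &device)?;
    let num_experts_per_tok = 2;
    let norm_topk_prob = true;

    // Softmax over the expert dimension, as in SparseMoeBlockWeights::forward.
    let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;

    // Sort experts by weight (descending) and keep the top-k expert indices per token.
    let experts_per_tok = routing_weights
        .arg_sort_last_dim(false)?
        .narrow(D::Minus1, 0, num_experts_per_tok)?
        .contiguous()?;

    // Gather the probabilities of the selected experts.
    let mut topk_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;

    // Optionally renormalize so the selected weights sum to 1 for each token.
    if norm_topk_prob {
        topk_weights = topk_weights.broadcast_div(&topk_weights.sum_keepdim(D::Minus1)?)?;
    }

    println!("selected experts per token:\n{experts_per_tok}");
    println!("routing weights:\n{topk_weights}");
    Ok(())
}
```

Note that `ModelWeights::from_gguf` in the patch hardcodes `norm_topk_prob = false`, so in the model itself the gathered weights are used as-is; the flag is enabled here only to show the renormalization step.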