From bbd4f34963b866e5f29d28e266dbe6244d48e5b1 Mon Sep 17 00:00:00 2001 From: dengyingxu Date: Tue, 2 Dec 2025 23:01:39 +0800 Subject: [PATCH 1/7] feat: add wrappers for ATB and ACLNN fused operators. --- xllm/core/kernels/npu/CMakeLists.txt | 23 +- .../kernels/npu/{rms_norm.h => active.cpp} | 26 +- xllm/core/kernels/npu/attention.cpp | 65 +++ .../npu/custom_functions_npu/atb_common.cpp | 175 +++++++ .../npu/custom_functions_npu/atb_common.h | 494 ++++++++++++++++++ .../operation_cache_compute.cpp | 188 +++++++ .../operation_cache_compute.h | 148 ++++++ .../custom_functions_npu/operation_create.h | 127 +++++ .../npu/custom_functions_npu/utils.cpp | 81 +++ .../kernels/npu/custom_functions_npu/utils.h | 103 ++++ xllm/core/kernels/npu/fused_layernorm.cpp | 36 ++ .../core/kernels/npu/{linear.h => matmul.cpp} | 31 +- xllm/core/kernels/npu/npu_ops_api.h | 62 +++ xllm/core/kernels/npu/ops_npu/npu_ops.h | 56 ++ .../npu/ops_npu/paged_attention_atb.cpp | 62 +++ .../npu/ops_npu/reshape_and_cach_atb.cpp | 60 +++ .../npu/ops_npu/self_attention_atb.cpp | 73 +++ xllm/core/kernels/npu/rope.cpp | 42 ++ xllm/core/kernels/npu/rope.h | 30 -- xllm/core/kernels/npu/split.h | 35 -- xllm/core/kernels/ops_api.cpp | 46 ++ xllm/core/kernels/ops_api.h | 4 + xllm/core/kernels/param.h | 3 + xllm/core/layers/common/rotary_embedding.cpp | 2 +- 24 files changed, 1873 insertions(+), 99 deletions(-) rename xllm/core/kernels/npu/{rms_norm.h => active.cpp} (60%) create mode 100644 xllm/core/kernels/npu/attention.cpp create mode 100644 xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp create mode 100644 xllm/core/kernels/npu/custom_functions_npu/atb_common.h create mode 100644 xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.cpp create mode 100644 xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h create mode 100644 xllm/core/kernels/npu/custom_functions_npu/operation_create.h create mode 100644 xllm/core/kernels/npu/custom_functions_npu/utils.cpp create mode 100644 xllm/core/kernels/npu/custom_functions_npu/utils.h create mode 100644 xllm/core/kernels/npu/fused_layernorm.cpp rename xllm/core/kernels/npu/{linear.h => matmul.cpp} (61%) create mode 100644 xllm/core/kernels/npu/npu_ops_api.h create mode 100644 xllm/core/kernels/npu/ops_npu/npu_ops.h create mode 100644 xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp create mode 100644 xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp create mode 100644 xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp create mode 100644 xllm/core/kernels/npu/rope.cpp delete mode 100644 xllm/core/kernels/npu/rope.h delete mode 100644 xllm/core/kernels/npu/split.h diff --git a/xllm/core/kernels/npu/CMakeLists.txt b/xllm/core/kernels/npu/CMakeLists.txt index 412f7f188..380df2e87 100644 --- a/xllm/core/kernels/npu/CMakeLists.txt +++ b/xllm/core/kernels/npu/CMakeLists.txt @@ -2,13 +2,28 @@ include(cc_library) add_subdirectory(xllm_ops) +file(GLOB_RECURSE XLLM_CORE_KERNELS_NPU_HEADER + "${CMAKE_CURRENT_LIST_DIR}/custom_functions_npu/*.h" + "${CMAKE_CURRENT_LIST_DIR}/ops_npu/*.h" + "${CMAKE_CURRENT_LIST_DIR}/*.h" +) + +file(GLOB_RECURSE XLLM_CORE_KERNELS_NPU_SRCS + "${CMAKE_CURRENT_LIST_DIR}/custom_functions_npu/*.cpp" + "${CMAKE_CURRENT_LIST_DIR}/ops_npu/*.cpp" + "${CMAKE_CURRENT_LIST_DIR}/*.cpp" +) + cc_library( NAME npu_kernels HDRS - linear.h - split.h - rope.h + ${XLLM_CORE_KERNELS_NPU_HEADER} + SRCS + ${XLLM_CORE_KERNELS_NPU_SRCS} DEPS - # spdlog::spdlog + :model_context + glog::glog + torch + torch_npu ) diff --git 
a/xllm/core/kernels/npu/rms_norm.h b/xllm/core/kernels/npu/active.cpp
similarity index 60%
rename from xllm/core/kernels/npu/rms_norm.h
rename to xllm/core/kernels/npu/active.cpp
index ed7f8d047..e3e66be54 100644
--- a/xllm/core/kernels/npu/rms_norm.h
+++ b/xllm/core/kernels/npu/active.cpp
@@ -13,20 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#pragma once
-#include "impl/npu_rms_norm_impl.h"
+#include <torch_npu/csrc/aten/CustomFunctions.h>
-namespace xllm {
-namespace kernel {
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
-class RmsNorm : public torch::nn::ModuleHolder<NpuRmsNormImpl> {
- public:
-  using torch::nn::ModuleHolder<NpuRmsNormImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = NpuRmsNormImpl;
+namespace xllm::kernel::npu {
-  RmsNorm(const ModelContext& context)
-      : ModuleHolder(std::make_shared<Impl>(context)) {}
-};
-
-}  // namespace kernel
-}  // namespace xllm
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode) {
+  if (act_mode != "silu" && act_mode != "swiglu") {
+    throw std::runtime_error(
+        "Only silu and swiglu activations are supported in NPU active");
+  }
+  return at_npu::native::custom_ops::npu_swiglu(input);
+}
+}  // namespace xllm::kernel::npu
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/attention.cpp b/xllm/core/kernels/npu/attention.cpp
new file mode 100644
index 000000000..0b4d80347
--- /dev/null
+++ b/xllm/core/kernels/npu/attention.cpp
@@ -0,0 +1,65 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
+namespace xllm::kernel::npu {
+
+void reshape_paged_cache(torch::Tensor& key,
+                         std::optional<torch::Tensor>& value,
+                         torch::Tensor& k_cache,
+                         std::optional<torch::Tensor>& v_cache,
+                         const torch::Tensor& slot_mapping) {
+  atb::_npu_reshape_and_cache(
+      key, value.value(), k_cache, v_cache.value(), slot_mapping);
+}
+
+void batch_prefill(const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   const torch::Tensor& mask,
+                   const torch::Tensor& seq_len,
+                   float scale,
+                   torch::Tensor& output) {
+  auto num_heads = query.size(-2);
+  auto num_kv_heads = key.size(-2);
+  atb::_npu_flash_attention(
+      query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output);
+}
+
+void batch_decode(const torch::Tensor& query,
+                  const torch::Tensor& k_cache,
+                  const torch::Tensor& v_cache,
+                  float scale,
+                  const torch::Tensor& block_table,
+                  const torch::Tensor& seq_lens,
+                  torch::Tensor& output) {
+  auto head_size = query.size(-1);
+  auto num_heads = query.size(-2);
+  auto num_kv_heads = k_cache.size(-2);
+  auto q = query.view({-1, num_heads, head_size});
+  auto o = output.view({-1, num_heads, head_size});
+  atb::_npu_paged_attention(q,
+                            k_cache,
+                            v_cache,
+                            num_kv_heads,
+                            num_heads,
+                            scale,
+                            block_table,
+                            seq_lens,
+                            o);
+}
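+
+// Usage sketch (illustration only; shapes are assumed from the wrappers
+// above, not from upstream tests): query/key/value are contiguous NPU
+// tensors of shape [num_tokens, num_heads, head_size] and seq_len holds
+// per-request token counts:
+//   auto out = torch::empty_like(query);
+//   float scale = 1.0f / std::sqrt(static_cast<float>(query.size(-1)));
+//   batch_prefill(query, key, value, mask, seq_len, scale, out);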
+
+}  // namespace xllm::kernel::npu
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp b/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp
new file mode 100644
index 000000000..70cf6b6c9
--- /dev/null
+++ b/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp
@@ -0,0 +1,175 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "atb_common.h"
+
+namespace atb {
+atb::Tensor at_tensor_to_atb_tensor(const at::Tensor at_tensor) {
+  static std::map<at::ScalarType, aclDataType> dtype_map = {
+      {at::ScalarType::Bool, ACL_BOOL},
+      {at::ScalarType::Byte, ACL_UINT8},
+      {at::ScalarType::Char, ACL_INT8},
+      {at::ScalarType::Half, ACL_FLOAT16},
+      {at::ScalarType::Float, ACL_FLOAT},
+      {at::ScalarType::Int, ACL_INT32},
+      {at::ScalarType::Long, ACL_INT64},
+      {at::ScalarType::BFloat16, ACL_BF16},
+      {at::ScalarType::Double, ACL_DOUBLE},
+      {at::ScalarType::Short, ACL_INT16},
+      {at::ScalarType::ComplexHalf, ACL_COMPLEX32},
+      {at::ScalarType::ComplexFloat, ACL_COMPLEX64},
+      {at::ScalarType::ComplexDouble, ACL_COMPLEX128},
+  };
+
+  TORCH_CHECK(at_tensor.is_contiguous(), "at_tensor is not contiguous");
+  atb::Tensor tensor;
+  tensor.desc.format = atb::utils::get_format_for_atb(at_tensor);
+  if (at_tensor.device().type() == at::kCPU) {
+    tensor.hostData = at_tensor.data_ptr();
+  } else {
+    tensor.deviceData = at_tensor.data_ptr();
+  }
+
+  tensor.desc.shape.dimNum = at_tensor.sizes().size();
+  for (uint64_t i = 0; i < at_tensor.sizes().size(); i++) {
+    tensor.desc.shape.dims[i] = at_tensor.sizes()[i];
+  }
+
+  auto dtype_iterator = dtype_map.find(at_tensor.scalar_type());
+  TORCH_CHECK(dtype_iterator != dtype_map.end(),
+              "unsupported dtype: ",
+              at_tensor.scalar_type());
+  tensor.desc.dtype = dtype_iterator->second;
+
+  tensor.dataSize = atb::Utils::GetTensorSize(tensor);
+
+  return tensor;
+}
+
+void run_atb_cmd_v1(atb::Operation* op,
+                    const ParamSetter& paramsetter,
+                    const std::string& name) {
+  aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
+  auto context_ptr = atb::utils::get_context(stream);
+  atb::VariantPack variant_pack = paramsetter.variant_pack_;
+  uint64_t workspace_size = operation_setup(variant_pack, op, context_ptr);
+  at::Tensor workspace_tensor;
+  void* workspace_ptr = nullptr;
+  if (workspace_size != 0) {
+    at::TensorOptions options = at::TensorOptions(c10::DeviceType::PrivateUse1);
+    workspace_tensor = at::empty({workspace_size}, options.dtype(at::kByte));
+    workspace_ptr = const_cast<void*>(workspace_tensor.storage().data());
+  }
+  const c10::SmallVector<at::Tensor, N>& cpu_tensors =
+      paramsetter.tensor_maintainer_.cpu_tensors;
+  auto acl_call = [variant_pack,
+                   workspace_ptr,
+                   workspace_size,
+                   context_ptr,
+                   op,
+                   cpu_tensors]() -> int {
+    auto st = op->Execute(
+        variant_pack, (uint8_t*)workspace_ptr, workspace_size, context_ptr);
+    DestroyOperation(op);
+    return st;
+  };
+  at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
+}
+
+void run_atb_cmd_v2(atb::Operation* op,
+                    const ParamSetter& paramsetter,
+                    const std::string& name) {
+  aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false);
+  atb::VariantPack variant_pack = paramsetter.variant_pack_;
+  const c10::SmallVector<at::Tensor, N>& cpu_tensors =
+      paramsetter.tensor_maintainer_.cpu_tensors;
+  auto acl_call = [op, variant_pack, stream, cpu_tensors]() -> int {
+    auto context_ptr = atb::utils::get_context(stream);
+    uint64_t workspace_size = operation_setup(variant_pack, op, context_ptr);
+    at::Tensor workspace_tensor;
+    void* workspace_ptr = nullptr;
+    if (workspace_size != 0) {
+      workspace_tensor =
+          at_npu::native::allocate_workspace(workspace_size, stream);
+      workspace_ptr = const_cast<void*>(workspace_tensor.storage().data());
+    }
+    auto st = op->Execute(
+        variant_pack, (uint8_t*)workspace_ptr, workspace_size, context_ptr);
+    return st;
+  };
+  at_npu::native::OpCommand::RunOpApiV2(name, acl_call);
+}
+
+void run_atb_cmd(atb::Operation* op,
+                 const ParamSetter& paramsetter,
+                 const std::string& name) {
+  const auto is_capturing =
+      static_cast<bool>(c10_npu::currentStreamCaptureStatusMayInitCtx());
+  if (is_capturing) {
+    run_atb_cmd_v1(op, paramsetter, name);
+  } else {
+    run_atb_cmd_v2(op, paramsetter, name);
+  }
+}
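+
+// Usage sketch (illustration only; tensor names are hypothetical): inputs
+// are pushed in the order the ATB operation expects, outputs last, then the
+// pack is executed through the capture-aware dispatcher above:
+//   ParamSetter setter;
+//   setter.Input(query).Input(key_cache).Output(out);
+//   run_atb_cmd(op, setter, "SomeOperation");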
+
+ParamSetter& ParamSetter::Input(const at::Tensor& tensor,
+                                const bool& format_trans) {
+  if (!tensor.defined()) {
+    variant_pack_.inTensors.push_back(atb::Tensor());
+    return *this;
+  }
+  at::Tensor new_tensor = tensor.contiguous();
+  if (format_trans) {
+    new_tensor = atb::utils::format_trans(new_tensor);
+  }
+  atb::Tensor atb_tensor;
+  if (new_tensor.device().type() == at::kCPU) {
+    auto tensor_clone = new_tensor.clone();
+    atb_tensor = at_tensor_to_atb_tensor(tensor_clone);
+    tensor_maintainer_.cpu_tensors.emplace_back(std::move(tensor_clone));
+  } else {
+    atb_tensor = at_tensor_to_atb_tensor(new_tensor);
+    tensor_maintainer_.contiguous_tensors.emplace_back(std::move(new_tensor));
+  }
+  variant_pack_.inTensors.push_back(atb_tensor);
+  return *this;
+}
+
+ParamSetter& ParamSetter::Input(const c10::optional<at::Tensor>& tensor,
+                                const bool& format_trans) {
+  if (!tensor.has_value()) {
+    variant_pack_.inTensors.push_back(atb::Tensor());
+    return *this;
+  }
+  return Input(tensor.value(), format_trans);
+}
+
+ParamSetter& ParamSetter::Output(at::Tensor& output) {
+  auto atb_tensor = at_tensor_to_atb_tensor(output);
+  variant_pack_.outTensors.push_back(atb_tensor);
+  return *this;
+}
+
+uint64_t operation_setup(atb::VariantPack variant_pack,
+                         atb::Operation* operation,
+                         atb::Context* context_ptr) {
+  uint64_t workspace_size = 0;
+  atb::Status status =
+      operation->Setup(variant_pack, workspace_size, context_ptr);
+  TORCH_CHECK(status == 0, operation->GetName(), " setup failed!");
+  return workspace_size;
+}
+
+}  // namespace atb
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/custom_functions_npu/atb_common.h b/xllm/core/kernels/npu/custom_functions_npu/atb_common.h
new file mode 100644
index 000000000..198c425a2
--- /dev/null
+++ b/xllm/core/kernels/npu/custom_functions_npu/atb_common.h
@@ -0,0 +1,494 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLLM_CORE_KERNELS_NPU_ATB_COMMON_H
+#define XLLM_CORE_KERNELS_NPU_ATB_COMMON_H
+#include <dlfcn.h>
+#include <acl/acl.h>
+#include <torch/torch.h>
+#include <torch_npu/csrc/core/npu/NPUGraphsUtils.h>
+#include <torch_npu/csrc/core/npu/NPUStream.h>
+#include <torch_npu/csrc/framework/OpCommand.h>
+
+#include "./operation_create.h"
+#include "atb/atb_infer.h"
+#include "utils.h"
+
+namespace atb {
+
+using aclTensor = struct aclTensor;
+constexpr int64_t MAX_DIM_NUM = 5;
+const int N = 32;
+
+using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims,
+                                        uint64_t view_dims_num,
+                                        aclDataType data_type,
+                                        const int64_t* stride,
+                                        int64_t offset,
+                                        aclFormat format,
+                                        const int64_t* storage_dims,
+                                        uint64_t storage_dims_num,
+                                        void* tensor_data);
+using _aclDestroyTensor = int (*)(const aclTensor*);
+
+using AtbApiFunc = int (*)(void*, uint64_t, atb::Operation*, atb::Context*);
+
+#define GET_OP_API_FUNC(api_name) \
+  reinterpret_cast<_##api_name>(get_api_func_addr(#api_name))
+
+inline const char* get_atb_api_lib_name(void) { return "libatb.so"; }
+
+inline const char* get_op_api_lib_name(void) { return "libopapi.so"; }
+
+inline void* get_api_lib_handler(const char* lib_name) {
+  auto handler = dlopen(lib_name, RTLD_LAZY);
+  if (handler == nullptr) {
+    ASCEND_LOGW("dlopen %s failed, error:%s.", lib_name, dlerror());
+  }
+  return handler;
+}
+
+inline void* get_api_func_addr_in_lib(void* handler,
+                                      const char* lib_name,
+                                      const char* api_name) {
+  auto func_addr = dlsym(handler, api_name);
+  if (func_addr == nullptr) {
+    ASCEND_LOGW(
+        "dlsym %s from %s failed, error:%s.", api_name, lib_name, dlerror());
+  }
+  return func_addr;
+}
+
+inline void* get_api_func_addr(const char* api_name) {
+  static auto atb_api_handler = get_api_lib_handler(get_atb_api_lib_name());
+  if (atb_api_handler != nullptr) {
+    auto func_addr = get_api_func_addr_in_lib(
+        atb_api_handler, get_atb_api_lib_name(), api_name);
+    if (func_addr != nullptr) {
+      return func_addr;
+    }
+  }
+  static auto op_api_handler = get_api_lib_handler(get_op_api_lib_name());
+  if (op_api_handler != nullptr) {
+    auto func_addr = get_api_func_addr_in_lib(
+        op_api_handler, get_op_api_lib_name(), api_name);
+    if (func_addr != nullptr) {
+      return func_addr;
+    }
+  }
+  TORCH_CHECK(false, "get_api_func_addr not found ", api_name);
+  return nullptr;
+}
+
+struct TensorMaintainer {
+  c10::SmallVector<at::Tensor, N>
+      contiguous_tensors;  // NPU tensors must stay alive after a
+                           // non-contiguous tensor is made contiguous.
+  c10::SmallVector<at::Tensor, N>
+      cpu_tensors;  // CPU tensors must stay alive while the task queue
+                    // consumes them.
+};
+
+inline aclTensor* convert_type(TensorMaintainer& maintainer,
+                               const at::Tensor& tensor) {
+  static const auto aclCreateTensor =
+      reinterpret_cast<_aclCreateTensor>(get_api_func_addr("aclCreateTensor"));
+  if (aclCreateTensor == nullptr) {
+    return nullptr;
+  }
+
+  if (!tensor.defined()) {
+    return nullptr;
+  }
+  at::Tensor at_tensor = tensor.contiguous();
+  aclFormat format = atb::utils::get_format_for_atb(at_tensor);
+
+  at::ScalarType scalar_data_type = at_tensor.scalar_type();
+  aclDataType acl_data_type =
+      atb::utils::convert_to_acl_data_type(scalar_data_type);
+  c10::SmallVector<int64_t, MAX_DIM_NUM> storageDims;
+  // if acl_data_type is ACL_STRING, storageDims is empty.
+  if (acl_data_type != ACL_STRING) {
+    TORCH_CHECK(at_tensor.itemsize() > 0,
+                "the itemsize of tensor must be greater than 0.");
+    storageDims.push_back(at_tensor.storage().nbytes() / at_tensor.itemsize());
+  }
+
+  const auto dimNum = at_tensor.sizes().size();
+  auto acl_tensor =
+      aclCreateTensor(at_tensor.sizes().data(),
+                      at_tensor.sizes().size(),
+                      acl_data_type,
+                      at_tensor.strides().data(),
+                      at_tensor.storage_offset(),
+                      format,
+                      storageDims.data(),
+                      storageDims.size(),
+                      const_cast<void*>(at_tensor.storage().data()));
+  if (at_tensor.device().type() == at::kCPU) {
+    maintainer.cpu_tensors.emplace_back(std::move(at_tensor));
+  } else {
+    maintainer.contiguous_tensors.emplace_back(std::move(at_tensor));
+  }
+  return acl_tensor;
+}
+
+inline aclTensor* convert_type(TensorMaintainer& maintainer,
+                               const c10::optional<at::Tensor>& opt_tensor) {
+  if (opt_tensor.has_value() && opt_tensor.value().defined()) {
+    return convert_type(maintainer, opt_tensor.value());
+  }
+
+  return nullptr;
+}
+
+template <typename T>
+T convert_type(TensorMaintainer& maintainer, T value) {
+  return value;
+}
+
+template <typename... Ts>
+constexpr auto convert_types(TensorMaintainer& maintainer, Ts&... args) {
+  return std::make_tuple(convert_type(maintainer, args)...);
+}
+
+struct TensorStruct {
+  void* data_ptr = nullptr;    // at_tensor.storage().data()
+  at::ScalarType scalar_type;  // at_tensor.scalar_type()
+  size_t nbytes;               // at_tensor.storage().nbytes()
+  size_t itemsize;             // at_tensor.itemsize()
+  int64_t storage_offset;      // at_tensor.storage_offset()
+  std::vector<int64_t> sizes;    // at_tensor.sizes()
+  std::vector<int64_t> strides;  // at_tensor.strides()
+  aclFormat format;            // at_tensor format
+
+  TensorStruct(void* data_ptr_,
+               at::ScalarType scalar_type_,
+               size_t nbytes_,
+               size_t itemsize_,
+               int64_t storage_offset_,
+               at::IntArrayRef sizes_,
+               at::IntArrayRef strides_,
+               aclFormat format_)
+      : data_ptr(data_ptr_),
+        scalar_type(scalar_type_),
+        nbytes(nbytes_),
+        itemsize(itemsize_),
+        storage_offset(storage_offset_),
+        sizes(sizes_.vec()),
+        strides(strides_.vec()),
+        format(format_) {}
+};
+using TensorStructPtr = std::shared_ptr<TensorStruct>;
+
+inline TensorStructPtr copy_type_v2(TensorMaintainer& maintainer,
+                                    const at::Tensor& tensor) {
+  if (!tensor.defined()) {
+    return nullptr;
+  }
+  at::Tensor at_tensor = tensor.contiguous();
+  aclFormat format = atb::utils::get_format_for_atb(at_tensor);
+  std::shared_ptr<TensorStruct> tensor_structptr =
+      std::make_shared<TensorStruct>(
+          const_cast<void*>(at_tensor.storage().data()),
+          at_tensor.scalar_type(),
+          at_tensor.storage().nbytes(),
+          at_tensor.itemsize(),
+          at_tensor.storage_offset(),
+          at_tensor.sizes(),
+          at_tensor.strides(),
+          format);
+  if (at_tensor.device().type() == at::kCPU) {
+    maintainer.cpu_tensors.emplace_back(std::move(at_tensor));
+  } else {
+    maintainer.contiguous_tensors.emplace_back(std::move(at_tensor));
+  }
+  return tensor_structptr;
+}
+
+inline TensorStructPtr copy_type_v2(
+    TensorMaintainer& maintainer,
+    const c10::optional<at::Tensor>& opt_tensor) {
+  if (opt_tensor.has_value() && opt_tensor.value().defined()) {
+    return copy_type_v2(maintainer, opt_tensor.value());
+  }
+
+  return nullptr;
+}
+
+template <typename T>
+T copy_type_v2(TensorMaintainer& maintainer, T value) {
+  return value;
+}
+
+inline aclTensor* convert_type_v2(TensorStructPtr at_tensor) {
+  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
+  if (aclCreateTensor == nullptr) {
+    return nullptr;
+  }
+
+  if (at_tensor == nullptr) {
+    return nullptr;
+  }
+  at::ScalarType scalar_data_type = (*at_tensor).scalar_type;
+  aclDataType acl_data_type =
+      atb::utils::convert_to_acl_data_type(scalar_data_type);
+  c10::SmallVector<int64_t, MAX_DIM_NUM> storageDims;
+  if (acl_data_type != ACL_STRING) {
+    TORCH_CHECK((*at_tensor).itemsize > 0,
+                "the itemsize of tensor must be greater than 0.");
+    storageDims.push_back((*at_tensor).nbytes / (*at_tensor).itemsize);
+  }
+
+  const auto dimNum = (*at_tensor).sizes.size();
+
+  auto acl_tensor = aclCreateTensor((*at_tensor).sizes.data(),
+                                    (*at_tensor).sizes.size(),
+                                    acl_data_type,
+                                    (*at_tensor).strides.data(),
+                                    (*at_tensor).storage_offset,
+                                    (*at_tensor).format,
+                                    storageDims.data(),
+                                    storageDims.size(),
+                                    (*at_tensor).data_ptr);
+  return acl_tensor;
+}
+
+template <typename T>
+T convert_type_v2(T value) {
+  return value;
+}
+
+template <typename Tuple, std::size_t... I>
+auto convert_types_impl_v2(const Tuple& t, std::index_sequence<I...>) {
+  return std::make_tuple(convert_type_v2(std::get<I>(t))...);
+}
+
+template <typename... Ts>
+constexpr auto convert_types_v2(const std::tuple<Ts...>& args,
+                                uint64_t* workspace_size_addr,
+                                atb::Operation** op_addr,
+                                atb::Context* context_ptr) {
+  auto convert_args =
+      convert_types_impl_v2(args, std::make_index_sequence<sizeof...(Ts)>{});
+  auto appends = std::make_tuple(workspace_size_addr, op_addr, context_ptr);
+  return std::tuple_cat(convert_args, appends);
+}
+
+template <typename... Ts>
+constexpr auto copy_types_v2(TensorMaintainer& maintainer, Ts&... args) {
+  return std::make_tuple(copy_type_v2(maintainer, args)...);
+}
+
+template <typename Function, typename Tuple, std::size_t... I>
+auto call(Function f, Tuple t, std::index_sequence<I...>) {
+  return f(std::get<I>(t)...);
+}
+
+template <typename Function, typename Tuple>
+auto call(Function f, Tuple t) {
+  static constexpr auto size = std::tuple_size<Tuple>::value;
+  return call(f, t, std::make_index_sequence<size>{});
+}
+
+template <typename Tuple, std::size_t... I>
+auto convert_to_op_api_func(const Tuple& params,
+                            void* opApiAddr,
+                            std::index_sequence<I...>) {
+  using OpApiFunc =
+      int (*)(typename std::decay<decltype(std::get<I>(params))>::type...);
+  auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
+  return func;
+}
+
+template <typename Tuple>
+auto convert_to_op_api_func(const Tuple& params, void* opApiAddr) {
+  static constexpr auto size = std::tuple_size<Tuple>::value;
+  return convert_to_op_api_func(
+      params, opApiAddr, std::make_index_sequence<size>{});
+}
+
+inline void release(atb::Context* context) {}
+
+inline void release(aclTensor* p) {
+  static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
+  if (aclDestroyTensor == nullptr) {
+    return;
+  }
+  aclDestroyTensor(p);
+}
+
+template <typename T>
+void release(T value) {
+  (void)value;
+}
+
+template <typename Tuple, std::size_t... I>
+void call_release(Tuple t, std::index_sequence<I...>) {
+  (void)std::initializer_list<int>{(release(std::get<I>(t)), 0)...};
+}
+
+template <typename Tuple>
+void release_convert_types(Tuple& t) {
+  static constexpr auto size = std::tuple_size<Tuple>::value;
+  call_release(t, std::make_index_sequence<size>{});
+}
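+
+// Dispatch sketch (illustration only): EXEC_ATB_CMD below expands to the V1
+// path while an NPU graph capture is in flight (the operation is created and
+// destroyed per call) and to the cached V2 path otherwise. A wrapper built on
+// a hypothetical AtbFoo / AtbFooGetWorkspaceSize pair would be invoked as:
+//   EXEC_ATB_CMD(AtbFoo, input_tensor, output_tensor);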
+#define EXEC_ATB_CMD_V1(atb_api, ...)                                        \
+  do {                                                                       \
+    static const auto getWorkspaceSizeFuncAddr =                             \
+        get_api_func_addr(#atb_api "GetWorkspaceSize");                      \
+    static const auto atbApiFuncAddr = get_api_func_addr(#atb_api);          \
+    TORCH_CHECK(                                                             \
+        getWorkspaceSizeFuncAddr != nullptr && atbApiFuncAddr != nullptr,    \
+        #atb_api,                                                            \
+        " or ",                                                              \
+        #atb_api "GetWorkspaceSize",                                         \
+        " not in ",                                                          \
+        get_atb_api_lib_name(),                                              \
+        ", or ",                                                             \
+        get_atb_api_lib_name(),                                              \
+        " not found.");                                                      \
+    auto acl_stream = c10_npu::getCurrentNPUStream().stream(false);          \
+    auto context_ptr = atb::utils::get_context(acl_stream);                  \
+    uint64_t workspace_size = 0;                                             \
+    uint64_t* workspace_size_addr = &workspace_size;                         \
+    atb::Operation* op = nullptr;                                            \
+    atb::Operation** op_addr = &op;                                          \
+    TensorMaintainer tensor_maintainer;                                      \
+    auto converted_params = convert_types(tensor_maintainer,                 \
+                                          __VA_ARGS__,                       \
+                                          workspace_size_addr,               \
+                                          op_addr,                           \
+                                          context_ptr);                      \
+    static auto getWorkspaceSizeFunc =                                       \
+        convert_to_op_api_func(converted_params, getWorkspaceSizeFuncAddr);  \
+    auto workspace_status = call(getWorkspaceSizeFunc, converted_params);    \
+    TORCH_CHECK(workspace_status == 0,                                       \
+                "call " #atb_api "GetWorkspaceSize failed");                 \
+    void* workspace_addr = nullptr;                                          \
+    at::Tensor workspace_tensor;                                             \
+    if (workspace_size != 0) {                                               \
+      at::TensorOptions options =                                            \
+          at::TensorOptions(c10::DeviceType::PrivateUse1);                   \
+      workspace_tensor =                                                     \
+          at::empty({workspace_size}, options.dtype(at::kByte));             \
+      workspace_addr = const_cast<void*>(workspace_tensor.storage().data()); \
+    }                                                                        \
+    const c10::SmallVector<at::Tensor, N>& cpu_tensors =                     \
+        tensor_maintainer.cpu_tensors;                                       \
+    auto atb_call = [converted_params,                                       \
+                     workspace_addr,                                         \
+                     workspace_size,                                         \
+                     context_ptr,                                            \
+                     op,                                                     \
+                     cpu_tensors]() -> int {                                 \
+      AtbApiFunc atbApiFunc = reinterpret_cast<AtbApiFunc>(atbApiFuncAddr);  \
+      auto api_ret =                                                         \
+          atbApiFunc(workspace_addr, workspace_size, op, context_ptr);       \
+      TORCH_CHECK(api_ret == 0, "call " #atb_api " failed");                 \
+      DestroyOperation(op);                                                  \
+      release_convert_types(converted_params);                               \
+      return api_ret;                                                        \
+    };                                                                       \
+    at_npu::native::OpCommand::RunOpApiV2(#atb_api, atb_call);               \
+  } while (false)
+
+#define EXEC_ATB_CMD_V2(atb_api, ...)                                        \
+  do {                                                                       \
+    static const auto getWorkspaceSizeFuncAddr =                             \
+        get_api_func_addr(#atb_api "GetWorkspaceSize");                      \
+    static const auto AtbApiFuncAddr = get_api_func_addr(#atb_api);          \
+    TORCH_CHECK(                                                             \
+        getWorkspaceSizeFuncAddr != nullptr && AtbApiFuncAddr != nullptr,    \
+        #atb_api,                                                            \
+        " or ",                                                              \
+        #atb_api "GetWorkspaceSize",                                         \
+        " not in ",                                                          \
+        get_atb_api_lib_name(),                                              \
+        ", or ",                                                             \
+        get_atb_api_lib_name(),                                              \
+        " not found.");                                                      \
+    auto acl_stream = c10_npu::getCurrentNPUStream().stream(false);          \
+    TensorMaintainer tensor_maintainer;                                      \
+    auto copied_params = copy_types_v2(tensor_maintainer, __VA_ARGS__);      \
+    auto hash_id = compute_hash(std::string(#atb_api), __VA_ARGS__);         \
+    const c10::SmallVector<at::Tensor, N>& cpu_tensors =                     \
+        tensor_maintainer.cpu_tensors;                                       \
+    auto atb_call =                                                          \
+        [copied_params, acl_stream, hash_id, cpu_tensors]() -> int {         \
+      auto context_ptr = atb::utils::get_context(acl_stream);                \
+      uint64_t workspace_size = 0;                                           \
+      uint64_t* workspace_size_addr = &workspace_size;                       \
+      OpParamCache<std::string>& opParamCache =                              \
+          OpParamCache<std::string>::getInstance();                          \
+      atb::Operation* op = opParamCache.get_operation(hash_id);              \
+      atb::Operation** op_addr = &op;                                        \
+      int api_ret = 0;                                                       \
+      auto converted_params = convert_types_v2(                              \
+          copied_params, workspace_size_addr, op_addr, context_ptr);         \
+      auto getWorkspaceSizeFunc =                                            \
+          convert_to_op_api_func(converted_params, getWorkspaceSizeFuncAddr); \
+      auto workspace_status = call(getWorkspaceSizeFunc, converted_params);  \
+      opParamCache.save_operation(hash_id, op);                              \
+      TORCH_CHECK(workspace_status == 0,                                     \
+                  "call " #atb_api "GetWorkspaceSize failed");               \
+      void* workspace_addr = nullptr;                                        \
+      at::Tensor workspace_tensor;                                           \
+      if (workspace_size != 0) {                                             \
+        workspace_tensor =                                                   \
+            at_npu::native::allocate_workspace(workspace_size, acl_stream);  \
+        workspace_addr = const_cast<void*>(workspace_tensor.storage().data()); \
+      }                                                                      \
+      AtbApiFunc atbApiFunc = reinterpret_cast<AtbApiFunc>(AtbApiFuncAddr);  \
+      api_ret = atbApiFunc(workspace_addr, workspace_size, op, context_ptr); \
+      TORCH_CHECK(api_ret == 0, "call " #atb_api " failed");                 \
+      release_convert_types(converted_params);                               \
+      return api_ret;                                                        \
+    };                                                                       \
+    at_npu::native::OpCommand::RunOpApiV2(#atb_api, atb_call);               \
+  } while (false)
+
+#define EXEC_ATB_CMD(atb_api, ...)                                          \
+  do {                                                                      \
+    const auto is_capturing =                                               \
+        static_cast<bool>(c10_npu::currentStreamCaptureStatusMayInitCtx()); \
+    if (is_capturing) {                                                     \
+      EXEC_ATB_CMD_V1(atb_api, __VA_ARGS__);                                \
+    } else {                                                                \
+      EXEC_ATB_CMD_V2(atb_api, __VA_ARGS__);                                \
+    }                                                                       \
+  } while (false)
+
+atb::Tensor at_tensor_to_atb_tensor(const at::Tensor atTensor);
+atb::Context* get_context(aclrtStream stream);
+uint64_t operation_setup(atb::VariantPack variant_pack,
+                         atb::Operation* operation,
+                         atb::Context* context_ptr);
+class ParamSetter {
+ public:
+  ParamSetter& Input(const at::Tensor& tensor,
+                     const bool& format_trans = false);
+  ParamSetter& Input(const c10::optional<at::Tensor>& tensor,
+                     const bool& format_trans = false);
+  ParamSetter& Output(at::Tensor& tensor);
+  atb::VariantPack variant_pack_;
+  TensorMaintainer tensor_maintainer_;
+};
+
+void run_atb_cmd(atb::Operation* op,
+                 const ParamSetter& paramsetter,
+                 const std::string& name);
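+
+// A new ATB wrapper typically combines the pieces declared above; a minimal
+// sketch (FooParam/FooOperation are hypothetical, mirroring the wrappers in
+// ops_npu/):
+//   auto& cache = OpParamCache<FooParam>::getInstance();
+//   atb::Operation* op = cache.get_operation(param, "FooOperation");
+//   ParamSetter setter;
+//   setter.Input(x).Output(y);
+//   run_atb_cmd(op, setter, "FooOperation");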
+
+}  // namespace atb
+
+#endif
diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.cpp b/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.cpp
new file mode 100644
index 000000000..9d1f368b7
--- /dev/null
+++ b/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.cpp
@@ -0,0 +1,188 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "operation_cache_compute.h"
+
+namespace atb {
+
+thread_local char g_hash_buf[g_hash_buf_size];
+thread_local int g_hash_offset = 0;
+constexpr int g_rShift33Bits = 33;
+constexpr uint64_t MIX_STEP1 = 18397679294719823053LLU;
+constexpr uint64_t MIX_STEP2 = 14181476777654086739LLU;
+
+void add_param_to_buf(const std::string& s) {
+  MEMCPY_TO_BUF(s.c_str(), static_cast<int>(s.size()));
+}
+
+void add_param_to_buf(const c10::optional<at::Tensor>& t) {}
+void add_param_to_buf(const at::Tensor& t) {}
+
+void add_param_to_buf() {}
+
+inline uint64_t rotating_left(uint64_t x, uint8_t n) {
+  return (x << n) | (x >> (64 - n));
+}
+
+inline uint64_t mixture(uint64_t x) {
+  x ^= x >> g_rShift33Bits;
+  x *= MIX_STEP1;
+  x ^= x >> g_rShift33Bits;
+  x *= MIX_STEP2;
+  x ^= x >> g_rShift33Bits;
+
+  return x;
+}
+
+uint64_t gen_hash(const void* key,
+                  const int len,
+                  const uint32_t seed = 0xdeadb0d7) {
+  const uint8_t* data = static_cast<const uint8_t*>(key);
+  const int block_num = len / 16;
+  uint64_t has = seed;
+  uint64_t hax = seed;
+
+  const uint64_t c1 = 9782798678568883157LLU;
+  const uint64_t c2 = 5545529020109919103LLU;
+
+  const uint64_t* blocks =
+      static_cast<const uint64_t*>(static_cast<const void*>(data));
+
+  for (int i = 0; i < block_num; i++) {
+    int even_num = 2;
+    uint64_t tmp1 = blocks[i * even_num];
+    uint64_t tmp2 = blocks[i * even_num + 1];
+
+    int8_t bits_31 = 31;
+    tmp1 *= c1;
+    tmp1 = rotating_left(tmp1, bits_31);
+    tmp1 *= c2;
+    has ^= tmp1;
+
+    int8_t bits_27 = 27;
+    has = rotating_left(has, bits_27);
+    has += hax;
+    has = has * 5 + 1390208809;
+
+    int8_t bits_33 = 33;
+    tmp2 *= c2;
+    tmp2 = rotating_left(tmp2, bits_33);
+    tmp2 *= c1;
+    hax ^= tmp2;
+
+    hax = rotating_left(hax, bits_31);
+    hax += has;
+    hax = hax * 5 + 944331445;
+  }
+
+  const uint8_t* tail = data + block_num * 16;
+  uint64_t t1 = 0;
+  uint64_t t2 = 0;
+  switch (static_cast<uint64_t>(len) & 15) {
+    case 15:
+      t2 ^= (static_cast<uint64_t>(tail[14])) << 48;
+      [[fallthrough]];
+    case 14:
+      t2 ^= (static_cast<uint64_t>(tail[13])) << 40;
+      [[fallthrough]];
+    case 13:
+      t2 ^= (static_cast<uint64_t>(tail[12])) << 32;
+      [[fallthrough]];
+    case 12:
+      t2 ^= (static_cast<uint64_t>(tail[11])) << 24;
+      [[fallthrough]];
+    case 11:
+      t2 ^= (static_cast<uint64_t>(tail[10])) << 16;
+      [[fallthrough]];
+    case 10:
+      t2 ^= (static_cast<uint64_t>(tail[9])) << 8;
+      [[fallthrough]];
+    case 9:
+      t2 ^= (static_cast<uint64_t>(tail[8])) << 0;
+      t2 *= c2;
+      t2 = rotating_left(t2, 33);
+      t2 *= c1;
+      hax ^= t2;
+      [[fallthrough]];
+    case 8:
+      t1 ^= (static_cast<uint64_t>(tail[7])) << 56;
+      [[fallthrough]];
+    case 7:
+      t1 ^= (static_cast<uint64_t>(tail[6])) << 48;
+      [[fallthrough]];
+    case 6:
+      t1 ^= (static_cast<uint64_t>(tail[5])) << 40;
+      [[fallthrough]];
+    case 5:
+      t1 ^= (static_cast<uint64_t>(tail[4])) << 32;
+      [[fallthrough]];
+    case 4:
+      t1 ^= (static_cast<uint64_t>(tail[3])) << 24;
+      [[fallthrough]];
+    case 3:
+      t1 ^= (static_cast<uint64_t>(tail[2])) << 16;
+      [[fallthrough]];
+    case 2:
+      t1 ^= (static_cast<uint64_t>(tail[1])) << 8;
+      [[fallthrough]];
+    case 1:
+      t1 ^= (static_cast<uint64_t>(tail[0])) << 0;
+      t1 *= c1;
+      t1 = rotating_left(t1, 31);
+      t1 *= c2;
+      has ^= t1;
+      [[fallthrough]];
+    default:
+      break;
+  };
+
+  has ^= static_cast<uint64_t>(len);
+  hax ^= static_cast<uint64_t>(len);
+
+  has += hax;
+  hax += has;
+
+  has = mixture(has);
+  hax = mixture(hax);
+
+  has += hax;
+  hax += has;
+  return hax;
+}
+
+uint64_t calc_hash_id() {
+  if (g_hash_offset == g_hash_buf_max_size) {
+    return 0;
+  }
+  uint64_t hash_id = gen_hash(g_hash_buf, g_hash_offset);
+  return hash_id;
+}
+
+}  // namespace atb
diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h b/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h
new file mode 100644
index 000000000..f7077aaa4
--- /dev/null
+++ b/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h
@@ -0,0 +1,148 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLLM_CORE_KERNELS_NPU_ATB_PARAM_OPERATION_CACHE_COMPUTE_H
+#define XLLM_CORE_KERNELS_NPU_ATB_PARAM_OPERATION_CACHE_COMPUTE_H
+
+#include <torch/torch.h>
+
+#include <cstdint>
+#include <cstring>
+#include <string>
+
+#include "atb/atb_infer.h"
+
+namespace atb {
+
+constexpr int g_hash_buf_size = 8192;
+constexpr int g_hash_buf_max_size = g_hash_buf_size + 1024;
+extern thread_local char g_hash_buf[g_hash_buf_size];
+extern thread_local int g_hash_offset;
+
+#define MEMCPY_TO_BUF(data_expression, size_expression)                  \
+  if (g_hash_offset + (size_expression) > g_hash_buf_size) {             \
+    g_hash_offset = g_hash_buf_max_size;                                 \
+    return;                                                              \
+  }                                                                      \
+  memcpy(g_hash_buf + g_hash_offset, data_expression, size_expression);  \
+  g_hash_offset += size_expression;
+
+uint64_t calc_hash_id();
+
+template <typename T>
+void add_param_to_buf(const T& value) {
+  MEMCPY_TO_BUF(&value, sizeof(T));
+}
+
+void add_param_to_buf(const std::string& s);
+void add_param_to_buf(const c10::optional<at::Tensor>& t);
+void add_param_to_buf(const at::Tensor& t);
+void add_param_to_buf();
+
+template <typename T>
+void add_param_to_buf(const std::string& name, const T& value) {
+  add_param_to_buf(name);
+  add_param_to_buf(value);
+}
+
+template <typename T, typename... Args>
+void add_param_to_buf(const T& arg, Args&... args) {
+  add_param_to_buf(arg);
+  add_param_to_buf(args...);
+}
+
+template <typename T>
+struct HashOpParam {
+  void operator()(const T& param) const {}
+};
+template <>
+struct HashOpParam<atb::infer::RmsNormParam> {
+  void operator()(const atb::infer::RmsNormParam& param) const {
+    add_param_to_buf("epsilon", param.normParam.epsilon);
+    add_param_to_buf("layerType", param.layerType);
+    add_param_to_buf("quantType", param.normParam.quantType);
+  }
+};
+
+template <>
+struct HashOpParam<atb::infer::GroupTopkParam> {
+  void operator()(const atb::infer::GroupTopkParam& param) const {
+    add_param_to_buf("groupNum", param.groupNum);
+    add_param_to_buf("k", param.k);
+    add_param_to_buf("groupMultiFlag", param.groupMultiFlag);
+    add_param_to_buf("n", param.n);
+  }
+};
+
+template <>
+struct HashOpParam<atb::infer::PagedAttentionParam> {
+  void operator()(const atb::infer::PagedAttentionParam& param) const {
+    add_param_to_buf("num_kv_heads", param.kvHeadNum);
+    add_param_to_buf("num_heads", param.headNum);
+    add_param_to_buf("scale_value", param.qkScale);
+    add_param_to_buf("quant_type", param.quantType);
+    add_param_to_buf("outdata_type", param.outDataType);
+    add_param_to_buf("mla_vheadsize", param.mlaVHeadSize);
+    add_param_to_buf("maskType", param.maskType);
+    add_param_to_buf("calcType", param.calcType);
+  }
+};
+
+template <>
+struct HashOpParam<atb::infer::SelfAttentionParam> {
+  void operator()(const atb::infer::SelfAttentionParam& param) const {
+    add_param_to_buf("num_kv_heads", param.kvHeadNum);
+    add_param_to_buf("num_heads", param.headNum);
+    add_param_to_buf("scale_value", param.qkScale);
+    add_param_to_buf("calcType", param.calcType);
+    add_param_to_buf("kernelType", param.kernelType);
+    add_param_to_buf("maskType", param.maskType);
+    add_param_to_buf("quantType", param.quantType);
+    add_param_to_buf("isTriuMask", param.isTriuMask);
+  }
+};
+
+template <>
+struct HashOpParam<atb::infer::RopeParam> {
+  void operator()(const atb::infer::RopeParam& param) const {
+    add_param_to_buf("rotaryCoeff", param.rotaryCoeff);
+  }
+};
+
+template <>
+struct HashOpParam<atb::infer::ReshapeAndCacheParam> {
+  void operator()(const atb::infer::ReshapeAndCacheParam& param) const {
+    add_param_to_buf("compressType", param.compressType);
+    add_param_to_buf("kvCacheCfg", param.kvCacheCfg);
+  }
+};
+
+template <typename T>
+uint64_t compute_hash(const T& obj) {
+  g_hash_offset = 0;
+  HashOpParam<T>{}(obj);
+  return calc_hash_id();
+}
+
+template <typename... Ts>
+uint64_t compute_hash(const std::string& name, Ts&... args) {
+  g_hash_offset = 0;
+  add_param_to_buf(name, args...);
+  return calc_hash_id();
+}
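+
+// Hashing sketch (illustration only): parameter fields are serialized one by
+// one into the thread-local buffer and reduced to a 64-bit cache key, so two
+// identically configured params map to the same cached operation:
+//   atb::infer::RopeParam param;
+//   param.rotaryCoeff = 2;
+//   uint64_t key = compute_hash(param);  // same fields => same key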
+
+}  // namespace atb
+
+#endif  // XLLM_CORE_KERNELS_NPU_ATB_PARAM_OPERATION_CACHE_COMPUTE_H
diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_create.h b/xllm/core/kernels/npu/custom_functions_npu/operation_create.h
new file mode 100644
index 000000000..50d93fc5c
--- /dev/null
+++ b/xllm/core/kernels/npu/custom_functions_npu/operation_create.h
@@ -0,0 +1,127 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLLM_CORE_KERNELS_NPU_ATB_OPERATION_CREATE_H
+#define XLLM_CORE_KERNELS_NPU_ATB_OPERATION_CREATE_H
+
+#include <torch/torch.h>
+#include <torch_npu/csrc/core/npu/NPUGraphsUtils.h>
+
+#include <mutex>
+#include <string>
+#include <unordered_map>
+
+#include "atb/atb_infer.h"
+#include "operation_cache_compute.h"
+#include "utils.h"
+
+namespace atb {
+
+template <typename ParamType>
+class OpParamCache {
+ public:
+  static OpParamCache& getInstance();
+
+  atb::Operation* get_operation(const ParamType& param,
+                                const std::string& name);
+  atb::Operation* get_operation(uint64_t hash_id);
+  void save_operation(uint64_t hash_id, atb::Operation* op);
+
+ private:
+  OpParamCache();
+
+  OpParamCache(const OpParamCache&) = delete;
+  OpParamCache& operator=(const OpParamCache&) = delete;
+
+  ~OpParamCache();
+
+  std::unordered_map<uint64_t, atb::Operation*> op_map_;
+  mutable std::mutex mutex_;
+};
+
+template <typename ParamType>
+atb::Operation* create_atb_operation(const ParamType& param,
+                                     const std::string& name) {
+  atb::Operation* op = nullptr;
+  atb::CreateOperation(param, &op);
+  TORCH_CHECK(op != nullptr, name, " CreateOperation failed!");
+  return op;
+}
+
+template <typename ParamType>
+OpParamCache<ParamType>& OpParamCache<ParamType>::getInstance() {
+  static OpParamCache<ParamType> instance;
+  return instance;
+}
+
+template <typename ParamType>
+atb::Operation* OpParamCache<ParamType>::get_operation(
+    const ParamType& param,
+    const std::string& name) {
+  const auto is_capturing =
+      static_cast<bool>(c10_npu::currentStreamCaptureStatusMayInitCtx());
+  if (is_capturing) {
+    return create_atb_operation(param, name);
+  } else {
+    uint64_t hashValue = compute_hash(param);
+    {
+      std::lock_guard<std::mutex> lock(mutex_);
+      auto op_cache = op_map_.find(hashValue);
+      if (op_cache != op_map_.end()) {
+        return op_cache->second;
+      }
+      atb::Operation* op = create_atb_operation(param, name);
+      op_map_[hashValue] = op;
+      return op;
+    }
+  }
+}
+
+template <typename ParamType>
+atb::Operation* OpParamCache<ParamType>::get_operation(uint64_t hash_id) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  auto op_cache = op_map_.find(hash_id);
+  if (op_cache != op_map_.end()) {
+    return op_cache->second;
+  }
+
+  return nullptr;
+}
+
+template <typename ParamType>
+void OpParamCache<ParamType>::save_operation(uint64_t hash_id,
+                                             atb::Operation* op) {
+  std::lock_guard<std::mutex> lock(mutex_);
+  op_map_[hash_id] = op;
+  return;
+}
+
+template <typename ParamType>
+OpParamCache<ParamType>::OpParamCache() {
+  atb::utils::ContextManager::get_instance();
+}
+
+template <typename ParamType>
+OpParamCache<ParamType>::~OpParamCache() {
+  std::lock_guard<std::mutex> lock(mutex_);
+  for (auto& op_item : op_map_) {
+    DestroyOperation(op_item.second);
+  }
+}
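+
+// Note the capture-aware path above: while an NPU graph capture is active,
+// get_operation(param, name) deliberately bypasses the cache and builds a
+// fresh operation, so graph replay never shares state with eager-mode calls.
+// Eager-mode lookup sketch (illustration only):
+//   auto& cache = OpParamCache<atb::infer::RopeParam>::getInstance();
+//   atb::Operation* op = cache.get_operation(param, "RopeOperation");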
+
+}  // namespace atb
+
+#endif  // XLLM_CORE_KERNELS_NPU_ATB_OPERATION_CREATE_H
diff --git a/xllm/core/kernels/npu/custom_functions_npu/utils.cpp b/xllm/core/kernels/npu/custom_functions_npu/utils.cpp
new file mode 100644
index 000000000..290923506
--- /dev/null
+++ b/xllm/core/kernels/npu/custom_functions_npu/utils.cpp
@@ -0,0 +1,81 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "utils.h"
+
+#include <torch_npu/csrc/core/npu/NPUFormat.h>
+
+namespace atb {
+namespace utils {
+
+ContextManager& ContextManager::get_instance() {
+  static ContextManager instance;
+  return instance;
+}
+
+ContextManager::ContextManager() : atb_context_(nullptr) {}
+
+ContextManager::~ContextManager() {
+  if (atb_context_) {
+    auto status = atb::DestroyContext(atb_context_);
+    TORCH_CHECK(status == 0, "Destroy context failed!");
+    atb_context_ = nullptr;
+  }
+}
+
+atb::Context* ContextManager::get_context(aclrtStream stream) {
+  std::call_once(create_flag_, [this]() {
+    auto status = atb::CreateContext(&atb_context_);
+    TORCH_CHECK(status == 0, "Create context failed!");
+  });
+
+  atb_context_->SetExecuteStream(stream);
+  return atb_context_;
+}
+
+atb::Context* get_context(aclrtStream stream) {
+  return ContextManager::get_instance().get_context(stream);
+}
+
+aclDataType convert_to_acl_data_type(const at::ScalarType& data_type) {
+  auto acl_dtype =
+      kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(data_type)];
+  TORCH_CHECK(acl_dtype != ACL_DT_UNDEFINED,
+              std::string(c10::toString(data_type)) + " has not been supported");
+  return acl_dtype;
+}
+
+at::Tensor format_trans(const at::Tensor& at_tensor) {
+  if (torch_npu::utils::is_npu(at_tensor)) {
+    return at_npu::native::npu_format_cast(at_tensor, ACL_FORMAT_ND);
+  }
+  return at_tensor;
+}
+
+bool is_base_format(aclFormat& format) {
+  return (format == ACL_FORMAT_NCHW) || (format == ACL_FORMAT_ND) ||
+         (format == ACL_FORMAT_NHWC) || (format == ACL_FORMAT_NCDHW);
+}
+
+aclFormat get_format_for_atb(const at::Tensor& at_tensor) {
+  if (torch_npu::utils::is_npu(at_tensor)) {
+    aclFormat format =
+        static_cast<aclFormat>(at_npu::native::get_npu_format(at_tensor));
+    return is_base_format(format) ? ACL_FORMAT_ND : format;
+  }
+  return ACL_FORMAT_ND;
+}
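+
+// Format sketch (illustration only): base layouts collapse to ND for ATB,
+// while private layouts such as ACL_FORMAT_FRACTAL_NZ pass through so the
+// operation sees the real device layout:
+//   get_format_for_atb(nd_tensor)  -> ACL_FORMAT_ND
+//   get_format_for_atb(nz_tensor)  -> ACL_FORMAT_FRACTAL_NZ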
+}  // namespace utils
+}  // namespace atb
diff --git a/xllm/core/kernels/npu/custom_functions_npu/utils.h b/xllm/core/kernels/npu/custom_functions_npu/utils.h
new file mode 100644
index 000000000..f9bb2c525
--- /dev/null
+++ b/xllm/core/kernels/npu/custom_functions_npu/utils.h
@@ -0,0 +1,103 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLLM_CORE_KERNELS_NPU_ATB_UTILS_H
+#define XLLM_CORE_KERNELS_NPU_ATB_UTILS_H
+
+#include <acl/acl.h>
+#include <torch/torch.h>
+#include <mutex>
+
+#include "atb/atb_infer.h"
+
+namespace atb {
+namespace utils {
+
+class ContextManager {
+ public:
+  static ContextManager& get_instance();
+  atb::Context* get_context(aclrtStream stream);
+  ~ContextManager();
+
+  ContextManager(const ContextManager&) = delete;
+  ContextManager& operator=(const ContextManager&) = delete;
+
+ private:
+  ContextManager();
+  std::once_flag create_flag_;
+  atb::Context* atb_context_;
+};
+
+atb::Context* get_context(aclrtStream stream);
+
+#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_)  \
+  _(at::ScalarType::Byte, ACL_UINT8)                 \
+  _(at::ScalarType::Char, ACL_INT8)                  \
+  _(at::ScalarType::Short, ACL_INT16)                \
+  _(at::ScalarType::Int, ACL_INT32)                  \
+  _(at::ScalarType::Long, ACL_INT64)                 \
+  _(at::ScalarType::Half, ACL_FLOAT16)               \
+  _(at::ScalarType::Float, ACL_FLOAT)                \
+  _(at::ScalarType::Double, ACL_DOUBLE)              \
+  _(at::ScalarType::ComplexHalf, ACL_COMPLEX32)      \
+  _(at::ScalarType::ComplexFloat, ACL_COMPLEX64)     \
+  _(at::ScalarType::ComplexDouble, ACL_COMPLEX128)   \
+  _(at::ScalarType::Bool, ACL_BOOL)                  \
+  _(at::ScalarType::QInt8, ACL_DT_UNDEFINED)         \
+  _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED)        \
+  _(at::ScalarType::QInt32, ACL_DT_UNDEFINED)        \
+  _(at::ScalarType::BFloat16, ACL_BF16)              \
+  _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED)      \
+  _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED)      \
+  _(at::ScalarType::Bits1x8, ACL_DT_UNDEFINED)       \
+  _(at::ScalarType::Bits2x4, ACL_DT_UNDEFINED)       \
+  _(at::ScalarType::Bits4x2, ACL_DT_UNDEFINED)       \
+  _(at::ScalarType::Bits8, ACL_DT_UNDEFINED)         \
+  _(at::ScalarType::Bits16, ACL_DT_UNDEFINED)        \
+  _(at::ScalarType::Float8_e5m2, ACL_DT_UNDEFINED)   \
+  _(at::ScalarType::Float8_e4m3fn, ACL_DT_UNDEFINED) \
+  _(at::ScalarType::Undefined, ACL_DT_UNDEFINED)     \
+  _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED)
+
+constexpr aclDataType kATenScalarTypeToAclDataTypeTable
+    [static_cast<int64_t>(at::ScalarType::NumOptions) + 1] = {
+#define DEFINE_ENUM(_1, n) n,
+        AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM)
+#undef DEFINE_ENUM
+};
+
+aclDataType convert_to_acl_data_type(const at::ScalarType& data_type);
+at::Tensor format_trans(const at::Tensor& at_tensor);
+aclFormat get_format_for_atb(const at::Tensor& at_tensor);
+
+template <typename MapType>
+inline int get_op_mode(const MapType& mode_map,
+                       c10::optional<c10::string_view> mode_opt,
+                       c10::string_view default_mode,
+                       const char* mode_name) {
+  c10::string_view mode_str = mode_opt.value_or(default_mode);
+  auto it = mode_map.find(mode_str);
+  TORCH_CHECK(it != mode_map.end(),
+              "Unsupported ",
+              mode_name,
+              " value: '",
+              mode_str,
+              "'");
+  return it->second;
+}
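+
+// Usage sketch (map contents hypothetical):
+//   static const std::unordered_map<c10::string_view, int> kNormModes = {
+//       {"rmsnorm", 0}, {"layernorm", 1}};
+//   int mode = get_op_mode(kNormModes, mode_opt, "rmsnorm", "norm mode");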
+}  // namespace utils
+}  // namespace atb
+
+#endif
diff --git a/xllm/core/kernels/npu/fused_layernorm.cpp b/xllm/core/kernels/npu/fused_layernorm.cpp
new file mode 100644
index 000000000..3c8e51708
--- /dev/null
+++ b/xllm/core/kernels/npu/fused_layernorm.cpp
@@ -0,0 +1,36 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <torch_npu/csrc/aten/CustomFunctions.h>
+
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
+
+namespace xllm::kernel::npu {
+
+torch::Tensor fused_layernorm(const torch::Tensor& input,
+                              const torch::Tensor& weight,
+                              double eps,
+                              const std::string& mode) {
+  if (mode != "rmsnorm") {
+    throw std::runtime_error(
+        "Only rmsnorm mode is supported in NPU fused_layernorm");
+  }
+  std::tuple<at::Tensor, at::Tensor> result =
+      at_npu::native::custom_ops::npu_rms_norm(input, weight, eps);
+  auto normalized_input = std::get<0>(result);
+  return normalized_input;
+}
+
+}  // namespace xllm::kernel::npu
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/linear.h b/xllm/core/kernels/npu/matmul.cpp
similarity index 61%
rename from xllm/core/kernels/npu/linear.h
rename to xllm/core/kernels/npu/matmul.cpp
index 0834c014b..0b80c9dd4 100644
--- a/xllm/core/kernels/npu/linear.h
+++ b/xllm/core/kernels/npu/matmul.cpp
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#pragma once
-#include "impl/npu_linear_impl.h"
-
-namespace xllm::kernel {
-
-class Linear : public torch::nn::ModuleHolder<NpuLinearImpl> {
- public:
-  using torch::nn::ModuleHolder<NpuLinearImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = NpuLinearImpl;
-
-  Linear(const ModelContext& context)
-      : ModuleHolder(std::make_shared<Impl>(context)) {}
-};
-
-}  // namespace xllm::kernel
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
+
+namespace xllm::kernel::npu {
+
+torch::Tensor matmul(const torch::Tensor& a,
+                     const torch::Tensor& b,
+                     const std::optional<torch::Tensor>& bias) {
+  if (!bias.has_value()) {
+    return torch::nn::functional::linear(a, b);
+  } else {
+    return torch::nn::functional::linear(a, b, bias.value());
+  }
+}
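+
+// Usage sketch (weights hypothetical): matmul computes a @ b.T via
+// torch::nn::functional::linear, so b is expected in
+// [out_features, in_features] layout:
+//   auto y = matmul(x, w, std::nullopt);            // no bias
+//   auto y2 = matmul(x, w, std::make_optional(b));  // with bias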
+
+}  // namespace xllm::kernel::npu
diff --git a/xllm/core/kernels/npu/npu_ops_api.h b/xllm/core/kernels/npu/npu_ops_api.h
new file mode 100644
index 000000000..6c10a272f
--- /dev/null
+++ b/xllm/core/kernels/npu/npu_ops_api.h
@@ -0,0 +1,62 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#include <torch/torch.h>
+
+#include <optional>
+
+#include "./custom_functions_npu/atb_common.h"
+
+namespace xllm::kernel::npu {
+
+void reshape_paged_cache(torch::Tensor& key,
+                         std::optional<torch::Tensor>& value,
+                         torch::Tensor& k_cache,
+                         std::optional<torch::Tensor>& v_cache,
+                         const torch::Tensor& slot_mapping);
+
+void batch_prefill(const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   const torch::Tensor& mask,
+                   const torch::Tensor& seq_len,
+                   float scale,
+                   torch::Tensor& output);
+
+void batch_decode(const torch::Tensor& query,
+                  const torch::Tensor& k_cache,
+                  const torch::Tensor& v_cache,
+                  float scale,
+                  const torch::Tensor& block_table,
+                  const torch::Tensor& seq_lens,
+                  torch::Tensor& output);
+
+torch::Tensor matmul(const torch::Tensor& a,
+                     const torch::Tensor& b,
+                     const std::optional<torch::Tensor>& bias);
+
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode);
+
+torch::Tensor fused_layernorm(const torch::Tensor& input,
+                              const torch::Tensor& weight,
+                              double eps,
+                              const std::string& mode);
+
+void apply_rotary(torch::Tensor& q,
+                  torch::Tensor& k,
+                  const torch::Tensor& cos_sin_cache,
+                  const torch::Tensor& positions);
+}  // namespace xllm::kernel::npu
diff --git a/xllm/core/kernels/npu/ops_npu/npu_ops.h b/xllm/core/kernels/npu/ops_npu/npu_ops.h
new file mode 100644
index 000000000..01301a174
--- /dev/null
+++ b/xllm/core/kernels/npu/ops_npu/npu_ops.h
@@ -0,0 +1,56 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#ifndef XLLM_NPU_OPS_H +#define XLLM_NPU_OPS_H + +#include "../custom_functions_npu/atb_common.h" + +using namespace std; + +namespace atb { + +using PagedAttentionParam = atb::infer::PagedAttentionParam; +using ReshapeAndCacheParam = atb::infer::ReshapeAndCacheParam; +using SelfAttentionParam = atb::infer::SelfAttentionParam; + +void _npu_paged_attention(const at::Tensor& query, + const at::Tensor& key_cache, + const at::Tensor& value_cache, + int64_t num_kv_heads, + int64_t num_heads, + double scale_value, + const at::Tensor& block_table, + const at::Tensor& context_lens, + at::Tensor& out); + +void _npu_reshape_and_cache(const at::Tensor& key, + const at::Tensor& value, + at::Tensor& key_cache, + at::Tensor& value_cache, + const at::Tensor& slot_indices); + +void _npu_flash_attention(const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& mask, + const at::Tensor& seq_len, + const double scale_value, + const int64_t num_heads, + const int64_t num_kv_heads, + at::Tensor& out); + +} // namespace atb + +#endif // XLLM_NPU_OPS_H \ No newline at end of file diff --git a/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp b/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp new file mode 100644 index 000000000..f43f4ecc9 --- /dev/null +++ b/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp @@ -0,0 +1,62 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include <torch/torch.h>
+
+#include "../custom_functions_npu/atb_common.h"
+
+namespace atb {
+using PagedAttentionParam = atb::infer::PagedAttentionParam;
+void _npu_paged_attention(const at::Tensor& query,
+                          const at::Tensor& key_cache,
+                          const at::Tensor& value_cache,
+                          int64_t num_kv_heads,
+                          int64_t num_heads,
+                          double scale_value,
+                          const at::Tensor& block_table,
+                          const at::Tensor& context_lens,
+                          at::Tensor& out) {
+  const c10::OptionalDeviceGuard device_guard(device_of(query));
+  OpParamCache<PagedAttentionParam>& pagedAttentionParamCache =
+      OpParamCache<PagedAttentionParam>::getInstance();
+  PagedAttentionParam pagedparam;
+  pagedparam.headNum = num_heads;
+  pagedparam.qkScale = scale_value;
+  pagedparam.kvHeadNum = num_kv_heads;
+  pagedparam.maskType = PagedAttentionParam::UNDEFINED;
+  pagedparam.batchRunStatusEnable = false;
+  pagedparam.quantType = PagedAttentionParam::TYPE_QUANT_UNDEFINED;
+  pagedparam.outDataType = ACL_DT_UNDEFINED;
+  pagedparam.hasQuantOffset = false;
+  pagedparam.compressType = PagedAttentionParam::COMPRESS_TYPE_UNDEFINED;
+  pagedparam.calcType = PagedAttentionParam::CALC_TYPE_UNDEFINED;
+  pagedparam.scaleType = PagedAttentionParam::SCALE_TYPE_TOR;
+  pagedparam.inputLayout = atb::infer::TYPE_BSND;
+  pagedparam.mlaVHeadSize = 0;
+
+  ParamSetter paramsetter;
+  paramsetter.Input(query, true)
+      .Input(key_cache)
+      .Input(value_cache)
+      .Input(block_table, true)
+      .Input(context_lens, true)
+      .Output(out);
+  auto opPaged = pagedAttentionParamCache.get_operation(
+      pagedparam, "PagedAttentionOperation");
+  run_atb_cmd(opPaged, paramsetter, "PagedAttentionOperation");
+
+  return;
+}
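+
+// Call sketch (illustration only; mirrors the decode path in
+// kernels/npu/attention.cpp): query is [num_tokens, num_heads, head_size]
+// and block_table/context_lens come from the paged KV cache:
+//   atb::_npu_paged_attention(q, k_cache, v_cache, num_kv_heads, num_heads,
+//                             scale, block_table, context_lens, out);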
+==============================================================================*/
+
+#include 
+
+#include "../custom_functions_npu/atb_common.h"
+
+using namespace std;
+namespace atb {
+using ReshapeAndCacheParam = atb::infer::ReshapeAndCacheParam;
+void _npu_reshape_and_cache(const at::Tensor& key,
+                            const at::Tensor& value,
+                            at::Tensor& key_cache,
+                            at::Tensor& value_cache,
+                            const at::Tensor& slot_indices) {
+  const c10::OptionalDeviceGuard device_guard(device_of(key));
+  OpParamCache<ReshapeAndCacheParam>& reshapeAndCacheParamCache =
+      OpParamCache<ReshapeAndCacheParam>::getInstance();
+  ReshapeAndCacheParam reshapeparam;
+  reshapeparam.compressType = ReshapeAndCacheParam::COMPRESS_TYPE_UNDEFINED;
+
+  auto key_cache_format = at_npu::native::get_npu_format(key_cache);
+  auto value_cache_format = at_npu::native::get_npu_format(value_cache);
+  bool is_key_cache_nz = (key_cache_format == ACL_FORMAT_FRACTAL_NZ);
+  bool is_value_cache_nz = (value_cache_format == ACL_FORMAT_FRACTAL_NZ);
+
+  if (is_key_cache_nz && is_value_cache_nz) {
+    reshapeparam.kvCacheCfg = ReshapeAndCacheParam::K_CACHE_V_CACHE_NZ;
+  } else {
+    reshapeparam.kvCacheCfg = ReshapeAndCacheParam::K_CACHE_V_CACHE;
+  }
+
+  ParamSetter parametter;
+  parametter.Input(key, true)
+      .Input(value, true)
+      .Input(key_cache)
+      .Input(value_cache)
+      .Input(slot_indices, true)
+      .Output(key_cache)
+      .Output(value_cache);
+  auto opReshape = reshapeAndCacheParamCache.get_operation(
+      reshapeparam, "ReshapeCacheOperation");
+  run_atb_cmd(opReshape, parametter, "ReshapeCacheOperation");
+
+  return;
+}
+
+}  // namespace atb
diff --git a/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp b/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp
new file mode 100644
index 000000000..281664161
--- /dev/null
+++ b/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp
@@ -0,0 +1,73 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include 
+
+#include "../custom_functions_npu/atb_common.h"
+
+using namespace std;
+namespace atb {
+using SelfAttentionParam = atb::infer::SelfAttentionParam;
+void _npu_flash_attention(const at::Tensor& query,
+                          const at::Tensor& key,
+                          const at::Tensor& value,
+                          const at::Tensor& mask,
+                          const at::Tensor& seq_len,
+                          const double scale_value,
+                          const int64_t num_heads,
+                          const int64_t num_kv_heads,
+                          at::Tensor& out) {
+  const c10::OptionalDeviceGuard device_guard(device_of(query));
+  OpParamCache<SelfAttentionParam>& selfAttentionParamCache =
+      OpParamCache<SelfAttentionParam>::getInstance();
+  SelfAttentionParam selfattentionparam;
+
+  selfattentionparam.calcType = SelfAttentionParam::PA_ENCODER;
+  selfattentionparam.kernelType = SelfAttentionParam::KERNELTYPE_DEFAULT;
+  selfattentionparam.clampType = SelfAttentionParam::CLAMP_TYPE_UNDEFINED;
+  selfattentionparam.maskType = SelfAttentionParam::MASK_TYPE_NORM;
+  selfattentionparam.kvcacheCfg = SelfAttentionParam::K_CACHE_V_CACHE;
+  selfattentionparam.scaleType = SelfAttentionParam::SCALE_TYPE_TOR;
+  selfattentionparam.quantType = SelfAttentionParam::TYPE_QUANT_UNDEFINED;
+  selfattentionparam.cacheType = SelfAttentionParam::CACHE_TYPE_NORM;
+  selfattentionparam.outDataType = ACL_DT_UNDEFINED;
+  selfattentionparam.headNum = num_heads;
+  selfattentionparam.kvHeadNum = num_kv_heads;
+  selfattentionparam.qScale = 1;
+  selfattentionparam.qkScale = scale_value;
+  selfattentionparam.batchRunStatusEnable = false;
+  selfattentionparam.isTriuMask = 0;
+  selfattentionparam.clampMin = 0;
+  selfattentionparam.clampMax = 0;
+  selfattentionparam.inputLayout = atb::infer::TYPE_BSND;
+  selfattentionparam.mlaVHeadSize = 0;
+  selfattentionparam.windowSize = 0;
+
+  ParamSetter parametter;
+  parametter.Input(query, true)
+      .Input(key, true)
+      .Input(value, true)
+      .Input(mask)
+      .Input(seq_len, true)
+      .Output(out);
+
+  auto opSelfattention = selfAttentionParamCache.get_operation(
+      selfattentionparam, "SelfAttentionOperation");
+  run_atb_cmd(opSelfattention, parametter, "SelfAttentionOperation");
+
+  return;
+}
+
+}  // namespace atb
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/rope.cpp b/xllm/core/kernels/npu/rope.cpp
new file mode 100644
index 000000000..9e312f961
--- /dev/null
+++ b/xllm/core/kernels/npu/rope.cpp
@@ -0,0 +1,42 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include + +#include "npu_ops_api.h" +#include "ops_npu/npu_ops.h" + +namespace xllm::kernel::npu { + +void apply_rotary(torch::Tensor& q, + torch::Tensor& k, + const torch::Tensor& cos_sin_cache, + const torch::Tensor& positions) { + auto cos_sin = cos_sin_cache.index_select(0, positions); + auto last_dim = cos_sin.size(-1); + auto cos_sin_vec = cos_sin.view({-1, 2, last_dim / 2}) + .repeat({1, 1, 2}) + .chunk(2, /*dim=*/-2); + auto cos = cos_sin_vec[0].view({1, -1, 1, last_dim}); + auto sin = cos_sin_vec[1].view({1, -1, 1, last_dim}); + + const int64_t rotary_dim = sin.size(-1); + q = q.view({1, q.size(0), -1, rotary_dim}); + k = k.view({1, k.size(0), -1, rotary_dim}); + + at_npu::native::custom_ops::npu_apply_rotary_pos_emb(q, k, cos, sin); +} + +} // namespace xllm::kernel::npu \ No newline at end of file diff --git a/xllm/core/kernels/npu/rope.h b/xllm/core/kernels/npu/rope.h deleted file mode 100644 index 7a075b0d3..000000000 --- a/xllm/core/kernels/npu/rope.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#pragma once -#include "impl/npu_rope_impl.h" - -namespace xllm::kernel { - -class Rope : public torch::nn::ModuleHolder { - public: - using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = NpuRopeImpl; - - Rope(const ModelContext& context) - : ModuleHolder(std::make_shared(context)) {} -}; - -} // namespace xllm::kernel diff --git a/xllm/core/kernels/npu/split.h b/xllm/core/kernels/npu/split.h deleted file mode 100644 index cda39703e..000000000 --- a/xllm/core/kernels/npu/split.h +++ /dev/null @@ -1,35 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#pragma once -#include "impl/npu_split_impl.h" - -namespace xllm::kernel { -class Split : public torch::nn::ModuleHolder { - public: - using torch::nn::ModuleHolder::ModuleHolder; - using Impl __attribute__((__unused__)) = NpuSplitImpl; - - Split(const ModelContext& context, - int32_t splitDim = 2, - int32_t splitNum = 3, - atb::SVector splitSizes = {}) - : ModuleHolder(std::make_shared(context, - splitDim, - splitNum, - splitSizes)) {} -}; - -} // namespace xllm::kernel diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 07ab88f62..cd39c0738 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -17,6 +17,8 @@ limitations under the License. #if defined(USE_MLU) #include "mlu/mlu_ops_api.h" +#elif defined(USE_NPU) +#include "npu/npu_ops_api.h" #elif defined(USE_CUDA) #include "cuda/cuda_ops_api.h" #endif @@ -38,6 +40,9 @@ void apply_rotary(RotaryParams& params) { params.discrete, params.dynamic_ntk, params.max_query_len); +#elif defined(USE_NPU) + npu::apply_rotary( + params.q, params.k, params.cos_sin, params.position_ids.value()); #elif defined(USE_CUDA) bool is_neox = !params.interleaved; @@ -70,6 +75,14 @@ void active(ActivationParams& params) { #endif } +torch::Tensor active_tensor(ActivationParams& params) { +#if defined(USE_NPU) + return npu::active(params.input, params.act_mode); +#else + LOG(FATAL) << "active_tensor not implemented"; +#endif +} + void reshape_paged_cache(ReshapePagedCacheParams& params) { #if defined(USE_MLU) mlu::reshape_paged_cache(params.key, @@ -78,6 +91,12 @@ void reshape_paged_cache(ReshapePagedCacheParams& params) { params.v_cache, params.slot_mapping, params.direction); +#elif defined(USE_NPU) + npu::reshape_paged_cache(params.key, + params.value, + params.k_cache, + params.v_cache, + params.slot_mapping); #elif defined(USE_CUDA) cuda::reshape_paged_cache(params.slot_mapping, params.key, @@ -113,6 +132,14 @@ void batch_prefill(AttentionParams& params) { params.window_size_right, params.compute_dtype, params.return_lse); +#elif defined(USE_NPU) + npu::batch_prefill(params.query, + params.key, + params.value, + params.attn_mask, + params.seq_lens, + params.scale, + params.output); #elif defined(USE_CUDA) cuda::batch_prefill(params.float_workspace_buffer, params.int_workspace_buffer, @@ -154,6 +181,14 @@ void batch_decode(AttentionParams& params) { params.scale, params.return_lse, params.kv_cache_quant_bit_size); +#elif defined(USE_NPU) + npu::batch_decode(params.query, + params.k_cache, + params.v_cache.value_or(torch::Tensor()), + params.scale, + params.block_table.value(), + params.seq_lens, + params.output); #elif defined(USE_CUDA) params.query = params.query.squeeze(1); params.output = params.output.squeeze(1); @@ -207,10 +242,21 @@ void fused_layernorm(FusedLayerNormParams& params) { #endif } +torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params) { +#if defined(USE_NPU) + return npu::fused_layernorm( + params.input, params.weight, params.eps, params.mode); +#else + LOG(FATAL) << "fused_layernorm not implemented"; +#endif +} + torch::Tensor matmul(MatmulParams& params) { #if defined(USE_MLU) return mlu::matmul( params.a, params.b, params.bias, params.c, params.alpha, params.beta); +#elif defined(USE_NPU) + return npu::matmul(params.a, params.b, params.bias); #elif defined(USE_CUDA) return cuda::matmul(params.a, params.b, params.bias); #else diff --git a/xllm/core/kernels/ops_api.h 
b/xllm/core/kernels/ops_api.h index 4f17659f6..ef3484978 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -28,6 +28,8 @@ void apply_rotary(RotaryParams& params); void active(ActivationParams& params); +torch::Tensor active_tensor(ActivationParams& params); + void reshape_paged_cache(ReshapePagedCacheParams& params); void batch_prefill(AttentionParams& params); @@ -36,6 +38,8 @@ void batch_decode(AttentionParams& params); void fused_layernorm(FusedLayerNormParams& params); +torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params); + torch::Tensor matmul(MatmulParams& params); torch::Tensor group_gemm(GroupGemmParams& params); diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h index 2be0e88aa..5cf7d7e89 100644 --- a/xllm/core/kernels/param.h +++ b/xllm/core/kernels/param.h @@ -195,6 +195,9 @@ struct AttentionParams { float scale; // Whether to return log-sum-exp values in output_lse. bool return_lse = false; + // ========== Torch NPU related parameters ========== + torch::Tensor seq_lens; + torch::Tensor attn_mask; // ========== FlashInfer related parameters ========== torch::Tensor paged_kv_indptr; diff --git a/xllm/core/layers/common/rotary_embedding.cpp b/xllm/core/layers/common/rotary_embedding.cpp index 00ea4bd04..50fcd2583 100644 --- a/xllm/core/layers/common/rotary_embedding.cpp +++ b/xllm/core/layers/common/rotary_embedding.cpp @@ -51,7 +51,7 @@ void RotaryEmbeddingImpl::forward(torch::Tensor& q, std::optional position_ids; if (is_prompt) { discrete = false; - if (Device::type_str() == "cuda") { + if (Device::type_str() == "cuda" || Device::type_str() == "npu") { position_ids = positions; } } else { From 8f9c3fa2a9003bb42d62b74b719bd35a086ab88c Mon Sep 17 00:00:00 2001 From: dengyingxu Date: Fri, 5 Dec 2025 22:46:13 +0800 Subject: [PATCH 2/7] refactor: standardize interface for active kernel execution. 
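
The NPU swiglu kernel returns a fresh tensor rather than filling a
caller-provided buffer, so the NPU path now funnels through the shared
active() entry point and assigns params.output directly, removing the
separate active_tensor() hook. Below is a minimal CPU-only sketch of the
resulting call pattern; it assumes libtorch, trims ActivationParams down
to the fields used here, and uses npu_like_active as an illustrative
stand-in for the real fused NPU kernel:

#include <torch/torch.h>

#include <iostream>
#include <string>

struct ActivationParams {
  torch::Tensor input;
  torch::Tensor output;  // may be left undefined on value-returning paths
  std::string act_mode = "silu";
};

// Stand-in for the fused swiglu kernel: split the last dimension in half
// and compute silu(gate) * up, returning a fresh tensor.
torch::Tensor npu_like_active(const torch::Tensor& input) {
  auto chunks = input.chunk(2, /*dim=*/-1);
  return torch::silu(chunks[0]) * chunks[1];
}

// Unified entry point: the value-returning backend assigns into
// params.output instead of going through a second API.
void active(ActivationParams& params) {
  params.output = npu_like_active(params.input);
}

int main() {
  ActivationParams params;
  params.input = torch::randn({4, 8});
  active(params);  // no pre-allocated output needed on this path
  std::cout << params.output.sizes() << "\n";  // prints [4, 4]
  return 0;
}

In-place backends still expect a pre-allocated output, which is why
DenseMLP keeps the torch::empty allocation for every device type except
NPU.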
--- xllm/core/kernels/ops_api.cpp | 10 ++-------- xllm/core/kernels/ops_api.h | 2 -- xllm/core/layers/common/dense_mlp.cpp | 13 ++++++++----- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index cd39c0738..08b0e5ec1 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -70,19 +70,13 @@ void active(ActivationParams& params) { params.expert_size); #elif defined(USE_CUDA) cuda::act_and_mul(params.output, params.input, params.act_mode); +#elif defined(USE_NPU) + params.output = npu::active(params.input, params.act_mode); #else LOG(FATAL) << "active not implemented"; #endif } -torch::Tensor active_tensor(ActivationParams& params) { -#if defined(USE_NPU) - return npu::active(params.input, params.act_mode); -#else - LOG(FATAL) << "active_tensor not implemented"; -#endif -} - void reshape_paged_cache(ReshapePagedCacheParams& params) { #if defined(USE_MLU) mlu::reshape_paged_cache(params.key, diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index ef3484978..2649ed6fd 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -28,8 +28,6 @@ void apply_rotary(RotaryParams& params); void active(ActivationParams& params); -torch::Tensor active_tensor(ActivationParams& params); - void reshape_paged_cache(ReshapePagedCacheParams& params); void batch_prefill(AttentionParams& params); diff --git a/xllm/core/layers/common/dense_mlp.cpp b/xllm/core/layers/common/dense_mlp.cpp index 57b52429b..4a5818473 100644 --- a/xllm/core/layers/common/dense_mlp.cpp +++ b/xllm/core/layers/common/dense_mlp.cpp @@ -90,11 +90,14 @@ torch::Tensor DenseMLPImpl::forward(const torch::Tensor& hidden_states) { // For w8a8 quantization, the active operation is fused with the down_proj return down_proj_->forward(gate_up); } else { - int64_t batch_size = gate_up.sizes()[0]; - auto output = torch::empty( - {batch_size, - intermediate_size_ / parallel_args_.tp_group_->world_size()}, - gate_up.options()); + torch::Tensor output; + if (Device::type_str() != "npu") { + int64_t batch_size = gate_up.sizes()[0]; + output = torch::empty( + {batch_size, + intermediate_size_ / parallel_args_.tp_group_->world_size()}, + gate_up.options()); + } xllm::kernel::ActivationParams activation_params; activation_params.input = gate_up; From 3d67d7b43da390f91d6ffb3058e18e42a53266e4 Mon Sep 17 00:00:00 2001 From: dengyingxu Date: Sat, 6 Dec 2025 00:15:13 +0800 Subject: [PATCH 3/7] refactor: redesign wrapper for NPU fused_layernorm operator. 
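
fused_layernorm on the NPU backend likewise returns its result by value
now, so RMSNormImpl only pre-allocates an output tensor on non-NPU
devices. For reference, the sketch below spells out the plain RMSNorm
the fused kernel computes; it assumes libtorch on CPU, and
rms_norm_reference is an illustrative name, not the Ascend kernel
itself:

#include <torch/torch.h>

#include <iostream>

// Reference RMSNorm: scale by the reciprocal root mean square of the
// last dimension (no mean subtraction), then apply the learned weight.
torch::Tensor rms_norm_reference(const torch::Tensor& input,
                                 const torch::Tensor& weight,
                                 double eps) {
  auto variance = input.pow(2).mean(-1, /*keepdim=*/true);
  return input * torch::rsqrt(variance + eps) * weight;
}

int main() {
  auto x = torch::randn({2, 16});
  auto w = torch::ones({16});
  auto y = rms_norm_reference(x, w, /*eps=*/1e-6);
  std::cout << y.sizes() << "\n";  // prints [2, 16]
  return 0;
}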
--- xllm/core/kernels/ops_api.cpp | 10 ++-------- xllm/core/kernels/ops_api.h | 2 -- xllm/core/layers/common/dense_mlp.cpp | 1 + xllm/core/layers/common/rms_norm.cpp | 6 +++++- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index 08b0e5ec1..f51b39f5a 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -231,14 +231,8 @@ void fused_layernorm(FusedLayerNormParams& params) { } else { cuda::rms_norm(params.output, params.input, params.weight, params.eps); } -#else - LOG(FATAL) << "fused_layernorm not implemented"; -#endif -} - -torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params) { -#if defined(USE_NPU) - return npu::fused_layernorm( +#elif defined(USE_NPU) + params.output = npu::fused_layernorm( params.input, params.weight, params.eps, params.mode); #else LOG(FATAL) << "fused_layernorm not implemented"; diff --git a/xllm/core/kernels/ops_api.h b/xllm/core/kernels/ops_api.h index 2649ed6fd..4f17659f6 100644 --- a/xllm/core/kernels/ops_api.h +++ b/xllm/core/kernels/ops_api.h @@ -36,8 +36,6 @@ void batch_decode(AttentionParams& params); void fused_layernorm(FusedLayerNormParams& params); -torch::Tensor fused_layernorm_tensor(FusedLayerNormParams& params); - torch::Tensor matmul(MatmulParams& params); torch::Tensor group_gemm(GroupGemmParams& params); diff --git a/xllm/core/layers/common/dense_mlp.cpp b/xllm/core/layers/common/dense_mlp.cpp index 4a5818473..1bdfea078 100644 --- a/xllm/core/layers/common/dense_mlp.cpp +++ b/xllm/core/layers/common/dense_mlp.cpp @@ -18,6 +18,7 @@ limitations under the License. #include #include "kernels/ops_api.h" +#include "platform/device.h" namespace xllm { namespace layer { diff --git a/xllm/core/layers/common/rms_norm.cpp b/xllm/core/layers/common/rms_norm.cpp index 21addfb74..de2f5f44f 100644 --- a/xllm/core/layers/common/rms_norm.cpp +++ b/xllm/core/layers/common/rms_norm.cpp @@ -18,6 +18,7 @@ limitations under the License. #include #include "kernels/ops_api.h" +#include "platform/device.h" namespace xllm { namespace layer { @@ -40,7 +41,10 @@ RMSNormImpl::RMSNormImpl(const ModelContext& context) context.get_tensor_options()) {} torch::Tensor RMSNormImpl::forward(torch::Tensor& input) { - auto output = torch::empty_like(input); + torch::Tensor output; + if (Device::type_str() != "npu") { + output = torch::empty_like(input); + } return forward_output(input, output); } From 74cde33921cd29b20a87224d645f2bd0763d7d43 Mon Sep 17 00:00:00 2001 From: dengyingxu Date: Mon, 8 Dec 2025 15:21:34 +0800 Subject: [PATCH 4/7] refactor: replace header guards with #pragma once. --- xllm/core/kernels/npu/custom_functions_npu/atb_common.h | 6 ++---- .../npu/custom_functions_npu/operation_cache_compute.h | 5 +---- .../kernels/npu/custom_functions_npu/operation_create.h | 5 +---- xllm/core/kernels/npu/custom_functions_npu/utils.h | 5 +---- xllm/core/kernels/npu/ops_npu/npu_ops.h | 5 +---- 5 files changed, 6 insertions(+), 20 deletions(-) diff --git a/xllm/core/kernels/npu/custom_functions_npu/atb_common.h b/xllm/core/kernels/npu/custom_functions_npu/atb_common.h index 198c425a2..665de6552 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/atb_common.h +++ b/xllm/core/kernels/npu/custom_functions_npu/atb_common.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef XLLM_CORE_KERNELS_NPU_ATB_COMMON_H -#define XLLM_CORE_KERNELS_NPU_ATB_COMMON_H +#pragma once + #include #include #include @@ -490,5 +490,3 @@ void run_atb_cmd(atb::Operation* op, const std::string& name); } // namespace atb - -#endif diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h b/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h index f7077aaa4..3149f125d 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h +++ b/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLLM_CORE_KERNELS_NPU_ATB_PARAM_OPERATION_CACHE_COMPUTE_H -#define XLLM_CORE_KERNELS_NPU_ATB_PARAM_OPERATION_CACHE_COMPUTE_H +#pragma once #include @@ -144,5 +143,3 @@ uint64_t compute_hash(const std::string& name, Ts&... args) { } } // namespace atb - -#endif // XLLM_CORE_KERNELS_NPU_ATB_PARAM_OPERATION_CACHE_COMPUTE_H diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_create.h b/xllm/core/kernels/npu/custom_functions_npu/operation_create.h index 50d93fc5c..5ee0917a9 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/operation_create.h +++ b/xllm/core/kernels/npu/custom_functions_npu/operation_create.h @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLLM_CORE_KERNELS_NPU_ATB_OPERATION_CREATE_H -#define XLLM_CORE_KERNELS_NPU_ATB_OPERATION_CREATE_H +#pragma once #include #include @@ -123,5 +122,3 @@ OpParamCache::~OpParamCache() { } } // namespace atb - -#endif // XLLM_CORE_KERNELS_NPU_ATB_OPERATION_CREATE_H diff --git a/xllm/core/kernels/npu/custom_functions_npu/utils.h b/xllm/core/kernels/npu/custom_functions_npu/utils.h index f9bb2c525..f605ec55e 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/utils.h +++ b/xllm/core/kernels/npu/custom_functions_npu/utils.h @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLLM_CORE_KERNELS_NPU_ATB_UTILS_H -#define XLLM_CORE_KERNELS_NPU_ATB_UTILS_H +#pragma once #include #include @@ -99,5 +98,3 @@ inline int get_op_mode(const MapType& mode_map, } } // namespace utils } // namespace atb - -#endif diff --git a/xllm/core/kernels/npu/ops_npu/npu_ops.h b/xllm/core/kernels/npu/ops_npu/npu_ops.h index 01301a174..5ed5de487 100644 --- a/xllm/core/kernels/npu/ops_npu/npu_ops.h +++ b/xllm/core/kernels/npu/ops_npu/npu_ops.h @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef XLLM_NPU_OPS_H -#define XLLM_NPU_OPS_H +#pragma once #include "../custom_functions_npu/atb_common.h" @@ -52,5 +51,3 @@ void _npu_flash_attention(const at::Tensor& query, at::Tensor& out); } // namespace atb - -#endif // XLLM_NPU_OPS_H \ No newline at end of file From 3a71812f71fc109f3e1d0c11ce15c31113564329 Mon Sep 17 00:00:00 2001 From: dengyingxu Date: Mon, 8 Dec 2025 17:25:32 +0800 Subject: [PATCH 5/7] refactor: replace TORCH_CHECK with CHECK macros and optimize code layout. --- xllm/core/kernels/npu/active.cpp | 4 +- xllm/core/kernels/npu/attention.cpp | 6 +-- .../npu/custom_functions_npu/atb_common.cpp | 9 ++-- .../npu/custom_functions_npu/atb_common.h | 51 ++++++++----------- .../custom_functions_npu/operation_create.h | 3 +- .../npu/custom_functions_npu/utils.cpp | 8 +-- .../kernels/npu/custom_functions_npu/utils.h | 9 ++-- xllm/core/kernels/npu/fused_layernorm.cpp | 12 ++--- xllm/core/kernels/npu/npu_ops_api.h | 10 ++-- xllm/core/kernels/npu/ops_npu/npu_ops.h | 6 +-- .../npu/ops_npu/paged_attention_atb.cpp | 20 ++++---- .../npu/ops_npu/reshape_and_cach_atb.cpp | 17 ++++--- .../npu/ops_npu/self_attention_atb.cpp | 30 ++++++----- xllm/core/kernels/ops_api.cpp | 4 +- 14 files changed, 88 insertions(+), 101 deletions(-) diff --git a/xllm/core/kernels/npu/active.cpp b/xllm/core/kernels/npu/active.cpp index e3e66be54..49a9e94f4 100644 --- a/xllm/core/kernels/npu/active.cpp +++ b/xllm/core/kernels/npu/active.cpp @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "npu_ops_api.h" @@ -22,8 +23,7 @@ namespace xllm::kernel::npu { torch::Tensor active(const torch::Tensor& input, const std::string& act_mode) { if (act_mode != "silu" && act_mode != "swiglu") { - throw std::runtime_error( - "Only swiglu activation is supported in NPU active"); + LOG(FATAL) << "Only swiglu activation is supported in NPU active"; } return at_npu::native::custom_ops::npu_swiglu(input); } diff --git a/xllm/core/kernels/npu/attention.cpp b/xllm/core/kernels/npu/attention.cpp index 0b4d80347..381e5ea79 100644 --- a/xllm/core/kernels/npu/attention.cpp +++ b/xllm/core/kernels/npu/attention.cpp @@ -46,9 +46,9 @@ void batch_decode(const torch::Tensor& query, const torch::Tensor& block_table, const torch::Tensor& seq_lens, torch::Tensor& output) { - auto head_size = query.size(-1); - auto num_heads = query.size(-2); - auto num_kv_heads = k_cache.size(-2); + int64_t head_size = query.size(-1); + int64_t num_heads = query.size(-2); + int64_t num_kv_heads = k_cache.size(-2); auto q = query.view({-1, num_heads, head_size}); auto o = output.view({-1, num_heads, head_size}); atb::_npu_paged_attention(q, diff --git a/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp b/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp index 70cf6b6c9..97672427e 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp +++ b/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp @@ -33,7 +33,7 @@ atb::Tensor at_tensor_to_atb_tensor(const at::Tensor at_tensor) { {at::ScalarType::ComplexDouble, ACL_COMPLEX128}, }; - TORCH_CHECK(at_tensor.is_contiguous(), "at_tensor is not contiguous"); + CHECK(at_tensor.is_contiguous()) << "at_tensor is not contiguous"; atb::Tensor tensor; tensor.desc.format = atb::utils::get_format_for_atb(at_tensor); if 
(at_tensor.device().type() == at::kCPU) { @@ -48,9 +48,8 @@ atb::Tensor at_tensor_to_atb_tensor(const at::Tensor at_tensor) { } auto dtype_iterator = dtype_map.find(at_tensor.scalar_type()); - TORCH_CHECK(dtype_iterator != dtype_map.end(), - "not support dtype: ", - at_tensor.scalar_type()); + CHECK(dtype_iterator != dtype_map.end()) + << "not support dtype: " << at_tensor.scalar_type(); tensor.desc.dtype = dtype_iterator->second; tensor.dataSize = atb::Utils::GetTensorSize(tensor); @@ -168,7 +167,7 @@ uint64_t operation_setup(atb::VariantPack variant_pack, uint64_t workspace_size = 0; atb::Status status = operation->Setup(variant_pack, workspace_size, context_ptr); - TORCH_CHECK(status == 0, operation->GetName(), " setup failed!"); + CHECK_EQ(status, 0) << operation->GetName() << " setup failed!"; return workspace_size; } diff --git a/xllm/core/kernels/npu/custom_functions_npu/atb_common.h b/xllm/core/kernels/npu/custom_functions_npu/atb_common.h index 665de6552..456d7f56d 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/atb_common.h +++ b/xllm/core/kernels/npu/custom_functions_npu/atb_common.h @@ -16,6 +16,7 @@ limitations under the License. #pragma once #include +#include #include #include #include @@ -30,7 +31,7 @@ namespace atb { using aclTensor = struct aclTensor; constexpr int64_t MAX_DIM_NUM = 5; -const int N = 32; +const int64_t N = 32; using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims, uint64_t view_dims_num, @@ -87,7 +88,7 @@ inline void* get_api_func_addr(const char* api_name) { if (func_addr != nullptr) { return func_addr; } - TORCH_CHECK(false, "get_api_func_addr not found ", api_name); + LOG(FATAL) << "get_api_func_addr not found " << api_name; } } @@ -119,8 +120,8 @@ inline aclTensor* convert_type(TensorMaintainer& maintainer, c10::SmallVector storageDims; // if acl_data_type is ACL_STRING, storageDims is empty. 
if (acl_data_type != ACL_STRING) { - TORCH_CHECK(at_tensor.itemsize() > 0, - "the itemsize of tensor must be greater than 0."); + CHECK_GT(at_tensor.itemsize(), 0) + << "the itemsize of tensor must be greater than 0."; storageDims.push_back(at_tensor.storage().nbytes() / at_tensor.itemsize()); } @@ -245,8 +246,8 @@ inline aclTensor* convert_type_v2(TensorStructPtr at_tensor) { atb::utils::convert_to_acl_data_type(scalar_data_type); c10::SmallVector storageDims; if (acl_data_type != ACL_STRING) { - TORCH_CHECK((*at_tensor).itemsize > 0, - "the itemsize of tensor must be greater than 0."); + CHECK_GT((*at_tensor).itemsize, 0) + << "the itemsize of tensor must be greater than 0."; storageDims.push_back((*at_tensor).nbytes / (*at_tensor).itemsize); } @@ -349,16 +350,10 @@ void release_convert_types(Tuple& t) { static const auto getWorkspaceSizeFuncAddr = \ get_api_func_addr(#atb_api "GetWorkspaceSize"); \ static const auto atbApiFuncAddr = get_api_func_addr(#atb_api); \ - TORCH_CHECK( \ - getWorkspaceSizeFuncAddr != nullptr && atbApiFuncAddr != nullptr, \ - #atb_api, \ - " or ", \ - #atb_api "GetWorkspaceSize", \ - " not in ", \ - get_atb_api_lib_name(), \ - ", or ", \ - get_atb_api_lib_name(), \ - "not found."); \ + CHECK(getWorkspaceSizeFuncAddr != nullptr && atbApiFuncAddr != nullptr) \ + << #atb_api << " or " << #atb_api "GetWorkspaceSize" << " not in " \ + << get_atb_api_lib_name() << ", or " << get_atb_api_lib_name() \ + << "not found."; \ auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ auto context_ptr = atb::utils::get_context(acl_stream); \ uint64_t workspace_size = 0; \ @@ -374,7 +369,7 @@ void release_convert_types(Tuple& t) { static auto getWorkspaceSizeFunc = \ convert_to_op_api_func(converted_params, getWorkspaceSizeFuncAddr); \ auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \ - TORCH_CHECK(workspace_status == 0, "call " #atb_api " failed, detail:"); \ + CHECK_EQ(workspace_status, 0) << "call " #atb_api " failed, detail:"; \ void* workspace_addr = nullptr; \ at::Tensor workspace_tensor; \ if (workspace_size != 0) { \ @@ -395,7 +390,7 @@ void release_convert_types(Tuple& t) { AtbApiFunc atbApiFunc = reinterpret_cast(atbApiFuncAddr); \ auto api_ret = \ atbApiFunc(workspace_addr, workspace_size, op, context_ptr); \ - TORCH_CHECK(api_ret == 0, "call " #atb_api " failed, detail:"); \ + CHECK_EQ(api_ret, 0) << "call " #atb_api " failed, detail:"; \ DestroyOperation(op); \ release_convert_types(converted_params); \ return api_ret; \ @@ -408,16 +403,10 @@ void release_convert_types(Tuple& t) { static const auto getWorkspaceSizeFuncAddr = \ get_api_func_addr(#atb_api "GetWorkspaceSize"); \ static const auto AtbApiFuncAddr = get_api_func_addr(#atb_api); \ - TORCH_CHECK( \ - getWorkspaceSizeFuncAddr != nullptr && AtbApiFuncAddr != nullptr, \ - #atb_api, \ - " or ", \ - #atb_api "GetWorkspaceSize", \ - " not in ", \ - get_atb_api_lib_name(), \ - ", or ", \ - get_atb_api_lib_name(), \ - "not found."); \ + CHECK(getWorkspaceSizeFuncAddr != nullptr && AtbApiFuncAddr != nullptr) \ + << #atb_api << " or " << #atb_api "GetWorkspaceSize" << " not in " \ + << get_atb_api_lib_name() << ", or " << get_atb_api_lib_name() \ + << "not found."; \ auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \ TensorMaintainer tensor_maintainer; \ auto copied_params = copy_types_v2(tensor_maintainer, __VA_ARGS__); \ @@ -440,8 +429,8 @@ void release_convert_types(Tuple& t) { convert_to_op_api_func(converted_params, getWorkspaceSizeFuncAddr); \ auto workspace_status 
= call(getWorkspaceSizeFunc, converted_params); \ opParamCache.save_operation(hash_id, op); \ - TORCH_CHECK(workspace_status == 0, \ - "call " #atb_api "GetWorkspaceSize failed"); \ + CHECK_EQ(workspace_status, 0) \ + << "call " #atb_api "GetWorkspaceSize failed"; \ void* workspace_addr = nullptr; \ at::Tensor workspace_tensor; \ if (workspace_size != 0) { \ @@ -451,7 +440,7 @@ void release_convert_types(Tuple& t) { } \ AtbApiFunc atbApiFunc = reinterpret_cast(AtbApiFuncAddr); \ api_ret = atbApiFunc(workspace_addr, workspace_size, op, context_ptr); \ - TORCH_CHECK(api_ret == 0, "call " #atb_api " failed"); \ + CHECK_EQ(api_ret, 0) << "call " #atb_api " failed"; \ release_convert_types(converted_params); \ return api_ret; \ }; \ diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_create.h b/xllm/core/kernels/npu/custom_functions_npu/operation_create.h index 5ee0917a9..051208b07 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/operation_create.h +++ b/xllm/core/kernels/npu/custom_functions_npu/operation_create.h @@ -15,6 +15,7 @@ limitations under the License. #pragma once +#include #include #include @@ -55,7 +56,7 @@ atb::Operation* create_atb_operation(const ParamType& param, const std::string& name) { atb::Operation* op = nullptr; atb::CreateOperation(param, &op); - TORCH_CHECK(op != nullptr, name, " CreateOperation failed!"); + CHECK(op != nullptr) << name << " CreateOperation failed!"; return op; } diff --git a/xllm/core/kernels/npu/custom_functions_npu/utils.cpp b/xllm/core/kernels/npu/custom_functions_npu/utils.cpp index 290923506..2034ac88b 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/utils.cpp +++ b/xllm/core/kernels/npu/custom_functions_npu/utils.cpp @@ -30,7 +30,7 @@ ContextManager::ContextManager() : atb_context_(nullptr) {} ContextManager::~ContextManager() { if (atb_context_) { auto status = atb::DestroyContext(atb_context_); - TORCH_CHECK(status == 0, "Destroy context failed!"); + CHECK_EQ(status, 0) << "Destroy context failed!"; atb_context_ = nullptr; } } @@ -38,7 +38,7 @@ ContextManager::~ContextManager() { atb::Context* ContextManager::get_context(aclrtStream stream) { std::call_once(create_flag_, [this]() { auto status = atb::CreateContext(&atb_context_); - TORCH_CHECK(status == 0, "Create context failed!"); + CHECK_EQ(status, 0) << "Create context failed!"; }); atb_context_->SetExecuteStream(stream); @@ -52,8 +52,8 @@ atb::Context* get_context(aclrtStream stream) { aclDataType convert_to_acl_data_type(const at::ScalarType& data_type) { auto acl_dtype = kATenScalarTypeToAclDataTypeTable[static_cast(data_type)]; - TORCH_CHECK(acl_dtype != ACL_DT_UNDEFINED, - std::string(c10::toString(data_type)) + " has not been supported") + CHECK_NE(acl_dtype, ACL_DT_UNDEFINED) + << std::string(c10::toString(data_type)) << " has not been supported"; return acl_dtype; } diff --git a/xllm/core/kernels/npu/custom_functions_npu/utils.h b/xllm/core/kernels/npu/custom_functions_npu/utils.h index f605ec55e..99cce00fe 100644 --- a/xllm/core/kernels/npu/custom_functions_npu/utils.h +++ b/xllm/core/kernels/npu/custom_functions_npu/utils.h @@ -17,6 +17,7 @@ limitations under the License. 
#include #include +#include #include #include "atb/atb_infer.h" @@ -88,12 +89,8 @@ inline int get_op_mode(const MapType& mode_map, const char* mode_name) { c10::string_view mode_str = mode_opt.value_or(default_mode); auto it = mode_map.find(mode_str); - TORCH_CHECK(it != mode_map.end(), - "Unsupported ", - mode_name, - " value: '", - mode_str, - "'"); + CHECK(it != mode_map.end()) + << "Unsupported " << mode_name << " value: '" << mode_str << "'"; return it->second; } } // namespace utils diff --git a/xllm/core/kernels/npu/fused_layernorm.cpp b/xllm/core/kernels/npu/fused_layernorm.cpp index 3c8e51708..c6e898429 100644 --- a/xllm/core/kernels/npu/fused_layernorm.cpp +++ b/xllm/core/kernels/npu/fused_layernorm.cpp @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include "npu_ops_api.h" @@ -19,13 +20,12 @@ limitations under the License. namespace xllm::kernel::npu { -torch::Tensor fused_layernorm(const torch::Tensor& input, - const torch::Tensor& weight, - double eps, - const std::string& mode) { +torch::Tensor rms_norm(const torch::Tensor& input, + const torch::Tensor& weight, + double eps, + const std::string& mode) { if (mode != "rmsnorm") { - throw std::runtime_error( - "Only rmsnorm mode is supported in NPU fused_layernorm"); + LOG(FATAL) << "Only rmsnorm mode is supported in NPU rms_norm"; } std::tuple result = at_npu::native::custom_ops::npu_rms_norm(input, weight, eps); diff --git a/xllm/core/kernels/npu/npu_ops_api.h b/xllm/core/kernels/npu/npu_ops_api.h index 6c10a272f..6c1d79671 100644 --- a/xllm/core/kernels/npu/npu_ops_api.h +++ b/xllm/core/kernels/npu/npu_ops_api.h @@ -18,7 +18,7 @@ limitations under the License. #include -#include "./custom_functions_npu/atb_common.h" +#include "custom_functions_npu/atb_common.h" namespace xllm::kernel::npu { @@ -50,10 +50,10 @@ torch::Tensor matmul(const torch::Tensor& a, torch::Tensor active(const torch::Tensor& input, const std::string& act_mode); -torch::Tensor fused_layernorm(const torch::Tensor& input, - const torch::Tensor& weight, - double eps, - const std::string& mode); +torch::Tensor rms_norm(const torch::Tensor& input, + const torch::Tensor& weight, + double eps, + const std::string& mode); void apply_rotary(torch::Tensor& q, torch::Tensor& k, diff --git a/xllm/core/kernels/npu/ops_npu/npu_ops.h b/xllm/core/kernels/npu/ops_npu/npu_ops.h index 5ed5de487..4850b16d8 100644 --- a/xllm/core/kernels/npu/ops_npu/npu_ops.h +++ b/xllm/core/kernels/npu/ops_npu/npu_ops.h @@ -14,16 +14,12 @@ limitations under the License. 
 ==============================================================================*/
 #pragma once
 
-#include "../custom_functions_npu/atb_common.h"
+#include "kernels/npu/custom_functions_npu/atb_common.h"
 
 using namespace std;
 
 namespace atb {
 
-using PagedAttentionParam = atb::infer::PagedAttentionParam;
-using ReshapeAndCacheParam = atb::infer::ReshapeAndCacheParam;
-using SelfAttentionParam = atb::infer::SelfAttentionParam;
-
 void _npu_paged_attention(const at::Tensor& query,
                           const at::Tensor& key_cache,
                           const at::Tensor& value_cache,
diff --git a/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp b/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp
index f43f4ecc9..164896fe6 100644
--- a/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp
+++ b/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp
@@ -14,10 +14,9 @@ limitations under the License.
 ==============================================================================*/
 #include 
 
-#include "../custom_functions_npu/atb_common.h"
+#include "kernels/npu/custom_functions_npu/atb_common.h"
 
 namespace atb {
-using PagedAttentionParam = atb::infer::PagedAttentionParam;
 void _npu_paged_attention(const at::Tensor& query,
                           const at::Tensor& key_cache,
                           const at::Tensor& value_cache,
@@ -28,20 +27,21 @@ void _npu_paged_attention(const at::Tensor& query,
                           const at::Tensor& context_lens,
                           at::Tensor& out) {
   const c10::OptionalDeviceGuard device_guard(device_of(query));
-  OpParamCache<PagedAttentionParam>& pagedAttentionParamCache =
-      OpParamCache<PagedAttentionParam>::getInstance();
-  PagedAttentionParam pagedparam;
+  OpParamCache<atb::infer::PagedAttentionParam>& pagedAttentionParamCache =
+      OpParamCache<atb::infer::PagedAttentionParam>::getInstance();
+  atb::infer::PagedAttentionParam pagedparam;
   pagedparam.headNum = num_heads;
   pagedparam.qkScale = scale_value;
   pagedparam.kvHeadNum = num_kv_heads;
-  pagedparam.maskType = PagedAttentionParam::UNDEFINED;
+  pagedparam.maskType = atb::infer::PagedAttentionParam::UNDEFINED;
   pagedparam.batchRunStatusEnable = false;
-  pagedparam.quantType = PagedAttentionParam::TYPE_QUANT_UNDEFINED;
+  pagedparam.quantType = atb::infer::PagedAttentionParam::TYPE_QUANT_UNDEFINED;
   pagedparam.outDataType = ACL_DT_UNDEFINED;
   pagedparam.hasQuantOffset = false;
-  pagedparam.compressType = PagedAttentionParam::COMPRESS_TYPE_UNDEFINED;
-  pagedparam.calcType = PagedAttentionParam::CALC_TYPE_UNDEFINED;
-  pagedparam.scaleType = PagedAttentionParam::SCALE_TYPE_TOR;
+  pagedparam.compressType =
+      atb::infer::PagedAttentionParam::COMPRESS_TYPE_UNDEFINED;
+  pagedparam.calcType = atb::infer::PagedAttentionParam::CALC_TYPE_UNDEFINED;
+  pagedparam.scaleType = atb::infer::PagedAttentionParam::SCALE_TYPE_TOR;
   pagedparam.inputLayout = atb::infer::TYPE_BSND;
   pagedparam.mlaVHeadSize = 0;
 
diff --git a/xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp b/xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp
index c9781cba7..b08e21f8e 100644
--- a/xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp
+++ b/xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp
@@ -15,21 +15,21 @@ limitations under the License.
 #include 
 
-#include "../custom_functions_npu/atb_common.h"
+#include "kernels/npu/custom_functions_npu/atb_common.h"
 
 using namespace std;
 namespace atb {
-using ReshapeAndCacheParam = atb::infer::ReshapeAndCacheParam;
 void _npu_reshape_and_cache(const at::Tensor& key,
                             const at::Tensor& value,
                             at::Tensor& key_cache,
                             at::Tensor& value_cache,
                             const at::Tensor& slot_indices) {
   const c10::OptionalDeviceGuard device_guard(device_of(key));
-  OpParamCache<ReshapeAndCacheParam>& reshapeAndCacheParamCache =
-      OpParamCache<ReshapeAndCacheParam>::getInstance();
-  ReshapeAndCacheParam reshapeparam;
-  reshapeparam.compressType = ReshapeAndCacheParam::COMPRESS_TYPE_UNDEFINED;
+  OpParamCache<atb::infer::ReshapeAndCacheParam>& reshapeAndCacheParamCache =
+      OpParamCache<atb::infer::ReshapeAndCacheParam>::getInstance();
+  atb::infer::ReshapeAndCacheParam reshapeparam;
+  reshapeparam.compressType =
+      atb::infer::ReshapeAndCacheParam::COMPRESS_TYPE_UNDEFINED;
 
   auto key_cache_format = at_npu::native::get_npu_format(key_cache);
   auto value_cache_format = at_npu::native::get_npu_format(value_cache);
@@ -37,9 +37,10 @@ void _npu_reshape_and_cache(const at::Tensor& key,
   bool is_value_cache_nz = (value_cache_format == ACL_FORMAT_FRACTAL_NZ);
 
   if (is_key_cache_nz && is_value_cache_nz) {
-    reshapeparam.kvCacheCfg = ReshapeAndCacheParam::K_CACHE_V_CACHE_NZ;
+    reshapeparam.kvCacheCfg =
+        atb::infer::ReshapeAndCacheParam::K_CACHE_V_CACHE_NZ;
   } else {
-    reshapeparam.kvCacheCfg = ReshapeAndCacheParam::K_CACHE_V_CACHE;
+    reshapeparam.kvCacheCfg = atb::infer::ReshapeAndCacheParam::K_CACHE_V_CACHE;
   }
 
   ParamSetter parametter;
diff --git a/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp b/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp
index 281664161..b43e4ff39 100644
--- a/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp
+++ b/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp
@@ -15,11 +15,10 @@ limitations under the License.
#include -#include "../custom_functions_npu/atb_common.h" +#include "kernels/npu/custom_functions_npu/atb_common.h" using namespace std; namespace atb { -using SelfAttentionParam = atb::infer::SelfAttentionParam; void _npu_flash_attention(const at::Tensor& query, const at::Tensor& key, const at::Tensor& value, @@ -30,18 +29,23 @@ void _npu_flash_attention(const at::Tensor& query, const int64_t num_kv_heads, at::Tensor& out) { const c10::OptionalDeviceGuard device_guard(device_of(query)); - OpParamCache& selfAttentionParamCache = - OpParamCache::getInstance(); - SelfAttentionParam selfattentionparam; + OpParamCache& selfAttentionParamCache = + OpParamCache::getInstance(); + atb::infer::SelfAttentionParam selfattentionparam; - selfattentionparam.calcType = SelfAttentionParam::PA_ENCODER; - selfattentionparam.kernelType = SelfAttentionParam::KERNELTYPE_DEFAULT; - selfattentionparam.clampType = SelfAttentionParam::CLAMP_TYPE_UNDEFINED; - selfattentionparam.maskType = SelfAttentionParam::MASK_TYPE_NORM; - selfattentionparam.kvcacheCfg = SelfAttentionParam::K_CACHE_V_CACHE; - selfattentionparam.scaleType = SelfAttentionParam::SCALE_TYPE_TOR; - selfattentionparam.quantType = SelfAttentionParam::TYPE_QUANT_UNDEFINED; - selfattentionparam.cacheType = SelfAttentionParam::CACHE_TYPE_NORM; + selfattentionparam.calcType = atb::infer::SelfAttentionParam::PA_ENCODER; + selfattentionparam.kernelType = + atb::infer::SelfAttentionParam::KERNELTYPE_DEFAULT; + selfattentionparam.clampType = + atb::infer::SelfAttentionParam::CLAMP_TYPE_UNDEFINED; + selfattentionparam.maskType = atb::infer::SelfAttentionParam::MASK_TYPE_NORM; + selfattentionparam.kvcacheCfg = + atb::infer::SelfAttentionParam::K_CACHE_V_CACHE; + selfattentionparam.scaleType = atb::infer::SelfAttentionParam::SCALE_TYPE_TOR; + selfattentionparam.quantType = + atb::infer::SelfAttentionParam::TYPE_QUANT_UNDEFINED; + selfattentionparam.cacheType = + atb::infer::SelfAttentionParam::CACHE_TYPE_NORM; selfattentionparam.outDataType = ACL_DT_UNDEFINED; selfattentionparam.headNum = num_heads; selfattentionparam.kvHeadNum = num_kv_heads; diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp index f51b39f5a..c3f7fabe8 100644 --- a/xllm/core/kernels/ops_api.cpp +++ b/xllm/core/kernels/ops_api.cpp @@ -232,8 +232,8 @@ void fused_layernorm(FusedLayerNormParams& params) { cuda::rms_norm(params.output, params.input, params.weight, params.eps); } #elif defined(USE_NPU) - params.output = npu::fused_layernorm( - params.input, params.weight, params.eps, params.mode); + params.output = + npu::rms_norm(params.input, params.weight, params.eps, params.mode); #else LOG(FATAL) << "fused_layernorm not implemented"; #endif From f3d1262fb8f0f6805bc8979befb3c0addd085fe4 Mon Sep 17 00:00:00 2001 From: dengyingxu Date: Mon, 8 Dec 2025 21:03:24 +0800 Subject: [PATCH 6/7] feat: integrate add_rms_norm interface for NPU backend. 
---
 xllm/core/kernels/npu/attention.cpp       |  4 ++--
 xllm/core/kernels/npu/fused_layernorm.cpp |  8 ++++++++
 xllm/core/kernels/npu/npu_ops_api.h       |  7 +++++++
 xllm/core/kernels/npu/rope.cpp            |  2 +-
 xllm/core/kernels/ops_api.cpp             | 10 ++++++++--
 5 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/xllm/core/kernels/npu/attention.cpp b/xllm/core/kernels/npu/attention.cpp
index 381e5ea79..9dcad44c1 100644
--- a/xllm/core/kernels/npu/attention.cpp
+++ b/xllm/core/kernels/npu/attention.cpp
@@ -33,8 +33,8 @@ void batch_prefill(const torch::Tensor& query,
                    const torch::Tensor& seq_len,
                    float scale,
                    torch::Tensor& output) {
-  auto num_heads = query.size(-2);
-  auto num_kv_heads = key.size(-2);
+  int64_t num_heads = query.size(-2);
+  int64_t num_kv_heads = key.size(-2);
   atb::_npu_flash_attention(
       query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output);
 }
diff --git a/xllm/core/kernels/npu/fused_layernorm.cpp b/xllm/core/kernels/npu/fused_layernorm.cpp
index c6e898429..6c222fbf1 100644
--- a/xllm/core/kernels/npu/fused_layernorm.cpp
+++ b/xllm/core/kernels/npu/fused_layernorm.cpp
@@ -33,4 +33,12 @@ torch::Tensor rms_norm(const torch::Tensor& input,
   return normalized_input;
 }
 
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> add_rms_norm(
+    const torch::Tensor& x1,
+    const torch::Tensor& x2,
+    const torch::Tensor& gamma,
+    double epsilon) {
+  return at_npu::native::custom_ops::npu_add_rms_norm(x1, x2, gamma, epsilon);
+}
+
 }  // namespace xllm::kernel::npu
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/npu_ops_api.h b/xllm/core/kernels/npu/npu_ops_api.h
index 6c1d79671..f59f39a07 100644
--- a/xllm/core/kernels/npu/npu_ops_api.h
+++ b/xllm/core/kernels/npu/npu_ops_api.h
@@ -17,6 +17,7 @@ limitations under the License.
 #include 
 #include 
+#include <tuple>
 
 #include "custom_functions_npu/atb_common.h"
 
 namespace xllm::kernel::npu {
@@ -55,6 +56,12 @@ torch::Tensor rms_norm(const torch::Tensor& input,
                        double eps,
                        const std::string& mode);
 
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> add_rms_norm(
+    const torch::Tensor& x1,
+    const torch::Tensor& x2,
+    const torch::Tensor& gamma,
+    double epsilon);
+
 void apply_rotary(torch::Tensor& q,
                   torch::Tensor& k,
                   const torch::Tensor& cos_sin_cache,
diff --git a/xllm/core/kernels/npu/rope.cpp b/xllm/core/kernels/npu/rope.cpp
index 9e312f961..7bcbbc7c4 100644
--- a/xllm/core/kernels/npu/rope.cpp
+++ b/xllm/core/kernels/npu/rope.cpp
@@ -25,7 +25,7 @@ void apply_rotary(torch::Tensor& q,
                   const torch::Tensor& cos_sin_cache,
                   const torch::Tensor& positions) {
   auto cos_sin = cos_sin_cache.index_select(0, positions);
-  auto last_dim = cos_sin.size(-1);
+  int64_t last_dim = cos_sin.size(-1);
   auto cos_sin_vec = cos_sin.view({-1, 2, last_dim / 2})
                          .repeat({1, 1, 2})
                          .chunk(2, /*dim=*/-2);
diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp
index c3f7fabe8..248871a33 100644
--- a/xllm/core/kernels/ops_api.cpp
+++ b/xllm/core/kernels/ops_api.cpp
@@ -232,8 +232,14 @@ void fused_layernorm(FusedLayerNormParams& params) {
     cuda::rms_norm(params.output, params.input, params.weight, params.eps);
   }
 #elif defined(USE_NPU)
-  params.output =
-      npu::rms_norm(params.input, params.weight, params.eps, params.mode);
+  if (params.residual.has_value()) {
+    std::tie(params.output, std::ignore, params.residual_out) =
+        npu::add_rms_norm(
+            params.input, params.residual.value(), params.weight, params.eps);
+  } else {
+    params.output =
+        npu::rms_norm(params.input, params.weight, params.eps, params.mode);
+  }
 #else
   LOG(FATAL) << "fused_layernorm not implemented";
 #endif

From 56790e3b5591a475b9c4e5a4023dfcdd9765fc23 Mon Sep 17 00:00:00 2001
From: dengyingxu
Date: Tue, 9 Dec 2025 16:39:19 +0800 Subject: [PATCH 7/7] feat: add torch_npu_ops library for NPU backend support. --- .gitmodules | 5 +- third_party/CMakeLists.txt | 1 + third_party/torch_npu_ops | 1 + xllm/core/kernels/CMakeLists.txt | 2 +- xllm/core/kernels/npu/CMakeLists.txt | 28 +- xllm/core/kernels/npu/attention.cpp | 22 +- .../npu/custom_functions_npu/atb_common.cpp | 174 ------- .../npu/custom_functions_npu/atb_common.h | 481 ------------------ .../operation_cache_compute.cpp | 188 ------- .../operation_cache_compute.h | 145 ------ .../custom_functions_npu/operation_create.h | 125 ----- .../npu/custom_functions_npu/utils.cpp | 81 --- .../kernels/npu/custom_functions_npu/utils.h | 97 ---- xllm/core/kernels/npu/ops_npu/npu_ops.h | 49 -- .../npu/ops_npu/paged_attention_atb.cpp | 62 --- .../npu/ops_npu/reshape_and_cach_atb.cpp | 61 --- .../npu/ops_npu/self_attention_atb.cpp | 77 --- 17 files changed, 19 insertions(+), 1580 deletions(-) create mode 160000 third_party/torch_npu_ops delete mode 100644 xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp delete mode 100644 xllm/core/kernels/npu/custom_functions_npu/atb_common.h delete mode 100644 xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.cpp delete mode 100644 xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h delete mode 100644 xllm/core/kernels/npu/custom_functions_npu/operation_create.h delete mode 100644 xllm/core/kernels/npu/custom_functions_npu/utils.cpp delete mode 100644 xllm/core/kernels/npu/custom_functions_npu/utils.h delete mode 100644 xllm/core/kernels/npu/ops_npu/npu_ops.h delete mode 100644 xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp delete mode 100644 xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp delete mode 100644 xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp diff --git a/.gitmodules b/.gitmodules index ceffdb5eb..620fcce1d 100755 --- a/.gitmodules +++ b/.gitmodules @@ -27,4 +27,7 @@ url = https://gitcode.com/xLLM-AI/spdlog.git [submodule "third_party/Mooncake"] path = third_party/Mooncake - url = https://gitcode.com/xLLM-AI/Mooncake.git \ No newline at end of file + url = https://gitcode.com/xLLM-AI/Mooncake.git +[submodule "third_party/torch_npu_ops"] + path = third_party/torch_npu_ops + url = https://gitcode.com/xLLM-AI/torch_npu_ops.git diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 7fb8fa937..de45daa90 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -12,6 +12,7 @@ add_subdirectory(etcd_cpp_apiv3) if(USE_NPU) add_subdirectory(spdlog) add_subdirectory(hccl_transfer/hccl_transfer) + add_subdirectory(torch_npu_ops) endif() add_subdirectory(Mooncake) diff --git a/third_party/torch_npu_ops b/third_party/torch_npu_ops new file mode 160000 index 000000000..2bc8f5784 --- /dev/null +++ b/third_party/torch_npu_ops @@ -0,0 +1 @@ +Subproject commit 2bc8f578424948145f65b3119c7af0345d3b18fd diff --git a/xllm/core/kernels/CMakeLists.txt b/xllm/core/kernels/CMakeLists.txt index 3bba0e16b..fbeb3c943 100644 --- a/xllm/core/kernels/CMakeLists.txt +++ b/xllm/core/kernels/CMakeLists.txt @@ -22,7 +22,7 @@ cc_library( ops_api.cpp DEPS torch - $<$:npu_kernels> + $<$:torch_npu_kernels> $<$:mlu_kernels> $<$:cuda_kernels> ) \ No newline at end of file diff --git a/xllm/core/kernels/npu/CMakeLists.txt b/xllm/core/kernels/npu/CMakeLists.txt index 380df2e87..3585da704 100644 --- a/xllm/core/kernels/npu/CMakeLists.txt +++ b/xllm/core/kernels/npu/CMakeLists.txt @@ -1,29 +1,3 @@ include(cc_library) 
-add_subdirectory(xllm_ops) - -file(GLOB_RECURSE XLLM_CORE_KERNELS_NPU_HEADER - "${CMAKE_CURRENT_LIST_DIR}/custom_functions_npu/*.h" - "${CMAKE_CURRENT_LIST_DIR}/ops_npu/*.h" - "${CMAKE_CURRENT_LIST_DIR}/*.h" -) - -file(GLOB_RECURSE XLLM_CORE_KERNELS_NPU_SRCS - "${CMAKE_CURRENT_LIST_DIR}/custom_functions_npu/*.cpp" - "${CMAKE_CURRENT_LIST_DIR}/ops_npu/*.cpp" - "${CMAKE_CURRENT_LIST_DIR}/*.cpp" -) - -cc_library( - NAME - npu_kernels - HDRS - ${XLLM_CORE_KERNELS_NPU_HEADER} - SRCS - ${XLLM_CORE_KERNELS_NPU_SRCS} - DEPS - :model_context - glog::glog - torch - torch_npu -) +add_subdirectory(xllm_ops) \ No newline at end of file diff --git a/xllm/core/kernels/npu/attention.cpp b/xllm/core/kernels/npu/attention.cpp index 9dcad44c1..d5f4b80ba 100644 --- a/xllm/core/kernels/npu/attention.cpp +++ b/xllm/core/kernels/npu/attention.cpp @@ -22,7 +22,7 @@ void reshape_paged_cache(torch::Tensor& key, torch::Tensor& k_cache, std::optional& v_cache, const torch::Tensor& slot_mapping) { - atb::_npu_reshape_and_cache( + atb::npu_reshape_and_cache( key, value.value(), k_cache, v_cache.value(), slot_mapping); } @@ -35,7 +35,7 @@ void batch_prefill(const torch::Tensor& query, torch::Tensor& output) { int64_t num_heads = query.size(-2); int64_t num_kv_heads = key.size(-2); - atb::_npu_flash_attention( + atb::npu_flash_attention( query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output); } @@ -51,15 +51,15 @@ void batch_decode(const torch::Tensor& query, int64_t num_kv_heads = k_cache.size(-2); auto q = query.view({-1, num_heads, head_size}); auto o = output.view({-1, num_heads, head_size}); - atb::_npu_paged_attention(q, - k_cache, - v_cache, - num_kv_heads, - num_heads, - scale, - block_table, - seq_lens, - o); + atb::npu_paged_attention(q, + k_cache, + v_cache, + num_kv_heads, + num_heads, + scale, + block_table, + seq_lens, + o); } } // namespace xllm::kernel::npu \ No newline at end of file diff --git a/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp b/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp deleted file mode 100644 index 97672427e..000000000 --- a/xllm/core/kernels/npu/custom_functions_npu/atb_common.cpp +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "atb_common.h" - -namespace atb { -atb::Tensor at_tensor_to_atb_tensor(const at::Tensor at_tensor) { - static std::map dtype_map = { - {at::ScalarType::Bool, ACL_BOOL}, - {at::ScalarType::Byte, ACL_UINT8}, - {at::ScalarType::Char, ACL_INT8}, - {at::ScalarType::Half, ACL_FLOAT16}, - {at::ScalarType::Float, ACL_FLOAT}, - {at::ScalarType::Int, ACL_INT32}, - {at::ScalarType::Long, ACL_INT64}, - {at::ScalarType::BFloat16, ACL_BF16}, - {at::ScalarType::Double, ACL_DOUBLE}, - {at::ScalarType::Short, ACL_INT16}, - {at::ScalarType::ComplexHalf, ACL_COMPLEX32}, - {at::ScalarType::ComplexFloat, ACL_COMPLEX64}, - {at::ScalarType::ComplexDouble, ACL_COMPLEX128}, - }; - - CHECK(at_tensor.is_contiguous()) << "at_tensor is not contiguous"; - atb::Tensor tensor; - tensor.desc.format = atb::utils::get_format_for_atb(at_tensor); - if (at_tensor.device().type() == at::kCPU) { - tensor.hostData = at_tensor.data_ptr(); - } else { - tensor.deviceData = at_tensor.data_ptr(); - } - - tensor.desc.shape.dimNum = at_tensor.sizes().size(); - for (uint64_t i = 0; i < at_tensor.sizes().size(); i++) { - tensor.desc.shape.dims[i] = at_tensor.sizes()[i]; - } - - auto dtype_iterator = dtype_map.find(at_tensor.scalar_type()); - CHECK(dtype_iterator != dtype_map.end()) - << "not support dtype: " << at_tensor.scalar_type(); - tensor.desc.dtype = dtype_iterator->second; - - tensor.dataSize = atb::Utils::GetTensorSize(tensor); - - return tensor; -} - -void run_atb_cmd_v1(atb::Operation* op, - const ParamSetter& paramsetter, - const std::string& name) { - aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false); - auto context_ptr = atb::utils::get_context(stream); - atb::VariantPack variant_pack = paramsetter.variant_pack_; - uint64_t workspace_size = operation_setup(variant_pack, op, context_ptr); - at::Tensor workspace_tensor; - void* workspace_ptr = nullptr; - if (workspace_size != 0) { - at::TensorOptions options = at::TensorOptions(c10::DeviceType::PrivateUse1); - workspace_tensor = at::empty({workspace_size}, options.dtype(at::kByte)); - workspace_ptr = const_cast(workspace_tensor.storage().data()); - } - const c10::SmallVector& cpu_tensors = - paramsetter.tensor_maintainer_.cpu_tensors; - auto acl_call = [variant_pack, - workspace_ptr, - workspace_size, - context_ptr, - op, - cpu_tensors]() -> int { - auto st = op->Execute( - variant_pack, (uint8_t*)workspace_ptr, workspace_size, context_ptr); - DestroyOperation(op); - return st; - }; - at_npu::native::OpCommand::RunOpApiV2(name, acl_call); -} - -void run_atb_cmd_v2(atb::Operation* op, - const ParamSetter& paramsetter, - const std::string& name) { - aclrtStream stream = c10_npu::getCurrentNPUStream().stream(false); - atb::VariantPack variant_pack = paramsetter.variant_pack_; - const c10::SmallVector& cpu_tensors = - paramsetter.tensor_maintainer_.cpu_tensors; - auto acl_call = [op, variant_pack, stream, cpu_tensors]() -> int { - auto context_ptr = atb::utils::get_context(stream); - uint64_t workspace_size = operation_setup(variant_pack, op, context_ptr); - at::Tensor workspace_tensor; - void* workspace_ptr = nullptr; - if (workspace_size != 0) { - workspace_tensor = - at_npu::native::allocate_workspace(workspace_size, stream); - workspace_ptr = const_cast(workspace_tensor.storage().data()); - } - auto st = op->Execute( - variant_pack, (uint8_t*)workspace_ptr, workspace_size, context_ptr); - return 0; - }; - at_npu::native::OpCommand::RunOpApiV2(name, acl_call); -} - 
-void run_atb_cmd(atb::Operation* op,
-                 const ParamSetter& paramsetter,
-                 const std::string& name) {
-  const auto is_capturing =
-      static_cast<bool>(c10_npu::currentStreamCaptureStatusMayInitCtx());
-  if (is_capturing) {
-    run_atb_cmd_v1(op, paramsetter, name);
-  } else {
-    run_atb_cmd_v2(op, paramsetter, name);
-  }
-}
-
-ParamSetter& ParamSetter::Input(const at::Tensor& tensor,
-                                const bool& format_trans) {
-  if (!tensor.defined()) {
-    variant_pack_.inTensors.push_back(atb::Tensor());
-    return *this;
-  }
-  at::Tensor new_tensor = tensor.contiguous();
-  if (format_trans) {
-    new_tensor = atb::utils::format_trans(new_tensor);
-  }
-  atb::Tensor atb_tensor;
-  if (new_tensor.device().type() == at::kCPU) {
-    auto tensor_clone = new_tensor.clone();
-    atb_tensor = at_tensor_to_atb_tensor(tensor_clone);
-    tensor_maintainer_.cpu_tensors.emplace_back(std::move(tensor_clone));
-  } else {
-    atb_tensor = at_tensor_to_atb_tensor(new_tensor);
-    tensor_maintainer_.contiguous_tensors.emplace_back(std::move(new_tensor));
-  }
-  variant_pack_.inTensors.push_back(atb_tensor);
-  return *this;
-}
-
-ParamSetter& ParamSetter::Input(const c10::optional<at::Tensor>& tensor,
-                                const bool& format_trans) {
-  if (!tensor.has_value()) {
-    variant_pack_.inTensors.push_back(atb::Tensor());
-    return *this;
-  }
-  return Input(tensor.value(), format_trans);
-}
-
-ParamSetter& ParamSetter::Output(at::Tensor& output) {
-  auto atb_tensor = at_tensor_to_atb_tensor(output);
-  variant_pack_.outTensors.push_back(atb_tensor);
-  return *this;
-}
-
-uint64_t operation_setup(atb::VariantPack variant_pack,
-                         atb::Operation* operation,
-                         atb::Context* context_ptr) {
-  uint64_t workspace_size = 0;
-  atb::Status status =
-      operation->Setup(variant_pack, workspace_size, context_ptr);
-  CHECK_EQ(status, 0) << operation->GetName() << " setup failed!";
-  return workspace_size;
-}
-
-} // namespace atb
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/custom_functions_npu/atb_common.h b/xllm/core/kernels/npu/custom_functions_npu/atb_common.h
deleted file mode 100644
index 456d7f56d..000000000
--- a/xllm/core/kernels/npu/custom_functions_npu/atb_common.h
+++ /dev/null
@@ -1,481 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
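The ParamSetter removed above is a fluent builder: each Input()/Output() appends to an internal pack and returns *this, so op wrappers read as one pipeline. A self-contained standard-C++ sketch of that pattern, with strings standing in for tensors:

#include <iostream>
#include <string>
#include <vector>

struct PackBuilder {
  std::vector<std::string> ins, outs;
  // Returning *this lets calls chain, exactly as ParamSetter's Input/Output do.
  PackBuilder& Input(std::string t)  { ins.push_back(std::move(t));  return *this; }
  PackBuilder& Output(std::string t) { outs.push_back(std::move(t)); return *this; }
};

int main() {
  PackBuilder p;
  p.Input("query").Input("key_cache").Input("block_table").Output("out");
  std::cout << p.ins.size() << " inputs, " << p.outs.size() << " outputs\n";
}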
-==============================================================================*/
-
-#pragma once
-
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-
-#include "./operation_create.h"
-#include "atb/atb_infer.h"
-#include "utils.h"
-
-namespace atb {
-
-using aclTensor = struct aclTensor;
-constexpr int64_t MAX_DIM_NUM = 5;
-const int64_t N = 32;
-
-using _aclCreateTensor = aclTensor* (*)(const int64_t* view_dims,
-                                        uint64_t view_dims_num,
-                                        aclDataType data_type,
-                                        const int64_t* stride,
-                                        int64_t offset,
-                                        aclFormat format,
-                                        const int64_t* storage_dims,
-                                        uint64_t storage_dims_num,
-                                        void* tensor_data);
-using _aclDestroyTensor = int (*)(const aclTensor*);
-
-using AtbApiFunc = int (*)(void*, uint64_t, atb::Operation*, atb::Context*);
-
-#define GET_OP_API_FUNC(api_name) \
-  reinterpret_cast<_##api_name>(get_api_func_addr(#api_name))
-
-inline const char* get_atb_api_lib_name(void) { return "libatb.so"; }
-
-inline const char* get_op_api_lib_name(void) { return "libopapi.so"; }
-
-inline void* get_api_lib_handler(const char* lib_name) {
-  auto handler = dlopen(lib_name, RTLD_LAZY);
-  if (handler == nullptr) {
-    ASCEND_LOGW("dlopen %s failed, error:%s.", lib_name, dlerror());
-  }
-  return handler;
-}
-
-inline void* get_api_func_addr_in_lib(void* handler,
-                                      const char* lib_name,
-                                      const char* api_name) {
-  auto func_addr = dlsym(handler, api_name);
-  if (func_addr == nullptr) {
-    ASCEND_LOGW(
-        "dlsym %s from %s failed, error:%s.", api_name, lib_name, dlerror());
-  }
-  return func_addr;
-}
-
-inline void* get_api_func_addr(const char* api_name) {
-  static auto atb_api_handler = get_api_lib_handler(get_atb_api_lib_name());
-  if (atb_api_handler != nullptr) {
-    auto func_addr = get_api_func_addr_in_lib(
-        atb_api_handler, get_atb_api_lib_name(), api_name);
-    if (func_addr != nullptr) {
-      return func_addr;
-    }
-  }
-  static auto op_api_handler = get_api_lib_handler(get_op_api_lib_name());
-  if (op_api_handler != nullptr) {
-    auto func_addr = get_api_func_addr_in_lib(
-        op_api_handler, get_op_api_lib_name(), api_name);
-    if (func_addr != nullptr) {
-      return func_addr;
-    }
-    LOG(FATAL) << "get_api_func_addr not found " << api_name;
-  }
-}
-
-struct TensorMaintainer {
-  c10::SmallVector<at::Tensor, N>
-      contiguous_tensors;  // npu tensor's life should maintain when
-                           // uncontiguous to contiguous.
-  c10::SmallVector<at::Tensor, N>
-      cpu_tensors;  // cpu tensor's life should maintain in taskqueue.
-};
-
-inline aclTensor* convert_type(TensorMaintainer& maintainer,
-                               const at::Tensor& tensor) {
-  static const auto aclCreateTensor =
-      reinterpret_cast<_aclCreateTensor>(get_api_func_addr("aclCreateTensor"));
-  if (aclCreateTensor == nullptr) {
-    return nullptr;
-  }
-
-  if (!tensor.defined()) {
-    return nullptr;
-  }
-  at::Tensor at_tensor = tensor.contiguous();
-  aclFormat format = atb::utils::get_format_for_atb(at_tensor);
-
-  at::ScalarType scalar_data_type = at_tensor.scalar_type();
-  aclDataType acl_data_type =
-      atb::utils::convert_to_acl_data_type(scalar_data_type);
-  c10::SmallVector<int64_t, MAX_DIM_NUM> storageDims;
-  // if acl_data_type is ACL_STRING, storageDims is empty.
-  if (acl_data_type != ACL_STRING) {
-    CHECK_GT(at_tensor.itemsize(), 0)
-        << "the itemsize of tensor must be greater than 0.";
-    storageDims.push_back(at_tensor.storage().nbytes() / at_tensor.itemsize());
-  }
-
-  const auto dimNum = at_tensor.sizes().size();
-  auto acl_tensor =
-      aclCreateTensor(at_tensor.sizes().data(),
-                      at_tensor.sizes().size(),
-                      acl_data_type,
-                      at_tensor.strides().data(),
-                      at_tensor.storage_offset(),
-                      format,
-                      storageDims.data(),
-                      storageDims.size(),
-                      const_cast<void*>(at_tensor.storage().data()));
-  if (at_tensor.device().type() == at::kCPU) {
-    maintainer.cpu_tensors.emplace_back(std::move(at_tensor));
-  } else {
-    maintainer.contiguous_tensors.emplace_back(std::move(at_tensor));
-  }
-  return acl_tensor;
-}
-
-inline aclTensor* convert_type(TensorMaintainer& maintainer,
-                               const c10::optional<at::Tensor>& opt_tensor) {
-  if (opt_tensor.has_value() && opt_tensor.value().defined()) {
-    return convert_type(maintainer, opt_tensor.value());
-  }
-
-  return nullptr;
-}
-
-template <typename T>
-T convert_type(TensorMaintainer& maintainer, T value) {
-  return value;
-}
-
-template <typename... Ts>
-constexpr auto convert_types(TensorMaintainer& maintainer, Ts&... args) {
-  return std::make_tuple(convert_type(maintainer, args)...);
-}
-
-struct TensorStruct {
-  void* data_ptr = nullptr;      // at_tensor.storage().data()
-  at::ScalarType scalar_type;    // at_tensor.scalar_type()
-  size_t nbytes;                 // at_tensor.storage().nbytes()
-  size_t itemsize;               // at_tensor.itemsize()
-  int64_t storage_offset;        // at_tensor.storage_offset()
-  std::vector<int64_t> sizes;    // at_tensor.sizes()
-  std::vector<int64_t> strides;  // at_tensor.strides()
-  aclFormat format;              // at_tensor format
-
-  TensorStruct(void* data_ptr_,
-               at::ScalarType scalar_type_,
-               size_t nbytes_,
-               size_t itemsize_,
-               int64_t storage_offset_,
-               at::IntArrayRef sizes_,
-               at::IntArrayRef strides_,
-               aclFormat format_)
-      : data_ptr(data_ptr_),
-        scalar_type(scalar_type_),
-        nbytes(nbytes_),
-        itemsize(itemsize_),
-        storage_offset(storage_offset_),
-        sizes(sizes_.vec()),
-        strides(strides_.vec()),
-        format(format_) {}
-};
-using TensorStructPtr = std::shared_ptr<TensorStruct>;
-
-inline TensorStructPtr copy_type_v2(TensorMaintainer& maintainer,
-                                    const at::Tensor& tensor) {
-  if (!tensor.defined()) {
-    return nullptr;
-  }
-  at::Tensor at_tensor = tensor.contiguous();
-  aclFormat format = atb::utils::get_format_for_atb(at_tensor);
-  std::shared_ptr<TensorStruct> tensor_structptr =
-      std::make_shared<TensorStruct>(
-          const_cast<void*>(at_tensor.storage().data()),
-          at_tensor.scalar_type(),
-          at_tensor.storage().nbytes(),
-          at_tensor.itemsize(),
-          at_tensor.storage_offset(),
-          at_tensor.sizes(),
-          at_tensor.strides(),
-          format);
-  if (at_tensor.device().type() == at::kCPU) {
-    maintainer.cpu_tensors.emplace_back(std::move(at_tensor));
-  } else {
-    maintainer.contiguous_tensors.emplace_back(std::move(at_tensor));
-  }
-  return tensor_structptr;
-}
-
-inline TensorStructPtr copy_type_v2(
-    TensorMaintainer& maintainer,
-    const c10::optional<at::Tensor>& opt_tensor) {
-  if (opt_tensor.has_value() && opt_tensor.value().defined()) {
-    return copy_type_v2(maintainer, opt_tensor.value());
-  }
-
-  return nullptr;
-}
-
-template <typename T>
-T copy_type_v2(TensorMaintainer& maintainer, T value) {
-  return value;
-}
-
-inline aclTensor* convert_type_v2(TensorStructPtr at_tensor) {
-  static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
-  if (aclCreateTensor == nullptr) {
-    return nullptr;
-  }
-
-  if (at_tensor == nullptr) {
-    return nullptr;
-  }
-  at::ScalarType scalar_data_type = (*at_tensor).scalar_type;
-  aclDataType acl_data_type =
-      atb::utils::convert_to_acl_data_type(scalar_data_type);
-  c10::SmallVector<int64_t, MAX_DIM_NUM> storageDims;
-  if (acl_data_type != ACL_STRING) {
-    CHECK_GT((*at_tensor).itemsize, 0)
-        << "the itemsize of tensor must be greater than 0.";
-    storageDims.push_back((*at_tensor).nbytes / (*at_tensor).itemsize);
-  }
-
-  const auto dimNum = (*at_tensor).sizes.size();
-
-  auto acl_tensor = aclCreateTensor((*at_tensor).sizes.data(),
-                                    (*at_tensor).sizes.size(),
-                                    acl_data_type,
-                                    (*at_tensor).strides.data(),
-                                    (*at_tensor).storage_offset,
-                                    (*at_tensor).format,
-                                    storageDims.data(),
-                                    storageDims.size(),
-                                    (*at_tensor).data_ptr);
-  return acl_tensor;
-}
-
-template <typename T>
-T convert_type_v2(T value) {
-  return value;
-}
-
-template <typename Tuple, size_t... I>
-auto convert_types_impl_v2(const Tuple& t, std::index_sequence<I...>) {
-  return std::make_tuple(convert_type_v2(std::get<I>(t))...);
-}
-
-template <typename... Ts>
-constexpr auto convert_types_v2(const std::tuple<Ts...>& args,
-                                uint64_t* workspace_size_addr,
-                                atb::Operation** op_addr,
-                                atb::Context* context_ptr) {
-  auto convert_args =
-      convert_types_impl_v2(args, std::make_index_sequence<sizeof...(Ts)>{});
-  auto appends = std::make_tuple(workspace_size_addr, op_addr, context_ptr);
-  return std::tuple_cat(convert_args, appends);
-}
-
-template <typename... Ts>
-constexpr auto copy_types_v2(TensorMaintainer& maintainer, Ts&... args) {
-  return std::make_tuple(copy_type_v2(maintainer, args)...);
-}
-
-template <typename Function, typename Tuple, size_t... I>
-auto call(Function f, Tuple t, std::index_sequence<I...>) {
-  return f(std::get<I>(t)...);
-}
-
-template <typename Function, typename Tuple>
-auto call(Function f, Tuple t) {
-  static constexpr auto size = std::tuple_size<Tuple>::value;
-  return call(f, t, std::make_index_sequence<size>{});
-}
-
-template <typename Tuple, size_t... I>
-auto convert_to_op_api_func(const Tuple& params,
-                            void* opApiAddr,
-                            std::index_sequence<I...>) {
-  using OpApiFunc =
-      int (*)(typename std::decay<decltype(std::get<I>(params))>::type...);
-  auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
-  return func;
-}
-
-template <typename Tuple>
-auto convert_to_op_api_func(const Tuple& params, void* opApiAddr) {
-  static constexpr auto size = std::tuple_size<Tuple>::value;
-  return convert_to_op_api_func(
-      params, opApiAddr, std::make_index_sequence<size>{});
-}
-
-inline void release(atb::Context* context) {}
-
-inline void release(aclTensor* p) {
-  static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
-  if (aclDestroyTensor == nullptr) {
-    return;
-  }
-  aclDestroyTensor(p);
-}
-
-template <typename T>
-void release(T value) {
-  (void)value;
-}
-
-template <typename Tuple, size_t... I>
-void call_release(Tuple t, std::index_sequence<I...>) {
-  (void)std::initializer_list<int>{(release(std::get<I>(t)), 0)...};
-}
-
-template <typename Tuple>
-void release_convert_types(Tuple& t) {
-  static constexpr auto size = std::tuple_size<Tuple>::value;
-  call_release(t, std::make_index_sequence<size>{});
-}
-
-#define EXEC_ATB_CMD_V1(atb_api, ...) \
-  do { \
-    static const auto getWorkspaceSizeFuncAddr = \
-        get_api_func_addr(#atb_api "GetWorkspaceSize"); \
-    static const auto atbApiFuncAddr = get_api_func_addr(#atb_api); \
-    CHECK(getWorkspaceSizeFuncAddr != nullptr && atbApiFuncAddr != nullptr) \
-        << #atb_api << " or " << #atb_api "GetWorkspaceSize" << " not in " \
-        << get_atb_api_lib_name() << ", or " << get_atb_api_lib_name() \
-        << "not found."; \
-    auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
-    auto context_ptr = atb::utils::get_context(acl_stream); \
-    uint64_t workspace_size = 0; \
-    uint64_t* workspace_size_addr = &workspace_size; \
-    atb::Operation* op = nullptr; \
-    atb::Operation** op_addr = &op; \
-    TensorMaintainer tensor_maintainer; \
-    auto converted_params = convert_types(tensor_maintainer, \
-                                          __VA_ARGS__, \
-                                          workspace_size_addr, \
-                                          op_addr, \
-                                          context_ptr); \
-    static auto getWorkspaceSizeFunc = \
-        convert_to_op_api_func(converted_params, getWorkspaceSizeFuncAddr); \
-    auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
-    CHECK_EQ(workspace_status, 0) << "call " #atb_api " failed, detail:"; \
-    void* workspace_addr = nullptr; \
-    at::Tensor workspace_tensor; \
-    if (workspace_size != 0) { \
-      at::TensorOptions options = \
-          at::TensorOptions(c10::DeviceType::PrivateUse1); \
-      workspace_tensor = \
-          at::empty({workspace_size}, options.dtype(at::kByte)); \
-      workspace_addr = const_cast<void*>(workspace_tensor.storage().data()); \
-    } \
-    const c10::SmallVector<at::Tensor, N>& cpu_tensors = \
-        tensor_maintainer.cpu_tensors; \
-    auto atb_call = [converted_params, \
-                     workspace_addr, \
-                     workspace_size, \
-                     context_ptr, \
-                     op, \
-                     cpu_tensors]() -> int { \
-      AtbApiFunc atbApiFunc = reinterpret_cast<AtbApiFunc>(atbApiFuncAddr); \
-      auto api_ret = \
-          atbApiFunc(workspace_addr, workspace_size, op, context_ptr); \
-      CHECK_EQ(api_ret, 0) << "call " #atb_api " failed, detail:"; \
-      DestroyOperation(op); \
-      release_convert_types(converted_params); \
-      return api_ret; \
-    }; \
-    at_npu::native::OpCommand::RunOpApiV2(#atb_api, atb_call); \
-  } while (false)
-
-#define EXEC_ATB_CMD_V2(atb_api, ...) \
-  do { \
-    static const auto getWorkspaceSizeFuncAddr = \
-        get_api_func_addr(#atb_api "GetWorkspaceSize"); \
-    static const auto AtbApiFuncAddr = get_api_func_addr(#atb_api); \
-    CHECK(getWorkspaceSizeFuncAddr != nullptr && AtbApiFuncAddr != nullptr) \
-        << #atb_api << " or " << #atb_api "GetWorkspaceSize" << " not in " \
-        << get_atb_api_lib_name() << ", or " << get_atb_api_lib_name() \
-        << "not found."; \
-    auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
-    TensorMaintainer tensor_maintainer; \
-    auto copied_params = copy_types_v2(tensor_maintainer, __VA_ARGS__); \
-    auto hash_id = compute_hash(std::string(#atb_api), __VA_ARGS__); \
-    const c10::SmallVector<at::Tensor, N>& cpu_tensors = \
-        tensor_maintainer.cpu_tensors; \
-    auto atb_call = \
-        [copied_params, acl_stream, hash_id, cpu_tensors]() -> int { \
-      auto context_ptr = atb::utils::get_context(acl_stream); \
-      uint64_t workspace_size = 0; \
-      uint64_t* workspace_size_addr = &workspace_size; \
-      OpParamCache& opParamCache = \
-          OpParamCache::getInstance(); \
-      atb::Operation* op = opParamCache.get_operation(hash_id); \
-      atb::Operation** op_addr = &op; \
-      int api_ret = 0; \
-      auto converted_params = convert_types_v2( \
-          copied_params, workspace_size_addr, op_addr, context_ptr); \
-      auto getWorkspaceSizeFunc = \
-          convert_to_op_api_func(converted_params, getWorkspaceSizeFuncAddr); \
-      auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
-      opParamCache.save_operation(hash_id, op); \
-      CHECK_EQ(workspace_status, 0) \
-          << "call " #atb_api "GetWorkspaceSize failed"; \
-      void* workspace_addr = nullptr; \
-      at::Tensor workspace_tensor; \
-      if (workspace_size != 0) { \
-        workspace_tensor = \
-            at_npu::native::allocate_workspace(workspace_size, acl_stream); \
-        workspace_addr = const_cast<void*>(workspace_tensor.storage().data()); \
-      } \
-      AtbApiFunc atbApiFunc = reinterpret_cast<AtbApiFunc>(AtbApiFuncAddr); \
-      api_ret = atbApiFunc(workspace_addr, workspace_size, op, context_ptr); \
-      CHECK_EQ(api_ret, 0) << "call " #atb_api " failed"; \
-      release_convert_types(converted_params); \
-      return api_ret; \
-    }; \
-    at_npu::native::OpCommand::RunOpApiV2(#atb_api, atb_call); \
-  } while (false)
-
-#define EXEC_ATB_CMD(atb_api, ...) \
-  do { \
-    const auto is_capturing = \
-        static_cast<bool>(c10_npu::currentStreamCaptureStatusMayInitCtx()); \
-    if (is_capturing) { \
-      EXEC_ATB_CMD_V1(atb_api, __VA_ARGS__); \
-    } else { \
-      EXEC_ATB_CMD_V2(atb_api, __VA_ARGS__); \
-    } \
-  } while (false)
-
-atb::Tensor at_tensor_to_atb_tensor(const at::Tensor atTensor);
-atb::Context* get_context(aclrtStream stream);
-uint64_t operation_setup(atb::VariantPack variant_pack,
-                         atb::Operation* operation,
-                         atb::Context* context_ptr);
-class ParamSetter {
- public:
-  ParamSetter& Input(const at::Tensor& tensor,
-                     const bool& format_trans = false);
-  ParamSetter& Input(const c10::optional<at::Tensor>& tensor,
-                     const bool& format_trans = false);
-  ParamSetter& Output(at::Tensor& tensor);
-  atb::VariantPack variant_pack_;
-  TensorMaintainer tensor_maintainer_;
-};
-
-void run_atb_cmd(atb::Operation* op,
-                 const ParamSetter& paramsetter,
-                 const std::string& name);
-
-} // namespace atb
diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.cpp b/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.cpp
deleted file mode 100644
index 9d1f368b7..000000000
--- a/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.cpp
+++ /dev/null
@@ -1,188 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
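The atb_common.h header removed above resolves every ATB/opapi entry point lazily through dlopen/dlsym (get_api_func_addr and GET_OP_API_FUNC). A standard-C++ sketch of that late-binding scheme, using libm's cos as a stand-in for libatb.so and its symbols; the library path assumes a glibc system and older toolchains need -ldl:

#include <dlfcn.h>
#include <cstdio>

using CosFn = double (*)(double);

CosFn load_cos() {
  // Library handle opened once and kept for the process lifetime, as the
  // static handlers in get_api_func_addr are.
  static void* handle = dlopen("libm.so.6", RTLD_LAZY);
  if (!handle) return nullptr;
  return reinterpret_cast<CosFn>(dlsym(handle, "cos"));
}

int main() {
  // Resolved once per call site, like the function-local statics behind
  // GET_OP_API_FUNC.
  static const auto fn = load_cos();
  if (fn) std::printf("cos(0) = %f\n", fn(0.0));
}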
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "operation_cache_compute.h"
-
-namespace atb {
-
-thread_local char g_hash_buf[g_hash_buf_size];
-thread_local int g_hash_offset = 0;
-constexpr int g_rShift33Bits = 33;
-constexpr uint64_t MIX_STEP1 = 18397679294719823053LLU;
-constexpr uint64_t MIX_STEP2 = 14181476777654086739LLU;
-
-void add_param_to_buf(const string& s) {
-  MEMCPY_TO_BUF(s.c_str(), static_cast<int64_t>(s.size()));
-}
-
-void add_param_to_buf(const c10::optional<at::Tensor>& t) {}
-void add_param_to_buf(const at::Tensor& t) {}
-
-void add_param_to_buf() {}
-
-inline uint64_t rotating_left(uint64_t x, uint8_t n) {
-  return (x << n) | (x >> (64 - n));
-}
-
-inline uint64_t mixture(uint64_t x) {
-  x ^= x >> g_rShift33Bits;
-  x *= MIX_STEP1;
-  x ^= x >> g_rShift33Bits;
-  x *= MIX_STEP2;
-  x ^= x >> g_rShift33Bits;
-
-  return x;
-}
-
-uint64_t gen_hash(const void* key,
-                  const int len,
-                  const uint32_t seed = 0xdeadb0d7) {
-  const uint8_t* data = static_cast<const uint8_t*>(key);
-  const int block_num = len / 16;
-  uint64_t has = seed;
-  uint64_t hax = seed;
-
-  const uint64_t c1 = 9782798678568883157LLU;
-  const uint64_t c2 = 5545529020109919103LLU;
-
-  const uint64_t* blocks =
-      static_cast<const uint64_t*>(static_cast<const void*>(data));
-
-  for (int i = 0; i < block_num; i++) {
-    int even_num = 2;
-    uint64_t tmp1 = blocks[i * even_num];
-    uint64_t tmp2 = blocks[i * even_num + 1];
-
-    int8_t bits_31 = 31;
-    tmp1 *= c1;
-    tmp1 = rotating_left(tmp1, bits_31);
-    tmp1 *= c2;
-    has ^= tmp1;
-
-    int8_t bits_27 = 27;
-    has = rotating_left(has, bits_27);
-    has += hax;
-    has = has * 5 + 1390208809;
-
-    int8_t bits_33 = 33;
-    tmp2 *= c2;
-    tmp2 = rotating_left(tmp2, bits_33);
-    tmp2 *= c1;
-    hax ^= tmp2;
-
-    hax = rotating_left(hax, bits_31);
-    hax += has;
-    hax = hax * 5 + 944331445;
-  }
-
-  const uint8_t* tail = data + block_num * 16;
-  uint64_t t1 = 0;
-  uint64_t t2 = 0;
-  switch (static_cast<uint64_t>(len) & 15) {
-    case 15:
-      t2 ^= (static_cast<uint64_t>(tail[14])) << 48;
-      [[fallthrough]];
-      ;
-    case 14:
-      t2 ^= (static_cast<uint64_t>(tail[13])) << 40;
-      [[fallthrough]];
-      ;
-    case 13:
-      t2 ^= (static_cast<uint64_t>(tail[12])) << 32;
-      [[fallthrough]];
-      ;
-    case 12:
-      t2 ^= (static_cast<uint64_t>(tail[11])) << 24;
-      [[fallthrough]];
-      ;
-    case 11:
-      t2 ^= (static_cast<uint64_t>(tail[10])) << 16;
-      [[fallthrough]];
-      ;
-    case 10:
-      t2 ^= (static_cast<uint64_t>(tail[9])) << 8;
-      [[fallthrough]];
-      ;
-    case 9:
-      t2 ^= (static_cast<uint64_t>(tail[8])) << 0;
-      t2 *= c2;
-      t2 = rotating_left(t2, 33);
-      t2 *= c1;
-      hax ^= t2;
-      [[fallthrough]];
-      ;
-    case 8:
-      t1 ^= (static_cast<uint64_t>(tail[7])) << 56;
-      [[fallthrough]];
-      ;
-    case 7:
-      t1 ^= (static_cast<uint64_t>(tail[6])) << 48;
-      [[fallthrough]];
-      ;
-    case 6:
-      t1 ^= (static_cast<uint64_t>(tail[5])) << 40;
-      [[fallthrough]];
-      ;
-    case 5:
-      t1 ^= (static_cast<uint64_t>(tail[4])) << 32;
-      [[fallthrough]];
-      ;
-    case 4:
-      t1 ^= (static_cast<uint64_t>(tail[3])) << 24;
-      [[fallthrough]];
-      ;
-    case 3:
-      t1 ^= (static_cast<uint64_t>(tail[2])) << 16;
-      [[fallthrough]];
-      ;
-    case 2:
-      t1 ^= (static_cast<uint64_t>(tail[1])) << 8;
-      [[fallthrough]];
-      ;
-    case 1:
-      t1 ^= (static_cast<uint64_t>(tail[0])) << 0;
-      t1 *= c1;
-      t1 = rotating_left(t1, 31);
-      t1 *= c2;
-      has ^= t1;
-      [[fallthrough]];
-      ;
-    default:
-      break;
-  };
-
-  has ^= static_cast<uint64_t>(len);
-  hax ^= static_cast<uint64_t>(len);
-
-  has += hax;
-  hax += has;
-
-  has = mixture(has);
-  hax = mixture(hax);
-
-  has += hax;
-  hax += has;
-  return hax;
-}
-
-uint64_t calc_hash_id() {
-  if (g_hash_offset == g_hash_buf_max_size) {
-    return 0;
-  }
-  uint64_t hash_id = gen_hash(g_hash_buf, g_hash_offset);
-  return hash_id;
-}
-
-} // namespace atb
diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h b/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h
deleted file mode 100644
index 3149f125d..000000000
--- a/xllm/core/kernels/npu/custom_functions_npu/operation_cache_compute.h
+++ /dev/null
@@ -1,145 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#pragma once
-
-#include
-
-#include
-#include
-#include
-
-#include "atb/atb_infer.h"
-
-namespace atb {
-
-constexpr int g_hash_buf_size = 8192;
-constexpr int g_hash_buf_max_size = g_hash_buf_size + 1024;
-extern thread_local char g_hash_buf[g_hash_buf_size];
-extern thread_local int g_hash_offset;
-
-#define MEMCPY_TO_BUF(data_expression, size_expression) \
-  if (g_hash_offset + (size_expression) > g_hash_buf_size) { \
-    g_hash_offset = g_hash_buf_max_size; \
-    return; \
-  } \
-  memcpy(g_hash_buf + g_hash_offset, data_expression, size_expression); \
-  g_hash_offset += size_expression;
-
-uint64_t calc_hash_id();
-
-template <typename T>
-void add_param_to_buf(const T& value) {
-  MEMCPY_TO_BUF(&value, sizeof(T));
-}
-
-void add_param_to_buf(const string& s);
-void add_param_to_buf(const c10::optional<at::Tensor>& t);
-void add_param_to_buf(const at::Tensor& t);
-void add_param_to_buf();
-
-template <typename T>
-void add_param_to_buf(const std::string& name, const T& value) {
-  add_param_to_buf(name);
-  add_param_to_buf(value);
-}
-
-template <typename T, typename... Args>
-void add_param_to_buf(const T& arg, Args&... args) {
-  add_param_to_buf(arg);
-  add_param_to_buf(args...);
-}
-
-template <typename T>
-struct HashOpParam {
-  void operator()(const T& param) const {};
-};
-template <>
-struct HashOpParam<atb::infer::RmsNormParam> {
-  void operator()(const atb::infer::RmsNormParam& param) const {
-    add_param_to_buf("epsilon", param.normParam.epsilon);
-    add_param_to_buf("layerType", param.layerType);
-    add_param_to_buf("quantType", param.normParam.quantType);
-  }
-};
-
-template <>
-struct HashOpParam<atb::infer::GroupTopkParam> {
-  void operator()(const atb::infer::GroupTopkParam& param) const {
-    add_param_to_buf("groupNum", param.groupNum);
-    add_param_to_buf("k", param.k);
-    add_param_to_buf("groupMultiFlag", param.groupMultiFlag);
-    add_param_to_buf("n", param.n);
-  }
-};
-
-template <>
-struct HashOpParam<atb::infer::PagedAttentionParam> {
-  void operator()(const atb::infer::PagedAttentionParam& param) const {
-    add_param_to_buf("num_kv_heads", param.kvHeadNum);
-    add_param_to_buf("num_heads", param.headNum);
-    add_param_to_buf("scale_value", param.qkScale);
-    add_param_to_buf("quant_type", param.quantType);
-    add_param_to_buf("outdata_type", param.outDataType);
-    add_param_to_buf("mla_vheadsize", param.mlaVHeadSize);
-    add_param_to_buf("maskType", param.maskType);
-    add_param_to_buf("calcType", param.calcType);
-  }
-};
-
-template <>
-struct HashOpParam<atb::infer::SelfAttentionParam> {
-  void operator()(const atb::infer::SelfAttentionParam& param) const {
-    add_param_to_buf("num_kv_heads", param.kvHeadNum);
-    add_param_to_buf("num_heads", param.headNum);
-    add_param_to_buf("scale_value", param.qkScale);
-    add_param_to_buf("calcType", param.calcType);
-    add_param_to_buf("kernelType", param.kernelType);
-    add_param_to_buf("maskType", param.maskType);
-    add_param_to_buf("quantType", param.quantType);
-    add_param_to_buf("isTriuMask", param.isTriuMask);
-  }
-};
-
-template <>
-struct HashOpParam<atb::infer::RopeParam> {
-  void operator()(const atb::infer::RopeParam& param) const {
-    add_param_to_buf("rotaryCoeff", param.rotaryCoeff);
-  }
-};
-
-template <>
-struct HashOpParam<atb::infer::ReshapeAndCacheParam> {
-  void operator()(const atb::infer::ReshapeAndCacheParam& param) const {
-    add_param_to_buf("compressType", param.compressType);
-    add_param_to_buf("kvCacheCfg", param.kvCacheCfg);
-  }
-};
-
-template <typename T>
-uint64_t compute_hash(const T& obj) {
-  g_hash_offset = 0;
-  HashOpParam<T>{}(obj);
-  return calc_hash_id();
-}
-
-template <typename... Ts>
-uint64_t compute_hash(const std::string& name, Ts&... args) {
-  g_hash_offset = 0;
-  add_param_to_buf(name, args...);
-  return calc_hash_id();
-}
-
-} // namespace atb
diff --git a/xllm/core/kernels/npu/custom_functions_npu/operation_create.h b/xllm/core/kernels/npu/custom_functions_npu/operation_create.h
deleted file mode 100644
index 051208b07..000000000
--- a/xllm/core/kernels/npu/custom_functions_npu/operation_create.h
+++ /dev/null
@@ -1,125 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
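The hashing scheme deleted above serializes each param field (label plus raw bytes) into a thread-local buffer, then hashes the buffer once to get the operation-cache key. A compact standard-C++ stand-in; std::string replaces the fixed g_hash_buf, std::hash replaces the murmur-style gen_hash, and DemoParam is an invented mini-param:

#include <cstdint>
#include <functional>
#include <string>

thread_local std::string g_buf;  // stand-in for the fixed-size g_hash_buf

template <typename T>
void add_field(const char* name, const T& v) {
  // Label + raw value bytes, mirroring add_param_to_buf(name, value).
  g_buf.append(name);
  g_buf.append(reinterpret_cast<const char*>(&v), sizeof(T));
}

struct DemoParam { int rotaryCoeff = 2; };  // invented stand-in param

uint64_t hash_param(const DemoParam& p) {
  g_buf.clear();  // like resetting g_hash_offset in compute_hash
  add_field("rotaryCoeff", p.rotaryCoeff);
  return std::hash<std::string>{}(g_buf);
}

int main() { return hash_param(DemoParam{}) == 0; }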
-==============================================================================*/
-
-#pragma once
-
-#include
-#include
-#include
-
-#include
-#include
-#include
-
-#include "atb/atb_infer.h"
-#include "operation_cache_compute.h"
-#include "utils.h"
-
-namespace atb {
-
-template <typename ParamType>
-class OpParamCache {
- public:
-  static OpParamCache& getInstance();
-
-  atb::Operation* get_operation(const ParamType& param,
-                                const std::string& name);
-  atb::Operation* get_operation(uint64_t hash_id);
-  void save_operation(uint64_t hash_id, atb::Operation* op);
-
- private:
-  OpParamCache();
-
-  OpParamCache(const OpParamCache&) = delete;
-  OpParamCache& operator=(const OpParamCache&) = delete;
-
-  ~OpParamCache();
-
-  std::unordered_map<uint64_t, atb::Operation*> op_map_;
-  mutable std::mutex mutex_;
-};
-
-template <typename ParamType>
-atb::Operation* create_atb_operation(const ParamType& param,
-                                     const std::string& name) {
-  atb::Operation* op = nullptr;
-  atb::CreateOperation(param, &op);
-  CHECK(op != nullptr) << name << " CreateOperation failed!";
-  return op;
-}
-
-template <typename ParamType>
-OpParamCache<ParamType>& OpParamCache<ParamType>::getInstance() {
-  static OpParamCache instance;
-  return instance;
-}
-
-template <typename ParamType>
-atb::Operation* OpParamCache<ParamType>::get_operation(
-    const ParamType& param,
-    const std::string& name) {
-  const auto is_capturing =
-      static_cast<bool>(c10_npu::currentStreamCaptureStatusMayInitCtx());
-  if (is_capturing) {
-    return create_atb_operation(param, name);
-  } else {
-    uint64_t hashValue = compute_hash(param);
-    {
-      std::lock_guard<std::mutex> lock(mutex_);
-      auto op_cache = op_map_.find(hashValue);
-      if (op_cache != op_map_.end()) {
-        return op_cache->second;
-      }
-      atb::Operation* op = create_atb_operation(param, name);
-      op_map_[hashValue] = op;
-      return op;
-    }
-  }
-}
-
-template <typename ParamType>
-atb::Operation* OpParamCache<ParamType>::get_operation(uint64_t hash_id) {
-  std::lock_guard<std::mutex> lock(mutex_);
-  auto op_cache = op_map_.find(hash_id);
-  if (op_cache != op_map_.end()) {
-    return op_cache->second;
-  }
-
-  atb::Operation* op = nullptr;
-  return op;
-}
-
-template <typename ParamType>
-void OpParamCache<ParamType>::save_operation(uint64_t hash_id,
-                                             atb::Operation* op) {
-  std::lock_guard<std::mutex> lock(mutex_);
-  op_map_[hash_id] = op;
-  return;
-}
-
-template <typename ParamType>
-OpParamCache<ParamType>::OpParamCache() {
-  atb::utils::ContextManager::get_instance();
-}
-
-template <typename ParamType>
-OpParamCache<ParamType>::~OpParamCache() {
-  std::lock_guard<std::mutex> lock(mutex_);
-  for (auto& op_item : op_map_) {
-    DestroyOperation(op_item.second);
-  }
-}
-
-} // namespace atb
diff --git a/xllm/core/kernels/npu/custom_functions_npu/utils.cpp b/xllm/core/kernels/npu/custom_functions_npu/utils.cpp
deleted file mode 100644
index 2034ac88b..000000000
--- a/xllm/core/kernels/npu/custom_functions_npu/utils.cpp
+++ /dev/null
@@ -1,81 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
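The OpParamCache removed in operation_create.h above combines a Meyers singleton, a mutex-guarded hash-to-operation map, and a capture-aware lookup that never reuses a cached handle while a graph capture is in flight. A standard-C++ sketch of that policy; plain ints stand in for atb::Operation* and is_capturing() is a made-up stub for the NPU capture query:

#include <cstdint>
#include <mutex>
#include <unordered_map>

bool is_capturing() { return false; }  // stand-in for the NPU runtime query

class OpCache {
 public:
  static OpCache& instance() {
    static OpCache c;  // initialized once; thread-safe since C++11
    return c;
  }
  int get_or_create(uint64_t hash) {
    if (is_capturing()) return make_handle();  // fresh handle during capture
    std::lock_guard<std::mutex> lk(mu_);
    auto [it, inserted] = map_.try_emplace(hash, 0);
    if (inserted) it->second = make_handle();  // create on first miss
    return it->second;
  }
 private:
  static int make_handle() { static int next = 0; return ++next; }
  std::mutex mu_;
  std::unordered_map<uint64_t, int> map_;
};

int main() {
  auto& c = OpCache::instance();
  return c.get_or_create(42) == c.get_or_create(42) ? 0 : 1;  // cache hit
}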
-==============================================================================*/
-
-#include "utils.h"
-
-#include
-
-namespace atb {
-namespace utils {
-
-ContextManager& ContextManager::get_instance() {
-  static ContextManager instance;
-  return instance;
-}
-
-ContextManager::ContextManager() : atb_context_(nullptr) {}
-
-ContextManager::~ContextManager() {
-  if (atb_context_) {
-    auto status = atb::DestroyContext(atb_context_);
-    CHECK_EQ(status, 0) << "Destroy context failed!";
-    atb_context_ = nullptr;
-  }
-}
-
-atb::Context* ContextManager::get_context(aclrtStream stream) {
-  std::call_once(create_flag_, [this]() {
-    auto status = atb::CreateContext(&atb_context_);
-    CHECK_EQ(status, 0) << "Create context failed!";
-  });
-
-  atb_context_->SetExecuteStream(stream);
-  return atb_context_;
-}
-
-atb::Context* get_context(aclrtStream stream) {
-  return ContextManager::get_instance().get_context(stream);
-}
-
-aclDataType convert_to_acl_data_type(const at::ScalarType& data_type) {
-  auto acl_dtype =
-      kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(data_type)];
-  CHECK_NE(acl_dtype, ACL_DT_UNDEFINED)
-      << std::string(c10::toString(data_type)) << " has not been supported";
-  return acl_dtype;
-}
-
-at::Tensor format_trans(const at::Tensor& at_tensor) {
-  if (torch_npu::utils::is_npu(at_tensor)) {
-    return at_npu::native::npu_format_cast(at_tensor, ACL_FORMAT_ND);
-  }
-  return at_tensor;
-}
-
-bool is_base_format(aclFormat& format) {
-  return (format == ACL_FORMAT_NCHW) || (format == ACL_FORMAT_ND) ||
-         (format == ACL_FORMAT_NHWC) || (format == ACL_FORMAT_NCDHW);
-}
-
-aclFormat get_format_for_atb(const at::Tensor& at_tensor) {
-  if (torch_npu::utils::is_npu(at_tensor)) {
-    aclFormat format =
-        static_cast<aclFormat>(at_npu::native::get_npu_format(at_tensor));
-    return is_base_format(format) ? ACL_FORMAT_ND : format;
-  }
-  return ACL_FORMAT_ND;
-}
-} // namespace utils
-} // namespace atb
diff --git a/xllm/core/kernels/npu/custom_functions_npu/utils.h b/xllm/core/kernels/npu/custom_functions_npu/utils.h
deleted file mode 100644
index 99cce00fe..000000000
--- a/xllm/core/kernels/npu/custom_functions_npu/utils.h
+++ /dev/null
@@ -1,97 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
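ContextManager::get_context above creates the ATB context exactly once via std::call_once, then rebinds it to the caller's stream on every request. A self-contained sketch of that lazy-init idiom; Ctx and the void* stream are simplified stand-ins for atb::Context and aclrtStream:

#include <mutex>

struct Ctx { void* stream = nullptr; };  // stand-in for atb::Context

class CtxManager {
 public:
  Ctx* get(void* stream) {
    // The expensive handle is created exactly once across all threads.
    std::call_once(once_, [this] { ctx_ = new Ctx(); });
    ctx_->stream = stream;  // per-call rebinding, like SetExecuteStream()
    return ctx_;
  }
 private:
  std::once_flag once_;
  Ctx* ctx_ = nullptr;
};

int main() {
  CtxManager m;
  int fake_stream = 0;
  return m.get(&fake_stream) == m.get(&fake_stream) ? 0 : 1;  // same handle
}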
-==============================================================================*/
-
-#pragma once
-
-#include
-#include
-#include
-#include
-
-#include "atb/atb_infer.h"
-
-namespace atb {
-namespace utils {
-
-class ContextManager {
- public:
-  static ContextManager& get_instance();
-  atb::Context* get_context(aclrtStream stream);
-  ~ContextManager();
-
-  ContextManager(const ContextManager&) = delete;
-  ContextManager& operator=(const ContextManager&) = delete;
-
- private:
-  ContextManager();
-  std::once_flag create_flag_;
-  atb::Context* atb_context_;
-};
-
-atb::Context* get_context(aclrtStream stream);
-
-#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \
-  _(at::ScalarType::Byte, ACL_UINT8) \
-  _(at::ScalarType::Char, ACL_INT8) \
-  _(at::ScalarType::Short, ACL_INT16) \
-  _(at::ScalarType::Int, ACL_INT32) \
-  _(at::ScalarType::Long, ACL_INT64) \
-  _(at::ScalarType::Half, ACL_FLOAT16) \
-  _(at::ScalarType::Float, ACL_FLOAT) \
-  _(at::ScalarType::Double, ACL_DOUBLE) \
-  _(at::ScalarType::ComplexHalf, ACL_COMPLEX32) \
-  _(at::ScalarType::ComplexFloat, ACL_COMPLEX64) \
-  _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \
-  _(at::ScalarType::Bool, ACL_BOOL) \
-  _(at::ScalarType::QInt8, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::QInt32, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::BFloat16, ACL_BF16) \
-  _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::Bits1x8, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::Bits2x4, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::Bits4x2, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::Bits8, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::Bits16, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::Float8_e5m2, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::Float8_e4m3fn, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \
-  _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED)
-
-constexpr aclDataType kATenScalarTypeToAclDataTypeTable
-    [static_cast<int64_t>(at::ScalarType::NumOptions) + 1] = {
-#define DEFINE_ENUM(_1, n) n,
-        AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM)
-#undef DEFINE_ENUM
-};
-
-aclDataType convert_to_acl_data_type(const at::ScalarType& data_type);
-at::Tensor format_trans(const at::Tensor& at_tensor);
-aclFormat get_format_for_atb(const at::Tensor& at_tensor);
-
-template <typename MapType>
-inline int get_op_mode(const MapType& mode_map,
-                       c10::optional<c10::string_view> mode_opt,
-                       c10::string_view default_mode,
-                       const char* mode_name) {
-  c10::string_view mode_str = mode_opt.value_or(default_mode);
-  auto it = mode_map.find(mode_str);
-  CHECK(it != mode_map.end())
-      << "Unsupported " << mode_name << " value: '" << mode_str << "'";
-  return it->second;
-}
-} // namespace utils
-} // namespace atb
diff --git a/xllm/core/kernels/npu/ops_npu/npu_ops.h b/xllm/core/kernels/npu/ops_npu/npu_ops.h
deleted file mode 100644
index 4850b16d8..000000000
--- a/xllm/core/kernels/npu/ops_npu/npu_ops.h
+++ /dev/null
@@ -1,49 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
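The utils.h table above is generated with an X-macro: one pair list expands into enum-indexed array entries, so the source enum's numeric value doubles as the array index. A tiny compilable sketch of that trick with invented values; -1 plays the role of ACL_DT_UNDEFINED:

#define PAIR_LIST(_) \
  _(0, 100)          \
  _(1, 101)          \
  _(2, -1)

constexpr int kTable[] = {
#define ENTRY(idx, val) val,  // each pair becomes one array element
    PAIR_LIST(ENTRY)
#undef ENTRY
};
static_assert(kTable[1] == 101, "enum value doubles as the table index");

int main() { return kTable[2] == -1 ? 0 : 1; }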
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#pragma once - -#include "kernels/npu/custom_functions_npu/atb_common.h" - -using namespace std; - -namespace atb { - -void _npu_paged_attention(const at::Tensor& query, - const at::Tensor& key_cache, - const at::Tensor& value_cache, - int64_t num_kv_heads, - int64_t num_heads, - double scale_value, - const at::Tensor& block_table, - const at::Tensor& context_lens, - at::Tensor& out); - -void _npu_reshape_and_cache(const at::Tensor& key, - const at::Tensor& value, - at::Tensor& key_cache, - at::Tensor& value_cache, - const at::Tensor& slot_indices); - -void _npu_flash_attention(const at::Tensor& query, - const at::Tensor& key, - const at::Tensor& value, - const at::Tensor& mask, - const at::Tensor& seq_len, - const double scale_value, - const int64_t num_heads, - const int64_t num_kv_heads, - at::Tensor& out); - -} // namespace atb diff --git a/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp b/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp deleted file mode 100644 index 164896fe6..000000000 --- a/xllm/core/kernels/npu/ops_npu/paged_attention_atb.cpp +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2025 The xLLM Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - https://github.com/jd-opensource/xllm/blob/main/LICENSE - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
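The npu_ops.h declarations above fix the decode entry point's tensor contract. A libtorch-only sketch of one plausible reading of those shapes (all sizes invented; the paged layout [num_blocks, block_size, kv_heads, head_size] is inferred from k_cache.size(-2) being used as kv_heads in attention.cpp, so treat it as an assumption); the actual ATB call is left commented out since it needs the NPU runtime:

#include <torch/torch.h>
#include <iostream>

int main() {
  auto q = torch::randn({3, 8, 128});               // [num_tokens, H, D]
  auto k_cache = torch::randn({16, 128, 2, 128});   // [blocks, block, kvH, D]
  auto v_cache = torch::randn({16, 128, 2, 128});
  auto block_table = torch::zeros({3, 4}, torch::kInt32);   // [batch, blocks]
  auto context_lens = torch::full({3}, 17, torch::kInt32);  // [batch]
  auto out = torch::empty_like(q);
  // atb::_npu_paged_attention(q, k_cache, v_cache, /*num_kv_heads=*/2,
  //                           /*num_heads=*/8, /*scale=*/0.0883,
  //                           block_table, context_lens, out);
  std::cout << out.sizes() << "\n";  // matches q: [3, 8, 128]
}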
-==============================================================================*/
-#include
-
-#include "kernels/npu/custom_functions_npu/atb_common.h"
-
-namespace atb {
-void _npu_paged_attention(const at::Tensor& query,
-                          const at::Tensor& key_cache,
-                          const at::Tensor& value_cache,
-                          int64_t num_kv_heads,
-                          int64_t num_heads,
-                          double scale_value,
-                          const at::Tensor& block_table,
-                          const at::Tensor& context_lens,
-                          at::Tensor& out) {
-  const c10::OptionalDeviceGuard device_guard(device_of(query));
-  OpParamCache<atb::infer::PagedAttentionParam>& pagedAttentionParamCache =
-      OpParamCache<atb::infer::PagedAttentionParam>::getInstance();
-  atb::infer::PagedAttentionParam pagedparam;
-  pagedparam.headNum = num_heads;
-  pagedparam.qkScale = scale_value;
-  pagedparam.kvHeadNum = num_kv_heads;
-  pagedparam.maskType = atb::infer::PagedAttentionParam::UNDEFINED;
-  pagedparam.batchRunStatusEnable = false;
-  pagedparam.quantType = atb::infer::PagedAttentionParam::TYPE_QUANT_UNDEFINED;
-  pagedparam.outDataType = ACL_DT_UNDEFINED;
-  pagedparam.hasQuantOffset = false;
-  pagedparam.compressType =
-      atb::infer::PagedAttentionParam::COMPRESS_TYPE_UNDEFINED;
-  pagedparam.calcType = atb::infer::PagedAttentionParam::CALC_TYPE_UNDEFINED;
-  pagedparam.scaleType = atb::infer::PagedAttentionParam::SCALE_TYPE_TOR;
-  pagedparam.inputLayout = atb::infer::TYPE_BSND;
-  pagedparam.mlaVHeadSize = 0;
-
-  ParamSetter paramsetter;
-  paramsetter.Input(query, true)
-      .Input(key_cache)
-      .Input(value_cache)
-      .Input(block_table, true)
-      .Input(context_lens, true)
-      .Output(out);
-  auto opPaged = pagedAttentionParamCache.get_operation(
-      pagedparam, "PagedAttentionOperation");
-  run_atb_cmd(opPaged, paramsetter, "PagedAttentionOperation");
-
-  return;
-}
-
-} // namespace atb
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp b/xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp
deleted file mode 100644
index b08e21f8e..000000000
--- a/xllm/core/kernels/npu/ops_npu/reshape_and_cach_atb.cpp
+++ /dev/null
@@ -1,61 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
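Each op wrapper above opens with the same RAII idiom: an OptionalDeviceGuard pinned to the input tensor's device for the duration of the call. A minimal libtorch sketch of that pattern (the wrapper body is elided because the real one dispatches to ATB):

#include <ATen/DeviceGuard.h>
#include <torch/torch.h>

void op_like_wrapper(const at::Tensor& query) {
  // Same idiom as _npu_paged_attention: make query's device current;
  // the previous device is restored automatically on scope exit.
  const c10::OptionalDeviceGuard guard(at::device_of(query));
  // ... build params and launch the operation ...
}

int main() { op_like_wrapper(torch::randn({2, 2})); }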
-==============================================================================*/
-
-#include
-
-#include "kernels/npu/custom_functions_npu/atb_common.h"
-
-using namespace std;
-namespace atb {
-void _npu_reshape_and_cache(const at::Tensor& key,
-                            const at::Tensor& value,
-                            at::Tensor& key_cache,
-                            at::Tensor& value_cache,
-                            const at::Tensor& slot_indices) {
-  const c10::OptionalDeviceGuard device_guard(device_of(key));
-  OpParamCache<atb::infer::ReshapeAndCacheParam>& reshapeAndCacheParamCache =
-      OpParamCache<atb::infer::ReshapeAndCacheParam>::getInstance();
-  atb::infer::ReshapeAndCacheParam reshapeparam;
-  reshapeparam.compressType =
-      atb::infer::ReshapeAndCacheParam::COMPRESS_TYPE_UNDEFINED;
-
-  auto key_cache_format = at_npu::native::get_npu_format(key_cache);
-  auto value_cache_format = at_npu::native::get_npu_format(value_cache);
-  bool is_key_cache_nz = (key_cache_format == ACL_FORMAT_FRACTAL_NZ);
-  bool is_value_cache_nz = (value_cache_format == ACL_FORMAT_FRACTAL_NZ);
-
-  if (is_key_cache_nz && is_value_cache_nz) {
-    reshapeparam.kvCacheCfg =
-        atb::infer::ReshapeAndCacheParam::K_CACHE_V_CACHE_NZ;
-  } else {
-    reshapeparam.kvCacheCfg = atb::infer::ReshapeAndCacheParam::K_CACHE_V_CACHE;
-  }
-
-  ParamSetter parametter;
-  parametter.Input(key, true)
-      .Input(value, true)
-      .Input(key_cache)
-      .Input(value_cache)
-      .Input(slot_indices, true)
-      .Output(key_cache)
-      .Output(value_cache);
-  auto opReshape = reshapeAndCacheParamCache.get_operation(
-      reshapeparam, "ReshapeCacheOperation");
-  run_atb_cmd(opReshape, parametter, "ReshapeCacheOperation");
-
-  return;
-}
-
-} // namespace atb
diff --git a/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp b/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp
deleted file mode 100644
index b43e4ff39..000000000
--- a/xllm/core/kernels/npu/ops_npu/self_attention_atb.cpp
+++ /dev/null
@@ -1,77 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
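The reshape-and-cache wrapper above picks its cache config from the NPU storage format: only when both caches are in the blocked NZ layout does it take the NZ path, otherwise it falls back to the plain layout. The decision in isolation, with stand-in enums for the ATB ones:

enum class Layout { ND, NZ };  // ND = plain, NZ = blocked (FRACTAL_NZ)
enum class KvCfg { K_CACHE_V_CACHE, K_CACHE_V_CACHE_NZ };

KvCfg pick_cfg(Layout k, Layout v) {
  // A mixed pair (one NZ, one ND) falls back to the plain config.
  return (k == Layout::NZ && v == Layout::NZ) ? KvCfg::K_CACHE_V_CACHE_NZ
                                              : KvCfg::K_CACHE_V_CACHE;
}

int main() {
  return pick_cfg(Layout::NZ, Layout::ND) == KvCfg::K_CACHE_V_CACHE ? 0 : 1;
}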
-==============================================================================*/
-
-#include
-
-#include "kernels/npu/custom_functions_npu/atb_common.h"
-
-using namespace std;
-namespace atb {
-void _npu_flash_attention(const at::Tensor& query,
-                          const at::Tensor& key,
-                          const at::Tensor& value,
-                          const at::Tensor& mask,
-                          const at::Tensor& seq_len,
-                          const double scale_value,
-                          const int64_t num_heads,
-                          const int64_t num_kv_heads,
-                          at::Tensor& out) {
-  const c10::OptionalDeviceGuard device_guard(device_of(query));
-  OpParamCache<atb::infer::SelfAttentionParam>& selfAttentionParamCache =
-      OpParamCache<atb::infer::SelfAttentionParam>::getInstance();
-  atb::infer::SelfAttentionParam selfattentionparam;
-
-  selfattentionparam.calcType = atb::infer::SelfAttentionParam::PA_ENCODER;
-  selfattentionparam.kernelType =
-      atb::infer::SelfAttentionParam::KERNELTYPE_DEFAULT;
-  selfattentionparam.clampType =
-      atb::infer::SelfAttentionParam::CLAMP_TYPE_UNDEFINED;
-  selfattentionparam.maskType = atb::infer::SelfAttentionParam::MASK_TYPE_NORM;
-  selfattentionparam.kvcacheCfg =
-      atb::infer::SelfAttentionParam::K_CACHE_V_CACHE;
-  selfattentionparam.scaleType = atb::infer::SelfAttentionParam::SCALE_TYPE_TOR;
-  selfattentionparam.quantType =
-      atb::infer::SelfAttentionParam::TYPE_QUANT_UNDEFINED;
-  selfattentionparam.cacheType =
-      atb::infer::SelfAttentionParam::CACHE_TYPE_NORM;
-  selfattentionparam.outDataType = ACL_DT_UNDEFINED;
-  selfattentionparam.headNum = num_heads;
-  selfattentionparam.kvHeadNum = num_kv_heads;
-  selfattentionparam.qScale = 1;
-  selfattentionparam.qkScale = scale_value;
-  selfattentionparam.batchRunStatusEnable = false;
-  selfattentionparam.isTriuMask = 0;
-  selfattentionparam.clampMin = 0;
-  selfattentionparam.clampMax = 0;
-  selfattentionparam.inputLayout = atb::infer::TYPE_BSND;
-  selfattentionparam.mlaVHeadSize = 0;
-  selfattentionparam.windowSize = 0;
-
-  ParamSetter parametter;
-  parametter.Input(query, true)
-      .Input(key, true)
-      .Input(value, true)
-      .Input(mask)
-      .Input(seq_len, true)
-      .Output(out);
-
-  auto opSelfattention = selfAttentionParamCache.get_operation(
-      selfattentionparam, "SelfAttentionOperation");
-  run_atb_cmd(opSelfattention, parametter, "SelfAttentionOperation");
-
-  return;
-}
-
-} // namespace atb
\ No newline at end of file
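Taken together, the three ATB wrappers deleted in this patch cover one serving step: flash attention over the prompt, scattering K/V into the paged cache, then paged attention per decoded token. A stub-level sketch of that loop; the functions are stand-ins for batch_prefill / reshape_paged_cache / batch_decode, with bodies elided because the real ones dispatch to ATB:

#include <cstdio>

void prefill()        { std::puts("flash attention over the prompt"); }
void write_kv_cache() { std::puts("scatter K/V into paged cache slots"); }
void decode_step()    { std::puts("paged attention over cached blocks"); }

int main() {
  prefill();
  write_kv_cache();                    // store the prompt's K/V
  for (int step = 0; step < 3; ++step) {  // autoregressive decode
    decode_step();
    write_kv_cache();                  // append the new token's K/V
  }
}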