diff --git a/.gitmodules b/.gitmodules
index ceffdb5eb..620fcce1d 100755
--- a/.gitmodules
+++ b/.gitmodules
@@ -27,4 +27,7 @@
 	url = https://gitcode.com/xLLM-AI/spdlog.git
 [submodule "third_party/Mooncake"]
 	path = third_party/Mooncake
-	url = https://gitcode.com/xLLM-AI/Mooncake.git
\ No newline at end of file
+	url = https://gitcode.com/xLLM-AI/Mooncake.git
+[submodule "third_party/torch_npu_ops"]
+	path = third_party/torch_npu_ops
+	url = https://gitcode.com/xLLM-AI/torch_npu_ops.git
diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt
index 7fb8fa937..de45daa90 100644
--- a/third_party/CMakeLists.txt
+++ b/third_party/CMakeLists.txt
@@ -12,6 +12,7 @@ add_subdirectory(etcd_cpp_apiv3)
 
 if(USE_NPU)
   add_subdirectory(spdlog)
   add_subdirectory(hccl_transfer/hccl_transfer)
+  add_subdirectory(torch_npu_ops)
 endif()
 add_subdirectory(Mooncake)
diff --git a/third_party/torch_npu_ops b/third_party/torch_npu_ops
new file mode 160000
index 000000000..2bc8f5784
--- /dev/null
+++ b/third_party/torch_npu_ops
@@ -0,0 +1 @@
+Subproject commit 2bc8f578424948145f65b3119c7af0345d3b18fd
diff --git a/xllm/core/kernels/CMakeLists.txt b/xllm/core/kernels/CMakeLists.txt
index 3bba0e16b..fbeb3c943 100644
--- a/xllm/core/kernels/CMakeLists.txt
+++ b/xllm/core/kernels/CMakeLists.txt
@@ -22,7 +22,7 @@ cc_library(
     ops_api.cpp
   DEPS
     torch
-    $<$<BOOL:${USE_NPU}>:npu_kernels>
+    $<$<BOOL:${USE_NPU}>:torch_npu_kernels>
     $<$<BOOL:${USE_MLU}>:mlu_kernels>
     $<$<BOOL:${USE_CUDA}>:cuda_kernels>
 )
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/CMakeLists.txt b/xllm/core/kernels/npu/CMakeLists.txt
index 412f7f188..3585da704 100644
--- a/xllm/core/kernels/npu/CMakeLists.txt
+++ b/xllm/core/kernels/npu/CMakeLists.txt
@@ -1,14 +1,3 @@
 include(cc_library)
 
-add_subdirectory(xllm_ops)
-
-cc_library(
-  NAME
-    npu_kernels
-  HDRS
-    linear.h
-    split.h
-    rope.h
-  DEPS
-    # spdlog::spdlog
-)
+add_subdirectory(xllm_ops)
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/rms_norm.h b/xllm/core/kernels/npu/active.cpp
similarity index 59%
rename from xllm/core/kernels/npu/rms_norm.h
rename to xllm/core/kernels/npu/active.cpp
index ed7f8d047..49a9e94f4 100644
--- a/xllm/core/kernels/npu/rms_norm.h
+++ b/xllm/core/kernels/npu/active.cpp
@@ -13,20 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#pragma once
-#include "impl/npu_rms_norm_impl.h"
+#include
+#include
 
-namespace xllm {
-namespace kernel {
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
 
-class RmsNorm : public torch::nn::ModuleHolder<NpuRmsNormImpl> {
- public:
-  using torch::nn::ModuleHolder<NpuRmsNormImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = NpuRmsNormImpl;
+namespace xllm::kernel::npu {
 
-  RmsNorm(const ModelContext& context)
-      : ModuleHolder(std::make_shared<NpuRmsNormImpl>(context)) {}
-};
-
-}  // namespace kernel
-}  // namespace xllm
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode) {
+  if (act_mode != "silu" && act_mode != "swiglu") {
+    LOG(FATAL) << "Only swiglu activation is supported in NPU active";
+  }
+  return at_npu::native::custom_ops::npu_swiglu(input);
+}
+}  // namespace xllm::kernel::npu
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/attention.cpp b/xllm/core/kernels/npu/attention.cpp
new file mode 100644
index 000000000..d5f4b80ba
--- /dev/null
+++ b/xllm/core/kernels/npu/attention.cpp
@@ -0,0 +1,65 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
+namespace xllm::kernel::npu {
+
+void reshape_paged_cache(torch::Tensor& key,
+                         std::optional<torch::Tensor>& value,
+                         torch::Tensor& k_cache,
+                         std::optional<torch::Tensor>& v_cache,
+                         const torch::Tensor& slot_mapping) {
+  atb::npu_reshape_and_cache(
+      key, value.value(), k_cache, v_cache.value(), slot_mapping);
+}
+
+void batch_prefill(const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   const torch::Tensor& mask,
+                   const torch::Tensor& seq_len,
+                   float scale,
+                   torch::Tensor& output) {
+  int64_t num_heads = query.size(-2);
+  int64_t num_kv_heads = key.size(-2);
+  atb::npu_flash_attention(
+      query, key, value, mask, seq_len, scale, num_heads, num_kv_heads, output);
+}
+
+void batch_decode(const torch::Tensor& query,
+                  const torch::Tensor& k_cache,
+                  const torch::Tensor& v_cache,
+                  float scale,
+                  const torch::Tensor& block_table,
+                  const torch::Tensor& seq_lens,
+                  torch::Tensor& output) {
+  int64_t head_size = query.size(-1);
+  int64_t num_heads = query.size(-2);
+  int64_t num_kv_heads = k_cache.size(-2);
+  auto q = query.view({-1, num_heads, head_size});
+  auto o = output.view({-1, num_heads, head_size});
+  atb::npu_paged_attention(q,
+                           k_cache,
+                           v_cache,
+                           num_kv_heads,
+                           num_heads,
+                           scale,
+                           block_table,
+                           seq_lens,
+                           o);
+}
+
+}  // namespace xllm::kernel::npu
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/fused_layernorm.cpp b/xllm/core/kernels/npu/fused_layernorm.cpp
new file mode 100644
index 000000000..6c222fbf1
--- /dev/null
+++ b/xllm/core/kernels/npu/fused_layernorm.cpp
@@ -0,0 +1,44 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include
+#include
+
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
+
+namespace xllm::kernel::npu {
+
+torch::Tensor rms_norm(const torch::Tensor& input,
+                       const torch::Tensor& weight,
+                       double eps,
+                       const std::string& mode) {
+  if (mode != "rmsnorm") {
+    LOG(FATAL) << "Only rmsnorm mode is supported in NPU rms_norm";
+  }
+  std::tuple<torch::Tensor, torch::Tensor> result =
+      at_npu::native::custom_ops::npu_rms_norm(input, weight, eps);
+  auto normalized_input = std::get<0>(result);
+  return normalized_input;
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> add_rms_norm(
+    const torch::Tensor& x1,
+    const torch::Tensor& x2,
+    const torch::Tensor& gamma,
+    double epsilon) {
+  return at_npu::native::custom_ops::npu_add_rms_norm(x1, x2, gamma, epsilon);
+}
+
+}  // namespace xllm::kernel::npu
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/linear.h b/xllm/core/kernels/npu/matmul.cpp
similarity index 61%
rename from xllm/core/kernels/npu/linear.h
rename to xllm/core/kernels/npu/matmul.cpp
index 0834c014b..0b80c9dd4 100644
--- a/xllm/core/kernels/npu/linear.h
+++ b/xllm/core/kernels/npu/matmul.cpp
@@ -13,18 +13,19 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#pragma once
-#include "impl/npu_linear_impl.h"
-
-namespace xllm::kernel {
-
-class Linear : public torch::nn::ModuleHolder<NpuLinearImpl> {
- public:
-  using torch::nn::ModuleHolder<NpuLinearImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = NpuLinearImpl;
-
-  Linear(const ModelContext& context)
-      : ModuleHolder(std::make_shared<NpuLinearImpl>(context)) {}
-};
-
-}  // namespace xllm::kernel
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
+
+namespace xllm::kernel::npu {
+
+torch::Tensor matmul(const torch::Tensor& a,
+                     const torch::Tensor& b,
+                     const std::optional<torch::Tensor>& bias) {
+  if (!bias.has_value()) {
+    return torch::nn::functional::linear(a, b);
+  } else {
+    return torch::nn::functional::linear(a, b, bias.value());
+  }
+}
+
+}  // namespace xllm::kernel::npu
diff --git a/xllm/core/kernels/npu/npu_ops_api.h b/xllm/core/kernels/npu/npu_ops_api.h
new file mode 100644
index 000000000..f59f39a07
--- /dev/null
+++ b/xllm/core/kernels/npu/npu_ops_api.h
@@ -0,0 +1,69 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+#include
+
+#include
+#include
+
+#include "custom_functions_npu/atb_common.h"
+
+namespace xllm::kernel::npu {
+
+void reshape_paged_cache(torch::Tensor& key,
+                         std::optional<torch::Tensor>& value,
+                         torch::Tensor& k_cache,
+                         std::optional<torch::Tensor>& v_cache,
+                         const torch::Tensor& slot_mapping);
+
+void batch_prefill(const torch::Tensor& query,
+                   const torch::Tensor& key,
+                   const torch::Tensor& value,
+                   const torch::Tensor& mask,
+                   const torch::Tensor& seq_len,
+                   float scale,
+                   torch::Tensor& output);
+
+void batch_decode(const torch::Tensor& query,
+                  const torch::Tensor& k_cache,
+                  const torch::Tensor& v_cache,
+                  float scale,
+                  const torch::Tensor& block_table,
+                  const torch::Tensor& seq_lens,
+                  torch::Tensor& output);
+
+torch::Tensor matmul(const torch::Tensor& a,
+                     const torch::Tensor& b,
+                     const std::optional<torch::Tensor>& bias);
+
+torch::Tensor active(const torch::Tensor& input, const std::string& act_mode);
+
+torch::Tensor rms_norm(const torch::Tensor& input,
+                       const torch::Tensor& weight,
+                       double eps,
+                       const std::string& mode);
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> add_rms_norm(
+    const torch::Tensor& x1,
+    const torch::Tensor& x2,
+    const torch::Tensor& gamma,
+    double epsilon);
+
+void apply_rotary(torch::Tensor& q,
+                  torch::Tensor& k,
+                  const torch::Tensor& cos_sin_cache,
+                  const torch::Tensor& positions);
+}  // namespace xllm::kernel::npu
diff --git a/xllm/core/kernels/npu/rope.cpp b/xllm/core/kernels/npu/rope.cpp
new file mode 100644
index 000000000..7bcbbc7c4
--- /dev/null
+++ b/xllm/core/kernels/npu/rope.cpp
@@ -0,0 +1,42 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+
+#include "npu_ops_api.h"
+#include "ops_npu/npu_ops.h"
+
+namespace xllm::kernel::npu {
+
+void apply_rotary(torch::Tensor& q,
+                  torch::Tensor& k,
+                  const torch::Tensor& cos_sin_cache,
+                  const torch::Tensor& positions) {
+  auto cos_sin = cos_sin_cache.index_select(0, positions);
+  int64_t last_dim = cos_sin.size(-1);
+  auto cos_sin_vec = cos_sin.view({-1, 2, last_dim / 2})
+                         .repeat({1, 1, 2})
+                         .chunk(2, /*dim=*/-2);
+  auto cos = cos_sin_vec[0].view({1, -1, 1, last_dim});
+  auto sin = cos_sin_vec[1].view({1, -1, 1, last_dim});
+
+  const int64_t rotary_dim = sin.size(-1);
+  q = q.view({1, q.size(0), -1, rotary_dim});
+  k = k.view({1, k.size(0), -1, rotary_dim});
+
+  at_npu::native::custom_ops::npu_apply_rotary_pos_emb(q, k, cos, sin);
+}
+
+}  // namespace xllm::kernel::npu
\ No newline at end of file
diff --git a/xllm/core/kernels/npu/rope.h b/xllm/core/kernels/npu/rope.h
deleted file mode 100644
index 7a075b0d3..000000000
--- a/xllm/core/kernels/npu/rope.h
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#pragma once
-#include "impl/npu_rope_impl.h"
-
-namespace xllm::kernel {
-
-class Rope : public torch::nn::ModuleHolder<NpuRopeImpl> {
- public:
-  using torch::nn::ModuleHolder<NpuRopeImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = NpuRopeImpl;
-
-  Rope(const ModelContext& context)
-      : ModuleHolder(std::make_shared<NpuRopeImpl>(context)) {}
-};
-
-}  // namespace xllm::kernel
diff --git a/xllm/core/kernels/npu/split.h b/xllm/core/kernels/npu/split.h
deleted file mode 100644
index cda39703e..000000000
--- a/xllm/core/kernels/npu/split.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2025 The xLLM Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    https://github.com/jd-opensource/xllm/blob/main/LICENSE
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#pragma once
-#include "impl/npu_split_impl.h"
-
-namespace xllm::kernel {
-class Split : public torch::nn::ModuleHolder<NpuSplitImpl> {
- public:
-  using torch::nn::ModuleHolder<NpuSplitImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = NpuSplitImpl;
-
-  Split(const ModelContext& context,
-        int32_t splitDim = 2,
-        int32_t splitNum = 3,
-        atb::SVector<int32_t> splitSizes = {})
-      : ModuleHolder(std::make_shared<NpuSplitImpl>(context,
-                                                    splitDim,
-                                                    splitNum,
-                                                    splitSizes)) {}
-};
-
-}  // namespace xllm::kernel
diff --git a/xllm/core/kernels/ops_api.cpp b/xllm/core/kernels/ops_api.cpp
index 07ab88f62..248871a33 100644
--- a/xllm/core/kernels/ops_api.cpp
+++ b/xllm/core/kernels/ops_api.cpp
@@ -17,6 +17,8 @@ limitations under the License.
 
 #if defined(USE_MLU)
 #include "mlu/mlu_ops_api.h"
+#elif defined(USE_NPU)
+#include "npu/npu_ops_api.h"
 #elif defined(USE_CUDA)
 #include "cuda/cuda_ops_api.h"
 #endif
@@ -38,6 +40,9 @@ void apply_rotary(RotaryParams& params) {
                     params.discrete,
                     params.dynamic_ntk,
                     params.max_query_len);
+#elif defined(USE_NPU)
+  npu::apply_rotary(
+      params.q, params.k, params.cos_sin, params.position_ids.value());
 #elif defined(USE_CUDA)
   bool is_neox = !params.interleaved;
 
@@ -65,6 +70,8 @@ void active(ActivationParams& params) {
                   params.expert_size);
 #elif defined(USE_CUDA)
   cuda::act_and_mul(params.output, params.input, params.act_mode);
+#elif defined(USE_NPU)
+  params.output = npu::active(params.input, params.act_mode);
 #else
   LOG(FATAL) << "active not implemented";
 #endif
@@ -78,6 +85,12 @@ void reshape_paged_cache(ReshapePagedCacheParams& params) {
                            params.v_cache,
                            params.slot_mapping,
                            params.direction);
+#elif defined(USE_NPU)
+  npu::reshape_paged_cache(params.key,
+                           params.value,
+                           params.k_cache,
+                           params.v_cache,
+                           params.slot_mapping);
 #elif defined(USE_CUDA)
   cuda::reshape_paged_cache(params.slot_mapping,
                             params.key,
@@ -113,6 +126,14 @@ void batch_prefill(AttentionParams& params) {
                      params.window_size_right,
                      params.compute_dtype,
                      params.return_lse);
+#elif defined(USE_NPU)
+  npu::batch_prefill(params.query,
+                     params.key,
+                     params.value,
+                     params.attn_mask,
+                     params.seq_lens,
+                     params.scale,
+                     params.output);
 #elif defined(USE_CUDA)
   cuda::batch_prefill(params.float_workspace_buffer,
                       params.int_workspace_buffer,
@@ -154,6 +175,14 @@ void batch_decode(AttentionParams& params) {
                     params.scale,
                     params.return_lse,
                     params.kv_cache_quant_bit_size);
+#elif defined(USE_NPU)
+  npu::batch_decode(params.query,
+                    params.k_cache,
+                    params.v_cache.value_or(torch::Tensor()),
+                    params.scale,
+                    params.block_table.value(),
+                    params.seq_lens,
+                    params.output);
 #elif defined(USE_CUDA)
   params.query = params.query.squeeze(1);
   params.output = params.output.squeeze(1);
@@ -202,6 +231,15 @@ void fused_layernorm(FusedLayerNormParams& params) {
   } else {
     cuda::rms_norm(params.output, params.input, params.weight, params.eps);
   }
+#elif defined(USE_NPU)
+  if (params.residual.has_value()) {
+    std::tie(params.output, std::ignore, params.residual_out) =
+        npu::add_rms_norm(
+            params.input, params.residual.value(), params.weight, params.eps);
+  } else {
+    params.output =
+        npu::rms_norm(params.input, params.weight, params.eps, params.mode);
+  }
 #else
   LOG(FATAL) << "fused_layernorm not implemented";
 #endif
@@ -211,6 +249,8 @@ torch::Tensor matmul(MatmulParams& params) {
 #if defined(USE_MLU)
   return mlu::matmul(
       params.a, params.b, params.bias, params.c, params.alpha, params.beta);
+#elif defined(USE_NPU)
+  return npu::matmul(params.a, params.b, params.bias);
 #elif defined(USE_CUDA)
   return cuda::matmul(params.a, params.b, params.bias);
 #else
diff --git a/xllm/core/kernels/param.h b/xllm/core/kernels/param.h
index 2be0e88aa..5cf7d7e89 100644
--- a/xllm/core/kernels/param.h
+++ b/xllm/core/kernels/param.h
@@ -195,6 +195,9 @@ struct AttentionParams {
   float scale;
   // Whether to return log-sum-exp values in output_lse.
   bool return_lse = false;
+  // ========== Torch NPU related parameters ==========
+  torch::Tensor seq_lens;
+  torch::Tensor attn_mask;
 
   // ========== FlashInfer related parameters ==========
   torch::Tensor paged_kv_indptr;
diff --git a/xllm/core/layers/common/dense_mlp.cpp b/xllm/core/layers/common/dense_mlp.cpp
index 57b52429b..1bdfea078 100644
--- a/xllm/core/layers/common/dense_mlp.cpp
+++ b/xllm/core/layers/common/dense_mlp.cpp
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 
 #include "kernels/ops_api.h"
+#include "platform/device.h"
 
 namespace xllm {
 namespace layer {
@@ -90,11 +91,14 @@ torch::Tensor DenseMLPImpl::forward(const torch::Tensor& hidden_states) {
     // For w8a8 quantization, the active operation is fused with the down_proj
     return down_proj_->forward(gate_up);
   } else {
-    int64_t batch_size = gate_up.sizes()[0];
-    auto output = torch::empty(
-        {batch_size,
-         intermediate_size_ / parallel_args_.tp_group_->world_size()},
-        gate_up.options());
+    torch::Tensor output;
+    if (Device::type_str() != "npu") {
+      int64_t batch_size = gate_up.sizes()[0];
+      output = torch::empty(
+          {batch_size,
+           intermediate_size_ / parallel_args_.tp_group_->world_size()},
+          gate_up.options());
+    }
 
     xllm::kernel::ActivationParams activation_params;
     activation_params.input = gate_up;
diff --git a/xllm/core/layers/common/rms_norm.cpp b/xllm/core/layers/common/rms_norm.cpp
index 21addfb74..de2f5f44f 100644
--- a/xllm/core/layers/common/rms_norm.cpp
+++ b/xllm/core/layers/common/rms_norm.cpp
@@ -18,6 +18,7 @@ limitations under the License.
 #include
 
 #include "kernels/ops_api.h"
+#include "platform/device.h"
 
 namespace xllm {
 namespace layer {
@@ -40,7 +41,10 @@ RMSNormImpl::RMSNormImpl(const ModelContext& context)
       context.get_tensor_options()) {}
 
 torch::Tensor RMSNormImpl::forward(torch::Tensor& input) {
-  auto output = torch::empty_like(input);
+  torch::Tensor output;
+  if (Device::type_str() != "npu") {
+    output = torch::empty_like(input);
+  }
   return forward_output(input, output);
 }
 
diff --git a/xllm/core/layers/common/rotary_embedding.cpp b/xllm/core/layers/common/rotary_embedding.cpp
index 00ea4bd04..50fcd2583 100644
--- a/xllm/core/layers/common/rotary_embedding.cpp
+++ b/xllm/core/layers/common/rotary_embedding.cpp
@@ -51,7 +51,7 @@ void RotaryEmbeddingImpl::forward(torch::Tensor& q,
   std::optional<torch::Tensor> position_ids;
   if (is_prompt) {
     discrete = false;
-    if (Device::type_str() == "cuda") {
+    if (Device::type_str() == "cuda" || Device::type_str() == "npu") {
       position_ids = positions;
     }
   } else {