diff --git a/python/sglang/srt/lora/backend/ascend_backend.py b/python/sglang/srt/lora/backend/ascend_backend.py
new file mode 100644
index 00000000000..4278b340e48
--- /dev/null
+++ b/python/sglang/srt/lora/backend/ascend_backend.py
@@ -0,0 +1,287 @@
+from typing import Optional
+
+import torch
+
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_npu
+
+if is_npu():
+    import sgl_kernel_npu  # noqa: F401
+    import torch_npu  # noqa: F401
+
+
+class AscendLoRABackend(BaseLoRABackend):
+    name = "ascend"
+
+    def __init__(
+        self,
+        max_loras_per_batch: int,
+        device: torch.device,
+        **kwargs,
+    ):
+        super().__init__(max_loras_per_batch, device)
+
+    def run_lora_a_sgemm(
+        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
+    ) -> torch.Tensor:
+
+        total_seq_len, _ = x.shape
+        _, weight_out_dim, _ = weights.shape
+
+        output_tensor = torch.zeros(
+            (total_seq_len, weight_out_dim), dtype=x.dtype, device=x.device
+        )
+        torch.ops.npu.sgmv_shrink(
+            x,
+            weights,
+            self.batch_info.weight_indices,
+            self.batch_info.seg_lens,
+            output_tensor,
+            1.0,
+        )
+        scaling = (
+            self.batch_info.scalings.gather(0, self.batch_info.weight_indices)
+            .repeat_interleave(self.batch_info.seg_lens, output_size=total_seq_len)
+            .unsqueeze(-1)
+        )
+        output_tensor *= scaling
+
+        return output_tensor
+
+    def run_lora_b_sgemm(
+        self,
+        x: torch.Tensor,
+        weights: torch.Tensor,
+        base_output: torch.Tensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        total_seq_len, _ = x.shape
+        _, weight_out_dim, _ = weights.shape
+
+        if base_output is None:
+            output_tensor = torch.zeros(
+                (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
+            )
+        else:
+            output_tensor = base_output
+
+        torch.ops.npu.sgmv_expand(
+            x,
+            weights,
+            self.batch_info.weight_indices,
+            self.batch_info.seg_lens,
+            output_tensor,
+            0,
+            weight_out_dim,
+        )
+
+        return output_tensor
+
+    def run_qkv_lora(
+        self,
+        x: torch.Tensor,
+        qkv_lora_a: torch.Tensor,
+        qkv_lora_b: torch.Tensor,
+        output_offset: torch.Tensor,
+        output_offset_cpu: torch.Tensor,
+        max_qkv_out_dim: int,
+        base_output: torch.Tensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        num_slices = 3
+        assert isinstance(qkv_lora_b, torch.Tensor)
+
+        total_seq_len, _ = x.shape
+        _, weight_intermediate_dim, _ = qkv_lora_a.shape
+        _, weight_out_dim, _ = qkv_lora_b.shape
+        max_rank = weight_intermediate_dim // num_slices
+
+        if base_output is None:
+            output_tensor = torch.zeros(
+                (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
+            )
+        else:
+            output_tensor = base_output
+
+        lora_a_output = torch.zeros(
+            total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device
+        )
+        torch.ops.npu.sgmv_shrink(
+            x,
+            qkv_lora_a,
+            self.batch_info.weight_indices,
+            self.batch_info.seg_lens,
+            lora_a_output,
+            1.0,
+        )
+
+        scaling = (
+            self.batch_info.scalings.gather(0, self.batch_info.weight_indices)
+            .repeat_interleave(self.batch_info.seg_lens, output_size=total_seq_len)
+            .unsqueeze(-1)
+        )
+        lora_a_output *= scaling
+
+        for slice_id in range(num_slices):
+            slice_offset = output_offset_cpu[slice_id]
+            slice_offset_next = output_offset_cpu[slice_id + 1]
+            slice_size = slice_offset_next - slice_offset
+            torch.ops.npu.sgmv_expand(
+                lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))],
+                qkv_lora_b[:, slice_offset:slice_offset_next],
+                self.batch_info.weight_indices,
+                self.batch_info.seg_lens,
+                output_tensor,
+                slice_offset,
+                slice_size,
+            )
+
+        return output_tensor
+
+    def run_gate_up_lora(
+        self,
+        x: torch.Tensor,
+        gate_up_lora_a: torch.Tensor,
+        gate_up_lora_b: torch.Tensor,
+        base_output: torch.Tensor = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        num_slices = 2
+        assert isinstance(gate_up_lora_b, torch.Tensor)
+
+        total_seq_len, _ = x.shape
+        _, weight_intermediate_dim, _ = gate_up_lora_a.shape
+        _, weight_out_dim, _ = gate_up_lora_b.shape
+        slice_size = weight_out_dim // num_slices
+        max_rank = weight_intermediate_dim // num_slices
+
+        if base_output is None:
+            output_tensor = torch.zeros(
+                (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
+            )
+        else:
+            output_tensor = base_output
+
+        lora_a_output = torch.zeros(
+            total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device
+        )
+
+        torch.ops.npu.sgmv_shrink(
+            x,
+            gate_up_lora_a,
+            self.batch_info.weight_indices,
+            self.batch_info.seg_lens,
+            lora_a_output,
+            1.0,
+        )
+
+        scaling = (
+            self.batch_info.scalings.gather(0, self.batch_info.weight_indices)
+            .repeat_interleave(self.batch_info.seg_lens, output_size=total_seq_len)
+            .unsqueeze(-1)
+        )
+        lora_a_output *= scaling
+
+        slice_offset = 0
+        for slice_id in range(num_slices):
+            torch.ops.npu.sgmv_expand(
+                lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))],
+                gate_up_lora_b[:, slice_offset : slice_offset + slice_size],
+                self.batch_info.weight_indices,
+                self.batch_info.seg_lens,
+                output_tensor,
+                slice_offset,
+                slice_size,
+            )
+            slice_offset += slice_size
+
+        return output_tensor
+
+    def init_cuda_graph_batch_info(
+        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+    ):
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+            dim=0,
+            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+        )
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        # Use pinned memory to avoid synchronizations during host-to-device transfer
+        weight_indices_tensor = torch.tensor(
+            weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        lora_ranks_tensor = torch.tensor(
+            lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        scalings_tensor = torch.tensor(
+            scalings, dtype=torch.float, pin_memory=True, device="cpu"
+        )
+
+        bs = forward_batch.batch_size
+
+        if batch_info is not None:
+            assert (
+                batch_info.use_cuda_graph
+            ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+            batch_info.bs = forward_batch.batch_size
+            batch_info.num_segments = forward_batch.batch_size
+        else:
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, dtype=torch.int32, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+            batch_info = LoRABatchInfo(
+                bs=forward_batch.batch_size,
+                num_segments=forward_batch.batch_size,
+                max_len=max_len,
+                use_cuda_graph=False,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                weight_indices=torch.empty(
+                    (bs,), dtype=torch.int32, device=self.device
+                ),
+                lora_ranks=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.int32, device=self.device
+                ),
+                scalings=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                ),
+                permutation=None,
+            )
+
+        # Copy to device asynchronously
+        batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+            lora_ranks_tensor, non_blocking=True
+        )
+        batch_info.scalings[: self.max_loras_per_batch].copy_(
+            scalings_tensor, non_blocking=True
+        )
+        batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
+        self.batch_info = batch_info
diff --git a/python/sglang/srt/lora/backend/base_backend.py b/python/sglang/srt/lora/backend/base_backend.py
index 4d241f93168..ad7f886d761 100644
--- a/python/sglang/srt/lora/backend/base_backend.py
+++ b/python/sglang/srt/lora/backend/base_backend.py
@@ -133,23 +133,3 @@ def prepare_lora_batch(
                 internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
         """
         pass
-
-
-def get_backend_from_name(name: str) -> BaseLoRABackend:
-    """
-    Get corresponding backend class from backend's name
-    """
-    if name == "triton":
-        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
-
-        return TritonLoRABackend
-    elif name == "csgmv":
-        from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
-
-        return ChunkedSgmvLoRABackend
-    elif name == "flashinfer":
-        raise ValueError(
-            "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
-        )
-    else:
-        raise ValueError(f"Invalid backend: {name}")
diff --git a/python/sglang/srt/lora/backend/lora_registry.py b/python/sglang/srt/lora/backend/lora_registry.py
new file mode 100644
index 00000000000..c3dd7788861
--- /dev/null
+++ b/python/sglang/srt/lora/backend/lora_registry.py
@@ -0,0 +1,53 @@
+import logging
+
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+
+logger = logging.getLogger(__name__)
+
+LORA_SUPPORTED_BACKENDS = {}
+
+
+def register_lora_backend(name):
+    def decorator(fn):
+        LORA_SUPPORTED_BACKENDS[name] = fn
+        return fn
+
+    return decorator
+
+
+@register_lora_backend("triton")
+def create_triton_backend():
+    from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+    return TritonLoRABackend
+
+
+@register_lora_backend("csgmv")
+def create_triton_csgmv_backend():
+    from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+
+    return ChunkedSgmvLoRABackend
+
+
+@register_lora_backend("ascend")
+def create_ascend_backend():
+    from sglang.srt.lora.backend.ascend_backend import AscendLoRABackend
+
+    return AscendLoRABackend
+
+
+@register_lora_backend("flashinfer")
+def create_flashinfer_backend():
+    raise ValueError(
+        "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
+    )
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name not in LORA_SUPPORTED_BACKENDS:
+        raise ValueError(f"Invalid backend: {name}")
+    lora_backend = LORA_SUPPORTED_BACKENDS[name]()
+    return lora_backend
diff --git a/python/sglang/srt/lora/layers.py b/python/sglang/srt/lora/layers.py
index 4426faccba7..139d97cbca3 100644
--- a/python/sglang/srt/lora/layers.py
+++ b/python/sglang/srt/lora/layers.py
@@ -27,6 +27,8 @@ def __init__(
         self.base_layer: nn.Module = base_layer
         self.set_lora: bool = False
         self.lora_backend: BaseLoRABackend = lora_backend
+        if hasattr(self.base_layer, "weight"):
+            self.weight = self.base_layer.weight
 
     def forward(self, x: torch.Tensor):
         return self.base_layer.forward(x)
@@ -198,6 +200,7 @@ def __init__(
             dtype=torch.int32,
             device=next(self.base_layer.parameters()).device,
         )
+        self.output_offset_cpu = self.output_offset.cpu()
 
         # For computing number of launched blocks
         self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)
@@ -218,6 +221,7 @@ def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor
             qkv_lora_b=self.B_buffer_qkv,
             base_output=base_output,
             output_offset=self.output_offset,
+            output_offset_cpu=self.output_offset_cpu,
             max_qkv_out_dim=self.max_qkv_out_dim,
         )
         return lora_output
diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py
index b1277caca84..f1199304a26 100644
--- a/python/sglang/srt/lora/lora.py
+++ b/python/sglang/srt/lora/lora.py
@@ -27,16 +27,13 @@
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
-from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
-from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+from sglang.srt.lora.backend.lora_registry import LORA_SUPPORTED_BACKENDS
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.utils.hf_transformers_utils import AutoConfig
 
 logger = logging.getLogger(__name__)
 
-SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
-
 
 class LoRALayer(nn.Module):
     def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -161,8 +158,8 @@ def normalize_gate_up_proj(
                 gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
                 if up_name not in weights:
                     weights[up_name] = torch.zeros_like(weights[weight_name])
-                    assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
-                        f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}"
+                    assert self.lora_backend.name in LORA_SUPPORTED_BACKENDS, (
+                        f"LoRA weight initialization currently only supported for LoRA backends: {', '.join(b for b in LORA_SUPPORTED_BACKENDS)}"
                         f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                         f"or consider implementing custom initialization logic for other backends."
                     )
diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py
index 19ff874dc1d..21558bebb2c 100644
--- a/python/sglang/srt/lora/lora_manager.py
+++ b/python/sglang/srt/lora/lora_manager.py
@@ -21,7 +21,8 @@
 import torch
 
 from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+from sglang.srt.lora.backend.lora_registry import get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
@@ -37,9 +38,16 @@
 from sglang.srt.managers.io_struct import LoRAUpdateOutput
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_npu, replace_submodule
 from sglang.srt.utils.hf_transformers_utils import AutoConfig
 
+if is_npu():
+    from torch_npu.contrib import transfer_to_npu  # noqa: F401
+
+    # Re-mock torch.cuda.is_available because transfer_to_npu mocks it to True
+    torch.cuda.is_available = lambda: False
+
+
 logger = logging.getLogger(__name__)
 
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5b9a520b988..f08b6d8b85c 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -129,7 +129,7 @@
     "intel_xpu",
 ]
 
-LORA_BACKEND_CHOICES = ["triton", "csgmv"]
+LORA_BACKEND_CHOICES = ["triton", "csgmv", "ascend"]
 
 DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
 
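
Usage sketch (illustrative, not part of the patch): with the registry in
lora_registry.py, the Ascend backend resolves through the same name-based lookup
as the existing backends. A minimal sketch, assuming an NPU build where is_npu()
is true (so torch_npu and sgl_kernel_npu import) and an example value of 8 for
max_loras_per_batch; at serving time the same selection would come from the new
"ascend" entry in LORA_BACKEND_CHOICES (e.g. --lora-backend ascend).

    import torch
    from sglang.srt.lora.backend.lora_registry import get_backend_from_name

    # Look up the class registered under "ascend" and instantiate it; the
    # constructor arguments mirror AscendLoRABackend.__init__ in this patch.
    backend_cls = get_backend_from_name("ascend")  # -> AscendLoRABackend
    backend = backend_cls(max_loras_per_batch=8, device=torch.device("npu"))
    assert backend.name == "ascend"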