287 changes: 287 additions & 0 deletions python/sglang/srt/lora/backend/ascend_backend.py
@@ -0,0 +1,287 @@
from typing import Optional

import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_npu

if is_npu():
    import sgl_kernel_npu  # noqa: F401
    import torch_npu  # noqa: F401


class AscendLoRABackend(BaseLoRABackend):
    name = "ascend"

    def __init__(
        self,
        max_loras_per_batch: int,
        device: torch.device,
        **kwargs,
    ):
        super().__init__(max_loras_per_batch, device)

    def run_lora_a_sgemm(
        self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
    ) -> torch.Tensor:
        total_seq_len, _ = x.shape
        _, weight_out_dim, _ = weights.shape

        output_tensor = torch.zeros(
            (total_seq_len, weight_out_dim), dtype=x.dtype, device=x.device
        )
        torch.ops.npu.sgmv_shrink(
            x,
            weights,
            self.batch_info.weight_indices,
            self.batch_info.seg_lens,
            output_tensor,
            1.0,
        )
        scaling = (
            self.batch_info.scalings.gather(0, self.batch_info.weight_indices)
            .repeat_interleave(self.batch_info.seg_lens, output_size=total_seq_len)
            .unsqueeze(-1)
        )
        output_tensor *= scaling

        return output_tensor

    def run_lora_b_sgemm(
        self,
        x: torch.Tensor,
        weights: torch.Tensor,
        base_output: torch.Tensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        total_seq_len, _ = x.shape
        _, weight_out_dim, _ = weights.shape

        if base_output is None:
            output_tensor = torch.zeros(
                (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
            )
        else:
            output_tensor = base_output

        torch.ops.npu.sgmv_expand(
            x,
            weights,
            self.batch_info.weight_indices,
            self.batch_info.seg_lens,
            output_tensor,
            0,
            weight_out_dim,
        )

        return output_tensor

    def run_qkv_lora(
        self,
        x: torch.Tensor,
        qkv_lora_a: torch.Tensor,
        qkv_lora_b: torch.Tensor,
        output_offset: torch.Tensor,
        output_offset_cpu: torch.Tensor,
        max_qkv_out_dim: int,
        base_output: torch.Tensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        num_slices = 3
        assert isinstance(qkv_lora_b, torch.Tensor)

        total_seq_len, _ = x.shape
        _, weight_intermediate_dim, _ = qkv_lora_a.shape
        _, weight_out_dim, _ = qkv_lora_b.shape
        max_rank = weight_intermediate_dim // num_slices

        if base_output is None:
            output_tensor = torch.zeros(
                (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
            )
        else:
            output_tensor = base_output

        lora_a_output = torch.zeros(
            total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device
        )
        torch.ops.npu.sgmv_shrink(
            x,
            qkv_lora_a,
            self.batch_info.weight_indices,
            self.batch_info.seg_lens,
            lora_a_output,
            1.0,
        )

        scaling = (
            self.batch_info.scalings.gather(0, self.batch_info.weight_indices)
            .repeat_interleave(self.batch_info.seg_lens, output_size=total_seq_len)
            .unsqueeze(-1)
        )
        lora_a_output *= scaling

        for slice_id in range(num_slices):
            slice_offset = output_offset_cpu[slice_id]
            slice_offset_next = output_offset_cpu[slice_id + 1]
            slice_size = slice_offset_next - slice_offset
            torch.ops.npu.sgmv_expand(
                lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))],
                qkv_lora_b[:, slice_offset:slice_offset_next],
                self.batch_info.weight_indices,
                self.batch_info.seg_lens,
                output_tensor,
                slice_offset,
                slice_size,
            )

        return output_tensor

    def run_gate_up_lora(
        self,
        x: torch.Tensor,
        gate_up_lora_a: torch.Tensor,
        gate_up_lora_b: torch.Tensor,
        base_output: torch.Tensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        num_slices = 2
        assert isinstance(gate_up_lora_b, torch.Tensor)

        total_seq_len, _ = x.shape
        _, weight_intermediate_dim, _ = gate_up_lora_a.shape
        _, weight_out_dim, _ = gate_up_lora_b.shape
        slice_size = weight_out_dim // num_slices
        max_rank = weight_intermediate_dim // num_slices

        if base_output is None:
            output_tensor = torch.zeros(
                (total_seq_len, weight_out_dim), device=x.device, dtype=x.dtype
            )
        else:
            output_tensor = base_output

        lora_a_output = torch.zeros(
            total_seq_len, weight_intermediate_dim, dtype=x.dtype, device=x.device
        )

        torch.ops.npu.sgmv_shrink(
            x,
            gate_up_lora_a,
            self.batch_info.weight_indices,
            self.batch_info.seg_lens,
            lora_a_output,
            1.0,
        )

        scaling = (
            self.batch_info.scalings.gather(0, self.batch_info.weight_indices)
            .repeat_interleave(self.batch_info.seg_lens, output_size=total_seq_len)
            .unsqueeze(-1)
        )
        lora_a_output *= scaling

        slice_offset = 0
        for slice_id in range(num_slices):
            torch.ops.npu.sgmv_expand(
                lora_a_output[:, (max_rank * slice_id) : (max_rank * (slice_id + 1))],
                gate_up_lora_b[:, slice_offset : slice_offset + slice_size],
                self.batch_info.weight_indices,
                self.batch_info.seg_lens,
                output_tensor,
                slice_offset,
                slice_size,
            )
            slice_offset += slice_size

        return output_tensor

    def init_cuda_graph_batch_info(
        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
    ):
        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
        # across batches.
        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
        torch.cumsum(
            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
            dim=0,
            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
        )

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: Optional[LoRABatchInfo] = None,
    ):
        # Use pinned memory to avoid synchronizations during host-to-device transfer
        weight_indices_tensor = torch.tensor(
            weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
        )
        lora_ranks_tensor = torch.tensor(
            lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
        )
        scalings_tensor = torch.tensor(
            scalings, dtype=torch.float, pin_memory=True, device="cpu"
        )

        bs = forward_batch.batch_size

        if batch_info is not None:
            assert (
                batch_info.use_cuda_graph
            ), "batch_info.use_cuda_graph must be True when batch_info is provided"
            batch_info.bs = forward_batch.batch_size
            batch_info.num_segments = forward_batch.batch_size
        else:
            max_len = (
                # Calculate max_len from the CPU copy to avoid D2H transfer.
                max(forward_batch.extend_seq_lens_cpu)
                if forward_batch.forward_mode.is_extend()
                else 1
            )
            seg_lens = (
                forward_batch.extend_seq_lens
                if forward_batch.forward_mode.is_extend()
                else torch.ones(bs, dtype=torch.int32, device=self.device)
            )
            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

            batch_info = LoRABatchInfo(
                bs=forward_batch.batch_size,
                num_segments=forward_batch.batch_size,
                max_len=max_len,
                use_cuda_graph=False,
                seg_lens=seg_lens,
                seg_indptr=seg_indptr,
                weight_indices=torch.empty(
                    (bs,), dtype=torch.int32, device=self.device
                ),
                lora_ranks=torch.empty(
                    (self.max_loras_per_batch,), dtype=torch.int32, device=self.device
                ),
                scalings=torch.empty(
                    (self.max_loras_per_batch,), dtype=torch.float, device=self.device
                ),
                permutation=None,
            )

        # Copy to device asynchronously
        batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
            lora_ranks_tensor, non_blocking=True
        )
        batch_info.scalings[: self.max_loras_per_batch].copy_(
            scalings_tensor, non_blocking=True
        )
        batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
        self.batch_info = batch_info
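
Note on the shrink-path scaling above: after each `sgmv_shrink` call, the per-adapter LoRA scaling is gathered per segment and broadcast to every token of that segment. The following is a minimal CPU-only sketch of that `gather` + `repeat_interleave` step with made-up sizes and values; the real backend reads these fields from `LoRABatchInfo` and operates on NPU tensors.

```python
import torch

# Hypothetical batch layout: 3 segments using adapters 0, 2, 1 respectively.
seg_lens = torch.tensor([2, 1, 3])              # tokens per segment
weight_indices = torch.tensor([0, 2, 1])        # adapter id per segment
scalings = torch.tensor([0.5, 1.0, 2.0, 0.25])  # per-adapter scaling factors
total_seq_len = int(seg_lens.sum())             # 6 tokens in total

# Stand-in for the sgmv_shrink output of shape (total_seq_len, rank).
shrink_out = torch.randn(total_seq_len, 8)

# Pick each segment's scaling, expand it to that segment's tokens,
# then broadcast over the rank dimension -- same pattern as run_lora_a_sgemm.
scaling = (
    scalings.gather(0, weight_indices)
    .repeat_interleave(seg_lens, output_size=total_seq_len)
    .unsqueeze(-1)
)
shrink_out *= scaling
print(scaling.squeeze(-1))  # tensor([0.5000, 0.5000, 2.0000, 1.0000, 1.0000, 1.0000])
```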
20 changes: 0 additions & 20 deletions python/sglang/srt/lora/backend/base_backend.py
@@ -133,23 +133,3 @@ def prepare_lora_batch
            internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
        """
        pass


def get_backend_from_name(name: str) -> BaseLoRABackend:
    """
    Get corresponding backend class from backend's name
    """
    if name == "triton":
        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

        return TritonLoRABackend
    elif name == "csgmv":
        from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend

        return ChunkedSgmvLoRABackend
    elif name == "flashinfer":
        raise ValueError(
            "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
        )
    else:
        raise ValueError(f"Invalid backend: {name}")
53 changes: 53 additions & 0 deletions python/sglang/srt/lora/backend/lora_registry.py
@@ -0,0 +1,53 @@
import logging

from sglang.srt.lora.backend.base_backend import BaseLoRABackend

logger = logging.getLogger(__name__)

LORA_SUPPORTED_BACKENDS = {}


def register_lora_backend(name):
    def decorator(fn):
        LORA_SUPPORTED_BACKENDS[name] = fn
        return fn

    return decorator


@register_lora_backend("triton")
def create_triton_backend():
    from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

    return TritonLoRABackend


@register_lora_backend("csgmv")
def create_triton_csgmv_backend():
    from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend

    return ChunkedSgmvLoRABackend


@register_lora_backend("ascend")
def create_ascend_backend():
    from sglang.srt.lora.backend.ascend_backend import AscendLoRABackend

    return AscendLoRABackend


@register_lora_backend("flashinfer")
def create_flashinfer_backend():
    raise ValueError(
        "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
    )


def get_backend_from_name(name: str) -> BaseLoRABackend:
    """
    Get corresponding backend class from backend's name
    """
    if name not in LORA_SUPPORTED_BACKENDS:
        raise ValueError(f"Invalid backend: {name}")
    lora_backend = LORA_SUPPORTED_BACKENDS[name]()
    return lora_backend
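
Usage note: the registry added here replaces the if/elif chain removed from `base_backend.py` above, so additional backends can be plugged in with a decorator. The sketch below registers a hypothetical backend under the made-up name `my_backend` and resolves it by name; `MyLoRABackend` and the `"my_backend"` key are illustrative only and not part of this change.

```python
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.backend.lora_registry import (
    get_backend_from_name,
    register_lora_backend,
)


class MyLoRABackend(BaseLoRABackend):
    """Hypothetical backend used only to illustrate registration."""

    name = "my_backend"


@register_lora_backend("my_backend")
def create_my_backend():
    # Factories return the backend class; get_backend_from_name calls the factory.
    return MyLoRABackend


backend_cls = get_backend_from_name("my_backend")
assert backend_cls is MyLoRABackend
```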
4 changes: 3 additions & 1 deletion python/sglang/srt/lora/backend/triton_backend.py
@@ -97,7 +97,9 @@ def run_gate_up_lora(
        return lora_output

    def init_cuda_graph_batch_info(
        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
        self,
        cuda_graph_batch_info: LoRABatchInfo,
        max_bs_in_cuda_graph: int,
    ):
        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
        # across batches.
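
To make the comment about constant segment metadata concrete: for decode-shaped CUDA graph batches each request contributes exactly one token, so `seg_lens` is all ones and `seg_indptr` is simply its prefix sum. Below is a small CPU sketch with an arbitrary `max_bs_in_cuda_graph`; the real code fills preallocated buffers on `LoRABatchInfo` rather than allocating new tensors.

```python
import torch

# Illustrative maximum batch size captured for CUDA graph replay.
max_bs_in_cuda_graph = 4

seg_lens = torch.ones(max_bs_in_cuda_graph, dtype=torch.int32)
seg_indptr = torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

print(seg_lens)    # tensor([1, 1, 1, 1], dtype=torch.int32)
print(seg_indptr)  # tensor([0, 1, 2, 3, 4], dtype=torch.int32)
```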