
Commit d471b2a

[Model Runner V2] Support num NaNs in logits (#30187)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 9e6562a

7 files changed: +89 additions, −26 deletions


vllm/v1/worker/gpu/async_utils.py

Lines changed: 21 additions & 20 deletions
@@ -2,14 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from contextlib import contextmanager

+import numpy as np
 import torch

 from vllm.v1.outputs import (
     AsyncModelRunnerOutput,
     LogprobsTensors,
     ModelRunnerOutput,
-    SamplerOutput,
 )
+from vllm.v1.worker.gpu.sample.output import SamplerOutput


 class AsyncOutput(AsyncModelRunnerOutput):
@@ -34,29 +35,18 @@ def __init__(
         with torch.cuda.stream(self.copy_stream):
             self.copy_stream.wait_stream(default_stream)

-            # NOTE(woosuk): We must ensure that CPU tensors are not freed
-            # before the device-to-host copy is fully completed. For instance,
-            # operations like
-            # self.sampled_token_np = ...to("cpu", non_blocking=True).numpy()
-            # are unsafe because the underlying CPU tensor can be prematurely freed and
-            # reused by other tensors before the asynchronous copy finishes, potentially
-            # causing race conditions. To prevent this, we delay freeing by holding
-            # references until the copy event signals completion.
-            # Likewise, we also need to keep the reference to the GPU tensors.
-            # This is done by keeping the reference to sampler_output and
-            # model_runner_output.
-            self.sampled_token_ids = sampler_output.sampled_token_ids.to(
-                "cpu", non_blocking=True
-            )
+            self.sampled_token_ids = async_copy_to_np(sampler_output.sampled_token_ids)
             if sampler_output.logprobs_tensors is not None:
                 self.logprobs_tensors: LogprobsTensors | None = (
                     sampler_output.logprobs_tensors.to_cpu_nonblocking()
                 )
             else:
                 self.logprobs_tensors = None
-            self.num_sampled_tokens_cpu = num_sampled_tokens.to(
-                "cpu", non_blocking=True
-            )
+            if sampler_output.num_nans is not None:
+                self.num_nans = async_copy_to_np(sampler_output.num_nans)
+            else:
+                self.num_nans = None
+            self.num_sampled_tokens_np = async_copy_to_np(num_sampled_tokens)
             self.prompt_logprobs_dict: dict[str, LogprobsTensors | None] = {}
             if self.model_runner_output.prompt_logprobs_dict:
                 for k, v in self.model_runner_output.prompt_logprobs_dict.items():
@@ -68,18 +58,25 @@ def __init__(

     def get_output(self) -> ModelRunnerOutput:
         self.copy_event.synchronize()
-        num_sampled_tokens_np = self.num_sampled_tokens_cpu.numpy()

         # NOTE(woosuk): The following code is to ensure compatibility with
         # the existing model runner.
         # Going forward, we should keep the data structures as NumPy arrays
         # rather than Python lists.
         sampled_token_ids: list[list[int]] = self.sampled_token_ids.tolist()
         num_reqs = len(sampled_token_ids)
+        num_sampled_tokens = self.num_sampled_tokens_np.tolist()
         for i in range(num_reqs):
-            del sampled_token_ids[i][num_sampled_tokens_np[i] :]
+            del sampled_token_ids[i][num_sampled_tokens[i] :]
         self.model_runner_output.sampled_token_ids = sampled_token_ids

+        if self.num_nans is not None:
+            num_nans = self.num_nans.tolist()
+            self.model_runner_output.num_nans_in_logits = {
+                req_id: num_nans[i]
+                for i, req_id in enumerate(self.model_runner_output.req_ids)
+            }
+
         if self.logprobs_tensors is not None:
             self.model_runner_output.logprobs = self.logprobs_tensors.tolists()
         self.model_runner_output.prompt_logprobs_dict = self.prompt_logprobs_dict
@@ -95,3 +92,7 @@ def async_barrier(event: torch.cuda.Event | None):
     finally:
         if event is not None:
             event.record()
+
+
+def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
+    return x.to("cpu", non_blocking=True).numpy()
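The new helper makes the non-blocking device-to-host pattern explicit: the returned NumPy array aliases a CPU tensor whose contents are only valid after the copy event fires, which is why get_output() calls copy_event.synchronize() before touching it. A minimal standalone sketch of the pattern (illustrative, not this file's code; the stream/event wiring condenses what the constructor above does):

import numpy as np
import torch


def async_copy_to_np(x: torch.Tensor) -> np.ndarray:
    # The copy is enqueued, not finished; the array must not be read yet.
    return x.to("cpu", non_blocking=True).numpy()


if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    copy_event = torch.cuda.Event()
    gpu_tokens = torch.randint(0, 32000, (8, 1), device="cuda")
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(torch.cuda.current_stream())
        tokens_np = async_copy_to_np(gpu_tokens)  # D2H copy still in flight
        copy_event.record()
    # Keep gpu_tokens and tokens_np referenced until the event signals,
    # then it is safe to read the host-side data.
    copy_event.synchronize()
    print(tokens_np.tolist())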

vllm/v1/worker/gpu/metrics/__init__.py

Whitespace-only changes.
vllm/v1/worker/gpu/metrics/logits.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+from torch._inductor.runtime.triton_helpers import libdevice
+
+from vllm.triton_utils import tl, triton
+
+
+@triton.jit
+def _num_nans_kernel(
+    logits_ptr,
+    logits_stride,
+    num_nans_ptr,
+    vocab_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    req_idx = tl.program_id(0)
+    num_nans = 0
+    for i in range(0, vocab_size, BLOCK_SIZE):
+        block = i + tl.arange(0, BLOCK_SIZE)
+        mask = block < vocab_size
+        logits = tl.load(
+            logits_ptr + req_idx * logits_stride + block, mask=mask, other=0
+        )
+        logits = logits.to(tl.float32)
+        is_nan = libdevice.isnan(logits).to(tl.int1)
+        num_nans += tl.sum(is_nan).to(tl.int32)
+    tl.store(num_nans_ptr + req_idx, num_nans)
+
+
+def get_num_nans(logits: torch.Tensor) -> torch.Tensor:
+    num_reqs, vocab_size = logits.shape
+    BLOCK_SIZE = 8192
+    num_nans = torch.empty(num_reqs, dtype=torch.int32, device=logits.device)
+    _num_nans_kernel[(num_reqs,)](
+        logits,
+        logits.stride(0),
+        num_nans,
+        vocab_size,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return num_nans
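As a sanity check, the kernel launches one program per request row and its result should match a plain torch reduction over the vocab dimension. A self-check sketch (not part of this commit; requires a CUDA device since the kernel is Triton):

import torch

from vllm.v1.worker.gpu.metrics.logits import get_num_nans

if torch.cuda.is_available():
    torch.manual_seed(0)
    logits = torch.randn(4, 32000, device="cuda")
    logits[0, :5] = float("nan")   # request 0: 5 NaNs
    logits[2, 100] = float("nan")  # request 2: 1 NaN

    num_nans = get_num_nans(logits)
    # Reference: count NaNs per row with a plain torch reduction.
    expected = torch.isnan(logits).sum(dim=-1, dtype=torch.int32)
    assert torch.equal(num_nans, expected)  # tensor([5, 0, 1, 0])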

vllm/v1/worker/gpu/model_runner.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,6 @@
     LogprobsTensors,
     ModelRunnerOutput,
 )
-from vllm.v1.sample.sampler import SamplerOutput
 from vllm.v1.worker.gpu.async_utils import AsyncOutput, async_barrier
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -53,6 +52,7 @@
     SamplingMetadata,
     expand_sampling_metadata,
 )
+from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample

vllm/v1/worker/gpu/sample/min_p.py

Lines changed: 1 addition & 3 deletions
@@ -39,9 +39,7 @@ def _min_p_kernel(
     tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask)


-def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor | None) -> None:
-    if min_p is None:
-        return
+def apply_min_p(logits: torch.Tensor, min_p: torch.Tensor) -> None:
     num_reqs, vocab_size = logits.shape
     BLOCK_SIZE = 1024
     _min_p_kernel[(num_reqs,)](
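This change narrows apply_min_p to a non-None min_p and moves the None check to the caller (see sampler.py below). For reference, min-p filtering keeps only tokens whose probability is at least min_p times the top token's probability; a plain torch sketch of those semantics (an out-of-place illustration, whereas the kernel edits logits in place):

import torch


def min_p_reference(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # Per-token probabilities for each request.
    probs = logits.softmax(dim=-1)
    # Per-request threshold: min_p scaled by the top token's probability.
    threshold = min_p.unsqueeze(-1) * probs.max(dim=-1, keepdim=True).values
    # Mask out tokens whose probability falls below the threshold.
    return logits.masked_fill(probs < threshold, float("-inf"))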
vllm/v1/worker/gpu/sample/output.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+
+import torch
+
+from vllm.v1.outputs import LogprobsTensors
+
+
+@dataclass
+class SamplerOutput:
+    sampled_token_ids: torch.Tensor
+    logprobs_tensors: LogprobsTensors | None
+    num_nans: torch.Tensor | None
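The new dataclass is plain data; per the sampler below, sampled_token_ids carries one token per request. A construction sketch with illustrative values:

import torch

from vllm.v1.worker.gpu.sample.output import SamplerOutput

output = SamplerOutput(
    sampled_token_ids=torch.tensor([[101], [42]]),  # shape (num_reqs, 1)
    logprobs_tensors=None,  # populated only when logprobs are requested
    num_nans=None,  # populated only when NaN counting is enabled
)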

vllm/v1/worker/gpu/sample/sampler.py

Lines changed: 10 additions & 2 deletions
@@ -3,13 +3,15 @@

 import torch

+import vllm.envs as envs
 from vllm.config.model import LogprobsMode
-from vllm.v1.outputs import SamplerOutput
 from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+from vllm.v1.worker.gpu.metrics.logits import get_num_nans
 from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
 from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
 from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu.sample.min_p import apply_min_p
+from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature


@@ -21,12 +23,16 @@ def __init__(
         if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
             raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
         self.logprobs_mode = logprobs_mode
+        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS  # False by default.

     def __call__(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> SamplerOutput:
+        # NOTE(woosuk): num_nans is intentionally computed up front, so that it
+        # reflects the raw logits before penalties and temperature are applied.
+        num_nans = get_num_nans(logits) if self.compute_nans else None
         sampled, processed_logits = self.sample(logits, sampling_metadata)
         if sampling_metadata.max_num_logprobs is not None:
             logits = (
@@ -49,6 +55,7 @@ def __call__(
             # token per request.
             sampled_token_ids=sampled.view(-1, 1),
             logprobs_tensors=logprobs_tensors,
+            num_nans=num_nans,
         )
         return sampler_output

@@ -63,7 +70,8 @@ def sample(
         # Apply penalties and temperature in place.
         apply_penalties_and_temperature(logits, sampling_metadata)
         # Apply min_p in place.
-        apply_min_p(logits, sampling_metadata.min_p)
+        if sampling_metadata.min_p is not None:
+            apply_min_p(logits, sampling_metadata.min_p)
         # Apply top_k and/or top_p. This might return a new tensor.
         logits = apply_top_k_top_p(
             logits, sampling_metadata.top_k, sampling_metadata.top_p