Skip to content

Commit 9e6562a

Browse files
authored
[Model Runner V2] Fix Triton warning on tl.where (#30355)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 0b6a8a3 commit 9e6562a

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

vllm/v1/worker/gpu/sample/penalties.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def _penalties_and_temperature_kernel(
6262
mask=packed_block < tl.cdiv(vocab_size, 32),
6363
)
6464
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
65+
prompt_bin_mask = prompt_bin_mask.to(tl.int1)
6566
prompt_bin_mask = prompt_bin_mask.reshape(BLOCK_SIZE)
6667

6768
# If token appears in prompt or output, apply, otherwise use 1.0 for no-op.

0 commit comments

Comments
 (0)