Skip to content

Commit fe0e3f5

Browse files
authored
[BUG FIX] Fix bug when preempted request rescheduled (PaddlePaddle#3080)
* Fix bug when preempted request rescheduled
* Fix bug when preempted request rescheduled
* Fix bug when preempted request rescheduled
1 parent 0616c20 commit fe0e3f5

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

fastdeploy/engine/request.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def __init__(
118118
self.status = RequestStatus.WAITING
119119
self.task_type = RequestType.PREFILL
120120
self.idx = None
121+
self.need_prefill_tokens = self.prompt_token_ids_len
121122

122123
@classmethod
123124
def from_dict(cls, d: dict):

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,11 +117,8 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
117117
break
118118
return can_schedule
119119

120-
def _get_num_new_tokens(self, request, token_budget, schedule_waiting=False):
121-
if schedule_waiting:
122-
num_new_tokens = request.num_total_tokens - request.num_computed_tokens
123-
else:
124-
num_new_tokens = request.prompt_token_ids_len - request.num_computed_tokens
120+
def _get_num_new_tokens(self, request, token_budget):
121+
num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
125122
num_new_tokens = min(num_new_tokens, token_budget)
126123

127124
if not self.config.enable_mm:
@@ -212,8 +209,8 @@ def schedule(self):
212209
num_decoding_req_nums = 0
213210
while req_index < len(self.running) and token_budget > 0:
214211
request = self.running[req_index]
215-
if request.num_computed_tokens >= request.prompt_token_ids_len: # to be decoding
216-
if request.num_total_tokens > request.prompt_token_ids_len: # has generated tokens
212+
if request.num_computed_tokens >= request.need_prefill_tokens: # to be decoding
213+
if request.num_total_tokens > request.need_prefill_tokens: # has generated tokens
217214
request.num_computed_tokens = request.num_total_tokens - 1
218215
if (
219216
self.allocated_slots(request) - request.num_total_tokens
@@ -246,7 +243,7 @@ def schedule(self):
246243
token_budget -= 1
247244
else: # need to prefill
248245
llm_logger.debug(
249-
f"scheduler prefill task: {request} request.prompt_token_ids_len {request.prompt_token_ids_len} request.num_computed_tokens {request.num_computed_tokens}"
246+
f"scheduler prefill task: {request} request.need_prefill_tokens {request.need_prefill_tokens} request.num_computed_tokens {request.num_computed_tokens}"
250247
)
251248
num_new_tokens = self._get_num_new_tokens(request, token_budget)
252249
num_new_block = self.get_new_block_nums(request, num_new_tokens)
@@ -274,7 +271,7 @@ def schedule(self):
274271
break
275272
request = self.waiting[0]
276273
if request.status == RequestStatus.WAITING:
277-
num_new_tokens = self._get_num_new_tokens(request, token_budget, True)
274+
num_new_tokens = self._get_num_new_tokens(request, token_budget)
278275
num_new_block = self.get_new_block_nums(request, num_new_tokens)
279276
# Allocate blocks to prefill
280277
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):
@@ -295,7 +292,8 @@ def schedule(self):
295292
else:
296293
break
297294
elif request.status == RequestStatus.PREEMPTED:
298-
num_new_tokens = self._get_num_new_tokens(request, token_budget, True)
295+
request.need_prefill_tokens = request.num_total_tokens # Before preempted task rescheduled, preempted task has been sent to engine, no more tokens are output, here num_total_tokens should be static and correct
296+
num_new_tokens = self._get_num_new_tokens(request, token_budget)
299297
num_new_block = self.get_new_block_nums(request, num_new_tokens)
300298
# Allocate blocks to prefill
301299
if self.cache_manager.can_allocate_gpu_blocks(num_new_block):

0 commit comments

Comments (0)