@@ -117,11 +117,8 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
117117 break
118118 return can_schedule
119119
120- def _get_num_new_tokens (self , request , token_budget , schedule_waiting = False ):
121- if schedule_waiting :
122- num_new_tokens = request .num_total_tokens - request .num_computed_tokens
123- else :
124- num_new_tokens = request .prompt_token_ids_len - request .num_computed_tokens
120+ def _get_num_new_tokens (self , request , token_budget ):
121+ num_new_tokens = request .need_prefill_tokens - request .num_computed_tokens
125122 num_new_tokens = min (num_new_tokens , token_budget )
126123
127124 if not self .config .enable_mm :
@@ -212,8 +209,8 @@ def schedule(self):
212209 num_decoding_req_nums = 0
213210 while req_index < len (self .running ) and token_budget > 0 :
214211 request = self .running [req_index ]
215- if request .num_computed_tokens >= request .prompt_token_ids_len : # to be decoding
216- if request .num_total_tokens > request .prompt_token_ids_len : # has generated tokens
212+ if request .num_computed_tokens >= request .need_prefill_tokens : # to be decoding
213+ if request .num_total_tokens > request .need_prefill_tokens : # has generated tokens
217214 request .num_computed_tokens = request .num_total_tokens - 1
218215 if (
219216 self .allocated_slots (request ) - request .num_total_tokens
@@ -246,7 +243,7 @@ def schedule(self):
246243 token_budget -= 1
247244 else : # need to prefill
248245 llm_logger .debug (
249- f"scheduler prefill task: { request } request.prompt_token_ids_len { request .prompt_token_ids_len } request.num_computed_tokens { request .num_computed_tokens } "
246+ f"scheduler prefill task: { request } request.need_prefill_tokens { request .need_prefill_tokens } request.num_computed_tokens { request .num_computed_tokens } "
250247 )
251248 num_new_tokens = self ._get_num_new_tokens (request , token_budget )
252249 num_new_block = self .get_new_block_nums (request , num_new_tokens )
@@ -274,7 +271,7 @@ def schedule(self):
274271 break
275272 request = self .waiting [0 ]
276273 if request .status == RequestStatus .WAITING :
277- num_new_tokens = self ._get_num_new_tokens (request , token_budget , True )
274+ num_new_tokens = self ._get_num_new_tokens (request , token_budget )
278275 num_new_block = self .get_new_block_nums (request , num_new_tokens )
279276 # Allocate blocks to prefill
280277 if self .cache_manager .can_allocate_gpu_blocks (num_new_block ):
@@ -295,7 +292,8 @@ def schedule(self):
295292 else :
296293 break
297294 elif request .status == RequestStatus .PREEMPTED :
298- num_new_tokens = self ._get_num_new_tokens (request , token_budget , True )
298- 295+ request .need_prefill_tokens = request .num_total_tokens # Before a preempted task is rescheduled it has already been sent to the engine and produces no further tokens, so num_total_tokens is static and correct here
296+ num_new_tokens = self ._get_num_new_tokens (request , token_budget )
299297 num_new_block = self .get_new_block_nums (request , num_new_tokens )
300298 # Allocate blocks to prefill
301299 if self .cache_manager .can_allocate_gpu_blocks (num_new_block ):
0 commit comments