1313import threading
1414import time
1515import traceback
16- from typing import (Any , AsyncGenerator , Callable , Dict , Generator , List ,
17- NamedTuple , Optional , Tuple )
16+ from typing import (Any , Callable , Dict , Generator , List , NamedTuple , Optional ,
17+ Tuple )
1818import uuid
1919
2020import anyio
@@ -394,124 +394,6 @@ def _update_request_row_fields(
394394 return tuple (content [col ] for col in REQUEST_COLUMNS )
395395
396396
def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
    """Cancel every unfinished request that targets a cluster.

    Looks up the PENDING and RUNNING requests associated with the cluster
    and hands their IDs to ``_kill_requests``.

    Args:
        cluster_name: the name of the cluster.
        exclude_request_name: requests with this name are left alone. This
            is to prevent killing the caller request.
    """
    active_filter = RequestTaskFilter(
        status=[RequestStatus.PENDING, RequestStatus.RUNNING],
        exclude_request_names=[exclude_request_name],
        cluster_names=[cluster_name],
        fields=['request_id'])
    active_ids = [
        task.request_id
        for task in get_request_tasks(req_filter=active_filter)
    ]
    _kill_requests(active_ids)
414-
415-
def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
                              user_id: Optional[str] = None) -> List[str]:
    """Resolve request ID prefixes to full IDs and kill those requests.

    Args:
        request_ids: request ID prefixes. A prefix that matches no request
            is skipped. If None, nothing is expanded and None is forwarded
            to ``_kill_requests``.
        user_id: forwarded unchanged to ``_kill_requests``.

    Returns:
        The list returned by ``_kill_requests`` for the resolved IDs.

    Raises:
        ValueError: if a prefix matches more than one request.
    """
    if request_ids is None:
        return _kill_requests(request_ids=None, user_id=user_id)

    resolved_ids: List[str] = []
    for prefix in request_ids:
        matches = get_requests_with_prefix(prefix, fields=['request_id'])
        if not matches:
            # Nothing matches this prefix; move on to the next one.
            continue
        if len(matches) > 1:
            raise ValueError(f'Multiple requests found for '
                             f'request ID prefix: {prefix} ')
        resolved_ids.append(matches[0].request_id)
    return _kill_requests(request_ids=resolved_ids, user_id=user_id)
432-
433-
def _should_kill_request(request_id: str,
                         request_record: Optional[Request]) -> bool:
    """Decide whether a request is eligible to be killed.

    Returns:
        False when the record is missing, when its ID belongs to one of the
        entries in ``daemons.INTERNAL_REQUEST_DAEMONS``, or when its status
        is past RUNNING; True otherwise.
    """
    if request_record is None:
        logger.debug(f'No request ID {request_id} ')
        return False
    # Skip internal requests: their IDs are the IDs of the entries in
    # daemons.INTERNAL_REQUEST_DAEMONS.
    internal_ids = {daemon.id for daemon in daemons.INTERNAL_REQUEST_DAEMONS}
    if request_record.request_id in internal_ids:
        return False
    if request_record.status > RequestStatus.RUNNING:
        logger.debug(f'Request {request_id} already finished')
        return False
    return True
448-
449-
def _kill_requests(request_ids: Optional[List[str]] = None,
                   user_id: Optional[str] = None) -> List[str]:
    """Kill SkyPilot API requests and set their status to cancelled.

    Args:
        request_ids: The request IDs to kill. If None, all pending and
            running requests for the user are killed, except the
            'sky.api_cancel' request itself.
        user_id: The user ID to kill requests for. If None, all users are
            killed.

    Returns:
        A list of request IDs that were cancelled.
    """
    if request_ids is None:
        request_ids = [
            request_task.request_id
            for request_task in get_request_tasks(req_filter=RequestTaskFilter(
                status=[RequestStatus.PENDING, RequestStatus.RUNNING],
                # Avoid cancelling the cancel request itself.
                exclude_request_names=['sky.api_cancel'],
                user_id=user_id,
                fields=['request_id']))
        ]
    cancelled_request_ids = []
    for request_id in request_ids:
        with update_request(request_id) as request_record:
            if not _should_kill_request(request_id, request_record):
                continue
            if request_record.pid is not None:
                logger.debug(f'Killing request process {request_record.pid} ')
                # Use SIGTERM instead of SIGKILL:
                # - The executor can handle SIGTERM gracefully
                # - After SIGTERM, the executor can reuse the request process
                #   for other requests, avoiding the overhead of forking a new
                #   process for each request.
                try:
                    os.kill(request_record.pid, signal.SIGTERM)
                except ProcessLookupError:
                    # The recorded process is already gone. Still mark the
                    # request cancelled instead of aborting the whole loop
                    # and leaving the remaining requests untouched.
                    logger.debug(f'Request process {request_record.pid} '
                                 'no longer exists')
            request_record.status = RequestStatus.CANCELLED
            request_record.finished_at = time.time()
            cancelled_request_ids.append(request_id)
    return cancelled_request_ids
490-
491-
@asyncio_utils.shield
async def kill_request_async(request_id: str) -> bool:
    """Kill a SkyPilot API request and set its status to cancelled.

    Args:
        request_id: The ID of the request to kill.

    Returns:
        True if the request was killed, False otherwise.
    """
    async with _update_request_async(request_id) as request_record:
        if not _should_kill_request(request_id, request_record):
            return False
        if request_record.pid is not None:
            logger.debug(f'Killing request process {request_record.pid} ')
            # Use SIGTERM instead of SIGKILL:
            # - The executor can handle SIGTERM gracefully
            # - After SIGTERM, the executor can reuse the request process
            #   for other requests, avoiding the overhead of forking a new
            #   process for each request.
            try:
                os.kill(request_record.pid, signal.SIGTERM)
            except ProcessLookupError:
                # The recorded process is already gone; the request must
                # still end up CANCELLED rather than stay PENDING/RUNNING.
                logger.debug(f'Request process {request_record.pid} '
                             'no longer exists')
        request_record.status = RequestStatus.CANCELLED
        request_record.finished_at = time.time()
    return True
513-
514-
515397def create_table (cursor , conn ):
516398 # Enable WAL mode to avoid locking issues.
517399 # See: issue #1441 and PR #1509
@@ -655,6 +537,128 @@ def request_lock_path(request_id: str) -> str:
655537 return os .path .join (lock_path , f'.{ request_id } .lock' )
656538
657539
def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
    """Cancel every unfinished request that targets a cluster.

    Looks up the PENDING and RUNNING requests associated with the cluster
    and hands their IDs to ``_kill_requests``.

    Args:
        cluster_name: the name of the cluster.
        exclude_request_name: requests with this name are left alone. This
            is to prevent killing the caller request.
    """
    active_filter = RequestTaskFilter(
        status=[RequestStatus.PENDING, RequestStatus.RUNNING],
        exclude_request_names=[exclude_request_name],
        cluster_names=[cluster_name],
        fields=['request_id'])
    active_ids = [
        task.request_id
        for task in get_request_tasks(req_filter=active_filter)
    ]
    _kill_requests(active_ids)
557+
558+
def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
                              user_id: Optional[str] = None) -> List[str]:
    """Resolve request ID prefixes to full IDs and kill those requests.

    Args:
        request_ids: request ID prefixes. A prefix that matches no request
            is skipped. If None, nothing is expanded and None is forwarded
            to ``_kill_requests``.
        user_id: forwarded unchanged to ``_kill_requests``.

    Returns:
        The list returned by ``_kill_requests`` for the resolved IDs.

    Raises:
        ValueError: if a prefix matches more than one request.
    """
    if request_ids is None:
        return _kill_requests(request_ids=None, user_id=user_id)

    resolved_ids: List[str] = []
    for prefix in request_ids:
        matches = get_requests_with_prefix(prefix, fields=['request_id'])
        if not matches:
            # Nothing matches this prefix; move on to the next one.
            continue
        if len(matches) > 1:
            raise ValueError(f'Multiple requests found for '
                             f'request ID prefix: {prefix} ')
        resolved_ids.append(matches[0].request_id)
    return _kill_requests(request_ids=resolved_ids, user_id=user_id)
575+
576+
def _should_kill_request(request_id: str,
                         request_record: Optional[Request]) -> bool:
    """Decide whether a request is eligible to be killed.

    Returns:
        False when the record is missing, when its ID belongs to one of the
        entries in ``daemons.INTERNAL_REQUEST_DAEMONS``, or when its status
        is past RUNNING; True otherwise.
    """
    if request_record is None:
        logger.debug(f'No request ID {request_id} ')
        return False
    # Skip internal requests: their IDs are the IDs of the entries in
    # daemons.INTERNAL_REQUEST_DAEMONS.
    internal_ids = {daemon.id for daemon in daemons.INTERNAL_REQUEST_DAEMONS}
    if request_record.request_id in internal_ids:
        return False
    if request_record.status > RequestStatus.RUNNING:
        logger.debug(f'Request {request_id} already finished')
        return False
    return True
591+
592+
def _kill_requests(request_ids: Optional[List[str]] = None,
                   user_id: Optional[str] = None) -> List[str]:
    """Kill SkyPilot API requests and set their status to cancelled.

    Args:
        request_ids: The request IDs to kill. If None, all pending and
            running requests for the user are killed, except the
            'sky.api_cancel' request itself.
        user_id: The user ID to kill requests for. If None, all users are
            killed.

    Returns:
        A list of request IDs that were cancelled.
    """
    if request_ids is None:
        request_ids = [
            request_task.request_id
            for request_task in get_request_tasks(req_filter=RequestTaskFilter(
                status=[RequestStatus.PENDING, RequestStatus.RUNNING],
                # Avoid cancelling the cancel request itself.
                exclude_request_names=['sky.api_cancel'],
                user_id=user_id,
                fields=['request_id']))
        ]
    cancelled_request_ids = []
    for request_id in request_ids:
        with update_request(request_id) as request_record:
            if not _should_kill_request(request_id, request_record):
                continue
            if request_record.pid is not None:
                logger.debug(f'Killing request process {request_record.pid} ')
                # Use SIGTERM instead of SIGKILL:
                # - The executor can handle SIGTERM gracefully
                # - After SIGTERM, the executor can reuse the request process
                #   for other requests, avoiding the overhead of forking a new
                #   process for each request.
                try:
                    os.kill(request_record.pid, signal.SIGTERM)
                except ProcessLookupError:
                    # The recorded process is already gone. Still mark the
                    # request cancelled instead of aborting the whole loop
                    # and leaving the remaining requests untouched.
                    logger.debug(f'Request process {request_record.pid} '
                                 'no longer exists')
            request_record.status = RequestStatus.CANCELLED
            request_record.finished_at = time.time()
            cancelled_request_ids.append(request_id)
    return cancelled_request_ids
633+
634+
@init_db_async
@asyncio_utils.shield
async def kill_request_async(request_id: str) -> bool:
    """Kill a SkyPilot API request and set its status to cancelled.

    The request's file lock is held while the record is read, the process is
    signalled and the updated record is written back.

    Args:
        request_id: The ID of the request to kill.

    Returns:
        True if the request was killed, False otherwise.
    """
    async with filelock.AsyncFileLock(request_lock_path(request_id)):
        request = await _get_request_no_lock_async(request_id)
        if not _should_kill_request(request_id, request):
            return False
        assert request is not None
        if request.pid is not None:
            logger.debug(f'Killing request process {request.pid} ')
            # Use SIGTERM instead of SIGKILL:
            # - The executor can handle SIGTERM gracefully
            # - After SIGTERM, the executor can reuse the request process
            #   for other requests, avoiding the overhead of forking a new
            #   process for each request.
            try:
                os.kill(request.pid, signal.SIGTERM)
            except ProcessLookupError:
                # The recorded process is already gone; the request must
                # still be persisted as CANCELLED rather than stay
                # PENDING/RUNNING.
                logger.debug(f'Request process {request.pid} '
                             'no longer exists')
        request.status = RequestStatus.CANCELLED
        request.finished_at = time.time()
        await _add_or_update_request_no_lock_async(request)
        return True
660+
661+
658662@contextlib .contextmanager
659663@init_db
660664@metrics_lib .time_me
@@ -669,24 +673,7 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
669673 _add_or_update_request_no_lock (request )
670674
671675
672- @contextlib .asynccontextmanager
673676@init_db_async
async def _update_request_async(
        request_id: str) -> AsyncGenerator[Optional[Request], None]:
    """Yield a SkyPilot API request for modification, then persist it.

    The request's file lock is held for the whole lifetime of the context to
    avoid race conditions between multiple request operations, e.g. execute
    and cancel.

    This function is not shielded from cancellation.
    The caller should use @asyncio_utils.shield to protect it.

    Yields:
        The request returned by the lookup, or None. A None request is not
        written back.
    """
    request_lock = filelock.AsyncFileLock(request_lock_path(request_id))
    async with request_lock:
        record = await _get_request_no_lock_async(request_id)
        yield record
        if record is None:
            return
        await _add_or_update_request_no_lock_async(record)
688-
689- @init_db
690677@metrics_lib .time_me
691678@asyncio_utils .shield
692679async def update_status_async (request_id : str , status : RequestStatus ) -> None :
@@ -698,7 +685,7 @@ async def update_status_async(request_id: str, status: RequestStatus) -> None:
698685 await _add_or_update_request_no_lock_async (request )
699686
700687
701- @init_db
688+ @init_db_async
702689@metrics_lib .time_me
703690@asyncio_utils .shield
704691async def update_status_msg_async (request_id : str , status_msg : str ) -> None :
0 commit comments