Skip to content

Commit 34fb28b

Browse files
committed
refactor
1 parent 25dac89 commit 34fb28b

File tree

1 file changed

+125
-138
lines changed

1 file changed

+125
-138
lines changed

sky/server/requests/requests.py

Lines changed: 125 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
import threading
1414
import time
1515
import traceback
16-
from typing import (Any, AsyncGenerator, Callable, Dict, Generator, List,
17-
NamedTuple, Optional, Tuple)
16+
from typing import (Any, Callable, Dict, Generator, List, NamedTuple, Optional,
17+
Tuple)
1818
import uuid
1919

2020
import anyio
@@ -394,124 +394,6 @@ def _update_request_row_fields(
394394
return tuple(content[col] for col in REQUEST_COLUMNS)
395395

396396

397-
def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
398-
"""Kill all pending and running requests for a cluster.
399-
400-
Args:
401-
cluster_name: the name of the cluster.
402-
exclude_request_names: exclude requests with these names. This is to
403-
prevent killing the caller request.
404-
"""
405-
request_ids = [
406-
request_task.request_id
407-
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
408-
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
409-
exclude_request_names=[exclude_request_name],
410-
cluster_names=[cluster_name],
411-
fields=['request_id']))
412-
]
413-
_kill_requests(request_ids)
414-
415-
416-
def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
417-
user_id: Optional[str] = None) -> List[str]:
418-
"""Kill requests with a given request ID prefix."""
419-
expanded_request_ids: Optional[List[str]] = None
420-
if request_ids is not None:
421-
expanded_request_ids = []
422-
for request_id in request_ids:
423-
request_tasks = get_requests_with_prefix(request_id,
424-
fields=['request_id'])
425-
if request_tasks is None or len(request_tasks) == 0:
426-
continue
427-
if len(request_tasks) > 1:
428-
raise ValueError(f'Multiple requests found for '
429-
f'request ID prefix: {request_id}')
430-
expanded_request_ids.append(request_tasks[0].request_id)
431-
return _kill_requests(request_ids=expanded_request_ids, user_id=user_id)
432-
433-
434-
def _should_kill_request(request_id: str,
435-
request_record: Optional[Request]) -> bool:
436-
if request_record is None:
437-
logger.debug(f'No request ID {request_id}')
438-
return False
439-
# Skip internal requests. The internal requests are scheduled with
440-
# request_id in range(len(INTERNAL_REQUEST_EVENTS)).
441-
if request_record.request_id in set(
442-
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
443-
return False
444-
if request_record.status > RequestStatus.RUNNING:
445-
logger.debug(f'Request {request_id} already finished')
446-
return False
447-
return True
448-
449-
450-
def _kill_requests(request_ids: Optional[List[str]] = None,
451-
user_id: Optional[str] = None) -> List[str]:
452-
"""Kill a SkyPilot API request and set its status to cancelled.
453-
454-
Args:
455-
request_ids: The request IDs to kill. If None, all requests for the
456-
user are killed.
457-
user_id: The user ID to kill requests for. If None, all users are
458-
killed.
459-
460-
Returns:
461-
A list of request IDs that were cancelled.
462-
"""
463-
if request_ids is None:
464-
request_ids = [
465-
request_task.request_id
466-
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
467-
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
468-
# Avoid cancelling the cancel request itself.
469-
exclude_request_names=['sky.api_cancel'],
470-
user_id=user_id,
471-
fields=['request_id']))
472-
]
473-
cancelled_request_ids = []
474-
for request_id in request_ids:
475-
with update_request(request_id) as request_record:
476-
if not _should_kill_request(request_id, request_record):
477-
continue
478-
if request_record.pid is not None:
479-
logger.debug(f'Killing request process {request_record.pid}')
480-
# Use SIGTERM instead of SIGKILL:
481-
# - The executor can handle SIGTERM gracefully
482-
# - After SIGTERM, the executor can reuse the request process
483-
# for other requests, avoiding the overhead of forking a new
484-
# process for each request.
485-
os.kill(request_record.pid, signal.SIGTERM)
486-
request_record.status = RequestStatus.CANCELLED
487-
request_record.finished_at = time.time()
488-
cancelled_request_ids.append(request_id)
489-
return cancelled_request_ids
490-
491-
492-
@asyncio_utils.shield
493-
async def kill_request_async(request_id: str) -> bool:
494-
"""Kill a SkyPilot API request and set its status to cancelled.
495-
496-
Returns:
497-
True if the request was killed, False otherwise.
498-
"""
499-
async with _update_request_async(request_id) as request_record:
500-
if not _should_kill_request(request_id, request_record):
501-
return False
502-
if request_record.pid is not None:
503-
logger.debug(f'Killing request process {request_record.pid}')
504-
# Use SIGTERM instead of SIGKILL:
505-
# - The executor can handle SIGTERM gracefully
506-
# - After SIGTERM, the executor can reuse the request process
507-
# for other requests, avoiding the overhead of forking a new
508-
# process for each request.
509-
os.kill(request_record.pid, signal.SIGTERM)
510-
request_record.status = RequestStatus.CANCELLED
511-
request_record.finished_at = time.time()
512-
return True
513-
514-
515397
def create_table(cursor, conn):
516398
# Enable WAL mode to avoid locking issues.
517399
# See: issue #1441 and PR #1509
@@ -655,6 +537,128 @@ def request_lock_path(request_id: str) -> str:
655537
return os.path.join(lock_path, f'.{request_id}.lock')
656538

657539

540+
def kill_cluster_requests(cluster_name: str, exclude_request_name: str):
541+
"""Kill all pending and running requests for a cluster.
542+
543+
Args:
544+
cluster_name: the name of the cluster.
545+
exclude_request_names: exclude requests with these names. This is to
546+
prevent killing the caller request.
547+
"""
548+
request_ids = [
549+
request_task.request_id
550+
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
551+
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
552+
exclude_request_names=[exclude_request_name],
553+
cluster_names=[cluster_name],
554+
fields=['request_id']))
555+
]
556+
_kill_requests(request_ids)
557+
558+
559+
def kill_requests_with_prefix(request_ids: Optional[List[str]] = None,
560+
user_id: Optional[str] = None) -> List[str]:
561+
"""Kill requests with a given request ID prefix."""
562+
expanded_request_ids: Optional[List[str]] = None
563+
if request_ids is not None:
564+
expanded_request_ids = []
565+
for request_id in request_ids:
566+
request_tasks = get_requests_with_prefix(request_id,
567+
fields=['request_id'])
568+
if request_tasks is None or len(request_tasks) == 0:
569+
continue
570+
if len(request_tasks) > 1:
571+
raise ValueError(f'Multiple requests found for '
572+
f'request ID prefix: {request_id}')
573+
expanded_request_ids.append(request_tasks[0].request_id)
574+
return _kill_requests(request_ids=expanded_request_ids, user_id=user_id)
575+
576+
577+
def _should_kill_request(request_id: str,
578+
request_record: Optional[Request]) -> bool:
579+
if request_record is None:
580+
logger.debug(f'No request ID {request_id}')
581+
return False
582+
# Skip internal requests. The internal requests are scheduled with
583+
# request_id in range(len(INTERNAL_REQUEST_EVENTS)).
584+
if request_record.request_id in set(
585+
event.id for event in daemons.INTERNAL_REQUEST_DAEMONS):
586+
return False
587+
if request_record.status > RequestStatus.RUNNING:
588+
logger.debug(f'Request {request_id} already finished')
589+
return False
590+
return True
591+
592+
593+
def _kill_requests(request_ids: Optional[List[str]] = None,
594+
user_id: Optional[str] = None) -> List[str]:
595+
"""Kill a SkyPilot API request and set its status to cancelled.
596+
597+
Args:
598+
request_ids: The request IDs to kill. If None, all requests for the
599+
user are killed.
600+
user_id: The user ID to kill requests for. If None, all users are
601+
killed.
602+
603+
Returns:
604+
A list of request IDs that were cancelled.
605+
"""
606+
if request_ids is None:
607+
request_ids = [
608+
request_task.request_id
609+
for request_task in get_request_tasks(req_filter=RequestTaskFilter(
610+
status=[RequestStatus.PENDING, RequestStatus.RUNNING],
611+
# Avoid cancelling the cancel request itself.
612+
exclude_request_names=['sky.api_cancel'],
613+
user_id=user_id,
614+
fields=['request_id']))
615+
]
616+
cancelled_request_ids = []
617+
for request_id in request_ids:
618+
with update_request(request_id) as request_record:
619+
if not _should_kill_request(request_id, request_record):
620+
continue
621+
if request_record.pid is not None:
622+
logger.debug(f'Killing request process {request_record.pid}')
623+
# Use SIGTERM instead of SIGKILL:
624+
# - The executor can handle SIGTERM gracefully
625+
# - After SIGTERM, the executor can reuse the request process
626+
# for other requests, avoiding the overhead of forking a new
627+
# process for each request.
628+
os.kill(request_record.pid, signal.SIGTERM)
629+
request_record.status = RequestStatus.CANCELLED
630+
request_record.finished_at = time.time()
631+
cancelled_request_ids.append(request_id)
632+
return cancelled_request_ids
633+
634+
635+
@init_db_async
636+
@asyncio_utils.shield
637+
async def kill_request_async(request_id: str) -> bool:
638+
"""Kill a SkyPilot API request and set its status to cancelled.
639+
640+
Returns:
641+
True if the request was killed, False otherwise.
642+
"""
643+
async with filelock.AsyncFileLock(request_lock_path(request_id)):
644+
request = await _get_request_no_lock_async(request_id)
645+
if not _should_kill_request(request_id, request):
646+
return False
647+
assert request is not None
648+
if request.pid is not None:
649+
logger.debug(f'Killing request process {request.pid}')
650+
# Use SIGTERM instead of SIGKILL:
651+
# - The executor can handle SIGTERM gracefully
652+
# - After SIGTERM, the executor can reuse the request process
653+
# for other requests, avoiding the overhead of forking a new
654+
# process for each request.
655+
os.kill(request.pid, signal.SIGTERM)
656+
request.status = RequestStatus.CANCELLED
657+
request.finished_at = time.time()
658+
await _add_or_update_request_no_lock_async(request)
659+
return True
660+
661+
658662
@contextlib.contextmanager
659663
@init_db
660664
@metrics_lib.time_me
@@ -669,24 +673,7 @@ def update_request(request_id: str) -> Generator[Optional[Request], None, None]:
669673
_add_or_update_request_no_lock(request)
670674

671675

672-
@contextlib.asynccontextmanager
673676
@init_db_async
674-
async def _update_request_async(
675-
request_id: str) -> AsyncGenerator[Optional[Request], None]:
676-
"""Get and update a SkyPilot API request.
677-
678-
This function is not shielded from cancellation.
679-
The caller should use @asyncio_utils.shield to protect it."""
680-
# Acquire the lock to avoid race conditions between multiple request
681-
# operations, e.g. execute and cancel.
682-
async with filelock.AsyncFileLock(request_lock_path(request_id)):
683-
request = await _get_request_no_lock_async(request_id)
684-
yield request
685-
if request is not None:
686-
await _add_or_update_request_no_lock_async(request)
687-
688-
689-
@init_db
690677
@metrics_lib.time_me
691678
@asyncio_utils.shield
692679
async def update_status_async(request_id: str, status: RequestStatus) -> None:
@@ -698,7 +685,7 @@ async def update_status_async(request_id: str, status: RequestStatus) -> None:
698685
await _add_or_update_request_no_lock_async(request)
699686

700687

701-
@init_db
688+
@init_db_async
702689
@metrics_lib.time_me
703690
@asyncio_utils.shield
704691
async def update_status_msg_async(request_id: str, status_msg: str) -> None:

0 commit comments

Comments
 (0)