Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions sky/server/requests/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import dataclasses
import enum
import functools
import json
import os
import pathlib
import shutil
Expand All @@ -21,6 +20,7 @@
import anyio
import colorama
import filelock
import orjson

from sky import exceptions
from sky import global_user_state
Expand Down Expand Up @@ -213,8 +213,8 @@ def readable_encode(self) -> payloads.RequestPayload:
entrypoint=self.entrypoint.__name__,
request_body=self.request_body.model_dump_json(),
status=self.status.value,
return_value=json.dumps(None),
error=json.dumps(None),
return_value=orjson.dumps(None).decode('utf-8'),
error=orjson.dumps(None).decode('utf-8'),
pid=None,
created_at=self.created_at,
schedule_type=self.schedule_type.value,
Expand All @@ -237,8 +237,8 @@ def encode(self) -> payloads.RequestPayload:
entrypoint=encoders.pickle_and_encode(self.entrypoint),
request_body=encoders.pickle_and_encode(self.request_body),
status=self.status.value,
return_value=json.dumps(self.return_value),
error=json.dumps(self.error),
return_value=orjson.dumps(self.return_value).decode('utf-8'),
error=orjson.dumps(self.error).decode('utf-8'),
pid=self.pid,
created_at=self.created_at,
schedule_type=self.schedule_type.value,
Expand Down Expand Up @@ -270,8 +270,8 @@ def decode(cls, payload: payloads.RequestPayload) -> 'Request':
entrypoint=decoders.decode_and_unpickle(payload.entrypoint),
request_body=decoders.decode_and_unpickle(payload.request_body),
status=RequestStatus(payload.status),
return_value=json.loads(payload.return_value),
error=json.loads(payload.error),
return_value=orjson.loads(payload.return_value),
error=orjson.loads(payload.error),
pid=payload.pid,
created_at=payload.created_at,
schedule_type=ScheduleType(payload.schedule_type),
Expand Down Expand Up @@ -328,10 +328,11 @@ def encode_requests(requests: List[Request]) -> List[payloads.RequestPayload]:
entrypoint=request.entrypoint.__name__
if request.entrypoint is not None else '',
request_body=request.request_body.model_dump_json()
if request.request_body is not None else json.dumps(None),
if request.request_body is not None else
orjson.dumps(None).decode('utf-8'),
status=request.status.value,
return_value=json.dumps(None),
error=json.dumps(None),
return_value=orjson.dumps(None).decode('utf-8'),
error=orjson.dumps(None).decode('utf-8'),
pid=None,
created_at=request.created_at,
schedule_type=request.schedule_type.value,
Expand Down Expand Up @@ -372,9 +373,9 @@ def _update_request_row_fields(
if 'user_id' not in fields:
content['user_id'] = ''
if 'return_value' not in fields:
content['return_value'] = json.dumps(None)
content['return_value'] = orjson.dumps(None).decode('utf-8')
if 'error' not in fields:
content['error'] = json.dumps(None)
content['error'] = orjson.dumps(None).decode('utf-8')
if 'schedule_type' not in fields:
content['schedule_type'] = ScheduleType.SHORT.value
# Optional fields in RequestPayload
Expand Down
4 changes: 2 additions & 2 deletions sky/server/requests/serializers/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ def encode_volume_list(


@register_encoder('job_status')
def encode_job_status(return_value: Dict[int, Any]) -> Dict[int, str]:
def encode_job_status(return_value: Dict[int, Any]) -> Dict[str, str]:
for job_id in return_value.keys():
if return_value[job_id] is not None:
return_value[job_id] = return_value[job_id].value
return return_value
return {str(k): v for k, v in return_value.items()}


@register_encoder('kubernetes_node_info')
Expand Down
3 changes: 2 additions & 1 deletion sky/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import aiofiles
import anyio
import fastapi
from fastapi import responses as fastapi_responses
from fastapi.middleware import cors
import starlette.middleware.base
import uvloop
Expand Down Expand Up @@ -1498,7 +1499,7 @@ async def local_down(request: fastapi.Request,


# === API server related APIs ===
@app.get('/api/get')
@app.get('/api/get', response_class=fastapi_responses.ORJSONResponse)
async def api_get(request_id: str) -> payloads.RequestPayload:
"""Gets a request with a given request ID prefix."""
while True:
Expand Down
1 change: 1 addition & 0 deletions sky/setup_files/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
# <= 3.13 may encounter https://github.com/ultralytics/yolov5/issues/414
'pyyaml > 3.13, != 5.4.*',
'ijson',
'orjson',
'requests',
# SkyPilot inherits from uvicorn.Server to customize the behavior of
# uvicorn, so we need to pin uvicorn version to avoid potential break
Expand Down