Skip to content

Commit 72125e8

Browse files
cg505 and aylei authored
[core] reduce database query frequency for /api/stream (#7569)
* [core] reduce database query frequency for /api/stream

* Refine

Signed-off-by: Aylei <[email protected]>

* Fix bug

Signed-off-by: Aylei <[email protected]>

* Fix get_request_async called in coroutine polling

Signed-off-by: Aylei <[email protected]>

* Fix UT

Signed-off-by: Aylei <[email protected]>

---------

Signed-off-by: Aylei <[email protected]>
Co-authored-by: Aylei <[email protected]>
1 parent 2865b60 commit 72125e8

File tree

7 files changed

+83
-41
lines changed

7 files changed

+83
-41
lines changed

sky/jobs/server/server.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ async def logs(
116116
# Cancel the coroutine after the request is done or client disconnects
117117
background_tasks.add_task(task.cancel)
118118

119-
return stream_utils.stream_response(
119+
return stream_utils.stream_response_for_long_request(
120120
request_id=request_task.request_id,
121121
logs_path=request_task.log_path,
122122
background_tasks=background_tasks,
@@ -201,7 +201,7 @@ async def pool_tail_logs(
201201

202202
request_task = api_requests.get_request(request.state.request_id)
203203

204-
return stream_utils.stream_response(
204+
return stream_utils.stream_response_for_long_request(
205205
request_id=request_task.request_id,
206206
logs_path=request_task.log_path,
207207
background_tasks=background_tasks,

sky/serve/server/server.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -109,7 +109,7 @@ async def tail_logs(
109109
task = executor.execute_request_in_coroutine(request_task)
110110
# Cancel the coroutine after the request is done or client disconnects
111111
background_tasks.add_task(task.cancel)
112-
return stream_utils.stream_response(
112+
return stream_utils.stream_response_for_long_request(
113113
request_id=request_task.request_id,
114114
logs_path=request_task.log_path,
115115
background_tasks=background_tasks,

sky/server/requests/executor.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -580,11 +580,11 @@ async def _execute_request_coroutine(request: api_requests.Request):
580580
**request_body.to_kwargs())
581581

582582
async def poll_task(request_id: str) -> bool:
583-
request = await api_requests.get_request_async(request_id)
584-
if request is None:
583+
req_status = await api_requests.get_request_status_async(request_id)
584+
if req_status is None:
585585
raise RuntimeError('Request not found')
586586

587-
if request.status == api_requests.RequestStatus.CANCELLED:
587+
if req_status.status == api_requests.RequestStatus.CANCELLED:
588588
ctx.cancel()
589589
return True
590590

sky/server/server.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1243,7 +1243,7 @@ async def logs(
12431243
background_tasks.add_task(task.cancel)
12441244
# TODO(zhwu): This makes viewing logs in browser impossible. We should adopt
12451245
# the same approach as /stream.
1246-
return stream_utils.stream_response(
1246+
return stream_utils.stream_response_for_long_request(
12471247
request_id=request.state.request_id,
12481248
logs_path=request_task.log_path,
12491249
background_tasks=background_tasks,
@@ -1539,6 +1539,7 @@ async def stream(
15391539
'X-Accel-Buffering': 'no'
15401540
})
15411541

1542+
polling_interval = stream_utils.DEFAULT_POLL_INTERVAL
15421543
# Original plain text streaming logic
15431544
if request_id is not None:
15441545
request_task = await requests_lib.get_request_async(request_id)
@@ -1553,6 +1554,8 @@ async def stream(
15531554
raise fastapi.HTTPException(
15541555
status_code=404,
15551556
detail=f'Log of request {request_id!r} has been deleted')
1557+
if request_task.schedule_type == requests_lib.ScheduleType.LONG:
1558+
polling_interval = stream_utils.LONG_REQUEST_POLL_INTERVAL
15561559
else:
15571560
assert log_path is not None, (request_id, log_path)
15581561
if log_path == constants.API_SERVER_LOGS:
@@ -1600,7 +1603,8 @@ async def stream(
16001603
log_path_to_stream,
16011604
plain_logs=format == 'plain',
16021605
tail=tail,
1603-
follow=follow),
1606+
follow=follow,
1607+
polling_interval=polling_interval),
16041608
media_type='text/plain',
16051609
headers=headers,
16061610
)

sky/server/stream_utils.py

Lines changed: 61 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
from sky import global_user_state
1212
from sky import sky_logging
1313
from sky.server.requests import requests as requests_lib
14+
from sky.utils import common_utils
1415
from sky.utils import message_utils
1516
from sky.utils import rich_utils
1617
from sky.utils import status_lib
@@ -24,7 +25,9 @@
2425
_BUFFER_SIZE = 8 * 1024 # 8KB
2526
_BUFFER_TIMEOUT = 0.02 # 20ms
2627
_HEARTBEAT_INTERVAL = 30
27-
_CLUSTER_STATUS_INTERVAL = 1
28+
29+
LONG_REQUEST_POLL_INTERVAL = 1
30+
DEFAULT_POLL_INTERVAL = 0.1
2831

2932

3033
async def _yield_log_file_with_payloads_skipped(
@@ -41,12 +44,14 @@ async def _yield_log_file_with_payloads_skipped(
4144

4245

4346
async def log_streamer(
44-
request_id: Optional[str],
45-
log_path: pathlib.Path,
46-
plain_logs: bool = False,
47-
tail: Optional[int] = None,
48-
follow: bool = True,
49-
cluster_name: Optional[str] = None) -> AsyncGenerator[str, None]:
47+
request_id: Optional[str],
48+
log_path: pathlib.Path,
49+
plain_logs: bool = False,
50+
tail: Optional[int] = None,
51+
follow: bool = True,
52+
cluster_name: Optional[str] = None,
53+
polling_interval: float = DEFAULT_POLL_INTERVAL
54+
) -> AsyncGenerator[str, None]:
5055
"""Streams the logs of a request.
5156
5257
Args:
@@ -84,6 +89,11 @@ async def log_streamer(
8489
f'scheduled: {request_id}')
8590
req_status = request_task.status
8691
req_msg = request_task.status_msg
92+
# Slowly back off the database polling up to every 1 second, to avoid
93+
# overloading the CPU and DB.
94+
backoff = common_utils.Backoff(initial_backoff=polling_interval,
95+
max_backoff_factor=10,
96+
multiplier=1.2)
8797
while req_status < requests_lib.RequestStatus.RUNNING:
8898
if req_msg is not None:
8999
waiting_msg = request_task.status_msg
@@ -99,7 +109,7 @@ async def log_streamer(
99109
# TODO(aylei): we should use a better mechanism to avoid busy
100110
# polling the DB, which can be a bottleneck for high-concurrency
101111
# requests.
102-
await asyncio.sleep(0.1)
112+
await asyncio.sleep(backoff.current_backoff())
103113
status_with_msg = await requests_lib.get_request_status_async(
104114
request_id, include_msg=True)
105115
req_status = status_with_msg.status
@@ -111,17 +121,20 @@ async def log_streamer(
111121

112122
async with aiofiles.open(log_path, 'rb') as f:
113123
async for chunk in _tail_log_file(f, request_id, plain_logs, tail,
114-
follow, cluster_name):
124+
follow, cluster_name,
125+
polling_interval):
115126
yield chunk
116127

117128

118129
async def _tail_log_file(
119-
f: aiofiles.threadpool.binary.AsyncBufferedReader,
120-
request_id: Optional[str] = None,
121-
plain_logs: bool = False,
122-
tail: Optional[int] = None,
123-
follow: bool = True,
124-
cluster_name: Optional[str] = None) -> AsyncGenerator[str, None]:
130+
f: aiofiles.threadpool.binary.AsyncBufferedReader,
131+
request_id: Optional[str] = None,
132+
plain_logs: bool = False,
133+
tail: Optional[int] = None,
134+
follow: bool = True,
135+
cluster_name: Optional[str] = None,
136+
polling_interval: float = DEFAULT_POLL_INTERVAL
137+
) -> AsyncGenerator[str, None]:
125138
"""Tail the opened log file, buffer the lines and flush in chunks."""
126139

127140
if tail is not None:
@@ -137,7 +150,7 @@ async def _tail_log_file(
137150
yield line_str
138151

139152
last_heartbeat_time = asyncio.get_event_loop().time()
140-
last_cluster_status_check_time = asyncio.get_event_loop().time()
153+
last_status_check_time = asyncio.get_event_loop().time()
141154

142155
# Buffer the lines in memory and flush them in chunks to improve log
143156
# tailing throughput.
@@ -167,7 +180,17 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
167180

168181
line: Optional[bytes] = await f.readline()
169182
if not line:
170-
if request_id is not None:
183+
# Avoid checking the status too frequently to avoid overloading the
184+
# DB.
185+
should_check_status = (current_time -
186+
last_status_check_time) >= polling_interval
187+
if not follow:
188+
# We will only hit this path once, but we should make sure to
189+
# check the status so that we display the final request status
190+
# if the request is complete.
191+
should_check_status = True
192+
if request_id is not None and should_check_status:
193+
last_status_check_time = current_time
171194
req_status = await requests_lib.get_request_status_async(
172195
request_id)
173196
if req_status.status > requests_lib.RequestStatus.RUNNING:
@@ -185,20 +208,19 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
185208
' cancelled\n')
186209
break
187210
if not follow:
211+
# The below checks (cluster status, heartbeat) are not needed
212+
# for non-follow logs.
188213
break
189214
# Provision logs pass in cluster_name, check cluster status
190-
# periodically to see if provisioning is done. We only
191-
# check once a second to avoid overloading the DB.
192-
check_status = (current_time - last_cluster_status_check_time
193-
) >= _CLUSTER_STATUS_INTERVAL
194-
if cluster_name is not None and check_status:
215+
# periodically to see if provisioning is done.
216+
if cluster_name is not None and should_check_status:
217+
last_status_check_time = current_time
195218
cluster_record = await (
196219
global_user_state.get_status_from_cluster_name_async(
197220
cluster_name))
198221
if (cluster_record is None or
199222
cluster_record != status_lib.ClusterStatus.INIT):
200223
break
201-
last_cluster_status_check_time = current_time
202224
if current_time - last_heartbeat_time >= _HEARTBEAT_INTERVAL:
203225
# Currently just used to keep the connection busy, refer to
204226
# https://github.com/skypilot-org/skypilot/issues/5750 for
@@ -234,9 +256,22 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
234256
yield chunk
235257

236258

259+
def stream_response_for_long_request(
260+
request_id: str,
261+
logs_path: pathlib.Path,
262+
background_tasks: fastapi.BackgroundTasks,
263+
) -> fastapi.responses.StreamingResponse:
264+
return stream_response(request_id,
265+
logs_path,
266+
background_tasks,
267+
polling_interval=LONG_REQUEST_POLL_INTERVAL)
268+
269+
237270
def stream_response(
238-
request_id: str, logs_path: pathlib.Path,
239-
background_tasks: fastapi.BackgroundTasks
271+
request_id: str,
272+
logs_path: pathlib.Path,
273+
background_tasks: fastapi.BackgroundTasks,
274+
polling_interval: float = DEFAULT_POLL_INTERVAL
240275
) -> fastapi.responses.StreamingResponse:
241276

242277
async def on_disconnect():
@@ -249,7 +284,7 @@ async def on_disconnect():
249284
background_tasks.add_task(on_disconnect)
250285

251286
return fastapi.responses.StreamingResponse(
252-
log_streamer(request_id, logs_path),
287+
log_streamer(request_id, logs_path, polling_interval=polling_interval),
253288
media_type='text/plain',
254289
headers={
255290
'Cache-Control': 'no-cache, no-transform',

sky/utils/common_utils.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -265,13 +265,16 @@ def get_global_job_id(job_timestamp: str,
265265

266266
class Backoff:
267267
"""Exponential backoff with jittering."""
268-
MULTIPLIER = 1.6
269268
JITTER = 0.4
270269

271-
def __init__(self, initial_backoff: float = 5, max_backoff_factor: int = 5):
270+
def __init__(self,
271+
initial_backoff: float = 5,
272+
max_backoff_factor: int = 5,
273+
multiplier: float = 1.6):
272274
self._initial = True
273275
self._backoff = 0.0
274276
self._initial_backoff = initial_backoff
277+
self._multiplier = multiplier
275278
self._max_backoff = max_backoff_factor * self._initial_backoff
276279

277280
# https://github.com/grpc/grpc/blob/2d4f3c56001cd1e1f85734b2f7c5ce5f2797c38a/doc/connection-backoff.md
@@ -283,7 +286,7 @@ def current_backoff(self) -> float:
283286
self._initial = False
284287
self._backoff = min(self._initial_backoff, self._max_backoff)
285288
else:
286-
self._backoff = min(self._backoff * self.MULTIPLIER,
289+
self._backoff = min(self._backoff * self._multiplier,
287290
self._max_backoff)
288291
self._backoff += random.uniform(-self.JITTER * self._backoff,
289292
self.JITTER * self._backoff)

tests/unit_tests/test_sky/server/test_server.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -210,10 +210,10 @@ def slow_execute(*args, **kwargs):
210210
# Verify the executor calls
211211
mock_prepare.assert_called_once()
212212
mock_execute.assert_called_once_with(mock_request_task)
213-
mock_stream.assert_called_once_with(
214-
request_id=mock.ANY,
215-
logs_path=mock_request_task.log_path,
216-
background_tasks=mock.ANY)
213+
mock_stream.assert_called_once_with(mock.ANY,
214+
mock_request_task.log_path,
215+
mock.ANY,
216+
polling_interval=1)
217217

218218

219219
@mock.patch('sky.utils.context_utils.hijack_sys_attrs')

0 commit comments

Comments
 (0)