@@ -11,6 +11,7 @@
 from sky import global_user_state
 from sky import sky_logging
 from sky.server.requests import requests as requests_lib
+from sky.utils import common_utils
 from sky.utils import message_utils
 from sky.utils import rich_utils
 from sky.utils import status_lib
@@ -24,7 +25,9 @@
 _BUFFER_SIZE = 8 * 1024  # 8KB
 _BUFFER_TIMEOUT = 0.02  # 20ms
 _HEARTBEAT_INTERVAL = 30
-_CLUSTER_STATUS_INTERVAL = 1
+
+LONG_REQUEST_POLL_INTERVAL = 1
+DEFAULT_POLL_INTERVAL = 0.1
 
 
 async def _yield_log_file_with_payloads_skipped(
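
Note on the constants: the fixed _CLUSTER_STATUS_INTERVAL is generalized into a per-stream polling knob. Streams for requests known to be long-running can poll the DB at the relaxed 1-second LONG_REQUEST_POLL_INTERVAL (via the stream_response_for_long_request helper added below), while everything else keeps the responsive 0.1-second DEFAULT_POLL_INTERVAL.
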
@@ -41,12 +44,14 @@ async def _yield_log_file_with_payloads_skipped(
 
 
 async def log_streamer(
-        request_id: Optional[str],
-        log_path: pathlib.Path,
-        plain_logs: bool = False,
-        tail: Optional[int] = None,
-        follow: bool = True,
-        cluster_name: Optional[str] = None) -> AsyncGenerator[str, None]:
+    request_id: Optional[str],
+    log_path: pathlib.Path,
+    plain_logs: bool = False,
+    tail: Optional[int] = None,
+    follow: bool = True,
+    cluster_name: Optional[str] = None,
+    polling_interval: float = DEFAULT_POLL_INTERVAL
+) -> AsyncGenerator[str, None]:
     """Streams the logs of a request.
 
     Args:
@@ -84,6 +89,11 @@ async def log_streamer(
                        f'scheduled: {request_id}')
         req_status = request_task.status
         req_msg = request_task.status_msg
+        # Slowly back off the database polling (up to 10x the polling
+        # interval) to avoid overloading the CPU and DB.
+        backoff = common_utils.Backoff(initial_backoff=polling_interval,
+                                       max_backoff_factor=10,
+                                       multiplier=1.2)
         while req_status < requests_lib.RequestStatus.RUNNING:
             if req_msg is not None:
                 waiting_msg = request_task.status_msg
@@ -99,7 +109,7 @@ async def log_streamer(
             # TODO(aylei): we should use a better mechanism to avoid busy
             # polling the DB, which can be a bottleneck for high-concurrency
             # requests.
-            await asyncio.sleep(0.1)
+            await asyncio.sleep(backoff.current_backoff())
             status_with_msg = await requests_lib.get_request_status_async(
                 request_id, include_msg=True)
             req_status = status_with_msg.status
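
For intuition, here is a minimal stand-in for common_utils.Backoff, inferred only from how it is used above (initial_backoff, max_backoff_factor, multiplier, and current_backoff()); SkyPilot's real implementation may differ in details such as jitter or the exact cap:

class Backoff:
    """Sketch of sky.utils.common_utils.Backoff, inferred from usage."""

    def __init__(self, initial_backoff: float, max_backoff_factor: int,
                 multiplier: float):
        self._current = initial_backoff
        # Assumption: the delay is capped at initial * max_backoff_factor.
        self._max = initial_backoff * max_backoff_factor
        self._multiplier = multiplier

    def current_backoff(self) -> float:
        # Return the current delay, then grow it geometrically for the
        # next call, up to the cap.
        delay = self._current
        self._current = min(self._current * self._multiplier, self._max)
        return delay


# With the defaults above (0.1s initial, multiplier 1.2, cap factor 10),
# successive sleeps grow 0.1 -> 0.12 -> 0.144 -> ... and level off at 1.0s.
b = Backoff(initial_backoff=0.1, max_backoff_factor=10, multiplier=1.2)
delays = [b.current_backoff() for _ in range(30)]
assert abs(delays[-1] - 1.0) < 1e-9
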
@@ -111,17 +121,20 @@ async def log_streamer(
 
     async with aiofiles.open(log_path, 'rb') as f:
         async for chunk in _tail_log_file(f, request_id, plain_logs, tail,
-                                          follow, cluster_name):
+                                          follow, cluster_name,
+                                          polling_interval):
             yield chunk
 
 
 async def _tail_log_file(
-        f: aiofiles.threadpool.binary.AsyncBufferedReader,
-        request_id: Optional[str] = None,
-        plain_logs: bool = False,
-        tail: Optional[int] = None,
-        follow: bool = True,
-        cluster_name: Optional[str] = None) -> AsyncGenerator[str, None]:
+    f: aiofiles.threadpool.binary.AsyncBufferedReader,
+    request_id: Optional[str] = None,
+    plain_logs: bool = False,
+    tail: Optional[int] = None,
+    follow: bool = True,
+    cluster_name: Optional[str] = None,
+    polling_interval: float = DEFAULT_POLL_INTERVAL
+) -> AsyncGenerator[str, None]:
     """Tail the opened log file, buffer the lines and flush in chunks."""
 
     if tail is not None:
@@ -137,7 +150,7 @@ async def _tail_log_file(
             yield line_str
 
     last_heartbeat_time = asyncio.get_event_loop().time()
-    last_cluster_status_check_time = asyncio.get_event_loop().time()
+    last_status_check_time = asyncio.get_event_loop().time()
 
     # Buffer the lines in memory and flush them in chunks to improve log
     # tailing throughput.
@@ -167,7 +180,17 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
 
         line: Optional[bytes] = await f.readline()
         if not line:
-            if request_id is not None:
+            # Avoid checking the status too frequently, which could overload
+            # the DB.
+            should_check_status = (current_time -
+                                   last_status_check_time) >= polling_interval
+            if not follow:
+                # We will only hit this path once, but we should make sure to
+                # check the status so that we display the final request status
+                # if the request is complete.
+                should_check_status = True
+            if request_id is not None and should_check_status:
+                last_status_check_time = current_time
                 req_status = await requests_lib.get_request_status_async(
                     request_id)
                 if req_status.status > requests_lib.RequestStatus.RUNNING:
@@ -185,20 +208,19 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
                                       ' cancelled\n')
                     break
             if not follow:
+                # The checks below (cluster status, heartbeat) are not needed
+                # for non-follow logs.
                 break
             # Provision logs pass in cluster_name, check cluster status
-            # periodically to see if provisioning is done. We only
-            # check once a second to avoid overloading the DB.
-            check_status = (current_time - last_cluster_status_check_time
-                            ) >= _CLUSTER_STATUS_INTERVAL
-            if cluster_name is not None and check_status:
+            # periodically to see if provisioning is done.
+            if cluster_name is not None and should_check_status:
+                last_status_check_time = current_time
                 cluster_record = await (
                     global_user_state.get_status_from_cluster_name_async(
                         cluster_name))
                 if (cluster_record is None or
                         cluster_record != status_lib.ClusterStatus.INIT):
                     break
-                last_cluster_status_check_time = current_time
             if current_time - last_heartbeat_time >= _HEARTBEAT_INTERVAL:
                 # Currently just used to keep the connection busy, refer to
                 # https://github.com/skypilot-org/skypilot/issues/5750 for
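
Both the request-status and cluster-status queries now share one last_status_check_time, so the DB is hit at most once per polling_interval even though the read loop itself stays tight. The pattern, condensed into a self-contained sketch (the helper and its callbacks are hypothetical, not part of this PR):

import asyncio
from typing import Awaitable, Callable


async def tail_with_throttled_checks(
        read_line: Callable[[], Awaitable[bytes]],
        check_done: Callable[[], Awaitable[bool]],
        interval: float) -> None:
    # Tight loop over the log source; the expensive DB check only runs
    # when `interval` seconds have elapsed since the last check.
    loop = asyncio.get_event_loop()
    last_check = loop.time()
    while True:
        await asyncio.sleep(0)  # yield control, keep the loop responsive
        line = await read_line()
        if line:
            continue  # got data (real code would buffer and yield it)
        now = loop.time()
        if now - last_check >= interval:
            last_check = now
            if await check_done():  # the throttled DB query
                break
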
@@ -234,9 +256,22 @@ async def flush_buffer() -> AsyncGenerator[str, None]:
         yield chunk
 
 
+def stream_response_for_long_request(
+    request_id: str,
+    logs_path: pathlib.Path,
+    background_tasks: fastapi.BackgroundTasks,
+) -> fastapi.responses.StreamingResponse:
+    return stream_response(request_id,
+                           logs_path,
+                           background_tasks,
+                           polling_interval=LONG_REQUEST_POLL_INTERVAL)
+
+
 def stream_response(
-        request_id: str, logs_path: pathlib.Path,
-        background_tasks: fastapi.BackgroundTasks
+    request_id: str,
+    logs_path: pathlib.Path,
+    background_tasks: fastapi.BackgroundTasks,
+    polling_interval: float = DEFAULT_POLL_INTERVAL
 ) -> fastapi.responses.StreamingResponse:
 
     async def on_disconnect():
@@ -249,7 +284,7 @@ async def on_disconnect():
     background_tasks.add_task(on_disconnect)
 
     return fastapi.responses.StreamingResponse(
-        log_streamer(request_id, logs_path),
+        log_streamer(request_id, logs_path, polling_interval=polling_interval),
         media_type='text/plain',
         headers={
             'Cache-Control': 'no-cache, no-transform',
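
And a hypothetical wiring of the new helper into an endpoint, assuming it lives in the same module as the functions above; the route path and log location below are illustrative, not from this PR:

import pathlib

import fastapi

app = fastapi.FastAPI()


@app.get('/api/v1/logs/{request_id}')
async def tail_logs(
        request_id: str, background_tasks: fastapi.BackgroundTasks
) -> fastapi.responses.StreamingResponse:
    # Long-running requests tolerate the relaxed 1s DB poll; a latency-
    # sensitive endpoint would call stream_response() and keep the 0.1s
    # default.
    logs_path = pathlib.Path(f'/tmp/sky_logs/{request_id}.log')  # illustrative
    return stream_response_for_long_request(request_id, logs_path,
                                            background_tasks)
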