2424
2525from  sky .server  import  constants 
2626from  sky .server .server  import  KubernetesSSHMessageType 
27+ from  sky .skylet  import  constants  as  skylet_constants 
2728
2829BUFFER_SIZE  =  2 ** 16   # 64KB 
29- HEARTBEAT_INTERVAL_SECONDS  =  60 
30+ HEARTBEAT_INTERVAL_SECONDS  =  10 
3031
3132# Environment variable for a file path to the API cookie file. 
3233# Keep in sync with server/constants.py 
3536# Keep in sync with server/constants.py 
3637API_COOKIE_FILE_DEFAULT_LOCATION  =  '~/.sky/cookies.txt' 
3738
39+ MAX_UNANSWERED_PINGS  =  100 
40+ 
3841
3942def  _get_cookie_header (url : str ) ->  Dict [str , str ]:
4043    """Extract Cookie header value from a cookie jar for a specific URL""" 
@@ -83,23 +86,27 @@ async def main(url: str, timestamps_supported: bool) -> None:
8386            stdout_writer  =  asyncio .StreamWriter (transport , protocol , None ,
8487                                                 loop )
8588            # Dictionary to store last ping time for latency measurement 
86-             last_ping_time_dict : Optional [dict ] =  None 
89+             last_ping_time_dict : Optional [Dict [int , float ]] =  None 
90+             last_ping_time_dict_lock  =  asyncio .Lock ()
8791            if  timestamps_supported :
8892                last_ping_time_dict  =  {}
8993
9094            # Use an Event to signal when websocket is closed 
9195            websocket_closed_event  =  asyncio .Event ()
96+             websocket_lock  =  asyncio .Lock ()
9297
93-             await  asyncio .gather (stdin_to_websocket (stdin_reader , websocket ,
94-                                                     timestamps_supported ,
95-                                                     websocket_closed_event ),
96-                                  websocket_to_stdout (websocket , stdout_writer ,
97-                                                      timestamps_supported ,
98-                                                      last_ping_time_dict ,
99-                                                      websocket_closed_event ),
100-                                  latency_monitor (websocket , last_ping_time_dict ,
101-                                                  websocket_closed_event ),
102-                                  return_exceptions = True )
98+             await  asyncio .gather (
99+                 stdin_to_websocket (stdin_reader , websocket ,
100+                                    timestamps_supported , websocket_closed_event ,
101+                                    websocket_lock ),
102+                 websocket_to_stdout (websocket , stdout_writer ,
103+                                     timestamps_supported , last_ping_time_dict ,
104+                                     last_ping_time_dict_lock ,
105+                                     websocket_closed_event , websocket_lock ),
106+                 latency_monitor (websocket , last_ping_time_dict ,
107+                                 last_ping_time_dict_lock ,
108+                                 websocket_closed_event , websocket_lock ),
109+                 return_exceptions = True )
103110        finally :
104111            if  old_settings :
105112                termios .tcsetattr (sys .stdin .fileno (), termios .TCSADRAIN ,
@@ -108,19 +115,30 @@ async def main(url: str, timestamps_supported: bool) -> None:
108115
109116async  def  latency_monitor (websocket : ClientConnection ,
110117                          last_ping_time_dict : Optional [dict ],
111-                           websocket_closed_event : asyncio .Event ):
118+                           last_ping_time_dict_lock : asyncio .Lock ,
119+                           websocket_closed_event : asyncio .Event ,
120+                           websocket_lock : asyncio .Lock ):
112121    """Periodically send PING messages (type 1) to measure latency.""" 
113122    if  last_ping_time_dict  is  None :
114123        return 
124+     next_id  =  0 
115125    while  not  websocket_closed_event .is_set ():
116126        try :
117127            await  asyncio .sleep (HEARTBEAT_INTERVAL_SECONDS )
128+             async  with  last_ping_time_dict_lock :
129+                 if  len (last_ping_time_dict ) >=  MAX_UNANSWERED_PINGS :
130+                     # We are not getting responses, clear the dictionary so 
131+                     # as not to grow unbounded. 
132+                     last_ping_time_dict .clear ()
118133            ping_time  =  time .time ()
119-             last_ping_time_dict ['time' ] =  ping_time 
120-             message_type_bytes  =  struct .pack (
121-                 '!B' , KubernetesSSHMessageType .PINGPONG .value )
134+             next_id  +=  1 
135+             async  with  last_ping_time_dict_lock :
136+                 last_ping_time_dict [next_id ] =  ping_time 
137+             message_header_bytes  =  struct .pack (
138+                 '!BI' , KubernetesSSHMessageType .PINGPONG .value , next_id )
122139            try :
123-                 await  websocket .send (message_type_bytes )
140+                 async  with  websocket_lock :
141+                     await  websocket .send (message_header_bytes )
124142            except  websockets .exceptions .ConnectionClosed  as  e :
125143                # Websocket is already closed. 
126144                print (f'Failed to send PING message: { e }  ' , file = sys .stderr )
@@ -134,7 +152,8 @@ async def latency_monitor(websocket: ClientConnection,
134152async  def  stdin_to_websocket (reader : asyncio .StreamReader ,
135153                             websocket : ClientConnection ,
136154                             timestamps_supported : bool ,
137-                              websocket_closed_event : asyncio .Event ):
155+                              websocket_closed_event : asyncio .Event ,
156+                              websocket_lock : asyncio .Lock ):
138157    try :
139158        while  not  websocket_closed_event .is_set ():
140159            # Read at most BUFFER_SIZE bytes, this not affect 
@@ -151,64 +170,74 @@ async def stdin_to_websocket(reader: asyncio.StreamReader,
151170                message_type_bytes  =  struct .pack (
152171                    '!B' , KubernetesSSHMessageType .REGULAR_DATA .value )
153172                data  =  message_type_bytes  +  data 
154-             await  websocket .send (data )
173+             async  with  websocket_lock :
174+                 await  websocket .send (data )
155175
156176    except  Exception  as  e :  # pylint: disable=broad-except 
157177        print (f'Error in stdin_to_websocket: { e }  ' , file = sys .stderr )
158178    finally :
159-         await  websocket .close ()
179+         async  with  websocket_lock :
180+             await  websocket .close ()
160181        websocket_closed_event .set ()
161182
162183
async def websocket_to_stdout(
        websocket: ClientConnection, writer: asyncio.StreamWriter,
        timestamps_supported: bool, last_ping_time_dict: Optional[dict],
        last_ping_time_dict_lock: asyncio.Lock,
        websocket_closed_event: asyncio.Event, websocket_lock: asyncio.Lock):
    """Forward websocket messages to stdout and answer PONGs with latency.

    When timestamps are not in use (``timestamps_supported`` is false or
    ``last_ping_time_dict`` is None), and for empty messages, every received
    message is written to ``writer`` unchanged. Otherwise the first byte of
    each message is treated as a ``KubernetesSSHMessageType`` value:

    * REGULAR_DATA: the type byte is stripped and the rest is written to
      ``writer``.
    * PINGPONG: the message must be exactly 5 bytes (type byte followed by a
      4-byte network-order ping id). The matching ping time is popped from
      ``last_ping_time_dict`` and a LATENCY_MEASUREMENT message (type byte
      followed by an 8-byte unsigned millisecond value) is sent back over the
      websocket. PONGs with no recorded ping are ignored.
    * Any other type: the message is written to ``writer`` as-is.

    Args:
        websocket: Connection to receive from and send measurements to.
        writer: Stream that receives the forwarded payload.
        timestamps_supported: Whether messages carry a leading type byte.
        last_ping_time_dict: Maps ping id to the time the ping was sent, or
            None when latency measurement is not in use.
        last_ping_time_dict_lock: Guards access to ``last_ping_time_dict``.
        websocket_closed_event: Checked before each receive; set on exit.
        websocket_lock: Held around ``websocket.send`` and
            ``websocket.close`` calls.

    Raises:
        ValueError: If a PINGPONG message is not exactly 5 bytes long.
        Exception: Any error other than a closed websocket connection is
            reported on stderr and re-raised.
    """
    try:
        while not websocket_closed_event.is_set():
            message = await websocket.recv()
            if (timestamps_supported and len(message) > 0 and
                    last_ping_time_dict is not None):
                message_type = struct.unpack('!B', message[:1])[0]
                if message_type == KubernetesSSHMessageType.REGULAR_DATA.value:
                    # Regular data - strip type byte; written to stdout below.
                    message = message[1:]
                elif message_type == KubernetesSSHMessageType.PINGPONG.value:
                    # PONG response - calculate latency and send measurement
                    if len(message) != 5:
                        raise ValueError(
                            f'Invalid PONG message length: {len(message)}')
                    pong_id = struct.unpack('!I', message[1:5])[0]
                    pong_time = time.time()

                    async with last_ping_time_dict_lock:
                        # None means we have no matching ping for this pong.
                        ping_time = last_ping_time_dict.pop(pong_id, None)

                    if ping_time is None:
                        continue

                    # Both timestamps are wall-clock readings, so a backwards
                    # clock step could make the difference negative. Clamp to
                    # zero: '!Q' is unsigned and packing a negative value
                    # would raise struct.error and end the session.
                    latency_seconds = pong_time - ping_time
                    latency_ms = max(0, int(latency_seconds * 1000))

                    # Send latency measurement (type 2)
                    message_type_bytes = struct.pack(
                        '!B',
                        KubernetesSSHMessageType.LATENCY_MEASUREMENT.value)
                    latency_bytes = struct.pack('!Q', latency_ms)
                    # Send to server.
                    async with websocket_lock:
                        await websocket.send(message_type_bytes +
                                             latency_bytes)
                    continue
            # Reached for untyped messages (timestamps not in use), stripped
            # regular data, and unknown message types: write to stdout.
            writer.write(message)
            await writer.drain()
    except websockets.exceptions.ConnectionClosed:
        print('WebSocket connection closed', file=sys.stderr)
    except Exception as e:  # pylint: disable=broad-except
        print(f'Error in websocket_to_stdout: {e}', file=sys.stderr)
        raise
    finally:
        async with websocket_lock:
            await websocket.close()
        websocket_closed_event.set()
213242
214243
@@ -223,7 +252,10 @@ async def websocket_to_stdout(websocket: ClientConnection,
223252    health_response  =  requests .get (health_url )
224253    health_data  =  health_response .json ()
225254    timestamps_are_supported  =  int (health_data ['api_version' ]) >  21 
226-     print (f'Timestamps are supported: { timestamps_are_supported }  ' )
255+     disable_latency_measurement  =  os .environ .get (
256+         skylet_constants .SSH_DISABLE_LATENCY_MEASUREMENT_ENV_VAR , '0' ) ==  '1' 
257+     timestamps_are_supported  =  (timestamps_are_supported  and 
258+                                 not  disable_latency_measurement )
227259
228260    server_proto , server_fqdn  =  server_url .split ('://' )
229261    websocket_proto  =  'ws' 
0 commit comments