1111import asyncio
1212from http .cookiejar import MozillaCookieJar
1313import os
14+ import struct
1415import sys
15- from typing import Dict
16+ import time
17+ from typing import Dict , Optional
1618from urllib .request import Request
1719
20+ import requests
1821import websockets
1922from websockets .asyncio .client import ClientConnection
2023from websockets .asyncio .client import connect
2124
25+ from sky .server import constants
26+ from sky .server .server import KubernetesSSHMessageType
27+ from sky .skylet import constants as skylet_constants
28+
2229BUFFER_SIZE = 2 ** 16 # 64KB
30+ HEARTBEAT_INTERVAL_SECONDS = 10
2331
2432# Environment variable for a file path to the API cookie file.
2533# Keep in sync with server/constants.py
2836# Keep in sync with server/constants.py
2937API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'
3038
39+ MAX_UNANSWERED_PINGS = 100
40+
3141
3242def _get_cookie_header (url : str ) -> Dict [str , str ]:
3343 """Extract Cookie header value from a cookie jar for a specific URL"""
@@ -49,7 +59,7 @@ def _get_cookie_header(url: str) -> Dict[str, str]:
4959 return {'Cookie' : cookie_header }
5060
5161
52- async def main (url : str ) -> None :
62+ async def main (url : str , timestamps_supported : bool ) -> None :
5363 cookie_header = _get_cookie_header (url )
5464 async with connect (url ,
5565 ping_interval = None ,
@@ -75,45 +85,149 @@ async def main(url: str) -> None:
7585 asyncio .streams .FlowControlMixin , sys .stdout ) # type: ignore
7686 stdout_writer = asyncio .StreamWriter (transport , protocol , None ,
7787 loop )
88+ # Dictionary to store last ping time for latency measurement
89+ last_ping_time_dict : Optional [Dict [int , float ]] = None
90+ if timestamps_supported :
91+ last_ping_time_dict = {}
92+
93+ # Use an Event to signal when websocket is closed
94+ websocket_closed_event = asyncio .Event ()
95+ websocket_lock = asyncio .Lock ()
7896
79- await asyncio .gather (stdin_to_websocket (stdin_reader , websocket ),
80- websocket_to_stdout (websocket , stdout_writer ))
97+ await asyncio .gather (
98+ stdin_to_websocket (stdin_reader , websocket ,
99+ timestamps_supported , websocket_closed_event ,
100+ websocket_lock ),
101+ websocket_to_stdout (websocket , stdout_writer ,
102+ timestamps_supported , last_ping_time_dict ,
103+ websocket_closed_event , websocket_lock ),
104+ latency_monitor (websocket , last_ping_time_dict ,
105+ websocket_closed_event , websocket_lock ),
106+ return_exceptions = True )
81107 finally :
82108 if old_settings :
83109 termios .tcsetattr (sys .stdin .fileno (), termios .TCSADRAIN ,
84110 old_settings )
85111
86112
113+ async def latency_monitor (websocket : ClientConnection ,
114+ last_ping_time_dict : Optional [dict ],
115+ websocket_closed_event : asyncio .Event ,
116+ websocket_lock : asyncio .Lock ):
117+ """Periodically send PING messages (type 1) to measure latency."""
118+ if last_ping_time_dict is None :
119+ return
120+ next_id = 0
121+ while not websocket_closed_event .is_set ():
122+ try :
123+ await asyncio .sleep (HEARTBEAT_INTERVAL_SECONDS )
124+ if len (last_ping_time_dict ) >= MAX_UNANSWERED_PINGS :
125+ # We are not getting responses, clear the dictionary so
126+ # as not to grow unbounded.
127+ last_ping_time_dict .clear ()
128+ ping_time = time .time ()
129+ next_id += 1
130+ last_ping_time_dict [next_id ] = ping_time
131+ message_header_bytes = struct .pack (
132+ '!BI' , KubernetesSSHMessageType .PINGPONG .value , next_id )
133+ try :
134+ async with websocket_lock :
135+ await websocket .send (message_header_bytes )
136+ except websockets .exceptions .ConnectionClosed as e :
137+ # Websocket is already closed.
138+ print (f'Failed to send PING message: { e } ' , file = sys .stderr )
139+ break
140+ except Exception as e :
141+ print (f'Error in latency_monitor: { e } ' , file = sys .stderr )
142+ websocket_closed_event .set ()
143+ raise e
144+
145+
87146async def stdin_to_websocket (reader : asyncio .StreamReader ,
88- websocket : ClientConnection ):
147+ websocket : ClientConnection ,
148+ timestamps_supported : bool ,
149+ websocket_closed_event : asyncio .Event ,
150+ websocket_lock : asyncio .Lock ):
89151 try :
90- while True :
152+ while not websocket_closed_event . is_set () :
91153 # Read at most BUFFER_SIZE bytes, this not affect
92154 # responsiveness since it will return as soon as
93155 # there is at least one byte.
94156 # The BUFFER_SIZE is chosen to be large enough to improve
95157 # throughput.
96158 data = await reader .read (BUFFER_SIZE )
159+
97160 if not data :
98161 break
99- await websocket .send (data )
162+ if timestamps_supported :
163+ # Send message with type 0 to indicate data.
164+ message_type_bytes = struct .pack (
165+ '!B' , KubernetesSSHMessageType .REGULAR_DATA .value )
166+ data = message_type_bytes + data
167+ async with websocket_lock :
168+ await websocket .send (data )
169+
100170 except Exception as e : # pylint: disable=broad-except
101171 print (f'Error in stdin_to_websocket: { e } ' , file = sys .stderr )
102172 finally :
103- await websocket .close ()
173+ async with websocket_lock :
174+ await websocket .close ()
175+ websocket_closed_event .set ()
104176
105177
106178async def websocket_to_stdout (websocket : ClientConnection ,
107- writer : asyncio .StreamWriter ):
179+ writer : asyncio .StreamWriter ,
180+ timestamps_supported : bool ,
181+ last_ping_time_dict : Optional [dict ],
182+ websocket_closed_event : asyncio .Event ,
183+ websocket_lock : asyncio .Lock ):
108184 try :
109- while True :
185+ while not websocket_closed_event . is_set () :
110186 message = await websocket .recv ()
187+ if (timestamps_supported and len (message ) > 0 and
188+ last_ping_time_dict is not None ):
189+ message_type = struct .unpack ('!B' , message [:1 ])[0 ]
190+ if message_type == KubernetesSSHMessageType .REGULAR_DATA .value :
191+ # Regular data - strip type byte and write to stdout
192+ message = message [1 :]
193+ elif message_type == KubernetesSSHMessageType .PINGPONG .value :
194+ # PONG response - calculate latency and send measurement
195+ if not len (message ) == struct .calcsize ('!BI' ):
196+ raise ValueError (
197+ f'Invalid PONG message length: { len (message )} ' )
198+ pong_id = struct .unpack ('!I' , message [1 :5 ])[0 ]
199+ pong_time = time .time ()
200+
201+ ping_time = last_ping_time_dict .pop (pong_id , None )
202+
203+ if ping_time is None :
204+ continue
205+
206+ latency_seconds = pong_time - ping_time
207+ latency_ms = int (latency_seconds * 1000 )
208+
209+ # Send latency measurement (type 2)
210+ message_type_bytes = struct .pack (
211+ '!B' ,
212+ KubernetesSSHMessageType .LATENCY_MEASUREMENT .value )
213+ latency_bytes = struct .pack ('!Q' , latency_ms )
214+ message = message_type_bytes + latency_bytes
215+ # Send to server.
216+ async with websocket_lock :
217+ await websocket .send (message )
218+ continue
219+ # No timestamps support, write directly
111220 writer .write (message )
112221 await writer .drain ()
113222 except websockets .exceptions .ConnectionClosed :
114223 print ('WebSocket connection closed' , file = sys .stderr )
115224 except Exception as e : # pylint: disable=broad-except
116225 print (f'Error in websocket_to_stdout: { e } ' , file = sys .stderr )
226+ raise e
227+ finally :
228+ async with websocket_lock :
229+ await websocket .close ()
230+ websocket_closed_event .set ()
117231
118232
119233if __name__ == '__main__' :
@@ -123,11 +237,25 @@ async def websocket_to_stdout(websocket: ClientConnection,
123237 # TODO(aylei): Remove this after 0.10.0
124238 server_url = f'http://{ server_url } '
125239
240+ health_url = f'{ server_url } /api/health'
241+ health_response = requests .get (health_url )
242+ health_data = health_response .json ()
243+ timestamps_are_supported = int (health_data ['api_version' ]) > 21
244+ disable_latency_measurement = os .environ .get (
245+ skylet_constants .SSH_DISABLE_LATENCY_MEASUREMENT_ENV_VAR , '0' ) == '1'
246+ timestamps_are_supported = (timestamps_are_supported and
247+ not disable_latency_measurement )
248+
126249 server_proto , server_fqdn = server_url .split ('://' )
127250 websocket_proto = 'ws'
128251 if server_proto == 'https' :
129252 websocket_proto = 'wss'
130253 server_url = f'{ websocket_proto } ://{ server_fqdn } '
254+
255+ client_version_str = (f'&client_version={ constants .API_VERSION } '
256+ if timestamps_are_supported else '' )
257+
131258 websocket_url = (f'{ server_url } /kubernetes-pod-ssh-proxy'
132- f'?cluster_name={ sys .argv [2 ]} ' )
133- asyncio .run (main (websocket_url ))
259+ f'?cluster_name={ sys .argv [2 ]} '
260+ f'{ client_version_str } ' )
261+ asyncio .run (main (websocket_url , timestamps_are_supported ))
0 commit comments