Skip to content

Commit d633474

Browse files
committed
Add locking.
1 parent da04213 commit d633474

File tree

4 files changed

+103
-58
lines changed

4 files changed

+103
-58
lines changed

sky/metrics/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,10 @@
145145

146146
SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS = prom.Histogram(
147147
'sky_apiserver_websocket_ssh_latency_seconds',
148-
'Time taken for ssh message to go from client to API server and back.',
148+
('Time taken for ssh message to go from client to API server and back '
149+
'to the client. This does not include: latency to reach the pod, '
150+
'overhead from sending through the k8s port-forward tunnel, or '
151+
'ssh server lag on the destination pod.'),
149152
['pid'],
150153
buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
151154
0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,

sky/server/server.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,13 +1863,20 @@ async def websocket_to_ssh():
18631863
message = message[1:]
18641864
elif message_type == KubernetesSSHMessageType.PINGPONG:
18651865
# PING message - respond with PONG (type 1)
1866-
pong_message = struct.pack(
1867-
'!B', KubernetesSSHMessageType.PINGPONG.value)
1868-
await websocket.send_bytes(pong_message)
1866+
if len(message) != 5:
1867+
raise ValueError('Invalid PING message '
1868+
f'length: {len(message)}')
1869+
# Return the same PING message, so that the client
1870+
# can measure the latency.
1871+
await websocket.send_bytes(message)
18691872
continue
18701873
elif (message_type ==
18711874
KubernetesSSHMessageType.LATENCY_MEASUREMENT):
18721875
# Latency measurement from client
1876+
if len(message) != 9:
1877+
raise ValueError(
1878+
'Invalid latency measurement '
1879+
f'message length: {len(message)}')
18731880
avg_latency_ms = struct.unpack('!Q',
18741881
message[1:9])[0]
18751882
latency_seconds = avg_latency_ms / 1000

sky/skylet/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,6 @@
548548

549549
ARM64_ARCH = 'arm64'
550550
X86_64_ARCH = 'x86_64'
551+
552+
SSH_DISABLE_LATENCY_MEASUREMENT_ENV_VAR = (
553+
f'{SKYPILOT_ENV_VAR_PREFIX}SSH_DISABLE_LATENCY_MEASUREMENT')

sky/templates/websocket_proxy.py

Lines changed: 86 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424

2525
from sky.server import constants
2626
from sky.server.server import KubernetesSSHMessageType
27+
from sky.skylet import constants as skylet_constants
2728

2829
BUFFER_SIZE = 2**16 # 64KB
29-
HEARTBEAT_INTERVAL_SECONDS = 60
30+
HEARTBEAT_INTERVAL_SECONDS = 10
3031

3132
# Environment variable for a file path to the API cookie file.
3233
# Keep in sync with server/constants.py
@@ -35,6 +36,8 @@
3536
# Keep in sync with server/constants.py
3637
API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'
3738

39+
MAX_UNANSWERED_PINGS = 100
40+
3841

3942
def _get_cookie_header(url: str) -> Dict[str, str]:
4043
"""Extract Cookie header value from a cookie jar for a specific URL"""
@@ -83,23 +86,27 @@ async def main(url: str, timestamps_supported: bool) -> None:
8386
stdout_writer = asyncio.StreamWriter(transport, protocol, None,
8487
loop)
8588
# Dictionary mapping each outstanding ping id to its send time, for latency measurement
86-
last_ping_time_dict: Optional[dict] = None
89+
last_ping_time_dict: Optional[Dict[int, float]] = None
90+
last_ping_time_dict_lock = asyncio.Lock()
8791
if timestamps_supported:
8892
last_ping_time_dict = {}
8993

9094
# Use an Event to signal when websocket is closed
9195
websocket_closed_event = asyncio.Event()
96+
websocket_lock = asyncio.Lock()
9297

93-
await asyncio.gather(stdin_to_websocket(stdin_reader, websocket,
94-
timestamps_supported,
95-
websocket_closed_event),
96-
websocket_to_stdout(websocket, stdout_writer,
97-
timestamps_supported,
98-
last_ping_time_dict,
99-
websocket_closed_event),
100-
latency_monitor(websocket, last_ping_time_dict,
101-
websocket_closed_event),
102-
return_exceptions=True)
98+
await asyncio.gather(
99+
stdin_to_websocket(stdin_reader, websocket,
100+
timestamps_supported, websocket_closed_event,
101+
websocket_lock),
102+
websocket_to_stdout(websocket, stdout_writer,
103+
timestamps_supported, last_ping_time_dict,
104+
last_ping_time_dict_lock,
105+
websocket_closed_event, websocket_lock),
106+
latency_monitor(websocket, last_ping_time_dict,
107+
last_ping_time_dict_lock,
108+
websocket_closed_event, websocket_lock),
109+
return_exceptions=True)
103110
finally:
104111
if old_settings:
105112
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN,
@@ -108,19 +115,30 @@ async def main(url: str, timestamps_supported: bool) -> None:
108115

109116
async def latency_monitor(websocket: ClientConnection,
110117
last_ping_time_dict: Optional[dict],
111-
websocket_closed_event: asyncio.Event):
118+
last_ping_time_dict_lock: asyncio.Lock,
119+
websocket_closed_event: asyncio.Event,
120+
websocket_lock: asyncio.Lock):
112121
"""Periodically send PING messages (type 1) to measure latency."""
113122
if last_ping_time_dict is None:
114123
return
124+
next_id = 0
115125
while not websocket_closed_event.is_set():
116126
try:
117127
await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
128+
async with last_ping_time_dict_lock:
129+
if len(last_ping_time_dict) >= MAX_UNANSWERED_PINGS:
130+
# We are not getting responses, clear the dictionary so
131+
# as not to grow unbounded.
132+
last_ping_time_dict.clear()
118133
ping_time = time.time()
119-
last_ping_time_dict['time'] = ping_time
120-
message_type_bytes = struct.pack(
121-
'!B', KubernetesSSHMessageType.PINGPONG.value)
134+
next_id += 1
135+
async with last_ping_time_dict_lock:
136+
last_ping_time_dict[next_id] = ping_time
137+
message_header_bytes = struct.pack(
138+
'!BI', KubernetesSSHMessageType.PINGPONG.value, next_id)
122139
try:
123-
await websocket.send(message_type_bytes)
140+
async with websocket_lock:
141+
await websocket.send(message_header_bytes)
124142
except websockets.exceptions.ConnectionClosed as e:
125143
# Websocket is already closed.
126144
print(f'Failed to send PING message: {e}', file=sys.stderr)
@@ -134,7 +152,8 @@ async def latency_monitor(websocket: ClientConnection,
134152
async def stdin_to_websocket(reader: asyncio.StreamReader,
135153
websocket: ClientConnection,
136154
timestamps_supported: bool,
137-
websocket_closed_event: asyncio.Event):
155+
websocket_closed_event: asyncio.Event,
156+
websocket_lock: asyncio.Lock):
138157
try:
139158
while not websocket_closed_event.is_set():
140159
# Read at most BUFFER_SIZE bytes, this does not affect
@@ -151,64 +170,74 @@ async def stdin_to_websocket(reader: asyncio.StreamReader,
151170
message_type_bytes = struct.pack(
152171
'!B', KubernetesSSHMessageType.REGULAR_DATA.value)
153172
data = message_type_bytes + data
154-
await websocket.send(data)
173+
async with websocket_lock:
174+
await websocket.send(data)
155175

156176
except Exception as e: # pylint: disable=broad-except
157177
print(f'Error in stdin_to_websocket: {e}', file=sys.stderr)
158178
finally:
159-
await websocket.close()
179+
async with websocket_lock:
180+
await websocket.close()
160181
websocket_closed_event.set()
161182

162183

163-
async def websocket_to_stdout(websocket: ClientConnection,
164-
writer: asyncio.StreamWriter,
165-
timestamps_supported: bool,
166-
last_ping_time_dict: Optional[dict],
167-
websocket_closed_event: asyncio.Event):
184+
async def websocket_to_stdout(
185+
websocket: ClientConnection, writer: asyncio.StreamWriter,
186+
timestamps_supported: bool, last_ping_time_dict: Optional[dict],
187+
last_ping_time_dict_lock: asyncio.Lock,
188+
websocket_closed_event: asyncio.Event, websocket_lock: asyncio.Lock):
168189
try:
169190
while not websocket_closed_event.is_set():
170191
message = await websocket.recv()
171-
if timestamps_supported and len(message) > 0:
192+
if (timestamps_supported and len(message) > 0 and
193+
last_ping_time_dict is not None):
172194
message_type = struct.unpack('!B', message[:1])[0]
173195
if message_type == KubernetesSSHMessageType.REGULAR_DATA.value:
174196
# Regular data - strip type byte and write to stdout
175-
data = message[1:]
176-
writer.write(data)
177-
await writer.drain()
197+
message = message[1:]
178198
elif message_type == KubernetesSSHMessageType.PINGPONG.value:
179199
# PONG response - calculate latency and send measurement
200+
if not len(message) == 5:
201+
raise ValueError(
202+
f'Invalid PONG message length: {len(message)}')
203+
pong_id = struct.unpack('!I', message[1:5])[0]
180204
pong_time = time.time()
181-
if last_ping_time_dict and 'time' in last_ping_time_dict:
182-
ping_time = last_ping_time_dict['time']
183-
latency_seconds = pong_time - ping_time
184-
latency_ms = int(latency_seconds * 1000)
185-
186-
# Send latency measurement (type 2)
187-
message_type_bytes = struct.pack(
188-
'!B',
189-
KubernetesSSHMessageType.LATENCY_MEASUREMENT.value)
190-
latency_bytes = struct.pack('!Q', latency_ms)
191-
data = message_type_bytes + latency_bytes
205+
206+
ping_time = None
207+
async with last_ping_time_dict_lock:
192208
try:
193-
await websocket.send(data)
194-
except Exception as e: # pylint: disable=broad-except
195-
print(f'Failed to send latency measurement: {e}',
196-
file=sys.stderr)
197-
else:
198-
# Unknown message type, write as-is
199-
writer.write(message)
200-
await writer.drain()
201-
else:
202-
# No timestamps support, write directly
203-
writer.write(message)
204-
await writer.drain()
209+
ping_time = last_ping_time_dict.pop(pong_id)
210+
except KeyError:
211+
# We don't have a matching ping, ignore the pong.
212+
pass
213+
214+
if ping_time is None:
215+
continue
216+
217+
latency_seconds = pong_time - ping_time
218+
latency_ms = int(latency_seconds * 1000)
219+
220+
# Send latency measurement (type 2)
221+
message_type_bytes = struct.pack(
222+
'!B',
223+
KubernetesSSHMessageType.LATENCY_MEASUREMENT.value)
224+
latency_bytes = struct.pack('!Q', latency_ms)
225+
message = message_type_bytes + latency_bytes
226+
# Send to server.
227+
async with websocket_lock:
228+
await websocket.send(message)
229+
continue
230+
# Regular data (type byte already stripped), an unknown message type, or no timestamps support: write to stdout
231+
writer.write(message)
232+
await writer.drain()
205233
except websockets.exceptions.ConnectionClosed:
206234
print('WebSocket connection closed', file=sys.stderr)
207235
except Exception as e: # pylint: disable=broad-except
208236
print(f'Error in websocket_to_stdout: {e}', file=sys.stderr)
209237
raise e
210238
finally:
211-
await websocket.close()
239+
async with websocket_lock:
240+
await websocket.close()
212241
websocket_closed_event.set()
213242

214243

@@ -223,7 +252,10 @@ async def websocket_to_stdout(websocket: ClientConnection,
223252
health_response = requests.get(health_url)
224253
health_data = health_response.json()
225254
timestamps_are_supported = int(health_data['api_version']) > 21
226-
print(f'Timestamps are supported: {timestamps_are_supported}')
255+
disable_latency_measurement = os.environ.get(
256+
skylet_constants.SSH_DISABLE_LATENCY_MEASUREMENT_ENV_VAR, '0') == '1'
257+
timestamps_are_supported = (timestamps_are_supported and
258+
not disable_latency_measurement)
227259

228260
server_proto, server_fqdn = server_url.split('://')
229261
websocket_proto = 'ws'

0 commit comments

Comments
 (0)