Skip to content

Commit 93cb798

Browse files
[Metrics] Add SSH Latency (#7538)
* Add new metric. * Add version support. * Change metric. * Working latency metric. * Tweak wording for metric. * Up interval. * Bump version again. * Add locking. * Remove lock --------- Co-authored-by: lloydbrownjr <[email protected]>
1 parent 0a2cecd commit 93cb798

File tree

5 files changed

+219
-15
lines changed

5 files changed

+219
-15
lines changed

sky/metrics/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,24 @@
143143
'RSS increment after requests', ['name'],
144144
buckets=_MEM_BUCKETS)
145145

146+
SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS = prom.Histogram(
147+
'sky_apiserver_websocket_ssh_latency_seconds',
148+
('Time taken for ssh message to go from client to API server and back'
149+
'to the client. This does not include: latency to reach the pod, '
150+
'overhead from sending through the k8s port-forward tunnel, or '
151+
'ssh server lag on the destination pod.'),
152+
['pid'],
153+
buckets=(0.001, 0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.25,
154+
0.35, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 2.75, 3, 3.5, 4, 4.5,
155+
5, 7.5, 10.0, 12.5, 15.0, 17.5, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0,
156+
50.0, 55.0, 60.0, 80.0, 120.0, 140.0, 160.0, 180.0, 200.0, 220.0,
157+
240.0, 260.0, 280.0, 300.0, 320.0, 340.0, 360.0, 380.0, 400.0,
158+
420.0, 440.0, 460.0, 480.0, 500.0, 520.0, 540.0, 560.0, 580.0,
159+
600.0, 620.0, 640.0, 660.0, 680.0, 700.0, 720.0, 740.0, 760.0,
160+
780.0, 800.0, 820.0, 840.0, 860.0, 880.0, 900.0, 920.0, 940.0,
161+
960.0, 980.0, 1000.0, float('inf')),
162+
)
163+
146164

147165
@contextlib.contextmanager
148166
def time_it(name: str, group: str = 'default'):

sky/server/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# based on version info is needed.
1111
# For more details and code guidelines, refer to:
1212
# https://docs.skypilot.co/en/latest/developers/CONTRIBUTING.html#backward-compatibility-guidelines
13-
API_VERSION = 21
13+
API_VERSION = 22
1414

1515
# The minimum peer API version that the code should still work with.
1616
# Notes (dev):

sky/server/server.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from concurrent.futures import ThreadPoolExecutor
77
import contextlib
88
import datetime
9+
from enum import IntEnum
910
import hashlib
1011
import json
1112
import multiprocessing
@@ -15,6 +16,7 @@
1516
import re
1617
import resource
1718
import shutil
19+
import struct
1820
import sys
1921
import threading
2022
import traceback
@@ -1809,13 +1811,25 @@ async def health(request: fastapi.Request) -> responses.APIHealthResponse:
18091811
)
18101812

18111813

1814+
class KubernetesSSHMessageType(IntEnum):
1815+
REGULAR_DATA = 0
1816+
PINGPONG = 1
1817+
LATENCY_MEASUREMENT = 2
1818+
1819+
18121820
@app.websocket('/kubernetes-pod-ssh-proxy')
1813-
async def kubernetes_pod_ssh_proxy(websocket: fastapi.WebSocket,
1814-
cluster_name: str) -> None:
1821+
async def kubernetes_pod_ssh_proxy(
1822+
websocket: fastapi.WebSocket,
1823+
cluster_name: str,
1824+
client_version: Optional[int] = None) -> None:
18151825
"""Proxies SSH to the Kubernetes pod with websocket."""
18161826
await websocket.accept()
18171827
logger.info(f'WebSocket connection accepted for cluster: {cluster_name}')
18181828

1829+
timestamps_supported = client_version is not None and client_version > 21
1830+
logger.info(f'Websocket timestamps supported: {timestamps_supported}, \
1831+
client_version = {client_version}')
1832+
18191833
# Run core.status in another thread to avoid blocking the event loop.
18201834
with ThreadPoolExecutor(max_workers=1) as thread_pool_executor:
18211835
cluster_records = await context_utils.to_thread_with_executor(
@@ -1870,6 +1884,42 @@ async def kubernetes_pod_ssh_proxy(websocket: fastapi.WebSocket,
18701884
async def websocket_to_ssh():
18711885
try:
18721886
async for message in websocket.iter_bytes():
1887+
if timestamps_supported:
1888+
type_size = struct.calcsize('!B')
1889+
message_type = struct.unpack('!B',
1890+
message[:type_size])[0]
1891+
if (message_type ==
1892+
KubernetesSSHMessageType.REGULAR_DATA):
1893+
# Regular data - strip type byte and forward to SSH
1894+
message = message[type_size:]
1895+
elif message_type == KubernetesSSHMessageType.PINGPONG:
1896+
# PING message - respond with PONG (type 1)
1897+
ping_id_size = struct.calcsize('!I')
1898+
if len(message) != type_size + ping_id_size:
1899+
raise ValueError('Invalid PING message '
1900+
f'length: {len(message)}')
1901+
# Return the same PING message, so that the client
1902+
# can measure the latency.
1903+
await websocket.send_bytes(message)
1904+
continue
1905+
elif (message_type ==
1906+
KubernetesSSHMessageType.LATENCY_MEASUREMENT):
1907+
# Latency measurement from client
1908+
latency_size = struct.calcsize('!Q')
1909+
if len(message) != type_size + latency_size:
1910+
raise ValueError(
1911+
'Invalid latency measurement '
1912+
f'message length: {len(message)}')
1913+
avg_latency_ms = struct.unpack(
1914+
'!Q',
1915+
message[type_size:type_size + latency_size])[0]
1916+
latency_seconds = avg_latency_ms / 1000
1917+
metrics_utils.SKY_APISERVER_WEBSOCKET_SSH_LATENCY_SECONDS.labels(pid=os.getpid()).observe(latency_seconds) # pylint: disable=line-too-long
1918+
continue
1919+
else:
1920+
# Unknown message type.
1921+
raise ValueError(
1922+
f'Unknown message type: {message_type}')
18731923
writer.write(message)
18741924
try:
18751925
await writer.drain()
@@ -1900,6 +1950,11 @@ async def ssh_to_websocket():
19001950
nonlocal ssh_failed
19011951
ssh_failed = True
19021952
break
1953+
if timestamps_supported:
1954+
# Prepend message type byte (0 = regular data)
1955+
message_type_bytes = struct.pack(
1956+
'!B', KubernetesSSHMessageType.REGULAR_DATA.value)
1957+
data = message_type_bytes + data
19031958
await websocket.send_bytes(data)
19041959
except Exception: # pylint: disable=broad-except
19051960
pass

sky/skylet/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,3 +548,6 @@
548548

549549
ARM64_ARCH = 'arm64'
550550
X86_64_ARCH = 'x86_64'
551+
552+
SSH_DISABLE_LATENCY_MEASUREMENT_ENV_VAR = (
553+
f'{SKYPILOT_ENV_VAR_PREFIX}SSH_DISABLE_LATENCY_MEASUREMENT')

sky/templates/websocket_proxy.py

Lines changed: 140 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,23 @@
1111
import asyncio
1212
from http.cookiejar import MozillaCookieJar
1313
import os
14+
import struct
1415
import sys
15-
from typing import Dict
16+
import time
17+
from typing import Dict, Optional
1618
from urllib.request import Request
1719

20+
import requests
1821
import websockets
1922
from websockets.asyncio.client import ClientConnection
2023
from websockets.asyncio.client import connect
2124

25+
from sky.server import constants
26+
from sky.server.server import KubernetesSSHMessageType
27+
from sky.skylet import constants as skylet_constants
28+
2229
BUFFER_SIZE = 2**16 # 64KB
30+
HEARTBEAT_INTERVAL_SECONDS = 10
2331

2432
# Environment variable for a file path to the API cookie file.
2533
# Keep in sync with server/constants.py
@@ -28,6 +36,8 @@
2836
# Keep in sync with server/constants.py
2937
API_COOKIE_FILE_DEFAULT_LOCATION = '~/.sky/cookies.txt'
3038

39+
MAX_UNANSWERED_PINGS = 100
40+
3141

3242
def _get_cookie_header(url: str) -> Dict[str, str]:
3343
"""Extract Cookie header value from a cookie jar for a specific URL"""
@@ -49,7 +59,7 @@ def _get_cookie_header(url: str) -> Dict[str, str]:
4959
return {'Cookie': cookie_header}
5060

5161

52-
async def main(url: str) -> None:
62+
async def main(url: str, timestamps_supported: bool) -> None:
5363
cookie_header = _get_cookie_header(url)
5464
async with connect(url,
5565
ping_interval=None,
@@ -75,45 +85,149 @@ async def main(url: str) -> None:
7585
asyncio.streams.FlowControlMixin, sys.stdout) # type: ignore
7686
stdout_writer = asyncio.StreamWriter(transport, protocol, None,
7787
loop)
88+
# Dictionary to store last ping time for latency measurement
89+
last_ping_time_dict: Optional[Dict[int, float]] = None
90+
if timestamps_supported:
91+
last_ping_time_dict = {}
92+
93+
# Use an Event to signal when websocket is closed
94+
websocket_closed_event = asyncio.Event()
95+
websocket_lock = asyncio.Lock()
7896

79-
await asyncio.gather(stdin_to_websocket(stdin_reader, websocket),
80-
websocket_to_stdout(websocket, stdout_writer))
97+
await asyncio.gather(
98+
stdin_to_websocket(stdin_reader, websocket,
99+
timestamps_supported, websocket_closed_event,
100+
websocket_lock),
101+
websocket_to_stdout(websocket, stdout_writer,
102+
timestamps_supported, last_ping_time_dict,
103+
websocket_closed_event, websocket_lock),
104+
latency_monitor(websocket, last_ping_time_dict,
105+
websocket_closed_event, websocket_lock),
106+
return_exceptions=True)
81107
finally:
82108
if old_settings:
83109
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN,
84110
old_settings)
85111

86112

113+
async def latency_monitor(websocket: ClientConnection,
114+
last_ping_time_dict: Optional[dict],
115+
websocket_closed_event: asyncio.Event,
116+
websocket_lock: asyncio.Lock):
117+
"""Periodically send PING messages (type 1) to measure latency."""
118+
if last_ping_time_dict is None:
119+
return
120+
next_id = 0
121+
while not websocket_closed_event.is_set():
122+
try:
123+
await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
124+
if len(last_ping_time_dict) >= MAX_UNANSWERED_PINGS:
125+
# We are not getting responses, clear the dictionary so
126+
# as not to grow unbounded.
127+
last_ping_time_dict.clear()
128+
ping_time = time.time()
129+
next_id += 1
130+
last_ping_time_dict[next_id] = ping_time
131+
message_header_bytes = struct.pack(
132+
'!BI', KubernetesSSHMessageType.PINGPONG.value, next_id)
133+
try:
134+
async with websocket_lock:
135+
await websocket.send(message_header_bytes)
136+
except websockets.exceptions.ConnectionClosed as e:
137+
# Websocket is already closed.
138+
print(f'Failed to send PING message: {e}', file=sys.stderr)
139+
break
140+
except Exception as e:
141+
print(f'Error in latency_monitor: {e}', file=sys.stderr)
142+
websocket_closed_event.set()
143+
raise e
144+
145+
87146
async def stdin_to_websocket(reader: asyncio.StreamReader,
88-
websocket: ClientConnection):
147+
websocket: ClientConnection,
148+
timestamps_supported: bool,
149+
websocket_closed_event: asyncio.Event,
150+
websocket_lock: asyncio.Lock):
89151
try:
90-
while True:
152+
while not websocket_closed_event.is_set():
91153
# Read at most BUFFER_SIZE bytes, this not affect
92154
# responsiveness since it will return as soon as
93155
# there is at least one byte.
94156
# The BUFFER_SIZE is chosen to be large enough to improve
95157
# throughput.
96158
data = await reader.read(BUFFER_SIZE)
159+
97160
if not data:
98161
break
99-
await websocket.send(data)
162+
if timestamps_supported:
163+
# Send message with type 0 to indicate data.
164+
message_type_bytes = struct.pack(
165+
'!B', KubernetesSSHMessageType.REGULAR_DATA.value)
166+
data = message_type_bytes + data
167+
async with websocket_lock:
168+
await websocket.send(data)
169+
100170
except Exception as e: # pylint: disable=broad-except
101171
print(f'Error in stdin_to_websocket: {e}', file=sys.stderr)
102172
finally:
103-
await websocket.close()
173+
async with websocket_lock:
174+
await websocket.close()
175+
websocket_closed_event.set()
104176

105177

106178
async def websocket_to_stdout(websocket: ClientConnection,
107-
writer: asyncio.StreamWriter):
179+
writer: asyncio.StreamWriter,
180+
timestamps_supported: bool,
181+
last_ping_time_dict: Optional[dict],
182+
websocket_closed_event: asyncio.Event,
183+
websocket_lock: asyncio.Lock):
108184
try:
109-
while True:
185+
while not websocket_closed_event.is_set():
110186
message = await websocket.recv()
187+
if (timestamps_supported and len(message) > 0 and
188+
last_ping_time_dict is not None):
189+
message_type = struct.unpack('!B', message[:1])[0]
190+
if message_type == KubernetesSSHMessageType.REGULAR_DATA.value:
191+
# Regular data - strip type byte and write to stdout
192+
message = message[1:]
193+
elif message_type == KubernetesSSHMessageType.PINGPONG.value:
194+
# PONG response - calculate latency and send measurement
195+
if not len(message) == struct.calcsize('!BI'):
196+
raise ValueError(
197+
f'Invalid PONG message length: {len(message)}')
198+
pong_id = struct.unpack('!I', message[1:5])[0]
199+
pong_time = time.time()
200+
201+
ping_time = last_ping_time_dict.pop(pong_id, None)
202+
203+
if ping_time is None:
204+
continue
205+
206+
latency_seconds = pong_time - ping_time
207+
latency_ms = int(latency_seconds * 1000)
208+
209+
# Send latency measurement (type 2)
210+
message_type_bytes = struct.pack(
211+
'!B',
212+
KubernetesSSHMessageType.LATENCY_MEASUREMENT.value)
213+
latency_bytes = struct.pack('!Q', latency_ms)
214+
message = message_type_bytes + latency_bytes
215+
# Send to server.
216+
async with websocket_lock:
217+
await websocket.send(message)
218+
continue
219+
# No timestamps support, write directly
111220
writer.write(message)
112221
await writer.drain()
113222
except websockets.exceptions.ConnectionClosed:
114223
print('WebSocket connection closed', file=sys.stderr)
115224
except Exception as e: # pylint: disable=broad-except
116225
print(f'Error in websocket_to_stdout: {e}', file=sys.stderr)
226+
raise e
227+
finally:
228+
async with websocket_lock:
229+
await websocket.close()
230+
websocket_closed_event.set()
117231

118232

119233
if __name__ == '__main__':
@@ -123,11 +237,25 @@ async def websocket_to_stdout(websocket: ClientConnection,
123237
# TODO(aylei): Remove this after 0.10.0
124238
server_url = f'http://{server_url}'
125239

240+
health_url = f'{server_url}/api/health'
241+
health_response = requests.get(health_url)
242+
health_data = health_response.json()
243+
timestamps_are_supported = int(health_data['api_version']) > 21
244+
disable_latency_measurement = os.environ.get(
245+
skylet_constants.SSH_DISABLE_LATENCY_MEASUREMENT_ENV_VAR, '0') == '1'
246+
timestamps_are_supported = (timestamps_are_supported and
247+
not disable_latency_measurement)
248+
126249
server_proto, server_fqdn = server_url.split('://')
127250
websocket_proto = 'ws'
128251
if server_proto == 'https':
129252
websocket_proto = 'wss'
130253
server_url = f'{websocket_proto}://{server_fqdn}'
254+
255+
client_version_str = (f'&client_version={constants.API_VERSION}'
256+
if timestamps_are_supported else '')
257+
131258
websocket_url = (f'{server_url}/kubernetes-pod-ssh-proxy'
132-
f'?cluster_name={sys.argv[2]}')
133-
asyncio.run(main(websocket_url))
259+
f'?cluster_name={sys.argv[2]}'
260+
f'{client_version_str}')
261+
asyncio.run(main(websocket_url, timestamps_are_supported))

0 commit comments

Comments
 (0)