
Commit b210f22

[https://nvbugs/5703953][fix] Preserving ip:port for trtllm-serve before initializing llm (#9646)
Signed-off-by: Junyi Xu <[email protected]>
1 parent 6dc8877 commit b210f22

6 files changed: 108 additions, 67 deletions

tensorrt_llm/commands/serve.py

Lines changed: 38 additions & 31 deletions

```diff
@@ -3,6 +3,7 @@
 import json
 import os
 import signal  # Added import
+import socket
 import subprocess  # nosec B404
 import sys
 from pathlib import Path
@@ -176,37 +177,43 @@ def launch_server(
 
     backend = llm_args["backend"]
     model = llm_args["model"]
-    if backend == 'pytorch':
-        llm_args.pop("build_config", None)
-        llm = PyTorchLLM(**llm_args)
-    elif backend == '_autodeploy':
-        from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
-
-        # AutoDeploy does not support build_config
-        llm_args.pop("build_config", None)
-        llm = AutoDeployLLM(**llm_args)
-    elif backend == 'tensorrt' or backend == 'trt':
-        llm_args.pop("backend")
-        llm = LLM(**llm_args)
-    else:
-        raise click.BadParameter(
-            f"{backend} is not a known backend, check help for available options.",
-            param_hint="backend")
-
-    server = OpenAIServer(llm=llm,
-                          model=model,
-                          tool_parser=tool_parser,
-                          server_role=server_role,
-                          metadata_server_cfg=metadata_server_cfg,
-                          disagg_cluster_config=disagg_cluster_config,
-                          multimodal_server_config=multimodal_server_config,
-                          chat_template=chat_template)
-
-    # Optionally disable GC (default: not disabled)
-    if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
-        gc.disable()
-
-    asyncio.run(server(host, port))
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        try:
+            s.bind((host, port))
+        except OSError as e:
+            raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")
+
+        if backend == 'pytorch':
+            llm_args.pop("build_config", None)
+            llm = PyTorchLLM(**llm_args)
+        elif backend == '_autodeploy':
+            from tensorrt_llm._torch.auto_deploy import LLM as AutoDeployLLM
+
+            # AutoDeploy does not support build_config
+            llm_args.pop("build_config", None)
+            llm = AutoDeployLLM(**llm_args)
+        elif backend == 'tensorrt' or backend == 'trt':
+            llm_args.pop("backend")
+            llm = LLM(**llm_args)
+        else:
+            raise click.BadParameter(
+                f"{backend} is not a known backend, check help for available options.",
+                param_hint="backend")
+
+        server = OpenAIServer(llm=llm,
+                              model=model,
+                              tool_parser=tool_parser,
+                              server_role=server_role,
+                              metadata_server_cfg=metadata_server_cfg,
+                              disagg_cluster_config=disagg_cluster_config,
+                              multimodal_server_config=multimodal_server_config,
+                              chat_template=chat_template)
+
+        # Optionally disable GC (default: not disabled)
+        if os.getenv("TRTLLM_SERVER_DISABLE_GC", "0") == "1":
+            gc.disable()
+
+        asyncio.run(server(host, port, sockets=[s]))
 
 
 def launch_mm_encoder_server(
```
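The effect of this change is that `host:port` is reserved with a plain `bind()` before the potentially slow engine construction, and the already-bound socket is then handed to uvicorn rather than letting it bind the address a second time. Below is a minimal, self-contained sketch of the same pattern, assuming a dummy ASGI `app` and a `time.sleep` as a stand-in for LLM initialization (both are illustrative, not TRT-LLM code); the socket handling and `uvicorn.Server(...).serve(sockets=[...])` call mirror the diff.

```python
import asyncio
import socket
import time

import uvicorn


async def app(scope, receive, send):
    # Trivial ASGI app standing in for the FastAPI app served by OpenAIServer.
    if scope["type"] != "http":
        return
    await send({"type": "http.response.start", "status": 200,
                "headers": [(b"content-type", b"text/plain")]})
    await send({"type": "http.response.body", "body": b"ok"})


def main(host: str = "127.0.0.1", port: int = 8000) -> None:
    # Bind first so the address is held for the whole startup window.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind((host, port))
        except OSError as e:
            raise RuntimeError(f"Failed to bind socket to {host}:{port}: {e}")

        time.sleep(5)  # stand-in for expensive LLM initialization

        config = uvicorn.Config(app, host=host, port=port, log_level="info")
        # uvicorn serves on the pre-bound socket instead of binding host/port itself.
        asyncio.run(uvicorn.Server(config).serve(sockets=[s]))


if __name__ == "__main__":
    main()
```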

tensorrt_llm/serve/openai_server.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -3,6 +3,7 @@
 import os
 import re
 import signal
+import socket
 import traceback
 from collections import deque
 from contextlib import asynccontextmanager
@@ -990,7 +991,7 @@ async def create_stream_response(generator, request: ResponsesRequest, sampling_
         return JSONResponse(content={"detail": "None"})
 
 
-    async def __call__(self, host, port):
+    async def __call__(self, host, port, sockets: list[socket.socket] | None = None):
         # Store the binding address for server registration
         self.binding_addr = f"http://{host}:{port}"
         self.host = host
@@ -1000,4 +1001,4 @@ async def __call__(self, host, port):
                                 port=port,
                                 log_level="info",
                                 timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
-        await uvicorn.Server(config).serve()
+        await uvicorn.Server(config).serve(sockets=sockets)
```
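Callers that omit the new argument keep the old behavior, since `sockets` defaults to `None` and uvicorn then binds `host:port` itself. A bound-but-not-yet-listening socket is sufficient here because uvicorn's asyncio backend hands provided sockets to `loop.create_server`, which puts them into listening mode on its own; the short snippet below demonstrates that asyncio behavior in isolation (plain asyncio, no uvicorn or TRT-LLM code), which is why `launch_server` only needs `s.bind(...)`.

```python
import asyncio
import socket


async def main() -> None:
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("127.0.0.1", 0))  # bound, not listening

    async def echo(reader, writer):
        # Echo back whatever the client sends, then close.
        writer.write(await reader.read(100))
        await writer.drain()
        writer.close()

    # asyncio calls listen() on the provided socket when it starts serving.
    server = await asyncio.start_server(echo, sock=s)
    print("serving on", s.getsockname())
    server.close()
    await server.wait_closed()


asyncio.run(main())
```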

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 1 addition & 18 deletions

```diff
@@ -13,8 +13,8 @@
 import pytest
 import requests
 import yaml
+from defs.common import revise_disaggregated_server_config_urls_with_free_ports
 
-from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.executor.result import GenerationResultBase
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs
@@ -68,23 +68,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         return False
 
 
-def revise_disaggregated_server_config_urls_with_free_ports(
-        disaggregated_server_config: Dict[str, Any]) -> Dict[str, Any]:
-    num_ctx_ports = len(disaggregated_server_config["context_servers"]["urls"])
-    num_gen_ports = len(
-        disaggregated_server_config["generation_servers"]["urls"])
-
-    disaggregated_server_config['port'] = get_free_port()
-    disaggregated_server_config["context_servers"]["urls"] = [
-        f"localhost:{get_free_port()}" for _ in range(num_ctx_ports)
-    ]
-    disaggregated_server_config["generation_servers"]["urls"] = [
-        f"localhost:{get_free_port()}" for _ in range(num_gen_ports)
-    ]
-
-    return disaggregated_server_config
-
-
 @contextlib.contextmanager
 def launch_disaggregated_llm(
         disaggregated_server_config: Dict[str, Any],
```

tests/integration/defs/common.py

Lines changed: 42 additions & 0 deletions

```diff
@@ -17,13 +17,17 @@
 import platform
 import re
 import socket
+import tempfile
 import time
 from difflib import SequenceMatcher
 from pathlib import Path
+from typing import Any
 
+import yaml
 from packaging import version
 
 from tensorrt_llm import LLM as LLM_torch
+from tensorrt_llm._utils import get_free_port
 from tensorrt_llm.executor.request import LoRARequest
 from tensorrt_llm.lora_manager import LoraConfig
 from tensorrt_llm.sampling_params import SamplingParams
@@ -1147,3 +1151,41 @@ def wait_for_server(host, port, timeout_seconds=180):
         except (socket.error, ConnectionRefusedError, OSError):
             time.sleep(2)
     return False
+
+
+def revise_disaggregated_server_config_urls_with_free_ports(
+        disaggregated_server_config: dict[str, Any]) -> dict[str, Any]:
+    # Revise serve port
+    disaggregated_server_config['port'] = get_free_port()
+
+    # Revise context and generation server urls
+    ctx_urls = disaggregated_server_config["context_servers"]["urls"]
+    gen_urls = disaggregated_server_config["generation_servers"]["urls"]
+    url_map = dict()
+    for url in set(ctx_urls + gen_urls):
+        url_map[url] = (url.split(':')[0], get_free_port())
+
+    for i, url in enumerate(ctx_urls):
+        disaggregated_server_config["context_servers"]["urls"][
+            i] = f"{url_map[url][0]}:{url_map[url][1]}"
+
+    for i, url in enumerate(gen_urls):
+        disaggregated_server_config["generation_servers"]["urls"][
+            i] = f"{url_map[url][0]}:{url_map[url][1]}"
+
+    return disaggregated_server_config
+
+
+def revise_disagg_config_file_with_free_ports(disagg_config_file: str) -> str:
+    # Revise the config file to use free ports
+    new_config = None
+    with open(disagg_config_file, 'r') as f:
+        config = yaml.safe_load(f)
+        new_config = revise_disaggregated_server_config_urls_with_free_ports(
+            config)
+
+    temp_fd, new_config_file = tempfile.mkstemp(suffix='.yaml')
+    with os.fdopen(temp_fd, 'w') as f:
+        yaml.dump(new_config, f)
+
+    return new_config_file
```
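These two helpers are what the updated tests import. The sketch below shows, on a toy config, what the URL rewrite does; the inline `get_free_port` is a local stand-in for `tensorrt_llm._utils.get_free_port` (whose actual implementation is not shown in this commit), and the sample YAML is illustrative.

```python
import socket

import yaml


def get_free_port() -> int:
    # Stand-in for tensorrt_llm._utils.get_free_port: ask the OS for an
    # ephemeral port and release it. Note the inherent race: another process
    # can claim the port before it is re-bound, which is the window the
    # serve.py change above closes by binding up front and keeping the socket.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


config = yaml.safe_load("""
hostname: localhost
port: 8000
context_servers:
  urls: ["localhost:8001", "localhost:8002"]
generation_servers:
  urls: ["localhost:8002"]
""")

# Mirror of revise_disaggregated_server_config_urls_with_free_ports: the serve
# port and every worker URL get a fresh port, and a URL shared between the
# context and generation lists maps to the same new port on both sides.
config["port"] = get_free_port()
ctx_urls = config["context_servers"]["urls"]
gen_urls = config["generation_servers"]["urls"]
url_map = {u: f"{u.split(':')[0]}:{get_free_port()}" for u in set(ctx_urls + gen_urls)}
config["context_servers"]["urls"] = [url_map[u] for u in ctx_urls]
config["generation_servers"]["urls"] = [url_map[u] for u in gen_urls]

print(yaml.dump(config, sort_keys=False))
```

Unlike the deleted test-local helper above, which assigned ports positionally, a URL that appears in both the context and generation lists is remapped to the same fresh port on both sides.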

tests/integration/defs/disaggregated/test_disaggregated.py

Lines changed: 22 additions & 16 deletions

```diff
@@ -22,12 +22,13 @@
 
 import pytest
 import yaml
-from defs.common import wait_for_server
+from defs.common import (revise_disagg_config_file_with_free_ports,
+                         wait_for_server)
 from defs.conftest import (get_sm_version, llm_models_root, skip_arm,
                            skip_no_hopper)
 from defs.trt_test_alternative import check_call, check_output, popen
 
-from tensorrt_llm._utils import mpi_disabled
+from tensorrt_llm._utils import get_free_port, mpi_disabled
 from tensorrt_llm.logger import logger
 
 
@@ -143,12 +144,12 @@ def validate_timing_metrics(perf_metrics_item, request_context=""):
     return True
 
 
-def get_disagg_server_url_from_cfg(config_file: str) -> str:
+def get_disagg_server_url_from_cfg(config_file: str) -> tuple[str, int]:
     with open(config_file, 'r') as file:
         config = yaml.safe_load(file)
     server_host = config.get('hostname', 'localhost')
     server_port = config.get('port', 8000)
-    return f"http://{server_host}:{server_port}"
+    return server_host, server_port
 
 
 def get_test_config(test_desc, example_dir, test_root):
@@ -277,7 +278,8 @@ def get_test_config(test_desc, example_dir, test_root):
         raise ValueError(f"Invalid test description: {test_desc}, "
                          f"valid descriptions are: {config_map.keys()}")
 
-    return config_map[test_desc]
+    return (config_map[test_desc][0],
+            revise_disagg_config_file_with_free_ports(config_map[test_desc][1]))
 
 
 def get_extra_llm_config(config, suffix, cwd):
@@ -481,7 +483,8 @@ def run_disaggregated_test(example_dir,
         'trtllm-serve', 'disaggregated', '--server_start_timeout',
         str(server_start_timeout), '-c', config_file
     ]
-    server_url = get_disagg_server_url_from_cfg(config_file)
+    server_host, server_port = get_disagg_server_url_from_cfg(config_file)
+    server_url = f"http://{server_host}:{server_port}"
 
     try:
         if not use_ray:
@@ -538,8 +541,8 @@
                       env=run_env,
                       cwd=cwd))
 
-            if not wait_for_server("localhost",
-                                   8000,
+            if not wait_for_server(server_host,
+                                   server_port,
                                    timeout_seconds=server_start_timeout):
                 raise RuntimeError(
                     f"Disaggregated server failed to start within {server_start_timeout} seconds"
@@ -1569,6 +1572,7 @@ def run_disaggregated_benchmark(example_dir,
        'trtllm-serve', 'disaggregated', '--server_start_timeout',
        str(server_start_timeout), '-c', config_file
    ]
+    server_host, server_port = get_disagg_server_url_from_cfg(config_file)
    try:
        with ( # Start workers
                open('output_workers.log', 'w') as output_workers,
@@ -1622,9 +1626,9 @@
        '--max-concurrency',
        str(max_concurrency),
        '--host',
-        'localhost',
+        server_host,
        '--port',
-        '8000',
+        str(server_port),
        '--ignore-eos',
        '--no-test-input',
        '--percentile-metrics',
@@ -1666,7 +1670,7 @@ def get_config_for_benchmark(model_root, backend):
    serve_config = {
        "model": model_root,
        "hostname": "localhost",
-        "port": 8000,
+        "port": get_free_port(),
        "backend": "pytorch",
        "context_servers": {
            "num_instances": 1,
@@ -1680,7 +1684,7 @@
                "backend": backend,
                "max_tokens_in_buffer": 512,
            },
-            "urls": ["localhost:8001"]
+            "urls": [f"localhost:{get_free_port()}"]
        },
        "generation_servers": {
            "num_instances": 1,
@@ -1693,7 +1697,7 @@
                "backend": backend,
                "max_tokens_in_buffer": 512,
            },
-            "urls": ["localhost:8002"]
+            "urls": [f"localhost:{get_free_port()}"]
        }
    }
    return serve_config
@@ -1724,6 +1728,7 @@ def run_disaggregated_genai_perf(config_file,
    ]
 
    artifact_dir = os.path.join(cwd or ".", "benchmark-results")
+    server_host, server_port = get_disagg_server_url_from_cfg(config_file)
 
    try:
        with (open('output_workers.log', 'w') as output_workers,
@@ -1740,8 +1745,9 @@
                    cwd=cwd) as server_proc):
 
            # Wait for server to be ready
-            if not wait_for_server(
-                    "localhost", 8000, timeout_seconds=server_start_timeout):
+            if not wait_for_server(server_host,
+                                   server_port,
+                                   timeout_seconds=server_start_timeout):
                raise RuntimeError(
                    f"Disaggregated server did not become ready within {server_start_timeout} seconds"
                )
@@ -1751,7 +1757,7 @@
                'genai-perf', 'profile', '--model', model_path, '--tokenizer',
                model_path, '--endpoint-type', 'chat', '--endpoint',
                '/v1/chat/completions', '--streaming', '--url',
-                'localhost:8000', '--synthetic-input-tokens-mean',
+                f'{server_host}:{server_port}', '--synthetic-input-tokens-mean',
                str(input_tokens), '--synthetic-input-tokens-stddev', '0',
                '--output-tokens-mean',
                str(output_tokens), '--output-tokens-stddev', '0',
```
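With the server address now resolved from the config via `get_disagg_server_url_from_cfg`, readiness is checked against that host and port rather than a hard-coded `localhost:8000`. For reference, a hedged sketch of a `wait_for_server`-style poll, a simplification rather than the exact `defs.common` implementation:

```python
import socket
import time


def wait_for_server(host: str, port: int, timeout_seconds: int = 180) -> bool:
    """Poll host:port until a TCP connection succeeds or the timeout expires."""
    deadline = time.monotonic() + timeout_seconds
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=2):
                return True
        except OSError:
            time.sleep(2)
    return False


if __name__ == "__main__":
    # Example: wait briefly for a locally launched disaggregated server.
    print(wait_for_server("localhost", 8000, timeout_seconds=10))
```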

tests/integration/defs/disaggregated/test_workers.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -9,6 +9,7 @@
 import aiohttp
 import pytest
 import yaml
+from defs.common import revise_disagg_config_file_with_free_ports
 from defs.conftest import skip_no_hopper
 from defs.trt_test_alternative import popen
 from transformers import AutoTokenizer
@@ -42,6 +43,7 @@ def run_disaggregated_workers(
     num_ranks: Optional[int] = None
 ) -> Tuple[Generator[subprocess.Popen, None, None], List[str], List[str]]:
 
+    config_file = revise_disagg_config_file_with_free_ports(config_file)
     ctx_servers, gen_servers = get_ctx_gen_server_urls_from_cfg(config_file)
 
     # TODO: auto detect num_ranks
```
