Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ jobs:
--transport streamable-http
--proxy-port 8080
--env TOOLHIVE_HOST=172.17.0.1
--env WORKLOAD_HOST=172.17.0.1
--env TOOLHIVE_PORT=9090
--env WORKLOAD_POLLING_INTERVAL=2
--env ALLOWED_GROUPS=default
Expand Down
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ USER app

# Set default environment variables for container deployment
ENV TOOLHIVE_HOST=host.docker.internal
ENV WORKLOAD_HOST=host.docker.internal
ENV FASTEMBED_CACHE_PATH=/app/.cache/fastembed
ENV TIKTOKEN_CACHE_DIR=/app/.cache/tiktoken
ENV COLORED_LOGS=false
Expand Down
1 change: 1 addition & 0 deletions Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ tasks:
desc: Run the application locally
env:
TOOLHIVE_HOST: "localhost"
WORKLOAD_HOST: "localhost"
TOOLHIVE_PORT: "8080"
cmds:
- uv run mtm
Expand Down
1 change: 1 addition & 0 deletions src/mcp_optimizer/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def main(**kwargs: Any) -> None:

toolhive_client = ToolhiveClient(
host=config.toolhive_host,
workload_host=config.workload_host,
port=config.toolhive_port,
scan_port_start=config.toolhive_start_port_scan,
scan_port_end=config.toolhive_end_port_scan,
Expand Down
4 changes: 4 additions & 0 deletions src/mcp_optimizer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def normalize_runtime_mode(cls, v) -> str:
toolhive_host: str = Field(
default="localhost", min_length=1, description="Host for ToolHive API"
)
workload_host: str = Field(
default="localhost", min_length=1, description="Host for MCP workload connections"
)
toolhive_port: int | None = Field(
default=None, ge=1024, le=65535, description="Port for ToolHive API (1024-65535)"
)
Expand Down Expand Up @@ -475,6 +478,7 @@ def _populate_config_from_env() -> dict[str, Any]:
"MCP_PORT": "mcp_port",
"RELOAD_SERVER": "reload_server",
"TOOLHIVE_HOST": "toolhive_host",
"WORKLOAD_HOST": "workload_host",
"TOOLHIVE_PORT": "toolhive_port",
"TOOLHIVE_START_PORT_SCAN": "toolhive_start_port_scan",
"TOOLHIVE_END_PORT_SCAN": "toolhive_end_port_scan",
Expand Down
1 change: 1 addition & 0 deletions src/mcp_optimizer/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def initialize_server_components(config: MCPOptimizerConfig) -> None:
mcp.settings.port = config.mcp_port
toolhive_client = ToolhiveClient(
host=config.toolhive_host,
workload_host=config.workload_host,
port=config.toolhive_port,
scan_port_start=config.toolhive_start_port_scan,
scan_port_end=config.toolhive_end_port_scan,
Expand Down
50 changes: 37 additions & 13 deletions src/mcp_optimizer/toolhive/toolhive_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import asyncio
from functools import wraps
from typing import Any, Awaitable, Callable, Self, TypeVar
from urllib.parse import urlparse
from urllib.parse import urlparse, urlunparse

import httpx
import structlog
Expand Down Expand Up @@ -43,6 +43,7 @@ class ToolhiveClient:
def __init__(
self,
host: str,
workload_host: str,
port: int | None,
scan_port_start: int,
scan_port_end: int,
Expand All @@ -57,6 +58,7 @@ def __init__(

Args:
host: Toolhive server host
workload_host: Host for MCP workload connections
port: Toolhive server port
scan_port_start: Start of port range to scan for Toolhive
scan_port_end: End of port range to scan for Toolhive
Expand All @@ -68,6 +70,7 @@ def __init__(
(useful when ToolHive is not needed, e.g., K8s mode)
"""
self.thv_host = host
self.workload_host = workload_host
self.timeout = timeout
self.max_retries = max_retries
self.initial_backoff = initial_backoff
Expand Down Expand Up @@ -379,6 +382,35 @@ def client(self) -> httpx.AsyncClient:
self._client = httpx.AsyncClient(timeout=self.timeout)
return self._client

def replace_localhost_in_url(self, url: str | None) -> str | None:
"""
Replace localhost/127.0.0.1 with the workload host in a URL.

This is required since in docker/podman containers, we cannot use
localhost/127.0.0.1 to reach services on the host machine.

Args:
url: The URL to process (can be None)

Returns:
The URL with localhost/127.0.0.1 replaced by workload_host,
or None if input was None
"""
if not url:
return url

parsed_url = urlparse(url)
hostname = parsed_url.hostname

if hostname and hostname in ("localhost", "127.0.0.1"):
# Replace hostname only in netloc, not in path/query/fragment
new_netloc = parsed_url.netloc.replace(hostname, self.workload_host, 1)
parsed = parsed_url._replace(netloc=new_netloc)
# urlunparse returns str when input is from urlparse(str)
return str(urlunparse(parsed))

return url

async def list_workloads(self, all_workloads: bool = False) -> WorkloadListResponse:
"""
Get a list of workloads from Toolhive.
Expand Down Expand Up @@ -410,15 +442,11 @@ async def _list_workloads_impl() -> WorkloadListResponse:
logger.warning("No workloads found", all_workloads=all_workloads)
return WorkloadListResponse(workloads=[])

# Replace the localhost/127.0.0.1 host with the toolhive host
# Replace the localhost/127.0.0.1 host with the workload host
# This is required since in docker/podman container, we cannot use
# localhost/127.0.0.1
for workload in workload_list.workloads:
if workload.url:
parsed_url = urlparse(workload.url)
workload_host = parsed_url.hostname
if workload_host in ("localhost", "127.0.0.1"):
workload.url = workload.url.replace(workload_host, self.thv_host)
workload.url = self.replace_localhost_in_url(workload.url)

logger.info(
"Successfully fetched workloads",
Expand Down Expand Up @@ -458,12 +486,8 @@ async def _get_workload_details_impl() -> Workload:
data = response.json()
workload = Workload.model_validate(data)

# Replace localhost/127.0.0.1 with the toolhive host (same as list_workloads)
if workload.url:
parsed_url = urlparse(workload.url)
workload_host = parsed_url.hostname
if workload_host in ("localhost", "127.0.0.1"):
workload.url = workload.url.replace(workload_host, self.thv_host)
# Replace localhost/127.0.0.1 with the workload host (same as list_workloads)
workload.url = self.replace_localhost_in_url(workload.url)

logger.info(
"Successfully fetched workload details",
Expand Down
1 change: 1 addition & 0 deletions tests/test_polling_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def mock_is_toolhive_available(self, host, port):

return ToolhiveClient(
host="localhost",
workload_host="localhost",
port=8080,
scan_port_start=50000,
scan_port_end=50100,
Expand Down
Loading