diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 8f1e791..c2379fa 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -59,6 +59,7 @@ jobs: --transport streamable-http --proxy-port 8080 --env TOOLHIVE_HOST=172.17.0.1 + --env WORKLOAD_HOST=172.17.0.1 --env TOOLHIVE_PORT=9090 --env WORKLOAD_POLLING_INTERVAL=2 --env ALLOWED_GROUPS=default diff --git a/Dockerfile b/Dockerfile index b984cc1..c6793b0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -106,6 +106,7 @@ USER app # Set default environment variables for container deployment ENV TOOLHIVE_HOST=host.docker.internal +ENV WORKLOAD_HOST=host.docker.internal ENV FASTEMBED_CACHE_PATH=/app/.cache/fastembed ENV TIKTOKEN_CACHE_DIR=/app/.cache/tiktoken ENV COLORED_LOGS=false diff --git a/Taskfile.yml b/Taskfile.yml index c6eec9c..e0507ca 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -102,6 +102,7 @@ tasks: desc: Run the application locally env: TOOLHIVE_HOST: "localhost" + WORKLOAD_HOST: "localhost" TOOLHIVE_PORT: "8080" cmds: - uv run mtm diff --git a/src/mcp_optimizer/cli.py b/src/mcp_optimizer/cli.py index ea76bce..e704d7e 100644 --- a/src/mcp_optimizer/cli.py +++ b/src/mcp_optimizer/cli.py @@ -198,6 +198,7 @@ def main(**kwargs: Any) -> None: toolhive_client = ToolhiveClient( host=config.toolhive_host, + workload_host=config.workload_host, port=config.toolhive_port, scan_port_start=config.toolhive_start_port_scan, scan_port_end=config.toolhive_end_port_scan, diff --git a/src/mcp_optimizer/config.py b/src/mcp_optimizer/config.py index 7a7575f..6d2fee0 100644 --- a/src/mcp_optimizer/config.py +++ b/src/mcp_optimizer/config.py @@ -59,6 +59,9 @@ def normalize_runtime_mode(cls, v) -> str: toolhive_host: str = Field( default="localhost", min_length=1, description="Host for ToolHive API" ) + workload_host: str = Field( + default="localhost", min_length=1, description="Host for MCP workload connections" + ) toolhive_port: int | 
None = Field( default=None, ge=1024, le=65535, description="Port for ToolHive API (1024-65535)" ) @@ -475,6 +478,7 @@ def _populate_config_from_env() -> dict[str, Any]: "MCP_PORT": "mcp_port", "RELOAD_SERVER": "reload_server", "TOOLHIVE_HOST": "toolhive_host", + "WORKLOAD_HOST": "workload_host", "TOOLHIVE_PORT": "toolhive_port", "TOOLHIVE_START_PORT_SCAN": "toolhive_start_port_scan", "TOOLHIVE_END_PORT_SCAN": "toolhive_end_port_scan", diff --git a/src/mcp_optimizer/server.py b/src/mcp_optimizer/server.py index deef72d..6d75107 100644 --- a/src/mcp_optimizer/server.py +++ b/src/mcp_optimizer/server.py @@ -154,6 +154,7 @@ def initialize_server_components(config: MCPOptimizerConfig) -> None: mcp.settings.port = config.mcp_port toolhive_client = ToolhiveClient( host=config.toolhive_host, + workload_host=config.workload_host, port=config.toolhive_port, scan_port_start=config.toolhive_start_port_scan, scan_port_end=config.toolhive_end_port_scan, diff --git a/src/mcp_optimizer/toolhive/toolhive_client.py b/src/mcp_optimizer/toolhive/toolhive_client.py index 7d603f3..d92197f 100644 --- a/src/mcp_optimizer/toolhive/toolhive_client.py +++ b/src/mcp_optimizer/toolhive/toolhive_client.py @@ -5,7 +5,7 @@ import asyncio from functools import wraps from typing import Any, Awaitable, Callable, Self, TypeVar -from urllib.parse import urlparse +from urllib.parse import urlparse, urlunparse import httpx import structlog @@ -43,6 +43,7 @@ class ToolhiveClient: def __init__( self, host: str, + workload_host: str, port: int | None, scan_port_start: int, scan_port_end: int, @@ -57,6 +58,7 @@ def __init__( Args: host: Toolhive server host + workload_host: Host for MCP workload connections port: Toolhive server port scan_port_start: Start of port range to scan for Toolhive scan_port_end: End of port range to scan for Toolhive @@ -68,6 +70,7 @@ def __init__( (useful when ToolHive is not needed, e.g., K8s mode) """ self.thv_host = host + self.workload_host = workload_host self.timeout = 
timeout self.max_retries = max_retries self.initial_backoff = initial_backoff @@ -379,6 +382,35 @@ def client(self) -> httpx.AsyncClient: self._client = httpx.AsyncClient(timeout=self.timeout) return self._client + def replace_localhost_in_url(self, url: str | None) -> str | None: + """ + Replace localhost/127.0.0.1 with the workload host in a URL. + + This is required since in docker/podman containers, we cannot use + localhost/127.0.0.1 to reach services on the host machine. + + Args: + url: The URL to process (can be None) + + Returns: + The URL with localhost/127.0.0.1 replaced by workload_host, + or None if input was None + """ + if not url: + return url + + parsed_url = urlparse(url) + hostname = parsed_url.hostname + + if hostname and hostname in ("localhost", "127.0.0.1"): + # Replace hostname only in netloc, not in path/query/fragment + new_netloc = parsed_url.netloc.replace(hostname, self.workload_host, 1) + parsed = parsed_url._replace(netloc=new_netloc) + # urlunparse returns str when input is from urlparse(str) + return str(urlunparse(parsed)) + + return url + async def list_workloads(self, all_workloads: bool = False) -> WorkloadListResponse: """ Get a list of workloads from Toolhive. 
@@ -410,15 +442,11 @@ async def _list_workloads_impl() -> WorkloadListResponse: logger.warning("No workloads found", all_workloads=all_workloads) return WorkloadListResponse(workloads=[]) - # Replace the localhost/127.0.0.1 host with the toolhive host + # Replace the localhost/127.0.0.1 host with the workload host # This is required since in docker/podman container, we cannot use # localhost/127.0.0.1 for workload in workload_list.workloads: - if workload.url: - parsed_url = urlparse(workload.url) - workload_host = parsed_url.hostname - if workload_host in ("localhost", "127.0.0.1"): - workload.url = workload.url.replace(workload_host, self.thv_host) + workload.url = self.replace_localhost_in_url(workload.url) logger.info( "Successfully fetched workloads", @@ -458,12 +486,8 @@ async def _get_workload_details_impl() -> Workload: data = response.json() workload = Workload.model_validate(data) - # Replace localhost/127.0.0.1 with the toolhive host (same as list_workloads) - if workload.url: - parsed_url = urlparse(workload.url) - workload_host = parsed_url.hostname - if workload_host in ("localhost", "127.0.0.1"): - workload.url = workload.url.replace(workload_host, self.thv_host) + # Replace localhost/127.0.0.1 with the workload host (same as list_workloads) + workload.url = self.replace_localhost_in_url(workload.url) logger.info( "Successfully fetched workload details", diff --git a/tests/test_polling_manager.py b/tests/test_polling_manager.py index b8363d3..5f70083 100644 --- a/tests/test_polling_manager.py +++ b/tests/test_polling_manager.py @@ -56,6 +56,7 @@ async def mock_is_toolhive_available(self, host, port): return ToolhiveClient( host="localhost", + workload_host="localhost", port=8080, scan_port_start=50000, scan_port_end=50100, diff --git a/tests/test_toolhive_client.py b/tests/test_toolhive_client.py index 43cf2c4..b49fd70 100644 --- a/tests/test_toolhive_client.py +++ b/tests/test_toolhive_client.py @@ -58,6 +58,7 @@ async def 
mock_is_toolhive_available(self, host, port): # Mock the port scanning to avoid network calls during testing client = ToolhiveClient( host="127.0.0.1", + workload_host="127.0.0.1", port=8080, scan_port_start=50000, scan_port_end=50100, @@ -173,6 +174,7 @@ def mock_discover_port(self, port): async with ToolhiveClient( host="127.0.0.1", + workload_host="127.0.0.1", port=8080, scan_port_start=50000, scan_port_end=50100, @@ -204,6 +206,7 @@ def mock_discover_port(self, port): client = ToolhiveClient( host="127.0.0.1", + workload_host="127.0.0.1", port=8080, scan_port_start=50000, scan_port_end=50100, @@ -241,6 +244,7 @@ def mock_discover_port(self, port): client = ToolhiveClient( host="127.0.0.1", + workload_host="127.0.0.1", port=8080, scan_port_start=50000, scan_port_end=50100, @@ -276,6 +280,7 @@ async def mock_is_toolhive_available(self, host, port): ): ToolhiveClient( host="127.0.0.1", + workload_host="127.0.0.1", port=None, scan_port_start=50000, scan_port_end=50100, @@ -303,6 +308,7 @@ async def mock_is_toolhive_available(self, host, port): client = ToolhiveClient( host="127.0.0.1", + workload_host="127.0.0.1", port=None, scan_port_start=50000, scan_port_end=50100, @@ -334,6 +340,7 @@ async def mock_is_toolhive_available(self, host, port): client = ToolhiveClient( host="127.0.0.1", + workload_host="127.0.0.1", port=8080, scan_port_start=50000, scan_port_end=50100, @@ -554,8 +561,8 @@ async def test_get_workload_details_replaces_localhost( # Call the method result = await client.get_workload_details("test-server") - # Verify that localhost was replaced with the toolhive host - assert result.url == f"http://{toolhive_client.thv_host}:8080/mcp" + # Verify that localhost was replaced with the workload host + assert result.url == f"http://{toolhive_client.workload_host}:8080/mcp" @pytest.mark.asyncio @@ -598,3 +605,298 @@ async def test_get_workload_details_timeout(toolhive_client): # Call the method and expect an exception with pytest.raises(httpx.TimeoutException): 
# Tests for replace_localhost_in_url method


def _make_workload_client(monkeypatch, workload_host="host.docker.internal"):
    """Build the ToolhiveClient shared by the replace_localhost_in_url tests.

    ``ToolhiveClient._discover_port`` is replaced with a stub that sets
    ``thv_port`` and ``base_url`` directly, mocking the port scanning so no
    network calls are made during testing.

    Args:
        monkeypatch: pytest's monkeypatch fixture, used to install the stub.
        workload_host: Host that localhost/127.0.0.1 should be rewritten to.

    Returns:
        A ToolhiveClient with host 127.0.0.1 and the given workload host.
    """

    def mock_discover_port(self, port):
        self.thv_port = 8080
        self.base_url = f"http://{self.thv_host}:{self.thv_port}"

    monkeypatch.setattr(
        "mcp_optimizer.toolhive.toolhive_client.ToolhiveClient._discover_port",
        mock_discover_port,
    )

    return ToolhiveClient(
        host="127.0.0.1",
        workload_host=workload_host,
        port=8080,
        scan_port_start=50000,
        scan_port_end=50100,
        timeout=5.0,
        max_retries=3,
        initial_backoff=1.0,
        max_backoff=60.0,
    )


def test_replace_localhost_in_url_with_none(monkeypatch):
    """Test replace_localhost_in_url with None URL."""
    client = _make_workload_client(monkeypatch)

    result = client.replace_localhost_in_url(None)
    assert result is None


def test_replace_localhost_in_url_with_empty_string(monkeypatch):
    """Test replace_localhost_in_url with empty string."""
    client = _make_workload_client(monkeypatch)

    result = client.replace_localhost_in_url("")
    assert result == ""


def test_replace_localhost_in_url_with_localhost(monkeypatch):
    """Test replace_localhost_in_url replaces localhost with workload_host."""
    client = _make_workload_client(monkeypatch)

    result = client.replace_localhost_in_url("http://localhost:8080/sse")
    assert result == "http://host.docker.internal:8080/sse"


def test_replace_localhost_in_url_with_127_0_0_1(monkeypatch):
    """Test replace_localhost_in_url replaces 127.0.0.1 with workload_host."""
    client = _make_workload_client(monkeypatch)

    result = client.replace_localhost_in_url("http://127.0.0.1:9000/mcp")
    assert result == "http://host.docker.internal:9000/mcp"


def test_replace_localhost_in_url_with_regular_hostname(monkeypatch):
    """Test replace_localhost_in_url does not replace regular hostnames."""
    client = _make_workload_client(monkeypatch)

    result = client.replace_localhost_in_url("https://api.github.com/mcp")
    assert result == "https://api.github.com/mcp"


def test_replace_localhost_in_url_localhost_in_path_only(monkeypatch):
    """Test replace_localhost_in_url only replaces hostname, not path."""
    client = _make_workload_client(monkeypatch)

    # Hostname is example.com, but path contains 'localhost'
    # Should NOT replace localhost in the path
    result = client.replace_localhost_in_url("http://example.com/api/localhost/data")
    assert result == "http://example.com/api/localhost/data"


def test_replace_localhost_in_url_localhost_in_hostname_and_path(monkeypatch):
    """Test that localhost in path is NOT replaced when localhost is also the hostname."""
    client = _make_workload_client(monkeypatch)

    # Hostname is localhost AND path contains 'localhost'
    # Should ONLY replace localhost in the hostname, NOT in the path
    result = client.replace_localhost_in_url("http://localhost:8080/api/localhost/data")
    assert result == "http://host.docker.internal:8080/api/localhost/data"

    # Test with 127.0.0.1 as well
    result2 = client.replace_localhost_in_url("http://127.0.0.1:8080/path/127.0.0.1/data")
    assert result2 == "http://host.docker.internal:8080/path/127.0.0.1/data"


def test_replace_localhost_in_url_with_port_number(monkeypatch):
    """Test replace_localhost_in_url works correctly with various port numbers."""
    client = _make_workload_client(monkeypatch)

    # Test with different port numbers
    result1 = client.replace_localhost_in_url("http://localhost:3000/api")
    assert result1 == "http://host.docker.internal:3000/api"

    result2 = client.replace_localhost_in_url("http://127.0.0.1:50001/sse#test")
    assert result2 == "http://host.docker.internal:50001/sse#test"


def test_replace_localhost_in_url_with_https(monkeypatch):
    """Test replace_localhost_in_url works with HTTPS URLs."""
    client = _make_workload_client(monkeypatch)

    result = client.replace_localhost_in_url("https://localhost:443/secure")
    assert result == "https://host.docker.internal:443/secure"


def test_replace_localhost_in_url_with_fragment(monkeypatch):
    """Test replace_localhost_in_url preserves URL fragments."""
    client = _make_workload_client(monkeypatch)

    result = client.replace_localhost_in_url("http://localhost:8080/sse#server-name")
    assert result == "http://host.docker.internal:8080/sse#server-name"
mock_is_toolhive_available(self, host, port): client = ToolhiveClient( host="localhost", + workload_host="localhost", port=50001, scan_port_start=50000, scan_port_end=50100,