From 2fa5c1fd3747172cb9d133c19be8cd6809052a34 Mon Sep 17 00:00:00 2001 From: tstadel <60758086+tstadel@users.noreply.github.com> Date: Tue, 7 Oct 2025 15:40:23 +0200 Subject: [PATCH] fix: add ctx to annotations --- src/deepset_mcp/mcp/tool_factory.py | 11 +++++++++++ test/unit/test_tool_factory.py | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/src/deepset_mcp/mcp/tool_factory.py b/src/deepset_mcp/mcp/tool_factory.py index 00635f5..c58d170 100644 --- a/src/deepset_mcp/mcp/tool_factory.py +++ b/src/deepset_mcp/mcp/tool_factory.py @@ -193,8 +193,15 @@ async def client_wrapper_with_context(*args: Any, **kwargs: Any) -> Any: ctx_param = inspect.Parameter(name="ctx", kind=inspect.Parameter.KEYWORD_ONLY, annotation=Context) new_params.append(ctx_param) client_wrapper_with_context.__signature__ = original_sig.replace(parameters=new_params) # type: ignore + + # Remove client from docstring client_wrapper_with_context.__doc__ = remove_params_from_docstring(base_func.__doc__, {"client"}) + # Remove client from annotations and add ctx + new_annotations = {k: v for k, v in base_func.__annotations__.items() if k != "client"} + new_annotations["ctx"] = Context + client_wrapper_with_context.__annotations__ = new_annotations + return client_wrapper_with_context else: @@ -214,6 +221,10 @@ async def client_wrapper_without_context(*args: Any, **kwargs: Any) -> Any: # Remove client from docstring client_wrapper_without_context.__doc__ = remove_params_from_docstring(base_func.__doc__, {"client"}) + # Remove client from annotations + new_annotations = {k: v for k, v in base_func.__annotations__.items() if k != "client"} + client_wrapper_without_context.__annotations__ = new_annotations + return client_wrapper_without_context diff --git a/test/unit/test_tool_factory.py b/test/unit/test_tool_factory.py index 39be272..a108baf 100644 --- a/test/unit/test_tool_factory.py +++ b/test/unit/test_tool_factory.py @@ -9,6 +9,7 @@ from unittest.mock import MagicMock, patch import pytest +from mcp.server.fastmcp import Context from deepset_mcp.api.protocols import AsyncClientProtocol from deepset_mcp.mcp.tool_factory import ( @@ -265,6 +266,11 @@ async def sample_func(client: AsyncClientProtocol, a: int) -> str: assert ":param client:" not in result.__doc__ assert ":param a:" in result.__doc__ + # Check annotations were updated + assert "client" not in result.__annotations__ + assert "ctx" in result.__annotations__ + assert result.__annotations__["ctx"] == Context + def test_client_signature_updated_without_context(self) -> None: """Test that client parameter is removed without ctx.""" @@ -289,6 +295,10 @@ async def sample_func(client: AsyncClientProtocol, a: int) -> str: assert result.__doc__ is not None assert ":param client:" not in result.__doc__ + # Check annotations were updated + assert "client" not in result.__annotations__ + assert "ctx" not in result.__annotations__ + @pytest.mark.asyncio async def test_client_context_missing_raises_error(self) -> None: """Test that missing context raises ValueError."""