Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/deepset_mcp/mcp/tool_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,15 @@ async def client_wrapper_with_context(*args: Any, **kwargs: Any) -> Any:
ctx_param = inspect.Parameter(name="ctx", kind=inspect.Parameter.KEYWORD_ONLY, annotation=Context)
new_params.append(ctx_param)
client_wrapper_with_context.__signature__ = original_sig.replace(parameters=new_params) # type: ignore

# Remove client from docstring
client_wrapper_with_context.__doc__ = remove_params_from_docstring(base_func.__doc__, {"client"})

# Remove client from annotations and add ctx
new_annotations = {k: v for k, v in base_func.__annotations__.items() if k != "client"}
new_annotations["ctx"] = Context
client_wrapper_with_context.__annotations__ = new_annotations

return client_wrapper_with_context
else:

Expand All @@ -214,6 +221,10 @@ async def client_wrapper_without_context(*args: Any, **kwargs: Any) -> Any:
# Remove client from docstring
client_wrapper_without_context.__doc__ = remove_params_from_docstring(base_func.__doc__, {"client"})

# Remove client from annotations
new_annotations = {k: v for k, v in base_func.__annotations__.items() if k != "client"}
client_wrapper_without_context.__annotations__ = new_annotations

return client_wrapper_without_context


Expand Down
10 changes: 10 additions & 0 deletions test/unit/test_tool_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from unittest.mock import MagicMock, patch

import pytest
from mcp.server.fastmcp import Context

from deepset_mcp.api.protocols import AsyncClientProtocol
from deepset_mcp.mcp.tool_factory import (
Expand Down Expand Up @@ -265,6 +266,11 @@ async def sample_func(client: AsyncClientProtocol, a: int) -> str:
assert ":param client:" not in result.__doc__
assert ":param a:" in result.__doc__

# Check annotations were updated
assert "client" not in result.__annotations__
assert "ctx" in result.__annotations__
assert result.__annotations__["ctx"] == Context

def test_client_signature_updated_without_context(self) -> None:
"""Test that client parameter is removed without ctx."""

Expand All @@ -289,6 +295,10 @@ async def sample_func(client: AsyncClientProtocol, a: int) -> str:
assert result.__doc__ is not None
assert ":param client:" not in result.__doc__

# Check annotations were updated
assert "client" not in result.__annotations__
assert "ctx" not in result.__annotations__

@pytest.mark.asyncio
async def test_client_context_missing_raises_error(self) -> None:
"""Test that missing context raises ValueError."""
Expand Down