Skip to content

Commit b4aaf81

Browse files
feat: allow setting a timeout when creating MCPAgentTool
1 parent 8cae18c commit b4aaf81

File tree

2 files changed

+39
-2
lines changed

2 files changed

+39
-2
lines changed

src/strands/tools/mcp/mcp_agent_tool.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
import logging
9+
from datetime import timedelta
910
from typing import TYPE_CHECKING, Any
1011

1112
from mcp.types import Tool as MCPTool
@@ -28,20 +29,28 @@ class MCPAgentTool(AgentTool):
2829
seamlessly within the agent framework.
2930
"""
3031

31-
def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None:
32+
def __init__(
33+
self,
34+
mcp_tool: MCPTool,
35+
mcp_client: "MCPClient",
36+
name_override: str | None = None,
37+
timeout: timedelta | None = None,
38+
) -> None:
3239
"""Initialize a new MCPAgentTool instance.
3340
3441
Args:
3542
mcp_tool: The MCP tool to adapt
3643
mcp_client: The MCP server connection to use for tool invocation
3744
name_override: Optional name to use for the agent tool (for disambiguation)
3845
If None, uses the original MCP tool name
46+
timeout: Optional timeout duration for tool execution
3947
"""
4048
super().__init__()
4149
logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name)
4250
self.mcp_tool = mcp_tool
4351
self.mcp_client = mcp_client
4452
self._agent_tool_name = name_override or mcp_tool.name
53+
self._timeout_seconds = timeout
4554

4655
@property
4756
def tool_name(self) -> str:
@@ -105,5 +114,6 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw
105114
tool_use_id=tool_use["toolUseId"],
106115
name=self.mcp_tool.name, # Use original MCP name for server communication
107116
arguments=tool_use["input"],
117+
read_timeout_seconds=self._timeout_seconds,
108118
)
109119
yield ToolResultEvent(result)

tests/strands/tools/mcp/test_mcp_agent_tool.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import timedelta
12
from unittest.mock import MagicMock
23

34
import pytest
@@ -88,5 +89,31 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist):
8889

8990
assert tru_events == exp_events
9091
mock_mcp_client.call_tool_async.assert_called_once_with(
91-
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}
92+
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None
93+
)
94+
95+
96+
def test_timeout_initialization(mock_mcp_tool, mock_mcp_client):
97+
timeout = timedelta(seconds=30)
98+
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout)
99+
assert agent_tool._timeout_seconds == timeout
100+
101+
102+
def test_timeout_default_none(mock_mcp_tool, mock_mcp_client):
103+
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client)
104+
assert agent_tool._timeout_seconds is None
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist):
109+
timeout = timedelta(seconds=45)
110+
agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout)
111+
tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}}
112+
113+
tru_events = await alist(agent_tool.stream(tool_use, {}))
114+
exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)]
115+
116+
assert tru_events == exp_events
117+
mock_mcp_client.call_tool_async.assert_called_once_with(
118+
tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout
92119
)

0 commit comments

Comments
 (0)