Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 189 additions & 8 deletions libs/langchain_v1/langchain/agents/middleware/tool_call_limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import TYPE_CHECKING, Annotated, Any, Literal

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
from langgraph.channels.untracked_value import UntrackedValue
from typing_extensions import NotRequired

Expand All @@ -16,7 +16,12 @@
)

if TYPE_CHECKING:
from collections.abc import Callable

from langgraph.runtime import Runtime
from langgraph.types import Command

from langchain.tools.tool_node import ToolCallRequest


class ToolCallLimitState(AgentState):
Expand Down Expand Up @@ -163,12 +168,12 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
from langchain.agents import create_agent

# Limit all tool calls globally
# Limit all tool calls globally - stop entire agent when exceeded
global_limiter = ToolCallLimitMiddleware(thread_limit=20, run_limit=10, exit_behavior="end")

# Limit a specific tool
# Limit a specific tool - block tool execution but let agent continue
search_limiter = ToolCallLimitMiddleware(
tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end_tools"
)

# Use both in the same agent
Expand All @@ -186,7 +191,7 @@ def __init__(
tool_name: str | None = None,
thread_limit: int | None = None,
run_limit: int | None = None,
exit_behavior: Literal["end", "error"] = "end",
exit_behavior: Literal["end", "end_tools", "error"] = "end",
) -> None:
"""Initialize the tool call limit middleware.

Expand All @@ -200,6 +205,9 @@ def __init__(
exit_behavior: What to do when limits are exceeded.
- "end": Jump to the end of the agent execution and
inject an artificial AI message indicating that the limit was exceeded.
- "end_tools": Allow the model to request tools, but block tool execution
when limits are exceeded. The agent receives warning messages and can
continue with partial results.
- "error": Raise a ToolCallLimitExceededError
Defaults to "end".

Expand All @@ -212,8 +220,8 @@ def __init__(
msg = "At least one limit must be specified (thread_limit or run_limit)"
raise ValueError(msg)

if exit_behavior not in ("end", "error"):
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
if exit_behavior not in ("end", "end_tools", "error"):
msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end', 'end_tools', or 'error'"
raise ValueError(msg)

self.tool_name = tool_name
Expand All @@ -237,18 +245,80 @@ def name(self) -> str:
def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Check tool call limits before making a model call.

For `end` and `error` behaviors, this prevents the model from being
called if limits are already exceeded. For `end_tools` behavior, this first
counts successful tool executions from the previous iteration, then allows
the model to run (blocking happens during tool execution instead).

Args:
state: The current agent state containing tool call counts.
runtime: The langgraph runtime.

Returns:
If limits are exceeded and exit_behavior is "end", returns
a Command to jump to the end with a limit exceeded message. Otherwise returns None.
a Command to jump to the end with a limit exceeded message.
For end_tools, returns state updates with updated counts.
Otherwise returns None.

Raises:
ToolCallLimitExceededError: If limits are exceeded and exit_behavior
is "error".
"""
# For end_tools behavior, count executions from current run
if self.exit_behavior == "end_tools":
messages = state.get("messages", [])
if not messages:
return None

# Only look at messages from the current run (after last HumanMessage)
run_messages = _get_run_messages(messages)
if not run_messages:
return None

# Count successful tool executions in the current run
count_key = self.tool_name if self.tool_name else "__all__"
successful_executions = 0

for msg in run_messages:
if not isinstance(msg, ToolMessage):
continue

# Check if this is a limit warning (not a successful execution)
content = msg.content if isinstance(msg.content, str) else str(msg.content)
is_limit_warning = "tool call limits exceeded" in content.lower()

# Check if this tool matches our filter
if self.tool_name is not None and msg.name != self.tool_name:
continue

if not is_limit_warning:
successful_executions += 1

if successful_executions == 0:
return None

# Check if we've already updated to this count
current_run_count = state.get("run_tool_call_count", {}).get(count_key, 0)

# If we've already counted all executions, don't update again
if current_run_count >= successful_executions:
return None

# Update counts with the delta
thread_counts = state.get("thread_tool_call_count", {}).copy()
run_counts = state.get("run_tool_call_count", {}).copy()

# Calculate how many new executions we haven't counted yet
new_executions = successful_executions - current_run_count

thread_counts[count_key] = thread_counts.get(count_key, 0) + new_executions
run_counts[count_key] = successful_executions

return {
"thread_tool_call_count": thread_counts,
"run_tool_call_count": run_counts,
}

# Get the count key for this middleware instance
count_key = self.tool_name if self.tool_name else "__all__"

Expand Down Expand Up @@ -285,13 +355,21 @@ def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str,
def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002
"""Increment tool call counts after a model call (when tool calls are made).

For `end_tools` behavior, counting happens in `before_model` on the next
iteration (after tools execute). For other behaviors, this increments the
count based on how many tool calls the model made.

Args:
state: The current agent state.
runtime: The langgraph runtime.

Returns:
State updates with incremented tool call counts if tool calls were made.
"""
# For end_tools, counting happens in before_model (after tools finish)
if self.exit_behavior == "end_tools":
return None

# Get the last AIMessage to check for tool calls
messages = state.get("messages", [])
if not messages:
Expand Down Expand Up @@ -331,3 +409,106 @@ def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str,
"thread_tool_call_count": thread_counts,
"run_tool_call_count": run_counts,
}

def wrap_tool_call(
    self,
    request: ToolCallRequest,
    execute: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
    """Intercept tool execution to enforce limits for end_tools behavior.

    For `end_tools` behavior, this method checks if executing this specific
    tool would exceed the limits. If so, it returns a warning message instead
    of executing the tool. This allows the agent to continue with partial results.

    The position of the tool call in the model's response is used to determine
    which tools should execute and which should be blocked, even with parallel
    tool execution.

    Args:
        request: The tool call request containing the tool call and state.
        execute: Function to execute the tool call.

    Returns:
        ToolMessage with the tool result, or a warning message if limit exceeded.
    """
    # Only intercept for end_tools behavior; "end" and "error" are
    # enforced in before_model instead.
    if self.exit_behavior != "end_tools":
        return execute(request)

    # Check if this tool matches our filter; non-matching tools execute
    # without counting against this middleware instance's limits.
    if self.tool_name is not None and request.tool_call["name"] != self.tool_name:
        return execute(request)

    # Get the count key for this middleware instance ("__all__" when no
    # tool_name filter is set).
    count_key = self.tool_name if self.tool_name else "__all__"

    # Find the last AI message to get the tool call position.
    messages = request.state.get("messages", [])
    last_ai_message = None
    for message in reversed(messages):
        if isinstance(message, AIMessage):
            last_ai_message = message
            break

    if not last_ai_message or not last_ai_message.tool_calls:
        # No AI message with tool calls found, execute normally.
        return execute(request)

    # Find the position of this tool call among the (filter-matching) tool
    # calls of the last AI message. With parallel tool calls, the position
    # decides how many sibling calls are assumed to execute before this one.
    tool_call_position = None
    for idx, tc in enumerate(last_ai_message.tool_calls):
        # Match by tool_call_id.
        if tc["id"] == request.tool_call["id"]:
            # Count how many matching tool calls come before this one.
            matching_before = sum(
                1
                for i in range(idx)
                if self.tool_name is None
                or last_ai_message.tool_calls[i]["name"] == self.tool_name
            )
            tool_call_position = matching_before
            break

    # Shouldn't happen (the request should always come from the last AI
    # message), but execute normally as a safety fallback.
    if tool_call_position is None:
        return execute(request)

    # Get current counts from state.
    thread_counts = request.state.get("thread_tool_call_count", {})
    run_counts = request.state.get("run_tool_call_count", {})

    current_thread_count = thread_counts.get(count_key, 0)
    current_run_count = run_counts.get(count_key, 0)

    # Calculate count after this tool executes (based on position).
    count_after_this_tool = current_thread_count + tool_call_position + 1
    run_count_after_this_tool = current_run_count + tool_call_position + 1

    # Check if this tool call would exceed limits.
    thread_limit_exceeded = (
        self.thread_limit is not None and count_after_this_tool > self.thread_limit
    )
    run_limit_exceeded = (
        self.run_limit is not None and run_count_after_this_tool > self.run_limit
    )

    if thread_limit_exceeded or run_limit_exceeded:
        # This tool would exceed the limit - return warning message instead
        # of executing, so the agent can continue with partial results.
        limit_message = _build_tool_limit_exceeded_message(
            thread_count=current_thread_count + tool_call_position,
            run_count=current_run_count + tool_call_position,
            thread_limit=self.thread_limit,
            run_limit=self.run_limit,
            tool_name=self.tool_name,
        )
        return ToolMessage(
            content=f"{limit_message} Do not call any more tools.",
            tool_call_id=request.tool_call["id"],
            name=request.tool_call["name"],
        )

    # Within limit - execute the tool.
    return execute(request)
Loading