feat(langchain_v1): add end_tools exit behavior to ToolCallLimitMiddleware #33641
Open: liam-langchain wants to merge 4 commits into master from feat/tool-limit-end-tools
+602 −9
Commits (4):
- 55bc0eb feat(agents): add end_tools exit behavior to ToolCallLimitMiddleware
- 749fb4f refactor: clean up code style and add sequential execution test
- 9d046a1 fix: resolve linting and formatting issues
- 256d1e6 fix: resolve mypy type error for ToolMessage content
```diff
@@ -4,7 +4,7 @@
 from typing import TYPE_CHECKING, Annotated, Any, Literal

-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
+from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
 from langgraph.channels.untracked_value import UntrackedValue
 from typing_extensions import NotRequired

```
```diff
@@ -16,7 +16,12 @@
 )

 if TYPE_CHECKING:
+    from collections.abc import Callable
+
     from langgraph.runtime import Runtime
+    from langgraph.types import Command
+
+    from langchain.tools.tool_node import ToolCallRequest


 class ToolCallLimitState(AgentState):
```
```diff
@@ -163,12 +168,12 @@ class ToolCallLimitMiddleware(AgentMiddleware[ToolCallLimitState, Any]):
     from langchain.agents.middleware.tool_call_limit import ToolCallLimitMiddleware
     from langchain.agents import create_agent

-    # Limit all tool calls globally
+    # Limit all tool calls globally - stop entire agent when exceeded
     global_limiter = ToolCallLimitMiddleware(thread_limit=20, run_limit=10, exit_behavior="end")

-    # Limit a specific tool
+    # Limit a specific tool - block tool execution but let agent continue
     search_limiter = ToolCallLimitMiddleware(
-        tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end"
+        tool_name="search", thread_limit=5, run_limit=3, exit_behavior="end_tools"
     )

     # Use both in the same agent
```
```diff
@@ -186,7 +191,7 @@
         tool_name: str | None = None,
         thread_limit: int | None = None,
         run_limit: int | None = None,
-        exit_behavior: Literal["end", "error"] = "end",
+        exit_behavior: Literal["end", "end_tools", "error"] = "end",
     ) -> None:
         """Initialize the tool call limit middleware.

```
```diff
@@ -200,6 +205,9 @@
             exit_behavior: What to do when limits are exceeded.
                 - "end": Jump to the end of the agent execution and
                   inject an artificial AI message indicating that the limit was exceeded.
+                - "end_tools": Allow the model to request tools, but block tool execution
+                  when limits are exceeded. The agent receives warning messages and can
+                  continue with partial results.
                 - "error": Raise a ToolCallLimitExceededError
             Defaults to "end".

```
```diff
@@ -212,8 +220,8 @@
             msg = "At least one limit must be specified (thread_limit or run_limit)"
             raise ValueError(msg)

-        if exit_behavior not in ("end", "error"):
-            msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end' or 'error'"
+        if exit_behavior not in ("end", "end_tools", "error"):
+            msg = f"Invalid exit_behavior: {exit_behavior}. Must be 'end', 'end_tools', or 'error'"
             raise ValueError(msg)

         self.tool_name = tool_name
```
```diff
@@ -237,18 +245,80 @@ def name(self) -> str:
     def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
         """Check tool call limits before making a model call.

+        For `end` and `error` behaviors, this prevents the model from being
+        called if limits are already exceeded. For `end_tools` behavior, this first
+        counts successful tool executions from the previous iteration, then allows
+        the model to run (blocking happens during tool execution instead).
+
         Args:
             state: The current agent state containing tool call counts.
             runtime: The langgraph runtime.

         Returns:
             If limits are exceeded and exit_behavior is "end", returns
-            a Command to jump to the end with a limit exceeded message. Otherwise returns None.
+            a Command to jump to the end with a limit exceeded message.
+            For end_tools, returns state updates with updated counts.
+            Otherwise returns None.

         Raises:
             ToolCallLimitExceededError: If limits are exceeded and exit_behavior
                 is "error".
         """
+        # For end_tools behavior, count executions from current run
+        if self.exit_behavior == "end_tools":
+            messages = state.get("messages", [])
+            if not messages:
+                return None
+
+            # Only look at messages from the current run (after last HumanMessage)
+            run_messages = _get_run_messages(messages)
+            if not run_messages:
+                return None
+
+            # Count successful tool executions in the current run
+            count_key = self.tool_name if self.tool_name else "__all__"
+            successful_executions = 0
+
+            for msg in run_messages:
+                if not isinstance(msg, ToolMessage):
+                    continue
+
+                # Check if this is a limit warning (not a successful execution)
+                content = msg.content if isinstance(msg.content, str) else str(msg.content)
+                is_limit_warning = "tool call limits exceeded" in content.lower()
+
+                # Check if this tool matches our filter
+                if self.tool_name is not None and msg.name != self.tool_name:
+                    continue
+
+                if not is_limit_warning:
+                    successful_executions += 1
+
+            if successful_executions == 0:
+                return None
+
+            # Check if we've already updated to this count
+            current_run_count = state.get("run_tool_call_count", {}).get(count_key, 0)
+
+            # If we've already counted all executions, don't update again
+            if current_run_count >= successful_executions:
+                return None
+
+            # Update counts with the delta
+            thread_counts = state.get("thread_tool_call_count", {}).copy()
+            run_counts = state.get("run_tool_call_count", {}).copy()
+
+            # Calculate how many new executions we haven't counted yet
+            new_executions = successful_executions - current_run_count
+
+            thread_counts[count_key] = thread_counts.get(count_key, 0) + new_executions
+            run_counts[count_key] = successful_executions
+
+            return {
+                "thread_tool_call_count": thread_counts,
+                "run_tool_call_count": run_counts,
+            }
+
         # Get the count key for this middleware instance
         count_key = self.tool_name if self.tool_name else "__all__"
```
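The `end_tools` counting path above can be exercised in isolation. A minimal sketch, with plain dicts standing in for LangChain message objects and `count_successful_executions` as a hypothetical stand-in for the middleware's counting loop (including the skip of injected limit warnings):

```python
def get_run_messages(messages):
    """Return messages after the last human message (the current run)."""
    for i in range(len(messages) - 1, -1, -1):
        if messages[i]["type"] == "human":
            return messages[i + 1:]
    return messages

def count_successful_executions(messages, tool_name=None):
    """Count tool messages in the current run that are real results,
    skipping limit-warning messages and non-matching tools."""
    count = 0
    for msg in get_run_messages(messages):
        if msg["type"] != "tool":
            continue
        if tool_name is not None and msg.get("name") != tool_name:
            continue
        if "tool call limits exceeded" in msg["content"].lower():
            continue  # a warning injected by the middleware, not a real result
        count += 1
    return count

messages = [
    {"type": "human", "content": "find papers"},
    {"type": "ai", "content": ""},
    {"type": "tool", "name": "search", "content": "result 1"},
    {"type": "tool", "name": "search", "content": "Tool call limits exceeded."},
    {"type": "tool", "name": "fetch", "content": "result 2"},
]
print(count_successful_executions(messages))                      # 2 (warning skipped)
print(count_successful_executions(messages, tool_name="search"))  # 1
```

Comparing this count against `run_tool_call_count` in state gives the delta that `before_model` writes back, which is why re-entering the hook without new tool results is a no-op.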
```diff
@@ -285,13 +355,21 @@ def before_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str,
     def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
         """Increment tool call counts after a model call (when tool calls are made).

+        For `end_tools` behavior, counting happens in `before_model` on the next
+        iteration (after tools execute). For other behaviors, this increments the
+        count based on how many tool calls the model made.
+
         Args:
             state: The current agent state.
             runtime: The langgraph runtime.

         Returns:
             State updates with incremented tool call counts if tool calls were made.
         """
+        # For end_tools, counting happens in before_model (after tools finish)
+        if self.exit_behavior == "end_tools":
+            return None
+
         # Get the last AIMessage to check for tool calls
         messages = state.get("messages", [])
         if not messages:
```
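For the `end` and `error` behaviors, counting stays in `after_model` and is driven by the tool calls on the last AI message. A simplified sketch of that increment, with a hypothetical `increment_counts` helper and plain dicts in place of agent state:

```python
def increment_counts(thread_counts, run_counts, tool_calls, tool_name=None):
    """Increment per-thread and per-run counters for matching tool calls,
    as after_model does for the "end"/"error" behaviors."""
    key = tool_name if tool_name else "__all__"
    matching = [tc for tc in tool_calls if tool_name is None or tc["name"] == tool_name]
    if not matching:
        return None  # no state update needed
    thread_counts = {**thread_counts, key: thread_counts.get(key, 0) + len(matching)}
    run_counts = {**run_counts, key: run_counts.get(key, 0) + len(matching)}
    return {"thread_tool_call_count": thread_counts, "run_tool_call_count": run_counts}

update = increment_counts({"__all__": 3}, {"__all__": 1},
                          [{"name": "search"}, {"name": "fetch"}])
print(update)  # thread count 5, run count 3
```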
```diff
@@ -331,3 +409,106 @@ def after_model(self, state: ToolCallLimitState, runtime: Runtime) -> dict[str,
             "thread_tool_call_count": thread_counts,
             "run_tool_call_count": run_counts,
         }
+
+    def wrap_tool_call(
+        self,
+        request: ToolCallRequest,
+        execute: Callable[[ToolCallRequest], ToolMessage | Command],
+    ) -> ToolMessage | Command:
+        """Intercept tool execution to enforce limits for end_tools behavior.
+
+        For `end_tools` behavior, this method checks if executing this specific
+        tool would exceed the limits. If so, it returns a warning message instead
+        of executing the tool. This allows the agent to continue with partial results.
+
+        The position of the tool call in the model's response is used to determine
+        which tools should execute and which should be blocked, even with parallel
+        tool execution.
+
+        Args:
+            request: The tool call request containing the tool call and state.
+            execute: Function to execute the tool call.
+
+        Returns:
+            ToolMessage with the tool result, or a warning message if limit exceeded.
+        """
+        # Only intercept for end_tools behavior
+        if self.exit_behavior != "end_tools":
+            return execute(request)
+
+        # Check if this tool matches our filter
+        if self.tool_name is not None and request.tool_call["name"] != self.tool_name:
+            # This tool doesn't match our filter, execute it without counting
+            return execute(request)
+
+        # Get the count key for this middleware instance
+        count_key = self.tool_name if self.tool_name else "__all__"
+
+        # Find the last AI message to get the tool call position
+        messages = request.state.get("messages", [])
+        last_ai_message = None
+        for message in reversed(messages):
+            if isinstance(message, AIMessage):
+                last_ai_message = message
+                break
+
+        if not last_ai_message or not last_ai_message.tool_calls:
+            # No AI message with tool calls found, execute normally
+            return execute(request)
+
+        # Find the position of this tool call in the list
+        # Only count tool calls that match our filter
+        tool_call_position = None
+        for idx, tc in enumerate(last_ai_message.tool_calls):
+            # Match by tool_call_id
+            if tc["id"] == request.tool_call["id"]:
+                # Count how many matching tool calls come before this one
+                matching_before = sum(
+                    1
+                    for i in range(idx)
+                    if self.tool_name is None
+                    or last_ai_message.tool_calls[i]["name"] == self.tool_name
+                )
+                tool_call_position = matching_before
+                break
+
+        # Shouldn't happen, but safety check
+        if tool_call_position is None:
+            return execute(request)
+
+        # Get current counts from state
+        thread_counts = request.state.get("thread_tool_call_count", {})
+        run_counts = request.state.get("run_tool_call_count", {})
+
+        current_thread_count = thread_counts.get(count_key, 0)
+        current_run_count = run_counts.get(count_key, 0)
+
+        # Calculate count after this tool executes (based on position)
+        count_after_this_tool = current_thread_count + tool_call_position + 1
+        run_count_after_this_tool = current_run_count + tool_call_position + 1
+
+        # Check if this tool call would exceed limits
+        thread_limit_exceeded = (
+            self.thread_limit is not None and count_after_this_tool > self.thread_limit
+        )
+        run_limit_exceeded = (
+            self.run_limit is not None and run_count_after_this_tool > self.run_limit
+        )
+
+        if thread_limit_exceeded or run_limit_exceeded:
+            # This tool would exceed the limit - return warning message
+            limit_message = _build_tool_limit_exceeded_message(
+                thread_count=current_thread_count + tool_call_position,
+                run_count=current_run_count + tool_call_position,
+                thread_limit=self.thread_limit,
+                run_limit=self.run_limit,
+                tool_name=self.tool_name,
+            )
+            return ToolMessage(
+                content=f"{limit_message} Do not call any more tools.",
+                tool_call_id=request.tool_call["id"],
+                name=request.tool_call["name"],
+            )
+
+        # Within limit - execute the tool
+        return execute(request)
```

Reviewer comment on the `content=f"{limit_message} Do not call any more tools.",` line: "Should this be specific to this tool?"
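The position arithmetic in `wrap_tool_call` is what keeps parallel tool calls deterministic: each call's rank among the matching calls in the same AI message decides whether it would push the counter past the limit. A minimal sketch with a hypothetical `would_block` helper (run limit only, for brevity):

```python
def would_block(tool_calls, call_id, current_run_count, run_limit, tool_name=None):
    """Decide whether executing the call with call_id would exceed run_limit,
    using its position among matching calls (mirrors wrap_tool_call)."""
    position = None
    for idx, tc in enumerate(tool_calls):
        if tc["id"] == call_id:
            # How many matching calls precede this one in the same AI message
            position = sum(
                1 for i in range(idx)
                if tool_name is None or tool_calls[i]["name"] == tool_name
            )
            break
    if position is None:
        return False  # call not found: execute normally
    return current_run_count + position + 1 > run_limit

calls = [{"id": "a", "name": "search"}, {"id": "b", "name": "search"},
         {"id": "c", "name": "search"}]
# run_limit=3 with 2 calls already counted: "a" would be the 3rd call (allowed),
# "b" would be the 4th (blocked), regardless of which tool finishes first.
print(would_block(calls, "a", 2, 3))  # False
print(would_block(calls, "b", 2, 3))  # True
```

In the real middleware, a blocked call gets a warning `ToolMessage` instead of a result, so the agent sees partial results plus an instruction to stop calling tools.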
Reviewer comment: "why not call this handler?"