Skip to content

Commit 9497125

Browse files
committed
feat(adk): Implement tool wrapper
1 parent 9f43f3a commit 9497125

File tree

4 files changed

+599
-0
lines changed

4 files changed

+599
-0
lines changed
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Awaitable, Callable, Dict, Optional, cast
16+
17+
import toolbox_core
18+
from fastapi.openapi.models import (
19+
OAuth2,
20+
OAuthFlowAuthorizationCode,
21+
OAuthFlows,
22+
SecurityScheme,
23+
)
24+
from google.adk.agents.readonly_context import ReadonlyContext
25+
from google.adk.auth.auth_credential import (
26+
AuthCredential,
27+
AuthCredentialTypes,
28+
OAuth2Auth,
29+
)
30+
from google.adk.auth.auth_tool import AuthConfig
31+
from google.adk.tools.base_tool import BaseTool
32+
from toolbox_core.tool import ToolboxTool as CoreToolboxTool
33+
from typing_extensions import override
34+
35+
from .client import USER_TOKEN_CONTEXT_VAR
36+
from .credentials import CredentialConfig, CredentialType
37+
38+
39+
class ToolboxContext:
40+
"""Context object passed to pre/post hooks."""
41+
42+
def __init__(self, arguments: Dict[str, Any], tool_context: ReadonlyContext):
43+
self.arguments = arguments
44+
self.tool_context = tool_context
45+
self.result: Optional[Any] = None
46+
self.error: Optional[Exception] = None
47+
48+
49+
class ToolboxTool(BaseTool):
50+
"""
51+
A tool that delegates to a remote Toolbox tool, integrated with ADK.
52+
"""
53+
54+
def __init__(
55+
self,
56+
core_tool: CoreToolboxTool,
57+
pre_hook: Optional[Callable[[ToolboxContext], Awaitable[None]]] = None,
58+
post_hook: Optional[Callable[[ToolboxContext], Awaitable[None]]] = None,
59+
auth_config: Optional[CredentialConfig] = None,
60+
):
61+
"""
62+
Args:
63+
core_tool: The underlying toolbox_core.py tool instance.
64+
pre_hook: Async function called before execution. Can modify ctx.arguments.
65+
post_hook: Async function called after execution (finally block). Can inspect ctx.result/error.
66+
auth_config: Credential configuration to handle interactive flows.
67+
"""
68+
# We act as a proxy.
69+
# We need to extract metadata from the core tool to satisfy BaseTool's contract.
70+
71+
name = getattr(core_tool, "__name__", "unknown_tool")
72+
description = (
73+
getattr(core_tool, "__doc__", "No description provided.")
74+
or "No description provided."
75+
)
76+
77+
super().__init__(
78+
name=name,
79+
description=description,
80+
# We pass empty custom_metadata or whatever is needed
81+
custom_metadata={},
82+
)
83+
self._core_tool = core_tool
84+
self._pre_hook = pre_hook
85+
self._post_hook = post_hook
86+
self._auth_config = auth_config
87+
88+
@override
89+
async def run_async(
90+
self,
91+
args: Dict[str, Any],
92+
tool_context: ReadonlyContext,
93+
) -> Any:
94+
# Create context
95+
ctx = ToolboxContext(arguments=args, tool_context=tool_context)
96+
97+
# 1. Pre-hook
98+
if self._pre_hook:
99+
await self._pre_hook(ctx)
100+
101+
# 2. ADK Auth Integration (3LO)
102+
# Check if USER_IDENTITY is configured
103+
reset_token = None
104+
105+
if self._auth_config and self._auth_config.type == CredentialType.USER_IDENTITY:
106+
if not self._auth_config.client_id or not self._auth_config.client_secret:
107+
raise ValueError("USER_IDENTITY requires client_id and client_secret")
108+
109+
# Construct ADK AuthConfig
110+
scopes = self._auth_config.scopes or [
111+
"https://www.googleapis.com/auth/cloud-platform"
112+
]
113+
scope_dict = {s: "" for s in scopes}
114+
115+
auth_config_adk = AuthConfig(
116+
auth_scheme=OAuth2(
117+
flows=OAuthFlows(
118+
authorizationCode=OAuthFlowAuthorizationCode(
119+
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
120+
tokenUrl="https://oauth2.googleapis.com/token",
121+
scopes=scope_dict,
122+
)
123+
)
124+
),
125+
raw_auth_credential=AuthCredential(
126+
auth_type=AuthCredentialTypes.OAUTH2,
127+
oauth2=OAuth2Auth(
128+
client_id=self._auth_config.client_id,
129+
client_secret=self._auth_config.client_secret,
130+
),
131+
),
132+
)
133+
134+
# Check if we already have credentials from a previous exchange
135+
try:
136+
# get_auth_response returns AuthCredential if found
137+
ctx_any = cast(Any, tool_context)
138+
creds = ctx_any.get_auth_response(auth_config_adk)
139+
if creds and creds.oauth2 and creds.oauth2.access_token:
140+
reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token)
141+
else:
142+
# Request credentials and pause execution
143+
ctx_any.request_credential(auth_config_adk)
144+
return None
145+
except Exception as e:
146+
ctx.error = e
147+
if "credential" in str(e).lower() or isinstance(e, ValueError):
148+
raise e
149+
# Fallback to request logic
150+
ctx_any = cast(Any, tool_context)
151+
ctx_any.request_credential(auth_config_adk)
152+
return None
153+
154+
try:
155+
# Execute the core tool
156+
result = await self._core_tool(**ctx.arguments)
157+
158+
ctx.result = result
159+
return result
160+
161+
except Exception as e:
162+
ctx.error = e
163+
raise e
164+
finally:
165+
if reset_token:
166+
USER_TOKEN_CONTEXT_VAR.reset(reset_token)
167+
if self._post_hook:
168+
await self._post_hook(ctx)
169+
170+
def bind_params(self, bounded_params: Dict[str, Any]) -> "ToolboxTool":
171+
"""Allows runtime binding of parameters, delegating to core tool."""
172+
new_core_tool = self._core_tool.bind_params(bounded_params)
173+
# Return a new wrapper
174+
return ToolboxTool(
175+
core_tool=new_core_tool,
176+
pre_hook=self._pre_hook,
177+
post_hook=self._post_hook,
178+
auth_config=self._auth_config,
179+
)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union
16+
17+
from google.adk.agents.readonly_context import ReadonlyContext
18+
from google.adk.tools.base_tool import BaseTool
19+
from google.adk.tools.base_toolset import BaseToolset
20+
from typing_extensions import override
21+
22+
from .client import ToolboxClient
23+
from .credentials import CredentialConfig
24+
from .tool import ToolboxContext, ToolboxTool
25+
26+
27+
class ToolboxToolset(BaseToolset):
28+
"""
29+
A Toolset that provides tools from a remote Toolbox server.
30+
"""
31+
32+
def __init__(
33+
self,
34+
server_url: str,
35+
toolset_name: Optional[str] = None,
36+
tool_names: Optional[List[str]] = None,
37+
credentials: Optional[CredentialConfig] = None,
38+
additional_headers: Optional[
39+
Dict[str, Union[str, Callable[[], str], Callable[[], Awaitable[str]]]]
40+
] = None,
41+
bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None,
42+
pre_hook: Optional[Callable[[ToolboxContext], Awaitable[None]]] = None,
43+
post_hook: Optional[Callable[[ToolboxContext], Awaitable[None]]] = None,
44+
**kwargs: Any,
45+
):
46+
"""
47+
Args:
48+
server_url: The URL of the Toolbox server.
49+
toolset_name: The name of the remote toolset to load.
50+
tool_names: Specific tool names to load (alternative to toolset_name).
51+
credentials: Authentication configuration.
52+
additional_headers: Extra headers (static or dynamic).
53+
bound_params: Parameters to bind globally to loaded tools.
54+
pre_hook: Hook to run before every tool execution.
55+
post_hook: Hook to run after every tool execution.
56+
"""
57+
super().__init__()
58+
self._client = ToolboxClient(
59+
server_url=server_url,
60+
credentials=credentials,
61+
additional_headers=additional_headers,
62+
**kwargs,
63+
)
64+
self._toolset_name = toolset_name
65+
self._tool_names = tool_names
66+
self._bound_params = bound_params
67+
self._pre_hook = pre_hook
68+
self._post_hook = post_hook
69+
70+
@override
71+
async def get_tools(
72+
self, readonly_context: Optional[ReadonlyContext] = None
73+
) -> List[BaseTool]:
74+
"""Loads tools from the toolbox server and wraps them."""
75+
# Note: We don't close the client after get_tools because tools might need it.
76+
77+
tools = []
78+
if self._toolset_name:
79+
core_tools = await self._client.load_toolset(
80+
self._toolset_name, bound_params=self._bound_params
81+
)
82+
tools.extend(core_tools)
83+
84+
if self._tool_names:
85+
for name in self._tool_names:
86+
core_tool = await self._client.load_tool(
87+
name, bound_params=self._bound_params
88+
)
89+
tools.append(core_tool)
90+
91+
# Wrap all core tools in ToolboxTool
92+
return [
93+
ToolboxTool(
94+
core_tool=t,
95+
pre_hook=self._pre_hook,
96+
post_hook=self._post_hook,
97+
auth_config=self._client.credential_config,
98+
)
99+
for t in tools
100+
]
101+
102+
async def close(self):
103+
await self._client.close()

0 commit comments

Comments
 (0)