Skip to content

Commit 269dc5c

Browse files
committed
feat: Develop the internal ToolboxClient wrapper
1 parent 1fe2393 commit 269dc5c

File tree

2 files changed

+315
-0
lines changed

2 files changed

+315
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional, Dict, Callable, Any, Union
16+
import toolbox_core
17+
from google.auth import transport
18+
from google.auth import compute_engine
19+
from google.oauth2 import id_token
20+
import google.auth
21+
22+
from .credentials import CredentialConfig, CredentialType
23+
24+
25+
class ToolboxClient:
26+
"""
27+
Wraps toolbox_core.ToolboxClient to provide ADK-native authentication strategy support.
28+
"""
29+
30+
def __init__(
31+
self,
32+
server_url: str,
33+
credentials: Optional[CredentialConfig] = None,
34+
additional_headers: Optional[Dict[str, str]] = None,
35+
**kwargs
36+
):
37+
"""
38+
Args:
39+
server_url: The URL of the Toolbox server.
40+
credentials: The CredentialConfig object (from CredentialStrategy).
41+
additional_headers: A dictionary of static headers to include in every request.
42+
**kwargs: Additional arguments passed to toolbox_core.ToolboxClient.
43+
"""
44+
self._server_url = server_url
45+
self._credentials = credentials
46+
self._additional_headers = additional_headers or {}
47+
48+
# Prepare auth_token_getters for toolbox-core
49+
# toolbox_core expects: dict[str, Callable[[], str | Awaitable[str]]]
50+
# However, for general headers (like Authorization), we can pass them in client_headers
51+
# if they are static or simpler. Toolbox-core supports `client_headers` which can be dynamic.
52+
53+
self._core_client_headers: Dict[str, Union[str, Callable[[], str]]] = {}
54+
55+
# Add static additional headers
56+
for k, v in self._additional_headers.items():
57+
self._core_client_headers[k] = v
58+
59+
if credentials:
60+
self._configure_auth(credentials)
61+
62+
self._client = toolbox_core.ToolboxClient(
63+
server_url=server_url,
64+
client_headers=self._core_client_headers,
65+
**kwargs
66+
)
67+
68+
def _configure_auth(self, creds: CredentialConfig):
69+
if creds.type == CredentialType.TOOLBOX_IDENTITY:
70+
# No auth headers needed
71+
pass
72+
73+
elif creds.type == CredentialType.APPLICATION_DEFAULT_CREDENTIALS:
74+
if not creds.target_audience:
75+
raise ValueError("target_audience is required for APPLICATION_DEFAULT_CREDENTIALS")
76+
77+
# Create an async capable token getter
78+
# We wrap it to match the signature expected by toolbox-core headers
79+
# (which accepts callables)
80+
self._core_client_headers["Authorization"] = self._create_adc_token_getter(creds.target_audience)
81+
82+
elif creds.type == CredentialType.MANUAL_TOKEN:
83+
if not creds.token:
84+
raise ValueError("token is required for MANUAL_TOKEN")
85+
scheme = creds.scheme or "Bearer"
86+
self._core_client_headers["Authorization"] = f"{scheme} {creds.token}"
87+
88+
elif creds.type == CredentialType.MANUAL_CREDS:
89+
if not creds.credentials:
90+
raise ValueError("credentials object is required for MANUAL_CREDS")
91+
92+
# Adapter for google-auth credentials object to callable
93+
self._core_client_headers["Authorization"] = self._create_creds_token_getter(creds.credentials)
94+
95+
elif creds.type == CredentialType.USER_IDENTITY:
96+
# For USER_IDENTITY (3LO), the *Tool* handles the interactive flow at runtime.
97+
# The client itself doesn't set a global header because the token is per-user
98+
# and passed via the tool's execution context (in the future) or handled by the tool wrapper.
99+
# The ToolboxTool wrapper will need to inject the token per-request or we rely on
100+
# toolbox-core's per-request auth support if widely available, but ADK flow involves
101+
# getting the token in `run_async`.
102+
# For now, we leave client-level headers empty for this strategy.
103+
pass
104+
105+
def _create_adc_token_getter(self, audience: str) -> Callable[[], str]:
106+
"""Returns a callable that fetches a fresh ID token using ADC."""
107+
def get_token() -> str:
108+
# Note: This is a synchronous call. Toolbox-core supports sync callables in headers.
109+
# Ideally we would use async but google-auth is primarily sync for these helpers.
110+
request = transport.requests.Request()
111+
# Try to get ID token directly (e.g. on Cloud Run)
112+
try:
113+
token = id_token.fetch_id_token(request, audience)
114+
return f"Bearer {token}"
115+
except Exception:
116+
# Fallback to default credentials (e.g. local gcloud)
117+
creds, _ = google.auth.default()
118+
if not creds.valid:
119+
creds.refresh(request)
120+
# If specific ID token credentials, use them, otherwise this might be Access Token (scoped)
121+
# For Toolbox we usually need ID Tokens.
122+
# If the user is locally auth'd via `gcloud auth login`, fetch_id_token is preferred.
123+
# If falling back to service account file:
124+
if hasattr(creds, 'id_token') and creds.id_token:
125+
return f"Bearer {creds.id_token}"
126+
127+
# If we are here, we might need to manually sign via IAM or similar if it's a generic SA.
128+
# For simplicity in this v1, we assume fetch_id_token works or standard creds work.
129+
# Re-attempt fetch_id_token on the credentials object if possible
130+
curr_token = creds.token
131+
return f"Bearer {curr_token}"
132+
133+
return get_token
134+
135+
def _create_creds_token_getter(self, credentials: Any) -> Callable[[], str]:
136+
def get_token() -> str:
137+
request = transport.requests.Request()
138+
if not credentials.valid:
139+
credentials.refresh(request)
140+
return f"Bearer {credentials.token}"
141+
return get_token
142+
143+
async def load_toolset(self, toolset_name: str, **kwargs):
144+
return await self._client.load_toolset(toolset_name, **kwargs)
145+
146+
async def load_tool(self, tool_name: str, **kwargs):
147+
return await self._client.load_tool(tool_name, **kwargs)
148+
149+
async def close(self):
150+
await self._client.close()
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from unittest.mock import MagicMock, patch, ANY, AsyncMock
17+
from toolbox_adk.client import ToolboxClient
18+
from toolbox_adk.credentials import CredentialStrategy
19+
20+
class TestToolboxClient:
21+
22+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
23+
def test_init_no_auth(self, mock_core_client):
24+
creds = CredentialStrategy.TOOLBOX_IDENTITY()
25+
client = ToolboxClient("http://server", credentials=creds)
26+
27+
mock_core_client.assert_called_with(
28+
server_url="http://server",
29+
client_headers={}
30+
)
31+
32+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
33+
def test_init_manual_token(self, mock_core_client):
34+
creds = CredentialStrategy.MANUAL_TOKEN("abc")
35+
client = ToolboxClient("http://server", credentials=creds)
36+
37+
mock_core_client.assert_called_with(
38+
server_url="http://server",
39+
client_headers={"Authorization": "Bearer abc"}
40+
)
41+
42+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
43+
def test_init_additional_headers(self, mock_core_client):
44+
creds = CredentialStrategy.TOOLBOX_IDENTITY()
45+
headers = {"X-Custom": "Val"}
46+
client = ToolboxClient(
47+
"http://server",
48+
credentials=creds,
49+
additional_headers=headers
50+
)
51+
52+
mock_core_client.assert_called_with(
53+
server_url="http://server",
54+
client_headers={"X-Custom": "Val"}
55+
)
56+
57+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
58+
@patch("toolbox_adk.client.id_token.fetch_id_token")
59+
def test_adc_auth_flow_success(self, mock_fetch_token, mock_core_client):
60+
mock_fetch_token.return_value = "id_token_123"
61+
62+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
63+
client = ToolboxClient("http://server", credentials=creds)
64+
65+
# Verify a callable was passed
66+
args, kwargs = mock_core_client.call_args
67+
assert "Authorization" in kwargs["client_headers"]
68+
token_getter = kwargs["client_headers"]["Authorization"]
69+
assert callable(token_getter)
70+
71+
# Verify callable behavior
72+
token = token_getter()
73+
assert token == "Bearer id_token_123"
74+
mock_fetch_token.assert_called()
75+
76+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
77+
@patch("toolbox_adk.client.id_token.fetch_id_token")
78+
@patch("toolbox_adk.client.google.auth.default")
79+
def test_adc_auth_flow_fallback(self, mock_default, mock_fetch_token, mock_core_client):
80+
# unexpected error on fetch_id_token
81+
mock_fetch_token.side_effect = Exception("No metadata")
82+
83+
mock_creds = MagicMock()
84+
mock_creds.valid = False
85+
mock_creds.id_token = "fallback_id_token"
86+
mock_default.return_value = (mock_creds, "proj")
87+
88+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
89+
client = ToolboxClient("http://server", credentials=creds)
90+
91+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
92+
token = token_getter()
93+
94+
assert token == "Bearer fallback_id_token"
95+
mock_creds.refresh.assert_called()
96+
97+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
98+
def test_manual_creds(self, mock_core_client):
99+
mock_g_creds = MagicMock()
100+
mock_g_creds.valid = False
101+
mock_g_creds.token = "oauth_token"
102+
103+
creds = CredentialStrategy.MANUAL_CREDS(mock_g_creds)
104+
client = ToolboxClient("http://server", credentials=creds)
105+
106+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
107+
token = token_getter()
108+
109+
assert token == "Bearer oauth_token"
110+
assert token == "Bearer oauth_token"
111+
mock_g_creds.refresh.assert_called()
112+
113+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
114+
def test_init_validation_errors(self, mock_core_client):
115+
# ADC missing audience
116+
with pytest.raises(ValueError, match="target_audience is required"):
117+
# Fix: only pass target_audience as keyword arg OR positional, not both mixed in a way that causes overlap if defined so
118+
# Actually simpler: just pass raw None
119+
ToolboxClient("url", credentials=CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(None))
120+
121+
# Manual token missing token
122+
with pytest.raises(ValueError, match="token is required"):
123+
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_TOKEN(None))
124+
125+
# Manual creds missing credentials
126+
with pytest.raises(ValueError, match="credentials object is required"):
127+
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_CREDS(None))
128+
129+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
130+
@patch("toolbox_adk.client.id_token.fetch_id_token")
131+
@patch("toolbox_adk.client.google.auth.default")
132+
def test_adc_auth_flow_fallback_access_token(self, mock_default, mock_fetch_token, mock_core_client):
133+
# fetch_id_token fails
134+
mock_fetch_token.side_effect = Exception("No metadata")
135+
136+
mock_creds = MagicMock()
137+
mock_creds.valid = False
138+
mock_creds.id_token = None # No ID token
139+
mock_creds.token = "access_token_123"
140+
mock_default.return_value = (mock_creds, "proj")
141+
142+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
143+
client = ToolboxClient("http://server", credentials=creds)
144+
145+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
146+
token = token_getter()
147+
148+
assert token == "Bearer access_token_123"
149+
mock_creds.refresh.assert_called()
150+
151+
@pytest.mark.asyncio
152+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
153+
async def test_delegation(self, mock_core_client):
154+
mock_instance = mock_core_client.return_value
155+
mock_instance.load_toolset = AsyncMock(return_value=["t1"])
156+
mock_instance.close = AsyncMock()
157+
158+
client = ToolboxClient("http://server")
159+
tools = await client.load_toolset("my-set", extra="arg")
160+
161+
mock_instance.load_toolset.assert_awaited_with("my-set", extra="arg")
162+
assert tools == ["t1"]
163+
164+
await client.close()
165+
mock_instance.close.assert_awaited()

0 commit comments

Comments
 (0)