Skip to content

Commit 9f43f3a

Browse files
committed
feat(adk): Implement client wrapper
1 parent b206d66 commit 9f43f3a

File tree

2 files changed

+345
-0
lines changed

2 files changed

+345
-0
lines changed
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from contextvars import ContextVar
16+
from typing import Any, Awaitable, Callable, Dict, Optional, Union
17+
18+
import google.auth
19+
import toolbox_core
20+
from google.auth import compute_engine, transport
21+
from google.oauth2 import id_token
22+
23+
from .credentials import CredentialConfig, CredentialType
24+
25+
USER_TOKEN_CONTEXT_VAR: ContextVar[Optional[str]] = ContextVar(
26+
"toolbox_user_token", default=None
27+
)
28+
29+
30+
class ToolboxClient:
31+
"""
32+
Wraps toolbox_core.ToolboxClient to provide ADK-native authentication strategy support.
33+
"""
34+
35+
def __init__(
36+
self,
37+
server_url: str,
38+
credentials: Optional[CredentialConfig] = None,
39+
additional_headers: Optional[
40+
Dict[str, Union[str, Callable[[], str], Callable[[], Awaitable[str]]]]
41+
] = None,
42+
**kwargs: Any,
43+
):
44+
"""
45+
Args:
46+
server_url: The URL of the Toolbox server.
47+
credentials: The CredentialConfig object (from CredentialStrategy).
48+
additional_headers: Dictionary of headers (static or dynamic callables).
49+
**kwargs: Additional arguments passed to toolbox_core.ToolboxClient.
50+
"""
51+
self._server_url = server_url
52+
self._credentials = credentials
53+
self._additional_headers = additional_headers or {}
54+
55+
self._core_client_headers: Dict[
56+
str, Union[str, Callable[[], str], Callable[[], Awaitable[str]]]
57+
] = {}
58+
59+
# Add static additional headers
60+
for k, v in self._additional_headers.items():
61+
self._core_client_headers[k] = v
62+
63+
if credentials:
64+
self._configure_auth(credentials)
65+
66+
self._client = toolbox_core.ToolboxClient(
67+
url=server_url, client_headers=self._core_client_headers, **kwargs
68+
)
69+
70+
def _configure_auth(self, creds: CredentialConfig) -> None:
71+
if creds.type == CredentialType.TOOLBOX_IDENTITY:
72+
# No auth headers needed
73+
pass
74+
75+
elif creds.type == CredentialType.APPLICATION_DEFAULT_CREDENTIALS:
76+
if not creds.target_audience:
77+
raise ValueError(
78+
"target_audience is required for APPLICATION_DEFAULT_CREDENTIALS"
79+
)
80+
81+
# Create an async capable token getter
82+
self._core_client_headers["Authorization"] = self._create_adc_token_getter(
83+
creds.target_audience
84+
)
85+
86+
elif creds.type == CredentialType.MANUAL_TOKEN:
87+
if not creds.token:
88+
raise ValueError("token is required for MANUAL_TOKEN")
89+
scheme = creds.scheme or "Bearer"
90+
self._core_client_headers["Authorization"] = f"{scheme} {creds.token}"
91+
92+
elif creds.type == CredentialType.MANUAL_CREDS:
93+
if not creds.credentials:
94+
raise ValueError("credentials object is required for MANUAL_CREDS")
95+
96+
# Adapter for google-auth credentials object to callable
97+
self._core_client_headers["Authorization"] = (
98+
self._create_creds_token_getter(creds.credentials)
99+
)
100+
101+
elif creds.type == CredentialType.USER_IDENTITY:
102+
# For USER_IDENTITY (3LO), the *Tool* handles the interactive flow at runtime.
103+
104+
def get_user_token() -> str:
105+
token = USER_TOKEN_CONTEXT_VAR.get()
106+
if not token:
107+
return ""
108+
return f"Bearer {token}"
109+
110+
self._core_client_headers["Authorization"] = get_user_token
111+
112+
def _create_adc_token_getter(self, audience: str) -> Callable[[], str]:
113+
"""Returns a callable that fetches a fresh ID token using ADC."""
114+
115+
def get_token() -> str:
116+
request = transport.requests.Request()
117+
try:
118+
token = id_token.fetch_id_token(request, audience)
119+
return f"Bearer {token}"
120+
except Exception:
121+
# Fallback to default credentials
122+
creds, _ = google.auth.default()
123+
if not creds.valid:
124+
creds.refresh(request)
125+
126+
if hasattr(creds, "id_token") and creds.id_token:
127+
return f"Bearer {creds.id_token}"
128+
129+
curr_token = getattr(creds, "token", None)
130+
return f"Bearer {curr_token}" if curr_token else ""
131+
132+
return get_token
133+
134+
def _create_creds_token_getter(self, credentials: Any) -> Callable[[], str]:
135+
def get_token() -> str:
136+
request = transport.requests.Request()
137+
if not credentials.valid:
138+
credentials.refresh(request)
139+
return f"Bearer {credentials.token}"
140+
141+
return get_token
142+
143+
@property
144+
def credential_config(self) -> Optional[CredentialConfig]:
145+
return self._credentials
146+
147+
async def load_toolset(self, toolset_name: str, **kwargs: Any) -> Any:
148+
return await self._client.load_toolset(toolset_name, **kwargs)
149+
150+
async def load_tool(self, tool_name: str, **kwargs: Any) -> Any:
151+
return await self._client.load_tool(tool_name, **kwargs)
152+
153+
async def close(self):
154+
await self._client.close()
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
from unittest.mock import AsyncMock, MagicMock, patch
17+
18+
import pytest
19+
20+
from toolbox_adk import CredentialStrategy, ToolboxClient
21+
from toolbox_adk.client import CredentialType
22+
23+
24+
@pytest.mark.asyncio
25+
class TestToolboxClientAuth:
26+
"""Unit tests for Client Auth logic."""
27+
28+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
29+
async def test_init_toolbox_identity(self, mock_core_client):
30+
"""Test init with TOOLBOX_IDENTITY (no auth headers)."""
31+
creds = CredentialStrategy.TOOLBOX_IDENTITY()
32+
client = ToolboxClient(server_url="http://test", credentials=creds)
33+
34+
# Verify core client created with empty headers for auth
35+
_, kwargs = mock_core_client.call_args
36+
assert "client_headers" in kwargs
37+
headers = kwargs["client_headers"]
38+
assert "Authorization" not in headers
39+
40+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
41+
@patch("toolbox_adk.client.id_token.fetch_id_token")
42+
@patch("toolbox_adk.client.google.auth.default")
43+
@patch("toolbox_adk.client.transport.requests.Request")
44+
async def test_init_adc_success_fetch_id_token(
45+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
46+
):
47+
"""Test ADC strategy where fetch_id_token succeeds."""
48+
mock_fetch_id.return_value = "id-token-123"
49+
50+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
51+
target_audience="aud"
52+
)
53+
client = ToolboxClient(server_url="http://test", credentials=creds)
54+
55+
_, kwargs = mock_core_client.call_args
56+
headers = kwargs["client_headers"]
57+
assert "Authorization" in headers
58+
token_getter = headers["Authorization"]
59+
assert callable(token_getter)
60+
61+
# Call the getter
62+
token_val = token_getter()
63+
assert token_val == "Bearer id-token-123"
64+
mock_fetch_id.assert_called()
65+
66+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
67+
@patch("toolbox_adk.client.id_token.fetch_id_token")
68+
@patch("toolbox_adk.client.google.auth.default")
69+
@patch("toolbox_adk.client.transport.requests.Request")
70+
async def test_init_adc_fallback_creds(
71+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
72+
):
73+
"""Test ADC strategy fallback to default() when fetch_id_token fails."""
74+
mock_fetch_id.side_effect = Exception("No metadata server")
75+
76+
# Mock default creds
77+
mock_creds = MagicMock()
78+
mock_creds.valid = False
79+
mock_creds.id_token = "fallback-id-token"
80+
mock_default.return_value = (mock_creds, "proj")
81+
82+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
83+
target_audience="aud"
84+
)
85+
client = ToolboxClient(server_url="http://test", credentials=creds)
86+
87+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
88+
token = token_getter()
89+
assert token == "Bearer fallback-id-token"
90+
mock_creds.refresh.assert_called() # Because we set valid=False
91+
92+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
93+
@patch("toolbox_adk.client.id_token.fetch_id_token")
94+
@patch("toolbox_adk.client.google.auth.default")
95+
@patch("toolbox_adk.client.transport.requests.Request")
96+
async def test_init_adc_fallback_creds_token(
97+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
98+
):
99+
"""Test ADC fallback when creds have .token but no .id_token."""
100+
mock_fetch_id.side_effect = Exception("No metadata server")
101+
102+
mock_creds = MagicMock()
103+
mock_creds.valid = True
104+
del mock_creds.id_token # Simulate no id_token attr or None
105+
mock_creds.token = "access-token-123" # e.g. user creds
106+
mock_default.return_value = (mock_creds, "proj")
107+
108+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
109+
target_audience="aud"
110+
)
111+
client = ToolboxClient(server_url="http://test", credentials=creds)
112+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
113+
assert token_getter() == "Bearer access-token-123"
114+
115+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
116+
async def test_init_manual_token(self, mock_core_client):
117+
creds = CredentialStrategy.MANUAL_TOKEN(token="abc")
118+
client = ToolboxClient("http://test", credentials=creds)
119+
headers = mock_core_client.call_args[1]["client_headers"]
120+
assert headers["Authorization"] == "Bearer abc"
121+
122+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
123+
async def test_init_manual_creds(self, mock_core_client):
124+
mock_google_creds = MagicMock()
125+
mock_google_creds.valid = True
126+
mock_google_creds.token = "creds-token"
127+
128+
creds = CredentialStrategy.MANUAL_CREDS(credentials=mock_google_creds)
129+
client = ToolboxClient("http://test", credentials=creds)
130+
131+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
132+
assert token_getter() == "Bearer creds-token"
133+
134+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
135+
async def test_init_user_identity(self, mock_core_client):
136+
creds = CredentialStrategy.USER_IDENTITY(client_id="c", client_secret="s")
137+
client = ToolboxClient("http://test", credentials=creds)
138+
139+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
140+
# Should be empty initially
141+
assert token_getter() == ""
142+
143+
# Set context
144+
from toolbox_adk.client import USER_TOKEN_CONTEXT_VAR
145+
146+
token = USER_TOKEN_CONTEXT_VAR.set("user-tok")
147+
try:
148+
assert token_getter() == "Bearer user-tok"
149+
finally:
150+
USER_TOKEN_CONTEXT_VAR.reset(token)
151+
152+
async def test_validation_errors(self):
153+
with pytest.raises(ValueError):
154+
# ADC requires audience
155+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
156+
target_audience=""
157+
)
158+
ToolboxClient("http://test", credentials=creds)
159+
160+
with pytest.raises(ValueError):
161+
creds = CredentialStrategy.MANUAL_TOKEN(token="")
162+
ToolboxClient("http://test", credentials=creds)
163+
164+
with pytest.raises(ValueError):
165+
creds = CredentialStrategy.MANUAL_CREDS(credentials=None)
166+
ToolboxClient("http://test", credentials=creds)
167+
168+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
169+
async def test_load_methods(self, mock_core_client_class):
170+
# Setup mock instance
171+
mock_instance = AsyncMock()
172+
mock_core_client_class.return_value = mock_instance
173+
174+
client = ToolboxClient(
175+
"http://test", credentials=CredentialStrategy.TOOLBOX_IDENTITY()
176+
)
177+
178+
# Test load_toolset
179+
await client.load_toolset("ts", foo="bar")
180+
mock_instance.load_toolset.assert_called_with("ts", foo="bar")
181+
182+
# Test load_tool
183+
await client.load_tool("t", baz="qux")
184+
mock_instance.load_tool.assert_called_with("t", baz="qux")
185+
186+
# Test close
187+
await client.close()
188+
mock_instance.close.assert_called_once()
189+
190+
# Test property
191+
assert client.credential_config is not None

0 commit comments

Comments
 (0)