Skip to content

Commit 7e000e2

Browse files
committed
test(client): add comprehensive unit tests for 100% coverage
1 parent 3a4877e commit 7e000e2

File tree

2 files changed

+129
-129
lines changed

2 files changed

+129
-129
lines changed

packages/toolbox-adk/src/toolbox_adk/client.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from .credentials import CredentialConfig, CredentialType
2424

25-
# Global ContextVar for User Identity (3LO) tokens to be injected per-request
2625
USER_TOKEN_CONTEXT_VAR: ContextVar[Optional[str]] = ContextVar(
2726
"toolbox_user_token", default=None
2827
)
@@ -53,11 +52,6 @@ def __init__(
5352
self._credentials = credentials
5453
self._additional_headers = additional_headers or {}
5554

56-
# Prepare auth_token_getters for toolbox-core
57-
# toolbox_core expects: dict[str, Callable[[], str | Awaitable[str]]]
58-
# However, for general headers (like Authorization), we can pass them in client_headers
59-
# if they are static or simpler. Toolbox-core supports `client_headers` which can be dynamic.
60-
6155
self._core_client_headers: Dict[
6256
str, Union[str, Callable[[], str], Callable[[], Awaitable[str]]]
6357
] = {}
@@ -85,8 +79,6 @@ def _configure_auth(self, creds: CredentialConfig) -> None:
8579
)
8680

8781
# Create an async capable token getter
88-
# We wrap it to match the signature expected by toolbox-core headers
89-
# (which accepts callables)
9082
self._core_client_headers["Authorization"] = self._create_adc_token_getter(
9183
creds.target_audience
9284
)
@@ -108,14 +100,10 @@ def _configure_auth(self, creds: CredentialConfig) -> None:
108100

109101
elif creds.type == CredentialType.USER_IDENTITY:
110102
# For USER_IDENTITY (3LO), the *Tool* handles the interactive flow at runtime.
111-
# We use a ContextVar to inject the token per-request.
112103

113104
def get_user_token() -> str:
114105
token = USER_TOKEN_CONTEXT_VAR.get()
115106
if not token:
116-
# If this is called but no token is set in context, it means
117-
# the tool wrapper failed to set it or we are in a context where
118-
# we expected it. We return empty string which might cause 401.
119107
return ""
120108
return f"Bearer {token}"
121109

@@ -125,28 +113,19 @@ def _create_adc_token_getter(self, audience: str) -> Callable[[], str]:
125113
"""Returns a callable that fetches a fresh ID token using ADC."""
126114

127115
def get_token() -> str:
128-
# Note: This is a synchronous call. Toolbox-core supports sync callables in headers.
129-
# Ideally we would use async but google-auth is primarily sync for these helpers.
130116
request = transport.requests.Request()
131-
# Try to get ID token directly (e.g. on Cloud Run)
132117
try:
133118
token = id_token.fetch_id_token(request, audience)
134119
return f"Bearer {token}"
135120
except Exception:
136-
# Fallback to default credentials (e.g. local gcloud)
121+
# Fallback to default credentials
137122
creds, _ = google.auth.default()
138123
if not creds.valid:
139124
creds.refresh(request)
140-
# If specific ID token credentials, use them, otherwise this might be Access Token (scoped)
141-
# For Toolbox we usually need ID Tokens.
142-
# If the user is locally auth'd via `gcloud auth login`, fetch_id_token is preferred.
143-
# If falling back to service account file:
125+
144126
if hasattr(creds, "id_token") and creds.id_token:
145127
return f"Bearer {creds.id_token}"
146128

147-
# If we are here, we might need to manually sign via IAM or similar if it's a generic SA.
148-
# For simplicity in this v1, we assume fetch_id_token works or standard creds work.
149-
# Re-attempt fetch_id_token on the credentials object if possible
150129
curr_token = getattr(creds, "token", None)
151130
return f"Bearer {curr_token}" if curr_token else ""
152131

packages/toolbox-adk/tests/unit/test_client.py

Lines changed: 127 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -12,159 +12,180 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from unittest.mock import ANY, AsyncMock, MagicMock, patch
15+
import unittest
16+
from unittest.mock import AsyncMock, MagicMock, patch
1617

1718
import pytest
1819

19-
from toolbox_adk.client import ToolboxClient
20-
from toolbox_adk.credentials import CredentialStrategy
20+
from toolbox_adk import CredentialStrategy, ToolboxClient
21+
from toolbox_adk.client import CredentialType
2122

2223

23-
class TestToolboxClient:
24+
@pytest.mark.asyncio
25+
class TestToolboxClientAuth:
26+
"""Unit tests for Client Auth logic."""
2427

2528
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
26-
def test_init_no_auth(self, mock_core_client):
29+
async def test_init_toolbox_identity(self, mock_core_client):
30+
"""Test init with TOOLBOX_IDENTITY (no auth headers)."""
2731
creds = CredentialStrategy.TOOLBOX_IDENTITY()
28-
client = ToolboxClient("http://server", credentials=creds)
32+
client = ToolboxClient(server_url="http://test", credentials=creds)
2933

30-
mock_core_client.assert_called_with(
31-
server_url="http://server", client_headers={}
32-
)
33-
34-
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
35-
def test_init_manual_token(self, mock_core_client):
36-
creds = CredentialStrategy.MANUAL_TOKEN("abc")
37-
client = ToolboxClient("http://server", credentials=creds)
38-
39-
mock_core_client.assert_called_with(
40-
server_url="http://server", client_headers={"Authorization": "Bearer abc"}
41-
)
42-
43-
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
44-
def test_init_additional_headers(self, mock_core_client):
45-
creds = CredentialStrategy.TOOLBOX_IDENTITY()
46-
headers = {"X-Custom": "Val"}
47-
client = ToolboxClient(
48-
"http://server", credentials=creds, additional_headers=headers
49-
)
50-
51-
mock_core_client.assert_called_with(
52-
server_url="http://server", client_headers={"X-Custom": "Val"}
53-
)
34+
# Verify core client created with empty headers for auth
35+
_, kwargs = mock_core_client.call_args
36+
assert "client_headers" in kwargs
37+
headers = kwargs["client_headers"]
38+
assert "Authorization" not in headers
5439

5540
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
5641
@patch("toolbox_adk.client.id_token.fetch_id_token")
57-
def test_adc_auth_flow_success(self, mock_fetch_token, mock_core_client):
58-
mock_fetch_token.return_value = "id_token_123"
42+
@patch("toolbox_adk.client.google.auth.default")
43+
@patch("toolbox_adk.client.transport.requests.Request")
44+
async def test_init_adc_success_fetch_id_token(
45+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
46+
):
47+
"""Test ADC strategy where fetch_id_token succeeds."""
48+
mock_fetch_id.return_value = "id-token-123"
5949

60-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
61-
client = ToolboxClient("http://server", credentials=creds)
50+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
51+
target_audience="aud"
52+
)
53+
client = ToolboxClient(server_url="http://test", credentials=creds)
6254

63-
# Verify a callable was passed
64-
args, kwargs = mock_core_client.call_args
65-
assert "Authorization" in kwargs["client_headers"]
66-
token_getter = kwargs["client_headers"]["Authorization"]
55+
_, kwargs = mock_core_client.call_args
56+
headers = kwargs["client_headers"]
57+
assert "Authorization" in headers
58+
token_getter = headers["Authorization"]
6759
assert callable(token_getter)
6860

69-
# Verify callable behavior
70-
token = token_getter()
71-
assert token == "Bearer id_token_123"
72-
mock_fetch_token.assert_called()
61+
# Call the getter
62+
token_val = token_getter()
63+
assert token_val == "Bearer id-token-123"
64+
mock_fetch_id.assert_called()
7365

7466
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
7567
@patch("toolbox_adk.client.id_token.fetch_id_token")
7668
@patch("toolbox_adk.client.google.auth.default")
77-
def test_adc_auth_flow_fallback(
78-
self, mock_default, mock_fetch_token, mock_core_client
69+
@patch("toolbox_adk.client.transport.requests.Request")
70+
async def test_init_adc_fallback_creds(
71+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
7972
):
80-
# unexpected error on fetch_id_token
81-
mock_fetch_token.side_effect = Exception("No metadata")
73+
"""Test ADC strategy fallback to default() when fetch_id_token fails."""
74+
mock_fetch_id.side_effect = Exception("No metadata server")
8275

76+
# Mock default creds
8377
mock_creds = MagicMock()
8478
mock_creds.valid = False
85-
mock_creds.id_token = "fallback_id_token"
79+
mock_creds.id_token = "fallback-id-token"
8680
mock_default.return_value = (mock_creds, "proj")
8781

88-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
89-
client = ToolboxClient("http://server", credentials=creds)
82+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
83+
target_audience="aud"
84+
)
85+
client = ToolboxClient(server_url="http://test", credentials=creds)
9086

9187
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
9288
token = token_getter()
93-
94-
assert token == "Bearer fallback_id_token"
95-
mock_creds.refresh.assert_called()
89+
assert token == "Bearer fallback-id-token"
90+
mock_creds.refresh.assert_called() # Because we set valid=False
9691

9792
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
98-
def test_manual_creds(self, mock_core_client):
99-
mock_g_creds = MagicMock()
100-
mock_g_creds.valid = False
101-
mock_g_creds.token = "oauth_token"
93+
@patch("toolbox_adk.client.id_token.fetch_id_token")
94+
@patch("toolbox_adk.client.google.auth.default")
95+
@patch("toolbox_adk.client.transport.requests.Request")
96+
async def test_init_adc_fallback_creds_token(
97+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
98+
):
99+
"""Test ADC fallback when creds have .token but no .id_token."""
100+
mock_fetch_id.side_effect = Exception("No metadata server")
102101

103-
creds = CredentialStrategy.MANUAL_CREDS(mock_g_creds)
104-
client = ToolboxClient("http://server", credentials=creds)
102+
mock_creds = MagicMock()
103+
mock_creds.valid = True
104+
del mock_creds.id_token # Simulate no id_token attr or None
105+
mock_creds.token = "access-token-123" # e.g. user creds
106+
mock_default.return_value = (mock_creds, "proj")
105107

108+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
109+
target_audience="aud"
110+
)
111+
client = ToolboxClient(server_url="http://test", credentials=creds)
106112
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
107-
token = token_getter()
113+
assert token_getter() == "Bearer access-token-123"
108114

109-
assert token == "Bearer oauth_token"
110-
assert token == "Bearer oauth_token"
111-
mock_g_creds.refresh.assert_called()
115+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
116+
async def test_init_manual_token(self, mock_core_client):
117+
creds = CredentialStrategy.MANUAL_TOKEN(token="abc")
118+
client = ToolboxClient("http://test", credentials=creds)
119+
headers = mock_core_client.call_args[1]["client_headers"]
120+
assert headers["Authorization"] == "Bearer abc"
112121

113122
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
114-
def test_init_validation_errors(self, mock_core_client):
115-
# ADC missing audience
116-
with pytest.raises(ValueError, match="target_audience is required"):
117-
# Fix: only pass target_audience as keyword arg OR positional, not both mixed in a way that causes overlap if defined so
118-
# Actually simpler: just pass raw None
119-
ToolboxClient(
120-
"url",
121-
credentials=CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(None),
122-
)
123+
async def test_init_manual_creds(self, mock_core_client):
124+
mock_google_creds = MagicMock()
125+
mock_google_creds.valid = True
126+
mock_google_creds.token = "creds-token"
123127

124-
# Manual token missing token
125-
with pytest.raises(ValueError, match="token is required"):
126-
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_TOKEN(None))
128+
creds = CredentialStrategy.MANUAL_CREDS(credentials=mock_google_creds)
129+
client = ToolboxClient("http://test", credentials=creds)
127130

128-
# Manual creds missing credentials
129-
with pytest.raises(ValueError, match="credentials object is required"):
130-
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_CREDS(None))
131+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
132+
assert token_getter() == "Bearer creds-token"
131133

132134
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
133-
@patch("toolbox_adk.client.id_token.fetch_id_token")
134-
@patch("toolbox_adk.client.google.auth.default")
135-
def test_adc_auth_flow_fallback_access_token(
136-
self, mock_default, mock_fetch_token, mock_core_client
137-
):
138-
# fetch_id_token fails
139-
mock_fetch_token.side_effect = Exception("No metadata")
140-
141-
mock_creds = MagicMock()
142-
mock_creds.valid = False
143-
mock_creds.id_token = None # No ID token
144-
mock_creds.token = "access_token_123"
145-
mock_default.return_value = (mock_creds, "proj")
146-
147-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
148-
client = ToolboxClient("http://server", credentials=creds)
135+
async def test_init_user_identity(self, mock_core_client):
136+
creds = CredentialStrategy.USER_IDENTITY(client_id="c", client_secret="s")
137+
client = ToolboxClient("http://test", credentials=creds)
149138

150139
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
151-
token = token_getter()
140+
# Should be empty initially
141+
assert token_getter() == ""
142+
143+
# Set context
144+
from toolbox_adk.client import USER_TOKEN_CONTEXT_VAR
145+
146+
token = USER_TOKEN_CONTEXT_VAR.set("user-tok")
147+
try:
148+
assert token_getter() == "Bearer user-tok"
149+
finally:
150+
USER_TOKEN_CONTEXT_VAR.reset(token)
151+
152+
async def test_validation_errors(self):
153+
with pytest.raises(ValueError):
154+
# ADC requires audience
155+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
156+
target_audience=""
157+
)
158+
ToolboxClient("http://test", credentials=creds)
152159

153-
assert token == "Bearer access_token_123"
154-
mock_creds.refresh.assert_called()
160+
with pytest.raises(ValueError):
161+
creds = CredentialStrategy.MANUAL_TOKEN(token="")
162+
ToolboxClient("http://test", credentials=creds)
163+
164+
with pytest.raises(ValueError):
165+
creds = CredentialStrategy.MANUAL_CREDS(credentials=None)
166+
ToolboxClient("http://test", credentials=creds)
155167

156-
@pytest.mark.asyncio
157168
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
158-
async def test_delegation(self, mock_core_client):
159-
mock_instance = mock_core_client.return_value
160-
mock_instance.load_toolset = AsyncMock(return_value=["t1"])
161-
mock_instance.close = AsyncMock()
169+
async def test_load_methods(self, mock_core_client_class):
170+
# Setup mock instance
171+
mock_instance = AsyncMock()
172+
mock_core_client_class.return_value = mock_instance
162173

163-
client = ToolboxClient("http://server")
164-
tools = await client.load_toolset("my-set", extra="arg")
174+
client = ToolboxClient(
175+
"http://test", credentials=CredentialStrategy.TOOLBOX_IDENTITY()
176+
)
165177

166-
mock_instance.load_toolset.assert_awaited_with("my-set", extra="arg")
167-
assert tools == ["t1"]
178+
# Test load_toolset
179+
await client.load_toolset("ts", foo="bar")
180+
mock_instance.load_toolset.assert_called_with("ts", foo="bar")
168181

182+
# Test load_tool
183+
await client.load_tool("t", baz="qux")
184+
mock_instance.load_tool.assert_called_with("t", baz="qux")
185+
186+
# Test close
169187
await client.close()
170-
mock_instance.close.assert_awaited()
188+
mock_instance.close.assert_called_once()
189+
190+
# Test property
191+
assert client.credential_config is not None

0 commit comments

Comments
 (0)