Skip to content

Commit d81d97c

Browse files
committed
test(client): add comprehensive unit tests for 100% coverage
1 parent c87a69a commit d81d97c

File tree

2 files changed

+125
-147
lines changed

2 files changed

+125
-147
lines changed

packages/toolbox-adk/src/toolbox_adk/client.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from .credentials import CredentialConfig, CredentialType
2424

25-
# Global ContextVar for User Identity (3LO) tokens to be injected per-request
2625
USER_TOKEN_CONTEXT_VAR: ContextVar[Optional[str]] = ContextVar(
2726
"toolbox_user_token", default=None
2827
)
@@ -53,11 +52,6 @@ def __init__(
5352
self._credentials = credentials
5453
self._additional_headers = additional_headers or {}
5554

56-
# Prepare auth_token_getters for toolbox-core
57-
# toolbox_core expects: dict[str, Callable[[], str | Awaitable[str]]]
58-
# However, for general headers (like Authorization), we can pass them in client_headers
59-
# if they are static or simpler. Toolbox-core supports `client_headers` which can be dynamic.
60-
6155
self._core_client_headers: Dict[
6256
str, Union[str, Callable[[], str], Callable[[], Awaitable[str]]]
6357
] = {}
@@ -85,8 +79,6 @@ def _configure_auth(self, creds: CredentialConfig) -> None:
8579
)
8680

8781
# Create an async capable token getter
88-
# We wrap it to match the signature expected by toolbox-core headers
89-
# (which accepts callables)
9082
self._core_client_headers["Authorization"] = self._create_adc_token_getter(
9183
creds.target_audience
9284
)
@@ -108,14 +100,10 @@ def _configure_auth(self, creds: CredentialConfig) -> None:
108100

109101
elif creds.type == CredentialType.USER_IDENTITY:
110102
# For USER_IDENTITY (3LO), the *Tool* handles the interactive flow at runtime.
111-
# We use a ContextVar to inject the token per-request.
112103

113104
def get_user_token() -> str:
114105
token = USER_TOKEN_CONTEXT_VAR.get()
115106
if not token:
116-
# If this is called but no token is set in context, it means
117-
# the tool wrapper failed to set it or we are in a context where
118-
# we expected it. We return empty string which might cause 401.
119107
return ""
120108
return f"Bearer {token}"
121109

@@ -125,28 +113,19 @@ def _create_adc_token_getter(self, audience: str) -> Callable[[], str]:
125113
"""Returns a callable that fetches a fresh ID token using ADC."""
126114

127115
def get_token() -> str:
128-
# Note: This is a synchronous call. Toolbox-core supports sync callables in headers.
129-
# Ideally we would use async but google-auth is primarily sync for these helpers.
130116
request = transport.requests.Request()
131-
# Try to get ID token directly (e.g. on Cloud Run)
132117
try:
133118
token = id_token.fetch_id_token(request, audience)
134119
return f"Bearer {token}"
135120
except Exception:
136-
# Fallback to default credentials (e.g. local gcloud)
121+
# Fallback to default credentials
137122
creds, _ = google.auth.default()
138123
if not creds.valid:
139124
creds.refresh(request)
140-
# If specific ID token credentials, use them, otherwise this might be Access Token (scoped)
141-
# For Toolbox we usually need ID Tokens.
142-
# If the user is locally auth'd via `gcloud auth login`, fetch_id_token is preferred.
143-
# If falling back to service account file:
125+
144126
if hasattr(creds, "id_token") and creds.id_token:
145127
return f"Bearer {creds.id_token}"
146128

147-
# If we are here, we might need to manually sign via IAM or similar if it's a generic SA.
148-
# For simplicity in this v1, we assume fetch_id_token works or standard creds work.
149-
# Re-attempt fetch_id_token on the credentials object if possible
150129
curr_token = getattr(creds, "token", None)
151130
return f"Bearer {curr_token}" if curr_token else ""
152131

packages/toolbox-adk/tests/unit/test_client.py

Lines changed: 123 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -12,159 +12,158 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from unittest.mock import ANY, AsyncMock, MagicMock, patch
15+
import unittest
16+
from unittest.mock import MagicMock, patch, AsyncMock
1617

1718
import pytest
19+
from toolbox_adk import CredentialStrategy, ToolboxClient
20+
from toolbox_adk.client import CredentialType
1821

19-
from toolbox_adk.client import ToolboxClient
20-
from toolbox_adk.credentials import CredentialStrategy
21-
22-
23-
class TestToolboxClient:
22+
@pytest.mark.asyncio
23+
class TestToolboxClientAuth:
24+
"""Unit tests for Client Auth logic."""
2425

2526
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
26-
def test_init_no_auth(self, mock_core_client):
27+
async def test_init_toolbox_identity(self, mock_core_client):
28+
"""Test init with TOOLBOX_IDENTITY (no auth headers)."""
2729
creds = CredentialStrategy.TOOLBOX_IDENTITY()
28-
client = ToolboxClient("http://server", credentials=creds)
29-
30-
mock_core_client.assert_called_with(
31-
server_url="http://server", client_headers={}
32-
)
33-
34-
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
35-
def test_init_manual_token(self, mock_core_client):
36-
creds = CredentialStrategy.MANUAL_TOKEN("abc")
37-
client = ToolboxClient("http://server", credentials=creds)
38-
39-
mock_core_client.assert_called_with(
40-
server_url="http://server", client_headers={"Authorization": "Bearer abc"}
41-
)
42-
43-
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
44-
def test_init_additional_headers(self, mock_core_client):
45-
creds = CredentialStrategy.TOOLBOX_IDENTITY()
46-
headers = {"X-Custom": "Val"}
47-
client = ToolboxClient(
48-
"http://server", credentials=creds, additional_headers=headers
49-
)
50-
51-
mock_core_client.assert_called_with(
52-
server_url="http://server", client_headers={"X-Custom": "Val"}
53-
)
30+
client = ToolboxClient(server_url="http://test", credentials=creds)
31+
32+
# Verify core client created with empty headers for auth
33+
_, kwargs = mock_core_client.call_args
34+
assert "client_headers" in kwargs
35+
headers = kwargs["client_headers"]
36+
assert "Authorization" not in headers
5437

5538
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
5639
@patch("toolbox_adk.client.id_token.fetch_id_token")
57-
def test_adc_auth_flow_success(self, mock_fetch_token, mock_core_client):
58-
mock_fetch_token.return_value = "id_token_123"
59-
60-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
61-
client = ToolboxClient("http://server", credentials=creds)
62-
63-
# Verify a callable was passed
64-
args, kwargs = mock_core_client.call_args
65-
assert "Authorization" in kwargs["client_headers"]
66-
token_getter = kwargs["client_headers"]["Authorization"]
40+
@patch("toolbox_adk.client.google.auth.default")
41+
@patch("toolbox_adk.client.transport.requests.Request")
42+
async def test_init_adc_success_fetch_id_token(self, mock_req, mock_default, mock_fetch_id, mock_core_client):
43+
"""Test ADC strategy where fetch_id_token succeeds."""
44+
mock_fetch_id.return_value = "id-token-123"
45+
46+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(target_audience="aud")
47+
client = ToolboxClient(server_url="http://test", credentials=creds)
48+
49+
_, kwargs = mock_core_client.call_args
50+
headers = kwargs["client_headers"]
51+
assert "Authorization" in headers
52+
token_getter = headers["Authorization"]
6753
assert callable(token_getter)
68-
69-
# Verify callable behavior
70-
token = token_getter()
71-
assert token == "Bearer id_token_123"
72-
mock_fetch_token.assert_called()
54+
55+
# Call the getter
56+
token_val = token_getter()
57+
assert token_val == "Bearer id-token-123"
58+
mock_fetch_id.assert_called()
7359

7460
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
7561
@patch("toolbox_adk.client.id_token.fetch_id_token")
7662
@patch("toolbox_adk.client.google.auth.default")
77-
def test_adc_auth_flow_fallback(
78-
self, mock_default, mock_fetch_token, mock_core_client
79-
):
80-
# unexpected error on fetch_id_token
81-
mock_fetch_token.side_effect = Exception("No metadata")
82-
63+
@patch("toolbox_adk.client.transport.requests.Request")
64+
async def test_init_adc_fallback_creds(self, mock_req, mock_default, mock_fetch_id, mock_core_client):
65+
"""Test ADC strategy fallback to default() when fetch_id_token fails."""
66+
mock_fetch_id.side_effect = Exception("No metadata server")
67+
68+
# Mock default creds
8369
mock_creds = MagicMock()
8470
mock_creds.valid = False
85-
mock_creds.id_token = "fallback_id_token"
71+
mock_creds.id_token = "fallback-id-token"
8672
mock_default.return_value = (mock_creds, "proj")
87-
88-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
89-
client = ToolboxClient("http://server", credentials=creds)
90-
73+
74+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(target_audience="aud")
75+
client = ToolboxClient(server_url="http://test", credentials=creds)
76+
9177
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
9278
token = token_getter()
93-
94-
assert token == "Bearer fallback_id_token"
95-
mock_creds.refresh.assert_called()
96-
97-
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
98-
def test_manual_creds(self, mock_core_client):
99-
mock_g_creds = MagicMock()
100-
mock_g_creds.valid = False
101-
mock_g_creds.token = "oauth_token"
102-
103-
creds = CredentialStrategy.MANUAL_CREDS(mock_g_creds)
104-
client = ToolboxClient("http://server", credentials=creds)
105-
106-
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
107-
token = token_getter()
108-
109-
assert token == "Bearer oauth_token"
110-
assert token == "Bearer oauth_token"
111-
mock_g_creds.refresh.assert_called()
112-
113-
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
114-
def test_init_validation_errors(self, mock_core_client):
115-
# ADC missing audience
116-
with pytest.raises(ValueError, match="target_audience is required"):
117-
# Fix: only pass target_audience as keyword arg OR positional, not both mixed in a way that causes overlap if defined so
118-
# Actually simpler: just pass raw None
119-
ToolboxClient(
120-
"url",
121-
credentials=CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(None),
122-
)
123-
124-
# Manual token missing token
125-
with pytest.raises(ValueError, match="token is required"):
126-
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_TOKEN(None))
127-
128-
# Manual creds missing credentials
129-
with pytest.raises(ValueError, match="credentials object is required"):
130-
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_CREDS(None))
79+
assert token == "Bearer fallback-id-token"
80+
mock_creds.refresh.assert_called() # Because we set valid=False
13181

13282
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
13383
@patch("toolbox_adk.client.id_token.fetch_id_token")
13484
@patch("toolbox_adk.client.google.auth.default")
135-
def test_adc_auth_flow_fallback_access_token(
136-
self, mock_default, mock_fetch_token, mock_core_client
137-
):
138-
# fetch_id_token fails
139-
mock_fetch_token.side_effect = Exception("No metadata")
140-
85+
@patch("toolbox_adk.client.transport.requests.Request")
86+
async def test_init_adc_fallback_creds_token(self, mock_req, mock_default, mock_fetch_id, mock_core_client):
87+
"""Test ADC fallback when creds have .token but no .id_token."""
88+
mock_fetch_id.side_effect = Exception("No metadata server")
89+
14190
mock_creds = MagicMock()
142-
mock_creds.valid = False
143-
mock_creds.id_token = None # No ID token
144-
mock_creds.token = "access_token_123"
91+
mock_creds.valid = True
92+
del mock_creds.id_token # Simulate no id_token attr or None
93+
mock_creds.token = "access-token-123" # e.g. user creds
14594
mock_default.return_value = (mock_creds, "proj")
146-
147-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
148-
client = ToolboxClient("http://server", credentials=creds)
149-
95+
96+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(target_audience="aud")
97+
client = ToolboxClient(server_url="http://test", credentials=creds)
15098
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
151-
token = token_getter()
152-
153-
assert token == "Bearer access_token_123"
154-
mock_creds.refresh.assert_called()
99+
assert token_getter() == "Bearer access-token-123"
155100

156-
@pytest.mark.asyncio
157101
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
158-
async def test_delegation(self, mock_core_client):
159-
mock_instance = mock_core_client.return_value
160-
mock_instance.load_toolset = AsyncMock(return_value=["t1"])
161-
mock_instance.close = AsyncMock()
102+
async def test_init_manual_token(self, mock_core_client):
103+
creds = CredentialStrategy.MANUAL_TOKEN(token="abc")
104+
client = ToolboxClient("http://test", credentials=creds)
105+
headers = mock_core_client.call_args[1]["client_headers"]
106+
assert headers["Authorization"] == "Bearer abc"
162107

163-
client = ToolboxClient("http://server")
164-
tools = await client.load_toolset("my-set", extra="arg")
108+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
109+
async def test_init_manual_creds(self, mock_core_client):
110+
mock_google_creds = MagicMock()
111+
mock_google_creds.valid = True
112+
mock_google_creds.token = "creds-token"
113+
114+
creds = CredentialStrategy.MANUAL_CREDS(credentials=mock_google_creds)
115+
client = ToolboxClient("http://test", credentials=creds)
116+
117+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
118+
assert token_getter() == "Bearer creds-token"
165119

166-
mock_instance.load_toolset.assert_awaited_with("my-set", extra="arg")
167-
assert tools == ["t1"]
120+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
121+
async def test_init_user_identity(self, mock_core_client):
122+
creds = CredentialStrategy.USER_IDENTITY(client_id="c", client_secret="s")
123+
client = ToolboxClient("http://test", credentials=creds)
124+
125+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
126+
# Should be empty initially
127+
assert token_getter() == ""
128+
129+
# Set context
130+
from toolbox_adk.client import USER_TOKEN_CONTEXT_VAR
131+
token = USER_TOKEN_CONTEXT_VAR.set("user-tok")
132+
try:
133+
assert token_getter() == "Bearer user-tok"
134+
finally:
135+
USER_TOKEN_CONTEXT_VAR.reset(token)
136+
137+
async def test_validation_errors(self):
138+
with pytest.raises(ValueError):
139+
# ADC requires audience
140+
CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(target_audience="")
141+
142+
with pytest.raises(ValueError):
143+
CredentialStrategy.MANUAL_TOKEN(token="")
144+
145+
with pytest.raises(ValueError):
146+
CredentialStrategy.MANUAL_CREDS(credentials=None)
168147

148+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
149+
async def test_load_methods(self, mock_core_client_class):
150+
# Setup mock instance
151+
mock_instance = AsyncMock()
152+
mock_core_client_class.return_value = mock_instance
153+
154+
client = ToolboxClient("http://test", credentials=CredentialStrategy.TOOLBOX_IDENTITY())
155+
156+
# Test load_toolset
157+
await client.load_toolset("ts", foo="bar")
158+
mock_instance.load_toolset.assert_called_with("ts", foo="bar")
159+
160+
# Test load_tool
161+
await client.load_tool("t", baz="qux")
162+
mock_instance.load_tool.assert_called_with("t", baz="qux")
163+
164+
# Test close
169165
await client.close()
170-
mock_instance.close.assert_awaited()
166+
mock_instance.close.assert_called_once()
167+
168+
# Test property
169+
assert client.credential_config is not None

0 commit comments

Comments
 (0)