Skip to content

Commit aaf8403

Browse files
committed
style: apply black formatting to client wrapper
1 parent 92e5dde commit aaf8403

File tree

2 files changed

+84
-70
lines changed

2 files changed

+84
-70
lines changed

packages/toolbox-adk/src/toolbox_adk/client.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Optional, Dict, Callable, Any, Union, Awaitable
15+
from contextvars import ContextVar
16+
from typing import Any, Awaitable, Callable, Dict, Optional, Union
17+
18+
import google.auth
1619
import toolbox_core
17-
from google.auth import transport
18-
from google.auth import compute_engine
20+
from google.auth import compute_engine, transport
1921
from google.oauth2 import id_token
20-
import google.auth
21-
from contextvars import ContextVar
2222

2323
from .credentials import CredentialConfig, CredentialType
2424

2525
# Global ContextVar for User Identity (3LO) tokens to be injected per-request
26-
USER_TOKEN_CONTEXT_VAR: ContextVar[Optional[str]] = ContextVar("toolbox_user_token", default=None)
26+
USER_TOKEN_CONTEXT_VAR: ContextVar[Optional[str]] = ContextVar(
27+
"toolbox_user_token", default=None
28+
)
2729

2830

2931
class ToolboxClient:
@@ -35,9 +37,10 @@ def __init__(
3537
self,
3638
server_url: str,
3739
credentials: Optional[CredentialConfig] = None,
38-
39-
additional_headers: Optional[Dict[str, Union[str, Callable[[], str], Callable[[], Awaitable[str]]]]] = None,
40-
**kwargs
40+
additional_headers: Optional[
41+
Dict[str, Union[str, Callable[[], str], Callable[[], Awaitable[str]]]]
42+
] = None,
43+
**kwargs,
4144
):
4245
"""
4346
Args:
@@ -49,14 +52,14 @@ def __init__(
4952
self._server_url = server_url
5053
self._credentials = credentials
5154
self._additional_headers = additional_headers or {}
52-
55+
5356
# Prepare auth_token_getters for toolbox-core
5457
# toolbox_core expects: dict[str, Callable[[], str | Awaitable[str]]]
5558
# However, for general headers (like Authorization), we can pass them in client_headers
5659
# if they are static or simpler. Toolbox-core supports `client_headers` which can be dynamic.
57-
60+
5861
self._core_client_headers: Dict[str, Union[str, Callable[[], str]]] = {}
59-
62+
6063
# Add static additional headers
6164
for k, v in self._additional_headers.items():
6265
self._core_client_headers[k] = v
@@ -65,55 +68,60 @@ def __init__(
6568
self._configure_auth(credentials)
6669

6770
self._client = toolbox_core.ToolboxClient(
68-
url=server_url,
69-
client_headers=self._core_client_headers,
70-
**kwargs
71+
url=server_url, client_headers=self._core_client_headers, **kwargs
7172
)
7273

7374
def _configure_auth(self, creds: CredentialConfig):
7475
if creds.type == CredentialType.TOOLBOX_IDENTITY:
7576
# No auth headers needed
7677
pass
77-
78+
7879
elif creds.type == CredentialType.APPLICATION_DEFAULT_CREDENTIALS:
7980
if not creds.target_audience:
80-
raise ValueError("target_audience is required for APPLICATION_DEFAULT_CREDENTIALS")
81-
81+
raise ValueError(
82+
"target_audience is required for APPLICATION_DEFAULT_CREDENTIALS"
83+
)
84+
8285
# Create an async capable token getter
8386
# We wrap it to match the signature expected by toolbox-core headers
8487
# (which accepts callables)
85-
self._core_client_headers["Authorization"] = self._create_adc_token_getter(creds.target_audience)
86-
88+
self._core_client_headers["Authorization"] = self._create_adc_token_getter(
89+
creds.target_audience
90+
)
91+
8792
elif creds.type == CredentialType.MANUAL_TOKEN:
8893
if not creds.token:
8994
raise ValueError("token is required for MANUAL_TOKEN")
9095
scheme = creds.scheme or "Bearer"
9196
self._core_client_headers["Authorization"] = f"{scheme} {creds.token}"
92-
97+
9398
elif creds.type == CredentialType.MANUAL_CREDS:
9499
if not creds.credentials:
95100
raise ValueError("credentials object is required for MANUAL_CREDS")
96-
101+
97102
# Adapter for google-auth credentials object to callable
98-
self._core_client_headers["Authorization"] = self._create_creds_token_getter(creds.credentials)
99-
103+
self._core_client_headers["Authorization"] = (
104+
self._create_creds_token_getter(creds.credentials)
105+
)
106+
100107
elif creds.type == CredentialType.USER_IDENTITY:
101108
# For USER_IDENTITY (3LO), the *Tool* handles the interactive flow at runtime.
102109
# We use a ContextVar to inject the token per-request.
103-
110+
104111
def get_user_token() -> str:
105112
token = USER_TOKEN_CONTEXT_VAR.get()
106113
if not token:
107-
# If this is called but no token is set in context, it means
108-
# the tool wrapper failed to set it or we are in a context where
114+
# If this is called but no token is set in context, it means
115+
# the tool wrapper failed to set it or we are in a context where
109116
# we expected it. We return empty string which might cause 401.
110117
return ""
111118
return f"Bearer {token}"
112-
119+
113120
self._core_client_headers["Authorization"] = get_user_token
114121

115122
def _create_adc_token_getter(self, audience: str) -> Callable[[], str]:
116123
"""Returns a callable that fetches a fresh ID token using ADC."""
124+
117125
def get_token() -> str:
118126
# Note: This is a synchronous call. Toolbox-core supports sync callables in headers.
119127
# Ideally we would use async but google-auth is primarily sync for these helpers.
@@ -131,9 +139,9 @@ def get_token() -> str:
131139
# For Toolbox we usually need ID Tokens.
132140
# If the user is locally auth'd via `gcloud auth login`, fetch_id_token is preferred.
133141
# If falling back to service account file:
134-
if hasattr(creds, 'id_token') and creds.id_token:
135-
return f"Bearer {creds.id_token}"
136-
142+
if hasattr(creds, "id_token") and creds.id_token:
143+
return f"Bearer {creds.id_token}"
144+
137145
# If we are here, we might need to manually sign via IAM or similar if it's a generic SA.
138146
# For simplicity in this v1, we assume fetch_id_token works or standard creds work.
139147
# Re-attempt fetch_id_token on the credentials object if possible
@@ -148,6 +156,7 @@ def get_token() -> str:
148156
if not credentials.valid:
149157
credentials.refresh(request)
150158
return f"Bearer {credentials.token}"
159+
151160
return get_token
152161

153162
@property

packages/toolbox-adk/tests/unit/test_client.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,62 +12,60 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest.mock import ANY, AsyncMock, MagicMock, patch
16+
1517
import pytest
16-
from unittest.mock import MagicMock, patch, ANY, AsyncMock
18+
1719
from toolbox_adk.client import ToolboxClient
1820
from toolbox_adk.credentials import CredentialStrategy
1921

22+
2023
class TestToolboxClient:
21-
24+
2225
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
2326
def test_init_no_auth(self, mock_core_client):
2427
creds = CredentialStrategy.TOOLBOX_IDENTITY()
2528
client = ToolboxClient("http://server", credentials=creds)
26-
29+
2730
mock_core_client.assert_called_with(
28-
server_url="http://server",
29-
client_headers={}
31+
server_url="http://server", client_headers={}
3032
)
3133

3234
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
3335
def test_init_manual_token(self, mock_core_client):
3436
creds = CredentialStrategy.MANUAL_TOKEN("abc")
3537
client = ToolboxClient("http://server", credentials=creds)
36-
38+
3739
mock_core_client.assert_called_with(
38-
server_url="http://server",
39-
client_headers={"Authorization": "Bearer abc"}
40+
server_url="http://server", client_headers={"Authorization": "Bearer abc"}
4041
)
4142

4243
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
4344
def test_init_additional_headers(self, mock_core_client):
4445
creds = CredentialStrategy.TOOLBOX_IDENTITY()
4546
headers = {"X-Custom": "Val"}
4647
client = ToolboxClient(
47-
"http://server",
48-
credentials=creds,
49-
additional_headers=headers
48+
"http://server", credentials=creds, additional_headers=headers
5049
)
51-
50+
5251
mock_core_client.assert_called_with(
53-
server_url="http://server",
54-
client_headers={"X-Custom": "Val"}
52+
server_url="http://server", client_headers={"X-Custom": "Val"}
5553
)
5654

5755
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
5856
@patch("toolbox_adk.client.id_token.fetch_id_token")
5957
def test_adc_auth_flow_success(self, mock_fetch_token, mock_core_client):
6058
mock_fetch_token.return_value = "id_token_123"
61-
59+
6260
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
6361
client = ToolboxClient("http://server", credentials=creds)
64-
62+
6563
# Verify a callable was passed
6664
args, kwargs = mock_core_client.call_args
6765
assert "Authorization" in kwargs["client_headers"]
6866
token_getter = kwargs["client_headers"]["Authorization"]
6967
assert callable(token_getter)
70-
68+
7169
# Verify callable behavior
7270
token = token_getter()
7371
assert token == "Bearer id_token_123"
@@ -76,36 +74,38 @@ def test_adc_auth_flow_success(self, mock_fetch_token, mock_core_client):
7674
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
7775
@patch("toolbox_adk.client.id_token.fetch_id_token")
7876
@patch("toolbox_adk.client.google.auth.default")
79-
def test_adc_auth_flow_fallback(self, mock_default, mock_fetch_token, mock_core_client):
77+
def test_adc_auth_flow_fallback(
78+
self, mock_default, mock_fetch_token, mock_core_client
79+
):
8080
# unexpected error on fetch_id_token
8181
mock_fetch_token.side_effect = Exception("No metadata")
82-
82+
8383
mock_creds = MagicMock()
8484
mock_creds.valid = False
8585
mock_creds.id_token = "fallback_id_token"
8686
mock_default.return_value = (mock_creds, "proj")
87-
87+
8888
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
8989
client = ToolboxClient("http://server", credentials=creds)
90-
90+
9191
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
9292
token = token_getter()
93-
93+
9494
assert token == "Bearer fallback_id_token"
9595
mock_creds.refresh.assert_called()
96-
96+
9797
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
9898
def test_manual_creds(self, mock_core_client):
9999
mock_g_creds = MagicMock()
100100
mock_g_creds.valid = False
101101
mock_g_creds.token = "oauth_token"
102-
102+
103103
creds = CredentialStrategy.MANUAL_CREDS(mock_g_creds)
104104
client = ToolboxClient("http://server", credentials=creds)
105-
105+
106106
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
107107
token = token_getter()
108-
108+
109109
assert token == "Bearer oauth_token"
110110
assert token == "Bearer oauth_token"
111111
mock_g_creds.refresh.assert_called()
@@ -115,36 +115,41 @@ def test_init_validation_errors(self, mock_core_client):
115115
# ADC missing audience
116116
with pytest.raises(ValueError, match="target_audience is required"):
117117
# Fix: only pass target_audience as keyword arg OR positional, not both mixed in a way that causes overlap if defined so
118-
# Actually simpler: just pass raw None
119-
ToolboxClient("url", credentials=CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(None))
120-
118+
# Actually simpler: just pass raw None
119+
ToolboxClient(
120+
"url",
121+
credentials=CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(None),
122+
)
123+
121124
# Manual token missing token
122125
with pytest.raises(ValueError, match="token is required"):
123126
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_TOKEN(None))
124-
127+
125128
# Manual creds missing credentials
126129
with pytest.raises(ValueError, match="credentials object is required"):
127130
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_CREDS(None))
128131

129132
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
130133
@patch("toolbox_adk.client.id_token.fetch_id_token")
131134
@patch("toolbox_adk.client.google.auth.default")
132-
def test_adc_auth_flow_fallback_access_token(self, mock_default, mock_fetch_token, mock_core_client):
135+
def test_adc_auth_flow_fallback_access_token(
136+
self, mock_default, mock_fetch_token, mock_core_client
137+
):
133138
# fetch_id_token fails
134139
mock_fetch_token.side_effect = Exception("No metadata")
135-
140+
136141
mock_creds = MagicMock()
137142
mock_creds.valid = False
138-
mock_creds.id_token = None # No ID token
143+
mock_creds.id_token = None # No ID token
139144
mock_creds.token = "access_token_123"
140145
mock_default.return_value = (mock_creds, "proj")
141-
146+
142147
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
143148
client = ToolboxClient("http://server", credentials=creds)
144-
149+
145150
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
146151
token = token_getter()
147-
152+
148153
assert token == "Bearer access_token_123"
149154
mock_creds.refresh.assert_called()
150155

@@ -154,12 +159,12 @@ async def test_delegation(self, mock_core_client):
154159
mock_instance = mock_core_client.return_value
155160
mock_instance.load_toolset = AsyncMock(return_value=["t1"])
156161
mock_instance.close = AsyncMock()
157-
162+
158163
client = ToolboxClient("http://server")
159164
tools = await client.load_toolset("my-set", extra="arg")
160-
165+
161166
mock_instance.load_toolset.assert_awaited_with("my-set", extra="arg")
162167
assert tools == ["t1"]
163-
168+
164169
await client.close()
165170
mock_instance.close.assert_awaited()

0 commit comments

Comments
 (0)