Skip to content

Commit 3c99afa

Browse files
committed
refactor(tool): enforce strict ADK auth (3LO) and full test coverage
1 parent 41d02f7 commit 3c99afa

File tree

2 files changed

+176
-30
lines changed

2 files changed

+176
-30
lines changed

packages/toolbox-adk/src/toolbox_adk/tool.py

Lines changed: 57 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
from google.adk.tools.base_tool import BaseTool
2121
from google.adk.agents.readonly_context import ReadonlyContext
2222

23-
from google_auth_oauthlib.flow import InstalledAppFlow
24-
from google.auth.transport.requests import Request
23+
from google.adk.auth.auth_tool import AuthConfig
24+
from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, OAuth2Auth
25+
from fastapi.openapi.models import SecurityScheme, OAuthFlows, OAuthFlowAuthorizationCode, OAuth2
26+
2527
from .credentials import CredentialConfig, CredentialType
2628
from .client import USER_TOKEN_CONTEXT_VAR
2729

2830

31+
2932
class ToolboxContext:
3033
"""Context object passed to pre/post hooks."""
3134
def __init__(self, arguments: Dict[str, Any], tool_context: ReadonlyContext):
@@ -70,7 +73,6 @@ def __init__(
7073
self._pre_hook = pre_hook
7174
self._post_hook = post_hook
7275
self._auth_config = auth_config
73-
self._user_creds: Optional[Any] = None
7476

7577
@override
7678
async def run_async(
@@ -88,36 +90,61 @@ async def run_async(
8890
# 2. ADK Auth Integration (3LO)
8991
# Check if USER_IDENTITY is configured
9092
reset_token = None
93+
9194
if self._auth_config and self._auth_config.type == CredentialType.USER_IDENTITY:
92-
# Handle interactive flow if credentials are missing or expired
93-
if not self._user_creds or not self._user_creds.valid:
94-
if self._user_creds and self._user_creds.expired and self._user_creds.refresh_token:
95-
try:
96-
self._user_creds.refresh(Request())
97-
except Exception:
98-
self._user_creds = None
95+
if not self._auth_config.client_id or not self._auth_config.client_secret:
96+
raise ValueError("USER_IDENTITY requires client_id and client_secret")
9997

100-
if not self._user_creds:
101-
# Trigger flow
102-
if not (self._auth_config.client_id and self._auth_config.client_secret):
103-
raise ValueError("USER_IDENTITY requires client_id and client_secret")
104-
105-
config = {
106-
"installed": {
107-
"client_id": self._auth_config.client_id,
108-
"client_secret": self._auth_config.client_secret,
109-
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
110-
"token_uri": "https://oauth2.googleapis.com/token",
111-
}
112-
}
113-
scopes = self._auth_config.scopes or ["https://www.googleapis.com/auth/cloud-platform"]
114-
115-
flow = InstalledAppFlow.from_client_config(config, scopes=scopes)
116-
self._user_creds = flow.run_local_server(port=0)
98+
# Construct ADK AuthConfig
99+
scopes = self._auth_config.scopes or ["https://www.googleapis.com/auth/cloud-platform"]
100+
scope_dict = {s: "" for s in scopes}
117101

118-
# Inject token into ContextVar
119-
if self._user_creds:
120-
reset_token = USER_TOKEN_CONTEXT_VAR.set(self._user_creds.token)
102+
auth_config_adk = AuthConfig(
103+
auth_scheme=OAuth2(
104+
flows=OAuthFlows(
105+
authorizationCode=OAuthFlowAuthorizationCode(
106+
authorizationUrl="https://accounts.google.com/o/oauth2/auth",
107+
tokenUrl="https://oauth2.googleapis.com/token",
108+
scopes=scope_dict
109+
)
110+
)
111+
),
112+
raw_auth_credential=AuthCredential(
113+
auth_type=AuthCredentialTypes.OAUTH2,
114+
oauth2=OAuth2Auth(
115+
client_id=self._auth_config.client_id,
116+
client_secret=self._auth_config.client_secret,
117+
scopes=scopes
118+
)
119+
)
120+
)
121+
122+
# Check if we already have credentials from a previous exchange
123+
try:
124+
# get_auth_response returns AuthCredential if found
125+
creds = tool_context.get_auth_response(auth_config_adk)
126+
if creds and creds.oauth2 and creds.oauth2.access_token:
127+
reset_token = USER_TOKEN_CONTEXT_VAR.set(creds.oauth2.access_token)
128+
else:
129+
# Request credentials. This will signal the runner to pause.
130+
# We return None (or raise) to stop current execution.
131+
tool_context.request_credential(auth_config_adk)
132+
# Returning None here. The ADK runner will see the requested_auth_configs
133+
# in the tool_context/event and trigger the client event.
134+
return None
135+
except Exception as e:
136+
# If get_auth_response fails drastically or request_credential fails
137+
ctx.error = e
138+
# We might want to request credential if retrieval failed?
139+
# For now let's assume if it fails we can't proceed.
140+
# Actually, strictly we should probably request credential if get_auth_response returns nothing
141+
# but get_auth_response typically handles the lookup.
142+
# If exception is unrelated, raise.
143+
if "credential" in str(e).lower() or isinstance(e, ValueError): # Loose check, but safest is to re-raise
144+
raise e
145+
# Fallback to request logic if it was a lookup error?
146+
tool_context.request_credential(auth_config_adk)
147+
return None
121148

122149
try:
123150
# Execute the core tool

packages/toolbox-adk/tests/unit/test_tool.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,122 @@ async def test_bind_params(self):
127127
# unless we mock properly.
128128
assert new_tool._core_tool == "new_core_tool"
129129
mock_core.bind_params.assert_called_with({"a": 1})
130+
@pytest.mark.asyncio
131+
async def test_3lo_missing_client_secret(self):
132+
# Test ValueError when client_id/secret missing
133+
core_tool = AsyncMock()
134+
auth_config = CredentialConfig(type=CredentialType.USER_IDENTITY)
135+
# Missing client_id/secret
136+
137+
tool = ToolboxTool(core_tool, auth_config=auth_config)
138+
ctx = MagicMock() # Mock the context
139+
140+
with pytest.raises(ValueError, match="USER_IDENTITY requires client_id and client_secret"):
141+
await tool.run_async({"arg": "val"}, ctx)
142+
143+
@pytest.mark.asyncio
144+
async def test_3lo_request_credential_when_missing(self):
145+
# Test that if creds are missing, request_credential is called and returns None
146+
core_tool = AsyncMock()
147+
auth_config = CredentialConfig(
148+
type=CredentialType.USER_IDENTITY,
149+
client_id="cid",
150+
client_secret="csec"
151+
)
152+
153+
tool = ToolboxTool(core_tool, auth_config=auth_config)
154+
155+
ctx = MagicMock()
156+
# Mock get_auth_response returning None (no creds yet)
157+
ctx.get_auth_response.return_value = None
158+
159+
result = await tool.run_async({}, ctx)
160+
161+
# Verify result is None (signal pause)
162+
assert result is None
163+
# Verify request_credential was called
164+
ctx.request_credential.assert_called_once()
165+
# Verify core tool was NOT called
166+
core_tool.assert_not_called()
167+
168+
@pytest.mark.asyncio
169+
async def test_3lo_uses_existing_credential(self):
170+
# Test that if creds exist, they are used and injected
171+
core_tool = AsyncMock(return_value="success")
172+
auth_config = CredentialConfig(
173+
type=CredentialType.USER_IDENTITY,
174+
client_id="cid",
175+
client_secret="csec"
176+
)
177+
178+
tool = ToolboxTool(core_tool, auth_config=auth_config)
179+
180+
ctx = MagicMock()
181+
# Mock get_auth_response returning valid creds
182+
mock_creds = MagicMock()
183+
mock_creds.oauth2.access_token = "valid_token"
184+
ctx.get_auth_response.return_value = mock_creds
185+
186+
result = await tool.run_async({}, ctx)
187+
188+
# Verify result is success
189+
assert result == "success"
190+
# Verify request_credential was NOT called
191+
ctx.request_credential.assert_not_called()
192+
# Verify core tool WAS called
193+
core_tool.assert_called_once()
194+
195+
196+
@pytest.mark.asyncio
197+
async def test_3lo_exception_reraise(self):
198+
# Test that specific credential errors are re-raised
199+
core_tool = AsyncMock()
200+
auth_config = CredentialConfig(
201+
type=CredentialType.USER_IDENTITY,
202+
client_id="cid",
203+
client_secret="csec"
204+
)
205+
tool = ToolboxTool(core_tool, auth_config=auth_config)
206+
ctx = MagicMock()
207+
208+
# Mock get_auth_response raising ValueError
209+
ctx.get_auth_response.side_effect = ValueError("Invalid Credential")
210+
211+
with pytest.raises(ValueError, match="Invalid Credential"):
212+
await tool.run_async({}, ctx)
213+
214+
@pytest.mark.asyncio
215+
async def test_3lo_exception_fallback(self):
216+
# Test that non-credential errors trigger fallback request
217+
core_tool = AsyncMock()
218+
auth_config = CredentialConfig(
219+
type=CredentialType.USER_IDENTITY,
220+
client_id="cid",
221+
client_secret="csec"
222+
)
223+
tool = ToolboxTool(core_tool, auth_config=auth_config)
224+
ctx = MagicMock()
225+
226+
# Mock get_auth_response raising generic error
227+
ctx.get_auth_response.side_effect = RuntimeError("Random failure")
228+
229+
result = await tool.run_async({}, ctx)
230+
231+
# Should catch RuntimeError, call request_credential, and return None
232+
assert result is None
233+
ctx.request_credential.assert_called_once()
234+
235+
def test_init_defaults(self):
236+
# Test initialization with minimal tool metadata
237+
class EmptyTool:
238+
pass
239+
240+
core_tool = EmptyTool()
241+
args = {"core_tool": core_tool}
242+
243+
# Directly instantiate or if strict typing prevents it, force it
244+
# ToolboxTool expects CoreToolboxTool which is a Protocol/Class.
245+
# But at runtime it just checks attributes.
246+
tool = ToolboxTool(core_tool)
247+
assert tool.name == "unknown_tool"
248+
assert tool.description == "No description provided."

0 commit comments

Comments
 (0)