Skip to content

Commit 7bf3bfd

Browse files
MarkDaoustcopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 838331631
1 parent 99058b6 commit 7bf3bfd

File tree

7 files changed

+643
-142
lines changed

7 files changed

+643
-142
lines changed

google/genai/_api_client.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -679,10 +679,29 @@ def __init__(
679679
)
680680
self._http_options.api_version = 'v1beta1'
681681
else: # Implicit initialization or missing arguments.
682-
if not self.api_key:
682+
if env_api_key and api_key:
683+
# Explicit credentials take precedence over implicit api_key.
684+
logger.info(
685+
'The client initialiser api_key argument takes '
686+
'precedence over the API key from the environment variable.'
687+
)
688+
if credentials:
689+
if api_key:
690+
raise ValueError(
691+
'Credentials and API key are mutually exclusive in the client'
692+
' initializer.'
693+
)
694+
elif env_api_key:
695+
logger.info(
696+
'The user `credentials` argument will take precedence over the'
697+
' api key from the environment variables.'
698+
)
699+
self.api_key = None
700+
701+
if not self.api_key and not credentials:
683702
raise ValueError(
684703
'Missing key inputs argument! To use the Google AI API,'
685-
' provide (`api_key`) arguments. To use the Google Cloud API,'
704+
' provide (`api_key` or `credentials`) arguments. To use the Google Cloud API,'
686705
' provide (`vertexai`, `project` & `location`) arguments.'
687706
)
688707
self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
@@ -1163,19 +1182,21 @@ def _request_once(
11631182
) -> HttpResponse:
11641183
data: Optional[Union[str, bytes]] = None
11651184
# If using proj/location, fetch ADC
1166-
if self.vertexai and (self.project or self.location):
1185+
if (
1186+
self.vertexai and (self.project or self.location)
1187+
or not self.vertexai and self._credentials
1188+
):
11671189
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
11681190
if self._credentials and self._credentials.quota_project_id:
11691191
http_request.headers['x-goog-user-project'] = (
11701192
self._credentials.quota_project_id
11711193
)
1172-
data = json.dumps(http_request.data) if http_request.data else None
1173-
else:
1174-
if http_request.data:
1175-
if not isinstance(http_request.data, bytes):
1176-
data = json.dumps(http_request.data) if http_request.data else None
1177-
else:
1178-
data = http_request.data
1194+
1195+
if http_request.data:
1196+
if not isinstance(http_request.data, bytes):
1197+
data = json.dumps(http_request.data) if http_request.data else None
1198+
else:
1199+
data = http_request.data
11791200

11801201
if stream:
11811202
httpx_request = self._httpx_client.build_request(
@@ -1229,21 +1250,23 @@ async def _async_request_once(
12291250
data: Optional[Union[str, bytes]] = None
12301251

12311252
# If using proj/location, fetch ADC
1232-
if self.vertexai and (self.project or self.location):
1253+
if (
1254+
self.vertexai and (self.project or self.location) or
1255+
not self.vertexai and self._credentials
1256+
):
12331257
http_request.headers['Authorization'] = (
12341258
f'Bearer {await self._async_access_token()}'
12351259
)
12361260
if self._credentials and self._credentials.quota_project_id:
12371261
http_request.headers['x-goog-user-project'] = (
12381262
self._credentials.quota_project_id
12391263
)
1240-
data = json.dumps(http_request.data) if http_request.data else None
1241-
else:
1242-
if http_request.data:
1243-
if not isinstance(http_request.data, bytes):
1244-
data = json.dumps(http_request.data) if http_request.data else None
1245-
else:
1246-
data = http_request.data
1264+
1265+
if http_request.data:
1266+
if not isinstance(http_request.data, bytes):
1267+
data = json.dumps(http_request.data) if http_request.data else None
1268+
else:
1269+
data = http_request.data
12471270

12481271
if stream:
12491272
if self._use_aiohttp():

google/genai/_extra_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,20 @@
1616
"""Extra utils depending on types that are shared between sync and async modules."""
1717

1818
import asyncio
19+
from collections.abc import Callable, MutableMapping
1920
import inspect
2021
import io
2122
import logging
2223
import sys
2324
import typing
24-
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
25+
from typing import Any, Optional, Union, get_args, get_origin
2526
import mimetypes
2627
import os
2728
import pydantic
2829

30+
import google.auth.transport.requests
31+
32+
2933
from . import _common
3034
from . import _mcp_utils
3135
from . import _transformers as t
@@ -674,3 +678,18 @@ def prepare_resumable_upload(
674678
http_options.headers = {}
675679
http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file)
676680
return http_options, size_bytes, mime_type
681+
682+
683+
async def _maybe_update_and_insert_auth_token(
684+
headers:MutableMapping[str, str],
685+
creds: google.auth.credentials.Credentials) -> None:
686+
# Refresh credentials to ensure token is valid
687+
if not (creds.token and creds.valid):
688+
try:
689+
auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call]
690+
await asyncio.to_thread(creds.refresh, auth_req)
691+
except Exception as e:
692+
raise ConnectionError(f"Failed to refresh credentials") from e
693+
694+
if not headers.get('Authorization'):
695+
headers['Authorization'] = f'Bearer {creds.token}'

0 commit comments

Comments
 (0)