Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,10 +679,29 @@ def __init__(
)
self._http_options.api_version = 'v1beta1'
else: # Implicit initialization or missing arguments.
if not self.api_key:
if env_api_key and api_key:
# Explicit credentials take precedence over implicit api_key.
logger.info(
'The client initialiser api_key argument takes '
'precedence over the API key from the environment variable.'
)
if credentials:
if api_key:
raise ValueError(
'Credentials and API key are mutually exclusive in the client'
' initializer.'
)
elif env_api_key:
logger.info(
'The user `credentials` argument will take precedence over the'
' api key from the environment variables.'
)
self.api_key = None

if not self.api_key and not credentials:
raise ValueError(
'Missing key inputs argument! To use the Google AI API,'
' provide (`api_key`) arguments. To use the Google Cloud API,'
' provide (`api_key` or `credentials`) arguments. To use the Google Cloud API,'
' provide (`vertexai`, `project` & `location`) arguments.'
)
self._http_options.base_url = 'https://generativelanguage.googleapis.com/'
Expand Down Expand Up @@ -1162,20 +1181,21 @@ def _request_once(
stream: bool = False,
) -> HttpResponse:
data: Optional[Union[str, bytes]] = None
# If using proj/location, fetch ADC
if self.vertexai and (self.project or self.location):

uses_vertex_creds = self.vertexai and (self.project or self.location)
uses_mldev_creds = not self.vertexai and self._credentials
if (uses_vertex_creds or uses_mldev_creds):
http_request.headers['Authorization'] = f'Bearer {self._access_token()}'
if self._credentials and self._credentials.quota_project_id:
http_request.headers['x-goog-user-project'] = (
self._credentials.quota_project_id
)
data = json.dumps(http_request.data) if http_request.data else None
else:
if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data) if http_request.data else None
else:
data = http_request.data

if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data) if http_request.data else None
else:
data = http_request.data

if stream:
httpx_request = self._httpx_client.build_request(
Expand Down Expand Up @@ -1228,22 +1248,22 @@ async def _async_request_once(
) -> HttpResponse:
data: Optional[Union[str, bytes]] = None

# If using proj/location, fetch ADC
if self.vertexai and (self.project or self.location):
uses_vertex_creds = self.vertexai and (self.project or self.location)
uses_mldev_creds = not self.vertexai and self._credentials
if (uses_vertex_creds or uses_mldev_creds):
http_request.headers['Authorization'] = (
f'Bearer {await self._async_access_token()}'
)
if self._credentials and self._credentials.quota_project_id:
http_request.headers['x-goog-user-project'] = (
self._credentials.quota_project_id
)
data = json.dumps(http_request.data) if http_request.data else None
else:
if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data) if http_request.data else None
else:
data = http_request.data

if http_request.data:
if not isinstance(http_request.data, bytes):
data = json.dumps(http_request.data) if http_request.data else None
else:
data = http_request.data

if stream:
if self._use_aiohttp():
Expand Down
21 changes: 20 additions & 1 deletion google/genai/_extra_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,20 @@
"""Extra utils depending on types that are shared between sync and async modules."""

import asyncio
from collections.abc import Callable, MutableMapping
import inspect
import io
import logging
import sys
import typing
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
from typing import Any, Optional, Union, get_args, get_origin
import mimetypes
import os
import pydantic

import google.auth.transport.requests


from . import _common
from . import _mcp_utils
from . import _transformers as t
Expand Down Expand Up @@ -674,3 +678,18 @@ def prepare_resumable_upload(
http_options.headers = {}
http_options.headers['X-Goog-Upload-File-Name'] = os.path.basename(file)
return http_options, size_bytes, mime_type


async def _maybe_update_and_insert_auth_token(
headers:MutableMapping[str, str],
creds: google.auth.credentials.Credentials) -> None:
# Refresh credentials to ensure token is valid
if not (creds.token and creds.valid):
try:
auth_req = google.auth.transport.requests.Request() # type: ignore[no-untyped-call]
await asyncio.to_thread(creds.refresh, auth_req)
except Exception as e:
raise ConnectionError(f"Failed to refresh credentials") from e

if not headers.get('Authorization'):
headers['Authorization'] = f'Bearer {creds.token}'
Loading