Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions python/packages/kagent-adk/src/kagent/adk/_agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
set_kagent_span_attributes,
)

from ._context import clear_user_id, set_user_id
from .converters.event_converter import convert_event_to_a2a_events
from .converters.request_converter import convert_a2a_request_to_adk_run_args

Expand Down Expand Up @@ -235,14 +236,23 @@ async def _handle_request(
)
)

task_result_aggregator = TaskResultAggregator()
async with Aclosing(runner.run_async(**run_args)) as agen:
async for adk_event in agen:
for a2a_event in convert_event_to_a2a_events(
adk_event, invocation_context, context.task_id, context.context_id
):
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)
# Set user_id in context for passthrough to LLM providers
user_id = run_args.get("user_id")
if user_id:
set_user_id(user_id)

try:
task_result_aggregator = TaskResultAggregator()
async with Aclosing(runner.run_async(**run_args)) as agen:
async for adk_event in agen:
for a2a_event in convert_event_to_a2a_events(
adk_event, invocation_context, context.task_id, context.context_id
):
task_result_aggregator.process_event(a2a_event)
await event_queue.enqueue_event(a2a_event)
finally:
# Clear user_id from context to avoid leaking between requests
clear_user_id()

# publish the task result event - this is final
if (
Expand Down
29 changes: 29 additions & 0 deletions python/packages/kagent-adk/src/kagent/adk/_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Context variable module for passing user_id through async call chains."""

from contextvars import ContextVar
from typing import Optional

_user_id_context: ContextVar[Optional[str]] = ContextVar("user_id", default=None)


def set_user_id(user_id: str) -> None:
"""Set the user_id in the current async context.

Args:
user_id: The user identifier to store in context.
"""
_user_id_context.set(user_id)


def get_user_id() -> Optional[str]:
"""Get the user_id from the current async context.

Returns:
The user_id if set, None otherwise.
"""
return _user_id_context.get()


def clear_user_id() -> None:
"""Clear the user_id from the current async context."""
_user_id_context.set(None)
6 changes: 6 additions & 0 deletions python/packages/kagent-adk/src/kagent/adk/models/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from openai.types.shared_params import FunctionDefinition, FunctionParameters
from pydantic import Field

from .._context import get_user_id
from ._ssl import create_ssl_context

if TYPE_CHECKING:
Expand Down Expand Up @@ -394,6 +395,11 @@ async def generate_content_async(
"messages": messages,
}

# Add user parameter for usage tracking if user_id is available in context
user_id = get_user_id()
if user_id and user_id.strip(): # Only add if non-empty and non-whitespace
kwargs["user"] = user_id

if self.frequency_penalty is not None:
kwargs["frequency_penalty"] = self.frequency_penalty
if self.max_tokens:
Expand Down
173 changes: 173 additions & 0 deletions python/packages/kagent-adk/tests/unittests/models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from google.genai.types import Content, Part
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam

from kagent.adk._context import clear_user_id, set_user_id
from kagent.adk.models import OpenAI
from kagent.adk.models._openai import _convert_tools_to_openai

Expand Down Expand Up @@ -338,6 +339,178 @@ async def mock_coro(*args, **kwargs):
assert kwargs["max_tokens"] == 4096


@pytest.mark.asyncio
async def test_generate_content_async_with_user_passthrough(llm_request, generate_content_response):
"""Test that user_id from context is passed as 'user' parameter to OpenAI API."""
openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake")

# Set user_id in context
test_user_id = "[email protected]"
set_user_id(test_user_id)

try:
with mock.patch.object(openai_llm, "_client") as mock_client:
# Create a mock coroutine that returns the generate_content_response.
async def mock_coro(*args, **kwargs):
return generate_content_response

# Assign the coroutine to the mocked method
mock_client.chat.completions.create.return_value = mock_coro()

_ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)]
mock_client.chat.completions.create.assert_called_once()
_, kwargs = mock_client.chat.completions.create.call_args
assert kwargs["user"] == test_user_id
finally:
clear_user_id()


@pytest.mark.asyncio
async def test_generate_content_async_without_user_passthrough(llm_request, generate_content_response):
"""Test that 'user' parameter is not included when user_id is not in context."""
openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake")

# Ensure user_id is not in context
clear_user_id()

with mock.patch.object(openai_llm, "_client") as mock_client:
# Create a mock coroutine that returns the generate_content_response.
async def mock_coro(*args, **kwargs):
return generate_content_response

# Assign the coroutine to the mocked method
mock_client.chat.completions.create.return_value = mock_coro()

_ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)]
mock_client.chat.completions.create.assert_called_once()
_, kwargs = mock_client.chat.completions.create.call_args
assert "user" not in kwargs


@pytest.mark.asyncio
async def test_generate_content_async_streaming_with_user_passthrough(llm_request):
"""Test that user_id from context is passed as 'user' parameter in streaming mode."""
openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake")

# Set user_id in context
test_user_id = "[email protected]"
set_user_id(test_user_id)

try:
with mock.patch.object(openai_llm, "_client") as mock_client:
# Create a mock async generator for streaming
async def mock_stream(*args, **kwargs):
class MockChunk:
def __init__(self):
class MockDelta:
def __init__(self):
self.content = "Hello"

class MockChoice:
def __init__(self):
self.delta = MockDelta()
self.finish_reason = "stop"

self.choices = [MockChoice()]

yield MockChunk()

# Assign the async generator to the mocked method
mock_client.chat.completions.create.return_value = mock_stream()

_ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=True)]
mock_client.chat.completions.create.assert_called_once()
_, kwargs = mock_client.chat.completions.create.call_args
assert kwargs["user"] == test_user_id
assert kwargs["stream"] is True
finally:
clear_user_id()


@pytest.mark.asyncio
async def test_azure_openai_with_user_passthrough(llm_request, generate_content_response):
"""Test that user_id from context is passed as 'user' parameter to Azure OpenAI API."""
from kagent.adk.models import AzureOpenAI

azure_llm = AzureOpenAI(
model="gpt-35-turbo",
type="azure_openai",
api_key="fake",
azure_endpoint="https://test.openai.azure.com",
api_version="2024-02-15-preview",
)

# Set user_id in context
test_user_id = "[email protected]"
set_user_id(test_user_id)

try:
with mock.patch.object(azure_llm, "_client") as mock_client:
# Create a mock coroutine that returns the generate_content_response.
async def mock_coro(*args, **kwargs):
return generate_content_response

# Assign the coroutine to the mocked method
mock_client.chat.completions.create.return_value = mock_coro()

_ = [resp async for resp in azure_llm.generate_content_async(llm_request, stream=False)]
mock_client.chat.completions.create.assert_called_once()
_, kwargs = mock_client.chat.completions.create.call_args
assert kwargs["user"] == test_user_id
finally:
clear_user_id()


@pytest.mark.asyncio
async def test_generate_content_async_with_empty_user_id(llm_request, generate_content_response):
"""Test that empty string user_id is not passed to OpenAI API."""
openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake")

# Set empty user_id in context
set_user_id("")

try:
with mock.patch.object(openai_llm, "_client") as mock_client:

async def mock_coro(*args, **kwargs):
return generate_content_response

mock_client.chat.completions.create.return_value = mock_coro()

_ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)]
mock_client.chat.completions.create.assert_called_once()
_, kwargs = mock_client.chat.completions.create.call_args
# Empty string should not be included (checked with .strip())
assert "user" not in kwargs
finally:
clear_user_id()


@pytest.mark.asyncio
async def test_generate_content_async_with_whitespace_only_user_id(llm_request, generate_content_response):
"""Test that whitespace-only user_id is not passed to OpenAI API."""
openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake")

# Set whitespace-only user_id in context
set_user_id(" ")

try:
with mock.patch.object(openai_llm, "_client") as mock_client:

async def mock_coro(*args, **kwargs):
return generate_content_response

mock_client.chat.completions.create.return_value = mock_coro()

_ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)]
mock_client.chat.completions.create.assert_called_once()
_, kwargs = mock_client.chat.completions.create.call_args
# Whitespace-only string should not be included
assert "user" not in kwargs
finally:
clear_user_id()


# ============================================================================
# SSL/TLS Configuration Tests
# ============================================================================
Expand Down
71 changes: 71 additions & 0 deletions python/packages/kagent-adk/tests/unittests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Unit tests for context variable module."""

import pytest

from kagent.adk._context import clear_user_id, get_user_id, set_user_id


def test_get_user_id_default():
"""Test that get_user_id returns None by default."""
# Ensure context is cleared
clear_user_id()
assert get_user_id() is None


def test_set_and_get_user_id():
"""Test setting and getting user_id."""
test_user_id = "[email protected]"
set_user_id(test_user_id)
try:
assert get_user_id() == test_user_id
finally:
clear_user_id()


def test_clear_user_id():
"""Test clearing user_id from context."""
test_user_id = "[email protected]"
set_user_id(test_user_id)
assert get_user_id() == test_user_id

clear_user_id()
assert get_user_id() is None


def test_set_user_id_overwrites_previous():
"""Test that setting user_id overwrites the previous value."""
set_user_id("[email protected]")
try:
assert get_user_id() == "[email protected]"

set_user_id("[email protected]")
assert get_user_id() == "[email protected]"
finally:
clear_user_id()


@pytest.mark.asyncio
async def test_context_isolation():
"""Test that context variables are isolated per async task."""
import asyncio

async def task1():
set_user_id("[email protected]")
await asyncio.sleep(0.01)
return get_user_id()

async def task2():
set_user_id("[email protected]")
await asyncio.sleep(0.01)
return get_user_id()

# Run tasks concurrently
results = await asyncio.gather(task1(), task2())

# Each task should see its own user_id
assert results[0] == "[email protected]"
assert results[1] == "[email protected]"

# Main context should not have user_id set
clear_user_id()
assert get_user_id() is None