diff --git a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py index 9c265c72e..7feed7552 100644 --- a/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py +++ b/python/packages/kagent-adk/src/kagent/adk/_agent_executor.py @@ -33,6 +33,7 @@ set_kagent_span_attributes, ) +from ._context import clear_user_id, set_user_id from .converters.event_converter import convert_event_to_a2a_events from .converters.request_converter import convert_a2a_request_to_adk_run_args @@ -235,14 +236,23 @@ async def _handle_request( ) ) - task_result_aggregator = TaskResultAggregator() - async with Aclosing(runner.run_async(**run_args)) as agen: - async for adk_event in agen: - for a2a_event in convert_event_to_a2a_events( - adk_event, invocation_context, context.task_id, context.context_id - ): - task_result_aggregator.process_event(a2a_event) - await event_queue.enqueue_event(a2a_event) + # Set user_id in context for passthrough to LLM providers + user_id = run_args.get("user_id") + if user_id: + set_user_id(user_id) + + try: + task_result_aggregator = TaskResultAggregator() + async with Aclosing(runner.run_async(**run_args)) as agen: + async for adk_event in agen: + for a2a_event in convert_event_to_a2a_events( + adk_event, invocation_context, context.task_id, context.context_id + ): + task_result_aggregator.process_event(a2a_event) + await event_queue.enqueue_event(a2a_event) + finally: + # Clear user_id from context to avoid leaking between requests + clear_user_id() # publish the task result event - this is final if ( diff --git a/python/packages/kagent-adk/src/kagent/adk/_context.py b/python/packages/kagent-adk/src/kagent/adk/_context.py new file mode 100644 index 000000000..850a63771 --- /dev/null +++ b/python/packages/kagent-adk/src/kagent/adk/_context.py @@ -0,0 +1,29 @@ +"""Context variable module for passing user_id through async call chains.""" + +from contextvars import ContextVar +from typing import Optional + +_user_id_context: ContextVar[Optional[str]] = ContextVar("user_id", default=None) + + +def set_user_id(user_id: str) -> None: + """Set the user_id in the current async context. + + Args: + user_id: The user identifier to store in context. + """ + _user_id_context.set(user_id) + + +def get_user_id() -> Optional[str]: + """Get the user_id from the current async context. + + Returns: + The user_id if set, None otherwise. + """ + return _user_id_context.get() + + +def clear_user_id() -> None: + """Clear the user_id from the current async context.""" + _user_id_context.set(None) diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_openai.py b/python/packages/kagent-adk/src/kagent/adk/models/_openai.py index 30fda6e9b..84e84a1cb 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_openai.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_openai.py @@ -32,6 +32,7 @@ from openai.types.shared_params import FunctionDefinition, FunctionParameters from pydantic import Field +from .._context import get_user_id from ._ssl import create_ssl_context if TYPE_CHECKING: @@ -394,6 +395,11 @@ async def generate_content_async( "messages": messages, } + # Add user parameter for usage tracking if user_id is available in context + user_id = get_user_id() + if user_id and user_id.strip(): # Only add if non-empty and non-whitespace + kwargs["user"] = user_id + if self.frequency_penalty is not None: kwargs["frequency_penalty"] = self.frequency_penalty if self.max_tokens: diff --git a/python/packages/kagent-adk/tests/unittests/models/test_openai.py b/python/packages/kagent-adk/tests/unittests/models/test_openai.py index 67683b80d..9b700a137 100644 --- a/python/packages/kagent-adk/tests/unittests/models/test_openai.py +++ b/python/packages/kagent-adk/tests/unittests/models/test_openai.py @@ -21,6 +21,7 @@ from google.genai.types import Content, Part from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam +from kagent.adk._context import clear_user_id, set_user_id from kagent.adk.models import OpenAI from kagent.adk.models._openai import _convert_tools_to_openai @@ -338,6 +339,178 @@ async def mock_coro(*args, **kwargs): assert kwargs["max_tokens"] == 4096 +@pytest.mark.asyncio +async def test_generate_content_async_with_user_passthrough(llm_request, generate_content_response): + """Test that user_id from context is passed as 'user' parameter to OpenAI API.""" + openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake") + + # Set user_id in context + test_user_id = "user@example.com" + set_user_id(test_user_id) + + try: + with mock.patch.object(openai_llm, "_client") as mock_client: + # Create a mock coroutine that returns the generate_content_response. + async def mock_coro(*args, **kwargs): + return generate_content_response + + # Assign the coroutine to the mocked method + mock_client.chat.completions.create.return_value = mock_coro() + + _ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)] + mock_client.chat.completions.create.assert_called_once() + _, kwargs = mock_client.chat.completions.create.call_args + assert kwargs["user"] == test_user_id + finally: + clear_user_id() + + +@pytest.mark.asyncio +async def test_generate_content_async_without_user_passthrough(llm_request, generate_content_response): + """Test that 'user' parameter is not included when user_id is not in context.""" + openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake") + + # Ensure user_id is not in context + clear_user_id() + + with mock.patch.object(openai_llm, "_client") as mock_client: + # Create a mock coroutine that returns the generate_content_response. + async def mock_coro(*args, **kwargs): + return generate_content_response + + # Assign the coroutine to the mocked method + mock_client.chat.completions.create.return_value = mock_coro() + + _ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)] + mock_client.chat.completions.create.assert_called_once() + _, kwargs = mock_client.chat.completions.create.call_args + assert "user" not in kwargs + + +@pytest.mark.asyncio +async def test_generate_content_async_streaming_with_user_passthrough(llm_request): + """Test that user_id from context is passed as 'user' parameter in streaming mode.""" + openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake") + + # Set user_id in context + test_user_id = "user@example.com" + set_user_id(test_user_id) + + try: + with mock.patch.object(openai_llm, "_client") as mock_client: + # Create a mock async generator for streaming + async def mock_stream(*args, **kwargs): + class MockChunk: + def __init__(self): + class MockDelta: + def __init__(self): + self.content = "Hello" + + class MockChoice: + def __init__(self): + self.delta = MockDelta() + self.finish_reason = "stop" + + self.choices = [MockChoice()] + + yield MockChunk() + + # Assign the async generator to the mocked method + mock_client.chat.completions.create.return_value = mock_stream() + + _ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=True)] + mock_client.chat.completions.create.assert_called_once() + _, kwargs = mock_client.chat.completions.create.call_args + assert kwargs["user"] == test_user_id + assert kwargs["stream"] is True + finally: + clear_user_id() + + +@pytest.mark.asyncio +async def test_azure_openai_with_user_passthrough(llm_request, generate_content_response): + """Test that user_id from context is passed as 'user' parameter to Azure OpenAI API.""" + from kagent.adk.models import AzureOpenAI + + azure_llm = AzureOpenAI( + model="gpt-35-turbo", + type="azure_openai", + api_key="fake", + azure_endpoint="https://test.openai.azure.com", + api_version="2024-02-15-preview", + ) + + # Set user_id in context + test_user_id = "user@example.com" + set_user_id(test_user_id) + + try: + with mock.patch.object(azure_llm, "_client") as mock_client: + # Create a mock coroutine that returns the generate_content_response. + async def mock_coro(*args, **kwargs): + return generate_content_response + + # Assign the coroutine to the mocked method + mock_client.chat.completions.create.return_value = mock_coro() + + _ = [resp async for resp in azure_llm.generate_content_async(llm_request, stream=False)] + mock_client.chat.completions.create.assert_called_once() + _, kwargs = mock_client.chat.completions.create.call_args + assert kwargs["user"] == test_user_id + finally: + clear_user_id() + + +@pytest.mark.asyncio +async def test_generate_content_async_with_empty_user_id(llm_request, generate_content_response): + """Test that empty string user_id is not passed to OpenAI API.""" + openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake") + + # Set empty user_id in context + set_user_id("") + + try: + with mock.patch.object(openai_llm, "_client") as mock_client: + + async def mock_coro(*args, **kwargs): + return generate_content_response + + mock_client.chat.completions.create.return_value = mock_coro() + + _ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)] + mock_client.chat.completions.create.assert_called_once() + _, kwargs = mock_client.chat.completions.create.call_args + # Empty string should not be included (checked with .strip()) + assert "user" not in kwargs + finally: + clear_user_id() + + +@pytest.mark.asyncio +async def test_generate_content_async_with_whitespace_only_user_id(llm_request, generate_content_response): + """Test that whitespace-only user_id is not passed to OpenAI API.""" + openai_llm = OpenAI(model="gpt-3.5-turbo", type="openai", api_key="fake") + + # Set whitespace-only user_id in context + set_user_id(" ") + + try: + with mock.patch.object(openai_llm, "_client") as mock_client: + + async def mock_coro(*args, **kwargs): + return generate_content_response + + mock_client.chat.completions.create.return_value = mock_coro() + + _ = [resp async for resp in openai_llm.generate_content_async(llm_request, stream=False)] + mock_client.chat.completions.create.assert_called_once() + _, kwargs = mock_client.chat.completions.create.call_args + # Whitespace-only string should not be included + assert "user" not in kwargs + finally: + clear_user_id() + + # ============================================================================ # SSL/TLS Configuration Tests # ============================================================================ diff --git a/python/packages/kagent-adk/tests/unittests/test_context.py b/python/packages/kagent-adk/tests/unittests/test_context.py new file mode 100644 index 000000000..765b975bb --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/test_context.py @@ -0,0 +1,71 @@ +"""Unit tests for context variable module.""" + +import pytest + +from kagent.adk._context import clear_user_id, get_user_id, set_user_id + + +def test_get_user_id_default(): + """Test that get_user_id returns None by default.""" + # Ensure context is cleared + clear_user_id() + assert get_user_id() is None + + +def test_set_and_get_user_id(): + """Test setting and getting user_id.""" + test_user_id = "user@example.com" + set_user_id(test_user_id) + try: + assert get_user_id() == test_user_id + finally: + clear_user_id() + + +def test_clear_user_id(): + """Test clearing user_id from context.""" + test_user_id = "user@example.com" + set_user_id(test_user_id) + assert get_user_id() == test_user_id + + clear_user_id() + assert get_user_id() is None + + +def test_set_user_id_overwrites_previous(): + """Test that setting user_id overwrites the previous value.""" + set_user_id("user1@example.com") + try: + assert get_user_id() == "user1@example.com" + + set_user_id("user2@example.com") + assert get_user_id() == "user2@example.com" + finally: + clear_user_id() + + +@pytest.mark.asyncio +async def test_context_isolation(): + """Test that context variables are isolated per async task.""" + import asyncio + + async def task1(): + set_user_id("user1@example.com") + await asyncio.sleep(0.01) + return get_user_id() + + async def task2(): + set_user_id("user2@example.com") + await asyncio.sleep(0.01) + return get_user_id() + + # Run tasks concurrently + results = await asyncio.gather(task1(), task2()) + + # Each task should see its own user_id + assert results[0] == "user1@example.com" + assert results[1] == "user2@example.com" + + # Main context should not have user_id set + clear_user_id() + assert get_user_id() is None