Merged
Changes from 7 commits
@@ -1,9 +1,11 @@
import json
import pydantic
import re
+import threading
import time

from openai import AsyncStream, Stream
+from wrapt import ObjectProxy

# Conditional imports for backward compatibility
try:
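The conditional-import block opened by the clipped `try:` above is the usual guard against older dependency versions. A minimal sketch of the pattern, assuming the incubating gen-ai attributes module (the exact import path is an assumption, not shown in this diff):

    # Sketch only; the module path is assumed for illustration.
    try:
        from opentelemetry.semconv._incubating.attributes import (
            gen_ai_attributes as GenAIAttributes,
        )
    except ImportError:
        # Older opentelemetry-semantic-conventions releases may not ship this
        # module; falling back to None makes later hasattr() checks degrade
        # gracefully instead of crashing at import time.
        GenAIAttributes = None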
@@ -190,15 +192,20 @@ def set_data_attributes(traced_response: TracedData, span: Span):
span, SpanAttributes.LLM_USAGE_TOTAL_TOKENS, usage.total_tokens
)
if usage.input_tokens_details:
-            _set_span_attribute(
-                span,
-                GenAIAttributes.GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS,
-                usage.input_tokens_details.cached_tokens,
-            )
+            if hasattr(GenAIAttributes, 'GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS'):
+                _set_span_attribute(
+                    span,
+                    GenAIAttributes.GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS,
+                    usage.input_tokens_details.cached_tokens,
+                )
+            elif hasattr(GenAIAttributes, 'GEN_AI_USAGE_INPUT_TOKENS_CACHED'):
+                _set_span_attribute(
+                    span,
+                    GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS_CACHED,
+                    usage.input_tokens_details.cached_tokens,
+                )

Review comment (Member), on the new hasattr check: why do you need it?
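One plausible answer to the question above: this attribute constant has gone by more than one name across opentelemetry-semantic-conventions releases, so the code feature-detects whichever spelling the installed version defines and silently skips the attribute otherwise. A hedged sketch of the same idea factored into a helper (`_resolve_attr` is hypothetical, not part of this PR):

    def _resolve_attr(namespace, *candidates):
        # Return the first attribute constant the installed semconv defines.
        for name in candidates:
            if hasattr(namespace, name):
                return getattr(namespace, name)
        return None

    cached_key = _resolve_attr(
        GenAIAttributes,
        "GEN_AI_USAGE_CACHE_READ_INPUT_TOKENS",
        "GEN_AI_USAGE_INPUT_TOKENS_CACHED",
    )
    if cached_key is not None:
        _set_span_attribute(span, cached_key, usage.input_tokens_details.cached_tokens)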

# Usage - count of reasoning tokens
reasoning_tokens = None
# Support both dict-style and object-style `usage`
tokens_details = (
usage.get("output_tokens_details") if isinstance(usage, dict)
else getattr(usage, "output_tokens_details", None)
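The dict-or-object access above (clipped by the diff) generalizes to a small helper; a sketch (`_get_field` is hypothetical, not in this PR):

    def _get_field(obj, name, default=None):
        # Works for both dict payloads and pydantic/object-style ones.
        if isinstance(obj, dict):
            return obj.get(name, default)
        return getattr(obj, name, default)

    tokens_details = _get_field(usage, "output_tokens_details")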
@@ -433,7 +440,19 @@ def responses_get_or_create_wrapper(tracer: Tracer, wrapped, instance, args, kwargs):
try:
response = wrapped(*args, **kwargs)
if isinstance(response, Stream):
-            return response
+            span = tracer.start_span(
+                SPAN_NAME,
+                kind=SpanKind.CLIENT,
+                start_time=start_time,
+            )
+
+            return ResponseStream(
+                span=span,
+                response=response,
+                start_time=start_time,
+                request_kwargs=kwargs,
+                tracer=tracer,
+            )
except Exception as e:
response_id = kwargs.get("response_id")
existing_data = {}
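Note the span lifecycle in the streaming branch above: `tracer.start_span` (unlike `start_as_current_span`) neither activates nor auto-ends the span, so it can legally outlive this wrapper and be ended later by the proxy once the stream is drained. A minimal sketch of that hand-off, assuming a configured tracer and a hypothetical `consume` callback that takes ownership:

    from opentelemetry import trace
    from opentelemetry.trace import SpanKind, StatusCode

    tracer = trace.get_tracer(__name__)

    def start_detached_span(consume):
        # Caller-owned span: nothing ends it until consume() decides to.
        span = tracer.start_span("openai.response", kind=SpanKind.CLIENT)
        try:
            return consume(span)   # e.g. hand it to a stream wrapper
        except Exception as exc:
            # If the hand-off itself fails, close the span here.
            span.record_exception(exc)
            span.set_status(StatusCode.ERROR, str(exc))
            span.end()
            raise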
@@ -563,7 +582,19 @@ async def async_responses_get_or_create_wrapper(
try:
response = await wrapped(*args, **kwargs)
if isinstance(response, (Stream, AsyncStream)):
-            return response
+            span = tracer.start_span(
+                SPAN_NAME,
+                kind=SpanKind.CLIENT,
+                start_time=start_time,
+            )
+
+            return ResponseStream(
+                span=span,
+                response=response,
+                start_time=start_time,
+                request_kwargs=kwargs,
+                tracer=tracer,
+            )
except Exception as e:
response_id = kwargs.get("response_id")
existing_data = {}
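From the caller's perspective the async path is a drop-in replacement; events flow through the proxy untouched. A hedged usage sketch (assumes an AsyncOpenAI client; the model name and loop body are illustrative):

    async def demo(client):
        stream = await client.responses.create(
            model="gpt-4o", input="hello", stream=True
        )
        # `stream` is really a ResponseStream proxying the SDK's AsyncStream.
        async for event in stream:
            ...
        # Exhaustion raises StopAsyncIteration, which triggers
        # _process_complete_response() and ends the span.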
@@ -728,4 +759,188 @@ async def async_responses_cancel_wrapper(
return response


-# TODO: build streaming responses
class ResponseStream(ObjectProxy):
"""Proxy class for streaming responses to capture telemetry data"""

_span = None
_start_time = None
_request_kwargs = None
_tracer = None
_traced_data = None

def __init__(
self,
span,
response,
start_time=None,
request_kwargs=None,
tracer=None,
traced_data=None,
):
super().__init__(response)
self._span = span
self._start_time = start_time
self._request_kwargs = request_kwargs or {}
self._tracer = tracer
self._traced_data = traced_data or TracedData(
start_time=start_time,
response_id="",
input=process_input(self._request_kwargs.get("input", [])),
instructions=self._request_kwargs.get("instructions"),
tools=get_tools_from_kwargs(self._request_kwargs),
output_blocks={},
usage=None,
output_text="",
request_model=self._request_kwargs.get("model", ""),
response_model="",
request_reasoning_summary=self._request_kwargs.get("reasoning", {}).get(
"summary"
),
request_reasoning_effort=self._request_kwargs.get("reasoning", {}).get("effort"),
response_reasoning_effort=None,
)

self._complete_response_data = None
self._output_text = ""

self._cleanup_completed = False
self._cleanup_lock = threading.Lock()
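A wrapt subtlety worth knowing here: `ObjectProxy.__setattr__` forwards unknown attributes to the wrapped object, and keeps an attribute on the proxy itself only when its name starts with `_self_` or is already declared on the proxy class, which is why `_span`, `_traced_data`, and friends are declared as class attributes above. (By that rule, `_output_text` and `_cleanup_lock`, which are not declared on the class, would end up stored on the wrapped stream object; that still works, just via a different mechanism.) A minimal sketch of the three behaviors, independent of this PR:

    import types
    from wrapt import ObjectProxy

    class Wrapper(ObjectProxy):
        _hits = None  # class-level declaration keeps this on the proxy

        def __init__(self, wrapped):
            super().__init__(wrapped)
            self._hits = 0            # stays on the proxy (declared above)
            self._self_note = "ok"    # _self_ prefix also stays on the proxy
            self.tag = "forwarded"    # anything else lands on `wrapped`

    w = Wrapper(types.SimpleNamespace())
    assert w._hits == 0 and w._self_note == "ok"
    assert w.__wrapped__.tag == "forwarded"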

def __del__(self):
"""Cleanup when object is garbage collected"""
if hasattr(self, "_cleanup_completed") and not self._cleanup_completed:
self._ensure_cleanup()

def __enter__(self):
"""Context manager entry"""
return self

def __exit__(self, exc_type, exc_val, exc_tb):
"""Context manager exit"""
if exc_type is not None:
self._handle_exception(exc_val)
else:
self._process_complete_response()
return False
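Because `__exit__` returns False, an in-flight exception still propagates to the caller after the span is finalized. A hedged usage sketch (constructor wiring abbreviated; `handle` is a hypothetical per-event callback):

    with ResponseStream(span=span, response=raw_stream, tracer=tracer) as stream:
        for event in stream:
            handle(event)
    # Normal exit: the span is finished via _process_complete_response().
    # Error exit: the span is marked ERROR, then the exception re-raises.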

def __iter__(self):
"""Synchronous iterator"""
return self

def __next__(self):
"""Synchronous iteration"""
try:
chunk = self.__wrapped__.__next__()
except StopIteration:
self._process_complete_response()
raise
except Exception as e:
self._handle_exception(e)
raise
else:
self._process_chunk(chunk)
return chunk

def __aiter__(self):
"""Async iterator"""
return self

async def __anext__(self):
"""Async iteration"""
try:
chunk = await self.__wrapped__.__anext__()
except StopAsyncIteration:
self._process_complete_response()
raise
except Exception as e:
self._handle_exception(e)
raise
else:
self._process_chunk(chunk)
return chunk
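The sync and async paths share one shape: delegate to the wrapped iterator, finalize once on exhaustion, fail the span on a producer error, and record the chunk only in the `else` arm so a failed fetch is never half-processed. Stripped to a skeleton (a sketch, not PR code):

    class FinalizingIterator:
        def __init__(self, inner, on_done, on_error):
            self._inner = inner
            self._on_done = on_done      # runs once, at normal exhaustion
            self._on_error = on_error    # runs if the producer raises

        def __iter__(self):
            return self

        def __next__(self):
            try:
                item = next(self._inner)
            except StopIteration:
                self._on_done()
                raise
            except Exception as exc:
                self._on_error(exc)
                raise
            else:
                return item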

    def _process_chunk(self, chunk):
        """Process a streaming chunk"""
        if hasattr(chunk, "type"):
            if chunk.type == "response.output_text.delta":
                if hasattr(chunk, "delta") and chunk.delta:
                    self._output_text += chunk.delta
            elif chunk.type == "response.completed" and hasattr(chunk, "response"):
                self._complete_response_data = chunk.response

        if hasattr(chunk, "delta"):
            if hasattr(chunk.delta, "text") and chunk.delta.text:
                self._output_text += chunk.delta.text

        if hasattr(chunk, "response") and chunk.response:
            self._complete_response_data = chunk.response

Review comment (Contributor), on the delta accumulation: consider accumulating streaming text into a list and then joining the list once (e.g. using ''.join(list)) rather than performing repeated string concatenation; this can improve performance when processing many small chunks.
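A hedged sketch of the reviewer's suggestion: buffer fragments in a list and join once at finalization, replacing repeated concatenation (quadratic in the worst case) with a single linear join:

    class TextAccumulator:
        # Sketch of the suggested change, not what the PR currently does.
        def __init__(self):
            self._parts = []

        def add(self, fragment):
            self._parts.append(fragment)    # O(1) per chunk

        def value(self):
            return "".join(self._parts)     # one O(n) join at the end

Under that change, `_process_chunk` would `add()` each delta and `_process_complete_response` would call `value()` when filling `output_text`.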

@dont_throw
def _process_complete_response(self):
"""Process the complete response and emit span"""
with self._cleanup_lock:
if self._cleanup_completed:
return

try:
if self._complete_response_data:
parsed_response = parse_response(self._complete_response_data)

self._traced_data.response_id = parsed_response.id
self._traced_data.response_model = parsed_response.model
self._traced_data.output_text = self._output_text

if parsed_response.usage:
self._traced_data.usage = parsed_response.usage

if parsed_response.output:
self._traced_data.output_blocks = {
block.id: block for block in parsed_response.output
}

responses[parsed_response.id] = self._traced_data

set_data_attributes(self._traced_data, self._span)
self._span.set_status(StatusCode.OK)
self._span.end()
self._cleanup_completed = True

except Exception as e:
if self._span and self._span.is_recording():
self._span.set_attribute(ERROR_TYPE, e.__class__.__name__)
self._span.set_status(StatusCode.ERROR, str(e))
self._span.end()
self._cleanup_completed = True

@dont_throw
def _handle_exception(self, exception):
"""Handle exceptions during streaming"""
with self._cleanup_lock:
if self._cleanup_completed:
return

if self._span and self._span.is_recording():
self._span.set_attribute(ERROR_TYPE, exception.__class__.__name__)
self._span.record_exception(exception)
self._span.set_status(StatusCode.ERROR, str(exception))
self._span.end()

self._cleanup_completed = True

@dont_throw
def _ensure_cleanup(self):
"""Ensure cleanup happens even if stream is not fully consumed"""
with self._cleanup_lock:
if self._cleanup_completed:
return

try:
if self._span and self._span.is_recording():
set_data_attributes(self._traced_data, self._span)
self._span.set_status(StatusCode.OK)
self._span.end()

self._cleanup_completed = True

except Exception:
self._cleanup_completed = True
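Putting the pieces together, a hedged end-to-end sketch of the caller's view once the wrappers are installed (tracer provider setup is elided; the model name and event filter are assumptions):

    from openai import OpenAI

    client = OpenAI()  # assumed already patched by this instrumentation

    stream = client.responses.create(model="gpt-4o", input="hi", stream=True)
    for event in stream:   # actually a ResponseStream proxy
        if getattr(event, "type", "") == "response.output_text.delta":
            print(event.delta, end="")
    # Exhaustion fires _process_complete_response(): attributes are copied
    # from TracedData onto the span, the status is set OK, the span ends,
    # and the traced response is cached in `responses` keyed by its id.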