@@ -22,7 +22,10 @@
     guardrail_converse,
     guardrail_handling,
 )
-from opentelemetry.instrumentation.bedrock.prompt_caching import prompt_caching_handling
+from opentelemetry.instrumentation.bedrock.prompt_caching import (
+    prompt_caching_converse_handling,
+    prompt_caching_handling,
+)
 from opentelemetry.instrumentation.bedrock.reusable_streaming_body import (
     ReusableStreamingBody,
 )
@@ -354,6 +357,7 @@ def _handle_call(span: Span, kwargs, response, metric_params, event_logger):
 def _handle_converse(span, kwargs, response, metric_params, event_logger):
     (provider, model_vendor, model) = _get_vendor_model(kwargs.get("modelId"))
     guardrail_converse(span, response, provider, model, metric_params)
+    prompt_caching_converse_handling(response, provider, model, metric_params)
 
     set_converse_model_span_attributes(span, provider, model, kwargs)
 
@@ -394,7 +398,11 @@ def wrap(*args, **kwargs):
                 role = event["messageStart"]["role"]
             elif "metadata" in event:
                 # last message sent
+                metadata = event.get("metadata", {})
                 guardrail_converse(span, event["metadata"], provider, model, metric_params)
+                prompt_caching_converse_handling(
+                    metadata, provider, model, metric_params
+                )
                 converse_usage_record(span, event["metadata"], metric_params)
                 span.end()
             elif "messageStop" in event:
@@ -41,3 +41,45 @@ def prompt_caching_handling(headers, vendor, model, metric_params):
        )
    if write_cached_tokens > 0:
        span.set_attribute(CacheSpanAttrs.CACHED, "write")


def prompt_caching_converse_handling(response, vendor, model, metric_params):
    base_attrs = {
        "gen_ai.system": vendor,
        "gen_ai.response.model": model,
    }
    span = trace.get_current_span()
    if not isinstance(span, trace.Span) or not span.is_recording():
        return

    usage = response.get("usage", {})
    read_cached_tokens = usage.get("cache_read_input_tokens", 0)
    write_cached_tokens = usage.get("cache_creation_input_tokens", 0)

    if read_cached_tokens > 0:
        if metric_params.prompt_caching:
            metric_params.prompt_caching.add(
                read_cached_tokens,
                attributes={
                    **base_attrs,
                    CacheSpanAttrs.TYPE: "read",
                },
            )
        span.set_attribute(CacheSpanAttrs.CACHED, "read")
        span.set_attribute(
            "gen_ai.usage.cache_read_input_tokens", read_cached_tokens
        )

    if write_cached_tokens > 0:
        if metric_params.prompt_caching:
            metric_params.prompt_caching.add(
                write_cached_tokens,
                attributes={
                    **base_attrs,
                    CacheSpanAttrs.TYPE: "write",
                },
            )
        span.set_attribute(CacheSpanAttrs.CACHED, "write")
        span.set_attribute(
            "gen_ai.usage.cache_creation_input_tokens", write_cached_tokens
        )
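
For orientation, a minimal sketch of driving the new handler by hand, assuming it is imported from opentelemetry.instrumentation.bedrock.prompt_caching as added in this diff; the SimpleNamespace stand-in for metric_params, the meter and counter names, and the "AWS"/model strings are illustrative assumptions, not values taken from this PR:

from types import SimpleNamespace

from opentelemetry import metrics, trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.instrumentation.bedrock.prompt_caching import (
    prompt_caching_converse_handling,
)

# Hypothetical stand-in: the handler only needs an object whose
# `prompt_caching` attribute behaves like a Counter (or is falsy).
trace.set_tracer_provider(TracerProvider())
meter = metrics.get_meter("bedrock-demo")
metric_params = SimpleNamespace(
    prompt_caching=meter.create_counter("gen_ai.bedrock.prompt.caching.demo")
)

# Shape of a Converse response that carries cache usage, using the same
# field names the handler reads above.
sample_response = {
    "usage": {
        "cache_read_input_tokens": 0,
        "cache_creation_input_tokens": 1024,
    }
}

tracer = trace.get_tracer("bedrock-demo")
with tracer.start_as_current_span("bedrock.converse"):
    # With a recording span active, this records a "write" cache metric and
    # sets the cache-related span attributes shown in the function above.
    prompt_caching_converse_handling(
        sample_response, "AWS", "anthropic.claude-3-haiku-20240307-v1:0", metric_params
    )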
packages/opentelemetry-instrumentation-bedrock/tests/metrics/test_bedrock_converse_prompt_caching_metrics.py
@@ -0,0 +1,70 @@
import pytest
from opentelemetry.instrumentation.bedrock import PromptCaching
from opentelemetry.instrumentation.bedrock.prompt_caching import CacheSpanAttrs


def call(brt):
    return brt.converse(
        modelId="anthropic.claude-3-haiku-20240307-v1:0",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "text": "What is the capital of the USA?",
                    }
                ],
            }
        ],
        inferenceConfig={"maxTokens": 50, "temperature": 0.1},
        additionalModelRequestFields={"cacheControl": {"type": "ephemeral"}},
    )


def get_metric(resource_metrics, name):
    for rm in resource_metrics:
        for sm in rm.scope_metrics:
            for metric in sm.metrics:
                if metric.name == name:
                    return metric
    raise Exception(f"No metric found with name {name}")
⚠️ Potential issue | 🟡 Minor

Use AssertionError for test helper.

The generic Exception should be replaced with AssertionError since this is a test assertion helper and test failures should propagate as assertion failures.

As per static analysis hints.

Apply this diff:

 def get_metric(resource_metrics, name):
     for rm in resource_metrics:
         for sm in rm.scope_metrics:
             for metric in sm.metrics:
                 if metric.name == name:
                     return metric
-    raise Exception(f"No metric found with name {name}")
+    raise AssertionError(f"No metric found with name {name}")
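
An equivalent option, since the test module already imports pytest (an alternative sketch, not part of the suggested change above):

    pytest.fail(f"No metric found with name {name}")

pytest.fail raises pytest's own failure exception, so the helper still aborts the test with the same message, and because there is no explicit raise of a generic Exception it also avoids the TRY002/TRY003 hints below.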
🧰 Tools
🪛 Ruff (0.14.2)

30-30: Create your own exception (TRY002)

30-30: Avoid specifying long messages outside the exception class (TRY003)
🤖 Prompt for AI Agents
In packages/opentelemetry-instrumentation-bedrock/tests/metrics/test_bedrock_converse_prompt_caching_metrics.py around lines 24 to 30, the helper raises a generic Exception when a metric is not found; change it to raise an AssertionError (or use assert False) with the same message so test failures surface as assertion failures. Update the raise statement accordingly and keep the message f"No metric found with name {name}" intact.



def assert_metric(reader, usage):
    metrics_data = reader.get_metrics_data()
    resource_metrics = metrics_data.resource_metrics
    assert len(resource_metrics) > 0

    m = get_metric(resource_metrics, PromptCaching.LLM_BEDROCK_PROMPT_CACHING)
    for data_point in m.data.data_points:
        assert data_point.attributes[CacheSpanAttrs.TYPE] in [
            "read",
            "write",
        ]
        if data_point.attributes[CacheSpanAttrs.TYPE] == "read":
            assert data_point.value == usage["cache_read_input_tokens"]
        else:
            assert data_point.value == usage["cache_creation_input_tokens"]


@pytest.mark.vcr
def test_prompt_cache_converse(test_context, brt):
    _, _, reader = test_context

    response = call(brt)
    # assert first prompt writes a cache
    usage = response["usage"]
    assert usage["cache_read_input_tokens"] == 0
    assert usage["cache_creation_input_tokens"] > 0
    cumulative_workaround = usage["cache_creation_input_tokens"]
    assert_metric(reader, usage)

    response = call(brt)
    # assert second prompt reads from the cache
    usage = response["usage"]
    assert usage["cache_read_input_tokens"] > 0
    assert usage["cache_creation_input_tokens"] == 0
    # data is stored across reads of metric data due to the cumulative behavior
    usage["cache_creation_input_tokens"] = cumulative_workaround
    assert_metric(reader, usage)
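
The cumulative_workaround line exists because OpenTelemetry counters default to cumulative aggregation temporality: each get_metrics_data() call reports the running total since startup, so the write-cache data point from the first request is still present, with its old value, when metrics are read after the second request. A minimal sketch of that behavior with an in-memory reader (the names below are illustrative, not taken from the test fixtures):

from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import InMemoryMetricReader

reader = InMemoryMetricReader()
meter = MeterProvider(metric_readers=[reader]).get_meter("demo")
counter = meter.create_counter("cache.tokens.demo")

counter.add(1024, attributes={"type": "write"})  # first request writes the cache
first = reader.get_metrics_data()                # "write" data point reports 1024

counter.add(256, attributes={"type": "read"})    # second request reads the cache
second = reader.get_metrics_data()
# With cumulative temporality the "write" data point still reports 1024 here,
# which is why the test carries cache_creation_input_tokens forward before
# calling assert_metric a second time.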