Describe the bug
Description
I've identified four issues with the Google GenAI (Gemini) integration in Weave, particularly affecting streaming responses and thinking models (e.g., gemini-2.0-flash-thinking-exp, gemini-2.5-flash-preview, gemini-2.5-pro-preview).
Issues
1. Token Overcounting in Streaming Responses
Location: weave/integrations/google_genai/gemini_utils.py - google_genai_gemini_accumulator()
Problem: The accumulator sums token counts across streaming chunks, but Gemini returns cumulative counts (not incremental). This results in massively inflated token counts in traces.
Reference: Google AI Token Documentation
"When streaming output, the usageMetadata attribute only appears on the last chunk of the stream."
Expected: Token counts should be replaced with the latest non-None values, not summed.
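To make the arithmetic concrete, here is a minimal sketch with made-up numbers, assuming each streamed chunk repeats the running totals as described above:
# Hypothetical cumulative usage values seen across three streamed chunks
# (illustrative numbers only -- the API reports running totals, not per-chunk deltas).
chunk_usage = [
    {"prompt_token_count": 23, "candidates_token_count": 10, "total_token_count": 33},
    {"prompt_token_count": 23, "candidates_token_count": 35, "total_token_count": 58},
    {"prompt_token_count": 23, "candidates_token_count": 53, "total_token_count": 76},
]
# Current behavior: summing across chunks inflates the totals.
summed_total = sum(u["total_token_count"] for u in chunk_usage)  # 167 -- inflated
# Expected behavior: keep the latest non-None value from each chunk.
latest: dict[str, int] = {}
for usage in chunk_usage:
    for key, value in usage.items():
        if value is not None:
            latest[key] = value
print(latest["total_token_count"])  # 76 -- matches the final chunk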
2. System Instruction Not Captured in Traces
Location: weave/integrations/google_genai/gemini_utils.py - google_genai_gemini_postprocess_inputs()
Problem: The system_instruction from GenerateContentConfig is not extracted and surfaced in trace inputs, making it difficult to see what system prompt was used when reviewing traces.
Expected: System instructions should be visible at the top level of trace inputs.
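For illustration, the trace inputs I'd expect after the fix might look roughly like this (a sketch only, not the exact trace schema):
inputs = {
    "model": "gemini-2.5-flash-preview",
    "contents": "Think step by step to write a short poem about robots",
    "system_instruction": "You are a helpful AI assistant.",  # surfaced at the top level
    "config": "...",  # remaining GenerateContentConfig fields (elided)
}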
3. thoughts_token_count Not Tracked
Location: weave/integrations/google_genai/gemini_utils.py - google_genai_gemini_on_finish()
Problem: Thinking models return thoughts_token_count in usage_metadata, but this is not captured in the usage summary. This makes it impossible to track thinking token usage for cost analysis.
Expected: thoughts_tokens should be included in the usage summary alongside prompt_tokens, completion_tokens, and total_tokens.
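For illustration, the per-call usage summary I'd expect (shape only; the numbers reuse the example totals from Additional Context below, and the split between fields is illustrative):
usage = {
    "gemini-2.5-flash-preview": {
        "requests": 1,
        "prompt_tokens": 23,
        "completion_tokens": 53,
        "thoughts_tokens": 308,  # currently missing for thinking models
        "total_tokens": 384,     # 23 + 53 + 308
    }
}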
4. Response Content Dropped When Streaming with Thinking Models
Location: weave/integrations/google_genai/gemini_utils.py - google_genai_gemini_accumulator()
Problem: When streaming responses from thinking models, chunks may arrive at the same part index but with different thought values (thinking content vs response content). The current index-based accumulation overwrites response content with thinking content (or vice versa), causing data loss.
Reference: Google AI Thinking Documentation
Expected: Parts should be accumulated by type (thought=True vs thought=False/None), not by index, ensuring all content is preserved.
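A simplified sketch of the failure mode and the proposed fix (the part contents are invented; real chunks are GenerateContentResponse objects with Part models rather than dicts):
# Two chunks deliver a thinking part and a response part at the same part index.
chunk_a_parts = [{"thought": True, "text": "Plan: four short lines about robots..."}]
chunk_b_parts = [{"thought": False, "text": "Rust-flecked gears hum softly..."}]
# Index-based accumulation: the later part lands on index 0 and overwrites the earlier one.
by_index = {0: chunk_a_parts[0]}
by_index[0] = chunk_b_parts[0]  # thinking content is lost
# Type-based accumulation: keep one accumulated part per value of the thought flag.
by_type = {}
for part in chunk_a_parts + chunk_b_parts:
    key = bool(part["thought"])
    if key in by_type:
        by_type[key]["text"] += part["text"]
    else:
        by_type[key] = dict(part)
# by_type now preserves both the thinking part and the response part.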
Environment
- Weave version: 0.51.41 (also tested on earlier versions)
- google-genai version: 1.x
- Python version: 3.11+
Reproduction
import weave
from google import genai
weave.init("my-project")
client = genai.Client(api_key="...")
# Use a thinking model with streaming
response = client.models.generate_content_stream(
    model="gemini-2.5-flash-preview",
    contents="Think step by step to write a short poem about robots",
    config={
        "system_instruction": "You are a helpful AI assistant.",
        "thinking_config": {"include_thoughts": True},
    },
)
for chunk in response:
    pass  # Consume stream
# Check Weave trace - you'll see:
# 1. Inflated token counts
# 2. Missing system_instruction in inputs
# 3. Missing thoughts_tokens in usage
# 4. Potentially missing response content if thinking and response arrived at same index
Workaround
I've developed a monkey-patch that fixes all four issues. It must be applied before weave.init() is called:
# Apply BEFORE weave.init()
import weave_gemini_patches
weave_gemini_patches.apply_patches()
import weave
weave.init("my-project")Click to expand workaround patch code
"""
Monkey patches for Weave SDK's Google GenAI integration.
These patches fix four issues:
1. Token overcounting in streaming responses (tokens were being summed instead of replaced)
2. System prompt not being captured in traces
3. Thinking content not being properly distinguished from response content
4. Response content being dropped when streaming chunks send response at same index as thinking
USAGE:
# Apply these patches BEFORE calling weave.init()
import weave_gemini_patches
weave_gemini_patches.apply_patches()
import weave
weave.init("my-project")
# Now use Google GenAI as normal
from google import genai
client = genai.Client(api_key="...")
response = client.models.generate_content(...)
IMPORTANT: These patches MUST be applied before weave.init() is called,
because weave.init() triggers the automatic patching of the Google GenAI SDK.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import weave.integrations.google_genai.gemini_utils as gemini_utils
from weave.trace.serialization.serialize import dictify
if TYPE_CHECKING:
from google.genai.types import GenerateContentResponse
_patches_applied = False
def _fixed_google_genai_gemini_accumulator(
    acc: GenerateContentResponse | None, value: GenerateContentResponse
) -> GenerateContentResponse:
    """
    Fixed accumulator for Google GenAI streaming responses.
    Fixes:
    1. Token counts: REPLACES instead of SUMS (Gemini returns cumulative counts)
    2. Thinking: Accumulates parts by type (thought=True vs thought=False), not by index,
       ensuring all thinking content goes into one part and all response content into another
    3. thoughts_token_count: Now tracked in usage metadata
    See: https://ai.google.dev/gemini-api/docs/tokens
        "When streaming output, the usageMetadata attribute only appears on the
        last chunk of the stream."
    See: https://ai.google.dev/gemini-api/docs/thinking
        Parts with thought=True contain internal model reasoning.
    """
    if acc is None:
        return value
    # Accumulate content from candidates
    value_candidates = getattr(value, "candidates", None) or []
    acc_candidates = getattr(acc, "candidates", None) or []
    for i, value_candidate in enumerate(value_candidates):
        if i >= len(acc_candidates):
            break
        acc_candidate = acc_candidates[i]
        # Handle cases where content or parts might be None
        value_content = getattr(value_candidate, "content", None)
        acc_content = getattr(acc_candidate, "content", None)
        if value_content is None or acc_content is None:
            continue
        value_parts = getattr(value_content, "parts", None) or []
        acc_parts = getattr(acc_content, "parts", None)
        if acc_parts is None:
            acc_parts = []
            acc_content.parts = acc_parts  # Ensure appends modify the actual attribute
        # Accumulate parts by type (thought vs response), not by index.
        # This handles streaming where chunks may arrive at the same index but with
        # different thought values, and ensures all content of the same type is
        # accumulated into a single part.
        for value_part in value_parts:
            is_thought = getattr(value_part, 'thought', False)
            # Find existing part of the same type to accumulate into
            target_part = None
            for existing_part in acc_parts:
                if getattr(existing_part, 'thought', False) == is_thought:
                    target_part = existing_part
                    break
            if target_part is not None:
                # Found existing part of same type - accumulate text
                if hasattr(value_part, 'text') and value_part.text is not None:
                    if hasattr(target_part, 'text') and target_part.text is not None:
                        target_part.text += value_part.text
                    elif hasattr(target_part, 'text'):
                        target_part.text = value_part.text
            else:
                # No existing part of this type - append new part
                acc_parts.append(value_part.model_copy(deep=True))
    # FIX: REPLACE token counts with latest non-None values instead of adding.
    # The Gemini API returns cumulative counts, not incremental ones.
    if hasattr(value, 'usage_metadata') and value.usage_metadata is not None:
        # If the accumulator lacks usage_metadata, adopt it from the latest chunk
        if not hasattr(acc, 'usage_metadata') or acc.usage_metadata is None:
            acc.usage_metadata = value.usage_metadata
        if hasattr(acc, 'usage_metadata') and acc.usage_metadata is not None:
            if value.usage_metadata.prompt_token_count is not None:
                acc.usage_metadata.prompt_token_count = value.usage_metadata.prompt_token_count
            if value.usage_metadata.candidates_token_count is not None:
                acc.usage_metadata.candidates_token_count = value.usage_metadata.candidates_token_count
            if value.usage_metadata.total_token_count is not None:
                acc.usage_metadata.total_token_count = value.usage_metadata.total_token_count
            if hasattr(value.usage_metadata, 'cached_content_token_count') and \
                    value.usage_metadata.cached_content_token_count is not None:
                acc.usage_metadata.cached_content_token_count = value.usage_metadata.cached_content_token_count
            # FIX: Also track thoughts_token_count for thinking models
            if hasattr(value.usage_metadata, 'thoughts_token_count') and \
                    value.usage_metadata.thoughts_token_count is not None:
                acc.usage_metadata.thoughts_token_count = value.usage_metadata.thoughts_token_count
    return acc
def _fixed_google_genai_gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
    """
    Fixed postprocess_inputs that also extracts system instructions.
    The original implementation only extracted the model name and serialized 'self'.
    This fix also explicitly extracts system_instruction from the config parameter
    and from Chat objects to ensure it appears in traces.
    """
    # Extract the model name from the inputs and ensure it is present in the inputs
    if "self" in inputs:
        model_name = getattr(inputs["self"], "_model", None)
        if model_name is not None:
            inputs["model"] = model_name
        # For Chat objects, extract system instruction from the config
        self_obj = inputs["self"]
        # Try to extract system instruction from Chat._config
        if hasattr(self_obj, '_config') and self_obj._config is not None:
            config = self_obj._config
            system_instruction = None
            # Try different ways the system instruction might be stored
            if hasattr(config, 'system_instruction'):
                system_instruction = config.system_instruction
            elif isinstance(config, dict) and 'system_instruction' in config:
                system_instruction = config['system_instruction']
            if system_instruction is not None:
                inputs["system_instruction"] = _serialize_content(system_instruction)
        # Convert the `self` parameter to a dictionary
        inputs["self"] = dictify(inputs["self"])
    # Also check for system_instruction in the config parameter (for generate_content calls)
    if "config" in inputs and inputs["config"] is not None:
        config = inputs["config"]
        system_instruction = None
        if hasattr(config, 'system_instruction') and config.system_instruction is not None:
            system_instruction = config.system_instruction
        elif isinstance(config, dict) and config.get('system_instruction') is not None:
            system_instruction = config['system_instruction']
        if system_instruction is not None:
            inputs["system_instruction"] = _serialize_content(system_instruction)
    return inputs
def _serialize_content(content: Any) -> Any:
    """Helper to serialize Content/Part objects to a readable format."""
    if content is None:
        return None
    # If it's a string, return as-is
    if isinstance(content, str):
        return content
    # If it has a to_dict method (Pydantic model), use it
    if hasattr(content, 'to_dict'):
        try:
            return content.to_dict()
        except Exception:
            pass
    # If it has a model_dump method (Pydantic v2), use it
    if hasattr(content, 'model_dump'):
        try:
            return content.model_dump()
        except Exception:
            pass
    # If it has a parts attribute (Content object), extract text from parts
    if hasattr(content, 'parts'):
        parts = content.parts
        if parts:
            texts = []
            for part in parts:
                if hasattr(part, 'text') and part.text:
                    texts.append(part.text)
            if texts:
                return '\n'.join(texts) if len(texts) > 1 else texts[0]
    # Fallback to string representation
    return str(content)
def _fixed_google_genai_gemini_on_finish(
    call: Any, output: Any, exception: Any = None
) -> None:
    """
    Fixed on_finish handler that also tracks thoughts_token_count.
    The original implementation only tracked prompt_tokens, completion_tokens,
    and total_tokens. This fix also tracks thoughts_tokens for thinking models.
    """
    if not (model_name := call.inputs.get("model")):
        raise ValueError("Unknown model type")
    usage = {model_name: {"requests": 1}}
    summary_update = {"usage": usage}
    if output:
        call.output = dictify(output)
        if hasattr(output, "usage_metadata"):
            usage[model_name].update(
                {
                    "prompt_tokens": output.usage_metadata.prompt_token_count,
                    "completion_tokens": output.usage_metadata.candidates_token_count,
                    "total_tokens": output.usage_metadata.total_token_count,
                }
            )
            # FIX: Also track thoughts_tokens for thinking models
            if hasattr(output.usage_metadata, 'thoughts_token_count') and \
                    output.usage_metadata.thoughts_token_count is not None:
                usage[model_name]["thoughts_tokens"] = output.usage_metadata.thoughts_token_count
    if call.summary is not None:
        call.summary.update(summary_update)
def apply_patches() -> None:
    """
    Apply the monkey patches to fix Google GenAI integration issues.
    MUST be called BEFORE weave.init() for the patches to take effect.
    """
    global _patches_applied
    if _patches_applied:
        return
    # Store original functions for potential restoration
    if not hasattr(gemini_utils, "_original_accumulator"):
        gemini_utils._original_accumulator = gemini_utils.google_genai_gemini_accumulator  # type: ignore[attr-defined]
    if not hasattr(gemini_utils, "_original_postprocess_inputs"):
        gemini_utils._original_postprocess_inputs = gemini_utils.google_genai_gemini_postprocess_inputs  # type: ignore[attr-defined]
    if not hasattr(gemini_utils, "_original_on_finish"):
        gemini_utils._original_on_finish = gemini_utils.google_genai_gemini_on_finish  # type: ignore[attr-defined]
    # Apply the fixed functions
    gemini_utils.google_genai_gemini_accumulator = _fixed_google_genai_gemini_accumulator
    gemini_utils.google_genai_gemini_postprocess_inputs = _fixed_google_genai_gemini_postprocess_inputs
    gemini_utils.google_genai_gemini_on_finish = _fixed_google_genai_gemini_on_finish
    _patches_applied = True
    print("✓ Weave Google GenAI patches applied successfully")
    print(" - Fixed: Token overcounting in streaming responses")
    print(" - Fixed: System instruction capture in traces")
    print(" - Fixed: Thinking token tracking (thoughts_tokens)")
    print(" - Fixed: Response content preservation in streaming with thinking")
def restore_original() -> None:
    """
    Restore the original (buggy) functions if needed.
    Note: This only works if the patches were applied before weave.init().
    If weave.init() has already been called, the patchers have already
    captured references to the functions.
    """
    global _patches_applied
    if hasattr(gemini_utils, '_original_accumulator'):
        gemini_utils.google_genai_gemini_accumulator = gemini_utils._original_accumulator  # type: ignore[attr-defined]
    if hasattr(gemini_utils, '_original_postprocess_inputs'):
        gemini_utils.google_genai_gemini_postprocess_inputs = gemini_utils._original_postprocess_inputs  # type: ignore[attr-defined]
    if hasattr(gemini_utils, '_original_on_finish'):
        gemini_utils.google_genai_gemini_on_finish = gemini_utils._original_on_finish  # type: ignore[attr-defined]
    _patches_applied = False
    print("✓ Original functions restored")
Suggested Fix
The fixes are relatively straightforward:
- Token counting: Replace token counts with the latest values instead of summing
- System instruction: Extract from config.system_instruction in postprocess_inputs
- Thoughts tokens: Add thoughts_tokens to the usage summary in on_finish
- Part accumulation: Accumulate parts by type (the thought attribute), not by index
I'm happy to submit a PR if that would be helpful -- I haven't submitted one before, but I'd be glad to give it a go. (EDIT: I submitted a draft PR -- hope it helps!)
Additional Context
I verified the patch works correctly by examining Weave traces. After applying it:
- Token counts are accurate (e.g., 23 + 53 + 308 = 384 total)
- System instruction appears in trace inputs
- Thinking and response content are properly separated into distinct parts
- No content is dropped during streaming accumulation