
[Bug] Google GenAI Integration: Token overcounting, missing system instructions, and streaming issues with thinking models #5880

@oekekezie

Describe the bug

Description

I've identified four issues with the Google GenAI (Gemini) integration in Weave, particularly affecting streaming responses and thinking models (e.g., gemini-2.0-flash-thinking-exp, gemini-2.5-flash-preview, gemini-2.5-pro-preview).

Issues

1. Token Overcounting in Streaming Responses

Location: weave/integrations/google_genai/gemini_utils.py - google_genai_gemini_accumulator()

Problem: The accumulator sums token counts across streaming chunks, but Gemini returns cumulative counts (not incremental). This results in massively inflated token counts in traces.

Reference: Google AI Token Documentation (https://ai.google.dev/gemini-api/docs/tokens)

"When streaming output, the usageMetadata attribute only appears on the last chunk of the stream."

Expected: Token counts should be replaced with the latest non-None values, not summed.
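
For illustration, a minimal sketch of the intended merge. Field names follow google.genai's usage metadata; merge_usage and its arguments are placeholder names, not Weave's actual code:

def merge_usage(acc_usage, chunk_usage):
    """Adopt the latest cumulative counts from the newest chunk.

    Gemini reports cumulative totals per chunk, so summing across chunks
    inflates the numbers; replacing with the latest non-None value does not.
    """
    for field in ("prompt_token_count", "candidates_token_count", "total_token_count"):
        latest = getattr(chunk_usage, field, None)
        if latest is not None:
            setattr(acc_usage, field, latest)  # replace, don't add
    return acc_usage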

2. System Instruction Not Captured in Traces

Location: weave/integrations/google_genai/gemini_utils.py - google_genai_gemini_postprocess_inputs()

Problem: The system_instruction from GenerateContentConfig is not extracted and surfaced in trace inputs, making it difficult to see what system prompt was used when reviewing traces.

Expected: System instructions should be visible at the top level of trace inputs.
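
A rough sketch of the expected behavior; surface_system_instruction and the shape of the inputs dict are illustrative, not Weave's actual hook:

from typing import Any

def surface_system_instruction(inputs: dict[str, Any]) -> dict[str, Any]:
    """Copy config.system_instruction to the top level of the trace inputs."""
    config = inputs.get("config")
    if config is None:
        return inputs
    # config may be a GenerateContentConfig instance or a plain dict
    instruction = getattr(config, "system_instruction", None)
    if instruction is None and isinstance(config, dict):
        instruction = config.get("system_instruction")
    if instruction is not None:
        inputs["system_instruction"] = instruction
    return inputs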

3. thoughts_token_count Not Tracked

Location: weave/integrations/google_genai/gemini_utils.py - google_genai_gemini_on_finish()

Problem: Thinking models return thoughts_token_count in usage_metadata, but this is not captured in the usage summary. This makes it impossible to track thinking token usage for cost analysis.

Expected: thoughts_tokens should be included in the usage summary alongside prompt_tokens, completion_tokens, and total_tokens.
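
A minimal sketch of the expected summary, assuming usage_metadata exposes thoughts_token_count as it does for thinking models (build_usage_summary is an illustrative name):

def build_usage_summary(usage_metadata) -> dict:
    """Per-model usage dict, including thinking tokens when the model reports them."""
    summary = {
        "prompt_tokens": usage_metadata.prompt_token_count,
        "completion_tokens": usage_metadata.candidates_token_count,
        "total_tokens": usage_metadata.total_token_count,
    }
    thoughts = getattr(usage_metadata, "thoughts_token_count", None)
    if thoughts is not None:
        summary["thoughts_tokens"] = thoughts  # surfaced alongside the standard counts
    return summary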

4. Response Content Dropped When Streaming with Thinking Models

Location: weave/integrations/google_genai/gemini_utils.py - google_genai_gemini_accumulator()

Problem: When streaming responses from thinking models, chunks may arrive at the same part index but with different thought values (thinking content vs response content). The current index-based accumulation overwrites response content with thinking content (or vice versa), causing data loss.

Reference: Google AI Thinking Documentation (https://ai.google.dev/gemini-api/docs/thinking)

Expected: Parts should be accumulated by type (thought=True vs thought=False/None), not by index, ensuring all content is preserved.
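
A sketch of type-based accumulation; merge_parts and the argument names are illustrative, and it only assumes each part has a text attribute and an optional thought flag, as in google.genai's Part type:

def merge_parts(acc_parts: list, new_parts: list) -> list:
    """Merge streamed parts by kind (thinking vs. response) instead of by index."""
    for part in new_parts:
        is_thought = bool(getattr(part, "thought", False))
        # Find the accumulated part of the same kind, if any
        target = next(
            (p for p in acc_parts if bool(getattr(p, "thought", False)) == is_thought),
            None,
        )
        if target is None:
            acc_parts.append(part)  # first part of this kind
        elif getattr(part, "text", None):
            target.text = (target.text or "") + part.text
    return acc_parts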

Environment

  • Weave version: 0.51.41 (also tested on earlier versions)
  • google-genai version: 1.x
  • Python version: 3.11+

Reproduction

import weave
from google import genai

weave.init("my-project")

client = genai.Client(api_key="...")

# Use a thinking model with streaming
response = client.models.generate_content_stream(
    model="gemini-2.5-flash-preview",
    contents="Think step by step to write a short poem about robots",
    config={
        "system_instruction": "You are a helpful AI assistant.",
        "thinking_config": {"include_thoughts": True},
    }
)

for chunk in response:
    pass  # Consume stream

# Check Weave trace - you'll see:
# 1. Inflated token counts
# 2. Missing system_instruction in inputs
# 3. Missing thoughts_tokens in usage
# 4. Potentially missing response content if thinking and response arrived at same index

Workaround

I've developed a monkey-patch that fixes all four issues. It must be applied before weave.init() is called:

# Apply BEFORE weave.init()
import weave_gemini_patches
weave_gemini_patches.apply_patches()

import weave
weave.init("my-project")

Workaround patch code:

"""
Monkey patches for Weave SDK's Google GenAI integration.

These patches fix four issues:
1. Token overcounting in streaming responses (tokens were being summed instead of replaced)
2. System prompt not being captured in traces
3. Thinking content not being properly distinguished from response content
4. Response content being dropped when streaming chunks send response at same index as thinking

USAGE:
    # Apply these patches BEFORE calling weave.init()
    import weave_gemini_patches
    weave_gemini_patches.apply_patches()

    import weave
    weave.init("my-project")

    # Now use Google GenAI as normal
    from google import genai
    client = genai.Client(api_key="...")
    response = client.models.generate_content(...)

IMPORTANT: These patches MUST be applied before weave.init() is called,
because weave.init() triggers the automatic patching of the Google GenAI SDK.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

import weave.integrations.google_genai.gemini_utils as gemini_utils
from weave.trace.serialization.serialize import dictify

if TYPE_CHECKING:
    from google.genai.types import GenerateContentResponse

_patches_applied = False


def _fixed_google_genai_gemini_accumulator(
    acc: GenerateContentResponse | None, value: GenerateContentResponse
) -> GenerateContentResponse:
    """
    Fixed accumulator for Google GenAI streaming responses.

    Fixes:
    1. Token counts: REPLACES instead of SUMS (Gemini returns cumulative counts)
    2. Thinking: Accumulates parts by type (thought=True vs thought=False), not by index,
       ensuring all thinking content goes into one part and all response content into another
    3. thoughts_token_count: Now tracked in usage metadata

    See: https://ai.google.dev/gemini-api/docs/tokens
    "When streaming output, the usageMetadata attribute only appears on the
    last chunk of the stream."

    See: https://ai.google.dev/gemini-api/docs/thinking
    Parts with thought=True contain internal model reasoning.
    """
    if acc is None:
        return value

    # Accumulate content from candidates
    value_candidates = getattr(value, "candidates", None) or []
    acc_candidates = getattr(acc, "candidates", None) or []

    for i, value_candidate in enumerate(value_candidates):
        if i >= len(acc_candidates):
            break
        acc_candidate = acc_candidates[i]

        # Handle cases where content or parts might be None
        value_content = getattr(value_candidate, "content", None)
        acc_content = getattr(acc_candidate, "content", None)
        if value_content is None or acc_content is None:
            continue

        value_parts = getattr(value_content, "parts", None) or []
        acc_parts = getattr(acc_content, "parts", None)
        if acc_parts is None:
            acc_parts = []
            acc_content.parts = acc_parts  # Ensure appends modify the actual attribute

        # Accumulate parts by type (thought vs response), not by index.
        # This handles streaming where chunks may arrive at the same index but with
        # different thought values, and ensures all content of the same type is
        # accumulated into a single part.
        for value_part in value_parts:
            is_thought = getattr(value_part, 'thought', False)

            # Find existing part of the same type to accumulate into
            target_part = None
            for existing_part in acc_parts:
                if getattr(existing_part, 'thought', False) == is_thought:
                    target_part = existing_part
                    break

            if target_part is not None:
                # Found existing part of same type - accumulate text
                if hasattr(value_part, 'text') and value_part.text is not None:
                    if hasattr(target_part, 'text') and target_part.text is not None:
                        target_part.text += value_part.text
                    elif hasattr(target_part, 'text'):
                        target_part.text = value_part.text
            else:
                # No existing part of this type - append new part
                acc_parts.append(value_part.model_copy(deep=True))

    # FIX: REPLACE token counts with latest non-None values instead of adding
    # The Gemini API returns cumulative counts, not incremental ones
    if hasattr(value, 'usage_metadata') and value.usage_metadata is not None:
        # If the accumulator lacks usage_metadata, adopt it from the latest chunk
        if not hasattr(acc, 'usage_metadata') or acc.usage_metadata is None:
            acc.usage_metadata = value.usage_metadata

        if hasattr(acc, 'usage_metadata') and acc.usage_metadata is not None:
            if value.usage_metadata.prompt_token_count is not None:
                acc.usage_metadata.prompt_token_count = value.usage_metadata.prompt_token_count

            if value.usage_metadata.candidates_token_count is not None:
                acc.usage_metadata.candidates_token_count = value.usage_metadata.candidates_token_count

            if value.usage_metadata.total_token_count is not None:
                acc.usage_metadata.total_token_count = value.usage_metadata.total_token_count

            if hasattr(value.usage_metadata, 'cached_content_token_count') and \
               value.usage_metadata.cached_content_token_count is not None:
                acc.usage_metadata.cached_content_token_count = value.usage_metadata.cached_content_token_count

            # FIX: Also track thoughts_token_count for thinking models
            if hasattr(value.usage_metadata, 'thoughts_token_count') and \
               value.usage_metadata.thoughts_token_count is not None:
                acc.usage_metadata.thoughts_token_count = value.usage_metadata.thoughts_token_count

    return acc


def _fixed_google_genai_gemini_postprocess_inputs(inputs: dict[str, Any]) -> dict[str, Any]:
    """
    Fixed postprocess_inputs that also extracts system instructions.

    The original implementation only extracted the model name and serialized 'self'.
    This fix also explicitly extracts system_instruction from the config parameter
    and from Chat objects to ensure it appears in traces.
    """
    # Extract the model name from the inputs and ensure it is present in the inputs
    if "self" in inputs:
        model_name = getattr(inputs["self"], "_model", None)
        if model_name is not None:
            inputs["model"] = model_name

        # For Chat objects, extract system instruction from the config
        self_obj = inputs["self"]

        # Try to extract system instruction from Chat._config
        if hasattr(self_obj, '_config') and self_obj._config is not None:
            config = self_obj._config
            system_instruction = None

            # Try different ways the system instruction might be stored
            if hasattr(config, 'system_instruction'):
                system_instruction = config.system_instruction
            elif isinstance(config, dict) and 'system_instruction' in config:
                system_instruction = config['system_instruction']

            if system_instruction is not None:
                inputs["system_instruction"] = _serialize_content(system_instruction)

        # Convert the `self` parameter to a dictionary
        inputs["self"] = dictify(inputs["self"])

    # Also check for system_instruction in the config parameter (for generate_content calls)
    if "config" in inputs and inputs["config"] is not None:
        config = inputs["config"]
        system_instruction = None

        if hasattr(config, 'system_instruction') and config.system_instruction is not None:
            system_instruction = config.system_instruction
        elif isinstance(config, dict) and config.get('system_instruction') is not None:
            system_instruction = config['system_instruction']

        if system_instruction is not None:
            inputs["system_instruction"] = _serialize_content(system_instruction)

    return inputs


def _serialize_content(content: Any) -> Any:
    """Helper to serialize Content/Part objects to a readable format."""
    if content is None:
        return None

    # If it's a string, return as-is
    if isinstance(content, str):
        return content

    # If it has a to_dict method (Pydantic model), use it
    if hasattr(content, 'to_dict'):
        try:
            return content.to_dict()
        except Exception:
            pass

    # If it has a model_dump method (Pydantic v2), use it
    if hasattr(content, 'model_dump'):
        try:
            return content.model_dump()
        except Exception:
            pass

    # If it has parts attribute (Content object), extract text from parts
    if hasattr(content, 'parts'):
        parts = content.parts
        if parts:
            texts = []
            for part in parts:
                if hasattr(part, 'text') and part.text:
                    texts.append(part.text)
            if texts:
                return '\n'.join(texts) if len(texts) > 1 else texts[0]

    # Fallback to string representation
    return str(content)


def _fixed_google_genai_gemini_on_finish(
    call: Any, output: Any, exception: Any = None
) -> None:
    """
    Fixed on_finish handler that also tracks thoughts_token_count.

    The original implementation only tracked prompt_tokens, completion_tokens,
    and total_tokens. This fix also tracks thoughts_tokens for thinking models.
    """
    if not (model_name := call.inputs.get("model")):
        raise ValueError("Unknown model type")
    usage = {model_name: {"requests": 1}}
    summary_update = {"usage": usage}
    if output:
        call.output = dictify(output)
        if hasattr(output, "usage_metadata"):
            usage[model_name].update(
                {
                    "prompt_tokens": output.usage_metadata.prompt_token_count,
                    "completion_tokens": output.usage_metadata.candidates_token_count,
                    "total_tokens": output.usage_metadata.total_token_count,
                }
            )
            # FIX: Also track thoughts_tokens for thinking models
            if hasattr(output.usage_metadata, 'thoughts_token_count') and \
               output.usage_metadata.thoughts_token_count is not None:
                usage[model_name]["thoughts_tokens"] = output.usage_metadata.thoughts_token_count
    if call.summary is not None:
        call.summary.update(summary_update)


def apply_patches() -> None:
    """
    Apply the monkey patches to fix Google GenAI integration issues.

    MUST be called BEFORE weave.init() for the patches to take effect.
    """
    global _patches_applied
    if _patches_applied:
        return

    # Store original functions for potential restoration
    if not hasattr(gemini_utils, "_original_accumulator"):
        gemini_utils._original_accumulator = gemini_utils.google_genai_gemini_accumulator  # type: ignore[attr-defined]
    if not hasattr(gemini_utils, "_original_postprocess_inputs"):
        gemini_utils._original_postprocess_inputs = gemini_utils.google_genai_gemini_postprocess_inputs  # type: ignore[attr-defined]
    if not hasattr(gemini_utils, "_original_on_finish"):
        gemini_utils._original_on_finish = gemini_utils.google_genai_gemini_on_finish  # type: ignore[attr-defined]

    # Apply the fixed functions
    gemini_utils.google_genai_gemini_accumulator = _fixed_google_genai_gemini_accumulator
    gemini_utils.google_genai_gemini_postprocess_inputs = _fixed_google_genai_gemini_postprocess_inputs
    gemini_utils.google_genai_gemini_on_finish = _fixed_google_genai_gemini_on_finish

    _patches_applied = True
    print("✓ Weave Google GenAI patches applied successfully")
    print("  - Fixed: Token overcounting in streaming responses")
    print("  - Fixed: System instruction capture in traces")
    print("  - Fixed: Thinking token tracking (thoughts_tokens)")
    print("  - Fixed: Response content preservation in streaming with thinking")


def restore_original() -> None:
    """
    Restore the original (buggy) functions if needed.

    Note: This only works if the patches were applied before weave.init().
    If weave.init() has already been called, the patchers have already
    captured references to the functions.
    """
    global _patches_applied

    if hasattr(gemini_utils, '_original_accumulator'):
        gemini_utils.google_genai_gemini_accumulator = gemini_utils._original_accumulator  # type: ignore[attr-defined]
    if hasattr(gemini_utils, '_original_postprocess_inputs'):
        gemini_utils.google_genai_gemini_postprocess_inputs = gemini_utils._original_postprocess_inputs  # type: ignore[attr-defined]
    if hasattr(gemini_utils, '_original_on_finish'):
        gemini_utils.google_genai_gemini_on_finish = gemini_utils._original_on_finish  # type: ignore[attr-defined]

    _patches_applied = False
    print("✓ Original functions restored")

Suggested Fix

The fixes are relatively straightforward:

  1. Token counting: Replace token counts with latest values instead of summing
  2. System instruction: Extract from config.system_instruction in postprocess_inputs
  3. Thoughts tokens: Add thoughts_tokens to usage summary in on_finish
  4. Part accumulation: Accumulate parts by type (thought attribute), not by index

I'm happy to submit a PR if that would be helpful. (I haven't submitted one before, but I'd be glad to give it a go!) (EDIT: I submitted a draft PR -- hope it helps!)

Additional Context

I verified the patch works correctly by examining Weave traces. After applying it:

  • Token counts are accurate (e.g., 23 + 53 + 308 = 384 total)
  • System instruction appears in trace inputs
  • Thinking and response content are properly separated into distinct parts
  • No content is dropped during streaming accumulation

Weave Project Link

No response

Screenshots

No response
