Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions tests/trace/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1720,6 +1720,125 @@ def test_summary_tokens_cost_sqlite(client):
assert with_cost_call_summary == weave_summary


def _setup_calls_for_storage_size_test(client):
"""Helper function to set up calls for storage size tests.

Returns:
List of created Call objects.
"""
call0 = client.create_call("x", {"a": 5, "b": 10})
call0_child1 = client.create_call("x", {"a": 5, "b": 11}, call0)
call1 = client.create_call("y", {"a": 6, "b": 11})
return [call0, call0_child1, call1]


def test_get_calls_storage_size_with_filter(client):
"""Test that storage size parameters can be combined with other get_calls parameters."""
all_calls = _setup_calls_for_storage_size_test(client)
assert len(all_calls) > 2

call0 = all_calls[0]

# Test that parameters can be combined with other parameters
calls_filtered = list(
client.get_calls(
filter=tsi.CallsFilter(op_names=[call0.op_name]),
include_storage_size=True,
include_total_storage_size=True,
)
)
assert len(calls_filtered) == 2


def test_get_calls_storage_size_with_limit(client):
"""Test that storage size parameters can be combined with other get_calls parameters."""
all_calls = _setup_calls_for_storage_size_test(client)
assert len(all_calls) > 2

# Test that parameters can be combined with other parameters
calls_limited = list(
client.get_calls(
include_storage_size=True,
include_total_storage_size=True,
limit=2,
)
)
assert len(calls_limited) == 2


@pytest.fixture
def clickhouse_client(client):
if client_is_sqlite(client):
return None
return client.server._next_trace_server.ch_client


def test_get_calls_storage_size_values(client, clickhouse_client):
"""Test that storage size values are correctly included when parameters are set."""
if client_is_sqlite(client):
pytest.skip("Skipping test for sqlite clients")

_setup_calls_for_storage_size_test(client)

# This is a best effort to achieve consistency in the calls_merged_stats table.
# calls_merged_stats is an AggregatingMergeTree table populated by a materialized view.
# ClickHouse merges data asynchronously, so queries may see unmerged data.
# OPTIMIZE TABLE ... FINAL forces an immediate merge to ensure consistency for tests.
if clickhouse_client:
clickhouse_client.command("OPTIMIZE TABLE calls_merged_stats FINAL")

# Get calls via get_calls with storage size parameters
client_calls = list(
client.get_calls(include_storage_size=True, include_total_storage_size=True)
)

# Get calls directly from server with same parameters
server_calls = list(
client.server.calls_query_stream(
tsi.CallsQueryReq(
project_id=client._project_id(),
include_storage_size=True,
include_total_storage_size=True,
)
)
)

# Verify same number of calls
assert len(client_calls) == len(server_calls)
assert len(server_calls) > 0

# Verify that get_calls returns the same calls (by ID) as direct server calls
client_call_ids = {call.id for call in client_calls if call.id}
server_call_ids = {call.id for call in server_calls if call.id}
assert client_call_ids == server_call_ids

# Create a mapping of call IDs to client calls for easy lookup
client_calls_by_id = {call.id: call for call in client_calls if call.id}

# Verify storage size fields and compare values between server and client calls
for server_call in server_calls:
# Verify storage size fields are present on server calls
assert hasattr(server_call, "storage_size_bytes")
assert hasattr(server_call, "total_storage_size_bytes")

# Verify that storage size values match between server and client calls
if server_call.id and server_call.id in client_calls_by_id:
client_call = client_calls_by_id[server_call.id]
assert server_call.storage_size_bytes == client_call.storage_size_bytes
assert (
server_call.total_storage_size_bytes
== client_call.total_storage_size_bytes
)
assert server_call.storage_size_bytes is not None

# total_storage_size_bytes is only set for root calls (parent_id is None)
# For child calls, it is intentionally None
expect_total_storage_size_bytes = server_call.parent_id is None
assert expect_total_storage_size_bytes == (
server_call.total_storage_size_bytes is not None
)


def test_ref_in_dict(client):
ref = client._save_object({"a": 5}, "d1")

Expand Down
12 changes: 12 additions & 0 deletions weave/trace/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ class Call:
_children: list[Call] = dataclasses.field(default_factory=list)
_feedback: RefFeedbackQuery | None = None

# Size of metadata storage for this call
storage_size_bytes: int | None = None

# Total size of metadata storage for the entire trace
total_storage_size_bytes: int | None = None

@property
def display_name(self) -> str | Callable[[Call], str] | None:
return self._display_name
Expand Down Expand Up @@ -332,6 +338,8 @@ def _make_calls_iterator(
query: Query | None = None,
include_costs: bool = False,
include_feedback: bool = False,
include_storage_size: bool = False,
include_total_storage_size: bool = False,
columns: list[str] | None = None,
expand_columns: list[str] | None = None,
return_expanded_column_values: bool = True,
Expand All @@ -353,6 +361,8 @@ def fetch_func(offset: int, limit: int) -> list[CallSchema]:
limit=limit,
include_costs=include_costs,
include_feedback=include_feedback,
include_storage_size=include_storage_size,
include_total_storage_size=include_total_storage_size,
query=query,
sort_by=sort_by,
columns=columns,
Expand Down Expand Up @@ -422,6 +432,8 @@ def make_client_call(
wb_run_id=server_call.wb_run_id,
wb_run_step=server_call.wb_run_step,
wb_run_step_end=server_call.wb_run_step_end,
storage_size_bytes=server_call.storage_size_bytes,
total_storage_size_bytes=server_call.total_storage_size_bytes,
)
if isinstance(call.attributes, AttributesDict):
call.attributes.freeze()
Expand Down
6 changes: 6 additions & 0 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ def get_calls(
query: QueryLike | None = None,
include_costs: bool = False,
include_feedback: bool = False,
include_storage_size: bool = False,
include_total_storage_size: bool = False,
columns: list[str] | None = None,
expand_columns: list[str] | None = None,
return_expanded_column_values: bool = True,
Expand All @@ -551,6 +553,8 @@ def get_calls(
`query`: A mongo-like expression for advanced filtering. Not all Mongo operators are supported.
`include_costs`: If True, includes token/cost info in `summary.weave`.
`include_feedback`: If True, includes feedback in `summary.weave.feedback`.
`include_storage_size`: If True, includes the storage size for a call.
`include_total_storage_size`: If True, includes the total storage size for a trace.
`columns`: List of fields to return per call. Reducing this can significantly improve performance.
(Some fields like `id`, `trace_id`, `op_name`, and `started_at` are always included.)
`scored_by`: Filter by one or more scorers (name or ref URI). Multiple scorers are AND-ed.
Expand Down Expand Up @@ -585,6 +589,8 @@ def get_calls(
query=query,
include_costs=include_costs,
include_feedback=include_feedback,
include_storage_size=include_storage_size,
include_total_storage_size=include_total_storage_size,
columns=columns,
expand_columns=expand_columns,
return_expanded_column_values=return_expanded_column_values,
Expand Down