diff --git a/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py
index 0aa71db370d0..9e8cd380cbe3 100644
--- a/tests/trace/test_weave_client.py
+++ b/tests/trace/test_weave_client.py
@@ -1720,6 +1720,125 @@ def test_summary_tokens_cost_sqlite(client):
     assert with_cost_call_summary == weave_summary
 
 
+def _setup_calls_for_storage_size_test(client):
+    """Helper function to set up calls for storage size tests.
+
+    Returns:
+        List of created Call objects.
+    """
+    call0 = client.create_call("x", {"a": 5, "b": 10})
+    call0_child1 = client.create_call("x", {"a": 5, "b": 11}, call0)
+    call1 = client.create_call("y", {"a": 6, "b": 11})
+    return [call0, call0_child1, call1]
+
+
+def test_get_calls_storage_size_with_filter(client):
+    """Test that storage size parameters can be combined with a get_calls filter."""
+    all_calls = _setup_calls_for_storage_size_test(client)
+    assert len(all_calls) > 2
+
+    call0 = all_calls[0]
+
+    # Only the two calls sharing call0's op name should match the filter
+    calls_filtered = list(
+        client.get_calls(
+            filter=tsi.CallsFilter(op_names=[call0.op_name]),
+            include_storage_size=True,
+            include_total_storage_size=True,
+        )
+    )
+    assert len(calls_filtered) == 2
+
+
+def test_get_calls_storage_size_with_limit(client):
+    """Test that storage size parameters can be combined with a get_calls limit."""
+    all_calls = _setup_calls_for_storage_size_test(client)
+    assert len(all_calls) > 2
+
+    # More calls exist than the limit, so exactly `limit` calls should come back
+    calls_limited = list(
+        client.get_calls(
+            include_storage_size=True,
+            include_total_storage_size=True,
+            limit=2,
+        )
+    )
+    assert len(calls_limited) == 2
+
+
+@pytest.fixture
+def clickhouse_client(client):
+    if client_is_sqlite(client):
+        return None
+    return client.server._next_trace_server.ch_client
+
+
+def test_get_calls_storage_size_values(client, clickhouse_client):
+    """Test that storage size values are correctly included when parameters are set."""
+    if client_is_sqlite(client):
+        pytest.skip("Skipping test for sqlite clients")
+
+    _setup_calls_for_storage_size_test(client)
+
+    # This is a best effort to achieve consistency in the calls_merged_stats table.
+    # calls_merged_stats is an AggregatingMergeTree table populated by a materialized view.
+    # ClickHouse merges data asynchronously, so queries may see unmerged data.
+    # OPTIMIZE TABLE ... FINAL forces an immediate merge to ensure consistency for tests.
+    if clickhouse_client:
+        clickhouse_client.command("OPTIMIZE TABLE calls_merged_stats FINAL")
+
+    # Get calls via get_calls with storage size parameters
+    client_calls = list(
+        client.get_calls(include_storage_size=True, include_total_storage_size=True)
+    )
+
+    # Get calls directly from server with same parameters
+    server_calls = list(
+        client.server.calls_query_stream(
+            tsi.CallsQueryReq(
+                project_id=client._project_id(),
+                include_storage_size=True,
+                include_total_storage_size=True,
+            )
+        )
+    )
+
+    # Verify same number of calls
+    assert len(client_calls) == len(server_calls)
+    assert len(server_calls) > 0
+
+    # Verify that get_calls returns the same calls (by ID) as direct server calls
+    client_call_ids = {call.id for call in client_calls if call.id}
+    server_call_ids = {call.id for call in server_calls if call.id}
+    assert client_call_ids == server_call_ids
+
+    # Create a mapping of call IDs to client calls for easy lookup
+    client_calls_by_id = {call.id: call for call in client_calls if call.id}
+
+    # Verify storage size fields and compare values between server and client calls
+    for server_call in server_calls:
+        # Verify storage size fields are present on server calls
+        assert hasattr(server_call, "storage_size_bytes")
+        assert hasattr(server_call, "total_storage_size_bytes")
+
+        # Verify that storage size values match between server and client calls
+        if server_call.id and server_call.id in client_calls_by_id:
+            client_call = client_calls_by_id[server_call.id]
+            assert server_call.storage_size_bytes == client_call.storage_size_bytes
+            assert (
+                server_call.total_storage_size_bytes
+                == client_call.total_storage_size_bytes
+            )
+            assert server_call.storage_size_bytes is not None
+
+            # total_storage_size_bytes is only set for root calls (parent_id is None)
+            # For child calls, it is intentionally None
+            expect_total_storage_size_bytes = server_call.parent_id is None
+            assert expect_total_storage_size_bytes == (
+                server_call.total_storage_size_bytes is not None
+            )
+
+
 def test_ref_in_dict(client):
     ref = client._save_object({"a": 5}, "d1")
 
diff --git a/weave/trace/call.py b/weave/trace/call.py
index cf5b91fb09ff..3f26f2d683f9 100644
--- a/weave/trace/call.py
+++ b/weave/trace/call.py
@@ -79,6 +79,12 @@ class Call:
     _children: list[Call] = dataclasses.field(default_factory=list)
     _feedback: RefFeedbackQuery | None = None
 
+    # Size of metadata storage for this call
+    storage_size_bytes: int | None = None
+
+    # Total size of metadata storage for the entire trace
+    total_storage_size_bytes: int | None = None
+
     @property
     def display_name(self) -> str | Callable[[Call], str] | None:
         return self._display_name
@@ -332,6 +338,8 @@ def _make_calls_iterator(
     query: Query | None = None,
     include_costs: bool = False,
     include_feedback: bool = False,
+    include_storage_size: bool = False,
+    include_total_storage_size: bool = False,
     columns: list[str] | None = None,
     expand_columns: list[str] | None = None,
     return_expanded_column_values: bool = True,
@@ -353,6 +361,8 @@ def fetch_func(offset: int, limit: int) -> list[CallSchema]:
                     limit=limit,
                     include_costs=include_costs,
                     include_feedback=include_feedback,
+                    include_storage_size=include_storage_size,
+                    include_total_storage_size=include_total_storage_size,
                     query=query,
                     sort_by=sort_by,
                     columns=columns,
@@ -422,6 +432,8 @@ def make_client_call(
         wb_run_id=server_call.wb_run_id,
         wb_run_step=server_call.wb_run_step,
         wb_run_step_end=server_call.wb_run_step_end,
+        storage_size_bytes=server_call.storage_size_bytes,
+        total_storage_size_bytes=server_call.total_storage_size_bytes,
     )
     if isinstance(call.attributes, AttributesDict):
         call.attributes.freeze()
diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py
index 434152f82a13..abc8126c2a14 100644
--- a/weave/trace/weave_client.py
+++ b/weave/trace/weave_client.py
@@ -529,6 +529,8 @@ def get_calls(
         query: QueryLike | None = None,
         include_costs: bool = False,
         include_feedback: bool = False,
+        include_storage_size: bool = False,
+        include_total_storage_size: bool = False,
         columns: list[str] | None = None,
         expand_columns: list[str] | None = None,
         return_expanded_column_values: bool = True,
@@ -551,6 +553,8 @@ def get_calls(
             `query`: A mongo-like expression for advanced filtering. Not all Mongo operators are supported.
             `include_costs`: If True, includes token/cost info in `summary.weave`.
             `include_feedback`: If True, includes feedback in `summary.weave.feedback`.
+            `include_storage_size`: If True, includes the storage size for a call.
+            `include_total_storage_size`: If True, includes the total storage size for a trace.
             `columns`: List of fields to return per call. Reducing this can significantly improve performance.
                         (Some fields like `id`, `trace_id`, `op_name`, and `started_at` are always included.)
             `scored_by`: Filter by one or more scorers (name or ref URI). Multiple scorers are AND-ed.
@@ -585,6 +589,8 @@ def get_calls(
             query=query,
             include_costs=include_costs,
             include_feedback=include_feedback,
+            include_storage_size=include_storage_size,
+            include_total_storage_size=include_total_storage_size,
             columns=columns,
             expand_columns=expand_columns,
             return_expanded_column_values=return_expanded_column_values,