Skip to content

Commit 263a816

Browse files
committed
Query Embedding for Tool Rag
1 parent df5d40a commit 263a816

File tree

4 files changed

+132
-14
lines changed

4 files changed

+132
-14
lines changed

evaluator/algorithms/tool_rag_algorithm.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,33 @@ async def process_query(self, query_spec: QuerySpecification) -> AlgoResponse:
549549
agent = create_react_agent(self._model, relevant_tools)
550550
return await self._invoke_agent_on_query(agent, query_spec.query), relevant_tool_names
551551

552+
def embed_additional_queries(self, query_specs: List[QuerySpecification]):
    """
    Embed all additional_queries found in the provided QuerySpecification instances.

    Args:
        query_specs: Specifications whose ``additional_queries`` attribute, when
            present, maps tool names to ``{query_key: query_text}`` dicts,
            optionally wrapped in a single-key ``{"additional_queries": ...}``
            envelope.

    Returns:
        A dict of the form ``{query_id: {tool_name: [(queryN, embedding), ...]}}``.
        Specs without additional queries are omitted from the result.

    Raises:
        RuntimeError: If the embeddings backend has not been initialized.
    """
    if getattr(self, "embeddings", None) is None:
        raise RuntimeError("Embeddings must be initialized before calling this method.")
    results = {}
    for spec in query_specs:
        add_queries = getattr(spec, "additional_queries", None)
        if not add_queries:
            continue
        # Flatten if wrapped in {"additional_queries": ...}
        if isinstance(add_queries, dict) and len(add_queries) == 1 and "additional_queries" in add_queries:
            add_queries = add_queries["additional_queries"]
        # The payload originates from unvalidated LLM JSON; a malformed
        # (non-dict) value would otherwise crash on .items().
        if not isinstance(add_queries, dict):
            continue
        spec_results = {}
        for tool_name, queries in add_queries.items():
            if not isinstance(queries, dict):
                # Same defensive skip at the per-tool level.
                continue
            emb_list = [
                (key, self.embeddings.embed_query(text))
                for key, text in queries.items()
                if isinstance(text, str)  # ignore non-string entries
            ]
            spec_results[tool_name] = emb_list
        results[spec.id] = spec_results
    return results
578+
552579
def tear_down(self) -> None:
    """Tear down the algorithm: disconnect from the Milvus vector store."""
    # Closes the alias-scoped connection presumably opened during set_up —
    # NOTE(review): confirm no other component still uses this alias.
    connections.disconnect(alias=MILVUS_CONNECTION_ALIAS)
554581

evaluator/components/data_provider.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class QuerySpecification(BaseModel):
2727
"""
2828
id: int
2929
query: str
30+
additional_queries: Optional[Dict[str, Any]] = None
31+
path: Optional[str] = None
3032
reference_answer: Optional[str] = None
3133
golden_tools: ToolSet = Field(default_factory=dict)
3234
additional_tools: Optional[ToolSet] = None
@@ -309,7 +311,7 @@ def _load_queries_from_single_file(
309311
root_dataset_path: str or Path,
310312
experiment_environment: EnvironmentConfig,
311313
dataset_config: DatasetConfig,
312-
) -> List[QuerySpecification]:
314+
) -> Tuple[List[QuerySpecification], List[Dict[str, Any]]]:
313315
with open(query_file_path, 'r') as f:
314316
data = json.load(f)
315317

@@ -328,6 +330,13 @@ def _load_queries_from_single_file(
328330
log(f"Invalid query spec, skipping this query.")
329331
else:
330332
query = raw_query_spec.get("query")
333+
if raw_query_spec.get("additional_queries"):
334+
additional_queries = raw_query_spec.get("additional_queries")
335+
print(f"Additional queries provided: {additional_queries}")
336+
337+
else:
338+
print(f"No additional queries provided")
339+
additional_queries = None
331340
query_id = int(raw_query_spec.get("query_id"))
332341
golden_tools, additional_tools = (
333342
_parse_raw_query_tool_definitions(raw_query_spec, experiment_environment, dataset_config))
@@ -341,6 +350,8 @@ def _load_queries_from_single_file(
341350
QuerySpecification(
342351
id=query_id,
343352
query=query,
353+
path=str(query_file_path),
354+
additional_queries=additional_queries,
344355
reference_answer=reference_answer,
345356
golden_tools=golden_tools,
346357
additional_tools=additional_tools or None
@@ -358,7 +369,7 @@ def get_queries(
358369
experiment_environment: EnvironmentConfig,
359370
dataset_config: DatasetConfig,
360371
fine_tuning_mode=False
361-
) -> List[QuerySpecification]:
372+
) -> Tuple[List[QuerySpecification], List[Dict[str, Any]]]:
362373
"""Load queries from the dataset."""
363374
root_dataset_path = Path(os.getenv("ROOT_DATASET_PATH"))
364375
if not root_dataset_path:
@@ -375,14 +386,14 @@ def get_queries(
375386
queries_num = None if fine_tuning_mode else dataset_config.queries_num
376387
queries = []
377388
for path in local_paths:
389+
print(f"\n\n")
390+
print(f"--------------------------------")
391+
print(f"Loading queries from file: {path}")
392+
print(f"\n\n")
378393
remaining_queries_num = None if queries_num is None else queries_num - len(queries)
379394
if remaining_queries_num == 0:
380395
break
381-
new_queries = _load_queries_from_single_file(path,
382-
remaining_queries_num,
383-
root_dataset_path,
384-
experiment_environment,
385-
dataset_config)
396+
new_queries= _load_queries_from_single_file(path, remaining_queries_num, root_dataset_path, experiment_environment, dataset_config)
386397
queries.extend(new_queries)
387398

388399
return queries

evaluator/evaluator.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import traceback
55
from typing import List, Tuple
6-
6+
from pathlib import Path
77
import openai
88
from langgraph.errors import GraphRecursionError
99
from pydantic import ValidationError
@@ -17,6 +17,8 @@
1717
from evaluator.interfaces.algorithm import Algorithm
1818
from evaluator.utils.csv_logger import CSVLogger
1919
from evaluator.components.llm_provider import get_llm
20+
from evaluator.utils.parsing_tools import generate_and_save_additional_queries
21+
import json as _json
2022
from dotenv import load_dotenv
2123

2224
from evaluator.utils.tool_logger import ToolLogger
@@ -35,13 +37,13 @@ class Evaluator(object):
3537

3638
config: EvaluationConfig
3739

def __init__(self, config_path: str | None, use_defaults: bool, test_with_additional_queries: bool = False):
    """
    Create an Evaluator from a configuration file.

    Args:
        config_path: Path to the evaluation config, or None to rely on defaults.
        use_defaults: Fall back to default config options when True.
        test_with_additional_queries: Enable the additional-queries test flow.

    Raises:
        SystemExit: With status 2 when the configuration cannot be loaded.
    """
    try:
        self.config = load_config(config_path, use_defaults=use_defaults)
    except ConfigError as ce:
        log(f"Configuration error: {ce}")
        # Chain the original error so the traceback keeps the root cause.
        raise SystemExit(2) from ce
    self.test_with_additional_queries = test_with_additional_queries
4547
async def run(self) -> None:
4648

4749
# Set up the necessary components for the experiments:
@@ -112,15 +114,13 @@ async def _run_experiment(self,
112114
Runs the specified experiment and returns the number of evaluated queries.
113115
"""
114116
processed_queries_num = 0
115-
116117
try:
117118
queries = await self._set_up_experiment(spec, metric_collectors, mcp_proxy_manager)
118119
algorithm, environment = spec
119120

120121
try:
121122
for i, query_spec in enumerate(queries):
122123
log(f"Processing query #{query_spec.id} (Experiment {exp_index} of {total_exp_num}, query {i+1} of {len(queries)})...")
123-
124124
for mc in metric_collectors:
125125
mc.prepare_for_measurement(query_spec)
126126

@@ -199,12 +199,12 @@ async def _set_up_experiment(self,
199199
log(f"Initializing LLM connection: {environment.model_id}")
200200
llm = get_llm(model_id=environment.model_id, model_config=self.config.models)
201201
log("Connection established successfully.\n")
202-
203202
log("Fetching queries for the current experiment...")
204203
queries = get_queries(environment, self.config.data)
205204
log(f"Successfully loaded {len(queries)} queries.\n")
206205
print_iterable_verbose("The following queries will be executed:\n", queries)
207-
206+
log(f"Generating additional queries queries.\n")
207+
generate_and_save_additional_queries(llm, queries)
208208
log("Retrieving tool definitions for the current experiment...")
209209
tool_specs = get_tools_from_queries(queries)
210210
tools = await mcp_proxy_manager.run_mcp_proxy(tool_specs, init_client=True).get_tools()
@@ -213,8 +213,40 @@ async def _set_up_experiment(self,
213213

214214
log("Setting up the algorithm and the metric collectors...")
215215
algorithm.set_up(llm, tools)
216+
217+
if algorithm.__module__ == "evaluator.algorithms.tool_rag_algorithm":
218+
log("Embedding additional queries...")
219+
additional_query_embeddings = algorithm.embed_additional_queries(queries)
220+
print("Additional query embedding counts per query/tool:",
221+
{k: {tk: len(tv) for tk, tv in v.items()} for k, v in additional_query_embeddings.items()})
222+
else:
223+
print("No additional queries to embed.")
216224
for mc in metric_collectors:
217225
mc.set_up()
218226
log("All set!\n")
219227

220228
return queries
229+
230+
if __name__ == "__main__":
    # Command-line entry point for running the evaluator directly.
    import argparse
    import asyncio

    from evaluator.utils.utils import log

    cli = argparse.ArgumentParser(description="Run the Evaluator experiments.")
    cli.add_argument("--config", type=str, default=None, help="Path to evaluation config YAML file")
    cli.add_argument("--defaults", action="store_true", help="Use default config options if set")
    cli.add_argument("--test-with-additional-queries", action="store_true", help="Test with additional queries")
    cli_args = cli.parse_args()

    log("Starting Evaluator main...")
    runner = Evaluator(
        config_path=cli_args.config,
        use_defaults=cli_args.defaults,
        test_with_additional_queries=cli_args.test_with_additional_queries,
    )
    try:
        asyncio.run(runner.run())
        log("Evaluator finished successfully!")
    except Exception as e:
        # Log the failure for the run record, then let it propagate.
        log(f"Evaluator failed: {e}")
        raise

evaluator/utils/parsing_tools.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from evaluator.components.llm_provider import query_llm
2+
from pathlib import Path
3+
import re
4+
import json
5+
from evaluator.utils.utils import print_iterable_verbose, log
6+
7+
def generate_and_save_additional_queries(llm, queries):
    """
    For each query spec lacking ``additional_queries``, ask the LLM to generate
    five extra queries per golden tool, attach the parsed result to the spec,
    and persist it back into the spec's source JSON file (matched by query_id).

    Args:
        llm: LLM handle accepted by ``query_llm``.
        queries: Iterable of QuerySpecification-like objects exposing ``id``,
            ``query``, ``path``, ``golden_tools`` and ``additional_queries``.
    """
    system_prompt = '''You create 5 additional queries for each tool and only return the additional queries information, given the query implemented, return in the following format as a JSON string:
{tool_name: {"query1": "", "query2": "", "query3": "", "query4": "", "query5": ""}} '''
    for query_spec in queries:
        # If additional_queries already present, skip generating and saving.
        if getattr(query_spec, 'additional_queries', None):
            log(f"Skipping query_id {getattr(query_spec, 'id', '<N/A>')} because additional_queries is present.")
            continue
        user_prompt = f"tool_name = {getattr(query_spec, 'golden_tools', {}).keys()}, Query= {getattr(query_spec, 'query', None)}"
        result = query_llm(llm, system_prompt, user_prompt)
        additional = _parse_additional_queries(result)
        query_spec.additional_queries = additional
        # Guard on path: Path(None) would raise TypeError for specs loaded
        # without a source file.
        if additional is not None and getattr(query_spec, 'path', None):
            _save_additional_queries(Path(query_spec.path), query_spec.id, additional)


def _parse_additional_queries(result):
    """Extract the JSON payload from a raw LLM response; return None if unparsable."""
    # Drop any chain-of-thought prefix emitted by reasoning models.
    match = re.search(r"</think>\s*(.*)", result, re.DOTALL)
    text = (match.group(1) if match else result).strip()
    # Remove markdown code fences (```json ... ```) if the model added them.
    fence = re.match(r"^```[a-zA-Z0-9_-]*\s*(.*?)\s*```$", text, re.DOTALL)
    if fence:
        text = fence.group(1).strip()
    try:
        return json.loads(text)
    except ValueError:
        # json.JSONDecodeError subclasses ValueError; malformed LLM output
        # is expected occasionally and handled by returning None.
        return None


def _save_additional_queries(path, query_id, additional):
    """Write *additional* into every entry matching *query_id* in the JSON file at *path*."""
    if not path.exists():
        return
    with path.open('r', encoding='utf-8') as f:
        orig_queries = json.load(f)
    for item in orig_queries:
        # query_id may be stored as int or str depending on the dataset file.
        if str(item.get("query_id")) == str(query_id):
            item["additional_queries"] = additional
    with path.open('w', encoding='utf-8') as f:
        json.dump(orig_queries, f, indent=2, ensure_ascii=False)
    log(f"Successfully added additional queries to original file {path}")

0 commit comments

Comments
 (0)