@@ -216,7 +216,6 @@ def _compose_tool_text(self, tool: BaseTool) -> str:
216216 parts_to_include = self ._settings ["indexed_tool_def_parts" ]
217217 if not parts_to_include :
218218 raise ValueError ("indexed_tool_def_parts must be a non-empty list" )
219-
220219 segments = []
221220 for p in parts_to_include :
222221 if p .lower () == "name" :
@@ -236,18 +235,15 @@ def _compose_tool_text(self, tool: BaseTool) -> str:
236235 if tags :
237236 segments .append (f"tags: { ' ' .join (tags )} " )
238237 elif p .lower () == "additional_queries" :
239- # Append example queries supplied via settings["additional_queries"][tool.name]
240238 examples_map = self ._settings .get ("additional_queries" ) or {}
241239 examples_list = examples_map .get (tool .name ) or []
242240 if examples_list :
243241 rendered = self ._render_examples (examples_list )
244242 if rendered :
245243 segments .append (f"ex: { rendered } " )
246-
247244 if not segments :
248245 raise ValueError (f"The following tool contains none of the fields listed in indexed_tool_def_parts:\n { tool } " )
249246 text = " | " .join (segments )
250-
251247 # one-pass preprocess + truncation
252248 text = self ._preprocess_text (text )
253249 text = self ._truncate (text )
@@ -260,7 +256,31 @@ def _create_docs_from_tools(self, tools: List[BaseTool]) -> List[Document]:
260256 documents .append (Document (page_content = page_content , metadata = {"name" : tool .name }))
261257 return documents
262258
263- def _index_tools (self , tools : List [BaseTool ], queries : List [QuerySpecification ]) -> None :
259+ def _collect_examples_from_tool_specs (self , tool_specs : Dict [str , Dict [str , Any ]]) -> Dict [str , List [str ]]:
260+ """
261+ Build {tool_name: [example1, example2, ...]} from a tools dict where each
262+ value may contain an 'additional_queries' dict mapping query keys to strings.
263+ """
264+ examples : Dict [str , List [str ]] = {}
265+ for tool_name , spec in (tool_specs or {}).items ():
266+ if not isinstance (spec , dict ):
267+ continue
268+ aq = spec .get ("additional_queries" )
269+ if isinstance (aq , dict ):
270+ for _ , qtext in aq .items ():
271+ if isinstance (qtext , str ) and qtext .strip ():
272+ examples .setdefault (tool_name , []).append (qtext .strip ())
273+ # de-duplicate while preserving order
274+ for k , v in list (examples .items ()):
275+ seen , out = set (), []
276+ for s in v :
277+ if s not in seen :
278+ seen .add (s )
279+ out .append (s )
280+ examples [k ] = out
281+ return examples
282+
283+ def _index_tools (self , tools : List [BaseTool ]) -> None :
264284 self .tool_name_to_base_tool = {tool .name : tool for tool in tools }
265285
266286 self .embeddings = HuggingFaceEmbeddings (model_name = self ._settings ["embedding_model_id" ])
@@ -319,7 +339,7 @@ def _index_tools(self, tools: List[BaseTool], queries: List[QuerySpecification])
319339 search_params = search_params ,
320340 )
321341
322- def set_up (self , model : BaseChatModel , tools : List [BaseTool ], queries : List [ QuerySpecification ] ) -> None :
342+ def set_up (self , model : BaseChatModel , tools : List [BaseTool ], tool_specs : Any ) -> None :
323343 super ().set_up (model , tools )
324344
325345 if self ._settings ["cross_encoder_model_name" ]:
@@ -331,34 +351,15 @@ def set_up(self, model: BaseChatModel, tools: List[BaseTool], queries: List[Quer
331351 if self ._settings ["enable_query_decomposition" ] or self ._settings ["enable_query_rewriting" ]:
332352 self .query_rewriting_model = self ._get_llm (self ._settings ["query_rewriting_model_id" ])
333353
334- # Build additional_queries mapping from provided QuerySpecifications so YAML is not required.
354+ # Build additional_queries mapping from provided specs (accept dict of tool specs or list of QuerySpecifications)
335355 try :
336- tool_examples : Dict [str , List [str ]] = {}
337- for spec in (queries or []):
338- add_q = getattr (spec , "additional_queries" , None ) or {}
339- # Flatten wrapper {"additional_queries": {...}} if present
340- if isinstance (add_q , dict ) and "additional_queries" in add_q and len (add_q ) == 1 :
341- add_q = add_q ["additional_queries" ]
342- for tool_name , qmap in add_q .items ():
343- if isinstance (qmap , dict ):
344- for _ , qtext in qmap .items ():
345- if isinstance (qtext , str ) and qtext .strip ():
346- tool_examples .setdefault (tool_name , []).append (qtext .strip ())
347- # Dedupe while preserving order
348- for k , v in list (tool_examples .items ()):
349- seen = set ()
350- deduped = []
351- for s in v :
352- if s not in seen :
353- seen .add (s )
354- deduped .append (s )
355- tool_examples [k ] = deduped
356- if tool_examples :
357- self ._settings ["additional_queries" ] = tool_examples
356+ examples_map : Dict [str , List [str ]] = {}
357+ if isinstance (tool_specs , dict ):
358+ examples_map = self ._collect_examples_from_tool_specs (tool_specs )
359+ self ._settings ["additional_queries" ] = examples_map
358360 except Exception :
359361 pass
360-
361- self ._index_tools (tools , queries )
362+ self ._index_tools (tools )
362363
363364 def _threshold_results (self , docs_and_scores : List [Tuple [Document , float ]]) -> List [Document ]:
364365 """
@@ -619,4 +620,4 @@ def _dedup_keep_order(xs: List[str]) -> List[str]:
619620
620621 @staticmethod
621622 def _strip_numbering (s : str ) -> str :
622- return re .sub (r"^\s*(?:[-*]|\d+[).:]?)\s*" , "" , s ).strip ().rstrip ("." )
623+ return re .sub (r"^\s*(?:[-*]|\d+[).:]?)\s*" , "" , s ).strip ().rstrip ("." )
0 commit comments