1+ import asyncio
2+ from dataclasses import dataclass
3+ from tenacity import retry , stop_after_attempt , wait_exponential , retry_if_exception_type
4+ from elastic_transport import TransportError
5+ from typing import List
16from elasticsearch import AsyncElasticsearch
27from ..config .settings import es_settings
38from ..exceptions import ToolException
813
914
1015class AsyncElasticClient :
    @dataclass
    class SearchResponse:
        """Result envelope for a news search: hit documents plus the total hit count."""

        # The `_source` dicts extracted from the Elasticsearch hits.
        data: List[dict]
        # Total number of matching documents, as reported by ES `hits.total.value`.
        total: int = 0
20+
1121 def __init__ (self ):
1222 # 初始化异步 ElasticSearch 客户端
1323 self ._client = AsyncElasticsearch (es_settings .URL ,
1424 api_key = es_settings .api_key , verify_certs = False )
1525 self .index = es_settings .ES_INDEX
1626
27+ @retry (
28+ reraise = True ,
29+ stop = stop_after_attempt (3 ),
30+ wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
31+ retry = (
32+ retry_if_exception_type (TransportError ) |
33+ retry_if_exception_type (asyncio .TimeoutError )
34+ ),
35+ )
1736 async def search_news (self , query : str , source : str = None , date_from : str = None , date_to : str = None , max_results : int = 10 ) -> list :
1837 """
1938 ElasticSearch 异步搜索新闻
@@ -112,5 +131,82 @@ async def get_by_id(self, news_id: str) -> dict:
112131 except Exception :
113132 raise ToolException (f'Tool call exception with news_id { news_id } ' )
114133
134+ def _append_common_filters (self , must : list , search_word : str , date_from : str , date_to : str ):
135+ """提炼公共过滤器: 添加 search_word 和时间范围到 must 列表"""
136+ if search_word :
137+ must .append ({
138+ 'multi_match' : {
139+ 'query' : search_word ,
140+ 'fields' : ['title^5' , 'content' ],
141+ 'operator' : 'and'
142+ }
143+ })
144+ if date_from or date_to :
145+ range_filter = {}
146+ if date_from :
147+ range_filter ['gte' ] = date_from
148+ if date_to :
149+ range_filter ['lte' ] = date_to
150+ must .append ({'range' : {'release_time' : range_filter }})
151+
152+ def _add_clauses (self , should_clauses : list , base_filters : list , secondary_queries : list [str ], search_word : str , date_from : str , date_to : str ):
153+ """根据 base_filters 和 secondary_queries 构建子句并添加到 should_clauses"""
154+ if secondary_queries :
155+ for sec in secondary_queries :
156+ must = base_filters + [{'match_phrase' : {'title' : sec }}]
157+ self ._append_common_filters (must , search_word , date_from , date_to )
158+ should_clauses .append ({'bool' : {'must' : must }})
159+ else :
160+ must = base_filters .copy ()
161+ self ._append_common_filters (must , search_word , date_from , date_to )
162+ should_clauses .append ({'bool' : {'must' : must }})
163+
164+ @retry (
165+ reraise = True ,
166+ stop = stop_after_attempt (3 ),
167+ wait = wait_exponential (multiplier = 1 , min = 1 , max = 10 ),
168+ retry = (
169+ retry_if_exception_type (TransportError ) |
170+ retry_if_exception_type (asyncio .TimeoutError )
171+ ),
172+ )
173+ async def search_topic_news (
174+ self ,
175+ primary_queries : List [str ],
176+ secondary_query : List [str ]= None ,
177+ max_results : int = 10 ,
178+ sources : List [str ] = None ,
179+ search_word = None ,
180+ date_from : str = None ,
181+ date_to : str = None
182+ ) -> SearchResponse :
183+ """
184+ "根据多个标签列表、筛选词列表(组)、数据源列表以 OR 关系批量查询新闻,支持时间范围筛选. "
185+ "基本查询逻辑:<label1>&<filtered_words>|<label2>&<filtered_words>|<source1>&<filtered_words>|...|"
186+ "允许在基本查询逻辑之上再搜索"
187+ """
188+ limit = min (max_results , es_settings .MAX_RESULTS_LIMIT )
189+ secondary_queries = secondary_query or []
190+ should_clauses = []
191+ for primary in primary_queries or []:
192+ self ._add_clauses (should_clauses , [{'match_phrase' : {'title' : primary }}], secondary_queries , search_word , date_from , date_to )
193+ for source in sources or []:
194+ self ._add_clauses (should_clauses , [{'term' : {'source.keyword' : source }}], secondary_queries , search_word , date_from , date_to )
195+ body = {'query' : {'bool' : {'should' : should_clauses }}}
196+ # 按发布日期降序排序
197+ body ['sort' ] = [{'release_time' : {'order' : 'desc' }}]
198+
199+ response = await self ._client .search (
200+ index = self .index ,
201+ body = body ,
202+ size = limit ,
203+ source_includes = OUTPUT_SOURCE_FIELDS
204+ )
205+ raw_hits = response .get ('hits' , {})
206+ hits = raw_hits .get ('hits' , [])
207+ total = raw_hits .get ("total" , {}).get ("value" , 0 )
208+ return self .SearchResponse (data = [hit .get ('_source' , {}) for hit in hits ], total = total )
209+
210+
    async def close(self):
        """Close the underlying AsyncElasticsearch client and release its connections."""
        await self._client.close()