55from pydantic import BaseModel , Field , InstanceOf
66from sentence_transformers import CrossEncoder
77
8+ from memmachine .common .utils import chunk_text , unflatten_like
9+
810from .reranker import Reranker
911
1012
@@ -15,6 +17,10 @@ class CrossEncoderRerankerParams(BaseModel):
1517 ...,
1618 description = "The cross-encoder model to use for reranking" ,
1719 )
20+ max_input_length : int | None = Field (
21+ default = None ,
22+ description = "Maximum input length for the model (in Unicode code points)" ,
23+ )
1824
1925
2026class CrossEncoderReranker (Reranker ):
@@ -25,15 +31,38 @@ def __init__(self, params: CrossEncoderRerankerParams) -> None:
2531 super ().__init__ ()
2632
2733 self ._cross_encoder = params .cross_encoder
34+ self ._max_input_length = params .max_input_length
2835
async def score(self, query: str, candidates: list[str]) -> list[float]:
    """Score candidates for a query using the cross-encoder.

    When a maximum input length is configured, the query is cut to that
    many characters and every candidate is split with ``chunk_text``;
    each chunk is scored against the query and a candidate's score is
    the highest score among its chunks. Without a configured limit each
    candidate is scored as a single piece.

    Args:
        query: The query text.
        candidates: The candidate texts to score.

    Returns:
        One score per candidate, in the same order as ``candidates``.
        An empty ``candidates`` list yields an empty list.
    """
    # Nothing to score: return early rather than handing the model an
    # empty batch, which predict() may not accept.
    if not candidates:
        return []

    limit = self._max_input_length
    if limit:
        query = query[:limit]
        # Fall back to the unsplit candidate if chunking yields nothing
        # (presumably possible for an empty string), so that every
        # candidate keeps at least one score for max() below.
        candidates_chunks = [
            chunk_text(candidate, limit) or [candidate]
            for candidate in candidates
        ]
    else:
        candidates_chunks = [[candidate] for candidate in candidates]

    # Flatten so that all chunks are scored in a single predict() call.
    chunks = [
        chunk
        for candidate_chunks in candidates_chunks
        for chunk in candidate_chunks
    ]

    # predict() is a blocking call; run it in a worker thread.
    raw_scores = await asyncio.to_thread(
        self._cross_encoder.predict,
        [(query, chunk) for chunk in chunks],
        show_progress_bar=False,
    )
    chunk_scores = [float(raw_score) for raw_score in raw_scores]

    # Regroup the flat scores to mirror the per-candidate chunk lists.
    candidates_chunk_scores = unflatten_like(chunk_scores, candidates_chunks)

    # Take the maximum score among chunks for each candidate.
    return [
        max(candidate_chunk_scores)
        for candidate_chunk_scores in candidates_chunk_scores
    ]
0 commit comments