Skip to content

Commit eec97d4

Browse files
committed
Implement cross-encoder reranker input length limit
Signed-off-by: Edwin Yu <[email protected]>
1 parent 5822ba4 commit eec97d4

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

src/memmachine/common/reranker/cross_encoder_reranker.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from pydantic import BaseModel, Field, InstanceOf
66
from sentence_transformers import CrossEncoder
77

8+
from memmachine.common.utils import chunk_text, unflatten_like
9+
810
from .reranker import Reranker
911

1012

@@ -15,6 +17,10 @@ class CrossEncoderRerankerParams(BaseModel):
1517
...,
1618
description="The cross-encoder model to use for reranking",
1719
)
20+
max_input_length: int | None = Field(
21+
default=None,
22+
description="Maximum input length for the model (in Unicode code points)",
23+
)
1824

1925

2026
class CrossEncoderReranker(Reranker):
@@ -25,15 +31,38 @@ def __init__(self, params: CrossEncoderRerankerParams) -> None:
2531
super().__init__()
2632

2733
self._cross_encoder = params.cross_encoder
34+
self._max_input_length = params.max_input_length
2835

2936
async def score(self, query: str, candidates: list[str]) -> list[float]:
    """
    Score candidates for a query using the cross-encoder.

    When a maximum input length is configured, the query is truncated
    to that length and every candidate is split into chunks of at most
    that length. Each (query, chunk) pair is scored, and a candidate's
    score is the maximum score among its chunks. Without a configured
    limit, each candidate is scored as a single chunk.

    Args:
        query (str): The query to score candidates against.
        candidates (list[str]): The candidate texts to score.

    Returns:
        list[float]: One score per candidate, in input order.

    """
    if not candidates:
        # Nothing to score; skip the model call entirely.
        return []

    max_length = self._max_input_length
    if max_length:
        query = query[:max_length]

    # chunk_text yields no chunks for an empty candidate, which would make
    # max() fail below. Fall back to scoring the candidate itself so every
    # candidate contributes at least one chunk.
    candidates_chunks = [
        (chunk_text(candidate, max_length) or [candidate])
        if max_length
        else [candidate]
        for candidate in candidates
    ]

    # Flatten so the model scores all chunks in a single predict call.
    chunks = [
        chunk
        for candidate_chunks in candidates_chunks
        for chunk in candidate_chunks
    ]

    # predict is a blocking call; run it in a worker thread so the event
    # loop is not stalled.
    chunk_scores = [
        float(score)
        for score in await asyncio.to_thread(
            self._cross_encoder.predict,
            [(query, chunk) for chunk in chunks],
            show_progress_bar=False,
        )
    ]

    # Regroup the flat scores per candidate (unflatten_like is expected to
    # mirror the nesting of candidates_chunks).
    candidates_chunk_scores = unflatten_like(chunk_scores, candidates_chunks)

    # Take the maximum score among chunks for each candidate.
    return [
        max(candidate_chunk_scores)
        for candidate_chunk_scores in candidates_chunk_scores
    ]

src/memmachine/common/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,21 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
4949
return wrapper
5050

5151

52+
def chunk_text(text: str, max_length: int) -> list[str]:
    """
    Chunk text into consecutive partitions not exceeding max_length.

    Every chunk except possibly the last has exactly max_length
    characters; the last chunk holds the remainder. Concatenating the
    chunks in order reproduces the input text.

    Args:
        text (str): The input text to chunk.
        max_length (int): The maximum length of each chunk,
            counted in string indices. Must be positive.

    Returns:
        list[str]: A list of text chunks. Empty if text is empty.

    Raises:
        ValueError: If max_length is not positive.

    """
    # A negative step would silently produce no chunks (dropping the text)
    # and a zero step fails inside range() with an unrelated message, so
    # reject both up front.
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")

    return [text[i : i + max_length] for i in range(0, len(text), max_length)]
65+
66+
5267
def chunk_text_balanced(text: str, max_length: int) -> list[str]:
5368
"""
5469
Chunk text into balanced partitions not exceeding max_length.

0 commit comments

Comments
 (0)