Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions libs/community/langchain_community/retrievers/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,19 @@ def from_texts(
A BM25Retriever instance.
"""
try:
from rank_bm25 import BM25Okapi
from bm25s import BM25
except ImportError:
raise ImportError(
"Could not import rank_bm25, please install with `pip install "
"rank_bm25`."
"Could not import bm25s, please install with `pip install "
"bm25s`."
)

texts_processed = [preprocess_func(t) for t in texts]
bm25_params = bm25_params or {}
vectorizer = BM25Okapi(texts_processed, **bm25_params)
method = bm25_params.pop("method", "atire")
idf_method = bm25_params.pop("idf_method", "lucene")
vectorizer = BM25(method=method, idf_method=idf_method, **bm25_params)
vectorizer.index(texts_processed)
metadatas = metadatas or ({} for _ in texts)
if ids:
docs = [
Expand Down Expand Up @@ -109,8 +112,23 @@ def from_documents(
)

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
self, query: str, *, run_manager: CallbackManagerForRetrieverRun, **kwargs: Any
) -> List[Document]:
# Cache self.vectorizer.retrieve() parameters
if not hasattr(self, "_retrieve_params"):
import inspect
self._retrieve_params = inspect.signature(self.vectorizer.retrieve).parameters
retrieve_params = self._retrieve_params
retrieve_kwargs = {
k: v for k, v in kwargs.items() if k in retrieve_params
}

processed_query = self.preprocess_func(query)
return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
return return_docs
results = self.vectorizer.retrieve(
query_tokens=[processed_query],
corpus=self.docs,
k=self.k,
return_as="documents",
**retrieve_kwargs # for params like 'backend_selection', 'n_threads' ...
) # np.ndarray of shape (num_queries, k), dtype=object; each entry is a Document
return list(results[0])