From 56f333277019d7dcb75b07afff898cb0b4649c56 Mon Sep 17 00:00:00 2001
From: Raphael Sourty
Date: Thu, 18 Apr 2024 00:41:51 +0200
Subject: [PATCH 1/3] init explorer

---
 neural_cherche/explore/__init__.py |   3 +
 neural_cherche/explore/bm25.py     | 334 +++++++++++++++++++++++++++++
 2 files changed, 337 insertions(+)
 create mode 100644 neural_cherche/explore/__init__.py
 create mode 100644 neural_cherche/explore/bm25.py

diff --git a/neural_cherche/explore/__init__.py b/neural_cherche/explore/__init__.py
new file mode 100644
index 0000000..d8db128
--- /dev/null
+++ b/neural_cherche/explore/__init__.py
@@ -0,0 +1,3 @@
+from .bm25 import BM25
+
+__all__ = ["BM25"]
diff --git a/neural_cherche/explore/bm25.py b/neural_cherche/explore/bm25.py
new file mode 100644
index 0000000..8d50ef8
--- /dev/null
+++ b/neural_cherche/explore/bm25.py
@@ -0,0 +1,334 @@
+import itertools
+
+import torch
+from scipy.sparse import csr_matrix
+from sklearn.feature_extraction.text import CountVectorizer
+
+from ..rank import ColBERT
+from ..retrieve import BM25 as BM25Retriever
+
+__all__ = ["BM25"]
+
+
+class BM25(BM25Retriever):
+    """BM25 explorer.
+
+    Parameters
+    ----------
+    key
+        Field identifier of each document.
+    on
+        Fields to use to match the query to the documents.
+    documents
+        Documents in the BM25 retriever are static. The retriever must be reset to
+        index new documents.
+    count_vectorizer
+        Sklearn CountVectorizer instance, to override the default term counter.
+    b
+        The impact of document length normalization. Default is `0.75`. Higher values
+        penalize longer documents more.
+    k1
+        How quickly the impact of term frequency saturates. Default is `1.5`. Higher
+        values make term frequency more influential.
+    epsilon
+        Smoothing term. Default is `0`.
+    fit
+        Fit the CountVectorizer on the documents. Default is `True`.
+
+    Examples
+    --------
+    >>> from neural_cherche import explore, models, rank
+    >>> from pprint import pprint
+
+    >>> documents = [
+    ...     {"id": 0, "document": "Food"},
+    ...     {"id": 1, "document": "Sports"},
+    ...     {"id": 2, "document": "Cinema"},
+    ... ]
+
+    >>> model = models.ColBERT(
+    ...     model_name_or_path="raphaelsty/neural-cherche-colbert",
+    ... )
+
+    >>> ranker = rank.ColBERT(
+    ...     model=model,
+    ...     key="id",
+    ...     on=["document"],
+    ... )
+
+    >>> explorer = explore.BM25(
+    ...     key="id",
+    ...     on=["document"],
+    ...     ranker=ranker,
+    ... )
+
+    >>> documents_embeddings = explorer.encode_documents(
+    ...     documents=documents
+    ... )
+
+    >>> explorer = explorer.add(
+    ...     documents_embeddings=documents_embeddings,
+    ... )
+
+    >>> queries = ["Food", "Sports", "Cinema food sports", "cinema"]
+    >>> queries_embeddings = explorer.encode_queries(
+    ...     queries=queries
+    ... )
+
+    >>> scores = explorer(
+    ...     queries_embeddings=queries_embeddings,
+    ...     documents_embeddings=documents_embeddings,
+    ...     k=10,
+    ...     ranker_batch_size=32,
+    ...     retriever_batch_size=2000,
+    ...     max_step=3,
+    ...     beam_size=3,
+    ... )
+
+    >>> pprint(scores)
+
+    """
+
+    def __init__(
+        self,
+        key: str,
+        on: list[str],
+        ranker: ColBERT,
+        count_vectorizer: CountVectorizer = None,
+        b: float = 0.75,
+        k1: float = 1.5,
+        epsilon: float = 0,
+        fit: bool = True,
+    ) -> None:
+        super().__init__(
+            key=key,
+            on=on,
+            count_vectorizer=count_vectorizer,
+            b=b,
+            k1=k1,
+            epsilon=epsilon,
+            fit=fit,
+        )
+
+        self.ranker = ranker
+        self.mapping_documents = {}
+
+    def encode_documents(
+        self,
+        documents: list[dict],
+        ranker_embeddings: bool = False,
+        batch_size: int = 32,
+        query_mode: bool = False,
+        tqdm_bar: bool = True,
+        **kwargs,
+    ) -> dict[str, csr_matrix]:
+        """Encode documents."""
+        embeddings = {
+            "retriever": super().encode_documents(
+                documents=documents,
+            ),
+            "ranker": {},
+        }
+
+        for document in documents:
+            self.mapping_documents[document[self.key]] = document
+
+        if ranker_embeddings:
+            embeddings["ranker"] = self.ranker.encode_documents(
+                documents=documents,
+                batch_size=batch_size,
+                query_mode=query_mode,
+                tqdm_bar=tqdm_bar,
+                **kwargs,
+            )
+
+        return embeddings
+
+    def encode_queries(
+        self,
+        queries: list[str],
+        batch_size: int = 32,
+        query_mode: bool = True,
+        tqdm_bar: bool = True,
+        warn_duplicates: bool = True,
+        **kwargs,
+    ) -> dict[str, csr_matrix]:
+        """Encode queries."""
+        return {
+            "retriever": super().encode_queries(
+                queries=queries, warn_duplicates=warn_duplicates
+            ),
+            "ranker": self.ranker.encode_queries(
+                queries=queries,
+                batch_size=batch_size,
+                query_mode=query_mode,
+                tqdm_bar=tqdm_bar,
+                **kwargs,
+            ),
+        }
+
+    def add(self, documents_embeddings: dict[str, dict[str, torch.Tensor]]) -> "BM25":
+        """Add new documents to the BM25 retriever."""
+        super().add(documents_embeddings=documents_embeddings["retriever"])
+
+        return self
+
+    def __call__(
+        self,
+        queries_embeddings: dict[str, dict[str, csr_matrix] | dict[str, torch.Tensor]],
+        documents_embeddings: dict[
+            str, dict[str, csr_matrix] | dict[str, torch.Tensor]
+        ],
+        k: int = 30,
+        beam_size: int = 3,
+        max_step: int = 3,
+        retriever_batch_size: int = 2000,
+        ranker_batch_size: int = 32,
+        early_stopping: bool = False,
+        tqdm_bar: bool = False,
+        queries: list[str] = None,
+        actual_step: int = 0,
+        scores: list[dict] = None,
+    ) -> list[list[dict]]:
+        """Explore the documents."""
+        scores = (
+            [{} for _ in queries_embeddings["retriever"]] if scores is None else scores
+        )
+
+        queries = (
+            queries
+            if queries is not None
+            else [[query] for query in list(queries_embeddings["retriever"].keys())]
+        )
+
+        retriever_queries_embeddings = super().encode_queries(
+            queries=list(
+                set([query for group_queries in queries for query in group_queries])
+            ),
+            warn_duplicates=False,
+        )
+
+        # Retrieve the top k documents
+        candidates = super().__call__(
+            queries_embeddings=retriever_queries_embeddings,
+            k=k,
+            batch_size=retriever_batch_size,
+            tqdm_bar=tqdm_bar,
+        )
+
+        # Start post-process candidates retriever.
+        mapping_position = {
+            query: position
+            for position, query in enumerate(
+                iterable=list(retriever_queries_embeddings.keys())
+            )
+        }
+
+        # Map candidates back to queries and avoid duplicates and avoid already scored documents
+        candidates = [
+            [
+                [
+                    document
+                    for document in candidates[mapping_position[query]]
+                    if document[self.key] not in query_scores
+                ]
+                if query in mapping_position
+                else []
+                for query in group_queries
+            ]
+            for group_queries, query_scores in zip(queries, scores)
+        ]
+
+        candidates = list(itertools.chain.from_iterable(candidates))
+
+        # Drop duplicates
+        distinct_candidates = []
+        for query_candidates in candidates:
+            distinct_candidates_query, duplicates = [], {}
+            for document in query_candidates:
+                if document[self.key] not in duplicates:
+                    distinct_candidates_query.append(document)
+                    duplicates[document[self.key]] = True
+            distinct_candidates.append(distinct_candidates_query)
+        candidates = distinct_candidates
+
+        print(candidates, queries)
+        # End post-process candidates retriever.
+
+        # Encoding documents
+        documents_to_encode, duplicates = [], {}
+        for query_documents in candidates:
+            for document in query_documents:
+                if (
+                    document[self.key] not in documents_embeddings["ranker"]
+                    and document[self.key] not in duplicates
+                ):
+                    documents_to_encode.append(
+                        self.mapping_documents[document[self.key]]
+                    )
+
+                    duplicates[document[self.key]] = True
+
+        if documents_to_encode:
+            documents_embeddings["ranker"].update(
+                self.ranker.encode_documents(
+                    documents=documents_to_encode,
+                    batch_size=ranker_batch_size,
+                    tqdm_bar=False,
+                )
+            )
+        # End encoding documents
+
+        # Rank the candidates and take the top k
+        candidates = self.ranker(
+            documents=candidates,
+            queries_embeddings=queries_embeddings["ranker"],
+            documents_embeddings=documents_embeddings["ranker"],
+            batch_size=ranker_batch_size,
+            tqdm_bar=tqdm_bar,
+        )
+
+        scores = [
+            {
+                **query_scores,
+                **{
+                    document[self.key]: document["similarity"]
+                    for document in query_documents
+                },
+            }
+            for query_scores, query_documents in zip(scores, candidates)
+        ]
+
+        if (actual_step - 1) > max_step:
+            return scores
+
+        # Add early stopping
+        # Take beam_size top candidates which are not in query_explored
+        # Create query explored
+        top_candidates = candidates
+
+        queries = [
+            [
+                f"{query} {' '.join([self.mapping_documents[document[self.key]][field] for field in self.on])}"
+                for document in query_documents
+            ][:beam_size]
+            for query, query_documents in zip(
+                list(queries_embeddings["retriever"].keys()), top_candidates
+            )
+        ]
+
+        print(len(scores), len(queries))
+
+        return self(
+            queries_embeddings=queries_embeddings,
+            documents_embeddings=documents_embeddings,
+            k=k,
+            beam_size=beam_size,
+            max_step=max_step,
+            retriever_batch_size=retriever_batch_size,
+            ranker_batch_size=ranker_batch_size,
+            tqdm_bar=tqdm_bar,
+            queries=queries,
+            actual_step=actual_step + 1,
+            scores=scores,
+        )

From 48d66c11088b2ab7706787bb635d368b0f6e2ee1 Mon Sep 17 00:00:00 2001
From: Raphael Sourty
Date: Sun, 21 Apr 2024 17:00:21 +0200
Subject: [PATCH 2/3] initialize explorer

---
 neural_cherche/__init__.py     |   2 +-
 neural_cherche/explore/bm25.py | 352 ++++++++++++++++++++++++---------
 neural_cherche/rank/colbert.py |  36 ++--
 3 files changed, 284 insertions(+), 106 deletions(-)

diff --git a/neural_cherche/__init__.py b/neural_cherche/__init__.py
index bf8a7c3..d7c1a24 100644
--- a/neural_cherche/__init__.py
+++ b/neural_cherche/__init__.py
@@ -1 +1 @@
-__all__ = ["losses", "models", "retrieve", "rank", "train", "utils"]
+__all__ = ["explore", "losses", "models", "retrieve", "rank", "train", "utils"] diff --git a/neural_cherche/explore/bm25.py b/neural_cherche/explore/bm25.py index 8d50ef8..3b3ba12 100644 --- a/neural_cherche/explore/bm25.py +++ b/neural_cherche/explore/bm25.py @@ -40,10 +40,10 @@ class BM25(BM25Retriever): >>> from neural_cherche import explore, models, rank >>> from pprint import pprint + >>> queries = ["Food", "Sports", "Cinema food sports", "cinema"] + >>> documents = [ - ... {"id": 0, "document": "Food"}, - ... {"id": 1, "document": "Sports"}, - ... {"id": 2, "document": "Cinema"}, + ... {"id": id, "document": queries[id%4]} for id in range(100) ... ] >>> model = models.ColBERT( @@ -63,14 +63,14 @@ class BM25(BM25Retriever): ... ) >>> documents_embeddings = explorer.encode_documents( - ... documents=documents + ... documents=documents, + ... ranker_embeddings=False, ... ) >>> explorer = explorer.add( ... documents_embeddings=documents_embeddings, ... ) - >>> queries = ["Food", "Sports", "Cinema food sports", "cinema"] >>> queries_embeddings = explorer.encode_queries( ... queries=queries ... ) @@ -79,6 +79,7 @@ class BM25(BM25Retriever): ... queries_embeddings=queries_embeddings, ... documents_embeddings=documents_embeddings, ... k=10, + ... early_stopping=True, ... ranker_batch_size=32, ... retriever_batch_size=2000, ... max_step=3, @@ -86,6 +87,46 @@ class BM25(BM25Retriever): ... ) >>> pprint(scores) + [[{'id': 96, 'similarity': 4.7243194580078125}, + {'id': 24, 'similarity': 4.7243194580078125}, + {'id': 60, 'similarity': 4.7243194580078125}, + {'id': 20, 'similarity': 4.7243194580078125}, + {'id': 56, 'similarity': 4.7243194580078125}, + {'id': 52, 'similarity': 4.7243194580078125}, + {'id': 0, 'similarity': 4.7243194580078125}, + {'id': 48, 'similarity': 4.7243194580078125}, + {'id': 36, 'similarity': 4.7243194580078125}, + {'id': 40, 'similarity': 4.7243194580078125}], + [{'id': 97, 'similarity': 4.792297840118408}, + {'id': 25, 'similarity': 4.792297840118408}, + {'id': 61, 'similarity': 4.792297840118408}, + {'id': 21, 'similarity': 4.792297840118408}, + {'id': 57, 'similarity': 4.792297840118408}, + {'id': 53, 'similarity': 4.792297840118408}, + {'id': 1, 'similarity': 4.792297840118408}, + {'id': 49, 'similarity': 4.792297840118408}, + {'id': 37, 'similarity': 4.792297840118408}, + {'id': 41, 'similarity': 4.792297840118408}], + [{'id': 74, 'similarity': 7.377876281738281}, + {'id': 82, 'similarity': 7.377876281738281}, + {'id': 62, 'similarity': 7.377876281738281}, + {'id': 94, 'similarity': 7.377876281738281}, + {'id': 70, 'similarity': 7.377876281738281}, + {'id': 66, 'similarity': 7.377876281738281}, + {'id': 78, 'similarity': 7.377876281738281}, + {'id': 2, 'similarity': 7.377876281738281}, + {'id': 90, 'similarity': 7.377876281738281}, + {'id': 46, 'similarity': 7.377876281738281}], + [{'id': 31, 'similarity': 5.06969690322876}, + {'id': 23, 'similarity': 5.06969690322876}, + {'id': 55, 'similarity': 5.069695472717285}, + {'id': 47, 'similarity': 5.069695472717285}, + {'id': 43, 'similarity': 5.069695472717285}, + {'id': 39, 'similarity': 5.069695472717285}, + {'id': 35, 'similarity': 5.069695472717285}, + {'id': 63, 'similarity': 5.069695472717285}, + {'id': 27, 'similarity': 5.069695472717285}, + {'id': 11, 'similarity': 5.069695472717285}]] """ @@ -170,9 +211,142 @@ def encode_queries( def add(self, documents_embeddings: dict[dict[str, torch.Tensor]]) -> "BM25": """Add new documents to the BM25 retriever.""" 
         super().add(documents_embeddings=documents_embeddings["retriever"])
-
         return self
 
+    def _encode_queries_retriever(
+        self, queries: list, queries_embeddings: dict[str, torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
+        """Encode queries for the retriever."""
+        return super().encode_queries(
+            queries=list(
+                set(
+                    [
+                        query
+                        for group_queries in queries
+                        for query in group_queries
+                        if query
+                    ]
+                )
+            ),
+            warn_duplicates=False,
+        )
+
+    def _encode_documents_ranker(
+        self,
+        candidates: list[list[dict]],
+        documents_embeddings: dict[str, csr_matrix],
+        batch_size: int,
+    ) -> dict[str, torch.Tensor]:
+        """Encode documents for the ranker."""
+        documents_to_encode, duplicates = [], {}
+        for query_documents in candidates:
+            for document in query_documents:
+                if (
+                    document[self.key] not in documents_embeddings
+                    and document[self.key] not in duplicates
+                ):
+                    documents_to_encode.append(
+                        self.mapping_documents[document[self.key]]
+                    )
+
+                    duplicates[document[self.key]] = True
+
+        if documents_to_encode:
+            documents_embeddings.update(
+                self.ranker.encode_documents(
+                    documents=documents_to_encode,
+                    batch_size=batch_size,
+                    tqdm_bar=False,
+                )
+            )
+
+        return documents_embeddings
+
+    def _post_process_candidates_retriever(
+        self,
+        queries_embeddings: dict,
+        queries: list[str],
+        candidates: list[list[dict]],
+        documents_explored: list[dict],
+        k: int,
+    ) -> list[list[dict]]:
+        """Post-process candidates from the retriever."""
+        mapping_position = {
+            query: position
+            for position, query in enumerate(iterable=list(queries_embeddings.keys()))
+        }
+
+        # Gather all the documents retrieved for the same query
+        candidates = [
+            list(
+                itertools.chain.from_iterable(
+                    [
+                        [
+                            document
+                            for document in candidates[mapping_position[query]]
+                            if document[self.key] not in query_scores
+                        ]
+                        if query in mapping_position
+                        else []
+                        for query in group_queries
+                    ]
+                )
+            )
+            for group_queries, query_scores in zip(queries, documents_explored)
+        ]
+
+        # Drop duplicate documents retrieved for the same query.
+        distinct_candidates = []
+        for query_candidates in candidates:
+            distinct_candidates_query, duplicates = [], {}
+            for document in query_candidates:
+                if document[self.key] not in duplicates:
+                    distinct_candidates_query.append(document)
+                    duplicates[document[self.key]] = True
+            distinct_candidates.append(distinct_candidates_query)
+
+        return distinct_candidates
+
+    def _get_next_queries(
+        self,
+        candidates: list[list[dict]],
+        queries_embeddings: dict[str, csr_matrix],
+        documents_explored: list[dict],
+        beam_size: int,
+        max_scores: list[float],
+        early_stopping: bool,
+    ) -> tuple[list[str], list[float], list[dict]]:
+        """Get the next queries to explore."""
+        next_queries, next_max_scores = [], []
+
+        for query, query_documents, query_documents_explored, query_max_score in zip(
+            list(queries_embeddings.keys()), candidates, documents_explored, max_scores
+        ):
+            query_next_queries = []
+            early_stopping_condition = False
+
+            for document in query_documents:
+                if document[self.key] not in query_documents_explored:
+                    if (
+                        document["similarity"] >= query_max_score and early_stopping
+                    ) or (early_stopping_condition and early_stopping):
+                        if document["similarity"] > query_max_score:
+                            query_max_score = document["similarity"]
+
+                        early_stopping_condition = True
+                        query_documents_explored[document[self.key]] = True
+                        query_next_queries.append(
+                            f"{query} {' '.join([self.mapping_documents[document[self.key]][field] for field in self.on])}"
+                        )
+
+                        if len(query_next_queries) >= beam_size:
+                            break
+
+            next_max_scores.append(query_max_score)
+            next_queries.append(query_next_queries)
+
+        return next_queries, next_max_scores, documents_explored
+
     def __call__(
         self,
         queries_embeddings: dict[str, dict[str, csr_matrix] | dict[str, torch.Tensor]],
@@ -189,23 +363,56 @@ def __call__(
         queries: list[str] = None,
         actual_step: int = 0,
         scores: list[dict] = None,
+        documents_explored: list[dict] = None,
+        max_scores: list[float] = None,
     ) -> list[list[dict]]:
-        """Explore the documents."""
-        scores = (
-            [{} for _ in queries_embeddings["retriever"]] if scores is None else scores
-        )
-
+        """Explore the documents.
+
+        Parameters
+        ----------
+        queries_embeddings
+            Queries embeddings.
+        documents_embeddings
+            Documents embeddings.
+        k
+            Number of documents to retrieve.
+        beam_size
+            Among the top k documents retrieved, how many documents to explore.
+        max_step
+            Maximum number of steps to explore.
+        retriever_batch_size
+            Batch size for the retriever.
+        ranker_batch_size
+            Batch size for the ranker.
+        early_stopping
+            If `True`, expand a query only with documents that the ranker scores at
+            least as high as the best document found so far for that query.
+        """
         queries = (
             queries
            if queries is not None
            else [[query] for query in list(queries_embeddings["retriever"].keys())]
        )
 
-        retriever_queries_embeddings = super().encode_queries(
-            queries=list(
-                set([query for group_queries in queries for query in group_queries])
-            ),
-            warn_duplicates=False,
+        scores = (
+            [{} for _ in queries_embeddings["retriever"]] if scores is None else scores
+        )
+
+        max_scores = (
+            [0 for _ in queries_embeddings["retriever"]]
+            if max_scores is None
+            else max_scores
+        )
+
+        documents_explored = (
+            documents_explored
+            if documents_explored is not None
+            else [{} for _ in queries_embeddings["retriever"]]
+        )
+
+        retriever_queries_embeddings = self._encode_queries_retriever(
+            queries=queries,
+            queries_embeddings=queries_embeddings["retriever"],
         )
 
         # Retrieve the top k documents
@@ -217,67 +424,20 @@ def __call__(
         )
 
         # Start post-process candidates retriever.
-        mapping_position = {
-            query: position
-            for position, query in enumerate(
-                iterable=list(retriever_queries_embeddings.keys())
-            )
-        }
-
-        # Map candidates back to queries and avoid duplicates and avoid already scored documents
-        candidates = [
-            [
-                [
-                    document
-                    for document in candidates[mapping_position[query]]
-                    if document[self.key] not in query_scores
-                ]
-                if query in mapping_position
-                else []
-                for query in group_queries
-            ]
-            for group_queries, query_scores in zip(queries, scores)
-        ]
-
-        candidates = list(itertools.chain.from_iterable(candidates))
-
-        # Drop duplicates
-        distinct_candidates = []
-        for query_candidates in candidates:
-            distinct_candidates_query, duplicates = [], {}
-            for document in query_candidates:
-                if document[self.key] not in duplicates:
-                    distinct_candidates_query.append(document)
-                    duplicates[document[self.key]] = True
-            distinct_candidates.append(distinct_candidates_query)
-        candidates = distinct_candidates
-
-        print(candidates, queries)
-        # End post-process candidates retriever.
+        candidates = self._post_process_candidates_retriever(
+            queries_embeddings=retriever_queries_embeddings,
+            queries=queries,
+            candidates=candidates,
+            documents_explored=documents_explored,
+            k=k,
+        )
 
         # Encoding documents
-        documents_to_encode, duplicates = [], {}
-        for query_documents in candidates:
-            for document in query_documents:
-                if (
-                    document[self.key] not in documents_embeddings["ranker"]
-                    and document[self.key] not in duplicates
-                ):
-                    documents_to_encode.append(
-                        self.mapping_documents[document[self.key]]
-                    )
-
-                    duplicates[document[self.key]] = True
-
-        if documents_to_encode:
-            documents_embeddings["ranker"].update(
-                self.ranker.encode_documents(
-                    documents=documents_to_encode,
-                    batch_size=ranker_batch_size,
-                    tqdm_bar=False,
-                )
-            )
-        # End encoding documents
+        documents_embeddings["ranker"] = self._encode_documents_ranker(
+            candidates=candidates,
+            documents_embeddings=documents_embeddings["ranker"],
+            batch_size=ranker_batch_size,
+        )
 
         # Rank the candidates and take the top k
         candidates = self.ranker(
@@ -300,29 +460,23 @@ def __call__(
         ]
 
         if (actual_step - 1) > max_step:
-            return scores
+            return self._rank(scores=scores, k=k)
 
         # Add early stopping
-        # Take beam_size top candidates which are not in query_explored
-        # Create query explored
-        top_candidates = candidates
-
-        queries = [
-            [
-                f"{query} {' '.join([self.mapping_documents[document[self.key]][field] for field in self.on])}"
-                for document in query_documents
-            ][:beam_size]
-            for query, query_documents in zip(
-                list(queries_embeddings["retriever"].keys()), top_candidates
-            )
-        ]
-
-        print(len(scores), len(queries))
+        queries, max_scores, documents_explored = self._get_next_queries(
+            queries_embeddings=queries_embeddings["retriever"],
+            candidates=candidates,
+            documents_explored=documents_explored,
+            beam_size=beam_size,
+            max_scores=max_scores,
+            early_stopping=early_stopping,
+        )
 
         return self(
             queries_embeddings=queries_embeddings,
             documents_embeddings=documents_embeddings,
             k=k,
+            early_stopping=early_stopping,
             beam_size=beam_size,
             max_step=max_step,
             retriever_batch_size=retriever_batch_size,
@@ -331,4 +485,20 @@ def __call__(
             queries=queries,
             actual_step=actual_step + 1,
             scores=scores,
+            documents_explored=documents_explored,
+            max_scores=max_scores,
         )
+
+    def _rank(self, scores: list[dict], k: int) -> list[list[dict]]:
+        """Rank the scores."""
+        return [
+            [
+                {self.key: key, "similarity": similarity}
+                for key, similarity in sorted(
+                    query_scores.items(),
+                    key=lambda item: item[1],
+                    reverse=True,
+                )
+            ][:k]
+            for query_scores in scores
+        ]
diff --git a/neural_cherche/rank/colbert.py b/neural_cherche/rank/colbert.py
index 0903a94..5b9483b 100644
--- a/neural_cherche/rank/colbert.py
+++ b/neural_cherche/rank/colbert.py
@@ -138,23 +138,31 @@ def encode_documents(
         if not documents:
             return {}
 
-        embeddings = self.encode_queries(
-            queries=[
-                " ".join([document[field] for field in self.on])
-                for document in documents
-            ],
+        embeddings = {}
+
+        for batch_documents in utils.batchify(
+            X=documents,
+            batch_size=batch_size,
             tqdm_bar=tqdm_bar,
-            query_mode=query_mode,
-            desc=desc,
-            warn_duplicates=False,
-            **kwargs,
-        )
+            desc=f"{self.__class__.__name__} {desc}",
+        ):
+            batch_embeddings = self.model.encode(
+                texts=[
+                    " ".join([document[field] for field in self.on])
+                    for document in batch_documents
+                ],
+                query_mode=query_mode,
+                **kwargs,
+            )
 
-        return {
-            document[self.key]: embedding
-            for document, embedding in zip(documents, embeddings.values())
-        }
+            batch_embeddings = (
+                batch_embeddings["embeddings"].cpu().detach().numpy().astype("float32")
+            )
+
+            for document, embedding in zip(batch_documents, batch_embeddings):
+                embeddings[document[self.key]] = embedding
+
+        return embeddings
 
     def encode_candidates_documents(
         self,

From 198a6b4f62c0742b7b133d42fb06cc63c21b5c15 Mon Sep 17 00:00:00 2001
From: Arthur
Date: Sat, 4 May 2024 21:45:53 +0200
Subject: [PATCH 3/3] Create xtr.py

I have implemented XTR from Google ("Rethinking the Role of Token Retrieval
in Multi-Vector Retrieval"). We still need to optimize the code and to add
missing-similarity imputation; please let me know if you have any questions.

---
 neural_cherche/retrieve/xtr.py | 196 +++++++++++++++++++++++++++++++++
 1 file changed, 196 insertions(+)
 create mode 100644 neural_cherche/retrieve/xtr.py

diff --git a/neural_cherche/retrieve/xtr.py b/neural_cherche/retrieve/xtr.py
new file mode 100644
index 0000000..5aea9ac
--- /dev/null
+++ b/neural_cherche/retrieve/xtr.py
@@ -0,0 +1,196 @@
+from collections import defaultdict
+
+import torch
+import tqdm
+
+from .. import models, utils
+from ..retrieve import ColBERT
+
+__all__ = ["XTR"]
+
+
+class XTR(ColBERT):
+    """XTR retriever.
+
+    Parameters
+    ----------
+    key
+        Document unique identifier.
+    on
+        Document texts.
+    model
+        ColBERT model.
+
+    Examples
+    --------
+    >>> from neural_cherche import models, retrieve
+    >>> from pprint import pprint
+    >>> import torch
+
+    >>> _ = torch.manual_seed(42)
+
+    >>> encoder = models.ColBERT(
+    ...     model_name_or_path="raphaelsty/neural-cherche-colbert",
+    ...     device="mps",
+    ... )
+
+    >>> documents = [
+    ...     {"id": 0, "document": "Food"},
+    ...     {"id": 1, "document": "Sports"},
+    ...     {"id": 2, "document": "Cinema"},
+    ... ]
+
+    >>> queries = ["Food", "Sports", "Cinema"]
+
+    >>> retriever = retrieve.XTR(
+    ...     key="id",
+    ...     on=["document"],
+    ...     model=encoder,
+    ... )
+
+    >>> documents_embeddings = retriever.encode_documents(
+    ...     documents=documents,
+    ...     batch_size=3,
+    ... )
+
+    >>> retriever = retriever.add(
+    ...     documents_embeddings=documents_embeddings,
+    ... )
+
+    >>> queries_embeddings = retriever.encode_queries(
+    ...     queries=queries,
+    ...     batch_size=3,
+    ... )
+
+    >>> scores = retriever(
+    ...     queries_embeddings=queries_embeddings,
+    ...     batch_size=3,
+    ...     tqdm_bar=True,
+    ...     k=3,
+    ... )
+
+    >>> pprint(scores)
+    [[{'id': 0, 'similarity': 4.7243194580078125},
+      {'id': 2, 'similarity': 2.403003692626953},
+      {'id': 1, 'similarity': 2.286036252975464}],
+     [{'id': 1, 'similarity': 4.792296886444092},
+      {'id': 2, 'similarity': 2.6001152992248535},
+      {'id': 0, 'similarity': 2.312016487121582}],
+     [{'id': 2, 'similarity': 5.069696426391602},
+      {'id': 1, 'similarity': 2.5587477684020996},
+      {'id': 0, 'similarity': 2.4474282264709473}]]
+
+    """
+
+    def __init__(
+        self,
+        key: str,
+        on: list[str],
+        model: models.ColBERT,
+    ) -> None:
+        self.key = key
+        self.on = on if isinstance(on, list) else [on]
+        self.model = model
+        self.device = self.model.device
+        self.documents = []
+        self.documents_embeddings = {}
+
+    def __call__(
+        self,
+        queries_embeddings: dict[str, torch.Tensor],
+        batch_size: int = 32,
+        tqdm_bar: bool = True,
+        k: int = None,
+    ) -> list[list[dict]]:
+        """Rank documents given queries.
+
+        Parameters
+        ----------
+        queries_embeddings
+            Queries embeddings.
+        batch_size
+            Batch size.
+        tqdm_bar
+            Show tqdm bar.
+        k
+            Number of documents to retrieve.
+        """
+        scores = []
+
+        bar = (
+            tqdm.tqdm(iterable=queries_embeddings.items(), position=0)
+            if tqdm_bar
+            else queries_embeddings.items()
+        )
+
+        for query, query_embedding in bar:
+            query_scores = []
+
+            embedding_query = torch.tensor(
+                data=query_embedding,
+                device=self.device,
+                dtype=torch.float32,
+            )
+
+            for batch_query_documents in utils.batchify(
+                X=self.documents,
+                batch_size=batch_size,
+                tqdm_bar=False,
+            ):
+                embeddings_batch_documents = torch.stack(
+                    tensors=[
+                        torch.tensor(
+                            data=self.documents_embeddings[document[self.key]],
+                            device=self.device,
+                            dtype=torch.float32,
+                        )
+                        for document in batch_query_documents
+                    ],
+                    dim=0,
+                )
+
+                query_documents_scores = torch.einsum(
+                    "sh,bth->bst",
+                    embedding_query,
+                    embeddings_batch_documents,
+                )
+
+                query_scores.append(query_documents_scores)
+
+            scores.append(self.xtr_score(torch.cat(tensors=query_scores), k))
+
+        return scores
+
+    def xtr_score(self, all_scores: torch.Tensor, k: int) -> list[dict]:
+        """Aggregate token-level similarities into document scores."""
+        num_tokens = all_scores.shape[1]
+        sum_tokens_queries = defaultdict(float)
+        for token_id in range(num_tokens):
+            # Iterate through query tokens.
+            tensor = all_scores[:, token_id, :]
+            # Flatten the tensor.
+            flattened_tensor = tensor.flatten()
+            # Use topk to get the top k' token similarities across every document.
+            top_values, top_indices = flattened_tensor.topk(
+                min(1000, flattened_tensor.numel())
+            )
+            # Convert the flattened indices back to document indices.
+            index_top_doc = top_indices // tensor.shape[1]  # index of the doc
+            # index_top_token = top_indices % tensor.shape[1]  # index of the matching document token
+            # Accumulate the retrieved similarities per document.
+            for idx, i_doc in enumerate(index_top_doc):
+                sum_tokens_queries[self.documents[i_doc.item()][self.key]] += (
+                    top_values[idx].item()
+                )
+        # Normalize by the number of query tokens, then keep the top k documents in
+        # the format {self.key: key_, "similarity": value_}.
+        scores = [
+            {self.key: key_, "similarity": value_ / num_tokens}
+            for key_, value_ in sorted(
+                sum_tokens_queries.items(), key=lambda item: item[1], reverse=True
+            )[:k]
+        ]
+        return scores
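
Note on the scoring in xtr.py: for each query token, `xtr_score` keeps only the top
k' token similarities across the whole candidate set, adds each kept value to its
document's running sum, and normalizes the sums by the number of query tokens. A
minimal self-contained sketch of that aggregation on toy tensors (the shapes, the
`top_k_tokens` budget, and all variable names here are illustrative; the patch
hardcodes a budget of 1000):

    import torch

    # Toy setup: 3 documents with 4 token embeddings each, a query with 5 tokens.
    torch.manual_seed(0)
    documents = torch.randn(3, 4, 8)  # (n_docs, doc_tokens, dim)
    query = torch.randn(5, 8)         # (query_tokens, dim)

    # Token-level similarities, mirroring the patch's einsum:
    # scores[d, q, t] = <query token q, token t of document d>.
    scores = torch.einsum("qh,dth->dqt", query, documents)

    top_k_tokens = 6  # retrieval budget per query token (k' in the XTR paper)
    document_sums = torch.zeros(3)

    for token_id in range(scores.shape[1]):
        flat = scores[:, token_id, :].flatten()
        values, indices = flat.topk(min(top_k_tokens, flat.numel()))
        doc_ids = indices // scores.shape[2]  # map flat indices back to documents
        document_sums.index_add_(0, doc_ids, values)

    # Normalize by the number of query tokens, as xtr_score does.
    print(document_sums / scores.shape[1])

Documents whose tokens never enter a query token's top k' contribute nothing, which
is why the missing-similarity imputation mentioned in the commit message still
needs to be added.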