diff --git a/fastembed/text/onnx_embedding.py b/fastembed/text/onnx_embedding.py
index 4cc892f5..74659ba4 100644
--- a/fastembed/text/onnx_embedding.py
+++ b/fastembed/text/onnx_embedding.py
@@ -168,6 +168,26 @@
         sources=ModelSource(hf="snowflake/snowflake-arctic-embed-l"),
         model_file="onnx/model.onnx",
     ),
+    DenseModelDescription(
+        model="snowflake/snowflake-arctic-embed-l-v2.0",
+        dim=1024,
+        tasks={
+            "embedding": {
+                "query_prefix": "query: ",
+                "passage_prefix": "",
+            }
+        },
+        description=(
+            "Text embeddings, Unimodal (text), Multilingual (74 languages), 8192 input tokens truncation, "
+            "Based on XLM-RoBERTa, supports Matryoshka Representation Learning for dimension truncation, "
+            "Prefixes for queries: recommended (query: ), 2024 year."
+        ),
+        license="apache-2.0",
+        size_in_GB=2.27,
+        sources=ModelSource(hf="Snowflake/snowflake-arctic-embed-l-v2.0"),
+        model_file="onnx/model.onnx",
+        additional_files=["onnx/model.onnx_data"],
+    ),
     DenseModelDescription(
         model="jinaai/jina-clip-v1",
         dim=768,
@@ -294,6 +314,60 @@ def embed(
             **kwargs,
         )
 
+    def query_embed(
+        self, query: Union[str, Iterable[str]], **kwargs: Any
+    ) -> Iterable[NumpyArray]:
+        """
+        Embeds queries using the model-specific query prefix if configured.
+
+        Args:
+            query: The query or queries to embed
+            **kwargs: Additional arguments to pass to embed
+
+        Yields:
+            Iterable[NumpyArray]: Query embeddings
+        """
+        # Check whether the model has task-specific prefixes configured
+        if self.model_description.tasks:
+            embedding_task = self.model_description.tasks.get("embedding", {})
+            query_prefix = embedding_task.get("query_prefix", "")
+
+            if query_prefix:
+                # Prepend the prefix to each query
+                if isinstance(query, str):
+                    query = f"{query_prefix}{query}"
+                else:
+                    query = [f"{query_prefix}{q}" for q in query]
+
+        # Delegate to the regular embed implementation
+        if isinstance(query, str):
+            yield from self.embed([query], **kwargs)
+        else:
+            yield from self.embed(query, **kwargs)
+
+    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+        """
+        Embeds passages using the model-specific passage prefix if configured.
+
+        Args:
+            texts: The passages to embed
+            **kwargs: Additional arguments to pass to embed
+
+        Yields:
+            Iterable[NumpyArray]: Passage embeddings
+        """
+        # Check whether the model has task-specific prefixes configured
+        if self.model_description.tasks:
+            embedding_task = self.model_description.tasks.get("embedding", {})
+            passage_prefix = embedding_task.get("passage_prefix", "")
+
+            if passage_prefix:
+                # Prepend the prefix to each passage
+                texts = [f"{passage_prefix}{t}" for t in texts]
+
+        # Delegate to the regular embed implementation
+        yield from self.embed(texts, **kwargs)
+
     @classmethod
     def _get_worker_class(cls) -> Type["TextEmbeddingWorker[NumpyArray]"]:
         return OnnxTextEmbeddingWorker
diff --git a/tests/test_text_onnx_embeddings.py b/tests/test_text_onnx_embeddings.py
index 6b25d900..ad91c87b 100644
--- a/tests/test_text_onnx_embeddings.py
+++ b/tests/test_text_onnx_embeddings.py
@@ -64,6 +64,9 @@
         [0.0080, -0.0266, -0.0335, 0.0282, 0.0143]
     ),
     "snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
+    "snowflake/snowflake-arctic-embed-l-v2.0": np.array(
+        [-0.0266, 0.0167, -0.0478, -0.0039, -0.0128]
+    ),
     "Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
     "thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
     "jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
@@ -175,3 +178,75 @@ def test_embedding_size() -> None:
 
     if is_ci:
         delete_model_cache(model.model._model_dir)
+
+
+def test_query_passage_prefix() -> None:
+    """Test that query/passage prefixes are applied correctly for models with a prefix configuration."""
+    is_ci = os.getenv("CI")
+
+    # snowflake-arctic-embed-l-v2.0 has a "query: " prefix configured for queries
+    model_name = "snowflake/snowflake-arctic-embed-l-v2.0"
+    model = TextEmbedding(model_name=model_name)
+
+    test_text = "what is fastembed?"
+
+    # query_embed should apply the "query: " prefix
+    query_embedding = list(model.query_embed(test_text))
+    query_embedding_array = np.array(query_embedding)
+
+    # Regular embed should not apply any prefix
+    regular_embedding = list(model.embed([test_text]))
+    regular_embedding_array = np.array(regular_embedding)
+
+    # Query embeddings with the prefix should differ from regular embeddings without it
+    assert not np.allclose(query_embedding_array, regular_embedding_array), (
+        "Query embeddings with prefix should differ from regular embeddings"
+    )
+
+    # passage_embed should not apply a prefix for this model (empty passage_prefix)
+    passage_embedding = list(model.passage_embed([test_text]))
+    passage_embedding_array = np.array(passage_embedding)
+
+    # Passage embeddings should match regular embeddings (both without a prefix)
+    assert np.allclose(passage_embedding_array, regular_embedding_array, atol=1e-5), (
+        "Passage embeddings should match regular embeddings when no passage prefix is configured"
+    )
+
+    # Multiple queries should each get the prefix applied
+    queries = ["query one", "query two"]
+    query_embeddings = list(model.query_embed(queries))
+    assert len(query_embeddings) == 2
+    assert query_embeddings[0].shape == (1024,)
+
+    if is_ci:
+        delete_model_cache(model.model._model_dir)
+
+
+def test_prefix_backward_compatibility() -> None:
+    """Test that models without a prefix configuration still work correctly."""
+    is_ci = os.getenv("CI")
+
+    # BAAI/bge-small-en-v1.5 has no prefix configuration
+    model_name = "BAAI/bge-small-en-v1.5"
+    model = TextEmbedding(model_name=model_name)
+
+    test_text = "hello world"
+
+    # All three methods should produce the same embeddings for models without a prefix config
+    query_embedding = list(model.query_embed(test_text))
+    passage_embedding = list(model.passage_embed([test_text]))
+    regular_embedding = list(model.embed([test_text]))
+
+    query_array = np.array(query_embedding)
+    passage_array = np.array(passage_embedding)
+    regular_array = np.array(regular_embedding)
+
+    assert np.allclose(query_array, regular_array, atol=1e-5), (
+        "Query embed should match regular embed for models without prefix config"
+    )
+    assert np.allclose(passage_array, regular_array, atol=1e-5), (
+        "Passage embed should match regular embed for models without prefix config"
+    )
+
+    if is_ci:
+        delete_model_cache(model.model._model_dir)
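For context, a minimal usage sketch of the API this patch enables, not part of the diff itself. It assumes `query_embed` and `passage_embed` are reachable through the public `TextEmbedding` wrapper (as the tests above already rely on) and that the model weights can be downloaded:

```python
from fastembed import TextEmbedding

# Sketch only: model name and expected dim come from the DenseModelDescription in the diff.
model = TextEmbedding(model_name="snowflake/snowflake-arctic-embed-l-v2.0")

# query_embed prepends the configured "query: " prefix before embedding.
query_vecs = list(model.query_embed("what is fastembed?"))

# passage_embed uses the passage prefix, which is empty for this model,
# so the result matches a plain embed() call on the same text.
passage_vecs = list(model.passage_embed(["FastEmbed is a lightweight embedding library."]))

assert query_vecs[0].shape == (1024,)  # dim declared in the model description
```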