Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,26 @@
sources=ModelSource(hf="snowflake/snowflake-arctic-embed-l"),
model_file="onnx/model.onnx",
),
DenseModelDescription(
model="snowflake/snowflake-arctic-embed-l-v2.0",
dim=1024,
tasks={
"embedding": {
"query_prefix": "query: ",
"passage_prefix": "",
}
},
description=(
"Text embeddings, Unimodal (text), Multilingual (74 languages), 8192 input tokens truncation, "
"Based on XLM-RoBERTa, supports Matryoshka learning for dimension truncation, "
"Prefixes for queries: recommended (query: ), 2024 year."
),
license="apache-2.0",
size_in_GB=2.27,
sources=ModelSource(hf="Snowflake/snowflake-arctic-embed-l-v2.0"),
model_file="onnx/model.onnx",
additional_files=["onnx/model.onnx_data"],
),
DenseModelDescription(
model="jinaai/jina-clip-v1",
dim=768,
Expand Down Expand Up @@ -294,6 +314,60 @@ def embed(
**kwargs,
)

def query_embed(
    self, query: Union[str, Iterable[str]], **kwargs: Any
) -> Iterable[NumpyArray]:
    """
    Embeds queries, prepending the model's configured query prefix when one exists.

    Args:
        query: A single query string or an iterable of query strings.
        **kwargs: Additional arguments forwarded to embed.

    Yields:
        Iterable[NumpyArray]: One embedding per query.
    """
    # Resolve the query prefix from the model's task configuration, if any.
    task_config = (self.model_description.tasks or {}).get("embedding", {})
    prefix = task_config.get("query_prefix", "")

    # Normalize to an iterable of (possibly prefixed) strings.
    if isinstance(query, str):
        queries: Iterable[str] = [prefix + query if prefix else query]
    elif prefix:
        queries = [prefix + q for q in query]
    else:
        queries = query

    yield from self.embed(queries, **kwargs)

def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
    """
    Embeds passages, prepending the model's configured passage prefix when one exists.

    Args:
        texts: The passages to embed.
        **kwargs: Additional arguments forwarded to embed.

    Yields:
        Iterable[NumpyArray]: One embedding per passage.
    """
    # Resolve the passage prefix from the model's task configuration, if any.
    task_config = (self.model_description.tasks or {}).get("embedding", {})
    prefix = task_config.get("passage_prefix", "")

    # An empty prefix means the texts pass through unchanged.
    if prefix:
        texts = [prefix + t for t in texts]

    yield from self.embed(texts, **kwargs)

@classmethod
def _get_worker_class(cls) -> Type["TextEmbeddingWorker[NumpyArray]"]:
    """Return the worker class used to run this embedder in parallel workers."""
    return OnnxTextEmbeddingWorker
Expand Down
75 changes: 75 additions & 0 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@
[0.0080, -0.0266, -0.0335, 0.0282, 0.0143]
),
"snowflake/snowflake-arctic-embed-l": np.array([0.0189, -0.0673, 0.0183, 0.0124, 0.0146]),
"snowflake/snowflake-arctic-embed-l-v2.0": np.array(
[-0.0266, 0.0167, -0.0478, -0.0039, -0.0128]
),
"Qdrant/clip-ViT-B-32-text": np.array([0.0083, 0.0103, -0.0138, 0.0199, -0.0069]),
"thenlper/gte-base": np.array([0.0038, 0.0355, 0.0181, 0.0092, 0.0654]),
"jinaai/jina-clip-v1": np.array([-0.0862, -0.0101, -0.0056, 0.0375, -0.0472]),
Expand Down Expand Up @@ -175,3 +178,75 @@ def test_embedding_size() -> None:

if is_ci:
delete_model_cache(model.model._model_dir)


def test_query_passage_prefix() -> None:
    """Test that query/passage prefixes are applied correctly for models with prefix configuration."""
    is_ci = os.getenv("CI")

    # Snowflake Arctic Embed L v2.0 declares a "query: " prefix and an empty passage prefix.
    model = TextEmbedding(model_name="snowflake/snowflake-arctic-embed-l-v2.0")
    sample = "what is fastembed?"

    # Plain embed() applies no prefix; query_embed() prepends the configured prefix,
    # so the two embeddings must not coincide.
    plain_vecs = np.array(list(model.embed([sample])))
    query_vecs = np.array(list(model.query_embed(sample)))
    assert not np.allclose(query_vecs, plain_vecs), (
        "Query embeddings with prefix should differ from regular embeddings"
    )

    # passage_embed() adds nothing for this model, so it must match plain embed().
    passage_vecs = np.array(list(model.passage_embed([sample])))
    assert np.allclose(passage_vecs, plain_vecs, atol=1e-5), (
        "Passage embeddings should match regular embeddings when no passage prefix configured"
    )

    # An iterable of queries yields one 1024-dimensional vector per query.
    batch = list(model.query_embed(["query one", "query two"]))
    assert len(batch) == 2
    assert batch[0].shape == (1024,)

    if is_ci:
        delete_model_cache(model.model._model_dir)


def test_prefix_backward_compatibility() -> None:
    """Test that models without prefix configuration still work correctly."""
    is_ci = os.getenv("CI")

    # bge-small declares no task/prefix configuration, so all three entry points
    # must produce identical embeddings for the same input.
    model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
    sample = "hello world"

    baseline = np.array(list(model.embed([sample])))
    via_query = np.array(list(model.query_embed(sample)))
    via_passage = np.array(list(model.passage_embed([sample])))

    assert np.allclose(via_query, baseline, atol=1e-5), (
        "Query embed should match regular embed for models without prefix config"
    )
    assert np.allclose(via_passage, baseline, atol=1e-5), (
        "Passage embed should match regular embed for models without prefix config"
    )

    if is_ci:
        delete_model_cache(model.model._model_dir)