Skip to content

Commit ce5ea9b

Browse files
feat: add custom embedding types and migrate providers
- introduce `BaseEmbeddingsProvider` and a helper for embedding functions
- add core embedding types and migrate providers, factory, and storage modules
- remove unused type aliases and fix a pydantic schema error
- update providers with env var support and related fixes
1 parent e070c14 commit ce5ea9b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

74 files changed

+2736
-1277
lines changed

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,15 @@ aisuite = [
7272
qdrant = [
7373
"qdrant-client[fastembed]>=1.14.3",
7474
]
75+
aws = [
76+
"boto3>=1.40.38",
77+
]
78+
watson = [
79+
"ibm-watsonx-ai>=1.3.39",
80+
]
81+
voyageai = [
82+
"voyageai>=0.3.5",
83+
]
7584

7685
[dependency-groups]
7786
dev = [

src/crewai/knowledge/storage/knowledge_storage.py

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
99
from crewai.rag.config.utils import get_rag_client
1010
from crewai.rag.core.base_client import BaseClient
11-
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
11+
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
12+
from crewai.rag.embeddings.factory import build_embedder
13+
from crewai.rag.embeddings.types import ProviderSpec
1214
from crewai.rag.factory import create_client
1315
from crewai.rag.types import BaseRecord, SearchResult
1416
from crewai.utilities.logger import Logger
@@ -22,12 +24,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
2224

2325
def __init__(
2426
self,
25-
embedder: dict[str, Any] | None = None,
27+
embedder: ProviderSpec | BaseEmbeddingsProvider | None = None,
2628
collection_name: str | None = None,
2729
) -> None:
2830
self.collection_name = collection_name
2931
self._client: BaseClient | None = None
30-
self._embedder_config = embedder # Store embedder config
3132

3233
warnings.filterwarnings(
3334
"ignore",
@@ -36,29 +37,12 @@ def __init__(
3637
)
3738

3839
if embedder:
39-
# Cast to EmbedderConfig for type checking
40-
embedder_typed = cast(EmbedderConfig, embedder)
41-
embedding_function = get_embedding_function(embedder_typed)
42-
batch_size = None
43-
if isinstance(embedder, dict) and "config" in embedder:
44-
nested_config = embedder["config"]
45-
if isinstance(nested_config, dict):
46-
batch_size = nested_config.get("batch_size")
47-
48-
# Create config with batch_size if provided
49-
if batch_size is not None:
50-
config = ChromaDBConfig(
51-
embedding_function=cast(
52-
ChromaEmbeddingFunctionWrapper, embedding_function
53-
),
54-
batch_size=batch_size,
55-
)
56-
else:
57-
config = ChromaDBConfig(
58-
embedding_function=cast(
59-
ChromaEmbeddingFunctionWrapper, embedding_function
60-
)
40+
embedding_function = build_embedder(embedder)
41+
config = ChromaDBConfig(
42+
embedding_function=cast(
43+
ChromaEmbeddingFunctionWrapper, embedding_function
6144
)
45+
)
6246
self._client = create_client(config)
6347

6448
def _get_client(self) -> BaseClient:
@@ -123,23 +107,9 @@ def save(self, documents: list[str]) -> None:
123107

124108
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
125109

126-
batch_size = None
127-
if self._embedder_config and isinstance(self._embedder_config, dict):
128-
if "config" in self._embedder_config:
129-
nested_config = self._embedder_config["config"]
130-
if isinstance(nested_config, dict):
131-
batch_size = nested_config.get("batch_size")
132-
133-
if batch_size is not None:
134-
client.add_documents(
135-
collection_name=collection_name,
136-
documents=rag_documents,
137-
batch_size=batch_size,
138-
)
139-
else:
140-
client.add_documents(
141-
collection_name=collection_name, documents=rag_documents
142-
)
110+
client.add_documents(
111+
collection_name=collection_name, documents=rag_documents
112+
)
143113
except Exception as e:
144114
if "dimension mismatch" in str(e).lower():
145115
Logger(verbose=True).log(

src/crewai/memory/storage/rag_storage.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
88
from crewai.rag.config.utils import get_rag_client
99
from crewai.rag.core.base_client import BaseClient
10-
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
11-
from crewai.rag.embeddings.types import EmbeddingOptions
10+
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
11+
from crewai.rag.embeddings.factory import build_embedder
12+
from crewai.rag.embeddings.types import ProviderSpec
1213
from crewai.rag.factory import create_client
1314
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
1415
from crewai.rag.types import BaseRecord
@@ -26,7 +27,7 @@ def __init__(
2627
self,
2728
type: str,
2829
allow_reset: bool = True,
29-
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
30+
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
3031
crew: Any = None,
3132
path: str | None = None,
3233
) -> None:
@@ -50,15 +51,17 @@ def __init__(
5051
)
5152

5253
if self.embedder_config:
53-
embedding_function = get_embedding_function(self.embedder_config)
54+
embedding_function = build_embedder(self.embedder_config)
5455

5556
try:
5657
_ = embedding_function(["test"])
5758
except Exception as e:
5859
provider = (
59-
self.embedder_config.provider
60-
if isinstance(self.embedder_config, EmbeddingOptions)
61-
else self.embedder_config.get("provider", "unknown")
60+
self.embedder_config["provider"]
61+
if isinstance(self.embedder_config, dict)
62+
else self.embedder_config.__class__.__name__.replace(
63+
"Provider", ""
64+
).lower()
6265
)
6366
raise ValueError(
6467
f"Failed to initialize embedder. Please check your configuration or connection.\n"
@@ -80,7 +83,7 @@ def __init__(
8083
embedding_function=cast(
8184
ChromaEmbeddingFunctionWrapper, embedding_function
8285
),
83-
batch_size=batch_size,
86+
batch_size=cast(int, batch_size),
8487
)
8588
else:
8689
config = ChromaDBConfig(
@@ -142,7 +145,7 @@ def save(self, value: Any, metadata: dict[str, Any]) -> None:
142145
client.add_documents(
143146
collection_name=collection_name,
144147
documents=[document],
145-
batch_size=batch_size,
148+
batch_size=cast(int, batch_size),
146149
)
147150
else:
148151
client.add_documents(
src/crewai/rag/core/base_embeddings_callable.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Base embeddings callable utilities for RAG systems."""
2+
3+
from typing import Protocol, TypeVar, runtime_checkable
4+
5+
import numpy as np
6+
7+
from crewai.rag.core.types import (
8+
Embeddable,
9+
Embedding,
10+
Embeddings,
11+
PyEmbedding,
12+
)
13+
14+
T = TypeVar("T")
15+
D = TypeVar("D", bound=Embeddable, contravariant=True)
16+
17+
18+
def normalize_embeddings(
    target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding] | None,
) -> Embeddings | None:
    """Normalize various embedding formats to a list of float32 numpy arrays.

    Args:
        target: Input embeddings in one of the supported shapes: a flat
            sequence of numbers (one embedding), a sequence of lists/tuples
            (many embeddings), a 1-D or 2-D numpy array, or a sequence of
            numpy arrays. ``None`` is passed through unchanged.

    Returns:
        A list with one ``float32`` numpy array per embedding, or ``None`` if
        ``target`` is ``None``.

    Raises:
        ValueError: If the input is empty, is a numpy array with more than two
            dimensions, or holds elements of an unsupported type.
    """
    if target is None:
        return None

    if isinstance(target, np.ndarray):
        if target.ndim == 1:
            return [target.astype(np.float32)]
        if target.ndim == 2:
            return [row.astype(np.float32) for row in target]
        raise ValueError(f"Unsupported numpy array shape: {target.shape}")

    # Guard before indexing so an empty input surfaces as the documented
    # ValueError rather than an IndexError from ``target[0]``.
    if len(target) == 0:
        raise ValueError("Expected at least one embedding, got an empty sequence")

    # The first element decides how the whole input is interpreted.
    first = target[0]
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(first, (int, float, np.integer, np.floating)) and not isinstance(
        first, bool
    ):
        return [np.array(target, dtype=np.float32)]
    if isinstance(first, (list, tuple)):
        return [np.array(emb, dtype=np.float32) for emb in target]
    if isinstance(first, np.ndarray):
        return [emb.astype(np.float32) for emb in target]  # type: ignore[union-attr]

    raise ValueError(f"Unsupported embeddings format: {type(first)}")
49+
50+
51+
def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
    """Wrap a lone item in a list, leaving lists and ``None`` untouched.

    Args:
        target: Either a single item, a list of items, or ``None``.

    Returns:
        ``None`` when ``target`` is ``None``; ``target`` itself when it is
        already a list; otherwise a new one-element list containing ``target``.
    """
    # Both pass-through cases return the argument unchanged.
    if target is None or isinstance(target, list):
        return target
    return [target]
63+
64+
65+
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Validate embeddings format and content.

    Args:
        embeddings: List of numpy arrays to validate.

    Returns:
        The same ``embeddings`` list, unchanged, once every check has passed.

    Raises:
        ValueError: If ``embeddings`` is not a non-empty list of non-empty,
            1-dimensional numpy arrays holding only integer or floating
            values.
    """
    if not isinstance(embeddings, list):
        raise ValueError(
            f"Expected embeddings to be a list, got {type(embeddings).__name__}"
        )
    if len(embeddings) == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
        )
    if not all(isinstance(e, np.ndarray) for e in embeddings):
        raise ValueError(
            "Expected each embedding in the embeddings to be a numpy array"
        )
    for i, embedding in enumerate(embeddings):
        # Reject every non-1-D array here. Checking only for 0-d let 2-d+
        # arrays fall through and be misreported as "non-numeric values".
        if embedding.ndim != 1:
            raise ValueError(
                f"Expected a 1-dimensional array, got a {embedding.ndim}-dimensional array {embedding}"
            )
        if embedding.size == 0:
            raise ValueError(
                f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
                f"Got an array with no values at position {i}"
            )
        if embedding.dtype.kind == "O":
            # Object arrays hold arbitrary Python objects, so each element
            # has to be inspected individually.
            is_numeric = all(
                isinstance(value, (np.integer, float, np.floating))
                and not isinstance(value, bool)
                for value in embedding
            )
        else:
            # For native dtypes every element has the dtype's scalar type,
            # so a single dtype check replaces the per-element loop.
            is_numeric = bool(
                np.issubdtype(embedding.dtype, np.integer)
                or np.issubdtype(embedding.dtype, np.floating)
            )
        if not is_numeric:
            raise ValueError(
                f"Expected embedding to contain numeric values, got non-numeric values at position {i}"
            )
    return embeddings
108+
109+
110+
@runtime_checkable
class EmbeddingFunction(Protocol[D]):
    """Protocol for embedding functions.

    Embedding functions convert input data (documents or images) into vector
    embeddings. Every subclass has its ``__call__`` wrapped at class-creation
    time so that whatever it returns is normalized and validated before it
    reaches the caller.
    """

    def __call__(self, input: D) -> Embeddings:
        """Convert input data to embeddings.

        Args:
            input: Input data to embed (documents or images).

        Returns:
            List of numpy arrays representing the embeddings.
        """
        ...

    def __init_subclass__(cls) -> None:
        """Wrap ``__call__`` so its result is normalized and validated.

        The wrapper raises ``ValueError`` when the underlying call returns
        ``None`` and otherwise passes the result through
        ``normalize_embeddings`` and ``validate_embeddings``.
        """
        super().__init_subclass__()
        original_call = cls.__call__

        # A subclass that does not override __call__ inherits its parent's
        # already-wrapped version. Wrapping it again would run normalization
        # and validation once per inheritance level.
        if getattr(original_call, "_normalizes_embeddings", False):
            return

        def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
            result = original_call(self, input)
            if result is None:
                raise ValueError("Embedding function returned None")
            normalized = normalize_embeddings(result)
            if normalized is None:
                raise ValueError("Normalization returned None for non-None input")
            return validate_embeddings(normalized)

        # Marker consulted by the guard above when later subclasses are created.
        wrapped_call._normalizes_embeddings = True  # type: ignore[attr-defined]
        cls.__call__ = wrapped_call  # type: ignore[method-assign]
src/crewai/rag/core/base_embeddings_provider.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Base class for embedding providers."""
2+
3+
from typing import Generic, TypeVar
4+
5+
from pydantic import Field
6+
from pydantic_settings import BaseSettings, SettingsConfigDict
7+
8+
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
9+
10+
T = TypeVar("T", bound=EmbeddingFunction)
11+
12+
13+
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
    """Settings-backed base class for embedding providers.

    Gives every provider the same shape: a settings model, generic over the
    embedding function type ``T``, that names the embedding function class
    to build via ``embedding_callable``.
    """

    # Settings behaviour shared by all providers: fields may be populated by
    # name, and extra (undeclared) values are allowed.
    model_config = SettingsConfigDict(populate_by_name=True, extra="allow")

    # Required: the embedding function class (not an instance) to use.
    embedding_callable: type[T] = Field(
        ...,
        description="The embedding function class to use",
    )

src/crewai/rag/core/types.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""Core type definitions for RAG systems."""
2+
3+
from collections.abc import Sequence
4+
from typing import TypeVar
5+
6+
import numpy as np
7+
from numpy import floating, integer, number
8+
from numpy.typing import NDArray
9+
10+
T = TypeVar("T")
11+
12+
PyEmbedding = Sequence[float] | Sequence[int]
13+
PyEmbeddings = list[PyEmbedding]
14+
Embedding = NDArray[np.int32 | np.float32]
15+
Embeddings = list[Embedding]
16+
17+
Documents = list[str]
18+
Images = list[np.ndarray]
19+
Embeddable = Documents | Images
20+
21+
ScalarType = TypeVar("ScalarType", bound=np.generic)
22+
IntegerType = TypeVar("IntegerType", bound=integer)
23+
FloatingType = TypeVar("FloatingType", bound=floating)
24+
NumberType = TypeVar("NumberType", bound=number)
25+
26+
DType32 = TypeVar("DType32", np.int32, np.float32)
27+
DType64 = TypeVar("DType64", np.int64, np.float64)
28+
DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)

0 commit comments

Comments
 (0)