87 changes: 45 additions & 42 deletions nemoguardrails/embeddings/basic.py
@@ -15,9 +15,9 @@

import asyncio
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast

from annoy import AnnoyIndex
from annoy import AnnoyIndex # type: ignore

from nemoguardrails.embeddings.cache import cache_embeddings
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
@@ -45,62 +45,51 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
max_batch_hold: The maximum time a batch is held before being processed
"""

embedding_model: str
embedding_engine: str
embedding_params: Dict[str, Any]
index: AnnoyIndex
embedding_size: int
cache_config: EmbeddingsCacheConfig
embeddings: List[List[float]]
search_threshold: float
use_batching: bool
max_batch_size: int
max_batch_hold: float

def __init__(
self,
embedding_model=None,
embedding_engine=None,
embedding_params=None,
index=None,
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
search_threshold: float = None,
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
embedding_engine: str = "SentenceTransformers",
embedding_params: Optional[Dict[str, Any]] = None,
index: Optional[AnnoyIndex] = None,
cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
search_threshold: float = float("inf"),
use_batching: bool = False,
max_batch_size: int = 10,
max_batch_hold: float = 0.01,
):
"""Initialize the BasicEmbeddingsIndex.

Args:
embedding_model (str, optional): The model for computing embeddings. Defaults to None.
embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
embedding_model: The model for computing embeddings.
embedding_engine: The engine for computing embeddings.
index: The pre-existing index.
cache_config: The cache configuration.
search_threshold: The threshold for filtering search results.
use_batching: Whether to batch requests when computing the embeddings.
max_batch_size: The maximum size of a batch.
max_batch_hold: The maximum time a batch is held before being processed
"""
self._model: Optional[EmbeddingModel] = None
self._items = []
self._embeddings = []
self._items: List[IndexItem] = []
self._embeddings: List[List[float]] = []
self.embedding_model = embedding_model
self.embedding_engine = embedding_engine
self.embedding_params = embedding_params or {}
self._embedding_size = 0
self.search_threshold = search_threshold or float("inf")
self.search_threshold = search_threshold
if isinstance(cache_config, Dict):
self._cache_config = EmbeddingsCacheConfig(**cache_config)
else:
self._cache_config = cache_config or EmbeddingsCacheConfig()
self._index = index

# Data structures for batching embedding requests
self._req_queue = {}
self._req_results = {}
self._req_idx = 0
self._current_batch_finished_event = None
self._current_batch_full_event = None
self._current_batch_submitted = asyncio.Event()
self._req_queue: Dict[int, str] = {}
self._req_results: Dict[int, List[float]] = {}
self._req_idx: int = 0
self._current_batch_finished_event: Optional[asyncio.Event] = None
self._current_batch_full_event: Optional[asyncio.Event] = None
self._current_batch_submitted: asyncio.Event = asyncio.Event()

# Initialize the batching configuration
self.use_batching = use_batching
@@ -112,6 +101,11 @@ def embeddings_index(self):
"""Get the current embedding index"""
return self._index

@embeddings_index.setter
def embeddings_index(self, index):
"""Setter to allow replacing the index dynamically."""
self._index = index

@property
def cache_config(self):
"""Get the cache configuration."""
@@ -127,16 +121,14 @@ def embeddings(self):
"""Get the computed embeddings."""
return self._embeddings

@embeddings_index.setter
def embeddings_index(self, index):
"""Setter to allow replacing the index dynamically."""
self._index = index

def _init_model(self):
"""Initialize the model used for computing the embeddings."""
model = self.embedding_model
engine = self.embedding_engine

self._model = init_embedding_model(
embedding_model=self.embedding_model,
embedding_engine=self.embedding_engine,
embedding_model=model,
embedding_engine=engine,
embedding_params=self.embedding_params,
)

@@ -153,7 +145,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
if self._model is None:
self._init_model()

embeddings = await self._model.encode_async(texts)
# self._model cannot be None here: _init_model() either set it or raised a ValueError
model: EmbeddingModel = cast(EmbeddingModel, self._model)
embeddings = await model.encode_async(texts)
return embeddings

async def add_item(self, item: IndexItem):
@@ -199,6 +193,12 @@ async def _run_batch(self):
"""Runs the current batch of embeddings."""

# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
if (
self._current_batch_full_event is None
or self._current_batch_finished_event is None
):
raise RuntimeError("Batch events not initialized. This should not happen.")

done, pending = await asyncio.wait(
[
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -244,7 +244,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
self._req_idx += 1
self._req_queue[req_id] = text

if self._current_batch_finished_event is None:
if (
self._current_batch_finished_event is None
or self._current_batch_full_event is None
):
self._current_batch_finished_event = asyncio.Event()
self._current_batch_full_event = asyncio.Event()
self._current_batch_submitted.clear()
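The `cast` in `_get_embeddings` is the standard way to narrow an `Optional` attribute after a lazy initializer the type checker cannot see through. A minimal, self-contained sketch of the same pattern; the `FakeModel` and `LazyIndex` names are illustrative, not from the PR:

```python
import asyncio
from typing import List, Optional, cast


class FakeModel:
    async def encode_async(self, texts: List[str]) -> List[List[float]]:
        return [[0.0, 0.0, 0.0] for _ in texts]  # placeholder embeddings


class LazyIndex:
    def __init__(self) -> None:
        self._model: Optional[FakeModel] = None

    def _init_model(self) -> None:
        # In the real code this either sets the model or raises a ValueError.
        self._model = FakeModel()

    async def embed(self, texts: List[str]) -> List[List[float]]:
        if self._model is None:
            self._init_model()
        # After _init_model(), _model is non-None; cast() narrows the
        # Optional for the type checker without a runtime check.
        model = cast(FakeModel, self._model)
        return await model.encode_async(texts)


print(asyncio.run(LazyIndex().embed(["hello"])))  # [[0.0, 0.0, 0.0]]
```

An `assert self._model is not None` would narrow the type as well, at the cost of a cheap runtime check on every call.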
38 changes: 26 additions & 12 deletions nemoguardrails/embeddings/cache.py
@@ -20,7 +20,12 @@
from abc import ABC, abstractmethod
from functools import singledispatchmethod
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional

try:
import redis # type: ignore
except ImportError:
redis = None # type: ignore

from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig

@@ -30,6 +35,8 @@
class KeyGenerator(ABC):
"""Abstract class for key generators."""

name: str # Class attribute that should be defined in subclasses

@abstractmethod
def generate_key(self, text: str) -> str:
pass
@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str:
class CacheStore(ABC):
"""Abstract class for cache stores."""

name: str

@abstractmethod
def get(self, key):
"""Get a value from the cache."""
@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore):

name = "filesystem"

def __init__(self, cache_dir: str = None):
def __init__(self, cache_dir: Optional[str] = None):
self._cache_dir = Path(cache_dir or ".cache/embeddings")
self._cache_dir.mkdir(parents=True, exist_ok=True)

@@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore):
name = "redis"

def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0):
import redis

if redis is None:
raise ImportError(
"Could not import redis, please install it with `pip install redis`."
)
self._redis = redis.Redis(host=host, port=port, db=db)

def get(self, key):
@@ -207,9 +218,9 @@ def clear(self):
class EmbeddingsCache:
def __init__(
self,
key_generator: KeyGenerator = None,
cache_store: CacheStore = None,
store_config: dict = None,
key_generator: KeyGenerator,
cache_store: CacheStore,
store_config: Optional[dict] = None,
):
self._key_generator = key_generator
self._cache_store = cache_store
@@ -218,7 +229,10 @@ def __init__(self):
@classmethod
def from_dict(cls, d: Dict[str, str]):
key_generator = KeyGenerator.from_name(d.get("key_generator"))()
store_config = d.get("store_config")
store_config_raw = d.get("store_config")
store_config: dict = (
store_config_raw if isinstance(store_config_raw, dict) else {}
)
cache_store = CacheStore.from_name(d.get("store"))(**store_config)

return cls(key_generator=key_generator, cache_store=cache_store)
@@ -239,7 +253,7 @@ def get_config(self):
def get(self, texts):
raise NotImplementedError

@get.register
@get.register(str)
def _(self, text: str):
key = self._key_generator.generate_key(text)
log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")
@@ -248,7 +262,7 @@ def _(self, text: str):

return result

@get.register
@get.register(list)
def _(self, texts: list):
cached = {}

@@ -266,13 +280,13 @@ def _(self, texts: list):
def set(self, texts):
raise NotImplementedError

@set.register
@set.register(str)
def _(self, text: str, value: List[float]):
key = self._key_generator.generate_key(text)
log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
self._cache_store.set(key, value)

@set.register
@set.register(list)
def _(self, texts: list, values: List[List[float]]):
for text, value in zip(texts, values):
self.set(text, value)
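The switch from bare `@get.register` to `@get.register(str)` and `@get.register(list)` makes the dispatch type explicit rather than inferred from the parameter annotation, which some type checkers and older Python versions handle poorly. A runnable sketch of the pattern; the `Demo` class is illustrative, not from the PR:

```python
from functools import singledispatchmethod
from typing import List


class Demo:
    @singledispatchmethod
    def get(self, texts):
        raise NotImplementedError

    # Dispatch happens on the runtime type of the first argument after self.
    @get.register(str)
    def _(self, text: str) -> str:
        return f"one:{text}"

    @get.register(list)
    def _(self, texts: list) -> List[str]:
        return [self.get(t) for t in texts]


demo = Demo()
assert demo.get("a") == "one:a"
assert demo.get(["a", "b"]) == ["one:a", "one:b"]
```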
7 changes: 3 additions & 4 deletions nemoguardrails/embeddings/providers/azureopenai.py
@@ -46,17 +46,16 @@ class AzureEmbeddingModel(EmbeddingModel):

def __init__(self, embedding_model: str):
try:
from openai import AzureOpenAI
from openai import AzureOpenAI # type: ignore
except ImportError:
raise ImportError(
"Could not import openai, please install it with "
"`pip install openai`."
"Could not import openai, please install it with `pip install openai`."
)
# Set Azure OpenAI API credentials
self.client = AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore
)

self.embedding_model = embedding_model
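The `# type: ignore` on `azure_endpoint` papers over the fact that `os.getenv` returns `Optional[str]`. An alternative, not what this PR does, is to fail fast when a required variable is missing; a sketch using the same environment variable names:

```python
import os


def require_env(name: str) -> str:
    """Return the environment variable's value, or raise a clear error."""
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(f"Environment variable {name} is not set")
    return value


# require_env() returns a plain str, so no ignore comment would be needed:
# client = AzureOpenAI(
#     api_key=require_env("AZURE_OPENAI_API_KEY"),
#     api_version=require_env("AZURE_OPENAI_API_VERSION"),
#     azure_endpoint=require_env("AZURE_OPENAI_ENDPOINT"),
# )
```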
15 changes: 11 additions & 4 deletions nemoguardrails/embeddings/providers/cohere.py
@@ -14,7 +14,7 @@
# limitations under the License.
import asyncio
from contextvars import ContextVar
from typing import List
from typing import TYPE_CHECKING, List

from .base import EmbeddingModel

@@ -23,6 +23,10 @@
# is changed, it will fail.
async_client_var: ContextVar = ContextVar("async_client", default=None)

if TYPE_CHECKING:
import cohere
from cohere import AsyncClient, Client


class CohereEmbeddingModel(EmbeddingModel):
"""
@@ -64,7 +68,7 @@ def __init__(

self.model = embedding_model
self.input_type = input_type
self.client = cohere.Client(**kwargs)
self.client = cohere.Client(**kwargs) # type: ignore[reportCallIssue]

self.embedding_size_dict = {
"embed-v4.0": 1536,
@@ -120,6 +124,9 @@ def encode(self, documents: List[str]) -> List[List[float]]:
"""

# Make embedding request to Cohere API
return self.client.embed(
# Since we don't pass embedding_types parameter, the response should be
# EmbeddingsFloatsEmbedResponse with embeddings as List[List[float]]
response = self.client.embed(
texts=documents, model=self.model, input_type=self.input_type
).embeddings
)
return response.embeddings # type: ignore[return-value]
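The `TYPE_CHECKING` block gives the type checker access to the `cohere` names while keeping the package an optional runtime dependency: the block is never executed, so the module imports cleanly even when `cohere` is not installed. A minimal sketch of the pattern, reusing the `Client.embed` call shown in this diff:

```python
from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from cohere import Client


def embed_all(
    client: "Client", texts: List[str], model: str, input_type: str
) -> List[List[float]]:
    # The string annotation "Client" avoids any runtime reference to the
    # cohere package; the response union needs the same narrowing ignore
    # the PR uses.
    response = client.embed(texts=texts, model=model, input_type=input_type)
    return response.embeddings  # type: ignore[return-value]
```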
2 changes: 1 addition & 1 deletion nemoguardrails/embeddings/providers/fastembed.py
@@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):
engine_name = "FastEmbed"

def __init__(self, embedding_model: str, **kwargs):
from fastembed import TextEmbedding as Embedding
from fastembed import TextEmbedding as Embedding # type: ignore

# Enabling a short form model name for all-MiniLM-L6-v2.
if embedding_model == "all-MiniLM-L6-v2":
2 changes: 1 addition & 1 deletion nemoguardrails/embeddings/providers/google.py
@@ -46,7 +46,7 @@ class GoogleEmbeddingModel(EmbeddingModel):

def __init__(self, embedding_model: str, **kwargs):
try:
from google import genai
from google import genai # type: ignore[import]

except ImportError:
raise ImportError(
2 changes: 1 addition & 1 deletion nemoguardrails/embeddings/providers/nim.py
@@ -35,7 +35,7 @@ class NIMEmbeddingModel(EmbeddingModel):

def __init__(self, embedding_model: str, **kwargs):
try:
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings # type: ignore

self.model = embedding_model
self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs)
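The bare `# type: ignore` comments added across these provider modules suppress every error on the line, most commonly the "missing library stubs" complaint for packages that ship no type information. Scoping the ignore to a specific error code, as a possible refinement, keeps unrelated errors visible (mypy error codes shown; pyright uses its own rule names):

```python
# Narrow ignores are safer than bare ones: only the stub complaint is
# silenced, and any other error on the line still surfaces.
from annoy import AnnoyIndex  # type: ignore[import-untyped]
```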
6 changes: 3 additions & 3 deletions nemoguardrails/embeddings/providers/openai.py
@@ -46,14 +46,14 @@ def __init__(
**kwargs,
):
try:
import openai
from openai import AsyncOpenAI, OpenAI
import openai # type: ignore
from openai import AsyncOpenAI, OpenAI # type: ignore
except ImportError:
raise ImportError(
"Could not import openai, please install it with "
"`pip install openai`."
)
if openai.__version__ < "1.0.0":
if openai.__version__ < "1.0.0": # type: ignore
raise RuntimeError(
"`openai<1.0.0` is no longer supported. "
"Please upgrade using `pip install openai>=1.0.0`."
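One caveat worth noting: `openai.__version__ < "1.0.0"` compares version strings lexicographically, which happens to work at this particular boundary but fails in general ("10.0.0" < "9.0.0" as strings). A more robust variant, not what the PR uses, parses the versions first; this sketch assumes both `openai` and the `packaging` package are installed:

```python
from packaging.version import Version

import openai  # type: ignore

# Version() compares release segments numerically, so "10.0.0" > "9.0.0".
if Version(openai.__version__) < Version("1.0.0"):
    raise RuntimeError(
        "`openai<1.0.0` is no longer supported. "
        "Please upgrade using `pip install openai>=1.0.0`."
    )
```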