Skip to content

Commit c60f8ce

Browse files
committed
chore(types): Type-clean embeddings/ (25 errors) (#1383)
1 parent e5a1d46 commit c60f8ce

File tree

10 files changed

+94
-70
lines changed

10 files changed

+94
-70
lines changed

nemoguardrails/embeddings/basic.py

Lines changed: 45 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,9 @@
1515

1616
import asyncio
1717
import logging
18-
from typing import Any, Dict, List, Optional, Union
18+
from typing import Any, Dict, List, Optional, Union, cast
1919

20-
from annoy import AnnoyIndex
20+
from annoy import AnnoyIndex # type: ignore
2121

2222
from nemoguardrails.embeddings.cache import cache_embeddings
2323
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
@@ -45,62 +45,51 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
4545
max_batch_hold: The maximum time a batch is held before being processed
4646
"""
4747

48-
embedding_model: str
49-
embedding_engine: str
50-
embedding_params: Dict[str, Any]
51-
index: AnnoyIndex
52-
embedding_size: int
53-
cache_config: EmbeddingsCacheConfig
54-
embeddings: List[List[float]]
55-
search_threshold: float
56-
use_batching: bool
57-
max_batch_size: int
58-
max_batch_hold: float
59-
6048
def __init__(
6149
self,
62-
embedding_model=None,
63-
embedding_engine=None,
64-
embedding_params=None,
65-
index=None,
66-
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
67-
search_threshold: float = None,
50+
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
51+
embedding_engine: str = "SentenceTransformers",
52+
embedding_params: Optional[Dict[str, Any]] = None,
53+
index: Optional[AnnoyIndex] = None,
54+
cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
55+
search_threshold: float = float("inf"),
6856
use_batching: bool = False,
6957
max_batch_size: int = 10,
7058
max_batch_hold: float = 0.01,
7159
):
7260
"""Initialize the BasicEmbeddingsIndex.
7361
7462
Args:
75-
embedding_model (str, optional): The model for computing embeddings. Defaults to None.
76-
embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
77-
index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
78-
cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
63+
embedding_model: The model for computing embeddings.
64+
embedding_engine: The engine for computing embeddings.
65+
index: The pre-existing index.
66+
cache_config: The cache configuration.
67+
search_threshold: The threshold for filtering search results.
7968
use_batching: Whether to batch requests when computing the embeddings.
8069
max_batch_size: The maximum size of a batch.
8170
max_batch_hold: The maximum time a batch is held before being processed
8271
"""
8372
self._model: Optional[EmbeddingModel] = None
84-
self._items = []
85-
self._embeddings = []
73+
self._items: List[IndexItem] = []
74+
self._embeddings: List[List[float]] = []
8675
self.embedding_model = embedding_model
8776
self.embedding_engine = embedding_engine
8877
self.embedding_params = embedding_params or {}
8978
self._embedding_size = 0
90-
self.search_threshold = search_threshold or float("inf")
79+
self.search_threshold = search_threshold
9180
if isinstance(cache_config, Dict):
9281
self._cache_config = EmbeddingsCacheConfig(**cache_config)
9382
else:
9483
self._cache_config = cache_config or EmbeddingsCacheConfig()
9584
self._index = index
9685

9786
# Data structures for batching embedding requests
98-
self._req_queue = {}
99-
self._req_results = {}
100-
self._req_idx = 0
101-
self._current_batch_finished_event = None
102-
self._current_batch_full_event = None
103-
self._current_batch_submitted = asyncio.Event()
87+
self._req_queue: Dict[int, str] = {}
88+
self._req_results: Dict[int, List[float]] = {}
89+
self._req_idx: int = 0
90+
self._current_batch_finished_event: Optional[asyncio.Event] = None
91+
self._current_batch_full_event: Optional[asyncio.Event] = None
92+
self._current_batch_submitted: asyncio.Event = asyncio.Event()
10493

10594
# Initialize the batching configuration
10695
self.use_batching = use_batching
@@ -112,6 +101,11 @@ def embeddings_index(self):
112101
"""Get the current embedding index"""
113102
return self._index
114103

104+
@embeddings_index.setter
105+
def embeddings_index(self, index):
106+
"""Setter to allow replacing the index dynamically."""
107+
self._index = index
108+
115109
@property
116110
def cache_config(self):
117111
"""Get the cache configuration."""
@@ -127,16 +121,14 @@ def embeddings(self):
127121
"""Get the computed embeddings."""
128122
return self._embeddings
129123

130-
@embeddings_index.setter
131-
def embeddings_index(self, index):
132-
"""Setter to allow replacing the index dynamically."""
133-
self._index = index
134-
135124
def _init_model(self):
136125
"""Initialize the model used for computing the embeddings."""
126+
model = self.embedding_model
127+
engine = self.embedding_engine
128+
137129
self._model = init_embedding_model(
138-
embedding_model=self.embedding_model,
139-
embedding_engine=self.embedding_engine,
130+
embedding_model=model,
131+
embedding_engine=engine,
140132
embedding_params=self.embedding_params,
141133
)
142134

@@ -153,7 +145,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
153145
if self._model is None:
154146
self._init_model()
155147

156-
embeddings = await self._model.encode_async(texts)
148+
# self._model can't be None here, or self._init_model() would throw a ValueError
149+
model: EmbeddingModel = cast(EmbeddingModel, self._model)
150+
embeddings = await model.encode_async(texts)
157151
return embeddings
158152

159153
async def add_item(self, item: IndexItem):
@@ -199,6 +193,12 @@ async def _run_batch(self):
199193
"""Runs the current batch of embeddings."""
200194

201195
# Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
196+
if (
197+
self._current_batch_full_event is None
198+
or self._current_batch_finished_event is None
199+
):
200+
raise RuntimeError("Batch events not initialized. This should not happen.")
201+
202202
done, pending = await asyncio.wait(
203203
[
204204
asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -244,7 +244,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
244244
self._req_idx += 1
245245
self._req_queue[req_id] = text
246246

247-
if self._current_batch_finished_event is None:
247+
if (
248+
self._current_batch_finished_event is None
249+
or self._current_batch_full_event is None
250+
):
248251
self._current_batch_finished_event = asyncio.Event()
249252
self._current_batch_full_event = asyncio.Event()
250253
self._current_batch_submitted.clear()

nemoguardrails/embeddings/cache.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from abc import ABC, abstractmethod
2121
from functools import singledispatchmethod
2222
from pathlib import Path
23-
from typing import Dict, List
23+
from typing import Dict, List, Optional
24+
25+
try:
26+
import redis # type: ignore
27+
except ImportError:
28+
redis = None # type: ignore
2429

2530
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig
2631

@@ -30,6 +35,8 @@
3035
class KeyGenerator(ABC):
3136
"""Abstract class for key generators."""
3237

38+
name: str # Class attribute that should be defined in subclasses
39+
3340
@abstractmethod
3441
def generate_key(self, text: str) -> str:
3542
pass
@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str:
7683
class CacheStore(ABC):
7784
"""Abstract class for cache stores."""
7885

86+
name: str
87+
7988
@abstractmethod
8089
def get(self, key):
8190
"""Get a value from the cache."""
@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore):
147156

148157
name = "filesystem"
149158

150-
def __init__(self, cache_dir: str = None):
159+
def __init__(self, cache_dir: Optional[str] = None):
151160
self._cache_dir = Path(cache_dir or ".cache/embeddings")
152161
self._cache_dir.mkdir(parents=True, exist_ok=True)
153162

@@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore):
190199
name = "redis"
191200

192201
def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0):
193-
import redis
194-
202+
if redis is None:
203+
raise ImportError(
204+
"Could not import redis, please install it with `pip install redis`."
205+
)
195206
self._redis = redis.Redis(host=host, port=port, db=db)
196207

197208
def get(self, key):
@@ -207,9 +218,9 @@ def clear(self):
207218
class EmbeddingsCache:
208219
def __init__(
209220
self,
210-
key_generator: KeyGenerator = None,
211-
cache_store: CacheStore = None,
212-
store_config: dict = None,
221+
key_generator: KeyGenerator,
222+
cache_store: CacheStore,
223+
store_config: Optional[dict] = None,
213224
):
214225
self._key_generator = key_generator
215226
self._cache_store = cache_store
@@ -218,7 +229,10 @@ def __init__(
218229
@classmethod
219230
def from_dict(cls, d: Dict[str, str]):
220231
key_generator = KeyGenerator.from_name(d.get("key_generator"))()
221-
store_config = d.get("store_config")
232+
store_config_raw = d.get("store_config")
233+
store_config: dict = (
234+
store_config_raw if isinstance(store_config_raw, dict) else {}
235+
)
222236
cache_store = CacheStore.from_name(d.get("store"))(**store_config)
223237

224238
return cls(key_generator=key_generator, cache_store=cache_store)
@@ -239,7 +253,7 @@ def get_config(self):
239253
def get(self, texts):
240254
raise NotImplementedError
241255

242-
@get.register
256+
@get.register(str)
243257
def _(self, text: str):
244258
key = self._key_generator.generate_key(text)
245259
log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")
@@ -248,7 +262,7 @@ def _(self, text: str):
248262

249263
return result
250264

251-
@get.register
265+
@get.register(list)
252266
def _(self, texts: list):
253267
cached = {}
254268

@@ -266,13 +280,13 @@ def _(self, texts: list):
266280
def set(self, texts):
267281
raise NotImplementedError
268282

269-
@set.register
283+
@set.register(str)
270284
def _(self, text: str, value: List[float]):
271285
key = self._key_generator.generate_key(text)
272286
log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
273287
self._cache_store.set(key, value)
274288

275-
@set.register
289+
@set.register(list)
276290
def _(self, texts: list, values: List[List[float]]):
277291
for text, value in zip(texts, values):
278292
self.set(text, value)

nemoguardrails/embeddings/providers/azureopenai.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,16 @@ class AzureEmbeddingModel(EmbeddingModel):
4646

4747
def __init__(self, embedding_model: str):
4848
try:
49-
from openai import AzureOpenAI
49+
from openai import AzureOpenAI # type: ignore
5050
except ImportError:
5151
raise ImportError(
52-
"Could not import openai, please install it with "
53-
"`pip install openai`."
52+
"Could not import openai, please install it with `pip install openai`."
5453
)
5554
# Set Azure OpenAI API credentials
5655
self.client = AzureOpenAI(
5756
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
5857
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
59-
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
58+
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore
6059
)
6160

6261
self.embedding_model = embedding_model

nemoguardrails/embeddings/providers/cohere.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
import asyncio
1616
from contextvars import ContextVar
17-
from typing import List
17+
from typing import TYPE_CHECKING, List
1818

1919
from .base import EmbeddingModel
2020

@@ -23,6 +23,10 @@
2323
# is changed, it will fail.
2424
async_client_var: ContextVar = ContextVar("async_client", default=None)
2525

26+
if TYPE_CHECKING:
27+
import cohere
28+
from cohere import AsyncClient, Client
29+
2630

2731
class CohereEmbeddingModel(EmbeddingModel):
2832
"""
@@ -64,7 +68,7 @@ def __init__(
6468

6569
self.model = embedding_model
6670
self.input_type = input_type
67-
self.client = cohere.Client(**kwargs)
71+
self.client = cohere.Client(**kwargs) # type: ignore[reportCallIssue]
6872

6973
self.embedding_size_dict = {
7074
"embed-v4.0": 1536,
@@ -120,6 +124,9 @@ def encode(self, documents: List[str]) -> List[List[float]]:
120124
"""
121125

122126
# Make embedding request to Cohere API
123-
return self.client.embed(
127+
# Since we don't pass embedding_types parameter, the response should be
128+
# EmbeddingsFloatsEmbedResponse with embeddings as List[List[float]]
129+
response = self.client.embed(
124130
texts=documents, model=self.model, input_type=self.input_type
125-
).embeddings
131+
)
132+
return response.embeddings # type: ignore[return-value]

nemoguardrails/embeddings/providers/fastembed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):
4242
engine_name = "FastEmbed"
4343

4444
def __init__(self, embedding_model: str, **kwargs):
45-
from fastembed import TextEmbedding as Embedding
45+
from fastembed import TextEmbedding as Embedding # type: ignore
4646

4747
# Enabling a short form model name for all-MiniLM-L6-v2.
4848
if embedding_model == "all-MiniLM-L6-v2":

nemoguardrails/embeddings/providers/google.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class GoogleEmbeddingModel(EmbeddingModel):
4646

4747
def __init__(self, embedding_model: str, **kwargs):
4848
try:
49-
from google import genai
49+
from google import genai # type: ignore[import]
5050

5151
except ImportError:
5252
raise ImportError(

nemoguardrails/embeddings/providers/nim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class NIMEmbeddingModel(EmbeddingModel):
3535

3636
def __init__(self, embedding_model: str, **kwargs):
3737
try:
38-
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
38+
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings # type: ignore
3939

4040
self.model = embedding_model
4141
self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs)

nemoguardrails/embeddings/providers/openai.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ def __init__(
4646
**kwargs,
4747
):
4848
try:
49-
import openai
50-
from openai import AsyncOpenAI, OpenAI
49+
import openai # type: ignore
50+
from openai import AsyncOpenAI, OpenAI # type: ignore
5151
except ImportError:
5252
raise ImportError(
5353
"Could not import openai, please install it with "
5454
"`pip install openai`."
5555
)
56-
if openai.__version__ < "1.0.0":
56+
if openai.__version__ < "1.0.0": # type: ignore
5757
raise RuntimeError(
5858
"`openai<1.0.0` is no longer supported. "
5959
"Please upgrade using `pip install openai>=1.0.0`."

0 commit comments

Comments
 (0)