1515
1616import asyncio
1717import logging
18- from typing import Any , Dict , List , Optional , Union
18+ from typing import Any , Dict , List , Optional , Union , cast
1919
20- from annoy import AnnoyIndex
20+ from annoy import AnnoyIndex # type: ignore
2121
2222from nemoguardrails .embeddings .cache import cache_embeddings
2323from nemoguardrails .embeddings .index import EmbeddingsIndex , IndexItem
@@ -45,62 +45,51 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
4545 max_batch_hold: The maximum time a batch is held before being processed
4646 """
4747
48- embedding_model : str
49- embedding_engine : str
50- embedding_params : Dict [str , Any ]
51- index : AnnoyIndex
52- embedding_size : int
53- cache_config : EmbeddingsCacheConfig
54- embeddings : List [List [float ]]
55- search_threshold : float
56- use_batching : bool
57- max_batch_size : int
58- max_batch_hold : float
59-
6048 def __init__ (
6149 self ,
62- embedding_model = None ,
63- embedding_engine = None ,
64- embedding_params = None ,
65- index = None ,
66- cache_config : Union [EmbeddingsCacheConfig , Dict [str , Any ]] = None ,
67- search_threshold : float = None ,
50+ embedding_model : str = "sentence-transformers/all-MiniLM-L6-v2" ,
51+ embedding_engine : str = "SentenceTransformers" ,
52+ embedding_params : Optional [ Dict [ str , Any ]] = None ,
53+ index : Optional [ AnnoyIndex ] = None ,
54+ cache_config : Optional [ Union [EmbeddingsCacheConfig , Dict [str , Any ] ]] = None ,
55+ search_threshold : float = float ( "inf" ) ,
6856 use_batching : bool = False ,
6957 max_batch_size : int = 10 ,
7058 max_batch_hold : float = 0.01 ,
7159 ):
7260 """Initialize the BasicEmbeddingsIndex.
7361
7462 Args:
75- embedding_model (str, optional): The model for computing embeddings. Defaults to None.
76- embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
77- index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
78- cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
63+ embedding_model: The model for computing embeddings.
64+ embedding_engine: The engine for computing embeddings.
65+ index: The pre-existing index.
66+ cache_config: The cache configuration.
67+ search_threshold: The threshold for filtering search results.
7968 use_batching: Whether to batch requests when computing the embeddings.
8069 max_batch_size: The maximum size of a batch.
8170 max_batch_hold: The maximum time a batch is held before being processed
8271 """
8372 self ._model : Optional [EmbeddingModel ] = None
84- self ._items = []
85- self ._embeddings = []
73+ self ._items : List [ IndexItem ] = []
74+ self ._embeddings : List [ List [ float ]] = []
8675 self .embedding_model = embedding_model
8776 self .embedding_engine = embedding_engine
8877 self .embedding_params = embedding_params or {}
8978 self ._embedding_size = 0
90- self .search_threshold = search_threshold or float ( "inf" )
79+ self .search_threshold = search_threshold
9180 if isinstance (cache_config , Dict ):
9281 self ._cache_config = EmbeddingsCacheConfig (** cache_config )
9382 else :
9483 self ._cache_config = cache_config or EmbeddingsCacheConfig ()
9584 self ._index = index
9685
9786 # Data structures for batching embedding requests
98- self ._req_queue = {}
99- self ._req_results = {}
100- self ._req_idx = 0
101- self ._current_batch_finished_event = None
102- self ._current_batch_full_event = None
103- self ._current_batch_submitted = asyncio .Event ()
87+ self ._req_queue : Dict [ int , str ] = {}
88+ self ._req_results : Dict [ int , List [ float ]] = {}
89+ self ._req_idx : int = 0
90+ self ._current_batch_finished_event : Optional [ asyncio . Event ] = None
91+ self ._current_batch_full_event : Optional [ asyncio . Event ] = None
92+ self ._current_batch_submitted : asyncio . Event = asyncio .Event ()
10493
10594 # Initialize the batching configuration
10695 self .use_batching = use_batching
@@ -112,6 +101,11 @@ def embeddings_index(self):
112101 """Get the current embedding index"""
113102 return self ._index
114103
    @embeddings_index.setter
    def embeddings_index(self, index):
        """Setter to allow replacing the index dynamically.

        Args:
            index: The new (pre-built) index instance; subsequent searches
                will run against it instead of the one set at construction.
        """
        self._index = index
115109 @property
116110 def cache_config (self ):
117111 """Get the cache configuration."""
    @property
    def embeddings(self):
        """Get the computed embeddings.

        Returns:
            The embedding vectors (one list of floats per indexed item)
            stored in ``self._embeddings``.
        """
        return self._embeddings
130- @embeddings_index .setter
131- def embeddings_index (self , index ):
132- """Setter to allow replacing the index dynamically."""
133- self ._index = index
134-
135124 def _init_model (self ):
136125 """Initialize the model used for computing the embeddings."""
126+ model = self .embedding_model
127+ engine = self .embedding_engine
128+
137129 self ._model = init_embedding_model (
138- embedding_model = self . embedding_model ,
139- embedding_engine = self . embedding_engine ,
130+ embedding_model = model ,
131+ embedding_engine = engine ,
140132 embedding_params = self .embedding_params ,
141133 )
142134
@@ -153,7 +145,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
153145 if self ._model is None :
154146 self ._init_model ()
155147
156- embeddings = await self ._model .encode_async (texts )
148+ # self._model can't be None here, or self._init_model() would throw a ValueError
149+ model : EmbeddingModel = cast (EmbeddingModel , self ._model )
150+ embeddings = await model .encode_async (texts )
157151 return embeddings
158152
159153 async def add_item (self , item : IndexItem ):
@@ -199,6 +193,12 @@ async def _run_batch(self):
199193 """Runs the current batch of embeddings."""
200194
201195 # Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
196+ if (
197+ self ._current_batch_full_event is None
198+ or self ._current_batch_finished_event is None
199+ ):
200+ raise RuntimeError ("Batch events not initialized. This should not happen." )
201+
202202 done , pending = await asyncio .wait (
203203 [
204204 asyncio .create_task (asyncio .sleep (self .max_batch_hold )),
@@ -244,7 +244,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
244244 self ._req_idx += 1
245245 self ._req_queue [req_id ] = text
246246
247- if self ._current_batch_finished_event is None :
247+ if (
248+ self ._current_batch_finished_event is None
249+ or self ._current_batch_full_event is None
250+ ):
248251 self ._current_batch_finished_event = asyncio .Event ()
249252 self ._current_batch_full_event = asyncio .Event ()
250253 self ._current_batch_submitted .clear ()
0 commit comments