diff --git a/griptape/drivers/embedding/base_embedding_driver.py b/griptape/drivers/embedding/base_embedding_driver.py index eaf16169a..ebe146e13 100644 --- a/griptape/drivers/embedding/base_embedding_driver.py +++ b/griptape/drivers/embedding/base_embedding_driver.py @@ -15,7 +15,7 @@ if TYPE_CHECKING: from griptape.tokenizers import BaseTokenizer -VectorOperation = Literal["query", "upsert"] +VectorOperation = Literal["query", "upsert", "insert"] @define diff --git a/griptape/drivers/vector/base_vector_store_driver.py b/griptape/drivers/vector/base_vector_store_driver.py index eb16f2735..6bf3ba0d8 100644 --- a/griptape/drivers/vector/base_vector_store_driver.py +++ b/griptape/drivers/vector/base_vector_store_driver.py @@ -58,7 +58,7 @@ def upsert_text_artifact( **kwargs, ) -> str: warnings.warn( - "`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.", + "`BaseVectorStoreDriver.upsert_text_artifact` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.", DeprecationWarning, stacklevel=2, ) @@ -80,6 +80,60 @@ def upsert_text( ) return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs) + def insert_collection( + self, + artifacts: list[TextArtifact] + | list[ImageArtifact] + | dict[str, list[TextArtifact]] + | dict[str, list[ImageArtifact]], + *, + meta: Optional[dict] = None, + **kwargs, + ) -> list[str] | dict[str, list[str]]: + with self.create_futures_executor() as futures_executor: + if isinstance(artifacts, list): + return utils.execute_futures_list( + [ + futures_executor.submit(with_contextvars(self.insert), a, namespace=None, meta=meta, **kwargs) + for a in artifacts + ], + ) + futures_dict = {} + + for namespace, artifact_list in artifacts.items(): + for a in artifact_list: + if not futures_dict.get(namespace): + futures_dict[namespace] = [] + + futures_dict[namespace].append( + futures_executor.submit( + with_contextvars(self.insert), a, namespace=namespace, meta=meta, **kwargs + ) + ) + + return utils.execute_futures_list_dict(futures_dict) + + def insert( + self, + value: str | TextArtifact | ImageArtifact, + *, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + vector_id: Optional[str] = None, + **kwargs, + ) -> str: + artifact = TextArtifact(value) if isinstance(value, str) else value + + meta = {} if meta is None else meta + + if vector_id is None: + vector_id = str(uuid.uuid4()) + + meta = {**meta, "artifact": artifact.to_json()} + vector = self.embedding_driver.embed(artifact, vector_operation="insert") + + return self.insert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs) + @overload def upsert_collection( self, @@ -169,6 +223,16 @@ def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact: @abstractmethod def delete_vector(self, vector_id: str) -> None: ... + def insert_vector( + self, + vector: list[float], + *, + vector_id: Optional[str] = None, + namespace: Optional[str] = None, + meta: Optional[dict] = None, + **kwargs, + ) -> str: ... + @abstractmethod def upsert_vector( self, diff --git a/tests/unit/drivers/vector/test_base_vector_store_driver.py b/tests/unit/drivers/vector/test_base_vector_store_driver.py index 0b4aae7a6..a258d4e54 100644 --- a/tests/unit/drivers/vector/test_base_vector_store_driver.py +++ b/tests/unit/drivers/vector/test_base_vector_store_driver.py @@ -12,6 +12,83 @@ class TestBaseVectorStoreDriver(ABC): @abstractmethod def driver(self, *args, **kwargs) -> BaseVectorStoreDriver: ... + def test_insert(self, driver, mocker): + spy = mocker.patch.object(driver, "insert_vector", return_value="vid123") + returned_id = driver.insert( + "foobar", + namespace="ns1", + meta={"k": "v"}, + vector_id="vid123", + ) + # Assert + assert returned_id == "vid123" + assert spy.call_count == 1 + args, kwargs = spy.call_args + # vector should come from embedding driver mock: [0, 1] + assert args[0] == [0, 1] + assert kwargs["namespace"] == "ns1" + assert kwargs["vector_id"] == "vid123" + # meta should be merged and include serialized artifact + assert kwargs["meta"]["k"] == "v" + assert "artifact" in kwargs["meta"] + + def test_insert_generates_vector_id_when_not_provided(self, driver, mocker): + spy = mocker.patch.object(driver, "insert_vector", return_value="auto-id") + + result_id = driver.insert(TextArtifact("hello"), namespace="ns2") + + assert result_id == "auto-id" + assert spy.call_count == 1 + _, kwargs = spy.call_args + # Should auto-generate some vector_id string + assert isinstance(kwargs["vector_id"], str) + assert len(kwargs["vector_id"]) > 0 + assert kwargs["namespace"] == "ns2" + assert "artifact" in kwargs["meta"] + + def test_insert_collection_list(self, driver, mocker): + # Prepare two deterministic return ids but allow any execution order + ids = ["a1", "a2"] + mock = mocker.patch.object(driver, "insert_vector", side_effect=ids) + + result = driver.insert_collection([TextArtifact("one"), TextArtifact("two")]) + + # Order of execution may vary across Python/platforms, so compare as sets + assert len(result) == 2 + assert set(result) == set(ids) + # insert is called under the hood once per artifact + assert mock.call_count == 2 + # ensure namespace is None for list inputs + for call in mock.call_args_list: + assert call.kwargs.get("namespace") is None + assert "artifact" in call.kwargs.get("meta", {}) + + def test_insert_collection_dict(self, driver, mocker): + # Generate IDs deterministically per-namespace regardless of interleaving + prefix = {"nsx": "n1", "nsy": "n2"} + counts = {"nsx": 0, "nsy": 0} + + def side_effect(*args, **kwargs): + ns = kwargs.get("namespace") + counts[ns] += 1 + return f"{prefix[ns]}-{counts[ns]}" + + mock = mocker.patch.object(driver, "insert_vector", side_effect=side_effect) + + artifacts = {"nsx": [TextArtifact("a"), TextArtifact("b")], "nsy": [TextArtifact("c")]} + result = driver.insert_collection(artifacts) + + assert isinstance(result, dict) + assert set(result.keys()) == {"nsx", "nsy"} + # Per-namespace order is preserved by BaseVectorStoreDriver utilities + assert result["nsx"] == ["n1-1", "n1-2"] + assert result["nsy"] == ["n2-1"] + assert mock.call_count == 3 + # Validate counts per namespace without assuming cross-namespace call order + namespaces = [c.kwargs.get("namespace") for c in mock.call_args_list] + assert namespaces.count("nsx") == 2 + assert namespaces.count("nsy") == 1 + def test_upsert(self, driver): namespace = driver.upsert(TextArtifact(id="foo1", value="foobar"))