Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion griptape/drivers/embedding/base_embedding_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
if TYPE_CHECKING:
from griptape.tokenizers import BaseTokenizer

VectorOperation = Literal["query", "upsert"]
VectorOperation = Literal["query", "upsert", "insert"]


@define
Expand Down
66 changes: 65 additions & 1 deletion griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def upsert_text_artifact(
**kwargs,
) -> str:
warnings.warn(
"`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
"`BaseVectorStoreDriver.upsert_text_artifact` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
DeprecationWarning,
stacklevel=2,
)
Expand All @@ -80,6 +80,60 @@ def upsert_text(
)
return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

def insert_collection(
self,
artifacts: list[TextArtifact]
| list[ImageArtifact]
| dict[str, list[TextArtifact]]
| dict[str, list[ImageArtifact]],
*,
meta: Optional[dict] = None,
**kwargs,
) -> list[str] | dict[str, list[str]]:
with self.create_futures_executor() as futures_executor:
if isinstance(artifacts, list):
return utils.execute_futures_list(
[
futures_executor.submit(with_contextvars(self.insert), a, namespace=None, meta=meta, **kwargs)
for a in artifacts
],
)
futures_dict = {}

for namespace, artifact_list in artifacts.items():
for a in artifact_list:
if not futures_dict.get(namespace):
futures_dict[namespace] = []

futures_dict[namespace].append(
futures_executor.submit(
with_contextvars(self.insert), a, namespace=namespace, meta=meta, **kwargs
)
)

return utils.execute_futures_list_dict(futures_dict)

def insert(
self,
value: str | TextArtifact | ImageArtifact,
*,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
vector_id: Optional[str] = None,
**kwargs,
) -> str:
artifact = TextArtifact(value) if isinstance(value, str) else value

meta = {} if meta is None else meta

if vector_id is None:
vector_id = str(uuid.uuid4())

meta = {**meta, "artifact": artifact.to_json()}
vector = self.embedding_driver.embed(artifact, vector_operation="insert")

return self.insert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs)

@overload
def upsert_collection(
self,
Expand Down Expand Up @@ -169,6 +223,16 @@ def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
@abstractmethod
def delete_vector(self, vector_id: str) -> None: ...

def insert_vector(
self,
vector: list[float],
*,
vector_id: Optional[str] = None,
namespace: Optional[str] = None,
meta: Optional[dict] = None,
**kwargs,
) -> str: ...

@abstractmethod
def upsert_vector(
self,
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/drivers/vector/test_base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,83 @@ class TestBaseVectorStoreDriver(ABC):
@abstractmethod
def driver(self, *args, **kwargs) -> BaseVectorStoreDriver: ...

def test_insert(self, driver, mocker):
spy = mocker.patch.object(driver, "insert_vector", return_value="vid123")
returned_id = driver.insert(
"foobar",
namespace="ns1",
meta={"k": "v"},
vector_id="vid123",
)
# Assert
assert returned_id == "vid123"
assert spy.call_count == 1
args, kwargs = spy.call_args
# vector should come from embedding driver mock: [0, 1]
assert args[0] == [0, 1]
assert kwargs["namespace"] == "ns1"
assert kwargs["vector_id"] == "vid123"
# meta should be merged and include serialized artifact
assert kwargs["meta"]["k"] == "v"
assert "artifact" in kwargs["meta"]

def test_insert_generates_vector_id_when_not_provided(self, driver, mocker):
spy = mocker.patch.object(driver, "insert_vector", return_value="auto-id")

result_id = driver.insert(TextArtifact("hello"), namespace="ns2")

assert result_id == "auto-id"
assert spy.call_count == 1
_, kwargs = spy.call_args
# Should auto-generate some vector_id string
assert isinstance(kwargs["vector_id"], str)
assert len(kwargs["vector_id"]) > 0
assert kwargs["namespace"] == "ns2"
assert "artifact" in kwargs["meta"]

def test_insert_collection_list(self, driver, mocker):
# Prepare two deterministic return ids but allow any execution order
ids = ["a1", "a2"]
mock = mocker.patch.object(driver, "insert_vector", side_effect=ids)

result = driver.insert_collection([TextArtifact("one"), TextArtifact("two")])

# Order of execution may vary across Python/platforms, so compare as sets
assert len(result) == 2
assert set(result) == set(ids)
# insert is called under the hood once per artifact
assert mock.call_count == 2
# ensure namespace is None for list inputs
for call in mock.call_args_list:
assert call.kwargs.get("namespace") is None
assert "artifact" in call.kwargs.get("meta", {})

def test_insert_collection_dict(self, driver, mocker):
# Generate IDs deterministically per-namespace regardless of interleaving
prefix = {"nsx": "n1", "nsy": "n2"}
counts = {"nsx": 0, "nsy": 0}

def side_effect(*args, **kwargs):
ns = kwargs.get("namespace")
counts[ns] += 1
return f"{prefix[ns]}-{counts[ns]}"

mock = mocker.patch.object(driver, "insert_vector", side_effect=side_effect)

artifacts = {"nsx": [TextArtifact("a"), TextArtifact("b")], "nsy": [TextArtifact("c")]}
result = driver.insert_collection(artifacts)

assert isinstance(result, dict)
assert set(result.keys()) == {"nsx", "nsy"}
# Per-namespace order is preserved by BaseVectorStoreDriver utilities
assert result["nsx"] == ["n1-1", "n1-2"]
assert result["nsy"] == ["n2-1"]
assert mock.call_count == 3
# Validate counts per namespace without assuming cross-namespace call order
namespaces = [c.kwargs.get("namespace") for c in mock.call_args_list]
assert namespaces.count("nsx") == 2
assert namespaces.count("nsy") == 1

def test_upsert(self, driver):
namespace = driver.upsert(TextArtifact(id="foo1", value="foobar"))

Expand Down
Loading