Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions examples/quickstart/quickstart_falkordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,6 @@ async def main():
graphiti = Graphiti(graph_driver=falkor_driver)

try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()

#################################################
# ADDING EPISODES
#################################################
Expand Down
3 changes: 0 additions & 3 deletions examples/quickstart/quickstart_neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ async def main():
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)

try:
# Initialize the graph database with graphiti's indices. This only needs to be done once.
await graphiti.build_indices_and_constraints()

#################################################
# ADDING EPISODES
#################################################
Expand Down
110 changes: 110 additions & 0 deletions graphiti_core/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import functools
import inspect
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.helpers import semaphore_gather
from graphiti_core.search.search_config import SearchResults

F = TypeVar('F', bound=Callable[..., Awaitable[Any]])


def handle_multiple_group_ids(func: F) -> F:
"""
Decorator for FalkorDB methods that need to handle multiple group_ids.
Runs the function for each group_id separately and merges results.
"""

@functools.wraps(func)
async def wrapper(self, *args, **kwargs):
group_ids_func_pos = get_parameter_position(func, 'group_ids')
group_ids_pos = (
group_ids_func_pos - 1 if group_ids_func_pos is not None else None
) # Adjust for zero-based index
group_ids = kwargs.get('group_ids')

# If not in kwargs and position exists, get from args
if group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos:
group_ids = args[group_ids_pos]

# Only handle FalkorDB with multiple group_ids
if (
hasattr(self, 'clients')
and hasattr(self.clients, 'driver')
and self.clients.driver.provider == GraphProvider.FALKORDB
and group_ids
and len(group_ids) > 1
):
# Execute for each group_id concurrently
driver = self.clients.driver

async def execute_for_group(gid: str):
# Remove group_ids from args if it was passed positionally
filtered_args = list(args)
if group_ids_pos is not None and len(args) > group_ids_pos:
filtered_args.pop(group_ids_pos)

return await func(
self,
*filtered_args,
**{**kwargs, 'group_ids': [gid], 'driver': driver.clone(database=gid)},
)

results = await semaphore_gather(
*[execute_for_group(gid) for gid in group_ids],
max_coroutines=getattr(self, 'max_coroutines', None),
)

# Merge results based on type
if isinstance(results[0], SearchResults):
return SearchResults.merge(results)
elif isinstance(results[0], list):
return [item for result in results for item in result]
elif isinstance(results[0], tuple):
# Handle tuple outputs (like build_communities returning (nodes, edges))
merged_tuple = []
for i in range(len(results[0])):
component_results = [result[i] for result in results]
if isinstance(component_results[0], list):
merged_tuple.append(
[item for component in component_results for item in component]
)
else:
merged_tuple.append(component_results)
return tuple(merged_tuple)
else:
return results

# Normal execution
return await func(self, *args, **kwargs)

return wrapper # type: ignore


def get_parameter_position(func: Callable, param_name: str) -> int | None:
"""
Returns the positional index of a parameter in the function signature.
If the parameter is not found, returns None.
"""
sig = inspect.signature(func)
for idx, (name, _param) in enumerate(sig.parameters.items()):
if name == param_name:
return idx
return None
9 changes: 9 additions & 0 deletions graphiti_core/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class GraphDriver(ABC):
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
)
_database: str
default_group_id: str = ''
search_interface: SearchInterface | None = None
graph_operations_interface: GraphOperationsInterface | None = None

Expand Down Expand Up @@ -105,6 +106,14 @@ def with_database(self, database: str) -> 'GraphDriver':

return cloned

@abstractmethod
async def build_indices_and_constraints(self, delete_existing: bool = False):
raise NotImplementedError()

def clone(self, database: str) -> 'GraphDriver':
"""Clone the driver with a different database or graph name."""
return self

def build_fulltext_query(
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
) -> str:
Expand Down
37 changes: 34 additions & 3 deletions graphiti_core/driver/falkordb_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
) from None

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -112,6 +113,8 @@ async def run(self, query: str | list, **kwargs: Any) -> Any:

class FalkorDriver(GraphDriver):
provider = GraphProvider.FALKORDB
default_group_id: str = '\\_'
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
aoss_client: None = None

def __init__(
Expand All @@ -129,17 +132,32 @@ def __init__(
FalkorDB is a multi-tenant graph database.
To connect, provide the host and port.
The default parameters assume a local (on-premises) FalkorDB instance.

Args:
host (str): The host where FalkorDB is running.
port (int): The port on which FalkorDB is listening.
username (str | None): The username for authentication (if required).
password (str | None): The password for authentication (if required).
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
database (str): The name of the database to connect to. Defaults to 'default_db'.
"""
super().__init__()

self._database = database
if falkor_db is not None:
# If a FalkorDB instance is provided, use it directly
self.client = falkor_db
else:
self.client = FalkorDB(host=host, port=port, username=username, password=password)

self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
# Schedule the indices and constraints to be built
try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# Schedule the build_indices_and_constraints to run
loop.create_task(self.build_indices_and_constraints())
except RuntimeError:
# No event loop running, this will be handled later
pass

def _get_graph(self, graph_name: str | None) -> FalkorGraph:
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
Expand Down Expand Up @@ -224,12 +242,25 @@ async def delete_all_indexes(self) -> None:
if drop_tasks:
await asyncio.gather(*drop_tasks)

async def build_indices_and_constraints(self, delete_existing=False):
if delete_existing:
await self.delete_all_indexes()
index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
for query in index_queries:
await self.execute_query(query)

def clone(self, database: str) -> 'GraphDriver':
"""
Returns a shallow copy of this driver with a different default database.
Reuses the same connection (e.g. FalkorDB, Neo4j).
"""
cloned = FalkorDriver(falkor_db=self.client, database=database)
if database == self._database:
cloned = self
elif database == self.default_group_id:
cloned = FalkorDriver(falkor_db=self.client)
else:
# Create a new instance of FalkorDriver with the same connection but a different database
cloned = FalkorDriver(falkor_db=self.client, database=database)

return cloned

Expand Down
34 changes: 34 additions & 0 deletions graphiti_core/driver/neo4j_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
from typing_extensions import LiteralString

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.helpers import semaphore_gather

logger = logging.getLogger(__name__)


class Neo4jDriver(GraphDriver):
provider = GraphProvider.NEO4J
default_group_id: str = ''

def __init__(
self,
Expand All @@ -43,6 +46,18 @@ def __init__(
)
self._database = database

# Schedule the indices and constraints to be built
import asyncio

try:
# Try to get the current event loop
loop = asyncio.get_running_loop()
# Schedule the build_indices_and_constraints to run
loop.create_task(self.build_indices_and_constraints())
except RuntimeError:
# No event loop running, this will be handled later
pass

self.aoss_client = None

async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
Expand Down Expand Up @@ -73,6 +88,25 @@ def delete_all_indexes(self) -> Coroutine:
'CALL db.indexes() YIELD name DROP INDEX name',
)

async def build_indices_and_constraints(self, delete_existing: bool = False):
if delete_existing:
await self.delete_all_indexes()

range_indices: list[LiteralString] = get_range_indices(self.provider)

fulltext_indices: list[LiteralString] = get_fulltext_indices(self.provider)

index_queries: list[LiteralString] = range_indices + fulltext_indices

await semaphore_gather(
*[
self.execute_query(
query,
)
for query in index_queries
]
)

async def health_check(self) -> None:
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
try:
Expand Down
Loading
Loading