Skip to content

Commit c144ff5

Browse files
galshubeli and Naseem77 authored
[Improvement] Add GraphID isolation support for FalkorDB multi-tenant architecture (#835)
* Update node_db_queries.py * Update node_db_queries.py * graph-per-graphid * fix-groupid-usage * ruff-fix * rev-driver-changes * rm-un-changes * fix lint --------- Co-authored-by: Naseem Ali <[email protected]>
1 parent 8d99984 commit c144ff5

File tree

10 files changed

+267
-72
lines changed

10 files changed

+267
-72
lines changed

examples/quickstart/quickstart_falkordb.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,6 @@ async def main():
7878
graphiti = Graphiti(graph_driver=falkor_driver)
7979

8080
try:
81-
# Initialize the graph database with graphiti's indices. This only needs to be done once.
82-
await graphiti.build_indices_and_constraints()
83-
8481
#################################################
8582
# ADDING EPISODES
8683
#################################################

examples/quickstart/quickstart_neo4j.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ async def main():
6767
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
6868

6969
try:
70-
# Initialize the graph database with graphiti's indices. This only needs to be done once.
71-
await graphiti.build_indices_and_constraints()
72-
7370
#################################################
7471
# ADDING EPISODES
7572
#################################################

graphiti_core/decorators.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
Copyright 2024, Zep Software, Inc.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import functools
18+
import inspect
19+
from collections.abc import Awaitable, Callable
20+
from typing import Any, TypeVar
21+
22+
from graphiti_core.driver.driver import GraphProvider
23+
from graphiti_core.helpers import semaphore_gather
24+
from graphiti_core.search.search_config import SearchResults
25+
26+
F = TypeVar('F', bound=Callable[..., Awaitable[Any]])
27+
28+
29+
def handle_multiple_group_ids(func: F) -> F:
    """
    Decorator for FalkorDB methods that need to handle multiple group_ids.

    When the decorated method is called on an object whose ``self.clients.driver``
    has the FalkorDB provider and more than one group_id is supplied, the method is
    run once per group_id. Each run receives ``group_ids=[gid]`` (in the same
    positional or keyword form the caller used) and a ``driver`` keyword argument
    set to ``driver.clone(database=gid)``. The per-group results are then merged:

    - ``SearchResults`` are combined with ``SearchResults.merge``
    - lists are concatenated
    - tuples are merged component-wise: list components are concatenated, any
      other component becomes a list holding one entry per group_id
    - anything else is returned as the list of per-group results

    In every other case the method is called unchanged.
    """
    # Resolved once at decoration time: the signature never changes between calls.
    # Subtract 1 because ``self`` is bound separately and is not part of ``*args``.
    # A position of 0 (or None) means group_ids cannot arrive through ``*args``.
    group_ids_func_pos = get_parameter_position(func, 'group_ids')
    group_ids_pos = group_ids_func_pos - 1 if group_ids_func_pos else None

    @functools.wraps(func)
    async def wrapper(self, *args, **kwargs):
        group_ids = kwargs.get('group_ids')

        # If not in kwargs and position exists, get from args
        passed_positionally = (
            group_ids is None and group_ids_pos is not None and len(args) > group_ids_pos
        )
        if passed_positionally:
            group_ids = args[group_ids_pos]

        # Only handle FalkorDB with multiple group_ids
        is_multi_group_falkordb = (
            hasattr(self, 'clients')
            and hasattr(self.clients, 'driver')
            and self.clients.driver.provider == GraphProvider.FALKORDB
            and group_ids
            and len(group_ids) > 1
        )
        if not is_multi_group_falkordb:
            # Normal execution
            return await func(self, *args, **kwargs)

        driver = self.clients.driver

        async def execute_for_group(gid: str):
            call_kwargs = {**kwargs, 'driver': driver.clone(database=gid)}
            if passed_positionally:
                # Substitute in place rather than removing the argument: removing
                # it would shift every later positional argument one slot left,
                # into the group_ids position.
                call_args = (*args[:group_ids_pos], [gid], *args[group_ids_pos + 1 :])
            else:
                call_args = args
                call_kwargs['group_ids'] = [gid]
            return await func(self, *call_args, **call_kwargs)

        # Execute for each group_id concurrently
        results = await semaphore_gather(
            *[execute_for_group(gid) for gid in group_ids],
            max_coroutines=getattr(self, 'max_coroutines', None),
        )

        # Merge results based on the type of the first result
        first = results[0]
        if isinstance(first, SearchResults):
            return SearchResults.merge(results)
        if isinstance(first, list):
            return [item for result in results for item in result]
        if isinstance(first, tuple):
            # Handle tuple outputs (like build_communities returning (nodes, edges))
            merged_tuple = []
            for component_results in map(list, zip(*results)):
                if isinstance(component_results[0], list):
                    merged_tuple.append(
                        [item for component in component_results for item in component]
                    )
                else:
                    merged_tuple.append(component_results)
            return tuple(merged_tuple)
        return results

    return wrapper  # type: ignore
99+
100+
101+
def get_parameter_position(func: Callable, param_name: str) -> int | None:
102+
"""
103+
Returns the positional index of a parameter in the function signature.
104+
If the parameter is not found, returns None.
105+
"""
106+
sig = inspect.signature(func)
107+
for idx, (name, _param) in enumerate(sig.parameters.items()):
108+
if name == param_name:
109+
return idx
110+
return None

graphiti_core/driver/driver.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class GraphDriver(ABC):
7676
'' # Neo4j (default) syntax does not require a prefix for fulltext queries
7777
)
7878
_database: str
79+
default_group_id: str = ''
7980
search_interface: SearchInterface | None = None
8081
graph_operations_interface: GraphOperationsInterface | None = None
8182

@@ -105,6 +106,14 @@ def with_database(self, database: str) -> 'GraphDriver':
105106

106107
return cloned
107108

109+
@abstractmethod
async def build_indices_and_constraints(self, delete_existing: bool = False):
    """
    Build the indices and constraints for this driver.

    Abstract: concrete drivers must provide the implementation.

    Args:
        delete_existing: Flag passed by callers; how it is honoured is up to
            the concrete driver.
    """
    raise NotImplementedError
112+
113+
def clone(self, database: str) -> 'GraphDriver':
    """
    Clone the driver with a different database or graph name.

    This base implementation ignores ``database`` and returns the same
    driver instance; subclasses may override it to return a driver bound
    to the requested database.
    """
    return self
116+
108117
def build_fulltext_query(
109118
self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
110119
) -> str:

graphiti_core/driver/falkordb_driver.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
) from None
3535

3636
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
37+
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
3738
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
3839

3940
logger = logging.getLogger(__name__)
@@ -112,6 +113,8 @@ async def run(self, query: str | list, **kwargs: Any) -> Any:
112113

113114
class FalkorDriver(GraphDriver):
114115
provider = GraphProvider.FALKORDB
116+
default_group_id: str = '\\_'
117+
fulltext_syntax: str = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries
115118
aoss_client: None = None
116119

117120
def __init__(
@@ -129,17 +132,32 @@ def __init__(
129132
FalkorDB is a multi-tenant graph database.
130133
To connect, provide the host and port.
131134
The default parameters assume a local (on-premises) FalkorDB instance.
135+
136+
Args:
137+
host (str): The host where FalkorDB is running.
138+
port (int): The port on which FalkorDB is listening.
139+
username (str | None): The username for authentication (if required).
140+
password (str | None): The password for authentication (if required).
141+
falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
142+
database (str): The name of the database to connect to. Defaults to 'default_db'.
132143
"""
133144
super().__init__()
134-
135145
self._database = database
136146
if falkor_db is not None:
137147
# If a FalkorDB instance is provided, use it directly
138148
self.client = falkor_db
139149
else:
140150
self.client = FalkorDB(host=host, port=port, username=username, password=password)
141151

142-
self.fulltext_syntax = '@' # FalkorDB uses a redisearch-like syntax for fulltext queries see https://redis.io/docs/latest/develop/ai/search-and-query/query/full-text/
152+
# Schedule the indices and constraints to be built
153+
try:
154+
# Try to get the current event loop
155+
loop = asyncio.get_running_loop()
156+
# Schedule the build_indices_and_constraints to run
157+
loop.create_task(self.build_indices_and_constraints())
158+
except RuntimeError:
159+
# No event loop running, this will be handled later
160+
pass
143161

144162
def _get_graph(self, graph_name: str | None) -> FalkorGraph:
145163
# FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
@@ -224,12 +242,25 @@ async def delete_all_indexes(self) -> None:
224242
if drop_tasks:
225243
await asyncio.gather(*drop_tasks)
226244

245+
async def build_indices_and_constraints(self, delete_existing: bool = False) -> None:
    """
    Create the range and fulltext indices defined for this driver's provider.

    Args:
        delete_existing: When True, ``delete_all_indexes`` is awaited before
            any index is created.
    """
    if delete_existing:
        await self.delete_all_indexes()

    index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
    # Queries are issued one at a time, range indices first, then fulltext.
    for query in index_queries:
        await self.execute_query(query)
251+
227252
def clone(self, database: str) -> 'GraphDriver':
    """
    Return a driver for ``database`` that reuses this driver's client.

    - If ``database`` is already this driver's database, this same instance
      is returned.
    - If ``database`` equals ``default_group_id``, a new FalkorDriver is
      created on the same client with the constructor's default database.
    - Otherwise a new FalkorDriver is created on the same client for
      ``database``.
    """
    if database == self._database:
        return self

    if database == self.default_group_id:
        return FalkorDriver(falkor_db=self.client)

    # Same connection, different database
    return FalkorDriver(falkor_db=self.client, database=database)
235266

graphiti_core/driver/neo4j_driver.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@
2222
from typing_extensions import LiteralString
2323

2424
from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
25+
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
26+
from graphiti_core.helpers import semaphore_gather
2527

2628
logger = logging.getLogger(__name__)
2729

2830

2931
class Neo4jDriver(GraphDriver):
3032
provider = GraphProvider.NEO4J
33+
default_group_id: str = ''
3134

3235
def __init__(
3336
self,
@@ -43,6 +46,18 @@ def __init__(
4346
)
4447
self._database = database
4548

49+
# Schedule the indices and constraints to be built
50+
import asyncio
51+
52+
try:
53+
# Try to get the current event loop
54+
loop = asyncio.get_running_loop()
55+
# Schedule the build_indices_and_constraints to run
56+
loop.create_task(self.build_indices_and_constraints())
57+
except RuntimeError:
58+
# No event loop running, this will be handled later
59+
pass
60+
4661
self.aoss_client = None
4762

4863
async def execute_query(self, cypher_query_: LiteralString, **kwargs: Any) -> EagerResult:
@@ -73,6 +88,25 @@ def delete_all_indexes(self) -> Coroutine:
7388
'CALL db.indexes() YIELD name DROP INDEX name',
7489
)
7590

91+
async def build_indices_and_constraints(self, delete_existing: bool = False):
    """
    Create the range and fulltext indices defined for this driver's provider.

    All index queries are handed to ``semaphore_gather`` together rather
    than awaited one by one.

    Args:
        delete_existing: When True, ``delete_all_indexes`` is awaited before
            any index is created.
    """
    if delete_existing:
        await self.delete_all_indexes()

    index_queries: list[LiteralString] = [
        *get_range_indices(self.provider),
        *get_fulltext_indices(self.provider),
    ]

    pending = [self.execute_query(statement) for statement in index_queries]
    await semaphore_gather(*pending)
109+
76110
async def health_check(self) -> None:
77111
"""Check Neo4j connectivity by running the driver's verify_connectivity method."""
78112
try:

0 commit comments

Comments
 (0)