Skip to content

Commit 7fd8a8f

Browse files
committed
Move query and streaming_query to vector_stores
1 parent 8e414a2 commit 7fd8a8f

File tree

4 files changed

+60
-36
lines changed

4 files changed

+60
-36
lines changed

src/app/endpoints/query.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,10 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
683683
),
684684
}
685685

686-
vector_db_ids = [
687-
vector_db.identifier for vector_db in await client.vector_dbs.list()
686+
vector_store_ids = [
687+
vector_store.id for vector_store in (await client.vector_stores.list()).data
688688
]
689-
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
689+
toolgroups = (get_rag_toolgroups(vector_store_ids) or []) + [
690690
mcp_server.name for mcp_server in configuration.mcp_servers
691691
]
692692
# Convert empty list to None for consistency with existing behavior
@@ -782,30 +782,30 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:
782782

783783

784784
def get_rag_toolgroups(
785-
vector_db_ids: list[str],
785+
vector_store_ids: list[str],
786786
) -> list[Toolgroup] | None:
787787
"""
788-
Return a list of RAG Tool groups if the given vector DB list is not empty.
788+
Return a list of RAG Tool groups if the given vector store list is not empty.
789789
790790
Generate a list containing a RAG knowledge search toolgroup if
791-
vector database IDs are provided.
791+
vector store IDs are provided.
792792
793793
Parameters:
794-
vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup.
794+
vector_store_ids (list[str]): List of vector store identifiers to include in the toolgroup.
795795
796796
Returns:
797797
list[Toolgroup] | None: A list with a single RAG toolgroup if
798-
vector_db_ids is non-empty; otherwise, None.
798+
vector_store_ids is non-empty; otherwise, None.
799799
"""
800800
return (
801801
[
802802
ToolgroupAgentToolGroupWithArgs(
803803
name="builtin::rag/knowledge_search",
804804
args={
805-
"vector_db_ids": vector_db_ids,
805+
"vector_store_ids": vector_store_ids,
806806
},
807807
)
808808
]
809-
if vector_db_ids
809+
if vector_store_ids
810810
else None
811811
)

src/app/endpoints/streaming_query.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1031,10 +1031,10 @@ async def retrieve_response(
10311031
),
10321032
}
10331033

1034-
vector_db_ids = [
1035-
vector_db.identifier for vector_db in await client.vector_dbs.list()
1034+
vector_store_ids = [
1035+
vector_store.id for vector_store in (await client.vector_stores.list()).data
10361036
]
1037-
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
1037+
toolgroups = (get_rag_toolgroups(vector_store_ids) or []) + [
10381038
mcp_server.name for mcp_server in configuration.mcp_servers
10391039
]
10401040
# Convert empty list to None for consistency with existing behavior

tests/unit/app/endpoints/test_query.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,16 +1386,16 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker) -
13861386

13871387
def test_get_rag_toolgroups() -> None:
13881388
"""Test get_rag_toolgroups function."""
1389-
vector_db_ids: list[str] = []
1390-
result = get_rag_toolgroups(vector_db_ids)
1389+
vector_store_ids: list[str] = []
1390+
result = get_rag_toolgroups(vector_store_ids)
13911391
assert result is None
13921392

1393-
vector_db_ids = ["Vector-DB-1", "Vector-DB-2"]
1394-
result = get_rag_toolgroups(vector_db_ids)
1393+
vector_store_ids = ["Vector-DB-1", "Vector-DB-2"]
1394+
result = get_rag_toolgroups(vector_store_ids)
13951395
assert result is not None
13961396
assert len(result) == 1
13971397
assert result[0]["name"] == "builtin::rag/knowledge_search"
1398-
assert result[0]["args"]["vector_db_ids"] == vector_db_ids
1398+
assert result[0]["args"]["vector_store_ids"] == vector_store_ids
13991399

14001400

14011401
@pytest.mark.asyncio

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -440,9 +440,11 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker
440440
mock_client, mock_agent = prepare_agent_mocks
441441
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
442442
mock_client.shields.list.return_value = []
443-
mock_vector_db = mocker.Mock()
444-
mock_vector_db.identifier = "VectorDB-1"
445-
mock_client.vector_dbs.list.return_value = [mock_vector_db]
443+
mock_vector_store = mocker.Mock()
444+
mock_vector_store.id = "VectorDB-1"
445+
mock_list_response = mocker.Mock()
446+
mock_list_response.data = [mock_vector_store]
447+
mock_client.vector_stores.list.return_value = mock_list_response
446448

447449
# Mock configuration with empty MCP servers
448450
mock_config = mocker.Mock()
@@ -483,7 +485,9 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocke
483485
mock_client, mock_agent = prepare_agent_mocks
484486
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
485487
mock_client.shields.list.return_value = []
486-
mock_client.vector_dbs.list.return_value = []
488+
mock_list_response = mocker.Mock()
489+
mock_list_response.data = []
490+
mock_client.vector_stores.list.return_value = mock_list_response
487491

488492
# Mock configuration with empty MCP servers
489493
mock_config = mocker.Mock()
@@ -537,7 +541,9 @@ def __repr__(self):
537541
mock_client, mock_agent = prepare_agent_mocks
538542
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
539543
mock_client.shields.list.return_value = [MockShield("shield1")]
540-
mock_client.vector_dbs.list.return_value = []
544+
mock_list_response = mocker.Mock()
545+
mock_list_response.data = []
546+
mock_client.vector_stores.list.return_value = mock_list_response
541547

542548
# Mock configuration with empty MCP servers
543549
mock_config = mocker.Mock()
@@ -592,7 +598,9 @@ def __repr__(self):
592598
MockShield("shield1"),
593599
MockShield("shield2"),
594600
]
595-
mock_client.vector_dbs.list.return_value = []
601+
mock_list_response = mocker.Mock()
602+
mock_list_response.data = []
603+
mock_client.vector_stores.list.return_value = mock_list_response
596604

597605
# Mock configuration with empty MCP servers
598606
mock_config = mocker.Mock()
@@ -649,7 +657,9 @@ def __repr__(self):
649657
MockShield("output_shield3"),
650658
MockShield("inout_shield4"),
651659
]
652-
mock_client.vector_dbs.list.return_value = []
660+
mock_list_response = mocker.Mock()
661+
mock_list_response.data = []
662+
mock_client.vector_stores.list.return_value = mock_list_response
653663

654664
# Mock configuration with empty MCP servers
655665
mock_config = mocker.Mock()
@@ -700,7 +710,9 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker
700710
mock_client, mock_agent = prepare_agent_mocks
701711
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
702712
mock_client.shields.list.return_value = []
703-
mock_client.vector_dbs.list.return_value = []
713+
mock_list_response = mocker.Mock()
714+
mock_list_response.data = []
715+
mock_client.vector_stores.list.return_value = mock_list_response
704716

705717
# Mock configuration with empty MCP servers
706718
mock_config = mocker.Mock()
@@ -752,7 +764,9 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocke
752764
mock_client, mock_agent = prepare_agent_mocks
753765
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
754766
mock_client.shields.list.return_value = []
755-
mock_client.vector_dbs.list.return_value = []
767+
mock_list_response = mocker.Mock()
768+
mock_list_response.data = []
769+
mock_client.vector_stores.list.return_value = mock_list_response
756770

757771
# Mock configuration with empty MCP servers
758772
mock_config = mocker.Mock()
@@ -1157,7 +1171,9 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
11571171
mock_client, mock_agent = prepare_agent_mocks
11581172
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
11591173
mock_client.shields.list.return_value = []
1160-
mock_client.vector_dbs.list.return_value = []
1174+
mock_list_response = mocker.Mock()
1175+
mock_list_response.data = []
1176+
mock_client.vector_stores.list.return_value = mock_list_response
11611177

11621178
# Mock configuration with MCP servers
11631179
mcp_servers = [
@@ -1236,7 +1252,9 @@ async def test_retrieve_response_with_mcp_servers_empty_token(
12361252
mock_client, mock_agent = prepare_agent_mocks
12371253
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
12381254
mock_client.shields.list.return_value = []
1239-
mock_client.vector_dbs.list.return_value = []
1255+
mock_list_response = mocker.Mock()
1256+
mock_list_response.data = []
1257+
mock_client.vector_stores.list.return_value = mock_list_response
12401258

12411259
# Mock configuration with MCP servers
12421260
mcp_servers = [
@@ -1298,7 +1316,9 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
12981316
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
12991317
mock_client = mocker.AsyncMock()
13001318
mock_client.shields.list.return_value = []
1301-
mock_client.vector_dbs.list.return_value = []
1319+
mock_list_response = mocker.Mock()
1320+
mock_list_response.data = []
1321+
mock_client.vector_stores.list.return_value = mock_list_response
13021322

13031323
# Mock configuration with MCP servers
13041324
mcp_servers = [
@@ -1548,9 +1568,11 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag(
15481568
mock_client, mock_agent = prepare_agent_mocks
15491569
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
15501570
mock_client.shields.list.return_value = []
1551-
mock_vector_db = mocker.Mock()
1552-
mock_vector_db.identifier = "VectorDB-1"
1553-
mock_client.vector_dbs.list.return_value = [mock_vector_db]
1571+
mock_vector_store = mocker.Mock()
1572+
mock_vector_store.id = "VectorDB-1"
1573+
mock_list_response = mocker.Mock()
1574+
mock_list_response.data = [mock_vector_store]
1575+
mock_client.vector_stores.list.return_value = mock_list_response
15541576

15551577
# Mock configuration with MCP servers
15561578
mcp_servers = [
@@ -1598,9 +1620,11 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
15981620
mock_client, mock_agent = prepare_agent_mocks
15991621
mock_agent.create_turn.return_value.output_message.content = "LLM answer"
16001622
mock_client.shields.list.return_value = []
1601-
mock_vector_db = mocker.Mock()
1602-
mock_vector_db.identifier = "VectorDB-1"
1603-
mock_client.vector_dbs.list.return_value = [mock_vector_db]
1623+
mock_vector_store = mocker.Mock()
1624+
mock_vector_store.id = "VectorDB-1"
1625+
mock_list_response = mocker.Mock()
1626+
mock_list_response.data = [mock_vector_store]
1627+
mock_client.vector_stores.list.return_value = mock_list_response
16041628

16051629
# Mock configuration with MCP servers
16061630
mcp_servers = [

0 commit comments

Comments
 (0)