diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py index eecb738fd..4d0866c33 100644 --- a/diracx-core/src/diracx/core/models.py +++ b/diracx-core/src/diracx/core/models.py @@ -271,3 +271,14 @@ class JobCommand(BaseModel): job_id: int command: Literal["Kill"] arguments: str | None = None + + +class LogLine(BaseModel): + line_no: int + line: str + + +class LogMessage(BaseModel): + pilot_stamp: str + lines: list[LogLine] + vo: str diff --git a/diracx-db/pyproject.toml b/diracx-db/pyproject.toml index a48c45441..1fcf48933 100644 --- a/diracx-db/pyproject.toml +++ b/diracx-db/pyproject.toml @@ -36,6 +36,7 @@ TaskQueueDB = "diracx.db.sql:TaskQueueDB" [project.entry-points."diracx.dbs.os"] JobParametersDB = "diracx.db.os:JobParametersDB" +PilotLogsDB = "diracx.db.os:PilotLogsDB" [tool.setuptools.packages.find] where = ["src"] diff --git a/diracx-db/src/diracx/db/os/__init__.py b/diracx-db/src/diracx/db/os/__init__.py index 535e2a954..c1ce89bcb 100644 --- a/diracx-db/src/diracx/db/os/__init__.py +++ b/diracx-db/src/diracx/db/os/__init__.py @@ -1,5 +1,9 @@ from __future__ import annotations -__all__ = ("JobParametersDB",) +__all__ = ( + "JobParametersDB", + "PilotLogsDB", +) from .job_parameters import JobParametersDB +from .pilot_logs import PilotLogsDB diff --git a/diracx-db/src/diracx/db/os/pilot_logs.py b/diracx-db/src/diracx/db/os/pilot_logs.py new file mode 100644 index 000000000..ab48096cb --- /dev/null +++ b/diracx-db/src/diracx/db/os/pilot_logs.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from diracx.db.os.utils import BaseOSDB + + +class PilotLogsDB(BaseOSDB): + fields = { + "PilotStamp": {"type": "keyword"}, + "PilotID": {"type": "long"}, + "SubmissionTime": {"type": "date"}, + "LineNumber": {"type": "long"}, + "Message": {"type": "text"}, + "VO": {"type": "keyword"}, + "timestamp": {"type": "date"}, + } + index_prefix = "pilot_logs" + + def index_name(self, doc_id: int) -> str: + # TODO decide how to define 
async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None:
    """Insert several documents into one index with a single bulk request.

    Args:
        index_name: Name of the index every document is written to.
        docs: Document bodies. Each one is merged into a new dict that also
            carries ``_index``, so the caller's dicts are not modified.
    """
    actions = [doc | {"_index": index_name} for doc in docs]
    # async_bulk returns a ``(success_count, errors)`` pair rather than a bare
    # count; unpack it so the ``%d`` placeholder below receives an integer.
    n_inserted, _errors = await async_bulk(self.client, actions=actions)
    logger.info("Inserted %d documents to %r", n_inserted, index_name)
diracx.core.models import JobCommand, SearchSpec, SortSpec -from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints +from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints, get_columns from ..utils.functions import utcnow from .schema import ( HeartBeatLoggingInfo, @@ -25,17 +25,6 @@ ) -def _get_columns(table, parameters): - columns = [x for x in table.columns] - if parameters: - if unrecognised_parameters := set(parameters) - set(table.columns.keys()): - raise InvalidQueryError( - f"Unrecognised parameters requested {unrecognised_parameters}" - ) - columns = [c for c in columns if c.name in parameters] - return columns - - class JobDB(BaseSQLDB): metadata = JobDBBase.metadata @@ -56,7 +45,7 @@ class JobDB(BaseSQLDB): async def summary(self, group_by, search) -> list[dict[str, str | int]]: """Get a summary of the jobs.""" - columns = _get_columns(Jobs.__table__, group_by) + columns = get_columns(Jobs.__table__, group_by) stmt = select(*columns, func.count(Jobs.job_id).label("count")) stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search) @@ -81,7 +70,7 @@ async def search( ) -> tuple[int, list[dict[Any, Any]]]: """Search for jobs in the database.""" # Find which columns to select - columns = _get_columns(Jobs.__table__, parameters) + columns = get_columns(Jobs.__table__, parameters) stmt = select(*columns) @@ -267,7 +256,7 @@ async def set_properties( required_parameters = list(required_parameters_set)[0] update_parameters = [{"job_id": k, **v} for k, v in properties.items()] - columns = _get_columns(Jobs.__table__, required_parameters) + columns = get_columns(Jobs.__table__, required_parameters) values: dict[str, BindParameter[Any] | datetime] = { c.name: bindparam(c.name) for c in columns } diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py index b4f801b78..51c0ce0f0 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ 
async def search(
    self,
    parameters: list[str] | None,
    search: list[SearchSpec],
    sorts: list[SortSpec],
    *,
    distinct: bool = False,
    per_page: int = 100,
    page: int | None = None,
) -> tuple[int, list[dict[Any, Any]]]:
    """Search the ``PilotAgents`` table.

    Args:
        parameters: Column names to return; a falsy value selects all columns.
        search: Filter specifications applied to the statement.
        sorts: Sort specifications applied to the statement.
        distinct: When true, ``DISTINCT`` is added to the statement.
        per_page: Page size; only used when ``page`` is given.
        page: 1-based page number, or ``None`` to return every match.

    Returns:
        A ``(total, rows)`` tuple where ``total`` counts the matching rows
        before pagination and ``rows`` holds one dict per returned row.

    Raises:
        InvalidQueryError: If ``page`` or ``per_page`` is smaller than 1
            while pagination is requested.
    """
    table = PilotAgents.__table__
    column_for = table.columns.__getitem__

    # Build the filtered and sorted statement over the requested columns.
    stmt = select(*get_columns(table, parameters))
    stmt = apply_search_filters(column_for, stmt, search)
    stmt = apply_sort_constraints(column_for, stmt, sorts)
    if distinct:
        stmt = stmt.distinct()

    # Reject bad pagination arguments before any query is sent, so an
    # invalid request does not cost a COUNT round trip.
    if page is not None:
        if page < 1:
            raise InvalidQueryError("Page must be a positive integer")
        if per_page < 1:
            raise InvalidQueryError("Per page must be a positive integer")

    # The total is computed on the unpaginated statement.
    count_stmt = select(func.count()).select_from(stmt.alias())
    total = (await self.conn.execute(count_stmt)).scalar_one()

    if page is not None:
        stmt = stmt.offset((page - 1) * per_page).limit(per_page)

    return total, [
        dict(row._mapping) async for row in (await self.conn.stream(stmt))
    ]
def get_columns(table, parameters):
    """Return the columns of *table*, optionally restricted to *parameters*.

    Args:
        table: Object exposing a ``columns`` collection that is iterable,
            provides ``keys()``, and whose items have a ``name`` attribute.
        parameters: Names of the columns to keep. A falsy value (``None`` or
            an empty collection) selects every column.

    Returns:
        A list of column objects in the table's own order, not in the order
        given by ``parameters``.

    Raises:
        InvalidQueryError: If a requested name is not among
            ``table.columns.keys()``.
    """
    if not parameters:
        return list(table.columns)

    # One set serves both the validation and the membership filter below.
    requested = set(parameters)
    if unrecognised_parameters := requested - set(table.columns.keys()):
        raise InvalidQueryError(
            f"Unrecognised parameters requested {unrecognised_parameters}"
        )
    return [column for column in table.columns if column.name in requested]
async def send_message(
    data: LogMessage,
    pilot_logs_db: PilotLogsDB,
    pilot_agents_db: PilotAgentsDB,
) -> int:
    """Store a batch of pilot log lines and return the owning pilot's ID.

    The pilot is looked up by ``data.pilot_stamp``. The ``VO`` and
    ``SubmissionTime`` written with each line are the values found for that
    pilot; ``data.vo`` is not read here.

    Args:
        data: Pilot stamp plus the numbered log lines to store.
        pilot_logs_db: Database receiving the log documents.
        pilot_agents_db: Database used to resolve the pilot stamp.

    Returns:
        The ``PilotID`` found for ``data.pilot_stamp``.

    Raises:
        ValueError: If the stamp does not match exactly one pilot.
    """
    stamp_filter = ScalarSearchSpec(
        parameter="PilotStamp",
        operator=ScalarSearchOperator.EQUAL,
        value=data.pilot_stamp,
    )
    total, result = await pilot_agents_db.search(
        ["PilotID", "VO", "SubmissionTime"], [stamp_filter], []
    )
    if total != 1:
        logger.error(
            "Cannot determine PilotID for requested PilotStamp: %r, (%d candidates)",
            data.pilot_stamp,
            total,
        )
        # A concrete exception type instead of a bare Exception; handlers
        # that catch Exception still see it.
        raise ValueError(f"Number of rows !=1 {total}")

    pilot = result[0]
    pilot_id = pilot["PilotID"]
    # Fields that are identical for every line of this message.
    shared_fields = {
        "PilotStamp": data.pilot_stamp,
        "PilotID": pilot_id,
        "SubmissionTime": pilot["SubmissionTime"],
        "VO": pilot["VO"],
    }
    docs = [
        shared_fields | {"LineNumber": line.line_no, "Message": line.line}
        for line in data.lines
    ]
    # All lines go to the index derived from the pilot ID in one bulk call.
    await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(pilot_id), docs)
    return pilot_id
as _JobLoggingDB @@ -38,7 +40,7 @@ def add_settings_annotation(cls: T) -> T: return Annotated[cls, Depends(cls.create)] # type: ignore -# Databases +# SQL Databases AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)] JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)] JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)] @@ -48,9 +50,9 @@ def add_settings_annotation(cls: T) -> T: ] TaskQueueDB = Annotated[_TaskQueueDB, Depends(_TaskQueueDB.transaction)] -# Opensearch databases +# OpenSearch Databases JobParametersDB = Annotated[_JobParametersDB, Depends(_JobParametersDB.session)] - +PilotLogsDB = Annotated[_PilotLogsDB, Depends(_PilotLogsDB.session)] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py new file mode 100644 index 000000000..3e9084bc7 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from logging import getLogger + +from ..fastapi_classes import DiracxRouter +from .logging import router as logging_router + +logger = getLogger(__name__) + +router = DiracxRouter() +router.include_router(logging_router) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py new file mode 100644 index 000000000..ea2b053ba --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import logging +from enum import StrEnum, auto +from typing import Annotated, Callable + +from fastapi import Depends, HTTPException, status + +from diracx.core.models import ScalarSearchOperator, ScalarSearchSpec +from diracx.core.properties import ( + NORMAL_USER, +) +from diracx.routers.access_policies import BaseAccessPolicy + +from ..dependencies import PilotAgentsDB +from ..utils.users 
class PilotLogsAccessPolicy(BaseAccessPolicy):
    """Access policy for pilot log records.

    Only ``ActionType.QUERY`` is handled. A query is allowed when the caller's
    VO is ``diracAdmin``, or when the caller holds ``NORMAL_USER`` and belongs
    to the VO found for the requested pilot. Every other caller is refused
    with HTTP 403, and every other action raises ``NotImplementedError``.
    """

    @staticmethod
    async def policy(
        policy_name: str,
        user_info: AuthorizedUserInfo,
        /,
        *,
        action: ActionType | None = None,
        pilot_agents_db: PilotAgentsDB | None = None,
        pilot_id: int | None = None,
    ):
        """Decide whether ``user_info`` may perform ``action`` on a pilot's logs.

        Args:
            policy_name: Name of the policy (not used by this implementation).
            user_info: Caller identity; ``vo`` and ``properties`` are read.
            action: Requested action; mandatory.
            pilot_agents_db: Database used to look up the pilot's VO; mandatory.
            pilot_id: ID of the pilot whose logs are requested.

        Raises:
            ValueError: If ``pilot_agents_db`` is missing.
            HTTPException: 400 if ``action`` or ``pilot_id`` is missing or the
                pilot ID does not match exactly one pilot; 403 if the caller
                may not read this pilot's logs.
            NotImplementedError: For any action other than ``QUERY``.
        """
        # Explicit check instead of ``assert`` so it is not stripped when
        # Python runs with ``-O``.
        if not pilot_agents_db:
            raise ValueError("pilot_agents_db is a mandatory argument")
        if action is None:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, detail="Action is a mandatory argument"
            )
        if action != ActionType.QUERY:
            raise NotImplementedError(action)

        if pilot_id is None:
            logger.error("Pilot ID value is not provided (None)")
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"PilotID not provided: {pilot_id}",
            )

        pilot_filter = ScalarSearchSpec(
            parameter="PilotID",
            operator=ScalarSearchOperator.EQUAL,
            value=pilot_id,
        )
        total, result = await pilot_agents_db.search(["VO"], [pilot_filter], [])
        # Exactly one pilot must match; anything else is treated as not found.
        if total != 1:
            logger.error(
                "Cannot determine VO for requested PilotID: %d, found %d candidates.",
                pilot_id,
                total,
            )
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, detail=f"PilotID not found: {pilot_id}"
            )
        pilot_vo = result[0]["VO"]

        if user_info.vo == "diracAdmin":
            return
        if NORMAL_USER in user_info.properties and user_info.vo == pilot_vo:
            return

        raise HTTPException(
            status.HTTP_403_FORBIDDEN,
            detail="You don't have permission to access this pilot's log.",
        )
+@router.get("/logs") +async def get_logs( + pilot_id: int, + db: PilotLogsDB, + pilot_agents_db: PilotAgentsDB, + check_permissions: CheckPilotLogsPolicyCallable, +) -> list[dict]: + + logger.debug("Retrieving logs for pilot ID %d", pilot_id) + # users will only see logs from their own VO if enforced by a policy: + await check_permissions( + action=ActionType.QUERY, pilot_agents_db=pilot_agents_db, pilot_id=pilot_id + ) + + return await get_logs_bl(pilot_id, db) diff --git a/diracx-routers/tests/pilots/test_pilot_logger.py b/diracx-routers/tests/pilots/test_pilot_logger.py new file mode 100644 index 000000000..40319466e --- /dev/null +++ b/diracx-routers/tests/pilots/test_pilot_logger.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from diracx.routers.utils.users import AuthSettings + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthSettings", + "PilotAgentsDB", + "PilotLogsDB", + "PilotLogsAccessPolicy", + "DevelopmentSettings", + ] +) + + +@pytest.fixture +def normal_user_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +async def test_send_and_retrieve_logs( + normal_user_client: TestClient, test_auth_settings: AuthSettings +): + + from diracx.db.sql import PilotAgentsDB + + # Add a pilot reference + upper_limit = 6 + refs = [f"ref_{i}" for i in range(1, upper_limit)] + stamps = [f"stamp_{i}" for i in range(1, upper_limit)] + stamp_dict = dict(zip(refs, stamps)) + + db = normal_user_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + + async with db: + await db.add_pilot_references( + refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict + ) + + msg = ( + "2022-02-26 13:48:35.123456 UTC DEBUG [PilotParams] JSON file loaded: pilot.json\n" + "2022-02-26 13:48:36.123456 UTC DEBUG [PilotParams] JSON file analysed: pilot.json" + ) + # message dict + lines = [] + for i, line in enumerate(msg.split("\n")): + 
async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None:
    """Insert several documents into the mock table in one statement.

    Keys listed in ``self.fields`` become column values; every other key is
    collected under the ``extra`` entry of the row. ``index_name`` is
    accepted for signature compatibility and is not used here.

    Args:
        index_name: Target index name (unused by this implementation).
        docs: Documents to insert; an empty list is a no-op.
    """
    if not docs:
        # Nothing to write; skip the statement rather than emitting an
        # INSERT with an empty VALUES list.
        return

    async with self._sql_db:
        rows = []
        for doc in docs:
            # doc_id is deliberately not set here, which keeps it unique.
            values = {}
            for key, value in doc.items():
                if key in self.fields:
                    values[key] = value
                else:
                    values.setdefault("extra", {})[key] = value
            rows.append(values)
        stmt = sqlite_insert(self._table).values(rows)
        await self._sql_db.conn.execute(stmt)