Skip to content

Commit 5b5094b

Browse files
committed
feat: refactor pilot logging to include a business logic layer
1 parent 6b6e2c7 commit 5b5094b

File tree

5 files changed

+103
-77
lines changed

5 files changed

+103
-77
lines changed

diracx-core/src/diracx/core/models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,14 @@ class VOInfo(TypedDict):
253253

254254
class Metadata(TypedDict):
255255
virtual_organizations: dict[str, VOInfo]
256+
257+
258+
class LogLine(BaseModel):
259+
line_no: int
260+
line: str
261+
262+
263+
class LogMessage(BaseModel):
264+
pilot_stamp: str
265+
lines: list[LogLine]
266+
vo: str

diracx-db/src/diracx/db/os/pilot_logs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,17 @@ def index_name(self, doc_id: int) -> str:
1919
# TODO decide how to define the index name
2020
# use pilot ID
2121
return f"{self.index_prefix}_{doc_id // 1e6:.0f}"
22+
23+
24+
async def search_message(db: PilotLogsDB, search_params: list[dict]):
25+
26+
return await db.search(
27+
["Message"],
28+
search_params,
29+
[{"parameter": "LineNumber", "direction": "asc"}],
30+
)
31+
32+
33+
async def bulk_insert(db: PilotLogsDB, docs: list[dict], pilot_id: int):
34+
35+
await db.bulk_insert(db.index_name(pilot_id), docs)

diracx-logic/src/diracx/logic/pilots/__init__.py

Whitespace-only changes.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
from diracx.core.models import LogMessage, ScalarSearchOperator, ScalarSearchSpec
6+
from diracx.db.os.pilot_logs import PilotLogsDB, bulk_insert, search_message
7+
from diracx.db.sql.pilot_agents.db import PilotAgentsDB
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
async def send_message(
13+
data: LogMessage,
14+
pilot_logs_db: PilotLogsDB,
15+
pilot_agents_db: PilotAgentsDB,
16+
) -> int:
17+
18+
# get the pilot ID corresponding to a given pilot stamp, expecting exactly one row:
19+
search_params = ScalarSearchSpec(
20+
parameter="PilotStamp",
21+
operator=ScalarSearchOperator.EQUAL,
22+
value=data.pilot_stamp,
23+
)
24+
25+
total, result = await pilot_agents_db.search(
26+
["PilotID", "VO", "SubmissionTime"], [search_params], []
27+
)
28+
if total != 1:
29+
logger.error(
30+
"Cannot determine PilotID for requested PilotStamp: %r, (%d candidates)",
31+
data.pilot_stamp,
32+
total,
33+
)
34+
raise Exception(f"Number of rows !=1 {total}")
35+
36+
pilot_id, vo, submission_time = (
37+
result[0]["PilotID"],
38+
result[0]["VO"],
39+
result[0]["SubmissionTime"],
40+
)
41+
docs = []
42+
for line in data.lines:
43+
docs.append(
44+
{
45+
"PilotStamp": data.pilot_stamp,
46+
"PilotID": pilot_id,
47+
"SubmissionTime": submission_time,
48+
"VO": vo,
49+
"LineNumber": line.line_no,
50+
"Message": line.line,
51+
}
52+
)
53+
# bulk insert pilot logs to OpenSearch DB:
54+
await bulk_insert(pilot_logs_db, docs, pilot_id)
55+
return pilot_id
56+
57+
58+
async def get_logs(
59+
pilot_id: int,
60+
db: PilotLogsDB,
61+
) -> list[dict]:
62+
63+
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
64+
65+
result = await search_message(db, search_params)
66+
67+
if not result:
68+
return [{"Message": f"No logs for pilot ID = {pilot_id}"}]
69+
return result

diracx-routers/src/diracx/routers/pilots/logging.py

Lines changed: 9 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import logging
44

55
from fastapi import HTTPException, status
6-
from pydantic import BaseModel
76

8-
from diracx.core.models import ScalarSearchOperator, ScalarSearchSpec
7+
from diracx.core.models import LogMessage
8+
from diracx.logic.pilots.logging import get_logs as get_logs_bl
9+
from diracx.logic.pilots.logging import send_message as send_message_bl
910

1011
from ..access_policies import open_access
1112
from ..dependencies import PilotAgentsDB, PilotLogsDB
@@ -16,22 +17,6 @@
1617
router = DiracxRouter()
1718

1819

19-
class LogLine(BaseModel):
20-
line_no: int
21-
line: str
22-
23-
24-
class LogMessage(BaseModel):
25-
pilot_stamp: str
26-
lines: list[LogLine]
27-
vo: str
28-
29-
30-
class DateRange(BaseModel):
31-
min: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z")
32-
max: str | None = None # expects a string in ISO 8601 ("%Y-%m-%dT%H:%M:%S.%f%z")
33-
34-
3520
@open_access
3621
@router.post("/")
3722
async def send_message(
@@ -41,55 +26,11 @@ async def send_message(
4126
# user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)],
4227
) -> int:
4328

44-
# expecting exactly one row:
45-
search_params = ScalarSearchSpec(
46-
parameter="PilotStamp",
47-
operator=ScalarSearchOperator.EQUAL,
48-
value=data.pilot_stamp,
49-
)
50-
51-
total, result = await pilot_agents_db.search(
52-
["PilotID", "VO", "SubmissionTime"], [search_params], []
53-
)
54-
if total != 1:
55-
logger.error(
56-
"Cannot determine PilotID for requested PilotStamp: %r, (%d candidates)",
57-
data.pilot_stamp,
58-
total,
59-
)
60-
raise HTTPException(
61-
status.HTTP_400_BAD_REQUEST, detail=f"Number of rows !=1: {total}"
62-
)
63-
pilot_id, vo, submission_time = (
64-
result[0]["PilotID"],
65-
result[0]["VO"],
66-
result[0]["SubmissionTime"],
67-
)
68-
69-
# await check_permissions(action=ActionType.CREATE, pilot_agent_db, pilot_id),
70-
71-
docs = []
72-
for line in data.lines:
73-
docs.append(
74-
{
75-
"PilotStamp": data.pilot_stamp,
76-
"PilotID": pilot_id,
77-
"SubmissionTime": submission_time,
78-
"VO": vo,
79-
"LineNumber": line.line_no,
80-
"Message": line.line,
81-
}
82-
)
83-
await pilot_logs_db.bulk_insert(pilot_logs_db.index_name(pilot_id), docs)
84-
"""
85-
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
86-
87-
result = await pilot_logs_db.search(
88-
["Message"],
89-
search_params,
90-
[{"parameter": "LineNumber", "direction": "asc"}],
91-
)
92-
"""
29+
# await check_permissions(action=ActionType.CREATE, pilot_agent_db, pilot_id)
30+
try:
31+
pilot_id = await send_message_bl(data, pilot_logs_db, pilot_agents_db)
32+
except Exception as exc:
33+
raise HTTPException(status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
9334
return pilot_id
9435

9536

@@ -107,13 +48,4 @@ async def get_logs(
10748
action=ActionType.QUERY, pilot_agents_db=pilot_agents_db, pilot_id=pilot_id
10849
)
10950

110-
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
111-
112-
result = await db.search(
113-
["Message"],
114-
search_params,
115-
[{"parameter": "LineNumber", "direction": "asc"}],
116-
)
117-
if not result:
118-
return [{"Message": f"No logs for pilot ID = {pilot_id}"}]
119-
return result
51+
return await get_logs_bl(pilot_id, db)

0 commit comments

Comments
 (0)