diff --git a/README.md b/README.md
index 1d63c26..360006b 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
*
DataOps Observability is part of DataKitchen's Open Source Data Observability. DataOps Observability monitors every data journey from data source to customer value, from any team development environment into production, across every tool, team, environment, and customer so that problems are detected, localized, and understood immediately.
*
-[](https://datakitchen.storylane.io/share/g01ss0plyamz)
+[](https://datakitchen.storylane.io/share/g01ss0plyamz)
[Interactive Product Tour](https://datakitchen.storylane.io/share/g01ss0plyamz)
## Developer Setup
@@ -100,9 +100,7 @@ We enforce the use of certain linting tools. To not get caught by the build-syst
The following hooks are enabled in pre-commit:
-- `black`: The black formatter is enforced on the project. We use a basic configuration. Ideally this should solve any and all
-formatting questions we might encounter.
-- `isort`: the isort import-sorter is enforced on the project. We use it with the `black` profile.
+- `ruff`: Handles code formatting, import sorting, and linting.
To enable pre-commit from within your virtual environment, simply run:
diff --git a/agent_api/config/defaults.py b/agent_api/config/defaults.py
index d02c934..cb6cd3e 100644
--- a/agent_api/config/defaults.py
+++ b/agent_api/config/defaults.py
@@ -5,13 +5,12 @@
"""
import os
-from typing import Optional
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
from common.entities import Service
-PROPAGATE_EXCEPTIONS: Optional[bool] = None
-SERVER_NAME: Optional[str] = os.environ.get("AGENT_API_HOSTNAME") # Use flask defaults if none set
+PROPAGATE_EXCEPTIONS: bool | None = None
+SERVER_NAME: str | None = os.environ.get("AGENT_API_HOSTNAME") # Use flask defaults if none set
USE_X_SENDFILE: bool = False # If we serve files enable this in production settings when webserver support configured
# Application settings
diff --git a/agent_api/config/local.py b/agent_api/config/local.py
index aaf5847..f063607 100644
--- a/agent_api/config/local.py
+++ b/agent_api/config/local.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
-PROPAGATE_EXCEPTIONS: Optional[bool] = True
+PROPAGATE_EXCEPTIONS: bool | None = True
SECRET_KEY: str = "NOT_VERY_SECRET"
diff --git a/agent_api/config/minikube.py b/agent_api/config/minikube.py
index e5ed0f8..6c72c79 100644
--- a/agent_api/config/minikube.py
+++ b/agent_api/config/minikube.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
-TESTING: Optional[bool] = True
+TESTING: bool | None = True
SECRET_KEY: str = "NOT_VERY_SECRET"
diff --git a/agent_api/endpoints/v1/heartbeat.py b/agent_api/endpoints/v1/heartbeat.py
index 451bd56..d01886e 100644
--- a/agent_api/endpoints/v1/heartbeat.py
+++ b/agent_api/endpoints/v1/heartbeat.py
@@ -1,7 +1,7 @@
import logging
-from datetime import datetime, timezone
+from datetime import datetime, UTC
from http import HTTPStatus
-from typing import Optional, Union, cast
+from typing import Union, cast
from uuid import UUID
from flask import Response, g, make_response
@@ -23,7 +23,7 @@ def _update_or_create(
version: str,
project_id: Union[str, UUID],
latest_heartbeat: datetime,
- latest_event_timestamp: Optional[datetime],
+ latest_event_timestamp: datetime | None,
) -> None:
try:
agent = Agent.select().where(Agent.key == key, Agent.tool == tool, Agent.project_id == project_id).get()
@@ -57,7 +57,7 @@ class Heartbeat(BaseView):
def post(self) -> Response:
data = self.parse_body(schema=HeartbeatSchema())
- data["latest_heartbeat"] = datetime.now(tz=timezone.utc)
+ data["latest_heartbeat"] = datetime.now(tz=UTC)
data["project_id"] = g.project.id
_update_or_create(**data)
return make_response("", HTTPStatus.NO_CONTENT)
diff --git a/agent_api/tests/integration/v1_endpoints/test_heartbeat.py b/agent_api/tests/integration/v1_endpoints/test_heartbeat.py
index ca2cfa5..0c6d08d 100644
--- a/agent_api/tests/integration/v1_endpoints/test_heartbeat.py
+++ b/agent_api/tests/integration/v1_endpoints/test_heartbeat.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, UTC
from http import HTTPStatus
import pytest
@@ -9,7 +9,7 @@
@pytest.mark.integration
def test_agent_heartbeat(client, database_ctx, headers):
- last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=timezone.utc)
+ last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=UTC)
data = {
"key": "test-key",
"tool": "test-tool",
@@ -35,7 +35,7 @@ def test_agent_heartbeat_no_event_timestamp(client, database_ctx, headers):
@pytest.mark.integration
def test_agent_heartbeat_update(client, database_ctx, headers):
- last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=timezone.utc)
+ last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=UTC)
data = {
"key": "test-key",
"tool": "test-tool",
@@ -47,7 +47,7 @@ def test_agent_heartbeat_update(client, database_ctx, headers):
assert HTTPStatus.NO_CONTENT == response_1.status_code, response_1.json
# The latest_event_timestamp should be older than "now"
- now = datetime.now(timezone.utc)
+ now = datetime.now(UTC)
agent_1 = Agent.select().get()
assert agent_1.latest_heartbeat < now
assert agent_1.status == AgentStatus.ONLINE
@@ -62,7 +62,7 @@ def test_agent_heartbeat_update(client, database_ctx, headers):
@pytest.mark.integration
def test_agent_heartbeat_existing_update(client, database_ctx, headers):
- last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=timezone.utc)
+ last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=UTC)
data_1 = {
"key": "test-key",
"tool": "test-tool",
@@ -79,7 +79,7 @@ def test_agent_heartbeat_existing_update(client, database_ctx, headers):
data_2 = data_1.copy()
data_2["version"] = "12.0.3"
- data_2["latest_event_timestamp"] = datetime(2023, 10, 20, 4, 44, 44, tzinfo=timezone.utc).isoformat()
+ data_2["latest_event_timestamp"] = datetime(2023, 10, 20, 4, 44, 44, tzinfo=UTC).isoformat()
response_2 = client.post("/agent/v1/heartbeat", json=data_2, headers=headers)
assert HTTPStatus.NO_CONTENT == response_2.status_code, response_2.json
diff --git a/cli/base.py b/cli/base.py
index 7d28942..889a187 100644
--- a/cli/base.py
+++ b/cli/base.py
@@ -4,7 +4,7 @@
from argparse import ArgumentParser
from logging.config import dictConfig
from pathlib import Path
-from typing import Any, Optional
+from typing import Any
from collections.abc import Callable
from log_color import ColorFormatter, ColorStripper
@@ -80,7 +80,7 @@ def __init__(self, **kwargs: Any) -> None:
LOG.info("#g<\u2714> Established #c<%s> connection to #c<%s>", DB.obj.__class__.__name__, DB.obj.database)
-def logging_init(*, level: str, logfile: Optional[str] = None) -> None:
+def logging_init(*, level: str, logfile: str | None = None) -> None:
"""Given the log level and an optional logging file location, configure all logging."""
# Don't bother with a file handler if we're not logging to a file
handlers = ["console", "filehandler"] if logfile else ["console"]
diff --git a/cli/entry_points/database_schema.py b/cli/entry_points/database_schema.py
index fa2376f..7a2c02a 100644
--- a/cli/entry_points/database_schema.py
+++ b/cli/entry_points/database_schema.py
@@ -1,6 +1,6 @@
import re
from argparse import ArgumentParser
-from typing import Any, Optional
+from typing import Any
from re import Pattern
from collections.abc import Iterable
@@ -19,7 +19,7 @@ class MysqlPrintDatabase(MySQLDatabase):
def __init__(self) -> None:
super().__init__("")
- def execute_sql(self, sql: str, params: Optional[Iterable[Any]] = None, commit: Optional[bool] = None) -> None:
+ def execute_sql(self, sql: str, params: Iterable[Any] | None = None, commit: bool | None = None) -> None:
if params:
raise Exception(f"Params are not expected to be needed to run DDL SQL, but found {params}")
if match := self._create_table_re.match(sql):
diff --git a/cli/entry_points/gen_events.py b/cli/entry_points/gen_events.py
index 955b54a..85f0d20 100644
--- a/cli/entry_points/gen_events.py
+++ b/cli/entry_points/gen_events.py
@@ -7,7 +7,7 @@
import time
from argparse import Action, ArgumentParser, Namespace
from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any, Union
from collections.abc import Sequence
from requests_extensions import get_session
@@ -90,7 +90,7 @@ def __call__(
parser: ArgumentParser,
namespace: Namespace,
values: Union[str, Sequence[Any], None],
- option_string: Optional[str] = None,
+ option_string: str | None = None,
) -> None:
event_data = {}
remove_fields = []
diff --git a/cli/entry_points/graph_schema.py b/cli/entry_points/graph_schema.py
index c433fe9..e44d6c6 100644
--- a/cli/entry_points/graph_schema.py
+++ b/cli/entry_points/graph_schema.py
@@ -2,6 +2,7 @@
import sys
from argparse import ArgumentParser
from pathlib import Path
+from typing import Any
from jinja2 import Environment, FileSystemLoader
from peewee import Field, ForeignKeyField, ManyToManyField, Model
@@ -54,7 +55,7 @@ def subcmd_entry_point(self) -> None:
dot_parts = [head.render({})]
# Initial context/config
- model_context = []
+ model_context: list[dict[str, Any]] = []
LOG.info("#m")
for name, model in model_map.items():
diff --git a/cli/lib.py b/cli/lib.py
index 80ed768..fd79dcb 100644
--- a/cli/lib.py
+++ b/cli/lib.py
@@ -2,7 +2,6 @@
import re
import textwrap
from argparse import ArgumentParser, ArgumentTypeError
-from typing import Optional
from uuid import UUID
from log_color.colors import ColorStr
@@ -21,7 +20,7 @@ def uuid_type(arg: str) -> UUID:
def slice_type(arg: str) -> slice:
"""Convert an argument to a slice; for simplicity, disallow negative slice values and steps."""
- def _int_or_none(val: str) -> Optional[int]:
+ def _int_or_none(val: str) -> int | None:
if not val:
return None
else:
diff --git a/common/actions/action.py b/common/actions/action.py
index 7eea0a7..18e062b 100644
--- a/common/actions/action.py
+++ b/common/actions/action.py
@@ -8,7 +8,7 @@
]
import logging
-from typing import Any, NamedTuple, Optional
+from typing import Any, NamedTuple
from uuid import UUID
from common.entities import Action, Rule
@@ -35,15 +35,15 @@ class InvalidActionTemplate(ActionException):
class ActionResult(NamedTuple):
result: bool
- response: Optional[dict]
- exception: Optional[Exception]
+ response: dict | None
+ exception: Exception | None
class BaseAction:
required_arguments: set = set()
requires_action_template: bool = False
- def __init__(self, action_template: Optional[Action], override_arguments: dict) -> None:
+ def __init__(self, action_template: Action | None, override_arguments: dict) -> None:
if self.requires_action_template and not action_template:
raise ActionTemplateRequired(f"'{self.__class__.__name__}' requires an action template to be set")
@@ -70,7 +70,7 @@ def _validate_args(self) -> None:
if missing_args:
raise ValueError(f"Required arguments {missing_args} missing for {self.__class__.__name__}")
- def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> ActionResult:
+ def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: UUID | None) -> ActionResult:
raise NotImplementedError("Base Action cannot be executed")
def _store_action_result(self, action_result: ActionResult) -> None:
@@ -88,7 +88,7 @@ def _store_action_result(self, action_result: ActionResult) -> None:
exc_info=action_result.exception,
)
- def execute(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> bool:
+ def execute(self, event: EVENT_TYPE, rule: Rule, journey_id: UUID | None) -> bool:
action_result = self._run(event, rule, journey_id)
self._store_action_result(action_result)
return action_result.result
diff --git a/common/actions/action_factory.py b/common/actions/action_factory.py
index 9aabcc1..b37a1c1 100644
--- a/common/actions/action_factory.py
+++ b/common/actions/action_factory.py
@@ -1,6 +1,5 @@
__all__ = ["ACTION_CLASS_MAP", "action_factory"]
-from typing import Optional
from common.entities import Action
@@ -11,7 +10,7 @@
ACTION_CLASS_MAP: dict[str, type[BaseAction]] = {"CALL_WEBHOOK": WebhookAction, "SEND_EMAIL": SendEmailAction}
-def action_factory(implementation: str, action_args: dict, template: Optional[Action]) -> BaseAction:
+def action_factory(implementation: str, action_args: dict, template: Action | None) -> BaseAction:
try:
action_class = ACTION_CLASS_MAP[implementation]
except KeyError as ke:
diff --git a/common/actions/data_points.py b/common/actions/data_points.py
index 6d5e4ac..289b013 100644
--- a/common/actions/data_points.py
+++ b/common/actions/data_points.py
@@ -1,7 +1,7 @@
import logging
from collections.abc import Mapping
from datetime import datetime
-from typing import Any, Optional, cast
+from typing import Any, cast
from collections.abc import Callable, Iterator
from uuid import UUID
@@ -102,7 +102,7 @@ def __init__(self, event: EVENT_TYPE):
"name": self._name,
}
- def _id(self) -> Optional[UUID]:
+ def _id(self) -> UUID | None:
return self.event.project_id
def _name(self) -> str:
@@ -120,16 +120,16 @@ def __init__(self, event: Event):
"type": self._type,
}
- def _id(self) -> Optional[UUID]:
+ def _id(self) -> UUID | None:
return self.event.component_id
- def _key(self) -> Optional[str]:
+ def _key(self) -> str | None:
return self.event.component_key
- def _name(self) -> Optional[str]:
+ def _name(self) -> str | None:
return self.event.component.display_name
- def _type(self) -> Optional[str]:
+ def _type(self) -> str | None:
return self.event.component_type.name
@@ -142,13 +142,13 @@ def __init__(self, event: Event):
"name": self._name,
}
- def _id(self) -> Optional[UUID]:
+ def _id(self) -> UUID | None:
return self.event.pipeline_id
- def _key(self) -> Optional[str]:
+ def _key(self) -> str | None:
return self.event.pipeline_key
- def _name(self) -> Optional[str]:
+ def _name(self) -> str | None:
return cast(str, self.event.pipeline.display_name)
@@ -170,13 +170,13 @@ def __init__(self, event: Event):
"expected_end_time_formatted": self._expected_end_time_formatted,
}
- def _id(self) -> Optional[UUID]:
+ def _id(self) -> UUID | None:
return self.event.run_id
- def _key(self) -> Optional[str]:
+ def _key(self) -> str | None:
return self.event.run_key
- def _name(self) -> Optional[str]:
+ def _name(self) -> str | None:
return self.event.run_name
def _status(self) -> str:
@@ -195,7 +195,7 @@ def _start_time_formatted(self) -> str:
return datetime_formatted(start_time)
return "N/A"
- def _expected_start_dt(self) -> Optional[datetime]:
+ def _expected_start_dt(self) -> datetime | None:
try:
run = getattr(self.event, "run", None)
except DoesNotExist:
@@ -206,11 +206,11 @@ def _expected_start_dt(self) -> Optional[datetime]:
return cast(datetime, expected_start_time)
return None
- def _expected_start_time(self) -> Optional[str]:
+ def _expected_start_time(self) -> str | None:
val = self._expected_start_dt()
return datetime_iso8601(val) if val else None
- def _expected_start_time_formatted(self) -> Optional[str]:
+ def _expected_start_time_formatted(self) -> str | None:
val = self._expected_start_dt()
return datetime_formatted(val) if val else None
@@ -226,7 +226,7 @@ def _end_time_formatted(self) -> str:
return datetime_formatted(end_time)
return "N/A"
- def _expected_end_dt(self) -> Optional[datetime]:
+ def _expected_end_dt(self) -> datetime | None:
try:
run = getattr(self.event, "run", None)
except DoesNotExist:
@@ -237,11 +237,11 @@ def _expected_end_dt(self) -> Optional[datetime]:
return cast(datetime, expected_end_time)
return None
- def _expected_end_time(self) -> Optional[str]:
+ def _expected_end_time(self) -> str | None:
val = self._expected_end_dt()
return datetime_iso8601(val) if val else None
- def _expected_end_time_formatted(self) -> Optional[str]:
+ def _expected_end_time_formatted(self) -> str | None:
val = self._expected_end_dt()
return datetime_formatted(val) if val else None
@@ -255,10 +255,10 @@ def __init__(self, event: Event):
"name": self._name,
}
- def _id(self) -> Optional[UUID]:
+ def _id(self) -> UUID | None:
return self.event.task_id
- def _key(self) -> Optional[str]:
+ def _key(self) -> str | None:
ret: str = getattr(self.event, "task_key")
return ret
@@ -278,7 +278,7 @@ def __init__(self, event: Event) -> None:
"end_time_formatted": self._end_time_formatted,
}
- def _id(self) -> Optional[UUID]:
+ def _id(self) -> UUID | None:
return self.event.run_task_id
def _status(self) -> str:
@@ -408,7 +408,7 @@ def _alert_type(self) -> str:
_type: str = self.event.type.value
return _type
- def _expected_start_dt(self) -> Optional[datetime]:
+ def _expected_start_dt(self) -> datetime | None:
try:
alert = getattr(self.event, "alert", None)
except DoesNotExist:
@@ -419,7 +419,7 @@ def _expected_start_dt(self) -> Optional[datetime]:
return cast(datetime, expected_start_time)
return None
- def _expected_end_dt(self) -> Optional[datetime]:
+ def _expected_end_dt(self) -> datetime | None:
try:
alert = getattr(self.event, "alert", None)
except DoesNotExist:
@@ -430,11 +430,11 @@ def _expected_end_dt(self) -> Optional[datetime]:
return cast(datetime, expected_end_time)
return None
- def _expected_start_time_formatted(self) -> Optional[str]:
+ def _expected_start_time_formatted(self) -> str | None:
val = self._expected_start_dt()
return datetime_formatted(val) if val else None
- def _expected_end_time_formatted(self) -> Optional[str]:
+ def _expected_end_time_formatted(self) -> str | None:
val = self._expected_end_dt()
return datetime_formatted(val) if val else None
@@ -448,16 +448,16 @@ def __init__(self, event: RunAlert) -> None:
"name": self._name,
}
- def _id(self) -> Optional[UUID]:
- id: Optional[UUID] = self.event.batch_pipeline_id
+ def _id(self) -> UUID | None:
+ id: UUID | None = self.event.batch_pipeline_id
return id
- def _key(self) -> Optional[str]:
- key: Optional[str] = self.event.batch_pipeline.key
+ def _key(self) -> str | None:
+ key: str | None = self.event.batch_pipeline.key
return key
- def _name(self) -> Optional[str]:
- name: Optional[str] = self.event.batch_pipeline.display_name
+ def _name(self) -> str | None:
+ name: str | None = self.event.batch_pipeline.display_name
return name
@@ -470,16 +470,16 @@ def __init__(self, event: RunAlert) -> None:
"name": self._name,
}
- def _id(self) -> Optional[UUID]:
- id: Optional[UUID] = self.event.run.id
+ def _id(self) -> UUID | None:
+ id: UUID | None = self.event.run.id
return id
- def _key(self) -> Optional[str]:
- key: Optional[str] = self.event.run.key
+ def _key(self) -> str | None:
+ key: str | None = self.event.run.key
return key
- def _name(self) -> Optional[str]:
- name: Optional[str] = self.event.run.name
+ def _name(self) -> str | None:
+ name: str | None = self.event.run.name
return name
@@ -493,26 +493,26 @@ def __init__(self, rule: Rule) -> None:
"run_state_trigger_successive": self._run_state_trigger_successive,
}
- def _run_state_matches(self) -> Optional[str]:
+ def _run_state_matches(self) -> str | None:
try:
- matches: Optional[str] = self.rule.rule_data["conditions"][0]["run_state"]["matches"]
+ matches: str | None = self.rule.rule_data["conditions"][0]["run_state"]["matches"]
return matches
except Exception:
return None
- def _run_state_count(self) -> Optional[str]:
+ def _run_state_count(self) -> str | None:
try:
return str(self.rule.rule_data["conditions"][0]["run_state"]["count"])
except Exception:
return None
- def _run_state_group_run_name(self) -> Optional[str]:
+ def _run_state_group_run_name(self) -> str | None:
try:
return str(self.rule.rule_data["conditions"][0]["run_state"]["group_run_name"])
except Exception:
return None
- def _run_state_trigger_successive(self) -> Optional[str]:
+ def _run_state_trigger_successive(self) -> str | None:
try:
return str(self.rule.rule_data["conditions"][0]["run_state"]["trigger_successive"])
except Exception:
diff --git a/common/actions/send_email_action.py b/common/actions/send_email_action.py
index 3906829..7845801 100644
--- a/common/actions/send_email_action.py
+++ b/common/actions/send_email_action.py
@@ -1,7 +1,6 @@
__all__ = ["SendEmailAction"]
import logging
from dataclasses import asdict
-from typing import Optional
from uuid import UUID
from peewee import DoesNotExist
@@ -22,7 +21,7 @@ class SendEmailAction(BaseAction):
required_arguments = {"recipients", "template"}
requires_action_template = True
- def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> ActionResult:
+ def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: UUID | None) -> ActionResult:
try:
context = self._get_data_points(event, rule, journey_id)
except Exception as e:
@@ -41,7 +40,7 @@ def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> Act
return ActionResult(True, response, None)
def _get_data_points(
- self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]
+ self, event: EVENT_TYPE, rule: Rule, journey_id: UUID | None
) -> dict | AgentStatusChangeDataPoints:
"""
Get the data points to be used in the email template
diff --git a/common/actions/webhook_action.py b/common/actions/webhook_action.py
index 207eee8..59dc93a 100644
--- a/common/actions/webhook_action.py
+++ b/common/actions/webhook_action.py
@@ -1,7 +1,7 @@
__all__ = ["WebhookAction"]
import logging
-from typing import Any, Optional, Union
+from typing import Any, Union
from collections.abc import Mapping
from uuid import UUID
@@ -44,7 +44,7 @@ def format_data(data: Union[None, list, dict, str], data_points: Mapping) -> Any
class WebhookAction(BaseAction):
required_arguments = {"url", "method"}
- def _run(self, event: EVENT_TYPE, rule: Rule, _: Optional[UUID]) -> ActionResult:
+ def _run(self, event: EVENT_TYPE, rule: Rule, _: UUID | None) -> ActionResult:
data_points: Mapping
match event:
case RunAlert() | InstanceAlert():
@@ -67,7 +67,7 @@ def _run(self, event: EVENT_TYPE, rule: Rule, _: Optional[UUID]) -> ActionResult
return ActionResult(False, None, e)
return ActionResult(True, {"status_code": response.status_code}, None)
- def _parse_headers(self, data_points: Mapping) -> Optional[dict[str, str]]:
+ def _parse_headers(self, data_points: Mapping) -> dict[str, str] | None:
if headers := self.arguments.get("headers"):
return {h["key"]: format_data(h["value"], data_points) for h in headers}
else:
diff --git a/common/api/base_view.py b/common/api/base_view.py
index ac6f9d7..33beae2 100644
--- a/common/api/base_view.py
+++ b/common/api/base_view.py
@@ -4,7 +4,7 @@
import logging
from dataclasses import dataclass
from functools import cached_property
-from typing import Any, Optional
+from typing import Any
from flask import g, request
from flask.typing import ResponseReturnValue
@@ -22,7 +22,7 @@
@dataclass
class Permission:
entity_attribute: str
- role: Optional[str] = None
+ role: str | None = None
methods: tuple[str, ...] = ("GET", "PUT", "POST", "PATCH", "DELETE")
def __call__(self, *methods: str) -> Permission:
@@ -48,7 +48,7 @@ class BaseView(MethodView):
"""
@property
- def user(self) -> Optional[User]:
+ def user(self) -> User | None:
"""Return the currently authenticated user."""
return getattr(g, "user", None)
@@ -61,12 +61,12 @@ def user_roles(self) -> list[str]:
return []
@property
- def claims(self) -> Optional[User]:
+ def claims(self) -> User | None:
"""Return the currently authenticated token."""
return getattr(g, "claims", None)
@property
- def project(self) -> Optional[Project]:
+ def project(self) -> Project | None:
return getattr(g, "project", None)
@property
diff --git a/common/api/flask_ext/authentication/common.py b/common/api/flask_ext/authentication/common.py
index 249a0a9..6f7125b 100644
--- a/common/api/flask_ext/authentication/common.py
+++ b/common/api/flask_ext/authentication/common.py
@@ -1,7 +1,6 @@
__all__ = ["get_domain", "BaseAuthPlugin", "validate_authentication"]
import logging
import re
-from typing import Optional
from urllib.parse import urlparse
from flask import current_app, g, request
@@ -16,10 +15,10 @@
class BaseAuthPlugin(BaseExtension):
header_name: str = NotImplemented
- header_prefix: Optional[str] = None
+ header_prefix: str | None = None
@classmethod
- def get_header_data(cls) -> Optional[str]:
+ def get_header_data(cls) -> str | None:
auth_data = request.headers.get(cls.header_name, None)
if auth_data and cls.header_prefix:
if match := re.match(rf"^{cls.header_prefix}\s+(.*)\s*$", auth_data):
diff --git a/common/api/flask_ext/authentication/jwt_plugin.py b/common/api/flask_ext/authentication/jwt_plugin.py
index 965ba9a..0bcd41f 100644
--- a/common/api/flask_ext/authentication/jwt_plugin.py
+++ b/common/api/flask_ext/authentication/jwt_plugin.py
@@ -1,7 +1,7 @@
__all__ = ["JWTAuth"]
import logging
-from datetime import datetime, timedelta, timezone
-from typing import Optional, cast
+from datetime import datetime, timedelta, UTC
+from typing import cast
from collections.abc import Callable
from flask import current_app, g, request
@@ -24,7 +24,7 @@ def get_token_expiration(claims: JWT_CLAIMS) -> datetime:
except KeyError as ke:
raise ValueError("Token claims missing 'exp' key") from ke
try:
- return datetime.fromtimestamp(cast(float | int, exp_timestamp), tz=timezone.utc)
+ return datetime.fromtimestamp(cast(float | int, exp_timestamp), tz=UTC)
except Exception as e:
raise ValueError(f"Unable to parse expiration from '{claims['exp']}'") from e
@@ -63,7 +63,7 @@ def pre_request_auth(cls) -> None:
except Exception as e:
raise Unauthorized("Invalid authentication token") from e
- if get_token_expiration(claims) < datetime.now(timezone.utc):
+ if get_token_expiration(claims) < datetime.now(UTC):
LOG.error("JWT token expired")
raise Unauthorized("Invalid authentication token")
@@ -95,7 +95,7 @@ def decode_token(cls, token: str) -> JWT_CLAIMS:
return decoded_token
@classmethod
- def log_user_in(cls, user: User, logout_callback: Optional[str] = None, claims: Optional[JWT_CLAIMS] = None) -> str:
+ def log_user_in(cls, user: User, logout_callback: str | None = None, claims: JWT_CLAIMS | None = None) -> str:
claims = claims or {}
if logout_callback:
@@ -105,7 +105,7 @@ def log_user_in(cls, user: User, logout_callback: Optional[str] = None, claims:
raise ValueError(f"Logout callback '{logout_callback}' is not registered.")
if "exp" not in claims:
- claims["exp"] = (datetime.now(timezone.utc) + cls.default_jwt_expiration).timestamp()
+ claims["exp"] = (datetime.now(UTC) + cls.default_jwt_expiration).timestamp()
claims["user_id"] = str(user.id)
claims["company_id"] = str(user.primary_company_id)
diff --git a/common/api/flask_ext/base_extension.py b/common/api/flask_ext/base_extension.py
index bfd403f..653cf74 100644
--- a/common/api/flask_ext/base_extension.py
+++ b/common/api/flask_ext/base_extension.py
@@ -1,12 +1,11 @@
__all__ = ["BaseExtension"]
-from typing import Optional
from flask import Flask
from flask.typing import AfterRequestCallable, AppOrBlueprintKey, BeforeRequestCallable
class BaseExtension:
- def __init__(self, app: Optional[Flask] = None) -> None:
+ def __init__(self, app: Flask | None = None) -> None:
if app is not None:
self.app = app
self.init_app()
diff --git a/common/api/flask_ext/config.py b/common/api/flask_ext/config.py
index 3e914d9..3e9070f 100644
--- a/common/api/flask_ext/config.py
+++ b/common/api/flask_ext/config.py
@@ -1,6 +1,5 @@
__all__ = ["Config"]
import os
-from typing import Optional
from flask import Flask
@@ -16,7 +15,7 @@ class Config:
then the configuration will load "foo.bar.production"
"""
- def __init__(self, app: Optional[Flask] = None, config_module: str = ""):
+ def __init__(self, app: Flask | None = None, config_module: str = ""):
if not config_module:
raise ValueError("You must provide a 'config_module' to the Config extension")
self.app = app
diff --git a/common/api/flask_ext/cors.py b/common/api/flask_ext/cors.py
index f6cb681..79e7d95 100644
--- a/common/api/flask_ext/cors.py
+++ b/common/api/flask_ext/cors.py
@@ -1,6 +1,5 @@
__all__ = ["CORS"]
from http import HTTPStatus
-from typing import Optional
from flask import Flask, Response, make_response, request
from werkzeug.exceptions import NotFound
@@ -26,13 +25,13 @@
class CORS(BaseExtension):
- def __init__(self, app: Optional[Flask] = None, allowed_methods: Optional[list[str]] = None):
+ def __init__(self, app: Flask | None = None, allowed_methods: list[str] | None = None):
allowed_methods = allowed_methods or []
self.allowed_methods = ", ".join(allowed_methods + ["OPTIONS"]).upper()
super().__init__(app)
@staticmethod
- def make_preflight_response() -> Optional[Response]:
+ def make_preflight_response() -> Response | None:
if request.method == "OPTIONS":
# When request.endpoint isn't populated it means that the URL didn't match any registered view. For this
# case we abort and issue a 404
diff --git a/common/api/request_parsing.py b/common/api/request_parsing.py
index 73f91e1..c794ad5 100644
--- a/common/api/request_parsing.py
+++ b/common/api/request_parsing.py
@@ -1,7 +1,7 @@
__all__ = ["get_bool_param", "no_body_allowed", "str_to_bool", "get_origin_domain"]
from functools import wraps
-from typing import Any, Optional
+from typing import Any
from collections.abc import Callable, Iterable
from urllib.parse import urlparse
@@ -33,10 +33,10 @@ def str_to_bool(value: str, param_name: str) -> bool:
elif case_insensitive_value == "false":
return False
else:
- raise ValidationError({param_name: ("Expected 'true' or 'false'. Instead received " f"'{value}'.")})
+        raise ValidationError({param_name: f"Expected 'true' or 'false'. Instead received '{value}'."})
-def no_body_allowed(func: Optional[Callable] = None, /, methods: Iterable[str] = SAFE_HTTP_METHODS) -> Callable:
+def no_body_allowed(func: Callable | None = None, /, methods: Iterable[str] = SAFE_HTTP_METHODS) -> Callable:
"""
Decorator to be used on MethodView functions if the function does not allow a request body to be passed.
@@ -58,7 +58,7 @@ def _wrapper(*args: list, **kwargs: dict) -> Any:
return decorator
-def get_origin_domain() -> Optional[str]:
+def get_origin_domain() -> str | None:
if (source_url := request.headers.get("Origin")) is not None:
try:
return urlparse(source_url).netloc or None
diff --git a/common/api/search_view.py b/common/api/search_view.py
index 44fad79..ad68568 100644
--- a/common/api/search_view.py
+++ b/common/api/search_view.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any
from flask import Blueprint, Response, request
from flask.typing import RouteCallable
@@ -12,7 +12,7 @@
class SearchView(BaseView):
- args_from_post: Optional[MultiDict] = None
+ args_from_post: MultiDict | None = None
request_body_schema: type[Schema]
def post(self, *args: Any, **kwargs: Any) -> Response:
diff --git a/common/apscheduler_extensions.py b/common/apscheduler_extensions.py
index 18fa404..4624eec 100644
--- a/common/apscheduler_extensions.py
+++ b/common/apscheduler_extensions.py
@@ -3,7 +3,6 @@
import logging
import re
from datetime import datetime, timedelta
-from typing import Optional
from collections.abc import Generator
from zoneinfo import ZoneInfo
@@ -36,10 +35,10 @@ def __init__(self, trigger: BaseTrigger, delay: timedelta):
self.trigger = trigger
self.delay = delay
- def get_next_fire_time(self, previous_fire_time: Optional[datetime], now: datetime) -> Optional[datetime]:
+ def get_next_fire_time(self, previous_fire_time: datetime | None, now: datetime) -> datetime | None:
if previous_fire_time:
previous_fire_time -= self.delay
- next_fire_time: Optional[datetime] = self.trigger.get_next_fire_time(previous_fire_time, now)
+ next_fire_time: datetime | None = self.trigger.get_next_fire_time(previous_fire_time, now)
return next_fire_time + self.delay if next_fire_time else None
def __str__(self) -> str:
@@ -109,7 +108,7 @@ def fix_weekdays(expression: str) -> str:
def get_crontab_trigger_times(
- crontab: str, timezone: ZoneInfo, start_range: datetime, end_range: Optional[datetime] = None
+ crontab: str, timezone: ZoneInfo, start_range: datetime, end_range: datetime | None = None
) -> Generator[datetime, None, None]:
"""
Generate the crontab trigger times for the given time range.
diff --git a/common/auth/keys/service_key.py b/common/auth/keys/service_key.py
index 8ed0a53..1fbedb3 100644
--- a/common/auth/keys/service_key.py
+++ b/common/auth/keys/service_key.py
@@ -1,8 +1,8 @@
import logging
from base64 import b64encode
from dataclasses import dataclass
-from datetime import datetime, timedelta, timezone
-from typing import NamedTuple, Optional
+from datetime import datetime, timedelta, UTC
+from typing import NamedTuple
from uuid import uuid4
from peewee import DoesNotExist
@@ -31,15 +31,15 @@ def generate_key(
*,
project: Project,
allowed_services: list[str],
- name: Optional[str] = None,
- description: Optional[str] = None,
+ name: str | None = None,
+ description: str | None = None,
expiration_days: int = DEFAULT_EXPIRY_DAYS,
) -> KeyPair:
"""Generate a new Service Account key for the given service name."""
passphrase = generate_passphrase()
salt = str(uuid4())
passphrase_hash = hash_value(value=passphrase, salt=salt)
- expiry = datetime.now(timezone.utc) + timedelta(days=expiration_days)
+ expiry = datetime.now(UTC) + timedelta(days=expiration_days)
digest = create_digest(iterations=HASH_ITERATIONS, salt=salt, passphrase_hash=passphrase_hash)
# Give the key a pretty unique name if none provided (Name only has to be unique per-project so this should safe)
diff --git a/common/datetime_utils.py b/common/datetime_utils.py
index 8677449..6c473e0 100644
--- a/common/datetime_utils.py
+++ b/common/datetime_utils.py
@@ -1,12 +1,12 @@
__all__ = ["datetime_formatted", "datetime_iso8601", "to_utc_aware"]
-from datetime import datetime, timezone
+from datetime import datetime, UTC
# Although datetimes in Events are tz aware they only contains the raw offset
# after being marshmallow serialized. Since the tz name is desired astimezone
# is used for strftime to return the name.
def to_utc_aware(dt: datetime) -> datetime:
- return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt.astimezone(timezone.utc)
+ return dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC)
def datetime_formatted(dt: datetime) -> str:
@@ -28,5 +28,5 @@ def datetime_to_timestamp(dt: datetime) -> float:
def timestamp_to_datetime(timestamp: float) -> datetime:
"""Convert a timestamp to a datetime object in UTC time."""
- dt = datetime.fromtimestamp(timestamp, tz=timezone.utc)
+ dt = datetime.fromtimestamp(timestamp, tz=UTC)
return dt
diff --git a/common/decorators.py b/common/decorators.py
index 075ae84..485f4bd 100644
--- a/common/decorators.py
+++ b/common/decorators.py
@@ -1,12 +1,12 @@
from __future__ import annotations
-from typing import Any, Generic, Optional, TypeVar, cast
+from typing import Any, TypeVar, cast
from collections.abc import Callable
PropertyType = TypeVar("PropertyType")
-class cached_property(Generic[PropertyType]):
+class cached_property[PropertyType]:
"""
A `property` decorator that caches the value on the instance.
@@ -47,7 +47,7 @@ class cached_property(Generic[PropertyType]):
"""
- _name: Optional[str] = None
+ _name: str | None = None
def __init__(self, f: Callable[[Any], PropertyType]) -> None:
self.func = f
@@ -59,7 +59,7 @@ def __set_name__(self, owner: type[object], name: str) -> None:
elif name != self._name:
raise TypeError(f"Cannot assign the same instance to two names ({self._name} and {name}).")
- def __get__(self, inst: object, cls: Optional[Any] = None) -> PropertyType:
+ def __get__(self, inst: object, cls: Any | None = None) -> PropertyType:
"""
Retrieve the value from instance, stashing the result in inst.__dict__
diff --git a/common/entities/alert.py b/common/entities/alert.py
index 17b7371..70a6ebf 100644
--- a/common/entities/alert.py
+++ b/common/entities/alert.py
@@ -2,7 +2,6 @@
from datetime import datetime
from enum import Enum
-from typing import Optional
from peewee import CharField, CompositeKey, ForeignKeyField
from playhouse.mysql_ext import JSONField
@@ -63,7 +62,7 @@ class AlertBase(BaseEntity, AuditUpdateTimeEntityMixin):
level = EnumStrField(AlertLevel, null=False, max_length=50)
@property
- def expected_start_time(self) -> Optional[datetime]:
+ def expected_start_time(self) -> datetime | None:
"""If the alert has expected_start_time in it's details dict, return it as a datetime object."""
timestamp = self.details.get("expected_start_time", None)
if timestamp:
@@ -76,7 +75,7 @@ def expected_start_time(self) -> Optional[datetime]:
return None
@expected_start_time.setter
- def expected_start_time(self, dt_obj: Optional[datetime]) -> None:
+ def expected_start_time(self, dt_obj: datetime | None) -> None:
"""Set the expected_start_time value (converts to timestamp in details dict)."""
if dt_obj is None:
self.details.pop("expected_start_time", None)
@@ -85,7 +84,7 @@ def expected_start_time(self, dt_obj: Optional[datetime]) -> None:
self.details["expected_start_time"] = timestamp
@property
- def expected_end_time(self) -> Optional[datetime]:
+ def expected_end_time(self) -> datetime | None:
"""If the alert has expected_end_time in it's details dict, return it as a datetime object."""
timestamp = self.details.get("expected_end_time", None)
if timestamp:
@@ -98,7 +97,7 @@ def expected_end_time(self) -> Optional[datetime]:
return None
@expected_end_time.setter
- def expected_end_time(self, dt_obj: Optional[datetime]) -> None:
+ def expected_end_time(self, dt_obj: datetime | None) -> None:
"""Set the expected_end_time value (converts to timestamp in details dict)."""
if dt_obj is None:
self.details.pop("expected_end_time", None)
diff --git a/common/entities/authentication.py b/common/entities/authentication.py
index 125cb1d..60f0776 100644
--- a/common/entities/authentication.py
+++ b/common/entities/authentication.py
@@ -1,6 +1,6 @@
__all__ = ["ApiKey", "Service", "ServiceAccountKey"]
import logging
-from datetime import datetime, timezone
+from datetime import datetime, UTC
from enum import Enum
from uuid import uuid4
@@ -24,7 +24,7 @@ class ApiKey(Model):
user = ForeignKeyField(User, backref="api_keys", on_delete="CASCADE", null=False, index=True)
def is_expired(self) -> bool:
- if datetime.now(timezone.utc) > self.expiry:
+ if datetime.now(UTC) > self.expiry:
return True
else:
return False
@@ -59,7 +59,7 @@ def is_expired(self) -> bool:
# If no expiration date is set, the key's duration is unlimited
if not self.expiry:
return False
- if datetime.now(timezone.utc) > self.expiry:
+ if datetime.now(UTC) > self.expiry:
return True
else:
return False
diff --git a/common/entities/base_entity.py b/common/entities/base_entity.py
index 7ac8e3f..f0cd4ba 100644
--- a/common/entities/base_entity.py
+++ b/common/entities/base_entity.py
@@ -1,6 +1,6 @@
__all__ = ["ActivableEntityMixin", "AuditEntityMixin", "AuditUpdateTimeEntityMixin", "BaseEntity", "BaseModel", "DB"]
-from datetime import datetime, timezone
+from datetime import datetime, UTC
from typing import Any
from uuid import uuid4
@@ -53,5 +53,5 @@ class AuditUpdateTimeEntityMixin(Model):
@classmethod
def update(cls, *args: Any, **kwargs: Any) -> Any:
- kwargs["updated_on"] = datetime.utcnow().replace(tzinfo=timezone.utc)
+ kwargs["updated_on"] = datetime.now(UTC)
return super().update(*args, **kwargs)
diff --git a/common/entities/company.py b/common/entities/company.py
index 6c84f63..20d4cd1 100644
--- a/common/entities/company.py
+++ b/common/entities/company.py
@@ -1,6 +1,5 @@
__all__ = ["Company"]
-from typing import Optional
from peewee import CharField, ForeignKeyField
@@ -13,5 +12,5 @@ class Company(BaseEntity, AuditEntityMixin):
name = CharField(unique=True, null=False)
@property
- def parent(self) -> Optional[ForeignKeyField]:
+ def parent(self) -> ForeignKeyField | None:
return None
diff --git a/common/entities/component_meta.py b/common/entities/component_meta.py
index 606f5cc..2bd2daf 100644
--- a/common/entities/component_meta.py
+++ b/common/entities/component_meta.py
@@ -1,4 +1,4 @@
-from typing import Any, Optional, Union
+from typing import Any, Union
from peewee import Field, ForeignKeyField, ModelBase, ModelSelect, ModelUpdate
@@ -144,7 +144,7 @@ def create(cls: type[BaseModel], **data: object) -> BaseModel:
component.save(force_insert=True)
return component
- def save(self: BaseModel, force_insert: bool = False, only: Optional[object] = None) -> Union[bool, int]:
+ def save(self: BaseModel, force_insert: bool = False, only: object | None = None) -> Union[bool, int]:
ret = 0
if not force_insert and self.component:
ret += self.component.save(only=only) or 0
diff --git a/common/entities/dataset_operation.py b/common/entities/dataset_operation.py
index c80553e..3d0c201 100644
--- a/common/entities/dataset_operation.py
+++ b/common/entities/dataset_operation.py
@@ -13,8 +13,8 @@
class DatasetOperationType(Enum):
- READ: str = "READ"
- WRITE: str = "WRITE"
+ READ = "READ"
+ WRITE = "WRITE"
class DatasetOperation(BaseEntity):
diff --git a/common/entities/upcoming_instance.py b/common/entities/upcoming_instance.py
index 949c486..362ad86 100644
--- a/common/entities/upcoming_instance.py
+++ b/common/entities/upcoming_instance.py
@@ -1,7 +1,6 @@
__all__ = ["UpcomingInstance"]
from dataclasses import dataclass
from datetime import datetime
-from typing import Optional
from common.entities.journey import Journey
@@ -9,5 +8,5 @@
@dataclass
class UpcomingInstance:
journey: Journey
- expected_start_time: Optional[datetime] = None
- expected_end_time: Optional[datetime] = None
+ expected_start_time: datetime | None = None
+ expected_end_time: datetime | None = None
diff --git a/common/entity_services/component_service.py b/common/entity_services/component_service.py
index a772926..413479b 100644
--- a/common/entity_services/component_service.py
+++ b/common/entity_services/component_service.py
@@ -2,7 +2,6 @@
from datetime import datetime
from itertools import cycle
-from typing import Optional
from collections.abc import Generator
from peewee import Select
@@ -31,7 +30,7 @@ def select_journeys(component: Component) -> Select:
@classmethod
def get_or_create_active_instances(
- cls, component: Component, start_time: Optional[datetime] = None
+ cls, component: Component, start_time: datetime | None = None
) -> Generator[tuple[bool, Instance], None, None]:
"""
Retrieves active Instances for a given component. Create active Instances when a Journey does not have one.
diff --git a/common/entity_services/helpers/filter_rules.py b/common/entity_services/helpers/filter_rules.py
index 290924c..391256d 100644
--- a/common/entity_services/helpers/filter_rules.py
+++ b/common/entity_services/helpers/filter_rules.py
@@ -13,7 +13,7 @@
from dataclasses import dataclass, field
from datetime import datetime
-from typing import Optional, TypeVar
+from typing import TypeVar
from collections.abc import Callable
from uuid import UUID
@@ -72,7 +72,7 @@ class ParamConfig:
func: Callable
-def _date_or_none(params: MultiDict, field_name: str) -> Optional[datetime]:
+def _date_or_none(params: MultiDict, field_name: str) -> datetime | None:
if date := params.get(field_name):
try:
return arrow.get(date).datetime
@@ -81,7 +81,7 @@ def _date_or_none(params: MultiDict, field_name: str) -> Optional[datetime]:
return None
-def _str_to_bool(params: MultiDict, field_name: str) -> Optional[bool]:
+def _str_to_bool(params: MultiDict, field_name: str) -> bool | None:
if (value := params.get(field_name)) is None:
return None
return str_to_bool(value, field_name)
@@ -130,28 +130,28 @@ class Filters:
Extend by specifying the wanted attributes and how to unpack them in from_params.
"""
- active: Optional[bool] = None
+ active: bool | None = None
component_ids: list[str] = field(default_factory=list)
component_types: list[str] = field(default_factory=list)
- date_range_end: Optional[datetime] = None
- date_range_start: Optional[datetime] = None
- end_range: Optional[datetime] = None
- end_range_begin: Optional[datetime] = None
- end_range_end: Optional[datetime] = None
+ date_range_end: datetime | None = None
+ date_range_start: datetime | None = None
+ end_range: datetime | None = None
+ end_range_begin: datetime | None = None
+ end_range_end: datetime | None = None
event_ids: list[str] = field(default_factory=list)
event_types: list[str] = field(default_factory=list)
instance_ids: list[str] = field(default_factory=list)
journey_ids: list[str] = field(default_factory=list)
journey_names: list[str] = field(default_factory=list)
- key: Optional[str] = None
+ key: str | None = None
levels: list[str] = field(default_factory=list)
pipeline_keys: list[str] = field(default_factory=list)
project_ids: list[str] = field(default_factory=list)
run_ids: list[str] = field(default_factory=list)
run_keys: list[str] = field(default_factory=list)
- start_range: Optional[datetime] = None
- start_range_begin: Optional[datetime] = None
- start_range_end: Optional[datetime] = None
+ start_range: datetime | None = None
+ start_range_begin: datetime | None = None
+ start_range_end: datetime | None = None
statuses: list[str] = field(default_factory=list)
task_ids: list[str] = field(default_factory=list)
tools: list[str] = field(default_factory=list)
@@ -166,9 +166,7 @@ def __bool__(self) -> bool:
return False
@staticmethod
- def validate_time_range(
- range_begin: Optional[datetime], range_end: Optional[datetime], range_begin_name: str
- ) -> None:
+ def validate_time_range(range_begin: datetime | None, range_end: datetime | None, range_begin_name: str) -> None:
if range_begin is None or range_end is None:
return None
if range_begin >= range_end:
diff --git a/common/entity_services/helpers/list_rules.py b/common/entity_services/helpers/list_rules.py
index 8352ea8..56e7869 100644
--- a/common/entity_services/helpers/list_rules.py
+++ b/common/entity_services/helpers/list_rules.py
@@ -4,7 +4,7 @@
from dataclasses import dataclass
from enum import Enum as std_Enum
from enum import auto
-from typing import Generic, Optional, TypeVar
+from typing import TypeVar
from collections.abc import Generator
from marshmallow import EXCLUDE, Schema
@@ -36,7 +36,7 @@ class Meta:
@dataclass
-class Page(Generic[T]):
+class Page[T]:
"""
Useful for returning results from the service layer to get paginated results
but also receive the total objects without pagination.
@@ -83,7 +83,7 @@ class ListRules:
page: int = DEFAULT_PAGE
count: int = DEFAULT_COUNT
sort: SortOrder = SortOrder.ASC
- search: Optional[str] = None
+ search: str | None = None
@classmethod
def from_params_without_search(cls, params: MultiDict) -> ListRules:
diff --git a/common/entity_services/instance_service.py b/common/entity_services/instance_service.py
index 62139ec..feb07b0 100644
--- a/common/entity_services/instance_service.py
+++ b/common/entity_services/instance_service.py
@@ -2,7 +2,6 @@
from collections import Counter, defaultdict
from datetime import datetime
-from typing import Optional
from collections.abc import Iterable
from uuid import UUID
@@ -163,11 +162,11 @@ def run_alerts_query(instances: Iterable[Instance]) -> ModelSelect:
def get_instance_run_counts(
instance: UUID | Instance,
*,
- include_run_statuses: Optional[Iterable[str]] = None,
- exclude_run_statuses: Optional[Iterable[str]] = None,
- journey: Optional[UUID] = None,
- pipelines: Optional[Iterable[UUID]] = None,
- end_before: Optional[datetime] = None,
+ include_run_statuses: Iterable[str] | None = None,
+ exclude_run_statuses: Iterable[str] | None = None,
+ journey: UUID | None = None,
+ pipelines: Iterable[UUID] | None = None,
+ end_before: datetime | None = None,
) -> dict[UUID, int]:
"""
Return a dict of pipelines with the corresponding run count per pipeline.
diff --git a/common/entity_services/journey_service.py b/common/entity_services/journey_service.py
index d6881ac..6be3e28 100644
--- a/common/entity_services/journey_service.py
+++ b/common/entity_services/journey_service.py
@@ -1,7 +1,6 @@
__all__ = ["JourneyService"]
import logging
-from typing import Optional
from uuid import UUID
from common.entities import Action, Company, Component, Journey, JourneyDagEdge, Organization, Project, Rule
@@ -19,7 +18,7 @@ def get_rules_with_rules(journey_id: UUID, list_rules: ListRules) -> Page[Rule]:
return Page[Rule].get_paginated_results(query, Rule.created_on, list_rules)
@staticmethod
- def get_action_by_implementation(journey_id: UUID, action_impl: str) -> Optional[Action]:
+ def get_action_by_implementation(journey_id: UUID, action_impl: str) -> Action | None:
"""
Fetches an Action entity given a Journey ID and the action implementation.
diff --git a/common/entity_services/pipeline_service.py b/common/entity_services/pipeline_service.py
index 0081ac4..551c4a4 100644
--- a/common/entity_services/pipeline_service.py
+++ b/common/entity_services/pipeline_service.py
@@ -1,6 +1,5 @@
__all__ = ["PipelineService"]
import logging
-from typing import Optional
from common.entities import Pipeline
@@ -9,6 +8,6 @@
class PipelineService:
@staticmethod
- def get_by_key_and_project(pipeline_key: Optional[str], project_id: str) -> Pipeline:
+ def get_by_key_and_project(pipeline_key: str | None, project_id: str) -> Pipeline:
pipeline: Pipeline = Pipeline.get(Pipeline.key == pipeline_key, Pipeline.project == project_id)
return pipeline
diff --git a/common/entity_services/project_service.py b/common/entity_services/project_service.py
index 6086e5a..0c7b3bb 100644
--- a/common/entity_services/project_service.py
+++ b/common/entity_services/project_service.py
@@ -2,7 +2,7 @@
from functools import reduce
from operator import or_
-from typing import Any, Optional
+from typing import Any
from peewee import PREFETCH_TYPE, Value, fn, prefetch, DoesNotExist
@@ -107,7 +107,7 @@ def get_components_with_rules(project_id: str, rules: ListRules, filters: Compon
@staticmethod
def get_runs_with_rules(
- project_id: Optional[str], pipeline_ids: list[str], rules: ListRules, filters: RunFilters
+ project_id: str | None, pipeline_ids: list[str], rules: ListRules, filters: RunFilters
) -> Page[Run]:
start_dt = fn.COALESCE(Run.start_time, Run.expected_start_time)
query = Run.select(Run, start_dt.alias("start_dt")).distinct()
@@ -154,7 +154,7 @@ def get_runs_with_rules(
@staticmethod
def get_instances_with_rules(
- rules: ListRules, filters: Filters, project_ids: list[str], company_id: Optional[str] = None
+ rules: ListRules, filters: Filters, project_ids: list[str], company_id: str | None = None
) -> Page[Instance]:
memberships = [Journey.project.in_(project_ids)] if project_ids else []
if company_id:
@@ -211,7 +211,7 @@ def get_instances_with_rules(
return Page[Instance](results=results, total=query.count())
@staticmethod
- def get_journeys_with_rules(project_id: str, rules: ListRules, component_id: Optional[str] = None) -> Page[Journey]:
+ def get_journeys_with_rules(project_id: str, rules: ListRules, component_id: str | None = None) -> Page[Journey]:
base_query = Journey.project == project_id
if rules.search is not None:
base_query &= Journey.name ** f"%{rules.search}%"
diff --git a/common/entity_services/test_outcome_service.py b/common/entity_services/test_outcome_service.py
index 9639889..b2835d7 100644
--- a/common/entity_services/test_outcome_service.py
+++ b/common/entity_services/test_outcome_service.py
@@ -1,6 +1,5 @@
__all__ = ["TestOutcomeService"]
-from typing import Optional
from uuid import UUID
from peewee import SqliteDatabase
@@ -17,9 +16,9 @@ def insert_from_event(
*,
event: TestOutcomesEvent,
component_id: UUID,
- instance_set_id: Optional[UUID] = None,
- run_id: Optional[UUID] = None,
- task_id: Optional[UUID] = None,
+ instance_set_id: UUID | None = None,
+ run_id: UUID | None = None,
+ task_id: UUID | None = None,
) -> None:
test_outcomes = []
test_outcome_integrations = []
@@ -65,7 +64,7 @@ def insert_from_event(
)
# Using the recursive lookup avoids having to check for None on optional values up the whole chain
- testgen_dataset: Optional[TestgenDataset] = getattr_recursive(
+ testgen_dataset: TestgenDataset | None = getattr_recursive(
event, "component_integrations__integrations__testgen", None
)
diff --git a/common/entity_services/upcoming_instance_service.py b/common/entity_services/upcoming_instance_service.py
index b02f671..64f44c9 100644
--- a/common/entity_services/upcoming_instance_service.py
+++ b/common/entity_services/upcoming_instance_service.py
@@ -4,7 +4,7 @@
from datetime import datetime
from heapq import merge
from operator import itemgetter
-from typing import Optional, cast
+from typing import cast
from collections.abc import Generator
from uuid import UUID
from zoneinfo import ZoneInfo
@@ -65,7 +65,7 @@ def _collect_journey_schedules(
def _get_instance_times(
schedules: JourneySchedules,
start_time: datetime,
- end_time: Optional[datetime],
+ end_time: datetime | None,
) -> Generator[tuple[datetime, bool], None, None]:
"""
Generate a sequence of expected instance start and end times from the given schedules
@@ -96,8 +96,8 @@ class UpcomingInstanceService:
def get_upcoming_instances_with_rules(
rules: ListRules,
filters: UpcomingInstanceFilters,
- project_id: Optional[UUID] = None,
- company_id: Optional[UUID] = None,
+ project_id: UUID | None = None,
+ company_id: UUID | None = None,
) -> list[UpcomingInstance]:
assert filters.start_range is not None
memberships = []
@@ -159,8 +159,8 @@ def get_upcoming_instances_with_rules(
def get_upcoming_instances(
journey: Journey,
start_time: datetime,
- end_time: Optional[datetime] = None,
- schedules: Optional[JourneySchedules] = None,
+ end_time: datetime | None = None,
+ schedules: JourneySchedules | None = None,
) -> Generator[UpcomingInstance, None, None]:
"""
Get upcoming instances for the given journey
diff --git a/common/entity_services/user_service.py b/common/entity_services/user_service.py
index 1f6a072..4e093be 100644
--- a/common/entity_services/user_service.py
+++ b/common/entity_services/user_service.py
@@ -1,14 +1,10 @@
-from typing import Optional
-
from common.entities import User
from common.entity_services.helpers import ListRules, Page
class UserService:
@staticmethod
- def list_with_rules(
- rules: ListRules, company_id: Optional[str] = None, name_filter: Optional[str] = None
- ) -> Page[User]:
+ def list_with_rules(rules: ListRules, company_id: str | None = None, name_filter: str | None = None) -> Page[User]:
query = User.select()
if company_id:
query = query.where(User.primary_company_id == company_id)
diff --git a/common/events/base.py b/common/events/base.py
index 64a496e..4d1f0d8 100644
--- a/common/events/base.py
+++ b/common/events/base.py
@@ -12,10 +12,9 @@
]
from dataclasses import dataclass, field
-from datetime import datetime, timezone
+from datetime import datetime, UTC
from enum import Enum as std_Enum
from functools import partial
-from typing import Optional
from uuid import UUID
from uuid import UUID as std_UUID
from uuid import uuid4
@@ -44,14 +43,14 @@ def partition_identifier(self) -> str:
@dataclass(kw_only=True)
class ComponentMixin:
- component_id: Optional[UUID] = None
- component_type: Optional[ComponentType] = None
+ component_id: UUID | None = None
+ component_type: ComponentType | None = None
@dataclass(kw_only=True)
class BatchPipelineMixin:
- batch_pipeline_id: Optional[UUID] = None
- run_id: Optional[UUID] = None
+ batch_pipeline_id: UUID | None = None
+ run_id: UUID | None = None
@cached_property
def batch_pipeline(self) -> Pipeline:
@@ -66,13 +65,13 @@ def run(self) -> Run:
@dataclass(kw_only=True)
class RunMixin:
- run_id: Optional[UUID] = None
+ run_id: UUID | None = None
@dataclass(kw_only=True)
class TaskMixin:
- task_id: Optional[UUID] = None
- run_task_id: Optional[UUID] = None
+ task_id: UUID | None = None
+ run_task_id: UUID | None = None
@cached_property
def task(self) -> Task:
@@ -109,11 +108,11 @@ class JourneysMixin:
@dataclass(kw_only=True)
class JourneyMixin:
- journey_id: Optional[UUID] = None
- instance_id: Optional[UUID] = None
+ journey_id: UUID | None = None
+ instance_id: UUID | None = None
@dataclass(kw_only=True)
class EventBaseMixin:
event_id: UUID = field(default_factory=uuid4)
- created_timestamp: datetime = field(default_factory=partial(datetime.now, tz=timezone.utc))
+ created_timestamp: datetime = field(default_factory=partial(datetime.now, tz=UTC))
diff --git a/common/events/converters.py b/common/events/converters.py
index accaf8b..88a9429 100644
--- a/common/events/converters.py
+++ b/common/events/converters.py
@@ -1,5 +1,5 @@
from dataclasses import asdict, fields
-from typing import Optional, cast
+from typing import cast
from common.entities import ComponentType, RunStatus
from common.entities.event import ApiEventType
@@ -29,7 +29,7 @@ def _extract_common_attributes(self, event: EventV2) -> dict:
"payload_keys": event.event_payload.payload_keys,
}
- def _extract_batch_attributes(self, batch: Optional[v2.BatchPipelineData]) -> dict:
+ def _extract_batch_attributes(self, batch: v2.BatchPipelineData | None) -> dict:
data = {
"run_name": batch.run_name if batch else None,
"run_key": batch.run_key if batch else None,
@@ -40,7 +40,7 @@ def _extract_batch_attributes(self, batch: Optional[v2.BatchPipelineData]) -> di
data["component_tool"] = batch.details.tool
return data
- def _extract_dataset_attributes(self, dataset: Optional[v2.DatasetData]) -> dict:
+ def _extract_dataset_attributes(self, dataset: v2.DatasetData | None) -> dict:
data = {
"dataset_key": dataset.dataset_key if dataset else None,
"dataset_name": dataset.details.name if dataset and dataset.details else None,
@@ -49,7 +49,7 @@ def _extract_dataset_attributes(self, dataset: Optional[v2.DatasetData]) -> dict
data["component_tool"] = dataset.details.tool
return data
- def _extract_server_attributes(self, server: Optional[v2.ServerData]) -> dict:
+ def _extract_server_attributes(self, server: v2.ServerData | None) -> dict:
data = {
"server_key": server.server_key if server else None,
"server_name": server.details.name if server and server.details else None,
@@ -58,7 +58,7 @@ def _extract_server_attributes(self, server: Optional[v2.ServerData]) -> dict:
data["component_tool"] = server.details.tool
return data
- def _extract_stream_attributes(self, stream: Optional[v2.StreamData]) -> dict:
+ def _extract_stream_attributes(self, stream: v2.StreamData | None) -> dict:
data = {
"stream_key": stream.stream_key if stream else None,
"stream_name": stream.details.name if stream and stream.details else None,
@@ -69,10 +69,10 @@ def _extract_stream_attributes(self, stream: Optional[v2.StreamData]) -> dict:
def _extract_component_data(
self,
- batch: Optional[v2.BatchPipelineData],
- dataset: Optional[v2.DatasetData],
- server: Optional[v2.ServerData],
- stream: Optional[v2.StreamData],
+ batch: v2.BatchPipelineData | None,
+ dataset: v2.DatasetData | None,
+ server: v2.ServerData | None,
+ stream: v2.StreamData | None,
) -> dict:
data = {
**self._extract_batch_attributes(batch),
@@ -84,7 +84,7 @@ def _extract_component_data(
data["component_tool"] = None
return data
- def _extract_task_attributes(self, event: EventV2, batch: Optional[v2.BatchPipelineData]) -> dict:
+ def _extract_task_attributes(self, event: EventV2, batch: v2.BatchPipelineData | None) -> dict:
return {
"task_key": batch.task_key if batch else None,
"task_name": batch.task_name if batch else None,
@@ -116,8 +116,8 @@ def _extract_testgen_item(self, testgen: dict) -> v1.TestgenItem:
)
def _extract_test_outcome_item_integrations(
- self, integrations: Optional[dict]
- ) -> Optional[v1.TestOutcomeItemIntegrations]:
+ self, integrations: dict | None
+ ) -> v1.TestOutcomeItemIntegrations | None:
if integrations is None:
return None
return v1.TestOutcomeItemIntegrations(testgen=self._extract_testgen_item(integrations["testgen"]))
@@ -142,8 +142,8 @@ def _extract_testgen_table(self, tables: dict) -> v1.TestgenTable:
)
def _extract_testgen_table_group_config(
- self, table_group_configuration: Optional[dict]
- ) -> Optional[v1.TestgenTableGroupV1]:
+ self, table_group_configuration: dict | None
+ ) -> v1.TestgenTableGroupV1 | None:
if table_group_configuration is None:
return None
return v1.TestgenTableGroupV1(
@@ -167,7 +167,7 @@ def _extract_testgen_integration_componenet(self, integrations: dict) -> v1.Test
def _extract_component_integrations(
self, component: v2.TestGenComponentData
- ) -> Optional[v1.TestGenTestOutcomeIntegrationComponent]:
+ ) -> v1.TestGenTestOutcomeIntegrationComponent | None:
integrations = next(c for f in fields(component) if (c := getattr(component, f.name, None))).integrations
if integrations is None:
return None
@@ -300,7 +300,7 @@ def _extract_common_internal_attributes(self, event: Event) -> dict:
"version": event.version,
}
- def _extract_batch_pipeline_data(self, event: Event) -> Optional[v2.BatchPipelineData]:
+ def _extract_batch_pipeline_data(self, event: Event) -> v2.BatchPipelineData | None:
new_component_data = (
v2.NewComponentData(name=event.pipeline_name, tool=event.component_tool)
if event.pipeline_name or event.component_tool
@@ -318,9 +318,7 @@ def _extract_batch_pipeline_data(self, event: Event) -> Optional[v2.BatchPipelin
else:
return None
- def _extract_testgen_batch_pipeline_data(
- self, event: v1.TestOutcomesEvent
- ) -> Optional[v2.TestGenBatchPipelineData]:
+ def _extract_testgen_batch_pipeline_data(self, event: v1.TestOutcomesEvent) -> v2.TestGenBatchPipelineData | None:
if data := self._extract_batch_pipeline_data(event):
return v2.TestGenBatchPipelineData(
batch_key=data.batch_key,
@@ -333,7 +331,7 @@ def _extract_testgen_batch_pipeline_data(
)
return None
- def _extract_testgen_dataset_data(self, event: v1.TestOutcomesEvent) -> Optional[v2.TestGenDatasetData]:
+ def _extract_testgen_dataset_data(self, event: v1.TestOutcomesEvent) -> v2.TestGenDatasetData | None:
if data := self._extract_dataset_data(event):
return v2.TestGenDatasetData(
dataset_key=data.dataset_key,
@@ -342,7 +340,7 @@ def _extract_testgen_dataset_data(self, event: v1.TestOutcomesEvent) -> Optional
)
return None
- def _extract_testgen_stream_data(self, event: v1.TestOutcomesEvent) -> Optional[v2.TestGenStreamData]:
+ def _extract_testgen_stream_data(self, event: v1.TestOutcomesEvent) -> v2.TestGenStreamData | None:
if data := self._extract_stream_data(event):
return v2.TestGenStreamData(
stream_key=data.stream_key,
@@ -351,7 +349,7 @@ def _extract_testgen_stream_data(self, event: v1.TestOutcomesEvent) -> Optional[
)
return None
- def _extract_testgen_server_data(self, event: v1.TestOutcomesEvent) -> Optional[v2.TestGenServerData]:
+ def _extract_testgen_server_data(self, event: v1.TestOutcomesEvent) -> v2.TestGenServerData | None:
if data := self._extract_server_data(event):
return v2.TestGenServerData(
server_key=data.server_key,
@@ -360,7 +358,7 @@ def _extract_testgen_server_data(self, event: v1.TestOutcomesEvent) -> Optional[
)
return None
- def _extract_dataset_data(self, event: Event) -> Optional[v2.DatasetData]:
+ def _extract_dataset_data(self, event: Event) -> v2.DatasetData | None:
new_component_data = (
v2.NewComponentData(name=event.dataset_name, tool=event.component_tool)
if event.dataset_name or event.component_tool
@@ -371,7 +369,7 @@ def _extract_dataset_data(self, event: Event) -> Optional[v2.DatasetData]:
else:
return None
- def _extract_stream_data(self, event: Event) -> Optional[v2.StreamData]:
+ def _extract_stream_data(self, event: Event) -> v2.StreamData | None:
new_component_data = (
v2.NewComponentData(name=event.stream_name, tool=event.component_tool)
if event.stream_name or event.component_tool
@@ -385,7 +383,7 @@ def _extract_stream_data(self, event: Event) -> Optional[v2.StreamData]:
else:
return None
- def _extract_server_data(self, event: Event) -> Optional[v2.ServerData]:
+ def _extract_server_data(self, event: Event) -> v2.ServerData | None:
new_component_data = (
v2.NewComponentData(name=event.server_name, tool=event.component_tool)
if event.server_name or event.component_tool
@@ -438,8 +436,8 @@ def _extract_testgen_item(self, testgen: dict) -> v2.TestgenItem:
)
def _extract_test_outcome_item_integrations(
- self, integrations: Optional[dict]
- ) -> Optional[v2.TestOutcomeItemIntegrations]:
+ self, integrations: dict | None
+ ) -> v2.TestOutcomeItemIntegrations | None:
if integrations is None:
return None
return v2.TestOutcomeItemIntegrations(testgen=self._extract_testgen_item(integrations["testgen"]))
@@ -464,8 +462,8 @@ def _extract_testgen_table(self, tables: dict) -> v2.TestgenTable:
)
def _extract_testgen_table_group_config(
- self, table_group_configuration: Optional[dict]
- ) -> Optional[v2.TestgenTableGroupV1]:
+ self, table_group_configuration: dict | None
+ ) -> v2.TestgenTableGroupV1 | None:
if table_group_configuration is None:
return None
return v2.TestgenTableGroupV1(
@@ -484,7 +482,7 @@ def _extract_testgen_integrations(self, testgen: dict) -> v2.TestgenDataset:
def _extract_testgen_integration_componenet(
self, event: v1.TestOutcomesEvent
- ) -> Optional[v2.TestGenTestOutcomeIntegrations]:
+ ) -> v2.TestGenTestOutcomeIntegrations | None:
if i := event.component_integrations:
return v2.TestGenTestOutcomeIntegrations(
testgen=self._extract_testgen_integrations(asdict(i.integrations.testgen)),
diff --git a/common/events/internal/alert.py b/common/events/internal/alert.py
index a928702..889ead1 100644
--- a/common/events/internal/alert.py
+++ b/common/events/internal/alert.py
@@ -4,7 +4,6 @@
]
from dataclasses import dataclass
-from typing import Optional
from uuid import UUID
from common.decorators import cached_property
@@ -20,7 +19,7 @@
class AlertBase:
alert_id: UUID
level: AlertLevel
- description: Optional[str]
+ description: str | None
@dataclass(kw_only=True)
diff --git a/common/events/internal/scheduled_event.py b/common/events/internal/scheduled_event.py
index 0c246f0..2844898 100644
--- a/common/events/internal/scheduled_event.py
+++ b/common/events/internal/scheduled_event.py
@@ -2,7 +2,6 @@
from dataclasses import dataclass
from datetime import datetime
-from typing import Optional
from uuid import UUID
from common.events.base import ComponentMixin
@@ -19,7 +18,7 @@ class ScheduledEvent(ComponentMixin):
schedule_id: UUID
schedule_type: ScheduleType
schedule_timestamp: datetime
- schedule_margin: Optional[datetime] = None
+ schedule_margin: datetime | None = None
@property
def partition_identifier(self) -> str:
diff --git a/common/events/v1/dataset_operation_event.py b/common/events/v1/dataset_operation_event.py
index 8594752..f5e9ac1 100644
--- a/common/events/v1/dataset_operation_event.py
+++ b/common/events/v1/dataset_operation_event.py
@@ -3,7 +3,6 @@
__all__ = ["DatasetOperationSchema", "DatasetOperationApiSchema", "DatasetOperationEvent", "DatasetOperationType"]
from dataclasses import dataclass
-from typing import Optional
from marshmallow import Schema, ValidationError, validates_schema
from marshmallow.fields import Str
@@ -55,7 +54,7 @@ class DatasetOperationEvent(Event):
__api_schema__ = DatasetOperationApiSchema
operation: str
- path: Optional[str] = None
+ path: str | None = None
def accept(self, handler: EventHandlerBase) -> bool:
return handler.handle_dataset_operation(self)
diff --git a/common/events/v1/event.py b/common/events/v1/event.py
index 47de3a2..2aec201 100644
--- a/common/events/v1/event.py
+++ b/common/events/v1/event.py
@@ -4,8 +4,7 @@
import logging
from dataclasses import InitVar, dataclass, field
-from datetime import datetime, timezone
-from typing import Optional
+from datetime import UTC, datetime
from uuid import UUID, uuid4
from common.decorators import cached_property
@@ -72,35 +71,35 @@ class Event(EventInterface):
this to define your own Event to ensure your events have all the expected fields.
"""
- pipeline_key: Optional[str]
+ pipeline_key: str | None
source: str
event_id: UUID
event_timestamp: datetime
received_timestamp: datetime
metadata: dict[str, object]
event_type: str
- run_name: Optional[str]
- run_key: Optional[str]
- component_tool: Optional[str]
- project_id: Optional[UUID]
- run_id: Optional[UUID]
- pipeline_id: Optional[UUID]
- task_id: Optional[UUID]
- task_name: Optional[str]
- task_key: Optional[str]
- run_task_id: Optional[UUID]
- external_url: Optional[str]
- pipeline_name: Optional[str]
+ run_name: str | None
+ run_key: str | None
+ component_tool: str | None
+ project_id: UUID | None
+ run_id: UUID | None
+ pipeline_id: UUID | None
+ task_id: UUID | None
+ task_name: str | None
+ task_key: str | None
+ run_task_id: UUID | None
+ external_url: str | None
+ pipeline_name: str | None
instances: list[InstanceRef]
- dataset_id: Optional[UUID]
- dataset_key: Optional[str]
- dataset_name: Optional[str]
- server_id: Optional[UUID]
- server_key: Optional[str]
- server_name: Optional[str]
- stream_id: Optional[UUID]
- stream_key: Optional[str]
- stream_name: Optional[str]
+ dataset_id: UUID | None
+ dataset_key: str | None
+ dataset_name: str | None
+ server_id: UUID | None
+ server_key: str | None
+ server_name: str | None
+ stream_id: UUID | None
+ stream_key: str | None
+ stream_name: str | None
payload_keys: list[str]
version: EventVersion
@@ -131,7 +130,7 @@ def as_event_from_request(cls, request_body: dict) -> Event:
# At a glance, we could have these defaulted by the schema. However, the spec says that if timestamp is not
# defined, it _must_ be matched to the received time. Setting it in the schema would generate very tiny
# differences in time.
- current_time = str(datetime.now(timezone.utc))
+ current_time = str(datetime.now(UTC))
if "event_timestamp" not in event_body:
event_body["event_timestamp"] = current_time
event_body["received_timestamp"] = current_time
@@ -215,7 +214,7 @@ def component_journeys(self) -> list[Journey]:
@cached_property
def component_key_details(self) -> EventComponentDetails:
if not (key := next((attr for attr in EVENT_ATTRIBUTES.keys() if getattr(self, attr, None) is not None), None)):
- LOG.error(f"Event component key details cannot be parsed from the event information provided: " f"{self}")
+ LOG.error(f"Event component key details cannot be parsed from the event information provided: {self}")
raise ValueError("Event component key details cannot be parsed.")
return EVENT_ATTRIBUTES[key]
@@ -227,7 +226,7 @@ def component_key(self) -> str:
return key
@property
- def component_id(self) -> Optional[UUID]:
+ def component_id(self) -> UUID | None:
return getattr(self, self.component_key_details.component_id, None)
@component_id.setter
@@ -239,7 +238,7 @@ def component_id(self, value: UUID) -> None:
setattr(self, self.component_key_details.component_id, value)
@property
- def component_name(self) -> Optional[str]:
+ def component_name(self) -> str | None:
return getattr(self, self.component_key_details.component_name, None)
@property
diff --git a/common/events/v1/event_schemas.py b/common/events/v1/event_schemas.py
index 1c63efc..f8ad150 100644
--- a/common/events/v1/event_schemas.py
+++ b/common/events/v1/event_schemas.py
@@ -1,7 +1,7 @@
__all__ = ["EventSchemaInterface", "EventApiSchema", "EventSchema"]
import json
-from datetime import timezone
+from datetime import UTC
from typing import Any, Union
from marshmallow import Schema, ValidationError, post_dump, post_load, pre_load, validates_schema
@@ -176,7 +176,7 @@ class EventApiSchema(EventSchemaInterface):
)
event_timestamp = AwareDateTime(
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
metadata={
"description": (
"Optional. An ISO8601 timestamp that describes when the event occurred. If no timezone "
@@ -251,7 +251,7 @@ class EventSchema(EventApiSchema):
received_timestamp = AwareDateTime(
format="iso",
required=True,
- default_timezone=timezone.utc,
+ default_timezone=UTC,
metadata={"description": "An ISO timestamp that the Event Ingestion API applies when it receives the event."},
)
# This is the source of the message.
diff --git a/common/events/v1/test_outcomes_event.py b/common/events/v1/test_outcomes_event.py
index 37c64be..9bd06ea 100644
--- a/common/events/v1/test_outcomes_event.py
+++ b/common/events/v1/test_outcomes_event.py
@@ -18,11 +18,11 @@
]
from dataclasses import asdict, dataclass
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal as std_decimal
from enum import Enum as std_Enum
from enum import IntEnum as std_IntEnum
-from typing import Any, Optional, Union
+from typing import Any, Union
from uuid import UUID as std_UUID
from marshmallow import Schema, ValidationError, post_load, validates_schema
@@ -100,7 +100,7 @@ class TestgenItem:
test_suite: str
version: int
test_parameters: list[TestgenItemTestParameters]
- columns: Optional[list[str]] = None
+ columns: list[str] | None = None
class TestgenItemSchema(Schema):
@@ -163,8 +163,8 @@ def to_testoutcome_item_integrations(self, data: dict, **_: Any) -> TestOutcomeI
@dataclass
class TestgenTable:
include_list: list[str]
- include_pattern: Optional[str] = None
- exclude_pattern: Optional[str] = None
+ include_pattern: str | None = None
+ exclude_pattern: str | None = None
class TestgenTableSchema(Schema):
@@ -217,9 +217,9 @@ def to_testgen_table(self, data: dict, **_: Any) -> TestgenTable:
class TestgenTableGroupV1:
group_id: std_UUID
project_code: str
- uses_sampling: Optional[bool] = None
- sample_percentage: Optional[str] = None
- sample_minimum_count: Optional[int] = None
+ uses_sampling: bool | None = None
+ sample_percentage: str | None = None
+ sample_minimum_count: int | None = None
class TestgenTableGroupV1Schema(Schema):
@@ -263,8 +263,8 @@ class TestgenDataset:
database_name: str
connection_name: str
tables: TestgenTable
- schema: Optional[str] = None
- table_group_configuration: Optional[TestgenTableGroupV1] = None
+ schema: str | None = None
+ table_group_configuration: TestgenTableGroupV1 | None = None
class TestgenDatasetSchema(Schema):
@@ -346,19 +346,19 @@ class TestOutcomeItem:
name: str
status: str
description: str = ""
- start_time: Optional[datetime] = None
- end_time: Optional[datetime] = None
- metadata: Optional[dict[str, Any]] = None
- metric_value: Optional[Decimal] = None
- metric_name: Optional[str] = None
- metric_description: Optional[str] = None
- min_threshold: Optional[Decimal] = None
- max_threshold: Optional[Decimal] = None
- integrations: Optional[TestOutcomeItemIntegrations] = None
- dimensions: Optional[list[str]] = None
- result: Optional[str] = None
- type: Optional[str] = None
- key: Optional[str] = None
+ start_time: datetime | None = None
+ end_time: datetime | None = None
+ metadata: dict[str, Any] | None = None
+ metric_value: Decimal | None = None
+ metric_name: str | None = None
+ metric_description: str | None = None
+ min_threshold: Decimal | None = None
+ max_threshold: Decimal | None = None
+ integrations: TestOutcomeItemIntegrations | None = None
+ dimensions: list[str] | None = None
+ result: str | None = None
+ type: str | None = None
+ key: str | None = None
# region Schemas
@@ -373,8 +373,7 @@ class TestOutcomeItemSchema(Schema):
enum=TestStatuses,
metadata={
"description": (
- "Required. The test status to be applied. Can set the status for both tests in runs and "
- "tests in tasks."
+ "Required. The test status to be applied. Can set the status for both tests in runs and tests in tasks."
)
},
)
@@ -384,13 +383,13 @@ class TestOutcomeItemSchema(Schema):
)
start_time = AwareDateTime(
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
allow_none=True,
metadata={"description": "An ISO timestamp of when the test execution started."},
)
end_time = AwareDateTime(
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
allow_none=True,
metadata={"description": "An ISO timestamp of when the test execution ended."},
)
@@ -530,7 +529,7 @@ class TestOutcomesEvent(Event):
"""Represents the single result of a test."""
test_outcomes: list[TestOutcomeItem]
- component_integrations: Optional[TestGenTestOutcomeIntegrationComponent] = None
+ component_integrations: TestGenTestOutcomeIntegrationComponent | None = None
__schema__ = TestOutcomesSchema
__api_schema__ = TestOutcomesApiSchema
diff --git a/common/events/v2/base.py b/common/events/v2/base.py
index e983983..4451195 100644
--- a/common/events/v2/base.py
+++ b/common/events/v2/base.py
@@ -7,8 +7,7 @@
from dataclasses import dataclass
-from datetime import datetime, timezone
-from typing import Optional
+from datetime import UTC, datetime
from marshmallow import Schema
from marshmallow.fields import UUID, AwareDateTime, Dict, List, Str, Url
@@ -24,9 +23,9 @@
@dataclass
class BasePayload:
- event_timestamp: Optional[datetime]
+ event_timestamp: datetime | None
metadata: dict[str, object]
- external_url: Optional[str]
+ external_url: str | None
payload_keys: list[str]
@@ -38,7 +37,7 @@ class BasePayloadSchema(Schema):
event_timestamp = AwareDateTime(
load_default=None,
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
metadata={
"description": (
"An ISO8601 timestamp that describes when the event occurred. "
diff --git a/common/events/v2/component_data.py b/common/events/v2/component_data.py
index 51f9f24..e0cbc70 100644
--- a/common/events/v2/component_data.py
+++ b/common/events/v2/component_data.py
@@ -15,7 +15,7 @@
]
from dataclasses import dataclass
-from typing import Any, Optional
+from typing import Any
from marshmallow import Schema, ValidationError, post_load, validates_schema
from marshmallow.fields import Nested, Str
@@ -28,44 +28,44 @@
@dataclass
class NewComponentData:
- name: Optional[str]
- tool: Optional[str]
+ name: str | None
+ tool: str | None
@dataclass
class BatchPipelineData:
batch_key: str
run_key: str
- run_name: Optional[str]
- task_key: Optional[str]
- task_name: Optional[str]
- details: Optional[NewComponentData]
+ run_name: str | None
+ task_key: str | None
+ task_name: str | None
+ details: NewComponentData | None
@dataclass
class DatasetData:
dataset_key: str
- details: Optional[NewComponentData]
+ details: NewComponentData | None
@dataclass
class ServerData:
server_key: str
- details: Optional[NewComponentData]
+ details: NewComponentData | None
@dataclass
class StreamData:
stream_key: str
- details: Optional[NewComponentData]
+ details: NewComponentData | None
@dataclass
class ComponentData:
- batch_pipeline: Optional[BatchPipelineData]
- stream: Optional[StreamData]
- dataset: Optional[DatasetData]
- server: Optional[ServerData]
+ batch_pipeline: BatchPipelineData | None
+ stream: StreamData | None
+ dataset: DatasetData | None
+ server: ServerData | None
class NewComponentDataSchema(Schema):
diff --git a/common/events/v2/dataset_operation.py b/common/events/v2/dataset_operation.py
index 22de9a6..9f9cb0e 100644
--- a/common/events/v2/dataset_operation.py
+++ b/common/events/v2/dataset_operation.py
@@ -7,7 +7,7 @@
from dataclasses import dataclass
from enum import Enum as std_Enum
-from typing import Any, Optional
+from typing import Any
from marshmallow import post_load
from marshmallow.fields import Enum, Nested, Str
@@ -28,7 +28,7 @@ class DatasetOperationType(std_Enum):
class DatasetOperation(BasePayload):
dataset_component: DatasetData
operation: DatasetOperationType
- path: Optional[str]
+ path: str | None
class DatasetOperationSchema(BasePayloadSchema):
diff --git a/common/events/v2/test_outcomes.py b/common/events/v2/test_outcomes.py
index 67c0fd8..b529781 100644
--- a/common/events/v2/test_outcomes.py
+++ b/common/events/v2/test_outcomes.py
@@ -13,10 +13,10 @@
from dataclasses import dataclass
from dataclasses import fields as dc_fields
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from decimal import Decimal as std_Decimal
from enum import Enum as std_Enum
-from typing import Any, Optional
+from typing import Any
from marshmallow import Schema, ValidationError, post_load, validates_schema
from marshmallow.fields import AwareDateTime, Decimal, Dict, Enum, List, Nested, Str
@@ -56,23 +56,23 @@ class TestOutcomeItem:
status: TestStatus
description: str
metadata: dict[str, Any]
- start_time: Optional[datetime]
- end_time: Optional[datetime]
- metric_value: Optional[std_Decimal]
- metric_name: Optional[str]
- metric_description: Optional[str]
- metric_min_threshold: Optional[std_Decimal]
- metric_max_threshold: Optional[std_Decimal]
- integrations: Optional[TestOutcomeItemIntegrations]
- dimensions: Optional[list[str]]
- result: Optional[str]
- type: Optional[str]
- key: Optional[str]
+ start_time: datetime | None
+ end_time: datetime | None
+ metric_value: std_Decimal | None
+ metric_name: str | None
+ metric_description: str | None
+ metric_min_threshold: std_Decimal | None
+ metric_max_threshold: std_Decimal | None
+ integrations: TestOutcomeItemIntegrations | None
+ dimensions: list[str] | None
+ result: str | None
+ type: str | None
+ key: str | None
@dataclass
class TestGenIntegrations:
- integrations: Optional[TestGenTestOutcomeIntegrations]
+ integrations: TestGenTestOutcomeIntegrations | None
@dataclass
@@ -93,10 +93,10 @@ class TestGenStreamData(StreamData, TestGenIntegrations): ...
@dataclass
class TestGenComponentData:
- batch_pipeline: Optional[TestGenBatchPipelineData]
- stream: Optional[TestGenStreamData]
- dataset: Optional[TestGenDatasetData]
- server: Optional[TestGenServerData]
+ batch_pipeline: TestGenBatchPipelineData | None
+ stream: TestGenStreamData | None
+ dataset: TestGenDatasetData | None
+ server: TestGenServerData | None
@dataclass
@@ -125,8 +125,7 @@ class TestOutcomeItemSchema(Schema):
enum=TestStatus,
metadata={
"description": (
- "Required. The test status to be applied. Can set the status for both tests in runs and "
- "tests in tasks."
+ "Required. The test status to be applied. Can set the status for both tests in runs and tests in tasks."
)
},
)
@@ -144,13 +143,13 @@ class TestOutcomeItemSchema(Schema):
)
start_time = AwareDateTime(
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
load_default=None,
metadata={"description": "An ISO timestamp of when the test execution started."},
)
end_time = AwareDateTime(
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
load_default=None,
metadata={"description": "An ISO timestamp of when the test execution ended."},
)
diff --git a/common/events/v2/testgen.py b/common/events/v2/testgen.py
index 7d45c1a..56978c0 100644
--- a/common/events/v2/testgen.py
+++ b/common/events/v2/testgen.py
@@ -14,7 +14,7 @@
from dataclasses import asdict, dataclass
from decimal import Decimal as std_Decimal
from enum import IntEnum
-from typing import Any, Optional, Union
+from typing import Any, Union
from uuid import UUID as std_UUID
from marshmallow import Schema, ValidationError, post_load, validates_schema
@@ -60,7 +60,7 @@ class TestgenItem:
test_suite: str
version: TestgenIntegrationVersions
test_parameters: list[TestgenItemTestParameters]
- columns: Optional[list[str]]
+ columns: list[str] | None
@dataclass
@@ -71,17 +71,17 @@ class TestOutcomeItemIntegrations:
@dataclass
class TestgenTable:
include_list: list[str]
- include_pattern: Optional[str]
- exclude_pattern: Optional[str]
+ include_pattern: str | None
+ exclude_pattern: str | None
@dataclass
class TestgenTableGroupV1:
group_id: std_UUID
project_code: str
- uses_sampling: Optional[bool]
- sample_percentage: Optional[str]
- sample_minimum_count: Optional[int]
+ uses_sampling: bool | None
+ sample_percentage: str | None
+ sample_minimum_count: int | None
@dataclass
@@ -90,8 +90,8 @@ class TestgenDataset:
database_name: str
connection_name: str
tables: TestgenTable
- schema: Optional[str]
- table_group_configuration: Optional[TestgenTableGroupV1]
+ schema: str | None
+ table_group_configuration: TestgenTableGroupV1 | None
@dataclass
diff --git a/common/kafka/consumer.py b/common/kafka/consumer.py
index 402f872..b9372d3 100644
--- a/common/kafka/consumer.py
+++ b/common/kafka/consumer.py
@@ -3,7 +3,7 @@
import logging
import signal
from types import FrameType
-from typing import Any, Optional
+from typing import Any
from collections.abc import Iterator
from confluent_kafka import Consumer, Message
@@ -49,7 +49,7 @@ def init_handlers() -> None:
signal.signal(signal.SIGTERM, GracefulKiller.exit_gracefully)
@staticmethod
- def exit_gracefully(sig_num: int, frame: Optional[FrameType]) -> Any:
+ def exit_gracefully(sig_num: int, frame: FrameType | None) -> Any:
LOG.info(f"Signal {sig_num} received, attempting to exit gracefully. Use SIGKILL to terminate immediately.")
GracefulKiller.should_exit = True
@@ -152,9 +152,9 @@ def commit(self) -> Any:
except Exception as e:
raise ConsumerCommitError from e
- def poll(self) -> Optional[KafkaMessage]:
+ def poll(self) -> KafkaMessage | None:
try:
- msg: Optional[Message] = self.consumer.poll(CONSUMER_POLL_PERIOD_SECS)
+ msg: Message | None = self.consumer.poll(CONSUMER_POLL_PERIOD_SECS)
except DisconnectedConsumerError:
raise
except Exception as ex:
diff --git a/common/kafka/message.py b/common/kafka/message.py
index ed13d51..b6f000b 100644
--- a/common/kafka/message.py
+++ b/common/kafka/message.py
@@ -1,12 +1,10 @@
__all__ = ["KafkaMessage"]
from dataclasses import dataclass
-from typing import Generic, Optional, TypeVar
-T = TypeVar("T")
@dataclass(frozen=True, kw_only=True)
-class KafkaMessage(Generic[T]):
+class KafkaMessage[T]:
"""
A generic Kafka message
@@ -19,4 +19,4 @@ class KafkaMessage(Generic[T]):
partition: int
offset: int
headers: dict
- key: Optional[str] = None
+ key: str | None = None
diff --git a/common/kafka/producer.py b/common/kafka/producer.py
index e272f74..72e82cf 100644
--- a/common/kafka/producer.py
+++ b/common/kafka/producer.py
@@ -6,7 +6,7 @@
import uuid
from contextlib import contextmanager
from types import TracebackType
-from typing import Any, Optional
+from typing import Any
from collections.abc import Callable, Generator
from confluent_kafka import KafkaError, KafkaException, Message, Producer
@@ -88,7 +88,7 @@ def __enter__(self) -> KafkaProducer:
return self
def __exit__(
- self, exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], tb: Optional[TracebackType]
+ self, exc_type: type[BaseException] | None, exc_value: BaseException | None, tb: TracebackType | None
) -> None:
self.disconnect()
@@ -105,7 +105,7 @@ def is_topic_available(self, topic: Topic, timeout: int = KAFKA_OP_TIMEOUT_SECS)
metadata = self.producer.list_topics(topic=topic.name, timeout=timeout)
return len(metadata.topics[topic.name].partitions) > 0
- def produce(self, topic: Topic, event: Any, callback: Optional[Callable] = None) -> None:
+ def produce(self, topic: Topic, event: Any, callback: Callable | None = None) -> None:
def delivery_report(err: KafkaError, message: Message) -> None:
"""Called once for each message produced to indicate delivery result. Triggered by poll() or flush()."""
if err is not None:
@@ -149,7 +149,7 @@ class KafkaTransactionalProducer(KafkaProducer):
"""
- def __init__(self, config: dict, tx_consumer: Optional[KafkaTransactionalConsumer] = None) -> None:
+ def __init__(self, config: dict, tx_consumer: KafkaTransactionalConsumer | None = None) -> None:
if PRODUCER_TX_MANDATORY_SETTINGS.keys() & config.keys():
raise ProducerConfigurationError(
f"Local configuration cannot override any of {PRODUCER_TX_MANDATORY_SETTINGS.keys()}"
diff --git a/common/kafka/topic.py b/common/kafka/topic.py
index 1988c36..4bd1b69 100644
--- a/common/kafka/topic.py
+++ b/common/kafka/topic.py
@@ -9,7 +9,7 @@
import json
from dataclasses import dataclass
-from typing import Any, NamedTuple, Optional, Protocol
+from typing import Any, NamedTuple, Protocol
from confluent_kafka import Message
@@ -29,7 +29,7 @@ class ProduceMessageArgs(NamedTuple):
value: bytes
topic: str
headers: dict[str, str]
- key: Optional[str] = None
+ key: str | None = None
def as_dict(self) -> dict[str, Any]:
d = {
@@ -44,7 +44,7 @@ def as_dict(self) -> dict[str, Any]:
return d
-def _get_headers_as_dict(headers: Optional[list[tuple[str, bytes]]]) -> dict[str, str]:
+def _get_headers_as_dict(headers: list[tuple[str, bytes]] | None) -> dict[str, str]:
return {k: v.decode("utf-8") for k, v in headers or {}}
diff --git a/common/logging/json_logging.py b/common/logging/json_logging.py
index a912df0..1a4b6d7 100644
--- a/common/logging/json_logging.py
+++ b/common/logging/json_logging.py
@@ -1,12 +1,11 @@
__all__ = ["JsonFormatter"]
import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from json import dumps
from common.json_encoder import JsonExtendedEncoder
-UTC = timezone.utc
class JsonFormatter(logging.Formatter):
diff --git a/common/messagepack.py b/common/messagepack.py
index 7005b8c..c77080e 100644
--- a/common/messagepack.py
+++ b/common/messagepack.py
@@ -9,7 +9,7 @@
from io import BytesIO
from pathlib import Path, PurePath, PurePosixPath, PureWindowsPath
from types import MappingProxyType
-from typing import Any, BinaryIO, Optional, TextIO, Union, cast
+from typing import Any, BinaryIO, TextIO, Union, cast
from collections.abc import Callable
from uuid import UUID
@@ -261,7 +261,7 @@ def decode_ext(code: int, data: bytes) -> object:
return ExtType(code, data)
-def dump(value: object, flo: FLO, hook: Optional[Callable] = None) -> None:
+def dump(value: object, flo: FLO, hook: Callable | None = None) -> None:
"""
Serialize as msgpack and write the result to a file-like-object.
@@ -281,7 +281,7 @@ def dump(value: object, flo: FLO, hook: Optional[Callable] = None) -> None:
)
-def dumps(value: object, hook: Optional[Callable] = None) -> bytes:
+def dumps(value: object, hook: Callable | None = None) -> bytes:
"""
Serialize object to msgpack and return resulting messagepack bytes.
@@ -295,7 +295,7 @@ def dumps(value: object, hook: Optional[Callable] = None) -> bytes:
return result
-def load(flo: FLO, object_hook: Optional[Callable] = None) -> Any:
+def load(flo: FLO, object_hook: Callable | None = None) -> Any:
"""
Deserialize a msgpack file-like-object.
@@ -317,7 +317,7 @@ def load(flo: FLO, object_hook: Optional[Callable] = None) -> Any:
return result
-def loads(stream: bytes, object_hook: Optional[Callable] = None) -> Any:
+def loads(stream: bytes, object_hook: Callable | None = None) -> Any:
"""
Deserialize msgpack bytes
diff --git a/common/model.py b/common/model.py
index d9009fa..9698289 100644
--- a/common/model.py
+++ b/common/model.py
@@ -7,7 +7,6 @@
from contextlib import suppress
from importlib import import_module
from types import ModuleType
-from typing import Optional
from peewee import Database, Model, SchemaManager
@@ -17,7 +16,7 @@
LOG = logging.getLogger(__name__)
-def walk(module: Optional[ModuleType] = None) -> dict[str, Model]:
+def walk(module: ModuleType | None = None) -> dict[str, Model]:
"""
Recursively scans a module for all PeeWee model classes. Defaults to `common.entities` but can scan any module.
diff --git a/common/peewee_extensions/fields.py b/common/peewee_extensions/fields.py
index e229af8..d26d203 100644
--- a/common/peewee_extensions/fields.py
+++ b/common/peewee_extensions/fields.py
@@ -3,11 +3,11 @@
import logging
import re
import socket
-from datetime import datetime, timezone
+from datetime import UTC, datetime
from enum import Enum
from json import dumps as json_dumps
from json import loads as json_loads
-from typing import Any, Optional, Union, cast
+from typing import Any, Union, cast
from re import Pattern
from unicodedata import normalize
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
@@ -28,7 +28,7 @@ def __init__(self, null: bool = False, defaults_to_now: bool = False, **kwargs:
if defaults_to_now:
# This is a convenient method to ensure newly created objects will have a timezone-aware value. It doesn't
# affect what is stored in the database.
- kwargs["default"] = lambda: datetime.utcnow().replace(tzinfo=timezone.utc)
+ kwargs["default"] = lambda: datetime.now(UTC)
else:
kwargs["default"] = None
@@ -36,9 +36,9 @@ def __init__(self, null: bool = False, defaults_to_now: bool = False, **kwargs:
# we put in.
super().__init__(null=null, utc=True, resolution=1_000_000, **kwargs)
- def python_value(self, value: Union[int, float]) -> Optional[datetime]:
+ def python_value(self, value: Union[int, float]) -> datetime | None:
if isinstance(ret_val := super().python_value(value), datetime):
- return ret_val.replace(tzinfo=timezone.utc)
+ return ret_val.replace(tzinfo=UTC)
else:
return None
@@ -90,7 +90,7 @@ def db_value(self, value: str) -> str:
return db_value
-def _enum_value_to_db_value(enum_class: type[Enum], value: Union[str, Enum, None]) -> Optional[str | int]:
+def _enum_value_to_db_value(enum_class: type[Enum], value: Union[str, Enum, None]) -> str | int | None:
"""Converts a value before sending it to the DB."""
if value is None:
return None
@@ -108,7 +108,7 @@ def _enum_value_to_db_value(enum_class: type[Enum], value: Union[str, Enum, None
return value
-def _db_value_to_enum_value(enum_class: type[Enum], value: str | int) -> Optional[Enum]:
+def _db_value_to_enum_value(enum_class: type[Enum], value: str | int) -> Enum | None:
if value:
try:
return enum_class(value)
@@ -125,12 +125,12 @@ def __init__(self, enum_class: type[Enum], **kwargs: Any) -> None:
self.enum_class = enum_class
super().__init__(**kwargs)
- def db_value(self, value: Union[str, Enum, None]) -> Optional[str]:
+ def db_value(self, value: Union[str, Enum, None]) -> str | None:
"""Converts a value before sending it to the DB."""
db_value: str = super().db_value(_enum_value_to_db_value(self.enum_class, value))
return db_value
- def python_value(self, value: str) -> Optional[Enum]:
+ def python_value(self, value: str) -> Enum | None:
return _db_value_to_enum_value(self.enum_class, super().python_value(value))
@@ -141,12 +141,12 @@ def __init__(self, enum_class: type[Enum], **kwargs: Any) -> None:
self.enum_class = enum_class
super().__init__(**kwargs)
- def db_value(self, value: Union[str, Enum, None]) -> Optional[str]:
+ def db_value(self, value: Union[str, Enum, None]) -> str | None:
"""Converts a value before sending it to the DB."""
db_value: str = super().db_value(_enum_value_to_db_value(self.enum_class, value))
return db_value
- def python_value(self, value: str) -> Optional[Enum]:
+ def python_value(self, value: str) -> Enum | None:
return _db_value_to_enum_value(self.enum_class, super().python_value(value))
@@ -174,7 +174,7 @@ class JSONStrListField(TextField):
def __init__(self, **kwargs: Any) -> None:
"""Ensure that `default` is always the list function or a function that returns a list."""
- if (default_func := kwargs.get("default", None)) not in (None, list):
+ if (default_func := kwargs.get("default", None)) and callable(default_func):
try:
_result = default_func()
except Exception as e:
@@ -187,7 +187,7 @@ def __init__(self, **kwargs: Any) -> None:
kwargs["default"] = list
super().__init__(**kwargs)
- def db_value(self, value: Optional[list[str]]) -> Optional[str]:
+ def db_value(self, value: list[str] | None) -> str | None:
"""Dump a list of strings as a JSON string. Keeps key order consistent."""
if value is not None:
if not isinstance(value, list):
@@ -199,7 +199,7 @@ def db_value(self, value: Optional[list[str]]) -> Optional[str]:
else:
return None
- def python_value(self, value: Optional[str]) -> Optional[list[str]]:
+ def python_value(self, value: str | None) -> list[str] | None:
"""Load the text retrieved from the JSON field into a list."""
if value is not None:
try:
@@ -223,7 +223,7 @@ class JSONDictListField(TextField):
def __init__(self, **kwargs: Any) -> None:
"""Ensure that `default` is always the list function or a function that returns a list."""
- if (default_func := kwargs.get("default", None)) not in (None, list):
+ if (default_func := kwargs.get("default", None)) and callable(default_func):
try:
_result = default_func()
except Exception as e:
@@ -236,7 +236,7 @@ def __init__(self, **kwargs: Any) -> None:
kwargs["default"] = list
super().__init__(**kwargs)
- def db_value(self, value: Optional[list[dict]]) -> Optional[str]:
+ def db_value(self, value: list[dict] | None) -> str | None:
"""Dump a list of strings as a JSON string. Keeps key order consistent."""
if value is not None:
if not isinstance(value, list):
@@ -248,7 +248,7 @@ def db_value(self, value: Optional[list[dict]]) -> Optional[str]:
else:
return None
- def python_value(self, value: Optional[str]) -> Optional[list[dict]]:
+ def python_value(self, value: str | None) -> list[dict] | None:
"""Load the text retrieved from the JSON field into a list."""
if value is not None:
try:
@@ -272,7 +272,7 @@ def __init__(self, schema: Schema, **kwargs: Any) -> None:
self.schema = schema
super().__init__(**kwargs)
- def db_value(self, value: Any) -> Optional[str]:
+ def db_value(self, value: Any) -> str | None:
if value is None:
json_data = "null"
else:
@@ -282,7 +282,7 @@ def db_value(self, value: Any) -> Optional[str]:
raise ValueError(f"The value in '{self.name}' can not be dumped by {self.schema}.") from e
return cast(str, super().db_value(json_data))
- def python_value(self, value: Optional[str]) -> Optional[Any]:
+ def python_value(self, value: str | None) -> Any | None:
if value is None or value == "null":
return None
else:
diff --git a/common/peewee_extensions/fixtures.py b/common/peewee_extensions/fixtures.py
index 18dc14e..26a440d 100644
--- a/common/peewee_extensions/fixtures.py
+++ b/common/peewee_extensions/fixtures.py
@@ -4,7 +4,7 @@
from functools import cache
from graphlib import CycleError, TopologicalSorter
from pathlib import Path
-from typing import Any, Optional
+from typing import Any
from uuid import UUID, uuid4
# TODO: When we move to Python 3.11+ we can switch to importing tomlib and we can remove the tomli requirement from
@@ -57,7 +57,7 @@ def generate_table_map() -> dict[str, Model]:
return {x._meta.table_name: x for x in model_map.values()}
-def dump_results(results: ModelSelect, *, requires_id: Optional[UUID] = None) -> str:
+def dump_results(results: ModelSelect, *, requires_id: UUID | None = None) -> str:
"""Given Peewee query results, generate a fixture dump."""
model_class = results.model
rows = []
diff --git a/common/predicate_engine/compilers/utils.py b/common/predicate_engine/compilers/utils.py
index 7819ed3..f7e9225 100644
--- a/common/predicate_engine/compilers/utils.py
+++ b/common/predicate_engine/compilers/utils.py
@@ -17,4 +17,4 @@ def _prefetch_query(query: SelectQuery, *, _queries: list) -> list[Model]:
result: list[Model] = query.prefetch(*_queries, prefetch_type=PREFETCH_TYPE.JOIN)
return result
- return partial(_prefetch_query, _queries=args)
+ return partial(_prefetch_query, _queries=list(args))
diff --git a/common/predicate_engine/query.py b/common/predicate_engine/query.py
index acd1fcd..a59ae50 100644
--- a/common/predicate_engine/query.py
+++ b/common/predicate_engine/query.py
@@ -6,10 +6,10 @@
from abc import ABC, abstractmethod
from collections.abc import Mapping
from copy import copy, deepcopy
-from datetime import datetime, timezone
+from datetime import datetime, UTC
from enum import Enum
from functools import partial, reduce
-from typing import Any, Final, Optional, Union
+from typing import Any, Final, Union
from collections.abc import Callable, Iterable
from ._operators import QueryOperator, get_operator, get_operators, split_operators
@@ -33,10 +33,10 @@ def _ensure_utc(dt: datetime) -> datetime:
'2006-11-06T10:10:10+00:00'
"""
if (tzinfo := dt.tzinfo) is None:
- return dt.replace(tzinfo=timezone.utc)
+ return dt.replace(tzinfo=UTC)
if tzinfo.utcoffset(dt) is None:
- return dt.replace(tzinfo=timezone.utc)
- return dt.astimezone(timezone.utc)
+ return dt.replace(tzinfo=UTC)
+ return dt.astimezone(UTC)
def getattr_recursive(lookup_obj: Any, attr_name: str, *args: Any) -> Any:
@@ -58,8 +58,8 @@ def _getattr(obj: Any, attr: str) -> Any:
class ConnectorType(Enum):
- OR: str = "OR"
- AND: str = "AND"
+ OR = "OR"
+ AND = "AND"
class _Encapsulate(ABC):
@@ -76,8 +76,8 @@ def __init__(
self,
value: object,
*,
- attr_name: Optional[str] = None,
- transform_funcs: Optional[Iterable[Callable[..., Iterable]]] = None,
+ attr_name: str | None = None,
+ transform_funcs: Iterable[Callable[..., Iterable]] | None = None,
) -> None:
self.wrapped_value = value
self.attr_name = attr_name
@@ -175,8 +175,8 @@ def __init__(
value: object,
*,
n: int,
- attr_name: Optional[str] = None,
- transform_funcs: Optional[Iterable[Callable[..., Iterable]]] = None,
+ attr_name: str | None = None,
+ transform_funcs: Iterable[Callable[..., Iterable]] | None = None,
) -> None:
self.count = n
super().__init__(value, attr_name=attr_name, transform_funcs=transform_funcs)
@@ -201,7 +201,7 @@ class ATLEAST(_Encapsulate):
"""
Encapsulate a value to indicate an ATLEAST predicate operation on an iterable.
- The match should be successful if ATLEAST the N first valures are matching.
+ The match should be successful if ATLEAST the N first values are matching.
"""
__slots__ = ("count",)
@@ -214,8 +214,8 @@ def __init__(
value: object,
*,
n: int,
- attr_name: Optional[str] = None,
- transform_funcs: Optional[Iterable[Callable[..., Iterable]]] = None,
+ attr_name: str | None = None,
+ transform_funcs: Iterable[Callable[..., Iterable]] | None = None,
) -> None:
self.count = n
super().__init__(value, attr_name=attr_name, transform_funcs=transform_funcs)
@@ -311,7 +311,7 @@ class R:
Encapsulate rules as objects that can then be combined using & and |
This is an implementation of a tree node for making expressions which can be used to construct rules of arbitrary
- complexity. It is loosely inspired by the Qobject implementation in Django but is object agnostic and not meant for
+ complexity. It is loosely inspired by the Q-object implementation in Django but is object agnostic and not meant for
an ORM.
"""
@@ -329,7 +329,7 @@ def __init__(self, **kwargs: object) -> None:
@classmethod
def _new_instance(
- cls, children: Optional[list] = None, conn_type: ConnectorType = ConnectorType.AND, negated: bool = False
+ cls, children: list | None = None, conn_type: ConnectorType = ConnectorType.AND, negated: bool = False
) -> R:
"""
Creates a new instance of this class.
diff --git a/common/schemas/fields/cron_expr_str.py b/common/schemas/fields/cron_expr_str.py
index e365bb1..586476b 100644
--- a/common/schemas/fields/cron_expr_str.py
+++ b/common/schemas/fields/cron_expr_str.py
@@ -1,6 +1,6 @@
__all__ = ["CronExpressionStr"]
-from typing import Any, Optional
+from typing import Any
from collections.abc import Mapping
from marshmallow import ValidationError
@@ -16,7 +16,7 @@ class CronExpressionStr(Str):
It validates against what ApScheduler's CronTrigger expects.
"""
- def _deserialize(self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: object) -> Any:
+ def _deserialize(self, value: Any, attr: str | None, data: Mapping[str, Any] | None, **kwargs: object) -> Any:
str_value = super(Str, self)._deserialize(value, attr, data, **kwargs)
if errors := validate_cron_expression(str_value):
raise ValidationError(" ".join(errors))
diff --git a/common/schemas/fields/enum_str.py b/common/schemas/fields/enum_str.py
index 87fa98f..7dc52ba 100644
--- a/common/schemas/fields/enum_str.py
+++ b/common/schemas/fields/enum_str.py
@@ -1,7 +1,7 @@
__all__ = ["EnumStr"]
from enum import Enum, EnumMeta
-from typing import Any, Optional, Union, cast
+from typing import Any, Union, cast
from collections.abc import Iterable
from marshmallow.utils import ensure_text_type
@@ -30,7 +30,7 @@ def __init__(self, enum: Union[EnumMeta, list], **kwargs: object) -> None:
super().__init__(validate=OneOf(allowed_values), **kwargs) # type: ignore[arg-type]
- def _serialize(self, value: Any, attr: Optional[str], obj: Any, **kwargs: object) -> Optional[str]:
+ def _serialize(self, value: Any, attr: str | None, obj: Any, **kwargs: object) -> str | None:
if value is None:
return None
if isinstance(value, Enum):
diff --git a/common/schemas/fields/normalized_str.py b/common/schemas/fields/normalized_str.py
index 4099162..48c3867 100644
--- a/common/schemas/fields/normalized_str.py
+++ b/common/schemas/fields/normalized_str.py
@@ -1,6 +1,6 @@
__all__ = ["NormalizedStr", "strip_upper_underscore"]
-from typing import Any, Optional
+from typing import Any
from collections.abc import Callable, Mapping
from marshmallow.fields import Str
@@ -23,13 +23,13 @@ def __init__(self, normalizer: Callable[[str], str] = str.upper, **kwargs: Any):
self.normalizer_func = normalizer
super().__init__(**kwargs)
- def _serialize(self, value: Any, attr: Optional[str], obj: Any, **kwargs: object) -> Optional[str]:
+ def _serialize(self, value: Any, attr: str | None, obj: Any, **kwargs: object) -> str | None:
if value is None:
return None
str_field = ensure_text_type(value)
return self.normalizer_func(str_field)
- def _deserialize(self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: object) -> Any:
+ def _deserialize(self, value: Any, attr: str | None, data: Mapping[str, Any] | None, **kwargs: object) -> Any:
if not isinstance(value, str | bytes):
raise self.make_error("invalid")
try:
diff --git a/common/schemas/fields/zoneinfo.py b/common/schemas/fields/zoneinfo.py
index e03b2da..1de173e 100644
--- a/common/schemas/fields/zoneinfo.py
+++ b/common/schemas/fields/zoneinfo.py
@@ -1,7 +1,7 @@
__all__ = ["ZoneInfo"]
import zoneinfo
-from typing import Any, Optional
+from typing import Any
from collections.abc import Mapping
from marshmallow.fields import Str
@@ -16,7 +16,7 @@ class ZoneInfo(Str):
default_error_messages = {"invalid_timezone": INVALID_TIMEZONE}
- def _deserialize(self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: object) -> Any:
+ def _deserialize(self, value: Any, attr: str | None, data: Mapping[str, Any] | None, **kwargs: object) -> Any:
str_value = super(Str, self)._deserialize(value, attr, data, **kwargs)
# Given ZoneInfo accepts a filesystem type as its constructor argument and we don't want to accept paths as
# values for ZoneInfo fields, we validate the input before trying to build the ZoneInfo object
diff --git a/common/schemas/filter_schemas.py b/common/schemas/filter_schemas.py
index fd25acd..d2a0d4e 100644
--- a/common/schemas/filter_schemas.py
+++ b/common/schemas/filter_schemas.py
@@ -1,5 +1,5 @@
-from datetime import datetime, timezone
-from typing import Any, Optional, Union
+from datetime import datetime, UTC
+from typing import Any, Union
from marshmallow import EXCLUDE, Schema, ValidationError, pre_load, validates_schema
from marshmallow.fields import UUID, AwareDateTime, Boolean, Enum, List, Str
@@ -11,7 +11,7 @@
class FiltersSchema(Schema):
def validate_time_range(
- self, range_begin: Optional[datetime], range_end: Optional[datetime], range_begin_name: str
+ self, range_begin: datetime | None, range_end: datetime | None, range_begin_name: str
) -> None:
if range_begin is None or range_end is None:
return None
@@ -43,7 +43,7 @@ class InstanceFiltersSchema(FiltersSchema):
start_range_begin = AwareDateTime(
allow_none=True,
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
metadata={
"description": "Optional. An ISO8601 datetime. If specified, The result will only include instances with a "
"start_time field equal or past the given datetime. May be specified with start_range_end "
@@ -53,7 +53,7 @@ class InstanceFiltersSchema(FiltersSchema):
start_range_end = AwareDateTime(
allow_none=True,
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
metadata={
"description": "Optional. An ISO8601 datetime. If specified, the result will only contain instances with a "
"start_time field before the given datetime. May be specified with start_range_begin to create "
@@ -63,7 +63,7 @@ class InstanceFiltersSchema(FiltersSchema):
end_range_begin = AwareDateTime(
allow_none=True,
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
metadata={
"description": "Optional. An ISO8601 datetime. If specified, The result will only include instances with an "
"end_time field equal or past the given datetime. May be specified with end_range_end to create a range."
@@ -72,7 +72,7 @@ class InstanceFiltersSchema(FiltersSchema):
end_range_end = AwareDateTime(
allow_none=True,
format="iso",
- default_timezone=timezone.utc,
+ default_timezone=UTC,
metadata={
"description": "Optional. An ISO8601 datetime. If specified, The result will only include instances with an "
"end_time field before the given datetime. May be specified with end_range_begin to create a range."
diff --git a/common/schemas/validators/regexp.py b/common/schemas/validators/regexp.py
index 87af20d..81631d3 100644
--- a/common/schemas/validators/regexp.py
+++ b/common/schemas/validators/regexp.py
@@ -1,7 +1,6 @@
__all__ = ["IsRegexp"]
import re
-from typing import Optional
from marshmallow import ValidationError
from marshmallow.validate import Validator
@@ -14,7 +13,7 @@ class IsRegexp(Validator):
message_invalid = "Invalid regular expression"
- def __init__(self, *, error: Optional[str] = None) -> None:
+ def __init__(self, *, error: str | None = None) -> None:
self.error: str = error or self.message_invalid
def __call__(self, value: str) -> str:
diff --git a/common/tests/integration/entities/test_alerts.py b/common/tests/integration/entities/test_alerts.py
index cdbda7d..d4afe3c 100644
--- a/common/tests/integration/entities/test_alerts.py
+++ b/common/tests/integration/entities/test_alerts.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
import pytest
@@ -94,7 +94,7 @@ def test_run_alert_expected_start_time_get(pipeline_run):
run_alert = RunAlert(run=pipeline_run, name="A", description="A", level=AlertLevel.ERROR, type="invalid")
assert run_alert.expected_start_time is None
run_alert.details["expected_start_time"] = 1123922544.0
- assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) == run_alert.expected_start_time
+ assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) == run_alert.expected_start_time
@pytest.mark.integration
@@ -117,7 +117,7 @@ def test_run_alert_expected_end_time_get(pipeline_run):
run_alert = RunAlert(run=pipeline_run, name="A", description="A", level=AlertLevel.ERROR, type="invalid")
assert run_alert.expected_end_time is None # No details have been added yet
run_alert.details["expected_end_time"] = 1123922544.0
- assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) == run_alert.expected_end_time
+ assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) == run_alert.expected_end_time
@pytest.mark.integration
@@ -134,8 +134,8 @@ def test_run_alert_expected_times_naive(pipeline_run):
run_alert.expected_start_time = datetime(2005, 8, 13, 8, 42, 24)
run_alert.expected_end_time = datetime(2005, 8, 13, 8, 55, 24)
- assert run_alert.expected_start_time == datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc)
- assert run_alert.expected_end_time == datetime(2005, 8, 13, 8, 55, 24, tzinfo=timezone.utc)
+ assert run_alert.expected_start_time == datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC)
+ assert run_alert.expected_end_time == datetime(2005, 8, 13, 8, 55, 24, tzinfo=UTC)
@pytest.mark.integration
@@ -153,7 +153,7 @@ def test_instance_alert_expected_start_time_get(instance):
instance=instance, name="A", description="A", message="A", level=AlertLevel.WARNING, type="invalid"
)
instance_alert.details["expected_start_time"] = 1123922544.0
- assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) == instance_alert.expected_start_time
+ assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) == instance_alert.expected_start_time
@pytest.mark.integration
@@ -171,7 +171,7 @@ def test_instance_alert_expected_end_time_get(instance):
instance=instance, name="A", description="A", message="A", level=AlertLevel.WARNING, type="invalid"
)
instance_alert.details["expected_end_time"] = 1123922544.0
- assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) == instance_alert.expected_end_time
+ assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) == instance_alert.expected_end_time
@pytest.mark.integration
@@ -183,8 +183,8 @@ def test_instance_alert_expected_times_naive(instance):
instance_alert.expected_start_time = datetime(2005, 8, 13, 8, 42, 24)
instance_alert.expected_end_time = datetime(2005, 8, 13, 8, 55, 24)
- assert instance_alert.expected_start_time == datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc)
- assert instance_alert.expected_end_time == datetime(2005, 8, 13, 8, 55, 24, tzinfo=timezone.utc)
+ assert instance_alert.expected_start_time == datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC)
+ assert instance_alert.expected_end_time == datetime(2005, 8, 13, 8, 55, 24, tzinfo=UTC)
@pytest.mark.integration
diff --git a/common/tests/integration/entities/test_runs.py b/common/tests/integration/entities/test_runs.py
index 8789b63..519a10b 100644
--- a/common/tests/integration/entities/test_runs.py
+++ b/common/tests/integration/entities/test_runs.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
import pytest
from peewee import IntegrityError
@@ -29,7 +29,7 @@ def test_add_pipeline_run_listening(pipeline):
run.save()
# After adding non-RUNNING status, listening should be False
- run.end_time = datetime.utcnow().replace(tzinfo=timezone.utc) + timedelta(days=3)
+ run.end_time = datetime.now(UTC) + timedelta(days=3)
run.status = "COMPLETED"
run.save()
diff --git a/common/tests/integration/entity_services/conftest.py b/common/tests/integration/entity_services/conftest.py
index 342206f..d101b1b 100644
--- a/common/tests/integration/entity_services/conftest.py
+++ b/common/tests/integration/entity_services/conftest.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
import pytest
@@ -26,7 +26,7 @@ def pipeline_4(test_db, project):
@pytest.fixture
def current_time() -> datetime:
- yield datetime.now(timezone.utc)
+ yield datetime.now(UTC)
@pytest.fixture()
@@ -61,7 +61,7 @@ def test_outcomes_instance(test_db, run, pipeline, instance_instance_set):
@pytest.fixture
def test_outcomes_event(project, pipeline, run, event_data):
- timestamp = datetime.now(timezone.utc).isoformat()
+ timestamp = datetime.now(UTC).isoformat()
data = {
"event_type": TestOutcomesEvent.__name__,
"test_outcomes": [
diff --git a/common/tests/integration/entity_services/test_project_service.py b/common/tests/integration/entity_services/test_project_service.py
index f0efc22..aa038ac 100644
--- a/common/tests/integration/entity_services/test_project_service.py
+++ b/common/tests/integration/entity_services/test_project_service.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from typing import Optional
from uuid import uuid4
@@ -47,7 +47,7 @@ def local_test_db(test_db):
def _add_runs(
- pipeline, instance, number_of_runs: int, current_time: datetime, *, expected_start_time: Optional[datetime] = None
+ pipeline, instance, number_of_runs: int, current_time: datetime, *, expected_start_time: datetime | None = None
):
instance_set = InstanceSet.get_or_create([instance.id])
for key in range(1, number_of_runs + 1):
@@ -158,19 +158,19 @@ def test_get_runs_with_rules_coalesce_sort(pipeline, instance, patched_instance_
err_level = AlertLevel["ERROR"].value
late_type = RunAlertType["LATE_END"].value
- r1_expected_start = datetime(2023, 5, 25, 7, 44, 1, tzinfo=timezone.utc)
+ r1_expected_start = datetime(2023, 5, 25, 7, 44, 1, tzinfo=UTC)
r1 = Run.create(
key="coalesce-run-1",
pipeline=pipeline,
instance_set=instance_set,
expected_start_time=r1_expected_start,
- start_time=datetime(2023, 5, 25, 7, 44, 6, tzinfo=timezone.utc),
- end_time=datetime(2023, 5, 25, 7, 45, 6, tzinfo=timezone.utc),
+ start_time=datetime(2023, 5, 25, 7, 44, 6, tzinfo=UTC),
+ end_time=datetime(2023, 5, 25, 7, 45, 6, tzinfo=UTC),
status=RunStatus.COMPLETED.name,
)
RunAlert.create(name="CA1", description="CD1", level=err_level, type=late_type, run=r1)
- r2_expected_start = datetime(2023, 5, 25, 7, 45, 10, tzinfo=timezone.utc)
+ r2_expected_start = datetime(2023, 5, 25, 7, 45, 10, tzinfo=UTC)
r2 = Run.create(
pipeline=pipeline,
instance_set=instance_set,
@@ -181,14 +181,14 @@ def test_get_runs_with_rules_coalesce_sort(pipeline, instance, patched_instance_
)
RunAlert.create(name="CA2", description="CD2", level=err_level, type=late_type, run=r2)
- r3_expected_start = datetime(2023, 5, 25, 7, 46, 1, tzinfo=timezone.utc)
+ r3_expected_start = datetime(2023, 5, 25, 7, 46, 1, tzinfo=UTC)
r3 = Run.create(
key="coalesce-run-3",
pipeline=pipeline,
instance_set=instance_set,
expected_start_time=r3_expected_start,
- start_time=datetime(2023, 5, 25, 7, 46, 22, tzinfo=timezone.utc),
- end_time=datetime(2023, 5, 25, 7, 49, 12, tzinfo=timezone.utc),
+ start_time=datetime(2023, 5, 25, 7, 46, 22, tzinfo=UTC),
+ end_time=datetime(2023, 5, 25, 7, 49, 12, tzinfo=UTC),
status=RunStatus.COMPLETED_WITH_WARNINGS.name,
)
RunAlert.create(name="CA3", description="CD3", level=err_level, type=late_type, run=r3)
diff --git a/common/tests/integration/entity_services/test_upcoming_instance_services.py b/common/tests/integration/entity_services/test_upcoming_instance_services.py
index de26d7a..5fb0a3e 100644
--- a/common/tests/integration/entity_services/test_upcoming_instance_services.py
+++ b/common/tests/integration/entity_services/test_upcoming_instance_services.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from uuid import uuid4
import pytest
@@ -282,7 +282,7 @@ def test_get_upcoming_instances_with_rules_discard_matching_existing_instance(
instance_rule_end,
current_time,
):
- base_time = datetime(2023, 8, 21, 10, 0, 0, tzinfo=timezone.utc)
+ base_time = datetime(2023, 8, 21, 10, 0, 0, tzinfo=UTC)
instance.start_time = base_time
instance.save()
instance_rule_end.expression = "30 * * * *"
@@ -307,7 +307,7 @@ def test_get_upcoming_instances_with_rules_do_not_discard_upcoming_instance(
instance_rule_end,
current_time,
):
- base_time = datetime(2023, 8, 21, 10, 0, 0, tzinfo=timezone.utc)
+ base_time = datetime(2023, 8, 21, 10, 0, 0, tzinfo=UTC)
instance.start_time = base_time
instance.save()
instance_rule_start.expression = "5 * * * *"
diff --git a/common/tests/integration/test_apscheduler_extensions.py b/common/tests/integration/test_apscheduler_extensions.py
index 7b9d463..248ba33 100644
--- a/common/tests/integration/test_apscheduler_extensions.py
+++ b/common/tests/integration/test_apscheduler_extensions.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from itertools import count
import pytest
@@ -12,7 +12,7 @@
def calculate_fire_time_sequence(trigger, n=10):
prev_fire_time = None
- now = datetime.now(tz=getattr(trigger, "timezone", timezone.utc))
+ now = datetime.now(tz=getattr(trigger, "timezone", UTC))
for _ in range(n):
next_fire_time = trigger.get_next_fire_time(prev_fire_time, now)
yield next_fire_time
@@ -24,21 +24,21 @@ def calculate_fire_time_sequence(trigger, n=10):
@pytest.mark.parametrize(
"trigger",
(
- CronTrigger.from_crontab("*/2 * * * *", timezone=timezone.utc),
- CronTrigger.from_crontab("*/4 * * * *", timezone=timezone.utc),
+ CronTrigger.from_crontab("*/2 * * * *", timezone=UTC),
+ CronTrigger.from_crontab("*/4 * * * *", timezone=UTC),
CronTrigger(
year="*",
month="*",
day="*",
hour="*",
minute="*/4",
- timezone=timezone.utc,
- start_date=datetime.now(tz=timezone.utc) + timedelta(days=5),
+ timezone=UTC,
+ start_date=datetime.now(tz=UTC) + timedelta(days=5),
),
CronTrigger.from_crontab("*/4 * * * *", timezone=astimezone("Asia/Tokyo")),
- IntervalTrigger(minutes=1, timezone=timezone.utc),
- IntervalTrigger(minutes=6, timezone=timezone.utc),
- DateTrigger(timezone=timezone.utc),
+ IntervalTrigger(minutes=1, timezone=UTC),
+ IntervalTrigger(minutes=6, timezone=UTC),
+ DateTrigger(timezone=UTC),
),
ids=(
"cron_smaller_interval",
@@ -73,7 +73,7 @@ def test_delayed_trigger_3_min_delay(trigger):
),
)
def test_delayed_trigger(delay):
- cron_trigger = CronTrigger.from_crontab("*/2 * * * *", timezone=timezone.utc)
+ cron_trigger = CronTrigger.from_crontab("*/2 * * * *", timezone=UTC)
delayed_trigger = DelayedTrigger(cron_trigger, delay)
for idx, original, delayed in zip(
diff --git a/common/tests/unit/actions/test_webhook_action.py b/common/tests/unit/actions/test_webhook_action.py
index 6458cb1..7f5dd01 100644
--- a/common/tests/unit/actions/test_webhook_action.py
+++ b/common/tests/unit/actions/test_webhook_action.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from unittest.mock import Mock, patch
import pytest
@@ -19,7 +19,7 @@ def session():
@pytest.fixture
def test_outcome_item_data():
- timestamp = datetime.now(timezone.utc).isoformat()
+ timestamp = datetime.now(UTC).isoformat()
return {
"name": "My_test_name",
"status": TestStatuses.PASSED.name,
diff --git a/common/tests/unit/entities/test_journey_dag.py b/common/tests/unit/entities/test_journey_dag.py
index 5edbd3d..d6abe82 100644
--- a/common/tests/unit/entities/test_journey_dag.py
+++ b/common/tests/unit/entities/test_journey_dag.py
@@ -10,7 +10,7 @@
@dataclass
class FakeEdge:
- left: Optional[str]
+ left: str | None
right: str
def __hash__(self):
diff --git a/common/tests/unit/entity_services/helpers/test_filter_rules.py b/common/tests/unit/entity_services/helpers/test_filter_rules.py
index c6547ea..fb51cba 100644
--- a/common/tests/unit/entity_services/helpers/test_filter_rules.py
+++ b/common/tests/unit/entity_services/helpers/test_filter_rules.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from uuid import uuid4
import pytest
@@ -58,8 +58,8 @@ def test_from_params_start():
# We want to see that arrow is giving us the kind of timestamp we think; i.e., a UTC aware
# timestamp.
- assert filters.start_range_begin == datetime(year=2022, month=8, day=16, tzinfo=timezone.utc)
- assert filters.start_range_end == datetime(year=2022, month=8, day=17, tzinfo=timezone.utc)
+ assert filters.start_range_begin == datetime(year=2022, month=8, day=16, tzinfo=UTC)
+ assert filters.start_range_end == datetime(year=2022, month=8, day=17, tzinfo=UTC)
assert filters.end_range_begin is None
assert filters.end_range_end is None
assert bool(filters)
@@ -102,8 +102,8 @@ def test_run_filters_from_parameters_end():
[("page", "5"), ("count", "25"), ("end_range_begin", "2022-08-16"), ("end_range_end", "2022-08-17")]
)
filters = RunFilters.from_params(end)
- assert filters.end_range_begin == datetime(year=2022, month=8, day=16, tzinfo=timezone.utc)
- assert filters.end_range_end == datetime(year=2022, month=8, day=17, tzinfo=timezone.utc)
+ assert filters.end_range_begin == datetime(year=2022, month=8, day=16, tzinfo=UTC)
+ assert filters.end_range_end == datetime(year=2022, month=8, day=17, tzinfo=UTC)
assert filters.start_range_begin is None
assert filters.start_range_end is None
assert filters.pipeline_keys == []
diff --git a/common/tests/unit/events/v1/test_base_events.py b/common/tests/unit/events/v1/test_base_events.py
index 81d5433..7857abe 100644
--- a/common/tests/unit/events/v1/test_base_events.py
+++ b/common/tests/unit/events/v1/test_base_events.py
@@ -1,5 +1,5 @@
import uuid
-from datetime import timedelta, timezone
+from datetime import timedelta, timezone, UTC
import pytest
from marshmallow import ValidationError
@@ -127,8 +127,8 @@ def test_event_with_batch_pipeline_component_missing_run_key_error(valid_event_d
@pytest.mark.parametrize(
["timestamp", "tz"],
[
- ("2018-07-25T00:00:00Z", timezone.utc),
- ("2018-07-25T00:00:00", timezone.utc),
+ ("2018-07-25T00:00:00Z", UTC),
+ ("2018-07-25T00:00:00", UTC),
("2014-12-22T03:12:58.019077+06:00", timezone(timedelta(hours=6))),
],
ids=["ZuluTime", "Naive", "TZ offset"],
diff --git a/common/tests/unit/events/v1/test_testoutcomes.py b/common/tests/unit/events/v1/test_testoutcomes.py
index 30ae353..f862e35 100644
--- a/common/tests/unit/events/v1/test_testoutcomes.py
+++ b/common/tests/unit/events/v1/test_testoutcomes.py
@@ -38,9 +38,9 @@ def test_testoutcomes_schema_with_testgen_integration(test_outcomes_testgen_even
assert item_integration_event.test_suite == item_integration_data["test_suite"]
assert item_integration_event.version == item_integration_data["version"]
assert len(item_integration_event.test_parameters) == len(item_integration_data["test_parameters"])
- assert (
- type(item_integration_event.test_parameters[0].value) == Decimal
- ), "expected dataclass's value to be Decimal type"
+ assert type(item_integration_event.test_parameters[0].value) is Decimal, (
+ "expected dataclass's value to be Decimal type"
+ )
assert str(item_integration_event.test_parameters[0].value) == item_integration_data["test_parameters"][0]["value"]
assert item_integration_event.test_parameters[0].name == item_integration_data["test_parameters"][0]["name"]
assert item_integration_event.columns == item_integration_data["columns"]
diff --git a/common/tests/unit/events/v2/test_test_outcomes.py b/common/tests/unit/events/v2/test_test_outcomes.py
index c8bc638..3b23371 100644
--- a/common/tests/unit/events/v2/test_test_outcomes.py
+++ b/common/tests/unit/events/v2/test_test_outcomes.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from decimal import Decimal
import pytest
@@ -172,8 +172,8 @@ def test_test_outcomes(
for actual, expected in zip(res.test_outcomes, test_outcome_items):
assert actual.name == expected.name
assert actual.description == expected.description
- assert actual.start_time == datetime.fromisoformat(expected.start_time).replace(tzinfo=timezone.utc)
- assert actual.end_time == datetime.fromisoformat(expected.end_time).replace(tzinfo=timezone.utc)
+ assert actual.start_time == datetime.fromisoformat(expected.start_time).replace(tzinfo=UTC)
+ assert actual.end_time == datetime.fromisoformat(expected.end_time).replace(tzinfo=UTC)
assert actual.metric_value == expected.metric_value
assert actual.metric_name == expected.metric_name
assert actual.metric_description == expected.metric_description
@@ -201,9 +201,9 @@ def test_testoutcomes_schema_with_testgen_integration(test_outcomes_testgen_data
assert item_integration_event.test_suite == item_integration_data["test_suite"]
assert item_integration_event.version == item_integration_data["version"]
assert len(item_integration_event.test_parameters) == len(item_integration_data["test_parameters"])
- assert (
- type(item_integration_event.test_parameters[0].value) == Decimal
- ), "expected dataclass's value to be Decimal type"
+ assert type(item_integration_event.test_parameters[0].value) is Decimal, (
+ "expected dataclass's value to be Decimal type"
+ )
assert str(item_integration_event.test_parameters[0].value) == item_integration_data["test_parameters"][0]["value"]
assert item_integration_event.test_parameters[0].name == item_integration_data["test_parameters"][0]["name"]
assert item_integration_event.columns == item_integration_data["columns"]
diff --git a/common/tests/unit/flask_ext/test_jwt_plugin.py b/common/tests/unit/flask_ext/test_jwt_plugin.py
index 47631c3..23ab982 100644
--- a/common/tests/unit/flask_ext/test_jwt_plugin.py
+++ b/common/tests/unit/flask_ext/test_jwt_plugin.py
@@ -1,7 +1,7 @@
import json
from base64 import b64decode, b64encode
from binascii import Error as B64DecodeError
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from unittest.mock import Mock, patch
from uuid import uuid4
@@ -50,7 +50,7 @@ def expired_token():
@pytest.fixture
def current_token():
data = TOKEN_DATA.copy()
- dt = datetime.now(timezone.utc) + timedelta(days=2)
+ dt = datetime.now(UTC) + timedelta(days=2)
data["exp"] = int(dt.replace(microsecond=0).timestamp())
return encode(data, key=JWT_KEY)
@@ -64,8 +64,8 @@ def user():
@pytest.mark.parametrize(
("ts_value", "dt_value"),
[
- (40, datetime(1970, 1, 1, 0, 0, 40, tzinfo=timezone.utc)),
- (435474000, datetime(1983, 10, 20, 5, 0, tzinfo=timezone.utc)),
+ (40, datetime(1970, 1, 1, 0, 0, 40, tzinfo=UTC)),
+ (435474000, datetime(1983, 10, 20, 5, 0, tzinfo=UTC)),
],
)
def test_get_expiration_int(ts_value, dt_value):
diff --git a/common/tests/unit/peewee_extensions/test_peewee_extensions.py b/common/tests/unit/peewee_extensions/test_peewee_extensions.py
index 77783c3..f3e76ad 100644
--- a/common/tests/unit/peewee_extensions/test_peewee_extensions.py
+++ b/common/tests/unit/peewee_extensions/test_peewee_extensions.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from enum import Enum, IntEnum
from zoneinfo import ZoneInfo
@@ -82,7 +82,7 @@ def test_domain_field_lowercase():
@pytest.mark.unit
def test_timestamp_to_utc():
"""UTCTimestampField.python_value returns timezone aware values."""
- expected_dt = datetime.now(timezone.utc)
+ expected_dt = datetime.now(UTC)
f_inst = UTCTimestampField()
db_value = f_inst.db_value(expected_dt)
result = f_inst.python_value(db_value)
diff --git a/common/tests/unit/predicate_engine/assertions.py b/common/tests/unit/predicate_engine/assertions.py
index 625bf62..13d6f2f 100644
--- a/common/tests/unit/predicate_engine/assertions.py
+++ b/common/tests/unit/predicate_engine/assertions.py
@@ -42,7 +42,7 @@ def _to_str(value):
raise AssertionError(f"Rules differ: \n\t{str_a}\n\t!=\n\t{str_b}")
-def assertRuleMatches(a: R, b: Any, msg: Optional[str] = None):
+def assertRuleMatches(a: R, b: Any, msg: str | None = None):
"""Assert that an R object matches a given value."""
try:
result = a.matches(b)
@@ -65,7 +65,7 @@ def assertRuleMatches(a: R, b: Any, msg: Optional[str] = None):
)
-def assertRuleNotMatches(a: R, b: Any, msg: Optional[str] = None):
+def assertRuleNotMatches(a: R, b: Any, msg: str | None = None):
try:
result = a.matches(b)
except Exception:
diff --git a/common/tests/unit/predicate_engine/conftest.py b/common/tests/unit/predicate_engine/conftest.py
index 09234c5..f0d34f1 100644
--- a/common/tests/unit/predicate_engine/conftest.py
+++ b/common/tests/unit/predicate_engine/conftest.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
import pytest
@@ -33,7 +33,7 @@ def timestamp(self) -> datetime:
@property
def timestamp_dt(self) -> datetime:
- return datetime(1983, 10, 20, 10, 10, 10, tzinfo=timezone.utc)
+ return datetime(1983, 10, 20, 10, 10, 10, tzinfo=UTC)
@pytest.fixture(scope="session")
diff --git a/common/tests/unit/predicate_engine/test_predicate_engine.py b/common/tests/unit/predicate_engine/test_predicate_engine.py
index 99d111d..437769f 100644
--- a/common/tests/unit/predicate_engine/test_predicate_engine.py
+++ b/common/tests/unit/predicate_engine/test_predicate_engine.py
@@ -2,7 +2,7 @@
import sys
from collections.abc import MutableMapping
from copy import copy
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from unicodedata import category
import pytest
@@ -193,7 +193,7 @@ def test_matching_invalid_data_types(rule, simple_entity):
@pytest.mark.parametrize(
"rule",
(
- R(timestamp__gte=datetime.now(timezone.utc)),
+ R(timestamp__gte=datetime.now(UTC)),
R(timestamp_dt__gte=datetime.now()),
),
)
diff --git a/common/tests/unit/test_apscheduler_extensions.py b/common/tests/unit/test_apscheduler_extensions.py
index 0c07f37..14f1e3c 100644
--- a/common/tests/unit/test_apscheduler_extensions.py
+++ b/common/tests/unit/test_apscheduler_extensions.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from unittest.mock import patch
from zoneinfo import ZoneInfo
@@ -15,7 +15,7 @@
@pytest.mark.unit
def test_delayed_trigger_negative():
- cron_trigger = CronTrigger.from_crontab("*/2 * * * *", timezone=timezone.utc)
+ cron_trigger = CronTrigger.from_crontab("*/2 * * * *", timezone=UTC)
delay = timedelta(hours=-1)
with pytest.raises(ValueError, match="positive"):
DelayedTrigger(cron_trigger, delay)
@@ -102,8 +102,8 @@ def test_fix_weekdays(weekday_expression, expected):
@pytest.mark.unit
def test_get_crontab_trigger_times_finite_range():
- start = datetime(2000, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
- end = datetime(2000, 1, 10, 0, 0, 0, tzinfo=timezone.utc)
+ start = datetime(2000, 1, 1, 0, 0, 0, tzinfo=UTC)
+ end = datetime(2000, 1, 10, 0, 0, 0, tzinfo=UTC)
for i, time in enumerate(get_crontab_trigger_times("0 10 * * *", ZoneInfo("UTC"), start, end)):
assert time == start + timedelta(days=i, hours=10)
assert i == 8
@@ -111,7 +111,7 @@ def test_get_crontab_trigger_times_finite_range():
@pytest.mark.unit
def test_get_crontab_trigger_times_infinite_range():
- start = datetime(2000, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
+ start = datetime(2000, 1, 1, 0, 0, 0, tzinfo=UTC)
gen = get_crontab_trigger_times("0 10 * * *", ZoneInfo("UTC"), start)
for i in range(5000):
assert next(gen) == start + timedelta(days=i, hours=10)
@@ -122,7 +122,7 @@ def test_get_crontab_trigger_times_infinite_range():
"start,end",
(
(datetime(2000, 1, 1, 0, 0, 0), None),
- (datetime(2000, 1, 1, 0, 0, 0, tzinfo=timezone.utc), datetime(2000, 1, 2, 0, 0, 0)),
+ (datetime(2000, 1, 1, 0, 0, 0, tzinfo=UTC), datetime(2000, 1, 2, 0, 0, 0)),
),
)
def test_get_crontab_trigger_times_invalid_range(start, end):
diff --git a/common/tests/unit/test_datetime_utils.py b/common/tests/unit/test_datetime_utils.py
index 3ceec35..060fbd8 100644
--- a/common/tests/unit/test_datetime_utils.py
+++ b/common/tests/unit/test_datetime_utils.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from zoneinfo import ZoneInfo, available_timezones
import pytest
@@ -26,10 +26,10 @@ def test_datetime_iso8601():
@pytest.mark.parametrize(
"ts, dt",
(
- (1685701704.039912, datetime(2023, 6, 2, 10, 28, 24, 39912, tzinfo=timezone.utc)),
- (1123905724.000002, datetime(2005, 8, 13, 4, 2, 4, 2, tzinfo=timezone.utc)),
- (435489642.424242, datetime(1983, 10, 20, 9, 20, 42, 424242, tzinfo=timezone.utc)),
- (1162815179.06429, datetime(2006, 11, 6, 12, 12, 59, 64290, tzinfo=timezone.utc)),
+ (1685701704.039912, datetime(2023, 6, 2, 10, 28, 24, 39912, tzinfo=UTC)),
+ (1123905724.000002, datetime(2005, 8, 13, 4, 2, 4, 2, tzinfo=UTC)),
+ (435489642.424242, datetime(1983, 10, 20, 9, 20, 42, 424242, tzinfo=UTC)),
+ (1162815179.06429, datetime(2006, 11, 6, 12, 12, 59, 64290, tzinfo=UTC)),
),
)
def test_timestamp_and_datetime(ts, dt):
@@ -47,7 +47,7 @@ def test_tzinfo_added():
naive_dt = datetime(2006, 11, 6, 12, 12)
timestamp = datetime_to_timestamp(naive_dt)
- expected_dt = datetime(2006, 11, 6, 12, 12, tzinfo=timezone.utc)
+ expected_dt = datetime(2006, 11, 6, 12, 12, tzinfo=UTC)
actual_dt = timestamp_to_datetime(timestamp)
assert actual_dt != naive_dt
@@ -65,7 +65,7 @@ def test_tzinfo_coerced_to_utc():
tz_dt = datetime(2001, 1, 1, 0, 0, 0, tzinfo=tzinfo)
timestamp = datetime_to_timestamp(tz_dt)
- expected_dt = tz_dt.astimezone(timezone.utc)
+ expected_dt = tz_dt.astimezone(UTC)
actual_dt = timestamp_to_datetime(timestamp)
assert expected_dt == actual_dt
diff --git a/common/tests/unit/test_messagepack.py b/common/tests/unit/test_messagepack.py
index ce269d3..aab95bb 100644
--- a/common/tests/unit/test_messagepack.py
+++ b/common/tests/unit/test_messagepack.py
@@ -1,7 +1,7 @@
from array import array
from collections import OrderedDict
from dataclasses import dataclass
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from decimal import Decimal
from io import BytesIO
from pathlib import Path, PurePath, PurePosixPath, PureWindowsPath
@@ -159,7 +159,7 @@ def test_dump_load_datetime():
@pytest.mark.unit
def test_dump_load_datetime_tzinfo():
"""Messagepack can dump/load datetime.datetime values and preserve tzinfo."""
- data = datetime.now(timezone.utc)
+ data = datetime.now(UTC)
out_value = loads(dumps(data))
assert data == out_value
diff --git a/deploy/search_view_plugin.py b/deploy/search_view_plugin.py
index 7cbd348..9559559 100644
--- a/deploy/search_view_plugin.py
+++ b/deploy/search_view_plugin.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any, Optional
+from typing import Any
from apispec import BasePlugin
from apispec.yaml_utils import load_yaml_from_docstring
@@ -16,8 +16,8 @@ class SearchViewPlugin(BasePlugin):
def operation_helper(
self,
- path: Optional[str] = None,
- operations: Optional[dict] = None,
+ path: str | None = None,
+ operations: dict | None = None,
**kwargs: Any,
) -> None:
view_class = getattr(kwargs.get("view", None), "view_class", None)
diff --git a/deploy/subcomponent_plugin.py b/deploy/subcomponent_plugin.py
index 85edd1d..1bd4c64 100644
--- a/deploy/subcomponent_plugin.py
+++ b/deploy/subcomponent_plugin.py
@@ -1,5 +1,5 @@
import logging
-from typing import Any, Optional
+from typing import Any
from apispec import BasePlugin
@@ -15,8 +15,8 @@ class SubcomponentPlugin(BasePlugin):
def operation_helper(
self,
- path: Optional[str] = None,
- operations: Optional[dict] = None,
+ path: str | None = None,
+ operations: dict | None = None,
**kwargs: Any,
) -> None:
description_dict: dict[str, str] = {
@@ -58,7 +58,7 @@ def request_body_helper(self, method: str) -> dict:
}
return request_body
- def parameter_helper(self, parameter: Optional[dict] = None, **kwargs: Any) -> dict:
+ def parameter_helper(self, parameter: dict | None = None, **kwargs: Any) -> dict:
method = kwargs["method"]
parameter = {"in": "path", "schema": {"type": "string"}, "required": "true", "name": "component_id"}
if method == "post":
@@ -66,7 +66,7 @@ def parameter_helper(self, parameter: Optional[dict] = None, **kwargs: Any) -> d
parameter["description"] = f"The ID of the project that the {self.subcomponent_name} will be created under."
return parameter
- def response_helper(self, response: Optional[dict] = None, **kwargs: Any) -> dict:
+ def response_helper(self, response: dict | None = None, **kwargs: Any) -> dict:
method = kwargs["method"]
response_desc_dict: dict[str, dict[int, str]] = {
"get": {
diff --git a/event_api/config/defaults.py b/event_api/config/defaults.py
index 6e2dfce..81aba97 100644
--- a/event_api/config/defaults.py
+++ b/event_api/config/defaults.py
@@ -5,13 +5,12 @@
"""
import os
-from typing import Optional
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
from common.entities import Service
-PROPAGATE_EXCEPTIONS: Optional[bool] = None
-SERVER_NAME: Optional[str] = os.environ.get("EVENTS_API_HOSTNAME") # Use flask defaults if none set
+PROPAGATE_EXCEPTIONS: bool | None = None
+SERVER_NAME: str | None = os.environ.get("EVENTS_API_HOSTNAME") # Use flask defaults if none set
USE_X_SENDFILE: bool = False # If we serve files enable this in production settings when webserver support configured
# Application settings
diff --git a/event_api/config/local.py b/event_api/config/local.py
index aaf5847..f063607 100644
--- a/event_api/config/local.py
+++ b/event_api/config/local.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
-PROPAGATE_EXCEPTIONS: Optional[bool] = True
+PROPAGATE_EXCEPTIONS: bool | None = True
SECRET_KEY: str = "NOT_VERY_SECRET"
diff --git a/event_api/config/minikube.py b/event_api/config/minikube.py
index e5ed0f8..6c72c79 100644
--- a/event_api/config/minikube.py
+++ b/event_api/config/minikube.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
-TESTING: Optional[bool] = True
+TESTING: bool | None = True
SECRET_KEY: str = "NOT_VERY_SECRET"
diff --git a/event_api/endpoints/v1/event_view.py b/event_api/endpoints/v1/event_view.py
index 825e2fe..ae257ec 100644
--- a/event_api/endpoints/v1/event_view.py
+++ b/event_api/endpoints/v1/event_view.py
@@ -1,7 +1,6 @@
import logging
-from datetime import datetime, timezone
+from datetime import datetime, UTC
from http import HTTPStatus
-from typing import Optional
from flask import Response, current_app, g, make_response, request
from marshmallow import ValidationError
@@ -34,14 +33,14 @@ class EventView(BaseView):
event_type: type[Event]
"""The class (not instance) that is used to deserialize the incoming request body"""
- def make_error(self, msg: str, e: Exception, error_code: Optional[int] = None) -> Response:
+ def make_error(self, msg: str, e: Exception, error_code: int | None = None) -> Response:
"""TODO: This should be turned into an ErrorHandler at the app level."""
return make_response(
{
"error": msg,
# TODO: Should this be exposed to the user?
"details": str(e),
- "timestamp": datetime.now(tz=timezone.utc),
+ "timestamp": datetime.now(tz=UTC),
},
error_code if error_code else 500,
)
diff --git a/event_api/tests/integration/v1_endpoints/conftest.py b/event_api/tests/integration/v1_endpoints/conftest.py
index 8b413e8..41adb91 100644
--- a/event_api/tests/integration/v1_endpoints/conftest.py
+++ b/event_api/tests/integration/v1_endpoints/conftest.py
@@ -1,7 +1,7 @@
import os
import shutil
from dataclasses import dataclass
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from unittest.mock import MagicMock, patch
from uuid import UUID
@@ -29,7 +29,7 @@ class DatabaseCtx:
@pytest.fixture
def predictable_datetime():
- return datetime(2022, 5, 25, 19, 56, 52, 759419, tzinfo=timezone.utc)
+ return datetime(2022, 5, 25, 19, 56, 52, 759419, tzinfo=UTC)
@pytest.fixture
diff --git a/event_api/tests/integration/v2_endpoints/conftest.py b/event_api/tests/integration/v2_endpoints/conftest.py
index 462feff..e9d6e88 100644
--- a/event_api/tests/integration/v2_endpoints/conftest.py
+++ b/event_api/tests/integration/v2_endpoints/conftest.py
@@ -1,6 +1,6 @@
import os
import shutil
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from unittest.mock import MagicMock, patch
import pytest
@@ -18,7 +18,7 @@
@pytest.fixture
def event_time():
- return datetime(2022, 5, 25, 19, 56, 52, 759419, tzinfo=timezone.utc)
+ return datetime(2022, 5, 25, 19, 56, 52, 759419, tzinfo=UTC)
@pytest.fixture
diff --git a/observability_api/config/defaults.py b/observability_api/config/defaults.py
index 5a8a923..05a39c0 100644
--- a/observability_api/config/defaults.py
+++ b/observability_api/config/defaults.py
@@ -6,13 +6,12 @@
import os
from datetime import timedelta
-from typing import Optional
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
from common.entities import Service
-PROPAGATE_EXCEPTIONS: Optional[bool] = None
-SERVER_NAME: Optional[str] = os.environ.get("OBSERVABILITY_API_HOSTNAME") # Use flask defaults if none set
+PROPAGATE_EXCEPTIONS: bool | None = None
+SERVER_NAME: str | None = os.environ.get("OBSERVABILITY_API_HOSTNAME") # Use flask defaults if none set
USE_X_SENDFILE: bool = False # If we serve files enable this in production settings when webserver support configured
# Application settings
diff --git a/observability_api/config/local.py b/observability_api/config/local.py
index 56bf7e2..b58b69b 100644
--- a/observability_api/config/local.py
+++ b/observability_api/config/local.py
@@ -1,7 +1,5 @@
-from typing import Optional
-
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
-PROPAGATE_EXCEPTIONS: Optional[bool] = True
+PROPAGATE_EXCEPTIONS: bool | None = True
SECRET_KEY: str = "NOT_VERY_SECRET"
# Application settings
diff --git a/observability_api/config/minikube.py b/observability_api/config/minikube.py
index bceae82..3b013a0 100644
--- a/observability_api/config/minikube.py
+++ b/observability_api/config/minikube.py
@@ -1,7 +1,5 @@
-from typing import Optional
-
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
-TESTING: Optional[bool] = True
+TESTING: bool | None = True
SECRET_KEY: str = "NOT_VERY_SECRET"
# Application settings
diff --git a/observability_api/config/test.py b/observability_api/config/test.py
index 5dc2eb9..d9cedfc 100644
--- a/observability_api/config/test.py
+++ b/observability_api/config/test.py
@@ -1,7 +1,5 @@
-from typing import Optional
-
# Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values
-PROPAGATE_EXCEPTIONS: Optional[bool] = True
+PROPAGATE_EXCEPTIONS: bool | None = True
SECRET_KEY: str = "NOT_VERY_SECRET"
TESTING: bool = True
diff --git a/observability_api/endpoints/component_view.py b/observability_api/endpoints/component_view.py
index 859da15..091dc0a 100644
--- a/observability_api/endpoints/component_view.py
+++ b/observability_api/endpoints/component_view.py
@@ -2,7 +2,7 @@
import logging
from http import HTTPStatus
-from typing import Any, Optional
+from typing import Any
from uuid import UUID
from flask import Blueprint, Response, make_response
@@ -21,7 +21,7 @@ class ComponentByIdAbstractView(BaseEntityView):
route: str
entity: type[BaseEntity]
schema: type[ModelSchema]
- patch_schema: Optional[type[ModelSchema]] = None
+ patch_schema: type[ModelSchema] | None = None
def get(self, component_id: UUID) -> Response:
component = self.get_entity_or_fail(self.entity, self.entity.id == component_id)
diff --git a/observability_api/endpoints/v1/journeys.py b/observability_api/endpoints/v1/journeys.py
index 92ba43a..56e3902 100644
--- a/observability_api/endpoints/v1/journeys.py
+++ b/observability_api/endpoints/v1/journeys.py
@@ -3,7 +3,6 @@
import logging
from graphlib import CycleError
from http import HTTPStatus
-from typing import Optional
from uuid import UUID
from flask import Response, make_response, request
@@ -114,7 +113,7 @@ def get(self, project_id: UUID) -> Response:
"""
_ = self.get_entity_or_fail(Project, Project.id == project_id)
- component_id: Optional[str] = request.args.get("component_id", None)
+ component_id: str | None = request.args.get("component_id", None)
page: Page = ProjectService.get_journeys_with_rules(
str(project_id), ListRules.from_params(request.args), component_id=component_id
)
diff --git a/observability_api/endpoints/v1/project_settings.py b/observability_api/endpoints/v1/project_settings.py
index 702539c..5a28789 100644
--- a/observability_api/endpoints/v1/project_settings.py
+++ b/observability_api/endpoints/v1/project_settings.py
@@ -1,6 +1,6 @@
__all__ = ["ProjectAlertsSettings"]
-from typing import Optional, cast, Any
+from typing import cast, Any
from uuid import UUID
from flask import Response, make_response
@@ -23,7 +23,7 @@
class ProjectAlertsSettings(BaseEntityView):
PERMISSION_REQUIREMENTS: tuple[Permission, ...] = (PERM_USER, PERM_PROJECT)
- _project: Optional[Project]
+ _project: Project | None
def get_request_schema(self) -> Schema:
return cast(Schema, Schema.from_dict(self.get_fields(), name=f"{self.__class__.__name__}Schema")())
diff --git a/observability_api/schemas/event_schemas.py b/observability_api/schemas/event_schemas.py
index 7b9ac85..192594a 100644
--- a/observability_api/schemas/event_schemas.py
+++ b/observability_api/schemas/event_schemas.py
@@ -52,8 +52,7 @@ class EventResponseSchema(Schema):
required=True,
metadata={
"description": (
- "The IDs of the components related to the event. The first item in the list is the primary "
- "component. "
+ "The IDs of the components related to the event. The first item in the list is the primary component. "
)
},
)
diff --git a/observability_api/tests/integration/v1_endpoints/conftest.py b/observability_api/tests/integration/v1_endpoints/conftest.py
index e1c00d3..f67c09e 100644
--- a/observability_api/tests/integration/v1_endpoints/conftest.py
+++ b/observability_api/tests/integration/v1_endpoints/conftest.py
@@ -1,7 +1,7 @@
import os
import shutil
import uuid
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from decimal import Decimal
import pytest
@@ -275,8 +275,8 @@ def test_outcome(client, instance_instance_set, instance_runs, pipeline):
dimensions=["a", "b", "c"],
status=TestStatuses.WARNING.name,
run=instance_runs[0].id,
- start_time=datetime.now(tz=timezone.utc) - timedelta(minutes=15),
- end_time=datetime.now(tz=timezone.utc) - timedelta(minutes=5),
+ start_time=datetime.now(tz=UTC) - timedelta(minutes=15),
+ end_time=datetime.now(tz=UTC) - timedelta(minutes=5),
component=pipeline,
instance_set=instance_instance_set.instance_set,
external_url="https://example.com",
@@ -299,8 +299,8 @@ def test_outcomes(client, instance_instance_set, pipeline):
dimension=[f"a-{i}", f"b-{i}", f"c-{i}"],
description=f"Description{i}",
status=f"{TestStatuses.PASSED.name if i % 2 == 0 else TestStatuses.FAILED.name}",
- start_time=datetime.now(tz=timezone.utc) + timedelta(minutes=5 * i),
- end_time=datetime.now(tz=timezone.utc) + timedelta(minutes=15 * i),
+ start_time=datetime.now(tz=UTC) + timedelta(minutes=5 * i),
+ end_time=datetime.now(tz=UTC) + timedelta(minutes=15 * i),
component=pipeline,
instance_set=instance_instance_set.instance_set,
external_url="https://example.com",
diff --git a/observability_api/tests/integration/v1_endpoints/test_alerts.py b/observability_api/tests/integration/v1_endpoints/test_alerts.py
index 9f7d59f..094d097 100644
--- a/observability_api/tests/integration/v1_endpoints/test_alerts.py
+++ b/observability_api/tests/integration/v1_endpoints/test_alerts.py
@@ -1,5 +1,5 @@
import uuid
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from http import HTTPStatus
from uuid import uuid4
@@ -107,7 +107,7 @@ def test_list_project_alerts_search(client, g_user, project, instance_alert, ins
@pytest.mark.integration
def test_list_project_alerts_filters(client, g_user, project, instance_alert, instance_alert_components, run_alerts):
- yesterday = datetime.now(tz=timezone.utc) - timedelta(days=1)
+ yesterday = datetime.now(tz=UTC) - timedelta(days=1)
past_instance = Instance.create(journey=instance_alert.instance.journey, start_time=yesterday)
past_instance_alert = InstanceAlert.create(
id=uuid4(),
diff --git a/observability_api/tests/integration/v1_endpoints/test_instances.py b/observability_api/tests/integration/v1_endpoints/test_instances.py
index 1492897..82bcdd3 100644
--- a/observability_api/tests/integration/v1_endpoints/test_instances.py
+++ b/observability_api/tests/integration/v1_endpoints/test_instances.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from http import HTTPStatus
from typing import Optional
from uuid import UUID, uuid4
@@ -51,8 +51,8 @@ def test_outcome(client, instance_runs, pipeline):
description="Abc_Description",
status=TestStatuses.WARNING.name,
run=instance_runs[0].id,
- start_time=datetime.now(tz=timezone.utc) - timedelta(minutes=15),
- end_time=datetime.now(tz=timezone.utc) - timedelta(minutes=5),
+ start_time=datetime.now(tz=UTC) - timedelta(minutes=15),
+ end_time=datetime.now(tz=UTC) - timedelta(minutes=5),
component=pipeline,
)
yield test_outcome
@@ -66,8 +66,8 @@ def test_outcomes(client, instance_instance_set, pipeline):
name=f"DKTest{i}",
description=f"Description{i}",
status=f"{TestStatuses.PASSED.name if i % 2 == 0 else TestStatuses.FAILED.name}",
- start_time=datetime.now(tz=timezone.utc) + timedelta(minutes=5 * i),
- end_time=datetime.now(tz=timezone.utc) + timedelta(minutes=15 * i),
+ start_time=datetime.now(tz=UTC) + timedelta(minutes=5 * i),
+ end_time=datetime.now(tz=UTC) + timedelta(minutes=15 * i),
component=pipeline,
instance_set=instance_instance_set.instance_set,
)
@@ -230,8 +230,8 @@ def create_test_outcomes(instances: list[Instance], component: Pipeline | Datase
name=f"DKTest{i}",
description=f"Description{i}",
status=f"{TestStatuses.PASSED.name if i % 2 == 0 else TestStatuses.FAILED.name}",
- start_time=datetime.now(tz=timezone.utc) + timedelta(minutes=5 * i),
- end_time=datetime.now(tz=timezone.utc) + timedelta(minutes=15 * i),
+ start_time=datetime.now(tz=UTC) + timedelta(minutes=5 * i),
+ end_time=datetime.now(tz=UTC) + timedelta(minutes=15 * i),
component=component,
instance_set=instance_set,
)
@@ -244,7 +244,7 @@ def create_dataset_operations(instances: list[Instance], dataset: Dataset):
DatasetOperation.create(
dataset=dataset,
instance_set=instance_set,
- operation_time=datetime.now(tz=timezone.utc),
+ operation_time=datetime.now(tz=UTC),
operation=f"{DatasetOperationType.READ.name if i % 2 == 0 else DatasetOperationType.WRITE.name}",
path="/path/to/file",
)
@@ -252,7 +252,7 @@ def create_dataset_operations(instances: list[Instance], dataset: Dataset):
@pytest.fixture
def run_status_event(pipeline, project, journey, instances):
- ts = datetime.now(timezone.utc)
+ ts = datetime.now(UTC)
yield RunStatusEvent(
project_id=project.id,
event_id=uuid4(),
@@ -360,7 +360,7 @@ def test_search_instances(client, journey, instances, g_user):
response = client.post(
f"/observability/v1/projects/{journey.project.id}/instances/search",
- json={"params": {"start_range_end": datetime.now(timezone.utc).isoformat()}},
+ json={"params": {"start_range_end": datetime.now(UTC).isoformat()}},
)
assert response.status_code == HTTPStatus.OK, response.json
assert response.json["total"] == 6
@@ -472,7 +472,7 @@ class InstanceData:
instances: list[Instance] = field(default_factory=list)
-def create_instance_data(number: int, proj: Optional[Project] = None) -> InstanceData:
+def create_instance_data(number: int, proj: Project | None = None) -> InstanceData:
c = Company.create(name=f"TestCompany{number}")
org = Organization.create(name=f"Internal Org{number}", company=c)
if proj:
@@ -519,8 +519,8 @@ def test_list_instances_with_filters_param_results(client, journey, g_user, proj
instance.save()
args = [("journey_name", name) for name in ("Test_Journey1", "Test_Journey2")] + [
- ("start_range_begin", (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()),
- ("start_range_end", datetime.now(timezone.utc).isoformat()),
+ ("start_range_begin", (datetime.now(UTC) - timedelta(hours=1)).isoformat()),
+ ("start_range_end", datetime.now(UTC).isoformat()),
]
query_string = MultiDict(args)
@@ -540,7 +540,7 @@ def test_list_instances_with_filters_param_results(client, journey, g_user, proj
assert r1.json["total"] == 3
r1 = client.get(
f"/observability/v1/projects/{instance_data1.project.id}/instances",
- query_string=MultiDict([("start_range_begin", datetime.now(timezone.utc).isoformat())]),
+ query_string=MultiDict([("start_range_begin", datetime.now(UTC).isoformat())]),
)
assert r1.status_code == HTTPStatus.OK, r1.json
assert r1.json["total"] == 0
diff --git a/observability_api/tests/integration/v1_endpoints/test_jwt_auth.py b/observability_api/tests/integration/v1_endpoints/test_jwt_auth.py
index 95cfd89..3041536 100644
--- a/observability_api/tests/integration/v1_endpoints/test_jwt_auth.py
+++ b/observability_api/tests/integration/v1_endpoints/test_jwt_auth.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from http import HTTPStatus
from unittest.mock import patch
from uuid import uuid4
@@ -61,7 +61,7 @@ def valid_token(token_user):
"company_id": str(token_user.primary_company_id),
"domain": "fakedomain.fake",
}
- dt = datetime.now(timezone.utc) + timedelta(days=2)
+ dt = datetime.now(UTC) + timedelta(days=2)
data["exp"] = int(dt.replace(microsecond=0).timestamp())
return JWTAuth.encode_token(data)
@@ -69,7 +69,7 @@ def valid_token(token_user):
@pytest.fixture
def invalid_token_bad_user(token_user):
data = {"user_id": str(uuid4()), "company_id": str(token_user.primary_company_id)}
- dt = datetime.now(timezone.utc) + timedelta(days=2)
+ dt = datetime.now(UTC) + timedelta(days=2)
data["exp"] = int(dt.replace(microsecond=0).timestamp())
return JWTAuth.encode_token(data)
@@ -140,15 +140,13 @@ def test_jwt_token_expiration(jwt_client, token_user):
token = JWTAuth.log_user_in(token_user)
claims = JWTAuth.decode_token(token)
- assert get_token_expiration(claims) < datetime.now(timezone.utc) + timedelta(seconds=20)
+ assert get_token_expiration(claims) < datetime.now(UTC) + timedelta(seconds=20)
@pytest.mark.integration
def test_jwt_token_expiration_explicit(jwt_client, token_user):
with patch("common.api.flask_ext.authentication.jwt_plugin.get_domain", return_value="fakedomain.fake"):
- token = JWTAuth.log_user_in(
- token_user, claims={"exp": (datetime.now(timezone.utc) + timedelta(seconds=10)).timestamp()}
- )
+ token = JWTAuth.log_user_in(token_user, claims={"exp": (datetime.now(UTC) + timedelta(seconds=10)).timestamp()})
claims = JWTAuth.decode_token(token)
- assert get_token_expiration(claims) < datetime.now(timezone.utc) + timedelta(seconds=10)
+ assert get_token_expiration(claims) < datetime.now(UTC) + timedelta(seconds=10)
diff --git a/observability_api/tests/integration/v1_endpoints/test_runs.py b/observability_api/tests/integration/v1_endpoints/test_runs.py
index 3902be2..50060e2 100644
--- a/observability_api/tests/integration/v1_endpoints/test_runs.py
+++ b/observability_api/tests/integration/v1_endpoints/test_runs.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from http import HTTPStatus
from itertools import chain
from typing import Optional
@@ -57,7 +57,7 @@ def uuid_value():
@pytest.fixture
def run_status_event(runs, pipeline, project, uuid_value):
- ts = str(datetime.now(timezone.utc))
+ ts = str(datetime.now(UTC))
yield RunStatusEvent(
**RunStatusSchema().load(
{
@@ -181,7 +181,7 @@ class RunData:
alerts: list[RunAlert] = field(default_factory=list)
-def create_run_data(number: int, proj: Optional[Project] = None, set_tool=False) -> RunData:
+def create_run_data(number: int, proj: Project | None = None, set_tool=False) -> RunData:
c = Company.create(name=f"TestCompany{number}")
org = Organization.create(name=f"Internal Org{number}", company=c)
if proj:
@@ -235,8 +235,8 @@ def test_list_runs_with_filters_param_results(client, g_user_2_admin, pipeline,
[("run_key", key) for key in ("1", "2")]
+ [("pipeline_key", key) for key in ("Test_Pipeline1", "Test_Pipeline2")]
+ [
- ("start_range_begin", (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()),
- ("start_range_end", datetime.now(timezone.utc).isoformat()),
+ ("start_range_begin", (datetime.now(UTC) - timedelta(hours=1)).isoformat()),
+ ("start_range_end", datetime.now(UTC).isoformat()),
]
)
query_string = MultiDict(args)
@@ -526,8 +526,8 @@ def test_list_batch_pipeline_runs_with_filters_param_results(client, g_user_2_ad
_ = create_run_data(3, proj=rd_two.project)
args = [("run_key", key) for key in ("1", "2")] + [
- ("start_range_begin", (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()),
- ("start_range_end", datetime.now(timezone.utc).isoformat()),
+ ("start_range_begin", (datetime.now(UTC) - timedelta(hours=1)).isoformat()),
+ ("start_range_end", datetime.now(UTC).isoformat()),
("pipeline_id", pipeline.id),
]
query_string = MultiDict(args)
@@ -539,7 +539,7 @@ def test_list_batch_pipeline_runs_with_filters_param_results(client, g_user_2_ad
f"/observability/v1/projects/{project.id}/runs",
query_string=MultiDict(
[
- ("start_range_begin", datetime.now(timezone.utc).isoformat()),
+ ("start_range_begin", datetime.now(UTC).isoformat()),
("pipeline_id", pipeline.id),
]
),
@@ -578,9 +578,9 @@ def test_list_runs_for_instance(client, g_user, instances, runs, project):
)
assert response.status_code == HTTPStatus.OK, response.json
response_body = response.json
- assert (
- response_body["total"] == 1 and len(response_body["entities"]) == response_body["total"]
- ), "should return one run for each instance"
+ assert response_body["total"] == 1 and len(response_body["entities"]) == response_body["total"], (
+ "should return one run for each instance"
+ )
assert len(response_body["entities"][0]["alerts"]) == 1
expected_run_ids = [str(r.id) for r in runs]
assert response_body["entities"][0]["id"] in expected_run_ids, "the returned ID isn't one of the expected runs"
@@ -619,9 +619,9 @@ def test_list_runs_for_instance_with_summaries(client, g_user, instances, runs,
(RunTaskStatus.FAILED, 2),
)
for status, expected in tasks_statuses:
- assert any(
- (status.name, expected) == (task["status"], task["count"]) for task in run["tasks_summary"]
- ), f"did not find expected {{\"status\": {status.name}, \"count\": {expected}}} in {run['tasks_summary']}"
+ assert any((status.name, expected) == (task["status"], task["count"]) for task in run["tasks_summary"]), (
+ f'did not find expected {{"status": {status.name}, "count": {expected}}} in {run["tasks_summary"]}'
+ )
assert len(run["alerts"]) == 1
assert run["alerts"][0]["level"] == AlertLevel["ERROR"].value
@@ -721,7 +721,7 @@ def test_get_instance_runs_status_filters(query_string, expected, client, g_user
key=key,
pipeline=pipeline,
instance_set=instance_set,
- status=f"{RunStatus.COMPLETED.name if int(key) % 2 == 0 else RunStatus.FAILED.name}",
+ status=f"{RunStatus.COMPLETED.name if int(key) % 2 == 0 else RunStatus.FAILED.name}",
)
base_query_string = {"instance_id": [instance.id], "status": query_string}
response = client.get(f"/observability/v1/projects/{project.id}/runs", query_string=base_query_string)
diff --git a/observability_api/tests/integration/v1_endpoints/test_service_account_keys.py b/observability_api/tests/integration/v1_endpoints/test_service_account_keys.py
index 1a7f3da..d93338f 100644
--- a/observability_api/tests/integration/v1_endpoints/test_service_account_keys.py
+++ b/observability_api/tests/integration/v1_endpoints/test_service_account_keys.py
@@ -1,5 +1,5 @@
import uuid
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from http import HTTPStatus
import pytest
@@ -24,7 +24,7 @@ def sa_key(client, project):
@pytest.mark.integration
def test_create_sa_key_success(client, g_user, project, sa_key_data):
- today = datetime.now(timezone.utc)
+ today = datetime.now(UTC)
response = client.post(
f"/observability/v1/projects/{project.id}/service-account-key",
headers={"Content-Type": "application/json"},
@@ -46,7 +46,7 @@ def test_create_sa_key_success(client, g_user, project, sa_key_data):
@pytest.mark.integration
def test_create_sa_key_with_name_and_description(client, g_user, project, sa_key_data):
sa_key_data["description"] = "Whoa man, I'm just using this for auth"
- today = datetime.now(timezone.utc)
+ today = datetime.now(UTC)
response = client.post(
f"/observability/v1/projects/{project.id}/service-account-key",
headers={"Content-Type": "application/json"},
diff --git a/observability_api/tests/integration/v1_endpoints/test_upcoming_instances.py b/observability_api/tests/integration/v1_endpoints/test_upcoming_instances.py
index f96c724..c42bdac 100644
--- a/observability_api/tests/integration/v1_endpoints/test_upcoming_instances.py
+++ b/observability_api/tests/integration/v1_endpoints/test_upcoming_instances.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from http import HTTPStatus
import pytest
@@ -14,7 +14,7 @@ def test_list_project_upcoming_instances_instance_schedule(client, journey, jour
InstanceRule.create(journey=journey, action=InstanceRuleAction.END, expression="30,40 * * * *")
InstanceRule.create(journey=journey_2, action=InstanceRuleAction.START, expression="10,50 * * * *")
- start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc)
+ start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC)
query = MultiDict(
[
("start_range", start_time.isoformat()),
@@ -70,7 +70,7 @@ def test_list_project_upcoming_instances_batch_schedule(
batch_end_schedule.schedule = "30 * * * *"
batch_end_schedule.save()
- start_time = datetime(1991, 2, 20, 10, 59, 00, tzinfo=timezone.utc)
+ start_time = datetime(1991, 2, 20, 10, 59, 00, tzinfo=UTC)
query = MultiDict(
[
("start_range", start_time.isoformat()),
@@ -94,7 +94,7 @@ def test_list_project_upcoming_instances_filters(client, journey, journey_2, ins
InstanceRule.create(journey=journey, action=InstanceRuleAction.START, expression="10 * * * *")
InstanceRule.create(journey=journey_2, action=InstanceRuleAction.START, expression="30 * * * *")
- start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc)
+ start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC)
query = MultiDict(
[
("start_range", start_time.isoformat()),
@@ -148,7 +148,7 @@ def test_list_company_upcoming_instances_instance_schedule(client, organization,
InstanceRule.create(journey=journey, action=InstanceRuleAction.START, expression="10 * * * *")
InstanceRule.create(journey=journey_2, action=InstanceRuleAction.START, expression="30 * * * *")
- start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc)
+ start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC)
query = MultiDict(
[
("start_range", start_time.isoformat()),
@@ -173,7 +173,7 @@ def test_list_company_upcoming_instances_filters(client, project, organization,
InstanceRule.create(journey=journey, action=InstanceRuleAction.START, expression="10 * * * *")
InstanceRule.create(journey=journey_2, action=InstanceRuleAction.START, expression="30 * * * *")
- start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc)
+ start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC)
query = MultiDict(
[
("start_range", start_time.isoformat()),
@@ -210,7 +210,7 @@ def test_list_company_upcoming_instances_filters(client, project, organization,
@pytest.mark.integration
def test_list_project_upcoming_instances_sa_key_auth_ok(client, journey, journey_2, instance, g_project):
- start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc)
+ start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC)
query = MultiDict(
[
("start_range", start_time.isoformat()),
@@ -223,7 +223,7 @@ def test_list_project_upcoming_instances_sa_key_auth_ok(client, journey, journey
@pytest.mark.integration
def test_list_company_upcoming_instances_sa_key_auth_forbidden(client, journey, journey_2, instance, g_project):
- start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc)
+ start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC)
query = MultiDict(
[
("start_range", start_time.isoformat()),
diff --git a/pyproject.toml b/pyproject.toml
index 67b7517..5144753 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,14 +47,14 @@ dependencies = [
[project.optional-dependencies]
dev = [
- "ruff~=0.7.3",
+ "ruff~=0.12.0",
"invoke~=2.1.2",
"lxml~=4.9.1",
- "mypy~=1.5.0",
+ "mypy~=1.16.1",
"pre-commit~=2.20.0",
- "pytest-cov~=4.0.0",
- "pytest-xdist~=3.1.0",
- "pytest~=7.2.0",
+ "pytest-cov~=6.2.1",
+ "pytest-xdist~=3.7.0",
+ "pytest~=8.4.1",
"pyyaml~=6.0",
"types-PyYAML~=6.0.8",
"types-requests==2.28.11.15",
@@ -223,6 +223,10 @@ check_untyped_defs = false
module = "PIL.*"
ignore_missing_imports = true
+[[tool.mypy.overrides]]
+module = "IPython.*"
+ignore_missing_imports = true
+
[[tool.mypy.overrides]]
module = "invoke"
ignore_missing_imports = true
@@ -235,6 +239,10 @@ ignore_missing_imports = true
module = "msgpack.*"
ignore_missing_imports = true
+[[tool.mypy.overrides]]
+module = "observability_plugins.*"
+ignore_missing_imports = true
+
[[tool.mypy.overrides]]
module = "marshmallow_union.*"
ignore_missing_imports = true
@@ -243,8 +251,12 @@ ignore_missing_imports = true
module = "pybars"
ignore_missing_imports = true
+[[tool.mypy.overrides]]
+module = "yoyo.*"
+ignore_missing_imports = true
+
[tool.ruff]
-target-version = "py310"
+target-version = "py312"
line-length = 120
[tool.ruff.lint]
diff --git a/rules_engine/journey_rules.py b/rules_engine/journey_rules.py
index 8ffe6ed..d13b3ad 100644
--- a/rules_engine/journey_rules.py
+++ b/rules_engine/journey_rules.py
@@ -5,7 +5,7 @@
import logging
import time
from functools import lru_cache
-from typing import Any, Optional
+from typing import Any
from collections.abc import Callable
from uuid import UUID
@@ -33,19 +33,19 @@ class JourneyRule:
def __init__(
self,
r_obj: R,
- rule_entity: Optional[RuleEntity],
+ rule_entity: RuleEntity | None,
*triggers: Callable,
- journey_id: Optional[UUID] = None,
- component_id: Optional[UUID] = None,
+ journey_id: UUID | None = None,
+ component_id: UUID | None = None,
) -> None:
self.r_obj: R = r_obj
self.rule_entity = rule_entity
- self.triggers: tuple[Callable[[EVENT_TYPE, Optional[RuleEntity], Optional[UUID]], ActionResult], ...] = triggers
- self.journey_id: Optional[UUID] = journey_id
- self.component_id: Optional[UUID] = component_id
+ self.triggers: tuple[Callable[[EVENT_TYPE, RuleEntity | None, UUID | None], ActionResult], ...] = triggers
+ self.journey_id: UUID | None = journey_id
+ self.component_id: UUID | None = component_id
@staticmethod
- def _get_component_id(event: EVENT_TYPE) -> Optional[UUID]:
+ def _get_component_id(event: EVENT_TYPE) -> UUID | None:
"""Extract the component id from the given event."""
match event:
case Event():
@@ -81,7 +81,7 @@ def __str__(self) -> str:
return f"{self.__module__}.{self.__class__.__name__}: {self.r_obj}"
-def _execute_action(event: EVENT_TYPE, rule_entity: RuleEntity, journey_id: Optional[UUID]) -> Any:
+def _execute_action(event: EVENT_TYPE, rule_entity: RuleEntity, journey_id: UUID | None) -> Any:
action_entity = JourneyService.get_action_by_implementation(rule_entity.journey_id, rule_entity.action)
action = action_factory(rule_entity.action, rule_entity.action_args, action_entity)
action.execute(event, rule_entity, journey_id)
diff --git a/rules_engine/rule_data.py b/rules_engine/rule_data.py
index f9f3f5f..e652a94 100644
--- a/rules_engine/rule_data.py
+++ b/rules_engine/rule_data.py
@@ -1,6 +1,5 @@
__all__ = ["RuleData"]
-from typing import Optional
from uuid import UUID
from peewee import SelectQuery, fn
@@ -21,7 +20,7 @@ class DatabaseData:
def __init__(self, event: EVENT_TYPE) -> None:
self.event = event
- def _get_batch_pipeline_id(self) -> Optional[UUID]:
+ def _get_batch_pipeline_id(self) -> UUID | None:
match self.event:
case Event():
return self.event.pipeline_id
diff --git a/rules_engine/tests/integration/conftest.py b/rules_engine/tests/integration/conftest.py
index 789f275..39938a2 100644
--- a/rules_engine/tests/integration/conftest.py
+++ b/rules_engine/tests/integration/conftest.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from unittest.mock import MagicMock, Mock, patch
from uuid import uuid4
@@ -31,8 +31,8 @@ def base_event_data():
"pipeline_name": None,
"project_id": str(uuid4()),
"source": EventSources.API.name,
- "event_timestamp": str(datetime.now(timezone.utc)),
- "received_timestamp": str(datetime.now(timezone.utc)),
+ "event_timestamp": str(datetime.now(UTC)),
+ "received_timestamp": str(datetime.now(UTC)),
"external_url": "https://example.com",
"metadata": {},
"run_id": None,
diff --git a/rules_engine/tests/unit/test_data_points.py b/rules_engine/tests/unit/test_data_points.py
index e189149..0b9fcb4 100644
--- a/rules_engine/tests/unit/test_data_points.py
+++ b/rules_engine/tests/unit/test_data_points.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
import pytest
@@ -206,7 +206,7 @@ def test_run_data_points(run_status_event, run, rule):
# DataPoints is using the correct source
assert run.status != run_status_event.status
assert dps.run.status == run.status
- assert dps.run.start_time == run.start_time.replace(tzinfo=timezone.utc).isoformat()
+ assert dps.run.start_time == run.start_time.replace(tzinfo=UTC).isoformat()
assert dps.run.start_time_formatted == datetime_formatted(run.start_time)
assert dps.run.end_time == "N/A"
assert dps.run.end_time_formatted == "N/A"
@@ -221,7 +221,7 @@ def test_run_data_points_with_end_time(run_status_event, run, rule):
run.save()
dps = DataPoints(run_status_event, rule)
# These are tested in a separate function because the run is cached in Event
- assert dps.run.end_time == run.end_time.replace(tzinfo=timezone.utc).isoformat()
+ assert dps.run.end_time == run.end_time.replace(tzinfo=UTC).isoformat()
assert dps.run.end_time_formatted == datetime_formatted(run.end_time)
@@ -260,9 +260,9 @@ def test_run_task_data_points_with_times(run_status_event, run_task, rule):
run_task.save()
dps = DataPoints(run_status_event, rule)
# These are tested in a separate function because the run task is cached in Event
- assert dps.run_task.start_time == run_task.start_time.replace(tzinfo=timezone.utc).isoformat()
+ assert dps.run_task.start_time == run_task.start_time.replace(tzinfo=UTC).isoformat()
assert dps.run_task.start_time_formatted == datetime_formatted(run_task.start_time)
- assert dps.run_task.end_time == run_task.end_time.replace(tzinfo=timezone.utc).isoformat()
+ assert dps.run_task.end_time == run_task.end_time.replace(tzinfo=UTC).isoformat()
assert dps.run_task.end_time_formatted == datetime_formatted(run_task.end_time)
diff --git a/run_manager/alerts.py b/run_manager/alerts.py
index ba707ed..aa4789a 100644
--- a/run_manager/alerts.py
+++ b/run_manager/alerts.py
@@ -1,5 +1,4 @@
import logging
-from typing import Optional
from collections.abc import Iterable
from uuid import UUID
@@ -82,8 +81,8 @@ def create_run_alert(alert_type: RunAlertType, run: Run, pipeline: Pipeline) ->
def create_instance_alert(
alert_type: InstanceAlertType,
instance: Instance,
- component: Optional[Component] = None,
- alert_components: Optional[Iterable[UUID]] = None,
+ component: Component | None = None,
+ alert_components: Iterable[UUID] | None = None,
) -> InstanceAlertEvent:
alert_level = INSTANCE_ALERT_LEVELS[alert_type]
alert_description = INSTANCE_ALERT_DESCRIPTIONS[alert_type].format(
diff --git a/run_manager/context.py b/run_manager/context.py
index 2da0c14..a881520 100644
--- a/run_manager/context.py
+++ b/run_manager/context.py
@@ -1,6 +1,5 @@
__all__ = ["RunManagerContext"]
from dataclasses import dataclass, field
-from typing import Optional
from uuid import UUID
from common.entities import Component, Instance, InstanceSet, Pipeline, Run, RunTask, Task
@@ -13,21 +12,21 @@ class RunManagerContext:
A context object to pass a state around when handling events in run manager
"""
- component: Optional[Component] = None
+ component: Component | None = None
# Keeping pipeline to keep "pre" event v1 code intact, i.e. avoid significant refactoring effort
- pipeline: Optional[Pipeline] = None
- run: Optional[Run] = None
- task: Optional[Task] = None
- run_task: Optional[RunTask] = None
+ pipeline: Pipeline | None = None
+ run: Run | None = None
+ task: Task | None = None
+ run_task: RunTask | None = None
instances: list[InstanceRef] = field(default_factory=list)
- instance_set: Optional[InstanceSet] = None
+ instance_set: InstanceSet | None = None
ended_instances: list[UUID | Instance] = field(default_factory=list)
"""List of instances that ended in this context"""
created_run: bool = False
"""Indicates if the run was created during this context"""
started_run: bool = False
"""Indicates if the run was started during this context"""
- prev_run_status: Optional[str] = None
+ prev_run_status: str | None = None
"""Previous run status before being processed by the run handler.
This is to check for unexpected run status changed"""
diff --git a/run_manager/event_handlers/component_identifier.py b/run_manager/event_handlers/component_identifier.py
index 61b98c1..30871ef 100644
--- a/run_manager/event_handlers/component_identifier.py
+++ b/run_manager/event_handlers/component_identifier.py
@@ -1,7 +1,7 @@
__all__ = ["ComponentIdentifier"]
import logging
-from typing import Optional, cast
+from typing import cast
from peewee import DoesNotExist
@@ -25,7 +25,7 @@
"""Map event component type to db model"""
-def _get_component(event: Event) -> Optional[Component]:
+def _get_component(event: Event) -> Component | None:
# v1 event can only have one (component type)_id
if component_id := event.component_id:
try:
@@ -56,7 +56,7 @@ def _get_component(event: Event) -> Optional[Component]:
return component
-def _create_component(event: Event) -> Optional[Component]:
+def _create_component(event: Event) -> Component | None:
component: Component = event.component_model.create(
key=event.component_key, name=event.component_name, tool=event.component_tool, project_id=event.project_id
)
diff --git a/run_manager/event_handlers/instance_handler.py b/run_manager/event_handlers/instance_handler.py
index 82cf862..1b98f36 100644
--- a/run_manager/event_handlers/instance_handler.py
+++ b/run_manager/event_handlers/instance_handler.py
@@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from itertools import chain
-from typing import Any, Optional, cast
+from typing import Any, cast
from collections.abc import Callable, Mapping
from uuid import UUID
@@ -34,7 +34,7 @@
LOG = logging.getLogger(__name__)
-def _find_existing_instance(journey: Journey, with_run: bool, payload_key: Optional[str]) -> Optional[Instance]:
+def _find_existing_instance(journey: Journey, with_run: bool, payload_key: str | None) -> Instance | None:
if payload_key is None:
f: Callable[[Any], bool | Any] = lambda k: k.payload_key is None
else:
@@ -234,7 +234,7 @@ def default_instance_creation(self, event: Event) -> list[InstanceRef]:
identified_instances: list[InstanceRef] = []
with_run = event.run_id is not None
for journey in event.component_journeys:
- payload_keys: list[Optional[str]] = [None]
+ payload_keys: list[str | None] = [None]
if any(rule.action is InstanceRuleAction.END_PAYLOAD for rule in journey.instance_rules):
payload_keys.extend(event.payload_keys)
for payload_key in payload_keys:
diff --git a/run_manager/event_handlers/run_handler.py b/run_manager/event_handlers/run_handler.py
index 91c829b..3f53243 100644
--- a/run_manager/event_handlers/run_handler.py
+++ b/run_manager/event_handlers/run_handler.py
@@ -1,7 +1,7 @@
__all__ = ["RunHandler"]
import logging
-from typing import Optional, cast
+from typing import cast
from peewee import DoesNotExist
@@ -102,7 +102,7 @@ def _handle_event(self, event: Event) -> None:
f"for batch-pipeline {self.context.pipeline.id}"
)
- def _get_run(self, event: Event, pipeline: Pipeline) -> Optional[Run]:
+ def _get_run(self, event: Event, pipeline: Pipeline) -> Run | None:
"""
Get an existing run instance
diff --git a/run_manager/event_handlers/run_unexpected_status_change_handler.py b/run_manager/event_handlers/run_unexpected_status_change_handler.py
index 8ed8a17..2078391 100644
--- a/run_manager/event_handlers/run_unexpected_status_change_handler.py
+++ b/run_manager/event_handlers/run_unexpected_status_change_handler.py
@@ -1,5 +1,4 @@
import logging
-from typing import Optional
from common.entities import RunAlertType, RunStatus
from common.events import EventHandlerBase
@@ -19,7 +18,7 @@
"""Map new run status to alert type"""
-def _get_alert_type(context: RunManagerContext) -> Optional[RunAlertType]:
+def _get_alert_type(context: RunManagerContext) -> RunAlertType | None:
if context.run is None:
raise ValueError("The `run` attribute for the context object must be populated with a valid Run instance")
@@ -52,7 +51,7 @@ def handle_test_outcomes(self, event: TestOutcomesEvent) -> bool:
def handle_run_status(self, event: RunStatusEvent) -> bool:
if (pipeline := self.context.pipeline) is None or (run := self.context.run) is None:
raise ValueError(
- "The context object must be populated with a valid Pipeline and Run instance " "for RunStatusEvent"
+ "The context object must be populated with a valid Pipeline and Run instance for RunStatusEvent"
)
if (alert_type := _get_alert_type(self.context)) is not None:
diff --git a/run_manager/tests/integration/conftest.py b/run_manager/tests/integration/conftest.py
index e49b6b8..b58f912 100644
--- a/run_manager/tests/integration/conftest.py
+++ b/run_manager/tests/integration/conftest.py
@@ -1,6 +1,6 @@
import uuid
from dataclasses import replace
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from unittest.mock import MagicMock
import pytest
@@ -57,7 +57,7 @@ def compare_event_data(unidentified_event, identified_event, pipeline, run, task
@pytest.fixture
def timestamp_now():
- return datetime.now(tz=timezone.utc)
+ return datetime.now(tz=UTC)
@pytest.fixture
@@ -367,7 +367,7 @@ def pipeline_end_payload_rule(journey, pipeline):
@pytest.fixture
def timestamp_now():
- return datetime.now(tz=timezone.utc)
+ return datetime.now(tz=UTC)
@pytest.fixture
diff --git a/run_manager/tests/integration/event_handlers/test_out_of_sequence_instance_handler.py b/run_manager/tests/integration/event_handlers/test_out_of_sequence_instance_handler.py
index 340ead1..8e62e85 100644
--- a/run_manager/tests/integration/event_handlers/test_out_of_sequence_instance_handler.py
+++ b/run_manager/tests/integration/event_handlers/test_out_of_sequence_instance_handler.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
import pytest
@@ -47,8 +47,8 @@ def __init__(self, project, pipelines=None, name="out-of-sequence-test"):
project=project,
instance_set=self.instance_set.id,
status=RunStatus.COMPLETED.name,
- start_time=datetime.now(tz=timezone.utc),
- end_time=datetime.now(tz=timezone.utc),
+ start_time=datetime.now(tz=UTC),
+ end_time=datetime.now(tz=UTC),
)
)
JourneyDagEdge.create(
diff --git a/run_manager/tests/integration/test_run_handler.py b/run_manager/tests/integration/test_run_handler.py
index dda4f41..e7048f2 100644
--- a/run_manager/tests/integration/test_run_handler.py
+++ b/run_manager/tests/integration/test_run_handler.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from uuid import UUID
import pytest
@@ -26,12 +26,12 @@ def test_run_handler_new_pending_metadata(run_status_event_pending, pipeline):
# Retrieve the run and make sure the expected_start_time has been updated
run = Run.get_by_id(handler.context.run.id)
- assert run.expected_start_time == datetime(2005, 3, 1, 1, 1, 1, tzinfo=timezone.utc)
+ assert run.expected_start_time == datetime(2005, 3, 1, 1, 1, 1, tzinfo=UTC)
@pytest.mark.integration
def test_run_handler_no_overwrite_expected_start_time(run_status_event_missing, pipeline):
- expected_start_time = datetime(2005, 3, 2, 2, 2, 2, tzinfo=timezone.utc)
+ expected_start_time = datetime(2005, 3, 2, 2, 2, 2, tzinfo=UTC)
run = Run.create(
id=UUID("dbed19e1-d0bb-4860-bbdf-9d768cb90764"),
pipeline=pipeline,
diff --git a/run_manager/tests/integration/test_run_manager_instance.py b/run_manager/tests/integration/test_run_manager_instance.py
index 13cf35c..e556ea0 100644
--- a/run_manager/tests/integration/test_run_manager_instance.py
+++ b/run_manager/tests/integration/test_run_manager_instance.py
@@ -1,5 +1,5 @@
import copy
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from uuid import uuid4
import pytest
@@ -257,7 +257,7 @@ def test_run_manager_dont_modify_previously_closed_instance(
):
p1 = Pipeline.create(key="pipe1", project=project)
j1 = Journey.create(name="journey1", project=project)
- instance_time = datetime.utcnow().replace(tzinfo=timezone.utc)
+ instance_time = datetime.utcnow().replace(tzinfo=UTC)
i1 = Instance.create(journey=j1, start_time=instance_time, end_time=instance_time)
InstanceRule.create(journey=j1, action=InstanceRuleAction.START, batch_pipeline=p1)
InstanceRule.create(journey=j1, action=InstanceRuleAction.END, batch_pipeline=p1)
diff --git a/run_manager/tests/integration/test_run_manager_unordered_events.py b/run_manager/tests/integration/test_run_manager_unordered_events.py
index dd1753c..c53ab6e 100644
--- a/run_manager/tests/integration/test_run_manager_unordered_events.py
+++ b/run_manager/tests/integration/test_run_manager_unordered_events.py
@@ -1,5 +1,5 @@
import copy
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from uuid import uuid4
import pytest
@@ -15,7 +15,7 @@ def test_keep_run_end_state_and_update_start_state(pipeline, kafka_consumer, kaf
Keep the run end state (end time, status) when an old message is processed.
Update the start time when older message is received.
"""
- new_time = datetime.now(tz=timezone.utc)
+ new_time = datetime.now(tz=UTC)
old_time = new_time - timedelta(minutes=5)
old_start_status_message = run_status_message
@@ -60,7 +60,7 @@ def test_reopen_run_on_newer_status(
assert run.end_time is not None
# Re-open existing run
- run_status_message.payload.event_timestamp = datetime.utcnow().replace(tzinfo=timezone.utc)
+ run_status_message.payload.event_timestamp = datetime.utcnow().replace(tzinfo=UTC)
run_status_message.payload.event_id = uuid4()
kafka_consumer.__iter__.return_value = iter((run_status_message,))
run_manager.process_events()
@@ -77,7 +77,7 @@ def test_keep_task_end_state_and_update_start_state(kafka_consumer, kafka_produc
Keep the task end state (end time, status) when an old message is processed.
Update the start time when older message is received.
"""
- new_time = datetime.now(tz=timezone.utc)
+ new_time = datetime.now(tz=UTC)
old_time = new_time - timedelta(minutes=5)
old_status_message = task_status_message
old_status_message.payload.event_timestamp = old_time
@@ -122,7 +122,7 @@ def test_reopen_task_on_newer_status(pipeline, kafka_consumer, kafka_producer, r
assert run.end_time is not None
# Re-open existing run
- run_status_message.payload.event_timestamp = datetime.utcnow().replace(tzinfo=timezone.utc)
+ run_status_message.payload.event_timestamp = datetime.utcnow().replace(tzinfo=UTC)
run_status_message.payload.event_id = uuid4()
kafka_consumer.__iter__.return_value = iter((run_status_message,))
run_manager.process_events()
diff --git a/run_manager/tests/integration/test_scheduler_events.py b/run_manager/tests/integration/test_scheduler_events.py
index 891ad71..1d5652a 100644
--- a/run_manager/tests/integration/test_scheduler_events.py
+++ b/run_manager/tests/integration/test_scheduler_events.py
@@ -1,5 +1,5 @@
from collections import Counter
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, UTC
from uuid import uuid4
import pytest
@@ -159,7 +159,7 @@ def test_scheduler_run_should_start_with_pending(
):
"""Check that pending runs are marked missing when start schedule occurs"""
instance = Instance.create(journey=journey)
- expected_start_time = datetime(2005, 3, 2, 2, 2, 2, tzinfo=timezone.utc)
+ expected_start_time = datetime(2005, 3, 2, 2, 2, 2, tzinfo=UTC)
run = Run.create(
status=RunStatus.PENDING.name,
start_time=None,
diff --git a/scheduler/agent_check.py b/scheduler/agent_check.py
index 5ba0b5c..55058d3 100644
--- a/scheduler/agent_check.py
+++ b/scheduler/agent_check.py
@@ -1,5 +1,5 @@
import logging
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timedelta, UTC
from apscheduler.triggers.interval import IntervalTrigger
@@ -14,7 +14,7 @@
def _get_agent_status(check_interval_seconds: int, latest_heartbeat: datetime) -> AgentStatus:
- lateness = (datetime.now(tz=timezone.utc) - latest_heartbeat).total_seconds()
+ lateness = (datetime.now(tz=UTC) - latest_heartbeat).total_seconds()
if lateness > check_interval_seconds * settings.AGENT_STATUS_CHECK_OFFLINE_FACTOR:
return AgentStatus.OFFLINE
elif lateness > check_interval_seconds * settings.AGENT_STATUS_CHECK_UNHEALTHY_FACTOR:
@@ -65,7 +65,7 @@ def _create_and_add_job(self, schedule: AgentCheckSchedule) -> None:
)
def _check_agents_are_online(self, project: Project) -> None:
- check_threshold = datetime.now(tz=timezone.utc) - timedelta(seconds=project.agent_check_interval)
+ check_threshold = datetime.now(tz=UTC) - timedelta(seconds=project.agent_check_interval)
for agent in Agent.select().where(
Agent.project == project.id,
Agent.latest_heartbeat < check_threshold,
diff --git a/scheduler/component_expectations.py b/scheduler/component_expectations.py
index 192ed50..c733296 100644
--- a/scheduler/component_expectations.py
+++ b/scheduler/component_expectations.py
@@ -1,6 +1,5 @@
import logging
from datetime import datetime, timedelta
-from typing import Optional
from uuid import UUID
from apscheduler.triggers.cron import CronTrigger
@@ -29,7 +28,7 @@ def _produce_event(
component_id: UUID,
schedule_type: ScheduleType,
is_margin: bool,
- margin: Optional[int] = None,
+ margin: int | None = None,
) -> None:
"""Create and forward corresponding scheduler event(s) to the run manager"""
if is_margin:
diff --git a/scheduler/schedule_source.py b/scheduler/schedule_source.py
index 7d3d22e..aac4b95 100644
--- a/scheduler/schedule_source.py
+++ b/scheduler/schedule_source.py
@@ -3,7 +3,7 @@
import logging
from collections import defaultdict
from datetime import datetime
-from typing import Any, Optional, Protocol, Generic, TypeVar
+from typing import Any, Protocol, TypeVar
from collections.abc import Callable
from apscheduler.executors.pool import ThreadPoolExecutor
@@ -41,7 +41,7 @@ def id(self) -> str: ...
ST = TypeVar("ST", bound=Schedule)
-class ScheduleSource(Generic[ST]):
+class ScheduleSource[ST: Schedule]:
"""Concentrates all features and configurations around a specific source of schedules."""
source_name: str
@@ -65,7 +65,7 @@ def jobstore_name(self) -> str:
def executor_name(self) -> str:
return self.source_name
- def add_job(self, func: Callable, job_id: str, trigger: BaseTrigger, kwargs: Optional[dict[str, Any]]) -> Job:
+ def add_job(self, func: Callable, job_id: str, trigger: BaseTrigger, kwargs: dict[str, Any] | None) -> Job:
return self.scheduler.add_job(
func,
id=job_id,
diff --git a/scheduler/tests/integration/test_agent_scheduler.py b/scheduler/tests/integration/test_agent_scheduler.py
index 0bc4d31..114ea84 100644
--- a/scheduler/tests/integration/test_agent_scheduler.py
+++ b/scheduler/tests/integration/test_agent_scheduler.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timezone, timedelta, UTC
from unittest.mock import patch
import pytest
@@ -19,7 +19,7 @@ def agents(project):
tool="tool",
version="vTest",
status=AgentStatus.ONLINE,
- latest_heartbeat=datetime.now(tz=timezone.utc) - timedelta(seconds=elapsed_time),
+ latest_heartbeat=datetime.now(tz=UTC) - timedelta(seconds=elapsed_time),
)
for elapsed_time in (
25, # Below the checking threshold
diff --git a/scheduler/tests/integration/test_schedule_source.py b/scheduler/tests/integration/test_schedule_source.py
index 67d50ad..e713686 100644
--- a/scheduler/tests/integration/test_schedule_source.py
+++ b/scheduler/tests/integration/test_schedule_source.py
@@ -1,5 +1,5 @@
import threading
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from unittest.mock import Mock, patch
import pytest
@@ -29,7 +29,7 @@ def get_next_fire_time(self, previous_fire_time, now):
if previous_fire_time:
return None
else:
- return datetime.now(tz=timezone.utc)
+ return datetime.now(tz=UTC)
class TestScheduleSource(ScheduleSource):
diff --git a/scheduler/tests/unit/conftest.py b/scheduler/tests/unit/conftest.py
index 1dae93d..05fab03 100644
--- a/scheduler/tests/unit/conftest.py
+++ b/scheduler/tests/unit/conftest.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import datetime, timezone, UTC
from unittest.mock import Mock, patch
from uuid import uuid4
@@ -42,7 +42,7 @@ def agent_source(scheduler, event_producer_mock):
@pytest.fixture
def job_kwargs():
return {
- "run_time": datetime.now(tz=timezone.utc),
+ "run_time": datetime.now(tz=UTC),
"schedule_type": ScheduleType.BATCH_END_TIME,
"schedule_id": str(uuid4()),
"component_id": str(uuid4()),
@@ -64,4 +64,4 @@ def schedule_data():
@pytest.fixture
def run_time():
- return datetime.now(tz=timezone.utc)
+ return datetime.now(tz=UTC)
diff --git a/scheduler/tests/unit/test_agent_scheduler.py b/scheduler/tests/unit/test_agent_scheduler.py
index 640d509..20aa3fa 100644
--- a/scheduler/tests/unit/test_agent_scheduler.py
+++ b/scheduler/tests/unit/test_agent_scheduler.py
@@ -1,5 +1,5 @@
import uuid
-from datetime import datetime, timezone, timedelta
+from datetime import datetime, timezone, timedelta, UTC
from unittest.mock import patch
import pytest
@@ -43,5 +43,5 @@ def test_add_job(agent_source):
],
)
def test_get_agent_status(elapsed_time, expected_status):
- latest_heartbeat = datetime.now(tz=timezone.utc) - timedelta(seconds=elapsed_time)
+ latest_heartbeat = datetime.now(tz=UTC) - timedelta(seconds=elapsed_time)
assert _get_agent_status(CHECK_INTERVAL, latest_heartbeat) == expected_status
diff --git a/scripts/invocations/deploy.py b/scripts/invocations/deploy.py
index 9b48a50..4ac5d37 100644
--- a/scripts/invocations/deploy.py
+++ b/scripts/invocations/deploy.py
@@ -315,7 +315,7 @@ def build(
if ui:
ctx.run(
- f"docker build . {args_str} " f"-t 'observability-ui:{tag}' -f ./deploy/docker/observability-ui.dockerfile",
+ f"docker build . {args_str} -t 'observability-ui:{tag}' -f ./deploy/docker/observability-ui.dockerfile",
env=env,
)
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 120c85e..0000000
--- a/setup.cfg
+++ /dev/null
@@ -1,2 +0,0 @@
-# This file is here to support editable installs (pip install -e .)
-# https://github.com/pypa/setuptools/issues/2816
diff --git a/testlib/fixtures/entities.py b/testlib/fixtures/entities.py
index 4acff83..c05ec17 100644
--- a/testlib/fixtures/entities.py
+++ b/testlib/fixtures/entities.py
@@ -61,7 +61,7 @@
import base64
import uuid
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, UTC
from decimal import Decimal
from unittest.mock import patch
from uuid import UUID
@@ -445,11 +445,11 @@ def batch_end_schedule(pipeline):
)
-ALERT_EXPECTED_START_DT: datetime = datetime(2005, 8, 13, 4, 2, 4, 2, tzinfo=timezone.utc)
+ALERT_EXPECTED_START_DT: datetime = datetime(2005, 8, 13, 4, 2, 4, 2, tzinfo=UTC)
"""The datetime of the expected_start_time for alert fixtures."""
-ALERT_EXPECTED_END_DT: datetime = datetime(2005, 8, 13, 8, 4, 8, 1, tzinfo=timezone.utc)
+ALERT_EXPECTED_END_DT: datetime = datetime(2005, 8, 13, 8, 4, 8, 1, tzinfo=UTC)
"""The datetime of the expected_end_time for alert fixtures."""
@@ -496,8 +496,8 @@ def test_outcome(component, run, task):
component=component,
description="Testy McTestface",
dimensions=["a", "b", "c"],
- start_time=datetime(2000, 3, 9, 12, 11, 10, tzinfo=timezone.utc),
- end_time=datetime(2000, 3, 9, 13, 12, 11, tzinfo=timezone.utc),
+ start_time=datetime(2000, 3, 9, 12, 11, 10, tzinfo=UTC),
+ end_time=datetime(2000, 3, 9, 13, 12, 11, tzinfo=UTC),
name="test-outcome-1",
external_url="https://fake.testy/do-not-go-here",
key="test-outcome-key-1",
@@ -563,10 +563,10 @@ def testgen_dataset_component(test_db, dataset):
yield dataset_component
-AGENT_LATEST_EVENT = datetime(2023, 10, 17, 12, 33, 19, 154295, tzinfo=timezone.utc)
+AGENT_LATEST_EVENT = datetime(2023, 10, 17, 12, 33, 19, 154295, tzinfo=UTC)
"""Default timestamp for latest event received by an agent."""
-AGENT_LATEST_HEARTBEAT = datetime(2023, 10, 17, 12, 42, 42, 424242, tzinfo=timezone.utc)
+AGENT_LATEST_HEARTBEAT = datetime(2023, 10, 17, 12, 42, 42, 424242, tzinfo=UTC)
"""Default timestamp for the lasttime an agent checked-in."""
@@ -585,7 +585,7 @@ def agent_1(test_db, project):
@pytest.fixture()
def agent_2(test_db, project):
- dt_1 = datetime.now(timezone.utc)
+ dt_1 = datetime.now(UTC)
dt_2 = dt_1 + timedelta(seconds=42)
return Agent.create(
project=project,
@@ -603,8 +603,8 @@ def event_entity(test_db, pipeline, task, run, run_task, instance_instance_set):
return EventEntity.create(
version=EventVersion.V2,
type=ApiEventType.BATCH_PIPELINE_STATUS,
- created_timestamp=datetime(2024, 1, 20, 10, 0, 0, tzinfo=timezone.utc),
- timestamp=datetime(2024, 1, 20, 9, 59, 0, tzinfo=timezone.utc),
+ created_timestamp=datetime(2024, 1, 20, 10, 0, 0, tzinfo=UTC),
+ timestamp=datetime(2024, 1, 20, 9, 59, 0, tzinfo=UTC),
project=pipeline.project_id,
component=pipeline,
task=task,
@@ -620,8 +620,8 @@ def event_entity_2(test_db, dataset):
return EventEntity.create(
version=EventVersion.V2,
type=ApiEventType.DATASET_OPERATION,
- created_timestamp=datetime(2024, 1, 20, 9, 55, 0, tzinfo=timezone.utc),
- timestamp=datetime(2024, 1, 20, 9, 55, 0, tzinfo=timezone.utc),
+ created_timestamp=datetime(2024, 1, 20, 9, 55, 0, tzinfo=UTC),
+ timestamp=datetime(2024, 1, 20, 9, 55, 0, tzinfo=UTC),
project=dataset.project_id,
component=dataset,
v2_payload={},
diff --git a/testlib/fixtures/v1_events.py b/testlib/fixtures/v1_events.py
index df9ab4e..060e8f8 100644
--- a/testlib/fixtures/v1_events.py
+++ b/testlib/fixtures/v1_events.py
@@ -32,7 +32,7 @@
]
-from datetime import datetime, timezone
+from datetime import datetime, UTC
from decimal import Decimal
from uuid import UUID
@@ -220,7 +220,7 @@ def FAILED_run_status_event_data(FAILED_run_status_event):
@pytest.fixture
def test_outcome_item_data(metadata_model) -> dict:
- timestamp = datetime.now(timezone.utc).isoformat()
+ timestamp = datetime.now(UTC).isoformat()
yield {
"name": "My_test_name",
"status": TestStatuses.PASSED.name,
diff --git a/testlib/fixtures/v2_events.py b/testlib/fixtures/v2_events.py
index 800fa87..371087f 100644
--- a/testlib/fixtures/v2_events.py
+++ b/testlib/fixtures/v2_events.py
@@ -19,7 +19,7 @@
"test_outcomes_testgen_event_v2",
]
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, UTC
from decimal import Decimal
from uuid import UUID
@@ -84,7 +84,7 @@
TEST_OUTCOMES_EVENT_ID: UUID = UUID("83af84bc-318e-4dda-9d40-6c7c8bacd992")
"""ID for EventV2 LOG event."""
-EVENT_TIMESTAMP: datetime = datetime(2023, 5, 10, 1, 1, 1, tzinfo=timezone.utc)
+EVENT_TIMESTAMP: datetime = datetime(2023, 5, 10, 1, 1, 1, tzinfo=UTC)
"""Default timestamp for events."""
CREATED_TIMESTAMP: datetime = EVENT_TIMESTAMP + timedelta(minutes=3, seconds=1)
@@ -191,7 +191,7 @@ def run_alert() -> RunAlert:
@pytest.fixture
def test_outcome_item(metadata_model) -> TestOutcomeItem:
- timestamp = datetime.now(timezone.utc)
+ timestamp = datetime.now(UTC)
return TestOutcomeItem(
name="My_test_name",
status=TestStatus.PASSED,
diff --git a/testlib/peewee.py b/testlib/peewee.py
index 2240fca..ace6dca 100644
--- a/testlib/peewee.py
+++ b/testlib/peewee.py
@@ -1,9 +1,10 @@
import contextlib
+from typing import Any
from unittest.mock import Mock, patch
@contextlib.contextmanager
-def patch_select(target: str, **kwargs):
+def patch_select(target: str, **kwargs: Any) -> Any:
with patch(target=f"{target}.select") as select_mock:
select_mock.return_value = select_mock
for attr in ("join", "left_outer_join", "switch", "order_by", "where"):