From 9b3f7235d053b3031c087b7fb0d3f74355a97faf Mon Sep 17 00:00:00 2001 From: Brant Watson Date: Thu, 19 Jun 2025 11:28:35 -0500 Subject: [PATCH] Package updates & linter/typing fixes - Removed empty config.cfg hack (it was an ancient workaround for a problem that no longer exists) - Updated ruff formatter config to format as python3.12 (the version of python used by the project) - Fixed some outdated readme text - Upgraded mypy version - Fixed type hint complaints from mypy - Updated pytest & pytest plugin versions (fixes several noisy deprecation warnings) --- README.md | 6 +- agent_api/config/defaults.py | 5 +- agent_api/config/local.py | 4 +- agent_api/config/minikube.py | 4 +- agent_api/endpoints/v1/heartbeat.py | 8 +- .../v1_endpoints/test_heartbeat.py | 12 +-- cli/base.py | 4 +- cli/entry_points/database_schema.py | 4 +- cli/entry_points/gen_events.py | 4 +- cli/entry_points/graph_schema.py | 3 +- cli/lib.py | 3 +- common/actions/action.py | 12 +-- common/actions/action_factory.py | 3 +- common/actions/data_points.py | 84 +++++++++---------- common/actions/send_email_action.py | 5 +- common/actions/webhook_action.py | 6 +- common/api/base_view.py | 10 +-- common/api/flask_ext/authentication/common.py | 5 +- .../flask_ext/authentication/jwt_plugin.py | 12 +-- common/api/flask_ext/base_extension.py | 3 +- common/api/flask_ext/config.py | 3 +- common/api/flask_ext/cors.py | 5 +- common/api/request_parsing.py | 8 +- common/api/search_view.py | 4 +- common/apscheduler_extensions.py | 7 +- common/auth/keys/service_key.py | 10 +-- common/datetime_utils.py | 6 +- common/decorators.py | 8 +- common/entities/alert.py | 9 +- common/entities/authentication.py | 6 +- common/entities/base_entity.py | 4 +- common/entities/company.py | 3 +- common/entities/component_meta.py | 4 +- common/entities/dataset_operation.py | 4 +- common/entities/upcoming_instance.py | 5 +- common/entity_services/component_service.py | 3 +- .../entity_services/helpers/filter_rules.py | 30 ++++--- common/entity_services/helpers/list_rules.py | 6 +- common/entity_services/instance_service.py | 11 ++- common/entity_services/journey_service.py | 3 +- common/entity_services/pipeline_service.py | 3 +- common/entity_services/project_service.py | 8 +- .../entity_services/test_outcome_service.py | 9 +- .../upcoming_instance_service.py | 12 +-- common/entity_services/user_service.py | 6 +- common/events/base.py | 23 +++-- common/events/converters.py | 58 +++++++------ common/events/internal/alert.py | 3 +- common/events/internal/scheduled_event.py | 3 +- common/events/v1/dataset_operation_event.py | 3 +- common/events/v1/event.py | 55 ++++++------ common/events/v1/event_schemas.py | 6 +- common/events/v1/test_outcomes_event.py | 55 ++++++------ common/events/v2/base.py | 9 +- common/events/v2/component_data.py | 28 +++---- common/events/v2/dataset_operation.py | 4 +- common/events/v2/test_outcomes.py | 45 +++++----- common/events/v2/testgen.py | 18 ++-- common/kafka/consumer.py | 8 +- common/kafka/message.py | 6 +- common/kafka/producer.py | 8 +- common/kafka/topic.py | 6 +- common/logging/json_logging.py | 4 +- common/messagepack.py | 10 +-- common/model.py | 3 +- common/peewee_extensions/fields.py | 38 ++++----- common/peewee_extensions/fixtures.py | 4 +- common/predicate_engine/compilers/utils.py | 2 +- common/predicate_engine/query.py | 32 +++---- common/schemas/fields/cron_expr_str.py | 4 +- common/schemas/fields/enum_str.py | 4 +- common/schemas/fields/normalized_str.py | 6 +- common/schemas/fields/zoneinfo.py | 4 +- common/schemas/filter_schemas.py | 14 ++-- common/schemas/validators/regexp.py | 3 +- .../tests/integration/entities/test_alerts.py | 18 ++-- .../tests/integration/entities/test_runs.py | 4 +- .../integration/entity_services/conftest.py | 6 +- .../entity_services/test_project_service.py | 18 ++-- .../test_upcoming_instance_services.py | 6 +- .../test_apscheduler_extensions.py | 20 ++--- .../tests/unit/actions/test_webhook_action.py | 4 +- .../tests/unit/entities/test_journey_dag.py | 2 +- .../helpers/test_filter_rules.py | 10 +-- .../tests/unit/events/v1/test_base_events.py | 6 +- .../tests/unit/events/v1/test_testoutcomes.py | 6 +- .../unit/events/v2/test_test_outcomes.py | 12 +-- .../tests/unit/flask_ext/test_jwt_plugin.py | 8 +- .../test_peewee_extensions.py | 4 +- .../tests/unit/predicate_engine/assertions.py | 4 +- .../tests/unit/predicate_engine/conftest.py | 4 +- .../predicate_engine/test_predicate_engine.py | 4 +- .../tests/unit/test_apscheduler_extensions.py | 12 +-- common/tests/unit/test_datetime_utils.py | 14 ++-- common/tests/unit/test_messagepack.py | 4 +- deploy/search_view_plugin.py | 6 +- deploy/subcomponent_plugin.py | 10 +-- event_api/config/defaults.py | 5 +- event_api/config/local.py | 4 +- event_api/config/minikube.py | 4 +- event_api/endpoints/v1/event_view.py | 7 +- .../integration/v1_endpoints/conftest.py | 4 +- .../integration/v2_endpoints/conftest.py | 4 +- observability_api/config/defaults.py | 5 +- observability_api/config/local.py | 4 +- observability_api/config/minikube.py | 4 +- observability_api/config/test.py | 4 +- observability_api/endpoints/component_view.py | 4 +- observability_api/endpoints/v1/journeys.py | 3 +- .../endpoints/v1/project_settings.py | 4 +- observability_api/schemas/event_schemas.py | 3 +- .../integration/v1_endpoints/conftest.py | 10 +-- .../integration/v1_endpoints/test_alerts.py | 4 +- .../v1_endpoints/test_instances.py | 28 +++---- .../integration/v1_endpoints/test_jwt_auth.py | 14 ++-- .../integration/v1_endpoints/test_runs.py | 30 +++---- .../v1_endpoints/test_service_account_keys.py | 6 +- .../v1_endpoints/test_upcoming_instances.py | 16 ++-- pyproject.toml | 24 ++++-- rules_engine/journey_rules.py | 18 ++-- rules_engine/rule_data.py | 3 +- rules_engine/tests/integration/conftest.py | 6 +- rules_engine/tests/unit/test_data_points.py | 10 +-- run_manager/alerts.py | 5 +- run_manager/context.py | 15 ++-- .../event_handlers/component_identifier.py | 6 +- .../event_handlers/instance_handler.py | 6 +- run_manager/event_handlers/run_handler.py | 4 +- .../run_unexpected_status_change_handler.py | 5 +- run_manager/tests/integration/conftest.py | 6 +- .../test_out_of_sequence_instance_handler.py | 6 +- .../tests/integration/test_run_handler.py | 6 +- .../integration/test_run_manager_instance.py | 4 +- .../test_run_manager_unordered_events.py | 10 +-- .../integration/test_scheduler_events.py | 4 +- scheduler/agent_check.py | 6 +- scheduler/component_expectations.py | 3 +- scheduler/schedule_source.py | 6 +- .../tests/integration/test_agent_scheduler.py | 4 +- .../tests/integration/test_schedule_source.py | 4 +- scheduler/tests/unit/conftest.py | 6 +- scheduler/tests/unit/test_agent_scheduler.py | 4 +- scripts/invocations/deploy.py | 2 +- setup.cfg | 2 - testlib/fixtures/entities.py | 24 +++--- testlib/fixtures/v1_events.py | 4 +- testlib/fixtures/v2_events.py | 6 +- testlib/peewee.py | 3 +- 148 files changed, 675 insertions(+), 726 deletions(-) delete mode 100644 setup.cfg diff --git a/README.md b/README.md index 1d63c26..360006b 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ *

DataOps Observability is part of DataKitchen's Open Source Data Observability. DataOps Observability monitors every data journey from data source to customer value, from any team development environment into production, across every tool, team, environment, and customer so that problems are detected, localized, and understood immediately.

* -[![DatKitchen Open Source Data Observability](https://datakitchen.io/wp-content/uploads/2024/04/both-products.png)](https://datakitchen.storylane.io/share/g01ss0plyamz) +[![DataKitchen Open Source Data Observability](https://datakitchen.io/wp-content/uploads/2024/04/both-products.png)](https://datakitchen.storylane.io/share/g01ss0plyamz) [Interactive Product Tour](https://datakitchen.storylane.io/share/g01ss0plyamz) ## Developer Setup @@ -100,9 +100,7 @@ We enforce the use of certain linting tools. To not get caught by the build-syst The following hooks are enabled in pre-commit: -- `black`: The black formatter is enforced on the project. We use a basic configuration. Ideally this should solve any and all -formatting questions we might encounter. -- `isort`: the isort import-sorter is enforced on the project. We use it with the `black` profile. +- `ruff`: Handles code formatting, import sorting, and linting To enable pre-commit from within your virtual environment, simply run: diff --git a/agent_api/config/defaults.py b/agent_api/config/defaults.py index d02c934..cb6cd3e 100644 --- a/agent_api/config/defaults.py +++ b/agent_api/config/defaults.py @@ -5,13 +5,12 @@ """ import os -from typing import Optional # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values from common.entities import Service -PROPAGATE_EXCEPTIONS: Optional[bool] = None -SERVER_NAME: Optional[str] = os.environ.get("AGENT_API_HOSTNAME") # Use flask defaults if none set +PROPAGATE_EXCEPTIONS: bool | None = None +SERVER_NAME: str | None = os.environ.get("AGENT_API_HOSTNAME") # Use flask defaults if none set USE_X_SENDFILE: bool = False # If we serve files enable this in production settings when webserver support configured # Application settings diff --git a/agent_api/config/local.py b/agent_api/config/local.py index aaf5847..f063607 100644 --- a/agent_api/config/local.py +++ b/agent_api/config/local.py @@ -1,5 +1,3 @@ -from typing import Optional - # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values -PROPAGATE_EXCEPTIONS: Optional[bool] = True +PROPAGATE_EXCEPTIONS: bool | None = True SECRET_KEY: str = "NOT_VERY_SECRET" diff --git a/agent_api/config/minikube.py b/agent_api/config/minikube.py index e5ed0f8..6c72c79 100644 --- a/agent_api/config/minikube.py +++ b/agent_api/config/minikube.py @@ -1,5 +1,3 @@ -from typing import Optional - # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values -TESTING: Optional[bool] = True +TESTING: bool | None = True SECRET_KEY: str = "NOT_VERY_SECRET" diff --git a/agent_api/endpoints/v1/heartbeat.py b/agent_api/endpoints/v1/heartbeat.py index 451bd56..d01886e 100644 --- a/agent_api/endpoints/v1/heartbeat.py +++ b/agent_api/endpoints/v1/heartbeat.py @@ -1,7 +1,7 @@ import logging -from datetime import datetime, timezone +from datetime import datetime, UTC from http import HTTPStatus -from typing import Optional, Union, cast +from typing import Union, cast from uuid import UUID from flask import Response, g, make_response @@ -23,7 +23,7 @@ def _update_or_create( version: str, project_id: Union[str, UUID], latest_heartbeat: datetime, - latest_event_timestamp: Optional[datetime], + latest_event_timestamp: datetime | None, ) -> None: try: agent = Agent.select().where(Agent.key == key, Agent.tool == tool, Agent.project_id == project_id).get() @@ -57,7 +57,7 @@ class Heartbeat(BaseView): def post(self) -> Response: data = self.parse_body(schema=HeartbeatSchema()) - data["latest_heartbeat"] = datetime.now(tz=timezone.utc) + data["latest_heartbeat"] = datetime.now(tz=UTC) data["project_id"] = g.project.id _update_or_create(**data) return make_response("", HTTPStatus.NO_CONTENT) diff --git a/agent_api/tests/integration/v1_endpoints/test_heartbeat.py b/agent_api/tests/integration/v1_endpoints/test_heartbeat.py index ca2cfa5..0c6d08d 100644 --- a/agent_api/tests/integration/v1_endpoints/test_heartbeat.py +++ b/agent_api/tests/integration/v1_endpoints/test_heartbeat.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from http import HTTPStatus import pytest @@ -9,7 +9,7 @@ @pytest.mark.integration def test_agent_heartbeat(client, database_ctx, headers): - last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=timezone.utc) + last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=UTC) data = { "key": "test-key", "tool": "test-tool", @@ -35,7 +35,7 @@ def test_agent_heartbeat_no_event_timestamp(client, database_ctx, headers): @pytest.mark.integration def test_agent_heartbeat_update(client, database_ctx, headers): - last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=timezone.utc) + last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=UTC) data = { "key": "test-key", "tool": "test-tool", @@ -47,7 +47,7 @@ def test_agent_heartbeat_update(client, database_ctx, headers): assert HTTPStatus.NO_CONTENT == response_1.status_code, response_1.json # The latest_event_timestamp should be older than "now" - now = datetime.now(timezone.utc) + now = datetime.now(UTC) agent_1 = Agent.select().get() assert agent_1.latest_heartbeat < now assert agent_1.status == AgentStatus.ONLINE @@ -62,7 +62,7 @@ def test_agent_heartbeat_update(client, database_ctx, headers): @pytest.mark.integration def test_agent_heartbeat_existing_update(client, database_ctx, headers): - last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=timezone.utc) + last_event_timestamp = datetime(2023, 10, 20, 4, 42, 42, tzinfo=UTC) data_1 = { "key": "test-key", "tool": "test-tool", @@ -79,7 +79,7 @@ def test_agent_heartbeat_existing_update(client, database_ctx, headers): data_2 = data_1.copy() data_2["version"] = "12.0.3" - data_2["latest_event_timestamp"] = datetime(2023, 10, 20, 4, 44, 44, tzinfo=timezone.utc).isoformat() + data_2["latest_event_timestamp"] = datetime(2023, 10, 20, 4, 44, 44, tzinfo=UTC).isoformat() response_2 = client.post("/agent/v1/heartbeat", json=data_2, headers=headers) assert HTTPStatus.NO_CONTENT == response_2.status_code, response_2.json diff --git a/cli/base.py b/cli/base.py index 7d28942..889a187 100644 --- a/cli/base.py +++ b/cli/base.py @@ -4,7 +4,7 @@ from argparse import ArgumentParser from logging.config import dictConfig from pathlib import Path -from typing import Any, Optional +from typing import Any from collections.abc import Callable from log_color import ColorFormatter, ColorStripper @@ -80,7 +80,7 @@ def __init__(self, **kwargs: Any) -> None: LOG.info("#g<\u2714> Established #c<%s> connection to #c<%s>", DB.obj.__class__.__name__, DB.obj.database) -def logging_init(*, level: str, logfile: Optional[str] = None) -> None: +def logging_init(*, level: str, logfile: str | None = None) -> None: """Given the log level and an optional logging file location, configure all logging.""" # Don't bother with a file handler if we're not logging to a file handlers = ["console", "filehandler"] if logfile else ["console"] diff --git a/cli/entry_points/database_schema.py b/cli/entry_points/database_schema.py index fa2376f..7a2c02a 100644 --- a/cli/entry_points/database_schema.py +++ b/cli/entry_points/database_schema.py @@ -1,6 +1,6 @@ import re from argparse import ArgumentParser -from typing import Any, Optional +from typing import Any from re import Pattern from collections.abc import Iterable @@ -19,7 +19,7 @@ class MysqlPrintDatabase(MySQLDatabase): def __init__(self) -> None: super().__init__("") - def execute_sql(self, sql: str, params: Optional[Iterable[Any]] = None, commit: Optional[bool] = None) -> None: + def execute_sql(self, sql: str, params: Iterable[Any] | None = None, commit: bool | None = None) -> None: if params: raise Exception(f"Params are not expected to be needed to run DDL SQL, but found {params}") if match := self._create_table_re.match(sql): diff --git a/cli/entry_points/gen_events.py b/cli/entry_points/gen_events.py index 955b54a..85f0d20 100644 --- a/cli/entry_points/gen_events.py +++ b/cli/entry_points/gen_events.py @@ -7,7 +7,7 @@ import time from argparse import Action, ArgumentParser, Namespace from datetime import datetime -from typing import Any, Optional, Union +from typing import Any, Union from collections.abc import Sequence from requests_extensions import get_session @@ -90,7 +90,7 @@ def __call__( parser: ArgumentParser, namespace: Namespace, values: Union[str, Sequence[Any], None], - option_string: Optional[str] = None, + option_string: str | None = None, ) -> None: event_data = {} remove_fields = [] diff --git a/cli/entry_points/graph_schema.py b/cli/entry_points/graph_schema.py index c433fe9..e44d6c6 100644 --- a/cli/entry_points/graph_schema.py +++ b/cli/entry_points/graph_schema.py @@ -2,6 +2,7 @@ import sys from argparse import ArgumentParser from pathlib import Path +from typing import Any from jinja2 import Environment, FileSystemLoader from peewee import Field, ForeignKeyField, ManyToManyField, Model @@ -54,7 +55,7 @@ def subcmd_entry_point(self) -> None: dot_parts = [head.render({})] # Initial context/config - model_context = [] + model_context: list[dict[str, Any]] = [] LOG.info("#m") for name, model in model_map.items(): diff --git a/cli/lib.py b/cli/lib.py index 80ed768..fd79dcb 100644 --- a/cli/lib.py +++ b/cli/lib.py @@ -2,7 +2,6 @@ import re import textwrap from argparse import ArgumentParser, ArgumentTypeError -from typing import Optional from uuid import UUID from log_color.colors import ColorStr @@ -21,7 +20,7 @@ def uuid_type(arg: str) -> UUID: def slice_type(arg: str) -> slice: """Convert an argument to a slice; for simplicity, disallow negative slice values and steps.""" - def _int_or_none(val: str) -> Optional[int]: + def _int_or_none(val: str) -> int | None: if not val: return None else: diff --git a/common/actions/action.py b/common/actions/action.py index 7eea0a7..18e062b 100644 --- a/common/actions/action.py +++ b/common/actions/action.py @@ -8,7 +8,7 @@ ] import logging -from typing import Any, NamedTuple, Optional +from typing import Any, NamedTuple from uuid import UUID from common.entities import Action, Rule @@ -35,15 +35,15 @@ class InvalidActionTemplate(ActionException): class ActionResult(NamedTuple): result: bool - response: Optional[dict] - exception: Optional[Exception] + response: dict | None + exception: Exception | None class BaseAction: required_arguments: set = set() requires_action_template: bool = False - def __init__(self, action_template: Optional[Action], override_arguments: dict) -> None: + def __init__(self, action_template: Action | None, override_arguments: dict) -> None: if self.requires_action_template and not action_template: raise ActionTemplateRequired(f"'{self.__class__.__name__}' requires an action template to be set") @@ -70,7 +70,7 @@ def _validate_args(self) -> None: if missing_args: raise ValueError(f"Required arguments {missing_args} missing for {self.__class__.__name__}") - def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> ActionResult: + def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: UUID | None) -> ActionResult: raise NotImplementedError("Base Action cannot be executed") def _store_action_result(self, action_result: ActionResult) -> None: @@ -88,7 +88,7 @@ def _store_action_result(self, action_result: ActionResult) -> None: exc_info=action_result.exception, ) - def execute(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> bool: + def execute(self, event: EVENT_TYPE, rule: Rule, journey_id: UUID | None) -> bool: action_result = self._run(event, rule, journey_id) self._store_action_result(action_result) return action_result.result diff --git a/common/actions/action_factory.py b/common/actions/action_factory.py index 9aabcc1..b37a1c1 100644 --- a/common/actions/action_factory.py +++ b/common/actions/action_factory.py @@ -1,6 +1,5 @@ __all__ = ["ACTION_CLASS_MAP", "action_factory"] -from typing import Optional from common.entities import Action @@ -11,7 +10,7 @@ ACTION_CLASS_MAP: dict[str, type[BaseAction]] = {"CALL_WEBHOOK": WebhookAction, "SEND_EMAIL": SendEmailAction} -def action_factory(implementation: str, action_args: dict, template: Optional[Action]) -> BaseAction: +def action_factory(implementation: str, action_args: dict, template: Action | None) -> BaseAction: try: action_class = ACTION_CLASS_MAP[implementation] except KeyError as ke: diff --git a/common/actions/data_points.py b/common/actions/data_points.py index 6d5e4ac..289b013 100644 --- a/common/actions/data_points.py +++ b/common/actions/data_points.py @@ -1,7 +1,7 @@ import logging from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional, cast +from typing import Any, cast from collections.abc import Callable, Iterator from uuid import UUID @@ -102,7 +102,7 @@ def __init__(self, event: EVENT_TYPE): "name": self._name, } - def _id(self) -> Optional[UUID]: + def _id(self) -> UUID | None: return self.event.project_id def _name(self) -> str: @@ -120,16 +120,16 @@ def __init__(self, event: Event): "type": self._type, } - def _id(self) -> Optional[UUID]: + def _id(self) -> UUID | None: return self.event.component_id - def _key(self) -> Optional[str]: + def _key(self) -> str | None: return self.event.component_key - def _name(self) -> Optional[str]: + def _name(self) -> str | None: return self.event.component.display_name - def _type(self) -> Optional[str]: + def _type(self) -> str | None: return self.event.component_type.name @@ -142,13 +142,13 @@ def __init__(self, event: Event): "name": self._name, } - def _id(self) -> Optional[UUID]: + def _id(self) -> UUID | None: return self.event.pipeline_id - def _key(self) -> Optional[str]: + def _key(self) -> str | None: return self.event.pipeline_key - def _name(self) -> Optional[str]: + def _name(self) -> str | None: return cast(str, self.event.pipeline.display_name) @@ -170,13 +170,13 @@ def __init__(self, event: Event): "expected_end_time_formatted": self._expected_end_time_formatted, } - def _id(self) -> Optional[UUID]: + def _id(self) -> UUID | None: return self.event.run_id - def _key(self) -> Optional[str]: + def _key(self) -> str | None: return self.event.run_key - def _name(self) -> Optional[str]: + def _name(self) -> str | None: return self.event.run_name def _status(self) -> str: @@ -195,7 +195,7 @@ def _start_time_formatted(self) -> str: return datetime_formatted(start_time) return "N/A" - def _expected_start_dt(self) -> Optional[datetime]: + def _expected_start_dt(self) -> datetime | None: try: run = getattr(self.event, "run", None) except DoesNotExist: @@ -206,11 +206,11 @@ def _expected_start_dt(self) -> Optional[datetime]: return cast(datetime, expected_start_time) return None - def _expected_start_time(self) -> Optional[str]: + def _expected_start_time(self) -> str | None: val = self._expected_start_dt() return datetime_iso8601(val) if val else None - def _expected_start_time_formatted(self) -> Optional[str]: + def _expected_start_time_formatted(self) -> str | None: val = self._expected_start_dt() return datetime_formatted(val) if val else None @@ -226,7 +226,7 @@ def _end_time_formatted(self) -> str: return datetime_formatted(end_time) return "N/A" - def _expected_end_dt(self) -> Optional[datetime]: + def _expected_end_dt(self) -> datetime | None: try: run = getattr(self.event, "run", None) except DoesNotExist: @@ -237,11 +237,11 @@ def _expected_end_dt(self) -> Optional[datetime]: return cast(datetime, expected_end_time) return None - def _expected_end_time(self) -> Optional[str]: + def _expected_end_time(self) -> str | None: val = self._expected_end_dt() return datetime_iso8601(val) if val else None - def _expected_end_time_formatted(self) -> Optional[str]: + def _expected_end_time_formatted(self) -> str | None: val = self._expected_end_dt() return datetime_formatted(val) if val else None @@ -255,10 +255,10 @@ def __init__(self, event: Event): "name": self._name, } - def _id(self) -> Optional[UUID]: + def _id(self) -> UUID | None: return self.event.task_id - def _key(self) -> Optional[str]: + def _key(self) -> str | None: ret: str = getattr(self.event, "task_key") return ret @@ -278,7 +278,7 @@ def __init__(self, event: Event) -> None: "end_time_formatted": self._end_time_formatted, } - def _id(self) -> Optional[UUID]: + def _id(self) -> UUID | None: return self.event.run_task_id def _status(self) -> str: @@ -408,7 +408,7 @@ def _alert_type(self) -> str: _type: str = self.event.type.value return _type - def _expected_start_dt(self) -> Optional[datetime]: + def _expected_start_dt(self) -> datetime | None: try: alert = getattr(self.event, "alert", None) except DoesNotExist: @@ -419,7 +419,7 @@ def _expected_start_dt(self) -> Optional[datetime]: return cast(datetime, expected_start_time) return None - def _expected_end_dt(self) -> Optional[datetime]: + def _expected_end_dt(self) -> datetime | None: try: alert = getattr(self.event, "alert", None) except DoesNotExist: @@ -430,11 +430,11 @@ def _expected_end_dt(self) -> Optional[datetime]: return cast(datetime, expected_end_time) return None - def _expected_start_time_formatted(self) -> Optional[str]: + def _expected_start_time_formatted(self) -> str | None: val = self._expected_start_dt() return datetime_formatted(val) if val else None - def _expected_end_time_formatted(self) -> Optional[str]: + def _expected_end_time_formatted(self) -> str | None: val = self._expected_end_dt() return datetime_formatted(val) if val else None @@ -448,16 +448,16 @@ def __init__(self, event: RunAlert) -> None: "name": self._name, } - def _id(self) -> Optional[UUID]: - id: Optional[UUID] = self.event.batch_pipeline_id + def _id(self) -> UUID | None: + id: UUID | None = self.event.batch_pipeline_id return id - def _key(self) -> Optional[str]: - key: Optional[str] = self.event.batch_pipeline.key + def _key(self) -> str | None: + key: str | None = self.event.batch_pipeline.key return key - def _name(self) -> Optional[str]: - name: Optional[str] = self.event.batch_pipeline.display_name + def _name(self) -> str | None: + name: str | None = self.event.batch_pipeline.display_name return name @@ -470,16 +470,16 @@ def __init__(self, event: RunAlert) -> None: "name": self._name, } - def _id(self) -> Optional[UUID]: - id: Optional[UUID] = self.event.run.id + def _id(self) -> UUID | None: + id: UUID | None = self.event.run.id return id - def _key(self) -> Optional[str]: - key: Optional[str] = self.event.run.key + def _key(self) -> str | None: + key: str | None = self.event.run.key return key - def _name(self) -> Optional[str]: - name: Optional[str] = self.event.run.name + def _name(self) -> str | None: + name: str | None = self.event.run.name return name @@ -493,26 +493,26 @@ def __init__(self, rule: Rule) -> None: "run_state_trigger_successive": self._run_state_trigger_successive, } - def _run_state_matches(self) -> Optional[str]: + def _run_state_matches(self) -> str | None: try: - matches: Optional[str] = self.rule.rule_data["conditions"][0]["run_state"]["matches"] + matches: str | None = self.rule.rule_data["conditions"][0]["run_state"]["matches"] return matches except Exception: return None - def _run_state_count(self) -> Optional[str]: + def _run_state_count(self) -> str | None: try: return str(self.rule.rule_data["conditions"][0]["run_state"]["count"]) except Exception: return None - def _run_state_group_run_name(self) -> Optional[str]: + def _run_state_group_run_name(self) -> str | None: try: return str(self.rule.rule_data["conditions"][0]["run_state"]["group_run_name"]) except Exception: return None - def _run_state_trigger_successive(self) -> Optional[str]: + def _run_state_trigger_successive(self) -> str | None: try: return str(self.rule.rule_data["conditions"][0]["run_state"]["trigger_successive"]) except Exception: diff --git a/common/actions/send_email_action.py b/common/actions/send_email_action.py index 3906829..7845801 100644 --- a/common/actions/send_email_action.py +++ b/common/actions/send_email_action.py @@ -1,7 +1,6 @@ __all__ = ["SendEmailAction"] import logging from dataclasses import asdict -from typing import Optional from uuid import UUID from peewee import DoesNotExist @@ -22,7 +21,7 @@ class SendEmailAction(BaseAction): required_arguments = {"recipients", "template"} requires_action_template = True - def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> ActionResult: + def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: UUID | None) -> ActionResult: try: context = self._get_data_points(event, rule, journey_id) except Exception as e: @@ -41,7 +40,7 @@ def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> Act return ActionResult(True, response, None) def _get_data_points( - self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID] + self, event: EVENT_TYPE, rule: Rule, journey_id: UUID | None ) -> dict | AgentStatusChangeDataPoints: """ Get the data points to be used in the email template diff --git a/common/actions/webhook_action.py b/common/actions/webhook_action.py index 207eee8..59dc93a 100644 --- a/common/actions/webhook_action.py +++ b/common/actions/webhook_action.py @@ -1,7 +1,7 @@ __all__ = ["WebhookAction"] import logging -from typing import Any, Optional, Union +from typing import Any, Union from collections.abc import Mapping from uuid import UUID @@ -44,7 +44,7 @@ def format_data(data: Union[None, list, dict, str], data_points: Mapping) -> Any class WebhookAction(BaseAction): required_arguments = {"url", "method"} - def _run(self, event: EVENT_TYPE, rule: Rule, _: Optional[UUID]) -> ActionResult: + def _run(self, event: EVENT_TYPE, rule: Rule, _: UUID | None) -> ActionResult: data_points: Mapping match event: case RunAlert() | InstanceAlert(): @@ -67,7 +67,7 @@ def _run(self, event: EVENT_TYPE, rule: Rule, _: Optional[UUID]) -> ActionResult return ActionResult(False, None, e) return ActionResult(True, {"status_code": response.status_code}, None) - def _parse_headers(self, data_points: Mapping) -> Optional[dict[str, str]]: + def _parse_headers(self, data_points: Mapping) -> dict[str, str] | None: if headers := self.arguments.get("headers"): return {h["key"]: format_data(h["value"], data_points) for h in headers} else: diff --git a/common/api/base_view.py b/common/api/base_view.py index ac6f9d7..33beae2 100644 --- a/common/api/base_view.py +++ b/common/api/base_view.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass from functools import cached_property -from typing import Any, Optional +from typing import Any from flask import g, request from flask.typing import ResponseReturnValue @@ -22,7 +22,7 @@ @dataclass class Permission: entity_attribute: str - role: Optional[str] = None + role: str | None = None methods: tuple[str, ...] = ("GET", "PUT", "POST", "PATCH", "DELETE") def __call__(self, *methods: str) -> Permission: @@ -48,7 +48,7 @@ class BaseView(MethodView): """ @property - def user(self) -> Optional[User]: + def user(self) -> User | None: """Return the currently authenticated user.""" return getattr(g, "user", None) @@ -61,12 +61,12 @@ def user_roles(self) -> list[str]: return [] @property - def claims(self) -> Optional[User]: + def claims(self) -> User | None: """Return the currently authenticated token.""" return getattr(g, "claims", None) @property - def project(self) -> Optional[Project]: + def project(self) -> Project | None: return getattr(g, "project", None) @property diff --git a/common/api/flask_ext/authentication/common.py b/common/api/flask_ext/authentication/common.py index 249a0a9..6f7125b 100644 --- a/common/api/flask_ext/authentication/common.py +++ b/common/api/flask_ext/authentication/common.py @@ -1,7 +1,6 @@ __all__ = ["get_domain", "BaseAuthPlugin", "validate_authentication"] import logging import re -from typing import Optional from urllib.parse import urlparse from flask import current_app, g, request @@ -16,10 +15,10 @@ class BaseAuthPlugin(BaseExtension): header_name: str = NotImplemented - header_prefix: Optional[str] = None + header_prefix: str | None = None @classmethod - def get_header_data(cls) -> Optional[str]: + def get_header_data(cls) -> str | None: auth_data = request.headers.get(cls.header_name, None) if auth_data and cls.header_prefix: if match := re.match(rf"^{cls.header_prefix}\s+(.*)\s*$", auth_data): diff --git a/common/api/flask_ext/authentication/jwt_plugin.py b/common/api/flask_ext/authentication/jwt_plugin.py index 965ba9a..0bcd41f 100644 --- a/common/api/flask_ext/authentication/jwt_plugin.py +++ b/common/api/flask_ext/authentication/jwt_plugin.py @@ -1,7 +1,7 @@ __all__ = ["JWTAuth"] import logging -from datetime import datetime, timedelta, timezone -from typing import Optional, cast +from datetime import datetime, timedelta, UTC +from typing import cast from collections.abc import Callable from flask import current_app, g, request @@ -24,7 +24,7 @@ def get_token_expiration(claims: JWT_CLAIMS) -> datetime: except KeyError as ke: raise ValueError("Token claims missing 'exp' key") from ke try: - return datetime.fromtimestamp(cast(float | int, exp_timestamp), tz=timezone.utc) + return datetime.fromtimestamp(cast(float | int, exp_timestamp), tz=UTC) except Exception as e: raise ValueError(f"Unable to parse expiration from '{claims['exp']}'") from e @@ -63,7 +63,7 @@ def pre_request_auth(cls) -> None: except Exception as e: raise Unauthorized("Invalid authentication token") from e - if get_token_expiration(claims) < datetime.now(timezone.utc): + if get_token_expiration(claims) < datetime.now(UTC): LOG.error("JWT token expired") raise Unauthorized("Invalid authentication token") @@ -95,7 +95,7 @@ def decode_token(cls, token: str) -> JWT_CLAIMS: return decoded_token @classmethod - def log_user_in(cls, user: User, logout_callback: Optional[str] = None, claims: Optional[JWT_CLAIMS] = None) -> str: + def log_user_in(cls, user: User, logout_callback: str | None = None, claims: JWT_CLAIMS | None = None) -> str: claims = claims or {} if logout_callback: @@ -105,7 +105,7 @@ def log_user_in(cls, user: User, logout_callback: Optional[str] = None, claims: raise ValueError(f"Logout callback '{logout_callback}' is not registered.") if "exp" not in claims: - claims["exp"] = (datetime.now(timezone.utc) + cls.default_jwt_expiration).timestamp() + claims["exp"] = (datetime.now(UTC) + cls.default_jwt_expiration).timestamp() claims["user_id"] = str(user.id) claims["company_id"] = str(user.primary_company_id) diff --git a/common/api/flask_ext/base_extension.py b/common/api/flask_ext/base_extension.py index bfd403f..653cf74 100644 --- a/common/api/flask_ext/base_extension.py +++ b/common/api/flask_ext/base_extension.py @@ -1,12 +1,11 @@ __all__ = ["BaseExtension"] -from typing import Optional from flask import Flask from flask.typing import AfterRequestCallable, AppOrBlueprintKey, BeforeRequestCallable class BaseExtension: - def __init__(self, app: Optional[Flask] = None) -> None: + def __init__(self, app: Flask | None = None) -> None: if app is not None: self.app = app self.init_app() diff --git a/common/api/flask_ext/config.py b/common/api/flask_ext/config.py index 3e914d9..3e9070f 100644 --- a/common/api/flask_ext/config.py +++ b/common/api/flask_ext/config.py @@ -1,6 +1,5 @@ __all__ = ["Config"] import os -from typing import Optional from flask import Flask @@ -16,7 +15,7 @@ class Config: then the configuration will load "foo.bar.production" """ - def __init__(self, app: Optional[Flask] = None, config_module: str = ""): + def __init__(self, app: Flask | None = None, config_module: str = ""): if not config_module: raise ValueError("You must provide a 'config_module' to the Config extension") self.app = app diff --git a/common/api/flask_ext/cors.py b/common/api/flask_ext/cors.py index f6cb681..79e7d95 100644 --- a/common/api/flask_ext/cors.py +++ b/common/api/flask_ext/cors.py @@ -1,6 +1,5 @@ __all__ = ["CORS"] from http import HTTPStatus -from typing import Optional from flask import Flask, Response, make_response, request from werkzeug.exceptions import NotFound @@ -26,13 +25,13 @@ class CORS(BaseExtension): - def __init__(self, app: Optional[Flask] = None, allowed_methods: Optional[list[str]] = None): + def __init__(self, app: Flask | None = None, allowed_methods: list[str] | None = None): allowed_methods = allowed_methods or [] self.allowed_methods = ", ".join(allowed_methods + ["OPTIONS"]).upper() super().__init__(app) @staticmethod - def make_preflight_response() -> Optional[Response]: + def make_preflight_response() -> Response | None: if request.method == "OPTIONS": # When request.endpoint isn't populated it means that the URL didn't match any registered view. For this # case we abort and issue a 404 diff --git a/common/api/request_parsing.py b/common/api/request_parsing.py index 73f91e1..c794ad5 100644 --- a/common/api/request_parsing.py +++ b/common/api/request_parsing.py @@ -1,7 +1,7 @@ __all__ = ["get_bool_param", "no_body_allowed", "str_to_bool", "get_origin_domain"] from functools import wraps -from typing import Any, Optional +from typing import Any from collections.abc import Callable, Iterable from urllib.parse import urlparse @@ -33,10 +33,10 @@ def str_to_bool(value: str, param_name: str) -> bool: elif case_insensitive_value == "false": return False else: - raise ValidationError({param_name: ("Expected 'true' or 'false'. Instead received " f"'{value}'.")}) + raise ValidationError({param_name: (f"Expected 'true' or 'false'. Instead received '{value}'.")}) -def no_body_allowed(func: Optional[Callable] = None, /, methods: Iterable[str] = SAFE_HTTP_METHODS) -> Callable: +def no_body_allowed(func: Callable | None = None, /, methods: Iterable[str] = SAFE_HTTP_METHODS) -> Callable: """ Decorator to be used on MethodView functions if the function does not allow a request body to be passed. @@ -58,7 +58,7 @@ def _wrapper(*args: list, **kwargs: dict) -> Any: return decorator -def get_origin_domain() -> Optional[str]: +def get_origin_domain() -> str | None: if (source_url := request.headers.get("Origin")) is not None: try: return urlparse(source_url).netloc or None diff --git a/common/api/search_view.py b/common/api/search_view.py index 44fad79..ad68568 100644 --- a/common/api/search_view.py +++ b/common/api/search_view.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from flask import Blueprint, Response, request from flask.typing import RouteCallable @@ -12,7 +12,7 @@ class SearchView(BaseView): - args_from_post: Optional[MultiDict] = None + args_from_post: MultiDict | None = None request_body_schema: type[Schema] def post(self, *args: Any, **kwargs: Any) -> Response: diff --git a/common/apscheduler_extensions.py b/common/apscheduler_extensions.py index 18fa404..4624eec 100644 --- a/common/apscheduler_extensions.py +++ b/common/apscheduler_extensions.py @@ -3,7 +3,6 @@ import logging import re from datetime import datetime, timedelta -from typing import Optional from collections.abc import Generator from zoneinfo import ZoneInfo @@ -36,10 +35,10 @@ def __init__(self, trigger: BaseTrigger, delay: timedelta): self.trigger = trigger self.delay = delay - def get_next_fire_time(self, previous_fire_time: Optional[datetime], now: datetime) -> Optional[datetime]: + def get_next_fire_time(self, previous_fire_time: datetime | None, now: datetime) -> datetime | None: if previous_fire_time: previous_fire_time -= self.delay - next_fire_time: Optional[datetime] = self.trigger.get_next_fire_time(previous_fire_time, now) + next_fire_time: datetime | None = self.trigger.get_next_fire_time(previous_fire_time, now) return next_fire_time + self.delay if next_fire_time else None def __str__(self) -> str: @@ -109,7 +108,7 @@ def fix_weekdays(expression: str) -> str: def get_crontab_trigger_times( - crontab: str, timezone: ZoneInfo, start_range: datetime, end_range: Optional[datetime] = None + crontab: str, timezone: ZoneInfo, start_range: datetime, end_range: datetime | None = None ) -> Generator[datetime, None, None]: """ Generate the crontab trigger times for the given time range. diff --git a/common/auth/keys/service_key.py b/common/auth/keys/service_key.py index 8ed0a53..1fbedb3 100644 --- a/common/auth/keys/service_key.py +++ b/common/auth/keys/service_key.py @@ -1,8 +1,8 @@ import logging from base64 import b64encode from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from typing import NamedTuple, Optional +from datetime import datetime, timedelta, UTC +from typing import NamedTuple from uuid import uuid4 from peewee import DoesNotExist @@ -31,15 +31,15 @@ def generate_key( *, project: Project, allowed_services: list[str], - name: Optional[str] = None, - description: Optional[str] = None, + name: str | None = None, + description: str | None = None, expiration_days: int = DEFAULT_EXPIRY_DAYS, ) -> KeyPair: """Generate a new Service Account key for the given service name.""" passphrase = generate_passphrase() salt = str(uuid4()) passphrase_hash = hash_value(value=passphrase, salt=salt) - expiry = datetime.now(timezone.utc) + timedelta(days=expiration_days) + expiry = datetime.now(UTC) + timedelta(days=expiration_days) digest = create_digest(iterations=HASH_ITERATIONS, salt=salt, passphrase_hash=passphrase_hash) # Give the key a pretty unique name if none provided (Name only has to be unique per-project so this should safe) diff --git a/common/datetime_utils.py b/common/datetime_utils.py index 8677449..6c473e0 100644 --- a/common/datetime_utils.py +++ b/common/datetime_utils.py @@ -1,12 +1,12 @@ __all__ = ["datetime_formatted", "datetime_iso8601", "to_utc_aware"] -from datetime import datetime, timezone +from datetime import datetime, UTC # Although datetimes in Events are tz aware they only contains the raw offset # after being marshmallow serialized. Since the tz name is desired astimezone # is used for strftime to return the name. def to_utc_aware(dt: datetime) -> datetime: - return dt.replace(tzinfo=timezone.utc) if dt.tzinfo is None else dt.astimezone(timezone.utc) + return dt.replace(tzinfo=UTC) if dt.tzinfo is None else dt.astimezone(UTC) def datetime_formatted(dt: datetime) -> str: @@ -28,5 +28,5 @@ def datetime_to_timestamp(dt: datetime) -> float: def timestamp_to_datetime(timestamp: float) -> datetime: """Convert a timestamp to a datetime object in UTC time.""" - dt = datetime.fromtimestamp(timestamp, tz=timezone.utc) + dt = datetime.fromtimestamp(timestamp, tz=UTC) return dt diff --git a/common/decorators.py b/common/decorators.py index 075ae84..485f4bd 100644 --- a/common/decorators.py +++ b/common/decorators.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import Any, Generic, Optional, TypeVar, cast +from typing import Any, TypeVar, cast from collections.abc import Callable PropertyType = TypeVar("PropertyType") -class cached_property(Generic[PropertyType]): +class cached_property[PropertyType]: """ A `property` decorator that caches the value on the instance. @@ -47,7 +47,7 @@ class cached_property(Generic[PropertyType]): """ - _name: Optional[str] = None + _name: str | None = None def __init__(self, f: Callable[[Any], PropertyType]) -> None: self.func = f @@ -59,7 +59,7 @@ def __set_name__(self, owner: type[object], name: str) -> None: elif name != self._name: raise TypeError(f"Cannot assign the same instance to two names ({self._name} and {name}).") - def __get__(self, inst: object, cls: Optional[Any] = None) -> PropertyType: + def __get__(self, inst: object, cls: Any | None = None) -> PropertyType: """ Retrieve the value from instance, stashing the result in inst.__dict__ diff --git a/common/entities/alert.py b/common/entities/alert.py index 17b7371..70a6ebf 100644 --- a/common/entities/alert.py +++ b/common/entities/alert.py @@ -2,7 +2,6 @@ from datetime import datetime from enum import Enum -from typing import Optional from peewee import CharField, CompositeKey, ForeignKeyField from playhouse.mysql_ext import JSONField @@ -63,7 +62,7 @@ class AlertBase(BaseEntity, AuditUpdateTimeEntityMixin): level = EnumStrField(AlertLevel, null=False, max_length=50) @property - def expected_start_time(self) -> Optional[datetime]: + def expected_start_time(self) -> datetime | None: """If the alert has expected_start_time in it's details dict, return it as a datetime object.""" timestamp = self.details.get("expected_start_time", None) if timestamp: @@ -76,7 +75,7 @@ def expected_start_time(self) -> Optional[datetime]: return None @expected_start_time.setter - def expected_start_time(self, dt_obj: Optional[datetime]) -> None: + def expected_start_time(self, dt_obj: datetime | None) -> None: """Set the expected_start_time value (converts to timestamp in details dict).""" if dt_obj is None: self.details.pop("expected_start_time", None) @@ -85,7 +84,7 @@ def expected_start_time(self, dt_obj: Optional[datetime]) -> None: self.details["expected_start_time"] = timestamp @property - def expected_end_time(self) -> Optional[datetime]: + def expected_end_time(self) -> datetime | None: """If the alert has expected_end_time in it's details dict, return it as a datetime object.""" timestamp = self.details.get("expected_end_time", None) if timestamp: @@ -98,7 +97,7 @@ def expected_end_time(self) -> Optional[datetime]: return None @expected_end_time.setter - def expected_end_time(self, dt_obj: Optional[datetime]) -> None: + def expected_end_time(self, dt_obj: datetime | None) -> None: """Set the expected_end_time value (converts to timestamp in details dict).""" if dt_obj is None: self.details.pop("expected_end_time", None) diff --git a/common/entities/authentication.py b/common/entities/authentication.py index 125cb1d..60f0776 100644 --- a/common/entities/authentication.py +++ b/common/entities/authentication.py @@ -1,6 +1,6 @@ __all__ = ["ApiKey", "Service", "ServiceAccountKey"] import logging -from datetime import datetime, timezone +from datetime import datetime, UTC from enum import Enum from uuid import uuid4 @@ -24,7 +24,7 @@ class ApiKey(Model): user = ForeignKeyField(User, backref="api_keys", on_delete="CASCADE", null=False, index=True) def is_expired(self) -> bool: - if datetime.now(timezone.utc) > self.expiry: + if datetime.now(UTC) > self.expiry: return True else: return False @@ -59,7 +59,7 @@ def is_expired(self) -> bool: # If no expiration date is set, the key's duration is unlimited if not self.expiry: return False - if datetime.now(timezone.utc) > self.expiry: + if datetime.now(UTC) > self.expiry: return True else: return False diff --git a/common/entities/base_entity.py b/common/entities/base_entity.py index 7ac8e3f..f0cd4ba 100644 --- a/common/entities/base_entity.py +++ b/common/entities/base_entity.py @@ -1,6 +1,6 @@ __all__ = ["ActivableEntityMixin", "AuditEntityMixin", "AuditUpdateTimeEntityMixin", "BaseEntity", "BaseModel", "DB"] -from datetime import datetime, timezone +from datetime import datetime, UTC from typing import Any from uuid import uuid4 @@ -53,5 +53,5 @@ class AuditUpdateTimeEntityMixin(Model): @classmethod def update(cls, *args: Any, **kwargs: Any) -> Any: - kwargs["updated_on"] = datetime.utcnow().replace(tzinfo=timezone.utc) + kwargs["updated_on"] = datetime.utcnow().replace(tzinfo=UTC) return super().update(*args, **kwargs) diff --git a/common/entities/company.py b/common/entities/company.py index 6c84f63..20d4cd1 100644 --- a/common/entities/company.py +++ b/common/entities/company.py @@ -1,6 +1,5 @@ __all__ = ["Company"] -from typing import Optional from peewee import CharField, ForeignKeyField @@ -13,5 +12,5 @@ class Company(BaseEntity, AuditEntityMixin): name = CharField(unique=True, null=False) @property - def parent(self) -> Optional[ForeignKeyField]: + def parent(self) -> ForeignKeyField | None: return None diff --git a/common/entities/component_meta.py b/common/entities/component_meta.py index 606f5cc..2bd2daf 100644 --- a/common/entities/component_meta.py +++ b/common/entities/component_meta.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Union +from typing import Any, Union from peewee import Field, ForeignKeyField, ModelBase, ModelSelect, ModelUpdate @@ -144,7 +144,7 @@ def create(cls: type[BaseModel], **data: object) -> BaseModel: component.save(force_insert=True) return component - def save(self: BaseModel, force_insert: bool = False, only: Optional[object] = None) -> Union[bool, int]: + def save(self: BaseModel, force_insert: bool = False, only: object | None = None) -> Union[bool, int]: ret = 0 if not force_insert and self.component: ret += self.component.save(only=only) or 0 diff --git a/common/entities/dataset_operation.py b/common/entities/dataset_operation.py index c80553e..3d0c201 100644 --- a/common/entities/dataset_operation.py +++ b/common/entities/dataset_operation.py @@ -13,8 +13,8 @@ class DatasetOperationType(Enum): - READ: str = "READ" - WRITE: str = "WRITE" + READ = "READ" + WRITE = "WRITE" class DatasetOperation(BaseEntity): diff --git a/common/entities/upcoming_instance.py b/common/entities/upcoming_instance.py index 949c486..362ad86 100644 --- a/common/entities/upcoming_instance.py +++ b/common/entities/upcoming_instance.py @@ -1,7 +1,6 @@ __all__ = ["UpcomingInstance"] from dataclasses import dataclass from datetime import datetime -from typing import Optional from common.entities.journey import Journey @@ -9,5 +8,5 @@ @dataclass class UpcomingInstance: journey: Journey - expected_start_time: Optional[datetime] = None - expected_end_time: Optional[datetime] = None + expected_start_time: datetime | None = None + expected_end_time: datetime | None = None diff --git a/common/entity_services/component_service.py b/common/entity_services/component_service.py index a772926..413479b 100644 --- a/common/entity_services/component_service.py +++ b/common/entity_services/component_service.py @@ -2,7 +2,6 @@ from datetime import datetime from itertools import cycle -from typing import Optional from collections.abc import Generator from peewee import Select @@ -31,7 +30,7 @@ def select_journeys(component: Component) -> Select: @classmethod def get_or_create_active_instances( - cls, component: Component, start_time: Optional[datetime] = None + cls, component: Component, start_time: datetime | None = None ) -> Generator[tuple[bool, Instance], None, None]: """ Retrieves active Instances for a given component. Create active Instances when a Journey does not have one. diff --git a/common/entity_services/helpers/filter_rules.py b/common/entity_services/helpers/filter_rules.py index 290924c..391256d 100644 --- a/common/entity_services/helpers/filter_rules.py +++ b/common/entity_services/helpers/filter_rules.py @@ -13,7 +13,7 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Optional, TypeVar +from typing import TypeVar from collections.abc import Callable from uuid import UUID @@ -72,7 +72,7 @@ class ParamConfig: func: Callable -def _date_or_none(params: MultiDict, field_name: str) -> Optional[datetime]: +def _date_or_none(params: MultiDict, field_name: str) -> datetime | None: if date := params.get(field_name): try: return arrow.get(date).datetime @@ -81,7 +81,7 @@ def _date_or_none(params: MultiDict, field_name: str) -> Optional[datetime]: return None -def _str_to_bool(params: MultiDict, field_name: str) -> Optional[bool]: +def _str_to_bool(params: MultiDict, field_name: str) -> bool | None: if (value := params.get(field_name)) is None: return None return str_to_bool(value, field_name) @@ -130,28 +130,28 @@ class Filters: Extend by specifying the wanted attributes and how to unpack them in from_params. """ - active: Optional[bool] = None + active: bool | None = None component_ids: list[str] = field(default_factory=list) component_types: list[str] = field(default_factory=list) - date_range_end: Optional[datetime] = None - date_range_start: Optional[datetime] = None - end_range: Optional[datetime] = None - end_range_begin: Optional[datetime] = None - end_range_end: Optional[datetime] = None + date_range_end: datetime | None = None + date_range_start: datetime | None = None + end_range: datetime | None = None + end_range_begin: datetime | None = None + end_range_end: datetime | None = None event_ids: list[str] = field(default_factory=list) event_types: list[str] = field(default_factory=list) instance_ids: list[str] = field(default_factory=list) journey_ids: list[str] = field(default_factory=list) journey_names: list[str] = field(default_factory=list) - key: Optional[str] = None + key: str | None = None levels: list[str] = field(default_factory=list) pipeline_keys: list[str] = field(default_factory=list) project_ids: list[str] = field(default_factory=list) run_ids: list[str] = field(default_factory=list) run_keys: list[str] = field(default_factory=list) - start_range: Optional[datetime] = None - start_range_begin: Optional[datetime] = None - start_range_end: Optional[datetime] = None + start_range: datetime | None = None + start_range_begin: datetime | None = None + start_range_end: datetime | None = None statuses: list[str] = field(default_factory=list) task_ids: list[str] = field(default_factory=list) tools: list[str] = field(default_factory=list) @@ -166,9 +166,7 @@ def __bool__(self) -> bool: return False @staticmethod - def validate_time_range( - range_begin: Optional[datetime], range_end: Optional[datetime], range_begin_name: str - ) -> None: + def validate_time_range(range_begin: datetime | None, range_end: datetime | None, range_begin_name: str) -> None: if range_begin is None or range_end is None: return None if range_begin >= range_end: diff --git a/common/entity_services/helpers/list_rules.py b/common/entity_services/helpers/list_rules.py index 8352ea8..56e7869 100644 --- a/common/entity_services/helpers/list_rules.py +++ b/common/entity_services/helpers/list_rules.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from enum import Enum as std_Enum from enum import auto -from typing import Generic, Optional, TypeVar +from typing import TypeVar from collections.abc import Generator from marshmallow import EXCLUDE, Schema @@ -36,7 +36,7 @@ class Meta: @dataclass -class Page(Generic[T]): +class Page[T]: """ Useful for returning results from the service layer to get paginated results but also receive the total objects without pagination. @@ -83,7 +83,7 @@ class ListRules: page: int = DEFAULT_PAGE count: int = DEFAULT_COUNT sort: SortOrder = SortOrder.ASC - search: Optional[str] = None + search: str | None = None @classmethod def from_params_without_search(cls, params: MultiDict) -> ListRules: diff --git a/common/entity_services/instance_service.py b/common/entity_services/instance_service.py index 62139ec..feb07b0 100644 --- a/common/entity_services/instance_service.py +++ b/common/entity_services/instance_service.py @@ -2,7 +2,6 @@ from collections import Counter, defaultdict from datetime import datetime -from typing import Optional from collections.abc import Iterable from uuid import UUID @@ -163,11 +162,11 @@ def run_alerts_query(instances: Iterable[Instance]) -> ModelSelect: def get_instance_run_counts( instance: UUID | Instance, *, - include_run_statuses: Optional[Iterable[str]] = None, - exclude_run_statuses: Optional[Iterable[str]] = None, - journey: Optional[UUID] = None, - pipelines: Optional[Iterable[UUID]] = None, - end_before: Optional[datetime] = None, + include_run_statuses: Iterable[str] | None = None, + exclude_run_statuses: Iterable[str] | None = None, + journey: UUID | None = None, + pipelines: Iterable[UUID] | None = None, + end_before: datetime | None = None, ) -> dict[UUID, int]: """ Return a dict of pipelines with the corresponding run count per pipeline. diff --git a/common/entity_services/journey_service.py b/common/entity_services/journey_service.py index d6881ac..6be3e28 100644 --- a/common/entity_services/journey_service.py +++ b/common/entity_services/journey_service.py @@ -1,7 +1,6 @@ __all__ = ["JourneyService"] import logging -from typing import Optional from uuid import UUID from common.entities import Action, Company, Component, Journey, JourneyDagEdge, Organization, Project, Rule @@ -19,7 +18,7 @@ def get_rules_with_rules(journey_id: UUID, list_rules: ListRules) -> Page[Rule]: return Page[Rule].get_paginated_results(query, Rule.created_on, list_rules) @staticmethod - def get_action_by_implementation(journey_id: UUID, action_impl: str) -> Optional[Action]: + def get_action_by_implementation(journey_id: UUID, action_impl: str) -> Action | None: """ Fetches an Action entity given a Journey ID and the action implementation. diff --git a/common/entity_services/pipeline_service.py b/common/entity_services/pipeline_service.py index 0081ac4..551c4a4 100644 --- a/common/entity_services/pipeline_service.py +++ b/common/entity_services/pipeline_service.py @@ -1,6 +1,5 @@ __all__ = ["PipelineService"] import logging -from typing import Optional from common.entities import Pipeline @@ -9,6 +8,6 @@ class PipelineService: @staticmethod - def get_by_key_and_project(pipeline_key: Optional[str], project_id: str) -> Pipeline: + def get_by_key_and_project(pipeline_key: str | None, project_id: str) -> Pipeline: pipeline: Pipeline = Pipeline.get(Pipeline.key == pipeline_key, Pipeline.project == project_id) return pipeline diff --git a/common/entity_services/project_service.py b/common/entity_services/project_service.py index 6086e5a..0c7b3bb 100644 --- a/common/entity_services/project_service.py +++ b/common/entity_services/project_service.py @@ -2,7 +2,7 @@ from functools import reduce from operator import or_ -from typing import Any, Optional +from typing import Any from peewee import PREFETCH_TYPE, Value, fn, prefetch, DoesNotExist @@ -107,7 +107,7 @@ def get_components_with_rules(project_id: str, rules: ListRules, filters: Compon @staticmethod def get_runs_with_rules( - project_id: Optional[str], pipeline_ids: list[str], rules: ListRules, filters: RunFilters + project_id: str | None, pipeline_ids: list[str], rules: ListRules, filters: RunFilters ) -> Page[Run]: start_dt = fn.COALESCE(Run.start_time, Run.expected_start_time) query = Run.select(Run, start_dt.alias("start_dt")).distinct() @@ -154,7 +154,7 @@ def get_runs_with_rules( @staticmethod def get_instances_with_rules( - rules: ListRules, filters: Filters, project_ids: list[str], company_id: Optional[str] = None + rules: ListRules, filters: Filters, project_ids: list[str], company_id: str | None = None ) -> Page[Instance]: memberships = [Journey.project.in_(project_ids)] if project_ids else [] if company_id: @@ -211,7 +211,7 @@ def get_instances_with_rules( return Page[Instance](results=results, total=query.count()) @staticmethod - def get_journeys_with_rules(project_id: str, rules: ListRules, component_id: Optional[str] = None) -> Page[Journey]: + def get_journeys_with_rules(project_id: str, rules: ListRules, component_id: str | None = None) -> Page[Journey]: base_query = Journey.project == project_id if rules.search is not None: base_query &= Journey.name ** f"%{rules.search}%" diff --git a/common/entity_services/test_outcome_service.py b/common/entity_services/test_outcome_service.py index 9639889..b2835d7 100644 --- a/common/entity_services/test_outcome_service.py +++ b/common/entity_services/test_outcome_service.py @@ -1,6 +1,5 @@ __all__ = ["TestOutcomeService"] -from typing import Optional from uuid import UUID from peewee import SqliteDatabase @@ -17,9 +16,9 @@ def insert_from_event( *, event: TestOutcomesEvent, component_id: UUID, - instance_set_id: Optional[UUID] = None, - run_id: Optional[UUID] = None, - task_id: Optional[UUID] = None, + instance_set_id: UUID | None = None, + run_id: UUID | None = None, + task_id: UUID | None = None, ) -> None: test_outcomes = [] test_outcome_integrations = [] @@ -65,7 +64,7 @@ def insert_from_event( ) # Using the recursive lookup avoids having to check for None on optional values up the whole chain - testgen_dataset: Optional[TestgenDataset] = getattr_recursive( + testgen_dataset: TestgenDataset | None = getattr_recursive( event, "component_integrations__integrations__testgen", None ) diff --git a/common/entity_services/upcoming_instance_service.py b/common/entity_services/upcoming_instance_service.py index b02f671..64f44c9 100644 --- a/common/entity_services/upcoming_instance_service.py +++ b/common/entity_services/upcoming_instance_service.py @@ -4,7 +4,7 @@ from datetime import datetime from heapq import merge from operator import itemgetter -from typing import Optional, cast +from typing import cast from collections.abc import Generator from uuid import UUID from zoneinfo import ZoneInfo @@ -65,7 +65,7 @@ def _collect_journey_schedules( def _get_instance_times( schedules: JourneySchedules, start_time: datetime, - end_time: Optional[datetime], + end_time: datetime | None, ) -> Generator[tuple[datetime, bool], None, None]: """ Generate a sequence of expected instance start and end times from the given schedules @@ -96,8 +96,8 @@ class UpcomingInstanceService: def get_upcoming_instances_with_rules( rules: ListRules, filters: UpcomingInstanceFilters, - project_id: Optional[UUID] = None, - company_id: Optional[UUID] = None, + project_id: UUID | None = None, + company_id: UUID | None = None, ) -> list[UpcomingInstance]: assert filters.start_range is not None memberships = [] @@ -159,8 +159,8 @@ def get_upcoming_instances_with_rules( def get_upcoming_instances( journey: Journey, start_time: datetime, - end_time: Optional[datetime] = None, - schedules: Optional[JourneySchedules] = None, + end_time: datetime | None = None, + schedules: JourneySchedules | None = None, ) -> Generator[UpcomingInstance, None, None]: """ Get upcoming instances for the given journey diff --git a/common/entity_services/user_service.py b/common/entity_services/user_service.py index 1f6a072..4e093be 100644 --- a/common/entity_services/user_service.py +++ b/common/entity_services/user_service.py @@ -1,14 +1,10 @@ -from typing import Optional - from common.entities import User from common.entity_services.helpers import ListRules, Page class UserService: @staticmethod - def list_with_rules( - rules: ListRules, company_id: Optional[str] = None, name_filter: Optional[str] = None - ) -> Page[User]: + def list_with_rules(rules: ListRules, company_id: str | None = None, name_filter: str | None = None) -> Page[User]: query = User.select() if company_id: query = query.where(User.primary_company_id == company_id) diff --git a/common/events/base.py b/common/events/base.py index 64a496e..4d1f0d8 100644 --- a/common/events/base.py +++ b/common/events/base.py @@ -12,10 +12,9 @@ ] from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime, UTC from enum import Enum as std_Enum from functools import partial -from typing import Optional from uuid import UUID from uuid import UUID as std_UUID from uuid import uuid4 @@ -44,14 +43,14 @@ def partition_identifier(self) -> str: @dataclass(kw_only=True) class ComponentMixin: - component_id: Optional[UUID] = None - component_type: Optional[ComponentType] = None + component_id: UUID | None = None + component_type: ComponentType | None = None @dataclass(kw_only=True) class BatchPipelineMixin: - batch_pipeline_id: Optional[UUID] = None - run_id: Optional[UUID] = None + batch_pipeline_id: UUID | None = None + run_id: UUID | None = None @cached_property def batch_pipeline(self) -> Pipeline: @@ -66,13 +65,13 @@ def run(self) -> Run: @dataclass(kw_only=True) class RunMixin: - run_id: Optional[UUID] = None + run_id: UUID | None = None @dataclass(kw_only=True) class TaskMixin: - task_id: Optional[UUID] = None - run_task_id: Optional[UUID] = None + task_id: UUID | None = None + run_task_id: UUID | None = None @cached_property def task(self) -> Task: @@ -109,11 +108,11 @@ class JourneysMixin: @dataclass(kw_only=True) class JourneyMixin: - journey_id: Optional[UUID] = None - instance_id: Optional[UUID] = None + journey_id: UUID | None = None + instance_id: UUID | None = None @dataclass(kw_only=True) class EventBaseMixin: event_id: UUID = field(default_factory=uuid4) - created_timestamp: datetime = field(default_factory=partial(datetime.now, tz=timezone.utc)) + created_timestamp: datetime = field(default_factory=partial(datetime.now, tz=UTC)) diff --git a/common/events/converters.py b/common/events/converters.py index accaf8b..88a9429 100644 --- a/common/events/converters.py +++ b/common/events/converters.py @@ -1,5 +1,5 @@ from dataclasses import asdict, fields -from typing import Optional, cast +from typing import cast from common.entities import ComponentType, RunStatus from common.entities.event import ApiEventType @@ -29,7 +29,7 @@ def _extract_common_attributes(self, event: EventV2) -> dict: "payload_keys": event.event_payload.payload_keys, } - def _extract_batch_attributes(self, batch: Optional[v2.BatchPipelineData]) -> dict: + def _extract_batch_attributes(self, batch: v2.BatchPipelineData | None) -> dict: data = { "run_name": batch.run_name if batch else None, "run_key": batch.run_key if batch else None, @@ -40,7 +40,7 @@ def _extract_batch_attributes(self, batch: Optional[v2.BatchPipelineData]) -> di data["component_tool"] = batch.details.tool return data - def _extract_dataset_attributes(self, dataset: Optional[v2.DatasetData]) -> dict: + def _extract_dataset_attributes(self, dataset: v2.DatasetData | None) -> dict: data = { "dataset_key": dataset.dataset_key if dataset else None, "dataset_name": dataset.details.name if dataset and dataset.details else None, @@ -49,7 +49,7 @@ def _extract_dataset_attributes(self, dataset: Optional[v2.DatasetData]) -> dict data["component_tool"] = dataset.details.tool return data - def _extract_server_attributes(self, server: Optional[v2.ServerData]) -> dict: + def _extract_server_attributes(self, server: v2.ServerData | None) -> dict: data = { "server_key": server.server_key if server else None, "server_name": server.details.name if server and server.details else None, @@ -58,7 +58,7 @@ def _extract_server_attributes(self, server: Optional[v2.ServerData]) -> dict: data["component_tool"] = server.details.tool return data - def _extract_stream_attributes(self, stream: Optional[v2.StreamData]) -> dict: + def _extract_stream_attributes(self, stream: v2.StreamData | None) -> dict: data = { "stream_key": stream.stream_key if stream else None, "stream_name": stream.details.name if stream and stream.details else None, @@ -69,10 +69,10 @@ def _extract_stream_attributes(self, stream: Optional[v2.StreamData]) -> dict: def _extract_component_data( self, - batch: Optional[v2.BatchPipelineData], - dataset: Optional[v2.DatasetData], - server: Optional[v2.ServerData], - stream: Optional[v2.StreamData], + batch: v2.BatchPipelineData | None, + dataset: v2.DatasetData | None, + server: v2.ServerData | None, + stream: v2.StreamData | None, ) -> dict: data = { **self._extract_batch_attributes(batch), @@ -84,7 +84,7 @@ def _extract_component_data( data["component_tool"] = None return data - def _extract_task_attributes(self, event: EventV2, batch: Optional[v2.BatchPipelineData]) -> dict: + def _extract_task_attributes(self, event: EventV2, batch: v2.BatchPipelineData | None) -> dict: return { "task_key": batch.task_key if batch else None, "task_name": batch.task_name if batch else None, @@ -116,8 +116,8 @@ def _extract_testgen_item(self, testgen: dict) -> v1.TestgenItem: ) def _extract_test_outcome_item_integrations( - self, integrations: Optional[dict] - ) -> Optional[v1.TestOutcomeItemIntegrations]: + self, integrations: dict | None + ) -> v1.TestOutcomeItemIntegrations | None: if integrations is None: return None return v1.TestOutcomeItemIntegrations(testgen=self._extract_testgen_item(integrations["testgen"])) @@ -142,8 +142,8 @@ def _extract_testgen_table(self, tables: dict) -> v1.TestgenTable: ) def _extract_testgen_table_group_config( - self, table_group_configuration: Optional[dict] - ) -> Optional[v1.TestgenTableGroupV1]: + self, table_group_configuration: dict | None + ) -> v1.TestgenTableGroupV1 | None: if table_group_configuration is None: return None return v1.TestgenTableGroupV1( @@ -167,7 +167,7 @@ def _extract_testgen_integration_componenet(self, integrations: dict) -> v1.Test def _extract_component_integrations( self, component: v2.TestGenComponentData - ) -> Optional[v1.TestGenTestOutcomeIntegrationComponent]: + ) -> v1.TestGenTestOutcomeIntegrationComponent | None: integrations = next(c for f in fields(component) if (c := getattr(component, f.name, None))).integrations if integrations is None: return None @@ -300,7 +300,7 @@ def _extract_common_internal_attributes(self, event: Event) -> dict: "version": event.version, } - def _extract_batch_pipeline_data(self, event: Event) -> Optional[v2.BatchPipelineData]: + def _extract_batch_pipeline_data(self, event: Event) -> v2.BatchPipelineData | None: new_component_data = ( v2.NewComponentData(name=event.pipeline_name, tool=event.component_tool) if event.pipeline_name or event.component_tool @@ -318,9 +318,7 @@ def _extract_batch_pipeline_data(self, event: Event) -> Optional[v2.BatchPipelin else: return None - def _extract_testgen_batch_pipeline_data( - self, event: v1.TestOutcomesEvent - ) -> Optional[v2.TestGenBatchPipelineData]: + def _extract_testgen_batch_pipeline_data(self, event: v1.TestOutcomesEvent) -> v2.TestGenBatchPipelineData | None: if data := self._extract_batch_pipeline_data(event): return v2.TestGenBatchPipelineData( batch_key=data.batch_key, @@ -333,7 +331,7 @@ def _extract_testgen_batch_pipeline_data( ) return None - def _extract_testgen_dataset_data(self, event: v1.TestOutcomesEvent) -> Optional[v2.TestGenDatasetData]: + def _extract_testgen_dataset_data(self, event: v1.TestOutcomesEvent) -> v2.TestGenDatasetData | None: if data := self._extract_dataset_data(event): return v2.TestGenDatasetData( dataset_key=data.dataset_key, @@ -342,7 +340,7 @@ def _extract_testgen_dataset_data(self, event: v1.TestOutcomesEvent) -> Optional ) return None - def _extract_testgen_stream_data(self, event: v1.TestOutcomesEvent) -> Optional[v2.TestGenStreamData]: + def _extract_testgen_stream_data(self, event: v1.TestOutcomesEvent) -> v2.TestGenStreamData | None: if data := self._extract_stream_data(event): return v2.TestGenStreamData( stream_key=data.stream_key, @@ -351,7 +349,7 @@ def _extract_testgen_stream_data(self, event: v1.TestOutcomesEvent) -> Optional[ ) return None - def _extract_testgen_server_data(self, event: v1.TestOutcomesEvent) -> Optional[v2.TestGenServerData]: + def _extract_testgen_server_data(self, event: v1.TestOutcomesEvent) -> v2.TestGenServerData | None: if data := self._extract_server_data(event): return v2.TestGenServerData( server_key=data.server_key, @@ -360,7 +358,7 @@ def _extract_testgen_server_data(self, event: v1.TestOutcomesEvent) -> Optional[ ) return None - def _extract_dataset_data(self, event: Event) -> Optional[v2.DatasetData]: + def _extract_dataset_data(self, event: Event) -> v2.DatasetData | None: new_component_data = ( v2.NewComponentData(name=event.dataset_name, tool=event.component_tool) if event.dataset_name or event.component_tool @@ -371,7 +369,7 @@ def _extract_dataset_data(self, event: Event) -> Optional[v2.DatasetData]: else: return None - def _extract_stream_data(self, event: Event) -> Optional[v2.StreamData]: + def _extract_stream_data(self, event: Event) -> v2.StreamData | None: new_component_data = ( v2.NewComponentData(name=event.stream_name, tool=event.component_tool) if event.stream_name or event.component_tool @@ -385,7 +383,7 @@ def _extract_stream_data(self, event: Event) -> Optional[v2.StreamData]: else: return None - def _extract_server_data(self, event: Event) -> Optional[v2.ServerData]: + def _extract_server_data(self, event: Event) -> v2.ServerData | None: new_component_data = ( v2.NewComponentData(name=event.server_name, tool=event.component_tool) if event.server_name or event.component_tool @@ -438,8 +436,8 @@ def _extract_testgen_item(self, testgen: dict) -> v2.TestgenItem: ) def _extract_test_outcome_item_integrations( - self, integrations: Optional[dict] - ) -> Optional[v2.TestOutcomeItemIntegrations]: + self, integrations: dict | None + ) -> v2.TestOutcomeItemIntegrations | None: if integrations is None: return None return v2.TestOutcomeItemIntegrations(testgen=self._extract_testgen_item(integrations["testgen"])) @@ -464,8 +462,8 @@ def _extract_testgen_table(self, tables: dict) -> v2.TestgenTable: ) def _extract_testgen_table_group_config( - self, table_group_configuration: Optional[dict] - ) -> Optional[v2.TestgenTableGroupV1]: + self, table_group_configuration: dict | None + ) -> v2.TestgenTableGroupV1 | None: if table_group_configuration is None: return None return v2.TestgenTableGroupV1( @@ -484,7 +482,7 @@ def _extract_testgen_integrations(self, testgen: dict) -> v2.TestgenDataset: def _extract_testgen_integration_componenet( self, event: v1.TestOutcomesEvent - ) -> Optional[v2.TestGenTestOutcomeIntegrations]: + ) -> v2.TestGenTestOutcomeIntegrations | None: if i := event.component_integrations: return v2.TestGenTestOutcomeIntegrations( testgen=self._extract_testgen_integrations(asdict(i.integrations.testgen)), diff --git a/common/events/internal/alert.py b/common/events/internal/alert.py index a928702..889ead1 100644 --- a/common/events/internal/alert.py +++ b/common/events/internal/alert.py @@ -4,7 +4,6 @@ ] from dataclasses import dataclass -from typing import Optional from uuid import UUID from common.decorators import cached_property @@ -20,7 +19,7 @@ class AlertBase: alert_id: UUID level: AlertLevel - description: Optional[str] + description: str | None @dataclass(kw_only=True) diff --git a/common/events/internal/scheduled_event.py b/common/events/internal/scheduled_event.py index 0c246f0..2844898 100644 --- a/common/events/internal/scheduled_event.py +++ b/common/events/internal/scheduled_event.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from datetime import datetime -from typing import Optional from uuid import UUID from common.events.base import ComponentMixin @@ -19,7 +18,7 @@ class ScheduledEvent(ComponentMixin): schedule_id: UUID schedule_type: ScheduleType schedule_timestamp: datetime - schedule_margin: Optional[datetime] = None + schedule_margin: datetime | None = None @property def partition_identifier(self) -> str: diff --git a/common/events/v1/dataset_operation_event.py b/common/events/v1/dataset_operation_event.py index 8594752..f5e9ac1 100644 --- a/common/events/v1/dataset_operation_event.py +++ b/common/events/v1/dataset_operation_event.py @@ -3,7 +3,6 @@ __all__ = ["DatasetOperationSchema", "DatasetOperationApiSchema", "DatasetOperationEvent", "DatasetOperationType"] from dataclasses import dataclass -from typing import Optional from marshmallow import Schema, ValidationError, validates_schema from marshmallow.fields import Str @@ -55,7 +54,7 @@ class DatasetOperationEvent(Event): __api_schema__ = DatasetOperationApiSchema operation: str - path: Optional[str] = None + path: str | None = None def accept(self, handler: EventHandlerBase) -> bool: return handler.handle_dataset_operation(self) diff --git a/common/events/v1/event.py b/common/events/v1/event.py index 47de3a2..2aec201 100644 --- a/common/events/v1/event.py +++ b/common/events/v1/event.py @@ -4,8 +4,7 @@ import logging from dataclasses import InitVar, dataclass, field -from datetime import datetime, timezone -from typing import Optional +from datetime import datetime, UTC from uuid import UUID, uuid4 from common.decorators import cached_property @@ -72,35 +71,35 @@ class Event(EventInterface): this to define your own Event to ensure your events have all the expected fields. """ - pipeline_key: Optional[str] + pipeline_key: str | None source: str event_id: UUID event_timestamp: datetime received_timestamp: datetime metadata: dict[str, object] event_type: str - run_name: Optional[str] - run_key: Optional[str] - component_tool: Optional[str] - project_id: Optional[UUID] - run_id: Optional[UUID] - pipeline_id: Optional[UUID] - task_id: Optional[UUID] - task_name: Optional[str] - task_key: Optional[str] - run_task_id: Optional[UUID] - external_url: Optional[str] - pipeline_name: Optional[str] + run_name: str | None + run_key: str | None + component_tool: str | None + project_id: UUID | None + run_id: UUID | None + pipeline_id: UUID | None + task_id: UUID | None + task_name: str | None + task_key: str | None + run_task_id: UUID | None + external_url: str | None + pipeline_name: str | None instances: list[InstanceRef] - dataset_id: Optional[UUID] - dataset_key: Optional[str] - dataset_name: Optional[str] - server_id: Optional[UUID] - server_key: Optional[str] - server_name: Optional[str] - stream_id: Optional[UUID] - stream_key: Optional[str] - stream_name: Optional[str] + dataset_id: UUID | None + dataset_key: str | None + dataset_name: str | None + server_id: UUID | None + server_key: str | None + server_name: str | None + stream_id: UUID | None + stream_key: str | None + stream_name: str | None payload_keys: list[str] version: EventVersion @@ -131,7 +130,7 @@ def as_event_from_request(cls, request_body: dict) -> Event: # At a glance, we could have these defaulted by the schema. However, the spec says that if timestamp is not # defined, it _must_ be matched to the received time. Setting it in the schema would generate very tiny # differences in time. - current_time = str(datetime.now(timezone.utc)) + current_time = str(datetime.now(UTC)) if "event_timestamp" not in event_body: event_body["event_timestamp"] = current_time event_body["received_timestamp"] = current_time @@ -215,7 +214,7 @@ def component_journeys(self) -> list[Journey]: @cached_property def component_key_details(self) -> EventComponentDetails: if not (key := next((attr for attr in EVENT_ATTRIBUTES.keys() if getattr(self, attr, None) is not None), None)): - LOG.error(f"Event component key details cannot be parsed from the event information provided: " f"{self}") + LOG.error(f"Event component key details cannot be parsed from the event information provided: {self}") raise ValueError("Event component key details cannot be parsed.") return EVENT_ATTRIBUTES[key] @@ -227,7 +226,7 @@ def component_key(self) -> str: return key @property - def component_id(self) -> Optional[UUID]: + def component_id(self) -> UUID | None: return getattr(self, self.component_key_details.component_id, None) @component_id.setter @@ -239,7 +238,7 @@ def component_id(self, value: UUID) -> None: setattr(self, self.component_key_details.component_id, value) @property - def component_name(self) -> Optional[str]: + def component_name(self) -> str | None: return getattr(self, self.component_key_details.component_name, None) @property diff --git a/common/events/v1/event_schemas.py b/common/events/v1/event_schemas.py index 1c63efc..f8ad150 100644 --- a/common/events/v1/event_schemas.py +++ b/common/events/v1/event_schemas.py @@ -1,7 +1,7 @@ __all__ = ["EventSchemaInterface", "EventApiSchema", "EventSchema"] import json -from datetime import timezone +from datetime import UTC from typing import Any, Union from marshmallow import Schema, ValidationError, post_dump, post_load, pre_load, validates_schema @@ -176,7 +176,7 @@ class EventApiSchema(EventSchemaInterface): ) event_timestamp = AwareDateTime( format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, metadata={ "description": ( "Optional. An ISO8601 timestamp that describes when the event occurred. If no timezone " @@ -251,7 +251,7 @@ class EventSchema(EventApiSchema): received_timestamp = AwareDateTime( format="iso", required=True, - default_timezone=timezone.utc, + default_timezone=UTC, metadata={"description": "An ISO timestamp that the Event Ingestion API applies when it receives the event."}, ) # This is the source of the message. diff --git a/common/events/v1/test_outcomes_event.py b/common/events/v1/test_outcomes_event.py index 37c64be..9bd06ea 100644 --- a/common/events/v1/test_outcomes_event.py +++ b/common/events/v1/test_outcomes_event.py @@ -18,11 +18,11 @@ ] from dataclasses import asdict, dataclass -from datetime import datetime, timezone +from datetime import datetime, UTC from decimal import Decimal as std_decimal from enum import Enum as std_Enum from enum import IntEnum as std_IntEnum -from typing import Any, Optional, Union +from typing import Any, Union from uuid import UUID as std_UUID from marshmallow import Schema, ValidationError, post_load, validates_schema @@ -100,7 +100,7 @@ class TestgenItem: test_suite: str version: int test_parameters: list[TestgenItemTestParameters] - columns: Optional[list[str]] = None + columns: list[str] | None = None class TestgenItemSchema(Schema): @@ -163,8 +163,8 @@ def to_testoutcome_item_integrations(self, data: dict, **_: Any) -> TestOutcomeI @dataclass class TestgenTable: include_list: list[str] - include_pattern: Optional[str] = None - exclude_pattern: Optional[str] = None + include_pattern: str | None = None + exclude_pattern: str | None = None class TestgenTableSchema(Schema): @@ -217,9 +217,9 @@ def to_testgen_table(self, data: dict, **_: Any) -> TestgenTable: class TestgenTableGroupV1: group_id: std_UUID project_code: str - uses_sampling: Optional[bool] = None - sample_percentage: Optional[str] = None - sample_minimum_count: Optional[int] = None + uses_sampling: bool | None = None + sample_percentage: str | None = None + sample_minimum_count: int | None = None class TestgenTableGroupV1Schema(Schema): @@ -263,8 +263,8 @@ class TestgenDataset: database_name: str connection_name: str tables: TestgenTable - schema: Optional[str] = None - table_group_configuration: Optional[TestgenTableGroupV1] = None + schema: str | None = None + table_group_configuration: TestgenTableGroupV1 | None = None class TestgenDatasetSchema(Schema): @@ -346,19 +346,19 @@ class TestOutcomeItem: name: str status: str description: str = "" - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None - metadata: Optional[dict[str, Any]] = None - metric_value: Optional[Decimal] = None - metric_name: Optional[str] = None - metric_description: Optional[str] = None - min_threshold: Optional[Decimal] = None - max_threshold: Optional[Decimal] = None - integrations: Optional[TestOutcomeItemIntegrations] = None - dimensions: Optional[list[str]] = None - result: Optional[str] = None - type: Optional[str] = None - key: Optional[str] = None + start_time: datetime | None = None + end_time: datetime | None = None + metadata: dict[str, Any] | None = None + metric_value: Decimal | None = None + metric_name: str | None = None + metric_description: str | None = None + min_threshold: Decimal | None = None + max_threshold: Decimal | None = None + integrations: TestOutcomeItemIntegrations | None = None + dimensions: list[str] | None = None + result: str | None = None + type: str | None = None + key: str | None = None # region Schemas @@ -373,8 +373,7 @@ class TestOutcomeItemSchema(Schema): enum=TestStatuses, metadata={ "description": ( - "Required. The test status to be applied. Can set the status for both tests in runs and " - "tests in tasks." + "Required. The test status to be applied. Can set the status for both tests in runs and tests in tasks." ) }, ) @@ -384,13 +383,13 @@ class TestOutcomeItemSchema(Schema): ) start_time = AwareDateTime( format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, allow_none=True, metadata={"description": "An ISO timestamp of when the test execution started."}, ) end_time = AwareDateTime( format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, allow_none=True, metadata={"description": "An ISO timestamp of when the test execution ended."}, ) @@ -530,7 +529,7 @@ class TestOutcomesEvent(Event): """Represents the single result of a test.""" test_outcomes: list[TestOutcomeItem] - component_integrations: Optional[TestGenTestOutcomeIntegrationComponent] = None + component_integrations: TestGenTestOutcomeIntegrationComponent | None = None __schema__ = TestOutcomesSchema __api_schema__ = TestOutcomesApiSchema diff --git a/common/events/v2/base.py b/common/events/v2/base.py index e983983..4451195 100644 --- a/common/events/v2/base.py +++ b/common/events/v2/base.py @@ -7,8 +7,7 @@ from dataclasses import dataclass -from datetime import datetime, timezone -from typing import Optional +from datetime import datetime, UTC from marshmallow import Schema from marshmallow.fields import UUID, AwareDateTime, Dict, List, Str, Url @@ -24,9 +23,9 @@ @dataclass class BasePayload: - event_timestamp: Optional[datetime] + event_timestamp: datetime | None metadata: dict[str, object] - external_url: Optional[str] + external_url: str | None payload_keys: list[str] @@ -38,7 +37,7 @@ class BasePayloadSchema(Schema): event_timestamp = AwareDateTime( load_default=None, format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, metadata={ "description": ( "An ISO8601 timestamp that describes when the event occurred. " diff --git a/common/events/v2/component_data.py b/common/events/v2/component_data.py index 51f9f24..e0cbc70 100644 --- a/common/events/v2/component_data.py +++ b/common/events/v2/component_data.py @@ -15,7 +15,7 @@ ] from dataclasses import dataclass -from typing import Any, Optional +from typing import Any from marshmallow import Schema, ValidationError, post_load, validates_schema from marshmallow.fields import Nested, Str @@ -28,44 +28,44 @@ @dataclass class NewComponentData: - name: Optional[str] - tool: Optional[str] + name: str | None + tool: str | None @dataclass class BatchPipelineData: batch_key: str run_key: str - run_name: Optional[str] - task_key: Optional[str] - task_name: Optional[str] - details: Optional[NewComponentData] + run_name: str | None + task_key: str | None + task_name: str | None + details: NewComponentData | None @dataclass class DatasetData: dataset_key: str - details: Optional[NewComponentData] + details: NewComponentData | None @dataclass class ServerData: server_key: str - details: Optional[NewComponentData] + details: NewComponentData | None @dataclass class StreamData: stream_key: str - details: Optional[NewComponentData] + details: NewComponentData | None @dataclass class ComponentData: - batch_pipeline: Optional[BatchPipelineData] - stream: Optional[StreamData] - dataset: Optional[DatasetData] - server: Optional[ServerData] + batch_pipeline: BatchPipelineData | None + stream: StreamData | None + dataset: DatasetData | None + server: ServerData | None class NewComponentDataSchema(Schema): diff --git a/common/events/v2/dataset_operation.py b/common/events/v2/dataset_operation.py index 22de9a6..9f9cb0e 100644 --- a/common/events/v2/dataset_operation.py +++ b/common/events/v2/dataset_operation.py @@ -7,7 +7,7 @@ from dataclasses import dataclass from enum import Enum as std_Enum -from typing import Any, Optional +from typing import Any from marshmallow import post_load from marshmallow.fields import Enum, Nested, Str @@ -28,7 +28,7 @@ class DatasetOperationType(std_Enum): class DatasetOperation(BasePayload): dataset_component: DatasetData operation: DatasetOperationType - path: Optional[str] + path: str | None class DatasetOperationSchema(BasePayloadSchema): diff --git a/common/events/v2/test_outcomes.py b/common/events/v2/test_outcomes.py index 67c0fd8..b529781 100644 --- a/common/events/v2/test_outcomes.py +++ b/common/events/v2/test_outcomes.py @@ -13,10 +13,10 @@ from dataclasses import dataclass from dataclasses import fields as dc_fields -from datetime import datetime, timezone +from datetime import datetime, UTC from decimal import Decimal as std_Decimal from enum import Enum as std_Enum -from typing import Any, Optional +from typing import Any from marshmallow import Schema, ValidationError, post_load, validates_schema from marshmallow.fields import AwareDateTime, Decimal, Dict, Enum, List, Nested, Str @@ -56,23 +56,23 @@ class TestOutcomeItem: status: TestStatus description: str metadata: dict[str, Any] - start_time: Optional[datetime] - end_time: Optional[datetime] - metric_value: Optional[std_Decimal] - metric_name: Optional[str] - metric_description: Optional[str] - metric_min_threshold: Optional[std_Decimal] - metric_max_threshold: Optional[std_Decimal] - integrations: Optional[TestOutcomeItemIntegrations] - dimensions: Optional[list[str]] - result: Optional[str] - type: Optional[str] - key: Optional[str] + start_time: datetime | None + end_time: datetime | None + metric_value: std_Decimal | None + metric_name: str | None + metric_description: str | None + metric_min_threshold: std_Decimal | None + metric_max_threshold: std_Decimal | None + integrations: TestOutcomeItemIntegrations | None + dimensions: list[str] | None + result: str | None + type: str | None + key: str | None @dataclass class TestGenIntegrations: - integrations: Optional[TestGenTestOutcomeIntegrations] + integrations: TestGenTestOutcomeIntegrations | None @dataclass @@ -93,10 +93,10 @@ class TestGenStreamData(StreamData, TestGenIntegrations): ... @dataclass class TestGenComponentData: - batch_pipeline: Optional[TestGenBatchPipelineData] - stream: Optional[TestGenStreamData] - dataset: Optional[TestGenDatasetData] - server: Optional[TestGenServerData] + batch_pipeline: TestGenBatchPipelineData | None + stream: TestGenStreamData | None + dataset: TestGenDatasetData | None + server: TestGenServerData | None @dataclass @@ -125,8 +125,7 @@ class TestOutcomeItemSchema(Schema): enum=TestStatus, metadata={ "description": ( - "Required. The test status to be applied. Can set the status for both tests in runs and " - "tests in tasks." + "Required. The test status to be applied. Can set the status for both tests in runs and tests in tasks." ) }, ) @@ -144,13 +143,13 @@ class TestOutcomeItemSchema(Schema): ) start_time = AwareDateTime( format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, load_default=None, metadata={"description": "An ISO timestamp of when the test execution started."}, ) end_time = AwareDateTime( format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, load_default=None, metadata={"description": "An ISO timestamp of when the test execution ended."}, ) diff --git a/common/events/v2/testgen.py b/common/events/v2/testgen.py index 7d45c1a..56978c0 100644 --- a/common/events/v2/testgen.py +++ b/common/events/v2/testgen.py @@ -14,7 +14,7 @@ from dataclasses import asdict, dataclass from decimal import Decimal as std_Decimal from enum import IntEnum -from typing import Any, Optional, Union +from typing import Any, Union from uuid import UUID as std_UUID from marshmallow import Schema, ValidationError, post_load, validates_schema @@ -60,7 +60,7 @@ class TestgenItem: test_suite: str version: TestgenIntegrationVersions test_parameters: list[TestgenItemTestParameters] - columns: Optional[list[str]] + columns: list[str] | None @dataclass @@ -71,17 +71,17 @@ class TestOutcomeItemIntegrations: @dataclass class TestgenTable: include_list: list[str] - include_pattern: Optional[str] - exclude_pattern: Optional[str] + include_pattern: str | None + exclude_pattern: str | None @dataclass class TestgenTableGroupV1: group_id: std_UUID project_code: str - uses_sampling: Optional[bool] - sample_percentage: Optional[str] - sample_minimum_count: Optional[int] + uses_sampling: bool | None + sample_percentage: str | None + sample_minimum_count: int | None @dataclass @@ -90,8 +90,8 @@ class TestgenDataset: database_name: str connection_name: str tables: TestgenTable - schema: Optional[str] - table_group_configuration: Optional[TestgenTableGroupV1] + schema: str | None + table_group_configuration: TestgenTableGroupV1 | None @dataclass diff --git a/common/kafka/consumer.py b/common/kafka/consumer.py index 402f872..b9372d3 100644 --- a/common/kafka/consumer.py +++ b/common/kafka/consumer.py @@ -3,7 +3,7 @@ import logging import signal from types import FrameType -from typing import Any, Optional +from typing import Any from collections.abc import Iterator from confluent_kafka import Consumer, Message @@ -49,7 +49,7 @@ def init_handlers() -> None: signal.signal(signal.SIGTERM, GracefulKiller.exit_gracefully) @staticmethod - def exit_gracefully(sig_num: int, frame: Optional[FrameType]) -> Any: + def exit_gracefully(sig_num: int, frame: FrameType | None) -> Any: LOG.info(f"Signal {sig_num} received, attempting to exit gracefully. Use SIGKILL to terminate immediately.") GracefulKiller.should_exit = True @@ -152,9 +152,9 @@ def commit(self) -> Any: except Exception as e: raise ConsumerCommitError from e - def poll(self) -> Optional[KafkaMessage]: + def poll(self) -> KafkaMessage | None: try: - msg: Optional[Message] = self.consumer.poll(CONSUMER_POLL_PERIOD_SECS) + msg: Message | None = self.consumer.poll(CONSUMER_POLL_PERIOD_SECS) except DisconnectedConsumerError: raise except Exception as ex: diff --git a/common/kafka/message.py b/common/kafka/message.py index ed13d51..b6f000b 100644 --- a/common/kafka/message.py +++ b/common/kafka/message.py @@ -1,12 +1,12 @@ __all__ = ["KafkaMessage"] from dataclasses import dataclass -from typing import Generic, Optional, TypeVar +from typing import TypeVar T = TypeVar("T") @dataclass(frozen=True, kw_only=True) -class KafkaMessage(Generic[T]): +class KafkaMessage[T]: """ A generic Kafka message @@ -19,4 +19,4 @@ class KafkaMessage(Generic[T]): partition: int offset: int headers: dict - key: Optional[str] = None + key: str | None = None diff --git a/common/kafka/producer.py b/common/kafka/producer.py index e272f74..72e82cf 100644 --- a/common/kafka/producer.py +++ b/common/kafka/producer.py @@ -6,7 +6,7 @@ import uuid from contextlib import contextmanager from types import TracebackType -from typing import Any, Optional +from typing import Any from collections.abc import Callable, Generator from confluent_kafka import KafkaError, KafkaException, Message, Producer @@ -88,7 +88,7 @@ def __enter__(self) -> KafkaProducer: return self def __exit__( - self, exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], tb: Optional[TracebackType] + self, exc_type: type[BaseException] | None, exc_value: BaseException | None, tb: TracebackType | None ) -> None: self.disconnect() @@ -105,7 +105,7 @@ def is_topic_available(self, topic: Topic, timeout: int = KAFKA_OP_TIMEOUT_SECS) metadata = self.producer.list_topics(topic=topic.name, timeout=timeout) return len(metadata.topics[topic.name].partitions) > 0 - def produce(self, topic: Topic, event: Any, callback: Optional[Callable] = None) -> None: + def produce(self, topic: Topic, event: Any, callback: Callable | None = None) -> None: def delivery_report(err: KafkaError, message: Message) -> None: """Called once for each message produced to indicate delivery result. Triggered by poll() or flush().""" if err is not None: @@ -149,7 +149,7 @@ class KafkaTransactionalProducer(KafkaProducer): """ - def __init__(self, config: dict, tx_consumer: Optional[KafkaTransactionalConsumer] = None) -> None: + def __init__(self, config: dict, tx_consumer: KafkaTransactionalConsumer | None = None) -> None: if PRODUCER_TX_MANDATORY_SETTINGS.keys() & config.keys(): raise ProducerConfigurationError( f"Local configuration cannot override any of {PRODUCER_TX_MANDATORY_SETTINGS.keys()}" diff --git a/common/kafka/topic.py b/common/kafka/topic.py index 1988c36..4bd1b69 100644 --- a/common/kafka/topic.py +++ b/common/kafka/topic.py @@ -9,7 +9,7 @@ import json from dataclasses import dataclass -from typing import Any, NamedTuple, Optional, Protocol +from typing import Any, NamedTuple, Protocol from confluent_kafka import Message @@ -29,7 +29,7 @@ class ProduceMessageArgs(NamedTuple): value: bytes topic: str headers: dict[str, str] - key: Optional[str] = None + key: str | None = None def as_dict(self) -> dict[str, Any]: d = { @@ -44,7 +44,7 @@ def as_dict(self) -> dict[str, Any]: return d -def _get_headers_as_dict(headers: Optional[list[tuple[str, bytes]]]) -> dict[str, str]: +def _get_headers_as_dict(headers: list[tuple[str, bytes]] | None) -> dict[str, str]: return {k: v.decode("utf-8") for k, v in headers or {}} diff --git a/common/logging/json_logging.py b/common/logging/json_logging.py index a912df0..1a4b6d7 100644 --- a/common/logging/json_logging.py +++ b/common/logging/json_logging.py @@ -1,12 +1,12 @@ __all__ = ["JsonFormatter"] import logging -from datetime import datetime, timezone +from datetime import datetime, UTC from json import dumps from common.json_encoder import JsonExtendedEncoder -UTC = timezone.utc +UTC = UTC class JsonFormatter(logging.Formatter): diff --git a/common/messagepack.py b/common/messagepack.py index 7005b8c..c77080e 100644 --- a/common/messagepack.py +++ b/common/messagepack.py @@ -9,7 +9,7 @@ from io import BytesIO from pathlib import Path, PurePath, PurePosixPath, PureWindowsPath from types import MappingProxyType -from typing import Any, BinaryIO, Optional, TextIO, Union, cast +from typing import Any, BinaryIO, TextIO, Union, cast from collections.abc import Callable from uuid import UUID @@ -261,7 +261,7 @@ def decode_ext(code: int, data: bytes) -> object: return ExtType(code, data) -def dump(value: object, flo: FLO, hook: Optional[Callable] = None) -> None: +def dump(value: object, flo: FLO, hook: Callable | None = None) -> None: """ Serialize as msgpack and write the result to a file-like-object. @@ -281,7 +281,7 @@ def dump(value: object, flo: FLO, hook: Optional[Callable] = None) -> None: ) -def dumps(value: object, hook: Optional[Callable] = None) -> bytes: +def dumps(value: object, hook: Callable | None = None) -> bytes: """ Serialize object to msgpack and return resulting messagepack bytes. @@ -295,7 +295,7 @@ def dumps(value: object, hook: Optional[Callable] = None) -> bytes: return result -def load(flo: FLO, object_hook: Optional[Callable] = None) -> Any: +def load(flo: FLO, object_hook: Callable | None = None) -> Any: """ Deserialize a msgpack file-like-object. @@ -317,7 +317,7 @@ def load(flo: FLO, object_hook: Optional[Callable] = None) -> Any: return result -def loads(stream: bytes, object_hook: Optional[Callable] = None) -> Any: +def loads(stream: bytes, object_hook: Callable | None = None) -> Any: """ Deserialize msgpack bytes diff --git a/common/model.py b/common/model.py index d9009fa..9698289 100644 --- a/common/model.py +++ b/common/model.py @@ -7,7 +7,6 @@ from contextlib import suppress from importlib import import_module from types import ModuleType -from typing import Optional from peewee import Database, Model, SchemaManager @@ -17,7 +16,7 @@ LOG = logging.getLogger(__name__) -def walk(module: Optional[ModuleType] = None) -> dict[str, Model]: +def walk(module: ModuleType | None = None) -> dict[str, Model]: """ Recursively scans a module for all PeeWee model classes. Defaults to `common.entities` but can scan any module. diff --git a/common/peewee_extensions/fields.py b/common/peewee_extensions/fields.py index e229af8..d26d203 100644 --- a/common/peewee_extensions/fields.py +++ b/common/peewee_extensions/fields.py @@ -3,11 +3,11 @@ import logging import re import socket -from datetime import datetime, timezone +from datetime import datetime, UTC from enum import Enum from json import dumps as json_dumps from json import loads as json_loads -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from re import Pattern from unicodedata import normalize from zoneinfo import ZoneInfo, ZoneInfoNotFoundError @@ -28,7 +28,7 @@ def __init__(self, null: bool = False, defaults_to_now: bool = False, **kwargs: if defaults_to_now: # This is a convenient method to ensure newly created objects will have a timezone-aware value. It doesn't # affect what is stored in the database. - kwargs["default"] = lambda: datetime.utcnow().replace(tzinfo=timezone.utc) + kwargs["default"] = lambda: datetime.utcnow().replace(tzinfo=UTC) else: kwargs["default"] = None @@ -36,9 +36,9 @@ def __init__(self, null: bool = False, defaults_to_now: bool = False, **kwargs: # we put in. super().__init__(null=null, utc=True, resolution=1_000_000, **kwargs) - def python_value(self, value: Union[int, float]) -> Optional[datetime]: + def python_value(self, value: Union[int, float]) -> datetime | None: if isinstance(ret_val := super().python_value(value), datetime): - return ret_val.replace(tzinfo=timezone.utc) + return ret_val.replace(tzinfo=UTC) else: return None @@ -90,7 +90,7 @@ def db_value(self, value: str) -> str: return db_value -def _enum_value_to_db_value(enum_class: type[Enum], value: Union[str, Enum, None]) -> Optional[str | int]: +def _enum_value_to_db_value(enum_class: type[Enum], value: Union[str, Enum, None]) -> str | int | None: """Converts a value before sending it to the DB.""" if value is None: return None @@ -108,7 +108,7 @@ def _enum_value_to_db_value(enum_class: type[Enum], value: Union[str, Enum, None return value -def _db_value_to_enum_value(enum_class: type[Enum], value: str | int) -> Optional[Enum]: +def _db_value_to_enum_value(enum_class: type[Enum], value: str | int) -> Enum | None: if value: try: return enum_class(value) @@ -125,12 +125,12 @@ def __init__(self, enum_class: type[Enum], **kwargs: Any) -> None: self.enum_class = enum_class super().__init__(**kwargs) - def db_value(self, value: Union[str, Enum, None]) -> Optional[str]: + def db_value(self, value: Union[str, Enum, None]) -> str | None: """Converts a value before sending it to the DB.""" db_value: str = super().db_value(_enum_value_to_db_value(self.enum_class, value)) return db_value - def python_value(self, value: str) -> Optional[Enum]: + def python_value(self, value: str) -> Enum | None: return _db_value_to_enum_value(self.enum_class, super().python_value(value)) @@ -141,12 +141,12 @@ def __init__(self, enum_class: type[Enum], **kwargs: Any) -> None: self.enum_class = enum_class super().__init__(**kwargs) - def db_value(self, value: Union[str, Enum, None]) -> Optional[str]: + def db_value(self, value: Union[str, Enum, None]) -> str | None: """Converts a value before sending it to the DB.""" db_value: str = super().db_value(_enum_value_to_db_value(self.enum_class, value)) return db_value - def python_value(self, value: str) -> Optional[Enum]: + def python_value(self, value: str) -> Enum | None: return _db_value_to_enum_value(self.enum_class, super().python_value(value)) @@ -174,7 +174,7 @@ class JSONStrListField(TextField): def __init__(self, **kwargs: Any) -> None: """Ensure that `default` is always the list function or a function that returns a list.""" - if (default_func := kwargs.get("default", None)) not in (None, list): + if (default_func := kwargs.get("default", None)) and callable(default_func): try: _result = default_func() except Exception as e: @@ -187,7 +187,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs["default"] = list super().__init__(**kwargs) - def db_value(self, value: Optional[list[str]]) -> Optional[str]: + def db_value(self, value: list[str] | None) -> str | None: """Dump a list of strings as a JSON string. Keeps key order consistent.""" if value is not None: if not isinstance(value, list): @@ -199,7 +199,7 @@ def db_value(self, value: Optional[list[str]]) -> Optional[str]: else: return None - def python_value(self, value: Optional[str]) -> Optional[list[str]]: + def python_value(self, value: str | None) -> list[str] | None: """Load the text retrieved from the JSON field into a list.""" if value is not None: try: @@ -223,7 +223,7 @@ class JSONDictListField(TextField): def __init__(self, **kwargs: Any) -> None: """Ensure that `default` is always the list function or a function that returns a list.""" - if (default_func := kwargs.get("default", None)) not in (None, list): + if (default_func := kwargs.get("default", None)) and callable(default_func): try: _result = default_func() except Exception as e: @@ -236,7 +236,7 @@ def __init__(self, **kwargs: Any) -> None: kwargs["default"] = list super().__init__(**kwargs) - def db_value(self, value: Optional[list[dict]]) -> Optional[str]: + def db_value(self, value: list[dict] | None) -> str | None: """Dump a list of strings as a JSON string. Keeps key order consistent.""" if value is not None: if not isinstance(value, list): @@ -248,7 +248,7 @@ def db_value(self, value: Optional[list[dict]]) -> Optional[str]: else: return None - def python_value(self, value: Optional[str]) -> Optional[list[dict]]: + def python_value(self, value: str | None) -> list[dict] | None: """Load the text retrieved from the JSON field into a list.""" if value is not None: try: @@ -272,7 +272,7 @@ def __init__(self, schema: Schema, **kwargs: Any) -> None: self.schema = schema super().__init__(**kwargs) - def db_value(self, value: Any) -> Optional[str]: + def db_value(self, value: Any) -> str | None: if value is None: json_data = "null" else: @@ -282,7 +282,7 @@ def db_value(self, value: Any) -> Optional[str]: raise ValueError(f"The value in '{self.name}' can not be dumped by {self.schema}.") from e return cast(str, super().db_value(json_data)) - def python_value(self, value: Optional[str]) -> Optional[Any]: + def python_value(self, value: str | None) -> Any | None: if value is None or value == "null": return None else: diff --git a/common/peewee_extensions/fixtures.py b/common/peewee_extensions/fixtures.py index 18dc14e..26a440d 100644 --- a/common/peewee_extensions/fixtures.py +++ b/common/peewee_extensions/fixtures.py @@ -4,7 +4,7 @@ from functools import cache from graphlib import CycleError, TopologicalSorter from pathlib import Path -from typing import Any, Optional +from typing import Any from uuid import UUID, uuid4 # TODO: When we move to Python 3.11+ we can switch to importing tomlib and we can remove the tomli requirement from @@ -57,7 +57,7 @@ def generate_table_map() -> dict[str, Model]: return {x._meta.table_name: x for x in model_map.values()} -def dump_results(results: ModelSelect, *, requires_id: Optional[UUID] = None) -> str: +def dump_results(results: ModelSelect, *, requires_id: UUID | None = None) -> str: """Given Peewee query results, generate a fixture dump.""" model_class = results.model rows = [] diff --git a/common/predicate_engine/compilers/utils.py b/common/predicate_engine/compilers/utils.py index 7819ed3..f7e9225 100644 --- a/common/predicate_engine/compilers/utils.py +++ b/common/predicate_engine/compilers/utils.py @@ -17,4 +17,4 @@ def _prefetch_query(query: SelectQuery, *, _queries: list) -> list[Model]: result: list[Model] = query.prefetch(*_queries, prefetch_type=PREFETCH_TYPE.JOIN) return result - return partial(_prefetch_query, _queries=args) + return partial(_prefetch_query, _queries=list(args)) diff --git a/common/predicate_engine/query.py b/common/predicate_engine/query.py index acd1fcd..a59ae50 100644 --- a/common/predicate_engine/query.py +++ b/common/predicate_engine/query.py @@ -6,10 +6,10 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from copy import copy, deepcopy -from datetime import datetime, timezone +from datetime import datetime, UTC from enum import Enum from functools import partial, reduce -from typing import Any, Final, Optional, Union +from typing import Any, Final, Union from collections.abc import Callable, Iterable from ._operators import QueryOperator, get_operator, get_operators, split_operators @@ -33,10 +33,10 @@ def _ensure_utc(dt: datetime) -> datetime: '2006-11-06T10:10:10+00:00' """ if (tzinfo := dt.tzinfo) is None: - return dt.replace(tzinfo=timezone.utc) + return dt.replace(tzinfo=UTC) if tzinfo.utcoffset(dt) is None: - return dt.replace(tzinfo=timezone.utc) - return dt.astimezone(timezone.utc) + return dt.replace(tzinfo=UTC) + return dt.astimezone(UTC) def getattr_recursive(lookup_obj: Any, attr_name: str, *args: Any) -> Any: @@ -58,8 +58,8 @@ def _getattr(obj: Any, attr: str) -> Any: class ConnectorType(Enum): - OR: str = "OR" - AND: str = "AND" + OR = "OR" + AND = "AND" class _Encapsulate(ABC): @@ -76,8 +76,8 @@ def __init__( self, value: object, *, - attr_name: Optional[str] = None, - transform_funcs: Optional[Iterable[Callable[..., Iterable]]] = None, + attr_name: str | None = None, + transform_funcs: Iterable[Callable[..., Iterable]] | None = None, ) -> None: self.wrapped_value = value self.attr_name = attr_name @@ -175,8 +175,8 @@ def __init__( value: object, *, n: int, - attr_name: Optional[str] = None, - transform_funcs: Optional[Iterable[Callable[..., Iterable]]] = None, + attr_name: str | None = None, + transform_funcs: Iterable[Callable[..., Iterable]] | None = None, ) -> None: self.count = n super().__init__(value, attr_name=attr_name, transform_funcs=transform_funcs) @@ -201,7 +201,7 @@ class ATLEAST(_Encapsulate): """ Encapsulate a value to indicate an ATLEAST predicate operation on an iterable. - The match should be successful if ATLEAST the N first valures are matching. + The match should be successful if ATLEAST the N first values are matching. """ __slots__ = ("count",) @@ -214,8 +214,8 @@ def __init__( value: object, *, n: int, - attr_name: Optional[str] = None, - transform_funcs: Optional[Iterable[Callable[..., Iterable]]] = None, + attr_name: str | None = None, + transform_funcs: Iterable[Callable[..., Iterable]] | None = None, ) -> None: self.count = n super().__init__(value, attr_name=attr_name, transform_funcs=transform_funcs) @@ -311,7 +311,7 @@ class R: Encapsulate rules as objects that can then be combined using & and | This is an implementation of a tree node for making expressions which can be used to construct rules of arbitrary - complexity. It is loosely inspired by the Qobject implementation in Django but is object agnostic and not meant for + complexity. It is loosely inspired by the Q-object implementation in Django but is object agnostic and not meant for an ORM. """ @@ -329,7 +329,7 @@ def __init__(self, **kwargs: object) -> None: @classmethod def _new_instance( - cls, children: Optional[list] = None, conn_type: ConnectorType = ConnectorType.AND, negated: bool = False + cls, children: list | None = None, conn_type: ConnectorType = ConnectorType.AND, negated: bool = False ) -> R: """ Creates a new instance of this class. diff --git a/common/schemas/fields/cron_expr_str.py b/common/schemas/fields/cron_expr_str.py index e365bb1..586476b 100644 --- a/common/schemas/fields/cron_expr_str.py +++ b/common/schemas/fields/cron_expr_str.py @@ -1,6 +1,6 @@ __all__ = ["CronExpressionStr"] -from typing import Any, Optional +from typing import Any from collections.abc import Mapping from marshmallow import ValidationError @@ -16,7 +16,7 @@ class CronExpressionStr(Str): It validates against what ApScheduler's CronTrigger expects. """ - def _deserialize(self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: object) -> Any: + def _deserialize(self, value: Any, attr: str | None, data: Mapping[str, Any] | None, **kwargs: object) -> Any: str_value = super(Str, self)._deserialize(value, attr, data, **kwargs) if errors := validate_cron_expression(str_value): raise ValidationError(" ".join(errors)) diff --git a/common/schemas/fields/enum_str.py b/common/schemas/fields/enum_str.py index 87fa98f..7dc52ba 100644 --- a/common/schemas/fields/enum_str.py +++ b/common/schemas/fields/enum_str.py @@ -1,7 +1,7 @@ __all__ = ["EnumStr"] from enum import Enum, EnumMeta -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from collections.abc import Iterable from marshmallow.utils import ensure_text_type @@ -30,7 +30,7 @@ def __init__(self, enum: Union[EnumMeta, list], **kwargs: object) -> None: super().__init__(validate=OneOf(allowed_values), **kwargs) # type: ignore[arg-type] - def _serialize(self, value: Any, attr: Optional[str], obj: Any, **kwargs: object) -> Optional[str]: + def _serialize(self, value: Any, attr: str | None, obj: Any, **kwargs: object) -> str | None: if value is None: return None if isinstance(value, Enum): diff --git a/common/schemas/fields/normalized_str.py b/common/schemas/fields/normalized_str.py index 4099162..48c3867 100644 --- a/common/schemas/fields/normalized_str.py +++ b/common/schemas/fields/normalized_str.py @@ -1,6 +1,6 @@ __all__ = ["NormalizedStr", "strip_upper_underscore"] -from typing import Any, Optional +from typing import Any from collections.abc import Callable, Mapping from marshmallow.fields import Str @@ -23,13 +23,13 @@ def __init__(self, normalizer: Callable[[str], str] = str.upper, **kwargs: Any): self.normalizer_func = normalizer super().__init__(**kwargs) - def _serialize(self, value: Any, attr: Optional[str], obj: Any, **kwargs: object) -> Optional[str]: + def _serialize(self, value: Any, attr: str | None, obj: Any, **kwargs: object) -> str | None: if value is None: return None str_field = ensure_text_type(value) return self.normalizer_func(str_field) - def _deserialize(self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: object) -> Any: + def _deserialize(self, value: Any, attr: str | None, data: Mapping[str, Any] | None, **kwargs: object) -> Any: if not isinstance(value, str | bytes): raise self.make_error("invalid") try: diff --git a/common/schemas/fields/zoneinfo.py b/common/schemas/fields/zoneinfo.py index e03b2da..1de173e 100644 --- a/common/schemas/fields/zoneinfo.py +++ b/common/schemas/fields/zoneinfo.py @@ -1,7 +1,7 @@ __all__ = ["ZoneInfo"] import zoneinfo -from typing import Any, Optional +from typing import Any from collections.abc import Mapping from marshmallow.fields import Str @@ -16,7 +16,7 @@ class ZoneInfo(Str): default_error_messages = {"invalid_timezone": INVALID_TIMEZONE} - def _deserialize(self, value: Any, attr: Optional[str], data: Optional[Mapping[str, Any]], **kwargs: object) -> Any: + def _deserialize(self, value: Any, attr: str | None, data: Mapping[str, Any] | None, **kwargs: object) -> Any: str_value = super(Str, self)._deserialize(value, attr, data, **kwargs) # Given ZoneInfo accepts a filesystem type as its constructor argument and we don't want to accept paths as # values for ZoneInfo fields, we validate the input before trying to build the ZoneInfo object diff --git a/common/schemas/filter_schemas.py b/common/schemas/filter_schemas.py index fd25acd..d2a0d4e 100644 --- a/common/schemas/filter_schemas.py +++ b/common/schemas/filter_schemas.py @@ -1,5 +1,5 @@ -from datetime import datetime, timezone -from typing import Any, Optional, Union +from datetime import datetime, UTC +from typing import Any, Union from marshmallow import EXCLUDE, Schema, ValidationError, pre_load, validates_schema from marshmallow.fields import UUID, AwareDateTime, Boolean, Enum, List, Str @@ -11,7 +11,7 @@ class FiltersSchema(Schema): def validate_time_range( - self, range_begin: Optional[datetime], range_end: Optional[datetime], range_begin_name: str + self, range_begin: datetime | None, range_end: datetime | None, range_begin_name: str ) -> None: if range_begin is None or range_end is None: return None @@ -43,7 +43,7 @@ class InstanceFiltersSchema(FiltersSchema): start_range_begin = AwareDateTime( allow_none=True, format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, metadata={ "description": "Optional. An ISO8601 datetime. If specified, The result will only include instances with a " "start_time field equal or past the given datetime. May be specified with start_range_end " @@ -53,7 +53,7 @@ class InstanceFiltersSchema(FiltersSchema): start_range_end = AwareDateTime( allow_none=True, format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, metadata={ "description": "Optional. An ISO8601 datetime. If specified, the result will only contain instances with a " "start_time field before the given datetime. May be specified with start_range_begin to create " @@ -63,7 +63,7 @@ class InstanceFiltersSchema(FiltersSchema): end_range_begin = AwareDateTime( allow_none=True, format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, metadata={ "description": "Optional. An ISO8601 datetime. If specified, The result will only include instances with an " "end_time field equal or past the given datetime. May be specified with end_range_end to create a range." @@ -72,7 +72,7 @@ class InstanceFiltersSchema(FiltersSchema): end_range_end = AwareDateTime( allow_none=True, format="iso", - default_timezone=timezone.utc, + default_timezone=UTC, metadata={ "description": "Optional. An ISO8601 datetime. If specified, The result will only include instances with an " "end_time field before the given datetime. May be specified with end_range_begin to create a range." diff --git a/common/schemas/validators/regexp.py b/common/schemas/validators/regexp.py index 87af20d..81631d3 100644 --- a/common/schemas/validators/regexp.py +++ b/common/schemas/validators/regexp.py @@ -1,7 +1,6 @@ __all__ = ["IsRegexp"] import re -from typing import Optional from marshmallow import ValidationError from marshmallow.validate import Validator @@ -14,7 +13,7 @@ class IsRegexp(Validator): message_invalid = "Invalid regular expression" - def __init__(self, *, error: Optional[str] = None) -> None: + def __init__(self, *, error: str | None = None) -> None: self.error: str = error or self.message_invalid def __call__(self, value: str) -> str: diff --git a/common/tests/integration/entities/test_alerts.py b/common/tests/integration/entities/test_alerts.py index cdbda7d..d4afe3c 100644 --- a/common/tests/integration/entities/test_alerts.py +++ b/common/tests/integration/entities/test_alerts.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC import pytest @@ -94,7 +94,7 @@ def test_run_alert_expected_start_time_get(pipeline_run): run_alert = RunAlert(run=pipeline_run, name="A", description="A", level=AlertLevel.ERROR, type="invalid") assert run_alert.expected_start_time is None run_alert.details["expected_start_time"] = 1123922544.0 - assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) == run_alert.expected_start_time + assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) == run_alert.expected_start_time @pytest.mark.integration @@ -117,7 +117,7 @@ def test_run_alert_expected_end_time_get(pipeline_run): run_alert = RunAlert(run=pipeline_run, name="A", description="A", level=AlertLevel.ERROR, type="invalid") assert run_alert.expected_end_time is None # No details have been added yet run_alert.details["expected_end_time"] = 1123922544.0 - assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) == run_alert.expected_end_time + assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) == run_alert.expected_end_time @pytest.mark.integration @@ -134,8 +134,8 @@ def test_run_alert_expected_times_naive(pipeline_run): run_alert.expected_start_time = datetime(2005, 8, 13, 8, 42, 24) run_alert.expected_end_time = datetime(2005, 8, 13, 8, 55, 24) - assert run_alert.expected_start_time == datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) - assert run_alert.expected_end_time == datetime(2005, 8, 13, 8, 55, 24, tzinfo=timezone.utc) + assert run_alert.expected_start_time == datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) + assert run_alert.expected_end_time == datetime(2005, 8, 13, 8, 55, 24, tzinfo=UTC) @pytest.mark.integration @@ -153,7 +153,7 @@ def test_instance_alert_expected_start_time_get(instance): instance=instance, name="A", description="A", message="A", level=AlertLevel.WARNING, type="invalid" ) instance_alert.details["expected_start_time"] = 1123922544.0 - assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) == instance_alert.expected_start_time + assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) == instance_alert.expected_start_time @pytest.mark.integration @@ -171,7 +171,7 @@ def test_instance_alert_expected_end_time_get(instance): instance=instance, name="A", description="A", message="A", level=AlertLevel.WARNING, type="invalid" ) instance_alert.details["expected_end_time"] = 1123922544.0 - assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) == instance_alert.expected_end_time + assert datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) == instance_alert.expected_end_time @pytest.mark.integration @@ -183,8 +183,8 @@ def test_instance_alert_expected_times_naive(instance): instance_alert.expected_start_time = datetime(2005, 8, 13, 8, 42, 24) instance_alert.expected_end_time = datetime(2005, 8, 13, 8, 55, 24) - assert instance_alert.expected_start_time == datetime(2005, 8, 13, 8, 42, 24, tzinfo=timezone.utc) - assert instance_alert.expected_end_time == datetime(2005, 8, 13, 8, 55, 24, tzinfo=timezone.utc) + assert instance_alert.expected_start_time == datetime(2005, 8, 13, 8, 42, 24, tzinfo=UTC) + assert instance_alert.expected_end_time == datetime(2005, 8, 13, 8, 55, 24, tzinfo=UTC) @pytest.mark.integration diff --git a/common/tests/integration/entities/test_runs.py b/common/tests/integration/entities/test_runs.py index 8789b63..519a10b 100644 --- a/common/tests/integration/entities/test_runs.py +++ b/common/tests/integration/entities/test_runs.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC import pytest from peewee import IntegrityError @@ -29,7 +29,7 @@ def test_add_pipeline_run_listening(pipeline): run.save() # After adding non-RUNNING status, listening should be False - run.end_time = datetime.utcnow().replace(tzinfo=timezone.utc) + timedelta(days=3) + run.end_time = datetime.utcnow().replace(tzinfo=UTC) + timedelta(days=3) run.status = "COMPLETED" run.save() diff --git a/common/tests/integration/entity_services/conftest.py b/common/tests/integration/entity_services/conftest.py index 342206f..d101b1b 100644 --- a/common/tests/integration/entity_services/conftest.py +++ b/common/tests/integration/entity_services/conftest.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC import pytest @@ -26,7 +26,7 @@ def pipeline_4(test_db, project): @pytest.fixture def current_time() -> datetime: - yield datetime.now(timezone.utc) + yield datetime.now(UTC) @pytest.fixture() @@ -61,7 +61,7 @@ def test_outcomes_instance(test_db, run, pipeline, instance_instance_set): @pytest.fixture def test_outcomes_event(project, pipeline, run, event_data): - timestamp = datetime.now(timezone.utc).isoformat() + timestamp = datetime.now(UTC).isoformat() data = { "event_type": TestOutcomesEvent.__name__, "test_outcomes": [ diff --git a/common/tests/integration/entity_services/test_project_service.py b/common/tests/integration/entity_services/test_project_service.py index f0efc22..aa038ac 100644 --- a/common/tests/integration/entity_services/test_project_service.py +++ b/common/tests/integration/entity_services/test_project_service.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from typing import Optional from uuid import uuid4 @@ -47,7 +47,7 @@ def local_test_db(test_db): def _add_runs( - pipeline, instance, number_of_runs: int, current_time: datetime, *, expected_start_time: Optional[datetime] = None + pipeline, instance, number_of_runs: int, current_time: datetime, *, expected_start_time: datetime | None = None ): instance_set = InstanceSet.get_or_create([instance.id]) for key in range(1, number_of_runs + 1): @@ -158,19 +158,19 @@ def test_get_runs_with_rules_coalesce_sort(pipeline, instance, patched_instance_ err_level = AlertLevel["ERROR"].value late_type = RunAlertType["LATE_END"].value - r1_expected_start = datetime(2023, 5, 25, 7, 44, 1, tzinfo=timezone.utc) + r1_expected_start = datetime(2023, 5, 25, 7, 44, 1, tzinfo=UTC) r1 = Run.create( key="coalesce-run-1", pipeline=pipeline, instance_set=instance_set, expected_start_time=r1_expected_start, - start_time=datetime(2023, 5, 25, 7, 44, 6, tzinfo=timezone.utc), - end_time=datetime(2023, 5, 25, 7, 45, 6, tzinfo=timezone.utc), + start_time=datetime(2023, 5, 25, 7, 44, 6, tzinfo=UTC), + end_time=datetime(2023, 5, 25, 7, 45, 6, tzinfo=UTC), status=RunStatus.COMPLETED.name, ) RunAlert.create(name="CA1", description="CD1", level=err_level, type=late_type, run=r1) - r2_expected_start = datetime(2023, 5, 25, 7, 45, 10, tzinfo=timezone.utc) + r2_expected_start = datetime(2023, 5, 25, 7, 45, 10, tzinfo=UTC) r2 = Run.create( pipeline=pipeline, instance_set=instance_set, @@ -181,14 +181,14 @@ def test_get_runs_with_rules_coalesce_sort(pipeline, instance, patched_instance_ ) RunAlert.create(name="CA2", description="CD2", level=err_level, type=late_type, run=r2) - r3_expected_start = datetime(2023, 5, 25, 7, 46, 1, tzinfo=timezone.utc) + r3_expected_start = datetime(2023, 5, 25, 7, 46, 1, tzinfo=UTC) r3 = Run.create( key="coalesce-run-3", pipeline=pipeline, instance_set=instance_set, expected_start_time=r3_expected_start, - start_time=datetime(2023, 5, 25, 7, 46, 22, tzinfo=timezone.utc), - end_time=datetime(2023, 5, 25, 7, 49, 12, tzinfo=timezone.utc), + start_time=datetime(2023, 5, 25, 7, 46, 22, tzinfo=UTC), + end_time=datetime(2023, 5, 25, 7, 49, 12, tzinfo=UTC), status=RunStatus.COMPLETED_WITH_WARNINGS.name, ) RunAlert.create(name="CA3", description="CD3", level=err_level, type=late_type, run=r3) diff --git a/common/tests/integration/entity_services/test_upcoming_instance_services.py b/common/tests/integration/entity_services/test_upcoming_instance_services.py index de26d7a..5fb0a3e 100644 --- a/common/tests/integration/entity_services/test_upcoming_instance_services.py +++ b/common/tests/integration/entity_services/test_upcoming_instance_services.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from uuid import uuid4 import pytest @@ -282,7 +282,7 @@ def test_get_upcoming_instances_with_rules_discard_matching_existing_instance( instance_rule_end, current_time, ): - base_time = datetime(2023, 8, 21, 10, 0, 0, tzinfo=timezone.utc) + base_time = datetime(2023, 8, 21, 10, 0, 0, tzinfo=UTC) instance.start_time = base_time instance.save() instance_rule_end.expression = "30 * * * *" @@ -307,7 +307,7 @@ def test_get_upcoming_instances_with_rules_do_not_discard_upcoming_instance( instance_rule_end, current_time, ): - base_time = datetime(2023, 8, 21, 10, 0, 0, tzinfo=timezone.utc) + base_time = datetime(2023, 8, 21, 10, 0, 0, tzinfo=UTC) instance.start_time = base_time instance.save() instance_rule_start.expression = "5 * * * *" diff --git a/common/tests/integration/test_apscheduler_extensions.py b/common/tests/integration/test_apscheduler_extensions.py index 7b9d463..248ba33 100644 --- a/common/tests/integration/test_apscheduler_extensions.py +++ b/common/tests/integration/test_apscheduler_extensions.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from itertools import count import pytest @@ -12,7 +12,7 @@ def calculate_fire_time_sequence(trigger, n=10): prev_fire_time = None - now = datetime.now(tz=getattr(trigger, "timezone", timezone.utc)) + now = datetime.now(tz=getattr(trigger, "timezone", UTC)) for _ in range(n): next_fire_time = trigger.get_next_fire_time(prev_fire_time, now) yield next_fire_time @@ -24,21 +24,21 @@ def calculate_fire_time_sequence(trigger, n=10): @pytest.mark.parametrize( "trigger", ( - CronTrigger.from_crontab("*/2 * * * *", timezone=timezone.utc), - CronTrigger.from_crontab("*/4 * * * *", timezone=timezone.utc), + CronTrigger.from_crontab("*/2 * * * *", timezone=UTC), + CronTrigger.from_crontab("*/4 * * * *", timezone=UTC), CronTrigger( year="*", month="*", day="*", hour="*", minute="*/4", - timezone=timezone.utc, - start_date=datetime.now(tz=timezone.utc) + timedelta(days=5), + timezone=UTC, + start_date=datetime.now(tz=UTC) + timedelta(days=5), ), CronTrigger.from_crontab("*/4 * * * *", timezone=astimezone("Asia/Tokyo")), - IntervalTrigger(minutes=1, timezone=timezone.utc), - IntervalTrigger(minutes=6, timezone=timezone.utc), - DateTrigger(timezone=timezone.utc), + IntervalTrigger(minutes=1, timezone=UTC), + IntervalTrigger(minutes=6, timezone=UTC), + DateTrigger(timezone=UTC), ), ids=( "cron_smaller_interval", @@ -73,7 +73,7 @@ def test_delayed_trigger_3_min_delay(trigger): ), ) def test_delayed_trigger(delay): - cron_trigger = CronTrigger.from_crontab("*/2 * * * *", timezone=timezone.utc) + cron_trigger = CronTrigger.from_crontab("*/2 * * * *", timezone=UTC) delayed_trigger = DelayedTrigger(cron_trigger, delay) for idx, original, delayed in zip( diff --git a/common/tests/unit/actions/test_webhook_action.py b/common/tests/unit/actions/test_webhook_action.py index 6458cb1..7f5dd01 100644 --- a/common/tests/unit/actions/test_webhook_action.py +++ b/common/tests/unit/actions/test_webhook_action.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from unittest.mock import Mock, patch import pytest @@ -19,7 +19,7 @@ def session(): @pytest.fixture def test_outcome_item_data(): - timestamp = datetime.now(timezone.utc).isoformat() + timestamp = datetime.now(UTC).isoformat() return { "name": "My_test_name", "status": TestStatuses.PASSED.name, diff --git a/common/tests/unit/entities/test_journey_dag.py b/common/tests/unit/entities/test_journey_dag.py index 5edbd3d..d6abe82 100644 --- a/common/tests/unit/entities/test_journey_dag.py +++ b/common/tests/unit/entities/test_journey_dag.py @@ -10,7 +10,7 @@ @dataclass class FakeEdge: - left: Optional[str] + left: str | None right: str def __hash__(self): diff --git a/common/tests/unit/entity_services/helpers/test_filter_rules.py b/common/tests/unit/entity_services/helpers/test_filter_rules.py index c6547ea..fb51cba 100644 --- a/common/tests/unit/entity_services/helpers/test_filter_rules.py +++ b/common/tests/unit/entity_services/helpers/test_filter_rules.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from uuid import uuid4 import pytest @@ -58,8 +58,8 @@ def test_from_params_start(): # We want to see that arrow is giving us the kind of timestamp we think; i.e., a UTC aware # timestamp. - assert filters.start_range_begin == datetime(year=2022, month=8, day=16, tzinfo=timezone.utc) - assert filters.start_range_end == datetime(year=2022, month=8, day=17, tzinfo=timezone.utc) + assert filters.start_range_begin == datetime(year=2022, month=8, day=16, tzinfo=UTC) + assert filters.start_range_end == datetime(year=2022, month=8, day=17, tzinfo=UTC) assert filters.end_range_begin is None assert filters.end_range_end is None assert bool(filters) @@ -102,8 +102,8 @@ def test_run_filters_from_parameters_end(): [("page", "5"), ("count", "25"), ("end_range_begin", "2022-08-16"), ("end_range_end", "2022-08-17")] ) filters = RunFilters.from_params(end) - assert filters.end_range_begin == datetime(year=2022, month=8, day=16, tzinfo=timezone.utc) - assert filters.end_range_end == datetime(year=2022, month=8, day=17, tzinfo=timezone.utc) + assert filters.end_range_begin == datetime(year=2022, month=8, day=16, tzinfo=UTC) + assert filters.end_range_end == datetime(year=2022, month=8, day=17, tzinfo=UTC) assert filters.start_range_begin is None assert filters.start_range_end is None assert filters.pipeline_keys == [] diff --git a/common/tests/unit/events/v1/test_base_events.py b/common/tests/unit/events/v1/test_base_events.py index 81d5433..7857abe 100644 --- a/common/tests/unit/events/v1/test_base_events.py +++ b/common/tests/unit/events/v1/test_base_events.py @@ -1,5 +1,5 @@ import uuid -from datetime import timedelta, timezone +from datetime import timedelta, timezone, UTC import pytest from marshmallow import ValidationError @@ -127,8 +127,8 @@ def test_event_with_batch_pipeline_component_missing_run_key_error(valid_event_d @pytest.mark.parametrize( ["timestamp", "tz"], [ - ("2018-07-25T00:00:00Z", timezone.utc), - ("2018-07-25T00:00:00", timezone.utc), + ("2018-07-25T00:00:00Z", UTC), + ("2018-07-25T00:00:00", UTC), ("2014-12-22T03:12:58.019077+06:00", timezone(timedelta(hours=6))), ], ids=["ZuluTime", "Naive", "TZ offset"], diff --git a/common/tests/unit/events/v1/test_testoutcomes.py b/common/tests/unit/events/v1/test_testoutcomes.py index 30ae353..f862e35 100644 --- a/common/tests/unit/events/v1/test_testoutcomes.py +++ b/common/tests/unit/events/v1/test_testoutcomes.py @@ -38,9 +38,9 @@ def test_testoutcomes_schema_with_testgen_integration(test_outcomes_testgen_even assert item_integration_event.test_suite == item_integration_data["test_suite"] assert item_integration_event.version == item_integration_data["version"] assert len(item_integration_event.test_parameters) == len(item_integration_data["test_parameters"]) - assert ( - type(item_integration_event.test_parameters[0].value) == Decimal - ), "expected dataclass's value to be Decimal type" + assert type(item_integration_event.test_parameters[0].value) == Decimal, ( + "expected dataclass's value to be Decimal type" + ) assert str(item_integration_event.test_parameters[0].value) == item_integration_data["test_parameters"][0]["value"] assert item_integration_event.test_parameters[0].name == item_integration_data["test_parameters"][0]["name"] assert item_integration_event.columns == item_integration_data["columns"] diff --git a/common/tests/unit/events/v2/test_test_outcomes.py b/common/tests/unit/events/v2/test_test_outcomes.py index c8bc638..3b23371 100644 --- a/common/tests/unit/events/v2/test_test_outcomes.py +++ b/common/tests/unit/events/v2/test_test_outcomes.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from decimal import Decimal import pytest @@ -172,8 +172,8 @@ def test_test_outcomes( for actual, expected in zip(res.test_outcomes, test_outcome_items): assert actual.name == expected.name assert actual.description == expected.description - assert actual.start_time == datetime.fromisoformat(expected.start_time).replace(tzinfo=timezone.utc) - assert actual.end_time == datetime.fromisoformat(expected.end_time).replace(tzinfo=timezone.utc) + assert actual.start_time == datetime.fromisoformat(expected.start_time).replace(tzinfo=UTC) + assert actual.end_time == datetime.fromisoformat(expected.end_time).replace(tzinfo=UTC) assert actual.metric_value == expected.metric_value assert actual.metric_name == expected.metric_name assert actual.metric_description == expected.metric_description @@ -201,9 +201,9 @@ def test_testoutcomes_schema_with_testgen_integration(test_outcomes_testgen_data assert item_integration_event.test_suite == item_integration_data["test_suite"] assert item_integration_event.version == item_integration_data["version"] assert len(item_integration_event.test_parameters) == len(item_integration_data["test_parameters"]) - assert ( - type(item_integration_event.test_parameters[0].value) == Decimal - ), "expected dataclass's value to be Decimal type" + assert type(item_integration_event.test_parameters[0].value) == Decimal, ( + "expected dataclass's value to be Decimal type" + ) assert str(item_integration_event.test_parameters[0].value) == item_integration_data["test_parameters"][0]["value"] assert item_integration_event.test_parameters[0].name == item_integration_data["test_parameters"][0]["name"] assert item_integration_event.columns == item_integration_data["columns"] diff --git a/common/tests/unit/flask_ext/test_jwt_plugin.py b/common/tests/unit/flask_ext/test_jwt_plugin.py index 47631c3..23ab982 100644 --- a/common/tests/unit/flask_ext/test_jwt_plugin.py +++ b/common/tests/unit/flask_ext/test_jwt_plugin.py @@ -1,7 +1,7 @@ import json from base64 import b64decode, b64encode from binascii import Error as B64DecodeError -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from unittest.mock import Mock, patch from uuid import uuid4 @@ -50,7 +50,7 @@ def expired_token(): @pytest.fixture def current_token(): data = TOKEN_DATA.copy() - dt = datetime.now(timezone.utc) + timedelta(days=2) + dt = datetime.now(UTC) + timedelta(days=2) data["exp"] = int(dt.replace(microsecond=0).timestamp()) return encode(data, key=JWT_KEY) @@ -64,8 +64,8 @@ def user(): @pytest.mark.parametrize( ("ts_value", "dt_value"), [ - (40, datetime(1970, 1, 1, 0, 0, 40, tzinfo=timezone.utc)), - (435474000, datetime(1983, 10, 20, 5, 0, tzinfo=timezone.utc)), + (40, datetime(1970, 1, 1, 0, 0, 40, tzinfo=UTC)), + (435474000, datetime(1983, 10, 20, 5, 0, tzinfo=UTC)), ], ) def test_get_expiration_int(ts_value, dt_value): diff --git a/common/tests/unit/peewee_extensions/test_peewee_extensions.py b/common/tests/unit/peewee_extensions/test_peewee_extensions.py index 77783c3..f3e76ad 100644 --- a/common/tests/unit/peewee_extensions/test_peewee_extensions.py +++ b/common/tests/unit/peewee_extensions/test_peewee_extensions.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from enum import Enum, IntEnum from zoneinfo import ZoneInfo @@ -82,7 +82,7 @@ def test_domain_field_lowercase(): @pytest.mark.unit def test_timestamp_to_utc(): """UTCTimestampField.python_value returns timezone aware values.""" - expected_dt = datetime.now(timezone.utc) + expected_dt = datetime.now(UTC) f_inst = UTCTimestampField() db_value = f_inst.db_value(expected_dt) result = f_inst.python_value(db_value) diff --git a/common/tests/unit/predicate_engine/assertions.py b/common/tests/unit/predicate_engine/assertions.py index 625bf62..13d6f2f 100644 --- a/common/tests/unit/predicate_engine/assertions.py +++ b/common/tests/unit/predicate_engine/assertions.py @@ -42,7 +42,7 @@ def _to_str(value): raise AssertionError(f"Rules differ: \n\t{str_a}\n\t!=\n\t{str_b}") -def assertRuleMatches(a: R, b: Any, msg: Optional[str] = None): +def assertRuleMatches(a: R, b: Any, msg: str | None = None): """Assert that an R object matches a given value.""" try: result = a.matches(b) @@ -65,7 +65,7 @@ def assertRuleMatches(a: R, b: Any, msg: Optional[str] = None): ) -def assertRuleNotMatches(a: R, b: Any, msg: Optional[str] = None): +def assertRuleNotMatches(a: R, b: Any, msg: str | None = None): try: result = a.matches(b) except Exception: diff --git a/common/tests/unit/predicate_engine/conftest.py b/common/tests/unit/predicate_engine/conftest.py index 09234c5..f0d34f1 100644 --- a/common/tests/unit/predicate_engine/conftest.py +++ b/common/tests/unit/predicate_engine/conftest.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC import pytest @@ -33,7 +33,7 @@ def timestamp(self) -> datetime: @property def timestamp_dt(self) -> datetime: - return datetime(1983, 10, 20, 10, 10, 10, tzinfo=timezone.utc) + return datetime(1983, 10, 20, 10, 10, 10, tzinfo=UTC) @pytest.fixture(scope="session") diff --git a/common/tests/unit/predicate_engine/test_predicate_engine.py b/common/tests/unit/predicate_engine/test_predicate_engine.py index 99d111d..437769f 100644 --- a/common/tests/unit/predicate_engine/test_predicate_engine.py +++ b/common/tests/unit/predicate_engine/test_predicate_engine.py @@ -2,7 +2,7 @@ import sys from collections.abc import MutableMapping from copy import copy -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from unicodedata import category import pytest @@ -193,7 +193,7 @@ def test_matching_invalid_data_types(rule, simple_entity): @pytest.mark.parametrize( "rule", ( - R(timestamp__gte=datetime.now(timezone.utc)), + R(timestamp__gte=datetime.now(UTC)), R(timestamp_dt__gte=datetime.now()), ), ) diff --git a/common/tests/unit/test_apscheduler_extensions.py b/common/tests/unit/test_apscheduler_extensions.py index 0c07f37..14f1e3c 100644 --- a/common/tests/unit/test_apscheduler_extensions.py +++ b/common/tests/unit/test_apscheduler_extensions.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from unittest.mock import patch from zoneinfo import ZoneInfo @@ -15,7 +15,7 @@ @pytest.mark.unit def test_delayed_trigger_negative(): - cron_trigger = CronTrigger.from_crontab("*/2 * * * *", timezone=timezone.utc) + cron_trigger = CronTrigger.from_crontab("*/2 * * * *", timezone=UTC) delay = timedelta(hours=-1) with pytest.raises(ValueError, match="positive"): DelayedTrigger(cron_trigger, delay) @@ -102,8 +102,8 @@ def test_fix_weekdays(weekday_expression, expected): @pytest.mark.unit def test_get_crontab_trigger_times_finite_range(): - start = datetime(2000, 1, 1, 0, 0, 0, tzinfo=timezone.utc) - end = datetime(2000, 1, 10, 0, 0, 0, tzinfo=timezone.utc) + start = datetime(2000, 1, 1, 0, 0, 0, tzinfo=UTC) + end = datetime(2000, 1, 10, 0, 0, 0, tzinfo=UTC) for i, time in enumerate(get_crontab_trigger_times("0 10 * * *", ZoneInfo("UTC"), start, end)): assert time == start + timedelta(days=i, hours=10) assert i == 8 @@ -111,7 +111,7 @@ def test_get_crontab_trigger_times_finite_range(): @pytest.mark.unit def test_get_crontab_trigger_times_infinite_range(): - start = datetime(2000, 1, 1, 0, 0, 0, tzinfo=timezone.utc) + start = datetime(2000, 1, 1, 0, 0, 0, tzinfo=UTC) gen = get_crontab_trigger_times("0 10 * * *", ZoneInfo("UTC"), start) for i in range(5000): assert next(gen) == start + timedelta(days=i, hours=10) @@ -122,7 +122,7 @@ def test_get_crontab_trigger_times_infinite_range(): "start,end", ( (datetime(2000, 1, 1, 0, 0, 0), None), - (datetime(2000, 1, 1, 0, 0, 0, tzinfo=timezone.utc), datetime(2000, 1, 2, 0, 0, 0)), + (datetime(2000, 1, 1, 0, 0, 0, tzinfo=UTC), datetime(2000, 1, 2, 0, 0, 0)), ), ) def test_get_crontab_trigger_times_invalid_range(start, end): diff --git a/common/tests/unit/test_datetime_utils.py b/common/tests/unit/test_datetime_utils.py index 3ceec35..060fbd8 100644 --- a/common/tests/unit/test_datetime_utils.py +++ b/common/tests/unit/test_datetime_utils.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from zoneinfo import ZoneInfo, available_timezones import pytest @@ -26,10 +26,10 @@ def test_datetime_iso8601(): @pytest.mark.parametrize( "ts, dt", ( - (1685701704.039912, datetime(2023, 6, 2, 10, 28, 24, 39912, tzinfo=timezone.utc)), - (1123905724.000002, datetime(2005, 8, 13, 4, 2, 4, 2, tzinfo=timezone.utc)), - (435489642.424242, datetime(1983, 10, 20, 9, 20, 42, 424242, tzinfo=timezone.utc)), - (1162815179.06429, datetime(2006, 11, 6, 12, 12, 59, 64290, tzinfo=timezone.utc)), + (1685701704.039912, datetime(2023, 6, 2, 10, 28, 24, 39912, tzinfo=UTC)), + (1123905724.000002, datetime(2005, 8, 13, 4, 2, 4, 2, tzinfo=UTC)), + (435489642.424242, datetime(1983, 10, 20, 9, 20, 42, 424242, tzinfo=UTC)), + (1162815179.06429, datetime(2006, 11, 6, 12, 12, 59, 64290, tzinfo=UTC)), ), ) def test_timestamp_and_datetime(ts, dt): @@ -47,7 +47,7 @@ def test_tzinfo_added(): naive_dt = datetime(2006, 11, 6, 12, 12) timestamp = datetime_to_timestamp(naive_dt) - expected_dt = datetime(2006, 11, 6, 12, 12, tzinfo=timezone.utc) + expected_dt = datetime(2006, 11, 6, 12, 12, tzinfo=UTC) actual_dt = timestamp_to_datetime(timestamp) assert actual_dt != naive_dt @@ -65,7 +65,7 @@ def test_tzinfo_coerced_to_utc(): tz_dt = datetime(2001, 1, 1, 0, 0, 0, tzinfo=tzinfo) timestamp = datetime_to_timestamp(tz_dt) - expected_dt = tz_dt.astimezone(timezone.utc) + expected_dt = tz_dt.astimezone(UTC) actual_dt = timestamp_to_datetime(timestamp) assert expected_dt == actual_dt diff --git a/common/tests/unit/test_messagepack.py b/common/tests/unit/test_messagepack.py index ce269d3..aab95bb 100644 --- a/common/tests/unit/test_messagepack.py +++ b/common/tests/unit/test_messagepack.py @@ -1,7 +1,7 @@ from array import array from collections import OrderedDict from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from decimal import Decimal from io import BytesIO from pathlib import Path, PurePath, PurePosixPath, PureWindowsPath @@ -159,7 +159,7 @@ def test_dump_load_datetime(): @pytest.mark.unit def test_dump_load_datetime_tzinfo(): """Messagepack can dump/load datetime.datetime values and preserve tzinfo.""" - data = datetime.now(timezone.utc) + data = datetime.now(UTC) out_value = loads(dumps(data)) assert data == out_value diff --git a/deploy/search_view_plugin.py b/deploy/search_view_plugin.py index 7cbd348..9559559 100644 --- a/deploy/search_view_plugin.py +++ b/deploy/search_view_plugin.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Optional +from typing import Any from apispec import BasePlugin from apispec.yaml_utils import load_yaml_from_docstring @@ -16,8 +16,8 @@ class SearchViewPlugin(BasePlugin): def operation_helper( self, - path: Optional[str] = None, - operations: Optional[dict] = None, + path: str | None = None, + operations: dict | None = None, **kwargs: Any, ) -> None: view_class = getattr(kwargs.get("view", None), "view_class", None) diff --git a/deploy/subcomponent_plugin.py b/deploy/subcomponent_plugin.py index 85edd1d..1bd4c64 100644 --- a/deploy/subcomponent_plugin.py +++ b/deploy/subcomponent_plugin.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Optional +from typing import Any from apispec import BasePlugin @@ -15,8 +15,8 @@ class SubcomponentPlugin(BasePlugin): def operation_helper( self, - path: Optional[str] = None, - operations: Optional[dict] = None, + path: str | None = None, + operations: dict | None = None, **kwargs: Any, ) -> None: description_dict: dict[str, str] = { @@ -58,7 +58,7 @@ def request_body_helper(self, method: str) -> dict: } return request_body - def parameter_helper(self, parameter: Optional[dict] = None, **kwargs: Any) -> dict: + def parameter_helper(self, parameter: dict | None = None, **kwargs: Any) -> dict: method = kwargs["method"] parameter = {"in": "path", "schema": {"type": "string"}, "required": "true", "name": "component_id"} if method == "post": @@ -66,7 +66,7 @@ def parameter_helper(self, parameter: Optional[dict] = None, **kwargs: Any) -> d parameter["description"] = f"The ID of the project that the {self.subcomponent_name} will be created under." return parameter - def response_helper(self, response: Optional[dict] = None, **kwargs: Any) -> dict: + def response_helper(self, response: dict | None = None, **kwargs: Any) -> dict: method = kwargs["method"] response_desc_dict: dict[str, dict[int, str]] = { "get": { diff --git a/event_api/config/defaults.py b/event_api/config/defaults.py index 6e2dfce..81aba97 100644 --- a/event_api/config/defaults.py +++ b/event_api/config/defaults.py @@ -5,13 +5,12 @@ """ import os -from typing import Optional # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values from common.entities import Service -PROPAGATE_EXCEPTIONS: Optional[bool] = None -SERVER_NAME: Optional[str] = os.environ.get("EVENTS_API_HOSTNAME") # Use flask defaults if none set +PROPAGATE_EXCEPTIONS: bool | None = None +SERVER_NAME: str | None = os.environ.get("EVENTS_API_HOSTNAME") # Use flask defaults if none set USE_X_SENDFILE: bool = False # If we serve files enable this in production settings when webserver support configured # Application settings diff --git a/event_api/config/local.py b/event_api/config/local.py index aaf5847..f063607 100644 --- a/event_api/config/local.py +++ b/event_api/config/local.py @@ -1,5 +1,3 @@ -from typing import Optional - # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values -PROPAGATE_EXCEPTIONS: Optional[bool] = True +PROPAGATE_EXCEPTIONS: bool | None = True SECRET_KEY: str = "NOT_VERY_SECRET" diff --git a/event_api/config/minikube.py b/event_api/config/minikube.py index e5ed0f8..6c72c79 100644 --- a/event_api/config/minikube.py +++ b/event_api/config/minikube.py @@ -1,5 +1,3 @@ -from typing import Optional - # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values -TESTING: Optional[bool] = True +TESTING: bool | None = True SECRET_KEY: str = "NOT_VERY_SECRET" diff --git a/event_api/endpoints/v1/event_view.py b/event_api/endpoints/v1/event_view.py index 825e2fe..ae257ec 100644 --- a/event_api/endpoints/v1/event_view.py +++ b/event_api/endpoints/v1/event_view.py @@ -1,7 +1,6 @@ import logging -from datetime import datetime, timezone +from datetime import datetime, UTC from http import HTTPStatus -from typing import Optional from flask import Response, current_app, g, make_response, request from marshmallow import ValidationError @@ -34,14 +33,14 @@ class EventView(BaseView): event_type: type[Event] """The class (not instance) that is used to deserialize the incoming request body""" - def make_error(self, msg: str, e: Exception, error_code: Optional[int] = None) -> Response: + def make_error(self, msg: str, e: Exception, error_code: int | None = None) -> Response: """TODO: This should be turned into an ErrorHandler at the app level.""" return make_response( { "error": msg, # TODO: Should this be exposed to the user? "details": str(e), - "timestamp": datetime.now(tz=timezone.utc), + "timestamp": datetime.now(tz=UTC), }, error_code if error_code else 500, ) diff --git a/event_api/tests/integration/v1_endpoints/conftest.py b/event_api/tests/integration/v1_endpoints/conftest.py index 8b413e8..41adb91 100644 --- a/event_api/tests/integration/v1_endpoints/conftest.py +++ b/event_api/tests/integration/v1_endpoints/conftest.py @@ -1,7 +1,7 @@ import os import shutil from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from unittest.mock import MagicMock, patch from uuid import UUID @@ -29,7 +29,7 @@ class DatabaseCtx: @pytest.fixture def predictable_datetime(): - return datetime(2022, 5, 25, 19, 56, 52, 759419, tzinfo=timezone.utc) + return datetime(2022, 5, 25, 19, 56, 52, 759419, tzinfo=UTC) @pytest.fixture diff --git a/event_api/tests/integration/v2_endpoints/conftest.py b/event_api/tests/integration/v2_endpoints/conftest.py index 462feff..e9d6e88 100644 --- a/event_api/tests/integration/v2_endpoints/conftest.py +++ b/event_api/tests/integration/v2_endpoints/conftest.py @@ -1,6 +1,6 @@ import os import shutil -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from unittest.mock import MagicMock, patch import pytest @@ -18,7 +18,7 @@ @pytest.fixture def event_time(): - return datetime(2022, 5, 25, 19, 56, 52, 759419, tzinfo=timezone.utc) + return datetime(2022, 5, 25, 19, 56, 52, 759419, tzinfo=UTC) @pytest.fixture diff --git a/observability_api/config/defaults.py b/observability_api/config/defaults.py index 5a8a923..05a39c0 100644 --- a/observability_api/config/defaults.py +++ b/observability_api/config/defaults.py @@ -6,13 +6,12 @@ import os from datetime import timedelta -from typing import Optional # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values from common.entities import Service -PROPAGATE_EXCEPTIONS: Optional[bool] = None -SERVER_NAME: Optional[str] = os.environ.get("OBSERVABILITY_API_HOSTNAME") # Use flask defaults if none set +PROPAGATE_EXCEPTIONS: bool | None = None +SERVER_NAME: str | None = os.environ.get("OBSERVABILITY_API_HOSTNAME") # Use flask defaults if none set USE_X_SENDFILE: bool = False # If we serve files enable this in production settings when webserver support configured # Application settings diff --git a/observability_api/config/local.py b/observability_api/config/local.py index 56bf7e2..b58b69b 100644 --- a/observability_api/config/local.py +++ b/observability_api/config/local.py @@ -1,7 +1,5 @@ -from typing import Optional - # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values -PROPAGATE_EXCEPTIONS: Optional[bool] = True +PROPAGATE_EXCEPTIONS: bool | None = True SECRET_KEY: str = "NOT_VERY_SECRET" # Application settings diff --git a/observability_api/config/minikube.py b/observability_api/config/minikube.py index bceae82..3b013a0 100644 --- a/observability_api/config/minikube.py +++ b/observability_api/config/minikube.py @@ -1,7 +1,5 @@ -from typing import Optional - # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values -TESTING: Optional[bool] = True +TESTING: bool | None = True SECRET_KEY: str = "NOT_VERY_SECRET" # Application settings diff --git a/observability_api/config/test.py b/observability_api/config/test.py index 5dc2eb9..d9cedfc 100644 --- a/observability_api/config/test.py +++ b/observability_api/config/test.py @@ -1,7 +1,5 @@ -from typing import Optional - # Flask specific settings: https://flask.palletsprojects.com/en/latest/config/#builtin-configuration-values -PROPAGATE_EXCEPTIONS: Optional[bool] = True +PROPAGATE_EXCEPTIONS: bool | None = True SECRET_KEY: str = "NOT_VERY_SECRET" TESTING: bool = True diff --git a/observability_api/endpoints/component_view.py b/observability_api/endpoints/component_view.py index 859da15..091dc0a 100644 --- a/observability_api/endpoints/component_view.py +++ b/observability_api/endpoints/component_view.py @@ -2,7 +2,7 @@ import logging from http import HTTPStatus -from typing import Any, Optional +from typing import Any from uuid import UUID from flask import Blueprint, Response, make_response @@ -21,7 +21,7 @@ class ComponentByIdAbstractView(BaseEntityView): route: str entity: type[BaseEntity] schema: type[ModelSchema] - patch_schema: Optional[type[ModelSchema]] = None + patch_schema: type[ModelSchema] | None = None def get(self, component_id: UUID) -> Response: component = self.get_entity_or_fail(self.entity, self.entity.id == component_id) diff --git a/observability_api/endpoints/v1/journeys.py b/observability_api/endpoints/v1/journeys.py index 92ba43a..56e3902 100644 --- a/observability_api/endpoints/v1/journeys.py +++ b/observability_api/endpoints/v1/journeys.py @@ -3,7 +3,6 @@ import logging from graphlib import CycleError from http import HTTPStatus -from typing import Optional from uuid import UUID from flask import Response, make_response, request @@ -114,7 +113,7 @@ def get(self, project_id: UUID) -> Response: """ _ = self.get_entity_or_fail(Project, Project.id == project_id) - component_id: Optional[str] = request.args.get("component_id", None) + component_id: str | None = request.args.get("component_id", None) page: Page = ProjectService.get_journeys_with_rules( str(project_id), ListRules.from_params(request.args), component_id=component_id ) diff --git a/observability_api/endpoints/v1/project_settings.py b/observability_api/endpoints/v1/project_settings.py index 702539c..5a28789 100644 --- a/observability_api/endpoints/v1/project_settings.py +++ b/observability_api/endpoints/v1/project_settings.py @@ -1,6 +1,6 @@ __all__ = ["ProjectAlertsSettings"] -from typing import Optional, cast, Any +from typing import cast, Any from uuid import UUID from flask import Response, make_response @@ -23,7 +23,7 @@ class ProjectAlertsSettings(BaseEntityView): PERMISSION_REQUIREMENTS: tuple[Permission, ...] = (PERM_USER, PERM_PROJECT) - _project: Optional[Project] + _project: Project | None def get_request_schema(self) -> Schema: return cast(Schema, Schema.from_dict(self.get_fields(), name=f"{self.__class__.__name__}Schema")()) diff --git a/observability_api/schemas/event_schemas.py b/observability_api/schemas/event_schemas.py index 7b9ac85..192594a 100644 --- a/observability_api/schemas/event_schemas.py +++ b/observability_api/schemas/event_schemas.py @@ -52,8 +52,7 @@ class EventResponseSchema(Schema): required=True, metadata={ "description": ( - "The IDs of the components related to the event. The first item in the list is the primary " - "component. " + "The IDs of the components related to the event. The first item in the list is the primary component. " ) }, ) diff --git a/observability_api/tests/integration/v1_endpoints/conftest.py b/observability_api/tests/integration/v1_endpoints/conftest.py index e1c00d3..f67c09e 100644 --- a/observability_api/tests/integration/v1_endpoints/conftest.py +++ b/observability_api/tests/integration/v1_endpoints/conftest.py @@ -1,7 +1,7 @@ import os import shutil import uuid -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from decimal import Decimal import pytest @@ -275,8 +275,8 @@ def test_outcome(client, instance_instance_set, instance_runs, pipeline): dimensions=["a", "b", "c"], status=TestStatuses.WARNING.name, run=instance_runs[0].id, - start_time=datetime.now(tz=timezone.utc) - timedelta(minutes=15), - end_time=datetime.now(tz=timezone.utc) - timedelta(minutes=5), + start_time=datetime.now(tz=UTC) - timedelta(minutes=15), + end_time=datetime.now(tz=UTC) - timedelta(minutes=5), component=pipeline, instance_set=instance_instance_set.instance_set, external_url="https://example.com", @@ -299,8 +299,8 @@ def test_outcomes(client, instance_instance_set, pipeline): dimension=[f"a-{i}", f"b-{i}", f"c-{i}"], description=f"Description{i}", status=f"{TestStatuses.PASSED.name if i % 2 == 0 else TestStatuses.FAILED.name}", - start_time=datetime.now(tz=timezone.utc) + timedelta(minutes=5 * i), - end_time=datetime.now(tz=timezone.utc) + timedelta(minutes=15 * i), + start_time=datetime.now(tz=UTC) + timedelta(minutes=5 * i), + end_time=datetime.now(tz=UTC) + timedelta(minutes=15 * i), component=pipeline, instance_set=instance_instance_set.instance_set, external_url="https://example.com", diff --git a/observability_api/tests/integration/v1_endpoints/test_alerts.py b/observability_api/tests/integration/v1_endpoints/test_alerts.py index 9f7d59f..094d097 100644 --- a/observability_api/tests/integration/v1_endpoints/test_alerts.py +++ b/observability_api/tests/integration/v1_endpoints/test_alerts.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from http import HTTPStatus from uuid import uuid4 @@ -107,7 +107,7 @@ def test_list_project_alerts_search(client, g_user, project, instance_alert, ins @pytest.mark.integration def test_list_project_alerts_filters(client, g_user, project, instance_alert, instance_alert_components, run_alerts): - yesterday = datetime.now(tz=timezone.utc) - timedelta(days=1) + yesterday = datetime.now(tz=UTC) - timedelta(days=1) past_instance = Instance.create(journey=instance_alert.instance.journey, start_time=yesterday) past_instance_alert = InstanceAlert.create( id=uuid4(), diff --git a/observability_api/tests/integration/v1_endpoints/test_instances.py b/observability_api/tests/integration/v1_endpoints/test_instances.py index 1492897..82bcdd3 100644 --- a/observability_api/tests/integration/v1_endpoints/test_instances.py +++ b/observability_api/tests/integration/v1_endpoints/test_instances.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from http import HTTPStatus from typing import Optional from uuid import UUID, uuid4 @@ -51,8 +51,8 @@ def test_outcome(client, instance_runs, pipeline): description="Abc_Description", status=TestStatuses.WARNING.name, run=instance_runs[0].id, - start_time=datetime.now(tz=timezone.utc) - timedelta(minutes=15), - end_time=datetime.now(tz=timezone.utc) - timedelta(minutes=5), + start_time=datetime.now(tz=UTC) - timedelta(minutes=15), + end_time=datetime.now(tz=UTC) - timedelta(minutes=5), component=pipeline, ) yield test_outcome @@ -66,8 +66,8 @@ def test_outcomes(client, instance_instance_set, pipeline): name=f"DKTest{i}", description=f"Description{i}", status=f"{TestStatuses.PASSED.name if i % 2 == 0 else TestStatuses.FAILED.name}", - start_time=datetime.now(tz=timezone.utc) + timedelta(minutes=5 * i), - end_time=datetime.now(tz=timezone.utc) + timedelta(minutes=15 * i), + start_time=datetime.now(tz=UTC) + timedelta(minutes=5 * i), + end_time=datetime.now(tz=UTC) + timedelta(minutes=15 * i), component=pipeline, instance_set=instance_instance_set.instance_set, ) @@ -230,8 +230,8 @@ def create_test_outcomes(instances: list[Instance], component: Pipeline | Datase name=f"DKTest{i}", description=f"Description{i}", status=f"{TestStatuses.PASSED.name if i % 2 == 0 else TestStatuses.FAILED.name}", - start_time=datetime.now(tz=timezone.utc) + timedelta(minutes=5 * i), - end_time=datetime.now(tz=timezone.utc) + timedelta(minutes=15 * i), + start_time=datetime.now(tz=UTC) + timedelta(minutes=5 * i), + end_time=datetime.now(tz=UTC) + timedelta(minutes=15 * i), component=component, instance_set=instance_set, ) @@ -244,7 +244,7 @@ def create_dataset_operations(instances: list[Instance], dataset: Dataset): DatasetOperation.create( dataset=dataset, instance_set=instance_set, - operation_time=datetime.now(tz=timezone.utc), + operation_time=datetime.now(tz=UTC), operation=f"{DatasetOperationType.READ.name if i % 2 == 0 else DatasetOperationType.WRITE.name}", path="/path/to/file", ) @@ -252,7 +252,7 @@ def create_dataset_operations(instances: list[Instance], dataset: Dataset): @pytest.fixture def run_status_event(pipeline, project, journey, instances): - ts = datetime.now(timezone.utc) + ts = datetime.now(UTC) yield RunStatusEvent( project_id=project.id, event_id=uuid4(), @@ -360,7 +360,7 @@ def test_search_instances(client, journey, instances, g_user): response = client.post( f"/observability/v1/projects/{journey.project.id}/instances/search", - json={"params": {"start_range_end": datetime.now(timezone.utc).isoformat()}}, + json={"params": {"start_range_end": datetime.now(UTC).isoformat()}}, ) assert response.status_code == HTTPStatus.OK, response.json assert response.json["total"] == 6 @@ -472,7 +472,7 @@ class InstanceData: instances: list[Instance] = field(default_factory=list) -def create_instance_data(number: int, proj: Optional[Project] = None) -> InstanceData: +def create_instance_data(number: int, proj: Project | None = None) -> InstanceData: c = Company.create(name=f"TestCompany{number}") org = Organization.create(name=f"Internal Org{number}", company=c) if proj: @@ -519,8 +519,8 @@ def test_list_instances_with_filters_param_results(client, journey, g_user, proj instance.save() args = [("journey_name", name) for name in ("Test_Journey1", "Test_Journey2")] + [ - ("start_range_begin", (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()), - ("start_range_end", datetime.now(timezone.utc).isoformat()), + ("start_range_begin", (datetime.now(UTC) - timedelta(hours=1)).isoformat()), + ("start_range_end", datetime.now(UTC).isoformat()), ] query_string = MultiDict(args) @@ -540,7 +540,7 @@ def test_list_instances_with_filters_param_results(client, journey, g_user, proj assert r1.json["total"] == 3 r1 = client.get( f"/observability/v1/projects/{instance_data1.project.id}/instances", - query_string=MultiDict([("start_range_begin", datetime.now(timezone.utc).isoformat())]), + query_string=MultiDict([("start_range_begin", datetime.now(UTC).isoformat())]), ) assert r1.status_code == HTTPStatus.OK, r1.json assert r1.json["total"] == 0 diff --git a/observability_api/tests/integration/v1_endpoints/test_jwt_auth.py b/observability_api/tests/integration/v1_endpoints/test_jwt_auth.py index 95cfd89..3041536 100644 --- a/observability_api/tests/integration/v1_endpoints/test_jwt_auth.py +++ b/observability_api/tests/integration/v1_endpoints/test_jwt_auth.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from http import HTTPStatus from unittest.mock import patch from uuid import uuid4 @@ -61,7 +61,7 @@ def valid_token(token_user): "company_id": str(token_user.primary_company_id), "domain": "fakedomain.fake", } - dt = datetime.now(timezone.utc) + timedelta(days=2) + dt = datetime.now(UTC) + timedelta(days=2) data["exp"] = int(dt.replace(microsecond=0).timestamp()) return JWTAuth.encode_token(data) @@ -69,7 +69,7 @@ def valid_token(token_user): @pytest.fixture def invalid_token_bad_user(token_user): data = {"user_id": str(uuid4()), "company_id": str(token_user.primary_company_id)} - dt = datetime.now(timezone.utc) + timedelta(days=2) + dt = datetime.now(UTC) + timedelta(days=2) data["exp"] = int(dt.replace(microsecond=0).timestamp()) return JWTAuth.encode_token(data) @@ -140,15 +140,13 @@ def test_jwt_token_expiration(jwt_client, token_user): token = JWTAuth.log_user_in(token_user) claims = JWTAuth.decode_token(token) - assert get_token_expiration(claims) < datetime.now(timezone.utc) + timedelta(seconds=20) + assert get_token_expiration(claims) < datetime.now(UTC) + timedelta(seconds=20) @pytest.mark.integration def test_jwt_token_expiration_explicit(jwt_client, token_user): with patch("common.api.flask_ext.authentication.jwt_plugin.get_domain", return_value="fakedomain.fake"): - token = JWTAuth.log_user_in( - token_user, claims={"exp": (datetime.now(timezone.utc) + timedelta(seconds=10)).timestamp()} - ) + token = JWTAuth.log_user_in(token_user, claims={"exp": (datetime.now(UTC) + timedelta(seconds=10)).timestamp()}) claims = JWTAuth.decode_token(token) - assert get_token_expiration(claims) < datetime.now(timezone.utc) + timedelta(seconds=10) + assert get_token_expiration(claims) < datetime.now(UTC) + timedelta(seconds=10) diff --git a/observability_api/tests/integration/v1_endpoints/test_runs.py b/observability_api/tests/integration/v1_endpoints/test_runs.py index 3902be2..50060e2 100644 --- a/observability_api/tests/integration/v1_endpoints/test_runs.py +++ b/observability_api/tests/integration/v1_endpoints/test_runs.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from http import HTTPStatus from itertools import chain from typing import Optional @@ -57,7 +57,7 @@ def uuid_value(): @pytest.fixture def run_status_event(runs, pipeline, project, uuid_value): - ts = str(datetime.now(timezone.utc)) + ts = str(datetime.now(UTC)) yield RunStatusEvent( **RunStatusSchema().load( { @@ -181,7 +181,7 @@ class RunData: alerts: list[RunAlert] = field(default_factory=list) -def create_run_data(number: int, proj: Optional[Project] = None, set_tool=False) -> RunData: +def create_run_data(number: int, proj: Project | None = None, set_tool=False) -> RunData: c = Company.create(name=f"TestCompany{number}") org = Organization.create(name=f"Internal Org{number}", company=c) if proj: @@ -235,8 +235,8 @@ def test_list_runs_with_filters_param_results(client, g_user_2_admin, pipeline, [("run_key", key) for key in ("1", "2")] + [("pipeline_key", key) for key in ("Test_Pipeline1", "Test_Pipeline2")] + [ - ("start_range_begin", (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()), - ("start_range_end", datetime.now(timezone.utc).isoformat()), + ("start_range_begin", (datetime.now(UTC) - timedelta(hours=1)).isoformat()), + ("start_range_end", datetime.now(UTC).isoformat()), ] ) query_string = MultiDict(args) @@ -526,8 +526,8 @@ def test_list_batch_pipeline_runs_with_filters_param_results(client, g_user_2_ad _ = create_run_data(3, proj=rd_two.project) args = [("run_key", key) for key in ("1", "2")] + [ - ("start_range_begin", (datetime.now(timezone.utc) - timedelta(hours=1)).isoformat()), - ("start_range_end", datetime.now(timezone.utc).isoformat()), + ("start_range_begin", (datetime.now(UTC) - timedelta(hours=1)).isoformat()), + ("start_range_end", datetime.now(UTC).isoformat()), ("pipeline_id", pipeline.id), ] query_string = MultiDict(args) @@ -539,7 +539,7 @@ def test_list_batch_pipeline_runs_with_filters_param_results(client, g_user_2_ad f"/observability/v1/projects/{project.id}/runs", query_string=MultiDict( [ - ("start_range_begin", datetime.now(timezone.utc).isoformat()), + ("start_range_begin", datetime.now(UTC).isoformat()), ("pipeline_id", pipeline.id), ] ), @@ -578,9 +578,9 @@ def test_list_runs_for_instance(client, g_user, instances, runs, project): ) assert response.status_code == HTTPStatus.OK, response.json response_body = response.json - assert ( - response_body["total"] == 1 and len(response_body["entities"]) == response_body["total"] - ), "should return one run for each instance" + assert response_body["total"] == 1 and len(response_body["entities"]) == response_body["total"], ( + "should return one run for each instance" + ) assert len(response_body["entities"][0]["alerts"]) == 1 expected_run_ids = [str(r.id) for r in runs] assert response_body["entities"][0]["id"] in expected_run_ids, "the returned ID isn't one of the expected runs" @@ -619,9 +619,9 @@ def test_list_runs_for_instance_with_summaries(client, g_user, instances, runs, (RunTaskStatus.FAILED, 2), ) for status, expected in tasks_statuses: - assert any( - (status.name, expected) == (task["status"], task["count"]) for task in run["tasks_summary"] - ), f"did not find expected {{\"status\": {status.name}, \"count\": {expected}}} in {run['tasks_summary']}" + assert any((status.name, expected) == (task["status"], task["count"]) for task in run["tasks_summary"]), ( + f'did not find expected {{"status": {status.name}, "count": {expected}}} in {run["tasks_summary"]}' + ) assert len(run["alerts"]) == 1 assert run["alerts"][0]["level"] == AlertLevel["ERROR"].value @@ -721,7 +721,7 @@ def test_get_instance_runs_status_filters(query_string, expected, client, g_user key=key, pipeline=pipeline, instance_set=instance_set, - status=f"{RunStatus.COMPLETED.name if int(key) % 2 == 0 else RunStatus.FAILED.name}", + status=f"{RunStatus.COMPLETED.name if int(key) % 2 == 0 else RunStatus.FAILED.name}", ) base_query_string = {"instance_id": [instance.id], "status": query_string} response = client.get(f"/observability/v1/projects/{project.id}/runs", query_string=base_query_string) diff --git a/observability_api/tests/integration/v1_endpoints/test_service_account_keys.py b/observability_api/tests/integration/v1_endpoints/test_service_account_keys.py index 1a7f3da..d93338f 100644 --- a/observability_api/tests/integration/v1_endpoints/test_service_account_keys.py +++ b/observability_api/tests/integration/v1_endpoints/test_service_account_keys.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from http import HTTPStatus import pytest @@ -24,7 +24,7 @@ def sa_key(client, project): @pytest.mark.integration def test_create_sa_key_success(client, g_user, project, sa_key_data): - today = datetime.now(timezone.utc) + today = datetime.now(UTC) response = client.post( f"/observability/v1/projects/{project.id}/service-account-key", headers={"Content-Type": "application/json"}, @@ -46,7 +46,7 @@ def test_create_sa_key_success(client, g_user, project, sa_key_data): @pytest.mark.integration def test_create_sa_key_with_name_and_description(client, g_user, project, sa_key_data): sa_key_data["description"] = "Whoa man, I'm just using this for auth" - today = datetime.now(timezone.utc) + today = datetime.now(UTC) response = client.post( f"/observability/v1/projects/{project.id}/service-account-key", headers={"Content-Type": "application/json"}, diff --git a/observability_api/tests/integration/v1_endpoints/test_upcoming_instances.py b/observability_api/tests/integration/v1_endpoints/test_upcoming_instances.py index f96c724..c42bdac 100644 --- a/observability_api/tests/integration/v1_endpoints/test_upcoming_instances.py +++ b/observability_api/tests/integration/v1_endpoints/test_upcoming_instances.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from http import HTTPStatus import pytest @@ -14,7 +14,7 @@ def test_list_project_upcoming_instances_instance_schedule(client, journey, jour InstanceRule.create(journey=journey, action=InstanceRuleAction.END, expression="30,40 * * * *") InstanceRule.create(journey=journey_2, action=InstanceRuleAction.START, expression="10,50 * * * *") - start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc) + start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC) query = MultiDict( [ ("start_range", start_time.isoformat()), @@ -70,7 +70,7 @@ def test_list_project_upcoming_instances_batch_schedule( batch_end_schedule.schedule = "30 * * * *" batch_end_schedule.save() - start_time = datetime(1991, 2, 20, 10, 59, 00, tzinfo=timezone.utc) + start_time = datetime(1991, 2, 20, 10, 59, 00, tzinfo=UTC) query = MultiDict( [ ("start_range", start_time.isoformat()), @@ -94,7 +94,7 @@ def test_list_project_upcoming_instances_filters(client, journey, journey_2, ins InstanceRule.create(journey=journey, action=InstanceRuleAction.START, expression="10 * * * *") InstanceRule.create(journey=journey_2, action=InstanceRuleAction.START, expression="30 * * * *") - start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc) + start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC) query = MultiDict( [ ("start_range", start_time.isoformat()), @@ -148,7 +148,7 @@ def test_list_company_upcoming_instances_instance_schedule(client, organization, InstanceRule.create(journey=journey, action=InstanceRuleAction.START, expression="10 * * * *") InstanceRule.create(journey=journey_2, action=InstanceRuleAction.START, expression="30 * * * *") - start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc) + start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC) query = MultiDict( [ ("start_range", start_time.isoformat()), @@ -173,7 +173,7 @@ def test_list_company_upcoming_instances_filters(client, project, organization, InstanceRule.create(journey=journey, action=InstanceRuleAction.START, expression="10 * * * *") InstanceRule.create(journey=journey_2, action=InstanceRuleAction.START, expression="30 * * * *") - start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc) + start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC) query = MultiDict( [ ("start_range", start_time.isoformat()), @@ -210,7 +210,7 @@ def test_list_company_upcoming_instances_filters(client, project, organization, @pytest.mark.integration def test_list_project_upcoming_instances_sa_key_auth_ok(client, journey, journey_2, instance, g_project): - start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc) + start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC) query = MultiDict( [ ("start_range", start_time.isoformat()), @@ -223,7 +223,7 @@ def test_list_project_upcoming_instances_sa_key_auth_ok(client, journey, journey @pytest.mark.integration def test_list_company_upcoming_instances_sa_key_auth_forbidden(client, journey, journey_2, instance, g_project): - start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=timezone.utc) + start_time = datetime(1991, 2, 20, 10, 00, 00, tzinfo=UTC) query = MultiDict( [ ("start_range", start_time.isoformat()), diff --git a/pyproject.toml b/pyproject.toml index 67b7517..5144753 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,14 +47,14 @@ dependencies = [ [project.optional-dependencies] dev = [ - "ruff~=0.7.3", + "ruff~=0.12.0", "invoke~=2.1.2", "lxml~=4.9.1", - "mypy~=1.5.0", + "mypy~=1.16.1", "pre-commit~=2.20.0", - "pytest-cov~=4.0.0", - "pytest-xdist~=3.1.0", - "pytest~=7.2.0", + "pytest-cov~=6.2.1", + "pytest-xdist~=3.7.0", + "pytest~=8.4.1", "pyyaml~=6.0", "types-PyYAML~=6.0.8", "types-requests==2.28.11.15", @@ -223,6 +223,10 @@ check_untyped_defs = false module = "PIL.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "IPython.*" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "invoke" ignore_missing_imports = true @@ -235,6 +239,10 @@ ignore_missing_imports = true module = "msgpack.*" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "observability_plugins.*" +ignore_missing_imports = true + [[tool.mypy.overrides]] module = "marshmallow_union.*" ignore_missing_imports = true @@ -243,8 +251,12 @@ ignore_missing_imports = true module = "pybars" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "yoyo.*" +ignore_missing_imports = true + [tool.ruff] -target-version = "py310" +target-version = "py312" line-length = 120 [tool.ruff.lint] diff --git a/rules_engine/journey_rules.py b/rules_engine/journey_rules.py index 8ffe6ed..d13b3ad 100644 --- a/rules_engine/journey_rules.py +++ b/rules_engine/journey_rules.py @@ -5,7 +5,7 @@ import logging import time from functools import lru_cache -from typing import Any, Optional +from typing import Any from collections.abc import Callable from uuid import UUID @@ -33,19 +33,19 @@ class JourneyRule: def __init__( self, r_obj: R, - rule_entity: Optional[RuleEntity], + rule_entity: RuleEntity | None, *triggers: Callable, - journey_id: Optional[UUID] = None, - component_id: Optional[UUID] = None, + journey_id: UUID | None = None, + component_id: UUID | None = None, ) -> None: self.r_obj: R = r_obj self.rule_entity = rule_entity - self.triggers: tuple[Callable[[EVENT_TYPE, Optional[RuleEntity], Optional[UUID]], ActionResult], ...] = triggers - self.journey_id: Optional[UUID] = journey_id - self.component_id: Optional[UUID] = component_id + self.triggers: tuple[Callable[[EVENT_TYPE, RuleEntity | None, UUID | None], ActionResult], ...] = triggers + self.journey_id: UUID | None = journey_id + self.component_id: UUID | None = component_id @staticmethod - def _get_component_id(event: EVENT_TYPE) -> Optional[UUID]: + def _get_component_id(event: EVENT_TYPE) -> UUID | None: """Extract the component id from the given event.""" match event: case Event(): @@ -81,7 +81,7 @@ def __str__(self) -> str: return f"{self.__module__}.{self.__class__.__name__}: {self.r_obj}" -def _execute_action(event: EVENT_TYPE, rule_entity: RuleEntity, journey_id: Optional[UUID]) -> Any: +def _execute_action(event: EVENT_TYPE, rule_entity: RuleEntity, journey_id: UUID | None) -> Any: action_entity = JourneyService.get_action_by_implementation(rule_entity.journey_id, rule_entity.action) action = action_factory(rule_entity.action, rule_entity.action_args, action_entity) action.execute(event, rule_entity, journey_id) diff --git a/rules_engine/rule_data.py b/rules_engine/rule_data.py index f9f3f5f..e652a94 100644 --- a/rules_engine/rule_data.py +++ b/rules_engine/rule_data.py @@ -1,6 +1,5 @@ __all__ = ["RuleData"] -from typing import Optional from uuid import UUID from peewee import SelectQuery, fn @@ -21,7 +20,7 @@ class DatabaseData: def __init__(self, event: EVENT_TYPE) -> None: self.event = event - def _get_batch_pipeline_id(self) -> Optional[UUID]: + def _get_batch_pipeline_id(self) -> UUID | None: match self.event: case Event(): return self.event.pipeline_id diff --git a/rules_engine/tests/integration/conftest.py b/rules_engine/tests/integration/conftest.py index 789f275..39938a2 100644 --- a/rules_engine/tests/integration/conftest.py +++ b/rules_engine/tests/integration/conftest.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 @@ -31,8 +31,8 @@ def base_event_data(): "pipeline_name": None, "project_id": str(uuid4()), "source": EventSources.API.name, - "event_timestamp": str(datetime.now(timezone.utc)), - "received_timestamp": str(datetime.now(timezone.utc)), + "event_timestamp": str(datetime.now(UTC)), + "received_timestamp": str(datetime.now(UTC)), "external_url": "https://example.com", "metadata": {}, "run_id": None, diff --git a/rules_engine/tests/unit/test_data_points.py b/rules_engine/tests/unit/test_data_points.py index e189149..0b9fcb4 100644 --- a/rules_engine/tests/unit/test_data_points.py +++ b/rules_engine/tests/unit/test_data_points.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC import pytest @@ -206,7 +206,7 @@ def test_run_data_points(run_status_event, run, rule): # DataPoints is using the correct source assert run.status != run_status_event.status assert dps.run.status == run.status - assert dps.run.start_time == run.start_time.replace(tzinfo=timezone.utc).isoformat() + assert dps.run.start_time == run.start_time.replace(tzinfo=UTC).isoformat() assert dps.run.start_time_formatted == datetime_formatted(run.start_time) assert dps.run.end_time == "N/A" assert dps.run.end_time_formatted == "N/A" @@ -221,7 +221,7 @@ def test_run_data_points_with_end_time(run_status_event, run, rule): run.save() dps = DataPoints(run_status_event, rule) # These are tested in a separate function because the run is cached in Event - assert dps.run.end_time == run.end_time.replace(tzinfo=timezone.utc).isoformat() + assert dps.run.end_time == run.end_time.replace(tzinfo=UTC).isoformat() assert dps.run.end_time_formatted == datetime_formatted(run.end_time) @@ -260,9 +260,9 @@ def test_run_task_data_points_with_times(run_status_event, run_task, rule): run_task.save() dps = DataPoints(run_status_event, rule) # These are tested in a separate function because the run task is cached in Event - assert dps.run_task.start_time == run_task.start_time.replace(tzinfo=timezone.utc).isoformat() + assert dps.run_task.start_time == run_task.start_time.replace(tzinfo=UTC).isoformat() assert dps.run_task.start_time_formatted == datetime_formatted(run_task.start_time) - assert dps.run_task.end_time == run_task.end_time.replace(tzinfo=timezone.utc).isoformat() + assert dps.run_task.end_time == run_task.end_time.replace(tzinfo=UTC).isoformat() assert dps.run_task.end_time_formatted == datetime_formatted(run_task.end_time) diff --git a/run_manager/alerts.py b/run_manager/alerts.py index ba707ed..aa4789a 100644 --- a/run_manager/alerts.py +++ b/run_manager/alerts.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from collections.abc import Iterable from uuid import UUID @@ -82,8 +81,8 @@ def create_run_alert(alert_type: RunAlertType, run: Run, pipeline: Pipeline) -> def create_instance_alert( alert_type: InstanceAlertType, instance: Instance, - component: Optional[Component] = None, - alert_components: Optional[Iterable[UUID]] = None, + component: Component | None = None, + alert_components: Iterable[UUID] | None = None, ) -> InstanceAlertEvent: alert_level = INSTANCE_ALERT_LEVELS[alert_type] alert_description = INSTANCE_ALERT_DESCRIPTIONS[alert_type].format( diff --git a/run_manager/context.py b/run_manager/context.py index 2da0c14..a881520 100644 --- a/run_manager/context.py +++ b/run_manager/context.py @@ -1,6 +1,5 @@ __all__ = ["RunManagerContext"] from dataclasses import dataclass, field -from typing import Optional from uuid import UUID from common.entities import Component, Instance, InstanceSet, Pipeline, Run, RunTask, Task @@ -13,21 +12,21 @@ class RunManagerContext: A context object to pass a state around when handling events in run manager """ - component: Optional[Component] = None + component: Component | None = None # Keeping pipeline to keep "pre" event v1 code intact, i.e. avoid significant refactoring effort - pipeline: Optional[Pipeline] = None - run: Optional[Run] = None - task: Optional[Task] = None - run_task: Optional[RunTask] = None + pipeline: Pipeline | None = None + run: Run | None = None + task: Task | None = None + run_task: RunTask | None = None instances: list[InstanceRef] = field(default_factory=list) - instance_set: Optional[InstanceSet] = None + instance_set: InstanceSet | None = None ended_instances: list[UUID | Instance] = field(default_factory=list) """List of instances that ended in this context""" created_run: bool = False """Indicates if the run was created during this context""" started_run: bool = False """Indicates if the run was started during this context""" - prev_run_status: Optional[str] = None + prev_run_status: str | None = None """Previous run status before being processed by the run handler. This is to check for unexpected run status changed""" diff --git a/run_manager/event_handlers/component_identifier.py b/run_manager/event_handlers/component_identifier.py index 61b98c1..30871ef 100644 --- a/run_manager/event_handlers/component_identifier.py +++ b/run_manager/event_handlers/component_identifier.py @@ -1,7 +1,7 @@ __all__ = ["ComponentIdentifier"] import logging -from typing import Optional, cast +from typing import cast from peewee import DoesNotExist @@ -25,7 +25,7 @@ """Map event component type to db model""" -def _get_component(event: Event) -> Optional[Component]: +def _get_component(event: Event) -> Component | None: # v1 event can only have one (component type)_id if component_id := event.component_id: try: @@ -56,7 +56,7 @@ def _get_component(event: Event) -> Optional[Component]: return component -def _create_component(event: Event) -> Optional[Component]: +def _create_component(event: Event) -> Component | None: component: Component = event.component_model.create( key=event.component_key, name=event.component_name, tool=event.component_tool, project_id=event.project_id ) diff --git a/run_manager/event_handlers/instance_handler.py b/run_manager/event_handlers/instance_handler.py index 82cf862..1b98f36 100644 --- a/run_manager/event_handlers/instance_handler.py +++ b/run_manager/event_handlers/instance_handler.py @@ -2,7 +2,7 @@ import logging from collections import defaultdict from itertools import chain -from typing import Any, Optional, cast +from typing import Any, cast from collections.abc import Callable, Mapping from uuid import UUID @@ -34,7 +34,7 @@ LOG = logging.getLogger(__name__) -def _find_existing_instance(journey: Journey, with_run: bool, payload_key: Optional[str]) -> Optional[Instance]: +def _find_existing_instance(journey: Journey, with_run: bool, payload_key: str | None) -> Instance | None: if payload_key is None: f: Callable[[Any], bool | Any] = lambda k: k.payload_key is None else: @@ -234,7 +234,7 @@ def default_instance_creation(self, event: Event) -> list[InstanceRef]: identified_instances: list[InstanceRef] = [] with_run = event.run_id is not None for journey in event.component_journeys: - payload_keys: list[Optional[str]] = [None] + payload_keys: list[str | None] = [None] if any(rule.action is InstanceRuleAction.END_PAYLOAD for rule in journey.instance_rules): payload_keys.extend(event.payload_keys) for payload_key in payload_keys: diff --git a/run_manager/event_handlers/run_handler.py b/run_manager/event_handlers/run_handler.py index 91c829b..3f53243 100644 --- a/run_manager/event_handlers/run_handler.py +++ b/run_manager/event_handlers/run_handler.py @@ -1,7 +1,7 @@ __all__ = ["RunHandler"] import logging -from typing import Optional, cast +from typing import cast from peewee import DoesNotExist @@ -102,7 +102,7 @@ def _handle_event(self, event: Event) -> None: f"for batch-pipeline {self.context.pipeline.id}" ) - def _get_run(self, event: Event, pipeline: Pipeline) -> Optional[Run]: + def _get_run(self, event: Event, pipeline: Pipeline) -> Run | None: """ Get an existing run instance diff --git a/run_manager/event_handlers/run_unexpected_status_change_handler.py b/run_manager/event_handlers/run_unexpected_status_change_handler.py index 8ed8a17..2078391 100644 --- a/run_manager/event_handlers/run_unexpected_status_change_handler.py +++ b/run_manager/event_handlers/run_unexpected_status_change_handler.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from common.entities import RunAlertType, RunStatus from common.events import EventHandlerBase @@ -19,7 +18,7 @@ """Map new run status to alert type""" -def _get_alert_type(context: RunManagerContext) -> Optional[RunAlertType]: +def _get_alert_type(context: RunManagerContext) -> RunAlertType | None: if context.run is None: raise ValueError("The `run` attribute for the context object must be populated with a valid Run instance") @@ -52,7 +51,7 @@ def handle_test_outcomes(self, event: TestOutcomesEvent) -> bool: def handle_run_status(self, event: RunStatusEvent) -> bool: if (pipeline := self.context.pipeline) is None or (run := self.context.run) is None: raise ValueError( - "The context object must be populated with a valid Pipeline and Run instance " "for RunStatusEvent" + "The context object must be populated with a valid Pipeline and Run instance for RunStatusEvent" ) if (alert_type := _get_alert_type(self.context)) is not None: diff --git a/run_manager/tests/integration/conftest.py b/run_manager/tests/integration/conftest.py index e49b6b8..b58f912 100644 --- a/run_manager/tests/integration/conftest.py +++ b/run_manager/tests/integration/conftest.py @@ -1,6 +1,6 @@ import uuid from dataclasses import replace -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from unittest.mock import MagicMock import pytest @@ -57,7 +57,7 @@ def compare_event_data(unidentified_event, identified_event, pipeline, run, task @pytest.fixture def timestamp_now(): - return datetime.now(tz=timezone.utc) + return datetime.now(tz=UTC) @pytest.fixture @@ -367,7 +367,7 @@ def pipeline_end_payload_rule(journey, pipeline): @pytest.fixture def timestamp_now(): - return datetime.now(tz=timezone.utc) + return datetime.now(tz=UTC) @pytest.fixture diff --git a/run_manager/tests/integration/event_handlers/test_out_of_sequence_instance_handler.py b/run_manager/tests/integration/event_handlers/test_out_of_sequence_instance_handler.py index 340ead1..8e62e85 100644 --- a/run_manager/tests/integration/event_handlers/test_out_of_sequence_instance_handler.py +++ b/run_manager/tests/integration/event_handlers/test_out_of_sequence_instance_handler.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC import pytest @@ -47,8 +47,8 @@ def __init__(self, project, pipelines=None, name="out-of-sequence-test"): project=project, instance_set=self.instance_set.id, status=RunStatus.COMPLETED.name, - start_time=datetime.now(tz=timezone.utc), - end_time=datetime.now(tz=timezone.utc), + start_time=datetime.now(tz=UTC), + end_time=datetime.now(tz=UTC), ) ) JourneyDagEdge.create( diff --git a/run_manager/tests/integration/test_run_handler.py b/run_manager/tests/integration/test_run_handler.py index dda4f41..e7048f2 100644 --- a/run_manager/tests/integration/test_run_handler.py +++ b/run_manager/tests/integration/test_run_handler.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from uuid import UUID import pytest @@ -26,12 +26,12 @@ def test_run_handler_new_pending_metadata(run_status_event_pending, pipeline): # Retrieve the run and make sure the expected_start_time has been updated run = Run.get_by_id(handler.context.run.id) - assert run.expected_start_time == datetime(2005, 3, 1, 1, 1, 1, tzinfo=timezone.utc) + assert run.expected_start_time == datetime(2005, 3, 1, 1, 1, 1, tzinfo=UTC) @pytest.mark.integration def test_run_handler_no_overwrite_expected_start_time(run_status_event_missing, pipeline): - expected_start_time = datetime(2005, 3, 2, 2, 2, 2, tzinfo=timezone.utc) + expected_start_time = datetime(2005, 3, 2, 2, 2, 2, tzinfo=UTC) run = Run.create( id=UUID("dbed19e1-d0bb-4860-bbdf-9d768cb90764"), pipeline=pipeline, diff --git a/run_manager/tests/integration/test_run_manager_instance.py b/run_manager/tests/integration/test_run_manager_instance.py index 13cf35c..e556ea0 100644 --- a/run_manager/tests/integration/test_run_manager_instance.py +++ b/run_manager/tests/integration/test_run_manager_instance.py @@ -1,5 +1,5 @@ import copy -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from uuid import uuid4 import pytest @@ -257,7 +257,7 @@ def test_run_manager_dont_modify_previously_closed_instance( ): p1 = Pipeline.create(key="pipe1", project=project) j1 = Journey.create(name="journey1", project=project) - instance_time = datetime.utcnow().replace(tzinfo=timezone.utc) + instance_time = datetime.utcnow().replace(tzinfo=UTC) i1 = Instance.create(journey=j1, start_time=instance_time, end_time=instance_time) InstanceRule.create(journey=j1, action=InstanceRuleAction.START, batch_pipeline=p1) InstanceRule.create(journey=j1, action=InstanceRuleAction.END, batch_pipeline=p1) diff --git a/run_manager/tests/integration/test_run_manager_unordered_events.py b/run_manager/tests/integration/test_run_manager_unordered_events.py index dd1753c..c53ab6e 100644 --- a/run_manager/tests/integration/test_run_manager_unordered_events.py +++ b/run_manager/tests/integration/test_run_manager_unordered_events.py @@ -1,5 +1,5 @@ import copy -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from uuid import uuid4 import pytest @@ -15,7 +15,7 @@ def test_keep_run_end_state_and_update_start_state(pipeline, kafka_consumer, kaf Keep the run end state (end time, status) when an old message is processed. Update the start time when older message is received. """ - new_time = datetime.now(tz=timezone.utc) + new_time = datetime.now(tz=UTC) old_time = new_time - timedelta(minutes=5) old_start_status_message = run_status_message @@ -60,7 +60,7 @@ def test_reopen_run_on_newer_status( assert run.end_time is not None # Re-open existing run - run_status_message.payload.event_timestamp = datetime.utcnow().replace(tzinfo=timezone.utc) + run_status_message.payload.event_timestamp = datetime.utcnow().replace(tzinfo=UTC) run_status_message.payload.event_id = uuid4() kafka_consumer.__iter__.return_value = iter((run_status_message,)) run_manager.process_events() @@ -77,7 +77,7 @@ def test_keep_task_end_state_and_update_start_state(kafka_consumer, kafka_produc Keep the task end state (end time, status) when an old message is processed. Update the start time when older message is received. """ - new_time = datetime.now(tz=timezone.utc) + new_time = datetime.now(tz=UTC) old_time = new_time - timedelta(minutes=5) old_status_message = task_status_message old_status_message.payload.event_timestamp = old_time @@ -122,7 +122,7 @@ def test_reopen_task_on_newer_status(pipeline, kafka_consumer, kafka_producer, r assert run.end_time is not None # Re-open existing run - run_status_message.payload.event_timestamp = datetime.utcnow().replace(tzinfo=timezone.utc) + run_status_message.payload.event_timestamp = datetime.utcnow().replace(tzinfo=UTC) run_status_message.payload.event_id = uuid4() kafka_consumer.__iter__.return_value = iter((run_status_message,)) run_manager.process_events() diff --git a/run_manager/tests/integration/test_scheduler_events.py b/run_manager/tests/integration/test_scheduler_events.py index 891ad71..1d5652a 100644 --- a/run_manager/tests/integration/test_scheduler_events.py +++ b/run_manager/tests/integration/test_scheduler_events.py @@ -1,5 +1,5 @@ from collections import Counter -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, timezone, UTC from uuid import uuid4 import pytest @@ -159,7 +159,7 @@ def test_scheduler_run_should_start_with_pending( ): """Check that pending runs are marked missing when start schedule occurs""" instance = Instance.create(journey=journey) - expected_start_time = datetime(2005, 3, 2, 2, 2, 2, tzinfo=timezone.utc) + expected_start_time = datetime(2005, 3, 2, 2, 2, 2, tzinfo=UTC) run = Run.create( status=RunStatus.PENDING.name, start_time=None, diff --git a/scheduler/agent_check.py b/scheduler/agent_check.py index 5ba0b5c..55058d3 100644 --- a/scheduler/agent_check.py +++ b/scheduler/agent_check.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timezone, timedelta +from datetime import datetime, timedelta, UTC from apscheduler.triggers.interval import IntervalTrigger @@ -14,7 +14,7 @@ def _get_agent_status(check_interval_seconds: int, latest_heartbeat: datetime) -> AgentStatus: - lateness = (datetime.now(tz=timezone.utc) - latest_heartbeat).total_seconds() + lateness = (datetime.now(tz=UTC) - latest_heartbeat).total_seconds() if lateness > check_interval_seconds * settings.AGENT_STATUS_CHECK_OFFLINE_FACTOR: return AgentStatus.OFFLINE elif lateness > check_interval_seconds * settings.AGENT_STATUS_CHECK_UNHEALTHY_FACTOR: @@ -65,7 +65,7 @@ def _create_and_add_job(self, schedule: AgentCheckSchedule) -> None: ) def _check_agents_are_online(self, project: Project) -> None: - check_threshold = datetime.now(tz=timezone.utc) - timedelta(seconds=project.agent_check_interval) + check_threshold = datetime.now(tz=UTC) - timedelta(seconds=project.agent_check_interval) for agent in Agent.select().where( Agent.project == project.id, Agent.latest_heartbeat < check_threshold, diff --git a/scheduler/component_expectations.py b/scheduler/component_expectations.py index 192ed50..c733296 100644 --- a/scheduler/component_expectations.py +++ b/scheduler/component_expectations.py @@ -1,6 +1,5 @@ import logging from datetime import datetime, timedelta -from typing import Optional from uuid import UUID from apscheduler.triggers.cron import CronTrigger @@ -29,7 +28,7 @@ def _produce_event( component_id: UUID, schedule_type: ScheduleType, is_margin: bool, - margin: Optional[int] = None, + margin: int | None = None, ) -> None: """Create and forward corresponding scheduler event(s) to the run manager""" if is_margin: diff --git a/scheduler/schedule_source.py b/scheduler/schedule_source.py index 7d3d22e..aac4b95 100644 --- a/scheduler/schedule_source.py +++ b/scheduler/schedule_source.py @@ -3,7 +3,7 @@ import logging from collections import defaultdict from datetime import datetime -from typing import Any, Optional, Protocol, Generic, TypeVar +from typing import Any, Protocol, TypeVar from collections.abc import Callable from apscheduler.executors.pool import ThreadPoolExecutor @@ -41,7 +41,7 @@ def id(self) -> str: ... ST = TypeVar("ST", bound=Schedule) -class ScheduleSource(Generic[ST]): +class ScheduleSource[ST: Schedule]: """Concentrates all features and configurations around a specific source of schedules.""" source_name: str @@ -65,7 +65,7 @@ def jobstore_name(self) -> str: def executor_name(self) -> str: return self.source_name - def add_job(self, func: Callable, job_id: str, trigger: BaseTrigger, kwargs: Optional[dict[str, Any]]) -> Job: + def add_job(self, func: Callable, job_id: str, trigger: BaseTrigger, kwargs: dict[str, Any] | None) -> Job: return self.scheduler.add_job( func, id=job_id, diff --git a/scheduler/tests/integration/test_agent_scheduler.py b/scheduler/tests/integration/test_agent_scheduler.py index 0bc4d31..114ea84 100644 --- a/scheduler/tests/integration/test_agent_scheduler.py +++ b/scheduler/tests/integration/test_agent_scheduler.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone, timedelta +from datetime import datetime, timezone, timedelta, UTC from unittest.mock import patch import pytest @@ -19,7 +19,7 @@ def agents(project): tool="tool", version="vTest", status=AgentStatus.ONLINE, - latest_heartbeat=datetime.now(tz=timezone.utc) - timedelta(seconds=elapsed_time), + latest_heartbeat=datetime.now(tz=UTC) - timedelta(seconds=elapsed_time), ) for elapsed_time in ( 25, # Below the checking threshold diff --git a/scheduler/tests/integration/test_schedule_source.py b/scheduler/tests/integration/test_schedule_source.py index 67d50ad..e713686 100644 --- a/scheduler/tests/integration/test_schedule_source.py +++ b/scheduler/tests/integration/test_schedule_source.py @@ -1,5 +1,5 @@ import threading -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from unittest.mock import Mock, patch import pytest @@ -29,7 +29,7 @@ def get_next_fire_time(self, previous_fire_time, now): if previous_fire_time: return None else: - return datetime.now(tz=timezone.utc) + return datetime.now(tz=UTC) class TestScheduleSource(ScheduleSource): diff --git a/scheduler/tests/unit/conftest.py b/scheduler/tests/unit/conftest.py index 1dae93d..05fab03 100644 --- a/scheduler/tests/unit/conftest.py +++ b/scheduler/tests/unit/conftest.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import datetime, timezone, UTC from unittest.mock import Mock, patch from uuid import uuid4 @@ -42,7 +42,7 @@ def agent_source(scheduler, event_producer_mock): @pytest.fixture def job_kwargs(): return { - "run_time": datetime.now(tz=timezone.utc), + "run_time": datetime.now(tz=UTC), "schedule_type": ScheduleType.BATCH_END_TIME, "schedule_id": str(uuid4()), "component_id": str(uuid4()), @@ -64,4 +64,4 @@ def schedule_data(): @pytest.fixture def run_time(): - return datetime.now(tz=timezone.utc) + return datetime.now(tz=UTC) diff --git a/scheduler/tests/unit/test_agent_scheduler.py b/scheduler/tests/unit/test_agent_scheduler.py index 640d509..20aa3fa 100644 --- a/scheduler/tests/unit/test_agent_scheduler.py +++ b/scheduler/tests/unit/test_agent_scheduler.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timezone, timedelta +from datetime import datetime, timezone, timedelta, UTC from unittest.mock import patch import pytest @@ -43,5 +43,5 @@ def test_add_job(agent_source): ], ) def test_get_agent_status(elapsed_time, expected_status): - latest_heartbeat = datetime.now(tz=timezone.utc) - timedelta(seconds=elapsed_time) + latest_heartbeat = datetime.now(tz=UTC) - timedelta(seconds=elapsed_time) assert _get_agent_status(CHECK_INTERVAL, latest_heartbeat) == expected_status diff --git a/scripts/invocations/deploy.py b/scripts/invocations/deploy.py index 9b48a50..4ac5d37 100644 --- a/scripts/invocations/deploy.py +++ b/scripts/invocations/deploy.py @@ -315,7 +315,7 @@ def build( if ui: ctx.run( - f"docker build . {args_str} " f"-t 'observability-ui:{tag}' -f ./deploy/docker/observability-ui.dockerfile", + f"docker build . {args_str} -t 'observability-ui:{tag}' -f ./deploy/docker/observability-ui.dockerfile", env=env, ) diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 120c85e..0000000 --- a/setup.cfg +++ /dev/null @@ -1,2 +0,0 @@ -# This file is here to support editable installs (pip install -e .) -# https://github.com/pypa/setuptools/issues/2816 diff --git a/testlib/fixtures/entities.py b/testlib/fixtures/entities.py index 4acff83..c05ec17 100644 --- a/testlib/fixtures/entities.py +++ b/testlib/fixtures/entities.py @@ -61,7 +61,7 @@ import base64 import uuid -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, UTC from decimal import Decimal from unittest.mock import patch from uuid import UUID @@ -445,11 +445,11 @@ def batch_end_schedule(pipeline): ) -ALERT_EXPECTED_START_DT: datetime = datetime(2005, 8, 13, 4, 2, 4, 2, tzinfo=timezone.utc) +ALERT_EXPECTED_START_DT: datetime = datetime(2005, 8, 13, 4, 2, 4, 2, tzinfo=UTC) """The datetime of the expected_start_time for alert fixtures.""" -ALERT_EXPECTED_END_DT: datetime = datetime(2005, 8, 13, 8, 4, 8, 1, tzinfo=timezone.utc) +ALERT_EXPECTED_END_DT: datetime = datetime(2005, 8, 13, 8, 4, 8, 1, tzinfo=UTC) """The datetime of the expected_end_time for alert fixtures.""" @@ -496,8 +496,8 @@ def test_outcome(component, run, task): component=component, description="Testy McTestface", dimensions=["a", "b", "c"], - start_time=datetime(2000, 3, 9, 12, 11, 10, tzinfo=timezone.utc), - end_time=datetime(2000, 3, 9, 13, 12, 11, tzinfo=timezone.utc), + start_time=datetime(2000, 3, 9, 12, 11, 10, tzinfo=UTC), + end_time=datetime(2000, 3, 9, 13, 12, 11, tzinfo=UTC), name="test-outcome-1", external_url="https://fake.testy/do-not-go-here", key="test-outcome-key-1", @@ -563,10 +563,10 @@ def testgen_dataset_component(test_db, dataset): yield dataset_component -AGENT_LATEST_EVENT = datetime(2023, 10, 17, 12, 33, 19, 154295, tzinfo=timezone.utc) +AGENT_LATEST_EVENT = datetime(2023, 10, 17, 12, 33, 19, 154295, tzinfo=UTC) """Default timestamp for latest event received by an agent.""" -AGENT_LATEST_HEARTBEAT = datetime(2023, 10, 17, 12, 42, 42, 424242, tzinfo=timezone.utc) +AGENT_LATEST_HEARTBEAT = datetime(2023, 10, 17, 12, 42, 42, 424242, tzinfo=UTC) """Default timestamp for the lasttime an agent checked-in.""" @@ -585,7 +585,7 @@ def agent_1(test_db, project): @pytest.fixture() def agent_2(test_db, project): - dt_1 = datetime.now(timezone.utc) + dt_1 = datetime.now(UTC) dt_2 = dt_1 + timedelta(seconds=42) return Agent.create( project=project, @@ -603,8 +603,8 @@ def event_entity(test_db, pipeline, task, run, run_task, instance_instance_set): return EventEntity.create( version=EventVersion.V2, type=ApiEventType.BATCH_PIPELINE_STATUS, - created_timestamp=datetime(2024, 1, 20, 10, 0, 0, tzinfo=timezone.utc), - timestamp=datetime(2024, 1, 20, 9, 59, 0, tzinfo=timezone.utc), + created_timestamp=datetime(2024, 1, 20, 10, 0, 0, tzinfo=UTC), + timestamp=datetime(2024, 1, 20, 9, 59, 0, tzinfo=UTC), project=pipeline.project_id, component=pipeline, task=task, @@ -620,8 +620,8 @@ def event_entity_2(test_db, dataset): return EventEntity.create( version=EventVersion.V2, type=ApiEventType.DATASET_OPERATION, - created_timestamp=datetime(2024, 1, 20, 9, 55, 0, tzinfo=timezone.utc), - timestamp=datetime(2024, 1, 20, 9, 55, 0, tzinfo=timezone.utc), + created_timestamp=datetime(2024, 1, 20, 9, 55, 0, tzinfo=UTC), + timestamp=datetime(2024, 1, 20, 9, 55, 0, tzinfo=UTC), project=dataset.project_id, component=dataset, v2_payload={}, diff --git a/testlib/fixtures/v1_events.py b/testlib/fixtures/v1_events.py index df9ab4e..060e8f8 100644 --- a/testlib/fixtures/v1_events.py +++ b/testlib/fixtures/v1_events.py @@ -32,7 +32,7 @@ ] -from datetime import datetime, timezone +from datetime import datetime, UTC from decimal import Decimal from uuid import UUID @@ -220,7 +220,7 @@ def FAILED_run_status_event_data(FAILED_run_status_event): @pytest.fixture def test_outcome_item_data(metadata_model) -> dict: - timestamp = datetime.now(timezone.utc).isoformat() + timestamp = datetime.now(UTC).isoformat() yield { "name": "My_test_name", "status": TestStatuses.PASSED.name, diff --git a/testlib/fixtures/v2_events.py b/testlib/fixtures/v2_events.py index 800fa87..371087f 100644 --- a/testlib/fixtures/v2_events.py +++ b/testlib/fixtures/v2_events.py @@ -19,7 +19,7 @@ "test_outcomes_testgen_event_v2", ] -from datetime import datetime, timedelta, timezone +from datetime import datetime, timedelta, UTC from decimal import Decimal from uuid import UUID @@ -84,7 +84,7 @@ TEST_OUTCOMES_EVENT_ID: UUID = UUID("83af84bc-318e-4dda-9d40-6c7c8bacd992") """ID for EventV2 LOG event.""" -EVENT_TIMESTAMP: datetime = datetime(2023, 5, 10, 1, 1, 1, tzinfo=timezone.utc) +EVENT_TIMESTAMP: datetime = datetime(2023, 5, 10, 1, 1, 1, tzinfo=UTC) """Default timestamp for events.""" CREATED_TIMESTAMP: datetime = EVENT_TIMESTAMP + timedelta(minutes=3, seconds=1) @@ -191,7 +191,7 @@ def run_alert() -> RunAlert: @pytest.fixture def test_outcome_item(metadata_model) -> TestOutcomeItem: - timestamp = datetime.now(timezone.utc) + timestamp = datetime.now(UTC) return TestOutcomeItem( name="My_test_name", status=TestStatus.PASSED, diff --git a/testlib/peewee.py b/testlib/peewee.py index 2240fca..ace6dca 100644 --- a/testlib/peewee.py +++ b/testlib/peewee.py @@ -1,9 +1,10 @@ import contextlib from unittest.mock import Mock, patch +from typing import Any @contextlib.contextmanager -def patch_select(target: str, **kwargs): +def patch_select(target: str, **kwargs: Any) -> Any: with patch(target=f"{target}.select") as select_mock: select_mock.return_value = select_mock for attr in ("join", "left_outer_join", "switch", "order_by", "where"):