Skip to content

Commit 897d228

Browse files
committed
fix: Just run Airflow 3.0 tests like before
1 parent 83cfdff commit 897d228

File tree

2 files changed

+54
-19
lines changed

2 files changed

+54
-19
lines changed

airflow_dbt_python/utils/version.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@ def _get_base_airflow_version_tuple() -> tuple[int, int, int]:
4040

4141
AIRFLOW_V_3_0_PLUS = _get_base_airflow_version_tuple() >= (3, 0, 0)
4242
AIRFLOW_V_3_1_PLUS = _get_base_airflow_version_tuple() >= (3, 1, 0)
43+
AIRFLOW_V_3_0 = AIRFLOW_V_3_0_PLUS and not AIRFLOW_V_3_1_PLUS

tests/dags/test_dbt_dags.py

Lines changed: 53 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
import pytest
1010
from airflow.models import DagBag, DagModel, DagRun, DagTag
1111
from airflow.models.dag import DagOwnerAttributes
12+
from airflow.models.dag_version import DagVersion
1213
from airflow.models.serialized_dag import SerializedDagModel
13-
from airflow.providers.common.compat.sdk import DAG, DagRunState, TaskInstanceState
14+
from airflow.providers.common.compat.sdk import DagRunState, TaskInstanceState
1415
from airflow.utils.session import create_session
1516
from airflow.utils.types import DagRunType
1617
from dbt.contracts.results import RunStatus, TestStatus
@@ -22,7 +23,17 @@
2223
DbtSourceFreshnessOperator,
2324
DbtTestOperator,
2425
)
25-
from airflow_dbt_python.utils.version import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS
26+
from airflow_dbt_python.utils.version import (
27+
AIRFLOW_V_3_0,
28+
AIRFLOW_V_3_0_PLUS,
29+
AIRFLOW_V_3_1_PLUS,
30+
)
31+
32+
if AIRFLOW_V_3_0:
33+
# For some reason Airflow 3.0 cannot use dag.test()
34+
from airflow import DAG
35+
else:
36+
from airflow.providers.common.compat.sdk import DAG
2637

2738
DATA_INTERVAL_START = pendulum.datetime(2022, 1, 1, tz="UTC")
2839
DATA_INTERVAL_END = DATA_INTERVAL_START + dt.timedelta(hours=1)
@@ -33,6 +44,7 @@ def sync_dag_to_db(
3344
bundle_name: str = "testing",
3445
):
3546
"""Sync dags into the database."""
47+
from airflow.models.dagbundle import DagBundleModel
3648
from airflow.models.serialized_dag import SerializedDagModel
3749
from airflow.serialization.serialized_objects import (
3850
LazyDeserializedDAG,
@@ -41,11 +53,8 @@ def sync_dag_to_db(
4153
from airflow.utils.session import create_session
4254

4355
with create_session() as session:
44-
if AIRFLOW_V_3_1_PLUS:
45-
from airflow.models.dagbundle import DagBundleModel
46-
47-
session.merge(DagBundleModel(name=bundle_name))
48-
session.flush()
56+
session.merge(DagBundleModel(name=bundle_name))
57+
session.flush()
4958

5059
def _write_dag(dag: DAG) -> SerializedDAG:
5160
if not SerializedDagModel.has_dag(dag.dag_id):
@@ -68,9 +77,26 @@ def _create_dagrun(
6877
start_date: dt.datetime,
6978
run_type: DagRunType,
7079
) -> DagRun:
71-
if AIRFLOW_V_3_0_PLUS:
80+
if AIRFLOW_V_3_1_PLUS:
7281
return parent_dag.test()
7382

83+
elif AIRFLOW_V_3_0_PLUS:
84+
from airflow.utils.types import DagRunTriggeredByType # type: ignore
85+
86+
return parent_dag.create_dagrun( # type: ignore
87+
run_id=f"{parent_dag.dag_id}-{logical_date.isoformat()}-RUN",
88+
state=state,
89+
logical_date=logical_date,
90+
data_interval=data_interval,
91+
start_date=start_date,
92+
conf={},
93+
backfill_id=None,
94+
creating_job_id=None,
95+
run_type=run_type,
96+
run_after=dt.datetime(1970, 1, 1, 0, 0, 0, tzinfo=dt.timezone.utc),
97+
triggered_by=DagRunTriggeredByType.TIMETABLE,
98+
)
99+
74100
else:
75101
return parent_dag.create_dagrun( # type: ignore
76102
state=state,
@@ -102,9 +128,7 @@ def test_dags_loaded(dagbag):
102128
@pytest.fixture
103129
def testing_dag_bundle():
104130
"""Create a DAG bundle for tests."""
105-
from airflow_dbt_python.utils.version import AIRFLOW_V_3_1_PLUS
106-
107-
if AIRFLOW_V_3_1_PLUS:
131+
if AIRFLOW_V_3_0_PLUS:
108132
from airflow.models.dagbundle import DagBundleModel
109133

110134
with create_session() as session:
@@ -125,13 +149,14 @@ def _clear_db():
125149

126150
session.query(DagFavorite).delete()
127151

152+
session.query(DagVersion).delete()
128153
session.query(DagTag).delete()
129154
session.query(DagOwnerAttributes).delete()
130155
session.query(DagRun).delete()
131156
session.query(DagModel).delete()
132157
session.query(SerializedDagModel).delete()
133158

134-
if AIRFLOW_V_3_1_PLUS:
159+
if AIRFLOW_V_3_0_PLUS:
135160
from airflow.models.dagbundle import DagBundleModel
136161

137162
session.query(DagBundleModel).delete()
@@ -303,7 +328,7 @@ def prepare_dbt_project_dir() -> str:
303328

304329
d = generate_dag()
305330

306-
if AIRFLOW_V_3_0_PLUS:
331+
if AIRFLOW_V_3_1_PLUS:
307332
sync_dag_to_db(d)
308333

309334
return d
@@ -315,8 +340,8 @@ def test_dbt_operators_in_taskflow_dag(
315340
profiles_file,
316341
):
317342
"""Assert DAG contains correct dbt operators when running."""
318-
if AIRFLOW_V_3_0_PLUS:
319-
dag = taskflow_dag
343+
if AIRFLOW_V_3_0:
344+
dag = DAG.from_sdk_dag(taskflow_dag) # type: ignore
320345
else:
321346
dag = taskflow_dag
322347

@@ -445,7 +470,7 @@ def target_connection_dag(
445470

446471
dbt_seed >> dbt_run >> dbt_test
447472

448-
if AIRFLOW_V_3_0_PLUS:
473+
if AIRFLOW_V_3_1_PLUS:
449474
sync_dag_to_db(dag)
450475
return dag
451476

@@ -510,6 +535,9 @@ def test_example_basic_dag(
510535
"""Test the example basic DAG."""
511536
dag = dagbag.get_dag(dag_id="example_basic_dbt")
512537

538+
if AIRFLOW_V_3_0:
539+
dag = DAG.from_sdk_dag(dag) # type: ignore
540+
513541
assert dag is not None
514542
assert len(dag.tasks) == 1
515543

@@ -525,7 +553,7 @@ def test_example_basic_dag(
525553
dbt_run.target = "test"
526554
dbt_run.profile = "default"
527555

528-
if AIRFLOW_V_3_0_PLUS:
556+
if AIRFLOW_V_3_1_PLUS:
529557
sync_dag_to_db(dag)
530558

531559
dagrun = _create_dagrun(
@@ -578,9 +606,12 @@ def test_example_dbt_project_in_github_dag(
578606
assert dag is not None
579607
assert len(dag.tasks) == 3
580608

581-
if AIRFLOW_V_3_0_PLUS:
609+
if AIRFLOW_V_3_1_PLUS:
582610
sync_dag_to_db(dag)
583611

612+
if AIRFLOW_V_3_0:
613+
dag = DAG.from_sdk_dag(dag) # type: ignore
614+
584615
dagrun = _create_dagrun(
585616
dag,
586617
state=DagRunState.RUNNING,
@@ -632,9 +663,12 @@ def test_example_complete_dbt_workflow_dag(
632663
assert dag is not None
633664
assert len(dag.tasks) == 5
634665

635-
if AIRFLOW_V_3_0_PLUS:
666+
if AIRFLOW_V_3_1_PLUS:
636667
sync_dag_to_db(dag)
637668

669+
if AIRFLOW_V_3_0:
670+
dag = DAG.from_sdk_dag(dag) # type: ignore
671+
638672
dagrun = _create_dagrun(
639673
dag,
640674
state=DagRunState.RUNNING,

0 commit comments

Comments
 (0)