99import pytest
1010from airflow .models import DagBag , DagModel , DagRun , DagTag
1111from airflow .models .dag import DagOwnerAttributes
12+ from airflow .models .dag_version import DagVersion
1213from airflow .models .serialized_dag import SerializedDagModel
13- from airflow .providers .common .compat .sdk import DAG , DagRunState , TaskInstanceState
14+ from airflow .providers .common .compat .sdk import DagRunState , TaskInstanceState
1415from airflow .utils .session import create_session
1516from airflow .utils .types import DagRunType
1617from dbt .contracts .results import RunStatus , TestStatus
2223 DbtSourceFreshnessOperator ,
2324 DbtTestOperator ,
2425)
25- from airflow_dbt_python .utils .version import AIRFLOW_V_3_0_PLUS , AIRFLOW_V_3_1_PLUS
26+ from airflow_dbt_python .utils .version import (
27+ AIRFLOW_V_3_0 ,
28+ AIRFLOW_V_3_0_PLUS ,
29+ AIRFLOW_V_3_1_PLUS ,
30+ )
31+
32+ if AIRFLOW_V_3_0 :
33+ # For some reason Airflow 3.0 cannot use dag.test()
34+ from airflow import DAG
35+ else :
36+ from airflow .providers .common .compat .sdk import DAG
2637
2738DATA_INTERVAL_START = pendulum .datetime (2022 , 1 , 1 , tz = "UTC" )
2839DATA_INTERVAL_END = DATA_INTERVAL_START + dt .timedelta (hours = 1 )
@@ -33,6 +44,7 @@ def sync_dag_to_db(
3344 bundle_name : str = "testing" ,
3445):
3546 """Sync dags into the database."""
47+ from airflow .models .dagbundle import DagBundleModel
3648 from airflow .models .serialized_dag import SerializedDagModel
3749 from airflow .serialization .serialized_objects import (
3850 LazyDeserializedDAG ,
@@ -41,11 +53,8 @@ def sync_dag_to_db(
4153 from airflow .utils .session import create_session
4254
4355 with create_session () as session :
44- if AIRFLOW_V_3_1_PLUS :
45- from airflow .models .dagbundle import DagBundleModel
46-
47- session .merge (DagBundleModel (name = bundle_name ))
48- session .flush ()
56+ session .merge (DagBundleModel (name = bundle_name ))
57+ session .flush ()
4958
5059 def _write_dag (dag : DAG ) -> SerializedDAG :
5160 if not SerializedDagModel .has_dag (dag .dag_id ):
@@ -68,9 +77,26 @@ def _create_dagrun(
6877 start_date : dt .datetime ,
6978 run_type : DagRunType ,
7079) -> DagRun :
71- if AIRFLOW_V_3_0_PLUS :
80+ if AIRFLOW_V_3_1_PLUS :
7281 return parent_dag .test ()
7382
83+ elif AIRFLOW_V_3_0_PLUS :
84+ from airflow .utils .types import DagRunTriggeredByType # type: ignore
85+
86+ return parent_dag .create_dagrun ( # type: ignore
87+ run_id = f"{ parent_dag .dag_id } -{ logical_date .isoformat ()} -RUN" ,
88+ state = state ,
89+ logical_date = logical_date ,
90+ data_interval = data_interval ,
91+ start_date = start_date ,
92+ conf = {},
93+ backfill_id = None ,
94+ creating_job_id = None ,
95+ run_type = run_type ,
96+ run_after = dt .datetime (1970 , 1 , 1 , 0 , 0 , 0 , tzinfo = dt .timezone .utc ),
97+ triggered_by = DagRunTriggeredByType .TIMETABLE ,
98+ )
99+
74100 else :
75101 return parent_dag .create_dagrun ( # type: ignore
76102 state = state ,
@@ -102,9 +128,7 @@ def test_dags_loaded(dagbag):
102128@pytest .fixture
103129def testing_dag_bundle ():
104130 """Create a DAG bundle for tests."""
105- from airflow_dbt_python .utils .version import AIRFLOW_V_3_1_PLUS
106-
107- if AIRFLOW_V_3_1_PLUS :
131+ if AIRFLOW_V_3_0_PLUS :
108132 from airflow .models .dagbundle import DagBundleModel
109133
110134 with create_session () as session :
@@ -125,13 +149,14 @@ def _clear_db():
125149
126150 session .query (DagFavorite ).delete ()
127151
152+ session .query (DagVersion ).delete ()
128153 session .query (DagTag ).delete ()
129154 session .query (DagOwnerAttributes ).delete ()
130155 session .query (DagRun ).delete ()
131156 session .query (DagModel ).delete ()
132157 session .query (SerializedDagModel ).delete ()
133158
134- if AIRFLOW_V_3_1_PLUS :
159+ if AIRFLOW_V_3_0_PLUS :
135160 from airflow .models .dagbundle import DagBundleModel
136161
137162 session .query (DagBundleModel ).delete ()
@@ -303,7 +328,7 @@ def prepare_dbt_project_dir() -> str:
303328
304329 d = generate_dag ()
305330
306- if AIRFLOW_V_3_0_PLUS :
331+ if AIRFLOW_V_3_1_PLUS :
307332 sync_dag_to_db (d )
308333
309334 return d
@@ -315,8 +340,8 @@ def test_dbt_operators_in_taskflow_dag(
315340 profiles_file ,
316341):
317342 """Assert DAG contains correct dbt operators when running."""
318- if AIRFLOW_V_3_0_PLUS :
319- dag = taskflow_dag
343+ if AIRFLOW_V_3_0 :
344+ dag = DAG . from_sdk_dag ( taskflow_dag ) # type: ignore
320345 else :
321346 dag = taskflow_dag
322347
@@ -445,7 +470,7 @@ def target_connection_dag(
445470
446471 dbt_seed >> dbt_run >> dbt_test
447472
448- if AIRFLOW_V_3_0_PLUS :
473+ if AIRFLOW_V_3_1_PLUS :
449474 sync_dag_to_db (dag )
450475 return dag
451476
@@ -510,6 +535,9 @@ def test_example_basic_dag(
510535 """Test the example basic DAG."""
511536 dag = dagbag .get_dag (dag_id = "example_basic_dbt" )
512537
538+ if AIRFLOW_V_3_0 :
539+ dag = DAG .from_sdk_dag (dag ) # type: ignore
540+
513541 assert dag is not None
514542 assert len (dag .tasks ) == 1
515543
@@ -525,7 +553,7 @@ def test_example_basic_dag(
525553 dbt_run .target = "test"
526554 dbt_run .profile = "default"
527555
528- if AIRFLOW_V_3_0_PLUS :
556+ if AIRFLOW_V_3_1_PLUS :
529557 sync_dag_to_db (dag )
530558
531559 dagrun = _create_dagrun (
@@ -578,9 +606,12 @@ def test_example_dbt_project_in_github_dag(
578606 assert dag is not None
579607 assert len (dag .tasks ) == 3
580608
581- if AIRFLOW_V_3_0_PLUS :
609+ if AIRFLOW_V_3_1_PLUS :
582610 sync_dag_to_db (dag )
583611
612+ if AIRFLOW_V_3_0 :
613+ dag = DAG .from_sdk_dag (dag ) # type: ignore
614+
584615 dagrun = _create_dagrun (
585616 dag ,
586617 state = DagRunState .RUNNING ,
@@ -632,9 +663,12 @@ def test_example_complete_dbt_workflow_dag(
632663 assert dag is not None
633664 assert len (dag .tasks ) == 5
634665
635- if AIRFLOW_V_3_0_PLUS :
666+ if AIRFLOW_V_3_1_PLUS :
636667 sync_dag_to_db (dag )
637668
669+ if AIRFLOW_V_3_0 :
670+ dag = DAG .from_sdk_dag (dag ) # type: ignore
671+
638672 dagrun = _create_dagrun (
639673 dag ,
640674 state = DagRunState .RUNNING ,
0 commit comments