# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import typing
from typing import Any

from asgiref.sync import sync_to_async
from sqlalchemy import func

from airflow.models import DagRun
from airflow.providers.standard.utils.sensor_helper import _get_count
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.session import NEW_SESSION, provide_session

if typing.TYPE_CHECKING:
    from datetime import datetime

    from sqlalchemy.orm import Session

    from airflow.utils.state import DagRunState


class WorkflowTrigger(BaseTrigger):
    """
    A trigger to monitor tasks, task groups, and DAG execution in Apache Airflow.

    :param external_dag_id: The ID of the external DAG.
    :param logical_dates: A list of logical dates for the external DAG (used on Airflow 3.0+).
    :param execution_dates: A list of execution dates for the external DAG (used before Airflow 3.0).
    :param external_task_ids: A collection of external task IDs to wait for.
    :param external_task_group_id: The ID of the external task group to wait for.
    :param failed_states: States considered as failed for external tasks.
    :param skipped_states: States considered as skipped for external tasks.
    :param allowed_states: States considered as successful for external tasks.
    :param poke_interval: The interval (in seconds) for poking the external tasks.
    :param soft_fail: If True, the trigger will not fail the entire DAG on external task failure.
    """

def __init__(
self,
external_dag_id: str,
logical_dates: list[datetime] | None = None,
execution_dates: list[datetime] | None = None,
external_task_ids: typing.Collection[str] | None = None,
external_task_group_id: str | None = None,
failed_states: typing.Iterable[str] | None = None,
skipped_states: typing.Iterable[str] | None = None,
allowed_states: typing.Iterable[str] | None = None,
poke_interval: float = 2.0,
soft_fail: bool = False,
**kwargs,
):
self.external_dag_id = external_dag_id
self.external_task_ids = external_task_ids
self.external_task_group_id = external_task_group_id
self.failed_states = failed_states
self.skipped_states = skipped_states
self.allowed_states = allowed_states
self.logical_dates = logical_dates
self.poke_interval = poke_interval
self.soft_fail = soft_fail
self.execution_dates = execution_dates
super().__init__(**kwargs)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger parameters and module path."""
_dates = (
{"logical_dates": self.logical_dates}
if AIRFLOW_V_3_0_PLUS
else {"execution_dates": self.execution_dates}
)
return (
"airflow.providers.standard.triggers.external_task.WorkflowTrigger",
{
"external_dag_id": self.external_dag_id,
"external_task_ids": self.external_task_ids,
"external_task_group_id": self.external_task_group_id,
"failed_states": self.failed_states,
"skipped_states": self.skipped_states,
"allowed_states": self.allowed_states,
**_dates,
"poke_interval": self.poke_interval,
"soft_fail": self.soft_fail,
},
)

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """Periodically check the status of the external tasks, task group, or DAG."""
        while True:
            if self.failed_states:
                failed_count = await self._get_count(self.failed_states)
                if failed_count > 0:
                    yield TriggerEvent({"status": "failed"})
                    return
if self.skipped_states:
skipped_count = await self._get_count(self.skipped_states)
if skipped_count > 0:
yield TriggerEvent({"status": "skipped"})
return
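            # Success requires every watched date to have reached an allowed state.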
allowed_count = await self._get_count(self.allowed_states)
_dates = self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates
if allowed_count == len(_dates): # type: ignore[arg-type]
yield TriggerEvent({"status": "success"})
return
self.log.info("Sleeping for %s seconds", self.poke_interval)
await asyncio.sleep(self.poke_interval)

    @sync_to_async
    def _get_count(self, states: typing.Iterable[str] | None) -> int:
        """
        Get the count of records against the dttm filter and states; an async wrapper around ``_get_count``.

        :param states: task or dag states
        :return: The count of records.
        """
return _get_count(
dttm_filter=self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates,
external_task_ids=self.external_task_ids,
external_task_group_id=self.external_task_group_id,
external_dag_id=self.external_dag_id,
states=states,
)
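

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the provider API: how a deferrable
# sensor would typically hand control to WorkflowTrigger from its execute()
# method. The function name and all IDs below are hypothetical placeholders;
# ``defer()`` and the ``method_name`` callback follow the standard
# deferrable-operator contract of BaseOperator.
def _example_defer_to_workflow_trigger(sensor, logical_date) -> None:
    # ``sensor`` stands in for ``self`` inside a deferrable sensor's execute().
    # ``defer()`` raises TaskDeferred, suspending the task until the triggerer
    # yields a terminal event, after which ``execute_complete`` runs with it.
    sensor.defer(
        trigger=WorkflowTrigger(
            external_dag_id="example_upstream_dag",      # hypothetical DAG id
            external_task_ids=["extract", "transform"],  # hypothetical task ids
            logical_dates=[logical_date],
            allowed_states=["success"],
            failed_states=["failed"],
            skipped_states=["skipped"],
            poke_interval=30.0,
        ),
        method_name="execute_complete",
    )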


class DagStateTrigger(BaseTrigger):
    """
    Waits asynchronously for a DAG run to reach one of the specified states at the given logical dates.

    :param dag_id: The dag_id of the DAG to wait for.
    :param states: The list of DAG run states to wait for, e.g. ``['success']``.
    :param logical_dates: The logical dates of the DAG runs to check (used on Airflow 3.0+).
    :param execution_dates: The execution dates of the DAG runs to check (used before Airflow 3.0).
    :param poll_interval: The time interval in seconds to check the state.
        The default value is 5.0 sec.
    """

def __init__(
self,
dag_id: str,
states: list[DagRunState],
logical_dates: list[datetime] | None = None,
execution_dates: list[datetime] | None = None,
poll_interval: float = 5.0,
):
super().__init__()
self.dag_id = dag_id
self.states = states
self.logical_dates = logical_dates
self.execution_dates = execution_dates
self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, typing.Any]]:
"""Serialize DagStateTrigger arguments and classpath."""
_dates = (
{"logical_dates": self.logical_dates}
if AIRFLOW_V_3_0_PLUS
else {"execution_dates": self.execution_dates}
)
return (
"airflow.providers.standard.triggers.external_task.DagStateTrigger",
{
"dag_id": self.dag_id,
"states": self.states,
**_dates,
"poll_interval": self.poll_interval,
},
)

    async def run(self) -> typing.AsyncIterator[TriggerEvent]:
        """Periodically check whether the DAG runs exist and have reached one of the given states."""
while True:
# mypy confuses typing here
num_dags = await self.count_dags() # type: ignore[call-arg]
_dates = self.logical_dates if AIRFLOW_V_3_0_PLUS else self.execution_dates
if num_dags == len(_dates): # type: ignore[arg-type]
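                # The event payload is the trigger's own (classpath, kwargs)
                # serialization, so the resuming operator can see which
                # dag_id, dates, and states were matched.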
yield TriggerEvent(self.serialize())
return
await asyncio.sleep(self.poll_interval)

    @sync_to_async
    @provide_session
    def count_dags(self, *, session: Session = NEW_SESSION) -> int | None:
"""Count how many dag runs in the database match our criteria."""
_dag_run_date_condition = (
DagRun.logical_date.in_(self.logical_dates)
if AIRFLOW_V_3_0_PLUS
else DagRun.execution_date.in_(self.execution_dates)
)
count = (
session.query(func.count("*")) # .count() is inefficient
.filter(
DagRun.dag_id == self.dag_id,
DagRun.state.in_(self.states),
_dag_run_date_condition,
)
.scalar()
)
return typing.cast(int, count)
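

if __name__ == "__main__":
    # Minimal standalone sketch: each trigger serializes to the (classpath,
    # kwargs) pair that the triggerer process uses to rehydrate it. The DAG id
    # and date below are hypothetical placeholders; nothing here touches the
    # metadata database.
    from datetime import datetime, timezone

    example_date = datetime(2024, 1, 1, tzinfo=timezone.utc)
    workflow_trigger = WorkflowTrigger(
        external_dag_id="example_upstream_dag",  # hypothetical
        logical_dates=[example_date],
        execution_dates=[example_date],
        allowed_states=["success"],
    )
    dag_state_trigger = DagStateTrigger(
        dag_id="example_upstream_dag",  # hypothetical
        states=["success"],  # DagRunState.SUCCESS serializes as "success"
        logical_dates=[example_date],
        execution_dates=[example_date],
    )
    for trigger in (workflow_trigger, dag_state_trigger):
        classpath, kwargs = trigger.serialize()
        print(classpath, kwargs)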