# Source code for airflow.providers.standard.operators.trigger_dagrun

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
import inspect
import json
import time
from collections.abc import Sequence
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any, cast, overload

from sqlalchemy import select
from sqlalchemy.orm.exc import NoResultFound

from airflow.api.common.trigger_dag import trigger_dag
from airflow.exceptions import DagNotFound, DagRunAlreadyExists
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.providers.common.compat.sdk import (
    AirflowException,
    AirflowSkipException,
    BaseOperatorLink,
    XCom,
    conf,
    timezone,
)
from airflow.providers.standard.triggers.external_task import DagStateTrigger
from airflow.providers.standard.utils.openlineage import safe_inject_openlineage_properties_into_dagrun_conf
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator, is_arg_set
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

try:
    from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
except ImportError:
    from airflow.utils.types import NOTSET, ArgNotSet  # type: ignore[attr-defined,no-redef]

# XCom keys under which the triggered run's identifiers are pushed, so that
# extra links / listeners can locate the DAG run created by this operator.
XCOM_LOGICAL_DATE_ISO = "trigger_logical_date_iso"
XCOM_RUN_ID = "trigger_run_id"
if TYPE_CHECKING:
    # Imported only for static type checking; never evaluated at runtime.
    from sqlalchemy.orm.session import Session

    from airflow.providers.common.compat.sdk import Context, TaskInstanceKey
class DagIsPaused(AirflowException):
    """Raise when a dag is paused and something tries to run it."""

    def __init__(self, dag_id: str) -> None:
        # Pass the dag_id to the base exception so it is preserved in args
        # (e.g. for pickling/serialization of the exception).
        super().__init__(dag_id)
        self.dag_id = dag_id

    def __str__(self) -> str:
        return f"Dag {self.dag_id} is paused"
class TriggerDagRunOperator(BaseOperator):
    """
    Triggers a DAG run for a specified DAG ID.

    Note that if database isolation mode is enabled, not all features are supported.

    :param trigger_dag_id: The ``dag_id`` of the DAG to trigger (templated).
    :param trigger_run_id: The run ID to use for the triggered DAG run (templated).
        If not provided, a run ID will be automatically generated.
    :param conf: Configuration for the DAG run (templated).
    :param logical_date: Logical date for the triggered DAG (templated).
    :param run_after: The date before which the triggered DAG should not run.
    :param reset_dag_run: Whether clear existing DAG run if already exists.
        This is useful when backfill or rerun an existing DAG run.
        This only resets (not recreates) the DAG run.
        DAG run conf is immutable and will not be reset on rerun of an existing DAG run.
        When reset_dag_run=False and dag run exists, DagRunAlreadyExists will be raised.
        When reset_dag_run=True and dag run exists, existing DAG run will be cleared to rerun.
    :param wait_for_completion: Whether or not wait for DAG run completion. (default: False)
    :param poke_interval: Poke interval to check DAG run status when wait_for_completion=True.
        (default: 60)
    :param allowed_states: Optional list of allowed DAG run states of the triggered DAG. This is useful
        when setting ``wait_for_completion`` to True. Must be a valid DagRunState.
        Default is ``[DagRunState.SUCCESS]``.
    :param failed_states: Optional list of failed or disallowed DAG run states of the triggered DAG. This
        is useful when setting ``wait_for_completion`` to True. Must be a valid DagRunState.
        Default is ``[DagRunState.FAILED]``.
    :param skip_when_already_exists: Set to true to mark the task as SKIPPED if a DAG run of the
        triggered DAG for the same logical date already exists.
    :param fail_when_dag_is_paused: If the dag to trigger is paused, DagIsPaused will be raised.
    :param deferrable: If waiting for completion, whether to defer the task until done,
        default is ``False``.
    :param openlineage_inject_parent_info: whether to include OpenLineage metadata about the parent task
        in the triggered DAG run's conf, enabling improved lineage tracking.
        The metadata is only injected if OpenLineage is enabled and running.
        This option does not modify any other part of the conf, and existing OpenLineage-related
        settings in the conf will not be overwritten. The injection process is safeguarded against
        exceptions - if any error occurs during metadata injection, it is gracefully handled and the
        conf remains unchanged - so it's safe to use. Default is ``True``
    """

    # Fields rendered by the templating engine before execute() runs.
    template_fields: Sequence[str] = (
        "trigger_dag_id",
        "trigger_run_id",
        "logical_date",
        "conf",
        "wait_for_completion",
        "skip_when_already_exists",
    )
    attributes_not_supported_in_airflow_2 = {
        # `run_after` uses NOTSET here so we can detect whether the user
        # explicitly provided it and warn in Airflow 2.
        "run_after": NOTSET,
        "note": None,
    }
    template_fields_renderers = {"conf": "py"}
    ui_color = "#ffefeb"
def __init__( self, *, trigger_dag_id: str, trigger_run_id: str | None = None, conf: dict | None = None, logical_date: str | datetime.datetime | None | ArgNotSet = NOTSET, run_after: str | datetime.datetime | None | ArgNotSet = NOTSET, reset_dag_run: bool = False, wait_for_completion: bool = False, poke_interval: int = 60, allowed_states: list[str | DagRunState] | None = None, failed_states: list[str | DagRunState] | None = None, skip_when_already_exists: bool = False, fail_when_dag_is_paused: bool = False, note: str | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), openlineage_inject_parent_info: bool = True, **kwargs, ) -> None: super().__init__(**kwargs)
[docs] self.trigger_dag_id = trigger_dag_id
[docs] self.trigger_run_id = trigger_run_id
[docs] self.conf = conf
[docs] self.reset_dag_run = reset_dag_run
[docs] self.wait_for_completion = wait_for_completion
[docs] self.poke_interval = poke_interval
if allowed_states: self.allowed_states = [DagRunState(s) for s in allowed_states] else: self.allowed_states = [DagRunState.SUCCESS] if failed_states is not None: self.failed_states = [DagRunState(s) for s in failed_states] else: self.failed_states = [DagRunState.FAILED]
[docs] self.skip_when_already_exists = skip_when_already_exists
[docs] self.fail_when_dag_is_paused = fail_when_dag_is_paused
[docs] self.openlineage_inject_parent_info = openlineage_inject_parent_info
[docs] self.note = note
[docs] self.deferrable = deferrable
logical_date = _validate_datetime_param("logical_date", logical_date) run_after = _validate_datetime_param("run_after", run_after)
[docs] self.logical_date = logical_date
[docs] self.run_after = run_after
if fail_when_dag_is_paused and AIRFLOW_V_3_0_PLUS: raise NotImplementedError("Setting `fail_when_dag_is_paused` not yet supported for Airflow 3.x")
[docs] def execute(self, context: Context): if self.logical_date is NOTSET: if self.run_after is not NOTSET: parsed_logical_date = None else: # If no logical_date is provided we will set utcnow() parsed_logical_date = timezone.utcnow() else: logical_date = cast("str | datetime.datetime | None", self.logical_date) parsed_logical_date = _parse_datetime_param(logical_date) if self.run_after is NOTSET: parsed_run_after = parsed_logical_date else: run_after = cast("str | datetime.datetime | None", self.run_after) parsed_run_after = _parse_datetime_param(run_after) try: if self.conf and isinstance(self.conf, str): self.conf = json.loads(self.conf) json.dumps(self.conf) except (TypeError, JSONDecodeError): raise ValueError("conf parameter should be JSON Serializable %s", self.conf) if self.openlineage_inject_parent_info: self.log.debug("Checking if OpenLineage information can be safely injected into dagrun conf.") self.conf = safe_inject_openlineage_properties_into_dagrun_conf( dr_conf=self.conf, ti=context.get("ti") ) if self.trigger_run_id: run_id = str(self.trigger_run_id) else: if AIRFLOW_V_3_0_PLUS: run_id = DagRun.generate_run_id( run_type=DagRunType.MANUAL, logical_date=parsed_logical_date, run_after=parsed_run_after or timezone.utcnow(), ) else: run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_logical_date or timezone.utcnow()) # type: ignore[misc,call-arg] # Save run_id as task attribute - to be used by listeners self.trigger_run_id = run_id if self.fail_when_dag_is_paused: dag_model = DagModel.get_current(self.trigger_dag_id) if not dag_model: raise ValueError(f"Dag {self.trigger_dag_id} is not found") if dag_model.is_paused: # TODO: enable this when dag state endpoint available from task sdk # if AIRFLOW_V_3_0_PLUS: # raise DagIsPaused(dag_id=self.trigger_dag_id) raise AirflowException(f"Dag {self.trigger_dag_id} is paused") if AIRFLOW_V_3_0_PLUS: self._trigger_dag_af_3( context=context, run_id=self.trigger_run_id, 
parsed_logical_date=parsed_logical_date, parsed_run_after=parsed_run_after if self.run_after is not NOTSET else None, ) else: self._trigger_dag_af_2( context=context, run_id=self.trigger_run_id, parsed_logical_date=parsed_logical_date )
def _trigger_dag_af_3(self, context, run_id, parsed_logical_date, parsed_run_after=None): from airflow.providers.common.compat.sdk import DagRunTriggerException kwargs_accepted = dict( trigger_dag_id=self.trigger_dag_id, dag_run_id=run_id, conf=self.conf, logical_date=parsed_logical_date, reset_dag_run=self.reset_dag_run, skip_when_already_exists=self.skip_when_already_exists, wait_for_completion=self.wait_for_completion, allowed_states=self.allowed_states, failed_states=self.failed_states, poke_interval=self.poke_interval, deferrable=self.deferrable, ) parameters = inspect.signature(DagRunTriggerException.__init__).parameters if self.note and "note" in parameters: kwargs_accepted["note"] = self.note if parsed_run_after and "run_after" in parameters: kwargs_accepted["run_after"] = parsed_run_after raise DagRunTriggerException(**kwargs_accepted) def _trigger_dag_af_2(self, context, run_id, parsed_logical_date): try: unsupported_parameters = [] for attr, default_value in self.attributes_not_supported_in_airflow_2.items(): value = getattr(self, attr, default_value) if value is not default_value: unsupported_parameters.append(attr) if unsupported_parameters: self.log.warning( "The following parameters are not supported in Airflow 2.x and will be ignored: %s", ", ".join(unsupported_parameters), ) dag_run = trigger_dag( dag_id=self.trigger_dag_id, run_id=run_id, conf=self.conf, execution_date=parsed_logical_date, replace_microseconds=False, ) except DagRunAlreadyExists as e: if self.reset_dag_run: dag_run = e.dag_run self.log.info("Clearing %s on %s", self.trigger_dag_id, dag_run.run_id) # Get target dag object and call clear() dag_model = DagModel.get_current(self.trigger_dag_id) if dag_model is None: raise DagNotFound(f"Dag id {self.trigger_dag_id} not found in DagModel") # Note: here execution fails on database isolation mode. 
Needs structural changes for AIP-72 dag = SerializedDagModel.get_dag(self.trigger_dag_id) dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date) else: if self.skip_when_already_exists: raise AirflowSkipException( "Skipping due to skip_when_already_exists is set to True and DagRunAlreadyExists" ) raise e if dag_run is None: raise RuntimeError("The dag_run should be set here!") # Store the run id from the dag run (either created or found above) to # be used when creating the extra link on the webserver. ti = context["task_instance"] ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id) if self.wait_for_completion: # Kick off the deferral process if self.deferrable: self.defer( trigger=DagStateTrigger( dag_id=self.trigger_dag_id, states=self.allowed_states + self.failed_states, execution_dates=[dag_run.logical_date], run_ids=[run_id], poll_interval=self.poke_interval, ), method_name="execute_complete", ) # wait for dag to complete while True: self.log.info( "Waiting for %s on %s to become allowed state %s ...", self.trigger_dag_id, run_id, self.allowed_states, ) time.sleep(self.poke_interval) # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72 dag_run.refresh_from_db() state = dag_run.state if state in self.failed_states: raise AirflowException(f"{self.trigger_dag_id} failed with failed states {state}") if state in self.allowed_states: self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state) return
[docs] def execute_complete(self, context: Context, event: tuple[str, dict[str, Any]]): """ Handle task completion after returning from a deferral. Args: context: The Airflow context dictionary. event: A tuple containing the class path of the trigger and the trigger event data. """ # Example event tuple content: # ( # "airflow.providers.standard.triggers.external_task.DagStateTrigger", # { # 'dag_id': 'some_dag', # 'states': ['success', 'failed'], # 'poll_interval': 15, # 'run_ids': ['manual__2025-11-19T17:49:20.907083+00:00'], # 'execution_dates': [ # DateTime(2025, 11, 19, 17, 49, 20, 907083, tzinfo=Timezone('UTC')) # ] # } # ) _, event_data = event run_ids = event_data["run_ids"] # Re-set as attribute after coming back from deferral - to be used by listeners. # Just a safety check on length, we should always have single run_id here. self.trigger_run_id = run_ids[0] if len(run_ids) == 1 else None if AIRFLOW_V_3_0_PLUS: self._trigger_dag_run_af_3_execute_complete(event_data=event_data) else: self._trigger_dag_run_af_2_execute_complete(event_data=event_data)
def _trigger_dag_run_af_3_execute_complete(self, event_data: dict[str, Any]): failed_run_id_conditions = [] for run_id in event_data["run_ids"]: state = event_data.get(run_id) if state in self.failed_states: failed_run_id_conditions.append(run_id) continue if state in self.allowed_states: self.log.info( "%s finished with allowed state %s for run_id %s", self.trigger_dag_id, state, run_id, ) if failed_run_id_conditions: raise AirflowException( f"{self.trigger_dag_id} failed with failed states {self.failed_states} for run_ids" f" {failed_run_id_conditions}" ) if not AIRFLOW_V_3_0_PLUS: from airflow.utils.session import NEW_SESSION, provide_session # type: ignore[misc] @provide_session def _trigger_dag_run_af_2_execute_complete( self, event_data: dict[str, Any], session: Session = NEW_SESSION ): # This logical_date is parsed from the return trigger event provided_logical_date = event_data["execution_dates"][0] try: # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72 dag_run = session.execute( select(DagRun).where( DagRun.dag_id == self.trigger_dag_id, DagRun.execution_date == provided_logical_date ) ).scalar_one() except NoResultFound: raise AirflowException( f"No DAG run found for DAG {self.trigger_dag_id} and logical date {self.logical_date}" ) state = dag_run.state if state in self.failed_states: raise AirflowException(f"{self.trigger_dag_id} failed with failed state {state}") if state in self.allowed_states: self.log.info("%s finished with allowed state %s", self.trigger_dag_id, state) return raise AirflowException( f"{self.trigger_dag_id} return {state} which is not in {self.failed_states}" f" or {self.allowed_states}" )
@overload def _validate_datetime_param(name: str, value: ArgNotSet) -> ArgNotSet: ... @overload def _validate_datetime_param(name: str, value: None) -> None: ... @overload def _validate_datetime_param(name: str, value: str) -> str: ... @overload def _validate_datetime_param(name: str, value: datetime.datetime) -> datetime.datetime: ... def _validate_datetime_param( name: str, value: str | datetime.datetime | None | ArgNotSet, ) -> str | datetime.datetime | None | ArgNotSet: if not is_arg_set(value): return NOTSET if value is None or isinstance(value, (str, datetime.datetime)): return value raise TypeError( f"Expected str, datetime.datetime, or None for parameter '{name}'. Got {type(value).__name__}" ) @overload def _parse_datetime_param(value: None) -> None: ... @overload def _parse_datetime_param(value: datetime.datetime) -> datetime.datetime: ... @overload def _parse_datetime_param(value: str) -> datetime.datetime: ... def _parse_datetime_param( value: str | datetime.datetime | None, ) -> datetime.datetime | None: if value is None or isinstance(value, datetime.datetime): return value return timezone.parse(value)

# (documentation-site footer removed)