Source code for airflow.providers.amazon.aws.triggers.emr

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
import sys
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

from asgiref.sync import sync_to_async

from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.triggers.base import TriggerEvent
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
    from sqlalchemy.orm.session import Session

    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

if not AIRFLOW_V_3_0_PLUS:
    from airflow.models.taskinstance import TaskInstance
    from airflow.utils.session import provide_session


class EmrAddStepsTrigger(AwsBaseWaiterTrigger):
    """
    Wait until every EMR step in ``step_ids`` reaches a terminal state.

    Polls with the ``steps_wait_for_terminal`` waiter and, on success, emits the
    list of step ids as the return value.

    :param job_flow_id: ID of the EMR job flow (cluster) that owns the steps.
    :param step_ids: IDs of the steps whose state is polled.
    :param waiter_delay: Seconds to wait between two polls.
    :param waiter_max_attempts: Maximum number of polls before giving up.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        job_flow_id: str,
        step_ids: list[str],
        waiter_delay: int,
        waiter_max_attempts: int,
        aws_conn_id: str | None = "aws_default",
    ):
        # Arguments forwarded to the waiter on every poll.
        waiter_kwargs = {"ClusterId": job_flow_id, "StepIds": step_ids}
        # Response fields handed to the base trigger for its progress reporting.
        progress_fields = [
            "Steps[].Status.State",
            "Steps[].Status.FailureDetails",
        ]
        super().__init__(
            aws_conn_id=aws_conn_id,
            serialized_fields={"job_flow_id": job_flow_id, "step_ids": step_ids},
            waiter_name="steps_wait_for_terminal",
            waiter_args=waiter_kwargs,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            failure_message=f"Error while waiting for steps {step_ids} to complete",
            status_message=f"Step ids: {step_ids}, Steps are still in non-terminal state",
            status_queries=progress_fields,
            return_value=step_ids,
        )

    def hook(self) -> AwsGenericHook:
        """Build an :class:`EmrHook` for this trigger's AWS connection."""
        conn_id = self.aws_conn_id
        return EmrHook(aws_conn_id=conn_id)
class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger):
    """
    Asynchronously wait for a freshly created EMR JobFlow using a boto3 waiter.

    On success the job flow id is emitted under the ``job_flow_id`` key.

    :param job_flow_id: ID of the job flow to wait for.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param waiter_delay: Seconds to wait between two polls.
    :param waiter_max_attempts: Maximum number of polls before giving up.
    :param waiter_name: Name of the EMR waiter used for polling
        (``job_flow_waiting`` by default).
    """

    def __init__(
        self,
        job_flow_id: str,
        aws_conn_id: str | None = None,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        waiter_name: str = "job_flow_waiting",
    ):
        # Cluster status fields handed to the base trigger for progress reporting.
        cluster_status_fields = [
            "Cluster.Status.State",
            "Cluster.Status.StateChangeReason",
            "Cluster.Status.ErrorDetails",
        ]
        super().__init__(
            aws_conn_id=aws_conn_id,
            serialized_fields={"job_flow_id": job_flow_id},
            waiter_name=waiter_name,
            waiter_args={"ClusterId": job_flow_id},
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            failure_message="JobFlow creation failed",
            status_message="JobFlow creation in progress",
            status_queries=cluster_status_fields,
            return_key="job_flow_id",
            return_value=job_flow_id,
        )

    def hook(self) -> AwsGenericHook:
        """Build an :class:`EmrHook` for this trigger's AWS connection."""
        conn_id = self.aws_conn_id
        return EmrHook(aws_conn_id=conn_id)
class EmrTerminateJobFlowTrigger(AwsBaseWaiterTrigger):
    """
    Asynchronously wait for an EMR JobFlow to finish terminating.

    Polls with the ``job_flow_terminated`` waiter; no return value is emitted.

    :param job_flow_id: ID of the EMR job flow being terminated.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param waiter_delay: Seconds to wait between two polls.
    :param waiter_max_attempts: Maximum number of polls before giving up.
    """

    def __init__(
        self,
        job_flow_id: str,
        aws_conn_id: str | None = None,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
    ):
        # Cluster status fields handed to the base trigger for progress reporting.
        cluster_status_fields = [
            "Cluster.Status.State",
            "Cluster.Status.StateChangeReason",
            "Cluster.Status.ErrorDetails",
        ]
        super().__init__(
            aws_conn_id=aws_conn_id,
            serialized_fields={"job_flow_id": job_flow_id},
            waiter_name="job_flow_terminated",
            waiter_args={"ClusterId": job_flow_id},
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            failure_message="JobFlow termination failed",
            status_message="JobFlow termination in progress",
            status_queries=cluster_status_fields,
            return_value=None,
        )

    def hook(self) -> AwsGenericHook:
        """Build an :class:`EmrHook` for this trigger's AWS connection."""
        conn_id = self.aws_conn_id
        return EmrHook(aws_conn_id=conn_id)
class EmrContainerTrigger(AwsBaseWaiterTrigger):
    """
    Wait for an EMR container job run to reach a terminal state.

    On success the job id is emitted under the ``job_id`` key.

    :param virtual_cluster_id: ID of the EMR virtual cluster running the job.
    :param job_id: ID of the job run whose state is polled.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param waiter_delay: Seconds to wait between two polls.
    :param waiter_max_attempts: Maximum number of polls before giving up.
        Defaults to ``sys.maxsize``, i.e. an effectively unbounded wait.
    """

    def __init__(
        self,
        virtual_cluster_id: str,
        job_id: str,
        aws_conn_id: str | None = "aws_default",
        waiter_delay: int = 30,
        waiter_max_attempts: int = sys.maxsize,
    ):
        identity = {"virtual_cluster_id": virtual_cluster_id, "job_id": job_id}
        super().__init__(
            aws_conn_id=aws_conn_id,
            serialized_fields=identity,
            waiter_name="container_job_complete",
            waiter_args={"id": job_id, "virtualClusterId": virtual_cluster_id},
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            failure_message="Job failed",
            status_message="Job in progress",
            status_queries=["jobRun.state", "jobRun.failureReason"],
            return_key="job_id",
            return_value=job_id,
        )

    def hook(self) -> AwsGenericHook:
        """Build an :class:`EmrContainerHook` for this trigger's AWS connection."""
        conn_id = self.aws_conn_id
        return EmrContainerHook(aws_conn_id=conn_id)
class EmrStepSensorTrigger(AwsBaseWaiterTrigger):
    """
    Poll the status of a single EMR step until it reaches a terminal state.

    Polls with the ``step_wait_for_terminal`` waiter; no return value is emitted.

    :param job_flow_id: ID of the EMR job flow (cluster) that contains the step.
    :param step_id: ID of the step whose state is polled.
    :param waiter_delay: Seconds to wait between two polls.
    :param waiter_max_attempts: Maximum number of polls before giving up.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        job_flow_id: str,
        step_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str | None = "aws_default",
    ):
        # Step status fields handed to the base trigger for progress reporting.
        step_status_fields = [
            "Step.Status.State",
            "Step.Status.FailureDetails",
            "Step.Status.StateChangeReason",
        ]
        super().__init__(
            aws_conn_id=aws_conn_id,
            serialized_fields={"job_flow_id": job_flow_id, "step_id": step_id},
            waiter_name="step_wait_for_terminal",
            waiter_args={"ClusterId": job_flow_id, "StepId": step_id},
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            failure_message=f"Error while waiting for step {step_id} to complete",
            status_message=f"Step id: {step_id}, Step is still in non-terminal state",
            status_queries=step_status_fields,
            return_value=None,
        )

    def hook(self) -> AwsGenericHook:
        """Build an :class:`EmrHook` for this trigger's AWS connection."""
        conn_id = self.aws_conn_id
        return EmrHook(aws_conn_id=conn_id)
class EmrServerlessCreateApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an EMR Serverless application and wait for it to be created.

    Polls with the ``serverless_app_created`` waiter; on success the application
    id is emitted under the ``application_id`` key.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: Seconds to wait between two polls.
    :param waiter_max_attempts: Maximum number of polls before giving up.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str | None = "aws_default",
    ) -> None:
        app_fields = ["application.state", "application.stateDetails"]
        super().__init__(
            aws_conn_id=aws_conn_id,
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_created",
            waiter_args={"applicationId": application_id},
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            failure_message="Application creation failed",
            status_message="Application status is",
            status_queries=app_fields,
            return_key="application_id",
            return_value=application_id,
        )

    def hook(self) -> AwsGenericHook:
        """Build an :class:`EmrServerlessHook` for this trigger's AWS connection."""
        conn_id = self.aws_conn_id
        return EmrServerlessHook(conn_id)
class EmrServerlessStartApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an EMR Serverless application and wait for it to be started.

    Polls with the ``serverless_app_started`` waiter; on success the application
    id is emitted under the ``application_id`` key.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: Seconds to wait between two polls.
    :param waiter_max_attempts: Maximum number of polls before giving up.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str | None = "aws_default",
    ) -> None:
        app_fields = ["application.state", "application.stateDetails"]
        super().__init__(
            aws_conn_id=aws_conn_id,
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_started",
            waiter_args={"applicationId": application_id},
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            failure_message="Application failed to start",
            status_message="Application status is",
            status_queries=app_fields,
            return_key="application_id",
            return_value=application_id,
        )

    def hook(self) -> AwsGenericHook:
        """Build an :class:`EmrServerlessHook` for this trigger's AWS connection."""
        conn_id = self.aws_conn_id
        return EmrServerlessHook(conn_id)
class EmrServerlessStopApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an EMR Serverless application and wait for it to be stopped.

    Polls with the ``serverless_app_stopped`` waiter; on success the application
    id is emitted under the ``application_id`` key.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: Polling period in seconds to check for the status.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: Reference to AWS connection id.
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str | None = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_stopped",
            waiter_args={"applicationId": application_id},
            # This trigger waits for a *stop*; the message must describe that operation.
            failure_message="Application failed to stop",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return an :class:`EmrServerlessHook` for this trigger's AWS connection."""
        return EmrServerlessHook(self.aws_conn_id)
class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
    """
    Poll an Emr Serverless job run and wait for it to be completed.

    If the trigger is cancelled while waiting, ``cancel_on_kill`` is enabled and
    the task instance is no longer ``DEFERRED`` (a DEFERRED task indicates e.g. a
    triggerer restart rather than a user kill), the job run is cancelled as well.

    :param application_id: The ID of the application the job is being run on.
    :param job_id: The ID of the job run.
    :param waiter_delay: polling period in seconds to check for the status
    :param waiter_max_attempts: The maximum number of attempts to be made
    :param aws_conn_id: Reference to AWS connection id
    :param cancel_on_kill: Flag to indicate whether to cancel the job when the task is killed.
    """

    def __init__(
        self,
        application_id: str,
        job_id: str | None,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str | None = "aws_default",
        cancel_on_kill: bool = True,
    ) -> None:
        super().__init__(
            serialized_fields={
                "application_id": application_id,
                "job_id": job_id,
                "cancel_on_kill": cancel_on_kill,
            },
            waiter_name="serverless_job_completed",
            waiter_args={"applicationId": application_id, "jobRunId": job_id},
            failure_message="Serverless Job failed",
            status_message="Serverless Job status is",
            status_queries=["jobRun.state", "jobRun.stateDetails"],
            return_key="job_details",
            return_value={"application_id": application_id, "job_id": job_id},
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )
        self.application_id = application_id
        self.job_id = job_id
        self.cancel_on_kill = cancel_on_kill

    def hook(self) -> AwsGenericHook:
        """Return an :class:`EmrServerlessHook` for this trigger's AWS connection."""
        return EmrServerlessHook(self.aws_conn_id)

    if not AIRFLOW_V_3_0_PLUS:

        @provide_session
        def get_task_instance(self, session: Session) -> TaskInstance:
            """
            Get the task instance for the current trigger (Airflow 2.x compatibility).

            :param session: SQLAlchemy session, injected by ``provide_session``.
            :raises ValueError: if no matching task instance row exists.
            """
            from sqlalchemy import select

            query = select(TaskInstance).where(
                TaskInstance.dag_id == self.task_instance.dag_id,
                TaskInstance.task_id == self.task_instance.task_id,
                TaskInstance.run_id == self.task_instance.run_id,
                TaskInstance.map_index == self.task_instance.map_index,
            )
            task_instance = session.scalars(query).one_or_none()
            if task_instance is None:
                raise ValueError(
                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
                    f"task_id: {self.task_instance.task_id}, "
                    f"run_id: {self.task_instance.run_id} and "
                    f"map_index: {self.task_instance.map_index} is not found"
                )
            return task_instance

    async def get_task_state(self):
        """
        Get the current state of the task instance (Airflow 3.x).

        :raises ValueError: if the response has no entry for this run/task.
        """
        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

        task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)(
            dag_id=self.task_instance.dag_id,
            task_ids=[self.task_instance.task_id],
            run_ids=[self.task_instance.run_id],
            map_index=self.task_instance.map_index,
        )
        try:
            task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id]
        except Exception as err:
            # Chain the lookup failure so the original cause stays visible in tracebacks.
            raise ValueError(
                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
                f"task_id: {self.task_instance.task_id}, "
                f"run_id: {self.task_instance.run_id} and "
                f"map_index: {self.task_instance.map_index} is not found"
            ) from err
        return task_state

    async def safe_to_cancel(self) -> bool:
        """
        Whether it is safe to cancel the EMR Serverless job.

        Returns True if task is NOT DEFERRED (user-initiated cancellation).
        Returns False if task is DEFERRED (triggerer restart - don't cancel job).
        """
        if AIRFLOW_V_3_0_PLUS:
            task_state = await self.get_task_state()
        else:
            # The session-based DB query is synchronous; run it off the event loop
            # so other triggers in this triggerer are not blocked.
            task_instance = await sync_to_async(self.get_task_instance)()  # type: ignore[call-arg]
            task_state = task_instance.state
        return task_state != TaskInstanceState.DEFERRED

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Run the trigger and wait for the job to complete.

        If the task is cancelled while waiting, attempt to cancel the EMR Serverless
        job if cancel_on_kill is enabled and it's safe to do so. The original
        ``CancelledError`` is always re-raised, even if the cancel request fails.
        """
        hook = self.hook()
        try:
            async with await hook.get_async_conn() as client:
                waiter = hook.get_waiter(
                    self.waiter_name,
                    deferrable=True,
                    client=client,
                    config_overrides=self.waiter_config_overrides,
                )
                await async_wait(
                    waiter,
                    self.waiter_delay,
                    self.attempts,
                    self.waiter_args,
                    self.failure_message,
                    self.status_message,
                    self.status_queries,
                )
            yield TriggerEvent({"status": "success", self.return_key: self.return_value})
        except asyncio.CancelledError:
            if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
                self.log.info(
                    "Task was cancelled. Cancelling EMR Serverless job. Application ID: %s, Job ID: %s",
                    self.application_id,
                    self.job_id,
                )

                def _cancel_job_run() -> None:
                    # Creating the boto3 client and calling the API are both blocking.
                    hook.conn.cancel_job_run(applicationId=self.application_id, jobRunId=self.job_id)

                try:
                    await sync_to_async(_cancel_job_run)()
                except Exception:
                    # Do not let a failed cancel request replace the CancelledError below.
                    self.log.exception(
                        "Failed to cancel EMR Serverless job. Application ID: %s, Job ID: %s",
                        self.application_id,
                        self.job_id,
                    )
                else:
                    self.log.info("EMR Serverless job %s cancelled successfully.", self.job_id)
            else:
                self.log.info(
                    "Trigger may have shutdown or cancel_on_kill is disabled. "
                    "Skipping job cancellation. Application ID: %s, Job ID: %s",
                    self.application_id,
                    self.job_id,
                )
            raise
        except Exception as e:
            yield TriggerEvent({"status": "failure", "message": str(e)})
class EmrServerlessDeleteApplicationTrigger(AwsBaseWaiterTrigger):
    """
    Poll an EMR Serverless application and wait for it to be deleted.

    Polls with the ``serverless_app_terminated`` waiter; on success the
    application id is emitted under the ``application_id`` key.

    :param application_id: The ID of the application being polled.
    :param waiter_delay: Polling period in seconds to check for the status.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: Reference to AWS connection id.
    """

    def __init__(
        self,
        application_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str | None = "aws_default",
    ) -> None:
        super().__init__(
            serialized_fields={"application_id": application_id},
            waiter_name="serverless_app_terminated",
            waiter_args={"applicationId": application_id},
            # This trigger waits for a deletion; the message must describe that operation.
            failure_message="Application deletion failed",
            status_message="Application status is",
            status_queries=["application.state", "application.stateDetails"],
            return_key="application_id",
            return_value=application_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
        )

    def hook(self) -> AwsGenericHook:
        """Return an :class:`EmrServerlessHook` for this trigger's AWS connection."""
        return EmrServerlessHook(self.aws_conn_id)
class EmrServerlessCancelJobsTrigger(AwsBaseWaiterTrigger):
    """
    Trigger used while cancelling the jobs of an EMR Serverless application.

    Polls with the ``no_job_running`` waiter, passing the hook's
    ``JOB_INTERMEDIATE_STATES`` plus ``CANCELLING`` as the ``states`` argument.
    On success the application id is emitted under the ``application_id`` key.

    :param application_id: EMR Serverless application ID.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param waiter_delay: Seconds to wait between two status checks.
    :param waiter_max_attempts: Maximum number of status checks before giving up.
    """

    def __init__(
        self,
        application_id: str,
        aws_conn_id: str | None,
        waiter_delay: int,
        waiter_max_attempts: int,
    ) -> None:
        watched_states = EmrServerlessHook.JOB_INTERMEDIATE_STATES.union({"CANCELLING"})
        request_args = {"applicationId": application_id, "states": list(watched_states)}
        super().__init__(
            aws_conn_id=aws_conn_id,
            serialized_fields={"application_id": application_id},
            waiter_name="no_job_running",
            waiter_args=request_args,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            failure_message="Error while waiting for jobs to cancel",
            status_message="Currently running jobs",
            status_queries=["jobRuns[*].applicationId", "jobRuns[*].state"],
            return_key="application_id",
            return_value=application_id,
        )

    def hook(self) -> AwsGenericHook:
        """Build an :class:`EmrServerlessHook` for this trigger's AWS connection."""
        conn_id = self.aws_conn_id
        return EmrServerlessHook(conn_id)

    @property
    def hook_instance(self) -> AwsGenericHook:
        """Backward-compatible alias: a new hook built via :meth:`hook`."""
        return self.hook()

Was this entry helpful?