# Source code for airflow.providers.apache.livy.triggers.livy

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""This module contains the Apache Livy Trigger."""

from __future__ import annotations

import asyncio
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncIterator

from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class LivyTrigger(BaseTrigger):
    """
    Check for the state of a previously submitted job with batch_id.

    :param batch_id: Batch job id
    :param spark_params: Spark parameters; for example,
            spark_params = {"file": "test/pi.py", "class_name": "org.apache.spark.examples.SparkPi",
            "args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"],"jars": "command-runner.jar",
            "driver_cores": 1, "executor_cores": 4,"num_executors": 1}
    :param livy_conn_id: reference to a pre-defined Livy Connection.
    :param polling_interval: time in seconds between polling for job completion.
        If ``polling_interval=0``, a success event carrying the batch_id is emitted
        immediately without polling; if ``polling_interval > 0``, the Livy job is
        polled for termination at the interval defined.
    :param extra_options: A dictionary of options, where key is string and value
        depends on the option that's being modified.
    :param extra_headers: A dictionary of headers passed to the HTTP request to livy.
    :param livy_hook_async: LivyAsyncHook object
    :param execution_timeout: maximum time to keep polling for a terminal state.
        When exceeded, a ``"timeout"`` event is emitted. ``None`` disables the
        timeout and polls until the batch terminates.
    """

    def __init__(
        self,
        batch_id: int | str,
        spark_params: dict[Any, Any],
        livy_conn_id: str = "livy_default",
        polling_interval: int = 0,
        extra_options: dict[str, Any] | None = None,
        extra_headers: dict[str, Any] | None = None,
        livy_hook_async: LivyAsyncHook | None = None,
        execution_timeout: timedelta | None = None,
    ):
        super().__init__()
        self._batch_id = batch_id
        self.spark_params = spark_params
        self._livy_conn_id = livy_conn_id
        self._polling_interval = polling_interval
        self._extra_options = extra_options
        self._extra_headers = extra_headers
        self._livy_hook_async = livy_hook_async
        self._execution_timeout = execution_timeout

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize LivyTrigger arguments and classpath."""
        return (
            "airflow.providers.apache.livy.triggers.livy.LivyTrigger",
            {
                "batch_id": self._batch_id,
                "spark_params": self.spark_params,
                "livy_conn_id": self._livy_conn_id,
                "polling_interval": self._polling_interval,
                "extra_options": self._extra_options,
                "extra_headers": self._extra_headers,
                "livy_hook_async": self._livy_hook_async,
                "execution_timeout": self._execution_timeout,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Run the trigger and yield exactly one event.

        If ``_polling_interval > 0``, this polls Livy for batch termination
        asynchronously and yields the resulting payload (``success``, ``error``
        or ``timeout``). Otherwise the success response is created immediately.
        Any exception raised while polling is reported as an ``error`` event.
        """
        try:
            if self._polling_interval > 0:
                response = await self.poll_for_termination(self._batch_id)
            else:
                # No polling requested: report success straight away.
                response = {
                    "status": "success",
                    "batch_id": self._batch_id,
                    "response": f"Batch {self._batch_id} succeeded",
                    "log_lines": None,
                }
        except Exception as exc:
            response = {
                "status": "error",
                "batch_id": self._batch_id,
                "response": f"Batch {self._batch_id} did not succeed with {exc}",
                "log_lines": None,
            }
        # Previously a "success" event was also yielded after the polled
        # response, so a failed/timed-out batch could be followed by a bogus
        # success event. Emit a single event only.
        yield TriggerEvent(response)

    async def poll_for_termination(self, batch_id: int | str) -> dict[str, Any]:
        """
        Poll Livy for batch termination asynchronously.

        :param batch_id: id of the batch session to monitor.
        :return: event payload with keys ``status`` (``"success"``, ``"error"``
            or ``"timeout"``), ``batch_id``, ``response`` and ``log_lines``.
        """
        if self._execution_timeout is not None:
            timeout_datetime = datetime.now(timezone.utc) + self._execution_timeout
        else:
            timeout_datetime = None
        batch_execution_timed_out = False
        hook = self._get_async_hook()
        state = await hook.get_batch_state(batch_id)
        while state["batch_state"] not in hook.TERMINAL_STATES:
            # Logged once per observed non-terminal state (the initial state is
            # no longer logged twice).
            self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
            batch_execution_timed_out = (
                timeout_datetime is not None and datetime.now(timezone.utc) > timeout_datetime
            )
            if batch_execution_timed_out:
                break
            self.log.info("Sleeping for %s seconds", self._polling_interval)
            await asyncio.sleep(self._polling_interval)
            state = await hook.get_batch_state(batch_id)
        log_lines = await hook.dump_batch_logs(batch_id)
        if batch_execution_timed_out:
            self.log.info(
                "Batch with id %s did not terminate, but it reached execution timeout.",
                batch_id,
            )
            return {
                "status": "timeout",
                "batch_id": batch_id,
                "response": f"Batch {batch_id} timed out",
                "log_lines": log_lines,
            }
        self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value)
        if state["batch_state"] != BatchState.SUCCESS:
            return {
                "status": "error",
                "batch_id": batch_id,
                "response": f"Batch {batch_id} did not succeed",
                "log_lines": log_lines,
            }
        return {
            "status": "success",
            "batch_id": batch_id,
            "response": f"Batch {batch_id} succeeded",
            "log_lines": log_lines,
        }

    def _get_async_hook(self) -> LivyAsyncHook:
        """Return the stored LivyAsyncHook, creating one from the connection settings if needed."""
        # The stored value is replaced whenever it is not a LivyAsyncHook (this
        # also covers None). NOTE(review): presumably guards against a
        # non-hook value arriving via serialize()'s kwargs round-trip; confirm.
        if not isinstance(self._livy_hook_async, LivyAsyncHook):
            self._livy_hook_async = LivyAsyncHook(
                livy_conn_id=self._livy_conn_id,
                extra_headers=self._extra_headers,
                extra_options=self._extra_options,
            )
        return self._livy_hook_async

# Was this entry helpful?