# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains the Apache Livy Trigger."""
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from datetime import datetime, timedelta, timezone
from typing import Any

from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class LivyTrigger(BaseTrigger):
"""
Check for the state of a previously submitted job with batch_id.
:param batch_id: Batch job id
:param spark_params: Spark parameters; for example,
spark_params = {"file": "test/pi.py", "class_name": "org.apache.spark.examples.SparkPi",
"args": ["/usr/lib/spark/bin/run-example", "SparkPi", "10"],"jars": "command-runner.jar",
"driver_cores": 1, "executor_cores": 4,"num_executors": 1}
:param livy_conn_id: reference to a pre-defined Livy Connection.
:param polling_interval: time in seconds between polling for job completion. If poll_interval=0, in that
case return the batch_id and if polling_interval > 0, poll the livy job for termination in the
polling interval defined.
:param extra_options: A dictionary of options, where key is string and value
depends on the option that's being modified.
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.
:param livy_hook_async: LivyAsyncHook object
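
    Example (a minimal construction sketch; the batch id, file path, and polling
    interval below are illustrative placeholders, not values required by this
    module)::

        trigger = LivyTrigger(
            batch_id=42,
            spark_params={"file": "test/pi.py"},
            livy_conn_id="livy_default",
            polling_interval=30,
        )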
"""

    def __init__(
self,
batch_id: int | str,
spark_params: dict[Any, Any],
livy_conn_id: str = "livy_default",
polling_interval: int = 0,
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
livy_hook_async: LivyAsyncHook | None = None,
execution_timeout: timedelta | None = None,
):
super().__init__()
self._batch_id = batch_id
self.spark_params = spark_params
self._livy_conn_id = livy_conn_id
self._polling_interval = polling_interval
self._extra_options = extra_options
self._extra_headers = extra_headers
self._livy_hook_async = livy_hook_async
self._execution_timeout = execution_timeout

    def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize LivyTrigger arguments and classpath."""
return (
"airflow.providers.apache.livy.triggers.livy.LivyTrigger",
{
"batch_id": self._batch_id,
"spark_params": self.spark_params,
"livy_conn_id": self._livy_conn_id,
"polling_interval": self._polling_interval,
"extra_options": self._extra_options,
"extra_headers": self._extra_headers,
"livy_hook_async": self._livy_hook_async,
"execution_timeout": self._execution_timeout,
},
)

    async def run(self) -> AsyncIterator[TriggerEvent]:
"""
Run the trigger.
If ``_polling_interval > 0``, this pools Livy for batch termination
asynchronously. Otherwise the success response is created immediately.
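
        A sketch of the event a deferring operator can expect (field values are
        illustrative)::

            TriggerEvent(
                {
                    "status": "success",
                    "batch_id": 42,
                    "response": "Batch 42 succeeded",
                    "log_lines": None,
                }
            )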
"""
try:
            if self._polling_interval > 0:
                response = await self.poll_for_termination(self._batch_id)
                yield TriggerEvent(response)
                # The polled response already carries the final status
                # (success, error, or timeout), so do not also yield success.
                return
yield TriggerEvent(
{
"status": "success",
"batch_id": self._batch_id,
"response": f"Batch {self._batch_id} succeeded",
"log_lines": None,
}
)
except Exception as exc:
yield TriggerEvent(
{
"status": "error",
"batch_id": self._batch_id,
"response": f"Batch {self._batch_id} did not succeed with {exc}",
"log_lines": None,
}
)

    async def poll_for_termination(self, batch_id: int | str) -> dict[str, Any]:
"""
Pool Livy for batch termination asynchronously.
:param batch_id: id of the batch session to monitor.
"""
if self._execution_timeout is not None:
timeout_datetime = datetime.now(timezone.utc) + self._execution_timeout
else:
timeout_datetime = None
batch_execution_timed_out = False
hook = self._get_async_hook()
state = await hook.get_batch_state(batch_id)
self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
while state["batch_state"] not in hook.TERMINAL_STATES:
self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
batch_execution_timed_out = (
timeout_datetime is not None and datetime.now(timezone.utc) > timeout_datetime
)
if batch_execution_timed_out:
break
self.log.info("Sleeping for %s seconds", self._polling_interval)
await asyncio.sleep(self._polling_interval)
state = await hook.get_batch_state(batch_id)
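        # The loop has exited: the batch is terminal or the deadline passed.
        # Fetch the batch logs either way so the event carries them.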
log_lines = await hook.dump_batch_logs(batch_id)
if batch_execution_timed_out:
self.log.info(
"Batch with id %s did not terminate, but it reached execution timeout.",
batch_id,
)
return {
"status": "timeout",
"batch_id": batch_id,
"response": f"Batch {batch_id} timed out",
"log_lines": log_lines,
}
self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value)
if state["batch_state"] != BatchState.SUCCESS:
return {
"status": "error",
"batch_id": batch_id,
"response": f"Batch {batch_id} did not succeed",
"log_lines": log_lines,
}
return {
"status": "success",
"batch_id": batch_id,
"response": f"Batch {batch_id} succeeded",
"log_lines": log_lines,
}

    def _get_async_hook(self) -> LivyAsyncHook:
        if not isinstance(self._livy_hook_async, LivyAsyncHook):
self._livy_hook_async = LivyAsyncHook(
livy_conn_id=self._livy_conn_id,
extra_headers=self._extra_headers,
extra_options=self._extra_options,
)
return self._livy_hook_async
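

# A minimal sketch (not part of this module) of how a deferrable operator could
# hand off to LivyTrigger and consume the event it yields. The operator class
# and its ``hook`` attribute are hypothetical; ``self.defer`` and the
# ``execute_complete`` callback follow the standard Airflow deferral pattern,
# and ``AirflowException`` would come from ``airflow.exceptions``.
#
#     def execute(self, context):
#         batch_id = self.hook.post_batch(**self.spark_params)
#         self.defer(
#             trigger=LivyTrigger(
#                 batch_id=batch_id,
#                 spark_params=self.spark_params,
#                 polling_interval=60,
#             ),
#             method_name="execute_complete",
#         )
#
#     def execute_complete(self, context, event):
#         if event["status"] != "success":
#             raise AirflowException(event["response"])
#         return event["batch_id"]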