#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast
import requests
from tenacity import retry, stop_after_attempt, wait_fixed
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
from airflow.providers.common.compat.openlineage.utils.spark import (
inject_parent_job_information_into_spark_properties,
inject_transport_information_into_spark_properties,
)
from airflow.providers.common.compat.sdk import BaseOperator, conf
try:
    from airflow.sdk.bases.resumablemixin import ResumableJobMixin
except ImportError:
    # Airflow 2 compatibility: the resumable-job mixin is absent there, so install a
    # minimal stand-in whose execute_resumable() simply submits, waits, and returns
    # the result -- i.e. the behaviour operators had before the mixin existed.
    class ResumableJobMixin:  # type: ignore[no-redef]
        """Airflow 2 fallback: keeps no task_state and submits a fresh job on every run."""

        # Key name under which a resumable implementation would store the remote job ID.
        external_id_key: str = "remote_job_id"

        def execute_resumable(self, context):
            """Run submit -> poll -> fetch-result in sequence and return the result."""
            job_ref = self.submit_job(context)
            self.poll_until_complete(job_ref, context)
            outcome = self.get_job_result(job_ref, context)
            return outcome

if TYPE_CHECKING:
    from pydantic import JsonValue
    from requests.auth import AuthBase

    from airflow.providers.common.compat.sdk import Context
class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
    """
    Wrap the spark-submit binary to kick off a spark-submit job; requires "spark-submit" binary in the PATH.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SparkSubmitOperator`

    :param application: The application submitted as a job, either a jar or py file. (templated)
    :param conf: Arbitrary Spark configuration properties (templated)
    :param conn_id: The :ref:`spark connection id <howto/connection:spark>` as configured
        in Airflow administration. When an invalid connection_id is supplied, it will default to yarn.
    :param files: Upload additional files to the executor running the job, separated by a
        comma. Files will be placed in the working directory of each executor.
        For example, serialized objects. (templated)
    :param py_files: Additional python files used by the job, can be .zip, .egg or .py. (templated)
    :param archives: Archives to hand to spark-submit; forwarded unchanged to
        ``SparkSubmitHook`` as its ``archives`` argument.
    :param jars: Submit additional jars to upload and place them in driver and executor classpaths. (templated)
    :param driver_class_path: Additional, driver-specific, classpath settings. (templated)
    :param java_class: the main class of the Java application
    :param packages: Comma-separated list of maven coordinates of jars to include on the
        driver and executor classpaths. (templated)
    :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude
        while resolving the dependencies provided in 'packages' (templated)
    :param repositories: Comma-separated list of additional remote repositories to search
        for the maven coordinates given with 'packages'
    :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors
        (Default: all the available cores on the worker)
    :param executor_cores: (Standalone & YARN only) Number of cores per executor (Default: 2)
    :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
    :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G)
    :param keytab: Full path to the file that contains the keytab (templated)
        (will overwrite any keytab defined in the connection's extra JSON)
    :param principal: The name of the kerberos principal used for keytab (templated)
        (will overwrite any principal defined in the connection's extra JSON)
    :param proxy_user: User to impersonate when submitting the application (templated)
    :param name: Name of the job (default arrow-spark). (templated)
    :param num_executors: Number of executors to launch
    :param status_poll_interval: Seconds to wait between polls of driver status in cluster
        mode. Used both by the Spark standalone driver-status tracker and (when
        ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API
        polling loop. The YARN ResourceManager REST API polling loop uses at
        least 10 seconds to avoid flooding the ResourceManager on long-running
        jobs (Default: 1).
    :param application_args: Arguments for the application being submitted (templated)
    :param env_vars: Environment variables for spark-submit. It supports yarn and k8s mode too. (templated)
    :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
    :param spark_binary: The command to use for spark submit.
        Some distros may use spark2-submit or spark3-submit.
        (will overwrite any spark_binary defined in the connection's extra JSON)
    :param properties_file: Path to a file from which to load extra properties. If not
        specified, this will look for conf/spark-defaults.conf.
    :param yarn_queue: The name of the YARN queue to which the application is submitted.
        (will overwrite any yarn queue defined in the connection's extra JSON)
    :param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client.
        (will overwrite any deployment mode defined in the connection's extra JSON)
    :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
        on keytab for Kerberos login
    :param post_submit_commands: Optional list of shell commands to run after the Spark job finishes.
        Useful for cleaning up sidecars such as Istio. Failures produce a warning but do not fail the task.
    :param reconnect_on_retry: Only relevant when the hook tracks driver status (Spark
        standalone master in cluster mode). If True, the job runs through
        ``execute_resumable`` (which, where ``ResumableJobMixin`` is available, persists
        the driver ID in task state). If False, the job is still submitted and polled to
        completion, but task-state persistence is skipped. Defaults to ``True``.
    :param yarn_track_via_rm_api: If True (when master is YARN and ``deploy_mode``
        is ``cluster``), release the ``spark-submit`` JVM once the application has
        been submitted to YARN, then poll the YARN ResourceManager REST API
        (``GET /ws/v1/cluster/apps/{appId}``) until the application reaches a
        final state. The polling interval is controlled by ``status_poll_interval``
        with a 10-second minimum. This frees the worker from holding the
        long-lived submit JVM. Requires the Spark connection's ``extra``
        JSON to set ``yarn_resourcemanager_webapp_address`` (e.g. ``http://rm:8088``).
        Cluster-side driver logs should be used after the switch to polling.
        Defaults to ``False``.
    :param yarn_rm_auth: Optional ``requests.auth.AuthBase`` instance used for every
        call to the YARN ResourceManager REST API (status polling and kill). When
        omitted, Kerberos-enabled Spark connections with both ``keytab`` and
        ``principal`` configured use ``requests-kerberos`` automatically.
        Defaults to ``None`` (no auth for non-Kerberos connections).
    :param openlineage_inject_parent_job_info: If True, inject OpenLineage parent job
        information into ``conf`` at execution time. Defaults to the
        ``[openlineage] spark_inject_parent_job_info`` option (False if unset).
    :param openlineage_inject_transport_info: If True, inject OpenLineage transport
        information into ``conf`` at execution time. Defaults to the
        ``[openlineage] spark_inject_transport_info`` option (False if unset).
    """

    # Generic key used across all Spark deployment modes (standalone driver ID,
    # YARN application ID, K8s driver pod name).
    external_id_key = "spark_job_id"

    template_fields: Sequence[str] = (
        "application",
        "conf",
        "files",
        "py_files",
        "jars",
        "driver_class_path",
        "packages",
        "exclude_packages",
        "keytab",
        "principal",
        "proxy_user",
        "name",
        "application_args",
        "env_vars",
        "post_submit_commands",
        "properties_file",
    )

    def __init__(
        self,
        *,
        application: str = "",
        conf: dict[Any, Any] | None = None,
        conn_id: str = "spark_default",
        files: str | None = None,
        py_files: str | None = None,
        archives: str | None = None,
        driver_class_path: str | None = None,
        jars: str | None = None,
        java_class: str | None = None,
        packages: str | None = None,
        exclude_packages: str | None = None,
        repositories: str | None = None,
        total_executor_cores: int | None = None,
        executor_cores: int | None = None,
        executor_memory: str | None = None,
        driver_memory: str | None = None,
        keytab: str | None = None,
        principal: str | None = None,
        proxy_user: str | None = None,
        name: str = "arrow-spark",
        num_executors: int | None = None,
        status_poll_interval: int = 1,
        application_args: list[Any] | None = None,
        env_vars: dict[str, Any] | None = None,
        verbose: bool = False,
        spark_binary: str | None = None,
        properties_file: str | None = None,
        yarn_queue: str | None = None,
        deploy_mode: str | None = None,
        use_krb5ccache: bool = False,
        post_submit_commands: list[str] | None = None,
        reconnect_on_retry: bool = True,
        yarn_track_via_rm_api: bool = False,
        yarn_rm_auth: AuthBase | None = None,
        openlineage_inject_parent_job_info: bool = conf.getboolean(
            "openlineage", "spark_inject_parent_job_info", fallback=False
        ),
        openlineage_inject_transport_info: bool = conf.getboolean(
            "openlineage", "spark_inject_transport_info", fallback=False
        ),
        **kwargs: Any,
    ) -> None:
        # Note: inside this body ``conf`` is the parameter; the ``conf.getboolean``
        # defaults above are evaluated at definition time against the module-level config.
        super().__init__(**kwargs)
        self.application = application
        # conf, files, jars, keytab and name are template fields and are read by
        # execute() / _get_hook(), so they must be stored as public attributes.
        self.conf = conf
        self.files = files
        self.py_files = py_files
        self._archives = archives
        self.driver_class_path = driver_class_path
        self.jars = jars
        self._java_class = java_class
        self.packages = packages
        self.exclude_packages = exclude_packages
        self._repositories = repositories
        self._total_executor_cores = total_executor_cores
        self._executor_cores = executor_cores
        self._executor_memory = executor_memory
        self._driver_memory = driver_memory
        self.keytab = keytab
        self.principal = principal
        self.proxy_user = proxy_user
        self.name = name
        self._num_executors = num_executors
        self._status_poll_interval = status_poll_interval
        self.application_args = application_args
        self.env_vars = env_vars
        self._verbose = verbose
        self._spark_binary = spark_binary
        self.properties_file = properties_file
        self._yarn_queue = yarn_queue
        self._deploy_mode = deploy_mode
        # Built lazily so it sees the rendered template fields and any injected conf.
        self._hook: SparkSubmitHook | None = None
        self.post_submit_commands = post_submit_commands
        self._post_submit_commands = list(post_submit_commands) if post_submit_commands else []
        self._conn_id = conn_id
        self._use_krb5ccache = use_krb5ccache
        self._yarn_track_via_rm_api = yarn_track_via_rm_api
        self._yarn_rm_auth = yarn_rm_auth
        self.reconnect_on_retry = reconnect_on_retry
        self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
        self._openlineage_inject_transport_info = openlineage_inject_transport_info

    def execute(self, context: Context) -> None:
        """
        Call the SparkSubmitHook to run the provided spark job.

        Enabled OpenLineage properties are injected into ``conf`` first, so a hook built
        afterwards (if one does not exist yet) receives them. When the hook tracks driver
        status, the driver is submitted and polled to completion, with task-state
        persistence only when ``reconnect_on_retry`` is True. Otherwise
        ``SparkSubmitHook.submit`` is called once.
        """
        self.conf = self.conf or {}
        if self._openlineage_inject_parent_job_info:
            self.log.debug("Injecting OpenLineage parent job information into Spark properties.")
            self.conf = inject_parent_job_information_into_spark_properties(self.conf, context)
        if self._openlineage_inject_transport_info:
            self.log.debug("Injecting OpenLineage transport information into Spark properties.")
            self.conf = inject_transport_information_into_spark_properties(self.conf, context)
        hook = self._ensure_hook()
        if hook._should_track_driver_status:
            if self.reconnect_on_retry:
                return self.execute_resumable(context)
            # reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence.
            driver_id = self.submit_job(context)
            self.poll_until_complete(driver_id, context)
            return self.get_job_result(driver_id, context)
        hook.submit(self.application)

    def submit_job(self, context: Context) -> str:
        """
        Submit the application through the hook and return the driver ID it reports.

        :raises RuntimeError: if the hook returns an empty/None driver ID.
        """
        driver_id = self._ensure_hook().submit(self.application)
        if not driver_id:
            raise RuntimeError("spark-submit did not return a driver ID")
        self.log.info("Spark driver submitted: %s", driver_id)
        return driver_id

    def get_job_status(self, external_id: JsonValue) -> str:
        """
        Return the driver state reported by the Spark standalone master REST API.

        Each master of an HA master URL is queried in order and the first successful
        answer is returned. If every master fails, the last error is re-raised.

        :param external_id: Driver ID produced by :meth:`submit_job`.
        :raises NotImplementedError: for YARN and Kubernetes masters (not implemented yet).
        """
        # external_id originates from submit_job, which always returns a str
        # (Spark driver IDs are strings).
        external_id = cast("str", external_id)
        hook = self._ensure_hook()
        # The YARN and K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete)
        # are currently unreachable: execute_resumable is only called when _should_track_driver_status
        # is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR
        # that extends ResumableJobMixin support to YARN and Kubernetes.
        if hook._is_yarn:
            # TODO: call YARN ResourceManager REST API
            # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
            raise NotImplementedError("YARN job status not yet implemented")
        if hook._is_kubernetes:
            # TODO: call K8s pod status API
            raise NotImplementedError("K8s job status not yet implemented")
        scheme = hook._connection.get("rest_scheme", "http")
        rest_port = hook._connection.get("rest_port", 6066)
        # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
        # The master URL port (e.g. 7077) is the RPC port — not the REST API port.
        # Use rest-port connection extra to override spark.master.rest.port (default 6066).
        master_urls = hook._connection["master"].replace("spark://", "").split(",")
        last_exc: Exception = RuntimeError("No Spark masters to query")
        for m in master_urls:
            host = m.strip().split(":")[0]
            url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
            try:
                return self._fetch_driver_status(url, external_id)
            except Exception as e:
                # Deliberate fail-over: log and move on to the next master.
                self.log.warning("Could not reach Spark master %s: %s", host, e)
                last_exc = e
        raise last_exc

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
    def _fetch_driver_status(self, url: str, external_id: str) -> str:
        """
        GET one master's submission-status endpoint and return its ``driverState``.

        Retried up to 3 attempts with a 1-second wait; the final error is re-raised.

        :raises RuntimeError: if the response body reports ``success`` as falsy.
        """
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # "success:false" means the master does not recognise the driver ID or is in recovery.
        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
        data = response.json()
        if not data.get("success"):
            raise RuntimeError(
                f"Spark REST API returned failure for {external_id}: {data.get('message', 'unknown error')}"
            )
        status = data["driverState"]
        self.log.info("Driver %s status: %s", external_id, status)
        return status

    def is_job_active(self, status: str) -> bool:
        """Return True if ``status`` (case-insensitive) denotes a job that has not finished yet."""
        hook = self._ensure_hook()
        status = status.upper()
        if hook._is_yarn:
            # https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
            return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING")
        if hook._is_kubernetes:
            return status in ("PENDING", "RUNNING")
        # RELAUNCHING: driver is being restarted after a failure, still alive.
        # UNKNOWN: master is in failure recovery, state is temporarily unavailable.
        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
        return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")

    def is_job_succeeded(self, status: str) -> bool:
        """Return True if ``status`` (case-insensitive) denotes successful completion."""
        hook = self._ensure_hook()
        status = status.upper()
        if hook._is_kubernetes:
            return status == "SUCCEEDED"
        # standalone and YARN both use FINISHED
        # NOTE(review): for YARN, an application ``state`` of FINISHED may still carry a
        # non-successful ``finalStatus``; revisit when the YARN branch becomes reachable.
        return status == "FINISHED"

    def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
        """
        Block until the standalone driver reaches a terminal state.

        Post-submit commands always run afterwards, whether the driver succeeded or not.

        :raises RuntimeError: if the driver ends in any state other than FINISHED.
        :raises NotImplementedError: for YARN and Kubernetes masters (not implemented yet).
        """
        # external_id originates from submit_job, which always returns a str
        # (Spark driver IDs are strings).
        external_id = cast("str", external_id)
        hook = self._ensure_hook()
        if hook._is_yarn:
            # TODO: poll YARN ResourceManager until app reaches terminal state
            raise NotImplementedError("YARN poll not yet implemented")
        if hook._is_kubernetes:
            # TODO: poll K8s pod phase until terminal
            raise NotImplementedError("K8s poll not yet implemented")
        self.log.info("Polling driver %s until completion", external_id)
        # Point the hook at this driver so its tracker can (re)attach, e.g. after a retry.
        hook._driver_id = external_id
        try:
            hook._start_driver_status_tracking()
            if hook._driver_status != "FINISHED":
                raise RuntimeError(f"Driver {external_id} exited with status {hook._driver_status}")
        finally:
            # post-submit commands must fire whether the job succeeded or failed.
            hook._run_post_submit_commands()

    def get_job_result(self, external_id: JsonValue, context: Context) -> None:
        """Spark jobs produce no XCom result; always returns None."""
        return None

    def on_kill(self) -> None:
        """Delegate task cancellation to the hook."""
        self._ensure_hook().on_kill()

    def _ensure_hook(self) -> SparkSubmitHook:
        """Return the cached hook, building it on first use."""
        if self._hook is None:
            self._hook = self._get_hook()
        return self._hook

    def _get_hook(self) -> SparkSubmitHook:
        """Build a new SparkSubmitHook from the operator's current (rendered) settings."""
        return SparkSubmitHook(
            conf=self.conf,
            conn_id=self._conn_id,
            files=self.files,
            py_files=self.py_files,
            archives=self._archives,
            driver_class_path=self.driver_class_path,
            jars=self.jars,
            java_class=self._java_class,
            packages=self.packages,
            exclude_packages=self.exclude_packages,
            repositories=self._repositories,
            total_executor_cores=self._total_executor_cores,
            executor_cores=self._executor_cores,
            executor_memory=self._executor_memory,
            driver_memory=self._driver_memory,
            keytab=self.keytab,
            principal=self.principal,
            proxy_user=self.proxy_user,
            name=self.name,
            num_executors=self._num_executors,
            status_poll_interval=self._status_poll_interval,
            application_args=self.application_args,
            env_vars=self.env_vars,
            verbose=self._verbose,
            spark_binary=self._spark_binary,
            properties_file=self.properties_file,
            yarn_queue=self._yarn_queue,
            deploy_mode=self._deploy_mode,
            use_krb5ccache=self._use_krb5ccache,
            post_submit_commands=self.post_submit_commands,
            yarn_track_via_rm_api=self._yarn_track_via_rm_api,
            yarn_rm_auth=self._yarn_rm_auth,
        )