Source code for airflow.providers.apache.spark.operators.spark_submit

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast

import requests
from tenacity import retry, stop_after_attempt, wait_fixed

from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
from airflow.providers.common.compat.openlineage.utils.spark import (
    inject_parent_job_information_into_spark_properties,
    inject_transport_information_into_spark_properties,
)
from airflow.providers.common.compat.sdk import BaseOperator, conf

try:
    from airflow.sdk.bases.resumablemixin import ResumableJobMixin
except ImportError:
    # Airflow 2 compat.
    # ResumableJobMixin does not exist in Airflow 2, so a minimal stand-in is
    # defined here that keeps the previous behaviour: every run submits a
    # fresh job and nothing is persisted in task state.
    class ResumableJobMixin:  # type: ignore[no-redef]
        """Airflow 2 stub — no task_state, always submits fresh."""

        # Name of the key the remote job identifier is associated with;
        # subclasses may override it.
        external_id_key: str = "remote_job_id"

        def execute_resumable(self, context):
            """
            Submit a job, block until it completes, and return its result.

            The subclass must provide ``submit_job``, ``poll_until_complete``
            and ``get_job_result``. They are called in that order, with the
            identifier returned by ``submit_job`` threaded through.

            :param context: The task execution context, forwarded unchanged.
            :return: Whatever ``get_job_result`` returns.
            """
            job_ref = self.submit_job(context)
            self.poll_until_complete(job_ref, context)
            outcome = self.get_job_result(job_ref, context)
            return outcome


if TYPE_CHECKING:
    from pydantic import JsonValue
    from requests.auth import AuthBase

    from airflow.providers.common.compat.sdk import Context
class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
    """
    Wrap the spark-submit binary to kick off a spark-submit job; requires "spark-submit" binary in the PATH.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SparkSubmitOperator`

    :param application: The application submitted as a job, either jar or py file. (templated)
    :param conf: Arbitrary Spark configuration properties (templated)
    :param conn_id: The :ref:`spark connection id <howto/connection:spark>` as configured
        in Airflow administration. When an invalid connection_id is supplied, it will default to yarn.
    :param files: Upload additional files to the executor running the job, separated by a
        comma. Files will be placed in the working directory of each executor.
        For example, serialized objects. (templated)
    :param py_files: Additional python files used by the job, can be .zip, .egg or .py. (templated)
    :param archives: Archives forwarded unchanged to ``SparkSubmitHook(archives=...)``;
        presumably rendered as spark-submit's ``--archives`` option (confirm against the hook).
        Not templated.
    :param jars: Submit additional jars to upload and place them in driver and executor classpaths. (templated)
    :param driver_class_path: Additional, driver-specific, classpath settings. (templated)
    :param java_class: the main class of the Java application
    :param packages: Comma-separated list of maven coordinates of jars to include on the
        driver and executor classpaths. (templated)
    :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude
        while resolving the dependencies provided in 'packages' (templated)
    :param repositories: Comma-separated list of additional remote repositories to search
        for the maven coordinates given with 'packages'
    :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors
        (Default: all the available cores on the worker)
    :param executor_cores: (Standalone & YARN only) Number of cores per executor (Default: 2)
    :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
    :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G)
    :param keytab: Full path to the file that contains the keytab (templated)
        (will overwrite any keytab defined in the connection's extra JSON)
    :param principal: The name of the kerberos principal used for keytab (templated)
        (will overwrite any principal defined in the connection's extra JSON)
    :param proxy_user: User to impersonate when submitting the application (templated)
    :param name: Name of the job (default arrow-spark). (templated)
    :param num_executors: Number of executors to launch
    :param status_poll_interval: Seconds to wait between polls of driver status in cluster
        mode. Used both by the Spark standalone driver-status tracker and (when
        ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API polling loop.
        The YARN ResourceManager REST API polling loop uses at least 10 seconds to avoid
        flooding the ResourceManager on long-running jobs (Default: 1).
    :param application_args: Arguments for the application being submitted (templated)
    :param env_vars: Environment variables for spark-submit. It supports yarn and k8s mode too. (templated)
    :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
    :param spark_binary: The command to use for spark submit.
        Some distros may use spark2-submit or spark3-submit.
        (will overwrite any spark_binary defined in the connection's extra JSON)
    :param properties_file: Path to a file from which to load extra properties. If not
        specified, this will look for conf/spark-defaults.conf.
    :param yarn_queue: The name of the YARN queue to which the application is submitted.
        (will overwrite any yarn queue defined in the connection's extra JSON)
    :param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client.
        (will overwrite any deployment mode defined in the connection's extra JSON)
    :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
        on keytab for Kerberos login
    :param post_submit_commands: Optional list of shell commands to run after the Spark job
        finishes. Useful for cleaning up sidecars such as Istio. Failures produce a warning
        but do not fail the task.
    :param reconnect_on_retry: Only relevant when the hook tracks driver status. If True
        (default), execution goes through ``ResumableJobMixin.execute_resumable``. If False,
        the driver is submitted and polled directly, without task-state persistence.
    :param yarn_track_via_rm_api: If True (when master is YARN and ``deploy_mode`` is
        ``cluster``), release the ``spark-submit`` JVM once the application has been
        submitted to YARN, then poll the YARN ResourceManager REST API
        (``GET /ws/v1/cluster/apps/{appId}``) until the application reaches a final state.
        The polling interval is controlled by ``status_poll_interval`` with a 10-second
        minimum. This frees the worker from holding the long-lived submit JVM. Requires the
        Spark connection's ``extra`` JSON to set ``yarn_resourcemanager_webapp_address``
        (e.g. ``http://rm:8088``). Cluster-side driver logs should be used after the switch
        to polling. Defaults to ``False``.
    :param yarn_rm_auth: Optional ``requests.auth.AuthBase`` instance used for every call to
        the YARN ResourceManager REST API (status polling and kill). When omitted,
        Kerberos-enabled Spark connections with both ``keytab`` and ``principal`` configured
        use ``requests-kerberos`` automatically. Defaults to ``None`` (no auth for
        non-Kerberos connections).
    :param openlineage_inject_parent_job_info: If True, inject OpenLineage parent job
        information into ``conf`` before submission. Defaults to the
        ``[openlineage] spark_inject_parent_job_info`` config option (read when this module
        is imported).
    :param openlineage_inject_transport_info: If True, inject OpenLineage transport
        information into ``conf`` before submission. Defaults to the
        ``[openlineage] spark_inject_transport_info`` config option (read when this module
        is imported).
    """

    # Generic key used across all Spark deployment modes (standalone driver ID,
    # YARN application ID, K8s driver pod name).
    external_id_key = "spark_job_id"

    template_fields: Sequence[str] = (
        "application",
        "conf",
        "files",
        "py_files",
        "jars",
        "driver_class_path",
        "packages",
        "exclude_packages",
        "keytab",
        "principal",
        "proxy_user",
        "name",
        "application_args",
        "env_vars",
        "post_submit_commands",
        "properties_file",
    )

    # NOTE: inside __init__ the ``conf`` parameter shadows the module-level Airflow ``conf``.
    # The OpenLineage defaults below are evaluated once, at function-definition time, where
    # ``conf`` still refers to the module-level configuration object.
    def __init__(
        self,
        *,
        application: str = "",
        conf: dict[Any, Any] | None = None,
        conn_id: str = "spark_default",
        files: str | None = None,
        py_files: str | None = None,
        archives: str | None = None,
        driver_class_path: str | None = None,
        jars: str | None = None,
        java_class: str | None = None,
        packages: str | None = None,
        exclude_packages: str | None = None,
        repositories: str | None = None,
        total_executor_cores: int | None = None,
        executor_cores: int | None = None,
        executor_memory: str | None = None,
        driver_memory: str | None = None,
        keytab: str | None = None,
        principal: str | None = None,
        proxy_user: str | None = None,
        name: str = "arrow-spark",
        num_executors: int | None = None,
        status_poll_interval: int = 1,
        application_args: list[Any] | None = None,
        env_vars: dict[str, Any] | None = None,
        verbose: bool = False,
        spark_binary: str | None = None,
        properties_file: str | None = None,
        yarn_queue: str | None = None,
        deploy_mode: str | None = None,
        use_krb5ccache: bool = False,
        post_submit_commands: list[str] | None = None,
        reconnect_on_retry: bool = True,
        yarn_track_via_rm_api: bool = False,
        yarn_rm_auth: AuthBase | None = None,
        openlineage_inject_parent_job_info: bool = conf.getboolean(
            "openlineage", "spark_inject_parent_job_info", fallback=False
        ),
        openlineage_inject_transport_info: bool = conf.getboolean(
            "openlineage", "spark_inject_transport_info", fallback=False
        ),
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.application = application
        self.conf = conf
        self.files = files
        self.py_files = py_files
        self._archives = archives
        self.driver_class_path = driver_class_path
        self.jars = jars
        self._java_class = java_class
        self.packages = packages
        self.exclude_packages = exclude_packages
        self._repositories = repositories
        self._total_executor_cores = total_executor_cores
        self._executor_cores = executor_cores
        self._executor_memory = executor_memory
        self._driver_memory = driver_memory
        self.keytab = keytab
        self.principal = principal
        self.proxy_user = proxy_user
        self.name = name
        self._num_executors = num_executors
        self._status_poll_interval = status_poll_interval
        self.application_args = application_args
        self.env_vars = env_vars
        self._verbose = verbose
        self._spark_binary = spark_binary
        self.properties_file = properties_file
        self._yarn_queue = yarn_queue
        self._deploy_mode = deploy_mode
        # Created lazily (see _get_hook) so that templated fields are rendered first.
        self._hook: SparkSubmitHook | None = None
        self.post_submit_commands = post_submit_commands
        # Copy taken before template rendering; the hook is built from the (rendered)
        # public ``post_submit_commands`` attribute, not from this copy.
        self._post_submit_commands = list(post_submit_commands) if post_submit_commands else []
        self._conn_id = conn_id
        self._use_krb5ccache = use_krb5ccache
        self._yarn_track_via_rm_api = yarn_track_via_rm_api
        self._yarn_rm_auth = yarn_rm_auth
        self.reconnect_on_retry = reconnect_on_retry
        self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
        self._openlineage_inject_transport_info = openlineage_inject_transport_info

    def execute(self, context: Context) -> None:
        """Call the SparkSubmitHook to run the provided spark job."""
        self.conf = self.conf or {}
        if self._openlineage_inject_parent_job_info:
            self.log.debug("Injecting OpenLineage parent job information into Spark properties.")
            self.conf = inject_parent_job_information_into_spark_properties(self.conf, context)
        if self._openlineage_inject_transport_info:
            self.log.debug("Injecting OpenLineage transport information into Spark properties.")
            self.conf = inject_transport_information_into_spark_properties(self.conf, context)
        # A hook created here is built after the OpenLineage injection above, so it
        # receives the updated conf.
        if self._hook is None:
            self._hook = self._get_hook()
        hook = self._hook
        if hook._should_track_driver_status:
            if self.reconnect_on_retry:
                return self.execute_resumable(context)
            # reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence.
            driver_id = self.submit_job(context)
            self.poll_until_complete(driver_id, context)
            return self.get_job_result(driver_id, context)
        hook.submit(self.application)

    def submit_job(self, context: Context) -> str:
        """
        Submit the application via the hook and return the driver ID.

        :raises RuntimeError: if the hook returns no driver ID.
        """
        if self._hook is None:
            self._hook = self._get_hook()
        driver_id = self._hook.submit(self.application)
        if not driver_id:
            raise RuntimeError("spark-submit did not return a driver ID")
        self.log.info("Spark driver submitted: %s", driver_id)
        return driver_id

    def get_job_status(self, external_id: JsonValue) -> str:
        """
        Return the current driver state reported by the Spark standalone REST API.

        Each master in a (possibly comma-separated, HA) ``spark://`` URL is tried in order.
        The first successful answer wins.

        :raises NotImplementedError: for YARN and Kubernetes masters.
        :raises Exception: the last error seen if no master could be queried.
        """
        # external_id comes from submit_job, which always returns a str (Spark driver IDs are strings).
        external_id = cast("str", external_id)
        if self._hook is None:
            self._hook = self._get_hook()
        # The YARN and K8s branches below (and in is_job_active, is_job_succeeded, poll_until_complete)
        # are currently unreachable: execute_resumable is only called when _should_track_driver_status
        # is True, which requires spark:// + cluster mode. They are scaffolding for a follow-up PR
        # that extends ResumableJobMixin support to YARN and Kubernetes.
        if self._hook._is_yarn:
            # TODO: call YARN ResourceManager REST API
            # GET http://rm:8088/ws/v1/cluster/apps/{external_id}
            raise NotImplementedError("YARN job status not yet implemented")
        if self._hook._is_kubernetes:
            # TODO: call K8s pod status API
            raise NotImplementedError("K8s job status not yet implemented")

        scheme = self._hook._connection.get("rest_scheme", "http")
        rest_port = self._hook._connection.get("rest_port", 6066)
        # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
        # The master URL port (e.g. 7077) is the RPC port — not the REST API port.
        # Use rest-port connection extra to override spark.master.rest.port (default 6066).
        # Blank entries (e.g. from a trailing comma) are skipped. Otherwise they would
        # produce a host-less URL whose error could mask the real failure.
        master_entries = [
            entry.strip()
            for entry in self._hook._connection["master"].replace("spark://", "").split(",")
            if entry.strip()
        ]
        last_exc: Exception = RuntimeError("No Spark masters to query")
        for entry in master_entries:
            host = entry.split(":")[0]
            url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
            try:
                return self._fetch_driver_status(url, external_id)
            except Exception as e:
                self.log.warning("Could not reach Spark master %s: %s", host, e)
                last_exc = e
        raise last_exc

    @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
    def _fetch_driver_status(self, url: str, external_id: str) -> str:
        """Query one Spark master's REST status endpoint (up to 3 attempts, 1s apart)."""
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        # "success:false" means the master does not recognise the driver ID or is in recovery.
        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
        data = response.json()
        if not data.get("success"):
            raise RuntimeError(
                f"Spark REST API returned failure for {external_id}: {data.get('message', 'unknown error')}"
            )
        status = data["driverState"]
        self.log.info("Driver %s status: %s", external_id, status)
        return status

    def is_job_active(self, status: str) -> bool:
        """Return True if ``status`` (case-insensitive) is a non-terminal state for this master type."""
        if self._hook is None:
            self._hook = self._get_hook()
        status = status.upper()
        if self._hook._is_yarn:
            # https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
            return status in ("NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING")
        if self._hook._is_kubernetes:
            return status in ("PENDING", "RUNNING")
        # RELAUNCHING: driver is being restarted after a failure, still alive.
        # UNKNOWN: master is in failure recovery, state is temporarily unavailable.
        # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
        return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")

    def is_job_succeeded(self, status: str) -> bool:
        """Return True if ``status`` (case-insensitive) is the success state for this master type."""
        if self._hook is None:
            self._hook = self._get_hook()
        status = status.upper()
        if self._hook._is_kubernetes:
            return status == "SUCCEEDED"
        # standalone and YARN both use FINISHED
        return status == "FINISHED"

    def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
        """
        Block until the driver ``external_id`` terminates, using the hook's status tracker.

        Post-submit commands always run afterwards, whether the driver succeeded or not.

        :raises RuntimeError: if the driver ends in any state other than FINISHED.
        :raises NotImplementedError: for YARN and Kubernetes masters.
        """
        # external_id comes from submit_job, which always returns a str (Spark driver IDs are strings).
        external_id = cast("str", external_id)
        if self._hook is None:
            self._hook = self._get_hook()
        if self._hook._is_yarn:
            # TODO: poll YARN ResourceManager until app reaches terminal state
            raise NotImplementedError("YARN poll not yet implemented")
        if self._hook._is_kubernetes:
            # TODO: poll K8s pod phase until terminal
            raise NotImplementedError("K8s poll not yet implemented")
        self.log.info("Polling driver %s until completion", external_id)
        # The hook's tracker reads the driver ID from this attribute.
        self._hook._driver_id = external_id
        try:
            self._hook._start_driver_status_tracking()
            if self._hook._driver_status != "FINISHED":
                raise RuntimeError(f"Driver {external_id} exited with status {self._hook._driver_status}")
        finally:
            # post-submit commands must fire whether the job succeeded or failed.
            self._hook._run_post_submit_commands()

    def get_job_result(self, external_id: JsonValue, context: Context) -> None:
        """Spark jobs produce no XCom-able result; always returns None."""
        return None

    def on_kill(self) -> None:
        """Delegate task kill to the hook, creating it first if necessary."""
        if self._hook is None:
            self._hook = self._get_hook()
        self._hook.on_kill()

    def _get_hook(self) -> SparkSubmitHook:
        """Build a SparkSubmitHook from the operator's current (rendered) attributes."""
        return SparkSubmitHook(
            conf=self.conf,
            conn_id=self._conn_id,
            files=self.files,
            py_files=self.py_files,
            archives=self._archives,
            driver_class_path=self.driver_class_path,
            jars=self.jars,
            java_class=self._java_class,
            packages=self.packages,
            exclude_packages=self.exclude_packages,
            repositories=self._repositories,
            total_executor_cores=self._total_executor_cores,
            executor_cores=self._executor_cores,
            executor_memory=self._executor_memory,
            driver_memory=self._driver_memory,
            keytab=self.keytab,
            principal=self.principal,
            proxy_user=self.proxy_user,
            name=self.name,
            num_executors=self._num_executors,
            status_poll_interval=self._status_poll_interval,
            application_args=self.application_args,
            env_vars=self.env_vars,
            verbose=self._verbose,
            spark_binary=self._spark_binary,
            properties_file=self.properties_file,
            yarn_queue=self._yarn_queue,
            deploy_mode=self._deploy_mode,
            use_krb5ccache=self._use_krb5ccache,
            post_submit_commands=self.post_submit_commands,
            yarn_track_via_rm_api=self._yarn_track_via_rm_api,
            yarn_rm_auth=self._yarn_rm_auth,
        )

Was this entry helpful?