#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any
from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
from airflow.exceptions import AirflowException
from airflow.providers.cncf.kubernetes import pod_generator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook, _load_body_to_dict
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import add_unique_suffix
from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import CustomObjectLauncher
from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.utils.helpers import prune_dict
if TYPE_CHECKING:
import jinja2
from airflow.utils.context import Context
class SparkKubernetesOperator(KubernetesPodOperator):
    """
    Creates sparkApplication object in kubernetes cluster.

    .. seealso::
        For more detail about Spark Application Object have a look at the reference:
        https://github.com/GoogleCloudPlatform/spark-on-k8s-operator/blob/v1beta2-1.3.3-3.1.1/docs/api-docs.md#sparkapplication

    :param image: Docker image you wish to launch. Defaults to hub.docker.com,
    :param code_path: path to the spark code in image,
    :param namespace: kubernetes namespace to put sparkApplication
    :param name: name of the pod in which the task will run, will be used (plus a random
        suffix if random_name_suffix is True) to generate a pod id (DNS-1123 subdomain,
        containing only [a-z0-9.-]).
    :param application_file: filepath to kubernetes custom_resource_definition of sparkApplication
    :param template_spec: kubernetes sparkApplication specification
    :param get_logs: get the stdout of the container as logs of the tasks.
    :param do_xcom_push: If True, the content of the file
        /airflow/xcom/return.json in the container will also be pushed to an
        XCom when the container completes.
    :param success_run_history_limit: Number of past successful runs of the application to keep.
    :param startup_timeout_seconds: timeout in seconds to startup the pod.
    :param log_events_on_failure: Log the pod's events if a failure occurs
    :param reattach_on_restart: if the scheduler dies while the pod is running, reattach and monitor
    :param delete_on_termination: What to do when the pod reaches its final
        state, or the execution is interrupted. If True (default), delete the
        pod; if False, leave the pod.
    :param kubernetes_conn_id: the connection to Kubernetes cluster
    """

    template_fields = ["application_file", "namespace", "template_spec", "kubernetes_conn_id"]
    template_fields_renderers = {"template_spec": "py"}
    template_ext = ("yaml", "yml", "json")

    # Container name forced onto ``base_container_name`` / ``container_logs`` in ``__init__``.
    BASE_CONTAINER_NAME = "spark-kubernetes-driver"

    # Suffix stripped from the driver pod name to obtain the spark job name.
    _DRIVER_POD_SUFFIX = "-driver"

    def __init__(
        self,
        *,
        image: str | None = None,
        code_path: str | None = None,
        namespace: str = "default",
        name: str | None = None,
        application_file: str | None = None,
        template_spec=None,
        get_logs: bool = True,
        do_xcom_push: bool = False,
        success_run_history_limit: int = 1,
        startup_timeout_seconds=600,
        log_events_on_failure: bool = False,
        reattach_on_restart: bool = True,
        delete_on_termination: bool = True,
        kubernetes_conn_id: str = "kubernetes_default",
        **kwargs,
    ) -> None:
        # Reject the removed ``xcom_push`` keyword before it reaches the parent constructor.
        if kwargs.get("xcom_push") is not None:
            raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
        super().__init__(name=name, **kwargs)

        # These assignments intentionally run after ``super().__init__`` so the values
        # given to this operator win over whatever the parent constructor stored.
        self.image = image
        self.code_path = code_path
        self.application_file = application_file
        self.template_spec = template_spec
        self.kubernetes_conn_id = kubernetes_conn_id
        self.startup_timeout_seconds = startup_timeout_seconds
        self.reattach_on_restart = reattach_on_restart
        self.delete_on_termination = delete_on_termination
        self.do_xcom_push = do_xcom_push
        self.namespace = namespace
        self.get_logs = get_logs
        self.log_events_on_failure = log_events_on_failure
        self.success_run_history_limit = success_run_history_limit

        # The container name is fixed for this operator; warn when a different value is replaced.
        if self.base_container_name != self.BASE_CONTAINER_NAME:
            self.log.warning(
                "base_container_name is not supported and will be overridden to %s", self.BASE_CONTAINER_NAME
            )
            self.base_container_name = self.BASE_CONTAINER_NAME

        if self.get_logs and self.container_logs != self.BASE_CONTAINER_NAME:
            self.log.warning(
                "container_logs is not supported and will be overridden to %s", self.BASE_CONTAINER_NAME
            )
            self.container_logs = [self.BASE_CONTAINER_NAME]

    def _render_nested_template_fields(
        self,
        content: Any,
        context: Mapping[str, Any],
        jinja_env: jinja2.Environment,
        seen_oids: set,
    ) -> None:
        """
        Render template fields nested inside ``content``.

        ``V1EnvVar`` objects get their ``value`` and ``name`` attributes rendered here;
        everything else (and any object already recorded in ``seen_oids``) is delegated
        to the parent implementation.
        """
        if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar):
            seen_oids.add(id(content))
            self._do_render_template_fields(content, ("value", "name"), context, jinja_env, seen_oids)
            return

        super()._render_nested_template_fields(content, context, jinja_env, seen_oids)

    def manage_template_specs(self):
        """
        Build the template body from ``application_file`` or ``template_spec``.

        ``application_file`` is first tried as a path to an existing file; when it cannot
        be resolved, the string itself is used as the body. The result is wrapped under a
        ``"spark"`` key when that key is not already present.

        :return: the template body, always containing a ``"spark"`` key.
        :raises TypeError: if the application file body does not load into a dict.
        :raises AirflowException: if neither ``application_file`` nor ``template_spec`` is set.
        """
        if self.application_file:
            try:
                filepath = Path(self.application_file.rstrip()).resolve(strict=True)
            except (OSError, RuntimeError, ValueError):
                # Not a resolvable path (FileNotFoundError is an OSError): treat the
                # value as inline content rather than a file name.
                application_file_body = self.application_file
            else:
                application_file_body = filepath.read_text()
            template_body = _load_body_to_dict(application_file_body)
            if not isinstance(template_body, dict):
                msg = f"application_file body can't transformed into the dictionary:\n{application_file_body}"
                raise TypeError(msg)
        elif self.template_spec:
            template_body = self.template_spec
        else:
            raise AirflowException("either application_file or template_spec should be passed")
        if "spark" not in template_body:
            template_body = {"spark": template_body}
        return template_body

    def create_job_name(self):
        """
        Derive a unique job name and pass it through ``_set_name``.

        The base name is the first truthy value among ``self.name``, the
        ``spark.metadata.name`` entry of the template body, and ``self.task_id``.
        """
        name = (
            self.name or self.template_body.get("spark", {}).get("metadata", {}).get("name") or self.task_id
        )
        updated_name = add_unique_suffix(name=name, max_len=MAX_LABEL_LEN)
        return self._set_name(updated_name)

    @staticmethod
    def _get_pod_identifying_label_string(labels) -> str:
        """
        Return ``labels`` as a comma-separated ``key=value`` selector string.

        The ``try_number`` label is left out and the remaining pairs are sorted by key.
        """
        filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"}
        return ",".join(f"{label_id}={label}" for label_id, label in sorted(filtered_labels.items()))

    @staticmethod
    def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
        """
        Generate labels for the pod to track the pod in case of Operator crash.

        :param include_try_number: add try number to labels
        :param context: task context provided by airflow DAG
        :return: dict of label name to label value; empty when no context is given.
        """
        if not context:
            return {}

        ti = context["ti"]
        run_id = context["run_id"]

        labels = {
            "dag_id": ti.dag_id,
            "task_id": ti.task_id,
            "run_id": run_id,
            "spark_kubernetes_operator": "True",
        }

        # ``map_index`` may be absent on the task instance; only non-negative values are recorded.
        map_index = getattr(ti, "map_index", -1)
        if map_index >= 0:
            labels["map_index"] = map_index

        if include_try_number:
            labels.update(try_number=ti.try_number)

        # In the case of sub dags this is just useful
        # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
        if getattr(context["dag"], "is_subdag", False):
            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id

        # Every value is stringified and passed through ``make_safe_label_value`` so the
        # returned mapping only holds values that helper has processed.
        return {
            label_id: pod_generator.make_safe_label_value(str(label)) for label_id, label in labels.items()
        }

    @cached_property
    def pod_manager(self) -> PodManager:
        """Pod manager bound to this operator's kubernetes client."""
        return PodManager(kube_client=self.client)

    @staticmethod
    def _try_numbers_match(context, pod) -> bool:
        """
        Tell whether the pod's ``try_number`` label equals the task instance try number.

        Both sides are compared as strings: the label is stored as text while
        ``ti.try_number`` is numeric, so a direct ``==`` would never be true.
        """
        return str(pod.metadata.labels["try_number"]) == str(context["ti"].try_number)

    @property
    def template_body(self):
        """Templated body for CustomObjectLauncher."""
        return self.manage_template_specs()

    def find_spark_job(self, context):
        """
        Look for an existing driver pod matching this task's labels.

        :param context: task context used to build the label selector.
        :return: the single matching pod, or ``None`` when there is none.
        :raises AirflowException: if more than one pod matches the selector.
        """
        labels = self.create_labels_for_pod(context, include_try_number=False)
        label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
        pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items

        pod = None
        if len(pod_list) > 1:  # and self.reattach_on_restart:
            raise AirflowException(f"More than one pod running with labels: {label_selector}")
        elif len(pod_list) == 1:
            pod = pod_list[0]
            self.log.info(
                "Found matching driver pod %s with labels %s", pod.metadata.name, pod.metadata.labels
            )
            self.log.info("`try_number` of task_instance: %s", context["ti"].try_number)
            # The selector above omits ``try_number``, so the matched pod is not guaranteed
            # to carry that label; a log statement must not fail the lookup.
            self.log.info("`try_number` of pod: %s", pod.metadata.labels.get("try_number"))
        return pod

    def get_or_create_spark_crd(self, launcher: CustomObjectLauncher, context) -> k8s.V1Pod:
        """
        Return the driver pod, reattaching to an existing one when allowed.

        When ``reattach_on_restart`` is set and ``find_spark_job`` returns a pod, that pod
        is returned; otherwise a new spark job is started through ``launcher``.
        """
        if self.reattach_on_restart:
            driver_pod = self.find_spark_job(context)
            if driver_pod:
                return driver_pod

        # ``start_spark_job`` also returns the spark object spec, which is not needed here.
        driver_pod, _ = launcher.start_spark_job(
            image=self.image, code_path=self.code_path, startup_timeout=self.startup_timeout_seconds
        )
        return driver_pod

    @classmethod
    def _spark_job_name_from_pod(cls, pod) -> str:
        """Return the pod name with a single trailing ``-driver`` suffix removed."""
        pod_name = pod.metadata.name
        if pod_name.endswith(cls._DRIVER_POD_SUFFIX):
            return pod_name[: -len(cls._DRIVER_POD_SUFFIX)]
        return pod_name

    def process_pod_deletion(self, pod, *, reraise=True):
        """
        Delete the spark job behind ``pod`` when ``delete_on_termination`` is set.

        Only the trailing ``-driver`` suffix is stripped to obtain the job name, so a job
        whose own name contains ``-driver`` is still addressed correctly.

        :param pod: driver pod of the job, or ``None`` (in which case nothing happens).
        :param reraise: accepted for signature compatibility; not used by this method.
        """
        if pod is None:
            return
        if self.delete_on_termination:
            spark_job_name = self._spark_job_name_from_pod(pod)
            self.log.info("Deleting spark job: %s", spark_job_name)
            self.launcher.delete_spark_job(spark_job_name)
        else:
            self.log.info("skipping deleting spark job: %s", pod.metadata.name)

    @cached_property
    def hook(self) -> KubernetesHook:
        """
        Kubernetes hook for ``kubernetes_conn_id``.

        For ``in_cluster``, ``config_file`` and ``cluster_context`` the operator attribute
        is used when truthy, otherwise the ``kubernetes`` section of the template body.
        """
        kube_settings = self.template_body.get("kubernetes", {})
        return KubernetesHook(
            conn_id=self.kubernetes_conn_id,
            in_cluster=self.in_cluster or kube_settings.get("in_cluster", False),
            config_file=self.config_file or kube_settings.get("kube_config_file", None),
            cluster_context=self.cluster_context or kube_settings.get("cluster_context", None),
        )

    @cached_property
    def client(self) -> CoreV1Api:
        """Core V1 API client taken from the hook."""
        return self.hook.core_v1_client

    @cached_property
    def custom_obj_api(self) -> CustomObjectsApi:
        """Custom objects API client built with the library's default configuration."""
        return CustomObjectsApi()

    def execute(self, context: Context):
        """
        Create the sparkApplication, then hand over to the parent ``execute``.

        Sets ``self.name``, ``self.launcher``, ``self.pod`` and ``self.pod_request_obj``
        before delegating.
        """
        self.name = self.create_job_name()

        self.log.info("Creating sparkApplication.")
        self.launcher = CustomObjectLauncher(
            name=self.name,
            namespace=self.namespace,
            kube_client=self.client,
            custom_obj_api=self.custom_obj_api,
            template_body=self.template_body,
        )
        self.pod = self.get_or_create_spark_crd(self.launcher, context)
        self.pod_request_obj = self.launcher.pod_spec

        return super().execute(context=context)

    def on_kill(self) -> None:
        """Delete the spark job if a launcher was created by ``execute``."""
        # ``launcher`` only exists once ``execute`` has run; a kill that arrives earlier
        # must be a no-op instead of raising AttributeError.
        launcher = getattr(self, "launcher", None)
        if launcher:
            self.log.debug("Deleting spark job for task %s", self.task_id)
            launcher.delete_spark_job()

    def patch_already_checked(self, pod: k8s.V1Pod, *, reraise=True):
        """
        Add an "already_checked" label to the pod and patch it in the cluster.

        :param pod: pod to label and patch.
        :param reraise: accepted for signature compatibility; not used by this method.
        """
        pod.metadata.labels["already_checked"] = "True"
        body = PodGenerator.serialize_pod(pod)
        self.client.patch_namespaced_pod(pod.metadata.name, pod.metadata.namespace, body)

    def dry_run(self) -> None:
        """Print out the spark job that would be created by this operator."""
        # NOTE(review): reads ``self.launcher``, which this class only sets in ``execute``.
        print(prune_dict(self.launcher.body, mode="strict"))