Source code for airflow.providers.alibaba.cloud.operators.analyticdb_spark

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import time
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.alibaba.cloud.hooks.analyticdb_spark import AnalyticDBSparkHook, AppState

if TYPE_CHECKING:
    from airflow.utils.context import Context


class AnalyticDBSparkBaseOperator(BaseOperator):
    """Abstract base class that defines how users develop AnalyticDB Spark."""

    def __init__(
        self,
        *,
        adb_spark_conn_id: str = "adb_spark_default",
        region: str | None = None,
        polling_interval: int = 0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        self.app_id: str | None = None
        self.polling_interval = polling_interval

        self._adb_spark_conn_id = adb_spark_conn_id
        self._region = region

    @cached_property
    def hook(self) -> AnalyticDBSparkHook:
        """Get valid hook."""
        return AnalyticDBSparkHook(adb_spark_conn_id=self._adb_spark_conn_id, region=self._region)

    def execute(self, context: Context) -> Any: ...

    def monitor_application(self):
        self.log.info("Monitoring application with %s", self.app_id)

        if self.polling_interval > 0:
            self.poll_for_termination(self.app_id)

    def poll_for_termination(self, app_id: str) -> None:
        """
        Poll for Spark application termination.

        :param app_id: id of the spark application to monitor
        """
        state = self.hook.get_spark_state(app_id)
        while AppState(state) not in AnalyticDBSparkHook.TERMINAL_STATES:
            self.log.debug("Application with id %s is in state: %s", app_id, state)
            time.sleep(self.polling_interval)
            state = self.hook.get_spark_state(app_id)
        self.log.info("Application with id %s terminated with state: %s", app_id, state)
        self.log.info(
            "Web ui address is %s for application with id %s",
            self.hook.get_spark_web_ui_address(app_id),
            app_id,
        )
        self.log.info(self.hook.get_spark_log(app_id))
        if AppState(state) != AppState.COMPLETED:
            raise AirflowException(f"Application {app_id} did not succeed")

    def on_kill(self) -> None:
        self.kill()

    def kill(self) -> None:
        """Delete the specified application."""
        if self.app_id is not None:
            self.hook.kill_spark_app(self.app_id)
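

# Illustrative sketch, not part of the module source: a hypothetical subclass showing the
# contract the base class expects. execute() submits an application through the hook,
# stores the returned id in self.app_id, and calls monitor_application() so the base class
# can poll for termination (when polling_interval > 0); on_kill() then stops the remote
# application via kill(). The cluster id and resource group name below are placeholders.
class _ExampleAnalyticDBSparkOperator(AnalyticDBSparkBaseOperator):
    def execute(self, context: Context) -> Any:
        submit_response = self.hook.submit_spark_sql(
            cluster_id="amv-example-cluster-id",  # placeholder AnalyticDB MySQL cluster id
            rg_name="spark",  # placeholder resource group name
            sql="SELECT 1",
        )
        self.app_id = submit_response.body.data.app_id
        self.monitor_application()
        return self.app_id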


class AnalyticDBSparkSQLOperator(AnalyticDBSparkBaseOperator):
    """
    Submits a Spark SQL application to the underlying cluster; wraps the AnalyticDB Spark REST API.

    :param sql: The SQL query to execute.
    :param conf: Spark configuration properties.
    :param driver_resource_spec: The resource specifications of the Spark driver.
    :param executor_resource_spec: The resource specifications of each Spark executor.
    :param num_executors: Number of executors to launch for this application.
    :param name: Name of this application.
    :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse.
    :param rg_name: The name of the resource group in the AnalyticDB MySQL 3.0 Data Lakehouse cluster.
    """

    template_fields: Sequence[str] = ("spark_params",)
    template_fields_renderers = {"spark_params": "json"}

    def __init__(
        self,
        *,
        sql: str,
        conf: dict[Any, Any] | None = None,
        driver_resource_spec: str | None = None,
        executor_resource_spec: str | None = None,
        num_executors: int | str | None = None,
        name: str | None = None,
        cluster_id: str,
        rg_name: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        spark_params = {
            "sql": sql,
            "conf": conf,
            "driver_resource_spec": driver_resource_spec,
            "executor_resource_spec": executor_resource_spec,
            "num_executors": num_executors,
            "name": name,
        }
        self.spark_params = spark_params

        self._cluster_id = cluster_id
        self._rg_name = rg_name

    def execute(self, context: Context) -> Any:
        submit_response = self.hook.submit_spark_sql(
            cluster_id=self._cluster_id, rg_name=self._rg_name, **self.spark_params
        )
        self.app_id = submit_response.body.data.app_id
        self.monitor_application()
        return self.app_id
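

# Illustrative usage sketch, not part of the module source: running the SQL operator from
# a DAG. The DAG id, schedule, cluster id, and resource group name below are placeholder
# assumptions, and the default "adb_spark_default" connection is assumed to be configured.
from datetime import datetime

from airflow import DAG

with DAG(dag_id="adb_spark_sql_example", start_date=datetime(2024, 1, 1), schedule=None):
    spark_sql = AnalyticDBSparkSQLOperator(
        task_id="spark_sql",
        sql="SHOW DATABASES;",
        cluster_id="amv-example-cluster-id",  # placeholder AnalyticDB MySQL cluster id
        rg_name="spark",  # placeholder resource group name
        polling_interval=5,  # poll every 5 seconds until the application terminates
    )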


class AnalyticDBSparkBatchOperator(AnalyticDBSparkBaseOperator):
    """
    Submits a Spark batch application to the underlying cluster; wraps the AnalyticDB Spark REST API.

    :param file: Path of the file containing the application to execute.
    :param class_name: Name of the application's Java/Spark main class.
    :param args: Application command line arguments.
    :param conf: Spark configuration properties.
    :param jars: Jars to be used in this application.
    :param py_files: Python files to be used in this application.
    :param files: Files to be used in this application.
    :param driver_resource_spec: The resource specifications of the Spark driver.
    :param executor_resource_spec: The resource specifications of each Spark executor.
    :param num_executors: Number of executors to launch for this application.
    :param archives: Archives to be used in this application.
    :param name: Name of this application.
    :param cluster_id: The cluster ID of AnalyticDB MySQL 3.0 Data Lakehouse.
    :param rg_name: The name of the resource group in the AnalyticDB MySQL 3.0 Data Lakehouse cluster.
    """

    template_fields: Sequence[str] = ("spark_params",)
    template_fields_renderers = {"spark_params": "json"}

    def __init__(
        self,
        *,
        file: str,
        class_name: str | None = None,
        args: Sequence[str | int | float] | None = None,
        conf: dict[Any, Any] | None = None,
        jars: Sequence[str] | None = None,
        py_files: Sequence[str] | None = None,
        files: Sequence[str] | None = None,
        driver_resource_spec: str | None = None,
        executor_resource_spec: str | None = None,
        num_executors: int | str | None = None,
        archives: Sequence[str] | None = None,
        name: str | None = None,
        cluster_id: str,
        rg_name: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)

        spark_params = {
            "file": file,
            "class_name": class_name,
            "args": args,
            "conf": conf,
            "jars": jars,
            "py_files": py_files,
            "files": files,
            "driver_resource_spec": driver_resource_spec,
            "executor_resource_spec": executor_resource_spec,
            "num_executors": num_executors,
            "archives": archives,
            "name": name,
        }
        self.spark_params = spark_params

        self._cluster_id = cluster_id
        self._rg_name = rg_name

    def execute(self, context: Context) -> Any:
        submit_response = self.hook.submit_spark_app(
            cluster_id=self._cluster_id, rg_name=self._rg_name, **self.spark_params
        )
        self.app_id = submit_response.body.data.app_id
        self.monitor_application()
        return self.app_id
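

# Illustrative usage sketch, not part of the module source: submitting a JAR-based batch
# application. The OSS path, main class, cluster id, and resource group name below are
# placeholder assumptions, and the default "adb_spark_default" connection is assumed to exist.
from datetime import datetime

from airflow import DAG

with DAG(dag_id="adb_spark_batch_example", start_date=datetime(2024, 1, 1), schedule=None):
    spark_pi = AnalyticDBSparkBatchOperator(
        task_id="spark_pi",
        file="oss://your-bucket/jars/spark-examples.jar",  # placeholder application path
        class_name="org.apache.spark.examples.SparkPi",
        args=[1000],
        num_executors=2,
        cluster_id="amv-example-cluster-id",  # placeholder AnalyticDB MySQL cluster id
        rg_name="spark",  # placeholder resource group name
        polling_interval=10,  # poll every 10 seconds until the application terminates
    )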
