# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import threading
from collections import namedtuple
from collections.abc import Iterable, Mapping, Sequence
from contextlib import closing
from copy import copy
from datetime import timedelta
from typing import (
TYPE_CHECKING,
Any,
Callable,
TypeVar,
cast,
overload,
)
from databricks import sql # type: ignore[attr-defined]
from databricks.sql.types import Row
from airflow.exceptions import AirflowException
from airflow.models.connection import Connection as AirflowConnection
from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results
from airflow.providers.databricks.exceptions import DatabricksSqlExecutionError, DatabricksSqlExecutionTimeout
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
if TYPE_CHECKING:
from databricks.sql.client import Connection
LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")

T = TypeVar("T")


def create_timeout_thread(cur, execution_timeout: timedelta | None) -> threading.Timer | None:
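    """
    Return a started ``threading.Timer`` that cancels the running statement when the timeout elapses.

    The timer invokes ``cur.connection.cancel`` after ``execution_timeout`` seconds; the caller
    is expected to ``cancel()`` the timer once the statement completes. Returns ``None`` when
    no timeout is given.
    """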
    if execution_timeout is not None:
        seconds_to_timeout = execution_timeout.total_seconds()
        t = threading.Timer(seconds_to_timeout, cur.connection.cancel)
        # Arm the timer so the statement is actually cancelled when the timeout elapses;
        # ``run`` cancels the timer in its ``finally`` block once the statement finishes.
        t.start()
    else:
        t = None
    return t


class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
"""
Hook to interact with Databricks SQL.
:param databricks_conn_id: Reference to the
:ref:`Databricks connection <howto/connection:databricks>`.
:param http_path: Optional string specifying HTTP path of Databricks SQL Endpoint or cluster.
If not specified, it should be either specified in the Databricks connection's extra parameters,
or ``sql_endpoint_name`` must be specified.
:param sql_endpoint_name: Optional name of Databricks SQL Endpoint. If not specified, ``http_path``
must be provided as described above.
:param session_configuration: An optional dictionary of Spark session parameters. Defaults to None.
If not specified, it could be specified in the Databricks connection's extra parameters.
:param http_headers: An optional list of (k, v) pairs that will be set as HTTP headers
on every request
:param catalog: An optional initial catalog to use. Requires DBR version 9.0+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
:param kwargs: Additional parameters internal to Databricks SQL Connector parameters
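
    A minimal usage sketch (the endpoint name and query below are illustrative, not defaults):

    .. code-block:: python

        hook = DatabricksSqlHook(sql_endpoint_name="my-endpoint")
        rows = hook.run("SELECT 1", handler=lambda cur: cur.fetchall())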
"""
    hook_name = "Databricks SQL"
_test_connection_sql = "select 42"
def __init__(
self,
databricks_conn_id: str = BaseDatabricksHook.default_conn_name,
http_path: str | None = None,
sql_endpoint_name: str | None = None,
session_configuration: dict[str, str] | None = None,
http_headers: list[tuple[str, str]] | None = None,
catalog: str | None = None,
schema: str | None = None,
caller: str = "DatabricksSqlHook",
**kwargs,
) -> None:
super().__init__(databricks_conn_id, caller=caller)
self._sql_conn: Connection | None = None
self._token: str | None = None
self._http_path = http_path
self._sql_endpoint_name = sql_endpoint_name
self.supports_autocommit = True
self.session_config = session_configuration
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema
self.additional_params = kwargs
def _get_extra_config(self) -> dict[str, Any | None]:
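        """Return a copy of the connection extras with parameters the hook consumes itself removed."""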
extra_params = copy(self.databricks_conn.extra_dejson)
for arg in ["http_path", "session_configuration", *self.extra_parameters]:
if arg in extra_params:
del extra_params[arg]
return extra_params
def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
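        """Retrieve the description of the SQL endpoint named ``endpoint_name`` via the SQL endpoints list API."""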
result = self._do_api_call(LIST_SQL_ENDPOINTS_ENDPOINT)
if "endpoints" not in result:
raise AirflowException("Can't list Databricks SQL endpoints")
try:
endpoint = next(endpoint for endpoint in result["endpoints"] if endpoint["name"] == endpoint_name)
except StopIteration:
raise AirflowException(f"Can't find Databricks SQL endpoint with name '{endpoint_name}'")
else:
return endpoint
    def get_conn(self) -> AirflowConnection:
"""Return a Databricks SQL connection object."""
if not self._http_path:
if self._sql_endpoint_name:
endpoint = self._get_sql_endpoint_by_name(self._sql_endpoint_name)
self._http_path = endpoint["odbc_params"]["path"]
elif "http_path" in self.databricks_conn.extra_dejson:
self._http_path = self.databricks_conn.extra_dejson["http_path"]
else:
raise AirflowException(
"http_path should be provided either explicitly, "
"or in extra parameter of Databricks connection, "
"or sql_endpoint_name should be specified"
)
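        # Tokens (e.g. AAD tokens) may expire between queries; re-read the token and
        # rebuild the SQL connection below whenever the token has changed.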
prev_token = self._token
new_token = self._get_token(raise_error=True)
if not self._token or new_token != self._token:
self._token = new_token
if not self.session_config:
self.session_config = self.databricks_conn.extra_dejson.get("session_configuration")
if not self._sql_conn or prev_token != new_token:
if self._sql_conn: # close already existing connection
self._sql_conn.close()
self._sql_conn = sql.connect(
self.host,
self._http_path,
self._token,
schema=self.schema,
catalog=self.catalog,
session_configuration=self.session_config,
http_headers=self.http_headers,
_user_agent_entry=self.user_agent_value,
**self._get_extra_config(),
**self.additional_params,
)
if self._sql_conn is None:
raise AirflowException("SQL connection is not initialized")
return cast(AirflowConnection, self._sql_conn)
@overload # type: ignore[override]
    def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: None = ...,
split_statements: bool = ...,
return_last: bool = ...,
        execution_timeout: timedelta | None = ...,
) -> None: ...
@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
        execution_timeout: timedelta | None = ...,
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None: ...
def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
parameters: Iterable | Mapping[str, Any] | None = None,
handler: Callable[[Any], T] | None = None,
split_statements: bool = True,
return_last: bool = True,
execution_timeout: timedelta | None = None,
) -> tuple | list[tuple] | list[list[tuple] | tuple] | None:
"""
Run a command or a list of commands.
Pass a list of SQL statements to the SQL parameter to get them to
execute sequentially.
:param sql: the sql statement to be executed (str) or a list of
sql statements to execute
:param autocommit: What to set the connection's autocommit setting to
before executing the query. Note that currently there is no commit functionality
in Databricks SQL so this flag has no effect.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
:param split_statements: Whether to split a single SQL string into statements and run separately
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the LAST SQL expression if handler was provided unless return_last
is set to False.
:param execution_timeout: max time allowed for the execution of this task instance, if it goes beyond
it will raise and fail.
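
        A sketch of running several statements with a handler (values are illustrative):

        .. code-block:: python

            results = hook.run(
                sql=["SELECT 1", "SELECT 2"],
                handler=lambda cur: cur.fetchall(),
                return_last=False,
                execution_timeout=timedelta(minutes=5),
            )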
"""
self.descriptions = []
if isinstance(sql, str):
if split_statements:
sql_list = [self.strip_sql_string(s) for s in self.split_sql_string(sql)]
else:
sql_list = [self.strip_sql_string(sql)]
else:
sql_list = [self.strip_sql_string(s) for s in sql]
        if not sql_list:
            raise ValueError("List of SQL statements is empty")
        self.log.debug("Executing the following statements against Databricks DB: %s", sql_list)
conn = None
results = []
for sql_statement in sql_list:
# when using AAD tokens, it could expire if previous query run longer than token lifetime
conn = self.get_conn()
            with closing(conn.cursor()) as cur:
                self.set_autocommit(conn, autocommit)
                t = create_timeout_thread(cur, execution_timeout)
                # TODO: adjust this to make testing easier
                try:
                    self._run_command(cur, sql_statement, parameters)  # type: ignore[attr-defined]
                except Exception as e:
                    # If the timer is still alive (or no timeout was set), the statement was not
                    # cancelled by the timeout, so this is an ordinary execution error.
                    if t is None or t.is_alive():
                        raise DatabricksSqlExecutionError(
                            f"Error running SQL statement: {sql_statement}. {str(e)}"
                        )
                    raise DatabricksSqlExecutionTimeout(
                        f"Timeout threshold exceeded for SQL statement: {sql_statement}; the statement was cancelled."
                    )
                finally:
                    if t is not None:
                        t.cancel()
                if handler is not None:
                    raw_result = handler(cur)
                    result = self._make_common_data_structure(raw_result)
                    if return_single_query_results(sql, return_last, split_statements):
                        results = [result]
                        self.descriptions = [cur.description]
                    else:
                        results.append(result)
                        self.descriptions.append(cur.description)
if conn:
conn.close()
self._sql_conn = None
if handler is None:
return None
if return_single_query_results(sql, return_last, split_statements):
return results[-1]
else:
return results
def _make_common_data_structure(self, result: T | Sequence[T]) -> tuple[Any, ...] | list[tuple[Any, ...]]:
"""Transform the databricks Row objects into namedtuple."""
# Below ignored lines respect namedtuple docstring, but mypy do not support dynamically
# instantiated namedtuple, and will never do: https://github.com/python/mypy/issues/848
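        # Illustrative sketch (hypothetical values): a databricks Row with __fields__ == ("id", "name")
        # and values (1, "x") becomes namedtuple("Row", ("id", "name"))(1, "x"), preserving
        # attribute access while yielding a plain, picklable tuple.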
if isinstance(result, list):
rows: Sequence[Row] = result
if not rows:
return []
rows_fields = tuple(rows[0].__fields__)
rows_object = namedtuple("Row", rows_fields, rename=True) # type: ignore
return cast(list[tuple[Any, ...]], [rows_object(*row) for row in rows])
elif isinstance(result, Row):
row_fields = tuple(result.__fields__)
row_object = namedtuple("Row", row_fields, rename=True) # type: ignore
return cast(tuple[Any, ...], row_object(*result))
else:
raise TypeError(f"Expected Sequence[Row] or Row, but got {type(result)}")
    def bulk_dump(self, table, tmp_file):
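        """Not implemented for the Databricks SQL hook."""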
raise NotImplementedError()
    def bulk_load(self, table, tmp_file):
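        """Not implemented for the Databricks SQL hook."""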
raise NotImplementedError()