#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import contextlib
import os
import re
import socket
import subprocess
import time
from collections.abc import Iterable, Mapping
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, Any
import csv
if TYPE_CHECKING:
    import pandas as pd
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.security import utils
from airflow.utils.helpers import as_flattened_list
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
HIVE_QUEUE_PRIORITIES = ["VERY_HIGH", "HIGH", "NORMAL", "LOW", "VERY_LOW"]
def get_context_from_env_var() -> dict[Any, Any]:
"""
    Extract context from environment variables (dag_id, task_id, etc.) for use in BashOperator and PythonOperator.
:return: The context of interest.
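    A hedged sketch of the result shape (keys come from AIRFLOW_VAR_NAME_FORMAT_MAPPING,
    e.g. ``airflow.ctx.dag_id``; the env var value below is illustrative):
    >>> os.environ["AIRFLOW_CTX_DAG_ID"] = "example_dag"
    >>> get_context_from_env_var()["airflow.ctx.dag_id"]
    'example_dag'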
"""
return {
format_map["default"]: os.environ.get(format_map["env_var_format"], "")
for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
}
class HiveCliHook(BaseHook):
"""
    Simple wrapper around the hive CLI.
    It also supports ``beeline``, a lighter CLI that runs JDBC
    and is replacing the heavier traditional CLI. To enable ``beeline``, set
    the use_beeline param in the extra field of your connection as in
    ``{ "use_beeline": true }``.
    Note that you can also set default hive CLI parameters by passing
    ``hive_cli_params``, a space-separated list of parameters to add to the
    hive command.
    The extra connection parameter ``auth`` gets passed as-is in the ``jdbc``
    connection string.
:param hive_cli_conn_id: Reference to the
:ref:`Hive CLI connection id <howto/connection:hive_cli>`.
:param mapred_queue: queue used by the Hadoop Scheduler (Capacity or Fair)
:param mapred_queue_priority: priority within the job queue.
Possible settings include: VERY_HIGH, HIGH, NORMAL, LOW, VERY_LOW
:param mapred_job_name: This name will appear in the jobtracker.
This can make monitoring easier.
:param hive_cli_params: Space separated list of hive command parameters to add to the
hive command.
:param proxy_user: Run HQL code as this user.
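    A minimal usage sketch (the connection id ``my_hive_cli`` is hypothetical and
    must exist among your Airflow connections; a local hive/beeline binary is required):
    >>> hook = HiveCliHook(hive_cli_conn_id="my_hive_cli", mapred_queue="default")
    >>> hook.run_cli("SHOW DATABASES;")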
"""
    conn_name_attr = "hive_cli_conn_id"
    default_conn_name = "hive_cli_default"
    hook_name = "Hive Client Wrapper"
def __init__(
self,
hive_cli_conn_id: str = default_conn_name,
mapred_queue: str | None = None,
mapred_queue_priority: str | None = None,
mapred_job_name: str | None = None,
hive_cli_params: str = "",
auth: str | None = None,
proxy_user: str | None = None,
) -> None:
super().__init__()
conn = self.get_connection(hive_cli_conn_id)
self.hive_cli_params: str = hive_cli_params
self.use_beeline: bool = conn.extra_dejson.get("use_beeline", False)
self.auth = auth
self.conn = conn
self.sub_process: Any = None
if mapred_queue_priority:
mapred_queue_priority = mapred_queue_priority.upper()
if mapred_queue_priority not in HIVE_QUEUE_PRIORITIES:
raise AirflowException(
f"Invalid Mapred Queue Priority. Valid values are: {', '.join(HIVE_QUEUE_PRIORITIES)}"
)
self.mapred_queue = mapred_queue or conf.get("hive", "default_hive_mapred_queue")
self.mapred_queue_priority = mapred_queue_priority
self.mapred_job_name = mapred_job_name
self.proxy_user = proxy_user
self.high_availability = self.conn.extra_dejson.get("high_availability", False)
    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Return custom UI field behaviour for Hive Client Wrapper connection."""
return {
"hidden_fields": ["extra"],
"relabeling": {},
}
def _get_proxy_user(self) -> str:
"""Set the proper proxy_user value in case the user overwrite the default."""
conn = self.conn
if self.proxy_user is not None:
return f"hive.server2.proxy.user={self.proxy_user}"
proxy_user_value: str = conn.extra_dejson.get("proxy_user", "")
if proxy_user_value != "":
return f"hive.server2.proxy.user={proxy_user_value}"
return ""
def _prepare_cli_cmd(self) -> list[Any]:
"""Create the command list from available information."""
conn = self.conn
hive_bin = "hive"
cmd_extra = []
if self.use_beeline:
hive_bin = "beeline"
self._validate_beeline_parameters(conn)
if self.high_availability:
jdbc_url = f"jdbc:hive2://{conn.host}/{conn.schema}"
self.log.info("High Availability selected, setting JDBC url as %s", jdbc_url)
else:
jdbc_url = f"jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}"
self.log.info("High Availability not selected, setting JDBC url as %s", jdbc_url)
if conf.get("core", "security") == "kerberos":
template = conn.extra_dejson.get("principal", "hive/_HOST@EXAMPLE.COM")
if "_HOST" in template:
template = utils.replace_hostname_pattern(utils.get_components(template))
proxy_user = self._get_proxy_user()
if ";" in template:
raise RuntimeError("The principal should not contain the ';' character")
if ";" in proxy_user:
raise RuntimeError("The proxy_user should not contain the ';' character")
jdbc_url += f";principal={template};{proxy_user}"
if self.high_availability:
if not jdbc_url.endswith(";"):
jdbc_url += ";"
jdbc_url += "serviceDiscoveryMode=zooKeeper;ssl=true;zooKeeperNamespace=hiveserver2"
elif self.auth:
jdbc_url += ";auth=" + self.auth
jdbc_url = f'"{jdbc_url}"'
cmd_extra += ["-u", jdbc_url]
if conn.login:
cmd_extra += ["-n", conn.login]
if conn.password:
cmd_extra += ["-p", conn.password]
hive_params_list = self.hive_cli_params.split()
return [hive_bin, *cmd_extra, *hive_params_list]
def _validate_beeline_parameters(self, conn):
if self.high_availability:
if ";" in conn.schema:
                raise ValueError(
                    f"The schema used in beeline command ({conn.schema}) should not contain the ';' character"
                )
return
elif ":" in conn.host or "/" in conn.host or ";" in conn.host:
            raise ValueError(
                f"The host used in beeline command ({conn.host}) should not contain ':/;' characters"
            )
try:
int_port = int(conn.port)
if not 0 < int_port <= 65535:
                raise ValueError(
                    f"The port used in beeline command ({conn.port}) should be in the range 1-65535"
                )
except (ValueError, TypeError) as e:
            raise ValueError(
                f"The port used in beeline command ({conn.port}) should be a valid integer: {e}"
            )
if ";" in conn.schema:
            raise ValueError(
                f"The schema used in beeline command ({conn.schema}) should not contain the ';' character"
            )
@staticmethod
def _prepare_hiveconf(d: dict[Any, Any]) -> list[Any]:
"""
Prepare a list of hiveconf params from a dictionary of key value pairs.
        :param d: dict of hiveconf key/value pairs
>>> hh = HiveCliHook()
>>> hive_conf = {"hive.exec.dynamic.partition": "true",
... "hive.exec.dynamic.partition.mode": "nonstrict"}
>>> hh._prepare_hiveconf(hive_conf)
["-hiveconf", "hive.exec.dynamic.partition=true",\
"-hiveconf", "hive.exec.dynamic.partition.mode=nonstrict"]
"""
if not d:
return []
return as_flattened_list(zip(["-hiveconf"] * len(d), [f"{k}={v}" for k, v in d.items()]))
    def run_cli(
self,
hql: str,
schema: str | None = None,
verbose: bool = True,
hive_conf: dict[Any, Any] | None = None,
) -> Any:
"""
Run an hql statement using the hive cli.
If hive_conf is specified it should be a dict and the entries
will be set as key/value pairs in HiveConf.
:param hql: an hql (hive query language) statement to run with hive cli
:param schema: Name of hive schema (database) to use
:param verbose: Provides additional logging. Defaults to True.
:param hive_conf: if specified these key value pairs will be passed
to hive as ``-hiveconf "key"="value"``. Note that they will be
passed after the ``hive_cli_params`` and thus will override
whatever values are specified in the database.
>>> hh = HiveCliHook()
>>> result = hh.run_cli("USE airflow;")
>>> ("OK" in result)
True
"""
conn = self.conn
schema = schema or conn.schema
invalid_chars_list = re.findall(r"[^a-z0-9_]", schema)
if invalid_chars_list:
invalid_chars = "".join(invalid_chars_list)
raise RuntimeError(f"The schema `{schema}` contains invalid characters: {invalid_chars}")
if schema:
hql = f"USE {schema};\n{hql}"
with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
hql += "\n"
f.write(hql.encode("UTF-8"))
f.flush()
hive_cmd = self._prepare_cli_cmd()
env_context = get_context_from_env_var()
# Only extend the hive_conf if it is defined.
if hive_conf:
env_context.update(hive_conf)
hive_conf_params = self._prepare_hiveconf(env_context)
if self.mapred_queue:
hive_conf_params.extend(
[
"-hiveconf",
f"mapreduce.job.queuename={self.mapred_queue}",
"-hiveconf",
f"mapred.job.queue.name={self.mapred_queue}",
"-hiveconf",
f"tez.queue.name={self.mapred_queue}",
]
)
if self.mapred_queue_priority:
hive_conf_params.extend(["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"])
if self.mapred_job_name:
hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])
hive_cmd.extend(hive_conf_params)
hive_cmd.extend(["-f", f.name])
if verbose:
self.log.info("%s", " ".join(hive_cmd))
sub_process: Any = subprocess.Popen(
hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
)
self.sub_process = sub_process
stdout = ""
for line in iter(sub_process.stdout.readline, b""):
line = line.decode()
stdout += line
if verbose:
self.log.info(line.strip())
sub_process.wait()
if sub_process.returncode:
raise AirflowException(stdout)
return stdout
    def test_hql(self, hql: str) -> None:
"""Test an hql statement using the hive cli and EXPLAIN."""
create, insert, other = [], [], []
for query in hql.split(";"): # naive
query_original = query
query = query.lower().strip()
if query.startswith("create table"):
create.append(query_original)
elif query.startswith(("set ", "add jar ", "create temporary function")):
other.append(query_original)
elif query.startswith("insert"):
insert.append(query_original)
other_ = ";".join(other)
for query_set in [create, insert]:
for query in query_set:
query_preview = " ".join(query.split())[:50]
self.log.info("Testing HQL [%s (...)]", query_preview)
if query_set == insert:
query = other_ + "; explain " + query
else:
query = "explain " + query
try:
self.run_cli(query, verbose=False)
except AirflowException as e:
message = e.args[0].splitlines()[-2]
self.log.info(message)
error_loc = re.search(r"(\d+):(\d+)", message)
if error_loc:
lst = int(error_loc.group(1))
begin = max(lst - 2, 0)
end = min(lst + 3, len(query.splitlines()))
context = "\n".join(query.splitlines()[begin:end])
self.log.info("Context :\n %s", context)
else:
self.log.info("SUCCESS")
    def load_df(
self,
df: pd.DataFrame,
table: str,
field_dict: dict[Any, Any] | None = None,
delimiter: str = ",",
encoding: str = "utf8",
pandas_kwargs: Any = None,
**kwargs: Any,
) -> None:
"""
Load a pandas DataFrame into hive.
Hive data types will be inferred if not passed but column names will
not be sanitized.
:param df: DataFrame to load into a Hive table
:param table: target Hive table, use dot notation to target a
specific database
:param field_dict: mapping from column name to hive data type.
Note that Python dict is ordered so it keeps columns' order.
:param delimiter: field delimiter in the file
:param encoding: str encoding to use when writing DataFrame to file
:param pandas_kwargs: passed to DataFrame.to_csv
:param kwargs: passed to self.load_file
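        A hedged example (the target table ``tmp.my_df`` and the DataFrame contents
        are illustrative only):
        >>> import pandas as pd
        >>> hh = HiveCliHook()
        >>> df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})
        >>> hh.load_df(df, table="tmp.my_df", delimiter="|")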
"""
def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
dtype_kind_hive_type = {
"b": "BOOLEAN", # boolean
"i": "BIGINT", # signed integer
"u": "BIGINT", # unsigned integer
"f": "DOUBLE", # floating-point
"c": "STRING", # complex floating-point
"M": "TIMESTAMP", # datetime
"O": "STRING", # object
"S": "STRING", # (byte-)string
"U": "STRING", # Unicode
"V": "STRING", # void
}
order_type = {}
for col, dtype in df.dtypes.items():
order_type[col] = dtype_kind_hive_type[dtype.kind]
return order_type
if pandas_kwargs is None:
pandas_kwargs = {}
with (
TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir,
NamedTemporaryFile(dir=tmp_dir, mode="w") as f,
):
if field_dict is None:
field_dict = _infer_field_types_from_df(df)
df.to_csv(
path_or_buf=f,
sep=delimiter,
header=False,
index=False,
encoding=encoding,
date_format="%Y-%m-%d %H:%M:%S",
**pandas_kwargs,
)
f.flush()
return self.load_file(
filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
)
    def load_file(
self,
filepath: str,
table: str,
delimiter: str = ",",
field_dict: dict[Any, Any] | None = None,
create: bool = True,
overwrite: bool = True,
partition: dict[str, Any] | None = None,
recreate: bool = False,
tblproperties: dict[str, Any] | None = None,
) -> None:
"""
Load a local file into Hive.
Note that the table generated in Hive uses ``STORED AS textfile``
which isn't the most efficient serialization format. If a
        large amount of data is loaded and/or if the table gets
queried considerably, you may want to use this operator only to
stage the data into a temporary table before loading it into its
final destination using a ``HiveOperator``.
:param filepath: local filepath of the file to load
:param table: target Hive table, use dot notation to target a
specific database
:param delimiter: field delimiter in the file
:param field_dict: A dictionary of the fields name in the file
as keys and their Hive types as values.
Note that Python dict is ordered so it keeps columns' order.
:param create: whether to create the table if it doesn't exist
:param overwrite: whether to overwrite the data in table or partition
:param partition: target partition as a dict of partition columns
and values
:param recreate: whether to drop and recreate the table at every
execution
:param tblproperties: TBLPROPERTIES of the hive table being created
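        A hedged example (file path, table name, field types, and partition value
        are placeholders):
        >>> hh = HiveCliHook()
        >>> hh.load_file(
        ...     filepath="/tmp/baby_names.csv",
        ...     table="airflow.static_babynames_partitioned",
        ...     field_dict={"state": "STRING", "year": "STRING", "name": "STRING", "num": "INT"},
        ...     partition={"ds": "2024-01-01"},
        ... )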
"""
hql = ""
if recreate:
hql += f"DROP TABLE IF EXISTS {table};\n"
if create or recreate:
if field_dict is None:
raise ValueError("Must provide a field dict when creating a table")
fields = ",\n ".join(f"`{k.strip('`')}` {v}" for k, v in field_dict.items())
hql += f"CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
if partition:
pfields = ",\n ".join(p + " STRING" for p in partition)
hql += f"PARTITIONED BY ({pfields})\n"
hql += "ROW FORMAT DELIMITED\n"
hql += f"FIELDS TERMINATED BY '{delimiter}'\n"
hql += "STORED AS textfile\n"
if tblproperties is not None:
tprops = ", ".join(f"'{k}'='{v}'" for k, v in tblproperties.items())
hql += f"TBLPROPERTIES({tprops})\n"
hql += ";"
self.log.info(hql)
self.run_cli(hql)
hql = f"LOAD DATA LOCAL INPATH '{filepath}' "
if overwrite:
hql += "OVERWRITE "
hql += f"INTO TABLE {table} "
if partition:
pvals = ", ".join(f"{k}='{v}'" for k, v in partition.items())
hql += f"PARTITION ({pvals})"
# Add a newline character as a workaround for https://issues.apache.org/jira/browse/HIVE-10541,
hql += ";\n"
self.log.info(hql)
self.run_cli(hql)
    def kill(self) -> None:
"""Kill Hive cli command."""
if hasattr(self, "sub_process"):
if self.sub_process.poll() is None:
print("Killing the Hive job")
self.sub_process.terminate()
time.sleep(60)
self.sub_process.kill()
class HiveServer2Hook(DbApiHook):
"""
Wrapper around the pyhive library.
Notes:
    * the default auth_mechanism is PLAIN; to override it, you
      can specify it in the ``extra`` of your connection in the UI
    * the default for run_set_variable_statements is true; if you
      are using Impala, you may need to set it to false in the
      ``extra`` of your connection in the UI
    :param hiveserver2_conn_id: Reference to the
        :ref:`Hive Server2 thrift service connection id <howto/connection:hiveserver2>`.
:param schema: Hive database name.
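    For example, a connection ``extra`` that switches to LDAP auth and disables
    ``SET`` statements (values shown are illustrative) could look like
    ``{"auth_mechanism": "LDAP", "run_set_variable_statements": false}``.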
"""
    conn_name_attr = "hiveserver2_conn_id"
    default_conn_name = "hiveserver2_default"
    conn_type = "hiveserver2"
    hook_name = "Hive Server 2 Thrift"
    supports_autocommit = False
    def get_conn(self, schema: str | None = None) -> Any:
"""Return a Hive connection object."""
username: str | None = None
password: str | None = None
db = self.get_connection(self.hiveserver2_conn_id) # type: ignore
auth_mechanism = db.extra_dejson.get("auth_mechanism", "NONE")
if auth_mechanism == "NONE" and db.login is None:
# we need to give a username
username = "airflow"
kerberos_service_name = None
if conf.get("core", "security") == "kerberos":
auth_mechanism = db.extra_dejson.get("auth_mechanism", "KERBEROS")
kerberos_service_name = db.extra_dejson.get("kerberos_service_name", "hive")
            # pyhive uses GSSAPI instead of KERBEROS as an auth_mechanism identifier
if auth_mechanism == "GSSAPI":
self.log.warning(
"Detected deprecated 'GSSAPI' for auth_mechanism for %s. Please use 'KERBEROS' instead",
self.hiveserver2_conn_id, # type: ignore
)
auth_mechanism = "KERBEROS"
# Password should be set if and only if in LDAP or CUSTOM mode
if auth_mechanism in ("LDAP", "CUSTOM"):
password = db.password
from pyhive.hive import connect
return connect(
host=db.host,
port=db.port,
auth=auth_mechanism,
kerberos_service_name=kerberos_service_name,
username=db.login or username,
password=password,
database=schema or db.schema or "default",
)
def _get_results(
self,
sql: str | list[str],
schema: str = "default",
fetch_size: int | None = None,
hive_conf: Iterable | Mapping | None = None,
) -> Any:
from pyhive.exc import ProgrammingError
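        # Note: this is a generator. For each result-producing statement it yields the
        # cursor description (column metadata) first and then the result rows; statements
        # that return nothing (e.g. SET or DDL) are silently skipped below.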
if isinstance(sql, str):
sql = [sql]
previous_description = None
with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur:
cur.arraysize = fetch_size or 1000
db = self.get_connection(self.hiveserver2_conn_id) # type: ignore
# Not all query services (e.g. impala) support the set command
if db.extra_dejson.get("run_set_variable_statements", True):
env_context = get_context_from_env_var()
if hive_conf:
env_context.update(hive_conf)
for k, v in env_context.items():
cur.execute(f"set {k}={v}")
for statement in sql:
cur.execute(statement)
                # we only get results of statements that return
lowered_statement = statement.lower().strip()
if lowered_statement.startswith(("select", "with", "show")) or (
lowered_statement.startswith("set") and "=" not in lowered_statement
):
description = cur.description
if previous_description and previous_description != description:
message = f"""The statements are producing different descriptions:
Current: {description!r}
Previous: {previous_description!r}"""
raise ValueError(message)
elif not previous_description:
previous_description = description
yield description
try:
# DB API 2 raises when no results are returned
# we're silencing here as some statements in the list
# may be `SET` or DDL
yield from cur
except ProgrammingError:
self.log.debug("get_results returned no records")
    def get_results(
self,
sql: str | list[str],
schema: str = "default",
fetch_size: int | None = None,
hive_conf: Iterable | Mapping | None = None,
) -> dict[str, Any]:
"""
Get results of the provided hql in target schema.
:param sql: hql to be executed.
:param schema: target schema, default to 'default'.
:param fetch_size: max size of result to fetch.
        :param hive_conf: hive_conf to execute along with the hql.
:return: results of hql execution, dict with data (list of results) and header
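        A minimal sketch (requires a reachable HiveServer2 instance; query and output
        are illustrative):
        >>> hh = HiveServer2Hook()
        >>> res = hh.get_results("SELECT 1")
        >>> len(res["data"])
        1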
"""
results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
header = next(results_iter)
results = {"data": list(results_iter), "header": header}
return results
    def to_csv(
self,
sql: str,
csv_filepath: str,
schema: str = "default",
delimiter: str = ",",
lineterminator: str = "\r\n",
output_header: bool = True,
fetch_size: int = 1000,
hive_conf: dict[Any, Any] | None = None,
) -> None:
"""
Execute hql in target schema and write results to a csv file.
:param sql: hql to be executed.
:param csv_filepath: filepath of csv to write results into.
:param schema: target schema, default to 'default'.
:param delimiter: delimiter of the csv file, default to ','.
:param lineterminator: lineterminator of the csv file.
        :param output_header: whether to write a header row to the csv file, default to True.
        :param fetch_size: number of result rows to fetch and write at a time, default to 1000.
        :param hive_conf: hive_conf to execute along with the hql.
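        A hedged example (query and file path are placeholders):
        >>> hh = HiveServer2Hook()
        >>> hh.to_csv(
        ...     sql="SELECT * FROM airflow.static_babynames LIMIT 100",
        ...     csv_filepath="/tmp/babynames.csv",
        ...     output_header=True,
        ... )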
"""
results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
header = next(results_iter)
message = None
i = 0
with open(csv_filepath, "w", encoding="utf-8") as file:
writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator)
try:
if output_header:
self.log.debug("Cursor description is %s", header)
writer.writerow([c[0] for c in header])
for i, row in enumerate(results_iter, 1):
writer.writerow(row)
if i % fetch_size == 0:
self.log.info("Written %s rows so far.", i)
except ValueError as exception:
message = str(exception)
if message:
# need to clean up the file first
os.remove(csv_filepath)
raise ValueError(message)
self.log.info("Done. Loaded a total of %s rows.", i)
    def get_records(
self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
) -> Any:
"""
Get a set of records from a Hive query; optionally pass a 'schema' kwarg to specify target schema.
:param sql: hql to be executed.
:param parameters: optional configuration passed to get_results
:return: result of hive execution
>>> hh = HiveServer2Hook()
>>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
>>> len(hh.get_records(sql))
100
"""
schema = kwargs["schema"] if "schema" in kwargs else "default"
return self.get_results(sql, schema=schema, hive_conf=parameters)["data"]
    def get_pandas_df(  # type: ignore
self,
sql: str,
schema: str = "default",
hive_conf: dict[Any, Any] | None = None,
**kwargs,
) -> pd.DataFrame:
"""
Get a pandas dataframe from a Hive query.
:param sql: hql to be executed.
:param schema: target schema, default to 'default'.
        :param hive_conf: hive_conf to execute along with the hql.
        :param kwargs: (optional) passed into pandas.DataFrame constructor
        :return: pandas.DataFrame with the results of the hive execution
        >>> hh = HiveServer2Hook()
        >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
        >>> df = hh.get_pandas_df(sql)
        >>> len(df.index)
        100
"""
try:
import pandas as pd
except ImportError as e:
from airflow.exceptions import AirflowOptionalProviderFeatureException
raise AirflowOptionalProviderFeatureException(e)
res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
df = pd.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
return df