# Source code for airflow.providers.apache.hive.hooks.hive

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import os
import re
import socket
import subprocess
import time
from collections.abc import Iterable, Mapping
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    import pandas as pd

import csv

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.security import utils
from airflow.utils.helpers import as_flattened_list
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING

# Accepted values for ``mapred_queue_priority`` (compared after upper-casing),
# ordered from highest to lowest priority.
HIVE_QUEUE_PRIORITIES = [
    "VERY_HIGH",
    "HIGH",
    "NORMAL",
    "LOW",
    "VERY_LOW",
]
def get_context_from_env_var() -> dict[Any, Any]:
    """
    Build the task context (dag_id, task_id, etc) from environment variables.

    Intended for use in BashOperator and PythonOperator. Every entry of
    ``AIRFLOW_VAR_NAME_FORMAT_MAPPING`` provides the context key (its ``"default"``
    name) and the environment variable to read (its ``"env_var_format"`` name).
    Environment variables that are not set map to an empty string.

    :return: The context of interest, keyed by the ``"default"`` names.
    """
    context: dict[Any, Any] = {}
    for name_formats in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values():
        env_var_name = name_formats["env_var_format"]
        context[name_formats["default"]] = os.environ.get(env_var_name, "")
    return context
class HiveCliHook(BaseHook):
    """
    Simple wrapper around the hive CLI.

    It also supports the ``beeline``, a lighter CLI that runs JDBC and is
    replacing the heavier traditional CLI. To enable ``beeline``, set the
    ``use_beeline`` param in the extra field of your connection as in
    ``{ "use_beeline": true }``.

    Note that you can also set default hive CLI parameters by passing
    ``hive_cli_params``, a space separated list of parameters to add to the
    hive command.

    :param hive_cli_conn_id: Reference to the
        :ref:`Hive CLI connection id <howto/connection:hive_cli>`.
    :param mapred_queue: queue used by the Hadoop Scheduler (Capacity or Fair).
        Defaults to the ``[hive] default_hive_mapred_queue`` configuration value.
    :param mapred_queue_priority: priority within the job queue (case-insensitive).
        Possible settings include: VERY_HIGH, HIGH, NORMAL, LOW, VERY_LOW
    :param mapred_job_name: This name will appear in the jobtracker.
        This can make monitoring easier.
    :param hive_cli_params: Space separated list of hive command parameters to add to the
        hive command.
    :param auth: When using beeline without kerberos security, appended as-is to the
        JDBC connection string as ``;auth=<value>``.
    :param proxy_user: Run HQL code as this user (beeline with kerberos only).
        Overrides the ``proxy_user`` connection extra.
    """

    conn_name_attr = "hive_cli_conn_id"
    default_conn_name = "hive_cli_default"
    conn_type = "hive_cli"
    hook_name = "Hive Client Wrapper"

    def __init__(
        self,
        hive_cli_conn_id: str = default_conn_name,
        mapred_queue: str | None = None,
        mapred_queue_priority: str | None = None,
        mapred_job_name: str | None = None,
        hive_cli_params: str = "",
        auth: str | None = None,
        proxy_user: str | None = None,
    ) -> None:
        super().__init__()
        conn = self.get_connection(hive_cli_conn_id)
        self.hive_cli_params: str = hive_cli_params
        self.use_beeline: bool = conn.extra_dejson.get("use_beeline", False)
        self.auth = auth
        self.conn = conn
        # Handle to the most recently started CLI process; set by run_cli, used by kill.
        self.sub_process: Any = None

        if mapred_queue_priority:
            mapred_queue_priority = mapred_queue_priority.upper()
            if mapred_queue_priority not in HIVE_QUEUE_PRIORITIES:
                raise AirflowException(
                    f"Invalid Mapred Queue Priority. Valid values are: {', '.join(HIVE_QUEUE_PRIORITIES)}"
                )

        self.mapred_queue = mapred_queue or conf.get("hive", "default_hive_mapred_queue")
        self.mapred_queue_priority = mapred_queue_priority
        self.mapred_job_name = mapred_job_name
        self.proxy_user = proxy_user
        self.high_availability = self.conn.extra_dejson.get("high_availability", False)

    @classmethod
    def get_connection_form_widgets(cls) -> dict[str, Any]:
        """Return connection widgets to add to Hive Client Wrapper connection form."""
        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
        from flask_babel import lazy_gettext
        from wtforms import BooleanField, StringField

        return {
            "use_beeline": BooleanField(lazy_gettext("Use Beeline"), default=True),
            "proxy_user": StringField(lazy_gettext("Proxy User"), widget=BS3TextFieldWidget(), default=""),
            "principal": StringField(
                lazy_gettext("Principal"), widget=BS3TextFieldWidget(), default="hive/_HOST@EXAMPLE.COM"
            ),
            "high_availability": BooleanField(lazy_gettext("High Availability mode"), default=False),
        }

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom UI field behaviour for Hive Client Wrapper connection."""
        return {
            "hidden_fields": ["extra"],
            "relabeling": {},
        }

    def _get_proxy_user(self) -> str:
        """
        Return the ``hive.server2.proxy.user=<user>`` JDBC fragment.

        The constructor's ``proxy_user`` wins over the ``proxy_user`` connection extra;
        an empty string is returned when neither is set.
        """
        conn = self.conn
        if self.proxy_user is not None:
            return f"hive.server2.proxy.user={self.proxy_user}"
        proxy_user_value: str = conn.extra_dejson.get("proxy_user", "")
        if proxy_user_value != "":
            return f"hive.server2.proxy.user={proxy_user_value}"
        return ""

    def _prepare_cli_cmd(self) -> list[Any]:
        """
        Create the command list from available information.

        Uses ``hive`` by default, or ``beeline`` with a ``-u <jdbc url>`` (plus ``-n``/``-p``
        credentials when set on the connection) when ``use_beeline`` is enabled. The
        space-separated ``hive_cli_params`` are always appended last.
        """
        conn = self.conn
        hive_bin = "hive"
        cmd_extra = []

        if self.use_beeline:
            hive_bin = "beeline"
            self._validate_beeline_parameters(conn)
            if self.high_availability:
                jdbc_url = f"jdbc:hive2://{conn.host}/{conn.schema}"
                self.log.info("High Availability selected, setting JDBC url as %s", jdbc_url)
            else:
                jdbc_url = f"jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}"
                self.log.info("High Availability not selected, setting JDBC url as %s", jdbc_url)
            if conf.get("core", "security") == "kerberos":
                template = conn.extra_dejson.get("principal", "hive/_HOST@EXAMPLE.COM")
                if "_HOST" in template:
                    template = utils.replace_hostname_pattern(utils.get_components(template))

                proxy_user = self._get_proxy_user()
                # ';' separates JDBC URL parameters, so it must not appear in interpolated values.
                if ";" in template:
                    raise RuntimeError("The principal should not contain the ';' character")
                if ";" in proxy_user:
                    raise RuntimeError("The proxy_user should not contain the ';' character")
                jdbc_url += f";principal={template};{proxy_user}"
                if self.high_availability:
                    if not jdbc_url.endswith(";"):
                        jdbc_url += ";"
                    jdbc_url += "serviceDiscoveryMode=zooKeeper;ssl=true;zooKeeperNamespace=hiveserver2"
            elif self.auth:
                jdbc_url += ";auth=" + self.auth

            # NOTE(review): the command is run without a shell, so these quotes reach beeline
            # literally; kept as-is because existing callers/tests depend on this format.
            jdbc_url = f'"{jdbc_url}"'

            cmd_extra += ["-u", jdbc_url]
            if conn.login:
                cmd_extra += ["-n", conn.login]
            if conn.password:
                cmd_extra += ["-p", conn.password]

        hive_params_list = self.hive_cli_params.split()

        return [hive_bin, *cmd_extra, *hive_params_list]

    def _validate_beeline_parameters(self, conn) -> None:
        """
        Validate connection fields that get interpolated into the beeline JDBC URL.

        In high-availability mode only the schema is checked (host/port are not validated
        here, presumably because the host field may carry a host:port list).

        :param conn: the Hive CLI connection.
        :raises ValueError: if the host contains ``:``, ``/`` or ``;``, the port is not an
            integer in 1-65535, or the schema contains ``;``.
        """
        if self.high_availability:
            if ";" in conn.schema:
                raise ValueError(
                    f"The schema used in beeline command ({conn.schema}) should not contain ';' character)"
                )
            return
        if ":" in conn.host or "/" in conn.host or ";" in conn.host:
            raise ValueError(
                f"The host used in beeline command ({conn.host}) should not contain ':/;' characters)"
            )
        try:
            int_port = int(conn.port)
        except (ValueError, TypeError) as e:
            raise ValueError(
                f"The port used in beeline command ({conn.port}) should be a valid integer: {e})"
            ) from e
        # The range check sits outside the ``try`` so its error is not re-wrapped above.
        if not 0 < int_port <= 65535:
            raise ValueError(f"The port used in beeline command ({conn.port}) should be in range 1-65535)")
        if ";" in conn.schema:
            raise ValueError(
                f"The schema used in beeline command ({conn.schema}) should not contain ';' character)"
            )

    @staticmethod
    def _prepare_hiveconf(d: dict[Any, Any]) -> list[Any]:
        """
        Prepare a list of hiveconf params from a dictionary of key value pairs.

        :param d: mapping of hiveconf keys to values; each pair becomes
            ``"-hiveconf", "key=value"``.
        :return: flat argument list (empty when ``d`` is empty or None).

        >>> hh = HiveCliHook()
        >>> hive_conf = {"hive.exec.dynamic.partition": "true",
        ...              "hive.exec.dynamic.partition.mode": "nonstrict"}
        >>> hh._prepare_hiveconf(hive_conf)
        ['-hiveconf', 'hive.exec.dynamic.partition=true', '-hiveconf', 'hive.exec.dynamic.partition.mode=nonstrict']
        """
        if not d:
            return []
        return as_flattened_list(zip(["-hiveconf"] * len(d), [f"{k}={v}" for k, v in d.items()]))

    def run_cli(
        self,
        hql: str,
        schema: str | None = None,
        verbose: bool = True,
        hive_conf: dict[Any, Any] | None = None,
    ) -> Any:
        """
        Run an hql statement using the hive cli.

        If hive_conf is specified it should be a dict and the entries
        will be set as key/value pairs in HiveConf.

        :param hql: an hql (hive query language) statement to run with hive cli
        :param schema: Name of hive schema (database) to use; defaults to the
            connection schema. When set, a ``USE <schema>;`` statement is prepended.
        :param verbose: Provides additional logging. Defaults to True.
        :param hive_conf: if specified these key value pairs will be passed
            to hive as ``-hiveconf "key"="value"``. Note that they will be
            passed after the ``hive_cli_params`` and thus will override
            whatever values are specified in the database.
        :return: the combined stdout/stderr output of the CLI process.
        :raises RuntimeError: if the schema contains characters other than ``[a-z0-9_]``.
        :raises AirflowException: if the CLI exits with a non-zero code (message is its output).

        >>> hh = HiveCliHook()
        >>> result = hh.run_cli("USE airflow;")
        >>> ("OK" in result)
        True
        """
        conn = self.conn
        schema = schema or conn.schema

        # Validate only when a schema is present: the connection schema may be None.
        if schema:
            # The schema is interpolated verbatim into a USE statement, so restrict it.
            invalid_chars_list = re.findall(r"[^a-z0-9_]", schema)
            if invalid_chars_list:
                invalid_chars = "".join(invalid_chars_list)
                raise RuntimeError(f"The schema `{schema}` contains invalid characters: {invalid_chars}")
            hql = f"USE {schema};\n{hql}"

        with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir, NamedTemporaryFile(dir=tmp_dir) as f:
            hql += "\n"
            f.write(hql.encode("UTF-8"))
            f.flush()
            hive_cmd = self._prepare_cli_cmd()
            env_context = get_context_from_env_var()
            # Only extend the hive_conf if it is defined.
            if hive_conf:
                env_context.update(hive_conf)
            hive_conf_params = self._prepare_hiveconf(env_context)
            if self.mapred_queue:
                hive_conf_params.extend(
                    [
                        "-hiveconf",
                        f"mapreduce.job.queuename={self.mapred_queue}",
                        "-hiveconf",
                        f"mapred.job.queue.name={self.mapred_queue}",
                        "-hiveconf",
                        f"tez.queue.name={self.mapred_queue}",
                    ]
                )

            if self.mapred_queue_priority:
                hive_conf_params.extend(["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"])

            if self.mapred_job_name:
                hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])

            hive_cmd.extend(hive_conf_params)
            hive_cmd.extend(["-f", f.name])

            if verbose:
                self.log.info("%s", " ".join(hive_cmd))
            sub_process: Any = subprocess.Popen(
                hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
            )
            self.sub_process = sub_process
            output_lines: list[str] = []
            for raw_line in iter(sub_process.stdout.readline, b""):
                # Replace undecodable bytes rather than aborting mid-run with the process alive.
                line = raw_line.decode(errors="replace")
                output_lines.append(line)
                if verbose:
                    self.log.info(line.strip())
            sub_process.wait()
            stdout = "".join(output_lines)

            if sub_process.returncode:
                raise AirflowException(stdout)

            return stdout

    def test_hql(self, hql: str) -> None:
        """
        Test an hql statement using the hive cli and EXPLAIN.

        The script is split naively on ``;``. Each ``CREATE TABLE`` statement is run
        as ``EXPLAIN <stmt>``; each ``INSERT`` statement is run as ``EXPLAIN <stmt>``
        preceded by all ``SET``/``ADD JAR``/``CREATE TEMPORARY FUNCTION`` statements.
        Failures are logged (with surrounding query lines when a ``line:col``
        location is found) and never raised.
        """
        create, insert, other = [], [], []
        for query in hql.split(";"):  # naive
            query_original = query
            query = query.lower().strip()

            if query.startswith("create table"):
                create.append(query_original)
            elif query.startswith(("set ", "add jar ", "create temporary function")):
                other.append(query_original)
            elif query.startswith("insert"):
                insert.append(query_original)
        other_ = ";".join(other)
        # Track which list is being processed explicitly instead of comparing lists by value.
        for query_set, is_insert in ((create, False), (insert, True)):
            for query in query_set:
                query_preview = " ".join(query.split())[:50]
                self.log.info("Testing HQL [%s (...)]", query_preview)
                if is_insert:
                    query = other_ + "; explain " + query
                else:
                    query = "explain " + query
                try:
                    self.run_cli(query, verbose=False)
                except AirflowException as e:
                    output_lines = str(e).splitlines()
                    # Heuristic carried over from the original: the second-to-last output line
                    # holds the error; fall back to whatever output exists if it is shorter.
                    message = output_lines[-2] if len(output_lines) >= 2 else "\n".join(output_lines)
                    self.log.info(message)
                    error_loc = re.search(r"(\d+):(\d+)", message)
                    if error_loc:
                        lst = int(error_loc.group(1))
                        query_lines = query.splitlines()
                        begin = max(lst - 2, 0)
                        end = min(lst + 3, len(query_lines))
                        context = "\n".join(query_lines[begin:end])
                        self.log.info("Context :\n %s", context)
                else:
                    self.log.info("SUCCESS")

    def load_df(
        self,
        df: pd.DataFrame,
        table: str,
        field_dict: dict[Any, Any] | None = None,
        delimiter: str = ",",
        encoding: str = "utf8",
        pandas_kwargs: Any = None,
        **kwargs: Any,
    ) -> None:
        """
        Load a pandas DataFrame into hive.

        Hive data types will be inferred if not passed but column names will
        not be sanitized.

        :param df: DataFrame to load into a Hive table
        :param table: target Hive table, use dot notation to target a
            specific database
        :param field_dict: mapping from column name to hive data type.
            Note that Python dict is ordered so it keeps columns' order.
        :param delimiter: field delimiter in the file
        :param encoding: str encoding to use when writing DataFrame to file
        :param pandas_kwargs: passed to DataFrame.to_csv
        :param kwargs: passed to self.load_file
        """

        def _infer_field_types_from_df(df: pd.DataFrame) -> dict[Any, Any]:
            # Map numpy dtype.kind codes to Hive column types.
            dtype_kind_hive_type = {
                "b": "BOOLEAN",  # boolean
                "i": "BIGINT",  # signed integer
                "u": "BIGINT",  # unsigned integer
                "f": "DOUBLE",  # floating-point
                "c": "STRING",  # complex floating-point
                "M": "TIMESTAMP",  # datetime
                "O": "STRING",  # object
                "S": "STRING",  # (byte-)string
                "U": "STRING",  # Unicode
                "V": "STRING",  # void
            }

            order_type = {}
            for col, dtype in df.dtypes.items():
                order_type[col] = dtype_kind_hive_type[dtype.kind]
            return order_type

        if pandas_kwargs is None:
            pandas_kwargs = {}

        # The encoding must be set on the text-mode handle itself: pandas does not apply
        # ``to_csv(encoding=...)`` to an already-open text file object.
        with (
            TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir,
            NamedTemporaryFile(dir=tmp_dir, mode="w", encoding=encoding) as f,
        ):
            if field_dict is None:
                field_dict = _infer_field_types_from_df(df)

            df.to_csv(
                path_or_buf=f,
                sep=delimiter,
                header=False,
                index=False,
                encoding=encoding,
                date_format="%Y-%m-%d %H:%M:%S",
                **pandas_kwargs,
            )
            f.flush()

            return self.load_file(
                filepath=f.name, table=table, delimiter=delimiter, field_dict=field_dict, **kwargs
            )

    def load_file(
        self,
        filepath: str,
        table: str,
        delimiter: str = ",",
        field_dict: dict[Any, Any] | None = None,
        create: bool = True,
        overwrite: bool = True,
        partition: dict[str, Any] | None = None,
        recreate: bool = False,
        tblproperties: dict[str, Any] | None = None,
    ) -> None:
        """
        Load a local file into Hive.

        Note that the table generated in Hive uses ``STORED AS textfile``
        which isn't the most efficient serialization format. If a
        large amount of data is loaded and/or if the tables gets
        queried considerably, you may want to use this operator only to
        stage the data into a temporary table before loading it into its
        final destination using a ``HiveOperator``.

        :param filepath: local filepath of the file to load
        :param table: target Hive table, use dot notation to target a
            specific database
        :param delimiter: field delimiter in the file
        :param field_dict: A dictionary of the fields name in the file
            as keys and their Hive types as values.
            Note that Python dict is ordered so it keeps columns' order.
        :param create: whether to create the table if it doesn't exist
        :param overwrite: whether to overwrite the data in table or partition
        :param partition: target partition as a dict of partition columns
            and values (partition columns are created as STRING)
        :param recreate: whether to drop and recreate the table at every
            execution
        :param tblproperties: TBLPROPERTIES of the hive table being created
        :raises ValueError: if a table must be created and ``field_dict`` is None.
        """
        hql = ""
        if recreate:
            hql += f"DROP TABLE IF EXISTS {table};\n"
        if create or recreate:
            if field_dict is None:
                raise ValueError("Must provide a field dict when creating a table")
            fields = ",\n ".join(f"`{k.strip('`')}` {v}" for k, v in field_dict.items())
            hql += f"CREATE TABLE IF NOT EXISTS {table} (\n{fields})\n"
            if partition:
                pfields = ",\n ".join(p + " STRING" for p in partition)
                hql += f"PARTITIONED BY ({pfields})\n"
            hql += "ROW FORMAT DELIMITED\n"
            hql += f"FIELDS TERMINATED BY '{delimiter}'\n"
            hql += "STORED AS textfile\n"
            if tblproperties is not None:
                tprops = ", ".join(f"'{k}'='{v}'" for k, v in tblproperties.items())
                hql += f"TBLPROPERTIES({tprops})\n"
            hql += ";"
            self.log.info(hql)
            self.run_cli(hql)
        hql = f"LOAD DATA LOCAL INPATH '{filepath}' "
        if overwrite:
            hql += "OVERWRITE "
        hql += f"INTO TABLE {table} "
        if partition:
            pvals = ", ".join(f"{k}='{v}'" for k, v in partition.items())
            hql += f"PARTITION ({pvals})"

        # Add a newline character as a workaround for https://issues.apache.org/jira/browse/HIVE-10541,
        hql += ";\n"

        self.log.info(hql)
        self.run_cli(hql)

    def kill(self) -> None:
        """
        Kill Hive cli command.

        Does nothing if no command was started or it already exited. Otherwise sends
        SIGTERM, waits up to 60 seconds, then sends SIGKILL if it is still running.
        """
        sub_process = getattr(self, "sub_process", None)
        # sub_process is initialised to None in __init__, so test for None, not presence.
        if sub_process is None or sub_process.poll() is not None:
            return
        self.log.info("Killing the Hive job")
        sub_process.terminate()
        try:
            sub_process.wait(timeout=60)
        except subprocess.TimeoutExpired:
            sub_process.kill()
class HiveMetastoreHook(BaseHook):
    """
    Wrapper to interact with the Hive Metastore.

    :param metastore_conn_id: reference to the
        :ref:`metastore thrift service connection id <howto/connection:hive_metastore>`.
    """

    # java short max val
    MAX_PART_COUNT = 32767

    conn_name_attr = "metastore_conn_id"
    default_conn_name = "metastore_default"
    conn_type = "hive_metastore"
    hook_name = "Hive Metastore Thrift"

    def __init__(self, metastore_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn = self.get_connection(metastore_conn_id)
        self.metastore = self.get_metastore_client()

    def __getstate__(self) -> dict[str, Any]:
        """Serialize object and omit non-serializable attributes."""
        # This is for pickling to work despite the thrift hive client not
        # being picklable
        state = dict(self.__dict__)
        del state["metastore"]
        return state

    def __setstate__(self, d: dict[str, Any]) -> None:
        """Deserialize object and restore non-serializable attributes (opens a new client)."""
        self.__dict__.update(d)
        self.__dict__["metastore"] = self.get_metastore_client()

    def get_metastore_client(self) -> Any:
        """
        Return a Hive thrift client.

        Uses SASL/GSSAPI transport when core security is ``kerberos`` and the effective
        ``auth_mechanism`` is ``GSSAPI``; otherwise a buffered plain transport.

        :raises AirflowException: if none of the configured hosts accepts a connection.
        """
        import hmsclient
        from thrift.protocol import TBinaryProtocol
        from thrift.transport import TSocket, TTransport

        host = self._find_valid_host()
        conn = self.conn

        if not host:
            raise AirflowException("Failed to locate the valid server.")

        auth_mechanism = conn.extra_dejson.get("auth_mechanism", "NOSASL")

        if conf.get("core", "security") == "kerberos":
            auth_mechanism = conn.extra_dejson.get("auth_mechanism", "GSSAPI")
            kerberos_service_name = conn.extra_dejson.get("kerberos_service_name", "hive")

        conn_socket = TSocket.TSocket(host, conn.port)

        if conf.get("core", "security") == "kerberos" and auth_mechanism == "GSSAPI":
            try:
                import saslwrapper as sasl
            except ImportError:
                import sasl

            def sasl_factory() -> sasl.Client:
                sasl_client = sasl.Client()
                sasl_client.setAttr("host", host)
                sasl_client.setAttr("service", kerberos_service_name)
                sasl_client.init()
                return sasl_client

            from thrift_sasl import TSaslClientTransport

            transport = TSaslClientTransport(sasl_factory, "GSSAPI", conn_socket)
        else:
            transport = TTransport.TBufferedTransport(conn_socket)

        protocol = TBinaryProtocol.TBinaryProtocol(transport)

        return hmsclient.HMSClient(iprot=protocol)

    def _find_valid_host(self) -> Any:
        """
        Return the first host from the comma-separated connection host list that accepts TCP.

        Each probe socket is closed whether or not the connection succeeded.

        :return: the reachable host name, or None if no host could be reached.
        """
        conn = self.conn
        for raw_host in conn.host.split(","):
            # Tolerate "h1, h2" style lists.
            host = raw_host.strip()
            self.log.info("Trying to connect to %s:%s", host, conn.port)
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as host_socket:
                connected = host_socket.connect_ex((host, conn.port)) == 0
            if connected:
                self.log.info("Connected to %s:%s", host, conn.port)
                return host
            self.log.error("Could not connect to %s:%s", host, conn.port)
        return None

    def get_conn(self) -> Any:
        """Return the metastore thrift client created when the hook was initialised."""
        return self.metastore

    def check_for_partition(self, schema: str, table: str, partition: str) -> bool:
        """
        Check whether a partition exists.

        :param schema: Name of hive schema (database) @table belongs to
        :param table: Name of hive table @partition belongs to
        :param partition: Expression that matches the partitions to check for
            (e.g. `a = 'b' AND c = 'd'`)

        >>> hh = HiveMetastoreHook()
        >>> t = "static_babynames_partitioned"
        >>> hh.check_for_partition("airflow", t, "ds='2015-01-01'")
        True
        """
        with self.metastore as client:
            partitions = client.get_partitions_by_filter(
                schema, table, partition, HiveMetastoreHook.MAX_PART_COUNT
            )

        return bool(partitions)

    def check_for_named_partition(self, schema: str, table: str, partition_name: str) -> Any:
        """
        Check whether a partition with a given name exists.

        :param schema: Name of hive schema (database) @table belongs to
        :param table: Name of hive table @partition belongs to
        :param partition_name: Name of the partitions to check for (eg `a=b/c=d`)

        >>> hh = HiveMetastoreHook()
        >>> t = "static_babynames_partitioned"
        >>> hh.check_for_named_partition("airflow", t, "ds=2015-01-01")
        True
        >>> hh.check_for_named_partition("airflow", t, "ds=xxx")
        False
        """
        with self.metastore as client:
            return client.check_for_named_partition(schema, table, partition_name)

    def get_table(self, table_name: str, db: str = "default") -> Any:
        """
        Get a metastore table object.

        When ``db`` is left as ``"default"`` and ``table_name`` uses dot notation
        (``db.table``), the database is taken from ``table_name``.

        >>> hh = HiveMetastoreHook()
        >>> t = hh.get_table(db="airflow", table_name="static_babynames")
        >>> t.tableName
        'static_babynames'
        >>> [col.name for col in t.sd.cols]
        ['state', 'year', 'name', 'gender', 'num']
        """
        if db == "default" and "." in table_name:
            db, table_name = table_name.split(".")[:2]
        with self.metastore as client:
            return client.get_table(dbname=db, tbl_name=table_name)

    def get_tables(self, db: str, pattern: str = "*") -> Any:
        """Get the metastore table objects of all tables in ``db`` whose names match ``pattern``."""
        with self.metastore as client:
            tables = client.get_tables(db_name=db, pattern=pattern)
            return client.get_table_objects_by_name(db, tables)

    def get_databases(self, pattern: str = "*") -> Any:
        """Get the names of the metastore databases matching ``pattern``."""
        with self.metastore as client:
            return client.get_databases(pattern)

    def get_partitions(self, schema: str, table_name: str, partition_filter: str | None = None) -> list[Any]:
        """
        Return a list of all partitions in a table.

        Works only for tables with less than 32767 (java short max val).
        For subpartitioned table, the number might easily exceed this.

        :param schema: Name of hive schema (database) the table belongs to.
        :param table_name: table name.
        :param partition_filter: optional metastore filter expression restricting the partitions.
        :return: one dict per partition, mapping partition key names to values.
        :raises AirflowException: if the table has no partition keys.

        >>> hh = HiveMetastoreHook()
        >>> t = "static_babynames_partitioned"
        >>> parts = hh.get_partitions(schema="airflow", table_name=t)
        >>> len(parts)
        1
        >>> parts
        [{'ds': '2015-01-01'}]
        """
        with self.metastore as client:
            table = client.get_table(dbname=schema, tbl_name=table_name)
            if table.partitionKeys:
                if partition_filter:
                    parts = client.get_partitions_by_filter(
                        db_name=schema,
                        tbl_name=table_name,
                        filter=partition_filter,
                        max_parts=HiveMetastoreHook.MAX_PART_COUNT,
                    )
                else:
                    parts = client.get_partitions(
                        db_name=schema, tbl_name=table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT
                    )

                pnames = [p.name for p in table.partitionKeys]
                return [dict(zip(pnames, p.values)) for p in parts]
            else:
                raise AirflowException("The table isn't partitioned")

    @staticmethod
    def _get_max_partition_from_part_specs(
        part_specs: list[Any], partition_key: str | None, filter_map: dict[str, Any] | None
    ) -> Any:
        """
        Get max partition of partitions with partition_key from part specs.

        key:value pair in filter_map will be used to filter out partitions.

        :param part_specs: list of partition specs.
        :param partition_key: partition key name.
        :param filter_map: partition_key:partition_value map used for partition filtering,
                           e.g. {'key1': 'value1', 'key2': 'value2'}.
                           Only partitions matching all partition_key:partition_value
                           pairs will be considered as candidates of max partition.
        :return: Max partition or None if part_specs is empty or no spec matches filter_map.
        """
        if not part_specs:
            return None

        # Assuming all specs have the same keys.
        if partition_key not in part_specs[0].keys():
            raise AirflowException(f"Provided partition_key {partition_key} is not in part_specs.")
        if filter_map and not set(filter_map).issubset(part_specs[0]):
            raise AirflowException(
                f"Keys in provided filter_map {', '.join(filter_map.keys())} "
                f"are not subset of part_spec keys: {', '.join(part_specs[0].keys())}"
            )

        return max(
            (
                p_dict[partition_key]
                for p_dict in part_specs
                if filter_map is None or all(item in p_dict.items() for item in filter_map.items())
            ),
            default=None,
        )

    def max_partition(
        self,
        schema: str,
        table_name: str,
        field: str | None = None,
        filter_map: dict[Any, Any] | None = None,
    ) -> Any:
        """
        Return the maximum value for all partitions with given field in a table.

        If only one partition key exist in the table, the key will be used as field.
        filter_map should be a partition_key:partition_value map and will be used to
        filter out partitions.

        :param schema: schema name.
        :param table_name: table name.
        :param field: partition key to get max partition from.
        :param filter_map: partition_key:partition_value map used for partition filtering.

        >>> hh = HiveMetastoreHook()
        >>> filter_map = {'ds': '2015-01-01'}
        >>> t = 'static_babynames_partitioned'
        >>> hh.max_partition(schema='airflow', table_name=t, field='ds', filter_map=filter_map)
        '2015-01-01'
        """
        with self.metastore as client:
            table = client.get_table(dbname=schema, tbl_name=table_name)
            key_name_set = {key.name for key in table.partitionKeys}
            if len(table.partitionKeys) == 1:
                field = table.partitionKeys[0].name
            elif not field:
                raise AirflowException("Please specify the field you want the max value for.")
            elif field not in key_name_set:
                raise AirflowException("Provided field is not a partition key.")

            if filter_map and not set(filter_map.keys()).issubset(key_name_set):
                raise AirflowException("Provided filter_map contains keys that are not partition key.")

            part_names = client.get_partition_names(
                schema, table_name, max_parts=HiveMetastoreHook.MAX_PART_COUNT
            )
            part_specs = [client.partition_name_to_spec(part_name) for part_name in part_names]

        return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs, field, filter_map)

    def table_exists(self, table_name: str, db: str = "default") -> bool:
        """
        Check if table exists.

        Any error raised while fetching the table (not only "not found") is reported as False.

        >>> hh = HiveMetastoreHook()
        >>> hh.table_exists(db="airflow", table_name="static_babynames")
        True
        >>> hh.table_exists(db="airflow", table_name="does_not_exist")
        False
        """
        try:
            self.get_table(table_name, db)
            return True
        except Exception:
            return False

    def drop_partitions(self, table_name: str, part_vals: Any, delete_data: bool = False, db: str = "default"):
        """
        Drop partitions from the given table matching the part_vals input.

        :param table_name: table name.
        :param part_vals: list of partition values identifying the partition to drop.
        :param delete_data: Setting to control if underlying data have to deleted
                            in addition to dropping partitions.
        :param db: Name of hive schema (database) @table belongs to
        :return: the metastore client's result, or False if the table does not exist.

        >>> hh = HiveMetastoreHook()
        >>> hh.drop_partitions(db='airflow', table_name='static_babynames_partitioned',
        ...                    part_vals=['2020-05-01'])
        True
        """
        if self.table_exists(table_name, db):
            with self.metastore as client:
                self.log.info(
                    "Dropping partition of table %s.%s matching the spec: %s", db, table_name, part_vals
                )
                return client.drop_partition(db, table_name, part_vals, delete_data)
        else:
            self.log.info("Table %s.%s does not exist!", db, table_name)
            return False
class HiveServer2Hook(DbApiHook):
    """
    Wrapper around the pyhive library.

    Notes:
    * the default auth_mechanism is PLAIN, to override it you
    can specify it in the ``extra`` of your connection in the UI
    * the default for run_set_variable_statements is true, if you
    are using impala you may need to set it to false in the
    ``extra`` of your connection in the UI

    :param hiveserver2_conn_id: Reference to the
        :ref:`Hive Server2 thrift service connection id <howto/connection:hiveserver2>`.
    :param schema: Hive database name.
    """

    conn_name_attr = "hiveserver2_conn_id"
    default_conn_name = "hiveserver2_default"
    conn_type = "hiveserver2"
    hook_name = "Hive Server 2 Thrift"
    supports_autocommit = False

    def get_conn(self, schema: str | None = None) -> Any:
        """
        Return a Hive connection object.

        :param schema: database to connect to; falls back to the connection schema,
            then ``"default"``.
        """
        username: str | None = None
        password: str | None = None

        db = self.get_connection(self.hiveserver2_conn_id)  # type: ignore

        auth_mechanism = db.extra_dejson.get("auth_mechanism", "NONE")
        if auth_mechanism == "NONE" and db.login is None:
            # we need to give a username
            username = "airflow"
        kerberos_service_name = None
        if conf.get("core", "security") == "kerberos":
            auth_mechanism = db.extra_dejson.get("auth_mechanism", "KERBEROS")
            kerberos_service_name = db.extra_dejson.get("kerberos_service_name", "hive")

        # 'GSSAPI' is accepted as a deprecated alias and normalised to 'KERBEROS' for pyhive.
        if auth_mechanism == "GSSAPI":
            self.log.warning(
                "Detected deprecated 'GSSAPI' for auth_mechanism for %s. Please use 'KERBEROS' instead",
                self.hiveserver2_conn_id,  # type: ignore
            )
            auth_mechanism = "KERBEROS"

        # Password should be set if and only if in LDAP or CUSTOM mode
        if auth_mechanism in ("LDAP", "CUSTOM"):
            password = db.password

        from pyhive.hive import connect

        return connect(
            host=db.host,
            port=db.port,
            auth=auth_mechanism,
            kerberos_service_name=kerberos_service_name,
            username=db.login or username,
            password=password,
            database=schema or db.schema or "default",
        )

    def _get_results(
        self,
        sql: str | list[str],
        schema: str = "default",
        fetch_size: int | None = None,
        hive_conf: Iterable | Mapping | None = None,
    ) -> Any:
        """
        Execute statements and lazily yield their output.

        The first item yielded is the cursor description of the first statement that
        returns rows (SELECT/WITH/SHOW, or SET without ``=``); all rows follow. Nothing is
        yielded if no statement returns rows.

        :raises ValueError: if result-returning statements have different descriptions.
        """
        from pyhive.exc import ProgrammingError

        if isinstance(sql, str):
            sql = [sql]
        previous_description = None
        with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur:
            cur.arraysize = fetch_size or 1000

            db = self.get_connection(self.hiveserver2_conn_id)  # type: ignore
            # Not all query services (e.g. impala) support the set command
            if db.extra_dejson.get("run_set_variable_statements", True):
                env_context = get_context_from_env_var()
                if hive_conf:
                    env_context.update(hive_conf)
                for k, v in env_context.items():
                    cur.execute(f"set {k}={v}")

            for statement in sql:
                cur.execute(statement)
                # we only get results of statements that returns
                lowered_statement = statement.lower().strip()
                if lowered_statement.startswith(("select", "with", "show")) or (
                    lowered_statement.startswith("set") and "=" not in lowered_statement
                ):
                    description = cur.description
                    if previous_description and previous_description != description:
                        message = f"""The statements are producing different descriptions: Current: {description!r} Previous: {previous_description!r}"""
                        raise ValueError(message)
                    elif not previous_description:
                        previous_description = description
                        yield description
                    try:
                        # DB API 2 raises when no results are returned
                        # we're silencing here as some statements in the list
                        # may be `SET` or DDL
                        yield from cur
                    except ProgrammingError:
                        self.log.debug("get_results returned no records")

    def get_results(
        self,
        sql: str | list[str],
        schema: str = "default",
        fetch_size: int | None = None,
        hive_conf: Iterable | Mapping | None = None,
    ) -> dict[str, Any]:
        """
        Get results of the provided hql in target schema.

        :param sql: hql to be executed.
        :param schema: target schema, default to 'default'.
        :param fetch_size: max size of result to fetch.
        :param hive_conf: hive_conf to execute alone with the hql.
        :return: results of hql execution, dict with data (list of results) and header.
            ``header`` is None (and ``data`` empty) when no statement returns rows.
        """
        results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
        # Default avoids leaking a bare StopIteration for DDL/SET-only input.
        header = next(results_iter, None)
        results = {"data": list(results_iter), "header": header}
        return results

    def to_csv(
        self,
        sql: str,
        csv_filepath: str,
        schema: str = "default",
        delimiter: str = ",",
        lineterminator: str = "\r\n",
        output_header: bool = True,
        fetch_size: int = 1000,
        hive_conf: dict[Any, Any] | None = None,
    ) -> None:
        """
        Execute hql in target schema and write results to a csv file.

        :param sql: hql to be executed.
        :param csv_filepath: filepath of csv to write results into.
        :param schema: target schema, default to 'default'.
        :param delimiter: delimiter of the csv file, default to ','.
        :param lineterminator: lineterminator of the csv file.
        :param output_header: header of the csv file, default to True.
            No header is written when the query returns no result set.
        :param fetch_size: number of result rows to write into the csv file, default to 1000.
            Also the interval at which progress is logged.
        :param hive_conf: hive_conf to execute alone with the hql.
        :raises ValueError: if statements produce different descriptions; the
            partially written file is removed first.
        """
        results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
        header = next(results_iter, None)
        error: ValueError | None = None
        i = 0
        with open(csv_filepath, "w", encoding="utf-8") as file:
            writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator)
            try:
                if output_header and header is not None:
                    self.log.debug("Cursor description is %s", header)
                    writer.writerow([c[0] for c in header])

                for i, row in enumerate(results_iter, 1):
                    writer.writerow(row)
                    # Guard against fetch_size=0 (the cursor itself falls back to 1000).
                    if fetch_size and i % fetch_size == 0:
                        self.log.info("Written %s rows so far.", i)
            except ValueError as exception:
                error = exception

        if error is not None:
            # need to clean up the file first (it is closed once the ``with`` exits)
            os.remove(csv_filepath)
            raise error

        self.log.info("Done. Loaded a total of %s rows.", i)

    def get_records(
        self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
    ) -> Any:
        """
        Get a set of records from a Hive query; optionally pass a 'schema' kwarg to specify target schema.

        :param sql: hql to be executed.
        :param parameters: optional configuration passed to get_results (as ``hive_conf``)
        :return: result of hive execution

        >>> hh = HiveServer2Hook()
        >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
        >>> len(hh.get_records(sql))
        100
        """
        schema = kwargs.get("schema", "default")
        return self.get_results(sql, schema=schema, hive_conf=parameters)["data"]

    def get_pandas_df(  # type: ignore
        self,
        sql: str,
        schema: str = "default",
        hive_conf: dict[Any, Any] | None = None,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Get a pandas dataframe from a Hive query.

        :param sql: hql to be executed.
        :param schema: target schema, default to 'default'.
        :param hive_conf: hive_conf to execute alone with the hql.
        :param kwargs: (optional) passed into pandas.DataFrame constructor
        :return: result of hive execution as a pandas.DataFrame (empty when no
            statement returns rows)

        >>> hh = HiveServer2Hook()
        >>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
        >>> df = hh.get_pandas_df(sql)
        >>> len(df.index)
        100
        """
        try:
            import pandas as pd
        except ImportError as e:
            from airflow.exceptions import AirflowOptionalProviderFeatureException

            raise AirflowOptionalProviderFeatureException(e)

        res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
        header = res["header"]
        columns = [c[0] for c in header] if header is not None else None
        df = pd.DataFrame(res["data"], columns=columns, **kwargs)
        return df

# Was this entry helpful?