Source code for airflow.providers.amazon.aws.hooks.redshift_data

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import time
from collections.abc import Iterable
from dataclasses import dataclass
from pprint import pformat
from typing import TYPE_CHECKING, Any
from uuid import UUID

from pendulum import duration

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import trim_none_values

if TYPE_CHECKING:
    from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient  # noqa: F401
    from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef

FINISHED_STATE = "FINISHED"
FAILED_STATE = "FAILED"
ABORTED_STATE = "ABORTED"
FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}


@dataclass
class QueryExecutionOutput:
    """Describes the output of a query execution."""

    statement_id: str
    session_id: str | None


class RedshiftDataQueryFailedError(ValueError):
    """Raised when an Amazon Redshift Data API query fails."""


class RedshiftDataQueryAbortedError(ValueError):
    """Raised when an Amazon Redshift Data API query is aborted."""


class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
    """
    Interact with Amazon Redshift Data API.

    Provide thin wrapper around
    :external+boto3:py:class:`boto3.client("redshift-data") <RedshiftDataAPIService.Client>`.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
        - `Amazon Redshift Data API \
          <https://docs.aws.amazon.com/redshift-data/latest/APIReference/Welcome.html>`__
    """

    def __init__(self, *args, **kwargs) -> None:
        kwargs["client_type"] = "redshift-data"
        super().__init__(*args, **kwargs)
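
    # A minimal usage sketch (the connection ID and region below are
    # assumptions, not values required by the hook):
    #
    #     hook = RedshiftDataHook(aws_conn_id="aws_default", region_name="us-east-1")
    #     client = hook.conn  # underlying boto3 "redshift-data" client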

    def execute_query(
        self,
        sql: str | list[str],
        database: str | None = None,
        cluster_identifier: str | None = None,
        db_user: str | None = None,
        parameters: Iterable | None = None,
        secret_arn: str | None = None,
        statement_name: str | None = None,
        with_event: bool = False,
        wait_for_completion: bool = True,
        poll_interval: int = 10,
        workgroup_name: str | None = None,
        session_id: str | None = None,
        session_keep_alive_seconds: int | None = None,
    ) -> QueryExecutionOutput:
        """
        Execute a statement against Amazon Redshift.

        :param sql: the SQL statement or list of SQL statements to run
        :param database: the name of the database
        :param cluster_identifier: unique identifier of a cluster
        :param db_user: the database username
        :param parameters: the parameters for the SQL statement
        :param secret_arn: the name or ARN of the secret that enables db access
        :param statement_name: the name of the SQL statement
        :param with_event: whether to send an event to EventBridge
        :param wait_for_completion: whether to wait for a result
        :param poll_interval: how often in seconds to check the query status
        :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
            ``cluster_identifier``. Specify this parameter to query Redshift Serverless. More info
            https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
        :param session_id: the session identifier of the query
        :param session_keep_alive_seconds: duration in seconds to keep the session alive after the
            query finishes. The maximum session lifetime is 24 hours
        :return: a :class:`QueryExecutionOutput` containing the statement ID and, when a session
            is used, the session ID
        """
        kwargs: dict[str, Any] = {
            "ClusterIdentifier": cluster_identifier,
            "Database": database,
            "DbUser": db_user,
            "Parameters": parameters,
            "WithEvent": with_event,
            "SecretArn": secret_arn,
            "StatementName": statement_name,
            "WorkgroupName": workgroup_name,
            "SessionId": session_id,
            "SessionKeepAliveSeconds": session_keep_alive_seconds,
        }
        if sum(x is not None for x in (cluster_identifier, workgroup_name, session_id)) != 1:
            raise ValueError(
                "Exactly one of cluster_identifier, workgroup_name, or session_id must be provided"
            )

        if session_id is not None:
            msg = "session_id must be a valid UUID4"
            try:
                if UUID(session_id).version != 4:
                    raise ValueError(msg)
            except ValueError:
                raise ValueError(msg)

        if session_keep_alive_seconds is not None and (
            session_keep_alive_seconds < 0 or duration(seconds=session_keep_alive_seconds).hours > 24
        ):
            raise ValueError("Session keep alive duration must be between 0 and 86400 seconds.")

        if isinstance(sql, list):
            kwargs["Sqls"] = sql
            resp = self.conn.batch_execute_statement(**trim_none_values(kwargs))
        else:
            kwargs["Sql"] = sql
            resp = self.conn.execute_statement(**trim_none_values(kwargs))

        statement_id = resp["Id"]

        if wait_for_completion:
            self.wait_for_results(statement_id, poll_interval=poll_interval)

        return QueryExecutionOutput(statement_id=statement_id, session_id=resp.get("SessionId"))
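
    # Example call (cluster, database, and SQL are hypothetical; a sketch of
    # the common synchronous path, not the only way to invoke it):
    #
    #     out = hook.execute_query(
    #         sql="SELECT 1",
    #         database="dev",
    #         cluster_identifier="my-cluster",
    #         wait_for_completion=True,
    #     )
    #     out.statement_id  # UUID of the submitted statement
    #     out.session_id    # None unless a session was requested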

    def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
        """Block until the statement reaches a terminal state, polling every ``poll_interval`` seconds."""
        while True:
            self.log.info("Polling statement %s", statement_id)
            is_finished = self.check_query_is_finished(statement_id)
            if is_finished:
                return FINISHED_STATE
            time.sleep(poll_interval)

    def check_query_is_finished(self, statement_id: str) -> bool:
        """Check whether the query has finished; raise an exception if it failed."""
        resp = self.conn.describe_statement(Id=statement_id)
        return self.parse_statement_response(resp)

    def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
        """Parse the response of describe_statement."""
        status = resp["Status"]
        if status == FINISHED_STATE:
            num_rows = resp.get("ResultRows")
            if num_rows is not None:
                self.log.info("Processed %s rows", num_rows)
            return True
        elif status in FAILURE_STATES:
            exception_cls = (
                RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
            )
            raise exception_cls(
                f"Statement {resp['Id']} terminated with status {status}. "
                f"Response details: {pformat(resp)}"
            )

        self.log.info("Query status: %s", status)
        return False
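
    # Callers that need to distinguish the two failure modes can catch the
    # specific exceptions raised above (statement_id is illustrative):
    #
    #     try:
    #         hook.check_query_is_finished(statement_id)
    #     except RedshiftDataQueryFailedError:
    #         ...  # statement ended in FAILED
    #     except RedshiftDataQueryAbortedError:
    #         ...  # statement ended in ABORTED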

    def get_table_primary_key(
        self,
        table: str,
        database: str,
        schema: str | None = "public",
        cluster_identifier: str | None = None,
        workgroup_name: str | None = None,
        db_user: str | None = None,
        secret_arn: str | None = None,
        statement_name: str | None = None,
        with_event: bool = False,
        wait_for_completion: bool = True,
        poll_interval: int = 10,
    ) -> list[str] | None:
        """
        Return the table primary key.

        Copied from ``RedshiftSQLHook.get_table_primary_key()``

        :param table: Name of the target table
        :param database: the name of the database
        :param schema: Name of the target schema, public by default
        :param cluster_identifier: unique identifier of a cluster
        :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
            ``cluster_identifier``. Specify this parameter to query Redshift Serverless. More info
            https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
        :param db_user: the database username
        :param secret_arn: the name or ARN of the secret that enables db access
        :param statement_name: the name of the SQL statement
        :param with_event: indicates whether to send an event to EventBridge
        :param wait_for_completion: indicates whether to wait for the query to finish before returning
        :param poll_interval: how often in seconds to check the query status
        :return: Primary key columns list
        """
        sql = f"""
            select kcu.column_name
            from information_schema.table_constraints tco
            join information_schema.key_column_usage kcu
                on kcu.constraint_name = tco.constraint_name
                and kcu.constraint_schema = tco.constraint_schema
            where tco.constraint_type = 'PRIMARY KEY'
                and kcu.table_schema = '{schema}'
                and kcu.table_name = '{table}'
        """
        stmt_id = self.execute_query(
            sql=sql,
            database=database,
            cluster_identifier=cluster_identifier,
            workgroup_name=workgroup_name,
            db_user=db_user,
            secret_arn=secret_arn,
            statement_name=statement_name,
            with_event=with_event,
            wait_for_completion=wait_for_completion,
            poll_interval=poll_interval,
        ).statement_id

        pk_columns = []
        token = ""
        while True:
            kwargs: dict[str, Any] = {"Id": stmt_id}
            if token:
                kwargs["NextToken"] = token
            response = self.conn.get_statement_result(**kwargs)
            # we only select a single string column, so it is safe to assume
            # that each record contains exactly one field
            pk_columns += [y["stringValue"] for x in response["Records"] for y in x]
            if "NextToken" in response:
                token = response["NextToken"]
            else:
                break

        return pk_columns or None
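
    # Example call (schema/table names are hypothetical; a sketch under the
    # same connection assumptions as above):
    #
    #     pk = hook.get_table_primary_key(
    #         table="orders",
    #         database="dev",
    #         schema="public",
    #         cluster_identifier="my-cluster",
    #     )
    #     # e.g. ["order_id"], or None when the table has no primary key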

    async def is_still_running(self, statement_id: str) -> bool:
        """
        Async function to check whether the query is still running.

        :param statement_id: the UUID of the statement
        """
        async with self.async_conn as client:
            desc = await client.describe_statement(Id=statement_id)
            return desc["Status"] in RUNNING_STATES

    async def check_query_is_finished_async(self, statement_id: str) -> bool:
        """
        Async function to check whether the statement has finished.

        It takes a statement_id, makes an async connection to Redshift Data to
        fetch the query status, and returns whether the query reached
        ``FINISHED`` (raising on failure states).

        :param statement_id: the UUID of the statement
        """
        async with self.async_conn as client:
            resp = await client.describe_statement(Id=statement_id)
            return self.parse_statement_response(resp)
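

# A minimal async polling sketch built on the two coroutines above (the sleep
# interval is an assumption; deferrable operators/triggers do something similar):
#
#     import asyncio
#
#     async def wait_until_done(hook: RedshiftDataHook, statement_id: str) -> bool:
#         while await hook.is_still_running(statement_id):
#             await asyncio.sleep(10)
#         return await hook.check_query_is_finished_async(statement_id)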
