#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import os
from collections.abc import Iterable, Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from urllib.parse import quote_plus, urlencode
import trino
from deprecated import deprecated
from trino.exceptions import DatabaseError
from trino.transaction import IsolationLevel
from airflow.configuration import conf
from airflow.exceptions import (
    AirflowException,
    AirflowOptionalProviderFeatureException,
    AirflowProviderDeprecationWarning,
)
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.trino.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils.helpers import exactly_one
if AIRFLOW_V_3_0_PLUS:
    from airflow.sdk.execution_time.context import AIRFLOW_VAR_NAME_FORMAT_MAPPING, DEFAULT_FORMAT_PREFIX
else:
    from airflow.utils.operator_helpers import (  # type: ignore[no-redef, attr-defined]
        AIRFLOW_VAR_NAME_FORMAT_MAPPING,
        DEFAULT_FORMAT_PREFIX,
    )
if TYPE_CHECKING:
    from airflow.models import Connection
def generate_trino_client_info() -> str:
    """
    Build the JSON document that identifies the running Airflow task to Trino.

    Every entry of ``AIRFLOW_VAR_NAME_FORMAT_MAPPING`` is looked up in the process
    environment through its ``env_var_format`` name; a variable that is not set
    yields an empty string. The values are keyed by the entry's ``default`` name
    with ``DEFAULT_FORMAT_PREFIX`` removed.

    :return: A JSON object (keys sorted) with ``dag_id``, ``task_id``,
        ``try_number``, ``dag_run_id``, ``dag_owner`` and the run date, which is
        stored under ``logical_date`` when ``AIRFLOW_V_3_0_PLUS`` is true and under
        ``execution_date`` otherwise.
    """
    env_values = {}
    for spec in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values():
        short_name = spec["default"].replace(DEFAULT_FORMAT_PREFIX, "")
        env_values[short_name] = os.environ.get(spec["env_var_format"], "")

    # The name of the date field depends on the Airflow major version.
    if AIRFLOW_V_3_0_PLUS:
        date_key = "logical_date"
    else:
        date_key = "execution_date"

    reported_keys = ("dag_id", "task_id", date_key, "try_number", "dag_run_id", "dag_owner")
    payload = {name: env_values[name] for name in reported_keys}
    return json.dumps(payload, sort_keys=True)
class TrinoException(Exception):
    """
    Error raised by the Trino hook in place of the client's ``DatabaseError``.

    The query helpers of ``TrinoHook`` catch ``trino.exceptions.DatabaseError``
    and re-raise it wrapped in this exception, passing the original error as the
    only argument.
    """
def _boolify(value):
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        if value.lower() == "false":
            return False
        if value.lower() == "true":
            return True
    return value
class TrinoHook(DbApiHook):
    """
    Interact with Trino through trino package.

    The hook builds a ``trino.dbapi`` connection from an Airflow connection:
    host, port, login, password and schema come from the connection itself,
    while the authentication type and further client options are read from the
    connection's JSON extras (see ``get_conn``). ``DatabaseError`` raised by the
    client in the query helpers is re-raised as ``TrinoException``.

    >>> ph = TrinoHook()
    >>> sql = "SELECT count(1) AS num FROM airflow.static_babynames"
    >>> ph.get_records(sql)
    [[340698]]
    """
    # Name under which the connection id is stored on the hook; presumably read
    # by DbApiHook when resolving the connection (confirm against the base class).
    conn_name_attr = "trino_conn_id"
    # Connection id used when the caller does not provide one.
    default_conn_name = "trino_default"
    # Trivial statement; presumably what the base class runs to test the
    # connection (confirm against DbApiHook).
    _test_connection_sql = "select 1"
    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """
        Return custom field behaviour.

        No field is hidden or relabelled. The ``extra`` field gets a JSON
        placeholder (indented by one space) that lists the extras understood by
        the hook together with a short description of each, and the ``login``
        field gets a hint that it is the effective user of the connection.

        :return: Mapping with the keys ``hidden_fields``, ``relabeling`` and
            ``placeholders``.
        """
        extra_placeholder = {
            "auth": "authentication type",
            "impersonate_as_owner": "allow impersonate as owner",
            "jwt__token": "JWT token",
            "jwt__file": "JWT file path",
            "certs__client_cert_path": "Client certificate path",
            "certs__client_key_path": "Client key path",
            "kerberos__config": "Kerberos config",
            "kerberos__service_name": "Kerberos service name",
            "kerberos__mutual_authentication": "Kerberos mutual authentication",
            "kerberos__force_preemptive": "Kerberos force preemptive",
            "kerberos__hostname_override": "Kerberos hostname override",
            "kerberos__sanitize_mutual_error_response": "Kerberos sanitize mutual error response",
            "kerberos__principal": "Kerberos principal",
            "kerberos__delegate": "Kerberos delegate",
            "kerberos__ca_bundle": "Kerberos CA bundle",
            "session_properties": "session properties",
            "client_tags": "Trino client tags. Example ['sales','cluster1']",
            "timezone": "Trino timezone",
        }
        return {
            "hidden_fields": [],
            "relabeling": {},
            "placeholders": {
                "extra": json.dumps(extra_placeholder, indent=1),
                "login": "Effective user for connection",
            },
        }
    def __init__(self, *args, **kwargs):
        """
        Create the hook, forwarding all arguments unchanged to ``DbApiHook``.

        :param args: Positional arguments for the base class constructor.
        :param kwargs: Keyword arguments for the base class constructor.
        """
        super().__init__(*args, **kwargs)
        # Use "?" as the SQL parameter placeholder; presumably consumed by the
        # base class when it generates parameterised statements (confirm there).
        self._placeholder: str = "?"
    def get_conn(self) -> Connection:
        """
        Return a connection object.

        The Trino DB-API connection is configured from the Airflow connection
        identified by ``get_conn_id()``. Authentication is selected as follows:

        * a password on the connection selects ``trino.auth.BasicAuthentication``;
        * extra ``auth == "jwt"`` selects ``trino.auth.JWTAuthentication``; the
          token is taken from ``jwt__token`` or read from the file named by
          ``jwt__file`` (surrounding whitespace in the file is stripped);
        * extra ``auth == "certs"`` selects ``trino.auth.CertificateAuthentication``;
        * extra ``auth == "kerberos"`` selects ``trino.auth.KerberosAuthentication``,
          with boolean-like string extras converted by ``_boolify``;
        * otherwise no authentication object is passed.

        When the extra ``impersonate_as_owner`` is true, the user is taken from
        the ``AIRFLOW_CTX_DAG_OWNER`` environment variable, falling back to the
        connection login when that variable is not set. Task identification is
        sent in the ``X-Trino-Client-Info`` HTTP header.

        NOTE(review): the return annotation resolves to the Airflow
        ``Connection`` model imported under ``TYPE_CHECKING``, but the value
        returned is whatever ``trino.dbapi.connect`` produces.

        :return: The connection created by ``trino.dbapi.connect``.
        :raises AirflowException: If a password is set together with an ``auth``
            extra of ``jwt``, ``certs`` or ``kerberos``.
        :raises ValueError: If ``auth`` is ``jwt`` and not exactly one of
            ``jwt__file`` / ``jwt__token`` is present in the extras.
        """
        db = self.get_connection(self.get_conn_id())
        extra = db.extra_dejson
        auth_type = extra.get("auth")
        auth = None
        user = db.login

        # Gather every configured mechanism first so that an ambiguous setup is
        # rejected instead of silently preferring the password.
        auth_methods = []
        if db.password:
            auth_methods.append("password")
        if auth_type in ("jwt", "certs", "kerberos"):
            auth_methods.append(auth_type)
        if len(auth_methods) > 1:
            raise AirflowException(
                f"Multiple authentication methods specified: {', '.join(auth_methods)}. Only one is allowed."
            )

        if db.password:
            auth = trino.auth.BasicAuthentication(db.login, db.password)
        elif auth_type == "jwt":
            jwt_file = "jwt__file" in extra
            jwt_token = "jwt__token" in extra
            if not exactly_one(jwt_file, jwt_token):
                msg = (
                    "When auth set to 'jwt' then expected exactly one parameter 'jwt__file' or 'jwt__token'"
                    " in connection extra, but "
                )
                if jwt_file and jwt_token:
                    msg += "provided both."
                else:
                    msg += "none of them provided."
                raise ValueError(msg)
            if jwt_file:
                # Token files usually end with a newline. It is not part of the
                # token and is not valid inside an HTTP header value, so drop
                # any surrounding whitespace.
                token = Path(extra["jwt__file"]).read_text().strip()
            else:
                token = extra["jwt__token"]
            auth = trino.auth.JWTAuthentication(token=token)
        elif auth_type == "certs":
            auth = trino.auth.CertificateAuthentication(
                extra.get("certs__client_cert_path"),
                extra.get("certs__client_key_path"),
            )
        elif auth_type == "kerberos":
            auth = trino.auth.KerberosAuthentication(
                config=extra.get("kerberos__config", os.environ.get("KRB5_CONFIG")),
                service_name=extra.get("kerberos__service_name"),
                mutual_authentication=_boolify(extra.get("kerberos__mutual_authentication", False)),
                force_preemptive=_boolify(extra.get("kerberos__force_preemptive", False)),
                hostname_override=extra.get("kerberos__hostname_override"),
                sanitize_mutual_error_response=_boolify(
                    extra.get("kerberos__sanitize_mutual_error_response", True)
                ),
                principal=extra.get("kerberos__principal", conf.get("kerberos", "principal")),
                delegate=_boolify(extra.get("kerberos__delegate", False)),
                ca_bundle=extra.get("kerberos__ca_bundle"),
            )

        if _boolify(extra.get("impersonate_as_owner", False)):
            # Run as the DAG owner when known, otherwise keep the connection login.
            user = os.getenv("AIRFLOW_CTX_DAG_OWNER", None)
            if user is None:
                user = db.login

        http_headers = {"X-Trino-Client-Info": generate_trino_client_info()}
        trino_conn = trino.dbapi.connect(
            host=db.host,
            port=db.port,
            user=user,
            source=extra.get("source", "airflow"),
            http_scheme=extra.get("protocol", "http"),
            http_headers=http_headers,
            catalog=extra.get("catalog", "hive"),
            schema=db.schema,
            auth=auth,
            isolation_level=self.get_isolation_level(),
            verify=_boolify(extra.get("verify", True)),
            # Empty values in the extras are normalised to None.
            session_properties=extra.get("session_properties") or None,
            client_tags=extra.get("client_tags") or None,
            timezone=extra.get("timezone") or None,
            extra_credential=extra.get("extra_credential") or None,
            roles=extra.get("roles") or None,
        )
        return trino_conn
    def get_isolation_level(self) -> Any:
        """
        Return an isolation level.

        The level is named by the ``isolation_level`` connection extra and is
        matched case-insensitively against the attributes of ``IsolationLevel``.

        :return: The matching ``IsolationLevel`` attribute, or
            ``IsolationLevel.AUTOCOMMIT`` when the extra is missing, empty, null
            or does not name a known level.
        """
        db = self.get_connection(self.get_conn_id())
        # ``or`` also covers an explicit null / empty value in the extras, which
        # would otherwise fail on ``.upper()``; ``str`` guards non-string values.
        configured = db.extra_dejson.get("isolation_level") or "AUTOCOMMIT"
        return getattr(IsolationLevel, str(configured).upper(), IsolationLevel.AUTOCOMMIT)
    def get_records(
        self,
        sql: str | list[str] = "",
        parameters: Iterable | Mapping[str, Any] | None = None,
    ) -> Any:
        """
        Execute a single statement and return its records.

        The statement is passed through ``strip_sql_string`` and then handed to
        the base class implementation.

        :param sql: The statement to execute; only a single string is accepted.
        :param parameters: The parameters to render the statement with.
        :return: Whatever the base class ``get_records`` returns.
        :raises ValueError: If ``sql`` is not a string.
        :raises TrinoException: If the Trino client raises ``DatabaseError``;
            the original error is attached as the cause.
        """
        if not isinstance(sql, str):
            raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!")
        try:
            return super().get_records(self.strip_sql_string(sql), parameters)
        except DatabaseError as e:
            # Chain explicitly so the client error is reported as the direct cause.
            raise TrinoException(e) from e
    def get_first(
        self, sql: str | list[str] = "", parameters: Iterable | Mapping[str, Any] | None = None
    ) -> Any:
        """
        Execute a single statement and return the first result.

        The statement is passed through ``strip_sql_string`` and then handed to
        the base class implementation.

        :param sql: The statement to execute; only a single string is accepted.
        :param parameters: The parameters to render the statement with.
        :return: Whatever the base class ``get_first`` returns.
        :raises ValueError: If ``sql`` is not a string.
        :raises TrinoException: If the Trino client raises ``DatabaseError``;
            the original error is attached as the cause.
        """
        if not isinstance(sql, str):
            raise ValueError(f"The sql in Trino Hook must be a string and is {sql}!")
        try:
            return super().get_first(self.strip_sql_string(sql), parameters)
        except DatabaseError as e:
            # Chain explicitly so the client error is reported as the direct cause.
            raise TrinoException(e) from e
    def _get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
        """
        Execute a statement and return the result as a pandas DataFrame.

        :param sql: The statement to execute; it is passed through
            ``strip_sql_string`` first.
        :param parameters: The parameters passed to ``cursor.execute``.
        :param kwargs: Extra keyword arguments forwarded to ``pandas.DataFrame``.
        :return: A DataFrame whose columns are renamed to the first element of
            each ``cursor.description`` entry. When the query returns no rows the
            DataFrame is built from ``kwargs`` alone.
        :raises AirflowOptionalProviderFeatureException: If pandas is not installed.
        :raises TrinoException: If the Trino client raises ``DatabaseError``.
        """
        try:
            import pandas as pd
        except ImportError as e:
            raise AirflowOptionalProviderFeatureException(
                "Pandas is not installed. Please install it with `pip install pandas`."
            ) from e

        cursor = self.get_cursor()
        try:
            cursor.execute(self.strip_sql_string(sql), parameters)
            data = cursor.fetchall()
        except DatabaseError as e:
            # Chain explicitly so the client error is reported as the direct cause.
            raise TrinoException(e) from e

        column_descriptions = cursor.description
        if data:
            df = pd.DataFrame(data, **kwargs)
            # Map the frame's current labels onto the names reported by the cursor.
            df.rename(columns={n: c[0] for n, c in zip(df.columns, column_descriptions)}, inplace=True)
        else:
            df = pd.DataFrame(**kwargs)
        return df
    def _get_polars_df(self, sql: str = "", parameters=None, **kwargs):
        """
        Execute a statement and return the result as a polars DataFrame.

        :param sql: The statement to execute; it is passed through
            ``strip_sql_string`` first.
        :param parameters: The parameters passed to ``cursor.execute``.
        :param kwargs: Extra keyword arguments forwarded to ``polars.DataFrame``.
        :return: A row-oriented DataFrame whose schema is the first element of
            each ``cursor.description`` entry. When the query returns no rows the
            DataFrame is built from ``kwargs`` alone.
        :raises AirflowOptionalProviderFeatureException: If polars is not installed.
        :raises TrinoException: If the Trino client raises ``DatabaseError``.
        """
        try:
            import polars as pl
        except ImportError as e:
            raise AirflowOptionalProviderFeatureException(
                "Polars is not installed. Please install it with `pip install polars`."
            ) from e

        cursor = self.get_cursor()
        try:
            cursor.execute(self.strip_sql_string(sql), parameters)
            data = cursor.fetchall()
        except DatabaseError as e:
            # Chain explicitly so the client error is reported as the direct cause.
            raise TrinoException(e) from e

        column_descriptions = cursor.description
        if data:
            df = pl.DataFrame(
                data,
                schema=[c[0] for c in column_descriptions],
                orient="row",
                **kwargs,
            )
        else:
            df = pl.DataFrame(**kwargs)
        return df
    @deprecated(
        reason="Replaced by function `get_df`.",
        category=AirflowProviderDeprecationWarning,
        action="ignore",
    )
    def get_pandas_df(self, sql: str = "", parameters=None, **kwargs):
        """
        Execute a statement and return the result as a pandas DataFrame.

        Deprecated: replaced by ``get_df``. This method only delegates to
        ``_get_pandas_df``.

        :param sql: The statement to execute.
        :param parameters: The parameters to render the statement with.
        :param kwargs: Extra keyword arguments forwarded to ``_get_pandas_df``.
        :return: The DataFrame produced by ``_get_pandas_df``.
        """
        return self._get_pandas_df(sql, parameters, **kwargs)
    def insert_rows(
        self,
        table: str,
        rows: Iterable[tuple],
        target_fields: Iterable[str] | None = None,
        commit_every: int = 0,
        replace: bool = False,
        **kwargs,
    ) -> None:
        """
        Insert a set of tuples into a table in a generic way.

        When the connection's isolation level is ``AUTOCOMMIT`` the requested
        ``commit_every`` is ignored and all rows are inserted in one go, because
        transactions are not enabled on such a connection.

        :param table: Name of the target table
        :param rows: The rows to insert into the table
        :param target_fields: The names of the columns to fill in the table
        :param commit_every: The maximum number of rows to insert in one
            transaction. Set to 0 to insert all rows in one transaction.
        :param replace: Whether to replace instead of insert
        :param kwargs: Additional keyword arguments, forwarded unchanged to the
            base class ``insert_rows``.
        """
        if self.get_isolation_level() == IsolationLevel.AUTOCOMMIT:
            self.log.info(
                "Transactions are not enable in trino connection. "
                "Please use the isolation_level property to enable it. "
                "Falling back to insert all rows in one transaction."
            )
            commit_every = 0

        # Forward the extra options instead of silently discarding them.
        super().insert_rows(table, rows, target_fields, commit_every, replace, **kwargs)
    @staticmethod
    def _serialize_cell(cell: Any, conn: Connection | None = None) -> Any:
        """
        Return the cell unchanged.

        Trino will adapt all execute() args internally, hence we return cell without any conversion.

        :param cell: The cell to insert into the table
        :param conn: The database connection; not used by this implementation
        :return: The cell, exactly as it was passed in
        """
        return cell
    def get_openlineage_database_info(self, connection):
        """
        Return Trino specific information for OpenLineage.

        :param connection: The Airflow connection the lineage is collected for.
        :return: A ``DatabaseInfo`` with scheme ``trino``, the authority derived
            from the connection (``trino.constants.DEFAULT_PORT`` as the default
            port), the information-schema columns to read, and the database taken
            from the ``catalog`` connection extra (``hive`` when absent).
        """
        # Imported lazily so the OpenLineage provider is only needed when used.
        from airflow.providers.openlineage.sqlparser import DatabaseInfo

        authority = DbApiHook.get_openlineage_authority_part(
            connection, default_port=trino.constants.DEFAULT_PORT
        )
        schema_columns = [
            "table_schema",
            "table_name",
            "column_name",
            "ordinal_position",
            "data_type",
            "table_catalog",
        ]
        catalog = connection.extra_dejson.get("catalog", "hive")
        return DatabaseInfo(
            scheme="trino",
            authority=authority,
            information_schema_columns=schema_columns,
            database=catalog,
            is_information_schema_cross_db=True,
        )
    def get_openlineage_database_dialect(self, _):
        """
        Return Trino dialect.

        :param _: Ignored; accepted only to satisfy the expected signature.
        :return: The constant string ``"trino"``.
        """
        return "trino" 
    def get_openlineage_default_schema(self):
        """
        Return Trino default schema.

        :return: The value of ``trino.constants.DEFAULT_SCHEMA``.
        """
        return trino.constants.DEFAULT_SCHEMA 
    def get_uri(self) -> str:
        """
        Return the Trino URI for the connection.

        The URI has the form
        ``trino://[login[:password]@]host[:port][/schema[/extra-schema]][?extras]``:

        * login, password and both schema segments are escaped with ``quote_plus``;
        * the host defaults to ``localhost`` when the connection has none;
        * the second path segment comes from the ``schema`` extra and is only
          added when the connection schema itself is set;
        * every other extra whose value is not ``None`` is stringified and
          appended as a URL-encoded query parameter.

        :return: The assembled URI string.
        """
        conn = self.connection
        pieces = ["trino://"]

        if conn.login:
            credentials = quote_plus(conn.login)
            if conn.password:
                credentials = f"{credentials}:{quote_plus(conn.password)}"
            pieces.append(f"{credentials}@")

        pieces.append(f"{conn.host or 'localhost'}")
        if conn.port:
            pieces.append(f":{conn.port}")

        if conn.schema:
            pieces.append(f"/{quote_plus(conn.schema)}")
            nested_schema = conn.extra_dejson.get("schema")
            if nested_schema:
                pieces.append(f"/{quote_plus(nested_schema)}")

        # The "schema" extra belongs to the path, so it is kept out of the query.
        query_params = {
            key: str(value)
            for key, value in conn.extra_dejson.items()
            if key != "schema" and value is not None
        }
        if query_params:
            pieces.append(f"?{urlencode(query_params)}")

        return "".join(pieces)