# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from functools import cached_property
from typing import TYPE_CHECKING, Any
import requests
from pyiceberg.catalog import load_catalog
from airflow.providers.common.compat.sdk import BaseHook
if TYPE_CHECKING:
from pyiceberg.catalog import Catalog
from pyiceberg.table import Table
# Path, relative to the connection host, of the OAuth2 token endpoint that
# ``IcebergHook.get_token`` POSTs client-credentials to
# (joined as ``f"{host}/{TOKENS_ENDPOINT}"``, so no leading slash).
TOKENS_ENDPOINT = "oauth/tokens"
class IcebergHook(BaseHook):
    """
    Hook for Apache Iceberg REST catalogs.

    Provides catalog-level operations (list namespaces, list tables, load schemas)
    using pyiceberg, plus OAuth2 token generation for external query engines.

    :param iceberg_conn_id: The :ref:`Iceberg connection id<howto/connection:iceberg>`
        which refers to the information to connect to the Iceberg catalog.
    """

    conn_name_attr = "iceberg_conn_id"
    default_conn_name = "iceberg_default"

    @classmethod
    def get_ui_field_behaviour(cls) -> dict[str, Any]:
        """Return custom UI field behaviour for Iceberg connection."""
        return {
            "hidden_fields": ["schema", "port"],
            "relabeling": {
                "host": "Catalog URI",
                "login": "Client ID",
                "password": "Client Secret",
            },
            "placeholders": {
                "host": "https://your-catalog.example.com/ws/v1",
                "login": "client_id (OAuth2 credentials)",
                "password": "client_secret (OAuth2 credentials)",
                "extra": '{"warehouse": "s3://my-warehouse/", "s3.region": "us-east-1"}',
            },
        }

    def __init__(self, iceberg_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = iceberg_conn_id

    @staticmethod
    def _normalize_host(host: str | None) -> str:
        """Return ``host`` with trailing slashes stripped, or ``""`` when it is unset/empty."""
        return host.rstrip("/") if host else ""

    @cached_property
    def catalog(self) -> Catalog:
        """
        Return a pyiceberg Catalog instance for the configured connection.

        Catalog properties are built from the connection's ``extra`` JSON first and
        then overridden by the connection fields:

        * ``host`` (trailing ``/`` stripped) becomes ``uri`` when set;
        * ``type`` defaults to ``"rest"`` when ``extra`` does not provide one;
        * for REST catalogs only, ``login``/``password`` become
          ``credential = "<login>:<password>"`` when both are set (a warning is
          logged when only one of them is set).

        The result is cached on the hook instance (``cached_property``), so the
        connection is resolved and the catalog loaded at most once per hook.
        """
        conn = self.get_connection(self.conn_id)
        # Start with extra so connection fields take precedence.
        # Values come from JSON, so they are not guaranteed to be strings.
        catalog_properties: dict[str, Any] = {**(conn.extra_dejson or {})}

        host = self._normalize_host(conn.host)
        if host:
            catalog_properties["uri"] = host

        catalog_properties.setdefault("type", "rest")

        # credential is REST-catalog-specific; other catalogs (Glue, BigQuery)
        # use their own auth fields passed through extra.
        if catalog_properties["type"] == "rest":
            if conn.login and conn.password:
                catalog_properties["credential"] = f"{conn.login}:{conn.password}"
            elif conn.login or conn.password:
                self.log.warning(
                    "Only one of Client ID / Client Secret is set. "
                    "Both are required for OAuth2 credential authentication."
                )

        return load_catalog(self.conn_id, **catalog_properties)

    def get_conn(self) -> Catalog:
        """Return the pyiceberg Catalog (same object as :attr:`catalog`)."""
        return self.catalog

    def test_connection(self) -> tuple[bool, str]:
        """
        Test the Iceberg connection by listing namespaces.

        :return: ``(True, message)`` on success, ``(False, error text)`` on any exception.
        """
        try:
            namespaces = self.catalog.list_namespaces()
            return True, f"Connected. Found {len(namespaces)} namespace(s)."
        except Exception as e:
            # Broad catch is deliberate: this is the UI "Test" boundary and must
            # report any failure as a message rather than raise.
            return False, str(e)

    # ---- Token methods (backward compatibility) ----

    def get_token(self) -> str:
        """
        Obtain a short-lived OAuth2 access token.

        This preserves the legacy behavior of the pre-2.0 ``get_conn()`` method.
        Use this when you need a raw token for external engines (Spark, Trino, Flink).

        Performs a client-credentials POST to ``<host>/<TOKENS_ENDPOINT>`` using the
        connection's login/password as ``client_id``/``client_secret``.

        :return: The ``access_token`` field of the JSON response.
        :raises requests.HTTPError: If the token endpoint returns an error status.
        :raises KeyError: If the response JSON has no ``access_token`` field.
        """
        conn = self.get_connection(self.conn_id)
        base_url = self._normalize_host(conn.host)
        data = {
            "client_id": conn.login,
            "client_secret": conn.password,
            "grant_type": "client_credentials",
        }
        response = requests.post(f"{base_url}/{TOKENS_ENDPOINT}", data=data, timeout=30)
        response.raise_for_status()
        return response.json()["access_token"]

    def get_token_macro(self) -> str:
        """
        Return a Jinja2 macro that resolves to a fresh token at render time.

        For the default connection this yields
        ``{{ conn.iceberg_default.get_hook().get_token() }}``.
        """
        # Quadruple braces in the f-string render as literal ``{{`` / ``}}``.
        return f"{{{{ conn.{self.conn_id}.get_hook().get_token() }}}}"

    # ---- Namespace operations ----

    def list_namespaces(self) -> list[str]:
        """Return the namespaces reported by ``Catalog.list_namespaces()``, dot-joined."""
        return [".".join(ns) for ns in self.catalog.list_namespaces()]

    # ---- Table operations ----

    def list_tables(self, namespace: str) -> list[str]:
        """
        Return all table names in the given namespace.

        :param namespace: Namespace (database/schema) to list tables from.
        :return: List of fully-qualified table names ("namespace.table").
        """
        return [".".join(ident) for ident in self.catalog.list_tables(namespace)]

    def load_table(self, table_name: str) -> Table:
        """
        Load an Iceberg table object.

        :param table_name: Fully-qualified table name ("namespace.table").
        :return: pyiceberg Table instance.
        :raises ValueError: If ``table_name`` contains no ``.`` separator.
        """
        if "." not in table_name:
            raise ValueError(f"Expected fully-qualified table name (namespace.table), got: {table_name!r}")
        return self.catalog.load_table(table_name)

    def table_exists(self, table_name: str) -> bool:
        """Check whether a table exists in the catalog (delegates to ``Catalog.table_exists``)."""
        return self.catalog.table_exists(table_name)

    # ---- Schema introspection ----

    def get_table_schema(self, table_name: str, **kwargs: Any) -> list[dict[str, str]]:
        """
        Return column names and types for an Iceberg table.

        Compatible with the ``DbApiHook.get_table_schema()`` contract so that
        LLM operators can use this hook interchangeably for schema context.

        :param table_name: Fully-qualified table name ("namespace.table").
        :param kwargs: Accepted for signature compatibility; ignored.
        :return: List of dicts with ``name`` and ``type`` keys, in schema field order.

        Example return value::

            [
                {"name": "id", "type": "long"},
                {"name": "name", "type": "string"},
                {"name": "created_at", "type": "timestamptz"},
            ]
        """
        table = self.load_table(table_name)
        return [
            {
                "name": field.name,
                "type": str(field.field_type),
            }
            for field in table.schema().fields
        ]

    def get_partition_spec(self, table_name: str) -> list[dict[str, str]]:
        """
        Return the partition spec for an Iceberg table.

        Each partition field is reported by the name of its *source* column in the
        current schema, together with the string form of its transform.

        :param table_name: Fully-qualified table name.
        :return: List of dicts with ``field`` and ``transform`` keys.

        Example::

            [
                {"field": "event_date", "transform": "day"},
                {"field": "region", "transform": "identity"},
            ]
        """
        table = self.load_table(table_name)
        schema = table.schema()
        return [
            {
                "field": schema.find_field(partition_field.source_id).name,
                "transform": str(partition_field.transform),
            }
            for partition_field in table.spec().fields
        ]

    def get_table_properties(self, table_name: str) -> dict[str, str]:
        """
        Return table properties (format version, write config, etc.).

        :param table_name: Fully-qualified table name.
        :return: A new dict copied from ``table.properties``.
        """
        table = self.load_table(table_name)
        return dict(table.properties)

    def get_snapshots(self, table_name: str, limit: int = 10) -> list[dict[str, Any]]:
        """
        Return recent snapshots for an Iceberg table.

        :param table_name: Fully-qualified table name.
        :param limit: Maximum number of snapshots to return (most recent first).
            Must be non-negative; ``0`` returns an empty list.
        :return: List of dicts with snapshot metadata, one per row of
            ``table.inspect.snapshots()``.
        :raises ValueError: If ``limit`` is negative.
        """
        if limit < 0:
            # Fail fast: a negative length would otherwise be passed to
            # pyarrow's ``Table.slice`` with undefined results.
            raise ValueError(f"limit must be non-negative, got: {limit}")
        table = self.load_table(table_name)
        snapshots = table.inspect.snapshots()
        # NOTE(review): assumes pyiceberg returns snapshots in table-metadata
        # order, oldest first, so the newest ``limit`` rows are the tail.
        start = max(len(snapshots) - limit, 0)
        rows = snapshots.slice(offset=start).to_pylist()
        rows.reverse()
        return rows