#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""This module contains a Google Cloud Spanner Hook."""
from __future__ import annotations
from collections import OrderedDict
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, NamedTuple
from google.api_core.exceptions import AlreadyExists, GoogleAPICallError
from google.cloud.spanner_v1.client import Client
from sqlalchemy import create_engine
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
from airflow.providers.openlineage.sqlparser import DatabaseInfo
if TYPE_CHECKING:
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1.transaction import Transaction
from google.longrunning.operations_grpc_pb2 import Operation
from airflow.models.connection import Connection
class SpannerConnectionParams(NamedTuple):
    """
    Information about Google Spanner connection parameters.

    The field order is significant: the hook builds and unpacks this tuple
    positionally as ``(project_id, instance_id, database_id)``.
    """

    # ID of the Google Cloud project taken from the connection (may be unset).
    project_id: str | None
    # ID of the Cloud Spanner instance taken from the connection extras.
    instance_id: str | None
    # ID of the Cloud Spanner database taken from the connection extras.
    database_id: str | None
class SpannerHook(GoogleBaseHook, DbApiHook):
    """
    Hook for Google Cloud Spanner APIs.

    All the methods in the hook where project_id is used must be called with
    keyword arguments rather than positional.
    """

    # Class-level connection metadata for this hook.
    conn_name_attr = "gcp_conn_id"
    default_conn_name = "google_cloud_spanner_default"
    conn_type = "gcpspanner"
    hook_name = "Google Cloud Spanner"
def __init__(
    self,
    gcp_conn_id: str = "google_cloud_default",
    impersonation_chain: str | Sequence[str] | None = None,
    **kwargs,
) -> None:
    """
    Initialise the hook and start with no cached Spanner client.

    :param gcp_conn_id: Connection ID forwarded to the parent initializer.
    :param impersonation_chain: Impersonation setting forwarded to the parent initializer.
    :param kwargs: Any further keyword arguments, forwarded unchanged to the parent initializer.
    """
    parent_args = {
        "gcp_conn_id": gcp_conn_id,
        "impersonation_chain": impersonation_chain,
    }
    super().__init__(**parent_args, **kwargs)
    # The client is built on first use by ``_get_client``.
    self._client: Client | None = None
def _get_client(self, project_id: str) -> Client:
    """
    Provide a client for interacting with the Cloud Spanner API.

    The client is cached on the hook. It is rebuilt when a different, non-empty
    ``project_id`` is requested, so a hook that is reused for several projects
    never hands back a client bound to the wrong project.

    :param project_id: The ID of the Google Cloud project.
    :return: Client
    """
    cached = self._client
    # An empty project_id carries no project preference, so any existing client is reused.
    # A cached object without a ``project`` attribute is treated as matching.
    wrong_project = (
        cached is not None
        and bool(project_id)
        and getattr(cached, "project", project_id) != project_id
    )
    if cached is None or wrong_project:
        self._client = Client(
            project=project_id, credentials=self.get_credentials(), client_info=CLIENT_INFO
        )
    return self._client
def _get_conn_params(self) -> SpannerConnectionParams:
    """
    Extract spanner database connection parameters.

    The project ID comes from the connection extras and falls back to the
    hook's own ``project_id`` when the extras do not provide one.

    :return: The ``(project_id, instance_id, database_id)`` parameters.
    """
    connection = self.get_connection(self.gcp_conn_id)
    extras = connection.extra_dejson
    return SpannerConnectionParams(
        get_field(extras, "project_id") or self.project_id,
        get_field(extras, "instance_id"),
        get_field(extras, "database_id"),
    )
def get_uri(self) -> str:
    """
    Override DbApiHook get_uri method for get_sqlalchemy_engine().

    :return: The ``spanner+spanner:///`` URI built from the connection parameters.
    :raises AirflowException: If the instance ID or the database ID is missing.
    """
    project_id, instance_id, database_id = self._get_conn_params()
    if not (instance_id and database_id):
        raise AirflowException("The instance_id or database_id were not specified")
    return f"spanner+spanner:///projects/{project_id}/instances/{instance_id}/databases/{database_id}"
def get_sqlalchemy_engine(self, engine_kwargs=None):
    """
    Get an sqlalchemy_engine object.

    The hook's Spanner client is handed to the engine through ``connect_args``.

    :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
    :return: the created engine.
    """
    extra_kwargs = {} if engine_kwargs is None else engine_kwargs
    project_id, _instance_id, _database_id = self._get_conn_params()
    connect_args = {"client": self._get_client(project_id=project_id)}
    return create_engine(self.get_uri(), connect_args=connect_args, **extra_kwargs)
@GoogleBaseHook.fallback_to_default_project_id
def get_instance(
    self,
    instance_id: str,
    project_id: str,
) -> Instance | None:
    """
    Get information about a particular instance.

    :param project_id: Optional, The ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :param instance_id: The ID of the Cloud Spanner instance.
    :return: Spanner instance, or None when the instance does not exist.
    """
    client = self._get_client(project_id=project_id)
    candidate = client.instance(instance_id=instance_id)
    return candidate if candidate.exists() else None
def _apply_to_instance(
    self,
    project_id: str,
    instance_id: str,
    configuration_name: str,
    node_count: int,
    display_name: str,
    func: Callable[[Instance], Operation],
) -> None:
    """
    Invoke a method on a given instance by applying a specified Callable.

    When the callable returns an operation, its result is waited for and logged.

    :param project_id: The ID of the Google Cloud project that owns the Cloud Spanner database.
    :param instance_id: The ID of the instance.
    :param configuration_name: Name of the instance configuration defining how the
        instance will be created. Required for instances which do not yet exist.
    :param node_count: Number of nodes allocated to the instance.
    :param display_name: The display name for the instance in the Cloud
        Console UI. (Must be between 4 and 30 characters.)
    :param func: Method of the instance to be called.
    :raises GoogleAPICallError: Re-raised, after logging, when the callable fails.
    """
    client = self._get_client(project_id=project_id)
    instance_kwargs = {
        "instance_id": instance_id,
        "configuration_name": configuration_name,
        "node_count": node_count,
        "display_name": display_name,
    }
    target = client.instance(**instance_kwargs)
    try:
        pending: Operation = func(target)
    except GoogleAPICallError as err:
        self.log.error("An error occurred: %s. Exiting.", err.message)
        raise err
    if not pending:
        return
    self.log.info(pending.result())
@GoogleBaseHook.fallback_to_default_project_id
def create_instance(
    self,
    instance_id: str,
    configuration_name: str,
    node_count: int,
    display_name: str,
    project_id: str,
) -> None:
    """
    Create a new Cloud Spanner instance.

    :param instance_id: The ID of the Cloud Spanner instance.
    :param configuration_name: The name of the instance configuration defining how the
        instance will be created. Possible configuration values can be retrieved via
        https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs/list
    :param node_count: The number of nodes allocated to the Cloud Spanner instance.
    :param display_name: The display name for the instance in the Google Cloud Console.
        Must be between 4 and 30 characters.
    :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :return: None
    """

    def _create(instance):
        return instance.create()

    self._apply_to_instance(project_id, instance_id, configuration_name, node_count, display_name, _create)
@GoogleBaseHook.fallback_to_default_project_id
def update_instance(
    self,
    instance_id: str,
    configuration_name: str,
    node_count: int,
    display_name: str,
    project_id: str,
) -> None:
    """
    Update an existing Cloud Spanner instance.

    :param instance_id: The ID of the Cloud Spanner instance.
    :param configuration_name: The name of the instance configuration defining how the
        instance will be created. Possible configuration values can be retrieved via
        https://cloud.google.com/spanner/docs/reference/rest/v1/projects.instanceConfigs/list
    :param node_count: The number of nodes allocated to the Cloud Spanner instance.
    :param display_name: The display name for the instance in the Google Cloud
        Console. Must be between 4 and 30 characters.
    :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :return: None
    """

    def _update(instance):
        return instance.update()

    self._apply_to_instance(project_id, instance_id, configuration_name, node_count, display_name, _update)
@GoogleBaseHook.fallback_to_default_project_id
def delete_instance(self, instance_id: str, project_id: str) -> None:
    """
    Delete an existing Cloud Spanner instance.

    :param instance_id: The ID of the Cloud Spanner instance.
    :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :return: None
    :raises GoogleAPICallError: Re-raised, after logging, when the deletion fails.
    """
    client = self._get_client(project_id=project_id)
    doomed = client.instance(instance_id)
    try:
        doomed.delete()
    except GoogleAPICallError as err:
        self.log.error("An error occurred: %s. Exiting.", err.message)
        raise err
@GoogleBaseHook.fallback_to_default_project_id
def get_database(
    self,
    instance_id: str,
    database_id: str,
    project_id: str,
) -> Database | None:
    """
    Retrieve a database in Cloud Spanner; return None if the database does not exist in the instance.

    :param instance_id: The ID of the Cloud Spanner instance.
    :param database_id: The ID of the database in Cloud Spanner.
    :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :return: Database object or None if database does not exist
    :raises AirflowException: If the instance itself does not exist.
    """
    parent = self._get_client(project_id=project_id).instance(instance_id=instance_id)
    if not parent.exists():
        raise AirflowException(f"The instance {instance_id} does not exist in project {project_id} !")
    found = parent.database(database_id=database_id)
    return found if found.exists() else None
@GoogleBaseHook.fallback_to_default_project_id
def create_database(
    self,
    instance_id: str,
    database_id: str,
    ddl_statements: list[str],
    project_id: str,
) -> None:
    """
    Create a new database in Cloud Spanner.

    When the create call returns an operation, its result is waited for and logged.

    :param instance_id: The ID of the Cloud Spanner instance.
    :param database_id: The ID of the database to create in Cloud Spanner.
    :param ddl_statements: The string list containing DDL for the new database.
    :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :return: None
    :raises AirflowException: If the instance does not exist.
    :raises GoogleAPICallError: Re-raised, after logging, when the create call fails.
    """
    parent = self._get_client(project_id=project_id).instance(instance_id=instance_id)
    if not parent.exists():
        raise AirflowException(f"The instance {instance_id} does not exist in project {project_id} !")
    new_database = parent.database(database_id=database_id, ddl_statements=ddl_statements)
    try:
        pending: Operation = new_database.create()
    except GoogleAPICallError as err:
        self.log.error("An error occurred: %s. Exiting.", err.message)
        raise err
    if not pending:
        return
    self.log.info(pending.result())
@GoogleBaseHook.fallback_to_default_project_id
def update_database(
    self,
    instance_id: str,
    database_id: str,
    ddl_statements: list[str],
    project_id: str,
    operation_id: str | None = None,
) -> None:
    """
    Update DDL of a database in Cloud Spanner.

    An ``AlreadyExists`` error is treated as a harmless replay only when an
    ``operation_id`` was supplied and the error message names it. Any other
    ``AlreadyExists`` is logged and re-raised instead of being dropped.

    :param instance_id: The ID of the Cloud Spanner instance.
    :param database_id: The ID of the database in Cloud Spanner.
    :param ddl_statements: The string list containing DDL for the new database.
    :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :param operation_id: (Optional) The unique per database operation ID that can be
        specified to implement idempotency check.
    :return: None
    :raises AirflowException: If the instance does not exist.
    :raises GoogleAPICallError: Re-raised, after logging, when the update fails.
    """
    instance = self._get_client(project_id=project_id).instance(instance_id=instance_id)
    if not instance.exists():
        raise AirflowException(f"The instance {instance_id} does not exist in project {project_id} !")
    database = instance.database(database_id=database_id)
    try:
        operation = database.update_ddl(ddl_statements=ddl_statements, operation_id=operation_id)
        if operation:
            self.log.info(operation.result())
    except AlreadyExists as e:
        # Guard on operation_id first: ``None in "<str>"`` would raise TypeError.
        if e.code == 409 and operation_id and operation_id in e.message:
            self.log.info(
                "Replayed update_ddl message - the operation id %s was already done before.",
                operation_id,
            )
            return
        # Not a replay of this operation, so the failure must reach the caller.
        self.log.error("An error occurred: %s. Exiting.", e.message)
        raise
    except GoogleAPICallError as e:
        self.log.error("An error occurred: %s. Exiting.", e.message)
        raise
@GoogleBaseHook.fallback_to_default_project_id
def delete_database(self, instance_id: str, database_id, project_id: str) -> bool:
    """
    Drop a database in Cloud Spanner.

    :param instance_id: The ID of the Cloud Spanner instance.
    :param database_id: The ID of the database in Cloud Spanner.
    :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :return: True if the database was dropped, False if it did not exist.
    :raises AirflowException: If the instance does not exist.
    :raises GoogleAPICallError: Re-raised, after logging, when the drop fails.
    """
    parent = self._get_client(project_id=project_id).instance(instance_id=instance_id)
    if not parent.exists():
        raise AirflowException(f"The instance {instance_id} does not exist in project {project_id} !")
    target = parent.database(database_id=database_id)
    if target.exists():
        try:
            target.drop()
        except GoogleAPICallError as err:
            self.log.error("An error occurred: %s. Exiting.", err.message)
            raise err
        return True
    self.log.info("The database %s is already deleted from instance %s. Exiting.", database_id, instance_id)
    return False
@GoogleBaseHook.fallback_to_default_project_id
def execute_dml(
    self,
    instance_id: str,
    database_id: str,
    queries: list[str],
    project_id: str,
) -> list[int]:
    """
    Execute an arbitrary DML query (INSERT, UPDATE, DELETE).

    All statements run inside a single transaction. One row count is kept per
    input statement, in input order, so identical statements that appear more
    than once each contribute their own count.

    :param instance_id: The ID of the Cloud Spanner instance.
    :param database_id: The ID of the database in Cloud Spanner.
    :param queries: The queries to execute.
    :param project_id: Optional, the ID of the Google Cloud project that owns the Cloud Spanner
        database. If set to None or missing, the default project_id from the Google Cloud connection
        is used.
    :return: list of numbers of affected rows by DML query
    """
    database = (
        self._get_client(project_id=project_id)
        .instance(instance_id=instance_id)
        .database(database_id=database_id)
    )

    def _tx_runner(tx: Transaction) -> list[tuple[str, int]]:
        # A list of pairs, not a dict keyed by SQL text, which would collapse repeated statements.
        return [(sql, tx.execute_update(sql)) for sql in queries]

    executed = database.run_in_transaction(_tx_runner)
    total = len(executed)
    affected_rows: list[int] = []
    for position, (sql, row_count) in enumerate(executed, start=1):
        # Statements starting with SELECT are neither logged nor counted.
        if sql.startswith("SELECT"):
            continue
        preview = sql if len(sql) <= 300 else sql[:300] + "…"
        self.log.info("[DML %d/%d] affected rows=%d | %s", position, total, row_count, preview)
        affected_rows.append(row_count)
    return affected_rows
@staticmethod
def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) -> dict[str, int]:
counts: OrderedDict[str, int] = OrderedDict()
for sql in queries:
rc = transaction.execute_update(sql)
counts[sql] = rc
return counts
def _get_openlineage_authority_part(self, connection: Connection) -> str | None:
    """
    Build Spanner-specific authority part for OpenLineage.

    :param connection: Connection whose ``extra_dejson`` holds ``project_id`` and ``instance_id``.
    :return: ``{project}/{instance}``, or None when either value is missing or empty.
    """
    extras = connection.extra_dejson
    project_id, instance_id = (extras.get(key) for key in ("project_id", "instance_id"))
    if project_id and instance_id:
        return f"{project_id}/{instance_id}"
    return None
def get_openlineage_database_dialect(self, connection: Connection) -> str:
    """
    Return database dialect for OpenLineage.

    :param connection: Unused; the dialect name is constant.
    :return: The string ``"spanner"``.
    """
    return "spanner"
def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
    """
    Return Spanner specific information for OpenLineage.

    :param connection: Connection whose ``extra_dejson`` supplies ``database_id``.
    :return: ``DatabaseInfo`` carrying the dialect, the authority part, the database ID
        and the list of information-schema column names.
    """
    database_id = connection.extra_dejson.get("database_id")
    schema_columns = [
        "table_schema",
        "table_name",
        "column_name",
        "ordinal_position",
        "spanner_type",
    ]
    scheme = self.get_openlineage_database_dialect(connection)
    authority = self._get_openlineage_authority_part(connection)
    return DatabaseInfo(
        scheme=scheme,
        authority=authority,
        database=database_id,
        information_schema_columns=schema_columns,
    )
def get_openlineage_default_schema(self) -> str | None:
    """
    Return the default schema for OpenLineage, which is always None for Spanner.

    Spanner exposes a 'public' or '' schema depending on the dialect (Postgres vs GoogleSQL).
    The SQLAlchemy dialect for Spanner does not expose a default schema, so None is returned
    to follow the same approach.
    """
    return None