Airflow Summit 2025 is coming October 07-09. Register now to secure your spot!

Source code for airflow.providers.postgres.dialects.postgres

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Callable

from methodtools import lru_cache

from airflow.providers.common.sql.dialects.dialect import Dialect, T


class PostgresDialect(Dialect):
    """Postgres dialect implementation."""

    @property
    def name(self) -> str:
        """Return the dialect identifier, ``"postgresql"``."""
        return "postgresql"

    @lru_cache(maxsize=None)
    def get_primary_keys(self, table: str, schema: str | None = None) -> list[str] | None:
        """
        Get the table's primary key.

        The result is memoized by ``lru_cache``, so repeated calls with the same
        arguments do not query the database again.

        :param table: Name of the target table
        :param schema: Name of the target schema; when omitted it is derived from
            ``table`` via ``extract_schema_from_table``
        :return: Primary key column names ordered by their position in the key,
            or ``None`` when no primary key column was found
        """
        if schema is None:
            table, schema = self.extract_schema_from_table(table)
        # PostgreSQL does not enforce schema-wide uniqueness of constraint names,
        # so the join must also correlate on the table name; otherwise a
        # same-named constraint on another table in the schema could make a
        # non-primary-key constraint of this table look like its primary key.
        pk_columns = [
            row[0]
            for row in self.get_records(
                """
                select kcu.column_name
                from information_schema.table_constraints tco
                        join information_schema.key_column_usage kcu
                            on kcu.constraint_name = tco.constraint_name
                                and kcu.constraint_schema = tco.constraint_schema
                                and kcu.table_name = tco.table_name
                where tco.constraint_type = 'PRIMARY KEY'
                and kcu.table_schema = %s
                and kcu.table_name = %s
                order by kcu.ordinal_position
                """,
                (self.unescape_word(schema), self.unescape_word(table)),
            )
        ]
        # An empty list is reported as None so callers can use a simple truthiness test.
        return pk_columns or None

    @staticmethod
    def _to_row(row):
        """
        Convert one record of the column query in ``get_column_names`` into a dict.

        :param row: Sequence of (column_name, data_type, is_nullable, column_default,
            is_generated, is_identity) as selected by ``get_column_names``
        :return: Dict with the keys ``name``, ``type``, ``nullable``, ``default``,
            ``autoincrement`` and ``identity``; the three flag values are booleans
        """
        return {
            "name": row[0],
            "type": row[1],
            # The flag columns arrive as text; compare them case-insensitively.
            "nullable": row[2].casefold() == "yes",
            "default": row[3],
            "autoincrement": row[4].casefold() == "always",
            "identity": row[5].casefold() == "yes",
        }

    @lru_cache(maxsize=None)
    def get_column_names(
        self, table: str, schema: str | None = None, predicate: Callable[[T], bool] = lambda column: True
    ) -> list[str] | None:
        """
        Get the names of the table's columns, optionally filtered by a predicate.

        The result is memoized by ``lru_cache``.

        :param table: Name of the target table
        :param schema: Name of the target schema; when omitted it is derived from
            ``table`` via ``extract_schema_from_table``
        :param predicate: Callable receiving the column description dict built by
            ``_to_row``; only columns for which it returns a truthy value are kept.
            Accepts every column by default
        :return: Column names in ordinal position order
        """
        if schema is None:
            table, schema = self.extract_schema_from_table(table)
        column_names = [
            row["name"]
            for row in filter(
                predicate,
                map(
                    self._to_row,
                    self.get_records(
                        """
                        select column_name,
                               data_type,
                               is_nullable,
                               column_default,
                               is_generated,
                               is_identity
                        from information_schema.columns
                        where table_schema = %s
                        and table_name = %s
                        order by ordinal_position
                        """,
                        (self.unescape_word(schema), self.unescape_word(table)),
                    ),
                ),
            )
        ]
        self.log.debug("Column names for table '%s': %s", table, column_names)
        return column_names

    def generate_replace_sql(self, table, values, target_fields, **kwargs) -> str:
        """
        Generate the REPLACE SQL statement.

        The statement is the output of ``generate_insert_sql`` with an
        ``ON CONFLICT`` clause appended.

        :param table: Name of the target table
        :param values: The row to insert into the table
        :param target_fields: The names of the columns to fill in the table
        :param kwargs: Forwarded unchanged to ``generate_insert_sql``. The optional
            ``replace_index`` entry (a column name or list of column names) is used
            as the index for the ON CONFLICT clause; when it is missing or empty,
            the table's primary key is used instead
        :return: The generated INSERT or REPLACE SQL statement
        :raises ValueError: If ``target_fields`` is empty, or if neither
            ``replace_index`` nor a primary key is available
        """
        if not target_fields:
            raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires column names")

        replace_index = kwargs.get("replace_index") or self.get_primary_keys(table)

        if not replace_index:
            raise ValueError("PostgreSQL ON CONFLICT upsert syntax requires an unique index")

        # Accept a single column name as shorthand for a one-element list.
        if isinstance(replace_index, str):
            replace_index = [replace_index]

        sql = self.generate_insert_sql(table, values, target_fields, **kwargs)
        on_conflict_str = f" ON CONFLICT ({', '.join(map(self.escape_word, replace_index))})"
        # Only columns that are not part of the conflict index get updated.
        replace_target = [self.escape_word(f) for f in target_fields if f not in replace_index]

        if replace_target:
            replace_target_str = ", ".join(f"{col} = excluded.{col}" for col in replace_target)
            sql += f"{on_conflict_str} DO UPDATE SET {replace_target_str}"
        else:
            # Every target column belongs to the index, so there is nothing to update.
            sql += f"{on_conflict_str} DO NOTHING"

        return sql

Was this entry helpful?