Source code for airflow.providers.common.ai.utils.sql_validation

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL safety validation for LLM-generated queries.

Uses an allowlist approach: only explicitly permitted statement types pass.
This is safer than a denylist because new/unexpected statement types
(INSERT, UPDATE, MERGE, TRUNCATE, COPY, etc.) are blocked by default.
"""

from __future__ import annotations

import sqlglot
from sqlglot import exp
from sqlglot.errors import ErrorLevel

# Allowlist of top-level statement types that validate_sql accepts by default:
#   * exp.Select -- plain queries, including CTE-wrapped ones. ``WITH ... AS ... SELECT``
#     parses to a Select carrying a ``with`` clause, so the top-level node is still a Select.
#   * exp.Union / exp.Intersect / exp.Except -- set operations over SELECT results.
# The order below is the order used when listing allowed types in error messages.
DEFAULT_ALLOWED_TYPES: tuple[type[exp.Expr], ...] = (
    exp.Select,
    exp.Union,
    exp.Intersect,
    exp.Except,
)
# Denylist of expression types that mutate data or schema wherever they appear in
# the AST. Scanning for them catches data-modifying CTEs
# (e.g. ``WITH del AS (DELETE ...) SELECT ...``), ``SELECT ... INTO`` and other
# constructs that slip past a top-level statement-type check.
#
# ``exp.Command`` is sqlglot's catch-all for syntax it does not recognise. Listing
# it makes the denylist fail closed (safer), at the cost of possibly rejecting
# legitimate vendor-specific SQL. Callers who need such syntax can pass
# ``allowed_types`` that include the relevant non-read-only type (for example
# ``exp.Command``) so that the deep scan is skipped.
_DATA_MODIFYING_NODES: tuple[type[exp.Expr], ...] = (
    exp.Insert,
    exp.Update,
    exp.Delete,
    exp.Merge,
    exp.Into,
    exp.Command,
)
class SQLSafetyError(Exception):
    """
    Generated SQL failed safety validation.

    Raised by the validation helpers in this module when input is empty or
    unparsable, contains a disallowed statement type, contains more statements
    than permitted, or hides a data-modifying operation in read-only mode.
    """
def validate_sql(
    sql: str,
    *,
    allowed_types: tuple[type[exp.Expr], ...] | None = None,
    dialect: str | None = None,
    allow_multiple_statements: bool = False,
) -> list[exp.Expr]:
    """
    Parse SQL and verify all statements are in the allowed types list.

    By default, only a single SELECT-family statement is allowed. Multi-statement
    SQL (separated by semicolons) is rejected unless ``allow_multiple_statements=True``,
    because multi-statement inputs can hide dangerous operations after a benign SELECT.

    When every allowed type is a read-only type (the default allowlist, or any
    subset of it such as ``(exp.Select,)``), the full AST of each statement is
    additionally scanned for data-modifying nodes (data-modifying CTEs,
    ``SELECT ... INTO``, etc.). Supplying an allowlist that contains any other
    statement type opts out of that deep scan.

    Returns parsed statements on success, raises :class:`SQLSafetyError` on violation.

    :param sql: SQL string to validate.
    :param allowed_types: Tuple of sqlglot expression types to permit.
        Defaults to ``(Select, Union, Intersect, Except)``. ``None`` or an empty
        tuple selects the default.
    :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
    :param allow_multiple_statements: Whether to allow multiple semicolon-separated
        statements. Default ``False``.
    :return: List of parsed sqlglot Expr objects.
    :raises SQLSafetyError: If the SQL is empty, cannot be tokenized or parsed,
        contains disallowed statement types, has multiple statements when not
        permitted, or (in read-only mode) contains a data-modifying operation.
    """
    if not sql or not sql.strip():
        raise SQLSafetyError("Empty SQL input.")

    # ``or`` rather than ``is None``: an empty tuple also falls back to the default.
    types = allowed_types or DEFAULT_ALLOWED_TYPES

    try:
        statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE)
    except sqlglot.errors.SqlglotError as e:
        # Catch sqlglot's base error, not only ParseError: tokenizer failures
        # (e.g. an unterminated string literal) raise TokenError, which is not a
        # ParseError subclass and would otherwise escape as a non-SQLSafetyError.
        raise SQLSafetyError(f"SQL parse error: {e}") from e

    # sqlglot.parse can return [None] for empty input
    parsed: list[exp.Expr] = [s for s in statements if s is not None]
    if not parsed:
        raise SQLSafetyError("Empty SQL input.")

    if not allow_multiple_statements and len(parsed) > 1:
        raise SQLSafetyError(
            f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default."
        )

    for stmt in parsed:
        if not isinstance(stmt, types):
            allowed_names = ", ".join(t.__name__ for t in types)
            raise SQLSafetyError(
                f"Statement type '{type(stmt).__name__}' is not allowed. "
                f"Allowed types: {allowed_names}"
            )

    # Deep scan: reject data-modifying nodes hidden inside otherwise-allowed statements
    # (e.g. data-modifying CTEs, SELECT INTO). It is skipped only when the caller has
    # opted into a non-read-only statement type. An identity check against
    # DEFAULT_ALLOWED_TYPES would also skip it for an equal tuple or for a *narrower*
    # read-only allowlist such as ``(exp.Select,)``, letting
    # ``WITH d AS (DELETE ...) SELECT * FROM d`` through.
    read_only_mode = all(issubclass(t, DEFAULT_ALLOWED_TYPES) for t in types)
    if read_only_mode:
        _check_for_data_modifying_nodes(parsed)

    return parsed
def _check_for_data_modifying_nodes(statements: list[exp.Expr]) -> None:
    """
    Reject any data-modifying expression found anywhere in the given ASTs.

    Every node of every statement is visited (via ``walk()``), so expressions
    nested inside CTEs or subqueries are checked too. Bypass vectors this catches:

    - Data-modifying CTEs: ``WITH d AS (DELETE FROM t RETURNING *) SELECT * FROM d``
    - SELECT INTO: ``SELECT * INTO new_table FROM t``
    - INSERT/UPDATE/DELETE hidden inside subqueries or CTEs

    :param statements: Parsed statements to inspect, in order.
    :raises SQLSafetyError: On the first node that is an instance of one of
        ``_DATA_MODIFYING_NODES``.
    """
    # Lazily search statement-by-statement, node-by-node; stop at the first hit.
    offending = next(
        (
            node
            for statement in statements
            for node in statement.walk()
            if isinstance(node, _DATA_MODIFYING_NODES)
        ),
        None,
    )
    if offending is None:
        return
    raise SQLSafetyError(
        f"Data-modifying operation '{type(offending).__name__}' found inside statement. "
        f"Only pure read operations are allowed in read-only mode."
    )

Was this entry helpful?