Source code for airflow.providers.common.ai.utils.sql_validation

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL safety validation for LLM-generated queries.

Uses an allowlist approach: only explicitly permitted statement types pass.
This is safer than a denylist because new/unexpected statement types
(INSERT, UPDATE, MERGE, TRUNCATE, COPY, etc.) are blocked by default.
"""

from __future__ import annotations

import sqlglot
from sqlglot import exp
from sqlglot.errors import ErrorLevel

# Allowlist: only these top-level statement types pass validation by default.
# - Select: plain queries and CTE-wrapped queries (WITH ... AS ... SELECT is parsed
#   as Select with a `with` clause property — still a Select node at the top level)
# - Union/Intersect/Except: set operations on SELECT results
# Read-only statement kinds accepted when the caller supplies no allowlist:
# plain SELECT plus the three set operations over SELECT results.
DEFAULT_ALLOWED_TYPES: tuple[type[exp.Expression], ...] = (
    exp.Select,
    exp.Union,
    exp.Intersect,
    exp.Except,
)
class SQLSafetyError(Exception):
    """Raised when generated SQL does not pass safety validation."""
def validate_sql(
    sql: str,
    *,
    allowed_types: tuple[type[exp.Expression], ...] | None = None,
    dialect: str | None = None,
    allow_multiple_statements: bool = False,
) -> list[exp.Expression]:
    """
    Parse SQL and verify all statements are in the allowed types list.

    By default, only a single SELECT-family statement is allowed. Multi-statement SQL
    (separated by semicolons) is rejected unless ``allow_multiple_statements=True``,
    because multi-statement inputs can hide dangerous operations after a benign SELECT.

    Returns parsed statements on success, raises :class:`SQLSafetyError` on violation.
    Any sqlglot failure while reading the input (parser errors as well as tokenizer
    errors such as an unterminated string literal) is reported as
    :class:`SQLSafetyError`, so callers only need to handle one exception type.

    :param sql: SQL string to validate.
    :param allowed_types: Tuple of sqlglot expression types to permit.
        Defaults to ``(Select, Union, Intersect, Except)``. An empty tuple is
        treated the same as ``None`` and falls back to the defaults.
    :param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
    :param allow_multiple_statements: Whether to allow multiple semicolon-separated
        statements. Default ``False``.
    :return: List of parsed sqlglot Expression objects.
    :raises SQLSafetyError: If the SQL is empty, cannot be tokenized or parsed,
        contains disallowed statement types, or has multiple statements when not
        permitted.
    """
    if not sql or not sql.strip():
        raise SQLSafetyError("Empty SQL input.")

    types = allowed_types or DEFAULT_ALLOWED_TYPES

    try:
        statements = sqlglot.parse(sql, dialect=dialect, error_level=ErrorLevel.RAISE)
    except sqlglot.errors.SqlglotError as e:
        # Catch the sqlglot base error rather than only ParseError: the tokenizer
        # raises TokenError (a sibling of ParseError) for inputs such as an
        # unterminated quote, and that must not escape as a non-safety exception.
        raise SQLSafetyError(f"SQL parse error: {e}") from e

    # sqlglot.parse can return [None] for empty input
    parsed: list[exp.Expression] = [s for s in statements if s is not None]
    if not parsed:
        raise SQLSafetyError("Empty SQL input.")

    if not allow_multiple_statements and len(parsed) > 1:
        raise SQLSafetyError(
            f"Multiple statements detected ({len(parsed)}). Only single statements are allowed by default."
        )

    # Allowlist check: every top-level statement must be an instance of a permitted type.
    for stmt in parsed:
        if not isinstance(stmt, types):
            allowed_names = ", ".join(t.__name__ for t in types)
            raise SQLSafetyError(
                f"Statement type '{type(stmt).__name__}' is not allowed. Allowed types: {allowed_names}"
            )

    return parsed

Was this entry helpful?