# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Operator for generating SQL queries from natural language using LLMs."""
from __future__ import annotations
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any
try:
from airflow.providers.common.ai.utils.sql_validation import (
DEFAULT_ALLOWED_TYPES,
validate_sql as _validate_sql,
)
from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
except ImportError as e:
from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException
raise AirflowOptionalProviderFeatureException(e)
from airflow.providers.common.ai.operators.llm import LLMOperator
from airflow.providers.common.ai.utils.logging import log_run_summary
from airflow.providers.common.compat.sdk import BaseHook
if TYPE_CHECKING:
from sqlglot import exp
from airflow.providers.common.sql.config import DataSourceConfig
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.sdk import Context
# SQLAlchemy dialect_name → sqlglot dialect mapping for names that differ.
_SQLALCHEMY_TO_SQLGLOT_DIALECT: dict[str, str] = {
"postgresql": "postgres",
"mssql": "tsql",
}
[docs]
class LLMSQLQueryOperator(LLMOperator):
"""
Generate SQL queries from natural language using an LLM.
Inherits from :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`
for LLM access and optionally uses a
:class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`
for schema introspection. The operator generates SQL but does not execute it —
the generated SQL is returned as XCom and can be passed to
``SQLExecuteQueryOperator`` or used in downstream tasks.
When ``system_prompt`` is provided, it is appended to the built-in SQL safety
instructions — use it for domain-specific guidance (e.g. "prefer CTEs over
subqueries", "always use LEFT JOINs").
:param prompt: Natural language description of the desired query.
:param llm_conn_id: Connection ID for the LLM provider.
:param model_id: Model identifier (e.g. ``"openai:gpt-4o"``).
Overrides the model stored in the connection's extra field.
:param system_prompt: Additional instructions appended to the built-in SQL
safety prompt. Use for domain-specific guidance.
:param agent_params: Additional keyword arguments passed to the pydantic-ai
``Agent`` constructor (e.g. ``retries``, ``model_settings``).
:param db_conn_id: Connection ID for database schema introspection.
The connection must resolve to a ``DbApiHook``.
:param table_names: Tables to include in the LLM's schema context.
Used with ``db_conn_id`` for automatic introspection.
:param schema_context: Manual schema context string. When provided,
this is used instead of ``db_conn_id`` introspection.
:param validate_sql: Whether to validate generated SQL via AST parsing.
Default ``True`` (safe by default).
:param allowed_sql_types: SQL statement types to allow.
Default: ``(Select, Union, Intersect, Except)``.
:param dialect: SQL dialect for parsing (``postgres``, ``mysql``, etc.).
Auto-detected from the database hook if not set.
Human-in-the-Loop approval parameters are inherited from
:class:`~airflow.providers.common.ai.operators.llm.LLMOperator`
(``require_approval``, ``approval_timeout``, ``allow_modifications``).
When ``allow_modifications=True`` and the reviewer edits the SQL, the
modified query is re-validated against the same safety rules before being
returned.
"""
[docs]
template_fields: Sequence[str] = (
*LLMOperator.template_fields,
"db_conn_id",
"table_names",
"schema_context",
)
def __init__(
self,
*,
db_conn_id: str | None = None,
table_names: list[str] | None = None,
schema_context: str | None = None,
validate_sql: bool = True,
allowed_sql_types: tuple[type[exp.Expression], ...] = DEFAULT_ALLOWED_TYPES,
dialect: str | None = None,
datasource_config: DataSourceConfig | None = None,
**kwargs: Any,
) -> None:
kwargs.pop("output_type", None) # SQL operator always returns str
super().__init__(**kwargs)
[docs]
self.db_conn_id = db_conn_id
[docs]
self.table_names = table_names
[docs]
self.schema_context = schema_context
[docs]
self.validate_sql = validate_sql
[docs]
self.allowed_sql_types = allowed_sql_types
[docs]
self.datasource_config = datasource_config
@cached_property
[docs]
def db_hook(self) -> DbApiHook | None:
"""Return DbApiHook for the configured database connection, or None."""
if not self.db_conn_id:
return None
from airflow.providers.common.sql.hooks.sql import DbApiHook
connection = BaseHook.get_connection(self.db_conn_id)
hook = connection.get_hook()
if not isinstance(hook, DbApiHook):
raise ValueError(
f"Connection {self.db_conn_id!r} does not provide a DbApiHook. Got {type(hook).__name__}."
)
return hook
[docs]
def execute(self, context: Context) -> str:
schema_info = self._get_schema_context()
full_system_prompt = self._build_system_prompt(schema_info)
agent = self.llm_hook.create_agent(
output_type=str, instructions=full_system_prompt, **self.agent_params
)
result = agent.run_sync(self.prompt)
log_run_summary(self.log, result)
sql = self._strip_llm_output(result.output)
if self.validate_sql:
_validate_sql(sql, allowed_types=self.allowed_sql_types, dialect=self._resolved_dialect)
self.log.info("Generated SQL:\n%s", sql)
if self.require_approval:
self.defer_for_approval(context, sql) # type: ignore[misc]
return sql
[docs]
def execute_complete(self, context: Context, generated_output: str, event: dict[str, Any]) -> str:
"""Resume after human review, re-validating if the reviewer modified the SQL."""
output = super().execute_complete(context, generated_output, event)
if output != generated_output:
_validate_sql(output, allowed_types=self.allowed_sql_types, dialect=self._resolved_dialect)
return output
@staticmethod
def _strip_llm_output(raw: str) -> str:
"""Strip whitespace and markdown code fences from LLM output."""
text = raw.strip()
if text.startswith("```"):
lines = text.split("\n")
# Remove opening fence (```sql, ```, etc.) and closing fence
if len(lines) >= 2:
end = -1 if lines[-1].strip().startswith("```") else len(lines)
text = "\n".join(lines[1:end]).strip()
return text
def _get_schema_context(self) -> str:
"""Return schema context from manual override or database introspection."""
if self.schema_context:
return self.schema_context
if (self.db_hook and self.table_names) or self.datasource_config:
return self._introspect_schemas()
return ""
def _introspect_schemas(self) -> str:
"""Build schema context by introspecting tables via the database hook."""
parts: list[str] = []
for table in self.table_names or []:
columns = self.db_hook.get_table_schema(table) # type: ignore[union-attr]
if not columns:
self.log.warning("Table %r returned no columns — it may not exist.", table)
continue
col_info = ", ".join(f"{c['name']} {c['type']}" for c in columns)
parts.append(f"Table: {table}\nColumns: {col_info}")
if not parts and self.table_names:
raise ValueError(
f"None of the requested tables ({self.table_names}) returned schema information. "
"Check that the table names are correct and the database connection has access."
)
if self.datasource_config:
object_storage_schema = self._introspect_object_storage_schema()
parts.append(f"Table: {self.datasource_config.table_name}\nColumns: {object_storage_schema}")
return "\n\n".join(parts)
def _introspect_object_storage_schema(self):
"""Use DataFusion Engine to get the schema of object stores."""
engine = DataFusionEngine()
engine.register_datasource(self.datasource_config)
return engine.get_schema(self.datasource_config.table_name)
def _build_system_prompt(self, schema_info: str) -> str:
"""Construct the system prompt for the LLM."""
dialect_label = self._resolved_dialect or "SQL"
prompt = (
f"You are a {dialect_label} expert. "
"Generate a single SQL query based on the user's request.\n"
"Return ONLY the SQL query, no explanation or markdown.\n"
)
if schema_info:
prompt += f"\nAvailable schema:\n{schema_info}\n"
prompt += (
"\nRules:\n"
"- Generate only SELECT queries (including CTEs, JOINs, subqueries, UNION)\n"
"- Never generate data modification statements "
"(INSERT, UPDATE, DELETE, DROP, etc.)\n"
"- Use proper syntax for the specified dialect\n"
)
if self.system_prompt:
prompt += f"\nAdditional instructions:\n{self.system_prompt}\n"
return prompt
@cached_property
def _resolved_dialect(self) -> str | None:
"""
Resolve the SQL dialect from explicit parameter or database hook.
Normalizes SQLAlchemy dialect names to sqlglot equivalents
(e.g. ``postgresql`` → ``postgres``).
"""
raw = self.dialect
if not raw and self.db_hook and hasattr(self.db_hook, "dialect_name"):
raw = self.db_hook.dialect_name
if raw:
return _SQLALCHEMY_TO_SQLGLOT_DIALECT.get(raw, raw)
return None