# Source code for airflow.providers.common.ai.toolsets.datafusion

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store workflows."""

from __future__ import annotations

import json
import logging
from typing import TYPE_CHECKING, Any

try:
    from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError, validate_sql as _validate_sql
    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
    from airflow.providers.common.sql.datafusion.exceptions import QueryExecutionException
except ImportError as e:
    from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException

    raise AirflowOptionalProviderFeatureException(e)

from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
from pydantic_core import SchemaValidator, core_schema

if TYPE_CHECKING:
    from pydantic_ai._run_context import RunContext

    from airflow.providers.common.sql.config import DataSourceConfig

# Module-level logger for this toolset.
log = logging.getLogger(__name__)

# Pass-through args validator: tool arguments are described to the model via the
# JSON Schemas below; `any_schema` accepts them unchanged, so no extra pydantic
# coercion happens before `call_tool` receives the dict.
_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())

# JSON Schemas for the three DataFusion tools.
_LIST_TABLES_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {},
}
_GET_SCHEMA_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "table_name": {"type": "string", "description": "Name of the table to inspect."},
    },
    "required": ["table_name"],
}
_QUERY_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "sql": {"type": "string", "description": "SQL query to execute."},
    },
    "required": ["sql"],
}
class DataFusionToolset(AbstractToolset[Any]):
    """
    Curated toolset that gives an LLM agent SQL access to object-storage data via Apache DataFusion.

    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` — backed by
    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.

    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry registers a table
    backed by Parquet, CSV, Avro, or Iceberg data on S3 or local storage. Multiple configs can be
    registered so that SQL queries can join across tables.

    Requires the ``datafusion`` extra of ``apache-airflow-providers-common-sql``.

    :param datasource_configs: One or more DataFusion data-source configurations.
    :param allow_writes: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW, INSERT INTO, etc.).
        Default ``False`` — only SELECT-family statements are permitted.
    :param max_rows: Maximum number of rows returned from the ``query`` tool. Default ``50``.
    """

    def __init__(
        self,
        datasource_configs: list[DataSourceConfig],
        *,
        allow_writes: bool = False,
        max_rows: int = 50,
    ) -> None:
        if not datasource_configs:
            raise ValueError("datasource_configs must contain at least one DataSourceConfig")
        self._datasource_configs = datasource_configs
        self._allow_writes = allow_writes
        self._max_rows = max_rows
        # Engine creation is deferred to first use; see _get_engine().
        self._engine: DataFusionEngine | None = None

    @property
    def id(self) -> str:
        """Stable toolset identifier derived from the registered table names."""
        suffix = "_".join(config.table_name.replace("-", "_") for config in self._datasource_configs)
        return f"sql_datafusion_{suffix}"

    def _get_engine(self) -> DataFusionEngine:
        """Lazily create and configure a DataFusionEngine from *datasource_configs*."""
        if self._engine is None:
            engine = DataFusionEngine()
            for config in self._datasource_configs:
                engine.register_datasource(config)
            self._engine = engine
        return self._engine

    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]:
        """Expose the three curated tools; arguments pass through unvalidated (see schemas)."""
        tools: dict[str, ToolsetTool[Any]] = {}
        for name, description, schema in (
            ("list_tables", "List available table names.", _LIST_TABLES_SCHEMA),
            ("get_schema", "Get column names and types for a table.", _GET_SCHEMA_SCHEMA),
            ("query", "Execute a SQL query and return rows as JSON.", _QUERY_SCHEMA),
        ):
            tool_def = ToolDefinition(
                name=name,
                description=description,
                parameters_json_schema=schema,
                sequential=True,
            )
            tools[name] = ToolsetTool(
                toolset=self,
                tool_def=tool_def,
                max_retries=1,
                args_validator=_PASSTHROUGH_VALIDATOR,
            )
        return tools

    async def call_tool(
        self,
        name: str,
        tool_args: dict[str, Any],
        ctx: RunContext[Any],
        tool: ToolsetTool[Any],
    ) -> Any:
        """Dispatch a tool invocation by *name* to the matching private implementation."""
        if name == "list_tables":
            return self._list_tables()
        if name == "get_schema":
            return self._get_schema(tool_args["table_name"])
        if name == "query":
            return self._query(tool_args["sql"])
        raise ValueError(f"Unknown tool: {name!r}")

    def _list_tables(self) -> str:
        """Return a JSON array of available table names, or a JSON error object on failure."""
        try:
            engine = self._get_engine()
            tables: list[str] = list(engine.session_context.catalog().schema().table_names())
            return json.dumps(tables)
        except Exception as ex:
            log.warning("list_tables failed: %s", ex)
            return json.dumps({"error": str(ex)})

    def _get_schema(self, table_name: str) -> str:
        """Return a JSON list of ``{"name", "type"}`` column descriptors for *table_name*.

        Failures (including engine setup errors) are returned as a JSON error object,
        consistent with ``_list_tables`` and ``_query`` — agent tools should degrade to
        structured errors rather than raise.
        """
        try:
            engine = self._get_engine()
            # session_context lookup is required here instead of engine.registered_tables,
            # because registered_tables only tracks tables registered via datasource config.
            # When allow_writes is enabled, the agent may create temporary in-memory tables
            # that would not be captured there.
            if not engine.session_context.table_exist(table_name):
                return json.dumps({"error": f"Table {table_name!r} is not available"})
            # Intentionally using session_context instead of engine.get_schema() —
            # the latter returns a pre-formatted string intended for other operators,
            # not a JSON-compatible format.
            # TODO: refactor engine.get_schema() to return JSON and update this accordingly
            table = engine.session_context.table(table_name)
            columns = [{"name": f.name, "type": str(f.type)} for f in table.schema()]
            return json.dumps(columns)
        except Exception as ex:
            log.warning("get_schema failed: %s", ex)
            return json.dumps({"error": str(ex)})

    def _query(self, sql: str) -> str:
        """Execute *sql* and return up to ``max_rows`` rows as a JSON object.

        Unsafe SQL (when writes are disallowed) raises SQLSafetyError so the agent run
        surfaces the policy violation; execution errors are returned as JSON.
        """
        try:
            if not self._allow_writes:
                _validate_sql(sql)
            engine = self._get_engine()
            pydict = engine.execute_query(sql)
            col_names = list(pydict.keys())
            # All columns have equal length; an empty result yields num_rows == 0.
            num_rows = len(next(iter(pydict.values()), []))
            result: list[dict[str, Any]] = [
                {col: pydict[col][i] for col in col_names}
                for i in range(min(num_rows, self._max_rows))
            ]
            truncated = num_rows > self._max_rows
            output: dict[str, Any] = {"rows": result, "count": num_rows}
            if truncated:
                output["truncated"] = True
                output["max_rows"] = self._max_rows
            # default=str: row values may be non-JSON-native (dates, decimals, etc.).
            return json.dumps(output, default=str)
        except SQLSafetyError as ex:
            log.warning("query failed SQL safety validation: %s", ex)
            raise
        except QueryExecutionException as ex:
            return json.dumps({"error": str(ex), "query": sql})
