# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Generic adapter that exposes Airflow Hook methods as pydantic-ai tools."""
from __future__ import annotations
import inspect
import json
import re
import types
from typing import TYPE_CHECKING, Any, Union, get_args, get_origin, get_type_hints
from pydantic_ai.tools import ToolDefinition
from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
from pydantic_core import SchemaValidator, core_schema
if TYPE_CHECKING:
from collections.abc import Callable
from pydantic_ai._run_context import RunContext
from airflow.providers.common.compat.sdk import BaseHook
# Single shared, module-level validator built from ``core_schema.any_schema()``.
# An "any" schema accepts every input value, so no argument validation or
# coercion happens through this object.
# NOTE(review): presumably used to "validate" tool-call arguments already
# JSON-decoded from the LLM further down this file — confirm at the call site.
_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
# Maps Python types to JSON Schema fragments.
_TYPE_MAP: dict[type, dict[str, Any]] = {
str: {"type": "string"},
int: {"type": "integer"},
float: {"type": "number"},
bool: {"type": "boolean"},
list: {"type": "array"},
dict: {"type": "object"},
bytes: {"type": "string"},
}
# ---------------------------------------------------------------------------
# Private introspection helpers
# ---------------------------------------------------------------------------
def _python_type_to_json_schema(annotation: Any) -> dict[str, Any]:
"""Convert a Python type annotation to a JSON Schema fragment."""
if annotation is inspect.Parameter.empty or annotation is Any:
return {"type": "string"}
origin = get_origin(annotation)
args = get_args(annotation)
# Optional[X] is Union[X, None] — handle both types.UnionType (3.10+) and typing.Union
if origin is types.UnionType or origin is Union:
non_none = [a for a in args if a is not type(None)]
if len(non_none) == 1:
return _python_type_to_json_schema(non_none[0])
return {"type": "string"}
# list[X]
if origin is list:
items = _python_type_to_json_schema(args[0]) if args else {"type": "string"}
return {"type": "array", "items": items}
# dict[K, V]
if origin is dict:
return {"type": "object"}
# Always return a fresh copy — callers may mutate the dict (e.g. adding "description").
schema = _TYPE_MAP.get(annotation)
return dict(schema) if schema else {"type": "string"}
def _build_json_schema_from_signature(method: Callable[..., Any]) -> dict[str, Any]:
    """
    Describe *method*'s parameters as a JSON Schema ``object``.

    Every parameter except ones named ``self``/``cls`` and variadic
    ``*args``/``**kwargs`` becomes a property whose schema is produced by
    ``_python_type_to_json_schema``. Parameters without a default value are
    listed under ``required``; that key is omitted when nothing is required.
    """
    parameters = inspect.signature(method).parameters
    try:
        type_hints = get_type_hints(method)
    except Exception:
        # Best effort: if hint resolution fails for any reason, fall back to
        # the raw ``param.annotation`` of each parameter below.
        type_hints = {}

    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    exposed = [
        param
        for param_name, param in parameters.items()
        if param_name not in ("self", "cls") and param.kind not in variadic_kinds
    ]

    schema: dict[str, Any] = {
        "type": "object",
        "properties": {
            param.name: _python_type_to_json_schema(type_hints.get(param.name, param.annotation))
            for param in exposed
        },
    }
    mandatory = [param.name for param in exposed if param.default is inspect.Parameter.empty]
    if mandatory:
        schema["required"] = mandatory
    return schema
def _extract_description(method: Callable[..., Any]) -> str:
"""Return the first paragraph of a method's docstring."""
doc = inspect.getdoc(method)
if not doc:
return method.__name__.replace("_", " ").capitalize()
# First paragraph = everything up to the first blank line.
lines: list[str] = []
for line in doc.splitlines():
if not line.strip():
if lines:
break
continue
lines.append(line.strip())
return " ".join(lines) if lines else method.__name__.replace("_", " ").capitalize()
# Matches Sphinx-style `:param name:` and Google-style `name:` under an ``Args:`` block.
_SPHINX_PARAM_RE = re.compile(r":param\s+(\w+):\s*(.+?)(?=\n\s*:|$)", re.DOTALL)
_GOOGLE_ARGS_RE = re.compile(r"^\s{2,}(\w+)\s*(?:\(.+?\))?:\s*(.+)", re.MULTILINE)
def _parse_param_docs(docstring: str) -> dict[str, str]:
"""Parse parameter descriptions from Sphinx or Google-style docstrings."""
params: dict[str, str] = {}
# Try Sphinx style first.
for match in _SPHINX_PARAM_RE.finditer(docstring):
name = match.group(1)
desc = " ".join(match.group(2).split())
params[name] = desc
if params:
return params
# Fall back to Google style (``Args:`` section).
in_args = False
for line in docstring.splitlines():
stripped = line.strip()
if stripped.lower().startswith("args:"):
in_args = True
continue
if in_args:
if stripped and not stripped[0].isspace() and ":" not in stripped:
break
m = _GOOGLE_ARGS_RE.match(line)
if m:
params[m.group(1)] = " ".join(m.group(2).split())
return params
def _serialize_for_llm(value: Any) -> str:
"""Convert a Python return value to a string suitable for an LLM."""
if value is None:
return "null"
if isinstance(value, str):
return value
try:
return json.dumps(value, default=str)
except (TypeError, ValueError):
return str(value)