Source code for airflow.providers.common.ai.hooks.pydantic_ai

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypeVar, overload

from pydantic_ai import Agent
from pydantic_ai.models import infer_model
from pydantic_ai.providers import infer_provider, infer_provider_class

from airflow.providers.common.compat.sdk import BaseHook

# Generic output type produced by an Agent; ties the `create_agent` overloads'
# `output_type` argument to the returned `Agent`'s result type.
OutputT = TypeVar("OutputT")

if TYPE_CHECKING:
    # Type-checking-only imports: used purely in annotations below, so they are
    # kept out of the runtime import path.
    from pydantic_ai.models import KnownModelName, Model
class PydanticAIHook(BaseHook):
    """
    Hook for LLM access via pydantic-ai.

    Covers providers that use a standard ``api_key`` + optional ``base_url``
    (OpenAI, Anthropic, Groq, Mistral, DeepSeek, Ollama, vLLM, …). For cloud
    providers with non-standard auth use the dedicated subclasses:
    :class:`PydanticAIAzureHook`, :class:`PydanticAIBedrockHook`,
    :class:`PydanticAIVertexHook`.

    Connection fields:

    - **password**: API key
    - **host**: Base URL (optional, e.g. ``https://api.openai.com/v1``)
    - **extra** JSON: ``{"model": "openai:gpt-5.3"}``

    :param llm_conn_id: Airflow connection ID for the LLM provider.
    :param model_id: Model identifier in ``provider:model`` format
        (e.g. ``"openai:gpt-5.3"``). Overrides the model stored in the
        connection's extra field.
    """

    # Standard Airflow hook metadata: maps the hook to its connection type
    # and names the constructor argument that carries the connection ID.
    conn_name_attr = "llm_conn_id"
    default_conn_name = "pydanticai_default"
    conn_type = "pydanticai"
    hook_name = "Pydantic AI"

    def __init__(
        self,
        llm_conn_id: str | None = None,
        model_id: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        # Resolve at runtime so each subclass uses its own default_conn_name.
        # A bare `llm_conn_id: str = default_conn_name` would bind the *base*
        # class value for all subclasses because Python evaluates default
        # argument values at class-definition time.
        self.llm_conn_id = llm_conn_id if llm_conn_id is not None else self.default_conn_name
        self.model_id = model_id
        # Cache slot for the resolved model; populated on first get_conn().
        self._model: Model | None = None

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour for the Airflow connection form."""
        return {
            "hidden_fields": ["schema", "port", "login"],
            "relabeling": {"password": "API Key"},
            "placeholders": {
                "host": "https://api.openai.com/v1 (optional, for custom endpoints / Ollama)",
                "extra": '{"model": "openai:gpt-5.3"}',
            },
        }

    # ------------------------------------------------------------------
    # Core connection / agent API
    # ------------------------------------------------------------------
    def _get_provider_kwargs(
        self,
        api_key: str | None,
        base_url: str | None,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Return the kwargs to pass to the provider constructor.

        Subclasses override this method to map their connection fields to the
        parameters expected by their specific provider class. The base
        implementation handles the common ``api_key`` / ``base_url`` pattern
        used by OpenAI, Anthropic, Groq, Mistral, Ollama, and most other
        providers.

        :param api_key: Value of ``conn.password``.
        :param base_url: Value of ``conn.host``.
        :param extra: Deserialized ``conn.extra`` JSON.
        :return: Kwargs forwarded to ``provider_cls(**kwargs)``. Empty dict
            signals that no explicit credentials are available and the hook
            should fall back to environment-variable–based auth.
        """
        kwargs: dict[str, Any] = {}
        # Truthiness checks deliberately skip empty strings as well as None.
        if api_key:
            kwargs["api_key"] = api_key
        if base_url:
            kwargs["base_url"] = base_url
        return kwargs

    def get_conn(self) -> Model:
        """
        Return a configured pydantic-ai ``Model``.

        Resolution order:

        1. **Explicit credentials** — when :meth:`_get_provider_kwargs` returns
           a non-empty dict the provider class is instantiated with those
           kwargs and wrapped in a ``provider_factory``.
        2. **Default resolution** — delegates to pydantic-ai ``infer_model``
           which reads standard env vars (``OPENAI_API_KEY``, ``AWS_PROFILE``, …).

        The resolved model is cached for the lifetime of this hook instance.

        :raises ValueError: If neither ``model_id`` nor the connection's extra
            field specifies a model.
        """
        if self._model is not None:
            return self._model
        conn = self.get_connection(self.llm_conn_id)
        extra: dict[str, Any] = conn.extra_dejson
        # Hook-level model_id takes precedence over the connection's extra.
        model_name: str | KnownModelName = self.model_id or extra.get("model", "")
        if not model_name:
            raise ValueError(
                "No model specified. Set model_id on the hook or the Model field on the connection."
            )
        # Normalize empty-string connection fields to None.
        api_key: str | None = conn.password or None
        base_url: str | None = conn.host or None
        provider_kwargs = self._get_provider_kwargs(api_key, base_url, extra)
        if provider_kwargs:
            _kwargs = provider_kwargs  # capture for closure
            # Log only the kwarg *names*, never the credential values.
            self.log.info(
                "Using explicit credentials for provider with model '%s': %s",
                model_name,
                list(provider_kwargs),
            )

            def _provider_factory(pname: str) -> Any:
                try:
                    return infer_provider_class(pname)(**_kwargs)
                except TypeError:
                    # Provider constructor does not accept these kwargs
                    # (e.g. a provider with a different auth signature);
                    # fall back to pydantic-ai's default env-var resolution.
                    self.log.warning(
                        "Provider '%s' rejected kwargs %s; falling back to env-var auth",
                        pname,
                        list(_kwargs),
                    )
                    return infer_provider(pname)

            self._model = infer_model(model_name, provider_factory=_provider_factory)
            return self._model
        # No explicit credentials: let pydantic-ai resolve everything itself.
        self._model = infer_model(model_name)
        return self._model

    @overload
    def create_agent(
        self, output_type: type[OutputT], *, instructions: str, **agent_kwargs
    ) -> Agent[None, OutputT]: ...

    @overload
    def create_agent(self, *, instructions: str, **agent_kwargs) -> Agent[None, str]: ...

    def create_agent(
        self, output_type: type[Any] = str, *, instructions: str, **agent_kwargs
    ) -> Agent[None, Any]:
        """
        Create a pydantic-ai Agent configured with this hook's model.

        :param output_type: The expected output type from the agent
            (default: ``str``).
        :param instructions: System-level instructions for the agent.
        :param agent_kwargs: Additional keyword arguments passed to the Agent
            constructor.
        """
        return Agent(self.get_conn(), output_type=output_type, instructions=instructions, **agent_kwargs)

    def test_connection(self) -> tuple[bool, str]:
        """
        Test connection by resolving the model.

        Validates that the model string is valid and the provider class can be
        instantiated with the supplied credentials. Does NOT make an LLM API
        call — that would be expensive and fail for reasons unrelated to
        connectivity (quotas, billing, rate limits).
        """
        try:
            self.get_conn()
            return True, "Model resolved successfully."
        except Exception as e:
            # Broad catch is intentional: any resolution failure should be
            # reported as a readable message in the connection-test UI.
            return False, str(e)
class PydanticAIAzureHook(PydanticAIHook):
    """
    Hook for Azure OpenAI via pydantic-ai.

    Connection fields:

    - **password**: Azure API key
    - **host**: Azure endpoint (e.g. ``https://<resource>.openai.azure.com``)
    - **extra** JSON::

        {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}

    :param llm_conn_id: Airflow connection ID.
    :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
    """

    conn_type = "pydanticai-azure"
    default_conn_name = "pydanticai_azure_default"
    hook_name = "Pydantic AI (Azure OpenAI)"

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour for the Airflow connection form."""
        placeholders = {
            "host": "https://<resource>.openai.azure.com",
            "extra": '{"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}',
        }
        return {
            "hidden_fields": ["schema", "port", "login"],
            "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
            "placeholders": placeholders,
        }

    def _get_provider_kwargs(
        self,
        api_key: str | None,
        base_url: str | None,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        # Translate the generic connection fields into the parameter names the
        # Azure provider expects; drop anything unset/empty so the base class
        # can fall back to env-var auth when nothing is configured.
        candidates: dict[str, Any] = {
            "api_key": api_key,
            "azure_endpoint": base_url,
            "api_version": extra.get("api_version"),
        }
        return {name: value for name, value in candidates.items() if value}
class PydanticAIBedrockHook(PydanticAIHook):
    """
    Hook for AWS Bedrock via pydantic-ai.

    Credentials are resolved in order:

    1. IAM keys from ``extra`` (``aws_access_key_id`` +
       ``aws_secret_access_key``, optionally ``aws_session_token``).
    2. Bearer token in ``extra`` (``api_key``, maps to env
       ``AWS_BEARER_TOKEN_BEDROCK``).
    3. Environment-variable / instance-role chain (``AWS_PROFILE``, IAM role, …)
       when no explicit keys are provided.

    Connection fields:

    - **extra** JSON::

        {
            "model": "bedrock:us.anthropic.claude-opus-4-5",
            "region_name": "us-east-1",
            "aws_access_key_id": "AKIA...",
            "aws_secret_access_key": "...",
            "aws_session_token": "...",
            "profile_name": "my-aws-profile",
            "api_key": "bearer-token",
            "base_url": "https://custom-bedrock-endpoint",
            "aws_read_timeout": 60.0,
            "aws_connect_timeout": 10.0
        }

    Leave ``aws_access_key_id`` / ``aws_secret_access_key`` and ``api_key``
    empty to use the default AWS credential chain.

    :param llm_conn_id: Airflow connection ID.
    :param model_id: Model identifier,
        e.g. ``"bedrock:us.anthropic.claude-opus-4-5"``.
    """

    conn_type = "pydanticai-bedrock"
    default_conn_name = "pydanticai_bedrock_default"
    hook_name = "Pydantic AI (AWS Bedrock)"

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour for the Airflow connection form."""
        extra_placeholder = (
            '{"model": "bedrock:us.anthropic.claude-opus-4-5", '
            '"region_name": "us-east-1"}'
            " — leave aws_access_key_id empty for IAM role / env-var auth"
        )
        return {
            "hidden_fields": ["schema", "port", "login", "host", "password"],
            "relabeling": {},
            "placeholders": {"extra": extra_placeholder},
        }

    def _get_provider_kwargs(
        self,
        api_key: str | None,
        base_url: str | None,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Return kwargs for ``BedrockProvider``.

        .. note::
            The ``api_key`` and ``base_url`` parameters (sourced from
            ``conn.password`` and ``conn.host``) are intentionally ignored
            here. Bedrock connections hide those fields in the UI; all config
            is stored in ``extra`` instead. The ``api_key`` and ``base_url``
            keys below refer to *extra* fields, not the method parameters.
        """
        kwargs: dict[str, Any] = {}
        string_fields = (
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
            "region_name",
            "profile_name",
            # Bearer-token auth (alternative to IAM key/secret).
            # Maps to AWS_BEARER_TOKEN_BEDROCK env var.
            "api_key",
            # Custom Bedrock runtime endpoint.
            "base_url",
        )
        for field in string_fields:
            value = extra.get(field)
            if value:
                kwargs[field] = value
        # BedrockProvider expects float for timeout values; JSON integers must
        # be coerced.
        for timeout_field in ("aws_read_timeout", "aws_connect_timeout"):
            raw_timeout = extra.get(timeout_field)
            if raw_timeout:
                kwargs[timeout_field] = float(raw_timeout)
        return kwargs
class PydanticAIVertexHook(PydanticAIHook):
    """
    Hook for Google Vertex AI (or Generative Language API) via pydantic-ai.

    Credentials are resolved in order:

    1. ``service_account_info`` (JSON object) in ``extra`` — loaded into a
       ``google.auth.credentials.Credentials`` object and passed as
       ``credentials`` to ``GoogleProvider``.
    2. ``api_key`` in ``extra`` — for Generative Language API (non-Vertex) or
       Vertex API-key auth.
    3. Application Default Credentials (``GOOGLE_APPLICATION_CREDENTIALS``,
       ``gcloud auth application-default login``, Workload Identity, …) when
       no explicit credentials are provided.

    Connection fields:

    - **extra** JSON::

        {
            "model": "google-vertex:gemini-2.0-flash",
            "project": "my-gcp-project",
            "location": "us-central1",
            "service_account_info": {...},
            "vertexai": true,
        }

    Use ``"service_account_info"`` to embed the service-account JSON directly
    (as an object, not a string path). Set ``"vertexai": true`` to force
    Vertex AI mode when only ``api_key`` is provided. Omit ``vertexai`` for
    the Generative Language API (GLA).

    :param llm_conn_id: Airflow connection ID.
    :param model_id: Model identifier, e.g. ``"google-vertex:gemini-2.0-flash"``.
    """

    conn_type = "pydanticai-vertex"
    default_conn_name = "pydanticai_vertex_default"
    hook_name = "Pydantic AI (Google Vertex AI)"

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour for the Airflow connection form."""
        extra_placeholder = (
            '{"model": "google-vertex:gemini-2.0-flash", '
            '"project": "my-project", "location": "us-central1", "vertexai": true}'
            " — add service_account_info (object) for SA auth;"
            " omit both to use Application Default Credentials"
        )
        return {
            "hidden_fields": ["schema", "port", "login", "host", "password"],
            "relabeling": {},
            "placeholders": {"extra": extra_placeholder},
        }

    def _get_provider_kwargs(
        self,
        api_key: str | None,
        base_url: str | None,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        # Scalar GoogleProvider kwargs copied straight from extra when truthy.
        kwargs: dict[str, Any] = {
            name: extra[name]
            for name in ("api_key", "project", "location", "base_url")
            if extra.get(name)
        }
        # Optional vertexai bool flag (force Vertex AI mode for API-key auth);
        # an explicit false is preserved, only absence is skipped.
        vertexai_flag = extra.get("vertexai")
        if vertexai_flag is not None:
            kwargs["vertexai"] = bool(vertexai_flag)
        # Service-account credentials — loaded lazily to avoid importing
        # google-auth on non-Vertex code paths (optional heavy dependency).
        credentials_info = extra.get("service_account_info")
        if credentials_info:
            from google.oauth2 import service_account  # lazy: optional dep

            kwargs["credentials"] = service_account.Credentials.from_service_account_info(
                credentials_info,
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )
        return kwargs

Was this entry helpful?