# Source code for airflow.providers.common.ai.operators.agent

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Operator for running pydantic-ai agents with tools and multi-turn reasoning."""

from __future__ import annotations

import json
from collections.abc import Sequence
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any

from pydantic import BaseModel

from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
from airflow.providers.common.ai.mixins.hitl_review import HITLReviewMixin
from airflow.providers.common.ai.utils.logging import log_run_summary, wrap_toolsets_for_logging
from airflow.providers.common.compat.sdk import (
    AirflowOptionalProviderFeatureException,
    BaseOperator,
    BaseOperatorLink,
    conf,
)
from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_1_PLUS

if TYPE_CHECKING:
    from pydantic_ai import Agent
    from pydantic_ai.toolsets.abstract import AbstractToolset

    from airflow.providers.common.ai.durable.step_counter import DurableStepCounter
    from airflow.providers.common.ai.durable.storage import DurableStorage
    from airflow.providers.common.compat.sdk import TaskInstanceKey
    from airflow.sdk import Context






class AgentOperator(BaseOperator, HITLReviewMixin):
    """
    Run a pydantic-ai Agent with tools and multi-turn reasoning.

    Provide ``llm_conn_id`` and optional ``toolsets`` to let the operator build and
    run the agent. The agent reasons about the prompt, calls tools in a multi-turn
    loop, and returns a final answer.

    :param prompt: The prompt to send to the agent.
    :param llm_conn_id: Connection ID for the LLM provider.
    :param model_id: Model identifier (e.g. ``"openai:gpt-5"``). Overrides the
        model stored in the connection's extra field.
    :param system_prompt: System-level instructions for the agent.
    :param output_type: Expected output type. Default ``str``. Set to a Pydantic
        ``BaseModel`` subclass for structured output.
    :param toolsets: List of pydantic-ai toolsets the agent can use
        (e.g. ``SQLToolset``, ``HookToolset``).
    :param enable_tool_logging: When ``True`` (default), wraps each toolset in a
        ``LoggingToolset`` that logs tool calls with timing at INFO level and
        arguments at DEBUG level. Set to ``False`` to disable.
    :param agent_params: Additional keyword arguments passed to the pydantic-ai
        ``Agent`` constructor (e.g. ``retries``, ``model_settings``).
    :param durable: When ``True``, enables step-level caching of model responses
        and tool results for durable execution. On retry, cached steps are
        replayed instead of re-executing. Default ``False``. Requires
        ``[common.ai] durable_cache_path`` to be set.

    **HITL Review parameters** (requires the ``hitl_review`` plugin):

    :param enable_hitl_review: When ``True``, the operator enters an iterative
        review loop after the first generation. A human reviewer can approve,
        reject, or request changes via the plugin's REST API at ``/hitl-review``
        or through the **HITL Review** extra link on the task instance.
        Default ``False``.
    :param max_hitl_iterations: Maximum outputs shown to the reviewer (1 = initial
        output). When the reviewer requests changes at iteration >= this limit,
        the task fails with ``HITLMaxIterationsError`` without calling the LLM.
        E.g. 5 allows changes at iterations 1–4. Default ``5``.
    :param hitl_timeout: Maximum wall-clock time to wait for all review rounds
        combined. ``None`` means no timeout (the operator blocks until a terminal
        action).
    :param hitl_poll_interval: Seconds between XCom polls while waiting for a
        human response. Default ``10``.
    """

    template_fields: Sequence[str] = (
        "prompt",
        "llm_conn_id",
        "model_id",
        "system_prompt",
        "agent_params",
    )

    def __init__(
        self,
        *,
        prompt: str,
        llm_conn_id: str,
        model_id: str | None = None,
        system_prompt: str = "",
        output_type: type = str,
        toolsets: list[AbstractToolset] | None = None,
        enable_tool_logging: bool = True,
        agent_params: dict[str, Any] | None = None,
        durable: bool = False,
        # Agent feedback parameters
        enable_hitl_review: bool = False,
        max_hitl_iterations: int = 5,
        hitl_timeout: timedelta | None = None,
        hitl_poll_interval: float = 10.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.prompt = prompt
        self.llm_conn_id = llm_conn_id
        self.model_id = model_id
        self.system_prompt = system_prompt
        self.output_type = output_type
        self.toolsets = toolsets
        self.enable_tool_logging = enable_tool_logging
        self.agent_params = agent_params or {}
        self.durable = durable
        # Durable replay and the interactive HITL loop are incompatible: replayed
        # cached steps would bypass the reviewer's feedback.
        if durable and enable_hitl_review:
            raise ValueError("durable=True and enable_hitl_review=True cannot be used together.")
        self.enable_hitl_review = enable_hitl_review
        self.max_hitl_iterations = max_hitl_iterations
        self.hitl_timeout = hitl_timeout
        self.hitl_poll_interval = hitl_poll_interval
        if self.enable_hitl_review and not AIRFLOW_V_3_1_PLUS:
            raise AirflowOptionalProviderFeatureException(
                "Human in the loop functionality needs Airflow 3.1+."
            )
        # Initialize durable-execution state here (not only in execute()) so that
        # _build_agent() never hits an AttributeError if it is invoked before
        # execute() has run. execute() re-populates these when durable=True.
        self._durable_storage: DurableStorage | None = None
        self._durable_counter: DurableStepCounter | None = None

    @cached_property
    def llm_hook(self) -> PydanticAIHook:
        """Return PydanticAIHook for the configured LLM connection."""
        hook_params = {
            "model_id": self.model_id,
        }
        return PydanticAIHook.get_hook(self.llm_conn_id, hook_params=hook_params)

    def _build_agent(self) -> Agent[None, Any]:
        """Build and return a pydantic-ai Agent from the operator's config."""
        extra_kwargs = dict(self.agent_params)
        if self.toolsets:
            toolsets = self.toolsets
            # Order matters: caching wraps the real toolset first, then logging
            # wraps the caching layer so replayed calls are still logged.
            if self.durable and self._durable_storage is not None and self._durable_counter is not None:
                toolsets = self._build_durable_toolsets(
                    toolsets, self._durable_storage, self._durable_counter
                )
            if self.enable_tool_logging:
                toolsets = wrap_toolsets_for_logging(toolsets, self.log)
            extra_kwargs["toolsets"] = toolsets
        return self.llm_hook.create_agent(
            output_type=self.output_type,
            instructions=self.system_prompt,
            **extra_kwargs,
        )

    def _build_durable_toolsets(
        self, toolsets: list[AbstractToolset], storage: DurableStorage, counter: DurableStepCounter
    ) -> list[AbstractToolset]:
        """Wrap each toolset with CachingToolset for durable execution."""
        # Imported lazily: the durable extra is optional at import time.
        from airflow.providers.common.ai.durable.caching_toolset import CachingToolset

        return [CachingToolset(wrapped=ts, storage=storage, counter=counter) for ts in toolsets]

    def execute(self, context: Context) -> Any:
        """Run the agent and return its output.

        When ``durable=True``, model responses and tool results are cached per
        step and replayed on retry. When ``enable_hitl_review=True``, the output
        is routed through the human review loop before being returned.

        :param context: The Airflow task execution context.
        :return: The agent's final output — a ``dict`` for structured
            (``BaseModel``) output, otherwise the raw output (or the reviewer's
            approved result in the HITL path).
        """
        self._durable_storage = None
        self._durable_counter = None
        if self.durable:
            # Imported lazily: only needed when durable execution is enabled.
            from airflow.providers.common.ai.durable.step_counter import DurableStepCounter
            from airflow.providers.common.ai.durable.storage import DurableStorage

            ti = context["task_instance"]
            self._durable_storage = DurableStorage(
                dag_id=ti.dag_id,
                task_id=ti.task_id,
                run_id=ti.run_id,
                map_index=ti.map_index if ti.map_index is not None else -1,
            )
            self._durable_counter = DurableStepCounter()

        agent = self._build_agent()

        storage = self._durable_storage
        counter = self._durable_counter
        if self.durable and storage is not None and counter is not None:
            from pydantic_ai.models import infer_model

            from airflow.providers.common.ai.durable.caching_model import CachingModel

            if agent.model is None:
                raise ValueError("Agent model must be set when durable=True")
            resolved_model = infer_model(agent.model)
            caching_model = CachingModel(resolved_model, storage=storage, counter=counter)
            # override() swaps in the caching model only for this run.
            with agent.override(model=caching_model):
                result = agent.run_sync(self.prompt)
        else:
            result = agent.run_sync(self.prompt)

        log_run_summary(self.log, result)

        if self._durable_counter is not None:
            c = self._durable_counter
            replayed = c.replayed_model + c.replayed_tool
            cached = c.cached_model + c.cached_tool
            if replayed:
                self.log.info(
                    "Durable: replayed %d cached steps (%d model, %d tool), "
                    "executed %d new steps (%d model, %d tool)",
                    replayed,
                    c.replayed_model,
                    c.replayed_tool,
                    cached,
                    c.cached_model,
                    c.cached_tool,
                )
        # Cleanup runs only on success: on failure the cache must survive so the
        # retry can replay the completed steps.
        if self._durable_storage is not None:
            self._durable_storage.cleanup()

        output = result.output

        if self.enable_hitl_review:
            result_str = self.run_hitl_review(  # type: ignore[misc]
                context,
                output,
                message_history=result.all_messages(),
            )
            # Deserialize back to dict
            try:
                return json.loads(result_str)
            except (ValueError, TypeError):
                return result_str

        if isinstance(output, BaseModel):
            return output.model_dump()
        return output

    def regenerate_with_feedback(self, *, feedback: str, message_history: Any) -> tuple[str, Any]:
        """Re-run the agent with *feedback* appended to the conversation history."""
        agent = self._build_agent()
        messages = message_history or []
        result = agent.run_sync(feedback, message_history=messages)
        log_run_summary(self.log, result)
        output = result.output
        if isinstance(output, BaseModel):
            # Serialize structured output so the reviewer sees JSON text.
            output = output.model_dump_json()
        return str(output), result.all_messages()
