# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Hook for LlamaIndex integration with Airflow connections."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from airflow.providers.common.compat.sdk import (
AirflowOptionalProviderFeatureException,
BaseHook,
)
if TYPE_CHECKING:
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms.llm import LLM
class LlamaIndexHook(BaseHook):
    """
    Bridge an Airflow connection to LlamaIndex chat and embedding models.

    The hook resolves credentials (API key, optional API base URL) from the
    Airflow connection and returns native LlamaIndex objects ready to pass
    to ``VectorStoreIndex(..., embed_model=...)``,
    ``load_index_from_storage(..., embed_model=...)``, or
    ``index.as_retriever(..., llm=...)``.

    LlamaIndex does not ship a universal ``init_chat_model`` /
    ``init_embedding_model`` equivalent (each vendor is a separate package
    under ``llama-index-llms-*`` / ``llama-index-embeddings-*`` with its own
    constructor kwargs). The hook therefore covers the OpenAI-compatible
    surface that matches LlamaIndex's own ``resolve_embed_model("default")``
    behaviour. For other vendors (Cohere, Bedrock, Vertex, HuggingFace, ...)
    instantiate the LlamaIndex class directly in your ``@task`` and pass it
    to the operator's ``embed_model=`` / ``llm=`` parameter -- both
    ``LlamaIndexEmbeddingOperator`` and ``LlamaIndexRetrievalOperator`` accept
    a pre-built ``BaseEmbedding`` / ``LLM`` instance and bypass the hook in
    that case.

    .. note::
        The hook deliberately does **not** mutate LlamaIndex's global
        ``Settings`` singleton. Operators pass the resolved model directly
        to LlamaIndex constructors so concurrent tasks in the same worker
        don't race on shared state.

    Connection fields:

    * **password**: API key passed as ``api_key=``.
    * **host**: Optional base URL passed as ``api_base=`` (custom endpoints,
      Ollama, vLLM).
    * **extra** JSON: ``{"embed_model": "text-embedding-3-small",
      "llm_model": "gpt-4o"}`` -- default model identifiers stored on the
      connection.

    :param llm_conn_id: Airflow connection ID for the LLM provider. Falls
        back to :attr:`default_conn_name` (``"llamaindex_default"``) when
        not provided.
    :param embed_conn_id: Optional separate Airflow connection ID for the
        embedding provider. Falls back to ``llm_conn_id`` when not set.
    :param embed_model: Embedding model name (e.g.
        ``"text-embedding-3-small"``). Overrides ``extra["embed_model"]``
        on the connection.
    :param llm_model: LLM model name (e.g. ``"gpt-4o"``). Overrides
        ``extra["llm_model"]`` on the connection. Required when calling
        :meth:`get_llm` unless the connection extra provides it.
    """

    conn_name_attr = "llm_conn_id"
    default_conn_name = "llamaindex_default"
    conn_type = "llamaindex"
    hook_name = "LlamaIndex"

    def __init__(
        self,
        llm_conn_id: str | None = None,
        embed_conn_id: str | None = None,
        embed_model: str | None = None,
        llm_model: str | None = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        # Resolve at runtime so a future per-vendor subclass with its own
        # ``default_conn_name`` is honoured.
        self.llm_conn_id = llm_conn_id if llm_conn_id is not None else self.default_conn_name
        # Embedding credentials default to the (already resolved) LLM connection.
        self.embed_conn_id = embed_conn_id if embed_conn_id is not None else self.llm_conn_id
        self.embed_model = embed_model
        self.llm_model = llm_model

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour for the Airflow connection form."""
        return {
            "hidden_fields": ["schema", "port", "login"],
            "relabeling": {"password": "API Key"},
            "placeholders": {
                "host": "https://api.openai.com/v1 (optional, for custom endpoints / Ollama)",
                "extra": '{"embed_model": "text-embedding-3-small", "llm_model": "gpt-4o"}',
            },
        }

    @staticmethod
    def _resolve_model(
        conn_extra: dict[str, Any],
        *,
        constructor_value: str | None,
        extra_key: str,
        kind: str,
    ) -> str:
        """
        Resolve a model identifier from the constructor arg or connection extra.

        :param conn_extra: Parsed connection extra (``conn.extra_dejson``).
        :param constructor_value: Value passed to the hook constructor; takes
            precedence over the connection extra when truthy.
        :param extra_key: Key to look up in ``conn_extra`` as a fallback.
        :param kind: Human-readable model kind used in the error message.
        :raises ValueError: If neither source provides a non-empty identifier.
        """
        model_id = constructor_value or conn_extra.get(extra_key)
        if not model_id:
            raise ValueError(
                f"No {kind} model identifier set. Pass {extra_key}= to the hook "
                f'constructor or set extra={{"{extra_key}": "model-name"}} on '
                "the connection."
            )
        return model_id

    @staticmethod
    def _connection_kwargs(conn: Any) -> dict[str, Any]:
        """
        Return shared OpenAI-style kwargs (api_key, api_base) from the connection.

        Empty fields are omitted so the LlamaIndex class keeps its own defaults
        for them.
        """
        kwargs: dict[str, Any] = {}
        if conn.password:
            kwargs["api_key"] = conn.password
        if conn.host:
            kwargs["api_base"] = conn.host
        return kwargs

    def get_embedding_model(self) -> BaseEmbedding:
        """
        Return a LlamaIndex embedding model configured from the Airflow connection.

        Uses ``embed_conn_id`` (falls back to ``llm_conn_id``) for credentials.
        Returns an ``OpenAIEmbedding`` instance; for other vendors,
        instantiate the LlamaIndex class directly and pass it to the
        operator's ``embed_model=`` parameter.

        :raises AirflowOptionalProviderFeatureException: If
            ``llama_index.embeddings.openai`` cannot be imported.
        :raises ValueError: If no embedding model identifier is configured.
        """
        # Lazy: llama-index is an optional extra; importing at module level
        # would break common.ai for users who haven't installed ``[llamaindex]``.
        try:
            from llama_index.embeddings.openai import OpenAIEmbedding
        except ImportError as e:
            # Chain explicitly so the original missing-module error stays visible.
            raise AirflowOptionalProviderFeatureException(e) from e

        conn = self.get_connection(self.embed_conn_id)
        model_id = self._resolve_model(
            conn.extra_dejson,
            constructor_value=self.embed_model,
            extra_key="embed_model",
            kind="embedding",
        )
        return OpenAIEmbedding(model=model_id, **self._connection_kwargs(conn))

    def get_llm(self) -> LLM:
        """
        Return a LlamaIndex LLM configured from the Airflow connection.

        Uses ``llm_conn_id`` for credentials. Returns an ``OpenAI`` LLM
        instance; for other vendors, instantiate the LlamaIndex class
        directly and pass it to the operator's ``llm=`` parameter.

        :raises AirflowOptionalProviderFeatureException: If
            ``llama_index.llms.openai`` cannot be imported.
        :raises ValueError: If no LLM model identifier is configured.
        """
        # Lazy import for the same optional-extra reason as get_embedding_model.
        try:
            from llama_index.llms.openai import OpenAI
        except ImportError as e:
            raise AirflowOptionalProviderFeatureException(e) from e

        conn = self.get_connection(self.llm_conn_id)
        model_id = self._resolve_model(
            conn.extra_dejson,
            constructor_value=self.llm_model,
            extra_key="llm_model",
            kind="llm",
        )
        # NOTE(review): LlamaIndex's ``OpenAI`` class may look up context-window
        # metadata by OpenAI model name, so non-OpenAI model names served via a
        # custom ``api_base`` (Ollama, vLLM) might be rejected there -- verify;
        # ``OpenAILike`` is the usual class for such endpoints.
        return OpenAI(model=model_id, **self._connection_kwargs(conn))