# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import logging
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
import requests
from fastapi import FastAPI
from airflow.api_fastapi.app import AUTH_MANAGER_FASTAPI_APP_PREFIX
from airflow.api_fastapi.auth.managers.base_auth_manager import BaseAuthManager, T
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.keycloak.auth_manager.resources import KeycloakResource
from airflow.providers.keycloak.auth_manager.user import KeycloakAuthManagerUser
if TYPE_CHECKING:
from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
from airflow.api_fastapi.auth.managers.models.resource_details import (
AccessView,
AssetAliasDetails,
AssetDetails,
BackfillDetails,
ConfigurationDetails,
ConnectionDetails,
DagAccessEntity,
DagDetails,
PoolDetails,
VariableDetails,
)
from airflow.api_fastapi.common.types import MenuItem
[docs]
log = logging.getLogger(__name__)
[docs]
class KeycloakAuthManager(BaseAuthManager[KeycloakAuthManagerUser]):
"""
Keycloak auth manager.
Leverages Keycloak to perform authentication and authorization in Airflow.
"""
[docs]
def deserialize_user(self, token: dict[str, Any]) -> KeycloakAuthManagerUser:
return KeycloakAuthManagerUser(
user_id=token.pop("user_id"),
name=token.pop("name"),
access_token=token.pop("access_token"),
refresh_token=token.pop("refresh_token"),
)
[docs]
def serialize_user(self, user: KeycloakAuthManagerUser) -> dict[str, Any]:
return {
"user_id": user.get_id(),
"name": user.get_name(),
"access_token": user.access_token,
"refresh_token": user.refresh_token,
}
[docs]
def get_url_login(self, **kwargs) -> str:
base_url = conf.get("api", "base_url", fallback="/")
return urljoin(base_url, f"{AUTH_MANAGER_FASTAPI_APP_PREFIX}/login")
[docs]
def is_authorized_configuration(
self,
*,
method: ResourceMethod,
user: KeycloakAuthManagerUser,
details: ConfigurationDetails | None = None,
) -> bool:
config_section = details.section if details else None
return self._is_authorized(
method=method, resource_type=KeycloakResource.CONFIGURATION, user=user
) or (
config_section is not None
and self._is_authorized(
method=method,
resource_type=KeycloakResource.CONFIGURATION,
user=user,
resource_id=config_section,
)
)
[docs]
def is_authorized_connection(
self,
*,
method: ResourceMethod,
user: KeycloakAuthManagerUser,
details: ConnectionDetails | None = None,
) -> bool:
connection_id = details.conn_id if details else None
return self._is_authorized(method=method, resource_type=KeycloakResource.CONNECTION, user=user) or (
connection_id is not None
and self._is_authorized(
method=method, resource_type=KeycloakResource.CONNECTION, user=user, resource_id=connection_id
)
)
[docs]
def is_authorized_dag(
self,
*,
method: ResourceMethod,
user: T,
access_entity: DagAccessEntity | None = None,
details: DagDetails | None = None,
) -> bool:
return True
[docs]
def is_authorized_backfill(
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: BackfillDetails | None = None
) -> bool:
backfill_id = str(details.id) if details else None
return self._is_authorized(method=method, resource_type=KeycloakResource.BACKFILL, user=user) or (
backfill_id is not None
and self._is_authorized(
method=method, resource_type=KeycloakResource.BACKFILL, user=user, resource_id=backfill_id
)
)
[docs]
def is_authorized_asset(
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: AssetDetails | None = None
) -> bool:
asset_id = details.id if details else None
return self._is_authorized(method=method, resource_type=KeycloakResource.ASSET, user=user) or (
asset_id is not None
and self._is_authorized(
method=method, resource_type=KeycloakResource.ASSET, user=user, resource_id=asset_id
)
)
[docs]
def is_authorized_asset_alias(
self,
*,
method: ResourceMethod,
user: KeycloakAuthManagerUser,
details: AssetAliasDetails | None = None,
) -> bool:
asset_alias_id = details.id if details else None
return self._is_authorized(method=method, resource_type=KeycloakResource.ASSET_ALIAS, user=user) or (
asset_alias_id is not None
and self._is_authorized(
method=method,
resource_type=KeycloakResource.ASSET_ALIAS,
user=user,
resource_id=asset_alias_id,
)
)
[docs]
def is_authorized_variable(
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: VariableDetails | None = None
) -> bool:
variable_key = details.key if details else None
return self._is_authorized(method=method, resource_type=KeycloakResource.VARIABLE, user=user) or (
variable_key is not None
and self._is_authorized(
method=method, resource_type=KeycloakResource.VARIABLE, user=user, resource_id=variable_key
)
)
[docs]
def is_authorized_pool(
self, *, method: ResourceMethod, user: KeycloakAuthManagerUser, details: PoolDetails | None = None
) -> bool:
pool_name = details.name if details else None
return self._is_authorized(method=method, resource_type=KeycloakResource.POOL, user=user) or (
pool_name is not None
and self._is_authorized(
method=method, resource_type=KeycloakResource.POOL, user=user, resource_id=pool_name
)
)
[docs]
def is_authorized_view(self, *, access_view: AccessView, user: KeycloakAuthManagerUser) -> bool:
return self._is_authorized(
method="GET", resource_type=KeycloakResource.VIEW, user=user
) or self._is_authorized(
method="GET",
resource_type=KeycloakResource.VIEW,
user=user,
resource_id=access_view.value,
)
[docs]
def is_authorized_custom_view(
self, *, method: ResourceMethod | str, resource_name: str, user: KeycloakAuthManagerUser
) -> bool:
return self._is_authorized(
method=method, resource_type=KeycloakResource.CUSTOM, user=user
) or self._is_authorized(
method=method, resource_type=KeycloakResource.CUSTOM, user=user, resource_id=resource_name
)
[docs]
def get_fastapi_app(self) -> FastAPI | None:
from airflow.providers.keycloak.auth_manager.routes.login import login_router
app = FastAPI(
title="Keycloak auth manager sub application",
description=(
"This is the Keycloak auth manager fastapi sub application. This API is only available if the "
"auth manager used in the Airflow environment is Keycloak auth manager. "
"This sub application provides login routes."
),
)
app.include_router(login_router)
return app
def _is_authorized(
self,
*,
method: ResourceMethod | str,
resource_type: KeycloakResource,
user: KeycloakAuthManagerUser,
resource_id: str | None = None,
) -> bool:
client_id = conf.get("keycloak_auth_manager", "client_id")
realm = conf.get("keycloak_auth_manager", "realm")
server_url = conf.get("keycloak_auth_manager", "server_url")
permission = (
f"{resource_type.value}:{resource_id}#{method}"
if resource_id
else f"{resource_type.value}#{method}"
)
resp = requests.post(
self._get_token_url(server_url, realm),
data=self._get_payload(client_id, permission),
headers=self._get_headers(user.access_token),
)
if resp.status_code == 200:
return True
if resp.status_code == 403:
return False
if resp.status_code == 400:
error = json.loads(resp.text)
if error.get("error") == "invalid_resource":
log.debug(error["error_description"])
return False
raise AirflowException(
f"Request not recognized by Keycloak. {error.get('error')}. {error.get('error_description')}"
)
raise AirflowException(f"Unexpected error: {resp.status_code} - {resp.text}")
@staticmethod
def _get_token_url(server_url, realm):
return f"{server_url}/realms/{realm}/protocol/openid-connect/token"
@staticmethod
def _get_payload(client_id, permission):
return {
"grant_type": "urn:ietf:params:oauth:grant-type:uma-ticket",
"audience": client_id,
"permission": permission,
}
@staticmethod
def _get_headers(access_token):
return {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/x-www-form-urlencoded",
}