Source code for airflow.providers.edge.worker_api.auth

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from functools import cache
from uuid import uuid4

from itsdangerous import BadSignature
from jwt import (
    ExpiredSignatureError,
    ImmatureSignatureError,
    InvalidAudienceError,
    InvalidIssuedAtError,
    InvalidSignatureError,
)

from airflow.configuration import conf
from airflow.providers.edge.worker_api.datamodels import JsonRpcRequestBase  # noqa: TCH001
from airflow.providers.edge.worker_api.routes._v2_compat import (
    Header,
    HTTPException,
    Request,
    status,
)
from airflow.utils.jwt_signer import JWTSigner

# Module-level logger; authentication failures are reported through it.
log = logging.getLogger(__name__)
@cache
def jwt_signer() -> JWTSigner:
    """
    Build the signer used to verify worker API tokens.

    The result is memoised by ``functools.cache``, so the configuration is
    only read on the first call. The ``internal_api_clock_grace`` setting
    (default 30) is used both as token lifetime and as leeway.
    """
    grace_seconds = conf.getint("core", "internal_api_clock_grace", fallback=30)
    signing_key = conf.get("core", "internal_api_secret_key")
    return JWTSigner(
        secret_key=signing_key,
        expiration_time_in_seconds=grace_seconds,
        leeway_in_seconds=grace_seconds,
        audience="api",
    )
def _forbidden_response(message: str):
    """
    Log the error server-side and raise an anonymized 403 response.

    The detailed ``message`` only goes to the log. The client merely receives
    a generic text plus a random ``error_id`` that allows correlating the
    response with the log entry.

    :param message: Detailed failure reason, written to the log only.
    :raises HTTPException: Always, with status 403; this function never returns.
    """
    import sys  # local import keeps this fix self-contained

    error_id = uuid4()
    # Logger.exception() is only meant to be called from an exception handler.
    # This helper is also invoked without an active exception, which would
    # append a bogus "NoneType: None" traceback, so attach the traceback only
    # when there is one. The log level (ERROR) is the same either way.
    has_active_exception = sys.exc_info()[0] is not None
    log.error("%s error_id=%s", message, error_id, exc_info=has_active_exception)
    raise HTTPException(
        status.HTTP_403_FORBIDDEN,
        f"Forbidden. The server side traceback may be identified with error_id={error_id}",
    )
def jwt_token_authorization(method: str, authorization: str):
    """
    Check if the JWT token is correct.

    The token is verified with :func:`jwt_signer` and the ``method`` entry of
    its payload must equal the called method after the configured URL prefix
    has been stripped from it.

    :param method: Called method or request path, possibly including the API URL prefix.
    :param authorization: The JWT token sent by the caller.
    :raises HTTPException: With status 403 if the token cannot be verified or
        was signed for a different method.
    """
    try:
        # The worker signs the method without the API URL prefix, so strip the
        # prefix (derived from configuration) before comparing.
        api_url = conf.get("edge", "api_url")
        base_url = conf.get("webserver", "base_url")
        url_prefix = api_url.replace(base_url, "").replace("/rpcapi", "/")
        # NOTE(review): str.replace removes every occurrence of url_prefix, not
        # only a leading one - confirm this matches how workers sign the method.
        pure_method = method.replace(url_prefix, "")
        payload = jwt_signer().verify_token(authorization)
        signed_method = payload.get("method")
        if not signed_method or signed_method != pure_method:
            _forbidden_response(
                "Invalid method in token authorization. "
                f"signed method='{signed_method}' "
                f"called method='{pure_method}'",
            )
    except HTTPException:
        # Raised by the method check above and already logged/anonymized.
        # Without this clause the generic handler at the bottom would catch
        # it, log a second misleading message and hand out a new error_id.
        raise
    except BadSignature:
        _forbidden_response("Bad Signature. Please use only the tokens provided by the API.")
    except InvalidAudienceError:
        _forbidden_response("Invalid audience for the request")
    except InvalidSignatureError:
        _forbidden_response("The signature of the request was wrong")
    except ImmatureSignatureError:
        _forbidden_response("The signature of the request was sent from the future")
    except ExpiredSignatureError:
        _forbidden_response(
            "The signature of the request has expired. Make sure that all components "
            "in your system have synchronized clocks.",
        )
    except InvalidIssuedAtError:
        _forbidden_response(
            "The request was issued in the future. Make sure that all components "
            "in your system have synchronized clocks.",
        )
    except Exception:
        # Boundary catch-all: any other failure must still end in an anonymized 403.
        _forbidden_response("Unable to authenticate API via token.")
def jwt_token_authorization_rpc(
    body: JsonRpcRequestBase,
    authorization: str = Header(description="JWT Authorization Token"),
):
    """Validate the JWT token of a JSON RPC request against the method named in its body."""
    rpc_method = body.method
    jwt_token_authorization(method=rpc_method, authorization=authorization)
def jwt_token_authorization_rest(
    request: Request,
    authorization: str = Header(description="JWT Authorization Token"),
):
    """Validate the JWT token of a REST API request against the requested URL path."""
    called_path = request.url.path
    jwt_token_authorization(method=called_path, authorization=authorization)

Was this entry helpful?