Source code for airflow.providers.edge.worker_api.auth
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import sys
from functools import cache
from typing import NoReturn
from uuid import uuid4

from itsdangerous import BadSignature
from jwt import (
    ExpiredSignatureError,
    ImmatureSignatureError,
    InvalidAudienceError,
    InvalidIssuedAtError,
    InvalidSignatureError,
)

from airflow.configuration import conf
from airflow.providers.edge.worker_api.datamodels import JsonRpcRequestBase  # noqa: TCH001
from airflow.providers.edge.worker_api.routes._v2_compat import (
    Header,
    HTTPException,
    Request,
    status,
)
from airflow.utils.jwt_signer import JWTSigner
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
@cache
def jwt_signer() -> JWTSigner:
    """Return the JWT signer used to verify tokens sent by edge workers.

    The signer is built once and memoized by ``@cache``, so later configuration
    changes are not picked up by an already running process.

    ``[core] internal_api_clock_grace`` (30 seconds if unset) is used both as
    the token expiration time and as the allowed clock leeway. Tokens are
    verified against ``[core] internal_api_secret_key`` and the ``"api"`` audience.
    """
    grace_seconds = conf.getint("core", "internal_api_clock_grace", fallback=30)
    secret = conf.get("core", "internal_api_secret_key")
    return JWTSigner(
        secret_key=secret,
        expiration_time_in_seconds=grace_seconds,
        leeway_in_seconds=grace_seconds,
        audience="api",
    )
def _forbidden_response(message: str) -> NoReturn:
    """Log ``message`` under a fresh error id and raise an anonymized 403 error.

    Despite its name this never returns. The detailed ``message`` is written
    only to the server log. The client receives a generic "Forbidden" text that
    carries the same ``error_id``, so the two can be correlated without leaking
    details to the caller.

    :param message: Server-side description of why the request is rejected.
    :raises HTTPException: Always, with status 403 Forbidden.
    """
    error_id = uuid4()
    # Attach a traceback only when an exception is actually being handled.
    # This helper is also called outside ``except`` blocks, where
    # ``log.exception`` would append a meaningless "NoneType: None".
    log.error(
        "%s error_id=%s",
        message,
        error_id,
        exc_info=sys.exc_info()[0] is not None,
    )
    raise HTTPException(
        status.HTTP_403_FORBIDDEN,
        f"Forbidden. The server side traceback may be identified with error_id={error_id}",
    )
def jwt_token_authorization(method: str, authorization: str):
    """Check that ``authorization`` is a valid JWT token signed for ``method``.

    The worker sends the method without the API URL prefix. The prefix is
    derived from ``[edge] api_url`` by removing ``[webserver] base_url`` and
    replacing ``/rpcapi`` with ``/``. It is stripped from ``method`` before the
    result is compared with the ``method`` claim of the verified token payload.

    Any failure results in ``_forbidden_response`` logging the reason and
    raising a 403. Failures include token verification errors, a missing or
    mismatching ``method`` claim, and any unexpected error (including config
    lookups).

    :param method: JSON-RPC method name or REST URL path of the current request.
    :param authorization: JWT token taken from the ``Authorization`` header.
    :raises HTTPException: With status 403 if the request is not authorized.
    """
    try:
        # worker sends method without api_url
        api_url = conf.get("edge", "api_url")
        base_url = conf.get("webserver", "base_url")
        url_prefix = api_url.replace(base_url, "").replace("/rpcapi", "/")
        pure_method = method.replace(url_prefix, "")
        payload = jwt_signer().verify_token(authorization)
        signed_method = payload.get("method")
        if not signed_method or signed_method != pure_method:
            _forbidden_response(
                "Invalid method in token authorization. "
                f"signed method='{signed_method}' "
                f"called method='{pure_method}'",
            )
    except HTTPException:
        # Already logged and raised by _forbidden_response above. Without this,
        # the generic ``except Exception`` handler below would log it again and
        # answer with a different error_id that points to a generic message
        # instead of the real reason.
        raise
    except BadSignature:
        _forbidden_response("Bad Signature. Please use only the tokens provided by the API.")
    except InvalidAudienceError:
        _forbidden_response("Invalid audience for the request")
    except InvalidSignatureError:
        _forbidden_response("The signature of the request was wrong")
    except ImmatureSignatureError:
        _forbidden_response("The signature of the request was sent from the future")
    except ExpiredSignatureError:
        _forbidden_response(
            "The signature of the request has expired. Make sure that all components "
            "in your system have synchronized clocks.",
        )
    except InvalidIssuedAtError:
        _forbidden_response(
            "The request was issued in the future. Make sure that all components "
            "in your system have synchronized clocks.",
        )
    except Exception:
        # Catch-all so that no authentication problem surfaces as anything
        # other than an anonymized 403.
        _forbidden_response("Unable to authenticate API via token.")
def jwt_token_authorization_rpc(
    body: JsonRpcRequestBase, authorization: str = Header(description="JWT Authorization Token")
):
    """Authorize a JSON-RPC request via its JWT token.

    The token must have been signed for the JSON-RPC method named in the
    request body. Failures propagate from ``jwt_token_authorization``.

    :param body: Parsed JSON-RPC request; only its ``method`` is checked.
    :param authorization: JWT token taken from the ``Authorization`` header.
    """
    rpc_method = body.method
    jwt_token_authorization(method=rpc_method, authorization=authorization)
def jwt_token_authorization_rest(
    request: Request, authorization: str = Header(description="JWT Authorization Token")
):
    """Authorize a REST API request via its JWT token.

    The token must have been signed for the URL path of the incoming request.
    Failures propagate from ``jwt_token_authorization``.

    :param request: Incoming request; only ``request.url.path`` is checked.
    :param authorization: JWT token taken from the ``Authorization`` header.
    """
    url_path = request.url.path
    jwt_token_authorization(method=url_path, authorization=authorization)