# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import logging
import os
from datetime import datetime
from functools import cache
from http import HTTPStatus
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import quote, urljoin

import requests
from retryhttp import retry, wait_retry_after
from tenacity import before_log, wait_random_exponential

from airflow.configuration import conf
from airflow.providers.edge3.models.edge_worker import EdgeWorkerVersionException
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.providers.edge3.worker_api.datamodels import (
    EdgeJobFetched,
    PushLogsBody,
    WorkerQueuesBody,
    WorkerRegistrationReturn,
    WorkerSetStateReturn,
    WorkerStateBody,
)
from airflow.utils.state import TaskInstanceState  # noqa: TC001
if TYPE_CHECKING:
    from airflow.models.taskinstancekey import TaskInstanceKey
    from airflow.providers.edge3.models.edge_worker import EdgeWorkerState
# Module-level logger; also used by the retry policy below (before_log hook).
logger = logging.getLogger(__name__)
# Hidden config options for Edge Worker how retries on HTTP requests should be handled
# Note: Given defaults make attempts after 1, 3, 7, 15, 31seconds, 1:03, 2:07, 3:37 and fails after 5:07min
# So far there is no other config facility in Task SDK we use ENV for the moment
# TODO: Consider these env variables jointly in task sdk together with task_sdk/src/airflow/sdk/api/client.py
# Maximum number of attempts for a single API call before the error is re-raised.
# AIRFLOW__EDGE__* takes precedence over the legacy AIRFLOW__WORKERS__* variable.
API_RETRIES = int(os.getenv("AIRFLOW__EDGE__API_RETRIES", os.getenv("AIRFLOW__WORKERS__API_RETRIES", "10")))
# Lower bound (seconds) of the random-exponential wait between retries.
# AIRFLOW__EDGE__* takes precedence over the legacy AIRFLOW__WORKERS__* variable.
_wait_min_raw = os.getenv(
    "AIRFLOW__EDGE__API_RETRY_WAIT_MIN",
    os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", "1.0"),
)
API_RETRY_WAIT_MIN = float(_wait_min_raw)
# Upper bound (seconds) of the random-exponential wait between retries.
# AIRFLOW__EDGE__* takes precedence over the legacy AIRFLOW__WORKERS__* variable.
_wait_max_raw = os.getenv(
    "AIRFLOW__EDGE__API_RETRY_WAIT_MAX",
    os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", "90.0"),
)
API_RETRY_WAIT_MAX = float(_wait_max_raw)
# Shared jittered exponential backoff, reused for every retryable error class below.
_default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX)
# Hidden config option: how long (seconds) a single HTTP attempt may take before it
# is aborted. Without a timeout a dead connection would hang the worker forever;
# timed-out attempts are retried by the policy below (wait_timeouts).
_API_REQUEST_TIMEOUT = float(os.getenv("AIRFLOW__EDGE__API_REQUEST_TIMEOUT", "60.0"))


@cache
def _jwt_generator():
    """Build the JWT generator once per process (Airflow 3.x only).

    Cached at module level on purpose: a ``@cache`` on a closure defined inside the
    request function would be re-created on every call, rebuilding the generator
    each time. The import stays local so Airflow 2.10 installs never load it.
    """
    from airflow.api_fastapi.auth.tokens import JWTGenerator

    return JWTGenerator(
        secret_key=conf.get("api_auth", "jwt_secret"),
        # NOTE(review): valid_for is fed from "jwt_leeway" — confirm this is intentional.
        valid_for=conf.getint("api_auth", "jwt_leeway", fallback=30),
        audience="api",
    )


@retry(
    reraise=True,
    max_attempt_number=API_RETRIES,
    wait_server_errors=_default_wait,
    wait_network_errors=_default_wait,
    wait_timeouts=_default_wait,
    wait_rate_limited=wait_retry_after(fallback=_default_wait),  # No infinite timeout on HTTP 429
    before_sleep=before_log(logger, logging.WARNING),
)
def _make_generic_request(method: str, rest_path: str, data: str | None = None) -> Any:
    """Send an authenticated JSON request to the Edge API and return the decoded body.

    :param method: HTTP verb, e.g. ``"GET"``, ``"POST"`` or ``"PATCH"``.
    :param rest_path: Path relative to the configured ``[edge] api_url``; it is also
        embedded in the auth token so the server can bind the token to the endpoint.
    :param data: Optional JSON-serialized request body.
    :return: Parsed JSON payload, or ``None`` for a 204 No Content response.
    :raises requests.HTTPError: If the server answers with an error status after all
        retries are exhausted.
    """
    if AIRFLOW_V_3_0_PLUS:
        authorization = _jwt_generator().generate({"method": rest_path})
    else:
        # Airflow 2.10 compatibility
        from airflow.providers.edge3.worker_api.auth import jwt_signer

        signer = jwt_signer()
        authorization = signer.generate_signed_token({"method": rest_path})
    api_url = conf.get("edge", "api_url")
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": authorization,
    }
    api_endpoint = urljoin(api_url, rest_path)
    response = requests.request(
        method, url=api_endpoint, data=data, headers=headers, timeout=_API_REQUEST_TIMEOUT
    )
    response.raise_for_status()
    if response.status_code == HTTPStatus.NO_CONTENT:
        return None
    return json.loads(response.content)
def worker_register(
    hostname: str, state: EdgeWorkerState, queues: list[str] | None, sysinfo: dict
) -> WorkerRegistrationReturn:
    """Register worker with the Edge API.

    :param hostname: Unique name of the worker to register.
    :param state: Initial worker state to report.
    :param queues: Queues the worker listens on; ``None`` means no restriction.
    :param sysinfo: System information reported to the central site.
    :return: Registration result as returned by the API.
    :raises EdgeWorkerVersionException: If the API answers with HTTP 400
        (mapped to a version mismatch between worker and server).
    :raises requests.HTTPError: For any other HTTP error status.
    """
    body = WorkerStateBody(state=state, jobs_active=0, queues=queues, sysinfo=sysinfo)
    try:
        result = _make_generic_request(
            "POST",
            f"worker/{quote(hostname)}",
            body.model_dump_json(exclude_unset=True),
        )
    except requests.HTTPError as e:
        if e.response.status_code == HTTPStatus.BAD_REQUEST:
            # Chain the original error so the HTTP response stays debuggable.
            raise EdgeWorkerVersionException(str(e)) from e
        raise
    return WorkerRegistrationReturn(**result)
def worker_set_state(
    hostname: str,
    state: EdgeWorkerState,
    jobs_active: int,
    queues: list[str] | None,
    sysinfo: dict,
    maintenance_comments: str | None = None,
) -> WorkerSetStateReturn:
    """Update the state of the worker in the central site and thereby implicitly heartbeat.

    :param hostname: Unique name of the worker.
    :param state: Current worker state to report.
    :param jobs_active: Number of jobs currently executing on the worker.
    :param queues: Queues the worker listens on; ``None`` means no restriction.
    :param sysinfo: System information reported to the central site.
    :param maintenance_comments: Optional free-text note shown for maintenance mode.
    :return: State-update result as returned by the API.
    :raises EdgeWorkerVersionException: If the API answers with HTTP 400
        (mapped to a version mismatch between worker and server).
    :raises requests.HTTPError: For any other HTTP error status.
    """
    body = WorkerStateBody(
        state=state,
        jobs_active=jobs_active,
        queues=queues,
        sysinfo=sysinfo,
        maintenance_comments=maintenance_comments,
    )
    try:
        result = _make_generic_request(
            "PATCH",
            f"worker/{quote(hostname)}",
            body.model_dump_json(exclude_unset=True),
        )
    except requests.HTTPError as e:
        if e.response.status_code == HTTPStatus.BAD_REQUEST:
            # Chain the original error so the HTTP response stays debuggable.
            raise EdgeWorkerVersionException(str(e)) from e
        raise
    return WorkerSetStateReturn(**result)
def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
    """Fetch a job to execute on the edge worker."""
    request_body = WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency)
    payload = _make_generic_request(
        "POST",
        f"jobs/fetch/{quote(hostname)}",
        request_body.model_dump_json(exclude_unset=True),
    )
    # An empty payload means the central site has no job for this worker right now.
    return EdgeJobFetched(**payload) if payload else None
def jobs_set_state(key: TaskInstanceKey, state: TaskInstanceState) -> None:
    """Set the state of a job."""
    # The task-instance key fully identifies the job on the central site.
    endpoint = (
        f"jobs/state/{key.dag_id}/{key.task_id}/{key.run_id}/"
        f"{key.try_number}/{key.map_index}/{state}"
    )
    _make_generic_request("PATCH", endpoint)
def logs_logfile_path(task: TaskInstanceKey) -> Path:
    """Elaborate the path and filename to expect from task execution."""
    relative_path = _make_generic_request(
        "GET",
        f"logs/logfile_path/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
    )
    # The API returns the path relative to the locally configured log folder.
    log_root = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
    return Path(log_root) / relative_path
[docs]
def logs_push(
    task: TaskInstanceKey,
    log_chunk_time: datetime,
    log_chunk_data: str,
) -> None:
    """Push an incremental log chunk from Edge Worker to central site."""
    chunk = PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data)
    endpoint = (
        f"logs/push/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}"
    )
    _make_generic_request("POST", endpoint, chunk.model_dump_json(exclude_unset=True))