# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
import logging
import os
from datetime import datetime
from http import HTTPStatus
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import quote, urljoin
import requests
from retryhttp import retry, wait_retry_after
from tenacity import before_log, wait_random_exponential
from airflow.configuration import conf
from airflow.providers.edge.models.edge_worker import EdgeWorkerVersionException
from airflow.providers.edge.worker_api.auth import jwt_signer
from airflow.providers.edge.worker_api.datamodels import (
    EdgeJobFetched,
    PushLogsBody,
    WorkerQueuesBody,
    WorkerSetStateReturn,
    WorkerStateBody,
)
from airflow.utils.state import TaskInstanceState # noqa: TC001
if TYPE_CHECKING:
    from airflow.models.taskinstancekey import TaskInstanceKey
    from airflow.providers.edge.models.edge_worker import EdgeWorkerState
logger = logging.getLogger(__name__)
# Hidden config options controlling how the Edge Worker retries HTTP requests.
# Note: With the given defaults, retries are attempted after 1, 3, 7, 15 and 31 seconds,
# then 1:03, 2:07 and 3:37 minutes, and the request fails after 5:07 minutes.
# As there is no other config facility in the Task SDK so far, we use environment variables for the moment.
# TODO: Consider these env variables jointly in the Task SDK together with task_sdk/src/airflow/sdk/api/client.py
API_RETRIES = int(os.getenv("AIRFLOW__EDGE__API_RETRIES", os.getenv("AIRFLOW__WORKERS__API_RETRIES", 10)))
API_RETRY_WAIT_MIN = float(
os.getenv("AIRFLOW__EDGE__API_RETRY_WAIT_MIN", os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", 1.0))
)
API_RETRY_WAIT_MAX = float(
os.getenv("AIRFLOW__EDGE__API_RETRY_WAIT_MAX", os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", 90.0))
)
_default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX)
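# Every Edge API call below is wrapped in this retry policy: server errors, network errors and
# timeouts are retried with random exponential backoff, while HTTP 429 honors the server's
# Retry-After header (falling back to the same backoff).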
@retry(
    reraise=True,
    max_attempt_number=API_RETRIES,
    wait_server_errors=_default_wait,
    wait_network_errors=_default_wait,
    wait_timeouts=_default_wait,
    wait_rate_limited=wait_retry_after(fallback=_default_wait),  # No infinite timeout on HTTP 429
    before_sleep=before_log(logger, logging.WARNING),
)
def _make_generic_request(method: str, rest_path: str, data: str | None = None) -> Any:
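    """Send a signed request to the Edge API and return the JSON-decoded response (None on 204 No Content)."""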
    signer = jwt_signer()
    api_url = conf.get("edge", "api_url")
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": signer.generate_signed_token({"method": rest_path}),
    }
    api_endpoint = urljoin(api_url, rest_path)
    response = requests.request(method, url=api_endpoint, data=data, headers=headers)
    response.raise_for_status()
    if response.status_code == HTTPStatus.NO_CONTENT:
        return None
    return json.loads(response.content)
def worker_register(
    hostname: str, state: EdgeWorkerState, queues: list[str] | None, sysinfo: dict
) -> datetime:
"""Register worker with the Edge API."""
try:
result = _make_generic_request(
"POST",
f"worker/{quote(hostname)}",
WorkerStateBody(state=state, jobs_active=0, queues=queues, sysinfo=sysinfo).model_dump_json(
exclude_unset=True
),
)
except requests.HTTPError as e:
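        # HTTP 400 from the Edge API signals a worker/API version mismatch.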
        if e.response.status_code == 400:
            raise EdgeWorkerVersionException(str(e))
        raise e
    return datetime.fromisoformat(result)
def worker_set_state(
    hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] | None, sysinfo: dict
) -> WorkerSetStateReturn:
"""Update the state of the worker in the central site and thereby implicitly heartbeat."""
try:
result = _make_generic_request(
"PATCH",
f"worker/{quote(hostname)}",
WorkerStateBody(
state=state, jobs_active=jobs_active, queues=queues, sysinfo=sysinfo
).model_dump_json(exclude_unset=True),
)
except requests.HTTPError as e:
if e.response.status_code == 400:
raise EdgeWorkerVersionException(str(e))
raise e
return WorkerSetStateReturn(**result)
def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
"""Fetch a job to execute on the edge worker."""
result = _make_generic_request(
"POST",
f"jobs/fetch/{quote(hostname)}",
WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency).model_dump_json(
exclude_unset=True
),
)
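    # An empty response means no job is currently available for the requested queues and free concurrency.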
    if result:
        return EdgeJobFetched(**result)
    return None
def jobs_set_state(key: TaskInstanceKey, state: TaskInstanceState) -> None:
"""Set the state of a job."""
    _make_generic_request(
        "PATCH",
        f"jobs/state/{key.dag_id}/{key.task_id}/{key.run_id}/{key.try_number}/{key.map_index}/{state}",
    )
def logs_logfile_path(task: TaskInstanceKey) -> Path:
"""Elaborate the path and filename to expect from task execution."""
result = _make_generic_request(
"GET",
f"logs/logfile_path/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
)
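    # The returned path is expected to be relative to the locally configured base_log_folder and is joined below.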
    base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
    return Path(base_log_folder, result)
def logs_push(
    task: TaskInstanceKey,
    log_chunk_time: datetime,
    log_chunk_data: str,
) -> None:
"""Push an incremental log chunk from Edge Worker to central site."""
_make_generic_request(
"POST",
f"logs/push/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(
exclude_unset=True
),
)
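# Typical call sequence (illustrative sketch only; the actual loop lives in the Edge Worker
# implementation, not in this module):
#
#   worker_register(hostname, state, queues, sysinfo)               # announce the worker to the Edge API
#   while running:
#       job = jobs_fetch(hostname, queues, free_concurrency)        # returns None if nothing is queued
#       if job:
#           ...run the task, reporting progress via jobs_set_state() and logs_push()...
#       worker_set_state(hostname, state, jobs_active, queues, sysinfo)  # state update doubles as heartbeat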