# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import json
from datetime import datetime
from typing import Annotated
from sqlalchemy import select
from airflow.providers.edge.models.edge_worker import EdgeWorkerModel, set_metrics
from airflow.providers.edge.worker_api.auth import jwt_token_authorization_rest
from airflow.providers.edge.worker_api.datamodels import (
WorkerQueueUpdateBody, # noqa: TC001
WorkerStateBody, # noqa: TC001
)
from airflow.providers.edge.worker_api.routes._v2_compat import (
AirflowRouter,
Body,
Depends,
HTTPException,
Path,
SessionDep,
create_openapi_http_exception_doc,
status,
)
from airflow.stats import Stats
from airflow.utils import timezone
[docs]worker_router = AirflowRouter(
tags=["Worker"],
prefix="/worker",
responses=create_openapi_http_exception_doc(
[
status.HTTP_400_BAD_REQUEST,
status.HTTP_403_FORBIDDEN,
]
),
)
def _assert_version(sysinfo: dict[str, str | int]) -> None:
"""Check if the Edge Worker version matches the central API site."""
from airflow import __version__ as airflow_version
from airflow.providers.edge import __version__ as edge_provider_version
# Note: In future, more stable versions we might be more liberate, for the
# moment we require exact version match for Edge Worker and core version
if "airflow_version" in sysinfo:
airflow_on_worker = sysinfo["airflow_version"]
if airflow_on_worker != airflow_version:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f"Edge Worker runs on Airflow {airflow_on_worker} "
f"and the core runs on {airflow_version}. Rejecting access due to difference.",
)
else:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Edge Worker does not specify the version it is running on."
)
if "edge_provider_version" in sysinfo:
provider_on_worker = sysinfo["edge_provider_version"]
if provider_on_worker != edge_provider_version:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
f"Edge Worker runs on Edge Provider {provider_on_worker} "
f"and the core runs on {edge_provider_version}. Rejecting access due to difference.",
)
else:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Edge Worker does not specify the provider version it is running on."
)
_worker_name_doc = Path(title="Worker Name", description="Hostname or instance name of the worker")
_worker_state_doc = Body(
title="Worker State",
description="State of the worker with details",
examples=[
{
"state": "running",
"jobs_active": 3,
"queues": ["large_node", "wisconsin_site"],
"sysinfo": {
"concurrency": 4,
"airflow_version": "2.10.0",
"edge_provider_version": "1.0.0",
},
}
],
)
_worker_queue_doc = Body(
title="Changes in worker queues",
description="Changes to be applied to current queues of worker",
examples=[{"new_queues": ["new_queue"], "remove_queues": ["old_queue"]}],
)
@worker_router.post("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)])
[docs]def register(
worker_name: Annotated[str, _worker_name_doc],
body: Annotated[WorkerStateBody, _worker_state_doc],
session: SessionDep,
) -> datetime:
"""Register a new worker to the backend."""
_assert_version(body.sysinfo)
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
if not worker:
worker = EdgeWorkerModel(worker_name=worker_name, state=body.state, queues=body.queues)
worker.state = body.state
worker.queues = body.queues
worker.sysinfo = json.dumps(body.sysinfo)
worker.last_update = timezone.utcnow()
session.add(worker)
return worker.last_update
@worker_router.patch("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)])
[docs]def set_state(
worker_name: Annotated[str, _worker_name_doc],
body: Annotated[WorkerStateBody, _worker_state_doc],
session: SessionDep,
) -> list[str] | None:
"""Set state of worker and returns the current assigned queues."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker.state = body.state
worker.jobs_active = body.jobs_active
worker.sysinfo = json.dumps(body.sysinfo)
worker.last_update = timezone.utcnow()
session.commit()
Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1)
Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name": worker_name})
set_metrics(
worker_name=worker_name,
state=body.state,
jobs_active=body.jobs_active,
concurrency=int(body.sysinfo.get("concurrency", -1)),
free_concurrency=int(body.sysinfo["free_concurrency"]),
queues=worker.queues,
)
_assert_version(body.sysinfo) # Exception only after worker state is in the DB
return worker.queues
@worker_router.patch(
"/queues/{worker_name}",
dependencies=[Depends(jwt_token_authorization_rest)],
)
[docs]def update_queues(
worker_name: Annotated[str, _worker_name_doc],
body: Annotated[WorkerQueueUpdateBody, _worker_queue_doc],
session: SessionDep,
) -> None:
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
if body.new_queues:
worker.add_queues(body.new_queues)
if body.remove_queues:
worker.remove_queues(body.remove_queues)
session.add(worker)