Source code for airflow.providers.edge3.cli.worker

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import os
import signal
import sys
import time
import traceback
from asyncio import Task, create_task, gather, get_running_loop, sleep
from collections.abc import Awaitable, Callable
from contextlib import suppress
from datetime import datetime
from functools import cached_property
from http import HTTPStatus
from multiprocessing import Process, Queue
from pathlib import Path
from typing import TYPE_CHECKING

import anyio
from aiofiles import open as aio_open
from aiohttp import ClientResponseError
from lockfile.pidlockfile import remove_existing_pidfile

from airflow import __version__ as airflow_version
from airflow.providers.common.compat.sdk import conf, timezone
from airflow.providers.edge3 import __version__ as edge_provider_version
from airflow.providers.edge3.cli.api_client import (
    jobs_fetch,
    jobs_set_state,
    logs_push,
    worker_register,
    worker_set_state,
)
from airflow.providers.edge3.cli.dataclasses import Job, MaintenanceMarker, WorkerStatus
from airflow.providers.edge3.cli.signalling import (
    SIG_STATUS,
    maintenance_marker_file_path,
    status_file_path,
    write_pid_to_pidfile,
)
from airflow.providers.edge3.models.edge_worker import (
    EdgeWorkerDuplicateException,
    EdgeWorkerState,
    EdgeWorkerVersionException,
)
from airflow.providers.edge3.version_compat import AIRFLOW_V_3_2_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.utils.net import getfqdn
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
    from airflow.configuration import AirflowConfigParser
    from airflow.executors.workloads import ExecuteTask

logger = logging.getLogger(__name__)

if sys.platform == "darwin":
    setproctitle = lambda title: logger.debug("Mac OS detected, skipping setproctitle")
else:
    from setproctitle import setproctitle


def _edge_hostname() -> str:
    """Get the hostname of the edge worker that should be reported by tasks."""
    return os.environ.get("HOSTNAME", getfqdn())


def _reset_parent_signal_state() -> None:
    """
    Detach a forked child from the parent's asyncio signal plumbing.

    The parent installs asyncio signal handlers for SIGTERM/SIGINT/SIG_STATUS via
    ``loop.add_signal_handler``, which internally uses ``signal.set_wakeup_fd`` on
    one end of a shared socketpair. On Linux ``fork()`` duplicates that fd into the
    child; signals delivered to the child then write bytes into the socketpair the
    parent is reading from, re-firing the parent's handlers. Reset the inherited
    state before the child runs any supervised code.
    """
    signal.set_wakeup_fd(-1)
    signal.signal(signal.SIGTERM, signal.SIG_DFL)
    signal.signal(signal.SIGINT, signal.SIG_DFL)
    with suppress(ValueError, OSError):
        signal.signal(SIG_STATUS, signal.SIG_DFL)

class EdgeWorker:
    """Runner instance which executes the Edge Worker."""
    jobs: list[Job] = []
    """List of jobs that the worker is currently running."""

    drain: bool = False
    """Flag if job processing should be completed and no new jobs fetched for a graceful stop/shutdown."""

    drain_started_at: float | None = None
    """``time.monotonic()`` timestamp of when drain was first requested."""

    drain_timed_out: bool = False
    """Flag set once ``drain_timeout_sec`` has elapsed and remaining jobs were sent SIGTERM."""

    drain_kill_deadline: float | None = None
    """``time.monotonic()`` deadline after which remaining jobs are sent SIGKILL."""

    maintenance_mode: bool = False
    """Flag if job processing should be completed and no new jobs fetched for maintenance mode."""

    maintenance_comments: str | None = None
    """Comments for maintenance mode."""

    versions_match: bool = True
    """Whether the worker and the server have matching versions of Airflow and the Edge Provider."""

    background_tasks: set[Task] = set()
    """Strong references to fire-and-forget asyncio tasks, kept so they are not garbage collected early."""

    def __init__(
        self,
        pid_file_path: str,
        hostname: str,
        queues: list[str] | None,
        concurrency: int,
        daemon: bool = False,
        team_name: str | None = None,
    ):
        self.pid_file_path = pid_file_path
        self.hostname = hostname
        self.queues = queues
        self.concurrency = concurrency
        self.daemon = daemon
        self.team_name = team_name
        self.worker_start_time: datetime = datetime.now()
        if TYPE_CHECKING:
            self.conf: ExecutorConf | AirflowConfigParser
        if AIRFLOW_V_3_2_PLUS:
            from airflow.executors.base_executor import ExecutorConf

            self.conf = ExecutorConf(team_name)
        else:
            self.conf = conf
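        # Note: all config reads below go through self.conf; on Airflow 3.2+ this is an
        # ExecutorConf bound to team_name (presumably so team-scoped overrides apply),
        # on older versions it is the plain global configuration object.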
        self.job_poll_interval = self.conf.getint("edge", "job_poll_interval")
        self.hb_interval = self.conf.getint("edge", "heartbeat_interval")
        self.drain_timeout_sec = self.conf.getint("edge", "drain_timeout_sec")
        self.drain_kill_grace_sec = self.conf.getint("edge", "drain_kill_grace_sec")
        self.base_log_folder: str = (
            self.conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE") or ""
        )
        self.push_logs = self.conf.getboolean("edge", "push_logs")
        self.push_log_chunk_size = self.conf.getint("edge", "push_log_chunk_size")
        self.automatic_maintenance_on = (
            self.conf.get("edge", "automatic_maintenance_on", fallback="Off") or "Off"
        ).lower()
        self.automatic_maintenance_exit = (
            self.conf.get("edge", "automatic_maintenance_exit", fallback="Off") or "Off"
        ).lower()
        self.extended_sysinfo: Callable[[], Awaitable[dict[str, str | int | float | datetime]]] | None = None
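        # Illustrative config entry (the dotted path below is hypothetical):
        #   [edge]
        #   extended_system_info_function = my_company.monitoring.edge_sysinfo
        # pointing at an async callable returning dict[str, str | int | float | datetime],
        # whose result is merged into the sysinfo reported by heartbeats.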
        extended_sysinfo_func_path = self.conf.get("edge", "extended_system_info_function", fallback=None)
        if extended_sysinfo_func_path:
            module_path, func_name = extended_sysinfo_func_path.rsplit(".", 1)
            try:
                module = __import__(module_path, fromlist=[func_name])
                self.extended_sysinfo = getattr(module, func_name)
                logger.info("Using extended sysinfo function: %s", extended_sysinfo_func_path)
            except Exception:
                logger.exception(
                    "Failed to import extended sysinfo function %s, skipping it.",
                    extended_sysinfo_func_path,
                )

    @cached_property
    def _execution_api_server_url(self) -> str | None:
        """Get the execution API server URL from config or environment."""
        execution_api_server_url = self.conf.get("core", "execution_api_server_url", fallback="")
        if not execution_api_server_url:
            # Derive the execution API URL from the edge API URL as a fallback
            api_url = self.conf.get("edge", "api_url")
            execution_api_server_url = (
                api_url.replace("edge_worker/v1/rpcapi", "execution") if api_url is not None else None
            )
        logger.info("Using execution api server url: %s", execution_api_server_url)
        return execution_api_server_url

    @property
    def free_concurrency(self) -> int:
        """Calculate the free concurrency of the worker."""
        used_concurrency = sum(job.edge_job.concurrency_slots for job in self.jobs)
        return self.concurrency - used_concurrency
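
    # Worked example: with concurrency=8 and two running jobs holding concurrency_slots
    # of 3 and 2, free_concurrency is 8 - (3 + 2) = 3, so a further job needing at most
    # 3 slots can still be fetched.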

    def signal_status(self):
        marker_path = Path(maintenance_marker_file_path(None))
        if marker_path.exists():
            request = MaintenanceMarker.from_json(marker_path.read_text())
            logger.info("Requested to set maintenance mode to %s", request.maintenance)
            self.maintenance_mode = request.maintenance == "on"
            if self.maintenance_mode and request.comments:
                logger.info("Comments: %s", request.comments)
                self.maintenance_comments = request.comments
            marker_path.unlink()
            # send heartbeat immediately to update state
            task = get_running_loop().create_task(self.heartbeat(self.maintenance_comments))
            self.background_tasks.add(task)
            task.add_done_callback(self.background_tasks.discard)
        else:
            logger.info("Request to get status of Edge Worker received.")
            status_path = Path(status_file_path(None))
            status_path.write_text(
                WorkerStatus(
                    job_count=len(self.jobs),
                    jobs=[job.edge_job.key for job in self.jobs],
                    state=self._get_state(),
                    maintenance=self.maintenance_mode,
                    maintenance_comments=self.maintenance_comments,
                    drain=self.drain,
                ).json
            )
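
    # Illustrative maintenance marker payload consumed by signal_status() (field names
    # inferred from usage above; the exact schema is defined by MaintenanceMarker):
    #   {"maintenance": "on", "comments": "patching host OS"}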

    def signal_drain(self):
        if self._start_draining():
            logger.info(
                "Request to shut down Edge Worker received, waiting for jobs to complete. %s",
                self._drain_policy_description(),
            )

    def shutdown_handler(self):
        if not self._start_draining():
            return
        logger.info("SIGTERM received. Sending SIGTERM to all jobs and quitting.")
        try:
            loop = get_running_loop()
        except RuntimeError:
            pass
        else:
            task = loop.create_task(
                self._push_drain_notice_to_all_jobs(
                    "Edge worker received external SIGTERM; terminating task supervisor."
                )
            )
            self.background_tasks.add(task)
            task.add_done_callback(self.background_tasks.discard)
        self._terminate_jobs(signal.SIGTERM)
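
    # Note the asymmetry of the two stop paths above: signal_drain() (SIGINT) only stops
    # fetching new jobs and lets running ones finish, while shutdown_handler() (SIGTERM)
    # additionally forwards SIGTERM to every running job supervisor right away.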

    def _start_draining(self) -> bool:
        """Mark drain start. Returns ``True`` on the first call, ``False`` on subsequent calls."""
        if self.drain:
            return False
        self.drain = True
        self.drain_started_at = time.monotonic()
        return True

    def _drain_policy_description(self) -> str:
        """One-line description of the configured drain-timeout policy for logging."""
        if self.drain_timeout_sec <= 0:
            return "Drain timeout disabled; will wait indefinitely."
        return (
            f"Drain timeout: {self.drain_timeout_sec}s (then SIGTERM), "
            f"kill grace: {self.drain_kill_grace_sec}s (then SIGKILL)."
        )

    def _terminate_jobs(self, sig: int) -> None:
        """Send ``sig`` to every running job process. Safe to call repeatedly."""
        for job in self.jobs:
            if job.process.pid:
                with suppress(ProcessLookupError, PermissionError):
                    os.kill(job.process.pid, sig)

    async def _push_drain_notice_to_all_jobs(self, message: str) -> None:
        """Best-effort push of ``message`` into each running job's task log stream."""

        async def push_one(job: Job) -> None:
            try:
                await logs_push(
                    task=job.edge_job.key,
                    log_chunk_time=timezone.utcnow(),
                    log_chunk_data=f"{message}\n",
                )
            except Exception:
                logger.exception("Failed to push drain notice to task log for %s", job.edge_job.identifier)

        await gather(*(push_one(job) for job in list(self.jobs)))

    async def _enforce_drain_timeout(self) -> bool:
        """
        Apply drain-timeout policy when configured.

        Two-phase escalation: once ``drain_timeout_sec`` elapses, SIGTERM remaining jobs;
        after ``drain_kill_grace_sec`` more, SIGKILL and return ``True`` so the loop exits.
        Returns ``False`` otherwise (not configured, not draining, deadline not hit, or
        waiting out grace).
        """
        if self.drain_timeout_sec <= 0:
            return False
        if not self.drain or not self.jobs or self.drain_started_at is None:
            return False
        now = time.monotonic()
        if now - self.drain_started_at < self.drain_timeout_sec:
            return False
        if not self.drain_timed_out:
            self.drain_timed_out = True
            self.drain_kill_deadline = now + self.drain_kill_grace_sec
            logger.warning(
                "Drain timeout of %ds exceeded with %d job(s) still running. Sending SIGTERM.",
                self.drain_timeout_sec,
                len(self.jobs),
            )
            await self._push_drain_notice_to_all_jobs(
                f"Edge worker drain timeout of {self.drain_timeout_sec}s expired; "
                f"sending SIGTERM to task supervisor. "
                f"Will escalate to SIGKILL after {self.drain_kill_grace_sec}s grace."
            )
            self._terminate_jobs(signal.SIGTERM)
            return False
        if self.drain_kill_deadline is not None and now >= self.drain_kill_deadline:
            logger.warning(
                "Drain kill grace of %ds exceeded with %d job(s) still running. Sending SIGKILL and exiting.",
                self.drain_kill_grace_sec,
                len(self.jobs),
            )
            await self._push_drain_notice_to_all_jobs(
                f"Edge worker drain kill-grace of {self.drain_kill_grace_sec}s expired; "
                f"sending SIGKILL and exiting worker."
            )
            self._terminate_jobs(signal.SIGKILL)
            return True
        return False

    def _adjust_maintenance_mode_based_on_sysinfo(
        self, sysinfo: dict[str, str | int | float | datetime]
    ) -> None:
        """Adjust maintenance mode based on sysinfo status and config."""
        status: int = sysinfo.get("status")  # type: ignore
        if not self.maintenance_mode and (
            (status >= logging.WARNING and self.automatic_maintenance_on == "warning")
            or (status >= logging.ERROR and self.automatic_maintenance_on == "error")
        ):
            logger.info(
                "Entering maintenance mode due to status %s in sysinfo.", logging.getLevelName(status)
            )
            self.maintenance_mode = True
            self.maintenance_comments = f"[{datetime.now().strftime('%Y-%m-%d %H:%M')}] - Automatic maintenance mode entered due to status {logging.getLevelName(status)} in sysinfo."
        elif (
            self.maintenance_mode
            and self.maintenance_comments
            and "] - Automatic maintenance mode entered due to status " in self.maintenance_comments
            and (
                (status < logging.WARNING and self.automatic_maintenance_exit == "info")
                or (status < logging.ERROR and self.automatic_maintenance_exit == "warning")
            )
        ):
            logger.info("Exiting maintenance mode due to status %s in sysinfo.", logging.getLevelName(status))
            self.maintenance_mode = False
            self.maintenance_comments = None

    async def _get_sysinfo(self) -> dict[str, str | int | float | datetime]:
        """Produce the sysinfo from worker to post to central site."""
        sysinfo: dict[str, str | int | float | datetime] = {
            **(
                {
                    "status": logging.INFO,
                }
                if self.versions_match
                else {
                    "status": logging.WARNING,
                    "status_text": "Healthy but version mismatch",
                    "version_mismatch_description": "The versions of the Edge Worker and the "
                    "Airflow Core do not match for either the edge or airflow package. "
                    "Please check if the Edge Provider version is compatible with your Airflow "
                    "version. The worker will still operate but you might miss some features or "
                    "have issues. Please consider upgrading the Edge Provider to a compatible "
                    "version.",
                }
            ),
            "airflow_version": airflow_version,
            "edge_provider_version": edge_provider_version,
            "python_version": sys.version,
            "worker_start_time": self.worker_start_time,
            "concurrency": self.concurrency,
            "free_concurrency": self.free_concurrency,
        }
        if self.extended_sysinfo:
            try:
                sysinfo.update(await self.extended_sysinfo())
            except Exception:
                logger.exception("Failed to get extended sysinfo, skipping it.")
        # After grabbing status, check if we need to enter/exit maintenance mode
        self._adjust_maintenance_mode_based_on_sysinfo(sysinfo)
        return sysinfo

    def _get_state(self) -> EdgeWorkerState:
        """State of the Edge Worker."""
        if self.jobs:
            if self.drain:
                return EdgeWorkerState.TERMINATING
            if self.maintenance_mode:
                return EdgeWorkerState.MAINTENANCE_PENDING
            return EdgeWorkerState.RUNNING
        if self.drain:
            if self.maintenance_mode:
                return EdgeWorkerState.OFFLINE_MAINTENANCE
            return EdgeWorkerState.OFFLINE
        if self.maintenance_mode:
            return EdgeWorkerState.MAINTENANCE_MODE
        return EdgeWorkerState.IDLE

    def _run_job_via_supervisor(self, workload: ExecuteTask, results_queue: Queue) -> int:
        _reset_parent_signal_state()
        # Ignore ctrl-c in this process -- we don't want to kill _this_ one;
        # we let tasks run to completion.
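        # os.setpgrp() makes this forked child its own process-group leader, so a
        # terminal-generated Ctrl+C (SIGINT sent to the parent's foreground process
        # group) is not delivered to it.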
        os.setpgrp()
        logger.info("Worker starting up pid=%d", os.getpid())
        try:
            if AIRFLOW_V_3_3_PLUS:
                from airflow.executors.base_executor import BaseExecutor

                BaseExecutor.run_workload(
                    workload=workload,
                    server=self._execution_api_server_url,
                )
            else:
                from airflow.sdk.execution_time.supervisor import supervise

                ti = workload.ti
                setproctitle(
                    "airflow edge supervisor: "
                    f"dag_id={ti.dag_id} task_id={ti.task_id} run_id={ti.run_id} map_index={ti.map_index} "
                    f"try_number={ti.try_number}"
                )
                supervise(
                    # This is the "wrong" ti type, but it duck types the same. TODO: Create a protocol for this.
                    # Same as in airflow/executors/local_executor.py:_execute_workload()
                    ti=ti,  # type: ignore[arg-type]
                    dag_rel_path=workload.dag_rel_path,
                    bundle_info=workload.bundle_info,
                    token=workload.token,
                    server=self._execution_api_server_url,
                    log_path=workload.log_path,
                )
            results_queue.put("OK")
            return 0
        except Exception as e:
            logger.exception("Task execution failed")
            results_queue.put(e)
            return 1

    def _launch_job(self, workload: ExecuteTask) -> tuple[Process, Queue[Exception]]:
        # Improvement: Use frozen GC to prevent child process from copying unnecessary memory
        # See _spawn_workers_with_gc_freeze() in airflow-core/src/airflow/executors/local_executor.py
        results_queue: Queue[Exception] = Queue()
        process = Process(
            target=self._run_job_via_supervisor,
            kwargs={"workload": workload, "results_queue": results_queue},
        )
        process.start()
        return process, results_queue

    async def _push_logs_in_chunks(self, job: Job):
        aio_logfile = anyio.Path(job.logfile)
        if self.push_logs and await aio_logfile.exists() and (await aio_logfile.stat()).st_size > job.logsize:
            async with aio_open(job.logfile, mode="rb") as logf:
                await logf.seek(job.logsize, os.SEEK_SET)
                read_data = await logf.read()
                job.logsize += len(read_data)
                # backslashreplace keeps undecodable bytes without raising an exception;
                # replace NUL with the Unicode replacement character to avoid issues during the DB push
                log_data = read_data.decode(errors="backslashreplace").replace("\x00", "\ufffd")
                while True:
                    chunk_data = log_data[: self.push_log_chunk_size]
                    log_data = log_data[self.push_log_chunk_size :]
                    if not chunk_data:
                        break
                    await logs_push(
                        task=job.edge_job.key,
                        log_chunk_time=timezone.utcnow(),
                        log_chunk_data=chunk_data,
                    )
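
    # Worked example for _push_logs_in_chunks: with push_log_chunk_size=1024 and a new
    # 2500-character log increment, three chunks of 1024, 1024 and 452 characters are
    # pushed, each stamped with the current UTC time.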

    async def start(self):
        """Start the execution in a loop until terminated."""
        try:
            sysinfo = await self._get_sysinfo()
            register_result = await worker_register(
                self.hostname,
                EdgeWorkerState.MAINTENANCE_MODE if self.maintenance_mode else EdgeWorkerState.STARTING,
                self.queues,
                sysinfo,
                self.team_name,
            )
            self.versions_match = register_result.versions_match
        except EdgeWorkerVersionException as e:
            logger.info("Version mismatch of Edge worker and Core. Shutting down worker.")
            raise SystemExit(str(e))
        except EdgeWorkerDuplicateException as e:
            logger.error(str(e))
            raise SystemExit(str(e))
        except ClientResponseError as e:
            # Note: "Method Not Allowed" is raised by FastAPI if the API is not enabled (not 404)
            if e.status in {HTTPStatus.NOT_FOUND, HTTPStatus.METHOD_NOT_ALLOWED}:
                raise SystemExit(
                    "Error: API endpoint is not ready, please set [edge] api_enabled=True. "
                    "Or check if the URL is correct for your deployment."
                )
            raise SystemExit(str(e))

        if not self.daemon:
            write_pid_to_pidfile(self.pid_file_path)
        loop = get_running_loop()
        loop.add_signal_handler(signal.SIGINT, self.signal_drain)
        loop.add_signal_handler(SIG_STATUS, self.signal_status)
        loop.add_signal_handler(signal.SIGTERM, self.shutdown_handler)
        setproctitle(f"airflow edge worker: {self.hostname}")
        os.environ["HOSTNAME"] = self.hostname
        os.environ["AIRFLOW__CORE__HOSTNAME_CALLABLE"] = f"{_edge_hostname.__module__}._edge_hostname"

        try:
            await self.loop()

            logger.info("Quitting worker, signal being offline.")
            try:
                sysinfo = await self._get_sysinfo()
                sysinfo["status"] = logging.NOTSET
                if "status_text" in sysinfo:
                    del sysinfo["status_text"]  # Remove old status text if it exists
                await worker_set_state(
                    self.hostname,
                    EdgeWorkerState.OFFLINE_MAINTENANCE if self.maintenance_mode else EdgeWorkerState.OFFLINE,
                    0,
                    self.queues,
                    sysinfo,
                    team_name=self.team_name,
                )
            except EdgeWorkerVersionException:
                logger.info("Version mismatch of Edge worker and Core. Quitting worker anyway.")
        finally:
            if not self.daemon:
                remove_existing_pidfile(self.pid_file_path)
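
    # Signal wiring established in start(): SIGINT -> signal_drain (graceful drain),
    # SIGTERM -> shutdown_handler (drain and terminate supervisors), SIG_STATUS ->
    # signal_status (status dump, or maintenance toggle via marker file).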

    async def loop(self):
        """Run a loop of scheduling and monitoring tasks."""
        last_hb = datetime.now()
        worker_state_changed = True  # force heartbeat at start
        previous_jobs = 0
        while not self.drain or self.jobs:
            if (
                self.drain
                or datetime.now().timestamp() - last_hb.timestamp() > self.hb_interval
                or worker_state_changed  # send heartbeat immediately if the state is different in db
                or previous_jobs != len(self.jobs)  # when number of jobs changes
            ):
                worker_state_changed = await self.heartbeat()
                last_hb = datetime.now()
                previous_jobs = len(self.jobs)
            if self.maintenance_mode:
                logger.info("in maintenance mode%s", f", {len(self.jobs)} draining jobs" if self.jobs else "")
            elif not self.drain and self.free_concurrency > 0:
                task = create_task(self.fetch_and_run_job())
                self.background_tasks.add(task)
                task.add_done_callback(self.background_tasks.discard)
            else:
                logger.info("%i %s running", len(self.jobs), "job is" if len(self.jobs) == 1 else "jobs are")
            if await self._enforce_drain_timeout():
                break
            await self.interruptible_sleep()
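
    # Heartbeat cadence in loop(): a heartbeat is sent when draining, when hb_interval
    # seconds have elapsed, when the server-side state differed from the locally
    # computed one, or when the number of running jobs changed.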

    async def fetch_and_run_job(self) -> None:
        """Fetch, start and monitor a new job."""
        logger.debug("Attempting to fetch a new job...")
        edge_job = await jobs_fetch(self.hostname, self.queues, self.free_concurrency, self.team_name)
        if not edge_job:
            logger.info(
                "No new job to process%s",
                f", {len(self.jobs)} still running" if self.jobs else "",
            )
            return
        logger.info("Received job: %s", edge_job.identifier)
        workload: ExecuteTask = edge_job.command
        process, results_queue = self._launch_job(workload)
        if TYPE_CHECKING:
            assert workload.log_path  # We need to assume this is defined here
        logfile = Path(self.base_log_folder, workload.log_path)
        job = Job(edge_job, process, logfile)
        self.jobs.append(job)
        await jobs_set_state(edge_job.key, TaskInstanceState.RUNNING)
        # As we got one job, directly fetch another one if possible
        if self.free_concurrency > 0:
            task = create_task(self.fetch_and_run_job())
            self.background_tasks.add(task)
            task.add_done_callback(self.background_tasks.discard)
        while job.is_running and results_queue.empty():
            await self._push_logs_in_chunks(job)
            for _ in range(0, self.job_poll_interval * 10):
                await sleep(0.1)
                if not job.is_running:
                    break
        await self._push_logs_in_chunks(job)
        supervisor_msg = (
            "(Unknown error, no exception details available)"
            if results_queue.empty()
            else results_queue.get()
        )
        # Ensure that the supervisor really ended after we grabbed results from the queue
        while True:
            if not job.is_running:
                break
            await sleep(0.1)
        self.jobs.remove(job)
        if job.is_success:
            logger.info("Job completed: %s", job.edge_job.identifier)
            await jobs_set_state(job.edge_job.key, TaskInstanceState.SUCCESS)
        else:
            if isinstance(supervisor_msg, Exception):
                supervisor_msg = "\n".join(traceback.format_exception(supervisor_msg))
            logger.error("Job failed: %s with:\n%s", job.edge_job.identifier, supervisor_msg)
            # Push it upwards to the logs for better diagnostics as well
            await logs_push(
                task=job.edge_job.key,
                log_chunk_time=timezone.utcnow(),
                log_chunk_data=f"Error executing job:\n{supervisor_msg}",
            )
            await jobs_set_state(job.edge_job.key, TaskInstanceState.FAILED)
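
    # fetch_and_run_job() chains itself: every successful fetch immediately spawns
    # another fetch task while free_concurrency > 0, so the worker fills its slots
    # without waiting for the next poll interval.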

    async def heartbeat(self, new_maintenance_comments: str | None = None) -> bool:
        """Report liveness state of worker to central site with stats."""
        sysinfo = await self._get_sysinfo()
        state = self._get_state()
        worker_state_changed: bool = False
        try:
            worker_info = await worker_set_state(
                self.hostname,
                state,
                len(self.jobs),
                self.queues,
                sysinfo,
                new_maintenance_comments or self.maintenance_comments,
                team_name=self.team_name,
            )
            self.versions_match = worker_info.versions_match
            self.queues = worker_info.queues
            if worker_info.concurrency is not None and worker_info.concurrency != self.concurrency:
                logger.info(
                    "Concurrency updated from %d to %d by remote request.",
                    self.concurrency,
                    worker_info.concurrency,
                )
                self.concurrency = worker_info.concurrency
            if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST:
                logger.info("Maintenance mode requested!")
                self.maintenance_mode = True
            elif (
                worker_info.state in [EdgeWorkerState.IDLE, EdgeWorkerState.RUNNING]
                and self.maintenance_mode
            ):
                logger.info("Exit Maintenance mode requested!")
                self.maintenance_mode = False
            if self.maintenance_mode:
                self.maintenance_comments = worker_info.maintenance_comments
            else:
                self.maintenance_comments = None
            if worker_info.state == EdgeWorkerState.SHUTDOWN_REQUEST and self._start_draining():
                logger.info("Shutdown requested!")
            worker_state_changed = worker_info.state != state
        except EdgeWorkerVersionException:
            logger.info("Version mismatch of Edge worker and Core. Shutting down worker.")
            self._start_draining()
        return worker_state_changed
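
    # heartbeat() returns True when the server-side state differs from the locally
    # computed one; loop() then sends another heartbeat on its next iteration.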

    async def interruptible_sleep(self):
        """Sleep, but stop sleeping early if drain is requested or a job completed."""
        drain_before_sleep = self.drain
        jobcount_before_sleep = len(self.jobs)
        for _ in range(0, self.job_poll_interval * 10):
            await sleep(0.1)
            if drain_before_sleep != self.drain or len(self.jobs) < jobcount_before_sleep:
                return
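

# A minimal usage sketch, assuming the worker is driven from an asyncio entrypoint as
# the provider CLI does; all parameter values below are illustrative only:
#
#   import asyncio
#
#   worker = EdgeWorker(
#       pid_file_path="/tmp/edge-worker.pid",
#       hostname=_edge_hostname(),
#       queues=None,        # no queue restriction (assumed default behavior)
#       concurrency=8,
#       daemon=False,
#       team_name=None,
#   )
#   asyncio.run(worker.start())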
