# Source code for airflow.providers.edge3.executors.edge_executor

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from collections.abc import Sequence
from copy import deepcopy
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any

from sqlalchemy import delete, select

from airflow.configuration import conf
from airflow.executors import workloads
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.compat.sdk import Stats, timezone
from airflow.providers.edge3.models.db import EdgeDBManager, check_db_manager_config
from airflow.providers.edge3.models.edge_job import EdgeJobModel
from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState, reset_metrics
from airflow.utils.db import DBLocks, create_global_lock
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import TaskInstanceState

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.cli.cli_config import GroupCommand
    from airflow.models.taskinstancekey import TaskInstanceKey

    # TODO: Airflow 2 type hints; remove when Airflow 2 support is removed
    CommandType = Sequence[str]
    # Task tuple to send to be executed.
    # NOTE(review): in the scraped source this assignment was fused into the
    # comment line, leaving TaskTuple undefined even though
    # _process_tasks() uses it in its annotation; restored here.
    TaskTuple = tuple[TaskInstanceKey, CommandType, str | None, Any | None]
# Default executor parallelism, read from the core Airflow configuration.
PARALLELISM: int = conf.getint("core", "PARALLELISM")
# Default queue name for operators; conf.get_mandatory_value raises if unset.
DEFAULT_QUEUE: str = conf.get_mandatory_value("operators", "default_queue")
class EdgeExecutor(BaseExecutor):
    """Implementation of the EdgeExecutor to distribute work to Edge Workers via HTTP."""

    def __init__(self, parallelism: int = PARALLELISM):
        """Initialize the executor.

        :param parallelism: Maximum number of parallel tasks, forwarded to
            BaseExecutor; defaults to the ``[core] PARALLELISM`` setting.
        """
        super().__init__(parallelism=parallelism)
        # Last TaskInstanceState reported to the scheduler per task key;
        # used by _purge_jobs() to only re-report a state when it changed.
        self.last_reported_state: dict[TaskInstanceKey, TaskInstanceState] = {}
@provide_session
[docs] def start(self, session: Session = NEW_SESSION): """If EdgeExecutor provider is loaded first time, ensure table exists.""" check_db_manager_config() edge_db_manager = EdgeDBManager(session) if edge_db_manager.check_migration(): return with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): edge_db_manager.initdb()
def _process_tasks(self, task_tuples: list[TaskTuple]) -> None: """ Temporary overwrite of _process_tasks function. Idea is to not change the interface of the execute_async function in BaseExecutor as it will be changed in Airflow 3. Edge worker needs task_instance in execute_async but BaseExecutor deletes this out of the self.queued_tasks. Store queued_tasks in own var to be able to access this in execute_async function. """ self.edge_queued_tasks = deepcopy(self.queued_tasks) super()._process_tasks(task_tuples) # type: ignore[misc] @provide_session
[docs] def queue_workload( self, workload: workloads.All, session: Session = NEW_SESSION, ) -> None: """Put new workload to queue. Airflow 3 entry point to execute a task.""" if not isinstance(workload, workloads.ExecuteTask): raise TypeError(f"Don't know how to queue workload of type {type(workload).__name__}") task_instance = workload.ti key = task_instance.key # Check if job already exists with same dag_id, task_id, run_id, map_index, try_number existing_job = session.scalars( select(EdgeJobModel).where( EdgeJobModel.dag_id == key.dag_id, EdgeJobModel.task_id == key.task_id, EdgeJobModel.run_id == key.run_id, EdgeJobModel.map_index == key.map_index, EdgeJobModel.try_number == key.try_number, ) ).first() if existing_job: existing_job.state = TaskInstanceState.QUEUED existing_job.queue = task_instance.queue existing_job.concurrency_slots = task_instance.pool_slots existing_job.command = workload.model_dump_json() else: session.add( EdgeJobModel( dag_id=key.dag_id, task_id=key.task_id, run_id=key.run_id, map_index=key.map_index, try_number=key.try_number, state=TaskInstanceState.QUEUED, queue=task_instance.queue, concurrency_slots=task_instance.pool_slots, command=workload.model_dump_json(), ) )
def _check_worker_liveness(self, session: Session) -> bool: """Reset worker state if heartbeat timed out.""" changed = False heartbeat_interval: int = conf.getint("edge", "heartbeat_interval") lifeless_workers: Sequence[EdgeWorkerModel] = session.scalars( select(EdgeWorkerModel) .with_for_update(skip_locked=True) .where( EdgeWorkerModel.state.not_in( [ EdgeWorkerState.UNKNOWN, EdgeWorkerState.OFFLINE, EdgeWorkerState.OFFLINE_MAINTENANCE, ] ), EdgeWorkerModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval * 5)), ) ).all() for worker in lifeless_workers: changed = True # If the worker dies in maintenance mode we want to remember it, so it can start in maintenance mode worker.state = ( EdgeWorkerState.OFFLINE_MAINTENANCE if worker.state in ( EdgeWorkerState.MAINTENANCE_MODE, EdgeWorkerState.MAINTENANCE_PENDING, EdgeWorkerState.MAINTENANCE_REQUEST, ) else EdgeWorkerState.UNKNOWN ) reset_metrics(worker.worker_name) return changed def _update_orphaned_jobs(self, session: Session) -> bool: """Update status ob jobs when workers die and don't update anymore.""" heartbeat_interval: int = conf.getint("scheduler", "task_instance_heartbeat_timeout") lifeless_jobs: Sequence[EdgeJobModel] = session.scalars( select(EdgeJobModel) .with_for_update(skip_locked=True) .where( EdgeJobModel.state == TaskInstanceState.RUNNING, EdgeJobModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval)), ) ).all() for job in lifeless_jobs: ti = TaskInstance.get_task_instance( dag_id=job.dag_id, run_id=job.run_id, task_id=job.task_id, map_index=job.map_index, session=session, ) job.state = ti.state if ti and ti.state else TaskInstanceState.REMOVED if job.state != TaskInstanceState.RUNNING: # Edge worker does not backport emitted Airflow metrics, so export some metrics # Export metrics as failed as these jobs will be deleted in the future tags = { "dag_id": job.dag_id, "task_id": job.task_id, "queue": job.queue, "state": str(TaskInstanceState.FAILED), } 
Stats.incr( f"edge_worker.ti.finish.{job.queue}.{TaskInstanceState.FAILED}.{job.dag_id}.{job.task_id}", tags=tags, ) Stats.incr("edge_worker.ti.finish", tags=tags) return bool(lifeless_jobs) def _purge_jobs(self, session: Session) -> bool: """Clean finished jobs.""" purged_marker = False job_success_purge = conf.getint("edge", "job_success_purge") job_fail_purge = conf.getint("edge", "job_fail_purge") jobs: Sequence[EdgeJobModel] = session.scalars( select(EdgeJobModel) .with_for_update(skip_locked=True) .where( EdgeJobModel.state.in_( [ TaskInstanceState.RUNNING, TaskInstanceState.SUCCESS, TaskInstanceState.FAILED, TaskInstanceState.REMOVED, TaskInstanceState.RESTARTING, TaskInstanceState.UP_FOR_RETRY, ] ) ) ).all() # Sync DB with executor otherwise runs out of sync in multi scheduler deployment already_removed = self.running - set(job.key for job in jobs) self.running = self.running - already_removed for job in jobs: if job.key in self.running: if job.state == TaskInstanceState.RUNNING: if ( job.key not in self.last_reported_state or self.last_reported_state[job.key] != job.state ): self.running_state(job.key) self.last_reported_state[job.key] = job.state elif job.state == TaskInstanceState.SUCCESS: if job.key in self.last_reported_state: del self.last_reported_state[job.key] self.success(job.key) elif job.state in [ TaskInstanceState.FAILED, TaskInstanceState.RESTARTING, TaskInstanceState.UP_FOR_RETRY, ]: if job.key in self.last_reported_state: del self.last_reported_state[job.key] self.fail(job.key) else: self.last_reported_state[job.key] = TaskInstanceState(job.state) if ( job.state == TaskInstanceState.SUCCESS and job.last_update_t < (datetime.now() - timedelta(minutes=job_success_purge)).timestamp() ) or ( job.state in ( TaskInstanceState.FAILED, TaskInstanceState.REMOVED, TaskInstanceState.RESTARTING, TaskInstanceState.UP_FOR_RETRY, ) and job.last_update_t < (datetime.now() - timedelta(minutes=job_fail_purge)).timestamp() ): if job.key in 
self.last_reported_state: del self.last_reported_state[job.key] purged_marker = True session.delete(job) session.execute( delete(EdgeLogsModel).where( EdgeLogsModel.dag_id == job.dag_id, EdgeLogsModel.run_id == job.run_id, EdgeLogsModel.task_id == job.task_id, EdgeLogsModel.map_index == job.map_index, EdgeLogsModel.try_number == job.try_number, ) ) return purged_marker @provide_session
[docs] def sync(self, session: Session = NEW_SESSION) -> None: """Sync will get called periodically by the heartbeat method.""" with Stats.timer("edge_executor.sync.duration"): orphaned = self._update_orphaned_jobs(session) purged = self._purge_jobs(session) liveness = self._check_worker_liveness(session) if purged or liveness or orphaned: session.commit()
[docs] def end(self) -> None: """End the executor.""" self.log.info("Shutting down EdgeExecutor")
[docs] def terminate(self): """Terminate the executor is not doing anything."""
@provide_session
[docs] def revoke_task(self, *, ti: TaskInstance, session: Session = NEW_SESSION): """ Revoke a task instance from the executor. This method removes the task from the executor's internal state and deletes the corresponding EdgeJobModel record to prevent edge workers from picking it up. :param ti: Task instance to revoke :param session: Database session """ # Remove from executor's internal state self.running.discard(ti.key) self.queued_tasks.pop(ti.key, None) if ti.key in self.last_reported_state: del self.last_reported_state[ti.key] # Delete the job from the database to prevent edge workers from picking it up session.execute( delete(EdgeJobModel).where( EdgeJobModel.dag_id == ti.dag_id, EdgeJobModel.task_id == ti.task_id, EdgeJobModel.run_id == ti.run_id, EdgeJobModel.map_index == ti.map_index, EdgeJobModel.try_number == ti.try_number, ) ) self.log.info("Revoked task instance %s from EdgeExecutor", ti.key)
[docs] def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]: """ Try to adopt running task instances that have been abandoned by a SchedulerJob dying. Anything that is not adopted will be cleared by the scheduler (and then become eligible for re-scheduling) :return: any TaskInstances that were unable to be adopted """ # We handle all running tasks from the DB in sync, no adoption logic needed. return []
@staticmethod
[docs] def get_cli_commands() -> list[GroupCommand]: from airflow.providers.edge3.cli.definition import get_edge_cli_commands return get_edge_cli_commands()

# (removed rendered-documentation footer text that leaked into this file)