Source code for airflow.providers.edge.worker_api.routes.jobs

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from ast import literal_eval
from typing import Annotated

from sqlalchemy import select, update

from airflow.providers.edge.models.edge_job import EdgeJobModel
from airflow.providers.edge.worker_api.auth import jwt_token_authorization_rest
from airflow.providers.edge.worker_api.datamodels import (
    EdgeJobFetched,
    WorkerApiDocs,
    WorkerQueuesBody,
)
from airflow.providers.edge.worker_api.routes._v2_compat import (
    AirflowRouter,
    Body,
    Depends,
    SessionDep,
    create_openapi_http_exception_doc,
    status,
)
from airflow.utils import timezone
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import TaskInstanceState

# Router collecting the edge worker job endpoints; every route below is
# registered on it and therefore served under the "/jobs" prefix.
jobs_router = AirflowRouter(prefix="/jobs", tags=["Jobs"])
@jobs_router.post(
    "/fetch/{worker_name}",
    dependencies=[Depends(jwt_token_authorization_rest)],
    responses=create_openapi_http_exception_doc(
        [
            status.HTTP_400_BAD_REQUEST,
            status.HTTP_403_FORBIDDEN,
        ]
    ),
)
def fetch(
    worker_name: str,
    body: Annotated[
        WorkerQueuesBody,
        Body(
            # Title describes the actual payload (queues + free capacity); it
            # previously read "Log data chunks", copied from another endpoint.
            title="Queues and capacity",
            description="The queues and capacity from which the worker can fetch jobs.",
        ),
    ],
    session: SessionDep,
) -> EdgeJobFetched | None:
    """Fetch a job to execute on the edge worker."""
    # NOTE(review): the docstring above is kept as a one-liner on purpose; route
    # docstrings are typically published as the OpenAPI operation description.
    #
    # worker_name: name of the requesting worker, recorded on the claimed job.
    # body:        free concurrency of the worker and, optionally, the queues
    #              it serves.
    # session:     database session used to select, lock and update the job.
    # Returns the claimed job as EdgeJobFetched, or None if nothing matches.

    # Candidates: queued jobs that fit into the worker's free capacity,
    # ordered by queue time ascending.
    query = (
        select(EdgeJobModel)
        .where(
            EdgeJobModel.state == TaskInstanceState.QUEUED,
            EdgeJobModel.concurrency_slots <= body.free_concurrency,
        )
        .order_by(EdgeJobModel.queued_dttm)
    )
    # An empty/absent queue list means "no queue restriction".
    if body.queues:
        query = query.where(EdgeJobModel.queue.in_(body.queues))
    query = query.limit(1)
    # Request a row lock with skip_locked=True — presumably so that workers
    # polling concurrently do not pick up the same job.
    query = with_row_locks(query, of=EdgeJobModel, session=session, skip_locked=True)

    # scalar() yields None when no row matches, hence the optional annotation.
    job: EdgeJobModel | None = session.scalar(query)
    if not job:
        return None

    # Claim the job for this worker and persist the claim before answering.
    job.state = TaskInstanceState.RUNNING
    job.edge_worker = worker_name
    job.last_update = timezone.utcnow()
    session.commit()

    return EdgeJobFetched(
        dag_id=job.dag_id,
        task_id=job.task_id,
        run_id=job.run_id,
        map_index=job.map_index,
        try_number=job.try_number,
        # The stored command is parsed back into a Python literal.
        command=literal_eval(job.command),
        concurrency_slots=job.concurrency_slots,
    )
@jobs_router.patch(
    "/state/{dag_id}/{task_id}/{run_id}/{try_number}/{map_index}/{state}",
    dependencies=[Depends(jwt_token_authorization_rest)],
    responses=create_openapi_http_exception_doc(
        [status.HTTP_400_BAD_REQUEST, status.HTTP_403_FORBIDDEN]
    ),
)
def state(
    dag_id: Annotated[str, WorkerApiDocs.dag_id],
    task_id: Annotated[str, WorkerApiDocs.task_id],
    run_id: Annotated[str, WorkerApiDocs.run_id],
    try_number: Annotated[int, WorkerApiDocs.try_number],
    map_index: Annotated[int, WorkerApiDocs.map_index],
    state: Annotated[TaskInstanceState, WorkerApiDocs.state],
    session: SessionDep,
) -> None:
    """Update the state of a job running on the edge worker."""
    # The five path parameters below together select the job row(s) to update.
    job_key = (
        EdgeJobModel.dag_id == dag_id,
        EdgeJobModel.task_id == task_id,
        EdgeJobModel.run_id == run_id,
        EdgeJobModel.map_index == map_index,
        EdgeJobModel.try_number == try_number,
    )
    # Store the reported state and stamp the row with the current UTC time.
    new_values = {"state": state, "last_update": timezone.utcnow()}
    statement = update(EdgeJobModel).where(*job_key).values(**new_values)
    # No explicit commit here (unlike fetch) — presumably left to the session
    # dependency; verify against SessionDep.
    session.execute(statement)

Was this entry helpful?