# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from datetime import datetime
from typing import (
    Annotated,
    Any,
)
from pydantic import BaseModel, Field
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.edge3.models.edge_worker import EdgeWorkerState  # noqa: TCH001
from airflow.providers.edge3.worker_api.routes._v2_compat import ExecuteTask, Path


class WorkerApiDocs:
    """
    Documentation collection for the worker API.

    Each attribute is a ``Path`` parameter declaration carrying the title and description
    of one path parameter, so route signatures can reuse the same documentation.
    """

    dag_id = Path(title="Dag ID", description="Identifier of the DAG to which the task belongs.")
    task_id = Path(title="Task ID", description="Task name in the DAG.")
    run_id = Path(title="Run ID", description="Run ID of the DAG execution.")
    try_number = Path(title="Try Number", description="The number of attempt to execute this task.")
    map_index = Path(
        title="Map Index",
        description="For dynamically mapped tasks the mapping number, -1 if the task is not mapped.",
    )
    state = Path(title="Task State", description="State of the assigned task under execution.")
 

class JsonRpcRequestBase(BaseModel):
    """Base JSON RPC request model to define just the method."""

    method: Annotated[
        str,
        Field(description="Fully qualified python module method name that is called via JSON RPC."),
    ]
 

class JsonRpcRequest(JsonRpcRequestBase):
    """JSON RPC request model: the inherited ``method`` plus protocol version and parameters."""

    jsonrpc: Annotated[str, Field(description="JSON RPC Version", examples=["2.0"])]
    # No default is assigned, so the key is required; its value may still be null.
    params: Annotated[
        dict[str, Any] | None,
        Field(description="Dictionary of parameters passed to the method."),
    ]
 

class EdgeJobBase(BaseModel):
    """Basic attributes of a job on the edge worker, identifying a single task instance try."""

    dag_id: Annotated[
        str, Field(title="Dag ID", description="Identifier of the DAG to which the task belongs.")
    ]
    task_id: Annotated[str, Field(title="Task ID", description="Task name in the DAG.")]
    run_id: Annotated[str, Field(title="Run ID", description="Run ID of the DAG execution.")]
    map_index: Annotated[
        int,
        Field(
            title="Map Index",
            description="For dynamically mapped tasks the mapping number, -1 if the task is not mapped.",
        ),
    ]
    try_number: Annotated[
        int, Field(title="Try Number", description="The number of attempt to execute this task.")
    ]

    @property
    def key(self) -> TaskInstanceKey:
        """Return the ``TaskInstanceKey`` built from this job's identifying fields."""
        # Arguments are passed positionally with try_number before map_index (unlike the
        # field declaration order above) - presumably TaskInstanceKey's field order; verify
        # if that type ever changes.
        return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
 

class EdgeJobFetched(EdgeJobBase):
    """Job that is to be executed on the edge worker, including its command and slot demand."""

    command: Annotated[
        ExecuteTask,
        Field(
            title="Command",
            description="Command line to use to execute the job in Airflow 2. Task definition in Airflow 3",
        ),
    ]
    concurrency_slots: Annotated[int, Field(description="Number of concurrency slots the job requires.")]
 

class WorkerQueuesBase(BaseModel):
    """Queues that a worker supports to run jobs on."""

    # The None default is assigned outside ``Annotated`` (pydantic's documented form for
    # defaults), matching how WorkerStateBody declares this same field.
    queues: Annotated[
        list[str] | None,
        Field(
            description="List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues.",
        ),
    ] = None
 

class WorkerQueuesBody(WorkerQueuesBase):
    """Queues that a worker supports to run jobs on, plus its currently free concurrency."""

    free_concurrency: Annotated[int, Field(description="Number of free concurrency slots on the worker.")]
 

class WorkerStateBody(WorkerQueuesBase):
    """Details of the worker state sent to the scheduler."""

    state: Annotated[EdgeWorkerState, Field(description="State of the worker from the view of the worker.")]
    jobs_active: Annotated[int, Field(description="Number of active jobs the worker is running.")] = 0
    # Re-declares the field inherited from WorkerQueuesBase, with an explicit None default.
    queues: Annotated[
        list[str] | None,
        Field(
            description="List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues."
        ),
    ] = None
    sysinfo: Annotated[
        dict[str, str | int],
        Field(
            description="System information of the worker.",
            examples=[
                {
                    "concurrency": 4,
                    "free_concurrency": 3,
                    "airflow_version": "2.0.0",
                    "edge_provider_version": "1.0.0",
                }
            ],
        ),
    ]
    maintenance_comments: Annotated[
        str | None,
        Field(description="Comments about the maintenance state of the worker."),
    ] = None
 

class WorkerQueueUpdateBody(BaseModel):
    """Changed queues for the worker: queues to add and queues to remove."""

    # Neither field has a default: both keys must be sent, though either may be null.
    new_queues: Annotated[
        list[str] | None,
        Field(description="Additional queues to be added to worker."),
    ]
    remove_queues: Annotated[
        list[str] | None,
        Field(description="Queues to remove from worker."),
    ]
 

class PushLogsBody(BaseModel):
    """Incremental new log content from worker, stamped with the time it was sent."""

    log_chunk_time: Annotated[datetime, Field(description="Time of the log chunk at point of sending.")]
    log_chunk_data: Annotated[str, Field(description="Log chunk data as incremental log text.")]
 

class WorkerRegistrationReturn(BaseModel):
    """The return class for the worker registration."""

    last_update: Annotated[datetime, Field(description="Time of the last update of the worker.")]
 

class WorkerSetStateReturn(BaseModel):
    """The return class for the worker set state, reflecting the server's view of the worker."""

    state: Annotated[EdgeWorkerState, Field(description="State of the worker from the view of the server.")]
    # No default is assigned, so the key is always present in the response; it may be null.
    queues: Annotated[
        list[str] | None,
        Field(
            description="List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues."
        ),
    ]
    maintenance_comments: Annotated[
        str | None,
        Field(description="Comments about the maintenance state of the worker."),
    ] = None