from __future__ import annotations

import datetime
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from import BaseConfigKeys
from airflow.utils.state import State

    from airflow.models.taskinstance import TaskInstanceKey

# Type alias: a command expressed as a list of string arguments.
CommandType = list[str]

# Type alias: free-form executor configuration keyed by string.
ExecutorConfigType = dict[str, Any]

# Name of the config group holding the Batch executor settings.
CONFIG_GROUP_NAME = "aws_batch_executor"

# Fallback settings; every value is stored in its string form.
CONFIG_DEFAULTS = {
    "conn_id": "aws_default",
    "max_submit_job_attempts": "3",
    "check_health_on_startup": "True",
}
@dataclass
class BatchQueuedJob:
    """
    Represents a Batch job that is queued. The job will be run in the next heartbeat.

    Declared as a dataclass so the annotated fields below become constructor
    arguments; without the decorator the class would have no usable ``__init__``.
    """

    # Airflow task instance this queued job belongs to.
    key: TaskInstanceKey
    # Command to run, as a list of string arguments.
    command: CommandType
    # Queue name recorded for the job.
    queue: str
    # Executor configuration supplied for the job.
    executor_config: ExecutorConfigType
    # Which attempt this queued entry represents.
    attempt_number: int
    # Point in time associated with the next attempt.
    next_attempt_time: datetime.datetime
@dataclass
class BatchJobInfo:
    """
    Contains information about a currently running Batch job.

    Declared as a dataclass so instances can be built with the ``cmd``,
    ``queue`` and ``config`` keyword arguments.
    """

    # Command the job was submitted with, as a list of string arguments.
    cmd: CommandType
    # Queue name recorded for the job.
    queue: str
    # Executor configuration the job was submitted with.
    config: ExecutorConfigType
class BatchJob:
    """
    Data Transfer Object for an AWS Batch Job.

    :param job_id: Identifier of the Batch job.
    :param status: Status string reported for the job; ``get_job_state`` translates it.
    :param status_reason: Optional explanation accompanying the status.
    """

    # Translation table from AWS Batch job status to Airflow state. Every status
    # that precedes actual execution is treated as queued.
    STATE_MAPPINGS = {
        "SUBMITTED": State.QUEUED,
        "PENDING": State.QUEUED,
        "RUNNABLE": State.QUEUED,
        "STARTING": State.QUEUED,
        "RUNNING": State.RUNNING,
        "SUCCEEDED": State.SUCCESS,
        "FAILED": State.FAILED,
    }

    def __init__(self, job_id: str, status: str, status_reason: str | None = None):
        self.job_id = job_id
        self.status = status
        self.status_reason = status_reason

    def get_job_state(self) -> str:
        """
        Return the state of the job.

        A status that is absent from ``STATE_MAPPINGS`` falls back to ``State.QUEUED``.
        """
        return self.STATE_MAPPINGS.get(self.status, State.QUEUED)

    def __repr__(self):
        """Return a visual representation of the Job status."""
        return f"({self.job_id} -> {self.status}, {self.get_job_state()})"
class BatchJobCollection:
    """
    A collection to manage running Batch Jobs.

    Jobs are indexed in both directions (task key <-> Batch job ID), and each
    job ID additionally carries a failure counter and a ``BatchJobInfo`` record.
    """

    def __init__(self):
        self.key_to_id: dict[TaskInstanceKey, str] = {}
        self.id_to_key: dict[str, TaskInstanceKey] = {}
        self.id_to_failure_counts: dict[str, int] = defaultdict(int)
        self.id_to_job_info: dict[str, BatchJobInfo] = {}

    def add_job(
        self,
        job_id: str,
        airflow_task_key: TaskInstanceKey,
        airflow_cmd: CommandType,
        queue: str,
        exec_config: ExecutorConfigType,
        attempt_number: int,
    ):
        """
        Add a job to the collection.

        :param job_id: Batch job ID to register.
        :param airflow_task_key: Task instance key the job runs.
        :param airflow_cmd: Command the job was submitted with.
        :param queue: Queue name recorded for the job.
        :param exec_config: Executor configuration recorded for the job.
        :param attempt_number: Value the job's failure counter starts from.
        """
        self.key_to_id[airflow_task_key] = job_id
        self.id_to_key[job_id] = airflow_task_key
        # The failure counter is seeded with the attempt number, not zero.
        self.id_to_failure_counts[job_id] = attempt_number
        self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, queue=queue, config=exec_config)

    def pop_by_id(self, job_id: str) -> TaskInstanceKey:
        """
        Delete job from collection based off of Batch Job ID.

        :param job_id: Batch job ID to remove.
        :return: The task instance key that was mapped to ``job_id``.
        :raises KeyError: If ``job_id`` is not in the collection.
        """
        task_key = self.id_to_key[job_id]
        del self.key_to_id[task_key]
        del self.id_to_key[job_id]
        del self.id_to_failure_counts[job_id]
        # Drop the job info too so finished jobs do not accumulate; tolerate a
        # missing entry so collections populated without job info still pop cleanly.
        self.id_to_job_info.pop(job_id, None)
        return task_key

    def failure_count_by_id(self, job_id: str) -> int:
        """Get the number of times a job has failed given a Batch Job Id."""
        return self.id_to_failure_counts[job_id]

    def increment_failure_count(self, job_id: str):
        """Increment the failure counter given a Batch Job Id."""
        self.id_to_failure_counts[job_id] += 1

    def get_all_jobs(self) -> list[str]:
        """Get the Batch Job IDs of all jobs in the collection."""
        return list(self.id_to_key)

    def __len__(self):
        """Return the number of jobs in collection."""
        return len(self.key_to_id)
class BatchSubmitJobKwargsConfigKeys(BaseConfigKeys):
    """Config keys whose values are valid kwargs for the Batch ``submit_job`` call."""

    JOB_NAME = "job_name"
    JOB_QUEUE = "job_queue"
    JOB_DEFINITION = "job_definition"
    EKS_PROPERTIES_OVERRIDE = "eks_properties_override"
    NODE_OVERRIDE = "node_override"
class AllBatchConfigKeys(BatchSubmitJobKwargsConfigKeys):
    """
    Every config key related to the Batch Executor.

    Extends the ``submit_job`` kwarg keys inherited from the parent class with
    the executor-level keys declared here.
    """

    MAX_SUBMIT_JOB_ATTEMPTS = "max_submit_job_attempts"
    AWS_CONN_ID = "conn_id"
    SUBMIT_JOB_KWARGS = "submit_job_kwargs"
    REGION_NAME = "region_name"
    CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
class BatchExecutorException(Exception):
    """Raised when something unexpected has occurred within the AWS Batch ecosystem."""

