Source code for airflow.providers.amazon.aws.executors.batch.utils

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import datetime
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from airflow.providers.amazon.aws.executors.utils.base_config_keys import BaseConfigKeys
from airflow.utils.state import State

if TYPE_CHECKING:
    from airflow.models.taskinstance import TaskInstanceKey

# Type aliases used throughout the Batch executor helpers.
CommandType = list[str]  # a command expressed as a list of string arguments
ExecutorConfigType = dict[str, Any]  # free-form per-task executor configuration

# Group name under which this executor's configuration options are looked up
# (presumably the airflow.cfg section name -- confirm against the executor).
CONFIG_GROUP_NAME = "aws_batch_executor"

# Fallback values for options in CONFIG_GROUP_NAME. All values are strings,
# including the numeric and boolean ones.
CONFIG_DEFAULTS = dict(
    conn_id="aws_default",
    max_submit_job_attempts="3",
    check_health_on_startup="True",
)
@dataclass
class BatchQueuedJob:
    """A Batch job waiting in the local queue; it is submitted on the next heartbeat."""

    # Airflow task instance key this job belongs to.
    key: TaskInstanceKey
    # Command to run for the task.
    command: CommandType
    # Queue the task was assigned to.
    queue: str
    # Per-task executor configuration.
    executor_config: ExecutorConfigType
    # Which submission attempt this entry represents.
    attempt_number: int
    # Time at which the next submission attempt is due.
    next_attempt_time: datetime.datetime
@dataclass
class BatchJobInfo:
    """Details recorded for a Batch job that is currently running."""

    # Command the job was submitted with.
    cmd: CommandType
    # Queue the task was assigned to.
    queue: str
    # Executor configuration the job was submitted with.
    config: ExecutorConfigType
class BatchJob:
    """
    Data Transfer Object for an AWS Batch Job.

    Holds the job id and the status string reported by AWS Batch, and translates
    that status into an Airflow state via ``STATE_MAPPINGS``.
    """

    # AWS Batch status -> Airflow state. The four statuses before RUNNING all map
    # to QUEUED; statuses missing from this table also fall back to QUEUED.
    STATE_MAPPINGS = {
        "SUBMITTED": State.QUEUED,
        "PENDING": State.QUEUED,
        "RUNNABLE": State.QUEUED,
        "STARTING": State.QUEUED,
        "RUNNING": State.RUNNING,
        "SUCCEEDED": State.SUCCESS,
        "FAILED": State.FAILED,
    }

    def __init__(self, job_id: str, status: str, status_reason: str | None = None):
        self.job_id, self.status, self.status_reason = job_id, status, status_reason

    def get_job_state(self) -> str:
        """Translate this job's Batch status into an Airflow state (QUEUED if unknown)."""
        mappings = self.STATE_MAPPINGS
        if self.status in mappings:
            return mappings[self.status]
        return State.QUEUED

    def __repr__(self):
        """Render the job as ``(<job_id> -> <batch status>, <airflow state>)``."""
        airflow_state = self.get_job_state()
        return f"({self.job_id} -> {self.status}, {airflow_state})"
class BatchJobCollection:
    """
    A collection to manage running Batch Jobs.

    Maintains bidirectional lookups between Airflow task instance keys and AWS Batch
    job ids, a per-job failure counter, and the ``BatchJobInfo`` recorded when each
    job was added.
    """

    def __init__(self):
        self.key_to_id: dict[TaskInstanceKey, str] = {}
        self.id_to_key: dict[str, TaskInstanceKey] = {}
        # defaultdict so increment_failure_count works even for an id that was never seeded.
        self.id_to_failure_counts: dict[str, int] = defaultdict(int)
        self.id_to_job_info: dict[str, BatchJobInfo] = {}

    def add_job(
        self,
        job_id: str,
        airflow_task_key: TaskInstanceKey,
        airflow_cmd: CommandType,
        queue: str,
        exec_config: ExecutorConfigType,
        attempt_number: int,
    ):
        """
        Add a job to the collection.

        :param job_id: AWS Batch job id.
        :param airflow_task_key: Airflow task instance key the job runs.
        :param airflow_cmd: command the job was submitted with.
        :param queue: queue the task was assigned to.
        :param exec_config: executor configuration the job was submitted with.
        :param attempt_number: initial value stored as this job's failure count.
        """
        self.key_to_id[airflow_task_key] = job_id
        self.id_to_key[job_id] = airflow_task_key
        # The failure counter is seeded from attempt_number rather than reset to 0.
        self.id_to_failure_counts[job_id] = attempt_number
        self.id_to_job_info[job_id] = BatchJobInfo(cmd=airflow_cmd, queue=queue, config=exec_config)

    def pop_by_id(self, job_id: str) -> TaskInstanceKey:
        """
        Delete job from collection based off of Batch Job ID.

        All per-job bookkeeping is removed, including the stored ``BatchJobInfo``.

        :param job_id: AWS Batch job id to remove.
        :return: the Airflow task instance key that was mapped to ``job_id``.
        :raises KeyError: if ``job_id`` is not in the collection.
        """
        task_key = self.id_to_key[job_id]
        del self.key_to_id[task_key]
        del self.id_to_key[job_id]
        del self.id_to_failure_counts[job_id]
        # Without this, id_to_job_info kept an entry for every job ever added and
        # grew for the lifetime of the collection.
        self.id_to_job_info.pop(job_id, None)
        return task_key

    def failure_count_by_id(self, job_id: str) -> int:
        """
        Get the number of times a job has failed given a Batch Job Id.

        Returns 0 for an id that has no counter. ``.get`` is used so that, unlike
        indexing a defaultdict, the lookup does not insert a phantom entry.
        """
        return self.id_to_failure_counts.get(job_id, 0)

    def increment_failure_count(self, job_id: str):
        """Increment the failure counter given a Batch Job Id."""
        self.id_to_failure_counts[job_id] += 1

    def get_all_jobs(self) -> list[str]:
        """Get all AWS Batch job ids in the collection."""
        return list(self.id_to_key.keys())

    def __len__(self):
        """Return the number of jobs in collection."""
        return len(self.key_to_id)
class BatchSubmitJobKwargsConfigKeys(BaseConfigKeys):
    """
    Config keys that are valid keyword arguments to Batch ``submit_job``.
    """

    # Declaration order is kept identical to the original.
    JOB_NAME = "job_name"
    JOB_QUEUE = "job_queue"
    JOB_DEFINITION = "job_definition"
    EKS_PROPERTIES_OVERRIDE = "eks_properties_override"
    NODE_OVERRIDE = "node_override"
class AllBatchConfigKeys(BatchSubmitJobKwargsConfigKeys):
    """
    Every config key related to the Batch Executor.

    Extends the ``submit_job`` keyword keys with executor-level settings.
    """

    MAX_SUBMIT_JOB_ATTEMPTS = "max_submit_job_attempts"
    AWS_CONN_ID = "conn_id"
    SUBMIT_JOB_KWARGS = "submit_job_kwargs"
    REGION_NAME = "region_name"
    CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
class BatchExecutorException(Exception):
    """
    Raised when something unexpected occurs within the AWS Batch ecosystem.
    """

Was this entry helpful?