Source code for airflow.providers.microsoft.azure.operators.container_instances

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import re
import time
from collections import namedtuple
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

from azure.mgmt.containerinstance.models import (
    Container,
    ContainerGroup,
    ContainerGroupDiagnostics,
    ContainerGroupSubnetId,
    ContainerPort,
    DnsConfiguration,
    EnvironmentVariable,
    IpAddress,
    ResourceRequests,
    ResourceRequirements,
    Volume as _AzureVolume,
    VolumeMount,
)
from msrestazure.azure_exceptions import CloudError

from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook
from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook
from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook

if TYPE_CHECKING:
    from airflow.utils.context import Context

Volume = namedtuple(
    "Volume",
    ["conn_id", "account_name", "share_name", "mount_path", "read_only"],
)

DEFAULT_ENVIRONMENT_VARIABLES: dict[str, str] = {}
DEFAULT_SECURED_VARIABLES: Sequence[str] = []
DEFAULT_VOLUMES: Sequence[Volume] = []
DEFAULT_MEMORY_IN_GB = 2.0
DEFAULT_CPU = 1.0
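
# Illustrative sketch (not part of the provider module): the ``Volume`` namedtuple
# above is how an Azure Fileshare mount is described to the operator. The connection
# id, storage account, and share names below are hypothetical placeholders.
#
#     input_volume = Volume(
#         conn_id="azure_container_volume_default",  # Airflow connection holding the storage key
#         account_name="mystorageaccount",           # Azure storage account name
#         share_name="my-fileshare",                 # Azure Files share to mount
#         mount_path="/input-data",                  # path inside the container
#         read_only=True,                            # mount the share read-only
#     )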


class AzureContainerInstancesOperator(BaseOperator):
    """
    Start a container on Azure Container Instances.

    :param ci_conn_id: connection id of a service principal which will be used
        to start the container instance
    :param registry_conn_id: connection id of a user which can login to a
        private docker registry. For Azure use :ref:`Azure connection id<howto/connection:azure>`
    :param resource_group: name of the resource group wherein this container
        instance should be started
    :param name: name of this container instance. Please note this name has
        to be unique in order to run containers in parallel.
    :param image: the docker image to be used
    :param region: the region wherein this container instance should be started
    :param environment_variables: key,value pairs containing environment
        variables which will be passed to the running container
    :param secured_variables: names of environment variables that should not
        be exposed outside the container (typically passwords).
    :param volumes: list of ``Volume`` tuples to be mounted to the container.
        Currently only Azure Fileshares are supported.
    :param memory_in_gb: the amount of memory to allocate to this container
    :param cpu: the number of cpus to allocate to this container
    :param gpu: GPU Resource for the container.
    :param command: the command to run inside the container
    :param container_timeout: max time allowed for the execution of
        the container instance.
    :param tags: azure tags as dict of str:str
    :param xcom_all: Control if logs are pushed to XCOM similarly to how DockerOperator does.
        Possible values include: 'None', 'True', 'False'.
        Defaults to 'None', meaning no logs are pushed to XCOM, which is the historical behaviour.
        'True' means push all logs to XCOM, which may run the risk of hitting XCOM size limits.
        'False' means push only the last line of the logs to XCOM.
        However, the logs are pushed into XCOM under "logs", not return_value,
        to avoid breaking the existing behaviour.
    :param os_type: The operating system type required by the containers
        in the container group. Possible values include: 'Windows', 'Linux'
    :param restart_policy: Restart policy for all containers within the container group.
        Possible values include: 'Always', 'OnFailure', 'Never'
    :param ip_address: The IP address type of the container group.
    :param subnet_ids: The subnet resource IDs for a container group
    :param dns_config: The DNS configuration for a container group.
    :param diagnostics: Container group diagnostic information (Log Analytics).
    :param priority: Container group priority. Possible values include: 'Regular', 'Spot'

    **Example**::

        AzureContainerInstancesOperator(
            ci_conn_id="azure_service_principal",
            registry_conn_id="azure_registry_user",
            resource_group="my-resource-group",
            name="my-container-name-{{ ds }}",
            image="myprivateregistry.azurecr.io/my_container:latest",
            region="westeurope",
            environment_variables={
                "MODEL_PATH": "my_value",
                "POSTGRES_LOGIN": "{{ macros.connection('postgres_default').login }}",
                "POSTGRES_PASSWORD": "{{ macros.connection('postgres_default').password }}",
                "JOB_GUID": "{{ ti.xcom_pull(task_ids='task1', key='guid') }}",
            },
            secured_variables=["POSTGRES_PASSWORD"],
            volumes=[
                ("azure_container_instance_conn_id", "my_storage_container", "my_fileshare", "/input-data", True),
            ],
            memory_in_gb=14.0,
            cpu=4.0,
            gpu=GpuResource(count=1, sku="K80"),
            subnet_ids=[
                {
                    "id": "/subscriptions/00000000-0000-0000-0000-00000000000/resourceGroups/my_rg/providers/Microsoft.Network/virtualNetworks/my_vnet/subnets/my_subnet"
                }
            ],
            dns_config={"name_servers": ["10.0.0.10", "10.0.0.11"]},
            diagnostics={
                "log_analytics": {
                    "workspaceId": "workspaceid",
                    "workspaceKey": "workspaceKey",
                }
            },
            priority="Regular",
            command=["/bin/echo", "world"],
            task_id="start_container",
        )
    """
    template_fields: Sequence[str] = ("name", "image", "command", "environment_variables", "volumes")
    template_fields_renderers = {"command": "bash", "environment_variables": "json"}

    def __init__(
        self,
        *,
        ci_conn_id: str,
        resource_group: str,
        name: str,
        image: str,
        region: str,
        registry_conn_id: str | None = None,
        environment_variables: dict | None = None,
        secured_variables: str | None = None,
        volumes: list | None = None,
        memory_in_gb: Any | None = None,
        cpu: Any | None = None,
        gpu: Any | None = None,
        command: list[str] | None = None,
        remove_on_error: bool = True,
        fail_if_exists: bool = True,
        tags: dict[str, str] | None = None,
        xcom_all: bool | None = None,
        os_type: str = "Linux",
        restart_policy: str = "Never",
        ip_address: IpAddress | None = None,
        ports: list[ContainerPort] | None = None,
        subnet_ids: list[ContainerGroupSubnetId] | None = None,
        dns_config: DnsConfiguration | None = None,
        diagnostics: ContainerGroupDiagnostics | None = None,
        priority: str | None = "Regular",
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.ci_conn_id = ci_conn_id
        self.resource_group = resource_group
        self.name = name
        self.image = image
        self.region = region
        self.registry_conn_id = registry_conn_id
        self.environment_variables = environment_variables or DEFAULT_ENVIRONMENT_VARIABLES
        self.secured_variables = secured_variables or DEFAULT_SECURED_VARIABLES
        self.volumes = volumes or DEFAULT_VOLUMES
        self.memory_in_gb = memory_in_gb or DEFAULT_MEMORY_IN_GB
        self.cpu = cpu or DEFAULT_CPU
        self.gpu = gpu
        self.command = command
        self.remove_on_error = remove_on_error
        self.fail_if_exists = fail_if_exists
        self._ci_hook: Any = None
        self.tags = tags
        self.xcom_all = xcom_all
        self.os_type = os_type
        if self.os_type not in ["Linux", "Windows"]:
            raise AirflowException(
                "Invalid value for the os_type argument. "
                "Please set 'Linux' or 'Windows' as the os_type. "
                f"Found `{self.os_type}`."
            )
        self.restart_policy = restart_policy
        if self.restart_policy not in ["Always", "OnFailure", "Never"]:
            raise AirflowException(
                "Invalid value for the restart_policy argument. "
                "Please set one of 'Always', 'OnFailure', 'Never' as the restart_policy. "
                f"Found `{self.restart_policy}`"
            )
        self.ip_address = ip_address
        self.ports = ports
        self.subnet_ids = subnet_ids
        self.dns_config = dns_config
        self.diagnostics = diagnostics
        self.priority = priority
        if self.priority not in ["Regular", "Spot"]:
            raise AirflowException(
                "Invalid value for the priority argument. "
                "Please set 'Regular' or 'Spot' as the priority. "
                f"Found `{self.priority}`."
            )

    def execute(self, context: Context) -> int:
        # Check name again in case it was templated.
        self._check_name(self.name)

        self._ci_hook = AzureContainerInstanceHook(azure_conn_id=self.ci_conn_id)

        if self.fail_if_exists:
            self.log.info("Testing if container group already exists")
            if self._ci_hook.exists(self.resource_group, self.name):
                raise AirflowException("Container group exists")

        if self.registry_conn_id:
            registry_hook = AzureContainerRegistryHook(self.registry_conn_id)
            image_registry_credentials: list | None = [
                registry_hook.connection,
            ]
        else:
            image_registry_credentials = None

        environment_variables = []
        for key, value in self.environment_variables.items():
            if key in self.secured_variables:
                e = EnvironmentVariable(name=key, secure_value=value)
            else:
                e = EnvironmentVariable(name=key, value=value)
            environment_variables.append(e)

        volumes: list[_AzureVolume] = []
        volume_mounts: list[VolumeMount] = []
        for conn_id, account_name, share_name, mount_path, read_only in self.volumes:
            hook = AzureContainerVolumeHook(conn_id)

            mount_name = f"mount-{len(volumes)}"
            volumes.append(hook.get_file_volume(mount_name, share_name, account_name, read_only))
            volume_mounts.append(VolumeMount(name=mount_name, mount_path=mount_path, read_only=read_only))

        exit_code = 1
        try:
            self.log.info("Starting container group with %.1f cpu %.1f mem", self.cpu, self.memory_in_gb)
            if self.gpu:
                self.log.info("GPU count: %.1f, GPU SKU: %s", self.gpu.count, self.gpu.sku)

            resources = ResourceRequirements(
                requests=ResourceRequests(memory_in_gb=self.memory_in_gb, cpu=self.cpu, gpu=self.gpu)
            )

            if self.ip_address and not self.ports:
                self.ports = [ContainerPort(port=80)]
                self.log.info("Default port set. Container will listen on port 80")

            container = Container(
                name=self.name,
                image=self.image,
                resources=resources,
                command=self.command,
                environment_variables=environment_variables,
                volume_mounts=volume_mounts,
                ports=self.ports,
            )

            container_group = ContainerGroup(
                location=self.region,
                containers=[
                    container,
                ],
                image_registry_credentials=image_registry_credentials,
                volumes=volumes,
                restart_policy=self.restart_policy,
                os_type=self.os_type,
                tags=self.tags,
                ip_address=self.ip_address,
                subnet_ids=self.subnet_ids,
                dns_config=self.dns_config,
                diagnostics=self.diagnostics,
                priority=self.priority,
            )

            self._ci_hook.create_or_update(self.resource_group, self.name, container_group)

            self.log.info("Container group started %s/%s", self.resource_group, self.name)

            exit_code = self._monitor_logging(self.resource_group, self.name)

            if self.xcom_all is not None:
                logs = self._ci_hook.get_logs(self.resource_group, self.name)
                if logs is None:
                    context["ti"].xcom_push(key="logs", value=[])
                else:
                    if self.xcom_all:
                        context["ti"].xcom_push(key="logs", value=logs)
                    else:
                        # slice off the last entry in the list logs and return it as a list
                        context["ti"].xcom_push(key="logs", value=logs[-1:])

            self.log.info("Container had exit code: %s", exit_code)
            if exit_code != 0:
                raise AirflowException(f"Container had a non-zero exit code, {exit_code}")
            return exit_code
        except CloudError:
            self.log.exception("Could not start container group")
            raise AirflowException("Could not start container group")
        finally:
            if exit_code == 0 or self.remove_on_error:
                self.on_kill()

    def on_kill(self) -> None:
        self.log.info("Deleting container group")
        try:
            self._ci_hook.delete(self.resource_group, self.name)
        except Exception:
            self.log.exception("Could not delete container group")

    def _monitor_logging(self, resource_group: str, name: str) -> int:
        last_state = None
        last_message_logged = None
        last_line_logged = None

        while True:
            try:
                cg_state = self._ci_hook.get_state(resource_group, name)
                instance_view = cg_state.containers[0].instance_view
                # If there is no instance view, we show the provisioning state
                if instance_view is not None:
                    c_state = instance_view.current_state
                    state, exit_code, detail_status = (
                        c_state.state,
                        c_state.exit_code,
                        c_state.detail_status,
                    )
                else:
                    state = cg_state.provisioning_state
                    exit_code = 0
                    detail_status = "Provisioning"

                if instance_view is not None and instance_view.events is not None:
                    messages = [event.message for event in instance_view.events]
                    last_message_logged = self._log_last(messages, last_message_logged)

                if state != last_state:
                    self.log.info("Container group state changed to %s", state)
                    last_state = state

                if state in ["Running", "Terminated", "Succeeded"]:
                    try:
                        logs = self._ci_hook.get_logs(resource_group, name)
                        if logs and logs[0] is None:
                            self.log.error("Container log is broken, marking as failed.")
                            return 1
                        last_line_logged = self._log_last(logs, last_line_logged)
                    except CloudError:
                        self.log.exception(
                            "Exception while getting logs from container instance, retrying..."
                        )

                if state == "Terminated":
                    self.log.info("Container exited with detail_status %s", detail_status)
                    return exit_code

                if state == "Failed":
                    self.log.error("Azure provision failure")
                    return 1

            except AirflowTaskTimeout:
                raise
            except CloudError as err:
                if "ResourceNotFound" in str(err):
                    self.log.warning(
                        "ResourceNotFound, container is probably removed "
                        "by another process "
                        "(make sure that the name is unique)."
                    )
                    return 1
                else:
                    self.log.exception("Exception while getting container groups")
            except Exception:
                self.log.exception("Exception while getting container groups")

            time.sleep(1)

    def _log_last(self, logs: list | None, last_line_logged: Any) -> Any | None:
        if logs:
            # determine the last line which was logged before
            last_line_index = 0
            for i in range(len(logs) - 1, -1, -1):
                if logs[i] == last_line_logged:
                    # this line is the same, hence print from i+1
                    last_line_index = i + 1
                    break

            # log all new ones
            for line in logs[last_line_index:]:
                self.log.info(line.rstrip())

            return logs[-1]
        return None

    @staticmethod
    def _check_name(name: str) -> str:
        regex_check = re.match("[a-z0-9]([-a-z0-9]*[a-z0-9])?", name)
        if regex_check is None or regex_check.group() != name:
            raise AirflowException('ACI name must match regex [a-z0-9]([-a-z0-9]*[a-z0-9])? (like "my-name")')
        if len(name) > 63:
            raise AirflowException("ACI name cannot be longer than 63 characters")
        return name
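

# Minimal usage sketch (not part of the provider module), assuming an Airflow 2.x
# DAG file; the connection id, image, resource group, and region are hypothetical.
# With xcom_all=True the full container log is pushed to XCom under the "logs" key
# (see execute() above) and can be pulled by a downstream task via
# ti.xcom_pull(task_ids="start_container", key="logs").
#
#     import pendulum
#
#     from airflow import DAG
#     from airflow.providers.microsoft.azure.operators.container_instances import (
#         AzureContainerInstancesOperator,
#     )
#
#     with DAG(
#         dag_id="aci_example",
#         start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
#         schedule=None,
#     ):
#         start_container = AzureContainerInstancesOperator(
#             task_id="start_container",
#             ci_conn_id="azure_default",
#             resource_group="my-resource-group",
#             name="aci-example-{{ ds }}",  # templated; must satisfy _check_name()
#             image="alpine:latest",
#             region="westeurope",
#             command=["echo", "hello"],
#             xcom_all=True,  # push all log lines to XCom under key "logs"
#         )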
