Source code for airflow.providers.microsoft.azure.hooks.data_factory

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Spelling exceptions.

.. spelling:word-list::
    CreateRunResponse
    DatasetResource
    LinkedServiceResource
    LROPoller
    PipelineResource
    PipelineRun
    TriggerResource
    datafactory
    DataFlow
    DataFlowResource
    mgmt
"""

from __future__ import annotations

import inspect
import time
from functools import wraps
from typing import IO, TYPE_CHECKING, Any, Callable, TypeVar, Union, cast

from asgiref.sync import sync_to_async
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.identity.aio import (
    ClientSecretCredential as AsyncClientSecretCredential,
    DefaultAzureCredential as AsyncDefaultAzureCredential,
)
from azure.mgmt.datafactory import DataFactoryManagementClient
from azure.mgmt.datafactory.aio import DataFactoryManagementClient as AsyncDataFactoryManagementClient

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import (
    add_managed_identity_connection_widgets,
    get_async_default_azure_credential,
    get_sync_default_azure_credential,
)

if TYPE_CHECKING:
    from azure.core.polling import LROPoller
    from azure.mgmt.datafactory.models import (
        CreateRunResponse,
        DataFlowResource,
        DatasetResource,
        Factory,
        LinkedServiceResource,
        PipelineResource,
        PipelineRun,
        TriggerResource,
    )

# Synchronous credential objects accepted when building the management client.
Credentials = Union[
    ClientSecretCredential,
    DefaultAzureCredential,
]

# Asynchronous (``azure.identity.aio``) counterparts of ``Credentials``.
AsyncCredentials = Union[
    AsyncClientSecretCredential,
    AsyncDefaultAzureCredential,
]

# Generic type variable used to keep a decorated callable's type unchanged.
T = TypeVar("T", bound=Any)
def provide_targeted_factory(func: Callable) -> Callable:
    """
    Fill in the targeted factory for the decorated method when the caller did not supply it.

    For each of ``resource_group_name`` and ``factory_name`` that is absent from the bound call
    arguments or bound to ``None``, the value is read from the connection extras, first under the
    short key and then under the legacy ``extra__azure_data_factory__``-prefixed key.

    :param func: The method to wrap; its first positional argument must expose ``conn_id`` and
        ``get_connection``.
    :raise AirflowException: If a missing value cannot be found in the connection extras.
    :return: The wrapped callable.
    """
    signature = inspect.signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs) -> Any:
        bound_args = signature.bind(*args, **kwargs)
        arguments = bound_args.arguments

        for name in ("resource_group_name", "factory_name"):
            # Skip arguments the caller supplied with a real (non-None) value.
            if arguments.get(name) is not None:
                continue

            hook = args[0]
            extras = hook.get_connection(hook.conn_id).extra_dejson
            fallback = extras.get(name) or extras.get(f"extra__azure_data_factory__{name}")
            if not fallback:
                raise AirflowException("Could not determine the targeted data factory.")

            arguments[name] = fallback

        return func(*bound_args.args, **bound_args.kwargs)

    return wrapper
class AzureDataFactoryPipelineRunStatus:
    """Status values of an Azure Data Factory pipeline run, plus convenience groupings."""

    # Individual status values.
    QUEUED = "Queued"
    IN_PROGRESS = "InProgress"
    SUCCEEDED = "Succeeded"
    FAILED = "Failed"
    CANCELING = "Canceling"
    CANCELLED = "Cancelled"

    # Groupings of the values above.
    TERMINAL_STATUSES = {
        SUCCEEDED,
        FAILED,
        CANCELLED,
    }
    INTERMEDIATE_STATES = {
        QUEUED,
        IN_PROGRESS,
        CANCELING,
    }
    FAILURE_STATES = {
        CANCELLED,
        FAILED,
    }
class AzureDataFactoryPipelineRunException(AirflowException):
    """Raised to signal that a pipeline run did not complete."""
def get_field(extras: dict, field_name: str, strict: bool = False):
    """
    Look up a field in the connection extras.

    The short name is tried first; for backward compatibility the name prefixed with
    ``extra__azure_data_factory__`` is tried second. A falsy stored value is returned as ``None``.

    :param extras: The connection extras mapping.
    :param field_name: The un-prefixed field name.
    :param strict: When true, raise instead of returning ``None`` if neither key is present.
    :raise ValueError: If ``field_name`` already starts with ``extra__``.
    :raise KeyError: If ``strict`` is true and neither key is present.
    :return: The stored value, or ``None``.
    """
    legacy_prefix = "extra__azure_data_factory__"
    if field_name.startswith("extra__"):
        raise ValueError(
            f"Got prefixed name {field_name}; please remove the '{legacy_prefix}' prefix "
            "when using this method."
        )

    # The first key that is present wins, even when its value is falsy.
    for candidate in (field_name, f"{legacy_prefix}{field_name}"):
        if candidate in extras:
            return extras[candidate] or None

    if strict:
        raise KeyError(f"Field {field_name} not found in extras")
    return None
[docs]class AzureDataFactoryHook(BaseHook): """ A hook to interact with Azure Data Factory. :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id<howto/connection:adf>`. """
[docs] conn_type: str = "azure_data_factory"
[docs] conn_name_attr: str = "azure_data_factory_conn_id"
[docs] default_conn_name: str = "azure_data_factory_default"
[docs] hook_name: str = "Azure Data Factory"
@classmethod @add_managed_identity_connection_widgets
[docs] def get_connection_form_widgets(cls) -> dict[str, Any]: """Return connection widgets to add to connection form.""" from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from flask_babel import lazy_gettext from wtforms import StringField return { "tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), "subscriptionId": StringField(lazy_gettext("Subscription ID"), widget=BS3TextFieldWidget()), "resource_group_name": StringField( lazy_gettext("Resource Group Name"), widget=BS3TextFieldWidget() ), "factory_name": StringField(lazy_gettext("Factory Name"), widget=BS3TextFieldWidget()), }
@classmethod
[docs] def get_ui_field_behaviour(cls) -> dict[str, Any]: """Return custom field behaviour.""" return { "hidden_fields": ["schema", "port", "host", "extra"], "relabeling": { "login": "Client ID", "password": "Secret", }, }
def __init__(self, azure_data_factory_conn_id: str = default_conn_name): self._conn: DataFactoryManagementClient | None = None self.conn_id = azure_data_factory_conn_id super().__init__()
[docs] def get_conn(self) -> DataFactoryManagementClient: if self._conn is not None: return self._conn conn = self.get_connection(self.conn_id) extras = conn.extra_dejson tenant = get_field(extras, "tenantId") try: subscription_id = get_field(extras, "subscriptionId", strict=True) except KeyError: raise ValueError("A Subscription ID is required to connect to Azure Data Factory.") credential: Credentials if conn.login is not None and conn.password is not None: if not tenant: raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.") credential = ClientSecretCredential( client_id=conn.login, client_secret=conn.password, tenant_id=tenant ) else: managed_identity_client_id = get_field(extras, "managed_identity_client_id") workload_identity_tenant_id = get_field(extras, "workload_identity_tenant_id") credential = get_sync_default_azure_credential( managed_identity_client_id=managed_identity_client_id, workload_identity_tenant_id=workload_identity_tenant_id, ) self._conn = self._create_client(credential, subscription_id) return self._conn
[docs] def refresh_conn(self) -> DataFactoryManagementClient: self._conn = None return self.get_conn()
@provide_targeted_factory
[docs] def get_factory(self, resource_group_name: str, factory_name: str, **config: Any) -> Factory | None: """ Get the factory. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :return: The factory. """ return self.get_conn().factories.get(resource_group_name, factory_name, **config)
def _factory_exists(self, resource_group_name, factory_name) -> bool: """Return whether or not the factory already exists.""" factories = { factory.name for factory in self.get_conn().factories.list_by_resource_group(resource_group_name) } return factory_name in factories @staticmethod def _create_client(credential: Credentials, subscription_id: str): return DataFactoryManagementClient( credential=credential, subscription_id=subscription_id, ) @provide_targeted_factory
[docs] def update_factory( self, factory: Factory, resource_group_name: str, factory_name: str, if_match: str | None = None, **config: Any, ) -> Factory: """ Update the factory. :param factory: The factory resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param if_match: ETag of the factory entity. Should only be specified for update, for which it should match existing entity or can be * for unconditional update. Default value is None. :param config: Extra parameters for the ADF client. :raise AirflowException: If the factory does not exist. :return: The factory. """ if not self._factory_exists(resource_group_name, factory_name): raise AirflowException(f"Factory {factory!r} does not exist.") return self.get_conn().factories.create_or_update( resource_group_name, factory_name, factory, if_match, **config )
@provide_targeted_factory
[docs] def create_factory( self, factory: Factory, resource_group_name: str, factory_name: str, **config: Any, ) -> Factory: """ Create the factory. :param factory: The factory resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :raise AirflowException: If the factory already exists. :return: The factory. """ if self._factory_exists(resource_group_name, factory_name): raise AirflowException(f"Factory {factory!r} already exists.") return self.get_conn().factories.create_or_update( resource_group_name, factory_name, factory, **config )
@provide_targeted_factory
[docs] def delete_factory(self, resource_group_name: str, factory_name: str, **config: Any) -> None: """ Delete the factory. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. """ self.get_conn().factories.delete(resource_group_name, factory_name, **config)
@provide_targeted_factory
[docs] def get_linked_service( self, linked_service_name: str, resource_group_name: str, factory_name: str, if_none_match: str | None = None, **config: Any, ) -> LinkedServiceResource | None: """ Get the linked service. :param linked_service_name: The linked service name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param if_none_match: ETag of the linked service entity. Should only be specified for get. If the ETag matches the existing entity tag, or if * was provided, then no content will be returned. Default value is None. :param config: Extra parameters for the ADF client. :return: The linked service. """ return self.get_conn().linked_services.get( resource_group_name, factory_name, linked_service_name, if_none_match, **config )
def _linked_service_exists(self, resource_group_name, factory_name, linked_service_name) -> bool: """Return whether or not the linked service already exists.""" linked_services = { linked_service.name for linked_service in self.get_conn().linked_services.list_by_factory( resource_group_name, factory_name ) } return linked_service_name in linked_services @provide_targeted_factory
[docs] def update_linked_service( self, linked_service_name: str, linked_service: LinkedServiceResource, resource_group_name: str, factory_name: str, **config: Any, ) -> LinkedServiceResource: """ Update the linked service. :param linked_service_name: The linked service name. :param linked_service: The linked service resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :raise AirflowException: If the linked service does not exist. :return: The linked service. """ if not self._linked_service_exists(resource_group_name, factory_name, linked_service_name): raise AirflowException(f"Linked service {linked_service_name!r} does not exist.") return self.get_conn().linked_services.create_or_update( resource_group_name, factory_name, linked_service_name, linked_service, **config )
@provide_targeted_factory
[docs] def create_linked_service( self, linked_service_name: str, linked_service: LinkedServiceResource, resource_group_name: str, factory_name: str, **config: Any, ) -> LinkedServiceResource: """ Create the linked service. :param linked_service_name: The linked service name. :param linked_service: The linked service resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :raise AirflowException: If the linked service already exists. :return: The linked service. """ if self._linked_service_exists(resource_group_name, factory_name, linked_service_name): raise AirflowException(f"Linked service {linked_service_name!r} already exists.") return self.get_conn().linked_services.create_or_update( resource_group_name, factory_name, linked_service_name, linked_service, **config )
@provide_targeted_factory
[docs] def delete_linked_service( self, linked_service_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> None: """ Delete the linked service. :param linked_service_name: The linked service name. :param resource_group_name: The linked service name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. """ self.get_conn().linked_services.delete( resource_group_name, factory_name, linked_service_name, **config )
@provide_targeted_factory
[docs] def get_dataset( self, dataset_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> DatasetResource | None: """ Get the dataset. :param dataset_name: The dataset name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :return: The dataset. """ return self.get_conn().datasets.get(resource_group_name, factory_name, dataset_name, **config)
def _dataset_exists(self, resource_group_name, factory_name, dataset_name) -> bool: """Return whether or not the dataset already exists.""" datasets = { dataset.name for dataset in self.get_conn().datasets.list_by_factory(resource_group_name, factory_name) } return dataset_name in datasets @provide_targeted_factory
[docs] def update_dataset( self, dataset_name: str, dataset: DatasetResource, resource_group_name: str, factory_name: str, **config: Any, ) -> DatasetResource: """ Update the dataset. :param dataset_name: The dataset name. :param dataset: The dataset resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :raise AirflowException: If the dataset does not exist. :return: The dataset. """ if not self._dataset_exists(resource_group_name, factory_name, dataset_name): raise AirflowException(f"Dataset {dataset_name!r} does not exist.") return self.get_conn().datasets.create_or_update( resource_group_name, factory_name, dataset_name, dataset, **config )
@provide_targeted_factory
[docs] def create_dataset( self, dataset_name: str, dataset: DatasetResource, resource_group_name: str, factory_name: str, **config: Any, ) -> DatasetResource: """ Create the dataset. :param dataset_name: The dataset name. :param dataset: The dataset resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :raise AirflowException: If the dataset already exists. :return: The dataset. """ if self._dataset_exists(resource_group_name, factory_name, dataset_name): raise AirflowException(f"Dataset {dataset_name!r} already exists.") return self.get_conn().datasets.create_or_update( resource_group_name, factory_name, dataset_name, dataset, **config )
@provide_targeted_factory
[docs] def delete_dataset( self, dataset_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> None: """ Delete the dataset. :param dataset_name: The dataset name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. """ self.get_conn().datasets.delete(resource_group_name, factory_name, dataset_name, **config)
@provide_targeted_factory
[docs] def get_dataflow( self, dataflow_name: str, resource_group_name: str, factory_name: str, if_none_match: str | None = None, **config: Any, ) -> DataFlowResource: """ Get the dataflow. :param dataflow_name: The dataflow name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param if_none_match: ETag of the data flow entity. Should only be specified for get. If the ETag matches the existing entity tag, or if * was provided, then no content will be returned. Default value is None. :param config: Extra parameters for the ADF client. :return: The DataFlowResource. """ return self.get_conn().data_flows.get( resource_group_name, factory_name, dataflow_name, if_none_match, **config )
def _dataflow_exists( self, dataflow_name: str, resource_group_name: str, factory_name: str, ) -> bool: """Return whether the dataflow already exists.""" dataflows = { dataflow.name for dataflow in self.get_conn().data_flows.list_by_factory(resource_group_name, factory_name) } return dataflow_name in dataflows @provide_targeted_factory
[docs] def update_dataflow( self, dataflow_name: str, dataflow: DataFlowResource | IO, resource_group_name: str, factory_name: str, if_match: str | None = None, **config: Any, ) -> DataFlowResource: """ Update the dataflow. :param dataflow_name: The dataflow name. :param dataflow: The dataflow resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param if_match: ETag of the data flow entity. Should only be specified for update, for which it should match existing entity or can be * for unconditional update. Default value is None. :param config: Extra parameters for the ADF client. :raise AirflowException: If the dataset does not exist. :return: DataFlowResource. """ if not self._dataflow_exists( dataflow_name, resource_group_name, factory_name, ): raise AirflowException(f"Dataflow {dataflow_name!r} does not exist.") return self.get_conn().data_flows.create_or_update( resource_group_name, factory_name, dataflow_name, dataflow, if_match, **config )
@provide_targeted_factory
[docs] def create_dataflow( self, dataflow_name: str, dataflow: DataFlowResource, resource_group_name: str, factory_name: str, if_match: str | None = None, **config: Any, ) -> DataFlowResource: """ Create the dataflow. :param dataflow_name: The dataflow name. :param dataflow: The dataflow resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param if_match: ETag of the factory entity. Should only be specified for update, for which it should match existing entity or can be * for unconditional update. Default value is None. :param config: Extra parameters for the ADF client. :raise AirflowException: If the dataset already exists. :return: The dataset. """ if self._dataflow_exists(dataflow_name, resource_group_name, factory_name): raise AirflowException(f"Dataflow {dataflow_name!r} already exists.") return self.get_conn().data_flows.create_or_update( resource_group_name, factory_name, dataflow_name, dataflow, if_match, **config )
@provide_targeted_factory
[docs] def delete_dataflow( self, dataflow_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> None: """ Delete the dataflow. :param dataflow_name: The dataflow name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. """ self.get_conn().data_flows.delete(resource_group_name, factory_name, dataflow_name, **config)
@provide_targeted_factory
[docs] def get_pipeline( self, pipeline_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> PipelineResource | None: """ Get the pipeline. :param pipeline_name: The pipeline name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :return: The pipeline. """ return self.get_conn().pipelines.get(resource_group_name, factory_name, pipeline_name, **config)
def _pipeline_exists(self, resource_group_name, factory_name, pipeline_name) -> bool: """Return whether or not the pipeline already exists.""" pipelines = { pipeline.name for pipeline in self.get_conn().pipelines.list_by_factory(resource_group_name, factory_name) } return pipeline_name in pipelines @provide_targeted_factory
[docs] def update_pipeline( self, pipeline_name: str, pipeline: PipelineResource, resource_group_name: str, factory_name: str, **config: Any, ) -> PipelineResource: """ Update the pipeline. :param pipeline_name: The pipeline name. :param pipeline: The pipeline resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :raise AirflowException: If the pipeline does not exist. :return: The pipeline. """ if not self._pipeline_exists(resource_group_name, factory_name, pipeline_name): raise AirflowException(f"Pipeline {pipeline_name!r} does not exist.") return self.get_conn().pipelines.create_or_update( resource_group_name, factory_name, pipeline_name, pipeline, **config )
@provide_targeted_factory
[docs] def create_pipeline( self, pipeline_name: str, pipeline: PipelineResource, resource_group_name: str, factory_name: str, **config: Any, ) -> PipelineResource: """ Create the pipeline. :param pipeline_name: The pipeline name. :param pipeline: The pipeline resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :raise AirflowException: If the pipeline already exists. :return: The pipeline. """ if self._pipeline_exists(resource_group_name, factory_name, pipeline_name): raise AirflowException(f"Pipeline {pipeline_name!r} already exists.") return self.get_conn().pipelines.create_or_update( resource_group_name, factory_name, pipeline_name, pipeline, **config )
@provide_targeted_factory
[docs] def delete_pipeline( self, pipeline_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> None: """ Delete the pipeline. :param pipeline_name: The pipeline name. :param resource_group_name: The pipeline name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. """ self.get_conn().pipelines.delete(resource_group_name, factory_name, pipeline_name, **config)
@provide_targeted_factory
[docs] def run_pipeline( self, pipeline_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> CreateRunResponse: """ Run a pipeline. :param pipeline_name: The pipeline name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :return: The pipeline run. """ return self.get_conn().pipelines.create_run( resource_group_name, factory_name, pipeline_name, **config )
@provide_targeted_factory
[docs] def get_pipeline_run( self, run_id: str, resource_group_name: str, factory_name: str, **config: Any, ) -> PipelineRun: """ Get the pipeline run. :param run_id: The pipeline run identifier. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :return: The pipeline run. """ return self.get_conn().pipeline_runs.get(resource_group_name, factory_name, run_id, **config)
[docs] def get_pipeline_run_status( self, run_id: str, resource_group_name: str, factory_name: str, ) -> str: """ Get a pipeline run's current status. :param run_id: The pipeline run identifier. :param resource_group_name: The resource group name. :param factory_name: The factory name. :return: The status of the pipeline run. """ self.log.info("Getting the status of run ID %s.", run_id) pipeline_run_status = self.get_pipeline_run(run_id, resource_group_name, factory_name).status self.log.info("Current status of pipeline run %s: %s", run_id, pipeline_run_status) return pipeline_run_status
[docs] def wait_for_pipeline_run_status( self, run_id: str, expected_statuses: str | set[str], resource_group_name: str, factory_name: str, check_interval: int = 60, timeout: int = 60 * 60 * 24 * 7, ) -> bool: """ Wait for a pipeline run to match an expected status. :param run_id: The pipeline run identifier. :param expected_statuses: The desired status(es) to check against a pipeline run's current status. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param check_interval: Time in seconds to check on a pipeline run's status. :param timeout: Time in seconds to wait for a pipeline to reach a terminal status or the expected status. :return: Boolean indicating if the pipeline run has reached the ``expected_status``. """ pipeline_run_status = self.get_pipeline_run_status(run_id, resource_group_name, factory_name) start_time = time.monotonic() while ( pipeline_run_status not in AzureDataFactoryPipelineRunStatus.TERMINAL_STATUSES and pipeline_run_status not in expected_statuses ): # Check if the pipeline-run duration has exceeded the ``timeout`` configured. if start_time + timeout < time.monotonic(): raise AzureDataFactoryPipelineRunException( f"Pipeline run {run_id} has not reached a terminal status after {timeout} seconds." ) # Wait to check the status of the pipeline run based on the ``check_interval`` configured. time.sleep(check_interval) pipeline_run_status = self.get_pipeline_run_status(run_id, resource_group_name, factory_name) return pipeline_run_status in expected_statuses
@provide_targeted_factory
[docs] def cancel_pipeline_run( self, run_id: str, resource_group_name: str, factory_name: str, **config: Any, ) -> None: """ Cancel the pipeline run. :param run_id: The pipeline run identifier. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. """ self.get_conn().pipeline_runs.cancel(resource_group_name, factory_name, run_id, **config)
@provide_targeted_factory
[docs] def get_trigger( self, trigger_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> TriggerResource | None: """ Get the trigger. :param trigger_name: The trigger name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :return: The trigger. """ return self.get_conn().triggers.get(resource_group_name, factory_name, trigger_name, **config)
def _trigger_exists(self, resource_group_name, factory_name, trigger_name) -> bool: """Return whether or not the trigger already exists.""" triggers = { trigger.name for trigger in self.get_conn().triggers.list_by_factory(resource_group_name, factory_name) } return trigger_name in triggers @provide_targeted_factory
[docs] def update_trigger( self, trigger_name: str, trigger: TriggerResource, resource_group_name: str, factory_name: str, if_match: str | None = None, **config: Any, ) -> TriggerResource: """ Update the trigger. :param trigger_name: The trigger name. :param trigger: The trigger resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param if_match: ETag of the trigger entity. Should only be specified for update, for which it should match existing entity or can be * for unconditional update. Default value is None. :param config: Extra parameters for the ADF client. :raise AirflowException: If the trigger does not exist. :return: The trigger. """ if not self._trigger_exists(resource_group_name, factory_name, trigger_name): raise AirflowException(f"Trigger {trigger_name!r} does not exist.") return self.get_conn().triggers.create_or_update( resource_group_name, factory_name, trigger_name, trigger, if_match, **config )
@provide_targeted_factory
[docs] def create_trigger( self, trigger_name: str, trigger: TriggerResource, resource_group_name: str, factory_name: str, **config: Any, ) -> TriggerResource: """ Create the trigger. :param trigger_name: The trigger name. :param trigger: The trigger resource definition. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :raise AirflowException: If the trigger already exists. :return: The trigger. """ if self._trigger_exists(resource_group_name, factory_name, trigger_name): raise AirflowException(f"Trigger {trigger_name!r} already exists.") return self.get_conn().triggers.create_or_update( resource_group_name, factory_name, trigger_name, trigger, **config )
@provide_targeted_factory
[docs] def delete_trigger( self, trigger_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> None: """ Delete the trigger. :param trigger_name: The trigger name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. """ self.get_conn().triggers.delete(resource_group_name, factory_name, trigger_name, **config)
@provide_targeted_factory
[docs] def start_trigger( self, trigger_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> LROPoller: """ Start the trigger. :param trigger_name: The trigger name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :return: An Azure operation poller. """ return self.get_conn().triggers.begin_start(resource_group_name, factory_name, trigger_name, **config)
@provide_targeted_factory
[docs] def stop_trigger( self, trigger_name: str, resource_group_name: str, factory_name: str, **config: Any, ) -> LROPoller: """ Stop the trigger. :param trigger_name: The trigger name. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. :return: An Azure operation poller. """ return self.get_conn().triggers.begin_stop(resource_group_name, factory_name, trigger_name, **config)
@provide_targeted_factory
[docs] def rerun_trigger( self, trigger_name: str, run_id: str, resource_group_name: str, factory_name: str, **config: Any, ) -> None: """ Rerun the trigger. :param trigger_name: The trigger name. :param run_id: The trigger run identifier. :param resource_group_name: The resource group name. :param factory_name: The factory name. :param config: Extra parameters for the ADF client. """ return self.get_conn().trigger_runs.rerun( resource_group_name, factory_name, trigger_name, run_id, **config )
@provide_targeted_factory
def cancel_trigger(
    self,
    trigger_name: str,
    run_id: str,
    resource_group_name: str,
    factory_name: str,
    **config: Any,
) -> None:
    """
    Cancel a specific run of a trigger.

    :param trigger_name: The trigger name.
    :param run_id: The trigger run identifier.
    :param resource_group_name: The resource group name.
    :param factory_name: The factory name.
    :param config: Extra parameters for the ADF client.
    """
    trigger_runs_client = self.get_conn().trigger_runs
    trigger_runs_client.cancel(
        resource_group_name,
        factory_name,
        trigger_name,
        run_id,
        **config,
    )
[docs] def test_connection(self) -> tuple[bool, str]: """Test a configured Azure Data Factory connection.""" success = (True, "Successfully connected to Azure Data Factory.") try: # Attempt to list existing factories under the configured subscription and retrieve the first in # the returned iterator. The Azure Data Factory API does allow for creation of a # DataFactoryManagementClient with incorrect values but then will fail properly once items are # retrieved using the client. We need to _actually_ try to retrieve an object to properly test the # connection. self.get_conn().factories.list() return success except StopIteration: # If the iterator returned is empty it should still be considered a successful connection since # it's possible to create a Data Factory via the ``AzureDataFactoryHook`` and none could # legitimately exist yet. return success except Exception as e: return False, str(e)
[docs]def provide_targeted_factory_async(func: T) -> T: """ Provide the targeted factory to the async decorated function in case it isn't specified. If ``resource_group_name`` or ``factory_name`` is not provided it defaults to the value specified in the connection extras. """ signature = inspect.signature(func) @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: bound_args = signature.bind(*args, **kwargs) async def bind_argument(arg: Any, default_key: str) -> None: # Check if arg was not included in the function signature or, if it is, the value is not provided. if arg not in bound_args.arguments or bound_args.arguments[arg] is None: self = args[0] conn = await sync_to_async(self.get_connection)(self.conn_id) extras = conn.extra_dejson default_value = extras.get(default_key) or extras.get( f"extra__azure_data_factory__{default_key}" ) if not default_value: raise AirflowException("Could not determine the targeted data factory.") bound_args.arguments[arg] = default_value await bind_argument("resource_group_name", "resource_group_name") await bind_argument("factory_name", "factory_name") return await func(*bound_args.args, **bound_args.kwargs) return cast(T, wrapper)
class AzureDataFactoryAsyncHook(AzureDataFactoryHook):
    """
    An Async Hook that connects to Azure DataFactory to perform pipeline operations.

    :param azure_data_factory_conn_id: The :ref:`Azure Data Factory connection id<howto/connection:adf>`.
    """

    default_conn_name: str = "azure_data_factory_default"

    def __init__(self, azure_data_factory_conn_id: str = default_conn_name):
        # Lazily created async management client; populated by ``get_async_conn``.
        self._async_conn: AsyncDataFactoryManagementClient | None = None
        self.conn_id = azure_data_factory_conn_id
        super().__init__(azure_data_factory_conn_id=azure_data_factory_conn_id)

    async def get_async_conn(self) -> AsyncDataFactoryManagementClient:
        """
        Get async connection and connect to azure data factory.

        The client is built once and cached on the hook; later calls return the cached instance.

        :raise ValueError: If the subscription ID is missing from the connection extras, or if a
            login and password are set but no tenant ID is configured.
        :return: The async Data Factory management client.
        """
        if self._async_conn is not None:
            return self._async_conn

        # ``get_connection`` is synchronous, so it is run through ``sync_to_async``.
        conn = await sync_to_async(self.get_connection)(self.conn_id)
        extras = conn.extra_dejson
        tenant = get_field(extras, "tenantId")

        try:
            subscription_id = get_field(extras, "subscriptionId", strict=True)
        except KeyError:
            raise ValueError("A Subscription ID is required to connect to Azure Data Factory.")

        credential: AsyncCredentials
        if conn.login is not None and conn.password is not None:
            # Login/password are used as the client ID/secret of a service principal.
            if not tenant:
                raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.")

            credential = AsyncClientSecretCredential(
                client_id=conn.login, client_secret=conn.password, tenant_id=tenant
            )
        else:
            # No explicit secret configured: fall back to the default Azure credential helper.
            managed_identity_client_id = get_field(extras, "managed_identity_client_id")
            workload_identity_tenant_id = get_field(extras, "workload_identity_tenant_id")
            credential = get_async_default_azure_credential(
                managed_identity_client_id=managed_identity_client_id,
                workload_identity_tenant_id=workload_identity_tenant_id,
            )

        self._async_conn = AsyncDataFactoryManagementClient(
            credential=credential,
            subscription_id=subscription_id,
        )

        return self._async_conn

    async def refresh_conn(self) -> AsyncDataFactoryManagementClient:  # type: ignore[override]
        """
        Drop the cached connection and build a new async client.

        :return: A newly created async Data Factory management client.
        """
        # NOTE(review): ``_conn`` is presumably the parent hook's cached sync client; it is
        # still cleared here exactly as before.
        self._conn = None
        # ``get_async_conn`` returns early when ``_async_conn`` is set, so the async cache
        # has to be cleared as well or the stale client would be handed back.
        self._async_conn = None
        return await self.get_async_conn()

    @provide_targeted_factory_async
    async def get_pipeline_run(
        self,
        run_id: str,
        resource_group_name: str,
        factory_name: str,
        **config: Any,
    ) -> PipelineRun:
        """
        Connect to Azure Data Factory asynchronously to get the pipeline run details by run id.

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :param config: Extra parameters for the ADF client.
        :return: The pipeline run.
        """
        client = await self.get_async_conn()
        pipeline_run = await client.pipeline_runs.get(resource_group_name, factory_name, run_id)
        return pipeline_run

    async def get_adf_pipeline_run_status(
        self, run_id: str, resource_group_name: str, factory_name: str
    ) -> str:
        """
        Connect to Azure Data Factory asynchronously and get the pipeline status by run_id.

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :return: The ``status`` attribute of the pipeline run.
        """
        pipeline_run = await self.get_pipeline_run(run_id, resource_group_name, factory_name)
        status: str = cast(str, pipeline_run.status)
        return status

    @provide_targeted_factory_async
    async def cancel_pipeline_run(
        self,
        run_id: str,
        resource_group_name: str,
        factory_name: str,
        **config: Any,
    ) -> None:
        """
        Cancel the pipeline run.

        :param run_id: The pipeline run identifier.
        :param resource_group_name: The resource group name.
        :param factory_name: The factory name.
        :param config: Extra parameters for the ADF client.
        :raise AirflowException: If the cancel request raises any exception.
        """
        client = await self.get_async_conn()
        try:
            await client.pipeline_runs.cancel(resource_group_name, factory_name, run_id)
        except Exception as e:
            # Keep the original exception attached as the cause for easier debugging.
            raise AirflowException(e) from e

Was this entry helpful?