Source code for airflow.providers.amazon.aws.bundles.s3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import os
from pathlib import Path

import structlog

from airflow.dag_processing.bundles.base import BaseDagBundle
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.s3 import S3Hook


class S3DagBundle(BaseDagBundle):
"""
S3 DAG bundle - exposes a directory in S3 as a DAG bundle.
This allows Airflow to load DAGs directly from an S3 bucket.
:param aws_conn_id: Airflow connection ID for AWS. Defaults to AwsBaseHook.default_conn_name.
:param bucket_name: The name of the S3 bucket containing the DAG files.
    :param prefix: Optional subdirectory within the S3 bucket where the DAG files are stored.
        If empty, DAGs are assumed to be at the root of the bucket.
"""
    supports_versioning = False

    def __init__(
self,
*,
aws_conn_id: str = AwsBaseHook.default_conn_name,
bucket_name: str,
prefix: str = "",
**kwargs,
) -> None:
super().__init__(**kwargs)
        self.aws_conn_id = aws_conn_id
        self.bucket_name = bucket_name
        self.prefix = prefix
        # Local path where S3 DAGs are downloaded
        self.s3_dags_dir: Path = self.base_dir

        log = structlog.get_logger(__name__)
self._log = log.bind(
bundle_name=self.name,
version=self.version,
bucket_name=self.bucket_name,
prefix=self.prefix,
aws_conn_id=self.aws_conn_id,
)
        self._s3_hook: S3Hook | None = None

    def _initialize(self):
with self.lock():
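            # The bundle lock serializes initialization and refresh, so concurrent
            # processes don't race on the shared local directory.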
if not self.s3_dags_dir.exists():
self._log.info("Creating local DAGs directory: %s", self.s3_dags_dir)
os.makedirs(self.s3_dags_dir)
if not self.s3_dags_dir.is_dir():
raise AirflowException(f"Local DAGs path: {self.s3_dags_dir} is not a directory.")
if not self.s3_hook.check_for_bucket(bucket_name=self.bucket_name):
raise AirflowException(f"S3 bucket '{self.bucket_name}' does not exist.")
if self.prefix:
# don't check when prefix is ""
if not self.s3_hook.check_for_prefix(
bucket_name=self.bucket_name, prefix=self.prefix, delimiter="/"
):
raise AirflowException(
f"S3 prefix 's3://{self.bucket_name}/{self.prefix}' does not exist."
)
            self.refresh()

    def initialize(self) -> None:
self._initialize()
        super().initialize()

    @property
    def s3_hook(self):
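        # The hook is created lazily, on first property access, and cached afterwards.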
if self._s3_hook is None:
try:
self._s3_hook: S3Hook = S3Hook(aws_conn_id=self.aws_conn_id) # Initialize S3 hook.
except AirflowException as e:
self._log.warning("Could not create S3Hook for connection %s: %s", self.aws_conn_id, e)
        return self._s3_hook

    def __repr__(self):
return (
f"<S3DagBundle("
f"name={self.name!r}, "
f"bucket_name={self.bucket_name!r}, "
f"prefix={self.prefix!r}, "
f"version={self.version!r}"
f")>"
        )

    def get_current_version(self) -> str | None:
"""Return the current version of the DAG bundle. Currently not supported."""
        return None

    @property
    def path(self) -> Path:
"""Return the local path to the DAG files."""
        return self.s3_dags_dir  # Path where DAGs are downloaded.

    def refresh(self) -> None:
"""Refresh the DAG bundle by re-downloading the DAGs from S3."""
if self.version:
raise AirflowException("Refreshing a specific version is not supported")
with self.lock():
self._log.debug(
"Downloading DAGs from s3://%s/%s to %s", self.bucket_name, self.prefix, self.s3_dags_dir
)
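            # Mirror the S3 prefix into the local directory; delete_stale removes
            # local files that no longer exist under the prefix in S3.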
self.s3_hook.sync_to_local_dir(
bucket_name=self.bucket_name,
s3_prefix=self.prefix,
local_dir=self.s3_dags_dir,
delete_stale=True,
        )

    def view_url(self, version: str | None = None) -> str | None:
"""Return a URL for viewing the DAGs in S3. Currently, versioning is not supported."""
if self.version:
raise AirflowException("S3 url with version is not supported")
# https://<bucket-name>.s3.<region>.amazonaws.com/<object-key>
url = f"https://{self.bucket_name}.s3"
if self.s3_hook.region_name:
url += f".{self.s3_hook.region_name}"
url += ".amazonaws.com"
if self.prefix:
url += f"/{self.prefix}"
return url
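A minimal usage sketch (not part of the module above): instantiating the bundle directly
and triggering a sync. The bundle name, bucket, prefix, and connection ID below are
illustrative assumptions; in a real deployment, bundles are normally registered through
Airflow's DAG bundle configuration rather than constructed by hand.

from airflow.providers.amazon.aws.bundles.s3 import S3DagBundle

# All values below are hypothetical placeholders.
bundle = S3DagBundle(
    name="s3-dags",               # bundle name, forwarded to BaseDagBundle
    bucket_name="my-dag-bucket",  # S3 bucket containing the DAG files
    prefix="dags/",               # optional subdirectory within the bucket
    aws_conn_id="aws_default",    # Airflow connection used to build the S3Hook
)
bundle.initialize()  # verifies the bucket/prefix exist, then downloads the DAGs
print(bundle.path)   # local directory the DAG files were synced into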