
Source code for tests.system.amazon.aws.example_datasync

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import boto3

from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator
from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if TYPE_CHECKING:
    from airflow.decorators import task
    from airflow.models.baseoperator import chain
    from airflow.models.dag import DAG
else:
    if AIRFLOW_V_3_0_PLUS:
        from airflow.sdk import DAG, chain, task
    else:
        # Airflow 2.10 compat
        from airflow.decorators import task
        from airflow.models.baseoperator import chain
        from airflow.models.dag import DAG
from airflow.utils.trigger_rule import TriggerRule

from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder

DAG_ID = "example_datasync"
# Externally fetched variables:
ROLE_ARN_KEY = "ROLE_ARN"

sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build()
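
# SystemTestContextBuilder (from the system-test utils imported above) produces a setup
# task whose XCom is a dict of test variables: the generated ENV_ID plus any requested
# external values, here the IAM role ARN. Downstream code reads them as
# test_context[ENV_ID_KEY] and test_context[ROLE_ARN_KEY].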

def get_s3_bucket_arn(bucket_name):
    return f"arn:aws:s3:::{bucket_name}"
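
# For example, get_s3_bucket_arn("my-bucket") returns "arn:aws:s3:::my-bucket".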

def create_location(bucket_name, role_arn):
    client = boto3.client("datasync")
    response = client.create_location_s3(
        Subdirectory="test",
        S3BucketArn=get_s3_bucket_arn(bucket_name),
        S3Config={
            "BucketAccessRoleArn": role_arn,
        },
    )
    return response["LocationArn"]
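
# create_location_s3 registers the bucket's "test" prefix as a DataSync S3 location
# (accessed through the supplied IAM role) and returns the new location's ARN, which
# the tasks below hand around via XCom.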

@task
def create_source_location(bucket_source, role_arn):
    return create_location(bucket_source, role_arn)

@task
def create_destination_location(bucket_destination, role_arn):
    return create_location(bucket_destination, role_arn)

@task
def create_task(**kwargs):
    client = boto3.client("datasync")
    response = client.create_task(
        SourceLocationArn=kwargs["ti"].xcom_pull("create_source_location"),
        DestinationLocationArn=kwargs["ti"].xcom_pull("create_destination_location"),
    )
    return response["TaskArn"]
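
# xcom_pull("<task_id>") returns the value the named upstream task pushed to XCom, so
# the two location ARNs created above become the source and destination of the task.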

@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_task(task_arn):
    client = boto3.client("datasync")
    client.delete_task(
        TaskArn=task_arn,
    )

@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_task_created_by_operator(**kwargs):
    client = boto3.client("datasync")
    client.delete_task(
        TaskArn=kwargs["ti"].xcom_pull("create_and_execute_task")["TaskArn"],
    )
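
# The create_and_execute_task operator below pushes a dict to XCom; as used here, its
# "TaskArn" key identifies the DataSync task the operator created, which is all this
# teardown step needs to delete it.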

@task(trigger_rule=TriggerRule.ALL_DONE)
def list_locations(bucket_source, bucket_destination):
    client = boto3.client("datasync")
    return client.list_locations(
        Filters=[
            {
                "Name": "LocationUri",
                "Values": [
                    f"s3://{bucket_source}/test/",
                    f"s3://{bucket_destination}/test/",
                    f"s3://{bucket_source}/test_create/",
                    f"s3://{bucket_destination}/test_create/",
                ],
                "Operator": "In",
            }
        ]
    )
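
# The LocationUri filter matches only the four prefixes this test creates ("test" via
# the helpers above, "test_create" via the create_and_execute_task operator), so the
# teardown below deletes nothing else in the account.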

@task(trigger_rule=TriggerRule.ALL_DONE)
def delete_locations(locations):
    client = boto3.client("datasync")
    for location in locations["Locations"]:
        client.delete_location(
            LocationArn=location["LocationArn"],
        )

with DAG(
    DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example"],
) as dag:
    test_context = sys_test_context_task()
    s3_bucket_source: str = f"{test_context[ENV_ID_KEY]}-datasync-bucket-source"
    s3_bucket_destination: str = f"{test_context[ENV_ID_KEY]}-datasync-bucket-destination"

    create_s3_bucket_source = S3CreateBucketOperator(
        task_id="create_s3_bucket_source", bucket_name=s3_bucket_source
    )

    create_s3_bucket_destination = S3CreateBucketOperator(
        task_id="create_s3_bucket_destination", bucket_name=s3_bucket_destination
    )

    source_location = create_source_location(s3_bucket_source, test_context[ROLE_ARN_KEY])
    destination_location = create_destination_location(s3_bucket_destination, test_context[ROLE_ARN_KEY])

    created_task_arn = create_task()

    # [START howto_operator_datasync_specific_task]
    # Execute a specific task
    execute_task_by_arn = DataSyncOperator(
        task_id="execute_task_by_arn",
        task_arn=created_task_arn,
    )
    # [END howto_operator_datasync_specific_task]

    # DataSyncOperator waits by default, setting as False to test the Sensor below.
    execute_task_by_arn.wait_for_completion = False

    # [START howto_operator_datasync_search_task]
    # Search and execute a task
    execute_task_by_locations = DataSyncOperator(
        task_id="execute_task_by_locations",
        source_location_uri=f"s3://{s3_bucket_source}/test",
        destination_location_uri=f"s3://{s3_bucket_destination}/test",
        # Only transfer files from /test/subdir folder
        task_execution_kwargs={
            "Includes": [{"FilterType": "SIMPLE_PATTERN", "Value": "/test/subdir"}],
        },
    )
    # [END howto_operator_datasync_search_task]

    # DataSyncOperator waits by default, setting as False to test the Sensor below.
    execute_task_by_locations.wait_for_completion = False

    # [START howto_operator_datasync_create_task]
    # Create a task (the task does not exist)
    create_and_execute_task = DataSyncOperator(
        task_id="create_and_execute_task",
        source_location_uri=f"s3://{s3_bucket_source}/test_create",
        destination_location_uri=f"s3://{s3_bucket_destination}/test_create",
        create_task_kwargs={"Name": "Created by Airflow"},
        create_source_location_kwargs={
            "Subdirectory": "test_create",
            "S3BucketArn": get_s3_bucket_arn(s3_bucket_source),
            "S3Config": {
                "BucketAccessRoleArn": test_context[ROLE_ARN_KEY],
            },
        },
        create_destination_location_kwargs={
            "Subdirectory": "test_create",
            "S3BucketArn": get_s3_bucket_arn(s3_bucket_destination),
            "S3Config": {
                "BucketAccessRoleArn": test_context[ROLE_ARN_KEY],
            },
        },
        delete_task_after_execution=False,
    )
    # [END howto_operator_datasync_create_task]

    # DataSyncOperator waits by default, setting as False to test the Sensor below.
    create_and_execute_task.wait_for_completion = False

    locations_task = list_locations(s3_bucket_source, s3_bucket_destination)
    delete_locations_task = delete_locations(locations_task)

    delete_s3_bucket_source = S3DeleteBucketOperator(
        task_id="delete_s3_bucket_source",
        bucket_name=s3_bucket_source,
        force_delete=True,
        trigger_rule=TriggerRule.ALL_DONE,
    )

    delete_s3_bucket_destination = S3DeleteBucketOperator(
        task_id="delete_s3_bucket_destination",
        bucket_name=s3_bucket_destination,
        force_delete=True,
        trigger_rule=TriggerRule.ALL_DONE,
    )

    chain(
        # TEST SETUP
        test_context,
        create_s3_bucket_source,
        create_s3_bucket_destination,
        source_location,
        destination_location,
        created_task_arn,
        # TEST BODY
        execute_task_by_arn,
        execute_task_by_locations,
        create_and_execute_task,
        # TEST TEARDOWN
        delete_task(created_task_arn),
        delete_task_created_by_operator(),
        locations_task,
        delete_locations_task,
        delete_s3_bucket_source,
        delete_s3_bucket_destination,
    )

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
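
# A minimal sketch of invoking this system test with pytest; the exact path depends on
# your checkout layout (see tests/system/README.md#run_via_pytest):
#
#   pytest tests/system/amazon/aws/example_datasync.py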
