Source code for tests.system.amazon.aws.example_emr_eks

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
import logging
import subprocess
import time
from datetime import datetime

import boto3
from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_exponential

from airflow.providers.amazon.aws.hooks.eks import ClusterStates, NodegroupStates
from airflow.providers.amazon.aws.operators.eks import EksCreateClusterOperator, EksDeleteClusterOperator
from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator
from airflow.providers.amazon.aws.operators.s3 import (
    S3CreateBucketOperator,
    S3CreateObjectOperator,
    S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksNodegroupStateSensor
from airflow.providers.amazon.aws.sensors.emr import EmrContainerSensor

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
    from airflow.sdk import DAG, chain, task
else:
    # Airflow 2.10 compat
    from airflow.decorators import task  # type: ignore[attr-defined,no-redef]
    from airflow.models.baseoperator import chain  # type: ignore[attr-defined,no-redef]
    from airflow.models.dag import DAG  # type: ignore[attr-defined,no-redef,assignment]
try:
    from airflow.sdk import TriggerRule
except ImportError:
    # Compatibility for Airflow < 3.1
    from airflow.utils.trigger_rule import TriggerRule  # type: ignore[no-redef,attr-defined]

from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
from system.amazon.aws.utils.k8s import get_describe_pod_operator

[docs] DAG_ID = "example_emr_eks"
# Externally fetched variables
[docs] ROLE_ARN_KEY = "ROLE_ARN"
[docs] JOB_ROLE_ARN_KEY = "JOB_ROLE_ARN"
[docs] JOB_ROLE_NAME_KEY = "JOB_ROLE_NAME"
[docs] SUBNETS_KEY = "SUBNETS"
[docs] sys_test_context_task = ( SystemTestContextBuilder() .add_variable(ROLE_ARN_KEY) .add_variable(JOB_ROLE_ARN_KEY) .add_variable(JOB_ROLE_NAME_KEY) .add_variable(SUBNETS_KEY, split_string=True) .build() )
[docs] S3_FILE_NAME = "pi.py"
[docs] S3_FILE_CONTENT = """ k = 1 s = 0 for i in range(1000000): if i % 2 == 0: s += 4/k else: s -= 4/k k += 2 print(s) """
@task
[docs] def create_launch_template(template_name: str): # This launch template enables IMDSv2. boto3.client("ec2").create_launch_template( LaunchTemplateName=template_name, LaunchTemplateData={ "MetadataOptions": {"HttpEndpoint": "enabled", "HttpTokens": "required"}, }, )
@task(trigger_rule=TriggerRule.ALL_DONE)
[docs] def delete_launch_template(template_name: str): boto3.client("ec2").delete_launch_template(LaunchTemplateName=template_name)
@task
[docs] def run_eksctl_commands(cluster_name, ns): # Install eksctl and enable access for EMR on EKS # See https://docs.aws.amazon.com/emr/latest/EMR-on-EKS-DevelopmentGuide/setting-up-cluster-access.html file = "https://github.com/weaveworks/eksctl/releases/latest/download/eksctl_$(uname -s)_amd64.tar.gz" commands = f""" curl --silent --location "{file}" | tar xz -C /tmp && /tmp/eksctl create iamidentitymapping --cluster {cluster_name} --namespace {ns} --service-name "emr-containers" && /tmp/eksctl utils associate-iam-oidc-provider --cluster {cluster_name} --approve """ build = subprocess.Popen( commands, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) _, err = build.communicate() if build.returncode != 0: raise RuntimeError(err)
@task(trigger_rule=TriggerRule.ALL_DONE)
[docs] def delete_iam_oidc_identity_provider(cluster_name): oidc_provider_issuer_url = boto3.client("eks").describe_cluster( name=cluster_name, )["cluster"]["identity"]["oidc"]["issuer"] oidc_provider_issuer_endpoint = oidc_provider_issuer_url.replace("https://", "") account_id = boto3.client("sts").get_caller_identity()["Account"] boto3.client("iam").delete_open_id_connect_provider( OpenIDConnectProviderArn=f"arn:aws:iam::{account_id}:oidc-provider/{oidc_provider_issuer_endpoint}" )
@task
[docs] def update_trust_policy_execution_role(cluster_name, cluster_namespace, role_name): # Remove any already existing trusted entities added with "update-role-trust-policy" # Prevent getting an error "Cannot exceed quota for ACLSizePerRole" client = boto3.client("iam") role_trust_policy = client.get_role(RoleName=role_name)["Role"]["AssumeRolePolicyDocument"] # We assume if the action is sts:AssumeRoleWithWebIdentity, the statement had been added with # "update-role-trust-policy". Removing it to not exceed the quota role_trust_policy["Statement"] = [ statement for statement in role_trust_policy["Statement"] if statement["Action"] != "sts:AssumeRoleWithWebIdentity" ] client.update_assume_role_policy( RoleName=role_name, PolicyDocument=json.dumps(role_trust_policy), ) # See https://docs.aws.amazon.com/emr/latest/EMR-on-EKS-DevelopmentGuide/setting-up-trust-policy.html # The action "update-role-trust-policy" is not available in boto3, thus we need to do it using AWS CLI commands = ( f"aws emr-containers update-role-trust-policy --cluster-name {cluster_name} " f"--namespace {cluster_namespace} --role-name {role_name}" ) build = subprocess.Popen( commands, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) _, err = build.communicate() if build.returncode != 0: raise RuntimeError(err)
@task
[docs] def wait_for_trust_policy_propagation(cluster_name, role_name): """Validate that the IAM trust policy has propagated by checking the role's trust policy contains the expected OIDC provider. Uses exponential backoff retries (up to 5 minutes) instead of a fixed sleep, which avoids both wasting time when propagation is fast and failing when it's slow. """ log = logging.getLogger(__name__) # Determine the expected OIDC provider ARN from the EKS cluster eks_client = boto3.client("eks") oidc_issuer_url = eks_client.describe_cluster(name=cluster_name)["cluster"]["identity"]["oidc"]["issuer"] oidc_issuer_endpoint = oidc_issuer_url.replace("https://", "") account_id = boto3.client("sts").get_caller_identity()["Account"] expected_oidc_provider_arn = f"arn:aws:iam::{account_id}:oidc-provider/{oidc_issuer_endpoint}" @retry( retry=retry_if_exception_type(RuntimeError), wait=wait_exponential(multiplier=1, min=5, max=30), stop=stop_after_delay(300), reraise=True, ) def _validate_trust_policy(): iam_client = boto3.client("iam") # Verify the trust policy document contains the expected OIDC provider role = iam_client.get_role(RoleName=role_name)["Role"] trust_policy = role["AssumeRolePolicyDocument"] has_oidc_statement = False for statement in trust_policy.get("Statement", []): if statement.get("Action") != "sts:AssumeRoleWithWebIdentity": continue principal = statement.get("Principal", {}) federated = principal.get("Federated", "") if oidc_issuer_endpoint in federated: has_oidc_statement = True break if not has_oidc_statement: log.info( "Trust policy does not yet contain OIDC provider %s, retrying...", expected_oidc_provider_arn, ) raise RuntimeError( f"Trust policy for role {role_name} does not yet contain " f"the expected OIDC provider: {expected_oidc_provider_arn}" ) log.info("Trust policy document confirmed for role %s", role_name) _validate_trust_policy() # Brief buffer after IAM confirms the trust policy document — cross-service # caches (EKS/EMR) may still serve the old policy for a few seconds. time.sleep(15) log.info("Trust policy validation complete, proceeding.")
@task(trigger_rule=TriggerRule.ALL_DONE)
[docs] def delete_virtual_cluster(virtual_cluster_id): boto3.client("emr-containers").delete_virtual_cluster( id=virtual_cluster_id, )
with DAG( dag_id=DAG_ID, schedule="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as dag:
[docs] test_context = sys_test_context_task()
env_id = test_context[ENV_ID_KEY] role_arn = test_context[ROLE_ARN_KEY] subnets = test_context[SUBNETS_KEY] job_role_arn = test_context[JOB_ROLE_ARN_KEY] job_role_name = test_context[JOB_ROLE_NAME_KEY] s3_bucket_name = f"{env_id}-bucket" eks_cluster_name = f"{env_id}-cluster" virtual_cluster_name = f"{env_id}-virtual-cluster" nodegroup_name = f"{env_id}-nodegroup" eks_namespace = "default" launch_template_name = f"{env_id}-launch-template" # [START howto_operator_emr_eks_config] job_driver_arg = { "sparkSubmitJobDriver": { "entryPoint": f"s3://{s3_bucket_name}/{S3_FILE_NAME}", "sparkSubmitParameters": "--conf spark.executors.instances=2 --conf spark.executors.memory=2G " "--conf spark.executor.cores=2 --conf spark.driver.cores=1", } } configuration_overrides_arg = { "monitoringConfiguration": { "cloudWatchMonitoringConfiguration": { "logGroupName": "/emr-eks-jobs", "logStreamNamePrefix": "airflow", } }, } # [END howto_operator_emr_eks_config] create_bucket = S3CreateBucketOperator( task_id="create_bucket", bucket_name=s3_bucket_name, ) upload_s3_file = S3CreateObjectOperator( task_id="upload_s3_file", s3_bucket=s3_bucket_name, s3_key=S3_FILE_NAME, data=S3_FILE_CONTENT, ) create_cluster_and_nodegroup = EksCreateClusterOperator( task_id="create_cluster_and_nodegroup", cluster_name=eks_cluster_name, nodegroup_name=nodegroup_name, cluster_role_arn=role_arn, # Opting to use the same ARN for the cluster and the nodegroup here, # but a different ARN could be configured and passed if desired. nodegroup_role_arn=role_arn, resources_vpc_config={"subnetIds": subnets}, # The launch template enforces IMDSv2 and is required for internal # compliance when running these system tests on AWS infrastructure. create_nodegroup_kwargs={"launchTemplate": {"name": launch_template_name}}, ) await_create_nodegroup = EksNodegroupStateSensor( task_id="await_create_nodegroup", cluster_name=eks_cluster_name, nodegroup_name=nodegroup_name, target_state=NodegroupStates.ACTIVE, poke_interval=10, ) # [START howto_operator_emr_eks_create_cluster] create_emr_eks_cluster = EmrEksCreateClusterOperator( task_id="create_emr_eks_cluster", virtual_cluster_name=virtual_cluster_name, eks_cluster_name=eks_cluster_name, eks_namespace=eks_namespace, ) # [END howto_operator_emr_eks_create_cluster] # [START howto_operator_emr_container] job_starter = EmrContainerOperator( task_id="start_job", virtual_cluster_id=str(create_emr_eks_cluster.output), execution_role_arn=job_role_arn, release_label="emr-7.0.0-latest", job_driver=job_driver_arg, configuration_overrides=configuration_overrides_arg, name="pi.py", ) # [END howto_operator_emr_container] job_starter.wait_for_completion = False job_starter.job_retry_max_attempts = 5 # [START howto_sensor_emr_container] job_waiter = EmrContainerSensor( task_id="job_waiter", virtual_cluster_id=str(create_emr_eks_cluster.output), job_id=str(job_starter.output), ) # [END howto_sensor_emr_container] # Describe pods only on failure to help diagnose EMR on EKS job issues. describe_pod = get_describe_pod_operator(cluster_name=eks_cluster_name, namespace=eks_namespace) describe_pod.trigger_rule = TriggerRule.ONE_FAILED delete_eks_cluster = EksDeleteClusterOperator( task_id="delete_eks_cluster", cluster_name=eks_cluster_name, force_delete_compute=True, trigger_rule=TriggerRule.ALL_DONE, ) await_delete_eks_cluster = EksClusterStateSensor( task_id="await_delete_eks_cluster", cluster_name=eks_cluster_name, target_state=ClusterStates.NONEXISTENT, trigger_rule=TriggerRule.ALL_DONE, poke_interval=10, ) delete_bucket = S3DeleteBucketOperator( task_id="delete_bucket", bucket_name=s3_bucket_name, force_delete=True, trigger_rule=TriggerRule.ALL_DONE, ) chain( # TEST SETUP test_context, create_bucket, upload_s3_file, create_launch_template(launch_template_name), create_cluster_and_nodegroup, await_create_nodegroup, run_eksctl_commands(eks_cluster_name, eks_namespace), update_trust_policy_execution_role(eks_cluster_name, eks_namespace, job_role_name), wait_for_trust_policy_propagation(eks_cluster_name, job_role_name), # TEST BODY create_emr_eks_cluster, job_starter, job_waiter, describe_pod, # TEST TEARDOWN delete_iam_oidc_identity_provider(eks_cluster_name), delete_virtual_cluster(str(create_emr_eks_cluster.output)), delete_eks_cluster, await_delete_eks_cluster, delete_launch_template(launch_template_name), delete_bucket, ) from tests_common.test_utils.watcher import watcher # This test needs watcher in order to properly mark success/failure # when "tearDown" task with trigger rule is part of the DAG list(dag.tasks) >> watcher() from tests_common.test_utils.system_tests import get_test_run # noqa: E402 # Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst)
[docs] test_run = get_test_run(dag)

Was this entry helpful?