Source code for tests.system.amazon.aws.example_sagemaker_unified_studio

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime

import pytest

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import (
    SageMakerNotebookOperator,
)
from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder

from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS

"""
Prerequisites: The AWS account that runs this test must already have the following, set up manually:
1. An IAM Identity Center (IDC) organization set up in the testing region, with at least one user initialized
2. A SageMaker Unified Studio domain (with the default VPC and roles)
3. A project within that SageMaker Unified Studio domain
4. A notebook (test_notebook.ipynb) placed in the project's S3 path

This test emulates a DAG run in the shared MWAA environment inside a SageMaker Unified Studio project.
The setup tasks read the project's configuration from the system-test context and set environment variables on the test runner to emulate an MWAA instance.
Then the SageMakerNotebookOperator runs the test notebook. This should start a SageMaker training job, run the notebook, and exit successfully.
"""

# Skip this whole module on Airflow versions older than 2.10.
pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+")

DAG_ID = "example_sagemaker_unified_studio"

# Externally fetched variables:
# names of the values the system-test context must supply for this test to run.
DOMAIN_ID_KEY = "DOMAIN_ID"
PROJECT_ID_KEY = "PROJECT_ID"
ENVIRONMENT_ID_KEY = "ENVIRONMENT_ID"
S3_PATH_KEY = "S3_PATH"
REGION_NAME_KEY = "REGION_NAME"

# Task factory that gathers the variables above into the test context; the DAG indexes its
# output by these keys (and by ENV_ID_KEY) to configure the emulated MWAA environment.
sys_test_context_task = (
    SystemTestContextBuilder()
    .add_variable(DOMAIN_ID_KEY)
    .add_variable(PROJECT_ID_KEY)
    .add_variable(ENVIRONMENT_ID_KEY)
    .add_variable(S3_PATH_KEY)
    .add_variable(REGION_NAME_KEY)
    .build()
)
def get_mwaa_environment_params(
    domain_id: str,
    project_id: str,
    environment_id: str,
    s3_path: str,
    region_name: str,
) -> dict[str, str]:
    """
    Build the environment variables that emulate an MWAA environment in a SageMaker Unified Studio project.

    :param domain_id: DataZone domain ID of the SageMaker Unified Studio domain.
    :param project_id: DataZone project ID.
    :param environment_id: DataZone environment ID.
    :param s3_path: The project's S3 path.
    :param region_name: AWS region of the domain; also used to build the DataZone endpoint URL.
    :return: Mapping of ``AIRFLOW__WORKFLOWS__*`` variable names to their values. Insertion order
        is fixed: domain, project, environment, scope name, stage, endpoint, S3 path, region.
    """
    airflow_prefix = "AIRFLOW__WORKFLOWS__"
    return {
        f"{airflow_prefix}DATAZONE_DOMAIN_ID": domain_id,
        f"{airflow_prefix}DATAZONE_PROJECT_ID": project_id,
        f"{airflow_prefix}DATAZONE_ENVIRONMENT_ID": environment_id,
        # Scope and stage are fixed for this emulated environment.
        f"{airflow_prefix}DATAZONE_SCOPE_NAME": "dev",
        f"{airflow_prefix}DATAZONE_STAGE": "prod",
        f"{airflow_prefix}DATAZONE_ENDPOINT": f"https://datazone.{region_name}.api.aws",
        f"{airflow_prefix}PROJECT_S3_PATH": s3_path,
        f"{airflow_prefix}DATAZONE_DOMAIN_REGION": region_name,
    }
@task
def mock_mwaa_environment(parameters: dict):
    """
    Export every entry of ``parameters`` into this process's environment.

    Doing so makes the running container look like an MWAA environment provisioned within
    SageMaker Unified Studio. Under the ECSExecutor the same variables are already supplied to
    the container, so there this step is a no-op.
    """
    import os

    # os.environ.update assigns each key/value pair in the mapping's iteration order.
    os.environ.update(parameters)
with DAG(
    DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    tags=["example"],
    catchup=False,
) as dag:
    # Fetch the externally supplied variables (see the *_KEY constants) for this run.
    test_context = sys_test_context_task()

    test_env_id = test_context[ENV_ID_KEY]
    domain_id = test_context[DOMAIN_ID_KEY]
    project_id = test_context[PROJECT_ID_KEY]
    environment_id = test_context[ENVIRONMENT_ID_KEY]
    s3_path = test_context[S3_PATH_KEY]
    region_name = test_context[REGION_NAME_KEY]

    # The same parameter mapping is used twice: exported by the setup task below, and passed to
    # the notebook operator through executor_config container overrides.
    mock_mwaa_environment_params = get_mwaa_environment_params(
        domain_id,
        project_id,
        environment_id,
        s3_path,
        region_name,
    )
    setup_mwaa_environment = mock_mwaa_environment(mock_mwaa_environment_params)

    # [START howto_operator_sagemaker_unified_studio_notebook]
    notebook_path = "test_notebook.ipynb"  # This should be the path to your .ipynb, .sqlnb, or .vetl file in your project.

    run_notebook = SageMakerNotebookOperator(
        task_id="run-notebook",
        input_config={"input_path": notebook_path, "input_params": {}},
        output_config={"output_formats": ["NOTEBOOK"]},  # optional
        compute={
            "instance_type": "ml.m5.large",
            "volume_size_in_gb": 30,
        },  # optional
        termination_condition={"max_runtime_in_seconds": 600},  # optional
        tags={},  # optional
        wait_for_completion=True,  # optional
        waiter_delay=5,  # optional
        deferrable=False,  # optional
        executor_config={  # optional
            "overrides": {"containerOverrides": {"environment": mock_mwaa_environment_params}}
        },
    )
    # [END howto_operator_sagemaker_unified_studio_notebook]

    chain(
        # TEST SETUP
        test_context,
        setup_mwaa_environment,
        # TEST BODY
        run_notebook,
    )

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)

Was this entry helpful?