Source code for tests.system.amazon.aws.example_sagemaker_unified_studio_notebook

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from datetime import datetime

from airflow.providers.amazon.aws.operators.sagemaker_unified_studio_notebook import (
    SageMakerUnifiedStudioNotebookOperator,
)
from airflow.providers.amazon.aws.sensors.sagemaker_unified_studio_notebook import (
    SageMakerUnifiedStudioNotebookSensor,
)

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
    from airflow.sdk import DAG, chain
else:
    from airflow.models.baseoperator import chain  # type: ignore[attr-defined,no-redef]
    from airflow.models.dag import DAG  # type: ignore[attr-defined,no-redef,assignment]

from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder

"""
Prerequisites: The account which runs this test must have the following:
1. A SageMaker Unified Studio Domain (with default VPC and roles)
2. A project within the SageMaker Unified Studio Domain
3. Two notebook assets registered in the project:
   - NOTEBOOK_ID: A notebook that produces output variables (e.g., name, age)
   - NOTEBOOK_B_ID: A notebook that accepts parameters and uses them

This test calls the DataZone StartNotebookRun / GetNotebookRun APIs directly
via boto3 using standard IAM credentials. No MWAA environment emulation is performed.
"""

[docs] DAG_ID = "example_sagemaker_unified_studio_notebook"
# Externally fetched variables:
[docs] DOMAIN_ID_KEY = "DOMAIN_ID"
[docs] PROJECT_ID_KEY = "PROJECT_ID"
[docs] NOTEBOOK_ID_KEY = "NOTEBOOK_ID"
[docs] NOTEBOOK_B_ID_KEY = "NOTEBOOK_B_ID"
[docs] sys_test_context_task = ( SystemTestContextBuilder() .add_variable(DOMAIN_ID_KEY) .add_variable(PROJECT_ID_KEY) .add_variable(NOTEBOOK_ID_KEY) .add_variable(NOTEBOOK_B_ID_KEY) .build() )
with DAG( DAG_ID, schedule="@once", start_date=datetime(2021, 1, 1), catchup=False, ) as dag:
[docs] test_context = sys_test_context_task()
test_env_id = test_context[ENV_ID_KEY] domain_id = test_context[DOMAIN_ID_KEY] project_id = test_context[PROJECT_ID_KEY] notebook_id = test_context[NOTEBOOK_ID_KEY] notebook_b_id = test_context[NOTEBOOK_B_ID_KEY] # [START howto_operator_sagemaker_unified_studio_notebook] import time client_token = f"idempotency-token-{int(time.time())}" run_notebook = SageMakerUnifiedStudioNotebookOperator( task_id="notebook-task", notebook_identifier=notebook_id, # This should be the notebook asset identifier from within the SageMaker Unified Studio domain domain_identifier=domain_id, owning_project_identifier=project_id, client_token=client_token, # optional notebook_parameters={ "param1": "value1", "param2": "value2", }, # optional compute_configuration={"instanceType": "sc.m5.large"}, # optional timeout_configuration={"runTimeoutInMinutes": 1440}, # optional wait_for_completion=True, # optional waiter_delay=30, # optional deferrable=False, # optional ) # [END howto_operator_sagemaker_unified_studio_notebook] # [START howto_operator_sagemaker_unified_studio_notebook_deferrable] run_notebook_deferrable = SageMakerUnifiedStudioNotebookOperator( task_id="notebook-deferrable-task", notebook_identifier=notebook_id, domain_identifier=domain_id, owning_project_identifier=project_id, deferrable=True, # optional waiter_delay=10, # optional ) # [END howto_operator_sagemaker_unified_studio_notebook_deferrable] # [START howto_sensor_sagemaker_unified_studio_notebook] run_sensor = SageMakerUnifiedStudioNotebookSensor( task_id="notebook-sensor-task", domain_identifier=domain_id, owning_project_identifier=project_id, notebook_identifier=notebook_id, notebook_run_id=run_notebook.output["notebook_run_id"], ) # [END howto_sensor_sagemaker_unified_studio_notebook] # [START howto_operator_sagemaker_unified_studio_notebook_pass_outputs] # Notebook A produces outputs (e.g., name, age) that are pushed to xcom. # Notebook B consumes those outputs via Jinja templating in notebook_parameters. run_notebook_a = SageMakerUnifiedStudioNotebookOperator( task_id="notebook-a-task", notebook_identifier=notebook_id, domain_identifier=domain_id, owning_project_identifier=project_id, wait_for_completion=True, ) run_notebook_b = SageMakerUnifiedStudioNotebookOperator( task_id="notebook-b-task", notebook_identifier=notebook_b_id, domain_identifier=domain_id, owning_project_identifier=project_id, notebook_parameters={ "employee_name": "{{ task_instance.xcom_pull(task_ids='notebook-a-task', key='NOTEBOOK_OUTPUT.name') }}", "employee_age": "{{ task_instance.xcom_pull(task_ids='notebook-a-task', key='NOTEBOOK_OUTPUT.age') }}", }, wait_for_completion=True, ) # [END howto_operator_sagemaker_unified_studio_notebook_pass_outputs] chain( test_context, run_notebook, run_notebook_deferrable, run_sensor, run_notebook_a, run_notebook_b, ) from tests_common.test_utils.watcher import watcher # This test needs watcher in order to properly mark success/failure # when "tearDown" task with trigger rule is part of the DAG list(dag.tasks) >> watcher() from tests_common.test_utils.system_tests import get_test_run # noqa: E402
[docs] test_run = get_test_run(dag)

Was this entry helpful?