# Source code for tests.system.amazon.aws.example_sagemaker_condition

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
System test for SageMakerConditionOperator.

This operator evaluates conditions against XCom values passed from upstream tasks.

The DAG simulates an ML accuracy-gate workflow:

1. ``produce_metrics`` pushes a dict of metrics to XCom.
2. ``check_accuracy`` uses SageMakerConditionOperator to branch:
   - accuracy >= 0.9 AND loss < 0.1 -> ``deploy_model``
   - otherwise -> ``retrain_model``
3. Only the correct branch task runs; the other is skipped.
"""

from __future__ import annotations

from datetime import datetime

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerConditionOperator

from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

if AIRFLOW_V_3_0_PLUS:
    from airflow.sdk import DAG, chain, task
else:
    from airflow.decorators import task  # type: ignore[attr-defined,no-redef]
    from airflow.models.baseoperator import chain  # type: ignore[attr-defined,no-redef]
    from airflow.models.dag import DAG  # type: ignore[attr-defined,no-redef,assignment]

from system.amazon.aws.utils import SystemTestContextBuilder

# Identifier of the example DAG defined below.
DAG_ID = "example_sagemaker_condition"

# Task factory built by the shared system-test helper; it is invoked inside the DAG.
sys_test_context_task = SystemTestContextBuilder().build()

with DAG(
    DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    catchup=False,
) as dag:
    test_context = sys_test_context_task()

    # TEST SETUP: push simulated ML metrics to XCom
    @task
    def produce_metrics():
        """Simulate an ML training job that returns accuracy and loss metrics."""
        return {"accuracy": 0.95, "loss": 0.04}

    metrics = produce_metrics()

    # TEST BODY

    # Scenario 1: accuracy 0.95 >= 0.9 and loss 0.04 < 0.1 -> if branch
    # [START howto_operator_sagemaker_condition]
    check_accuracy = SageMakerConditionOperator(
        task_id="check_accuracy",
        conditions=[
            {
                "type": "GreaterThanOrEqualTo",
                "left_value": "{{ ti.xcom_pull(task_ids='produce_metrics')['accuracy'] }}",
                "right_value": 0.9,
            },
            {
                "type": "LessThan",
                "left_value": "{{ ti.xcom_pull(task_ids='produce_metrics')['loss'] }}",
                "right_value": 0.1,
            },
        ],
        if_task_ids=["deploy_model"],
        else_task_ids=["retrain_model"],
    )
    # [END howto_operator_sagemaker_condition]

    @task
    def deploy_model():
        """Placeholder: model meets quality bar, proceed to deployment."""
        return "deployed"

    @task
    def retrain_model():
        """Placeholder: model does not meet quality bar, retrain."""
        return "retrained"

    # Scenario 2: condition evaluates to False -> else branch
    @task
    def produce_bad_metrics():
        """Simulate a training job with poor accuracy."""
        return {"accuracy": 0.5, "loss": 0.8}

    bad_metrics = produce_bad_metrics()

    # Single condition given through the flat keyword arguments instead of a ``conditions`` list.
    # [START howto_operator_sagemaker_condition_flat]
    check_bad_accuracy = SageMakerConditionOperator(
        task_id="check_bad_accuracy",
        condition_type="GreaterThanOrEqualTo",
        left_value="{{ ti.xcom_pull(task_ids='produce_bad_metrics')['accuracy'] }}",
        right_value=0.9,
        if_task_ids=["should_not_run"],
        else_task_ids=["should_run"],
    )
    # [END howto_operator_sagemaker_condition_flat]

    @task
    def should_not_run():
        """This task should be skipped because accuracy < 0.9."""
        return "error: should not have run"

    @task
    def should_run():
        """This task should execute because accuracy < 0.9 -> else branch."""
        return "correctly routed to else branch"

    # Scenario 3: Or condition + Not condition
    # [START howto_operator_sagemaker_condition_not_or]
    check_logical = SageMakerConditionOperator(
        task_id="check_logical",
        conditions=[
            {
                "type": "Or",
                "conditions": [
                    {"type": "Equals", "left_value": 1, "right_value": 2},
                    {"type": "Equals", "left_value": 3, "right_value": 3},
                ],
            },
            {
                "type": "Not",
                "condition": {"type": "Equals", "left_value": "a", "right_value": "b"},
            },
        ],
        if_task_ids=["logical_pass"],
        else_task_ids=["logical_fail"],
    )
    # [END howto_operator_sagemaker_condition_not_or]

    @task
    def logical_pass():
        """Or(1==2, 3==3) -> True AND Not(a==b) -> True -> if branch."""
        return "logical conditions passed"

    @task
    def logical_fail():
        """This task should be skipped because both logical conditions evaluate to True."""
        return "error: logical conditions should have passed"

    # All three scenarios start only after the test-context task.
    test_context >> [metrics, bad_metrics, check_logical]

    # Each condition operator fans out to its if/else branch tasks.
    chain(metrics, check_accuracy, [deploy_model(), retrain_model()])
    chain(bad_metrics, check_bad_accuracy, [should_not_run(), should_run()])
    chain(check_logical, [logical_pass(), logical_fail()])

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: contributing-docs/testing/system_tests.rst)
test_run = get_test_run(dag)

# Was this entry helpful?