Source code for tests.system.google.cloud.dataprep.example_dataprep

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG that shows how to use Google Dataprep.

This DAG relies on the following OS environment variables

* SYSTEM_TESTS_DATAPREP_TOKEN - Dataprep API access token.
  To generate one, follow the instructions at
  https://docs.trifacta.com/display/DP/Manage+API+Access+Tokens#:~:text=Enable%20individual%20access-,Generate%20New%20Token,-Via%20UI.
"""

from __future__ import annotations

import logging
import os
from datetime import datetime

from providers.google.tests.system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID

from airflow import models
from airflow.decorators import task
from airflow.models import Connection
from airflow.models.baseoperator import chain
from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook
from airflow.providers.google.cloud.operators.dataprep import (
    DataprepCopyFlowOperator,
    DataprepDeleteFlowOperator,
    DataprepGetJobGroupOperator,
    DataprepGetJobsForJobGroupOperator,
    DataprepRunFlowOperator,
    DataprepRunJobGroupOperator,
)
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
from airflow.providers.google.cloud.sensors.dataprep import DataprepJobGroupIsFinishedSensor
from airflow.settings import Session
from airflow.utils.trigger_rule import TriggerRule

ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
DAG_ID = "dataprep"
CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}".replace("-", "_")
DATAPREP_TOKEN = os.environ.get("SYSTEM_TESTS_DATAPREP_TOKEN", "")
GCP_PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID
GCS_BUCKET_NAME = f"dataprep-bucket-{DAG_ID}-{ENV_ID}"
GCS_BUCKET_PATH = f"gs://{GCS_BUCKET_NAME}/task_results/"
DATASET_URI = "gs://airflow-system-tests-resources/dataprep/dataset-00000.parquet"
DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}".replace("-", "_")
DATASET_WRANGLED_NAME = f"wrangled_{DATASET_NAME}"
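# The IDs below are Jinja templates: at run time they pull the IDs of the objects
# created by the corresponding upstream tasks from XCom.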
DATASET_WRANGLED_ID = "{{ task_instance.xcom_pull('create_wrangled_dataset')['id'] }}"
FLOW_ID = "{{ task_instance.xcom_pull('create_flow')['id'] }}"
FLOW_COPY_ID = "{{ task_instance.xcom_pull('copy_flow')['id'] }}"
RECIPE_NAME = DATASET_WRANGLED_NAME
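# Write settings are passed as "overrides" to the run_job_group task below and
# direct the ad-hoc CSV output of the recipe into the test bucket.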
WRITE_SETTINGS = {
    "writesettings": [
        {
            "path": GCS_BUCKET_PATH + f"adhoc_{RECIPE_NAME}.csv",
            "action": "create",
            "format": "csv",
        },
    ],
}

log = logging.getLogger(__name__)

with models.DAG(
    DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),  # Override to match your needs
    catchup=False,
    tags=["example", "dataprep"],
    render_template_as_native_obj=True,
) as dag:
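    # TEST SETUP: create the GCS bucket that will hold the job results.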
    create_bucket_task = GCSCreateBucketOperator(
        task_id="create_bucket",
        bucket_name=GCS_BUCKET_NAME,
        project_id=GCP_PROJECT_ID,
    )
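
    # Create a "dataprep" Airflow connection carrying the API token from
    # SYSTEM_TESTS_DATAPREP_TOKEN; any stale connection with the same ID is removed first.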
    @task
    def create_connection(connection_id: str) -> None:
        connection = Connection(
            conn_id=connection_id,
            description="Example Dataprep connection",
            conn_type="dataprep",
            extra={"token": DATAPREP_TOKEN},
        )
        session = Session()
        log.info("Removing connection %s if it exists", connection_id)
        query = session.query(Connection).filter(Connection.conn_id == connection_id)
        query.delete()
        session.add(connection)
        session.commit()
        log.info("Connection created: '%s'", connection_id)

    create_connection_task = create_connection(connection_id=CONNECTION_ID)

    @task
    def create_imported_dataset():
        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
        response = hook.create_imported_dataset(
            body_request={
                "uri": DATASET_URI,
                "name": DATASET_NAME,
            }
        )
        return response

    create_imported_dataset_task = create_imported_dataset()

    @task
    def create_flow():
        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
        response = hook.create_flow(
            body_request={
                "name": f"test_flow_{DAG_ID}_{ENV_ID}",
                "description": "Test flow",
            }
        )
        return response

    create_flow_task = create_flow()

    @task
    def create_wrangled_dataset(flow, imported_dataset):
        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
        response = hook.create_wrangled_dataset(
            body_request={
                "importedDataset": {"id": imported_dataset["id"]},
                "flow": {"id": flow["id"]},
                "name": DATASET_WRANGLED_NAME,
            }
        )
        return response

    create_wrangled_dataset_task = create_wrangled_dataset(create_flow_task, create_imported_dataset_task)

    @task
    def create_output(wrangled_dataset):
        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
        response = hook.create_output_object(
            body_request={
                "execution": "dataflow",
                "profiler": False,
                "flowNodeId": wrangled_dataset["id"],
            }
        )
        return response

    create_output_task = create_output(create_wrangled_dataset_task)

    @task
    def create_write_settings(output):
        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
        response = hook.create_write_settings(
            body_request={
                "path": GCS_BUCKET_PATH + f"adhoc_{RECIPE_NAME}.csv",
                "action": "create",
                "format": "csv",
                "outputObjectId": output["id"],
            }
        )
        return response

    create_write_settings_task = create_write_settings(create_output_task)

    # [START how_to_dataprep_copy_flow_operator]
    copy_task = DataprepCopyFlowOperator(
        task_id="copy_flow",
        dataprep_conn_id=CONNECTION_ID,
        project_id=GCP_PROJECT_ID,
        flow_id=FLOW_ID,
        name=f"copy_{DATASET_NAME}",
    )
    # [END how_to_dataprep_copy_flow_operator]

    # [START how_to_dataprep_run_job_group_operator]
    run_job_group_task = DataprepRunJobGroupOperator(
        task_id="run_job_group",
        dataprep_conn_id=CONNECTION_ID,
        project_id=GCP_PROJECT_ID,
        body_request={
            "wrangledDataset": {"id": DATASET_WRANGLED_ID},
            "overrides": WRITE_SETTINGS,
        },
    )
    # [END how_to_dataprep_run_job_group_operator]

    # [START how_to_dataprep_dataprep_run_flow_operator]
    run_flow_task = DataprepRunFlowOperator(
        task_id="run_flow",
        dataprep_conn_id=CONNECTION_ID,
        project_id=GCP_PROJECT_ID,
        flow_id=FLOW_COPY_ID,
        body_request={},
    )
    # [END how_to_dataprep_dataprep_run_flow_operator]

    # [START how_to_dataprep_get_job_group_operator]
    get_job_group_task = DataprepGetJobGroupOperator(
        task_id="get_job_group",
        dataprep_conn_id=CONNECTION_ID,
        project_id=GCP_PROJECT_ID,
        job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}",
        embed="",
        include_deleted=False,
    )
    # [END how_to_dataprep_get_job_group_operator]

    # [START how_to_dataprep_get_jobs_for_job_group_operator]
    get_jobs_for_job_group_task = DataprepGetJobsForJobGroupOperator(
        task_id="get_jobs_for_job_group",
        dataprep_conn_id=CONNECTION_ID,
        job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}",
    )
    # [END how_to_dataprep_get_jobs_for_job_group_operator]

    # [START how_to_dataprep_job_group_finished_sensor]
    check_flow_status_sensor = DataprepJobGroupIsFinishedSensor(
        task_id="check_flow_status",
        dataprep_conn_id=CONNECTION_ID,
        job_group_id="{{ task_instance.xcom_pull('run_flow')['data'][0]['id'] }}",
    )
    # [END how_to_dataprep_job_group_finished_sensor]

    # [START how_to_dataprep_job_group_finished_sensor]
    check_job_group_status_sensor = DataprepJobGroupIsFinishedSensor(
        task_id="check_job_group_status",
        dataprep_conn_id=CONNECTION_ID,
        job_group_id="{{ task_instance.xcom_pull('run_job_group')['id'] }}",
    )
    # [END how_to_dataprep_job_group_finished_sensor]

    # [START how_to_dataprep_delete_flow_operator]
    delete_flow_task = DataprepDeleteFlowOperator(
        task_id="delete_flow",
        dataprep_conn_id=CONNECTION_ID,
        flow_id="{{ task_instance.xcom_pull('copy_flow')['id'] }}",
    )
    # [END how_to_dataprep_delete_flow_operator]
    delete_flow_task.trigger_rule = TriggerRule.ALL_DONE

    delete_flow_task_original = DataprepDeleteFlowOperator(
        task_id="delete_flow_original",
        dataprep_conn_id=CONNECTION_ID,
        flow_id="{{ task_instance.xcom_pull('create_flow')['id'] }}",
        trigger_rule=TriggerRule.ALL_DONE,
    )

    @task(trigger_rule=TriggerRule.ALL_DONE)
    def delete_dataset(dataset):
        hook = GoogleDataprepHook(dataprep_conn_id=CONNECTION_ID)
        hook.delete_imported_dataset(dataset_id=dataset["id"])

    delete_dataset_task = delete_dataset(create_imported_dataset_task)

    delete_bucket_task = GCSDeleteBucketOperator(
        task_id="delete_bucket",
        bucket_name=GCS_BUCKET_NAME,
        trigger_rule=TriggerRule.ALL_DONE,
    )

    @task(task_id="delete_connection")
    def delete_connection(connection_id: str) -> None:
        session = Session()
        log.info("Removing connection %s", connection_id)
        query = session.query(Connection).filter(Connection.conn_id == connection_id)
        query.delete()
        session.commit()

    delete_connection_task = delete_connection(connection_id=CONNECTION_ID)

    chain(
        # TEST SETUP
        create_bucket_task,
        create_connection_task,
        [create_imported_dataset_task, create_flow_task],
        create_wrangled_dataset_task,
        create_output_task,
        create_write_settings_task,
        # TEST BODY
        copy_task,
        [run_job_group_task, run_flow_task],
        [get_job_group_task, get_jobs_for_job_group_task],
        [check_flow_status_sensor, check_job_group_status_sensor],
        # TEST TEARDOWN
        delete_dataset_task,
        [delete_flow_task, delete_flow_task_original],
        [delete_bucket_task, delete_connection_task],
    )

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()


from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
