@task.snowpark

Use the @task.snowpark decorator to run Snowpark Python code in a Snowflake database.

Warning

  • Snowpark does not support Python 3.12 yet.

  • Currently, this decorator does not support the Snowpark pandas API because it requires a pandas version that conflicts with the one used by Airflow. Consider using the Snowpark pandas API with other Snowpark decorators or operators.

Prerequisite Tasks

To use this decorator, you must install the apache-airflow-providers-snowflake provider package together with its Snowpark dependencies, and set up a Snowflake connection in Airflow.

Using the Operator

Use the snowflake_conn_id argument to specify which Snowflake connection to use. If not specified, snowflake_default will be used.
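
For instance, the following is a minimal sketch pointing a task at a non-default connection; my_snowflake_conn and some_table are placeholders:

    from airflow.decorators import task
    from snowflake.snowpark import Session


    @task.snowpark(snowflake_conn_id="my_snowflake_conn")  # placeholder connection ID
    def count_rows(session: Session) -> int:
        # Runs against the account, warehouse and database configured in the chosen connection
        return session.table("some_table").count()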

An example usage of the @task.snowpark decorator is as follows:

tests/system/snowflake/example_snowpark_decorator.py

    @task.snowpark
    def setup_data(session: Session):
        # The Snowpark session object is injected as an argument
        data = [
            (1, 0, 5, "Product 1", "prod-1", 1, 10),
            (2, 1, 5, "Product 1A", "prod-1-A", 1, 20),
            (3, 1, 5, "Product 1B", "prod-1-B", 1, 30),
            (4, 0, 10, "Product 2", "prod-2", 2, 40),
            (5, 4, 10, "Product 2A", "prod-2-A", 2, 50),
            (6, 4, 10, "Product 2B", "prod-2-B", 2, 60),
            (7, 0, 20, "Product 3", "prod-3", 3, 70),
            (8, 7, 20, "Product 3A", "prod-3-A", 3, 80),
            (9, 7, 20, "Product 3B", "prod-3-B", 3, 90),
            (10, 0, 50, "Product 4", "prod-4", 4, 100),
            (11, 10, 50, "Product 4A", "prod-4-A", 4, 100),
            (12, 10, 50, "Product 4B", "prod-4-B", 4, 100),
        ]
        columns = ["id", "parent_id", "category_id", "name", "serial_number", "key", "3rd"]
        df = session.create_dataframe(data, schema=columns)
        table_name = "sample_product_data"
        df.write.save_as_table(table_name, mode="overwrite")
        return table_name

    table_name = setup_data()  # type: ignore[call-arg, misc]

    @task.snowpark
    def check_num_rows(table_name: str):
        # Alternatively, retrieve the Snowpark session object using `get_active_session`
        from snowflake.snowpark.context import get_active_session

        session = get_active_session()
        df = session.table(table_name)
        assert df.count() == 12

    check_num_rows(table_name)

As the example demonstrates, there are two ways to use the Snowpark session object in your Python function:

  • Pass the Snowpark session object to the function as a keyword argument named session. The Snowpark session will be automatically injected into the function, allowing you to use it as you normally would (see the sketch after this list).

  • Use the get_active_session function from Snowpark to retrieve the Snowpark session object inside the function.
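
The following is a minimal sketch, assuming the injected session parameter can be combined with ordinary task arguments (the example above does not show this combination); the function name and the category_id filter are illustrative only:

    from airflow.decorators import task
    from snowflake.snowpark import Session
    from snowflake.snowpark.functions import col


    @task.snowpark
    def count_category(table_name: str, session: Session) -> int:
        # `table_name` arrives from an upstream task like any other argument;
        # `session` is injected automatically because of its parameter name.
        return session.table(table_name).filter(col("category_id") == 1).count()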

Note

Parameters passed to the decorator take priority over the values already set in the Airflow connection metadata (such as schema, role, database, and so forth).
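
For example, the sketch below overrides connection-level defaults for a single task; the values are placeholders, and warehouse is assumed to be accepted alongside the documented schema, role, and database parameters:

    from airflow.decorators import task
    from snowflake.snowpark import Session


    @task.snowpark(
        snowflake_conn_id="snowflake_default",
        database="MY_DB",      # placeholder; takes priority over the connection's database
        schema="MY_SCHEMA",    # placeholder; takes priority over the connection's schema
        warehouse="MY_WH",     # placeholder; takes priority over the connection's warehouse
        role="MY_ROLE",        # placeholder; takes priority over the connection's role
    )
    def row_count(session: Session) -> int:
        # The injected session is created with the overridden values above
        return session.table("sample_product_data").count()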
