Source code for airflow.example_dags.example_asset_partition

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.sdk import (
    DAG,
    AllowedKeyMapper,
    Asset,
    CronPartitionTimetable,
    DayWindow,
    FanOutMapper,
    FixedKeyMapper,
    IdentityMapper,
    MonthWindow,
    PartitionAtRuntime,
    PartitionedAssetTimetable,
    ProductMapper,
    RollupMapper,
    SegmentWindow,
    StartOfDayMapper,
    StartOfHourMapper,
    StartOfMonthMapper,
    StartOfWeekMapper,
    StartOfYearMapper,
    WeekWindow,
    asset,
    task,
)

# Hourly source asset written by the ``ingest_team_a_player_stats`` Dag below.
team_a_player_stats = Asset(
    name="team_a_player_stats",
    uri="file://incoming/player-stats/team_a.csv",
)
# Curated asset that ``clean_and_combine_player_stats`` writes once Team A, B and C align.
combined_player_stats = Asset(
    name="combined_player_stats",
    uri="file://curated/player-stats/combined.csv",
)

with DAG(
    dag_id="ingest_team_a_player_stats",
    tags=["example", "player-stats", "ingestion"],
    # Cron "0 * * * *": minute 0 of every hour, evaluated in UTC.
    schedule=CronPartitionTimetable("0 * * * *", timezone="UTC"),
):
    """Produce hourly partitioned stats for Team A."""

    @task(outlets=[team_a_player_stats])
    def ingest_team_a_stats():
        """Materialize Team A player statistics for the current hourly partition."""

    ingest_team_a_stats()


@asset(
    schedule=CronPartitionTimetable("15 * * * *", timezone="UTC"),  # quarter past each hour, UTC
    uri="file://incoming/player-stats/team_b.csv",
    tags=["player-stats", "ingestion"],
)
def team_b_player_stats():
    """Produce hourly partitioned stats for Team B."""
@asset(
    schedule=CronPartitionTimetable("30 * * * *", timezone="UTC"),  # half past each hour, UTC
    uri="file://incoming/player-stats/team_c.csv",
    tags=["player-stats", "ingestion"],
)
def team_c_player_stats():
    """Produce hourly partitioned stats for Team C."""
with DAG(
    dag_id="clean_and_combine_player_stats",
    catchup=False,
    tags=["example", "player-stats", "cleanup"],
    schedule=PartitionedAssetTimetable(
        # The three team assets are scheduled at :00, :15 and :30; mapping every
        # upstream key to its hour lets their partitions line up on one key.
        assets=team_a_player_stats & team_b_player_stats & team_c_player_stats,
        default_partition_mapper=StartOfHourMapper(),
    ),
):
    """
    Combine hourly partitions from Team A, B and C into a single curated dataset.

    This Dag demonstrates multi-asset partition alignment using StartOfHourMapper.
    """

    @task(outlets=[combined_player_stats])
    def combine_player_stats(dag_run=None):
        """Merge the aligned hourly partitions into a combined dataset."""
        if TYPE_CHECKING:
            # Type-checker-only narrowing; never executed at runtime.
            assert dag_run is not None
        print(dag_run.partition_key)

    combine_player_stats()


@asset(
    uri="file://analytics/player-stats/computed-player-odds.csv",
    tags=["player-stats", "odds"],
    # No partition mapper is given, so this falls back to IdentityMapper.
    # To use another temporal mapper here (e.g. StartOfHourMapper), also change
    # its input_format: the upstream partition_key is already in "%Y-%m-%dT%H"
    # format rather than a full timestamp.
    schedule=PartitionedAssetTimetable(assets=combined_player_stats),
)
def compute_player_odds():
    """
    Compute player odds from the combined hourly statistics.

    This asset is partition-aware and triggered by the combined stats asset.
    """
with DAG(
    # dag_id kept as-is: it is the Dag's identity, even though the wording is awkward.
    dag_id="player_odds_quality_check_wont_ever_to_trigger",
    schedule=PartitionedAssetTimetable(
        assets=(combined_player_stats & team_a_player_stats & Asset.ref(name="team_b_player_stats")),
        partition_mapper_config={
            # Year-level keys ("%Y") can never equal the hour-level keys below; incompatible on purpose.
            combined_player_stats: StartOfYearMapper(),
            team_a_player_stats: StartOfHourMapper(),
            Asset.ref(name="team_b_player_stats"): StartOfHourMapper(),
        },
    ),
    catchup=False,
    tags=["example", "player-stats", "odds"],
):
    """
    Demonstrate a partition mapper mismatch scenario.

    The configured partition mappers transform partition keys into formats that
    never match ("%Y" vs. "%Y-%m-%dT%H"), so this Dag will never trigger.
    """

    @task
    def check_partition_alignment():
        # Intentionally empty: the Dag never gets a matching partition to run.
        pass

    check_partition_alignment()
regional_sales = Asset(name="regional_sales", uri="file://incoming/sales/regional.csv")

with DAG(
    dag_id="ingest_regional_sales",
    tags=["example", "sales", "ingestion"],
    schedule=CronPartitionTimetable("0 * * * *", timezone="UTC"),
):
    """Produce hourly regional sales data with composite partition keys."""
    # NOTE(review): a plain CronPartitionTimetable presumably yields time-only
    # keys; confirm how the "region|timestamp" keys expected by
    # aggregate_regional_sales are actually produced.

    @task(outlets=[regional_sales])
    def ingest_sales():
        """Ingest regional sales data partitioned by region and time."""

    ingest_sales()


with DAG(
    dag_id="aggregate_regional_sales",
    catchup=False,
    tags=["example", "sales", "aggregation"],
    schedule=PartitionedAssetTimetable(
        assets=regional_sales,
        # One mapper per key segment, in order: region, then timestamp.
        default_partition_mapper=ProductMapper(IdentityMapper(), StartOfDayMapper()),
    ),
):
    """
    Aggregate regional sales using ProductMapper.

    The ProductMapper splits the composite key "region|timestamp" and applies
    IdentityMapper to the region segment and StartOfDayMapper to the timestamp
    segment, aligning hourly partitions to daily granularity per region.
    """

    @task
    def aggregate_sales(dag_run=None):
        """Aggregate sales data for the matched region-day partition."""
        if TYPE_CHECKING:
            # Type-checker-only narrowing; never executed at runtime.
            assert dag_run is not None
        print(dag_run.partition_key)

    aggregate_sales()
region_raw_stats = Asset(name="region_raw_stats", uri="file://incoming/player-stats/by-region.csv")

with DAG(
    dag_id="ingest_region_stats",
    tags=["example", "player-stats", "regional"],
    # No timetable: runs are triggered externally (see the note below).
    schedule=None,
):
    """
    Ingest player statistics per region.

    Externally triggered with partition_key set to a region code
    (``us``, ``eu``, ``apac``).
    """

    @task(outlets=[region_raw_stats])
    def ingest_region():
        """Materialize player statistics for a single region partition."""

    ingest_region()


@asset(
    uri="file://analytics/player-stats/regional-breakdown.csv",
    tags=["player-stats", "regional"],
    schedule=PartitionedAssetTimetable(
        assets=region_raw_stats,
        # Categorical keys: only these region codes are accepted from upstream.
        default_partition_mapper=AllowedKeyMapper(["us", "eu", "apac"]),
    ),
)
def regional_stats_breakdown():
    """
    Aggregate regional player statistics.

    This asset demonstrates AllowedKeyMapper, which validates that upstream
    partition keys belong to a fixed set of allowed values (``us``, ``eu``,
    ``apac``) rather than time-based partitions.
    """
@asset(
    schedule=PartitionAtRuntime(),
    uri="file://incoming/player-stats/live-region.csv",
    tags=["player-stats", "runtime"],
)
def live_region_player_stats(self, outlet_events):
    """
    Produce a single region partition whose key is decided at runtime.

    This asset demonstrates PartitionAtRuntime, which records the partition key
    on the emitted event with ``add_partitions`` while the task runs rather than
    from a timetable.
    """
    # The key is chosen here, inside the running task, not by a timetable.
    region = "us"
    outlet_events[self].add_partitions(region)
with DAG(
    dag_id="summarize_live_region_stats",
    catchup=False,
    tags=["example", "player-stats", "runtime"],
    schedule=PartitionedAssetTimetable(assets=Asset.ref(name="live_region_player_stats")),
):
    """
    Summarize the live region statistics for each runtime-emitted partition.

    Triggered once per partition key recorded upstream at runtime.
    """

    @task
    def summarize_live_region(dag_run=None):
        """Summarize stats for the matched runtime partition."""
        if TYPE_CHECKING:
            # Type-checker-only narrowing; never executed at runtime.
            assert dag_run is not None
        print(dag_run.partition_key)

    summarize_live_region()


@asset(
    schedule=PartitionAtRuntime(),
    uri="file://incoming/player-stats/multi-region.csv",
    tags=["player-stats", "runtime"],
)
def multi_region_player_stats(self, outlet_events):
    """
    Produce several region partitions from a single run.

    This asset demonstrates runtime fan-out, where each key emits its own asset
    event and duplicate keys collapse to a single event.
    """
    # Several keys recorded in one call: one run fans out to three partitions.
    regions = ["us", "eu", "apac"]
    outlet_events[self].add_partitions(regions)
# NOTE(review): neither asset is referenced elsewhere in this file; presumably
# kept for other examples or docs -- confirm before removing.
daily_sales = Asset(name="daily_sales", uri="s3://sales/daily")
daily_costs = Asset(name="daily_costs", uri="s3://costs/daily")
# --- Chained rollup: hourly → daily → monthly --------------------------------
# The hourly source asset (``team_a_player_stats``) is defined near the top of
# this module. Each rollup Dag publishes its own asset so the next level can
# consume it.
daily_team_a = Asset(name="daily_team_a", uri="s3://team-a/daily")
monthly_team_a = Asset(name="monthly_team_a", uri="s3://team-a/monthly")

with DAG(
    dag_id="daily_team_a_rollup",
    catchup=False,
    tags=["example", "player-stats", "rollup"],
    schedule=PartitionedAssetTimetable(
        assets=team_a_player_stats,
        default_partition_mapper=RollupMapper(
            # Map each hourly key onto its day ...
            upstream_mapper=StartOfDayMapper(),
            # ... and require the full day's worth of hours before firing.
            window=DayWindow(),
        ),
    ),
):
    """
    First rollup level: 24 hourly partitions of ``team_a_player_stats`` → one daily summary.

    ``StartOfDayMapper`` normalizes each upstream hourly timestamp
    (``%Y-%m-%dT%H:%M:%S``) to its day-start (``%Y-%m-%d``); ``DayWindow``
    declares the downstream run needs all 24 hourly partitions before firing.
    Publishes ``daily_team_a`` so the monthly rollup below can consume it.
    """

    @task(outlets=[daily_team_a])
    def summarise_team_a_day(dag_run=None):
        """Produce the full-day rollup once every hour has arrived."""
        if TYPE_CHECKING:
            # Type-checker-only narrowing; never executed at runtime.
            assert dag_run is not None
        print(f"All 24 hourly partitions received. Day: {dag_run.partition_key}")

    summarise_team_a_day()


with DAG(
    dag_id="monthly_team_a_rollup",
    catchup=False,
    tags=["example", "player-stats", "rollup"],
    schedule=PartitionedAssetTimetable(
        assets=daily_team_a,
        # ``daily_team_a`` carries day-formatted keys (``%Y-%m-%d``), not full
        # timestamps, so the upstream mapper must be told to parse that format.
        default_partition_mapper=RollupMapper(
            upstream_mapper=StartOfMonthMapper(input_format="%Y-%m-%d"),
            window=MonthWindow(),
        ),
    ),
):
    """
    Chained rollup: every day of ``daily_team_a`` (itself a rollup) → one monthly summary.

    Demonstrates how a rollup output can feed another rollup. ``StartOfMonthMapper``
    is configured with ``input_format="%Y-%m-%d"`` so it can parse the day keys
    emitted by ``daily_team_a_rollup``; ``MonthWindow`` waits for every day of the
    calendar month (28–31 depending on the month). The partition key is the month
    identifier, e.g. ``2024-01``.
    """

    @task(outlets=[monthly_team_a])
    def summarise_team_a_month(dag_run=None):
        """Produce the full-month rollup once every day has arrived."""
        if TYPE_CHECKING:
            # Type-checker-only narrowing; never executed at runtime.
            assert dag_run is not None
        print(f"All daily partitions received. Month: {dag_run.partition_key}")

    summarise_team_a_month()


# --- Fan-out: one weekly upstream → seven daily downstream Dag runs ----------
weekly_model_artifact = Asset(name="weekly_model_artifact", uri="file://artifacts/models/weekly.bin")

with DAG(
    dag_id="train_weekly_model",
    catchup=False,
    tags=["example", "model", "training"],
    # Cron "0 0 * * 1": 00:00 every Monday, evaluated in UTC.
    schedule=CronPartitionTimetable("0 0 * * 1", timezone="UTC"),
):
    """Train a weekly model artifact every Monday at 00:00 UTC."""

    @task(outlets=[weekly_model_artifact])
    def train_model():
        """Materialize the model artifact for the current weekly partition."""

    train_model()


with DAG(
    dag_id="daily_inference",
    catchup=False,
    tags=["example", "model", "inference"],
    schedule=PartitionedAssetTimetable(
        assets=weekly_model_artifact,
        # FanOutMapper composes upstream_mapper + window + (optional) downstream_mapper.
        # WeekWindow.to_upstream() yields seven daily datetimes inside one week,
        # and WeekWindow's default downstream_mapper is StartOfDayMapper, so one
        # weekly upstream key fans out to seven ``%Y-%m-%d`` downstream keys.
        default_partition_mapper=FanOutMapper(
            upstream_mapper=StartOfWeekMapper(),
            window=WeekWindow(),
        ),
    ),
):
    """Run daily inference, fanning the weekly model artifact out to one Dag run per day."""

    @task
    def run_inference(dag_run=None):
        """Run inference for one daily partition derived from the weekly model."""
        if TYPE_CHECKING:
            # Type-checker-only narrowing; never executed at runtime.
            assert dag_run is not None
        print(dag_run.partition_key)

    run_inference()


# --- Segment (categorical) rollup -------------------------------------------
# ``multi_region_player_stats`` (defined earlier in this module) emits one
# partition per region (``us``, ``eu``, ``apac``) from a single run. The Dag
# below holds a downstream run until every declared region key has arrived.
with DAG(
    dag_id="segment_region_stats_rollup",
    catchup=False,
    tags=["example", "player-stats", "rollup", "segment"],
    schedule=PartitionedAssetTimetable(
        assets=Asset.ref(name="multi_region_player_stats"),
        default_partition_mapper=RollupMapper(
            # Every region key collapses onto one downstream partition ...
            upstream_mapper=FixedKeyMapper("all_regions"),
            # ... which is held until all three region keys are present.
            window=SegmentWindow(["us", "eu", "apac"]),
        ),
    ),
):
    """
    Categorical rollup: hold until all three region partitions arrive.

    ``RollupMapper(upstream_mapper=FixedKeyMapper("all_regions"), window=SegmentWindow([...]))``
    declares the fixed set of region keys required for one downstream run and
    collapses every region key onto a single ``all_regions`` partition, so the
    three region events accumulate into one downstream run. The run is held until
    ``us``, ``eu``, and ``apac`` have all arrived from ``multi_region_player_stats``;
    partial arrivals remain pending in the next-run-assets view so operators can
    track progress.
    """

    @task
    def aggregate_all_regions(dag_run=None):
        """Produce the cross-region summary once every region partition has arrived."""
        if TYPE_CHECKING:
            # Type-checker-only narrowing; never executed at runtime.
            assert dag_run is not None
        print(f"All region partitions received. Partition: {dag_run.partition_key}")

    aggregate_all_regions()

Was this entry helpful?