# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
from airflow.sdk import (
DAG,
AllowedKeyMapper,
Asset,
CronPartitionTimetable,
DayWindow,
FanOutMapper,
FixedKeyMapper,
IdentityMapper,
MonthWindow,
PartitionAtRuntime,
PartitionedAssetTimetable,
ProductMapper,
RollupMapper,
SegmentWindow,
StartOfDayMapper,
StartOfHourMapper,
StartOfMonthMapper,
StartOfWeekMapper,
StartOfYearMapper,
WeekWindow,
asset,
task,
)
# Hourly Team A player stats, produced by the ``ingest_team_a_player_stats`` Dag below.
team_a_player_stats = Asset(uri="file://incoming/player-stats/team_a.csv", name="team_a_player_stats")
# Curated dataset written by the ``clean_and_combine_player_stats`` Dag below.
combined_player_stats = Asset(uri="file://curated/player-stats/combined.csv", name="combined_player_stats")
with DAG(
    dag_id="ingest_team_a_player_stats",
    # Top of every hour, UTC; each run produces one hourly partition.
    schedule=CronPartitionTimetable("0 * * * *", timezone="UTC"),
    tags=["example", "player-stats", "ingestion"],
):
    """Produce hourly partitioned stats for Team A."""

    @task(outlets=[team_a_player_stats])
    def ingest_team_a_stats():
        """Materialize Team A player statistics for the current hourly partition."""
        pass

    ingest_team_a_stats()
# No explicit ``name``: this asset is referenced elsewhere in the file as
# ``Asset.ref(name="team_b_player_stats")``, i.e. by the function name.
@asset(
    uri="file://incoming/player-stats/team_b.csv",
    # Minute 15 of every hour, UTC.
    schedule=CronPartitionTimetable("15 * * * *", timezone="UTC"),
    tags=["player-stats", "ingestion"],
)
def team_b_player_stats():
    """Produce hourly partitioned stats for Team B."""
    pass
@asset(
    uri="file://incoming/player-stats/team_c.csv",
    # Minute 30 of every hour, UTC.
    schedule=CronPartitionTimetable("30 * * * *", timezone="UTC"),
    tags=["player-stats", "ingestion"],
)
def team_c_player_stats():
    """Produce hourly partitioned stats for Team C."""
    pass
with DAG(
    dag_id="clean_and_combine_player_stats",
    schedule=PartitionedAssetTimetable(
        assets=team_a_player_stats & team_b_player_stats & team_c_player_stats,
        # The three sources fire at minutes 0, 15 and 30 (see their cron schedules);
        # StartOfHourMapper aligns their keys to the hour so one hour's partitions match.
        default_partition_mapper=StartOfHourMapper(),
    ),
    catchup=False,
    tags=["example", "player-stats", "cleanup"],
):
    """
    Combine hourly partitions from Team A, B and C into a single curated dataset.

    This Dag demonstrates multi-asset partition alignment using StartOfHourMapper.
    """

    @task(outlets=[combined_player_stats])
    def combine_player_stats(dag_run=None):
        """Merge the aligned hourly partitions into a combined dataset."""
        # dag_run is filled in from the task context at runtime; the assert only
        # narrows the Optional for static type checkers and never executes.
        if TYPE_CHECKING:
            assert dag_run
        print(dag_run.partition_key)

    combine_player_stats()
@asset(
    uri="file://analytics/player-stats/computed-player-odds.csv",
    # No partition mapper is specified, so this falls back to IdentityMapper.
    # To use another temporal mapper here (e.g., StartOfHourMapper), set its
    # input_format to match the upstream partition_key, which is in
    # "%Y-%m-%dT%H" format rather than a full timestamp.
    schedule=PartitionedAssetTimetable(assets=combined_player_stats),
    tags=["player-stats", "odds"],
)
def compute_player_odds():
    """
    Compute player odds from the combined hourly statistics.

    This asset is partition-aware and triggered by the combined stats asset.
    """
    pass
with DAG(
    dag_id="player_odds_quality_check_wont_ever_to_trigger",
    schedule=PartitionedAssetTimetable(
        assets=(combined_player_stats & team_a_player_stats & Asset.ref(name="team_b_player_stats")),
        # Per-asset mappers override any default; the year-level mapper below cannot
        # produce keys equal to the hour-level ones, which is the point of this example.
        partition_mapper_config={
            combined_player_stats: StartOfYearMapper(),  # incompatible on purpose
            team_a_player_stats: StartOfHourMapper(),
            Asset.ref(name="team_b_player_stats"): StartOfHourMapper(),
        },
    ),
    catchup=False,
    tags=["example", "player-stats", "odds"],
):
    """
    Demonstrate a partition mapper mismatch scenario.

    The configured partition mappers transform partition keys into formats
    that never match ("%Y" vs. "%Y-%m-%dT%H"), so the Dag will never trigger.
    """

    @task
    def check_partition_alignment():
        """Placeholder task; it never runs because the upstream partition keys never align."""
        pass

    check_partition_alignment()
# Hourly regional sales, produced by the ``ingest_regional_sales`` Dag below.
regional_sales = Asset(uri="file://incoming/sales/regional.csv", name="regional_sales")
with DAG(
    dag_id="ingest_regional_sales",
    schedule=CronPartitionTimetable("0 * * * *", timezone="UTC"),
    tags=["example", "sales", "aggregation"][:1] + ["sales", "ingestion"],
):
    """Produce hourly regional sales data with composite partition keys."""

    # NOTE(review): the downstream ``aggregate_regional_sales`` Dag splits composite
    # "region|timestamp" keys, while this Dag is scheduled by a CronPartitionTimetable;
    # confirm the emitted partition key actually carries the region segment.
    @task(outlets=[regional_sales])
    def ingest_sales():
        """Ingest regional sales data partitioned by region and time."""
        pass

    ingest_sales()
with DAG(
    dag_id="aggregate_regional_sales",
    schedule=PartitionedAssetTimetable(
        assets=regional_sales,
        # One mapper per key segment: region passes through, timestamp snaps to its day.
        default_partition_mapper=ProductMapper(IdentityMapper(), StartOfDayMapper()),
    ),
    catchup=False,
    tags=["example", "sales", "aggregation"],
):
    """
    Aggregate regional sales using ProductMapper.

    The ProductMapper splits the composite key "region|timestamp" and applies
    IdentityMapper to the region segment and StartOfDayMapper to the timestamp segment,
    aligning hourly partitions to daily granularity per region.
    """

    @task
    def aggregate_sales(dag_run=None):
        """Aggregate sales data for the matched region-day partition."""
        # Type-checker-only narrowing; dag_run comes from the task context at runtime.
        if TYPE_CHECKING:
            assert dag_run
        print(dag_run.partition_key)

    aggregate_sales()
# Per-region player stats; partition keys are region codes supplied when
# ``ingest_region_stats`` is triggered.
region_raw_stats = Asset(uri="file://incoming/player-stats/by-region.csv", name="region_raw_stats")
with DAG(
    dag_id="ingest_region_stats",
    # No timetable: runs are triggered externally with a region-code partition_key.
    schedule=None,
    tags=["example", "player-stats", "regional"],
):
    """
    Ingest player statistics per region.

    Externally triggered with partition_key set to a region code (``us``, ``eu``, ``apac``).
    """

    @task(outlets=[region_raw_stats])
    def ingest_region():
        """Materialize player statistics for a single region partition."""
        pass

    ingest_region()
@asset(
    uri="file://analytics/player-stats/regional-breakdown.csv",
    schedule=PartitionedAssetTimetable(
        assets=region_raw_stats,
        # Keep in sync with the region codes ``ingest_region_stats`` is triggered with.
        default_partition_mapper=AllowedKeyMapper(["us", "eu", "apac"]),
    ),
    tags=["player-stats", "regional"],
)
def regional_stats_breakdown():
    """
    Aggregate regional player statistics.

    This asset demonstrates AllowedKeyMapper, which validates that upstream partition
    keys belong to a fixed set of allowed values (``us``, ``eu``, ``apac``) rather than
    time-based partitions.
    """
    pass
@asset(
    uri="file://incoming/player-stats/live-region.csv",
    schedule=PartitionAtRuntime(),
    tags=["player-stats", "runtime"],
)
def live_region_player_stats(self, outlet_events):
    """
    Produce a single region partition whose key is decided at runtime.

    This asset demonstrates PartitionAtRuntime, which records the partition key on the
    emitted event with ``add_partitions`` while the task runs rather than from a timetable.

    :param self: the asset this function defines; used to look up its entry in ``outlet_events``.
    :param outlet_events: accessor for the asset events emitted by this run.
    """
    # Hard-coded for the example: the key is chosen in task code, not by a timetable.
    outlet_events[self].add_partitions("us")
with DAG(
    dag_id="summarize_live_region_stats",
    # Referenced by name, so this Dag only needs the producing asset's name.
    schedule=PartitionedAssetTimetable(assets=Asset.ref(name="live_region_player_stats")),
    catchup=False,
    tags=["example", "player-stats", "runtime"],
):
    """
    Summarize the live region statistics for each runtime-emitted partition.

    Triggered once per partition key recorded upstream at runtime.
    """

    @task
    def summarize_live_region(dag_run=None):
        """Summarize stats for the matched runtime partition."""
        # Type-checker-only narrowing; dag_run comes from the task context at runtime.
        if TYPE_CHECKING:
            assert dag_run
        print(dag_run.partition_key)

    summarize_live_region()
@asset(
    uri="file://incoming/player-stats/multi-region.csv",
    schedule=PartitionAtRuntime(),
    tags=["player-stats", "runtime"],
)
def multi_region_player_stats(self, outlet_events):
    """
    Produce several region partitions from a single run.

    This asset demonstrates runtime fan-out, where each key emits its own asset event
    and duplicate keys collapse to a single event.

    :param self: the asset this function defines; used to look up its entry in ``outlet_events``.
    :param outlet_events: accessor for the asset events emitted by this run.
    """
    # One call with a list of keys: each key becomes its own asset event.
    outlet_events[self].add_partitions(["us", "eu", "apac"])
# NOTE(review): neither asset is referenced by any Dag or asset in the code shown here;
# confirm they are consumed elsewhere, otherwise consider removing them.
daily_sales = Asset(uri="s3://sales/daily", name="daily_sales")
daily_costs = Asset(uri="s3://costs/daily", name="daily_costs")
# --- Chained rollup: hourly → daily → monthly --------------------------------
# The hourly source asset already exists above (``team_a_player_stats``).
# Each rollup Dag publishes its own asset so the next level can consume it.

# Output of ``daily_team_a_rollup``; input of ``monthly_team_a_rollup``.
daily_team_a = Asset(uri="s3://team-a/daily", name="daily_team_a")
# Output of ``monthly_team_a_rollup``.
monthly_team_a = Asset(uri="s3://team-a/monthly", name="monthly_team_a")
with DAG(
    dag_id="daily_team_a_rollup",
    schedule=PartitionedAssetTimetable(
        assets=team_a_player_stats,
        default_partition_mapper=RollupMapper(
            upstream_mapper=StartOfDayMapper(),
            window=DayWindow(),
        ),
    ),
    catchup=False,
    tags=["example", "player-stats", "rollup"],
):
    """
    First rollup level: 24 hourly partitions of ``team_a_player_stats`` → one daily summary.

    ``StartOfDayMapper`` normalizes each upstream hourly timestamp (``%Y-%m-%dT%H:%M:%S``)
    to its day-start (``%Y-%m-%d``); ``DayWindow`` declares the downstream run needs
    all 24 hourly partitions before firing. Publishes ``daily_team_a`` so the
    monthly rollup below can consume it.
    """

    @task(outlets=[daily_team_a])
    def summarise_team_a_day(dag_run=None):
        """Produce the full-day rollup once every hour has arrived."""
        # Type-checker-only narrowing; dag_run comes from the task context at runtime.
        if TYPE_CHECKING:
            assert dag_run
        print(f"All 24 hourly partitions received. Day: {dag_run.partition_key}")

    summarise_team_a_day()
with DAG(
    dag_id="monthly_team_a_rollup",
    schedule=PartitionedAssetTimetable(
        assets=daily_team_a,
        # The upstream (``daily_team_a``) emits day-formatted partition keys
        # (``%Y-%m-%d``), so the upstream mapper here must accept that format.
        default_partition_mapper=RollupMapper(
            upstream_mapper=StartOfMonthMapper(input_format="%Y-%m-%d"),
            window=MonthWindow(),
        ),
    ),
    catchup=False,
    tags=["example", "player-stats", "rollup"],
):
    """
    Chained rollup: every day of ``daily_team_a`` (itself a rollup) → one monthly summary.

    Demonstrates how a rollup output can feed another rollup. ``StartOfMonthMapper``
    is configured with ``input_format="%Y-%m-%d"`` so it can parse the day keys
    emitted by ``daily_team_a_rollup``; ``MonthWindow`` waits for every day of the
    calendar month (28–31 depending on the month). The partition key is the month
    identifier, e.g. ``2024-01``.
    """

    @task(outlets=[monthly_team_a])
    def summarise_team_a_month(dag_run=None):
        """Produce the full-month rollup once every day has arrived."""
        # Type-checker-only narrowing; dag_run comes from the task context at runtime.
        if TYPE_CHECKING:
            assert dag_run
        print(f"All daily partitions received. Month: {dag_run.partition_key}")

    summarise_team_a_month()
# --- Fan-out: one weekly upstream → seven daily downstream Dag runs ----------

# Weekly model artifact, produced by the ``train_weekly_model`` Dag below.
weekly_model_artifact = Asset(uri="file://artifacts/models/weekly.bin", name="weekly_model_artifact")
with DAG(
    dag_id="train_weekly_model",
    # Minute 0, hour 0, every Monday (day-of-week 1), UTC.
    schedule=CronPartitionTimetable("0 0 * * 1", timezone="UTC"),
    catchup=False,
    tags=["example", "model", "training"],
):
    """Train a weekly model artifact every Monday at 00:00 UTC."""

    @task(outlets=[weekly_model_artifact])
    def train_model():
        """Materialize the model artifact for the current weekly partition."""
        pass

    train_model()
with DAG(
    dag_id="daily_inference",
    schedule=PartitionedAssetTimetable(
        assets=weekly_model_artifact,
        # FanOutMapper composes upstream_mapper + window + (optional) downstream_mapper.
        # WeekWindow.to_upstream() yields seven daily datetimes inside one week,
        # and the default downstream_mapper for WeekWindow is StartOfDayMapper, so
        # a weekly upstream key fans out to seven ``%Y-%m-%d`` downstream keys.
        default_partition_mapper=FanOutMapper(
            upstream_mapper=StartOfWeekMapper(),
            window=WeekWindow(),
        ),
    ),
    catchup=False,
    tags=["example", "model", "inference"],
):
    """Run daily inference, fanning the weekly model artifact out to one Dag run per day."""

    @task
    def run_inference(dag_run=None):
        """Run inference for one daily partition derived from the weekly model."""
        # Type-checker-only narrowing; dag_run comes from the task context at runtime.
        if TYPE_CHECKING:
            assert dag_run
        print(dag_run.partition_key)

    run_inference()
# --- Segment (categorical) rollup -------------------------------------------
# ``multi_region_player_stats`` (defined above) emits one partition per region
# (``us``, ``eu``, ``apac``) from a single run. The Dag below holds a downstream
# run until every declared region key has arrived.
with DAG(
    dag_id="segment_region_stats_rollup",
    schedule=PartitionedAssetTimetable(
        assets=Asset.ref(name="multi_region_player_stats"),
        default_partition_mapper=RollupMapper(
            upstream_mapper=FixedKeyMapper("all_regions"),
            # Must list the same region keys that multi_region_player_stats emits.
            window=SegmentWindow(["us", "eu", "apac"]),
        ),
    ),
    catchup=False,
    tags=["example", "player-stats", "rollup", "segment"],
):
    """
    Categorical rollup: hold until all three region partitions arrive.

    ``RollupMapper(upstream_mapper=FixedKeyMapper("all_regions"), window=SegmentWindow([...]))``
    declares the fixed set of region keys required for one downstream run and collapses every
    region key onto a single ``all_regions`` partition, so the three region events accumulate
    into one downstream run. The run is held until ``us``, ``eu``, and ``apac`` have all
    arrived from ``multi_region_player_stats``; partial arrivals remain pending in the
    next-run-assets view so operators can track progress.
    """

    @task
    def aggregate_all_regions(dag_run=None):
        """Produce the cross-region summary once every region partition has arrived."""
        # Type-checker-only narrowing; dag_run comes from the task context at runtime.
        if TYPE_CHECKING:
            assert dag_run
        print(f"All region partitions received. Partition: {dag_run.partition_key}")

    aggregate_all_regions()