#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Implements the ``@task_group`` function decorator.
When the decorated function is called, a task group will be created to represent
a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""
from __future__ import annotations
import functools
import inspect
import warnings
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, TypeVar, overload
import attr
from airflow.decorators.base import ExpandableFactory
from airflow.models.expandinput import (
DictOfListsExpandInput,
ListOfDictsExpandInput,
MappedArgument,
)
from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.node import DAGNode
from airflow.typing_compat import ParamSpec
from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import MappedTaskGroup, TaskGroup
if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.expandinput import (
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
[docs]FParams = ParamSpec("FParams")
[docs]FReturn = TypeVar("FReturn", None, DAGNode)
[docs]task_group_sig = inspect.signature(TaskGroup.__init__)
@attr.define()
class _TaskGroupFactory(ExpandableFactory, Generic[FParams, FReturn]):
function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
tg_kwargs: dict[str, Any] = attr.ib(factory=dict) # Parameters forwarded to TaskGroup.
partial_kwargs: dict[str, Any] = attr.ib(factory=dict) # Parameters forwarded to 'function'.
_task_group_created: bool = attr.ib(False, init=False)
tg_class: ClassVar[type[TaskGroup]] = TaskGroup
@tg_kwargs.validator
def _validate(self, _, kwargs):
task_group_sig.bind_partial(**kwargs)
def __attrs_post_init__(self):
self.tg_kwargs.setdefault("group_id", self.function.__name__)
def __del__(self):
if self.partial_kwargs and not self._task_group_created:
try:
group_id = repr(self.tg_kwargs["group_id"])
except KeyError:
group_id = f"at {hex(id(self))}"
warnings.warn(f"Partial task group {group_id} was never mapped!", stacklevel=1)
def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> DAGNode:
"""
Instantiate the task group.
This uses the wrapped function to create a task group. Depending on the
return type of the wrapped function, this either returns the last task
in the group, or the group itself, to support task chaining.
"""
return self._create_task_group(TaskGroup, *args, **kwargs)
def _create_task_group(self, tg_factory: Callable[..., TaskGroup], *args: Any, **kwargs: Any) -> DAGNode:
with tg_factory(add_suffix_on_collision=True, **self.tg_kwargs) as task_group:
if self.function.__doc__ and not task_group.tooltip:
task_group.tooltip = self.function.__doc__
# Invoke function to run Tasks inside the TaskGroup
retval = self.function(*args, **kwargs)
self._task_group_created = True
# If the task-creating function returns a task, forward the return value
# so dependencies bind to it. This is equivalent to
# with TaskGroup(...) as tg:
# t2 = task_2(task_1())
# start >> t2 >> end
if retval is not None:
return retval
# Otherwise return the task group as a whole, equivalent to
# with TaskGroup(...) as tg:
# task_1()
# task_2()
# start >> tg >> end
return task_group
def override(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
# TODO: FIXME when mypy gets compatible with new attrs
return attr.evolve(self, tg_kwargs={**self.tg_kwargs, **kwargs}) # type: ignore[arg-type]
def partial(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
self._validate_arg_names("partial", kwargs)
prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="duplicate partial")
kwargs.update(self.partial_kwargs)
# TODO: FIXME when mypy gets compatible with new attrs
return attr.evolve(self, partial_kwargs=kwargs) # type: ignore[arg-type]
def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode:
if not kwargs:
raise TypeError("no arguments to expand against")
self._validate_arg_names("expand", kwargs)
prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="mapping already partial")
expand_input = DictOfListsExpandInput(kwargs)
return self._create_task_group(
functools.partial(MappedTaskGroup, expand_input=expand_input),
**self.partial_kwargs,
**{k: MappedArgument(input=expand_input, key=k) for k in kwargs},
)
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
if isinstance(kwargs, Sequence):
for item in kwargs:
if not isinstance(item, (XComArg, Mapping)):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
# It's impossible to build a dict of stubs as keyword arguments if the
# function uses * or ** wildcard arguments.
function_has_vararg = any(
v.kind == inspect.Parameter.VAR_POSITIONAL or v.kind == inspect.Parameter.VAR_KEYWORD
for v in self.function_signature.parameters.values()
)
if function_has_vararg:
raise TypeError("calling expand_kwargs() on task group function with * or ** is not supported")
# We can't be sure how each argument is used in the function (well
# technically we can with AST but let's not), so we have to create stubs
# for every argument, including those with default values.
map_kwargs = (k for k in self.function_signature.parameters if k not in self.partial_kwargs)
expand_input = ListOfDictsExpandInput(kwargs)
return self._create_task_group(
functools.partial(MappedTaskGroup, expand_input=expand_input),
**self.partial_kwargs,
**{k: MappedArgument(input=expand_input, key=k) for k in map_kwargs},
)
# This covers the @task_group() case. Annotations are copied from the TaskGroup
# class, only providing a default to 'group_id' (this is optional for the
# decorator and defaults to the decorated function's name). Please keep them in
# sync with TaskGroup when you can! Note that since this is an overload, these
# argument defaults aren't actually used at runtime--the real implementation
# does not use them, and simply rely on TaskGroup's defaults, so it's not
# disastrous if they go out of sync with TaskGroup.
@overload
[docs]def task_group(
group_id: str | None = None,
prefix_group_id: bool = True,
parent_group: TaskGroup | None = None,
dag: DAG | None = None,
default_args: dict[str, Any] | None = None,
tooltip: str = "",
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
add_suffix_on_collision: bool = False,
) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, FReturn]]: ...
# This covers the @task_group case (no parentheses).
@overload
def task_group(python_callable: Callable[FParams, FReturn]) -> _TaskGroupFactory[FParams, FReturn]: ...
def task_group(python_callable=None, **tg_kwargs):
"""
Python TaskGroup decorator.
This wraps a function into an Airflow TaskGroup. When used as the
``@task_group()`` form, all arguments are forwarded to the underlying
TaskGroup class. Can be used to parametrize TaskGroup.
:param python_callable: Function to decorate.
:param tg_kwargs: Keyword arguments for the TaskGroup object.
"""
if callable(python_callable) and not tg_kwargs:
return _TaskGroupFactory(function=python_callable, tg_kwargs=tg_kwargs)
return functools.partial(_TaskGroupFactory, tg_kwargs=tg_kwargs)