# Source code for airflow.providers.apache.kafka.triggers.await_message
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import asyncio
from functools import partial
from typing import Any, Sequence
from asgiref.sync import sync_to_async
from airflow.exceptions import AirflowException
from airflow.providers.apache.kafka.hooks.consume import KafkaConsumerHook
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.module_loading import import_string
class AwaitMessageTrigger(BaseTrigger):
    """
    A trigger that waits for a message matching specific criteria to arrive in Kafka.

    The behavior of the consumer of this trigger is as follows:

    - poll the Kafka topics for a message (blocking for at most ``poll_timeout`` seconds);
      if no message is returned, poll again immediately
    - if the returned message carries an error, raise an ``AirflowException``
    - otherwise process the message with the provided callable and commit the message offset:

      - if the callable returns a truthy value, emit a ``TriggerEvent`` carrying that value and stop
      - else sleep for ``poll_interval`` seconds and continue with the next message

    The consumer is closed when the trigger stops, whether it stops because of a match,
    an error, or cancellation.

    :param kafka_config_id: The connection object to use, defaults to "kafka_default"
    :param topics: The topic (or topic regex) that should be searched for messages
    :param apply_function: the location of the function to apply to messages for determination of matching
        criteria. (In python dot notation as a string)
    :param apply_function_args: A set of arguments to apply to the callable, defaults to None
    :param apply_function_kwargs: A set of key word arguments to apply to the callable, defaults to None
    :param poll_timeout: How long the Kafka client should wait before returning from a poll request to
        Kafka (seconds), defaults to 1
    :param poll_interval: How long the trigger should sleep after processing a message that did not
        match, before polling again (seconds), defaults to 5
    """

    def __init__(
        self,
        topics: Sequence[str],
        apply_function: str,
        kafka_config_id: str = "kafka_default",
        apply_function_args: Sequence[Any] | None = None,
        apply_function_kwargs: dict[Any, Any] | None = None,
        poll_timeout: float = 1,
        poll_interval: float = 5,
    ) -> None:
        self.topics = topics
        self.apply_function = apply_function
        # Normalise the optional containers so ``run`` can always unpack them.
        self.apply_function_args = apply_function_args or ()
        self.apply_function_kwargs = apply_function_kwargs or {}
        self.kafka_config_id = kafka_config_id
        self.poll_timeout = poll_timeout
        self.poll_interval = poll_interval

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Return the trigger's import path and the keyword arguments needed to rebuild it."""
        return (
            "airflow.providers.apache.kafka.triggers.await_message.AwaitMessageTrigger",
            {
                "topics": self.topics,
                "apply_function": self.apply_function,
                "apply_function_args": self.apply_function_args,
                "apply_function_kwargs": self.apply_function_kwargs,
                "kafka_config_id": self.kafka_config_id,
                "poll_timeout": self.poll_timeout,
                "poll_interval": self.poll_interval,
            },
        )

    async def run(self):
        """
        Consume messages until ``apply_function`` returns a truthy value for one of them.

        Yields a single ``TriggerEvent`` whose payload is the value returned by the callable.

        :raises AirflowException: if a polled message reports an error
        """
        consumer_hook = KafkaConsumerHook(topics=self.topics, kafka_config_id=self.kafka_config_id)

        # The consumer client is synchronous; every call on it goes through
        # ``sync_to_async`` so it is awaited instead of being called directly.
        consumer = await sync_to_async(consumer_hook.get_consumer)()

        try:
            async_poll = sync_to_async(consumer.poll)
            async_commit = sync_to_async(consumer.commit)

            # Resolve the dotted path once and pre-bind the user supplied arguments;
            # the polled message is passed as the last positional argument.
            processing_call = partial(
                import_string(self.apply_function),
                *self.apply_function_args,
                **self.apply_function_kwargs,
            )
            async_message_process = sync_to_async(processing_call)

            while True:
                message = await async_poll(self.poll_timeout)

                if message is None:
                    # Nothing arrived within ``poll_timeout``; poll again.
                    continue
                if message.error():
                    raise AirflowException(f"Error: {message.error()}")

                rv = await async_message_process(message)
                # The offset is committed synchronously whether or not the message matched.
                await async_commit(message=message, asynchronous=False)

                if rv:
                    yield TriggerEvent(rv)
                    break

                await asyncio.sleep(self.poll_interval)
        finally:
            # Always release the consumer, including on error and cancellation,
            # so the client connection is not leaked.
            await sync_to_async(consumer.close)()