Source code for airflow.providers.amazon.aws.utils.task_log_fetcher

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import json
import logging
import re
import time
from collections.abc import Generator
from datetime import datetime, timedelta, timezone
from threading import Event, Thread
from typing import TYPE_CHECKING

from botocore.exceptions import ClientError, ConnectionClosedError

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

if TYPE_CHECKING:
    from airflow.sdk.types import Logger

# Recognises a severity keyword at the very start of a plain-text message: optional leading
# whitespace, an optional "[", the keyword (captured as group 1, case-insensitive), an optional "]"
# and an optional ":" or "-" separator.
# The ``\b`` right after the keyword requires it to be a whole word, so that ordinary words which
# merely start with a level name ("Debugging ...", "Errors: 0", "Information ...") are not mistaken
# for a level prefix.
_LOG_LEVEL_PATTERN = re.compile(
    r"^\s*(?:\[)?(DEBUG|INFO|WARNING|WARN|ERROR|CRITICAL|FATAL)\b(?:\])?\s*[:\-]?\s*",
    re.IGNORECASE,
)
# Upper-cased level keyword -> ``logging`` level number. "WARN" and "FATAL" are aliases that map
# onto ``logging.WARNING`` and ``logging.CRITICAL`` respectively.
_LOG_LEVEL_MAP: dict[str, int] = {
    "DEBUG": logging.DEBUG,
    "INFO": logging.INFO,
    "WARNING": logging.WARNING,
    "WARN": logging.WARNING,
    "ERROR": logging.ERROR,
    "CRITICAL": logging.CRITICAL,
    "FATAL": logging.CRITICAL,
}


def _parse_log_level(message: str) -> int:
    """
    Work out which Python logging level a CloudWatch log message should be emitted at.

    Two message shapes are recognised:

    1. A JSON object carrying a ``levelname`` or ``level`` field (``levelname`` wins when both
       are set to a truthy value).
    2. Plain text starting with a level prefix such as ``ERROR:``, ``[WARNING]`` or ``CRITICAL -``.

    :param message: Raw message text of a single log event.
    :return: The matching ``logging`` level, or ``logging.INFO`` when no known level is found
        (this keeps the historical behaviour of logging everything at INFO).
    """
    candidate = message.strip()
    if candidate.startswith("{"):
        try:
            payload = json.loads(candidate)
            raw_level = payload.get("levelname") or payload.get("level") or ""
            json_level = _LOG_LEVEL_MAP.get(raw_level.upper())
        except (json.JSONDecodeError, AttributeError):
            # Not valid JSON, or the level field is not a string: fall back to the prefix check.
            json_level = None
        if json_level is not None:
            return json_level

    prefix = _LOG_LEVEL_PATTERN.match(message)
    if prefix is None:
        return logging.INFO
    return _LOG_LEVEL_MAP.get(prefix.group(1).upper(), logging.INFO)


class AwsTaskLogFetcher(Thread):
    """
    Thread that polls a CloudWatch log stream and re-emits its events through a logger.

    Once started, the thread sleeps for ``fetch_interval``, fetches log events through
    :class:`AwsLogsHook` and forwards each one to ``logger`` at the level detected from the
    message text. It keeps doing so until :meth:`stop` is called.

    :param log_group: Name of the CloudWatch log group to read from.
    :param log_stream_name: Name of the log stream inside ``log_group``.
    :param fetch_interval: Pause between two consecutive polls.
    :param logger: Logger that receives the fetched events.
    :param aws_conn_id: Connection id handed to :class:`AwsLogsHook`.
    :param region_name: Region name handed to :class:`AwsLogsHook`.
    """

    def __init__(
        self,
        *,
        log_group: str,
        log_stream_name: str,
        fetch_interval: timedelta,
        logger: Logger,
        aws_conn_id: str | None = "aws_default",
        region_name: str | None = None,
    ):
        super().__init__()
        # Set by stop(); checked by run() before every polling round.
        self._event = Event()

        self.fetch_interval = fetch_interval
        self.logger = logger
        self.log_group = log_group
        self.log_stream_name = log_stream_name
        self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, region_name=region_name)

    def run(self) -> None:
        """Poll the log stream and forward every fetched event until the thread is stopped."""
        # One token object is shared by all polling rounds.
        token = AwsLogsHook.ContinuationToken()
        while True:
            if self.is_stopped():
                return
            time.sleep(self.fetch_interval.total_seconds())

            last_seen: datetime | None = None
            for event in self._get_log_events(token):
                event_time = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
                if last_seen == event_time:
                    # Events sharing a timestamp were observed to collapse into a single logged
                    # line, so some of them went missing. A tiny pause before logging the next
                    # one works around that.
                    # See https://github.com/apache/airflow/issues/40875
                    time.sleep(0.001)
                severity = _parse_log_level(event["message"])
                self.logger.log(severity, self.event_to_str(event))
                last_seen = event_time

    def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None = None) -> Generator:
        """
        Yield the log events returned by the hook, swallowing the expected retrieval errors.

        ``ClientError`` and ``ConnectionClosedError`` are logged instead of raised, and the
        generator then simply ends for this round.

        :param skip_token: Continuation token passed on to the hook; a fresh one is created
            when omitted.
        """
        token = AwsLogsHook.ContinuationToken() if skip_token is None else skip_token
        try:
            yield from self.hook.get_log_events(
                self.log_group, self.log_stream_name, continuation_token=token
            )
        except ClientError as error:
            if error.response["Error"]["Code"] == "ResourceNotFoundException":
                # Reported at INFO only: the stream may simply not exist yet.
                self.logger.info(
                    "Cannot find log stream yet, it can take a couple of seconds to show up. "
                    "If this error persists, check that the log group and stream are correct: "
                    "group: %s\tstream: %s",
                    self.log_group,
                    self.log_stream_name,
                )
            else:
                self.logger.warning("Error on retrieving Cloudwatch log events: %s", error)
            return
        except ConnectionClosedError as error:
            self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events: %s", error)
            return

    @staticmethod
    def event_to_str(event: dict) -> str:
        """
        Render a log event as ``[YYYY-MM-DD HH:MM:SS,mmm] message``.

        :param event: Mapping with a ``timestamp`` in milliseconds since the epoch and a
            ``message``; the timestamp is rendered in UTC.
        """
        moment = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
        # %f yields six digits of microseconds; dropping the last three leaves milliseconds.
        stamp = moment.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
        return f"[{stamp}] {event['message']}"

    def get_last_log_messages(self, number_messages) -> list:
        """
        Return the most recent log messages using a single request.

        NOTE: some restrictions apply:
         - if logs are too old, the response will be empty
         - the max number of messages we can retrieve is constrained by cloudwatch limits (10,000).

        :param number_messages: Upper bound on the number of messages to request.
        """
        response = self.hook.conn.get_log_events(
            logGroupName=self.log_group,
            logStreamName=self.log_stream_name,
            startFromHead=False,
            limit=number_messages,
        )
        events = response["events"]
        return [entry["message"] for entry in events]

    def get_last_log_message(self) -> str | None:
        """Return the first message of a one-message fetch, or ``None`` when nothing came back."""
        try:
            latest = self.get_last_log_messages(1)[0]
        except IndexError:
            return None
        return latest

    def is_stopped(self) -> bool:
        """Tell whether :meth:`stop` has been called."""
        return self._event.is_set()

    def stop(self):
        """Ask the polling loop in :meth:`run` to finish."""
        self._event.set()

Was this entry helpful?