Source code for airflow.providers.microsoft.winrm.triggers.winrm

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Trigger for winrm remote execution."""

from __future__ import annotations

import asyncio
import base64
from collections import deque
from collections.abc import AsyncIterator
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
    from winrm import Protocol


class WinRMCommandOutputTrigger(BaseTrigger):
    """
    A trigger that polls the command output executed by the WinRMHook.

    This trigger avoids blocking a worker when using the WinRMOperator in deferred mode.

    The behavior of this trigger is as follows:
    - poll the command output from the shell launched by WinRM,
    - if command not done then sleep and retry,
    - when command done then return the output.

    Each non-empty stdout/stderr chunk is base64-encoded and kept as an ASCII string; the
    ``stdout`` and ``stderr`` lists of the yielded event contain these encoded chunks.

    :param ssh_conn_id: connection id from airflow Connections from where all the required
        parameters can be fetched like username and password,
        though priority is given to the params passed during init.
    :param shell_id: The shell id on the remote machine.
    :param command_id: The command id executed on the remote machine.
    :param output_encoding: the encoding used to decode stdout and stderr, defaults to utf-8.
        It is preserved in :meth:`serialize`; this trigger itself emits the raw output
        base64-encoded.
    :param return_output: Whether to accumulate and return the stdout or not, defaults to True.
    :param poll_interval: How often, in seconds, the trigger should poll the output command of
        the launched command, defaults to 1.
    :param max_output_chunks: Maximum number of stdout/stderr chunks to keep in a rolling buffer
        to prevent excessive memory usage for long-running commands, defaults to 100.
    """

    def __init__(
        self,
        ssh_conn_id: str,
        shell_id: str,
        command_id: str,
        output_encoding: str = "utf-8",
        return_output: bool = True,
        poll_interval: float = 1,
        max_output_chunks: int = 100,
    ) -> None:
        super().__init__()
        self.ssh_conn_id = ssh_conn_id
        self.shell_id = shell_id
        self.command_id = command_id
        self.output_encoding = output_encoding
        self.return_output = return_output
        self.poll_interval = poll_interval
        # Bounded buffers: once full, appending drops the oldest chunk.
        self._stdout: deque[str] = deque(maxlen=max_output_chunks)
        self._stderr: deque[str] = deque(maxlen=max_output_chunks)

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize WinRMCommandOutputTrigger arguments and classpath."""
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "ssh_conn_id": self.ssh_conn_id,
                "shell_id": self.shell_id,
                "command_id": self.command_id,
                "output_encoding": self.output_encoding,
                "return_output": self.return_output,
                "poll_interval": self.poll_interval,
                # max_output_chunks is not stored separately; the buffer's maxlen holds it.
                "max_output_chunks": self._stdout.maxlen,
            },
        )

    @cached_property
    def hook(self) -> WinRMHook:
        """Return a WinRMHook for ``ssh_conn_id``, created once per trigger instance."""
        return WinRMHook(ssh_conn_id=self.ssh_conn_id)

    async def get_command_output(self, conn: Protocol) -> tuple[bytes, bytes, int | None, bool]:
        """
        Fetch the next output of the remote command without blocking the event loop.

        The synchronous ``WinRMHook.get_command_output`` call is wrapped with
        ``asgiref.sync.sync_to_async``.

        :param conn: The WinRM protocol connection obtained from the hook.
        :return: A ``(stdout, stderr, return_code, command_done)`` tuple.
        """
        from asgiref.sync import sync_to_async

        return await sync_to_async(self.hook.get_command_output)(conn, self.shell_id, self.command_id)

    @staticmethod
    def _encode_chunk(chunk: bytes) -> str:
        """Base64-encode a raw output chunk and return it as text."""
        # Base64 output only contains ASCII characters, so it must be decoded as ASCII.
        # Decoding it with ``output_encoding`` would raise or corrupt the encoded text for
        # ASCII-incompatible encodings such as utf-16.
        return base64.standard_b64encode(chunk).decode("ascii")

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Poll the remote command until it completes, then yield a single event.

        While the command runs, stdout chunks (only when ``return_output`` is set) and stderr
        chunks are base64-encoded and appended to the rolling buffers, sleeping
        ``poll_interval`` seconds between polls. On completion a ``success`` event carrying the
        return code and the buffered chunks is yielded. Any exception is logged and reported
        as an ``error`` event instead.
        """
        try:
            conn = await self.hook.get_async_conn()
            command_done = False
            while not command_done:
                stdout, stderr, return_code, command_done = await self.get_command_output(conn)
                if self.return_output and stdout:
                    self._stdout.append(self._encode_chunk(stdout))
                if stderr:
                    self._stderr.append(self._encode_chunk(stderr))
                if not command_done:
                    await asyncio.sleep(self.poll_interval)
            # The loop body always runs at least once, so return_code is bound here.
            yield TriggerEvent(
                {
                    "status": "success",
                    "shell_id": self.shell_id,
                    "command_id": self.command_id,
                    "return_code": return_code,
                    "stdout": list(self._stdout),
                    "stderr": list(self._stderr),
                }
            )
            return
        except Exception as e:
            self.log.exception("An error occurred: %s", e)
            yield TriggerEvent(
                {
                    "status": "error",
                    "shell_id": self.shell_id,
                    "command_id": self.command_id,
                    "message": str(e),
                }
            )
            return

Was this entry helpful?