Airflow Summit 2025 is coming October 07-09. Register now for early bird tickets!

Source code for airflow.providers.microsoft.psrp.hooks.psrp

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Callable, Generator
from contextlib import ExitStack, contextmanager
from copy import copy
from logging import DEBUG, ERROR, INFO, WARNING
from typing import TYPE_CHECKING, Any, cast
from weakref import WeakKeyDictionary

from pypsrp.host import PSHost
from pypsrp.messages import MessageType
from pypsrp.powershell import PowerShell, PSInvocationState, RunspacePool
from pypsrp.wsman import WSMan

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.psrp.version_compat import BaseHook

# Logging level used when a PSRP record of the given message type is written to
# the task log. Verbose records are surfaced at INFO; record types not listed
# here are handled separately by ``PsrpHook._log_record``.
INFORMATIONAL_RECORD_LEVEL_MAP = {
    MessageType.DEBUG_RECORD: DEBUG,
    MessageType.ERROR_RECORD: ERROR,
    MessageType.VERBOSE_RECORD: INFO,
    MessageType.WARNING_RECORD: WARNING,
}

# Signature of the optional callback invoked with each item received on the
# PowerShell output stream while a pipeline is being polled.
OutputCallback = Callable[[str], None]
class PsrpHook(BaseHook):
    """
    Hook for PowerShell Remoting Protocol execution.

    When used as a context manager, the runspace pool is reused between shell
    sessions.

    :param psrp_conn_id: Required. The name of the PSRP connection.
    :param logging_level:
        Logging level for message streams which are received during remote execution.
        The default is to include all messages in the task log.
    :param operation_timeout: Override the default WSMan timeout when polling the pipeline.
    :param runspace_options:
        Optional dictionary which is passed when creating the runspace pool. See
        :py:class:`~pypsrp.powershell.RunspacePool` for a description of the
        available options.
    :param wsman_options:
        Optional dictionary which is passed when creating the `WSMan` client. See
        :py:class:`~pypsrp.wsman.WSMan` for a description of the available options.
    :param on_output_callback:
        Optional callback function to be called whenever an output response item is
        received during job status polling.
    :param exchange_keys:
        If true (default), automatically initiate a session key exchange when the
        hook is used as a context manager.
    :param host:
        Optional PowerShell host instance. If this is not set, the default
        implementation will be used.

    You can provide an alternative `configuration_name` using either `runspace_options`
    or by setting this key as the extra fields of your connection.
    """

    conn_name_attr = "psrp_conn_id"
    default_conn_name = "psrp_default"
    conn_type = "psrp"
    hook_name = "PowerShell Remoting Protocol"

    # Runspace pool opened by ``__enter__``; ``None`` while not inside a context.
    _conn: RunspacePool | None = None
    # Maps each runspace pool built by ``get_conn`` to the WSMan client it uses.
    # Weak keys mean this (class-level, shared) mapping never keeps a pool alive.
    _wsman_ref: WeakKeyDictionary[RunspacePool, WSMan] = WeakKeyDictionary()

    def __init__(
        self,
        psrp_conn_id: str,
        logging_level: int = DEBUG,
        operation_timeout: int | None = None,
        runspace_options: dict[str, Any] | None = None,
        wsman_options: dict[str, Any] | None = None,
        on_output_callback: OutputCallback | None = None,
        exchange_keys: bool = True,
        host: PSHost | None = None,
    ):
        self.conn_id = psrp_conn_id
        self._logging_level = logging_level
        self._operation_timeout = operation_timeout
        self._runspace_options = runspace_options or {}
        self._wsman_options = wsman_options or {}
        self._on_output_callback = on_output_callback
        self._exchange_keys = exchange_keys
        self._host = host or PSHost(None, None, False, type(self).__name__, None, None, "1.0")

    def __enter__(self):
        """
        Open the WSMan client and runspace pool, optionally exchanging session keys.

        If any step fails, whatever was already opened is closed again before the
        exception propagates, so no connection is leaked.
        """
        conn = self.get_conn()
        with ExitStack() as stack:
            stack.enter_context(self._wsman_ref[conn])
            stack.enter_context(conn)
            if self._exchange_keys:
                conn.exchange_keys()
            # Setup succeeded: detach the cleanup callbacks so both resources stay
            # open until ``__exit__``.
            stack.pop_all()
        self._conn = conn
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Close the runspace pool and then its WSMan client.

        The WSMan client is closed even if closing the pool raises. Calling this
        without an open pool is a no-op.
        """
        conn = self._conn
        if conn is None:
            return
        try:
            conn.__exit__(exc_type, exc_value, traceback)
        finally:
            try:
                self._wsman_ref[conn].__exit__(exc_type, exc_value, traceback)
            finally:
                self._conn = None

    def get_conn(self) -> RunspacePool:
        """
        Return a runspace pool.

        The returned object must be used as a context manager.

        Supported WSMan option keys, and ``configuration_name`` for the runspace
        pool, may also be set in the connection extras. Non-``None`` extras
        override the corresponding ``wsman_options`` / ``runspace_options`` values.

        :raises AirflowException: If the connection extras contain unsupported keys.
        """
        conn = self.get_connection(self.conn_id)
        self.log.info("Establishing WinRM connection %s to host: %s", self.conn_id, conn.host)

        extra = conn.extra_dejson.copy()

        def apply_extra(d, keys):
            # Overlay extras for ``keys`` onto a copy of ``d``. Each key is popped
            # from ``extra`` so that leftover (unsupported) keys can be detected.
            d = d.copy()
            for key in keys:
                value = extra.pop(key, None)
                if value is not None:
                    d[key] = value
            return d

        wsman_options = apply_extra(
            self._wsman_options,
            (
                "auth",
                "cert_validation",
                "connection_timeout",
                "locale",
                "read_timeout",
                "reconnection_retries",
                "reconnection_backoff",
                "ssl",
            ),
        )
        runspace_options = apply_extra(self._runspace_options, ("configuration_name",))

        # Validate before building any client objects, so a misconfigured
        # connection fails fast.
        if extra:
            raise AirflowException(f"Unexpected extra configuration keys: {', '.join(sorted(extra))}")

        host = cast("str", conn.host)
        wsman = WSMan(host, username=conn.login, password=conn.password, **wsman_options)
        pool = RunspacePool(wsman, host=self._host, **runspace_options)
        # Remember which WSMan client backs this pool so __enter__/__exit__ can manage it.
        self._wsman_ref[pool] = wsman
        return pool

    @contextmanager
    def invoke(self) -> Generator[PowerShell, None, None]:
        """
        Yield a PowerShell object to which commands can be added.

        Upon exit, the commands will be invoked. Output records are passed to
        ``on_output_callback`` (if set) and all other stream records are logged.
        If the hook is not already inside a context, a runspace pool is opened
        just for this invocation and closed afterwards.

        :raises AirflowException: If the error stream contains any records.
        """
        logger = copy(self.log)
        logger.setLevel(self._logging_level)
        local_context = self._conn is None
        if local_context:
            self.__enter__()
        try:
            if TYPE_CHECKING:
                assert self._conn is not None
            ps = PowerShell(self._conn)
            yield ps
            ps.begin_invoke()
            streams = [
                ps.output,
                ps.streams.debug,
                ps.streams.error,
                ps.streams.information,
                ps.streams.progress,
                ps.streams.verbose,
                ps.streams.warning,
            ]
            # Per-stream count of records already handled, so each record is
            # processed exactly once across polls.
            offsets = [0] * len(streams)

            # We're using polling to make sure output and streams are
            # handled while the process is running.
            while ps.state == PSInvocationState.RUNNING:
                ps.poll_invoke(timeout=self._operation_timeout)
                for i, stream in enumerate(streams):
                    offset = offsets[i]
                    while len(stream) > offset:
                        record = stream[offset]
                        # Records received on the output stream during job
                        # status polling are handled via an optional callback,
                        # while the other streams are simply logged.
                        if stream is ps.output:
                            if self._on_output_callback is not None:
                                self._on_output_callback(record)
                        else:
                            self._log_record(logger.log, record)
                        offset += 1
                    offsets[i] = offset

            # For good measure, we'll make sure the process has
            # stopped running in any case.
            ps.end_invoke()

            self.log.info("Invocation state: %s", str(PSInvocationState(ps.state)))
            if ps.streams.error:
                raise AirflowException("Process had one or more errors")
        finally:
            if local_context:
                self.__exit__(None, None, None)

    def invoke_cmdlet(
        self,
        name: str,
        use_local_scope: bool | None = None,
        arguments: list[str] | None = None,
        parameters: dict[str, str] | None = None,
    ) -> PowerShell:
        """Invoke a PowerShell cmdlet and return session."""
        with self.invoke() as ps:
            ps.add_cmdlet(name, use_local_scope=use_local_scope)
            for argument in arguments or ():
                ps.add_argument(argument)
            if parameters:
                ps.add_parameters(parameters)
        return ps

    def invoke_powershell(self, script: str) -> PowerShell:
        """Invoke a PowerShell script and return session."""
        with self.invoke() as ps:
            ps.add_script(script)
        return ps

    def _log_record(self, log, record):
        """
        Write one PSRP stream record to ``log``.

        :param log: A ``Logger.log``-style callable taking ``(level, msg, *args)``.
        :param record: A record received from one of the PowerShell streams.
        """
        message_type = record.MESSAGE_TYPE
        if message_type == MessageType.ERROR_RECORD:
            # Error records additionally carry a reason and an optional script
            # stack trace; these are emitted at INFO before the message itself.
            log(INFO, "%s: %s", record.reason, record)
            if record.script_stacktrace:
                for trace in record.script_stacktrace.splitlines():
                    log(INFO, trace)

        level = INFORMATIONAL_RECORD_LEVEL_MAP.get(message_type)
        if level is not None:
            try:
                message = str(record.message)
            except Exception as exc:
                # See https://github.com/jborean93/pypsrp/pull/130
                message = str(exc)

            # Sometimes a message will have a trailing \r\n sequence such as
            # the tracing output of the Set-PSDebug cmdlet.
            message = message.rstrip()

            if record.command_name is None:
                log(level, "%s", message)
            else:
                log(level, "%s: %s", record.command_name, message)
        elif message_type == MessageType.INFORMATION_RECORD:
            log(INFO, "%s (%s): %s", record.computer, record.user, record.message_data)
        elif message_type == MessageType.PROGRESS_RECORD:
            log(INFO, "Progress: %s (%s)", record.activity, record.description)
        else:
            log(WARNING, "Unsupported message type: %s", message_type)

    def test_connection(self) -> tuple[bool, str]:
        """
        Test PSRP Connection.

        Opens and immediately closes a runspace pool for ``conn_id``.

        :return: ``(True, message)`` on success, otherwise ``(False, error message)``.
        """
        try:
            with PsrpHook(psrp_conn_id=self.conn_id):
                pass
        except Exception as e:
            return False, str(e)
        return True, "Connection successfully tested"

Was this entry helpful?