From c0467c9abd07e4014267c8a7b39927184a28fd76 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 11 Feb 2026 11:45:34 +0100 Subject: [PATCH 1/2] refactor: Added deferrable support to WinRMOperator --- providers/microsoft/winrm/provider.yaml | 5 + .../microsoft/winrm/get_provider_info.py | 6 + .../providers/microsoft/winrm/hooks/winrm.py | 209 ++++++++++++------ .../microsoft/winrm/operators/winrm.py | 141 ++++++++++-- .../microsoft/winrm/triggers/__init__.py | 17 ++ .../microsoft/winrm/triggers/winrm.py | 147 ++++++++++++ .../unit/microsoft/winrm/hooks/test_winrm.py | 7 +- .../microsoft/winrm/operators/test_winrm.py | 58 ++++- .../unit/microsoft/winrm/triggers/__init__.py | 17 ++ .../microsoft/winrm/triggers/test_winrm.py | 85 +++++++ 10 files changed, 589 insertions(+), 103 deletions(-) create mode 100644 providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/__init__.py create mode 100644 providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py create mode 100644 providers/microsoft/winrm/tests/unit/microsoft/winrm/triggers/__init__.py create mode 100644 providers/microsoft/winrm/tests/unit/microsoft/winrm/triggers/test_winrm.py diff --git a/providers/microsoft/winrm/provider.yaml b/providers/microsoft/winrm/provider.yaml index d8da1111e3544..8475d877b1e89 100644 --- a/providers/microsoft/winrm/provider.yaml +++ b/providers/microsoft/winrm/provider.yaml @@ -83,6 +83,11 @@ hooks: python-modules: - airflow.providers.microsoft.winrm.hooks.winrm +triggers: + - integration-name: Windows Remote Management (WinRM) + python-modules: + - airflow.providers.microsoft.winrm.triggers.winrm + connection-types: - hook-class-name: airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook connection-type: winrm diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/get_provider_info.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/get_provider_info.py index 84cddb048c1fa..8c9d6c56f2f7e 100644 
--- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/get_provider_info.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/get_provider_info.py @@ -47,6 +47,12 @@ def get_provider_info(): "python-modules": ["airflow.providers.microsoft.winrm.hooks.winrm"], } ], + "triggers": [ + { + "integration-name": "Windows Remote Management (WinRM)", + "python-modules": ["airflow.providers.microsoft.winrm.triggers.winrm"], + } + ], "connection-types": [ { "hook-class-name": "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook", diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py index 1178b3c91f1a4..2a16e427fe258 100644 --- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/hooks/winrm.py @@ -19,16 +19,20 @@ from __future__ import annotations +import logging from base64 import b64encode from contextlib import suppress +from typing import TYPE_CHECKING, Any, Literal, cast from winrm.exceptions import WinRMOperationTimeoutError from winrm.protocol import Protocol +from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, BaseHook from airflow.utils.platform import getuser -# TODO: FIXME please - I have too complex implementation +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Connection class WinRMHook(BaseHook): @@ -122,13 +126,10 @@ def __init__( self.credssp_disable_tlsv1_2 = credssp_disable_tlsv1_2 self.send_cbt = send_cbt - self.winrm_protocol = None - - def get_conn(self): - self.log.debug("Creating WinRM client for conn_id: %s", self.ssh_conn_id) - if self.ssh_conn_id is not None: - conn = self.get_connection(self.ssh_conn_id) + self.winrm_protocol: Protocol | None = None + def create_protocol(self, conn: 
Connection | None) -> Protocol: + if conn: if self.username is None: self.username = conn.login if self.password is None: @@ -192,36 +193,57 @@ def get_conn(self): self.endpoint = f"http://{self.remote_host}:{self.remote_port}/wsman" try: - if self.password and self.password.strip(): - self.winrm_protocol = Protocol( - endpoint=self.endpoint, - transport=self.transport, - username=self.username, - password=self.password, - service=self.service, - keytab=self.keytab, - ca_trust_path=self.ca_trust_path, - cert_pem=self.cert_pem, - cert_key_pem=self.cert_key_pem, - server_cert_validation=self.server_cert_validation, - kerberos_delegation=self.kerberos_delegation, - read_timeout_sec=self.read_timeout_sec, - operation_timeout_sec=self.operation_timeout_sec, - kerberos_hostname_override=self.kerberos_hostname_override, - message_encryption=self.message_encryption, - credssp_disable_tlsv1_2=self.credssp_disable_tlsv1_2, - send_cbt=self.send_cbt, + winrm_protocol = Protocol( + endpoint=self.endpoint, + transport=cast( + "Literal['auto', 'basic', 'certificate', 'ntlm', 'kerberos', 'credssp', 'plaintext', 'ssl']", + self.transport, + ), + username=self.username, + password=self.password, + service=self.service, + keytab=cast("Any", self.keytab), + ca_trust_path=cast("str | Literal['legacy_requests']", self.ca_trust_path), + cert_pem=self.cert_pem, + cert_key_pem=self.cert_key_pem, + server_cert_validation=cast( + "Literal['validate', 'ignore'] | None", self.server_cert_validation + ), + kerberos_delegation=self.kerberos_delegation, + read_timeout_sec=self.read_timeout_sec, + operation_timeout_sec=self.operation_timeout_sec, + kerberos_hostname_override=self.kerberos_hostname_override, + message_encryption=cast("Literal['auto', 'always', 'never']", self.message_encryption), + credssp_disable_tlsv1_2=self.credssp_disable_tlsv1_2, + send_cbt=self.send_cbt, + ) + + if not hasattr(winrm_protocol, "get_command_output_raw"): + # since pywinrm>=0.5 get_command_output_raw replace 
_raw_get_command_output + winrm_protocol.get_command_output_raw = ( # type: ignore[method-assign] + winrm_protocol._raw_get_command_output ) + self.log.info("Establishing WinRM connection to host: %s", self.remote_host) + + return winrm_protocol except Exception as error: error_msg = f"Error creating connection to host: {self.remote_host}, error: {error}" self.log.error(error_msg) raise AirflowException(error_msg) - if not hasattr(self.winrm_protocol, "get_command_output_raw"): - # since pywinrm>=0.5 get_command_output_raw replace _raw_get_command_output - self.winrm_protocol.get_command_output_raw = self.winrm_protocol._raw_get_command_output + def get_conn(self) -> Protocol: + if self.winrm_protocol is None: + self.winrm_protocol = self.create_protocol( + self.get_connection(self.ssh_conn_id) if self.ssh_conn_id else None + ) + return self.winrm_protocol + async def get_async_conn(self) -> Protocol: + if self.winrm_protocol is None: + self.winrm_protocol = self.create_protocol( + await get_async_connection(self.ssh_conn_id) if self.ssh_conn_id else None + ) return self.winrm_protocol def run( @@ -231,7 +253,7 @@ def run( output_encoding: str = "utf-8", return_output: bool = True, working_directory: str | None = None, - ) -> tuple[int, list[bytes], list[bytes]]: + ) -> tuple[int | None, list[bytes], list[bytes]]: """ Run a command. @@ -243,55 +265,104 @@ def run( :param working_directory: specify working directory. :return: returns a tuple containing return_code, stdout and stderr in order. 
""" - winrm_client = self.get_conn() - self.log.info("Establishing WinRM connection to host: %s", self.remote_host) - try: - shell_id = winrm_client.open_shell(working_directory=working_directory) - except Exception as error: - error_msg = f"Error connecting to host: {self.remote_host}, error: {error}" - self.log.error(error_msg) - raise AirflowException(error_msg) + conn = self.get_conn() + shell_id, command_id = self._run_command( + conn=conn, + command=command, + ps_path=ps_path, + working_directory=working_directory, + ) try: - if ps_path is not None: - self.log.info("Running command as powershell script: '%s'...", command) - encoded_ps = b64encode(command.encode("utf_16_le")).decode("ascii") - command_id = winrm_client.run_command(shell_id, f"{ps_path} -encodedcommand {encoded_ps}") - else: - self.log.info("Running command: '%s'...", command) - command_id = winrm_client.run_command(shell_id, command) - - # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py + command_done = False stdout_buffer = [] stderr_buffer = [] - command_done = False + return_code: int | None = None + + # See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py while not command_done: - # this is an expected error when waiting for a long-running process, just silently retry - with suppress(WinRMOperationTimeoutError): - ( - stdout, - stderr, - return_code, - command_done, - ) = winrm_client.get_command_output_raw(shell_id, command_id) - - # Only buffer stdout if we need to so that we minimize memory usage. 
- if return_output: - stdout_buffer.append(stdout) - stderr_buffer.append(stderr) - - for line in stdout.decode(output_encoding).splitlines(): - self.log.info(line) - for line in stderr.decode(output_encoding).splitlines(): - self.log.warning(line) - - winrm_client.cleanup_command(shell_id, command_id) + ( + stdout, + stderr, + return_code, + command_done, + ) = self.get_command_output(conn, shell_id, command_id, output_encoding) + + # Only buffer stdout if we need to so that we minimize memory usage. + if return_output: + stdout_buffer.append(stdout) + stderr_buffer.append(stderr) return return_code, stdout_buffer, stderr_buffer except Exception as e: raise AirflowException(f"WinRM operator error: {e}") finally: - winrm_client.close_shell(shell_id) + conn.cleanup_command(shell_id, command_id) + conn.close_shell(shell_id) + + def run_command( + self, + command: str, + ps_path: str | None = None, + working_directory: str | None = None, + ) -> tuple[str, str]: + return self._run_command(self.get_conn(), command, ps_path, working_directory) + + def _run_command( + self, + conn: Protocol, + command: str, + ps_path: str | None = None, + working_directory: str | None = None, + ) -> tuple[str, str]: + if not command: + raise AirflowException("No command specified so nothing to execute here.") + + try: + shell_id = conn.open_shell(working_directory=working_directory) + + if ps_path is not None: + self.log.info("Running command as powershell script: '%s'...", command) + encoded_ps = b64encode(command.encode("utf_16_le")).decode("ascii") + command_id = conn.run_command(shell_id, f"{ps_path} -encodedcommand {encoded_ps}") + else: + self.log.info("Running command: '%s'...", command) + command_id = conn.run_command(shell_id, command) + except Exception as error: + error_msg = f"Error connecting to host: {self.remote_host}, error: {error}" + self.log.error(error_msg) + raise AirflowException(error_msg) + + return shell_id, command_id + + def get_command_output( + self, conn: 
Protocol, shell_id: str, command_id: str, output_encoding: str = "utf-8" + ) -> tuple[bytes, bytes, int | None, bool]: + with suppress(WinRMOperationTimeoutError): + ( + stdout, + stderr, + return_code, + command_done, + ) = conn.get_command_output_raw(shell_id, command_id) + + self.log.debug("return_code: %s", return_code) + self.log.debug("command_done: %s", command_done) + self.log_output(stdout, output_encoding=output_encoding) + self.log_output(stderr, level=logging.WARNING, output_encoding=output_encoding) + + return stdout, stderr, return_code, command_done + return b"", b"", None, False + + def log_output( + self, + output: bytes | None, + level: int = logging.INFO, + output_encoding: str = "utf-8", + ) -> None: + if output: + for line in output.decode(output_encoding).splitlines(): + self.log.log(level, line) def test_connection(self): try: diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py index a360df4e1febc..c71e7eb6fcaee 100644 --- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/operators/winrm.py @@ -17,13 +17,19 @@ # under the License. 
from __future__ import annotations +import base64 import logging +import warnings from base64 import b64encode from collections.abc import Sequence -from typing import TYPE_CHECKING +from contextlib import suppress +from datetime import timedelta +from typing import TYPE_CHECKING, Any +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.common.compat.sdk import AirflowException, BaseOperator, conf from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook +from airflow.providers.microsoft.winrm.triggers.winrm import WinRMCommandOutputTrigger if TYPE_CHECKING: from airflow.sdk import Context @@ -46,9 +52,14 @@ class WinRMOperator(BaseOperator): :param ps_path: path to powershell, `powershell` for v5.1- and `pwsh` for v6+. If specified, it will execute the command as powershell script. :param output_encoding: the encoding used to decode stout and stderr - :param timeout: timeout for executing the command. + :param max_output_chunks: Maximum number of stdout/stderr chunks to keep in a rolling buffer to prevent + excessive memory usage for long-running commands in deferrable mode, defaults to 100. + :param timeout: deprecated, use execution_timeout instead, defaults to None. + :param poll_interval: How often, in seconds, the trigger should poll the output of the launched command, + defaults to 1. :param expected_return_code: expected return code value(s) of command. :param working_directory: specify working directory. 
+ :param deferrable: Run operator in the deferrable mode """ template_fields: Sequence[str] = ( @@ -66,9 +77,12 @@ def __init__( command: str | None = None, ps_path: str | None = None, output_encoding: str = "utf-8", - timeout: int = 10, + max_output_chunks: int = 100, + timeout: int | timedelta | None = None, + poll_interval: int | timedelta | None = None, expected_return_code: int | list[int] | range = 0, working_directory: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -78,37 +92,91 @@ def __init__( self.command = command self.ps_path = ps_path self.output_encoding = output_encoding - self.timeout = timeout + self.max_output_chunks = max_output_chunks + if timeout is not None: + warnings.warn( + "timeout is deprecated and will be removed. Please use execution_timeout instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + self.execution_timeout = ( + timedelta(seconds=timeout) if not isinstance(timeout, timedelta) else timeout + ) + self.poll_interval = ( + poll_interval.total_seconds() + if isinstance(poll_interval, timedelta) + else poll_interval + if poll_interval is not None + else 1.0 + ) self.expected_return_code = expected_return_code self.working_directory = working_directory + self.deferrable = deferrable - def execute(self, context: Context) -> list | str: - if self.ssh_conn_id and not self.winrm_hook: - self.log.info("Hook not found, creating...") - self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id) - + @property + def hook(self) -> WinRMHook: if not self.winrm_hook: - raise AirflowException("Cannot operate without winrm_hook or ssh_conn_id.") + if self.ssh_conn_id: + self.log.info("Hook not found, creating...") + self.winrm_hook = WinRMHook(ssh_conn_id=self.ssh_conn_id, remote_host=self.remote_host) + else: + raise AirflowException("Cannot operate without winrm_hook.") - if self.remote_host is not None: - 
self.winrm_hook.remote_host = self.remote_host + return self.winrm_hook + def execute(self, context: Context) -> list | str: if not self.command: raise AirflowException("No command specified so nothing to execute here.") - return_code, stdout_buffer, stderr_buffer = self.winrm_hook.run( + if self.deferrable: + if not self.hook.ssh_conn_id: + raise AirflowException("Cannot operate in deferrable mode without ssh_conn_id.") + + shell_id, command_id = self.hook.run_command( + command=self.command, + ps_path=self.ps_path, + working_directory=self.working_directory, + ) + return self.defer( + trigger=WinRMCommandOutputTrigger( + ssh_conn_id=self.hook.ssh_conn_id, + shell_id=shell_id, + command_id=command_id, + output_encoding=self.output_encoding, + return_output=self.do_xcom_push, + max_output_chunks=self.max_output_chunks, + poll_interval=self.poll_interval, + ), + method_name=self.execute_complete.__name__, + timeout=self.execution_timeout, + ) + + return_code, stdout_buffer, stderr_buffer = self.hook.run( command=self.command, ps_path=self.ps_path, output_encoding=self.output_encoding, return_output=self.do_xcom_push, working_directory=self.working_directory, ) + return self.evaluate_result(return_code, stdout_buffer, stderr_buffer) + + def validate_return_code(self, return_code: int | None) -> bool: + if return_code is not None: + if isinstance(self.expected_return_code, int): + return return_code == self.expected_return_code + if isinstance(self.expected_return_code, list) or isinstance(self.expected_return_code, range): + return return_code in self.expected_return_code + return False - success = False - if isinstance(self.expected_return_code, int): - success = return_code == self.expected_return_code - elif isinstance(self.expected_return_code, list) or isinstance(self.expected_return_code, range): - success = return_code in self.expected_return_code + def evaluate_result( + self, + return_code: int | None, + stdout_buffer: list[bytes], + stderr_buffer: 
list[bytes], + ) -> Any: + success = self.validate_return_code(return_code) + + self.log.debug("success: %s", success) if success: # returning output if do_xcom_push is set @@ -122,3 +190,40 @@ def execute(self, context: Context) -> list | str: stderr_output = b"".join(stderr_buffer).decode(self.output_encoding) error_msg = f"Error running cmd: {self.command}, return code: {return_code}, error: {stderr_output}" raise AirflowException(error_msg) + + def _decode(self, output: str) -> bytes: + decoded_output = base64.standard_b64decode(output) + self.hook.log_output(decoded_output, output_encoding=self.output_encoding) + return decoded_output + + def execute_complete( + self, + context: Context, + event: dict[Any, Any], + ) -> Any: + """ + Execute callback when WinRMCommandOutputTrigger finishes execution. + + This method gets executed automatically when WinRMCommandOutputTrigger completes its execution. + """ + status = event.get("status") + + if status == "error": + raise AirflowException(f"Trigger failed: {event.get('message')}") + + return_code = event.get("return_code") + + self.log.info("%s completed with %s", self.task_id, status) + + stdout = [self._decode(output) for output in event.get("stdout", [])] + stderr = [self._decode(output) for output in event.get("stderr", [])] + + try: + return self.evaluate_result(return_code, stdout, stderr) + finally: + conn = self.hook.get_conn() + if conn: + with suppress(Exception): + conn.cleanup_command(event["shell_id"], event["command_id"]) + with suppress(Exception): + conn.close_shell(event["shell_id"]) diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/__init__.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under 
one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py new file mode 100644 index 0000000000000..1ba9bfd549689 --- /dev/null +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py @@ -0,0 +1,147 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Trigger for winrm remote execution.""" + +from __future__ import annotations + +import asyncio +import base64 +from collections import deque +from collections.abc import AsyncIterator +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + +if TYPE_CHECKING: + from winrm import Protocol + + +class WinRMCommandOutputTrigger(BaseTrigger): + """ + A trigger that polls the command output executed by the WinRMHook. + + This trigger avoids blocking a worker when using the WinRMOperator in deferred mode. + + The behavior of this trigger is as follows: + - poll the command output from the shell launched by WinRM, + - if command not done then sleep and retry, + - when command done then return the output. + + :param ssh_conn_id: connection id from airflow Connections from where + all the required parameters can be fetched like username and password, + though priority is given to the params passed during init. + :param shell_id: The shell id on the remote machine. + :param command_id: The command id executed on the remote machine. + :param output_encoding: the encoding used to decode stdout and stderr, defaults to utf-8. + :param return_output: Whether to accumulate and return the stdout or not, defaults to True. + :param poll_interval: How often, in seconds, the trigger should poll the output of the launched command, + defaults to 1. + :param max_output_chunks: Maximum number of stdout/stderr chunks to keep in a rolling buffer to prevent + excessive memory usage for long-running commands, defaults to 100. 
+ """ + + def __init__( + self, + ssh_conn_id: str, + shell_id: str, + command_id: str, + output_encoding: str = "utf-8", + return_output: bool = True, + poll_interval: float = 1, + max_output_chunks: int = 100, + ) -> None: + super().__init__() + self.ssh_conn_id = ssh_conn_id + self.shell_id = shell_id + self.command_id = command_id + self.output_encoding = output_encoding + self.return_output = return_output + self.poll_interval = poll_interval + self._stdout: deque[str] = deque(maxlen=max_output_chunks) + self._stderr: deque[str] = deque(maxlen=max_output_chunks) + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize WinRMCommandOutputTrigger arguments and classpath.""" + return ( + f"{self.__class__.__module__}.{self.__class__.__name__}", + { + "ssh_conn_id": self.ssh_conn_id, + "shell_id": self.shell_id, + "command_id": self.command_id, + "output_encoding": self.output_encoding, + "return_output": self.return_output, + "poll_interval": self.poll_interval, + "max_output_chunks": self._stdout.maxlen, + }, + ) + + @cached_property + def hook(self) -> WinRMHook: + return WinRMHook(ssh_conn_id=self.ssh_conn_id) + + async def get_command_output(self, conn: Protocol) -> tuple[bytes, bytes, int | None, bool]: + from asgiref.sync import sync_to_async + + return await sync_to_async(self.hook.get_command_output)(conn, self.shell_id, self.command_id) + + async def run(self) -> AsyncIterator[TriggerEvent]: + command_done: bool = False + + try: + conn = await self.hook.get_async_conn() + while not command_done: + ( + stdout, + stderr, + return_code, + command_done, + ) = await self.get_command_output(conn) + + if self.return_output and stdout: + self._stdout.append(base64.standard_b64encode(stdout).decode(self.output_encoding)) + if stderr: + self._stderr.append(base64.standard_b64encode(stderr).decode(self.output_encoding)) + + if not command_done: + await asyncio.sleep(self.poll_interval) + continue + + yield TriggerEvent( + { + "status": "success", + 
"shell_id": self.shell_id, + "command_id": self.command_id, + "return_code": return_code, + "stdout": list(self._stdout), + "stderr": list(self._stderr), + } + ) + return + except Exception as e: + self.log.exception("An error occurred: %s", e) + yield TriggerEvent( + { + "status": "error", + "shell_id": self.shell_id, + "command_id": self.command_id, + "message": str(e), + } + ) + return diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py index cf348a35a8e36..9f39d937e03fd 100644 --- a/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py +++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/hooks/test_winrm.py @@ -21,14 +21,9 @@ import pytest -from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.common.compat.sdk import AirflowException, Connection from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook -try: - from airflow.sdk import Connection # type: ignore -except ImportError: - from airflow.models import Connection # type: ignore - class TestWinRMHook: def test_get_conn_missing_remote_host(self): diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py index 5a28926e22848..3544c499ebc45 100644 --- a/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py +++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/operators/test_winrm.py @@ -19,23 +19,28 @@ from base64 import b64encode from unittest import mock +from unittest.mock import AsyncMock, MagicMock import pytest from airflow.providers.common.compat.sdk import AirflowException +from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook from airflow.providers.microsoft.winrm.operators.winrm import WinRMOperator +from airflow.triggers.base import TriggerEvent + +from 
tests_common.test_utils.operators.run_deferrable import execute_operator class TestWinRMOperator: def test_no_winrm_hook_no_ssh_conn_id(self): - op = WinRMOperator(task_id="test_task_id", winrm_hook=None, ssh_conn_id=None) - exception_msg = "Cannot operate without winrm_hook or ssh_conn_id." + op = WinRMOperator(task_id="test_task_id", winrm_hook=None, ssh_conn_id=None, command="not_empty") + exception_msg = "Cannot operate without winrm_hook." with pytest.raises(AirflowException, match=exception_msg): op.execute(None) - @mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook") - def test_no_command(self, mock_hook): - op = WinRMOperator(task_id="test_task_id", winrm_hook=mock_hook, command=None) + def test_no_command(self): + winrm_hook = WinRMHook(transport="ntml", remote_host="localhost", password="secret") + op = WinRMOperator(task_id="test_task_id", winrm_hook=winrm_hook, command=None) exception_msg = "No command specified so nothing to execute here." with pytest.raises(AirflowException, match=exception_msg): op.execute(None) @@ -84,11 +89,7 @@ def test_expected_return_code_command(self, mock_hook, expected_return_code, rea expected_return_code=expected_return_code, ) - should_task_succeed = False - if isinstance(expected_return_code, int): - should_task_succeed = real_return_code == expected_return_code - elif isinstance(expected_return_code, list) or isinstance(expected_return_code, range): - should_task_succeed = real_return_code in expected_return_code + should_task_succeed = op.validate_return_code(real_return_code) if should_task_succeed: execute_result = op.execute(None) @@ -104,3 +105,40 @@ def test_expected_return_code_command(self, mock_hook, expected_return_code, rea exception_msg = f"Error running cmd: {command}, return code: {real_return_code}, error: KO" with pytest.raises(AirflowException, match=exception_msg): op.execute(None) + + @mock.patch("airflow.providers.microsoft.winrm.operators.winrm.WinRMHook") + 
@mock.patch("airflow.providers.microsoft.winrm.triggers.winrm.WinRMHook") + def test_execute_deferrable_success(self, mock_trigger_hook, mock_operator_hook): + mock_hook_instance = MagicMock(spec=WinRMHook) + mock_hook_instance.ssh_conn_id = "winrm_default" + mock_operator_hook.return_value = mock_hook_instance + mock_trigger_hook.return_value = mock_hook_instance + mock_conn = MagicMock() + mock_hook_instance.get_async_conn = AsyncMock(return_value=mock_conn) + mock_hook_instance.get_conn.return_value = mock_conn + mock_hook_instance.get_command_output.return_value = (b"hello", b"", 0, True) + mock_hook_instance.run_command.return_value = ( + "043E496C-A9E5-4284-AFCC-78A90E2BCB65", + "E4C36903-E59F-43AB-9374-ABA87509F46D", + ) + + operator = WinRMOperator( + task_id="test_task", + winrm_hook=mock_hook_instance, + command="dir", + deferrable=True, + ) + + result, events = execute_operator(operator) + + assert len(events) == 1 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload == { + "command_id": "E4C36903-E59F-43AB-9374-ABA87509F46D", + "return_code": 0, + "shell_id": "043E496C-A9E5-4284-AFCC-78A90E2BCB65", + "status": "success", + "stderr": [], + "stdout": ["aGVsbG8="], + } + assert result == "aGVsbG8=" diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/triggers/__init__.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/triggers/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/triggers/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/microsoft/winrm/tests/unit/microsoft/winrm/triggers/test_winrm.py b/providers/microsoft/winrm/tests/unit/microsoft/winrm/triggers/test_winrm.py new file mode 100644 index 0000000000000..511725e686714 --- /dev/null +++ b/providers/microsoft/winrm/tests/unit/microsoft/winrm/triggers/test_winrm.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +import base64 +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from airflow.providers.microsoft.winrm.triggers.winrm import WinRMCommandOutputTrigger +from airflow.triggers.base import TriggerEvent + + +class TestWinRMCommandOutputTrigger: + def test_serialize(self): + trigger = WinRMCommandOutputTrigger( + ssh_conn_id="ssh_conn_id", + shell_id="043E496C-A9E5-4284-AFCC-78A90E2BCB65", + command_id="E4C36903-E59F-43AB-9374-ABA87509F46D", + output_encoding="utf-8", + return_output=True, + poll_interval=10, + max_output_chunks=100, + ) + + actual = trigger.serialize() + + assert isinstance(actual, tuple) + assert actual[0] == f"{WinRMCommandOutputTrigger.__module__}.{WinRMCommandOutputTrigger.__name__}" + assert actual[1] == { + "ssh_conn_id": "ssh_conn_id", + "shell_id": "043E496C-A9E5-4284-AFCC-78A90E2BCB65", + "command_id": "E4C36903-E59F-43AB-9374-ABA87509F46D", + "output_encoding": "utf-8", + "return_output": True, + "poll_interval": 10, + "max_output_chunks": 100, + } + + @pytest.mark.asyncio + @patch("airflow.providers.microsoft.winrm.triggers.winrm.WinRMHook") + async def test_run(self, mock_hook): + trigger = WinRMCommandOutputTrigger( + ssh_conn_id="ssh_conn_id", + shell_id="043E496C-A9E5-4284-AFCC-78B5717DF4D73", + command_id="78CE100B-04FD-4EE2-8DAF-0751795661BB", + poll_interval=1, + ) + + mock_hook_instance = MagicMock() + mock_hook.return_value = mock_hook_instance + mock_conn = MagicMock() + mock_hook_instance.get_async_conn = AsyncMock(return_value=mock_conn) + mock_hook_instance.get_conn.return_value = mock_conn + mock_hook_instance.get_command_output.return_value = (b"hello", b"", 0, True) + + with patch.object(asyncio, "sleep", return_value=None): + events = [event async for event in trigger.run()] + + assert len(events) == 1 + event = events[0] + assert isinstance(event, TriggerEvent) + + payload = event.payload + assert payload["status"] == "success" + assert 
payload["shell_id"] == "043E496C-A9E5-4284-AFCC-78B5717DF4D73" + assert payload["command_id"] == "78CE100B-04FD-4EE2-8DAF-0751795661BB" + assert payload["return_code"] == 0 + assert base64.b64decode(payload["stdout"][0]) == b"hello" + assert not payload["stderr"] From 104c95a3950214a346ba2e4b22c816795ea692be Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 11 Feb 2026 13:43:51 +0100 Subject: [PATCH 2/2] Update providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../src/airflow/providers/microsoft/winrm/triggers/winrm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py index 1ba9bfd549689..61bb37b7d32dd 100644 --- a/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py +++ b/providers/microsoft/winrm/src/airflow/providers/microsoft/winrm/triggers/winrm.py @@ -15,7 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Hook for winrm remote execution.""" +"""Trigger for winrm remote execution.""" from __future__ import annotations