Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions providers/microsoft/winrm/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ hooks:
python-modules:
- airflow.providers.microsoft.winrm.hooks.winrm

triggers:
- integration-name: Windows Remote Management (WinRM)
python-modules:
- airflow.providers.microsoft.winrm.triggers.winrm

connection-types:
- hook-class-name: airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook
connection-type: winrm
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def get_provider_info():
"python-modules": ["airflow.providers.microsoft.winrm.hooks.winrm"],
}
],
"triggers": [
{
"integration-name": "Windows Remote Management (WinRM)",
"python-modules": ["airflow.providers.microsoft.winrm.triggers.winrm"],
}
],
"connection-types": [
{
"hook-class-name": "airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,20 @@

from __future__ import annotations

import logging
from base64 import b64encode
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Literal, cast

from winrm.exceptions import WinRMOperationTimeoutError
from winrm.protocol import Protocol

from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.utils.platform import getuser

# TODO: FIXME please - I have too complex implementation
if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Connection


class WinRMHook(BaseHook):
Expand Down Expand Up @@ -122,13 +126,10 @@ def __init__(
self.credssp_disable_tlsv1_2 = credssp_disable_tlsv1_2
self.send_cbt = send_cbt

self.winrm_protocol = None

def get_conn(self):
self.log.debug("Creating WinRM client for conn_id: %s", self.ssh_conn_id)
if self.ssh_conn_id is not None:
conn = self.get_connection(self.ssh_conn_id)
self.winrm_protocol: Protocol | None = None

def create_protocol(self, conn: Connection | None) -> Protocol:
if conn:
if self.username is None:
self.username = conn.login
if self.password is None:
Expand Down Expand Up @@ -192,36 +193,57 @@ def get_conn(self):
self.endpoint = f"http://{self.remote_host}:{self.remote_port}/wsman"

try:
if self.password and self.password.strip():
self.winrm_protocol = Protocol(
endpoint=self.endpoint,
transport=self.transport,
username=self.username,
password=self.password,
service=self.service,
keytab=self.keytab,
ca_trust_path=self.ca_trust_path,
cert_pem=self.cert_pem,
cert_key_pem=self.cert_key_pem,
server_cert_validation=self.server_cert_validation,
kerberos_delegation=self.kerberos_delegation,
read_timeout_sec=self.read_timeout_sec,
operation_timeout_sec=self.operation_timeout_sec,
kerberos_hostname_override=self.kerberos_hostname_override,
message_encryption=self.message_encryption,
credssp_disable_tlsv1_2=self.credssp_disable_tlsv1_2,
send_cbt=self.send_cbt,
winrm_protocol = Protocol(
endpoint=self.endpoint,
transport=cast(
"Literal['auto', 'basic', 'certificate', 'ntlm', 'kerberos', 'credssp', 'plaintext', 'ssl']",
self.transport,
),
username=self.username,
password=self.password,
service=self.service,
keytab=cast("Any", self.keytab),
ca_trust_path=cast("str | Literal['legacy_requests']", self.ca_trust_path),
cert_pem=self.cert_pem,
cert_key_pem=self.cert_key_pem,
server_cert_validation=cast(
"Literal['validate', 'ignore'] | None", self.server_cert_validation
),
kerberos_delegation=self.kerberos_delegation,
read_timeout_sec=self.read_timeout_sec,
operation_timeout_sec=self.operation_timeout_sec,
kerberos_hostname_override=self.kerberos_hostname_override,
message_encryption=cast("Literal['auto', 'always', 'never']", self.message_encryption),
credssp_disable_tlsv1_2=self.credssp_disable_tlsv1_2,
send_cbt=self.send_cbt,
)

if not hasattr(winrm_protocol, "get_command_output_raw"):
# since pywinrm>=0.5 get_command_output_raw replace _raw_get_command_output
winrm_protocol.get_command_output_raw = ( # type: ignore[method-assign]
winrm_protocol._raw_get_command_output
)

self.log.info("Establishing WinRM connection to host: %s", self.remote_host)

return winrm_protocol
except Exception as error:
error_msg = f"Error creating connection to host: {self.remote_host}, error: {error}"
self.log.error(error_msg)
raise AirflowException(error_msg)

if not hasattr(self.winrm_protocol, "get_command_output_raw"):
# since pywinrm>=0.5 get_command_output_raw replace _raw_get_command_output
self.winrm_protocol.get_command_output_raw = self.winrm_protocol._raw_get_command_output
def get_conn(self) -> Protocol:
if self.winrm_protocol is None:
self.winrm_protocol = self.create_protocol(
self.get_connection(self.ssh_conn_id) if self.ssh_conn_id else None
)
return self.winrm_protocol

async def get_async_conn(self) -> Protocol:
if self.winrm_protocol is None:
self.winrm_protocol = self.create_protocol(
await get_async_connection(self.ssh_conn_id) if self.ssh_conn_id else None
)
return self.winrm_protocol

def run(
Expand All @@ -231,7 +253,7 @@ def run(
output_encoding: str = "utf-8",
return_output: bool = True,
working_directory: str | None = None,
) -> tuple[int, list[bytes], list[bytes]]:
) -> tuple[int | None, list[bytes], list[bytes]]:
"""
Run a command.

Expand All @@ -243,55 +265,104 @@ def run(
:param working_directory: specify working directory.
:return: returns a tuple containing return_code, stdout and stderr in order.
"""
winrm_client = self.get_conn()
self.log.info("Establishing WinRM connection to host: %s", self.remote_host)
try:
shell_id = winrm_client.open_shell(working_directory=working_directory)
except Exception as error:
error_msg = f"Error connecting to host: {self.remote_host}, error: {error}"
self.log.error(error_msg)
raise AirflowException(error_msg)
conn = self.get_conn()
shell_id, command_id = self._run_command(
conn=conn,
command=command,
ps_path=ps_path,
working_directory=working_directory,
)

try:
if ps_path is not None:
self.log.info("Running command as powershell script: '%s'...", command)
encoded_ps = b64encode(command.encode("utf_16_le")).decode("ascii")
command_id = winrm_client.run_command(shell_id, f"{ps_path} -encodedcommand {encoded_ps}")
else:
self.log.info("Running command: '%s'...", command)
command_id = winrm_client.run_command(shell_id, command)

# See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
command_done = False
stdout_buffer = []
stderr_buffer = []
command_done = False
return_code: int | None = None

# See: https://github.com/diyan/pywinrm/blob/master/winrm/protocol.py
while not command_done:
# this is an expected error when waiting for a long-running process, just silently retry
with suppress(WinRMOperationTimeoutError):
(
stdout,
stderr,
return_code,
command_done,
) = winrm_client.get_command_output_raw(shell_id, command_id)

# Only buffer stdout if we need to so that we minimize memory usage.
if return_output:
stdout_buffer.append(stdout)
stderr_buffer.append(stderr)

for line in stdout.decode(output_encoding).splitlines():
self.log.info(line)
for line in stderr.decode(output_encoding).splitlines():
self.log.warning(line)

winrm_client.cleanup_command(shell_id, command_id)
(
stdout,
stderr,
return_code,
command_done,
) = self.get_command_output(conn, shell_id, command_id, output_encoding)

# Only buffer stdout if we need to so that we minimize memory usage.
if return_output:
stdout_buffer.append(stdout)
stderr_buffer.append(stderr)

return return_code, stdout_buffer, stderr_buffer
except Exception as e:
raise AirflowException(f"WinRM operator error: {e}")
finally:
winrm_client.close_shell(shell_id)
conn.cleanup_command(shell_id, command_id)
conn.close_shell(shell_id)

def run_command(
self,
command: str,
ps_path: str | None = None,
working_directory: str | None = None,
) -> tuple[str, str]:
return self._run_command(self.get_conn(), command, ps_path, working_directory)

def _run_command(
self,
conn: Protocol,
command: str,
ps_path: str | None = None,
working_directory: str | None = None,
) -> tuple[str, str]:
if not command:
raise AirflowException("No command specified so nothing to execute here.")

try:
shell_id = conn.open_shell(working_directory=working_directory)

if ps_path is not None:
self.log.info("Running command as powershell script: '%s'...", command)
encoded_ps = b64encode(command.encode("utf_16_le")).decode("ascii")
command_id = conn.run_command(shell_id, f"{ps_path} -encodedcommand {encoded_ps}")
else:
self.log.info("Running command: '%s'...", command)
command_id = conn.run_command(shell_id, command)
except Exception as error:
error_msg = f"Error connecting to host: {self.remote_host}, error: {error}"
self.log.error(error_msg)
raise AirflowException(error_msg)

return shell_id, command_id

def get_command_output(
self, conn: Protocol, shell_id: str, command_id: str, output_encoding: str = "utf-8"
) -> tuple[bytes, bytes, int | None, bool]:
with suppress(WinRMOperationTimeoutError):
(
stdout,
stderr,
return_code,
command_done,
) = conn.get_command_output_raw(shell_id, command_id)

self.log.debug("return_code: %s", return_code)
self.log.debug("command_done: %s", command_done)
self.log_output(stdout, output_encoding=output_encoding)
self.log_output(stderr, level=logging.WARNING, output_encoding=output_encoding)

return stdout, stderr, return_code, command_done
return b"", b"", None, False

def log_output(
self,
output: bytes | None,
level: int = logging.INFO,
output_encoding: str = "utf-8",
) -> None:
if output:
for line in output.decode(output_encoding).splitlines():
self.log.log(level, line)

def test_connection(self):
try:
Expand Down
Loading
Loading