# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import uuid
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from paramiko import SSHClient
[docs]
class TPTConfig:
"""Configuration constants for TPT operations."""
[docs]
FILE_PERMISSIONS_READ_ONLY = 0o400
[docs]
TEMP_DIR_WINDOWS = "C:\\Windows\\Temp"
[docs]
def execute_remote_command(ssh_client: SSHClient, command: str) -> tuple[int, str, str]:
"""
Execute a command on remote host and properly manage SSH channels.
:param ssh_client: SSH client connection
:param command: Command to execute
:return: Tuple of (exit_status, stdout, stderr)
"""
stdin, stdout, stderr = ssh_client.exec_command(command)
try:
exit_status = stdout.channel.recv_exit_status()
stdout_data = stdout.read().decode().strip()
stderr_data = stderr.read().decode().strip()
return exit_status, stdout_data, stderr_data
finally:
stdin.close()
stdout.close()
stderr.close()
[docs]
def write_file(path: str, content: str) -> None:
with open(path, "w", encoding="utf-8") as f:
f.write(content)
[docs]
def secure_delete(file_path: str, logger: logging.Logger | None = None) -> None:
"""
Securely delete a file using shred if available, otherwise use os.remove.
:param file_path: Path to the file to be deleted
:param logger: Optional logger instance
"""
logger = logger or logging.getLogger(__name__)
if not os.path.exists(file_path):
return
try:
# Check if shred is available
if shutil.which("shred") is not None:
# Use shred to securely delete the file
subprocess.run(["shred", "--remove", file_path], check=True, timeout=TPTConfig.DEFAULT_TIMEOUT)
logger.info("Securely removed file using shred: %s", file_path)
else:
# Fall back to regular deletion
os.remove(file_path)
logger.info("Removed file: %s", file_path)
except (OSError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
logger.warning("Failed to remove file %s: %s", file_path, str(e))
[docs]
def remote_secure_delete(
ssh_client: SSHClient, remote_files: list[str], logger: logging.Logger | None = None
) -> None:
"""
Securely delete remote files via SSH. Attempts shred first, falls back to rm if shred is unavailable.
:param ssh_client: SSH client connection
:param remote_files: List of remote file paths to delete
:param logger: Optional logger instance
"""
logger = logger or logging.getLogger(__name__)
if not ssh_client or not remote_files:
return
try:
# Detect remote OS
remote_os = get_remote_os(ssh_client, logger)
windows_remote = remote_os == "windows"
# Check if shred is available on remote system (UNIX/Linux)
shred_available = False
if not windows_remote:
exit_status, output, _ = execute_remote_command(ssh_client, "command -v shred")
shred_available = exit_status == 0 and output.strip() != ""
for file_path in remote_files:
try:
if windows_remote:
# Windows remote host - use del command
replace_slash = file_path.replace("/", "\\")
execute_remote_command(
ssh_client, f'if exist "{replace_slash}" del /f /q "{replace_slash}"'
)
elif shred_available:
# UNIX/Linux with shred
execute_remote_command(ssh_client, f"shred --remove {file_path}")
else:
# UNIX/Linux without shred - overwrite then delete
execute_remote_command(
ssh_client,
f"if [ -f {file_path} ]; then "
f"dd if=/dev/zero of={file_path} bs=4096 count=$(($(stat -c '%s' {file_path})/4096+1)) 2>/dev/null; "
f"rm -f {file_path}; fi",
)
except Exception as e:
logger.warning("Failed to process remote file %s: %s", file_path, str(e))
logger.info("Processed remote files: %s", ", ".join(remote_files))
except Exception as e:
logger.warning("Failed to remove remote files: %s", str(e))
[docs]
def terminate_subprocess(sp: subprocess.Popen | None, logger: logging.Logger | None = None) -> None:
"""
Terminate a subprocess gracefully with proper error handling.
:param sp: Subprocess to terminate
:param logger: Optional logger instance
"""
logger = logger or logging.getLogger(__name__)
if not sp or sp.poll() is not None:
# Process is None or already terminated
return
logger.info("Terminating subprocess (PID: %s)", sp.pid)
try:
sp.terminate() # Attempt to terminate gracefully
sp.wait(timeout=TPTConfig.DEFAULT_TIMEOUT)
logger.info("Subprocess terminated gracefully")
except subprocess.TimeoutExpired:
logger.warning(
"Subprocess did not terminate gracefully within %d seconds, killing it", TPTConfig.DEFAULT_TIMEOUT
)
try:
sp.kill()
sp.wait(timeout=2) # Brief wait after kill
logger.info("Subprocess killed successfully")
except Exception as e:
logger.error("Error killing subprocess: %s", str(e))
except Exception as e:
logger.error("Error terminating subprocess: %s", str(e))
[docs]
def get_remote_os(ssh_client: SSHClient, logger: logging.Logger | None = None) -> str:
"""
Detect the operating system of the remote host via SSH.
:param ssh_client: SSH client connection
:param logger: Optional logger instance
:return: Operating system type as string ('windows' or 'unix')
"""
logger = logger or logging.getLogger(__name__)
if not ssh_client:
logger.warning("No SSH client provided for OS detection")
return "unix"
try:
# Check for Windows first
exit_status, stdout_data, stderr_data = execute_remote_command(ssh_client, "echo %OS%")
if "Windows" in stdout_data:
return "windows"
# All other systems are treated as Unix-like
return "unix"
except Exception as e:
logger.error("Error detecting remote OS: %s", str(e))
return "unix"
[docs]
def set_local_file_permissions(local_file_path: str, logger: logging.Logger | None = None) -> None:
"""
Set permissions for a local file to be read-only for the owner.
:param local_file_path: Path to the local file
:param logger: Optional logger instance
:raises FileNotFoundError: If the file does not exist
:raises OSError: If setting permissions fails
"""
logger = logger or logging.getLogger(__name__)
if not local_file_path:
logger.warning("No file path provided for permission setting")
return
if not os.path.exists(local_file_path):
raise FileNotFoundError(f"File does not exist: {local_file_path}")
try:
# Set file permission to read-only for the owner (400)
os.chmod(local_file_path, TPTConfig.FILE_PERMISSIONS_READ_ONLY)
logger.info("Set read-only permissions for file %s", local_file_path)
except (OSError, PermissionError) as e:
raise OSError(f"Error setting permissions for local file {local_file_path}: {e}") from e
def _set_windows_file_permissions(
ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger
) -> None:
"""Set restrictive permissions on Windows remote file."""
command = f'icacls "{remote_file_path}" /inheritance:r /grant:r "%USERNAME%":R'
exit_status, stdout_data, stderr_data = execute_remote_command(ssh_client, command)
if exit_status != 0:
raise RuntimeError(
f"Failed to set restrictive permissions on Windows remote file {remote_file_path}. "
f"Exit status: {exit_status}, Error: {stderr_data if stderr_data else 'N/A'}"
)
logger.info("Set restrictive permissions (owner read-only) for Windows remote file %s", remote_file_path)
def _set_unix_file_permissions(ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger) -> None:
"""Set read-only permissions on Unix/Linux remote file."""
command = f"chmod 400 {remote_file_path}"
exit_status, stdout_data, stderr_data = execute_remote_command(ssh_client, command)
if exit_status != 0:
raise RuntimeError(
f"Failed to set permissions (400) on remote file {remote_file_path}. "
f"Exit status: {exit_status}, Error: {stderr_data if stderr_data else 'N/A'}"
)
logger.info("Set read-only permissions for remote file %s", remote_file_path)
[docs]
def set_remote_file_permissions(
ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger | None = None
) -> None:
"""
Set permissions for a remote file to be read-only for the owner.
:param ssh_client: SSH client connection
:param remote_file_path: Path to the remote file
:param logger: Optional logger instance
:raises RuntimeError: If permission setting fails
"""
logger = logger or logging.getLogger(__name__)
if not ssh_client or not remote_file_path:
logger.warning(
"Invalid parameters: ssh_client=%s, remote_file_path=%s", bool(ssh_client), remote_file_path
)
return
try:
# Detect remote OS once
remote_os = get_remote_os(ssh_client, logger)
if remote_os == "windows":
_set_windows_file_permissions(ssh_client, remote_file_path, logger)
else:
_set_unix_file_permissions(ssh_client, remote_file_path, logger)
except RuntimeError:
raise
except Exception as e:
raise RuntimeError(f"Error setting permissions for remote file {remote_file_path}: {e}") from e
[docs]
def get_remote_temp_directory(ssh_client: SSHClient, logger: logging.Logger | None = None) -> str:
"""
Get the remote temporary directory path based on the operating system.
:param ssh_client: SSH client connection
:param logger: Optional logger instance
:return: Path to the remote temporary directory
"""
logger = logger or logging.getLogger(__name__)
try:
# Detect OS once
remote_os = get_remote_os(ssh_client, logger)
if remote_os == "windows":
exit_status, temp_dir, stderr_data = execute_remote_command(ssh_client, "echo %TEMP%")
if exit_status == 0 and temp_dir and temp_dir != "%TEMP%":
return temp_dir
logger.warning("Could not get TEMP directory, using default: %s", TPTConfig.TEMP_DIR_WINDOWS)
return TPTConfig.TEMP_DIR_WINDOWS
# Unix/Linux - use /tmp
return TPTConfig.TEMP_DIR_UNIX
except Exception as e:
logger.warning("Error getting remote temp directory: %s", str(e))
return TPTConfig.TEMP_DIR_UNIX
[docs]
def verify_tpt_utility_installed(utility: str) -> None:
"""Verify if a TPT utility (e.g., tbuild) is installed and available in the system's PATH."""
if shutil.which(utility) is None:
raise FileNotFoundError(
f"TPT utility '{utility}' is not installed or not available in the system's PATH"
)
[docs]
def verify_tpt_utility_on_remote_host(
ssh_client: SSHClient, utility: str, logger: logging.Logger | None = None
) -> None:
"""
Verify if a TPT utility (tbuild) is installed on the remote host via SSH.
:param ssh_client: SSH client connection
:param utility: Name of the utility to verify
:param logger: Optional logger instance
:raises FileNotFoundError: If utility is not found on remote host
:raises RuntimeError: If verification fails unexpectedly
"""
logger = logger or logging.getLogger(__name__)
try:
# Detect remote OS once
remote_os = get_remote_os(ssh_client, logger)
if remote_os == "windows":
command = f"where {utility}"
else:
command = f"which {utility}"
exit_status, output, error = execute_remote_command(ssh_client, command)
if exit_status != 0 or not output:
raise FileNotFoundError(
f"TPT utility '{utility}' is not installed or not available in PATH on the remote host. "
f"Command: {command}, Exit status: {exit_status}, "
f"stderr: {error if error else 'N/A'}"
)
logger.info("TPT utility '%s' found at: %s", utility, output.split("\n")[0])
except (FileNotFoundError, RuntimeError):
raise
except Exception as e:
raise RuntimeError(f"Failed to verify TPT utility '{utility}' on remote host: {e}") from e
[docs]
def prepare_tpt_ddl_script(
sql: list[str],
error_list: list[int] | None,
source_conn: dict[str, Any],
job_name: str | None = None,
) -> str:
"""
Prepare a TPT script for executing DDL statements.
This method generates a TPT script that defines a DDL operator and applies the provided SQL statements.
It also supports specifying a list of error codes to handle during the operation.
:param sql: A list of DDL statements to execute.
:param error_list: A list of error codes to handle during the operation.
:param source_conn: Connection details for the source database.
:param job_name: The name of the TPT job. Defaults to unique name if None.
:return: A formatted TPT script as a string.
:raises ValueError: If the SQL statement list is empty.
"""
if not sql or not isinstance(sql, list):
raise ValueError("SQL statement list must be a non-empty list")
# Clean and escape each SQL statement:
sql_statements = [
stmt.strip().rstrip(";").replace("'", "''")
for stmt in sql
if stmt and isinstance(stmt, str) and stmt.strip()
]
if not sql_statements:
raise ValueError("No valid SQL statements found in the provided input")
# Format for TPT APPLY block, indenting after the first line
apply_sql = ",\n".join(
[f"('{stmt};')" if i == 0 else f" ('{stmt};')" for i, stmt in enumerate(sql_statements)]
)
if job_name is None:
job_name = f"airflow_tptddl_{uuid.uuid4().hex}"
# Format error list for inclusion in the TPT script
if not error_list:
error_list_stmt = "ErrorList = ['']"
else:
error_list_str = ", ".join([f"'{error}'" for error in error_list])
error_list_stmt = f"ErrorList = [{error_list_str}]"
host = source_conn["host"]
login = source_conn["login"]
password = source_conn["password"]
tpt_script = f"""
DEFINE JOB {job_name}
DESCRIPTION 'TPT DDL Operation'
(
APPLY
{apply_sql}
TO OPERATOR ( $DDL ()
ATTR
(
TdpId = '{host}',
UserName = '{login}',
UserPassword = '{password}',
{error_list_stmt}
)
);
);
"""
return tpt_script
[docs]
def decrypt_remote_file(
ssh_client: SSHClient,
remote_enc_file: str,
remote_dec_file: str,
password: str,
logger: logging.Logger | None = None,
) -> int:
"""
Decrypt a remote file using OpenSSL.
:param ssh_client: SSH client connection
:param remote_enc_file: Path to the encrypted file
:param remote_dec_file: Path for the decrypted file
:param password: Decryption password
:param logger: Optional logger instance
:return: Exit status of the decryption command
:raises RuntimeError: If decryption fails
"""
logger = logger or logging.getLogger(__name__)
# Detect remote OS
remote_os = get_remote_os(ssh_client, logger)
windows_remote = remote_os == "windows"
if windows_remote:
# Windows - use different quoting and potentially different OpenSSL parameters
password_escaped = password.replace('"', '""') # Escape double quotes for Windows
decrypt_cmd = (
f'openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:"{password_escaped}" '
f'-in "{remote_enc_file}" -out "{remote_dec_file}"'
)
else:
# Unix/Linux - use single quote escaping
password_escaped = password.replace("'", "'\\''") # Escape single quotes
decrypt_cmd = (
f"openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:'{password_escaped}' "
f"-in {remote_enc_file} -out {remote_dec_file}"
)
exit_status, stdout_data, stderr_data = execute_remote_command(ssh_client, decrypt_cmd)
if exit_status != 0:
raise RuntimeError(
f"Decryption failed with exit status {exit_status}. Error: {stderr_data if stderr_data else 'N/A'}"
)
logger.info("Successfully decrypted remote file %s to %s", remote_enc_file, remote_dec_file)
return exit_status
[docs]
def transfer_file_sftp(
ssh_client: SSHClient, local_path: str, remote_path: str, logger: logging.Logger | None = None
) -> None:
"""
Transfer a file from local to remote host using SFTP.
:param ssh_client: SSH client connection
:param local_path: Local file path
:param remote_path: Remote file path
:param logger: Optional logger instance
:raises FileNotFoundError: If local file does not exist
:raises RuntimeError: If file transfer fails
"""
logger = logger or logging.getLogger(__name__)
if not os.path.exists(local_path):
raise FileNotFoundError(f"Local file does not exist: {local_path}")
sftp = None
try:
sftp = ssh_client.open_sftp()
sftp.put(local_path, remote_path)
logger.info("Successfully transferred file from %s to %s", local_path, remote_path)
except Exception as e:
raise RuntimeError(f"Failed to transfer file from {local_path} to {remote_path}: {e}") from e
finally:
if sftp:
sftp.close()