Source code for airflow.providers.teradata.utils.tpt_util

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import os
import shlex
import shutil
import subprocess
import uuid
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from paramiko import SSHClient


class TPTConfig:
    """Configuration constants for TPT operations."""

    # Seconds to wait for local child processes (``shred``, a terminating subprocess).
    DEFAULT_TIMEOUT: int = 5

    # File mode applied to sensitive files: readable by the owner only.
    FILE_PERMISSIONS_READ_ONLY: int = 0o400

    # Fallback temporary directories used when resolving the remote temp directory.
    TEMP_DIR_WINDOWS: str = "C:\\Windows\\Temp"
    TEMP_DIR_UNIX: str = "/tmp"
def execute_remote_command(ssh_client: SSHClient, command: str) -> tuple[int, str, str]:
    """
    Execute a command on remote host and properly manage SSH channels.

    Both output streams are read to completion *before* the exit status is
    requested. Asking for the exit status first can block forever when the
    remote command produces more output than the SSH channel window can hold,
    because nothing is consuming the data the remote side is trying to send.

    Output is decoded as UTF-8; bytes that are not valid UTF-8 are replaced
    rather than raising, so unexpected remote encodings cannot abort the call.

    :param ssh_client: SSH client connection
    :param command: Command to execute
    :return: Tuple of (exit_status, stdout, stderr) with surrounding whitespace stripped
    """
    stdin, stdout, stderr = ssh_client.exec_command(command)
    try:
        # Drain the streams first so the remote process is never stalled on a full window.
        stdout_data = stdout.read().decode("utf-8", errors="replace").strip()
        stderr_data = stderr.read().decode("utf-8", errors="replace").strip()
        exit_status = stdout.channel.recv_exit_status()
        return exit_status, stdout_data, stderr_data
    finally:
        stdin.close()
        stdout.close()
        stderr.close()
def write_file(path: str, content: str) -> None:
    """
    Write text to a file, replacing whatever the file held before.

    :param path: Destination file path
    :param content: Text to store; encoded as UTF-8
    """
    with open(path, mode="w", encoding="utf-8") as handle:
        handle.write(content)
def secure_delete(file_path: str, logger: logging.Logger | None = None) -> None:
    """
    Remove a local file, overwriting it with ``shred`` when that tool exists.

    A missing file is silently ignored. Failures while deleting are logged as
    warnings and never raised to the caller.

    :param file_path: Path to the file to be deleted
    :param logger: Optional logger instance
    """
    log = logger or logging.getLogger(__name__)

    if not os.path.exists(file_path):
        return

    try:
        if shutil.which("shred") is None:
            # No shred binary on PATH: plain deletion is the only option.
            os.remove(file_path)
            log.info("Removed file: %s", file_path)
            return

        # shred overwrites the file contents and then unlinks it (--remove).
        subprocess.run(["shred", "--remove", file_path], check=True, timeout=TPTConfig.DEFAULT_TIMEOUT)
        log.info("Securely removed file using shred: %s", file_path)
    except (OSError, subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
        log.warning("Failed to remove file %s: %s", file_path, str(e))
def remote_secure_delete(
    ssh_client: SSHClient, remote_files: list[str], logger: logging.Logger | None = None
) -> None:
    """
    Securely delete remote files via SSH.

    On Unix-like hosts ``shred`` is used when available; otherwise the file is
    overwritten with zeros via ``dd`` and then removed. On Windows hosts the
    file is removed with ``del``. Unix paths are shell-quoted before being
    placed in a command, so paths containing spaces or shell metacharacters
    are treated as a single literal argument.

    This is best-effort cleanup: every failure is logged as a warning and
    nothing is raised.

    :param ssh_client: SSH client connection
    :param remote_files: List of remote file paths to delete
    :param logger: Optional logger instance
    """
    logger = logger or logging.getLogger(__name__)
    if not ssh_client or not remote_files:
        return

    try:
        # Detect remote OS once for the whole batch
        remote_os = get_remote_os(ssh_client, logger)
        windows_remote = remote_os == "windows"

        # Check if shred is available on remote system (UNIX/Linux)
        shred_available = False
        if not windows_remote:
            exit_status, output, _ = execute_remote_command(ssh_client, "command -v shred")
            shred_available = exit_status == 0 and output.strip() != ""

        for file_path in remote_files:
            try:
                if windows_remote:
                    # Windows remote host - use del command with backslash separators
                    replace_slash = file_path.replace("/", "\\")
                    command = f'if exist "{replace_slash}" del /f /q "{replace_slash}"'
                else:
                    # Quote once; the same literal is reused everywhere the path appears.
                    quoted_path = shlex.quote(file_path)
                    if shred_available:
                        # UNIX/Linux with shred
                        command = f"shred --remove {quoted_path}"
                    else:
                        # UNIX/Linux without shred - overwrite with zeros, then delete.
                        # The block count is the file size rounded up to whole 4096-byte blocks.
                        command = (
                            f"if [ -f {quoted_path} ]; then "
                            f"dd if=/dev/zero of={quoted_path} bs=4096 "
                            f"count=$(($(stat -c '%s' {quoted_path})/4096+1)) 2>/dev/null; "
                            f"rm -f {quoted_path}; fi"
                        )
                execute_remote_command(ssh_client, command)
            except Exception as e:
                # One bad file must not stop cleanup of the remaining ones.
                logger.warning("Failed to process remote file %s: %s", file_path, str(e))

        logger.info("Processed remote files: %s", ", ".join(remote_files))
    except Exception as e:
        logger.warning("Failed to remove remote files: %s", str(e))
def terminate_subprocess(sp: subprocess.Popen | None, logger: logging.Logger | None = None) -> None:
    """
    Stop a running subprocess, escalating from terminate to kill if needed.

    Nothing happens when ``sp`` is missing or has already exited. Errors raised
    while stopping the process are logged and not propagated.

    :param sp: Subprocess to terminate
    :param logger: Optional logger instance
    """
    log = logger or logging.getLogger(__name__)

    if not sp:
        return
    if sp.poll() is not None:
        # Already exited; there is nothing left to stop.
        return

    log.info("Terminating subprocess (PID: %s)", sp.pid)
    grace_seconds = TPTConfig.DEFAULT_TIMEOUT

    try:
        # Ask politely first and give the process a grace period to exit.
        sp.terminate()
        sp.wait(timeout=grace_seconds)
        log.info("Subprocess terminated gracefully")
    except subprocess.TimeoutExpired:
        log.warning(
            "Subprocess did not terminate gracefully within %d seconds, killing it", grace_seconds
        )
        try:
            sp.kill()
            sp.wait(timeout=2)  # Brief wait so the kill is actually reaped
            log.info("Subprocess killed successfully")
        except Exception as e:
            log.error("Error killing subprocess: %s", str(e))
    except Exception as e:
        log.error("Error terminating subprocess: %s", str(e))
def get_remote_os(ssh_client: SSHClient, logger: logging.Logger | None = None) -> str:
    """
    Work out whether the remote host behind an SSH connection runs Windows.

    The probe echoes ``%OS%``; the host is classified as Windows only when the
    output contains ``Windows``. A missing client or any error during the probe
    falls back to ``'unix'``.

    :param ssh_client: SSH client connection
    :param logger: Optional logger instance
    :return: Operating system type as string ('windows' or 'unix')
    """
    log = logger or logging.getLogger(__name__)

    if not ssh_client:
        log.warning("No SSH client provided for OS detection")
        return "unix"

    try:
        _status, echoed, _errors = execute_remote_command(ssh_client, "echo %OS%")
        # Anything that does not identify itself as Windows is treated as Unix-like.
        return "windows" if "Windows" in echoed else "unix"
    except Exception as e:
        log.error("Error detecting remote OS: %s", str(e))
        return "unix"
def set_local_file_permissions(local_file_path: str, logger: logging.Logger | None = None) -> None:
    """
    Restrict a local file so that only its owner can read it.

    An empty path is logged as a warning and ignored.

    :param local_file_path: Path to the local file
    :param logger: Optional logger instance
    :raises FileNotFoundError: If the file does not exist
    :raises OSError: If setting permissions fails
    """
    log = logger or logging.getLogger(__name__)

    if not local_file_path:
        log.warning("No file path provided for permission setting")
        return

    file_is_present = os.path.exists(local_file_path)
    if not file_is_present:
        raise FileNotFoundError(f"File does not exist: {local_file_path}")

    try:
        # Owner read-only (mode 400)
        os.chmod(local_file_path, TPTConfig.FILE_PERMISSIONS_READ_ONLY)
        log.info("Set read-only permissions for file %s", local_file_path)
    except OSError as e:  # PermissionError is an OSError subclass, so it is covered here
        raise OSError(f"Error setting permissions for local file {local_file_path}: {e}") from e
def _set_windows_file_permissions(
    ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger
) -> None:
    """
    Set restrictive permissions on Windows remote file.

    Runs ``icacls`` with ``/inheritance:r`` and ``/grant:r "%USERNAME%":R`` so
    the connecting user is granted read access to the file.

    :param ssh_client: SSH client connection
    :param remote_file_path: Path to the remote file
    :param logger: Logger instance
    :raises RuntimeError: If ``icacls`` exits with a non-zero status
    """
    command = f'icacls "{remote_file_path}" /inheritance:r /grant:r "%USERNAME%":R'
    exit_status, _, stderr_data = execute_remote_command(ssh_client, command)

    if exit_status != 0:
        raise RuntimeError(
            f"Failed to set restrictive permissions on Windows remote file {remote_file_path}. "
            f"Exit status: {exit_status}, Error: {stderr_data if stderr_data else 'N/A'}"
        )
    logger.info("Set restrictive permissions (owner read-only) for Windows remote file %s", remote_file_path)


def _set_unix_file_permissions(ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger) -> None:
    """
    Set read-only permissions on Unix/Linux remote file.

    The path is shell-quoted so that spaces or shell metacharacters in it are
    passed to ``chmod`` as one literal argument.

    :param ssh_client: SSH client connection
    :param remote_file_path: Path to the remote file
    :param logger: Logger instance
    :raises RuntimeError: If ``chmod`` exits with a non-zero status
    """
    command = f"chmod 400 {shlex.quote(remote_file_path)}"
    exit_status, _, stderr_data = execute_remote_command(ssh_client, command)

    if exit_status != 0:
        raise RuntimeError(
            f"Failed to set permissions (400) on remote file {remote_file_path}. "
            f"Exit status: {exit_status}, Error: {stderr_data if stderr_data else 'N/A'}"
        )
    logger.info("Set read-only permissions for remote file %s", remote_file_path)
def set_remote_file_permissions(
    ssh_client: SSHClient, remote_file_path: str, logger: logging.Logger | None = None
) -> None:
    """
    Make a remote file read-only for its owner, whatever the remote OS is.

    When either the client or the path is missing, a warning is logged and the
    call returns without doing anything.

    :param ssh_client: SSH client connection
    :param remote_file_path: Path to the remote file
    :param logger: Optional logger instance
    :raises RuntimeError: If permission setting fails
    """
    log = logger or logging.getLogger(__name__)

    if not (ssh_client and remote_file_path):
        log.warning(
            "Invalid parameters: ssh_client=%s, remote_file_path=%s", bool(ssh_client), remote_file_path
        )
        return

    try:
        # The OS is probed once and decides which platform-specific helper runs.
        apply_permissions = (
            _set_windows_file_permissions
            if get_remote_os(ssh_client, log) == "windows"
            else _set_unix_file_permissions
        )
        apply_permissions(ssh_client, remote_file_path, log)
    except RuntimeError:
        # Already the documented failure type; pass it through untouched.
        raise
    except Exception as e:
        raise RuntimeError(f"Error setting permissions for remote file {remote_file_path}: {e}") from e
def get_remote_temp_directory(ssh_client: SSHClient, logger: logging.Logger | None = None) -> str:
    """
    Resolve the temporary directory to use on the remote host.

    Unix-like hosts always get ``TPTConfig.TEMP_DIR_UNIX``. Windows hosts are
    asked for ``%TEMP%``; if that cannot be resolved,
    ``TPTConfig.TEMP_DIR_WINDOWS`` is used. Any error falls back to the Unix
    default.

    :param ssh_client: SSH client connection
    :param logger: Optional logger instance
    :return: Path to the remote temporary directory
    """
    log = logger or logging.getLogger(__name__)

    try:
        if get_remote_os(ssh_client, log) != "windows":
            # Unix/Linux - use /tmp
            return TPTConfig.TEMP_DIR_UNIX

        status, remote_temp, _errors = execute_remote_command(ssh_client, "echo %TEMP%")
        # An unexpanded "%TEMP%" means the variable is not defined on the remote host.
        if status != 0 or not remote_temp or remote_temp == "%TEMP%":
            log.warning("Could not get TEMP directory, using default: %s", TPTConfig.TEMP_DIR_WINDOWS)
            return TPTConfig.TEMP_DIR_WINDOWS
        return remote_temp
    except Exception as e:
        log.warning("Error getting remote temp directory: %s", str(e))
        return TPTConfig.TEMP_DIR_UNIX
def verify_tpt_utility_installed(utility: str) -> None:
    """
    Check that a TPT utility (e.g., tbuild) can be found on the local PATH.

    :param utility: Name of the executable to look up
    :raises FileNotFoundError: If the executable cannot be located
    """
    located = shutil.which(utility)
    if located is not None:
        return
    raise FileNotFoundError(
        f"TPT utility '{utility}' is not installed or not available in the system's PATH"
    )
def verify_tpt_utility_on_remote_host(
    ssh_client: SSHClient, utility: str, logger: logging.Logger | None = None
) -> None:
    """
    Check over SSH that a TPT utility (tbuild) is on the remote host's PATH.

    The lookup uses ``where`` on Windows hosts and ``which`` everywhere else.

    :param ssh_client: SSH client connection
    :param utility: Name of the utility to verify
    :param logger: Optional logger instance
    :raises FileNotFoundError: If utility is not found on remote host
    :raises RuntimeError: If verification fails unexpectedly
    """
    log = logger or logging.getLogger(__name__)

    try:
        # Pick the platform's lookup tool after probing the remote OS once.
        locator = "where" if get_remote_os(ssh_client, log) == "windows" else "which"
        command = f"{locator} {utility}"

        exit_status, output, error = execute_remote_command(ssh_client, command)
        found = exit_status == 0 and bool(output)
        if not found:
            raise FileNotFoundError(
                f"TPT utility '{utility}' is not installed or not available in PATH on the remote host. "
                f"Command: {command}, Exit status: {exit_status}, "
                f"stderr: {error if error else 'N/A'}"
            )

        # The lookup may print several matches; report only the first one.
        first_match = output.split("\n", 1)[0]
        log.info("TPT utility '%s' found at: %s", utility, first_match)
    except (FileNotFoundError, RuntimeError):
        raise
    except Exception as e:
        raise RuntimeError(f"Failed to verify TPT utility '{utility}' on remote host: {e}") from e
def prepare_tpt_ddl_script(
    sql: list[str],
    error_list: list[int] | None,
    source_conn: dict[str, Any],
    job_name: str | None = None,
) -> str:
    """
    Prepare a TPT script for executing DDL statements.

    This method generates a TPT script that defines a DDL operator and applies the provided SQL
    statements. It also supports specifying a list of error codes to handle during the operation.

    Every value placed inside a single-quoted TPT literal (the SQL statements and the connection's
    host, login and password) has embedded single quotes doubled, so a value such as a password
    containing ``'`` cannot terminate the literal early and corrupt the script.

    :param sql: A list of DDL statements to execute.
    :param error_list: A list of error codes to handle during the operation.
    :param source_conn: Connection details for the source database; must contain the keys
        ``host``, ``login`` and ``password``.
    :param job_name: The name of the TPT job. Defaults to unique name if None.
    :return: A formatted TPT script as a string.
    :raises ValueError: If the SQL statement list is empty or holds no usable statement.
    :raises KeyError: If ``source_conn`` lacks one of the required keys.
    """
    if not sql or not isinstance(sql, list):
        raise ValueError("SQL statement list must be a non-empty list")

    # Clean and escape each SQL statement: drop surrounding whitespace and trailing
    # semicolons (one is re-added below) and double single quotes for the TPT literal.
    sql_statements = [
        stmt.strip().rstrip(";").replace("'", "''")
        for stmt in sql
        if stmt and isinstance(stmt, str) and stmt.strip()
    ]

    if not sql_statements:
        raise ValueError("No valid SQL statements found in the provided input")

    # Format for TPT APPLY block, indenting after the first line
    apply_sql = ",\n            ".join(f"('{stmt};')" for stmt in sql_statements)

    if job_name is None:
        job_name = f"airflow_tptddl_{uuid.uuid4().hex}"

    # Format error list for inclusion in the TPT script
    if not error_list:
        error_list_stmt = "ErrorList = ['']"
    else:
        error_list_str = ", ".join(f"'{error}'" for error in error_list)
        error_list_stmt = f"ErrorList = [{error_list_str}]"

    # Connection attributes sit inside single-quoted literals too, so escape them the same way.
    host, login, password = (
        str(source_conn[key]).replace("'", "''") for key in ("host", "login", "password")
    )

    tpt_script = f"""
        DEFINE JOB {job_name}
        DESCRIPTION 'TPT DDL Operation'
        (
        APPLY
            {apply_sql}
        TO OPERATOR ( $DDL ()
            ATTR
            (
                TdpId = '{host}',
                UserName = '{login}',
                UserPassword = '{password}',
                {error_list_stmt}
            )
        );
        );
        """
    return tpt_script
def decrypt_remote_file(
    ssh_client: SSHClient,
    remote_enc_file: str,
    remote_dec_file: str,
    password: str,
    logger: logging.Logger | None = None,
) -> int:
    """
    Decrypt a remote file using OpenSSL.

    Runs ``openssl enc -d -aes-256-cbc -salt -pbkdf2`` on the remote host. On Unix-like hosts the
    input and output paths are shell-quoted, so paths containing spaces or shell metacharacters
    are passed to OpenSSL as single literal arguments; on Windows hosts they are wrapped in
    double quotes.

    :param ssh_client: SSH client connection
    :param remote_enc_file: Path to the encrypted file
    :param remote_dec_file: Path for the decrypted file
    :param password: Decryption password
    :param logger: Optional logger instance
    :return: Exit status of the decryption command (always 0, since failures raise)
    :raises RuntimeError: If decryption fails
    """
    logger = logger or logging.getLogger(__name__)

    # Detect remote OS
    remote_os = get_remote_os(ssh_client, logger)
    windows_remote = remote_os == "windows"

    if windows_remote:
        # Windows - double-quote arguments and escape embedded double quotes by doubling them
        password_escaped = password.replace('"', '""')
        decrypt_cmd = (
            f'openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:"{password_escaped}" '
            f'-in "{remote_enc_file}" -out "{remote_dec_file}"'
        )
    else:
        # Unix/Linux - single-quote the password, closing and reopening the quote around
        # each embedded single quote, and shell-quote both file paths.
        password_escaped = password.replace("'", "'\\''")
        decrypt_cmd = (
            f"openssl enc -d -aes-256-cbc -salt -pbkdf2 -pass pass:'{password_escaped}' "
            f"-in {shlex.quote(remote_enc_file)} -out {shlex.quote(remote_dec_file)}"
        )

    exit_status, _, stderr_data = execute_remote_command(ssh_client, decrypt_cmd)

    if exit_status != 0:
        raise RuntimeError(
            f"Decryption failed with exit status {exit_status}. Error: {stderr_data if stderr_data else 'N/A'}"
        )

    logger.info("Successfully decrypted remote file %s to %s", remote_enc_file, remote_dec_file)
    return exit_status
def transfer_file_sftp(
    ssh_client: SSHClient, local_path: str, remote_path: str, logger: logging.Logger | None = None
) -> None:
    """
    Upload a local file to the remote host over SFTP.

    The SFTP session opened for the upload is always closed afterwards,
    whether or not the transfer succeeded.

    :param ssh_client: SSH client connection
    :param local_path: Local file path
    :param remote_path: Remote file path
    :param logger: Optional logger instance
    :raises FileNotFoundError: If local file does not exist
    :raises RuntimeError: If file transfer fails
    """
    log = logger or logging.getLogger(__name__)

    source_exists = os.path.exists(local_path)
    if not source_exists:
        raise FileNotFoundError(f"Local file does not exist: {local_path}")

    # Stays None until the session is actually open, so cleanup knows whether to close it.
    sftp_session = None
    try:
        sftp_session = ssh_client.open_sftp()
        sftp_session.put(local_path, remote_path)
        log.info("Successfully transferred file from %s to %s", local_path, remote_path)
    except Exception as e:
        raise RuntimeError(f"Failed to transfer file from {local_path} to {remote_path}: {e}") from e
    finally:
        if sftp_session:
            sftp_session.close()

Was this entry helpful?