diff --git a/TablePro/Core/SSH/SSHTunnelManager.swift b/TablePro/Core/SSH/SSHTunnelManager.swift index 0dc539e2..df8a2716 100644 --- a/TablePro/Core/SSH/SSHTunnelManager.swift +++ b/TablePro/Core/SSH/SSHTunnelManager.swift @@ -45,6 +45,12 @@ struct SSHTunnel { let createdAt: Date } +private struct SSHTunnelLaunch { + let process: Process + let errorPipe: Pipe + let askpassScriptPath: String? +} + /// Manages SSH tunnels for database connections using system ssh command actor SSHTunnelManager { static let shared = SSHTunnelManager() @@ -127,150 +133,79 @@ actor SSHTunnelManager { try await closeTunnel(connectionId: connectionId) } - // Find available local port - let localPort = try await findAvailablePort() - - // Build SSH command - let process = Process() - process.executableURL = URL(fileURLWithPath: "/usr/bin/ssh") - - var arguments = [ - "-N", // Don't execute remote command - "-o", "StrictHostKeyChecking=no", - "-o", "UserKnownHostsFile=/dev/null", - "-o", "ServerAliveInterval=60", - "-o", "ServerAliveCountMax=3", - "-o", "ConnectTimeout=10", - "-L", "\(localPort):\(remoteHost):\(remotePort)", - "-p", String(sshPort), - ] - - // Add authentication - switch authMethod { - case .privateKey: - guard let keyPath = privateKeyPath, !keyPath.isEmpty else { - throw SSHTunnelError.tunnelCreationFailed("Private key path is required for key authentication") + for localPort in localPortCandidates() { + let launch: SSHTunnelLaunch + + do { + launch = try createTunnelLaunch( + localPort: localPort, + sshHost: sshHost, + sshPort: sshPort, + sshUsername: sshUsername, + authMethod: authMethod, + privateKeyPath: privateKeyPath, + keyPassphrase: keyPassphrase, + sshPassword: sshPassword, + agentSocketPath: agentSocketPath, + remoteHost: remoteHost, + remotePort: remotePort, + jumpHosts: jumpHosts + ) + } catch let error as SSHTunnelError { + throw error + } catch { + throw SSHTunnelError.tunnelCreationFailed(error.localizedDescription) } - let expandedPath = expandPath(keyPath) - - // Validate private key exists and is readable - let fileManager = FileManager.default - guard fileManager.fileExists(atPath: expandedPath) else { - throw SSHTunnelError.tunnelCreationFailed("Private key file not found at: \(expandedPath)") - } - guard fileManager.isReadableFile(atPath: expandedPath) else { - throw SSHTunnelError.tunnelCreationFailed("Private key file is not readable. Check permissions (should be 600): \(expandedPath)") + do { + try launch.process.run() + } catch { + removeAskpassScript(launch.askpassScriptPath) + throw SSHTunnelError.tunnelCreationFailed(error.localizedDescription) } - // Force public key authentication - arguments.append(contentsOf: ["-i", expandedPath]) - arguments.append(contentsOf: ["-o", "PubkeyAuthentication=yes"]) - arguments.append(contentsOf: ["-o", "PasswordAuthentication=no"]) - arguments.append(contentsOf: ["-o", "PreferredAuthentications=publickey"]) - - case .password: - // For password auth, we'll use SSH_ASKPASS with a helper script - // Note: This requires ssh to be run without a TTY (which -N provides) - arguments.append(contentsOf: ["-o", "PasswordAuthentication=yes"]) - arguments.append(contentsOf: ["-o", "PreferredAuthentications=password"]) - arguments.append(contentsOf: ["-o", "PubkeyAuthentication=no"]) - - case .sshAgent: - arguments.append(contentsOf: ["-o", "PubkeyAuthentication=yes"]) - arguments.append(contentsOf: ["-o", "PasswordAuthentication=no"]) - arguments.append(contentsOf: ["-o", "PreferredAuthentications=publickey"]) - } - - // Jump host identity files - for jumpHost in jumpHosts where jumpHost.authMethod == .privateKey && !jumpHost.privateKeyPath.isEmpty { - arguments.append(contentsOf: ["-i", expandPath(jumpHost.privateKeyPath)]) - } - - // ProxyJump chain - if !jumpHosts.isEmpty { - let jumpString = jumpHosts.map(\.proxyJumpString).joined(separator: ",") - arguments.append(contentsOf: ["-J", jumpString]) - } - - arguments.append("\(sshUsername)@\(sshHost)") - - process.arguments = arguments - - // Set up SSH_ASKPASS for passphrase or password - var askpassScriptPath: String? - - if authMethod == .privateKey, let passphrase = keyPassphrase { - askpassScriptPath = try createAskpassScript(password: passphrase) - } else if authMethod == .password, let password = sshPassword { - askpassScriptPath = try createAskpassScript(password: password) - } - - if let scriptPath = askpassScriptPath { - var environment = ProcessInfo.processInfo.environment - environment["SSH_ASKPASS"] = scriptPath - environment["SSH_ASKPASS_REQUIRE"] = "force" - environment["DISPLAY"] = ":0" // Required for SSH_ASKPASS to work - process.environment = environment - } - - if authMethod == .sshAgent, let socketPath = agentSocketPath, !socketPath.isEmpty { - var environment = process.environment ?? ProcessInfo.processInfo.environment - environment["SSH_AUTH_SOCK"] = expandPath(socketPath) - process.environment = environment - } - - // Capture stderr for error messages - let errorPipe = Pipe() - process.standardError = errorPipe - process.standardOutput = FileHandle.nullDevice + let tunnelReady = await waitForTunnelReady( + localPort: localPort, + process: launch.process, + timeoutSeconds: 15 + ) - // Start the process - do { - try process.run() - } catch { - removeAskpassScript(askpassScriptPath) - throw SSHTunnelError.tunnelCreationFailed(error.localizedDescription) - } + removeAskpassScript(launch.askpassScriptPath) - // Wait for tunnel to become ready by probing the local port - let tunnelReady = await waitForTunnelReady( - localPort: localPort, - process: process, - timeoutSeconds: 15 - ) + if !tunnelReady { + if !launch.process.isRunning { + let errorData = launch.errorPipe.fileHandleForReading.readDataToEndOfFile() + let errorMessage = String(data: errorData, encoding: .utf8) ?? "Unknown error" - removeAskpassScript(askpassScriptPath) + if Self.isLocalPortBindFailure(errorMessage) { + Self.logger.notice("SSH tunnel bind race on local port \(localPort), retrying with another port") + continue + } - if !tunnelReady { - // Process died or timed out — read stderr for diagnostics - if !process.isRunning { - let errorData = errorPipe.fileHandleForReading.readDataToEndOfFile() - let errorMessage = String(data: errorData, encoding: .utf8) ?? "Unknown error" + throw classifySSHError( + errorMessage: errorMessage, + authMethod: authMethod + ) + } - throw classifySSHError( - errorMessage: errorMessage, - authMethod: authMethod - ) + launch.process.terminate() + throw SSHTunnelError.connectionTimeout } - // Process still running but port never became reachable - process.terminate() - throw SSHTunnelError.connectionTimeout - } + let tunnel = SSHTunnel( + connectionId: connectionId, + localPort: localPort, + remoteHost: remoteHost, + remotePort: remotePort, + process: launch.process, + createdAt: Date() + ) + tunnels[connectionId] = tunnel - // Store the tunnel - let tunnel = SSHTunnel( - connectionId: connectionId, - localPort: localPort, - remoteHost: remoteHost, - remotePort: remotePort, - process: process, - createdAt: Date() - ) - tunnels[connectionId] = tunnel + return localPort + } - return localPort + throw SSHTunnelError.noAvailablePort } /// Close an SSH tunnel @@ -311,32 +246,118 @@ actor SSHTunnelManager { // MARK: - Private Helpers - private func findAvailablePort() async throws -> Int { - for port in portRangeStart...portRangeEnd { - if isPortAvailable(port) { - return port - } - } - throw SSHTunnelError.noAvailablePort + private func localPortCandidates() -> [Int] { + Array(portRangeStart...portRangeEnd).shuffled() } - private func isPortAvailable(_ port: Int) -> Bool { - let socketFD = socket(AF_INET, SOCK_STREAM, 0) - guard socketFD >= 0 else { return false } - defer { close(socketFD) } + private func createTunnelLaunch( + localPort: Int, + sshHost: String, + sshPort: Int, + sshUsername: String, + authMethod: SSHAuthMethod, + privateKeyPath: String?, + keyPassphrase: String?, + sshPassword: String?, + agentSocketPath: String?, + remoteHost: String, + remotePort: Int, + jumpHosts: [SSHJumpHost] + ) throws -> SSHTunnelLaunch { + let process = Process() + let errorPipe = Pipe() + var askpassScriptPath: String? - var addr = sockaddr_in() - addr.sin_family = sa_family_t(AF_INET) - addr.sin_port = in_port_t(port).bigEndian - addr.sin_addr.s_addr = inet_addr("127.0.0.1") + do { + process.executableURL = URL(fileURLWithPath: "/usr/bin/ssh") + + var arguments = [ + "-N", // Don't execute remote command + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ServerAliveInterval=60", + "-o", "ServerAliveCountMax=3", + "-o", "ConnectTimeout=10", + "-o", "ExitOnForwardFailure=yes", + "-L", "127.0.0.1:\(localPort):\(remoteHost):\(remotePort)", + "-p", String(sshPort), + ] + + switch authMethod { + case .privateKey: + guard let keyPath = privateKeyPath, !keyPath.isEmpty else { + throw SSHTunnelError.tunnelCreationFailed("Private key path is required for key authentication") + } + + let expandedPath = expandPath(keyPath) + let fileManager = FileManager.default + guard fileManager.fileExists(atPath: expandedPath) else { + throw SSHTunnelError.tunnelCreationFailed("Private key file not found at: \(expandedPath)") + } + guard fileManager.isReadableFile(atPath: expandedPath) else { + throw SSHTunnelError.tunnelCreationFailed("Private key file is not readable. Check permissions (should be 600): \(expandedPath)") + } + + arguments.append(contentsOf: ["-i", expandedPath]) + arguments.append(contentsOf: ["-o", "PubkeyAuthentication=yes"]) + arguments.append(contentsOf: ["-o", "PasswordAuthentication=no"]) + arguments.append(contentsOf: ["-o", "PreferredAuthentications=publickey"]) + + case .password: + arguments.append(contentsOf: ["-o", "PasswordAuthentication=yes"]) + arguments.append(contentsOf: ["-o", "PreferredAuthentications=password"]) + arguments.append(contentsOf: ["-o", "PubkeyAuthentication=no"]) + + case .sshAgent: + arguments.append(contentsOf: ["-o", "PubkeyAuthentication=yes"]) + arguments.append(contentsOf: ["-o", "PasswordAuthentication=no"]) + arguments.append(contentsOf: ["-o", "PreferredAuthentications=publickey"]) + } - let result = withUnsafePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - bind(socketFD, $0, socklen_t(MemoryLayout.size)) + for jumpHost in jumpHosts where jumpHost.authMethod == .privateKey && !jumpHost.privateKeyPath.isEmpty { + arguments.append(contentsOf: ["-i", expandPath(jumpHost.privateKeyPath)]) } - } - return result == 0 + if !jumpHosts.isEmpty { + let jumpString = jumpHosts.map(\.proxyJumpString).joined(separator: ",") + arguments.append(contentsOf: ["-J", jumpString]) + } + + arguments.append("\(sshUsername)@\(sshHost)") + process.arguments = arguments + + if authMethod == .privateKey, let passphrase = keyPassphrase { + askpassScriptPath = try createAskpassScript(password: passphrase) + } else if authMethod == .password, let password = sshPassword { + askpassScriptPath = try createAskpassScript(password: password) + } + + if let scriptPath = askpassScriptPath { + var environment = ProcessInfo.processInfo.environment + environment["SSH_ASKPASS"] = scriptPath + environment["SSH_ASKPASS_REQUIRE"] = "force" + environment["DISPLAY"] = ":0" + process.environment = environment + } + + if authMethod == .sshAgent, let socketPath = agentSocketPath, !socketPath.isEmpty { + var environment = process.environment ?? ProcessInfo.processInfo.environment + environment["SSH_AUTH_SOCK"] = expandPath(socketPath) + process.environment = environment + } + + process.standardError = errorPipe + process.standardOutput = FileHandle.nullDevice + + return SSHTunnelLaunch( + process: process, + errorPipe: errorPipe, + askpassScriptPath: askpassScriptPath + ) + } catch { + removeAskpassScript(askpassScriptPath) + throw error + } } private func expandPath(_ path: String) -> String { @@ -391,8 +412,8 @@ actor SSHTunnelManager { // If the SSH process died, bail out immediately guard process.isRunning else { return false } - // Try to connect to the local forwarded port - if isPortReachable(localPort) { + if isPortReachable(localPort), + isPortOwnedByProcessTree(localPort, rootProcessId: process.processIdentifier) { return true } @@ -422,6 +443,102 @@ actor SSHTunnelManager { return result == 0 } + private func isPortOwnedByProcessTree(_ port: Int, rootProcessId: Int32) -> Bool { + let listeningProcessIds = listeningProcessIds(for: port) + guard !listeningProcessIds.isEmpty else { return false } + + let processTreeIds = processTreeIds(rootProcessId: rootProcessId) + return !listeningProcessIds.isDisjoint(with: processTreeIds) + } + + private func listeningProcessIds(for port: Int) -> Set { + let output = runCommand( + executablePath: "/usr/sbin/lsof", + arguments: ["-nP", "-iTCP:\(port)", "-sTCP:LISTEN", "-t"] + ) + + return Set( + output + .split(whereSeparator: \.isNewline) + .compactMap { Int32($0) } + ) + } + + private func processTreeIds(rootProcessId: Int32) -> Set { + let parentProcessIds = currentParentProcessIds() + return Self.descendantProcessIds( + rootProcessId: rootProcessId, + parentProcessIds: parentProcessIds + ) + } + + private func currentParentProcessIds() -> [Int32: Int32] { + let output = runCommand( + executablePath: "/bin/ps", + arguments: ["-axo", "pid=,ppid="] + ) + + var parentProcessIds: [Int32: Int32] = [:] + for line in output.split(whereSeparator: \.isNewline) { + let parts = line.split(whereSeparator: \.isWhitespace) + guard parts.count == 2, + let pid = Int32(parts[0]), + let parentPid = Int32(parts[1]) else { + continue + } + parentProcessIds[pid] = parentPid + } + return parentProcessIds + } + + private func runCommand(executablePath: String, arguments: [String]) -> String { + let process = Process() + let outputPipe = Pipe() + + process.executableURL = URL(fileURLWithPath: executablePath) + process.arguments = arguments + process.standardOutput = outputPipe + process.standardError = FileHandle.nullDevice + + do { + try process.run() + process.waitUntilExit() + } catch { + return "" + } + + let data = outputPipe.fileHandleForReading.readDataToEndOfFile() + return String(data: data, encoding: .utf8) ?? "" + } + + internal static func descendantProcessIds( + rootProcessId: Int32, + parentProcessIds: [Int32: Int32] + ) -> Set { + var discovered: Set = [rootProcessId] + var queue: [Int32] = [rootProcessId] + + while let currentProcessId = queue.first { + queue.removeFirst() + + for (processId, parentProcessId) in parentProcessIds + where parentProcessId == currentProcessId && !discovered.contains(processId) { + discovered.insert(processId) + queue.append(processId) + } + } + + return discovered + } + + static func isLocalPortBindFailure(_ errorMessage: String) -> Bool { + let normalized = errorMessage.lowercased() + return normalized.contains("address already in use") + || normalized.contains("cannot listen to port") + || normalized.contains("could not request local forwarding") + || normalized.contains("port forwarding failed") + } + /// Classify an SSH stderr message into a specific error type private func classifySSHError( errorMessage: String, diff --git a/TableProTests/Core/SSH/SSHTunnelManagerTests.swift b/TableProTests/Core/SSH/SSHTunnelManagerTests.swift new file mode 100644 index 00000000..3b1f4dbf --- /dev/null +++ b/TableProTests/Core/SSH/SSHTunnelManagerTests.swift @@ -0,0 +1,49 @@ +// +// SSHTunnelManagerTests.swift +// TableProTests +// +// Tests for SSH tunnel port binding safeguards. +// + +@testable import TablePro +import Testing + +@Suite("SSHTunnelManager") +struct SSHTunnelManagerTests { + @Test("Ownership checks include child ssh processes") + func descendantProcessIdsIncludeChildren() { + let processTree = SSHTunnelManager.descendantProcessIds( + rootProcessId: 100, + parentProcessIds: [ + 101: 100, + 102: 101, + 200: 999, + ] + ) + + #expect(processTree == [100, 101, 102]) + } + + @Test("Local port bind failures are treated as retryable") + func localPortBindFailuresAreRetryable() { + let errorMessage = """ + bind [127.0.0.1]:60000: Address already in use + channel_setup_fwd_listener_tcpip: cannot listen to port: 60000 + Could not request local forwarding. + """ + + #expect(SSHTunnelManager.isLocalPortBindFailure(errorMessage)) + } + + @Test("Non-bind SSH failures are not retried as port races") + func nonBindFailuresAreNotRetried() { + #expect(SSHTunnelManager.isLocalPortBindFailure("Permission denied (publickey,password).") == false) + #expect(SSHTunnelManager.isLocalPortBindFailure("Connection timed out during banner exchange") == false) + } + + @Test("Generic forwarding failures are treated as retryable bind failures") + func genericForwardingFailuresAreRetryable() { + let errorMessage = "Error: port forwarding failed for listen port 60123" + #expect(SSHTunnelManager.isLocalPortBindFailure(errorMessage)) + } +}