From 1933492c75d574dfa7ee284615730f60a71816b6 Mon Sep 17 00:00:00 2001 From: Eliott Wantz Date: Sun, 8 Mar 2026 22:50:50 -0400 Subject: [PATCH] fix: harden ssh local tunnel binding against TOCTOU races Stop pre-scanning for an open local port before launching ssh. Instead, let ssh bind the forwarded port directly with `ExitOnForwardFailure=yes` and retry only when stderr shows a local bind failure. Also tighten tunnel readiness checks so the forwarded port must be reachable and owned by the launched process tree, preventing false positives if another local process grabs the port first. Add focused tests for retryable bind failures and process-tree ownership checks. --- TablePro/Core/SSH/SSHTunnelManager.swift | 423 +++++++++++------- .../Core/SSH/SSHTunnelManagerTests.swift | 49 ++ 2 files changed, 319 insertions(+), 153 deletions(-) create mode 100644 TableProTests/Core/SSH/SSHTunnelManagerTests.swift diff --git a/TablePro/Core/SSH/SSHTunnelManager.swift b/TablePro/Core/SSH/SSHTunnelManager.swift index 0dc539e2..df8a2716 100644 --- a/TablePro/Core/SSH/SSHTunnelManager.swift +++ b/TablePro/Core/SSH/SSHTunnelManager.swift @@ -45,6 +45,12 @@ struct SSHTunnel { let createdAt: Date } +private struct SSHTunnelLaunch { + let process: Process + let errorPipe: Pipe + let askpassScriptPath: String? +} + /// Manages SSH tunnels for database connections using system ssh command actor SSHTunnelManager { static let shared = SSHTunnelManager() @@ -127,150 +133,79 @@ actor SSHTunnelManager { try await closeTunnel(connectionId: connectionId) } - // Find available local port - let localPort = try await findAvailablePort() - - // Build SSH command - let process = Process() - process.executableURL = URL(fileURLWithPath: "/usr/bin/ssh") - - var arguments = [ - "-N", // Don't execute remote command - "-o", "StrictHostKeyChecking=no", - "-o", "UserKnownHostsFile=/dev/null", - "-o", "ServerAliveInterval=60", - "-o", "ServerAliveCountMax=3", - "-o", "ConnectTimeout=10", - "-L", "\(localPort):\(remoteHost):\(remotePort)", - "-p", String(sshPort), - ] - - // Add authentication - switch authMethod { - case .privateKey: - guard let keyPath = privateKeyPath, !keyPath.isEmpty else { - throw SSHTunnelError.tunnelCreationFailed("Private key path is required for key authentication") + for localPort in localPortCandidates() { + let launch: SSHTunnelLaunch + + do { + launch = try createTunnelLaunch( + localPort: localPort, + sshHost: sshHost, + sshPort: sshPort, + sshUsername: sshUsername, + authMethod: authMethod, + privateKeyPath: privateKeyPath, + keyPassphrase: keyPassphrase, + sshPassword: sshPassword, + agentSocketPath: agentSocketPath, + remoteHost: remoteHost, + remotePort: remotePort, + jumpHosts: jumpHosts + ) + } catch let error as SSHTunnelError { + throw error + } catch { + throw SSHTunnelError.tunnelCreationFailed(error.localizedDescription) } - let expandedPath = expandPath(keyPath) - - // Validate private key exists and is readable - let fileManager = FileManager.default - guard fileManager.fileExists(atPath: expandedPath) else { - throw SSHTunnelError.tunnelCreationFailed("Private key file not found at: \(expandedPath)") - } - guard fileManager.isReadableFile(atPath: expandedPath) else { - throw SSHTunnelError.tunnelCreationFailed("Private key file is not readable. Check permissions (should be 600): \(expandedPath)") + do { + try launch.process.run() + } catch { + removeAskpassScript(launch.askpassScriptPath) + throw SSHTunnelError.tunnelCreationFailed(error.localizedDescription) } - // Force public key authentication - arguments.append(contentsOf: ["-i", expandedPath]) - arguments.append(contentsOf: ["-o", "PubkeyAuthentication=yes"]) - arguments.append(contentsOf: ["-o", "PasswordAuthentication=no"]) - arguments.append(contentsOf: ["-o", "PreferredAuthentications=publickey"]) - - case .password: - // For password auth, we'll use SSH_ASKPASS with a helper script - // Note: This requires ssh to be run without a TTY (which -N provides) - arguments.append(contentsOf: ["-o", "PasswordAuthentication=yes"]) - arguments.append(contentsOf: ["-o", "PreferredAuthentications=password"]) - arguments.append(contentsOf: ["-o", "PubkeyAuthentication=no"]) - - case .sshAgent: - arguments.append(contentsOf: ["-o", "PubkeyAuthentication=yes"]) - arguments.append(contentsOf: ["-o", "PasswordAuthentication=no"]) - arguments.append(contentsOf: ["-o", "PreferredAuthentications=publickey"]) - } - - // Jump host identity files - for jumpHost in jumpHosts where jumpHost.authMethod == .privateKey && !jumpHost.privateKeyPath.isEmpty { - arguments.append(contentsOf: ["-i", expandPath(jumpHost.privateKeyPath)]) - } - - // ProxyJump chain - if !jumpHosts.isEmpty { - let jumpString = jumpHosts.map(\.proxyJumpString).joined(separator: ",") - arguments.append(contentsOf: ["-J", jumpString]) - } - - arguments.append("\(sshUsername)@\(sshHost)") - - process.arguments = arguments - - // Set up SSH_ASKPASS for passphrase or password - var askpassScriptPath: String? - - if authMethod == .privateKey, let passphrase = keyPassphrase { - askpassScriptPath = try createAskpassScript(password: passphrase) - } else if authMethod == .password, let password = sshPassword { - askpassScriptPath = try createAskpassScript(password: password) - } - - if let scriptPath = askpassScriptPath { - var environment = ProcessInfo.processInfo.environment - environment["SSH_ASKPASS"] = scriptPath - environment["SSH_ASKPASS_REQUIRE"] = "force" - environment["DISPLAY"] = ":0" // Required for SSH_ASKPASS to work - process.environment = environment - } - - if authMethod == .sshAgent, let socketPath = agentSocketPath, !socketPath.isEmpty { - var environment = process.environment ?? ProcessInfo.processInfo.environment - environment["SSH_AUTH_SOCK"] = expandPath(socketPath) - process.environment = environment - } - - // Capture stderr for error messages - let errorPipe = Pipe() - process.standardError = errorPipe - process.standardOutput = FileHandle.nullDevice + let tunnelReady = await waitForTunnelReady( + localPort: localPort, + process: launch.process, + timeoutSeconds: 15 + ) - // Start the process - do { - try process.run() - } catch { - removeAskpassScript(askpassScriptPath) - throw SSHTunnelError.tunnelCreationFailed(error.localizedDescription) - } + removeAskpassScript(launch.askpassScriptPath) - // Wait for tunnel to become ready by probing the local port - let tunnelReady = await waitForTunnelReady( - localPort: localPort, - process: process, - timeoutSeconds: 15 - ) + if !tunnelReady { + if !launch.process.isRunning { + let errorData = launch.errorPipe.fileHandleForReading.readDataToEndOfFile() + let errorMessage = String(data: errorData, encoding: .utf8) ?? "Unknown error" - removeAskpassScript(askpassScriptPath) + if Self.isLocalPortBindFailure(errorMessage) { + Self.logger.notice("SSH tunnel bind race on local port \(localPort), retrying with another port") + continue + } - if !tunnelReady { - // Process died or timed out — read stderr for diagnostics - if !process.isRunning { - let errorData = errorPipe.fileHandleForReading.readDataToEndOfFile() - let errorMessage = String(data: errorData, encoding: .utf8) ?? "Unknown error" + throw classifySSHError( + errorMessage: errorMessage, + authMethod: authMethod + ) + } - throw classifySSHError( - errorMessage: errorMessage, - authMethod: authMethod - ) + launch.process.terminate() + throw SSHTunnelError.connectionTimeout } - // Process still running but port never became reachable - process.terminate() - throw SSHTunnelError.connectionTimeout - } + let tunnel = SSHTunnel( + connectionId: connectionId, + localPort: localPort, + remoteHost: remoteHost, + remotePort: remotePort, + process: launch.process, + createdAt: Date() + ) + tunnels[connectionId] = tunnel - // Store the tunnel - let tunnel = SSHTunnel( - connectionId: connectionId, - localPort: localPort, - remoteHost: remoteHost, - remotePort: remotePort, - process: process, - createdAt: Date() - ) - tunnels[connectionId] = tunnel + return localPort + } - return localPort + throw SSHTunnelError.noAvailablePort } /// Close an SSH tunnel @@ -311,32 +246,118 @@ actor SSHTunnelManager { // MARK: - Private Helpers - private func findAvailablePort() async throws -> Int { - for port in portRangeStart...portRangeEnd { - if isPortAvailable(port) { - return port - } - } - throw SSHTunnelError.noAvailablePort + private func localPortCandidates() -> [Int] { + Array(portRangeStart...portRangeEnd).shuffled() } - private func isPortAvailable(_ port: Int) -> Bool { - let socketFD = socket(AF_INET, SOCK_STREAM, 0) - guard socketFD >= 0 else { return false } - defer { close(socketFD) } + private func createTunnelLaunch( + localPort: Int, + sshHost: String, + sshPort: Int, + sshUsername: String, + authMethod: SSHAuthMethod, + privateKeyPath: String?, + keyPassphrase: String?, + sshPassword: String?, + agentSocketPath: String?, + remoteHost: String, + remotePort: Int, + jumpHosts: [SSHJumpHost] + ) throws -> SSHTunnelLaunch { + let process = Process() + let errorPipe = Pipe() + var askpassScriptPath: String? - var addr = sockaddr_in() - addr.sin_family = sa_family_t(AF_INET) - addr.sin_port = in_port_t(port).bigEndian - addr.sin_addr.s_addr = inet_addr("127.0.0.1") + do { + process.executableURL = URL(fileURLWithPath: "/usr/bin/ssh") + + var arguments = [ + "-N", // Don't execute remote command + "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-o", "ServerAliveInterval=60", + "-o", "ServerAliveCountMax=3", + "-o", "ConnectTimeout=10", + "-o", "ExitOnForwardFailure=yes", + "-L", "127.0.0.1:\(localPort):\(remoteHost):\(remotePort)", + "-p", String(sshPort), + ] + + switch authMethod { + case .privateKey: + guard let keyPath = privateKeyPath, !keyPath.isEmpty else { + throw SSHTunnelError.tunnelCreationFailed("Private key path is required for key authentication") + } + + let expandedPath = expandPath(keyPath) + let fileManager = FileManager.default + guard fileManager.fileExists(atPath: expandedPath) else { + throw SSHTunnelError.tunnelCreationFailed("Private key file not found at: \(expandedPath)") + } + guard fileManager.isReadableFile(atPath: expandedPath) else { + throw SSHTunnelError.tunnelCreationFailed("Private key file is not readable. Check permissions (should be 600): \(expandedPath)") + } + + arguments.append(contentsOf: ["-i", expandedPath]) + arguments.append(contentsOf: ["-o", "PubkeyAuthentication=yes"]) + arguments.append(contentsOf: ["-o", "PasswordAuthentication=no"]) + arguments.append(contentsOf: ["-o", "PreferredAuthentications=publickey"]) + + case .password: + arguments.append(contentsOf: ["-o", "PasswordAuthentication=yes"]) + arguments.append(contentsOf: ["-o", "PreferredAuthentications=password"]) + arguments.append(contentsOf: ["-o", "PubkeyAuthentication=no"]) + + case .sshAgent: + arguments.append(contentsOf: ["-o", "PubkeyAuthentication=yes"]) + arguments.append(contentsOf: ["-o", "PasswordAuthentication=no"]) + arguments.append(contentsOf: ["-o", "PreferredAuthentications=publickey"]) + } - let result = withUnsafePointer(to: &addr) { - $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { - bind(socketFD, $0, socklen_t(MemoryLayout.size)) + for jumpHost in jumpHosts where jumpHost.authMethod == .privateKey && !jumpHost.privateKeyPath.isEmpty { + arguments.append(contentsOf: ["-i", expandPath(jumpHost.privateKeyPath)]) } - } - return result == 0 + if !jumpHosts.isEmpty { + let jumpString = jumpHosts.map(\.proxyJumpString).joined(separator: ",") + arguments.append(contentsOf: ["-J", jumpString]) + } + + arguments.append("\(sshUsername)@\(sshHost)") + process.arguments = arguments + + if authMethod == .privateKey, let passphrase = keyPassphrase { + askpassScriptPath = try createAskpassScript(password: passphrase) + } else if authMethod == .password, let password = sshPassword { + askpassScriptPath = try createAskpassScript(password: password) + } + + if let scriptPath = askpassScriptPath { + var environment = ProcessInfo.processInfo.environment + environment["SSH_ASKPASS"] = scriptPath + environment["SSH_ASKPASS_REQUIRE"] = "force" + environment["DISPLAY"] = ":0" + process.environment = environment + } + + if authMethod == .sshAgent, let socketPath = agentSocketPath, !socketPath.isEmpty { + var environment = process.environment ?? ProcessInfo.processInfo.environment + environment["SSH_AUTH_SOCK"] = expandPath(socketPath) + process.environment = environment + } + + process.standardError = errorPipe + process.standardOutput = FileHandle.nullDevice + + return SSHTunnelLaunch( + process: process, + errorPipe: errorPipe, + askpassScriptPath: askpassScriptPath + ) + } catch { + removeAskpassScript(askpassScriptPath) + throw error + } } private func expandPath(_ path: String) -> String { @@ -391,8 +412,8 @@ actor SSHTunnelManager { // If the SSH process died, bail out immediately guard process.isRunning else { return false } - // Try to connect to the local forwarded port - if isPortReachable(localPort) { + if isPortReachable(localPort), + isPortOwnedByProcessTree(localPort, rootProcessId: process.processIdentifier) { return true } @@ -422,6 +443,102 @@ actor SSHTunnelManager { return result == 0 } + private func isPortOwnedByProcessTree(_ port: Int, rootProcessId: Int32) -> Bool { + let listeningProcessIds = listeningProcessIds(for: port) + guard !listeningProcessIds.isEmpty else { return false } + + let processTreeIds = processTreeIds(rootProcessId: rootProcessId) + return !listeningProcessIds.isDisjoint(with: processTreeIds) + } + + private func listeningProcessIds(for port: Int) -> Set { + let output = runCommand( + executablePath: "/usr/sbin/lsof", + arguments: ["-nP", "-iTCP:\(port)", "-sTCP:LISTEN", "-t"] + ) + + return Set( + output + .split(whereSeparator: \.isNewline) + .compactMap { Int32($0) } + ) + } + + private func processTreeIds(rootProcessId: Int32) -> Set { + let parentProcessIds = currentParentProcessIds() + return Self.descendantProcessIds( + rootProcessId: rootProcessId, + parentProcessIds: parentProcessIds + ) + } + + private func currentParentProcessIds() -> [Int32: Int32] { + let output = runCommand( + executablePath: "/bin/ps", + arguments: ["-axo", "pid=,ppid="] + ) + + var parentProcessIds: [Int32: Int32] = [:] + for line in output.split(whereSeparator: \.isNewline) { + let parts = line.split(whereSeparator: \.isWhitespace) + guard parts.count == 2, + let pid = Int32(parts[0]), + let parentPid = Int32(parts[1]) else { + continue + } + parentProcessIds[pid] = parentPid + } + return parentProcessIds + } + + private func runCommand(executablePath: String, arguments: [String]) -> String { + let process = Process() + let outputPipe = Pipe() + + process.executableURL = URL(fileURLWithPath: executablePath) + process.arguments = arguments + process.standardOutput = outputPipe + process.standardError = FileHandle.nullDevice + + do { + try process.run() + process.waitUntilExit() + } catch { + return "" + } + + let data = outputPipe.fileHandleForReading.readDataToEndOfFile() + return String(data: data, encoding: .utf8) ?? "" + } + + internal static func descendantProcessIds( + rootProcessId: Int32, + parentProcessIds: [Int32: Int32] + ) -> Set { + var discovered: Set = [rootProcessId] + var queue: [Int32] = [rootProcessId] + + while let currentProcessId = queue.first { + queue.removeFirst() + + for (processId, parentProcessId) in parentProcessIds + where parentProcessId == currentProcessId && !discovered.contains(processId) { + discovered.insert(processId) + queue.append(processId) + } + } + + return discovered + } + + static func isLocalPortBindFailure(_ errorMessage: String) -> Bool { + let normalized = errorMessage.lowercased() + return normalized.contains("address already in use") + || normalized.contains("cannot listen to port") + || normalized.contains("could not request local forwarding") + || normalized.contains("port forwarding failed") + } + /// Classify an SSH stderr message into a specific error type private func classifySSHError( errorMessage: String, diff --git a/TableProTests/Core/SSH/SSHTunnelManagerTests.swift b/TableProTests/Core/SSH/SSHTunnelManagerTests.swift new file mode 100644 index 00000000..3b1f4dbf --- /dev/null +++ b/TableProTests/Core/SSH/SSHTunnelManagerTests.swift @@ -0,0 +1,49 @@ +// +// SSHTunnelManagerTests.swift +// TableProTests +// +// Tests for SSH tunnel port binding safeguards. +// + +@testable import TablePro +import Testing + +@Suite("SSHTunnelManager") +struct SSHTunnelManagerTests { + @Test("Ownership checks include child ssh processes") + func descendantProcessIdsIncludeChildren() { + let processTree = SSHTunnelManager.descendantProcessIds( + rootProcessId: 100, + parentProcessIds: [ + 101: 100, + 102: 101, + 200: 999, + ] + ) + + #expect(processTree == [100, 101, 102]) + } + + @Test("Local port bind failures are treated as retryable") + func localPortBindFailuresAreRetryable() { + let errorMessage = """ + bind [127.0.0.1]:60000: Address already in use + channel_setup_fwd_listener_tcpip: cannot listen to port: 60000 + Could not request local forwarding. + """ + + #expect(SSHTunnelManager.isLocalPortBindFailure(errorMessage)) + } + + @Test("Non-bind SSH failures are not retried as port races") + func nonBindFailuresAreNotRetried() { + #expect(SSHTunnelManager.isLocalPortBindFailure("Permission denied (publickey,password).") == false) + #expect(SSHTunnelManager.isLocalPortBindFailure("Connection timed out during banner exchange") == false) + } + + @Test("Generic forwarding failures are treated as retryable bind failures") + func genericForwardingFailuresAreRetryable() { + let errorMessage = "Error: port forwarding failed for listen port 60123" + #expect(SSHTunnelManager.isLocalPortBindFailure(errorMessage)) + } +}