diff --git a/internal/installer/node/node.go b/internal/installer/node/node.go index 31c5953..0edf691 100644 --- a/internal/installer/node/node.go +++ b/internal/installer/node/node.go @@ -33,43 +33,59 @@ type NodeManager struct { // getAuthMethods constructs a slice of ssh.AuthMethod, prioritizing the SSH agent. func (n *NodeManager) getAuthMethods() ([]ssh.AuthMethod, error) { - var authMethods []ssh.AuthMethod + var signers []ssh.Signer + // 1. Get Agent Signers if authSocket := os.Getenv("SSH_AUTH_SOCK"); authSocket != "" { - // Connect to the SSH agent's socket - conn, err := net.Dial("unix", authSocket) - if err == nil { - // Create an Agent client and use it for authentication + if conn, err := net.Dial("unix", authSocket); err == nil { agentClient := agent.NewClient(conn) - authMethods = append(authMethods, ssh.PublicKeysCallback(agentClient.Signers)) - } else { - fmt.Printf("Could not connect to SSH Agent (%s): %v\n", authSocket, err) + if s, err := agentClient.Signers(); err == nil { + signers = append(signers, s...) + } } } + // 2. Add Private Key (File) if needed if n.KeyPath != "" { - // Use cached signer if available to avoid repeated passphrase prompts + shouldLoad := true + + // Use cached signer if available if n.cachedSigner != nil { - authMethods = append(authMethods, ssh.PublicKeys(n.cachedSigner)) - } else { - signer, err := n.loadPrivateKey() - if err != nil { - if len(authMethods) == 0 { - return nil, err + signers = append(signers, n.cachedSigner) + shouldLoad = false + } + + // Check if key is already in agent (requires .pub file) + if shouldLoad && len(signers) > 0 { + if pubBytes, err := n.FileIO.ReadFile(n.KeyPath + ".pub"); err == nil { + if targetPub, _, _, _, err := ssh.ParseAuthorizedKey(pubBytes); err == nil { + targetMarshaled := string(targetPub.Marshal()) + for _, s := range signers { + if string(s.PublicKey().Marshal()) == targetMarshaled { + shouldLoad = false + break + } + } } - fmt.Printf("Warning: %v\n", err) - } else { + } + } + + // Else load from file with passphrase prompt if needed + if shouldLoad { + if signer, err := n.loadPrivateKey(); err == nil { n.cachedSigner = signer - authMethods = append(authMethods, ssh.PublicKeys(signer)) + signers = append(signers, signer) + } else { + log.Printf("Warning: failed to load private key: %v\n", err) } } } - if len(authMethods) == 0 { + if len(signers) == 0 { return nil, fmt.Errorf("no valid authentication methods configured. Check SSH_AUTH_SOCK and private key path") } - return authMethods, nil + return []ssh.AuthMethod{ssh.PublicKeys(signers...)}, nil } // loadPrivateKey reads and parses the private key, prompting for passphrase if needed. @@ -89,9 +105,9 @@ func (n *NodeManager) loadPrivateKey() (ssh.Signer, error) { } // Key is encrypted, prompt for passphrase - fmt.Printf("Enter passphrase for key '%s': ", n.KeyPath) + log.Printf("Enter passphrase for key '%s': ", n.KeyPath) passphrase, err := term.ReadPassword(int(syscall.Stdin)) - fmt.Println() + log.Println() if err != nil { return nil, fmt.Errorf("failed to read passphrase: %v", err) } @@ -130,7 +146,7 @@ func (n *NodeManager) connectToJumpbox(ip, username string) (*ssh.Client, error) // Enable Agent Forwarding on the jumpbox connection if err := n.forwardAgent(jumpboxClient, nil); err != nil { - fmt.Printf(" Warning: Agent forwarding setup failed on jumpbox: %v\n", err) + log.Printf(" Warning: Agent forwarding setup failed on jumpbox: %v\n", err) } return jumpboxClient, nil @@ -183,7 +199,7 @@ func (n *NodeManager) RunSSHCommand(jumpboxIp string, ip string, username string err = n.forwardAgent(client, session) if err != nil { - fmt.Printf(" Warning: Agent forwarding setup failed on session: %v\n", err) + log.Printf(" Warning: Agent forwarding setup failed on session: %v\n", err) } session.Stdout = os.Stdout @@ -202,7 +218,6 @@ func (n *NodeManager) RunSSHCommand(jumpboxIp string, ip string, username string } func (n *NodeManager) GetClient(jumpboxIp string, ip string, username string) (*ssh.Client, error) { - authMethods, err := n.getAuthMethods() if err != nil { return nil, fmt.Errorf("failed to get authentication methods: %w", err)