Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 40 additions & 25 deletions internal/installer/node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,43 +33,59 @@ type NodeManager struct {

// getAuthMethods constructs a slice of ssh.AuthMethod, prioritizing the SSH agent.
func (n *NodeManager) getAuthMethods() ([]ssh.AuthMethod, error) {
var authMethods []ssh.AuthMethod
var signers []ssh.Signer

// 1. Get Agent Signers
if authSocket := os.Getenv("SSH_AUTH_SOCK"); authSocket != "" {
// Connect to the SSH agent's socket
conn, err := net.Dial("unix", authSocket)
if err == nil {
// Create an Agent client and use it for authentication
if conn, err := net.Dial("unix", authSocket); err == nil {
agentClient := agent.NewClient(conn)
authMethods = append(authMethods, ssh.PublicKeysCallback(agentClient.Signers))
} else {
fmt.Printf("Could not connect to SSH Agent (%s): %v\n", authSocket, err)
if s, err := agentClient.Signers(); err == nil {
signers = append(signers, s...)
}
}
}

// 2. Add Private Key (File) if needed
if n.KeyPath != "" {
// Use cached signer if available to avoid repeated passphrase prompts
shouldLoad := true

// Use cached signer if available
if n.cachedSigner != nil {
authMethods = append(authMethods, ssh.PublicKeys(n.cachedSigner))
} else {
signer, err := n.loadPrivateKey()
if err != nil {
if len(authMethods) == 0 {
return nil, err
signers = append(signers, n.cachedSigner)
shouldLoad = false
}

// Check if key is already in agent (requires .pub file)
if shouldLoad && len(signers) > 0 {
if pubBytes, err := n.FileIO.ReadFile(n.KeyPath + ".pub"); err == nil {
if targetPub, _, _, _, err := ssh.ParseAuthorizedKey(pubBytes); err == nil {
targetMarshaled := string(targetPub.Marshal())
for _, s := range signers {
if string(s.PublicKey().Marshal()) == targetMarshaled {
shouldLoad = false
break
}
}
}
fmt.Printf("Warning: %v\n", err)
} else {
}
}

// Else load from file with passphrase prompt if needed
if shouldLoad {
if signer, err := n.loadPrivateKey(); err == nil {
n.cachedSigner = signer
authMethods = append(authMethods, ssh.PublicKeys(signer))
signers = append(signers, signer)
} else {
log.Printf("Warning: failed to load private key: %v\n", err)
}
}
}

if len(authMethods) == 0 {
if len(signers) == 0 {
return nil, fmt.Errorf("no valid authentication methods configured. Check SSH_AUTH_SOCK and private key path")
}

return authMethods, nil
return []ssh.AuthMethod{ssh.PublicKeys(signers...)}, nil
}

// loadPrivateKey reads and parses the private key, prompting for passphrase if needed.
Expand All @@ -89,9 +105,9 @@ func (n *NodeManager) loadPrivateKey() (ssh.Signer, error) {
}

// Key is encrypted, prompt for passphrase
fmt.Printf("Enter passphrase for key '%s': ", n.KeyPath)
log.Printf("Enter passphrase for key '%s': ", n.KeyPath)
passphrase, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
log.Println()
if err != nil {
return nil, fmt.Errorf("failed to read passphrase: %v", err)
}
Expand Down Expand Up @@ -130,7 +146,7 @@ func (n *NodeManager) connectToJumpbox(ip, username string) (*ssh.Client, error)

// Enable Agent Forwarding on the jumpbox connection
if err := n.forwardAgent(jumpboxClient, nil); err != nil {
fmt.Printf(" Warning: Agent forwarding setup failed on jumpbox: %v\n", err)
log.Printf(" Warning: Agent forwarding setup failed on jumpbox: %v\n", err)
}

return jumpboxClient, nil
Expand Down Expand Up @@ -183,7 +199,7 @@ func (n *NodeManager) RunSSHCommand(jumpboxIp string, ip string, username string
err = n.forwardAgent(client, session)

if err != nil {
fmt.Printf(" Warning: Agent forwarding setup failed on session: %v\n", err)
log.Printf(" Warning: Agent forwarding setup failed on session: %v\n", err)
}

session.Stdout = os.Stdout
Expand All @@ -202,7 +218,6 @@ func (n *NodeManager) RunSSHCommand(jumpboxIp string, ip string, username string
}

func (n *NodeManager) GetClient(jumpboxIp string, ip string, username string) (*ssh.Client, error) {

authMethods, err := n.getAuthMethods()
if err != nil {
return nil, fmt.Errorf("failed to get authentication methods: %w", err)
Expand Down