diff --git a/cmd/connector/conn_tag_create.go b/cmd/connector/conn_tag_create.go index 8782c4ad..9440ff13 100644 --- a/cmd/connector/conn_tag_create.go +++ b/cmd/connector/conn_tag_create.go @@ -25,7 +25,7 @@ func newConnTagCreateCmd(client client.Client) *cobra.Command { tagName := cmd.Flags().Lookup("name").Value.String() versionStr := cmd.Flags().Lookup("version").Value.String() - version, err := strconv.Atoi(versionStr) + version, err := strconv.ParseUint(versionStr, 10, 32) if err != nil { return err } diff --git a/internal/va/va.go b/internal/va/va.go index bb1d0c73..8013ef9b 100644 --- a/internal/va/va.go +++ b/internal/va/va.go @@ -2,6 +2,7 @@ package va import ( + "bufio" "bytes" "errors" "fmt" @@ -9,6 +10,7 @@ import ( "net" "os" "path" + "strings" "sync" "github.com/charmbracelet/log" @@ -16,15 +18,96 @@ import ( "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/knownhosts" ) -func RunVACmd(addr string, password string, cmd string) (string, error) { - config := &ssh.ClientConfig{ +const sshDirMode = 0700 + +// ensureKnownHostsFile creates ~/.ssh and an empty known_hosts file if missing. +func ensureKnownHostsFile(knownHostsPath string) error { + sshDir := path.Dir(knownHostsPath) + if err := os.MkdirAll(sshDir, sshDirMode); err != nil { + return fmt.Errorf("could not create %s: %w", sshDir, err) + } + f, err := os.OpenFile(knownHostsPath, os.O_CREATE|os.O_RDONLY, 0600) + if err != nil { + return fmt.Errorf("could not create %s: %w", knownHostsPath, err) + } + _ = f.Close() + return nil +} + +// promptYesNo reads a line from stdin and returns true for yes, false for no or invalid input. +func promptYesNo() bool { + scanner := bufio.NewScanner(os.Stdin) + if !scanner.Scan() { + return false + } + line := strings.TrimSpace(strings.ToLower(scanner.Text())) + return line == "yes" || line == "y" +} + +// newSSHClientConfig returns an ssh.ClientConfig with host key verification via ~/.ssh/known_hosts +// and interactive first-time host acceptance. +func newSSHClientConfig(password string) (*ssh.ClientConfig, error) { + userHome, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("could not determine user home directory: %w", err) + } + knownHostsPath := path.Join(userHome, ".ssh", "known_hosts") + if err := ensureKnownHostsFile(knownHostsPath); err != nil { + return nil, err + } + + baseCallback, err := knownhosts.New(knownHostsPath) + if err != nil { + return nil, fmt.Errorf("could not load SSH known_hosts from %s: %w", knownHostsPath, err) + } + + callback := func(hostname string, remote net.Addr, key ssh.PublicKey) error { + err := baseCallback(hostname, remote, key) + if err == nil { + return nil + } + var keyErr *knownhosts.KeyError + if errors.As(err, &keyErr) { + // Host was known but key differs (possible MITM): always fail + if len(keyErr.Want) > 0 { + return err + } + // Unknown host (first-time connection): prompt and optionally add to known_hosts + fingerprint := ssh.FingerprintSHA256(key) + addrStr := hostname + if remote != nil { + addrStr = remote.String() + } + fmt.Fprintf(os.Stderr, "The authenticity of host %q can't be established.\n%s key fingerprint is %s.\nAre you sure you want to continue connecting (yes/no)? ", addrStr, key.Type(), fingerprint) + if !promptYesNo() { + return err + } + line := knownhosts.Line([]string{knownhosts.Normalize(addrStr)}, key) + "\n" + f, appendErr := os.OpenFile(knownHostsPath, os.O_WRONLY|os.O_APPEND, 0600) + if appendErr != nil { + return fmt.Errorf("could not append to known_hosts: %w", appendErr) + } + _, _ = f.WriteString(line) + _ = f.Close() + return nil + } + return err + } + + return &ssh.ClientConfig{ User: "sailpoint", - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - Auth: []ssh.AuthMethod{ - ssh.Password(password), - }, + HostKeyCallback: callback, + Auth: []ssh.AuthMethod{ssh.Password(password)}, + }, nil +} + +func RunVACmd(addr string, password string, cmd string) (string, error) { + config, err := newSSHClientConfig(password) + if err != nil { + return "", err } // Connect client, dialErr := ssh.Dial("tcp", net.JoinHostPort(addr, "22"), config) @@ -39,7 +122,6 @@ func RunVACmd(addr string, password string, cmd string) (string, error) { } defer session.Close() - // import "bytes" var b bytes.Buffer // get output @@ -56,12 +138,9 @@ func RunVACmd(addr string, password string, cmd string) (string, error) { } func RunVACmdLive(addr string, password string, cmd string) error { - config := &ssh.ClientConfig{ - User: "sailpoint", - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - Auth: []ssh.AuthMethod{ - ssh.Password(password), - }, + config, err := newSSHClientConfig(password) + if err != nil { + return err } // Connect client, dialErr := ssh.Dial("tcp", net.JoinHostPort(addr, "22"), config) @@ -92,12 +171,9 @@ func RunVACmdLive(addr string, password string, cmd string) error { func CollectVAFiles(endpoint string, password string, output string, files []string, p *mpb.Progress) error { log.Info("Starting File Collection", "VA", endpoint) - config := &ssh.ClientConfig{ - User: "sailpoint", - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - Auth: []ssh.AuthMethod{ - ssh.Password(password), - }, + config, err := newSSHClientConfig(password) + if err != nil { + return err } outputFolder := path.Join(output, endpoint)