Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/connector/conn_tag_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func newConnTagCreateCmd(client client.Client) *cobra.Command {
tagName := cmd.Flags().Lookup("name").Value.String()
versionStr := cmd.Flags().Lookup("version").Value.String()

version, err := strconv.Atoi(versionStr)
version, err := strconv.ParseUint(versionStr, 10, 32)
if err != nil {
return err
}
Expand Down
114 changes: 95 additions & 19 deletions internal/va/va.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,112 @@
package va

import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net"
"os"
"path"
"strings"
"sync"

"github.com/charmbracelet/log"
"github.com/pkg/sftp"
"github.com/vbauerster/mpb/v8"
"github.com/vbauerster/mpb/v8/decor"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
)

func RunVACmd(addr string, password string, cmd string) (string, error) {
config := &ssh.ClientConfig{
const sshDirMode = 0700

// ensureKnownHostsFile creates ~/.ssh and an empty known_hosts file if missing.
func ensureKnownHostsFile(knownHostsPath string) error {
sshDir := path.Dir(knownHostsPath)
if err := os.MkdirAll(sshDir, sshDirMode); err != nil {
return fmt.Errorf("could not create %s: %w", sshDir, err)
}
f, err := os.OpenFile(knownHostsPath, os.O_CREATE|os.O_RDONLY, 0600)
if err != nil {
return fmt.Errorf("could not create %s: %w", knownHostsPath, err)
}
_ = f.Close()
return nil
}

// promptYesNo reads a line from stdin and returns true for yes, false for no or invalid input.
func promptYesNo() bool {
scanner := bufio.NewScanner(os.Stdin)
if !scanner.Scan() {
return false
}
line := strings.TrimSpace(strings.ToLower(scanner.Text()))
return line == "yes" || line == "y"
}

// newSSHClientConfig returns an ssh.ClientConfig with host key verification via ~/.ssh/known_hosts
// and interactive first-time host acceptance.
func newSSHClientConfig(password string) (*ssh.ClientConfig, error) {
userHome, err := os.UserHomeDir()
if err != nil {
return nil, fmt.Errorf("could not determine user home directory: %w", err)
}
knownHostsPath := path.Join(userHome, ".ssh", "known_hosts")
if err := ensureKnownHostsFile(knownHostsPath); err != nil {
return nil, err
}

baseCallback, err := knownhosts.New(knownHostsPath)
if err != nil {
return nil, fmt.Errorf("could not load SSH known_hosts from %s: %w", knownHostsPath, err)
}

callback := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
err := baseCallback(hostname, remote, key)
if err == nil {
return nil
}
var keyErr *knownhosts.KeyError
if errors.As(err, &keyErr) {
// Host was known but key differs (possible MITM): always fail
if len(keyErr.Want) > 0 {
return err
}
// Unknown host (first-time connection): prompt and optionally add to known_hosts
fingerprint := ssh.FingerprintSHA256(key)
addrStr := hostname
if remote != nil {
addrStr = remote.String()
}
fmt.Fprintf(os.Stderr, "The authenticity of host %q can't be established.\n%s key fingerprint is %s.\nAre you sure you want to continue connecting (yes/no)? ", addrStr, key.Type(), fingerprint)
if !promptYesNo() {
return err
}
line := knownhosts.Line([]string{knownhosts.Normalize(addrStr)}, key) + "\n"
f, appendErr := os.OpenFile(knownHostsPath, os.O_WRONLY|os.O_APPEND, 0600)
if appendErr != nil {
return fmt.Errorf("could not append to known_hosts: %w", appendErr)
}
_, _ = f.WriteString(line)
_ = f.Close()
return nil
}
return err
}

return &ssh.ClientConfig{
User: "sailpoint",
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
HostKeyCallback: callback,
Auth: []ssh.AuthMethod{ssh.Password(password)},
}, nil
}

func RunVACmd(addr string, password string, cmd string) (string, error) {
config, err := newSSHClientConfig(password)
if err != nil {
return "", err
}
// Connect
client, dialErr := ssh.Dial("tcp", net.JoinHostPort(addr, "22"), config)
Expand All @@ -39,7 +122,6 @@ func RunVACmd(addr string, password string, cmd string) (string, error) {
}
defer session.Close()

// import "bytes"
var b bytes.Buffer

// get output
Expand All @@ -56,12 +138,9 @@ func RunVACmd(addr string, password string, cmd string) (string, error) {
}

func RunVACmdLive(addr string, password string, cmd string) error {
config := &ssh.ClientConfig{
User: "sailpoint",
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
config, err := newSSHClientConfig(password)
if err != nil {
return err
}
// Connect
client, dialErr := ssh.Dial("tcp", net.JoinHostPort(addr, "22"), config)
Expand Down Expand Up @@ -92,12 +171,9 @@ func RunVACmdLive(addr string, password string, cmd string) error {
func CollectVAFiles(endpoint string, password string, output string, files []string, p *mpb.Progress) error {
log.Info("Starting File Collection", "VA", endpoint)

config := &ssh.ClientConfig{
User: "sailpoint",
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
Auth: []ssh.AuthMethod{
ssh.Password(password),
},
config, err := newSSHClientConfig(password)
if err != nil {
return err
}

outputFolder := path.Join(output, endpoint)
Expand Down
Loading