Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions USAGE.md
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ WARNING: Use of this option runs against security best practices. It is recommen
To configure the default flag values of `aws-vault` and its subcommands:
* `AWS_VAULT_BACKEND`: Secret backend to use (see the flag `--backend`)
* `AWS_VAULT_BIOMETRICS`: Use biometric authentication using TouchID, if supported (see the flag `--biometrics`)
* `AWS_VAULT_PARALLEL_SAFE`: Enable cross-process locking for keychain and cached credentials (see the flag `--parallel-safe`)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please expand documentation of this feature -there's no explanation of when to use --parallel-safe, what it does, what the trade-offs are (serialized keychain ops), or what backends it applies to. Given the flag is opt-in, users who need it won't know what it does. Could you please add a short section on it?

* `AWS_VAULT_KEYCHAIN_NAME`: Name of macOS keychain to use (see the flag `--keychain`)
* `AWS_VAULT_AUTO_LOGOUT`: Enable auto-logout when doing `login` (see the flag `--auto-logout`)
* `AWS_VAULT_PROMPT`: Prompt driver to use (see the flag `--prompt`)
Expand Down
11 changes: 10 additions & 1 deletion cli/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type ExecCommandInput struct {
NoSession bool
UseStdout bool
ShowHelpMessages bool
ParallelSafe bool
}

func (input ExecCommandInput) validate() error {
Expand Down Expand Up @@ -121,6 +122,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) {
StringsVar(&input.Args)

cmd.Action(func(c *kingpin.ParseContext) (err error) {
input.ParallelSafe = a.ParallelSafe
input.Config.MfaPromptMethod = a.PromptDriver(hasBackgroundServer(input))
input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration
input.Config.AssumeRoleDuration = input.SessionDuration
Expand Down Expand Up @@ -155,6 +157,7 @@ func ConfigureExecCommand(app *kingpin.Application, a *AwsVault) {
Config: input.Config,
SessionDuration: input.SessionDuration,
NoSession: input.NoSession,
ParallelSafe: input.ParallelSafe,
}

err = ExportCommand(exportCommandInput, f, keyring)
Expand Down Expand Up @@ -185,7 +188,13 @@ func ExecCommand(input ExecCommandInput, f *vault.ConfigFile, keyring keyring.Ke
return 0, fmt.Errorf("Error loading config: %w", err)
}

credsProvider, err := vault.NewTempCredentialsProvider(config, &vault.CredentialKeyring{Keyring: keyring}, input.NoSession, false)
credsProvider, err := vault.NewTempCredentialsProviderWithOptions(
config,
&vault.CredentialKeyring{Keyring: keyring},
input.NoSession,
false,
vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe},
)
if err != nil {
return 0, fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
10 changes: 9 additions & 1 deletion cli/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type ExportCommandInput struct {
SessionDuration time.Duration
NoSession bool
UseStdout bool
ParallelSafe bool
}

var (
Expand Down Expand Up @@ -66,6 +67,7 @@ func ConfigureExportCommand(app *kingpin.Application, a *AwsVault) {
StringVar(&input.ProfileName)

cmd.Action(func(c *kingpin.ParseContext) (err error) {
input.ParallelSafe = a.ParallelSafe
input.Config.MfaPromptMethod = a.PromptDriver(false)
input.Config.NonChainedGetSessionTokenDuration = input.SessionDuration
input.Config.AssumeRoleDuration = input.SessionDuration
Expand Down Expand Up @@ -108,7 +110,13 @@ func ExportCommand(input ExportCommandInput, f *vault.ConfigFile, keyring keyrin
}

ckr := &vault.CredentialKeyring{Keyring: keyring}
credsProvider, err := vault.NewTempCredentialsProvider(config, ckr, input.NoSession, false)
credsProvider, err := vault.NewTempCredentialsProviderWithOptions(
config,
ckr,
input.NoSession,
false,
vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe},
)
if err != nil {
return fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
12 changes: 12 additions & 0 deletions cli/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type AwsVault struct {
KeyringConfig keyring.Config
KeyringBackend string
promptDriver string
ParallelSafe bool

keyringImpl keyring.Keyring
awsConfigFile *vault.ConfigFile
Expand Down Expand Up @@ -77,6 +78,13 @@ func (a *AwsVault) Keyring() (keyring.Keyring, error) {
if err != nil {
return nil, err
}
if a.KeyringBackend == string(keyring.KeychainBackend) && a.ParallelSafe {
lockKey := a.KeyringConfig.KeychainName
if lockKey == "" {
lockKey = "aws-vault"
}
a.keyringImpl = vault.NewKeychainLockedKeyring(a.keyringImpl, lockKey)
}
}

return a.keyringImpl, nil
Expand Down Expand Up @@ -201,6 +209,10 @@ func ConfigureGlobals(app *kingpin.Application) *AwsVault {
Envar("AWS_VAULT_BIOMETRICS").
BoolVar(&a.UseBiometrics)

app.Flag("parallel-safe", "Enable cross-process locking for keychain and cached credentials").
Envar("AWS_VAULT_PARALLEL_SAFE").
BoolVar(&a.ParallelSafe)

app.PreAction(func(c *kingpin.ParseContext) error {
if !a.Debug {
log.SetOutput(io.Discard)
Expand Down
16 changes: 12 additions & 4 deletions cli/rotate.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ import (
)

type RotateCommandInput struct {
NoSession bool
ProfileName string
Config vault.ProfileConfig
NoSession bool
ProfileName string
Config vault.ProfileConfig
ParallelSafe bool
}

func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) {
Expand All @@ -34,6 +35,7 @@ func ConfigureRotateCommand(app *kingpin.Application, a *AwsVault) {
StringVar(&input.ProfileName)

cmd.Action(func(c *kingpin.ParseContext) (err error) {
input.ParallelSafe = a.ParallelSafe
input.Config.MfaPromptMethod = a.PromptDriver(false)

f, err := a.AwsConfigFile()
Expand Down Expand Up @@ -97,7 +99,13 @@ func RotateCommand(input RotateCommandInput, f *vault.ConfigFile, keyring keyrin
credsProvider = vault.NewMasterCredentialsProvider(ckr, config.ProfileName)
} else {
// Can't always disable sessions completely, might need to use session for MFA-Protected API Access
credsProvider, err = vault.NewTempCredentialsProvider(config, ckr, input.NoSession, true)
credsProvider, err = vault.NewTempCredentialsProviderWithOptions(
config,
ckr,
input.NoSession,
true,
vault.TempCredentialsOptions{ParallelSafe: input.ParallelSafe},
)
if err != nil {
return fmt.Errorf("Error getting temporary credentials: %w", err)
}
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
github.com/byteness/keyring v1.7.2
github.com/charmbracelet/huh v0.8.0
github.com/charmbracelet/lipgloss v1.1.0
github.com/gofrs/flock v0.8.1
github.com/google/go-cmp v0.7.0
github.com/mattn/go-isatty v0.0.20
github.com/mattn/go-tty v0.0.7
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ github.com/extism/go-sdk v1.7.1 h1:lWJos6uY+tRFdlIHR+SJjwFDApY7OypS/2nMhiVQ9Sw=
github.com/extism/go-sdk v1.7.1/go.mod h1:IT+Xdg5AZM9hVtpFUA+uZCJMge/hbvshl8bwzLtFyKA=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
github.com/godbus/dbus/v5 v5.2.2 h1:TUR3TgtSVDmjiXOgAAyaZbYmIeP3DPkld3jgKGV8mXQ=
github.com/godbus/dbus/v5 v5.2.2/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
Expand Down
155 changes: 147 additions & 8 deletions vault/cachedsessionprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package vault

import (
"context"
"fmt"
"log"
"os"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -21,23 +23,160 @@ type CachedSessionProvider struct {
SessionProvider StsSessionProvider
Keyring *SessionKeyring
ExpiryWindow time.Duration
UseSessionLock bool
sessionLock SessionCacheLock
sessionLockWait time.Duration
sessionLockLog time.Duration
sessionNow func() time.Time
sessionSleep func(context.Context, time.Duration) error
sessionLogf func(string, ...any)
}

const (
defaultSessionLockWaitDelay = 100 * time.Millisecond
defaultSessionLockLogEvery = 15 * time.Second
defaultSessionLockWarnAfter = 5 * time.Second
)

func defaultSessionSleep(ctx context.Context, d time.Duration) error {
timer := time.NewTimer(d)
defer timer.Stop()

select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}

func (p *CachedSessionProvider) ensureSessionDependencies() {
if p.sessionLock == nil {
p.sessionLock = NewDefaultSessionCacheLock(p.SessionKey.StringForMatching())
}
if p.sessionLockWait == 0 {
p.sessionLockWait = defaultSessionLockWaitDelay
}
if p.sessionLockLog == 0 {
p.sessionLockLog = defaultSessionLockLogEvery
}
if p.sessionNow == nil {
p.sessionNow = time.Now
}
if p.sessionSleep == nil {
p.sessionSleep = defaultSessionSleep
}
if p.sessionLogf == nil {
p.sessionLogf = log.Printf
}
}

func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) {
creds, err := p.Keyring.Get(p.SessionKey)
creds, cached, err := p.getCachedSession()
if err == nil && cached {
return creds, nil
}

if !p.UseSessionLock {
return p.getSessionWithoutLock(ctx)
}

p.ensureSessionDependencies()

return p.getSessionWithLock(ctx)
}

func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, cached bool, err error) {
creds, err = p.Keyring.Get(p.SessionKey)
if err != nil {
return nil, false, err
}
if time.Until(*creds.Expiration) < p.ExpiryWindow {
return nil, false, nil
}
log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String())
return creds, true, nil
}

func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) {
waiter := newLockWaiter(
p.sessionLock,
"Waiting for session lock at %s\n",
"Waiting for session lock at %s",
p.sessionLockWait,
p.sessionLockLog,
defaultSessionLockWarnAfter,
p.sessionNow,
p.sessionSleep,
p.sessionLogf,
func(format string, args ...any) {
fmt.Fprintf(os.Stderr, format, args...)
},
)

if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow {
// lookup missed, we need to create a new one.
creds, err = p.SessionProvider.RetrieveStsCredentials(ctx)
for {
creds, cached, err := p.getCachedSession()
if err == nil && cached {
return creds, nil
}
if ctx.Err() != nil {
return nil, ctx.Err()
}

locked, err := p.sessionLock.TryLock()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here - the manual unlock-on-every-error-path pattern is fragile. Every new code path added in the future must remember to call Unlock(). A panic (e.g. from a nil pointer in newOIDCTokenFn) will leave the lock file held until process death.

The standard Go idiom is defer:

golocked, err := p.ssoTokenLock.TryLock()
if locked {
    defer p.ssoTokenLock.Unlock()
    // ... rest of logic
}

if err != nil {
return nil, err
}
err = p.Keyring.Set(p.SessionKey, creds)
if err != nil {
if locked {
creds, cached, err = p.getCachedSession()
if err == nil && cached {
unlockErr := p.sessionLock.Unlock()
if unlockErr != nil {
return nil, unlockErr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If both the operation and the unlock fail, the unlock error is returned and the original (more useful) error is silently dropped. The original error should be wrapped or joined: fmt.Errorf("unlock: %w (original: %v)", unlockErr, err). This could be a debugging nightmare in production.

The withLock in locked_keyring.go has the same pattern but correctly returns unlockErr only after function succeeds, so it's fine there.

}
return creds, nil
}

creds, err = p.SessionProvider.RetrieveStsCredentials(ctx)
if err != nil {
unlockErr := p.sessionLock.Unlock()
if unlockErr != nil {
return nil, unlockErr
}
return nil, err
}
if err = p.Keyring.Set(p.SessionKey, creds); err != nil {
unlockErr := p.sessionLock.Unlock()
if unlockErr != nil {
return nil, unlockErr
}
return nil, err
}

if err = p.sessionLock.Unlock(); err != nil {
return nil, err
}

return creds, nil
}
if err = waiter.sleepAfterMiss(ctx); err != nil {
return nil, err
}
} else {
log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String())
}
}

func (p *CachedSessionProvider) getSessionWithoutLock(ctx context.Context) (*ststypes.Credentials, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}

creds, err := p.SessionProvider.RetrieveStsCredentials(ctx)
if err != nil {
return nil, err
}

if err = p.Keyring.Set(p.SessionKey, creds); err != nil {
return nil, err
}

return creds, nil
Expand Down
Loading