-
Notifications
You must be signed in to change notification settings - Fork 17
feat: Allow aws-vault to safely be run in parallel #291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
dad4d29
091f156
3aeaf36
21896e7
a5670db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,9 @@ package vault | |
|
|
||
| import ( | ||
| "context" | ||
| "fmt" | ||
| "log" | ||
| "os" | ||
| "time" | ||
|
|
||
| "github.com/aws/aws-sdk-go-v2/aws" | ||
|
|
@@ -21,23 +23,160 @@ type CachedSessionProvider struct { | |
| SessionProvider StsSessionProvider | ||
| Keyring *SessionKeyring | ||
| ExpiryWindow time.Duration | ||
| UseSessionLock bool | ||
| sessionLock SessionCacheLock | ||
| sessionLockWait time.Duration | ||
| sessionLockLog time.Duration | ||
| sessionNow func() time.Time | ||
| sessionSleep func(context.Context, time.Duration) error | ||
| sessionLogf func(string, ...any) | ||
| } | ||
|
|
||
| const ( | ||
| defaultSessionLockWaitDelay = 100 * time.Millisecond | ||
| defaultSessionLockLogEvery = 15 * time.Second | ||
| defaultSessionLockWarnAfter = 5 * time.Second | ||
| ) | ||
|
|
||
| func defaultSessionSleep(ctx context.Context, d time.Duration) error { | ||
| timer := time.NewTimer(d) | ||
| defer timer.Stop() | ||
|
|
||
| select { | ||
| case <-ctx.Done(): | ||
| return ctx.Err() | ||
| case <-timer.C: | ||
| return nil | ||
| } | ||
| } | ||
|
|
||
| func (p *CachedSessionProvider) ensureSessionDependencies() { | ||
| if p.sessionLock == nil { | ||
| p.sessionLock = NewDefaultSessionCacheLock(p.SessionKey.StringForMatching()) | ||
| } | ||
| if p.sessionLockWait == 0 { | ||
| p.sessionLockWait = defaultSessionLockWaitDelay | ||
| } | ||
| if p.sessionLockLog == 0 { | ||
| p.sessionLockLog = defaultSessionLockLogEvery | ||
| } | ||
| if p.sessionNow == nil { | ||
| p.sessionNow = time.Now | ||
| } | ||
| if p.sessionSleep == nil { | ||
| p.sessionSleep = defaultSessionSleep | ||
| } | ||
| if p.sessionLogf == nil { | ||
| p.sessionLogf = log.Printf | ||
| } | ||
| } | ||
|
|
||
| func (p *CachedSessionProvider) RetrieveStsCredentials(ctx context.Context) (*ststypes.Credentials, error) { | ||
| creds, err := p.Keyring.Get(p.SessionKey) | ||
| creds, cached, err := p.getCachedSession() | ||
| if err == nil && cached { | ||
| return creds, nil | ||
| } | ||
|
|
||
| if !p.UseSessionLock { | ||
| return p.getSessionWithoutLock(ctx) | ||
| } | ||
|
|
||
| p.ensureSessionDependencies() | ||
|
|
||
| return p.getSessionWithLock(ctx) | ||
| } | ||
|
|
||
| func (p *CachedSessionProvider) getCachedSession() (creds *ststypes.Credentials, cached bool, err error) { | ||
| creds, err = p.Keyring.Get(p.SessionKey) | ||
| if err != nil { | ||
| return nil, false, err | ||
| } | ||
| if time.Until(*creds.Expiration) < p.ExpiryWindow { | ||
| return nil, false, nil | ||
| } | ||
| log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) | ||
| return creds, true, nil | ||
| } | ||
|
|
||
| func (p *CachedSessionProvider) getSessionWithLock(ctx context.Context) (*ststypes.Credentials, error) { | ||
| waiter := newLockWaiter( | ||
| p.sessionLock, | ||
| "Waiting for session lock at %s\n", | ||
| "Waiting for session lock at %s", | ||
| p.sessionLockWait, | ||
| p.sessionLockLog, | ||
| defaultSessionLockWarnAfter, | ||
| p.sessionNow, | ||
| p.sessionSleep, | ||
| p.sessionLogf, | ||
| func(format string, args ...any) { | ||
| fmt.Fprintf(os.Stderr, format, args...) | ||
| }, | ||
| ) | ||
|
|
||
| if err != nil || time.Until(*creds.Expiration) < p.ExpiryWindow { | ||
| // lookup missed, we need to create a new one. | ||
| creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) | ||
| for { | ||
| creds, cached, err := p.getCachedSession() | ||
| if err == nil && cached { | ||
| return creds, nil | ||
| } | ||
| if ctx.Err() != nil { | ||
| return nil, ctx.Err() | ||
| } | ||
|
|
||
| locked, err := p.sessionLock.TryLock() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here - the manual unlock-on-every-error-path pattern is fragile. Every new code path added in the future must remember to call The standard Go idiom is golocked, err := p.ssoTokenLock.TryLock()
if locked {
defer p.ssoTokenLock.Unlock()
// ... rest of logic
} |
||
| if err != nil { | ||
| return nil, err | ||
| } | ||
| err = p.Keyring.Set(p.SessionKey, creds) | ||
| if err != nil { | ||
| if locked { | ||
| creds, cached, err = p.getCachedSession() | ||
| if err == nil && cached { | ||
| unlockErr := p.sessionLock.Unlock() | ||
| if unlockErr != nil { | ||
| return nil, unlockErr | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If both the operation and the unlock fail, the unlock error is returned and the original (more useful) error is silently dropped. The original error should be wrapped or joined: fmt.Errorf("unlock: %w (original: %v)", unlockErr, err). This could be a debugging nightmare in production. The |
||
| } | ||
| return creds, nil | ||
| } | ||
|
|
||
| creds, err = p.SessionProvider.RetrieveStsCredentials(ctx) | ||
| if err != nil { | ||
| unlockErr := p.sessionLock.Unlock() | ||
| if unlockErr != nil { | ||
| return nil, unlockErr | ||
| } | ||
| return nil, err | ||
| } | ||
| if err = p.Keyring.Set(p.SessionKey, creds); err != nil { | ||
| unlockErr := p.sessionLock.Unlock() | ||
| if unlockErr != nil { | ||
| return nil, unlockErr | ||
| } | ||
| return nil, err | ||
| } | ||
|
|
||
| if err = p.sessionLock.Unlock(); err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| return creds, nil | ||
| } | ||
| if err = waiter.sleepAfterMiss(ctx); err != nil { | ||
| return nil, err | ||
| } | ||
| } else { | ||
| log.Printf("Re-using cached credentials %s from %s, expires in %s", FormatKeyForDisplay(*creds.AccessKeyId), p.SessionKey.Type, time.Until(*creds.Expiration).String()) | ||
| } | ||
| } | ||
|
|
||
| func (p *CachedSessionProvider) getSessionWithoutLock(ctx context.Context) (*ststypes.Credentials, error) { | ||
| if ctx.Err() != nil { | ||
| return nil, ctx.Err() | ||
| } | ||
|
|
||
| creds, err := p.SessionProvider.RetrieveStsCredentials(ctx) | ||
| if err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| if err = p.Keyring.Set(p.SessionKey, creds); err != nil { | ||
| return nil, err | ||
| } | ||
|
|
||
| return creds, nil | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please expand documentation of this feature -there's no explanation of when to use
--parallel-safe, what it does, what the trade-offs are (serialized keychain ops), or what backends it applies to. Given the flag is opt-in, users who need it won't know what it does. Could you please add a short section on it?