diff --git a/internal/aws/credentials/chain_provider.go b/internal/aws/credentials/chain_provider.go index 942c7690fa..aca9ec77d9 100644 --- a/internal/aws/credentials/chain_provider.go +++ b/internal/aws/credentials/chain_provider.go @@ -11,6 +11,8 @@ package credentials import ( + "context" + "go.mongodb.org/mongo-driver/v2/internal/aws/awserr" ) @@ -45,10 +47,10 @@ func NewChainCredentials(providers []Provider) *Credentials { // // If a provider is found it will be cached and any calls to IsExpired() // will return the expired state of the cached provider. -func (c *ChainProvider) Retrieve() (Value, error) { +func (c *ChainProvider) Retrieve(ctx context.Context) (Value, error) { var errs = make([]error, 0, len(c.Providers)) for _, p := range c.Providers { - creds, err := p.Retrieve() + creds, err := p.Retrieve(ctx) if err == nil { c.curr = p return creds, nil diff --git a/internal/aws/credentials/chain_provider_test.go b/internal/aws/credentials/chain_provider_test.go index 3680b78105..4514a8f122 100644 --- a/internal/aws/credentials/chain_provider_test.go +++ b/internal/aws/credentials/chain_provider_test.go @@ -11,6 +11,7 @@ package credentials import ( + "context" "reflect" "testing" @@ -23,7 +24,7 @@ type secondStubProvider struct { err error } -func (s *secondStubProvider) Retrieve() (Value, error) { +func (s *secondStubProvider) Retrieve(_ context.Context) (Value, error) { s.expired = false s.creds.ProviderName = "secondStubProvider" return s.creds, s.err @@ -54,7 +55,7 @@ func TestChainProviderWithNames(t *testing.T) { }, } - creds, err := p.Retrieve() + creds, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Expect no error, got %v", err) } @@ -90,7 +91,7 @@ func TestChainProviderGet(t *testing.T) { }, } - creds, err := p.Retrieve() + creds, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Expect no error, got %v", err) } @@ -113,10 +114,12 @@ func TestChainProviderIsExpired(t *testing.T) { }, } + ctx := context.Background() + if !p.IsExpired() { t.Errorf("Expect expired to be true before any Retrieve") } - _, err := p.Retrieve() + _, err := p.Retrieve(ctx) if err != nil { t.Errorf("Expect no error, got %v", err) } @@ -129,7 +132,7 @@ func TestChainProviderIsExpired(t *testing.T) { t.Errorf("Expect return of expired provider") } - _, err = p.Retrieve() + _, err = p.Retrieve(ctx) if err != nil { t.Errorf("Expect no error, got %v", err) } @@ -146,7 +149,7 @@ func TestChainProviderWithNoProvider(t *testing.T) { if !p.IsExpired() { t.Errorf("Expect expired with no providers") } - _, err := p.Retrieve() + _, err := p.Retrieve(context.Background()) if err.Error() != "NoCredentialProviders: no valid providers in chain" { t.Errorf("Expect no providers error returned, got %v", err) } @@ -167,7 +170,7 @@ func TestChainProviderWithNoValidProvider(t *testing.T) { if !p.IsExpired() { t.Errorf("Expect expired with no providers") } - _, err := p.Retrieve() + _, err := p.Retrieve(context.Background()) expectErr := awserr.NewBatchError("NoCredentialProviders", "no valid providers in chain", errs) if e, a := expectErr, err; !reflect.DeepEqual(e, a) { diff --git a/internal/aws/credentials/credentials.go b/internal/aws/credentials/credentials.go index 919d0819b1..6d46bce6ec 100644 --- a/internal/aws/credentials/credentials.go +++ b/internal/aws/credentials/credentials.go @@ -52,20 +52,13 @@ func (v Value) HasKeys() bool { type Provider interface { // Retrieve returns nil if it successfully retrieved the value. // Error is returned if the value were not obtainable, or empty. - Retrieve() (Value, error) + Retrieve(context.Context) (Value, error) // IsExpired returns if the credentials are no longer valid, and need // to be retrieved. IsExpired() bool } -// ProviderWithContext is a Provider that can retrieve credentials with a Context -type ProviderWithContext interface { - Provider - - RetrieveWithContext(context.Context) (Value, error) -} - // A Credentials provides concurrency safe retrieval of AWS credentials Value. // // A Credentials is also used to fetch Azure credentials Value. @@ -143,13 +136,7 @@ func (c *Credentials) singleRetrieve(ctx context.Context) (interface{}, error) { return curCreds, nil } - var creds Value - var err error - if p, ok := c.provider.(ProviderWithContext); ok { - creds, err = p.RetrieveWithContext(ctx) - } else { - creds, err = c.provider.Retrieve() - } + creds, err := c.provider.Retrieve(ctx) if err == nil { c.creds = creds } diff --git a/internal/aws/credentials/credentials_test.go b/internal/aws/credentials/credentials_test.go index e10848c7ac..4847441264 100644 --- a/internal/aws/credentials/credentials_test.go +++ b/internal/aws/credentials/credentials_test.go @@ -33,7 +33,7 @@ type stubProvider struct { err error } -func (s *stubProvider) Retrieve() (Value, error) { +func (s *stubProvider) Retrieve(_ context.Context) (Value, error) { s.retrievedCount++ s.expired = false s.creds.ProviderName = "stubProvider" @@ -133,7 +133,7 @@ func (e *MockProvider) IsExpired() bool { return e.expiration.Before(curTime()) } -func (*MockProvider) Retrieve() (Value, error) { +func (*MockProvider) Retrieve(_ context.Context) (Value, error) { return Value{}, nil } @@ -162,9 +162,9 @@ type stubProviderConcurrent struct { done chan struct{} } -func (s *stubProviderConcurrent) Retrieve() (Value, error) { +func (s *stubProviderConcurrent) Retrieve(ctx context.Context) (Value, error) { <-s.done - return s.stubProvider.Retrieve() + return s.stubProvider.Retrieve(ctx) } func TestCredentialsGetConcurrent(t *testing.T) { diff --git a/internal/aws/types.go b/internal/aws/types.go index 52aecda76b..d4fc5cc2ec 100644 --- a/internal/aws/types.go +++ b/internal/aws/types.go @@ -12,8 +12,31 @@ package aws import ( "io" + "time" ) +// Credentials represents AWS credentials. +type Credentials struct { + AccessKeyID string + SecretAccessKey string + SessionToken string + Source string + CanExpire bool + Expires time.Time + AccountID string +} + +func (v Credentials) Expired() bool { + if v.CanExpire { + // Calling Round(0) on the current time will truncate the monotonic + // reading only. Ensures credential expiry time is always based on + // reported wall-clock time. + return !v.Expires.After(time.Now().Round(0)) + } + + return false +} + // ReadSeekCloser wraps a io.Reader returning a ReaderSeekerCloser. Allows the // SDK to accept an io.Reader that is not also an io.Seeker for unsigned // streaming payload API operations. diff --git a/internal/credproviders/assume_role_provider.go b/internal/credproviders/assume_role_provider.go index eec2247c70..36aa939336 100644 --- a/internal/credproviders/assume_role_provider.go +++ b/internal/credproviders/assume_role_provider.go @@ -57,8 +57,8 @@ func NewAssumeRoleProvider(httpClient *http.Client, expiryWindow time.Duration) } } -// RetrieveWithContext retrieves the keys from the AWS service. -func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { +// Retrieve retrieves the keys from the AWS service. +func (a *AssumeRoleProvider) Retrieve(ctx context.Context) (credentials.Value, error) { const defaultHTTPTimeout = 10 * time.Second v := credentials.Value{ProviderName: assumeRoleProviderName} @@ -137,11 +137,6 @@ func (a *AssumeRoleProvider) RetrieveWithContext(ctx context.Context) (credentia return v, nil } -// Retrieve retrieves the keys from the AWS service. -func (a *AssumeRoleProvider) Retrieve() (credentials.Value, error) { - return a.RetrieveWithContext(context.Background()) -} - // IsExpired returns true if the credentials are expired. func (a *AssumeRoleProvider) IsExpired() bool { return a.expiration.Before(time.Now()) diff --git a/internal/credproviders/aws_provider.go b/internal/credproviders/aws_provider.go new file mode 100644 index 0000000000..9e16b18068 --- /dev/null +++ b/internal/credproviders/aws_provider.go @@ -0,0 +1,47 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package credproviders + +import ( + "context" + + "go.mongodb.org/mongo-driver/v2/internal/aws" + "go.mongodb.org/mongo-driver/v2/internal/aws/credentials" +) + +const awsProviderName = "AwsProvider" + +// AwsProvider retrieves credentials from the given AWS credentials provider. +type AwsProvider struct { + credentials *aws.Credentials + Provider func(context.Context) (aws.Credentials, error) +} + +// Retrieve retrieves the keys from the given AWS credentials provider. +func (a *AwsProvider) Retrieve(ctx context.Context) (credentials.Value, error) { + var value credentials.Value + if a.credentials == nil { + creds, err := a.Provider(ctx) + if err != nil { + return value, err + } + a.credentials = &creds + } + value.AccessKeyID = a.credentials.AccessKeyID + value.SecretAccessKey = a.credentials.SecretAccessKey + value.SessionToken = a.credentials.SessionToken + value.ProviderName = awsProviderName + return value, nil +} + +// IsExpired returns true if the credentials have not been retrieved. +func (a *AwsProvider) IsExpired() bool { + if a.credentials == nil { + return true + } + return a.credentials.Expired() +} diff --git a/internal/credproviders/ec2_provider.go b/internal/credproviders/ec2_provider.go index df15a7f2c3..3b299f894c 100644 --- a/internal/credproviders/ec2_provider.go +++ b/internal/credproviders/ec2_provider.go @@ -146,8 +146,8 @@ func (e *EC2Provider) getCredentials(ctx context.Context, token string, role str return v, ec2Resp.Expiration, nil } -// RetrieveWithContext retrieves the keys from the AWS service. -func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { +// Retrieve retrieves the keys from the AWS service. +func (e *EC2Provider) Retrieve(ctx context.Context) (credentials.Value, error) { v := credentials.Value{ProviderName: ec2ProviderName} token, err := e.getToken(ctx) @@ -172,11 +172,6 @@ func (e *EC2Provider) RetrieveWithContext(ctx context.Context) (credentials.Valu return v, nil } -// Retrieve retrieves the keys from the AWS service. -func (e *EC2Provider) Retrieve() (credentials.Value, error) { - return e.RetrieveWithContext(context.Background()) -} - // IsExpired returns true if the credentials are expired. func (e *EC2Provider) IsExpired() bool { return e.expiration.Before(time.Now()) diff --git a/internal/credproviders/ecs_provider.go b/internal/credproviders/ecs_provider.go index 6816a3182d..a292fe2ee3 100644 --- a/internal/credproviders/ecs_provider.go +++ b/internal/credproviders/ecs_provider.go @@ -49,8 +49,8 @@ func NewECSProvider(httpClient *http.Client, expiryWindow time.Duration) *ECSPro } } -// RetrieveWithContext retrieves the keys from the AWS service. -func (e *ECSProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { +// Retrieve retrieves the keys from the AWS service. +func (e *ECSProvider) Retrieve(ctx context.Context) (credentials.Value, error) { const defaultHTTPTimeout = 10 * time.Second v := credentials.Value{ProviderName: ecsProviderName} @@ -101,11 +101,6 @@ func (e *ECSProvider) RetrieveWithContext(ctx context.Context) (credentials.Valu return v, nil } -// Retrieve retrieves the keys from the AWS service. -func (e *ECSProvider) Retrieve() (credentials.Value, error) { - return e.RetrieveWithContext(context.Background()) -} - // IsExpired returns true if the credentials are expired. func (e *ECSProvider) IsExpired() bool { return e.expiration.Before(time.Now()) diff --git a/internal/credproviders/env_provider.go b/internal/credproviders/env_provider.go index cf6bb60b31..7c77c18b98 100644 --- a/internal/credproviders/env_provider.go +++ b/internal/credproviders/env_provider.go @@ -7,6 +7,7 @@ package credproviders import ( + "context" "os" "go.mongodb.org/mongo-driver/v2/internal/aws/credentials" @@ -46,7 +47,7 @@ func NewEnvProvider() *EnvProvider { } // Retrieve retrieves the keys from the environment. -func (e *EnvProvider) Retrieve() (credentials.Value, error) { +func (e *EnvProvider) Retrieve(_ context.Context) (credentials.Value, error) { e.retrieved = false v := credentials.Value{ diff --git a/internal/credproviders/imds_provider.go b/internal/credproviders/imds_provider.go index f3674c727b..f2a801a9e2 100644 --- a/internal/credproviders/imds_provider.go +++ b/internal/credproviders/imds_provider.go @@ -41,8 +41,8 @@ func NewAzureProvider(httpClient *http.Client, expiryWindow time.Duration) *Azur } } -// RetrieveWithContext retrieves the keys from the Azure service. -func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { +// Retrieve retrieves the keys from the Azure service. +func (a *AzureProvider) Retrieve(ctx context.Context) (credentials.Value, error) { v := credentials.Value{ProviderName: AzureProviderName} req, err := http.NewRequest(http.MethodGet, azureURI, nil) if err != nil { @@ -92,11 +92,6 @@ func (a *AzureProvider) RetrieveWithContext(ctx context.Context) (credentials.Va return v, err } -// Retrieve retrieves the keys from the Azure service. -func (a *AzureProvider) Retrieve() (credentials.Value, error) { - return a.RetrieveWithContext(context.Background()) -} - // IsExpired returns if the credentials have been retrieved. func (a *AzureProvider) IsExpired() bool { return a.expiration.Before(time.Now()) diff --git a/internal/credproviders/static_provider.go b/internal/credproviders/static_provider.go index bbb0e8033a..1b8ddd4d68 100644 --- a/internal/credproviders/static_provider.go +++ b/internal/credproviders/static_provider.go @@ -7,6 +7,7 @@ package credproviders import ( + "context" "errors" "go.mongodb.org/mongo-driver/v2/internal/aws/credentials" @@ -42,7 +43,7 @@ func verify(v credentials.Value) error { } // Retrieve returns the credentials or error if the credentials are invalid. -func (s *StaticProvider) Retrieve() (credentials.Value, error) { +func (s *StaticProvider) Retrieve(_ context.Context) (credentials.Value, error) { if !s.verified { s.err = verify(s.Value) s.Value.ProviderName = staticProviderName diff --git a/internal/integration/client_side_encryption_prose_test.go b/internal/integration/client_side_encryption_prose_test.go index 18d55c4cf5..12e49f65a5 100644 --- a/internal/integration/client_side_encryption_prose_test.go +++ b/internal/integration/client_side_encryption_prose_test.go @@ -3147,6 +3147,124 @@ func TestClientSideEncryptionProse(t *testing.T) { }) } +func TestCustomAwsCredentialsProse(t *testing.T) { + mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) + + mt.Run("Case 1: ClientEncryption with credentialProviders and incorrect kmsProviders", func(mt *mtest.T) { + opts := options.Client().ApplyURI(mtest.ClusterURI()) + integtest.AddTestServerAPIVersion(opts) + keyVaultClient, err := mongo.Connect(opts) + assert.NoErrorf(mt, err, "error on Connect: %v", err) + + ceo := options.ClientEncryption(). + SetKeyVaultNamespace("keyvault.datakeys"). + SetKmsProviders(map[string]map[string]any{ + "aws": { + "accessKeyId": awsAccessKeyID, + "secretAccessKey": awsSecretAccessKey, + }, + }). + SetCredentialProviders(map[string]options.CredentialsProvider{ + "aws": func(ctx context.Context) (options.Credentials, error) { + return options.Credentials{}, nil + }, + }) + _, err = mongo.NewClientEncryption(keyVaultClient, ceo) + assert.ErrorContains(mt, err, "can only provide a custom AWS credential provider", + "unexpected error: %v", err) + }) + + mt.Run("Case 2: ClientEncryption with credentialProviders works", func(mt *mtest.T) { + opts := options.Client().ApplyURI(mtest.ClusterURI()) + integtest.AddTestServerAPIVersion(opts) + keyVaultClient, err := mongo.Connect(opts) + assert.NoErrorf(mt, err, "error on Connect: %v", err) + + var calledCount int + ceo := options.ClientEncryption(). + SetKeyVaultNamespace("keyvault.datakeys"). + SetKmsProviders(map[string]map[string]any{ + "aws": map[string]any{}, + }). + SetCredentialProviders(map[string]options.CredentialsProvider{ + "aws": func(_ context.Context) (options.Credentials, error) { + calledCount++ + return options.Credentials{ + AccessKeyID: awsAccessKeyID, + SecretAccessKey: awsSecretAccessKey, + }, nil + }, + }) + clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo) + assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err) + + dkOpts := options.DataKey().SetMasterKey(bson.D{ + {"region", "us-east-1"}, + {"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"}, + }) + _, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts) + assert.NoErrorf(mt, err, "unexpected error %v", err) + assert.Equal(mt, 1, calledCount, "expected credential provider to be called once") + }) + + mt.Run("Case 3: AutoEncryptionOpts with credentialProviders and incorrect kmsProviders", func(mt *mtest.T) { + aeo := options.AutoEncryption(). + SetKeyVaultNamespace("keyvault.datakeys"). + SetKmsProviders(map[string]map[string]any{ + "aws": { + "accessKeyId": awsAccessKeyID, + "secretAccessKey": awsSecretAccessKey, + }, + }). + SetCredentialProviders(map[string]options.CredentialsProvider{ + "aws": func(ctx context.Context) (options.Credentials, error) { + return options.Credentials{}, nil + }, + }) + co := options.Client().SetAutoEncryptionOptions(aeo).ApplyURI(mtest.ClusterURI()) + integtest.AddTestServerAPIVersion(co) + _, err := mongo.Connect(co) + assert.ErrorContainsf(mt, err, "can only provide a custom AWS credential provider", + "unexpected error: %v", err) + }) + + mt.Run("Case 4: ClientEncryption with credentialProviders and valid environment variables", func(mt *mtest.T) { + mt.Setenv("AWS_ACCESS_KEY_ID", os.Getenv("FLE_AWS_SECRET_ACCESS_KEY")) + mt.Setenv("AWS_SECRET_ACCESS_KEY", os.Getenv("FLE_AWS_ACCESS_KEY_ID")) + + opts := options.Client().ApplyURI(mtest.ClusterURI()) + integtest.AddTestServerAPIVersion(opts) + keyVaultClient, err := mongo.Connect(opts) + assert.NoErrorf(mt, err, "error on Connect: %v", err) + + var calledCount int + ceo := options.ClientEncryption(). + SetKeyVaultNamespace("keyvault.datakeys"). + SetKmsProviders(map[string]map[string]any{ + "aws": map[string]any{}, + }). + SetCredentialProviders(map[string]options.CredentialsProvider{ + "aws": func(ctx context.Context) (options.Credentials, error) { + calledCount++ + return options.Credentials{ + AccessKeyID: awsAccessKeyID, + SecretAccessKey: awsSecretAccessKey, + }, nil + }, + }) + clientEncryption, err := mongo.NewClientEncryption(keyVaultClient, ceo) + assert.NoErrorf(mt, err, "error on NewClientEncryption: %v", err) + + dkOpts := options.DataKey().SetMasterKey(bson.D{ + {"region", "us-east-1"}, + {"key", "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"}, + }) + _, err = clientEncryption.CreateDataKey(context.Background(), "aws", dkOpts) + assert.NoErrorf(mt, err, "unexpected error %v", err) + assert.Equal(mt, 1, calledCount, "expected credential provider to be called once") + }) +} + func getWatcher(mt *mtest.T, streamType mongo.StreamType, cpt *cseProseTest) watcher { mt.Helper() diff --git a/mongo/client.go b/mongo/client.go index f0480a0c72..5351aecbd9 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -15,6 +15,9 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" + "go.mongodb.org/mongo-driver/v2/internal/aws" + "go.mongodb.org/mongo-driver/v2/internal/aws/credentials" + "go.mongodb.org/mongo-driver/v2/internal/credproviders" "go.mongodb.org/mongo-driver/v2/internal/httputil" "go.mongodb.org/mongo-driver/v2/internal/logger" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" @@ -595,6 +598,21 @@ func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt bypassAutoEncryption := opts.BypassAutoEncryption != nil && *opts.BypassAutoEncryption bypassQueryAnalysis := opts.BypassQueryAnalysis != nil && *opts.BypassQueryAnalysis + providers := make(map[string]credentials.Provider) + for k, fn := range opts.CredentialProviders { + if k == "aws" && fn != nil { + providers[k] = &credproviders.AwsProvider{ + Provider: func(ctx context.Context) (aws.Credentials, error) { + c, err := fn(ctx) + if err != nil { + return aws.Credentials{}, err + } + return aws.Credentials(c), nil + }, + } + } + } + mc, err := mongocrypt.NewMongoCrypt(mcopts.MongoCrypt(). SetKmsProviders(kmsProviders). SetLocalSchemaMap(cryptSchemaMap). @@ -603,7 +621,8 @@ func (c *Client) newMongoCrypt(opts *options.AutoEncryptionOptions) (*mongocrypt SetCryptSharedLibDisabled(cryptSharedLibDisabled || bypassAutoEncryption). SetCryptSharedLibOverridePath(cryptSharedLibPath). SetHTTPClient(opts.HTTPClient). - SetKeyExpiration(opts.KeyExpiration)) + SetKeyExpiration(opts.KeyExpiration). + SetCredentialProviders(providers)) if err != nil { return nil, err } diff --git a/mongo/client_encryption.go b/mongo/client_encryption.go index 32851ffffb..974793a2d1 100644 --- a/mongo/client_encryption.go +++ b/mongo/client_encryption.go @@ -14,6 +14,9 @@ import ( "strings" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/aws" + "go.mongodb.org/mongo-driver/v2/internal/aws/credentials" + "go.mongodb.org/mongo-driver/v2/internal/credproviders" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" @@ -53,6 +56,21 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options. return nil, fmt.Errorf("error creating KMS providers map: %w", err) } + providers := make(map[string]credentials.Provider) + for k, fn := range cea.CredentialProviders { + if k == "aws" && fn != nil { + providers[k] = &credproviders.AwsProvider{ + Provider: func(ctx context.Context) (aws.Credentials, error) { + c, err := fn(ctx) + if err != nil { + return aws.Credentials{}, err + } + return aws.Credentials(c), nil + }, + } + } + } + mc, err := mongocrypt.NewMongoCrypt(mcopts.MongoCrypt(). SetKmsProviders(kmsProviders). // Explicitly disable loading the crypt_shared library for the Crypt used for @@ -60,7 +78,8 @@ func NewClientEncryption(keyVaultClient *Client, opts ...options.Lister[options. // have the crypt_shared library installed if they're using ClientEncryption. SetCryptSharedLibDisabled(true). SetHTTPClient(cea.HTTPClient). - SetKeyExpiration(cea.KeyExpiration)) + SetKeyExpiration(cea.KeyExpiration). + SetCredentialProviders(providers)) if err != nil { return nil, err } diff --git a/mongo/client_examples_test.go b/mongo/client_examples_test.go index 971fc74f08..d8ee6d723b 100644 --- a/mongo/client_examples_test.go +++ b/mongo/client_examples_test.go @@ -267,10 +267,11 @@ func ExampleConnect_aWS() { // The order in which the driver searches for credentials is: // // 1. Credentials passed through the URI - // 2. Environment variables - // 3. ECS endpoint if and only if AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is + // 2. Custom AWS credential provider + // 3. Environment variables + // 4. ECS endpoint if and only if AWS_CONTAINER_CREDENTIALS_RELATIVE_URI is // set - // 4. EC2 endpoint + // 5. EC2 endpoint // // The following examples set the appropriate credentials via the // ClientOptions.SetAuth method. All of these credentials can be specified @@ -352,6 +353,28 @@ func ExampleConnect_aWS() { panic(err) } _ = ecClient + + // Custom AWS credential provider + + // Applications can authenticate using a custom AWS credential provider as + // well. + credential := options.Credential{ + AuthMechanism: "MONGODB-AWS", + AwsCredentialsProvider: func(_ context.Context) ( + options.Credentials, error) { + return options.Credentials{ + AccessKeyID: accessKeyID, + SecretAccessKey: secretAccessKey, + SessionToken: sessionToken, + }, nil + }, + } + awsClient, err := mongo.Connect( + options.Client().SetAuth(credential)) + if err != nil { + panic(err) + } + _ = awsClient } func ExampleConnect_stableAPI() { diff --git a/mongo/options/autoencryptionoptions.go b/mongo/options/autoencryptionoptions.go index db3508fb4f..e98eeae162 100644 --- a/mongo/options/autoencryptionoptions.go +++ b/mongo/options/autoencryptionoptions.go @@ -42,6 +42,7 @@ type AutoEncryptionOptions struct { EncryptedFieldsMap map[string]any BypassQueryAnalysis *bool KeyExpiration *time.Duration + CredentialProviders map[string]CredentialsProvider } // AutoEncryption creates a new AutoEncryptionOptions configured with default values. @@ -174,3 +175,9 @@ func (a *AutoEncryptionOptions) SetKeyExpiration(expiration time.Duration) *Auto return a } + +// SetCredentialProviders specifies options for custom credential providers. +func (a *AutoEncryptionOptions) SetCredentialProviders(providers map[string]CredentialsProvider) *AutoEncryptionOptions { + a.CredentialProviders = providers + return a +} diff --git a/mongo/options/clientencryptionoptions.go b/mongo/options/clientencryptionoptions.go index a6c477a7a9..fc73d6daa2 100644 --- a/mongo/options/clientencryptionoptions.go +++ b/mongo/options/clientencryptionoptions.go @@ -19,11 +19,12 @@ import ( // // See corresponding setter methods for documentation. type ClientEncryptionOptions struct { - KeyVaultNamespace string - KmsProviders map[string]map[string]any - TLSConfig map[string]*tls.Config - HTTPClient *http.Client - KeyExpiration *time.Duration + KeyVaultNamespace string + KmsProviders map[string]map[string]any + TLSConfig map[string]*tls.Config + HTTPClient *http.Client + KeyExpiration *time.Duration + CredentialProviders map[string]CredentialsProvider } // ClientEncryptionOptionsBuilder contains options to configure client @@ -94,6 +95,15 @@ func (c *ClientEncryptionOptionsBuilder) SetKeyExpiration(expiration time.Durati return c } +// SetCredentialProviders specifies options for custom credential providers. +func (c *ClientEncryptionOptionsBuilder) SetCredentialProviders(providers map[string]CredentialsProvider) *ClientEncryptionOptionsBuilder { + c.Opts = append(c.Opts, func(opts *ClientEncryptionOptions) error { + opts.CredentialProviders = providers + return nil + }) + return c +} + // BuildTLSConfig specifies tls.Config options for each KMS provider to use to configure TLS on all connections created // to the KMS provider. The input map should contain a mapping from each KMS provider to a document containing the necessary // options, as follows: diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index adc880a5e9..336c508773 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -116,6 +116,7 @@ type Credential struct { PasswordSet bool OIDCMachineCallback OIDCCallback OIDCHumanCallback OIDCCallback + AwsCredentialsProvider CredentialsProvider } // OIDCCallback is the type for both Human and Machine Callback flows. @@ -144,6 +145,20 @@ type IDPInfo struct { RequestScopes []string } +// CredentialsProvider is the function type that returns AWS credentials. +type CredentialsProvider func(context.Context) (Credentials, error) + +// Credentials represents AWS credentials. +type Credentials struct { + AccessKeyID string + SecretAccessKey string + SessionToken string + Source string + CanExpire bool + Expires time.Time + AccountID string +} + // BSONOptions are optional BSON marshaling and unmarshaling behaviors. type BSONOptions struct { // UseJSONStructTags causes the driver to fall back to using the "json" diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index dd9661e1a9..11d78c0cfa 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -27,27 +27,35 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica if httpClient == nil { return nil, errors.New("httpClient must not be nil") } - return &MongoDBAWSAuthenticator{ - credentials: &credproviders.StaticProvider{ - Value: credentials.Value{ - AccessKeyID: cred.Username, - SecretAccessKey: cred.Password, - SessionToken: cred.Props["AWS_SESSION_TOKEN"], + authenticator := MongoDBAWSAuthenticator{ + providers: []credentials.Provider{ + &credproviders.StaticProvider{ + Value: credentials.Value{ + AccessKeyID: cred.Username, + SecretAccessKey: cred.Password, + SessionToken: cred.Props["AWS_SESSION_TOKEN"], + }, }, }, httpClient: httpClient, - }, nil + } + if cred.AwsCredentialsProvider != nil { + authenticator.providers = append(authenticator.providers, &credproviders.AwsProvider{ + Provider: cred.AwsCredentialsProvider, + }) + } + return &authenticator, nil } // MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection. type MongoDBAWSAuthenticator struct { - credentials *credproviders.StaticProvider - httpClient *http.Client + providers []credentials.Provider + httpClient *http.Client } // Auth authenticates the connection. func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { - providers := creds.NewAWSCredentialProvider(a.httpClient, a.credentials) + providers := creds.NewAWSCredentialProvider(a.httpClient, a.providers...) adapter := &awsSaslAdapter{ conversation: &awsConversation{ credentials: providers.Cred, diff --git a/x/mongo/driver/auth/mongodbaws_test.go b/x/mongo/driver/auth/mongodbaws_test.go index ef72d7f29f..0c251f8f0f 100644 --- a/x/mongo/driver/auth/mongodbaws_test.go +++ b/x/mongo/driver/auth/mongodbaws_test.go @@ -7,10 +7,21 @@ package auth import ( + "context" "errors" + "net/http" "testing" + "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/aws" + "go.mongodb.org/mongo-driver/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" ) func TestGetRegion(t *testing.T) { @@ -46,3 +57,153 @@ func TestGetRegion(t *testing.T) { } } + +func TestAWSCustomCredentialProvider(t *testing.T) { + t.Setenv("AWS_ACCESS_KEY_ID", "AWS_ACCESS_KEY_ID") + t.Setenv("AWS_SECRET_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY") + + var cnt int + for _, tc := range []struct { + name string + cred *Cred + cnt int + }{ + { + name: "provider with cred", + cred: &Cred{ + Username: "user", + Password: "pass", + Props: map[string]string{"AWS_SESSION_TOKEN": "token"}, + AwsCredentialsProvider: func(_ context.Context) (aws.Credentials, error) { + cnt++ + return aws.Credentials{}, nil + }, + }, + cnt: 0, + }, + { + name: "provider with empty cred", + cred: &Cred{ + AwsCredentialsProvider: func(_ context.Context) (aws.Credentials, error) { + cnt++ + return aws.Credentials{}, nil + }, + }, + cnt: 1, + }, + } { + cnt = 0 + t.Run(tc.name, func(t *testing.T) { + authenticator, err := newMongoDBAWSAuthenticator( + tc.cred, + &http.Client{}, + ) + require.NoErrorf(t, err, "unexpected error %v", err) + + resps := make(chan []byte, 1) + written := make(chan []byte, 1) + var readErr chan error + go func() { + for { + processWm(resps, written, readErr) + } + }() + + desc := description.Server{ + WireVersion: &description.VersionRange{ + Max: 6, + }, + } + c := &drivertest.ChannelConn{ + Written: written, + ReadResp: resps, + ReadErr: readErr, + Desc: desc, + } + + mnetconn := mnet.NewConnection(c) + + err = authenticator.Auth(context.Background(), &driver.AuthConfig{Connection: mnetconn}) + assert.NoErrorf(t, err, "expected no error but got %v", err) + assert.Equalf(t, tc.cnt, cnt, "expected provider to be called %v times but got %v", tc.cnt, cnt) + }) + } +} + +func processWm(resps, written chan []byte, errChan chan error) { + buf := <-written + buf, ok := extractPayload(buf) + if !ok { + errChan <- errors.New("could not extract payload from message") + } + var p struct { + Payload bson.Binary `bson:"payload"` + } + err := bson.Unmarshal(buf, &p) + if err != nil { + errChan <- err + } + if p.Payload.Subtype != 0x00 { + errChan <- errors.New("unexpected payload subtype") + } + var n struct { + Nonce bson.Binary `bson:"r"` + } + err = bson.Unmarshal(p.Payload.Data, &n) + if err != nil { + errChan <- err + } + if n.Nonce.Subtype != 0x00 { + errChan <- errors.New("unexpected nonce subtype") + } + nonce := make([]byte, 64) + copy(nonce, n.Nonce.Data) + + writeReplies(resps, + bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendInt32Element(nil, "ok", 1), + bsoncore.AppendInt32Element(nil, "conversationId", 1), + bsoncore.AppendBinaryElement(nil, "payload", 0x00, bsoncore.BuildDocumentFromElements(nil, + bsoncore.AppendBinaryElement(nil, "s", n.Nonce.Subtype, nonce), + bsoncore.AppendStringElement(nil, "h", "region"), + )), + bsoncore.AppendBooleanElement(nil, "done", true), + ), + ) +} + +func extractPayload(wm []byte) (bsoncore.Document, bool) { + _, _, _, opcode, wm, ok := wiremessage.ReadHeader(wm) + if !ok { + return nil, ok + } + if opcode != wiremessage.OpMsg { + return nil, false + } + var actualPayload bsoncore.Document + _, wm, ok = wiremessage.ReadMsgFlags(wm) + if !ok { + return nil, ok + } + for loop := true; loop; { + var stype wiremessage.SectionType + stype, wm, ok = wiremessage.ReadMsgSectionType(wm) + if !ok { + return nil, ok + } + switch stype { + case wiremessage.DocumentSequence: + _, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm) + if !ok { + return nil, ok + } + case wiremessage.SingleDocument: + actualPayload, _, ok = wiremessage.ReadMsgSectionSingleDocument(wm) + if !ok { + return nil, ok + } + loop = false + } + } + return actualPayload, true +} diff --git a/x/mongo/driver/driver.go b/x/mongo/driver/driver.go index 88995263e5..62850903af 100644 --- a/x/mongo/driver/driver.go +++ b/x/mongo/driver/driver.go @@ -17,6 +17,7 @@ import ( "context" "time" + "go.mongodb.org/mongo-driver/v2/internal/aws" "go.mongodb.org/mongo-driver/v2/internal/csot" "go.mongodb.org/mongo-driver/v2/mongo/address" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" @@ -73,13 +74,14 @@ type Authenticator interface { // Cred is a user's credential. type Cred struct { - Source string - Username string - Password string - PasswordSet bool - Props map[string]string - OIDCMachineCallback OIDCCallback - OIDCHumanCallback OIDCCallback + Source string + Username string + Password string + PasswordSet bool + Props map[string]string + OIDCMachineCallback OIDCCallback + OIDCHumanCallback OIDCCallback + AwsCredentialsProvider func(context.Context) (aws.Credentials, error) } // Deployment is implemented by types that can select a server from a deployment. diff --git a/x/mongo/driver/mongocrypt/mongocrypt.go b/x/mongo/driver/mongocrypt/mongocrypt.go index 91b950c371..f9d49ccff5 100644 --- a/x/mongo/driver/mongocrypt/mongocrypt.go +++ b/x/mongo/driver/mongocrypt/mongocrypt.go @@ -24,6 +24,7 @@ import ( "unsafe" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/aws/credentials" "go.mongodb.org/mongo-driver/v2/internal/httputil" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth/creds" @@ -35,9 +36,10 @@ type kmsProvider interface { } type MongoCrypt struct { - wrapped *C.mongocrypt_t - kmsProviders map[string]kmsProvider - httpClient *http.Client + wrapped *C.mongocrypt_t + kmsProviders map[string]kmsProvider + httpClient *http.Client + credProviders map[string]credentials.Provider } // Version returns the version string for the loaded libmongocrypt, or an empty string @@ -63,8 +65,16 @@ func NewMongoCrypt(opts *options.MongoCryptOptions) (*MongoCrypt, error) { if needsKmsProvider(opts.KmsProviders, "gcp") { kmsProviders["gcp"] = creds.NewGCPCredentialProvider(httpClient) } + provider, ok := opts.CredentialProviders["aws"] if needsKmsProvider(opts.KmsProviders, "aws") { - kmsProviders["aws"] = creds.NewAWSCredentialProvider(httpClient) + var providers []credentials.Provider + if ok { + providers = append(providers, provider) + } + kmsProviders["aws"] = creds.NewAWSCredentialProvider(httpClient, providers...) + } else if ok { + return nil, fmt.Errorf("can only provide a custom AWS credential provider " + + "when the state machine is configured for automatic AWS credential fetching") } if needsKmsProvider(opts.KmsProviders, "azure") { kmsProviders["azure"] = creds.NewAzureCredentialProvider(httpClient) diff --git a/x/mongo/driver/mongocrypt/options/mongocrypt_options.go b/x/mongo/driver/mongocrypt/options/mongocrypt_options.go index 504065d4bf..e2003c9598 100644 --- a/x/mongo/driver/mongocrypt/options/mongocrypt_options.go +++ b/x/mongo/driver/mongocrypt/options/mongocrypt_options.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "go.mongodb.org/mongo-driver/v2/internal/aws/credentials" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -23,6 +24,7 @@ type MongoCryptOptions struct { CryptSharedLibOverridePath string HTTPClient *http.Client KeyExpiration *time.Duration + CredentialProviders map[string]credentials.Provider } // MongoCrypt creates a new MongoCryptOptions instance. @@ -79,3 +81,9 @@ func (mo *MongoCryptOptions) SetKeyExpiration(expiration *time.Duration) *MongoC mo.KeyExpiration = expiration return mo } + +// SetCredentialProviders sets the custom credential providers. +func (mo *MongoCryptOptions) SetCredentialProviders(providers map[string]credentials.Provider) *MongoCryptOptions { + mo.CredentialProviders = providers + return mo +} diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 2ddc7434bd..a46d72e4ec 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -14,6 +14,7 @@ import ( "time" "go.mongodb.org/mongo-driver/v2/event" + "go.mongodb.org/mongo-driver/v2/internal/aws" "go.mongodb.org/mongo-driver/v2/internal/logger" "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/mongo/options" @@ -112,14 +113,23 @@ func ConvertCreds(cred *options.Credential) *driver.Cred { } } + var awsCredentialsProvider func(context.Context) (aws.Credentials, error) + if cred.AwsCredentialsProvider != nil { + awsCredentialsProvider = func(ctx context.Context) (aws.Credentials, error) { + creds, err := cred.AwsCredentialsProvider(ctx) + return aws.Credentials(creds), err + } + } + return &auth.Cred{ - Source: cred.AuthSource, - Username: cred.Username, - Password: cred.Password, - PasswordSet: cred.PasswordSet, - Props: cred.AuthMechanismProperties, - OIDCMachineCallback: oidcMachineCallback, - OIDCHumanCallback: oidcHumanCallback, + Source: cred.AuthSource, + Username: cred.Username, + Password: cred.Password, + PasswordSet: cred.PasswordSet, + Props: cred.AuthMechanismProperties, + OIDCMachineCallback: oidcMachineCallback, + OIDCHumanCallback: oidcHumanCallback, + AwsCredentialsProvider: awsCredentialsProvider, } }