diff --git a/snowflake_test.go b/snowflake_test.go index e6848c2..d8cceae 100644 --- a/snowflake_test.go +++ b/snowflake_test.go @@ -28,23 +28,40 @@ import ( const ( envVarSnowflakeAccount = "SNOWFLAKE_ACCOUNT" envVarSnowflakeUser = "SNOWFLAKE_USER" - envVarSnowflakePassword = "SNOWFLAKE_PASSWORD" envVarSnowflakeDatabase = "SNOWFLAKE_DATABASE" envVarSnowflakeSchema = "SNOWFLAKE_SCHEMA" envVarSnowflakePrivateKey = "SNOWFLAKE_PRIVATE_KEY" envVarRunAccTests = "VAULT_ACC" + + defaultRSAKeyCreationStmts = ` +CREATE USER {{username}} RSA_PUBLIC_KEY='{{public_key}}'; +GRANT ROLE public TO USER {{username}}; +GRANT USAGE ON DATABASE %s TO USER {{username}};` + + defaultPasswordCreationStmts = ` +CREATE USER {{name}} PASSWORD = '{{password}}' DEFAULT_ROLE = public; +GRANT ROLE public TO USER {{name}}; +GRANT USAGE ON DATABASE %s TO USER {{username}};` + + defaultUsageCreationStmt = "GRANT USAGE ON DATABASE %s TO USER {{username}};" ) var runAcceptanceTests = os.Getenv(envVarRunAccTests) != "" -func connUrl(t *testing.T) string { - connURL, err := dsnString() +func connDetails(t *testing.T) (string, []byte, string) { + connURL, rawBase64PrivateKey, user, err := getKeyPairAuthParameters("") + if err != nil { + t.Fatalf("failed to retrieve connection URL: %s", err) + } + + // decode base64 encoded private key from environment + privateKey, err := base64.StdEncoding.DecodeString(rawBase64PrivateKey) if err != nil { - t.Fatalf("failed to retrieve connection DSN: %s", err) + t.Fatalf("failed to decode private key: %s", err) } - return connURL + return connURL, privateKey, user } // TestSnowflakeSQL_Initialize ensures initializing the Snowflake @@ -55,39 +72,6 @@ func TestSnowflakeSQL_Initialize(t *testing.T) { t.SkipNow() } - t.Run("userpass auth", func(t *testing.T) { - db := new() - defer dbtesting.AssertClose(t, db) - - connURL, err := dsnString() - if err != nil { - t.Fatalf("failed to retrieve connection DSN: %s", err) - } - - expectedConfig := map[string]interface{}{ - "connection_url": connURL, - dbplugin.SupportedCredentialTypesKey: []interface{}{ - dbplugin.CredentialTypePassword.String(), - dbplugin.CredentialTypeRSAPrivateKey.String(), - }, - } - req := dbplugin.InitializeRequest{ - Config: map[string]interface{}{ - "connection_url": connURL, - }, - VerifyConnection: true, - } - resp := dbtesting.AssertInitialize(t, db, req) - if !reflect.DeepEqual(resp.Config, expectedConfig) { - t.Fatalf("Actual: %#v\nExpected: %#v", resp.Config, expectedConfig) - } - - connProducer := db.snowflakeConnectionProducer - if !connProducer.Initialized { - t.Fatal("Database should be initialized") - } - }) - // the environment variable SNOWFLAKE_PRIVATE_KEY in CI // is a base64 encoded string. As such, this test expects the // input for the variable to be base64 encoded @@ -95,16 +79,7 @@ func TestSnowflakeSQL_Initialize(t *testing.T) { db := new() defer dbtesting.AssertClose(t, db) - connURL, rawBase64PrivateKey, user, err := getKeyPairAuthParameters("") - if err != nil { - t.Fatalf("failed to retrieve connection URL: %s", err) - } - - // decode base64 encoded private key from environment - privateKey, err := base64.StdEncoding.DecodeString(rawBase64PrivateKey) - if err != nil { - t.Fatalf("failed to decode private key: %s", err) - } + connURL, privateKey, user := connDetails(t) expectedConfig := map[string]interface{}{ "connection_url": connURL, @@ -187,6 +162,8 @@ func TestSnowflake_NewUser(t *testing.T) { t.SkipNow() } + dbName := getTestDatabase(t) + type testCase struct { creationStmts []string credentialType dbplugin.CredentialType @@ -204,9 +181,7 @@ func TestSnowflake_NewUser(t *testing.T) { "new user with password credential using name": { credentialType: dbplugin.CredentialTypePassword, creationStmts: []string{ - ` - CREATE USER {{name}} PASSWORD = '{{password}}' DEFAULT_ROLE = public; - GRANT ROLE public TO USER {{name}};`, + fmt.Sprintf(defaultPasswordCreationStmts, dbName), }, password: "y8fva_sdVA3rasf", }, @@ -215,22 +190,21 @@ func TestSnowflake_NewUser(t *testing.T) { creationStmts: []string{ "CREATE USER {{username}} PASSWORD = '{{password}}';", "GRANT ROLE public TO USER {{username}};", + fmt.Sprintf(defaultUsageCreationStmt, dbName), }, password: "secure_password", }, "new user with 2048 bit rsa_private_key credential": { credentialType: dbplugin.CredentialTypeRSAPrivateKey, creationStmts: []string{ - ` - CREATE USER {{username}} RSA_PUBLIC_KEY='{{public_key}}'; - GRANT ROLE public TO USER {{username}};`, + fmt.Sprintf(defaultRSAKeyCreationStmts, dbName), }, keyBits: 2048, }, "new user with 3072 bit rsa_private_key credential": { credentialType: dbplugin.CredentialTypeRSAPrivateKey, creationStmts: []string{ - "CREATE USER {{username}} RSA_PUBLIC_KEY='{{public_key}}';", + fmt.Sprintf(defaultRSAKeyCreationStmts, dbName), }, keyBits: 3072, }, @@ -239,6 +213,7 @@ func TestSnowflake_NewUser(t *testing.T) { creationStmts: []string{ "CREATE USER {{username}} RSA_PUBLIC_KEY='{{public_key}}';", "GRANT ROLE public TO USER {{username}};", + fmt.Sprintf(defaultUsageCreationStmt, dbName), }, keyBits: 4096, }, @@ -246,7 +221,7 @@ func TestSnowflake_NewUser(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { - connURL := connUrl(t) + connURL, privateKey, user := connDetails(t) db := new() defer dbtesting.AssertClose(t, db) @@ -254,6 +229,8 @@ func TestSnowflake_NewUser(t *testing.T) { initReq := dbplugin.InitializeRequest{ Config: map[string]interface{}{ "connection_url": connURL, + "username": user, + "private_key": privateKey, }, VerifyConnection: true, } @@ -284,7 +261,7 @@ func TestSnowflake_NewUser(t *testing.T) { } else if err != nil { t.Fatalf("failed to create user %s", err) } - defer attemptDropUser(connURL, createResp.Username) + defer attemptDropUser(connURL, user, createResp.Username, privateKey) assertPasswordCredentialsExist(t, connURL, createResp.Username, test.password) case dbplugin.CredentialTypeRSAPrivateKey: @@ -297,7 +274,7 @@ func TestSnowflake_NewUser(t *testing.T) { } else if err != nil { t.Fatalf("failed to create user %s", err) } - defer attemptDropUser(connURL, createResp.Username) + defer attemptDropUser(connURL, user, createResp.Username, privateKey) assertRSAKeyPairCredentialsExist(t, connURL, createResp.Username, priv) } }) @@ -309,7 +286,8 @@ func TestSnowflake_RenewUser(t *testing.T) { t.SkipNow() } - connURL := connUrl(t) + connURL, privateKey, user := connDetails(t) + dbName := getTestDatabase(t) db := new() defer dbtesting.AssertClose(t, db) @@ -317,6 +295,8 @@ func TestSnowflake_RenewUser(t *testing.T) { initReq := dbplugin.InitializeRequest{ Config: map[string]interface{}{ "connection_url": connURL, + "username": user, + "private_key": privateKey, }, VerifyConnection: true, } @@ -331,9 +311,7 @@ func TestSnowflake_RenewUser(t *testing.T) { }, Statements: dbplugin.Statements{ Commands: []string{ - ` - CREATE USER {{name}} PASSWORD = '{{password}}'; - GRANT ROLE public TO USER {{name}};`, + fmt.Sprintf(defaultPasswordCreationStmts, dbName), }, }, Password: password, @@ -341,7 +319,7 @@ func TestSnowflake_RenewUser(t *testing.T) { } createResp := dbtesting.AssertNewUser(t, db, createReq) - defer attemptDropUser(connURL, createResp.Username) + defer attemptDropUser(connURL, user, createResp.Username, privateKey) assertPasswordCredentialsExist(t, connURL, createResp.Username, password) @@ -365,7 +343,8 @@ func TestSnowflake_RevokeUser(t *testing.T) { t.SkipNow() } - connURL := connUrl(t) + connURL, privateKey, user := connDetails(t) + dbName := getTestDatabase(t) type testCase struct { deleteStatements []string @@ -395,6 +374,8 @@ func TestSnowflake_RevokeUser(t *testing.T) { initReq := dbplugin.InitializeRequest{ Config: map[string]interface{}{ "connection_url": connURL, + "username": user, + "private_key": privateKey, }, VerifyConnection: true, } @@ -409,9 +390,7 @@ func TestSnowflake_RevokeUser(t *testing.T) { }, Statements: dbplugin.Statements{ Commands: []string{ - ` - CREATE USER {{name}} PASSWORD = '{{password}}'; - GRANT ROLE public TO USER {{name}};`, + fmt.Sprintf(defaultPasswordCreationStmts, dbName), }, }, Password: password, @@ -439,7 +418,8 @@ func TestSnowflake_DefaultUsernameTemplate(t *testing.T) { t.SkipNow() } - connURL := connUrl(t) + connURL, privateKey, user := connDetails(t) + dbName := getTestDatabase(t) db := new() defer dbtesting.AssertClose(t, db) @@ -447,6 +427,8 @@ func TestSnowflake_DefaultUsernameTemplate(t *testing.T) { initReq := dbplugin.InitializeRequest{ Config: map[string]interface{}{ "connection_url": connURL, + "username": user, + "private_key": privateKey, }, VerifyConnection: true, } @@ -460,16 +442,14 @@ func TestSnowflake_DefaultUsernameTemplate(t *testing.T) { }, Statements: dbplugin.Statements{ Commands: []string{ - ` - CREATE USER {{name}} PASSWORD = '{{password}}'; - GRANT ROLE public TO USER {{name}};`, + fmt.Sprintf(defaultPasswordCreationStmts, dbName), }, }, Password: password, Expiration: time.Now().Add(time.Hour), } createResp := dbtesting.AssertNewUser(t, db, createReq) - defer attemptDropUser(connURL, createResp.Username) + defer attemptDropUser(connURL, user, createResp.Username, privateKey) if createResp.Username == "" { t.Fatalf("Missing username") @@ -485,7 +465,8 @@ func TestSnowflake_CustomUsernameTemplate(t *testing.T) { t.SkipNow() } - connURL := connUrl(t) + connURL, privateKey, user := connDetails(t) + dbName := getTestDatabase(t) db := new() defer dbtesting.AssertClose(t, db) @@ -493,6 +474,8 @@ func TestSnowflake_CustomUsernameTemplate(t *testing.T) { initReq := dbplugin.InitializeRequest{ Config: map[string]interface{}{ "connection_url": connURL, + "username": user, + "private_key": privateKey, "username_template": "{{.DisplayName}}_{{random 10}}", }, VerifyConnection: true, @@ -507,16 +490,14 @@ func TestSnowflake_CustomUsernameTemplate(t *testing.T) { }, Statements: dbplugin.Statements{ Commands: []string{ - ` - CREATE USER {{name}} PASSWORD = '{{password}}'; - GRANT ROLE public TO USER {{name}};`, + fmt.Sprintf(defaultPasswordCreationStmts, dbName), }, }, Password: password, Expiration: time.Now().Add(time.Hour), } createResp := dbtesting.AssertNewUser(t, db, createReq) - defer attemptDropUser(connURL, createResp.Username) + defer attemptDropUser(connURL, user, createResp.Username, privateKey) if createResp.Username == "" { t.Fatalf("Missing username") @@ -527,31 +508,6 @@ func TestSnowflake_CustomUsernameTemplate(t *testing.T) { require.Regexp(t, `^test_[a-zA-Z0-9]{10}$`, createResp.Username) } -func dsnString() (string, error) { - user := os.Getenv(envVarSnowflakeUser) - password := os.Getenv(envVarSnowflakePassword) - account := os.Getenv(envVarSnowflakeAccount) - - var err error - if user == "" { - err = multierror.Append(err, fmt.Errorf("SNOWFLAKE_USER not set")) - } - if password == "" { - err = multierror.Append(err, fmt.Errorf("SNOWFLAKE_PASSWORD not set")) - } - if account == "" { - err = multierror.Append(err, fmt.Errorf("SNOWFLAKE_ACCOUNT not set")) - } - - if err != nil { - return "", err - } - - dsnString := fmt.Sprintf("%s:%s@%s", user, password, account) - - return dsnString, nil -} - func getKeyPairAuthParameters(optionalQueryParams string) (connURL string, pKey string, user string, err error) { user = os.Getenv(envVarSnowflakeUser) pKey = os.Getenv(envVarSnowflakePrivateKey) @@ -572,7 +528,7 @@ func getKeyPairAuthParameters(optionalQueryParams string) (connURL string, pKey err = multierror.Append(err, fmt.Errorf("SNOWFLAKE_DATABASE not set")) } - connURL = fmt.Sprintf("%s.snowflakecomputing.com/%s", user, database) + connURL = fmt.Sprintf("%s.snowflakecomputing.com/%s", account, database) if optionalQueryParams != "" { connURL = fmt.Sprintf("%s?%s", connURL, optionalQueryParams) @@ -582,7 +538,10 @@ func getKeyPairAuthParameters(optionalQueryParams string) (connURL string, pKey } func verifyConnWithKeyPairCredential(connString, username string, private *rsa.PrivateKey) error { - conf, err := gosnowflake.ParseDSN(connString) + // empty password always fails here, so we set a placeholder so we can parse + // this is cleared out in the config below + url := fmt.Sprintf("%s:%s@%s", username, "empty", connString) + conf, err := gosnowflake.ParseDSN(url) if err != nil { return err } @@ -595,6 +554,7 @@ func verifyConnWithKeyPairCredential(connString, username string, private *rsa.P Schema: conf.Schema, User: username, PrivateKey: private, + Password: "", } dsn, err := gosnowflake.DSN(config) if err != nil { @@ -610,7 +570,8 @@ func verifyConnWithKeyPairCredential(connString, username string, private *rsa.P } func verifyConnWithPasswordCredential(connString, username, password string) error { - conf, err := gosnowflake.ParseDSN(connString) + connURL := fmt.Sprintf("%s:%s@%s", username, password, connString) + conf, err := gosnowflake.ParseDSN(connURL) if err != nil { return err } @@ -665,17 +626,18 @@ func assertRSAKeyPairCredentialsExist(t *testing.T, connString, username string, } } -func assertRSAKeyPairCredentialsDoNotExist(t *testing.T, connString, username string, private *rsa.PrivateKey) { - t.Helper() - err := verifyConnWithKeyPairCredential(connString, username, private) - if err == nil { - t.Fatalf("logged in when it shouldn't have been able to") +func getTestDatabase(t *testing.T) string { + database := os.Getenv(envVarSnowflakeDatabase) + if database == "" { + t.Fatalf("SNOWFLAKE_DATABASE not set") } + + return database } // Needed to not clutter the shared instance with testing artifacts -func attemptDropUser(connString, username string) { - db, err := sql.Open("snowflake", connString) +func attemptDropUser(connString, rootUser, username string, privateKey []byte) { + db, err := openSnowflake(connString, rootUser, privateKey) if err != nil { log.Printf("connection issue: %s", err) }