diff --git a/pgutils/connector.go b/pgutils/connector.go index 21dce91..1a0b773 100644 --- a/pgutils/connector.go +++ b/pgutils/connector.go @@ -5,8 +5,9 @@ import ( "errors" "fmt" "log" + "net" "net/url" - "time" + "strings" "database/sql" "database/sql/driver" @@ -20,109 +21,161 @@ import ( "github.com/lib/pq" ) -type baseConnectionStringProvider interface { - getBaseConnectionString(ctx context.Context) (string, error) -} +const defaultPostgresPort = "5432" + +var pqDriver = &pq.Driver{} -type PostgresqlConnector struct { - baseConnectionStringProvider - searchPath string +// ConnectionStringProvider returns a Postgres connection string for use by clients +// that need a DSN (e.g., pq.Listener) or to build a connector. +type ConnectionStringProvider interface { + ConnectionString(ctx context.Context) (string, error) } -func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: conn.baseConnectionStringProvider, - searchPath: searchPath, - } +type connectionStringProviderFunc func(context.Context) (string, error) + +func (f connectionStringProviderFunc) ConnectionString(ctx context.Context) (string, error) { + return f(ctx) } -func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { - dsn, err := conn.GetConnectionString(ctx) +// NewConnectionStringProviderFromURLString parses rawURL and constructs a provider. +// +// Standard Postgres example: +// +// postgres://user:pass@host:5432/dbname?sslmode=require +// +// IAM example 1: +// +// postgres+rds-iam://user@host:5432/dbname +// +// IAM example 2 (cross-account): +// +// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=... +// +// For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call. +func NewConnectionStringProviderFromURLString(ctx context.Context, rawURL string) (ConnectionStringProvider, error) { + u, err := url.Parse(rawURL) if err != nil { - return nil, fmt.Errorf("get connection string: %w", err) + return nil, fmt.Errorf("parsing URL: %w", err) } - pqConnector, err := pq.NewConnector(dsn) - if err != nil { - return nil, fmt.Errorf("create pq connector: %w", err) + + switch u.Scheme { + case "postgres", "postgresql": + return &staticConnectionStringProvider{connectionString: u.String()}, nil + case "postgres+rds-iam": + return newIAMConnectionStringProviderFromURL(ctx, u) + default: + return nil, fmt.Errorf("unsupported URL scheme: %q (expected postgres, postgresql, or postgres+rds-iam)", u.Scheme) } +} - return pqConnector.Connect(ctx) +// ToConnector wraps a ConnectionStringProvider as a driver.Connector. +// Each Connect(ctx) call asks the provider for a fresh DSN. +func ToConnector(provider ConnectionStringProvider) driver.Connector { + return &postgresqlConnector{connectionStringProvider: provider} } -func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) { - dsn, err := conn.getBaseConnectionString(ctx) - if err != nil { - return "", fmt.Errorf("get base connection string: %w", err) +// WithSchemaSearchPath returns a ConnectionStringProvider that appends search_path +// to the DSN produced by the underlying provider. +func WithSchemaSearchPath(provider ConnectionStringProvider, searchPath string) ConnectionStringProvider { + return connectionStringProviderFunc(func(ctx context.Context) (string, error) { + dsn, err := provider.ConnectionString(ctx) + if err != nil { + return "", fmt.Errorf("ConnectionString failed: %w", err) + } + + dsnWithPath, err := addSearchPathToURL(dsn, searchPath) + if err != nil { + return "", fmt.Errorf("applying schema search path failed: %w", err) + } + + return dsnWithPath, nil + }) +} + +// ConnectDB opens a connection using the connector and verifies it with a ping +func ConnectDB(conn driver.Connector) (*sqlx.DB, error) { + sqlDB := sql.OpenDB(conn) + db := sqlx.NewDb(sqlDB, "postgres") + if err := db.Ping(); err != nil { + db.Close() + return nil, err } - if conn.searchPath == "" { - return dsn, nil + return db, nil +} + +// MustConnectDB is like ConnectDB but panics on error +func MustConnectDB(conn driver.Connector) *sqlx.DB { + db, err := ConnectDB(conn) + if err != nil { + panic(err) } + return db +} - // Add search path - u, err := url.Parse(dsn) +// addSearchPathToURL returns a copy of u with search_path set in the query string. +// It returns an error if search_path is already present. +func addSearchPathToURL(rawURL string, searchPath string) (string, error) { + u, err := url.Parse(rawURL) if err != nil { - return "", fmt.Errorf("parse DSN URL: %w", err) + return "", fmt.Errorf("url string failed to parse while adding search path: %w", err) + } + + if searchPath == "" { + return u.String(), nil } + q := u.Query() if v := q.Get("search_path"); v != "" { return "", fmt.Errorf("search_path already set to %q", v) } - q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed + q.Set("search_path", searchPath) u.RawQuery = q.Encode() return u.String(), nil } -func (c *PostgresqlConnector) Driver() driver.Driver { - return &pq.Driver{} +type postgresqlConnector struct { + connectionStringProvider ConnectionStringProvider } -type staticConnectionStringProvider struct { - connectionString string -} +func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) { + dsn, err := c.connectionStringProvider.ConnectionString(ctx) + if err != nil { + return nil, fmt.Errorf("getting connection string from provider: %w", err) + } + pqConnector, err := pq.NewConnector(dsn) + if err != nil { + return nil, fmt.Errorf("creating pq connector: %w", err) + } -func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - return p.connectionString, nil + return pqConnector.Connect(ctx) } -func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector { - return &PostgresqlConnector{ - baseConnectionStringProvider: &staticConnectionStringProvider{connectionString}, - } +func (c *postgresqlConnector) Driver() driver.Driver { + return pqDriver } -type IAMAuthConfig struct { - RDSEndpoint string - User string - Database string - - // Optional: cross-account role assumption. - // Set this to a role ARN in the RDS account (Account A) that has rds-db:connect. - AssumeRoleARN string - - // Optional: if your trust policy requires an external ID. - AssumeRoleExternalID string - - // Optional: override the default session name. - AssumeRoleSessionName string - - // Optional: override STS assume role duration. - // If zero, SDK default is used. - AssumeRoleDuration time.Duration +type staticConnectionStringProvider struct { + connectionString string } -type iamAuthConnectionStringProvider struct { - IAMAuthConfig +func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + return p.connectionString, nil +} - region string - creds aws.CredentialsProvider +type rdsIAMConnectionStringProvider struct { + RDSEndpoint string + Region string + User string + Database string + CredentialsProvider aws.CredentialsProvider } -func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) { - authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds) +func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) { + authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider) if err != nil { return "", fmt.Errorf("building auth token: %w", err) } - log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database) + log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database) dsnURL := &url.URL{ Scheme: "postgresql", @@ -134,9 +187,43 @@ func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Co return dsnURL.String(), nil } -func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) { - if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" { - return nil, errors.New("RDS endpoint, user, and database are required") +func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) { + user := "" + if u.User != nil { + user = u.User.Username() + if _, hasPw := u.User.Password(); hasPw { + return nil, errors.New("postgres+rds-iam URL must not include a password") + } + } + if user == "" { + return nil, errors.New("postgres+rds-iam URL missing username") + } + + host := u.Hostname() + if host == "" { + return nil, errors.New("postgres+rds-iam URL missing host") + } + + port := u.Port() + if port == "" { + port = defaultPostgresPort + } + + // Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username. + dbName := strings.TrimPrefix(u.Path, "/") + if dbName == "" { + dbName = user + } + + q := u.Query() + supportedParams := map[string]struct{}{ + "assume_role_arn": {}, + "assume_role_session_name": {}, + } + for k := range q { + if _, ok := supportedParams[k]; !ok { + return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k) + } } awsCfg, err := awsconfig.LoadDefaultConfig(ctx) @@ -149,66 +236,25 @@ func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) } creds := awsCfg.Credentials - - // Cross-account support: - // If AssumeRoleARN is set, assume a role in the RDS account (Account A) - // using the ECS task role creds from Account B as the source credentials. - if cfg.AssumeRoleARN != "" { - log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database) + assumeRoleARN := q.Get("assume_role_arn") + if assumeRoleARN != "" { stsClient := sts.NewFromConfig(awsCfg) - - sessionName := cfg.AssumeRoleSessionName + sessionName := q.Get("assume_role_session_name") if sessionName == "" { sessionName = "pgutils-rds-iam" } - - assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) { - assumeRoleOpts.RoleSessionName = sessionName - - if cfg.AssumeRoleExternalID != "" { - assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID) - } - - if cfg.AssumeRoleDuration != 0 { - assumeRoleOpts.Duration = cfg.AssumeRoleDuration - } + log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName) + assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) { + opts.RoleSessionName = sessionName }) - - // Cache to avoid calling STS too frequently. creds = aws.NewCredentialsCache(assumeProvider) } - return &PostgresqlConnector{ - baseConnectionStringProvider: &iamAuthConnectionStringProvider{ - IAMAuthConfig: *cfg, - region: awsCfg.Region, - creds: creds, - }, + return &rdsIAMConnectionStringProvider{ + Region: awsCfg.Region, + RDSEndpoint: net.JoinHostPort(host, port), + User: user, + Database: dbName, + CredentialsProvider: creds, }, nil } - -// Provides missing sqlx.OpenDB -func OpenDB(conn *PostgresqlConnector) *sqlx.DB { - sqlDB := sql.OpenDB(conn) - return sqlx.NewDb(sqlDB, "postgres") -} - -// ConnectDB opens a connection using the connector and verifies it with a ping -func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) { - db := OpenDB(conn) - if err := db.Ping(); err != nil { - db.Close() - return nil, err - } - return db, nil -} - -// MustConnectDB is like ConnectDB but panics on error -func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB { - db, err := ConnectDB(conn) - if err != nil { - panic(err) - } - return db -} - diff --git a/pgutils/listener.go b/pgutils/listener.go index 958462c..d1a7d06 100644 --- a/pgutils/listener.go +++ b/pgutils/listener.go @@ -69,7 +69,7 @@ func listenerEventToString(t pq.ListenerEventType) string { // The callback is invoked from the listener goroutine; it MUST NOT block // for long periods. If you need to do heavy work, offload it to another // goroutine. -func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error { +func Listen(ctx context.Context, provider ConnectionStringProvider, pgChannelName string, callback func(*pq.Notification), onClose func()) error { if callback == nil { return fmt.Errorf("listener callback cannot be nil") } @@ -77,9 +77,9 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string reconnectEventCh := make(chan struct{}, 1) // We just need a single reconnect event to trigger, so buffer size of 1 makeListener := func() (*pq.Listener, error) { - url, err := conn.GetConnectionString(ctx) + url, err := provider.ConnectionString(ctx) if err != nil { - return nil, fmt.Errorf("get url: %w", err) + return nil, fmt.Errorf("error getting connection string from provider: %w", err) } cb := func(t pq.ListenerEventType, e error) { @@ -174,4 +174,3 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string return nil } - diff --git a/pgutils/pgutils.go b/pgutils/pgutils.go index 5cd1bb7..659cd96 100644 --- a/pgutils/pgutils.go +++ b/pgutils/pgutils.go @@ -3,6 +3,7 @@ package pgutils import ( "crypto/sha1" "fmt" + "net/url" "regexp" "strings" "testing" @@ -128,3 +129,14 @@ func getCanonicalFormat(s string) string { str = re.ReplaceAllString(str, "=") return str } + +func CensorDSNForLogs(dsn string) string { + u, err := url.Parse(dsn) + if err != nil { + return "" + } + if u.User != nil { + u.User = url.UserPassword(u.User.Username(), "xxx") + } + return u.String() +}