Skip to content
288 changes: 167 additions & 121 deletions pgutils/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ import (
"errors"
"fmt"
"log"
"net"
"net/url"
"time"
"strings"

"database/sql"
"database/sql/driver"
Expand All @@ -20,109 +21,161 @@ import (
"github.com/lib/pq"
)

type baseConnectionStringProvider interface {
getBaseConnectionString(ctx context.Context) (string, error)
}
const defaultPostgresPort = "5432"

var pqDriver = &pq.Driver{}

type PostgresqlConnector struct {
baseConnectionStringProvider
searchPath string
// ConnectionStringProvider returns a Postgres connection string for use by clients
// that need a DSN (e.g., pq.Listener) or to build a connector.
type ConnectionStringProvider interface {
ConnectionString(ctx context.Context) (string, error)
}

func (conn *PostgresqlConnector) WithSearchPath(searchPath string) *PostgresqlConnector {
return &PostgresqlConnector{
baseConnectionStringProvider: conn.baseConnectionStringProvider,
searchPath: searchPath,
}
type connectionStringProviderFunc func(context.Context) (string, error)

func (f connectionStringProviderFunc) ConnectionString(ctx context.Context) (string, error) {
return f(ctx)
}

func (conn *PostgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
dsn, err := conn.GetConnectionString(ctx)
// NewConnectionStringProviderFromURLString parses rawURL and constructs a provider.
//
// Standard Postgres example:
//
// postgres://user:pass@host:5432/dbname?sslmode=require
//
// IAM example 1:
//
// postgres+rds-iam://user@host:5432/dbname
//
// IAM example 2 (cross-account):
//
// postgres+rds-iam://user@host:5432/dbname?assume_role_arn=...&assume_role_session_name=...
//
// For postgres+rds-iam, the provider generates a fresh IAM auth token on each ConnectionString(ctx) call.
func NewConnectionStringProviderFromURLString(ctx context.Context, rawURL string) (ConnectionStringProvider, error) {
u, err := url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("get connection string: %w", err)
return nil, fmt.Errorf("parsing URL: %w", err)
}
pqConnector, err := pq.NewConnector(dsn)
if err != nil {
return nil, fmt.Errorf("create pq connector: %w", err)

switch u.Scheme {
case "postgres", "postgresql":
return &staticConnectionStringProvider{connectionString: u.String()}, nil
case "postgres+rds-iam":
return newIAMConnectionStringProviderFromURL(ctx, u)
default:
return nil, fmt.Errorf("unsupported URL scheme: %q (expected postgres, postgresql, or postgres+rds-iam)", u.Scheme)
}
}

return pqConnector.Connect(ctx)
// ToConnector wraps a ConnectionStringProvider as a driver.Connector.
// Each Connect(ctx) call asks the provider for a fresh DSN.
func ToConnector(provider ConnectionStringProvider) driver.Connector {
return &postgresqlConnector{connectionStringProvider: provider}
}

func (conn *PostgresqlConnector) GetConnectionString(ctx context.Context) (string, error) {
dsn, err := conn.getBaseConnectionString(ctx)
if err != nil {
return "", fmt.Errorf("get base connection string: %w", err)
// WithSchemaSearchPath returns a ConnectionStringProvider that appends search_path
// to the DSN produced by the underlying provider.
func WithSchemaSearchPath(provider ConnectionStringProvider, searchPath string) ConnectionStringProvider {
return connectionStringProviderFunc(func(ctx context.Context) (string, error) {
dsn, err := provider.ConnectionString(ctx)
if err != nil {
return "", fmt.Errorf("ConnectionString failed: %w", err)
}

dsnWithPath, err := addSearchPathToURL(dsn, searchPath)
if err != nil {
return "", fmt.Errorf("applying schema search path failed: %w", err)
}

return dsnWithPath, nil
})
}

// ConnectDB opens a connection using the connector and verifies it with a ping
func ConnectDB(conn driver.Connector) (*sqlx.DB, error) {
sqlDB := sql.OpenDB(conn)
db := sqlx.NewDb(sqlDB, "postgres")
if err := db.Ping(); err != nil {
db.Close()
return nil, err
}
if conn.searchPath == "" {
return dsn, nil
return db, nil
}

// MustConnectDB is like ConnectDB but panics on error
func MustConnectDB(conn driver.Connector) *sqlx.DB {
db, err := ConnectDB(conn)
if err != nil {
panic(err)
}
return db
}

// Add search path
u, err := url.Parse(dsn)
// addSearchPathToURL returns a copy of u with search_path set in the query string.
// It returns an error if search_path is already present.
func addSearchPathToURL(rawURL string, searchPath string) (string, error) {
u, err := url.Parse(rawURL)
if err != nil {
return "", fmt.Errorf("parse DSN URL: %w", err)
return "", fmt.Errorf("url string failed to parse while adding search path: %w", err)
}

if searchPath == "" {
return u.String(), nil
}

q := u.Query()
if v := q.Get("search_path"); v != "" {
return "", fmt.Errorf("search_path already set to %q", v)
}
q.Set("search_path", conn.searchPath) // url.Values will percent-encode commas as needed
q.Set("search_path", searchPath)
u.RawQuery = q.Encode()
return u.String(), nil
}

func (c *PostgresqlConnector) Driver() driver.Driver {
return &pq.Driver{}
type postgresqlConnector struct {
connectionStringProvider ConnectionStringProvider
}

type staticConnectionStringProvider struct {
connectionString string
}
func (c *postgresqlConnector) Connect(ctx context.Context) (driver.Conn, error) {
dsn, err := c.connectionStringProvider.ConnectionString(ctx)
if err != nil {
return nil, fmt.Errorf("getting connection string from provider: %w", err)
}
pqConnector, err := pq.NewConnector(dsn)
if err != nil {
return nil, fmt.Errorf("creating pq connector: %w", err)
}

func (p *staticConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
return p.connectionString, nil
return pqConnector.Connect(ctx)
}

func NewPostgresqlConnectorFromConnectionString(connectionString string) *PostgresqlConnector {
return &PostgresqlConnector{
baseConnectionStringProvider: &staticConnectionStringProvider{connectionString},
}
func (c *postgresqlConnector) Driver() driver.Driver {
return pqDriver
}

type IAMAuthConfig struct {
RDSEndpoint string
User string
Database string

// Optional: cross-account role assumption.
// Set this to a role ARN in the RDS account (Account A) that has rds-db:connect.
AssumeRoleARN string

// Optional: if your trust policy requires an external ID.
AssumeRoleExternalID string

// Optional: override the default session name.
AssumeRoleSessionName string

// Optional: override STS assume role duration.
// If zero, SDK default is used.
AssumeRoleDuration time.Duration
type staticConnectionStringProvider struct {
connectionString string
}

type iamAuthConnectionStringProvider struct {
IAMAuthConfig
func (p *staticConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) {
return p.connectionString, nil
}

region string
creds aws.CredentialsProvider
type rdsIAMConnectionStringProvider struct {
RDSEndpoint string
Region string
User string
Database string
CredentialsProvider aws.CredentialsProvider
}

func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Context) (string, error) {
authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.region, p.User, p.creds)
func (p *rdsIAMConnectionStringProvider) ConnectionString(ctx context.Context) (string, error) {
authToken, err := auth.BuildAuthToken(ctx, p.RDSEndpoint, p.Region, p.User, p.CredentialsProvider)
if err != nil {
return "", fmt.Errorf("building auth token: %w", err)
}
log.Printf("Signing RDS IAM token for \n Endpoint: %s \n User: %s \n Database: %s", p.RDSEndpoint, p.User, p.Database)
log.Printf("Signing RDS IAM token for Endpoint: %s User: %s Database: %s", p.RDSEndpoint, p.User, p.Database)

dsnURL := &url.URL{
Scheme: "postgresql",
Expand All @@ -134,9 +187,43 @@ func (p *iamAuthConnectionStringProvider) getBaseConnectionString(ctx context.Co
return dsnURL.String(), nil
}

func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig) (*PostgresqlConnector, error) {
if cfg.RDSEndpoint == "" || cfg.User == "" || cfg.Database == "" {
return nil, errors.New("RDS endpoint, user, and database are required")
func newIAMConnectionStringProviderFromURL(ctx context.Context, u *url.URL) (ConnectionStringProvider, error) {
user := ""
if u.User != nil {
user = u.User.Username()
if _, hasPw := u.User.Password(); hasPw {
return nil, errors.New("postgres+rds-iam URL must not include a password")
}
}
if user == "" {
return nil, errors.New("postgres+rds-iam URL missing username")
}

host := u.Hostname()
if host == "" {
return nil, errors.New("postgres+rds-iam URL missing host")
}

port := u.Port()
if port == "" {
port = defaultPostgresPort
}

// Match libpq/psql defaulting: if dbname isn't specified, dbname defaults to username.
dbName := strings.TrimPrefix(u.Path, "/")
if dbName == "" {
dbName = user
}

q := u.Query()
supportedParams := map[string]struct{}{
"assume_role_arn": {},
"assume_role_session_name": {},
}
for k := range q {
if _, ok := supportedParams[k]; !ok {
return nil, fmt.Errorf("postgres+rds-iam URL has unsupported query parameter: %s", k)
}
}

awsCfg, err := awsconfig.LoadDefaultConfig(ctx)
Expand All @@ -149,66 +236,25 @@ func NewPostgresqlConnectorWithIAMAuth(ctx context.Context, cfg *IAMAuthConfig)
}

creds := awsCfg.Credentials

// Cross-account support:
// If AssumeRoleARN is set, assume a role in the RDS account (Account A)
// using the ECS task role creds from Account B as the source credentials.
if cfg.AssumeRoleARN != "" {
log.Printf("RDS IAM Assuming Role: %s for \n Endpoint: %s \n User: %s \n Database: %s", cfg.AssumeRoleARN, cfg.RDSEndpoint, cfg.User, cfg.Database)
assumeRoleARN := q.Get("assume_role_arn")
if assumeRoleARN != "" {
stsClient := sts.NewFromConfig(awsCfg)

sessionName := cfg.AssumeRoleSessionName
sessionName := q.Get("assume_role_session_name")
if sessionName == "" {
sessionName = "pgutils-rds-iam"
}

assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, cfg.AssumeRoleARN, func(assumeRoleOpts *stscreds.AssumeRoleOptions) {
assumeRoleOpts.RoleSessionName = sessionName

if cfg.AssumeRoleExternalID != "" {
assumeRoleOpts.ExternalID = aws.String(cfg.AssumeRoleExternalID)
}

if cfg.AssumeRoleDuration != 0 {
assumeRoleOpts.Duration = cfg.AssumeRoleDuration
}
log.Printf("RDS IAM Assuming Role: %s with session name: %s for Host: %s User: %s Database: %s", assumeRoleARN, sessionName, host, user, dbName)
assumeProvider := stscreds.NewAssumeRoleProvider(stsClient, assumeRoleARN, func(opts *stscreds.AssumeRoleOptions) {
opts.RoleSessionName = sessionName
})

// Cache to avoid calling STS too frequently.
creds = aws.NewCredentialsCache(assumeProvider)
}

return &PostgresqlConnector{
baseConnectionStringProvider: &iamAuthConnectionStringProvider{
IAMAuthConfig: *cfg,
region: awsCfg.Region,
creds: creds,
},
return &rdsIAMConnectionStringProvider{
Region: awsCfg.Region,
RDSEndpoint: net.JoinHostPort(host, port),
User: user,
Database: dbName,
CredentialsProvider: creds,
}, nil
}

// Provides missing sqlx.OpenDB
func OpenDB(conn *PostgresqlConnector) *sqlx.DB {
sqlDB := sql.OpenDB(conn)
return sqlx.NewDb(sqlDB, "postgres")
}

// ConnectDB opens a connection using the connector and verifies it with a ping
func ConnectDB(conn *PostgresqlConnector) (*sqlx.DB, error) {
db := OpenDB(conn)
if err := db.Ping(); err != nil {
db.Close()
return nil, err
}
return db, nil
}

// MustConnectDB is like ConnectDB but panics on error
func MustConnectDB(conn *PostgresqlConnector) *sqlx.DB {
db, err := ConnectDB(conn)
if err != nil {
panic(err)
}
return db
}

7 changes: 3 additions & 4 deletions pgutils/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,17 @@ func listenerEventToString(t pq.ListenerEventType) string {
// The callback is invoked from the listener goroutine; it MUST NOT block
// for long periods. If you need to do heavy work, offload it to another
// goroutine.
func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string, callback func(*pq.Notification), onClose func()) error {
func Listen(ctx context.Context, provider ConnectionStringProvider, pgChannelName string, callback func(*pq.Notification), onClose func()) error {
if callback == nil {
return fmt.Errorf("listener callback cannot be nil")
}

reconnectEventCh := make(chan struct{}, 1) // We just need a single reconnect event to trigger, so buffer size of 1

makeListener := func() (*pq.Listener, error) {
url, err := conn.GetConnectionString(ctx)
url, err := provider.ConnectionString(ctx)
if err != nil {
return nil, fmt.Errorf("get url: %w", err)
return nil, fmt.Errorf("error getting connection string from provider: %w", err)
}

cb := func(t pq.ListenerEventType, e error) {
Expand Down Expand Up @@ -174,4 +174,3 @@ func Listen(ctx context.Context, conn *PostgresqlConnector, pgChannelName string

return nil
}

Loading