From 8750518fa8de86349ecec71d9118298ed7f7497e Mon Sep 17 00:00:00 2001 From: Feike Steenbergen Date: Tue, 21 May 2019 21:04:32 +0200 Subject: [PATCH] Support connection service file and service names The whole point of supporting this can best be said by directly quoting the PostgreSQL manual: > The connection service file allows libpq connection parameters to be > associated with a single service name. That service name can then be > specified by a libpq connection, and the associated settings will be > used. This allows connection parameters to be modified without > requiring a recompile of the libpq application. The service name can > also be specified using the PGSERVICE environment variable. source: https://www.postgresql.org/docs/current/libpq-pgservice.html Fixes #538 --- conn.go | 60 ++++++++++++++++++++++++++++++++++++++++++++++- conn_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++ connector.go | 19 +++++++++++++++ 3 files changed, 144 insertions(+), 1 deletion(-) diff --git a/conn.go b/conn.go index bab765bb6..9ffac0c0a 100644 --- a/conn.go +++ b/conn.go @@ -1876,7 +1876,9 @@ func parseEnviron(env []string) (out map[string]string) { accrue("user") case "PGPASSWORD": accrue("password") - case "PGSERVICE", "PGSERVICEFILE", "PGREALM": + case "PGSERVICE": + accrue("service") + case "PGREALM": unsupported() case "PGOPTIONS": accrue("options") @@ -1914,6 +1916,62 @@ func parseEnviron(env []string) (out map[string]string) { return out } +// parseServiceFile parses the options from a service file and adds them to the values. +// +// The parsing code is based on parseServiceInfo from libpq's fe-connect.c +func parseServiceFile(service string, o values) error { + filename := os.Getenv("PGSERVICEFILE") + if filename == "" { + // XXX this code doesn't work on Windows where the default filename is + // XXX %APPDATA%\postgresql\.pg_service.conf + // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 + userHome := os.Getenv("HOME") + if userHome == "" { + user, err := user.Current() + if err != nil { + return err + } + userHome = user.HomeDir + } + filename = filepath.Join(userHome, ".pg_service.conf") + } + + file, err := os.Open(filename) + if err != nil { + return err + } + defer file.Close() + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + + // once we find the header of our section, we can start reading + if line == fmt.Sprintf("[%s]", service) { + for scanner.Scan() { + line = strings.TrimSpace(scanner.Text()) + // once we find the next section, we're done + if strings.HasPrefix(line, "[") { + return nil + } else if line != "" { + if err := parseOpts(line, o); err != nil { + return err + } + } + } + // EOF means we're done + return nil + } + } + + if err := scanner.Err(); err != nil { + return err + } + + // if we end up here, we didn't find the service that was explicitly provided + return fmt.Errorf(`definition of service "%s" not found`, service) +} + // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8". func isUTF8(name string) bool { // Recognize all sorts of silly things as "UTF-8", like Postgres does diff --git a/conn_test.go b/conn_test.go index 0d25c9554..71007eab1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -140,6 +140,72 @@ func TestOpenURL(t *testing.T) { testURL("postgresql://") } +func TestPgServiceFile(t *testing.T) { + if os.Getenv("PGSERVICEFILE") == "" { + if os.Getenv("TRAVIS") != "true" { + t.Skip("PGSERVICEFILE not set, skipping service connection file tests") + } + os.Setenv("PGSERVICEFILE", "/tmp/pqgotest_pgservice") + os.Remove(pgpassFile) + pgservice, err := os.OpenFile(os.Getenv("PGSERVICEFILE"), os.O_RDWR|os.O_CREATE, 0644) + if err != nil { + t.Fatalf("Unexpected error writing pg service file %#v", err) + } + _, err = pgservice.WriteString(` +[service1] +host=localhost + +[service2] +dbname=template2 + +[service3] +thistestshould=fail +`) + if err != nil { + t.Fatalf("Unexpected error writing pg service file %#v", err) + } + pgservice.Close() + } + + testAssert := func(conninfo string, expected string, reason string) { + conn, err := openTestConnConninfo(conninfo) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + txn, err := conn.Begin() + if err != nil { + if expected != "fail" { + t.Fatalf(reason, err) + } + return + } + rows, err := txn.Query("SELECT USER") + if err != nil { + txn.Rollback() + if expected != "fail" { + t.Fatalf(reason, err) + } + } else { + rows.Close() + if expected != "ok" { + t.Fatalf(reason, err) + } + } + txn.Rollback() + } + + testAssert("service=service1", "ok", "connect to defaults failed") + testAssert("service=service2", "fail", "connect to template2 failed") + testAssert("service=service3", "fail", "unrecognized parameter %#v") + + os.Setenv("PGSERVICEFILE", "IdoNotExist") + testAssert("service=pietje", "fail", "service file does not exist") + + os.Setenv("PGSERVICEFILE", "") +} + const pgpassFile = "/tmp/pqgotest_pgpass" func TestPgpass(t *testing.T) { diff --git a/connector.go b/connector.go index 2f8ced673..b62d63b1a 100644 --- a/connector.go +++ b/connector.go @@ -47,6 +47,7 @@ func NewConnector(dsn string) (*Connector, error) { // // * Very low precedence defaults applied in every situation // * Environment variables + // * Service name variables // * Explicitly passed connection information o["host"] = "localhost" o["port"] = "5432" @@ -68,6 +69,24 @@ func NewConnector(dsn string) (*Connector, error) { return nil, err } + // whenever a service is specified, we will need to parse the connection service file + // and override the defaults with the parameters specified for that service. + // See https://www.postgresql.org/docs/current/libpq-pgservice.html + if service, ok := o["service"]; ok { + if err := parseServiceFile(service, o); err != nil { + return nil, err + } + + // By overwriting the options with the service parameters we may have masked some + // explicitly passed connection information, e.g. "service=staging user=read_only". + // By repeating the parseOpts we overcome this issue. + if err := parseOpts(dsn, o); err != nil { + return nil, err + } + // "service" itself should not be passed down as a connection parameter + delete(o, "service") + } + // Use the "fallback" application name if necessary if fallback, ok := o["fallback_application_name"]; ok { if _, ok := o["application_name"]; !ok {