Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1876,7 +1876,9 @@ func parseEnviron(env []string) (out map[string]string) {
accrue("user")
case "PGPASSWORD":
accrue("password")
case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
case "PGSERVICE":
accrue("service")
case "PGREALM":
unsupported()
case "PGOPTIONS":
accrue("options")
Expand Down Expand Up @@ -1914,6 +1916,62 @@ func parseEnviron(env []string) (out map[string]string) {
return out
}

// parseServiceFile parses the options from a service file and adds them to the values.
//
// The parsing code is based on parseServiceInfo from libpq's fe-connect.c
func parseServiceFile(service string, o values) error {
filename := os.Getenv("PGSERVICEFILE")
if filename == "" {
// XXX this code doesn't work on Windows where the default filename is
// XXX %APPDATA%\postgresql\.pg_service.conf
// Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
userHome := os.Getenv("HOME")
if userHome == "" {
user, err := user.Current()
if err != nil {
return err
}
userHome = user.HomeDir
}
filename = filepath.Join(userHome, ".pg_service.conf")
}

file, err := os.Open(filename)
if err != nil {
return err
}
defer file.Close()

scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())

// once we find the header of our section, we can start reading
if line == fmt.Sprintf("[%s]", service) {
for scanner.Scan() {
line = strings.TrimSpace(scanner.Text())
// once we find the next section, we're done
if strings.HasPrefix(line, "[") {
return nil
} else if line != "" {
if err := parseOpts(line, o); err != nil {
return err
}
}
}
// EOF means we're done
return nil
}
}

if err := scanner.Err(); err != nil {
return err
}

// if we end up here, we didn't find the service that was explicitly provided
return fmt.Errorf(`definition of service "%s" not found`, service)
}

// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
func isUTF8(name string) bool {
// Recognize all sorts of silly things as "UTF-8", like Postgres does
Expand Down
66 changes: 66 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,72 @@ func TestOpenURL(t *testing.T) {
testURL("postgresql://")
}

func TestPgServiceFile(t *testing.T) {
if os.Getenv("PGSERVICEFILE") == "" {
if os.Getenv("TRAVIS") != "true" {
t.Skip("PGSERVICEFILE not set, skipping service connection file tests")
}
os.Setenv("PGSERVICEFILE", "/tmp/pqgotest_pgservice")
os.Remove(pgpassFile)
pgservice, err := os.OpenFile(os.Getenv("PGSERVICEFILE"), os.O_RDWR|os.O_CREATE, 0644)
if err != nil {
t.Fatalf("Unexpected error writing pg service file %#v", err)
}
_, err = pgservice.WriteString(`
[service1]
host=localhost

[service2]
dbname=template2

[service3]
thistestshould=fail
`)
if err != nil {
t.Fatalf("Unexpected error writing pg service file %#v", err)
}
pgservice.Close()
}

testAssert := func(conninfo string, expected string, reason string) {
conn, err := openTestConnConninfo(conninfo)
if err != nil {
t.Fatal(err)
}
defer conn.Close()

txn, err := conn.Begin()
if err != nil {
if expected != "fail" {
t.Fatalf(reason, err)
}
return
}
rows, err := txn.Query("SELECT USER")
if err != nil {
txn.Rollback()
if expected != "fail" {
t.Fatalf(reason, err)
}
} else {
rows.Close()
if expected != "ok" {
t.Fatalf(reason, err)
}
}
txn.Rollback()
}

testAssert("service=service1", "ok", "connect to defaults failed")
testAssert("service=service2", "fail", "connect to template2 failed")
testAssert("service=service3", "fail", "unrecognized parameter %#v")

os.Setenv("PGSERVICEFILE", "IdoNotExist")
testAssert("service=pietje", "fail", "service file does not exist")

os.Setenv("PGSERVICEFILE", "")
}

const pgpassFile = "/tmp/pqgotest_pgpass"

func TestPgpass(t *testing.T) {
Expand Down
19 changes: 19 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ func NewConnector(dsn string) (*Connector, error) {
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Service name variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "5432"
Expand All @@ -68,6 +69,24 @@ func NewConnector(dsn string) (*Connector, error) {
return nil, err
}

// whenever a service is specified, we will need to parse the connection service file
// and override the defaults with the parameters specified for that service.
// See https://www.postgresql.org/docs/current/libpq-pgservice.html
if service, ok := o["service"]; ok {
if err := parseServiceFile(service, o); err != nil {
return nil, err
}

// By overwriting the options with the service parameters we may have masked some
// explicitly passed connection information, e.g. "service=staging user=read_only".
// By repeating the parseOpts we overcome this issue.
if err := parseOpts(dsn, o); err != nil {
return nil, err
}
// "service" itself should not be passed down as a connection parameter
delete(o, "service")
}

// Use the "fallback" application name if necessary
if fallback, ok := o["fallback_application_name"]; ok {
if _, ok := o["application_name"]; !ok {
Expand Down