diff --git a/ssl.go b/ssl.go index 36b61ba4..64d00f7b 100644 --- a/ssl.go +++ b/ssl.go @@ -8,14 +8,39 @@ import ( "os" "os/user" "path/filepath" + "runtime" "strings" ) +var testUser *user.User // for replacing user.Current() in tests + // ssl generates a function to upgrade a net.Conn based on the "sslmode" and // related settings. The function is nil when no upgrade should take place. func ssl(o values) (func(net.Conn) (net.Conn, error), error) { + var usr *user.User + // usr.Current() might fail when cross-compiling. We have to ignore the + // error and continue without home directory defaults, since we wouldn't + // know from where to load certificates. + if testUser != nil { + usr = new(user.User) + *usr = *testUser + } else { + usr, _ = user.Current() + } + verifyCaOnly := false tlsConf := tls.Config{} + + if usr != nil && o["sslmode"] != "disable" && o["sslrootcert"] == "" { + // https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLROOTCERT + // https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES + if runtime.GOOS == "windows" { + o["sslrootcert"] = filepath.Join(usr.HomeDir, "AppData", "Roaming", "postgresql", "root.crt") + } else { + o["sslrootcert"] = filepath.Join(usr.HomeDir, ".postgresql", "root.crt") + } + } + switch mode := o["sslmode"]; mode { // "require" is the default. case "", "require": @@ -61,7 +86,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { tlsConf.ServerName = o["host"] } - err := sslClientCertificates(&tlsConf, o) + err := sslClientCertificates(&tlsConf, o, usr) if err != nil { return nil, err } @@ -93,7 +118,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // "sslkey" settings, or if they aren't set, from the .postgresql directory // in the user's home directory. The configured files must exist and have // the correct permissions. -func sslClientCertificates(tlsConf *tls.Config, o values) error { +func sslClientCertificates(tlsConf *tls.Config, o values, user *user.User) error { sslinline := o["sslinline"] if sslinline == "true" { cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"])) @@ -104,11 +129,6 @@ func sslClientCertificates(tlsConf *tls.Config, o values) error { return nil } - // user.Current() might fail when cross-compiling. We have to ignore the - // error and continue without home directory defaults, since we wouldn't - // know from where to load them. - user, _ := user.Current() - // In libpq, the client certificate is only loaded if the setting is not blank. // // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 diff --git a/ssl_test.go b/ssl_test.go index 4c631b81..3261dd06 100644 --- a/ssl_test.go +++ b/ssl_test.go @@ -13,6 +13,7 @@ import ( "io" "net" "os" + "os/user" "path/filepath" "strings" "testing" @@ -76,31 +77,16 @@ func TestSSLConnection(t *testing.T) { rows.Close() } -// Test sslmode=verify-full +// Test sslmode=verify-full sslrootcert=rootCertPath func TestSSLVerifyFull(t *testing.T) { maybeSkipSSLTests(t) // Environment sanity check: should fail without SSL checkSSLSetup(t, "sslmode=disable user=pqgossltest") - // Not OK according to the system CA - _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest") - if err == nil { - t.Fatal("expected error") - } - { - var x509err x509.UnknownAuthorityError - if !errors.As(err, &x509err) { - var x509err x509.HostnameError - if !errors.As(err, &x509err) { - t.Fatalf("expected x509.UnknownAuthorityError or x509.HostnameError, got %#+v", err) - } - } - } - rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") rootCert := "sslrootcert=" + rootCertPath + " " // No match on Common Name - _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest") + _, err := openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest") if err == nil { t.Fatal("expected error") } @@ -117,6 +103,20 @@ func TestSSLVerifyFull(t *testing.T) { } } +// Test sslmode=verify-full +func TestSSLVerifyFullWithDefaultRootCert(t *testing.T) { + maybeSkipSSLTests(t) + // Environment sanity check: should fail without SSL + checkSSLSetup(t, "sslmode=disable user=pqgossltest") + + setupHomeWithRootCRT(t) + + _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest") + if err != nil { + t.Fatal(err) + } +} + // Test sslmode=require sslrootcert=rootCertPath func TestSSLRequireWithRootCert(t *testing.T) { maybeSkipSSLTests(t) @@ -162,30 +162,12 @@ func TestSSLRequireWithRootCert(t *testing.T) { } } -// Test sslmode=verify-ca +// Test sslmode=verify-ca sslrootcert=rootCertPath func TestSSLVerifyCA(t *testing.T) { maybeSkipSSLTests(t) // Environment sanity check: should fail without SSL checkSSLSetup(t, "sslmode=disable user=pqgossltest") - // Not OK according to the system CA - { - _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") - var x509err x509.UnknownAuthorityError - if !errors.As(err, &x509err) { - t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) - } - } - - // Still not OK according to the system CA; empty sslrootcert is treated as unspecified. - { - _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''") - var x509err x509.UnknownAuthorityError - if !errors.As(err, &x509err) { - t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) - } - } - rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") rootCert := "sslrootcert=" + rootCertPath + " " // No match on Common Name, but that's OK @@ -198,6 +180,23 @@ func TestSSLVerifyCA(t *testing.T) { } } +func TestSSLVerifyCAWithDefaultRootCert(t *testing.T) { + maybeSkipSSLTests(t) + // Environment sanity check: should fail without SSL + checkSSLSetup(t, "sslmode=disable user=pqgossltest") + + setupHomeWithRootCRT(t) + + // No match on Common Name, but that's OK + if _, err := openSSLConn(t, "host=127.0.0.1 sslmode=verify-ca user=pqgossltest"); err != nil { + t.Fatal(err) + } + // Everything OK + if _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest"); err != nil { + t.Fatal(err) + } +} + // Authenticate over SSL using client certificates func TestSSLClientCertificates(t *testing.T) { maybeSkipSSLTests(t) @@ -377,6 +376,38 @@ func TestSNISupport(t *testing.T) { } } +func setupHomeWithRootCRT(t *testing.T) { + t.Helper() + + homeDir, err := os.MkdirTemp("", "lib-pg-ssl-test-*") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { os.RemoveAll(homeDir) }) + + err = os.MkdirAll(filepath.Join(homeDir, ".postgresql"), 0700) + if err != nil { + t.Fatal(err) + } + + b, err := os.ReadFile("certs/root.crt") + if err != nil { + t.Fatal(err) + } + + err = os.WriteFile(filepath.Join(homeDir, ".postgresql", "root.crt"), b, 0600) + if err != nil { + t.Fatal(err) + } + + testUser = &user.User{ + // no leading slash to we can be sure that $HOME/.postgresql/root.crt + // does not exist + HomeDir: homeDir, + } + t.Cleanup(func() { testUser = nil }) +} + // Make a postgres mock server to test TLS SNI // // Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.