diff --git a/buf.go b/buf.go index 93d73b67..0d69f5ad 100644 --- a/buf.go +++ b/buf.go @@ -39,6 +39,20 @@ func (b *readBuf) string() string { return string(s) } +func (b *readBuf) strings() []string { + ss := []string{} + for (*b)[0] != 0 { + i := bytes.IndexByte(*b, 0) + if i < 0 { + errorf("invalid message format; expected string terminator") + } + s := (*b)[:i] + *b = (*b)[i+1:] + ss = append(ss, string(s)) + } + return ss +} + func (b *readBuf) next(n int) (v []byte) { v = (*b)[:n] *b = (*b)[n:] diff --git a/conn.go b/conn.go index 50f6a66d..54530c66 100644 --- a/conn.go +++ b/conn.go @@ -178,6 +178,9 @@ type conn struct { // GSSAPI context gss GSS + + // channel binding data used for SCRAM-SHA-256-PLUS + tlsServerEndPoint []byte } type syncErr struct { @@ -1204,7 +1207,20 @@ func (cn *conn) ssl(o values) error { } } - cn.c, err = upgrade(cn.c) + conn, err := upgrade(cn.c) + if err != nil { + return err + } + + if o["channel_binding"] != "disable" { + cb, err := tlsServerEndPoint(conn) + if err != nil { + return err + } + cn.tlsServerEndPoint = cb + } + + cn.c = conn return err } @@ -1282,8 +1298,16 @@ func (cn *conn) startup(o values) { func (cn *conn) auth(r *readBuf, o values) { switch code := r.int32(); code { case 0: + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + // OK case 3: + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + w := cn.writeBuf('p') w.string(o["password"]) cn.send(w) @@ -1297,6 +1321,10 @@ func (cn *conn) auth(r *readBuf, o values) { errorf("unexpected authentication response: %q", t) } case 5: + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + s := string(r.next(4)) w := cn.writeBuf('p') w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) @@ -1311,6 +1339,10 @@ func (cn *conn) auth(r *readBuf, o values) { errorf("unexpected authentication response: %q", t) } case 7: // GSSAPI, startup + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + if newGss == nil { errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)") } @@ -1346,6 +1378,9 @@ func (cn *conn) auth(r *readBuf, o values) { cn.gss = cli case 8: // GSSAPI continue + if o["channel_binding"] == "required" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } if cn.gss == nil { errorf("GSSAPI protocol error") @@ -1364,7 +1399,41 @@ func (cn *conn) auth(r *readBuf, o values) { // from the server.. case 10: + supported := r.strings() + + scramSha256 := false + scramSha256Plus := false + for _, s := range supported { + switch s { + case "SCRAM-SHA-256": + scramSha256 = true + case "SCRAM-SHA-256-PLUS": + scramSha256Plus = true + } + } + sc := scram.NewClient(sha256.New, o["user"], o["password"]) + + // channel binding is supported by the client + if cn.tlsServerEndPoint != nil { + sc.WithTlsServerEndPoint(cn.tlsServerEndPoint) + } + + var selected string + // SCRAM-SHA-256-PLUS always takes preference. + if cn.tlsServerEndPoint != nil && scramSha256Plus { + sc.UseChannelBinding() + selected = "SCRAM-SHA-256-PLUS" + } else if scramSha256 { + selected = "SCRAM-SHA-256" + } else { + errorf("SCRAM-SHA-256 protocol error") + } + + if o["channel_binding"] == "required" && selected != "SCRAM-SHA-256-PLUS" { + errorf("SCRAM-SHA-256 protocol error: channel binding required") + } + sc.Step(nil) if sc.Err() != nil { errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) @@ -1372,7 +1441,7 @@ func (cn *conn) auth(r *readBuf, o values) { scOut := sc.Out() w := cn.writeBuf('p') - w.string("SCRAM-SHA-256") + w.string(selected) w.int32(len(scOut)) w.bytes(scOut) cn.send(w) diff --git a/scram/scram.go b/scram/scram.go index 0e1ef656..26e4e4d7 100644 --- a/scram/scram.go +++ b/scram/scram.go @@ -65,6 +65,10 @@ type Client struct { serverNonce []byte saltedPass []byte authMsg bytes.Buffer + + // channel binding data used for SCRAM-SHA-256-PLUS + tlsServerEndPoint []byte + channelBinding bool } // NewClient returns a new SCRAM-* client with the provided hash algorithm. @@ -83,6 +87,16 @@ func NewClient(newHash func() hash.Hash, user, pass string) *Client { return c } +// Out returns the data to be sent to the server in the current step. +func (c *Client) WithTlsServerEndPoint(cb []byte) { + c.tlsServerEndPoint = cb +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) UseChannelBinding() { + c.channelBinding = true +} + // Out returns the data to be sent to the server in the current step. func (c *Client) Out() []byte { if c.out.Len() == 0 { @@ -140,7 +154,17 @@ func (c *Client) step1(in []byte) error { c.authMsg.WriteString(",r=") c.authMsg.Write(c.clientNonce) - c.out.WriteString("n,,") + if c.tlsServerEndPoint != nil && c.channelBinding { + // we support channel binding, and so does the server + c.out.WriteString("p=tls-server-end-point,,") + } else if c.tlsServerEndPoint != nil { + // we support channel binding, but the server doesn't. + c.out.WriteString("y,,") + } else { + // we do not support channel binding. + c.out.WriteString("n,,") + } + c.out.Write(c.authMsg.Bytes()) return nil } @@ -182,11 +206,34 @@ func (c *Client) step2(in []byte) error { } c.saltPassword(salt, iterCount) - c.authMsg.WriteString(",c=biws,r=") - c.authMsg.Write(c.serverNonce) + // channel binding: + c.authMsg.WriteString(",c=") + c.out.WriteString("c=") + + var mode string + if c.tlsServerEndPoint != nil && c.channelBinding { + // we support channel binding, and so does the server + data := []byte("p=tls-server-end-point,,") + data = append(data, c.tlsServerEndPoint...) + + mode = base64.StdEncoding.EncodeToString(data) + } else if c.tlsServerEndPoint != nil { + // we support channel binding, but the server doesn't. + mode = "eSws" + } else { + // we do not support channel binding. + mode = "biws" + } + c.authMsg.WriteString(mode) + c.out.WriteString(mode) - c.out.WriteString("c=biws,r=") + // server nonce + c.authMsg.WriteString(",r=") + c.out.WriteString(",r=") + + c.authMsg.Write(c.serverNonce) c.out.Write(c.serverNonce) + c.out.WriteString(",p=") c.out.Write(c.clientProof()) return nil diff --git a/ssl.go b/ssl.go index 2bedb953..54409cdb 100644 --- a/ssl.go +++ b/ssl.go @@ -1,6 +1,7 @@ package pq import ( + "crypto" "crypto/tls" "crypto/x509" "errors" @@ -14,7 +15,7 @@ import ( // ssl generates a function to upgrade a net.Conn based on the "sslmode" and // related settings. The function is nil when no upgrade should take place. -func ssl(o values) (func(net.Conn) (net.Conn, error), error) { +func ssl(o values) (func(net.Conn) (*tls.Conn, error), error) { verifyCaOnly := false tlsConf := tls.Config{} switch mode := o["sslmode"]; mode { @@ -78,7 +79,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { // also initiates renegotiations and cannot be reconfigured. tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient - return func(conn net.Conn) (net.Conn, error) { + return func(conn net.Conn) (*tls.Conn, error) { client := tls.Client(conn, &tlsConf) if verifyCaOnly { err := sslVerifyCertificateAuthority(client, &tlsConf) @@ -220,3 +221,28 @@ func sslnegotiation(o values) bool { } return true } + +func tlsServerEndPoint(conn *tls.Conn) ([]byte, error) { + err := conn.Handshake() + if err != nil { + return nil, err + } + + cert := conn.ConnectionState().PeerCertificates[0] + + // choose the channel binding hash type + // Use the same hash type used for the certificate signature, except for MD5 and SHA-1 which + // use SHA256 + hashType := crypto.SHA256 + switch cert.SignatureAlgorithm { + case x509.SHA384WithRSA, x509.ECDSAWithSHA384, x509.SHA384WithRSAPSS: + hashType = crypto.SHA384 + case x509.SHA512WithRSA, x509.ECDSAWithSHA512, x509.SHA512WithRSAPSS: + hashType = crypto.SHA512 + } + + hasher := hashType.New() + _, _ = hasher.Write(cert.Raw) + data := hasher.Sum(nil) + return data, nil +}