Skip to content

Commit bfebce3

Browse files
committed
API refactor, error handling refactor
1 parent 212706e commit bfebce3

File tree

3 files changed

+225
-102
lines changed

3 files changed

+225
-102
lines changed

README.md

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@ The library provides the following signature algorithms:
1414
- [ECDSA](https://pkg.go.dev/github.com/denpeshkov/httpsign/ecdsa)
1515
- [Ed25519](https://pkg.go.dev/github.com/denpeshkov/httpsign/ed25519)
1616

17-
The API is based on two interfaces: `Signer` and `Verifier`.
17+
The API is based on four interfaces: `Signer`, `Verifier` and `SignerSource`, `VerifierSource`.
18+
1819
`Signer` is essentially a wrapper around the signature algorithm's private key.
1920
Because the private key also contains the corresponding public key, `Signer` can be used for verification as well.
21+
`SignerSource` abstracts the retrieval of `Signer` based on the provided key ID.
2022

2123
`Verifier` uses the public key for verification. It is useful in situations where the user only has access to the public key and not the private key.
24+
`VerifierSource` abstracts the retrieval of `Verifier` based on the provided key ID.
2225

2326
The HMAC algorithm is an exception, as it uses the same shared secret key for both signing and verification.
2427
Therefore, the API provides a single structure, [`HMAC`](https://pkg.go.dev/github.com/denpeshkov/httpsign/hmac#HMAC), for both signing and verification.
@@ -28,65 +31,93 @@ Therefore, the API provides a single structure, [`HMAC`](https://pkg.go.dev/gith
2831
Here is an example using `HMAC-SHA-256` algorithm:
2932

3033
```go
31-
sharedKey := []byte("shared-secret")
34+
type staticSource struct{ h *hmac.HMAC }
35+
36+
func (s staticSource) Signer(ctx context.Context, kid string) (Signer, error) {
37+
return s.h, nil
38+
}
39+
func (s staticSource) Verifier(ctx context.Context, kid string) (Verifier, error) {
40+
return s.h, nil
41+
}
42+
43+
const (
44+
secret = "shared-secret"
45+
kid = "key-id"
46+
)
3247

33-
// Create the Signer using the shared secret key.
34-
sgn, err := hshmac.New(sharedKey, crypto.SHA256)
48+
// Create the signer using the shared secret key.
49+
sgn, err := hmac.New([]byte(secret), crypto.SHA256)
3550
if err != nil {
3651
log.Fatal(err)
3752
}
3853

54+
// Create the source given the signer.
55+
src := staticSource{sgn}
56+
3957
// Create the Transport.
40-
tr := httpsign.NewTransport(sgn)
58+
tr := NewTransport(src, kid)
4159

4260
// Create an HTTP client using our transport to sign outgoing requests.
4361
c := &http.Client{Transport: tr}
4462

4563
// Create the Middleware to verify incoming requests signatures.
46-
m := httpsign.NewMiddleware(sgn)
64+
m := Middleware(src, DefaultErrorHandler)
4765

4866
// Wrap the handler.
4967
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5068
fmt.Fprintln(w, "Hello")
5169
})
52-
handler = m.Handler(handler)
70+
handler = m(handler)
5371

5472
http.Handle("/api/foo", handler)
5573
```
5674

5775
Here is an example using `RSASSA-PKCS1-v1.5 SHA-256` algorithm:
5876

5977
```go
60-
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
78+
type staticSource struct{ h *rsa.PKCSSigner }
79+
80+
func (s staticSource) Signer(ctx context.Context, kid string) (Signer, error) {
81+
return s.h, nil
82+
}
83+
func (s staticSource) Verifier(ctx context.Context, kid string) (Verifier, error) {
84+
return s.h, nil
85+
}
86+
87+
const kid = "key-id"
88+
89+
privateKey, err := stdrsa.GenerateKey(rand.Reader, 2048)
6190
hash := crypto.SHA256
6291

63-
// Create the Signer using the private key.
64-
sgn, err := hsrsa.NewPKCSSigner(privateKey, hash)
92+
// Create the signer using the shared secret key.
93+
sgn, err := rsa.NewPKCSSigner(privateKey, hash)
6594
if err != nil {
6695
log.Fatal(err)
6796
}
6897

98+
// Create the source given the signer.
99+
src := staticSource{sgn}
100+
69101
// Create the Transport.
70-
tr := httpsign.NewTransport(sgn)
102+
tr := NewTransport(src, kid)
71103

72104
// Create an HTTP client using our transport to sign outgoing requests.
73105
c := &http.Client{Transport: tr}
74106

75107
// Create the Middleware to verify incoming requests signatures.
76-
m := httpsign.NewMiddleware(sgn)
108+
m := Middleware(src, DefaultErrorHandler)
77109

78110
// Alternatively, we can explicitly create a Verifier using the public key.
79111
vrf, err := hsrsa.NewPKCSVerifier(&privateKey.PublicKey, hash)
80112
if err != nil {
81113
log.Fatal(err)
82114
}
83-
m = httpsign.NewMiddleware(vrf)
84115

85116
// Wrap the handler.
86117
var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
87118
fmt.Fprintln(w, "Hello")
88119
})
89-
handler = m.Handler(handler)
120+
handler = m(handler)
90121

91122
http.Handle("/api/foo", handler)
92123
```

http.go

Lines changed: 101 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
package httpsign
33

44
import (
5+
"context"
56
"encoding/base64"
67
"errors"
78
"fmt"
@@ -13,33 +14,65 @@ import (
1314
)
1415

1516
const (
16-
signatureHeader = "X-Signature"
17-
timestampHeader = "X-Signature-Timestamp"
17+
SignatureHeader = "X-Signature"
18+
TimestampHeader = "X-Signature-Timestamp"
19+
KidHeader = "X-Key-ID"
1820
)
1921

2022
var (
21-
// ErrVerification represents a failure to verify a signature.
22-
ErrVerification = errors.New("signature verification error")
23+
// ErrRequestVerification represents a failure to verify an HTTP request.
24+
ErrRequestVerification = errors.New("failed to verify HTTP request")
2325
)
2426

25-
// Transport is an HTTP [http.RoundTripper] which signs outgoing HTTP requests.
27+
// Sign signs the HTTP request using the provided signer and timestamp.
28+
func Sign(s Signer, timestamp time.Time, r *http.Request) error {
29+
var (
30+
method = r.Method
31+
host = r.Host
32+
path = r.URL.EscapedPath()
33+
query = query{r.URL.Query()}.Encode()
34+
ts = timestamp.UTC().Format(time.RFC3339)
35+
)
36+
if path == "" {
37+
path = "/" // See https://www.rfc-editor.org/rfc/rfc9110#section-4.2.3
38+
}
39+
data := fmt.Sprintf("%s%s%s%s%s", method, host, path, query, ts)
40+
sig, err := s.Sign([]byte(data))
41+
if err != nil {
42+
return err
43+
}
44+
sige := base64.RawURLEncoding.EncodeToString(sig)
45+
r.Header.Add(TimestampHeader, ts)
46+
r.Header.Add(SignatureHeader, sige)
47+
return nil
48+
}
49+
50+
// SignerSource provides a [Signer] given a key ID.
51+
// It must be safe for concurrent use by multiple goroutines.
52+
type SignerSource interface {
53+
Signer(ctx context.Context, kid string) (Signer, error)
54+
}
55+
56+
// Transport is an [http.RoundTripper] which signs outgoing HTTP requests.
2657
type Transport struct {
2758
// Base is the base http.RoundTripper used to make HTTP requests.
2859
// By default, http.DefaultTransport is used.
2960
Base http.RoundTripper
3061

31-
signer Signer
62+
source SignerSource
63+
kid string
3264
}
3365

34-
// NewTransport returns a new [Transport] given a [Signer].
35-
func NewTransport(signer Signer) *Transport {
66+
// NewTransport returns a new [Transport] using a [SignerSource] with the provided key ID.
67+
func NewTransport(source SignerSource, kid string) *Transport {
3668
return &Transport{
3769
Base: http.DefaultTransport,
38-
signer: signer,
70+
source: source,
71+
kid: kid,
3972
}
4073
}
4174

42-
// RoundTrip implements the [http.RoundTripper] interface, signing the request using provided [Signer].
75+
// RoundTrip signs the request.
4376
func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
4477
bodyClosed := false
4578
if r.Body != nil {
@@ -51,103 +84,98 @@ func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
5184
}
5285

5386
r = r.Clone(r.Context()) // per RoundTripper contract.
54-
if err := t.sign(r); err != nil {
87+
r.Header.Set(KidHeader, t.kid)
88+
signer, err := t.source.Signer(r.Context(), t.kid)
89+
if err != nil {
90+
return nil, fmt.Errorf("obtain signer: %w", err)
91+
}
92+
if err := Sign(signer, time.Now(), r); err != nil {
5593
return nil, fmt.Errorf("sign request: %w", err)
5694
}
5795
bodyClosed = true // r.Body is closed by the base RoundTripper.
5896
return t.Base.RoundTrip(r)
5997
}
6098

61-
func (t *Transport) sign(r *http.Request) error {
99+
// Verify verifies the HTTP request using the provided verifier.
100+
func Verify(v Verifier, r *http.Request) error {
62101
var (
63102
method = r.Method
64103
host = r.Host
65104
path = r.URL.EscapedPath()
66105
query = query{r.URL.Query()}.Encode()
67-
timestamp = time.Now().UTC().Format(time.RFC3339)
106+
timestamp = r.Header.Get(TimestampHeader)
68107
)
69-
if path == "" {
70-
path = "/" // See https://www.rfc-editor.org/rfc/rfc9110#section-4.2.3
108+
if timestamp == "" {
109+
return fmt.Errorf("missing %s header", TimestampHeader)
110+
}
111+
sigRaw := r.Header.Get(SignatureHeader)
112+
if sigRaw == "" {
113+
return fmt.Errorf("missing %s header", SignatureHeader)
71114
}
72-
data := fmt.Sprintf("%s%s%s%s%s", method, host, path, query, timestamp)
73-
sig, err := t.signer.Sign([]byte(data))
115+
sig, err := base64.RawURLEncoding.DecodeString(sigRaw)
74116
if err != nil {
75-
return err
117+
return fmt.Errorf("malformed signature: %w", err)
118+
}
119+
msg := fmt.Sprintf("%s%s%s%s%s", method, host, path, query, timestamp)
120+
121+
valid, err := v.Verify([]byte(msg), sig)
122+
if err != nil {
123+
return fmt.Errorf("verification failure: %w", err)
124+
}
125+
if !valid {
126+
return errors.New("invalid signature")
76127
}
77-
esig := base64.RawURLEncoding.EncodeToString(sig)
78-
r.Header.Add(timestampHeader, timestamp)
79-
r.Header.Add(signatureHeader, esig)
80128
return nil
81129
}
82130

131+
// VerifierSource provides a [Verifier] given a key ID.
132+
// It must be safe for concurrent use by multiple goroutines.
133+
type VerifierSource interface {
134+
Verifier(ctx context.Context, kid string) (Verifier, error)
135+
}
136+
83137
// DefaultErrorHandler handles errors as follows:
84-
// - If the error is [ErrVerification], it sends a 401 Unauthorized response.
138+
// - If the error is [ErrRequestVerification], it sends a 401 Unauthorized response.
85139
// - For any other errors, it defaults to sending a 500 Internal Server Error response.
86140
func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
87141
if err == nil {
88142
return
89143
}
90144
switch {
91-
case errors.Is(err, ErrVerification):
145+
case errors.Is(err, ErrRequestVerification):
92146
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
93147
default:
94148
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
95149
}
96150
}
97151

98-
type Middleware struct {
99-
// ErrorHandler is used to handle errors that occur during signature verification.
100-
// If not provided, DefaultErrorHandler is used.
101-
ErrorHandler func(w http.ResponseWriter, r *http.Request, err error)
102-
103-
verifier Verifier
104-
}
105-
106-
// NewMiddleware returns a new [Middleware] given a [Verifier].
107-
func NewMiddleware(verifier Verifier) *Middleware {
108-
return &Middleware{
109-
ErrorHandler: DefaultErrorHandler,
110-
verifier: verifier,
152+
// Middleware creates middleware that verifies HTTP request signatures using a [Verifier]
153+
// from the provided source and handles errors with a custom handler.
154+
func Middleware(
155+
source VerifierSource,
156+
errHandler func(w http.ResponseWriter, r *http.Request, err error),
157+
) func(http.Handler) http.Handler {
158+
return func(h http.Handler) http.Handler {
159+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
160+
kid := r.Header.Get(KidHeader)
161+
if kid == "" {
162+
errHandler(w, r, fmt.Errorf("%w: missing %s header", ErrRequestVerification, KidHeader))
163+
return
164+
}
165+
v, err := source.Verifier(r.Context(), kid)
166+
if err != nil {
167+
errHandler(w, r, fmt.Errorf("obtain verifier: %w", err))
168+
return
169+
}
170+
if err := Verify(v, r); err != nil {
171+
errHandler(w, r, fmt.Errorf("%w: %w", ErrRequestVerification, err))
172+
return
173+
}
174+
h.ServeHTTP(w, r)
175+
})
111176
}
112177
}
113178

114-
// Handler returns a handler that serves requests with signature verification.
115-
func (m *Middleware) Handler(h http.Handler) http.Handler {
116-
return m.handler(func(w http.ResponseWriter, r *http.Request) error {
117-
var (
118-
method = r.Method
119-
host = r.Host
120-
path = r.URL.EscapedPath()
121-
query = query{r.URL.Query()}.Encode()
122-
timestamp = r.Header.Get(timestampHeader)
123-
)
124-
msg := fmt.Sprintf("%s%s%s%s%s", method, host, path, query, timestamp)
125-
126-
sig, err := base64.RawURLEncoding.DecodeString(r.Header.Get(signatureHeader))
127-
if err != nil {
128-
return fmt.Errorf("%w: %w", ErrVerification, err)
129-
}
130-
131-
valid, err := m.verifier.Verify([]byte(msg), sig)
132-
if err != nil {
133-
return err
134-
}
135-
if !valid {
136-
return ErrVerification
137-
}
138-
h.ServeHTTP(w, r)
139-
return nil
140-
})
141-
}
142-
143-
func (m *Middleware) handler(h func(w http.ResponseWriter, r *http.Request) error) http.Handler {
144-
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
145-
if err := h(w, r); err != nil {
146-
m.ErrorHandler(w, r, err)
147-
}
148-
})
149-
}
150-
151179
// query embeds [url.Values] overriding [url.Values.Encode] to sort by both key and value.
152180
type query struct{ url.Values }
153181

0 commit comments

Comments
 (0)