22package httpsign
33
44import (
5+ "context"
56 "encoding/base64"
67 "errors"
78 "fmt"
@@ -13,33 +14,65 @@ import (
1314)
1415
1516const (
16- signatureHeader = "X-Signature"
17- timestampHeader = "X-Signature-Timestamp"
17+ SignatureHeader = "X-Signature"
18+ TimestampHeader = "X-Signature-Timestamp"
19+ KidHeader = "X-Key-ID"
1820)
1921
2022var (
21- // ErrVerification represents a failure to verify a signature .
22- ErrVerification = errors .New ("signature verification error " )
23+ // ErrRequestVerification represents a failure to verify an HTTP request .
24+ ErrRequestVerification = errors .New ("failed to verify HTTP request " )
2325)
2426
25- // Transport is an HTTP [http.RoundTripper] which signs outgoing HTTP requests.
27+ // Sign signs the HTTP request using the provided signer and timestamp.
28+ func Sign (s Signer , timestamp time.Time , r * http.Request ) error {
29+ var (
30+ method = r .Method
31+ host = r .Host
32+ path = r .URL .EscapedPath ()
33+ query = query {r .URL .Query ()}.Encode ()
34+ ts = timestamp .UTC ().Format (time .RFC3339 )
35+ )
36+ if path == "" {
37+ path = "/" // See https://www.rfc-editor.org/rfc/rfc9110#section-4.2.3
38+ }
39+ data := fmt .Sprintf ("%s%s%s%s%s" , method , host , path , query , ts )
40+ sig , err := s .Sign ([]byte (data ))
41+ if err != nil {
42+ return err
43+ }
44+ sige := base64 .RawURLEncoding .EncodeToString (sig )
45+ r .Header .Add (TimestampHeader , ts )
46+ r .Header .Add (SignatureHeader , sige )
47+ return nil
48+ }
49+
50+ // SignerSource provides a [Signer] given a key ID.
51+ // It must be safe for concurrent use by multiple goroutines.
52+ type SignerSource interface {
53+ Signer (ctx context.Context , kid string ) (Signer , error )
54+ }
55+
56+ // Transport is an [http.RoundTripper] which signs outgoing HTTP requests.
2657type Transport struct {
2758 // Base is the base http.RoundTripper used to make HTTP requests.
2859 // By default, http.DefaultTransport is used.
2960 Base http.RoundTripper
3061
31- signer Signer
62+ source SignerSource
63+ kid string
3264}
3365
34- // NewTransport returns a new [Transport] given a [Signer] .
35- func NewTransport (signer Signer ) * Transport {
66+ // NewTransport returns a new [Transport] using a [SignerSource] with the provided key ID .
67+ func NewTransport (source SignerSource , kid string ) * Transport {
3668 return & Transport {
3769 Base : http .DefaultTransport ,
38- signer : signer ,
70+ source : source ,
71+ kid : kid ,
3972 }
4073}
4174
42- // RoundTrip implements the [http.RoundTripper] interface, signing the request using provided [Signer] .
75+ // RoundTrip signs the request.
4376func (t * Transport ) RoundTrip (r * http.Request ) (* http.Response , error ) {
4477 bodyClosed := false
4578 if r .Body != nil {
@@ -51,103 +84,98 @@ func (t *Transport) RoundTrip(r *http.Request) (*http.Response, error) {
5184 }
5285
5386 r = r .Clone (r .Context ()) // per RoundTripper contract.
54- if err := t .sign (r ); err != nil {
87+ r .Header .Set (KidHeader , t .kid )
88+ signer , err := t .source .Signer (r .Context (), t .kid )
89+ if err != nil {
90+ return nil , fmt .Errorf ("obtain signer: %w" , err )
91+ }
92+ if err := Sign (signer , time .Now (), r ); err != nil {
5593 return nil , fmt .Errorf ("sign request: %w" , err )
5694 }
5795 bodyClosed = true // r.Body is closed by the base RoundTripper.
5896 return t .Base .RoundTrip (r )
5997}
6098
61- func (t * Transport ) sign (r * http.Request ) error {
99+ // Verify verifies the HTTP request using the provided verifier.
100+ func Verify (v Verifier , r * http.Request ) error {
62101 var (
63102 method = r .Method
64103 host = r .Host
65104 path = r .URL .EscapedPath ()
66105 query = query {r .URL .Query ()}.Encode ()
67- timestamp = time . Now (). UTC (). Format ( time . RFC3339 )
106+ timestamp = r . Header . Get ( TimestampHeader )
68107 )
69- if path == "" {
70- path = "/" // See https://www.rfc-editor.org/rfc/rfc9110#section-4.2.3
108+ if timestamp == "" {
109+ return fmt .Errorf ("missing %s header" , TimestampHeader )
110+ }
111+ sigRaw := r .Header .Get (SignatureHeader )
112+ if sigRaw == "" {
113+ return fmt .Errorf ("missing %s header" , SignatureHeader )
71114 }
72- data := fmt .Sprintf ("%s%s%s%s%s" , method , host , path , query , timestamp )
73- sig , err := t .signer .Sign ([]byte (data ))
115+ sig , err := base64 .RawURLEncoding .DecodeString (sigRaw )
74116 if err != nil {
75- return err
117+ return fmt .Errorf ("malformed signature: %w" , err )
118+ }
119+ msg := fmt .Sprintf ("%s%s%s%s%s" , method , host , path , query , timestamp )
120+
121+ valid , err := v .Verify ([]byte (msg ), sig )
122+ if err != nil {
123+ return fmt .Errorf ("verification failure: %w" , err )
124+ }
125+ if ! valid {
126+ return errors .New ("invalid signature" )
76127 }
77- esig := base64 .RawURLEncoding .EncodeToString (sig )
78- r .Header .Add (timestampHeader , timestamp )
79- r .Header .Add (signatureHeader , esig )
80128 return nil
81129}
82130
131+ // VerifierSource provides a [Verifier] given a key ID.
132+ // It must be safe for concurrent use by multiple goroutines.
133+ type VerifierSource interface {
134+ Verifier (ctx context.Context , kid string ) (Verifier , error )
135+ }
136+
83137// DefaultErrorHandler handles errors as follows:
84- // - If the error is [ErrVerification ], it sends a 401 Unauthorized response.
138+ // - If the error is [ErrRequestVerification ], it sends a 401 Unauthorized response.
85139// - For any other errors, it defaults to sending a 500 Internal Server Error response.
86140func DefaultErrorHandler (w http.ResponseWriter , r * http.Request , err error ) {
87141 if err == nil {
88142 return
89143 }
90144 switch {
91- case errors .Is (err , ErrVerification ):
145+ case errors .Is (err , ErrRequestVerification ):
92146 http .Error (w , http .StatusText (http .StatusUnauthorized ), http .StatusUnauthorized )
93147 default :
94148 http .Error (w , http .StatusText (http .StatusInternalServerError ), http .StatusInternalServerError )
95149 }
96150}
97151
98- type Middleware struct {
99- // ErrorHandler is used to handle errors that occur during signature verification.
100- // If not provided, DefaultErrorHandler is used.
101- ErrorHandler func (w http.ResponseWriter , r * http.Request , err error )
102-
103- verifier Verifier
104- }
105-
106- // NewMiddleware returns a new [Middleware] given a [Verifier].
107- func NewMiddleware (verifier Verifier ) * Middleware {
108- return & Middleware {
109- ErrorHandler : DefaultErrorHandler ,
110- verifier : verifier ,
152+ // Middleware creates middleware that verifies HTTP request signatures using a [Verifier]
153+ // from the provided source and handles errors with a custom handler.
154+ func Middleware (
155+ source VerifierSource ,
156+ errHandler func (w http.ResponseWriter , r * http.Request , err error ),
157+ ) func (http.Handler ) http.Handler {
158+ return func (h http.Handler ) http.Handler {
159+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
160+ kid := r .Header .Get (KidHeader )
161+ if kid == "" {
162+ errHandler (w , r , fmt .Errorf ("%w: missing %s header" , ErrRequestVerification , KidHeader ))
163+ return
164+ }
165+ v , err := source .Verifier (r .Context (), kid )
166+ if err != nil {
167+ errHandler (w , r , fmt .Errorf ("obtain verifier: %w" , err ))
168+ return
169+ }
170+ if err := Verify (v , r ); err != nil {
171+ errHandler (w , r , fmt .Errorf ("%w: %w" , ErrRequestVerification , err ))
172+ return
173+ }
174+ h .ServeHTTP (w , r )
175+ })
111176 }
112177}
113178
114- // Handler returns a handler that serves requests with signature verification.
115- func (m * Middleware ) Handler (h http.Handler ) http.Handler {
116- return m .handler (func (w http.ResponseWriter , r * http.Request ) error {
117- var (
118- method = r .Method
119- host = r .Host
120- path = r .URL .EscapedPath ()
121- query = query {r .URL .Query ()}.Encode ()
122- timestamp = r .Header .Get (timestampHeader )
123- )
124- msg := fmt .Sprintf ("%s%s%s%s%s" , method , host , path , query , timestamp )
125-
126- sig , err := base64 .RawURLEncoding .DecodeString (r .Header .Get (signatureHeader ))
127- if err != nil {
128- return fmt .Errorf ("%w: %w" , ErrVerification , err )
129- }
130-
131- valid , err := m .verifier .Verify ([]byte (msg ), sig )
132- if err != nil {
133- return err
134- }
135- if ! valid {
136- return ErrVerification
137- }
138- h .ServeHTTP (w , r )
139- return nil
140- })
141- }
142-
143- func (m * Middleware ) handler (h func (w http.ResponseWriter , r * http.Request ) error ) http.Handler {
144- return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
145- if err := h (w , r ); err != nil {
146- m .ErrorHandler (w , r , err )
147- }
148- })
149- }
150-
151179// query embeds [url.Values] overriding [url.Values.Encode] to sort by both key and value.
152180type query struct { url.Values }
153181
0 commit comments