From 7a9a3a0332fb37be999b640c91dc8d3669582a1a Mon Sep 17 00:00:00 2001 From: Martin Vladev Date: Thu, 19 Sep 2019 15:35:22 +0300 Subject: [PATCH 1/2] [WIP] Add Implicit Grant Flow support See https://tools.ietf.org/html/rfc6749#section-4.2 https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth --- example/implicit/main.go | 89 +++++++++ example/main.go | 5 +- go.sum | 6 - implicit.go | 93 ++++++++++ implicit/types.go | 51 +++++ implicit_test.go | 222 ++++++++++++++++++++++ internal/implicit/implicit.go | 288 +++++++++++++++++++++++++++++ internal/implicit/implicit_test.go | 71 +++++++ internal/shared.go | 32 ++++ internal/shared_test.go | 58 ++++++ oauth2cli_test.go | 51 ++--- server.go | 22 +-- 12 files changed, 936 insertions(+), 52 deletions(-) create mode 100644 example/implicit/main.go create mode 100644 implicit.go create mode 100644 implicit/types.go create mode 100644 implicit_test.go create mode 100644 internal/implicit/implicit.go create mode 100644 internal/implicit/implicit_test.go create mode 100644 internal/shared.go create mode 100644 internal/shared_test.go diff --git a/example/implicit/main.go b/example/implicit/main.go new file mode 100644 index 0000000..e8954fa --- /dev/null +++ b/example/implicit/main.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "flag" + "log" + "os" + + "github.com/int128/oauth2cli" + "github.com/int128/oauth2cli/implicit" + "github.com/pkg/browser" + "golang.org/x/oauth2/google" + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" +) + +type cmdOptions struct { + clientID string + localServerCert string + localServerKey string +} + +func main() { + var o cmdOptions + flag.StringVar(&o.clientID, "client-id", "", "OAuth Client ID") + flag.StringVar(&o.localServerCert, "local-server-cert", "", "Path to a certificate file for the local server") + flag.StringVar(&o.localServerKey, "local-server-key", "", "Path to a key file for the local server") + flag.Parse() + + if o.clientID == "" { + log.Printf(`You need to set oauth2 credentials. +Open https://console.cloud.google.com/apis/credentials and create a client. +Then set the following options:`) + flag.PrintDefaults() + os.Exit(1) + return + } + + if o.localServerCert == "" || o.localServerKey == "" { + log.Printf("Certificate and key are required") + flag.PrintDefaults() + os.Exit(1) + return + } + + ctx := context.Background() + + ready := make(chan string, 1) + var eg errgroup.Group + eg.Go(func() error { + select { + case url, ok := <-ready: + if !ok { + return nil + } + log.Printf("Open %s", url) + if err := browser.OpenURL(url); err != nil { + log.Printf("could not open the browser: %s", err) + } + return nil + case err := <-ctx.Done(): + return xerrors.Errorf("context done while waiting for authorization: %w", err) + } + }) + eg.Go(func() error { + + defer close(ready) + token, nonce, err := oauth2cli.GeTokenIDTokenImplicitly(ctx, &implicit.ServerConfig{ + LocalServerPort: []int{8000}, + LocalServerReadyChan: ready, + LocalServerCertFile: o.localServerCert, + LocalServerKeyFile: o.localServerKey, + Config: implicit.Config{ + ClientID: o.clientID, + AuthURL: google.Endpoint.AuthURL, + RedirectURL: "https://localhost:8000/implicit", + Scopes: []string{"openid"}, + }, + }) + if err != nil { + return xerrors.Errorf("could not get a token: %w", err) + } + log.Printf("You got a valid token: %+v\nnonce: %q", *token, nonce) + return nil + }) + if err := eg.Wait(); err != nil { + log.Printf("error while authorization: %s", err) + } +} diff --git a/example/main.go b/example/main.go index eebbc70..5592160 100644 --- a/example/main.go +++ b/example/main.go @@ -3,14 +3,15 @@ package main import ( "context" "flag" + "log" + "os" + "github.com/int128/oauth2cli" "github.com/pkg/browser" "golang.org/x/oauth2" "golang.org/x/oauth2/google" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" - "log" - "os" ) type cmdOptions struct { diff --git a/go.sum b/go.sum index 612d623..a2001a4 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,6 @@ cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/int128/listener v0.0.0-20191003025026-314d86475e3b h1:7K9Vdxt6paQjHxlQpZ2UgSO68wcYUqVjxZc0rAYCRJg= -github.com/int128/listener v0.0.0-20191003025026-314d86475e3b/go.mod h1:sho0rrH7mNRRZH4hYOYx+xwRDGmtRndaUiu2z9iumes= github.com/int128/listener v1.0.0 h1:a9H3m4jbXgXpxJUK3fxWrh37Iic/UU/kYOGE0WtjbbI= github.com/int128/listener v1.0.0/go.mod h1:sho0rrH7mNRRZH4hYOYx+xwRDGmtRndaUiu2z9iumes= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4 h1:49lOXmGaUpV9Fz3gd7TFZY106KVlPVa5jcYD1gaQf98= @@ -11,8 +9,6 @@ github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/oauth2 v0.0.0-20190319182350-c85d3e98c914 h1:jIOcLT9BZzyJ9ce+IwwZ+aF9yeCqzrR+NrD68a/SHKw= -golang.org/x/oauth2 v0.0.0-20190319182350-c85d3e98c914/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45 h1:SVwTIAaPC2U/AvvLNZ2a7OVsmBpC8L5BlwK1whH3hm0= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= @@ -20,8 +16,6 @@ golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e h1:vcxGaoTs7kV8m5Np9uUNQin4BrLOthgV7252N8V+FwY= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= -golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= diff --git a/implicit.go b/implicit.go new file mode 100644 index 0000000..41c739b --- /dev/null +++ b/implicit.go @@ -0,0 +1,93 @@ +package oauth2cli + +import ( + "context" + + implict_types "github.com/int128/oauth2cli/implicit" + shared "github.com/int128/oauth2cli/internal" + implicit_int "github.com/int128/oauth2cli/internal/implicit" + "golang.org/x/oauth2" + "golang.org/x/xerrors" +) + +// GetTokenImplicitly performs Implicit Grant Flow and returns a token from the provider. +// See https://tools.ietf.org/html/rfc6749#section-4.2 +// +// This does the following steps: +// +// 1. Start a local server at the port. +// 2. Open a browser and navigate it to the local server. +// 3. Wait for the user authorization. +// 4. Receive a token via an authorization response (HTTP redirect). +// 5. Post the URL fragment via JavaScript to a local endpoint. +// 6. Return the token. +// +func GetTokenImplicitly(ctx context.Context, c *implict_types.ServerConfig) (token *oauth2.Token, err error) { + if c.LocalServerMiddleware == nil { + c.LocalServerMiddleware = shared.DefaultMiddleware + } + if c.LocalServerSuccessHTML == "" { + c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML + } + token, _, err = implicit_int.ReceiveTokenViaLocalServer(ctx, c, []string{"token"}) + if err != nil { + return token, xerrors.Errorf("error while receiving token: %w", err) + } + return token, err +} + +// GetIDTokenImplicitly performs Implicit Grant Flow and returns a id_token from the provider. +// See https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth +// +// This does the following steps: +// +// 1. Start a local server at the port. +// 2. Open a browser and navigate it to the local server. +// 3. Wait for the user authorization. +// 4. Receive a id_token via an authorization response (HTTP redirect). +// 5. Post the URL fragment via JavaScript to a local endpoint. +// 6. Return the id_token. +// +// Note: it's up to the consumer to validate the id_token with the nonce value. +// +func GetIDTokenImplicitly(ctx context.Context, c *implict_types.ServerConfig) (token *oauth2.Token, nonce string, err error) { + if c.LocalServerMiddleware == nil { + c.LocalServerMiddleware = shared.DefaultMiddleware + } + if c.LocalServerSuccessHTML == "" { + c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML + } + token, nonce, err = implicit_int.ReceiveTokenViaLocalServer(ctx, c, []string{"id_token"}) + if err != nil { + return token, nonce, xerrors.Errorf("error while receiving token: %w", err) + } + return token, nonce, err +} + +// GeTokenIDTokenImplicitly performs Implicit Grant Flow and returns a token and id_token from the provider. +// See https://openid.net/specs/openid-connect-core-1_0.html#ImplicitFlowAuth +// +// This does the following steps: +// +// 1. Start a local server at the port. +// 2. Open a browser and navigate it to the local server. +// 3. Wait for the user authorization. +// 4. Receive a id_token via an authorization response (HTTP redirect). +// 5. Post the URL fragment via JavaScript to a local endpoint. +// 6. Return the id_token. +// +// Note: it's up to the consumer to validate the id_token with the nonce value. +// +func GeTokenIDTokenImplicitly(ctx context.Context, c *implict_types.ServerConfig) (token *oauth2.Token, nonce string, err error) { + if c.LocalServerMiddleware == nil { + c.LocalServerMiddleware = shared.DefaultMiddleware + } + if c.LocalServerSuccessHTML == "" { + c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML + } + token, nonce, err = implicit_int.ReceiveTokenViaLocalServer(ctx, c, []string{"token", "id_token"}) + if err != nil { + return token, nonce, xerrors.Errorf("error while receiving token: %w", err) + } + return token, nonce, err +} diff --git a/implicit/types.go b/implicit/types.go new file mode 100644 index 0000000..5eac0c1 --- /dev/null +++ b/implicit/types.go @@ -0,0 +1,51 @@ +package implicit + +import ( + "net/http" +) + +// Config describes a typical OAuth2 Implicit flow, with both the +// client application information and the server's endpoint URLs. +type Config struct { + // ClientID is the application's ID. + ClientID string + + // AuthURL represents an OAuth 2.0 provider's authorization endpoint URL. + AuthURL string + + // RedirectURL is the URL to redirect users going through + // the OAuth flow, after the resource owner's URLs. + RedirectURL string + + // Scope specifies optional requested permissions. + Scopes []string +} + +// ServerConfig represents a config for GetToken. +type ServerConfig struct { + Config Config + + // Address which the local server binds to. + // Set to "0.0.0.0" to bind all interfaces. + // Default to "127.0.0.1". + LocalServerAddress string + // Candidates of a port which the local server binds to. + // If multiple ports are given, it will try the ports in order. + // If nil or an empty slice is given, it will allocate a free port. + LocalServerPort []int + // A PEM-encoded certificate, and possibly the complete certificate chain. + // When set, the server will serve TLS traffic using the specified + // certificates. It's recommended that the public key's SANs contain + // the loopback addresses - 'localhost', '127.0.0.1' and '::1' + LocalServerCertFile string + // A PEM-encoded private key for the certificate. + // This is required when LocalServerCertFile is set. + LocalServerKeyFile string + // Response HTML body on authorization completed. + // Default to DefaultLocalServerSuccessHTML. + LocalServerSuccessHTML string + // Middleware for the local server. Default to none. + LocalServerMiddleware func(h http.Handler) http.Handler + // A channel to send its URL when the local server is ready. Default to none. + LocalServerReadyChan chan<- string +} diff --git a/implicit_test.go b/implicit_test.go new file mode 100644 index 0000000..2d7b127 --- /dev/null +++ b/implicit_test.go @@ -0,0 +1,222 @@ +package oauth2cli_test + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "math/rand" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "github.com/int128/oauth2cli" + "github.com/int128/oauth2cli/implicit" + internal_implicit "github.com/int128/oauth2cli/internal/implicit" + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" +) + +func init() { + // for tests we want to have a stable time + internal_implicit.Now = func() time.Time { return time.Time{} } +} + +func TestGetTokenImplicitly(t *testing.T) { + cfg := &implicit.ServerConfig{ + Config: implicit.Config{ + ClientID: "YOUR_CLIENT_ID", + Scopes: []string{"email", "profile"}, + }, + LocalServerCertFile: "testdata/cert.pem", + LocalServerKeyFile: "testdata/cert-key.pem", + LocalServerMiddleware: loggingMiddleware(t), + } + t.Run("Success", func(t *testing.T) { successfulTokenImplicitTest(t, cfg) }) + // t.Run("ErrorAuthResponse", func(t *testing.T) { errorAuthResponseTest(t, cfg) }) +} + +type implicitAuthServerHandler struct { + t *testing.T + NewAuthResponse func(scope, state, nonce, redirectURI string) string +} + +func (h *implicitAuthServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if err := h.serveHTTP(w, r); err != nil { + h.t.Errorf("authServerHandler error: %s", err) + http.Error(w, err.Error(), 500) + } +} + +func (h *implicitAuthServerHandler) serveHTTP(w http.ResponseWriter, r *http.Request) error { + switch { + case r.Method == "GET" && r.URL.Path == "/auth": + q := r.URL.Query() + scope, nonce, state, redirectURI := q.Get("scope"), q.Get("nonce"), q.Get("state"), q.Get("redirect_uri") + + if scope == "" { + return xerrors.New("scope is missing") + } + if state == "" { + return xerrors.New("state is missing") + } + if redirectURI == "" { + return xerrors.New("redirect_uri is missing") + } + to := h.NewAuthResponse(scope, nonce, state, redirectURI) + http.Redirect(w, r, to, 302) + + default: + http.NotFound(w, r) + } + return nil +} + +func successfulTokenImplicitTest(t *testing.T, cfg *implicit.ServerConfig) { + + ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Hour) + defer cancel() + h := implicitAuthServerHandler{ + t: t, + NewAuthResponse: func(scope, nonce, state, redirectURI string) string { + if w := "email profile"; scope != w { + t.Errorf("scope wants %s but %s", w, scope) + return fmt.Sprintf("%s?error=invalid_scope", redirectURI) + } + if cfg.LocalServerCertFile != "" && !strings.HasPrefix(redirectURI, "https://") { + t.Errorf("redirect_uri must start with https:// when using TLS config %s", redirectURI) + return fmt.Sprintf("%s?error=invalid_redirect_uri", redirectURI) + } + return fmt.Sprintf("%s#access_token=ACCESS_TOKEN&state=%s&token_type=bearer&expires_in=3333", redirectURI, state) + }, + } + s := httptest.NewServer(&h) + defer s.Close() + openBrowserCh := make(chan string) + defer close(openBrowserCh) + + cfg.LocalServerReadyChan = openBrowserCh + + p := randomPort() + cfg.LocalServerPort = []int{p} + cfg.Config.AuthURL = s.URL + "/auth" + cfg.Config.RedirectURL = fmt.Sprintf("https://localhost:%d/oauth2/implicit/callback", p) + + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + // Wait for the local server and open a browser request. + select { + case url := <-openBrowserCh: + resp, body, err := openBrowserRequest(ctx, url) + if err != nil { + return xerrors.Errorf("could not open browser request: %w", err) + } + if resp.StatusCode != 200 { + return xerrors.Errorf("status wants 200 but %d", resp.StatusCode) + } + if expected := fmt.Sprintf(internal_implicit.JSPoster, "/oauth2/implicit/callback"); body != expected { + return xerrors.Errorf("response body did not match, want:\n%s\nbut was:\n%s", expected, body) + } + resp, err = postFragment(ctx, cfg.Config.RedirectURL, resp) + if resp.StatusCode != 200 { + return xerrors.Errorf("status wants 200 but %d", resp.StatusCode) + } + return err + case <-ctx.Done(): + return xerrors.Errorf("context done while waiting for opening browser: %w", ctx.Err()) + } + }) + eg.Go(func() error { + // Start a local server and get a token. + token, err := oauth2cli.GetTokenImplicitly(ctx, cfg) + if err != nil { + return xerrors.Errorf("could not get a token: %w", err) + } + if token.AccessToken != "ACCESS_TOKEN" { + return xerrors.Errorf("AccessToken wants %q but %q", "ACCESS_TOKEN", token.AccessToken) + } + if token.Type() != "Bearer" { + return xerrors.Errorf("TokenType should be %q but %q", "Bearer", token.Type()) + } + if token.RefreshToken != "" { + return xerrors.Errorf("RefreshToken should not be set but it is %q", token.RefreshToken) + } + + expectedTime := (time.Time{}).Add(time.Second * 3333) + if token.Expiry != expectedTime { + return xerrors.Errorf("Expiry wants %v but was %v", expectedTime, token.Expiry) + } + return nil + }) + if err := eg.Wait(); err != nil { + t.Errorf("error: %+v", err) + } + +} + +// returns a random port between 1024 and 32767 +func randomPort() int { + return 1024 + rand.New(rand.NewSource(time.Now().UnixNano())).Intn(31744) +} + +func openBrowserRequestImplicitly(ctx context.Context, url string) (int, string, error) { + c, err := client() + if err != nil { + return 0, "", xerrors.Errorf("could not create client: %w", err) + } + resp, err := getWithContext(ctx, c, url) + + if err != nil { + return 0, "", xerrors.Errorf("could not send a request: %w", err) + } + l := resp.Request.Response.Header.Get("Location") + fmt.Printf("Location: %s", l) + + defer resp.Body.Close() + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, "", xerrors.Errorf("could not read response body: %w", err) + } + return resp.StatusCode, string(b), nil +} + +func postFragment(ctx context.Context, postUrl string, r *http.Response) (*http.Response, error) { + c, err := client() + if err != nil { + return nil, xerrors.Errorf("could not create client: %w", err) + } + locationURL, err := url.Parse(r.Request.Response.Header.Get("Location")) + if err != nil { + return nil, xerrors.Errorf("could not paste location url: %w", err) + } + p, err := url.Parse(postUrl) + if err != nil { + return nil, xerrors.Errorf("could not parse postUrl: %w", err) + } + + p.RawQuery = locationURL.Fragment + + pr, err := http.NewRequestWithContext(ctx, "POST", p.String(), nil) + if err != nil { + return nil, xerrors.Errorf("could not create post request: %w", err) + } + return c.Do(pr) +} + +func client() (*http.Client, error) { + certPool := x509.NewCertPool() + data, err := ioutil.ReadFile("testdata/ca.pem") + if err != nil { + return nil, xerrors.Errorf("could not read certificate authority: %w", err) + } + if !certPool.AppendCertsFromPEM(data) { + return nil, fmt.Errorf("could not append certificate data") + } + + // we add our custom CA, otherwise the client will throw an invalid certificate error. + return &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: certPool}}}, nil +} diff --git a/internal/implicit/implicit.go b/internal/implicit/implicit.go new file mode 100644 index 0000000..36823ff --- /dev/null +++ b/internal/implicit/implicit.go @@ -0,0 +1,288 @@ +package implicit + +import ( + "bytes" + "context" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/int128/listener" + types "github.com/int128/oauth2cli/implicit" + shared "github.com/int128/oauth2cli/internal" + "golang.org/x/oauth2" + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" +) + +// Now returns the current time. Overritten in tests +var Now = time.Now + +type AuthorizationResponse struct { + token *oauth2.Token // non-empty if a valid token is received + nonce string // token and id_token should check it the claim "nonce" matches this value + err error // non-nil if an error is received or any error occurs +} + +type localServerHandler struct { + config *types.ServerConfig + // nonce is a token to protect the user from CSRF attacks. You must + // always provide a non-empty string and validate that it matches the + // the state query parameter on your redirect callback. + // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. + nonce string + state string + responseCh chan<- *AuthorizationResponse + redirectPath string + responseTypes []string +} + +// query get changed +func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + fmt.Println(h.redirectPath) + switch { + case r.Method == "GET" && r.URL.Path == h.redirectPath && q.Get("error") != "": + h.responseCh <- h.handleErrorResponse(w, r) + case r.Method == "POST" && r.URL.Path == h.redirectPath: + h.responseCh <- h.handleTokenResponse(w, r) + case r.Method == "GET" && r.URL.Path == h.redirectPath: + h.handleRawTokenResponse(w, r) + case r.Method == "GET" && r.URL.Path == "/": + h.handleIndex(w, r) + default: + http.NotFound(w, r) + } +} + +// JSPoster posts the url fragment to the redirect path. +const JSPoster = ` +` + +func (h *localServerHandler) handleRawTokenResponse(w http.ResponseWriter, r *http.Request) { + w.Header().Add("Content-Type", "text/html") + if _, err := fmt.Fprintf(w, JSPoster, h.redirectPath); err != nil { + http.Error(w, "server error", http.StatusInternalServerError) + } +} +func (h *localServerHandler) handleIndex(w http.ResponseWriter, r *http.Request) { + url := h.redirectURL() + http.Redirect(w, r, url, 302) +} + +func (h *localServerHandler) handleTokenResponse(w http.ResponseWriter, r *http.Request) *AuthorizationResponse { + vals := r.URL.Query() + token := &oauth2.Token{ + AccessToken: vals.Get("access_token"), + TokenType: vals.Get("token_type"), + } + + if state := vals.Get("state"); state != h.state { + http.Error(w, "server error", 500) + return &AuthorizationResponse{err: xerrors.Errorf("state does not match, wants %q but got %q", h.state, state)} + } + + if h.hasTokenResponse() { + if token.AccessToken = vals.Get("access_token"); token.AccessToken == "" { + http.Error(w, "server error", 500) + return &AuthorizationResponse{err: xerrors.Errorf("access_token missing in authentication response when requesting token")} + } + if token.TokenType = vals.Get("token_type"); token.TokenType == "" { + http.Error(w, "server error", 500) + return &AuthorizationResponse{err: xerrors.Errorf("token_type missing in authentication response when requesting token")} + } + } + + if h.hasIDTokenResponse() { + idToken := vals.Get("id_token") + if idToken == "" { + http.Error(w, "server error", 500) + return &AuthorizationResponse{err: xerrors.Errorf("id_token missing in authentication response when requesting id_token")} + } + vals.Set("id_token", idToken) + } + + e := vals.Get("expires_in") + expires, _ := strconv.Atoi(e) + if expires != 0 { + token.Expiry = Now().Add(time.Duration(expires) * time.Second) + } + token = token.WithExtra(vals) + + w.Header().Add("Content-Type", "text/html") + if _, err := fmt.Fprintf(w, h.config.LocalServerSuccessHTML); err != nil { + http.Error(w, "server error", 500) + return &AuthorizationResponse{err: xerrors.Errorf("error while writing response body: %w", err)} + } + return &AuthorizationResponse{token: token, nonce: h.nonce} +} + +func (h *localServerHandler) handleErrorResponse(w http.ResponseWriter, r *http.Request) *AuthorizationResponse { + q := r.URL.Query() + errorCode, errorDescription := q.Get("error"), q.Get("error_description") + + http.Error(w, "authorization error", 500) + return &AuthorizationResponse{err: xerrors.Errorf("authorization error from server: %s %s", errorCode, errorDescription)} +} + +func (h *localServerHandler) hasResponse(seek string) bool { + for _, rr := range h.responseTypes { + if rr == seek { + return true + } + } + return false +} + +func (h *localServerHandler) hasIDTokenResponse() bool { + return h.hasResponse("id_token") +} + +func (h *localServerHandler) hasTokenResponse() bool { + return h.hasResponse("token") +} + +// URL returns a URL to OAuth 2.0 provider's consent page +// that asks for permissions for the required scopes explicitly. +func (h *localServerHandler) redirectURL() string { + var buf bytes.Buffer + + c := h.config.Config + buf.WriteString(c.AuthURL) + + v := url.Values{ + "response_type": {strings.Join(h.responseTypes, " ")}, + "client_id": {c.ClientID}, + } + if c.RedirectURL != "" { + v.Set("redirect_uri", c.RedirectURL) + } + if len(c.Scopes) > 0 { + v.Set("scope", strings.Join(c.Scopes, " ")) + } + + v.Set("state", h.state) + + if h.nonce != "" { + v.Set("nonce", h.nonce) + } + if strings.Contains(c.AuthURL, "?") { + buf.WriteByte('&') + } else { + buf.WriteByte('?') + } + buf.WriteString(v.Encode()) + return buf.String() +} + +func ReceiveTokenViaLocalServer(ctx context.Context, c *types.ServerConfig, responseTypes []string) (token *oauth2.Token, nonce string, err error) { + state, err := shared.NewOAuth2State() + if err != nil { + return nil, "", xerrors.Errorf("error while state parameter generation: %w", err) + } + nonce, err = shared.NewOAuth2State() + if err != nil { + return nil, "", xerrors.Errorf("error while nonce parameter generation: %w", err) + } + l, err := listener.New(shared.ExpandAddresses(c.LocalServerAddress, c.LocalServerPort)) + if err != nil { + return nil, "", xerrors.Errorf("error while starting a local server: %w", err) + } + defer l.Close() + + if c.LocalServerCertFile == "" || c.LocalServerKeyFile == "" { + return nil, "", xerrors.Errorf("LocalServerCertFile and LocalServerKeyFile must be set when using implicit flow") + } + var redirectPath = "implicit" + + l.URL.Scheme = "https" + + if c.Config.RedirectURL == "" { + l.URL.Path = "implicit" + c.Config.RedirectURL = l.URL.String() + } else { + rd, err := url.Parse(c.Config.RedirectURL) + if err != nil { + return nil, "", xerrors.Errorf("redirect URL must be a valid URL: %w", err) + } + if rd.Path == "" || len(rd.Path) == 1 { + return nil, "", xerrors.Errorf("redirect URL path must not be empty") + } + // rd.ResolveReference() + redirectPath = rd.Path + + if rd.Scheme != "https" { + return nil, "", xerrors.Errorf("redirect URL scheme must be https") + } + } + + respCh := make(chan *AuthorizationResponse) + server := http.Server{ + Handler: c.LocalServerMiddleware(&localServerHandler{ + config: c, + nonce: nonce, + state: state, + responseCh: respCh, + redirectPath: redirectPath, + responseTypes: responseTypes, + }), + } + var resp *AuthorizationResponse + var eg errgroup.Group + eg.Go(func() error { + for { + select { + case received, ok := <-respCh: + if !ok { + return nil // channel is closed (after the server is stopped) + } + if resp == nil { + resp = received // pick only the first response + } + if err := server.Shutdown(ctx); err != nil { + return xerrors.Errorf("could not shutdown the local server: %w", err) + } + case <-ctx.Done(): + if err := server.Shutdown(ctx); err != nil { + return xerrors.Errorf("could not shutdown the local server: %w", err) + } + return xerrors.Errorf("context done while waiting for authorization response: %w", ctx.Err()) + } + } + }) + eg.Go(func() error { + defer close(respCh) + if c.LocalServerCertFile != "" && c.LocalServerKeyFile != "" { + if err := server.ServeTLS(l, c.LocalServerCertFile, c.LocalServerKeyFile); err != nil && err != http.ErrServerClosed { + return xerrors.Errorf("could not start a local TLS server: %w", err) + } + } else { + if err := server.Serve(l); err != nil && err != http.ErrServerClosed { + return xerrors.Errorf("could not start a local server: %w", err) + } + } + return nil + }) + + if c.LocalServerReadyChan != nil { + c.LocalServerReadyChan <- l.URL.String() + } + + if err := eg.Wait(); err != nil { + return nil, "", xerrors.Errorf("error while authorization: %w", err) + } + if resp == nil { + return nil, "", xerrors.New("no authorization response") + } + ctx.Done() + return resp.token, resp.nonce, resp.err +} diff --git a/internal/implicit/implicit_test.go b/internal/implicit/implicit_test.go new file mode 100644 index 0000000..992f25d --- /dev/null +++ b/internal/implicit/implicit_test.go @@ -0,0 +1,71 @@ +package implicit + +import ( + "net/url" + "reflect" + "testing" + + types "github.com/int128/oauth2cli/implicit" +) + +func Test_localServerHandler_redirectURL(t *testing.T) { + type fields struct { + config *types.ServerConfig + nonce string + state string + responseTypes []string + } + tests := []struct { + name string + fields fields + vals url.Values + }{ + { + "with redirect field", + fields{ + config: &types.ServerConfig{Config: types.Config{ + ClientID: "foo-client", + RedirectURL: "https://localhost:8080/foo/bar", + AuthURL: "https://auth.local:334/oauth-bar", + Scopes: []string{"openid"}, + }}, + state: "some-state", + nonce: "some-nonce", + responseTypes: []string{"token", "id_token"}, + }, + url.Values{ + "client_id": {"foo-client"}, + "redirect_uri": {"https://localhost:8080/foo/bar"}, + "response_type": {"token id_token"}, + "state": {"some-state"}, + "nonce": {"some-nonce"}, + "scope": {"openid"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &localServerHandler{ + config: tt.fields.config, + nonce: tt.fields.nonce, + state: tt.fields.state, + responseTypes: tt.fields.responseTypes, + } + u, err := url.Parse(h.redirectURL()) + if err != nil { + t.Errorf("error when parsing url %v", err) + } + + exected, err := url.Parse("https://auth.local:334/oauth-bar") + if err != nil { + t.Errorf("error when parsing url %v", err) + } + + exected.RawQuery = tt.vals.Encode() + + if !reflect.DeepEqual(u, exected) { + t.Errorf("expected url\n%+v but was\n%+v", exected, u) + } + }) + } +} diff --git a/internal/shared.go b/internal/shared.go new file mode 100644 index 0000000..390584f --- /dev/null +++ b/internal/shared.go @@ -0,0 +1,32 @@ +package internal + +import ( + "crypto/rand" + "encoding/binary" + "fmt" + "net/http" + + "golang.org/x/xerrors" +) + +// ExpandAddresses returns a slice of addresses for every port +func ExpandAddresses(address string, ports []int) (addresses []string) { + for _, port := range ports { + addresses = append(addresses, fmt.Sprintf("%s:%d", address, port)) + } + return +} + +// NewOAuth2State retruns random state +func NewOAuth2State() (string, error) { + var n uint64 + if err := binary.Read(rand.Reader, binary.LittleEndian, &n); err != nil { + return "", xerrors.Errorf("error while reading random: %w", err) + } + return fmt.Sprintf("%x", n), nil +} + +// DefaultMiddleware returns h handler +func DefaultMiddleware(h http.Handler) http.Handler { + return h +} diff --git a/internal/shared_test.go b/internal/shared_test.go new file mode 100644 index 0000000..9e0583e --- /dev/null +++ b/internal/shared_test.go @@ -0,0 +1,58 @@ +package internal + +import ( + "net/http" + "reflect" + "testing" +) + +func TestDefaultMiddleware(t *testing.T) { + + t.Run("same handler is returned", func(t *testing.T) { + + if got := DefaultMiddleware(http.DefaultServeMux); !reflect.DeepEqual(got, http.DefaultServeMux) { + t.Errorf("DefaultMiddleware() = %v, want %v", got, http.DefaultServeMux) + } + }) +} + +func TestNewOAuth2State(t *testing.T) { + + t.Run("different results are returned", func(t *testing.T) { + + s1, err := NewOAuth2State() + if err != nil { + t.Errorf("unexpected error calling NewOAuth2State(): %v", err) + } + s2, err := NewOAuth2State() + if err != nil { + t.Errorf("unexpected error calling NewOAuth2State(): %v", err) + } + + if s1 == s2 { + t.Errorf("DefaultMiddleware() returned the same value on different invocations: %q", s1) + } + }) +} + +func TestExpandAddresses(t *testing.T) { + type args struct { + address string + ports []int + } + tests := []struct { + name string + args args + wantAddresses []string + }{ + {"one port", args{"0.0.0.0", []int{80}}, []string{"0.0.0.0:80"}}, + {"multiple ports port", args{"0.0.0.0", []int{80, 8080}}, []string{"0.0.0.0:80", "0.0.0.0:8080"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotAddresses := ExpandAddresses(tt.args.address, tt.args.ports); !reflect.DeepEqual(gotAddresses, tt.wantAddresses) { + t.Errorf("ExpandAddresses() = %v, want %v", gotAddresses, tt.wantAddresses) + } + }) + } +} diff --git a/oauth2cli_test.go b/oauth2cli_test.go index ee36bc3..e195ff7 100644 --- a/oauth2cli_test.go +++ b/oauth2cli_test.go @@ -2,8 +2,6 @@ package oauth2cli_test import ( "context" - "crypto/tls" - "crypto/x509" "fmt" "io/ioutil" "net/http" @@ -90,13 +88,13 @@ func successfulTest(t *testing.T, cfg oauth2cli.Config) { // Wait for the local server and open a browser request. select { case url := <-openBrowserCh: - status, body, err := openBrowserRequest(url) + resp, body, err := openBrowserRequest(ctx, url) if err != nil { return xerrors.Errorf("could not open browser request: %w", err) } t.Logf("got response body: %s", body) - if status != 200 { - t.Errorf("status wants 200 but %d", status) + if resp.StatusCode != 200 { + t.Errorf("status wants 200 but %d", resp.StatusCode) } if body != oauth2cli.DefaultLocalServerSuccessHTML { t.Errorf("response body did not match") @@ -154,13 +152,13 @@ func errorAuthResponseTest(t *testing.T, cfg oauth2cli.Config) { // Wait for the local server and open a browser request. select { case url := <-openBrowserCh: - status, body, err := openBrowserRequest(url) + resp, body, err := openBrowserRequest(ctx, url) if err != nil { return xerrors.Errorf("could not open browser request: %w", err) } t.Logf("got response body: %s", body) - if status != 500 { - t.Errorf("status wants 500 but %d", status) + if resp.StatusCode != 500 { + return xerrors.Errorf("status wants 500 but %d", resp.StatusCode) } return nil case <-ctx.Done(): @@ -209,16 +207,16 @@ func errorTokenResponseTest(t *testing.T, cfg oauth2cli.Config) { // Wait for the local server and open a browser request. select { case url := <-openBrowserCh: - status, body, err := openBrowserRequest(url) + resp, body, err := openBrowserRequest(ctx, url) if err != nil { return xerrors.Errorf("could not open browser request: %w", err) } t.Logf("got response body: %s", body) - if status != 200 { - t.Errorf("status wants 200 but %d", status) + if resp.StatusCode != 200 { + return xerrors.Errorf("status wants 200 but %d", resp.StatusCode) } if body != oauth2cli.DefaultLocalServerSuccessHTML { - t.Errorf("response body did not match") + return xerrors.Errorf("response body did not match") } return nil case <-ctx.Done(): @@ -249,28 +247,23 @@ func loggingMiddleware(t *testing.T) func(h http.Handler) http.Handler { } } -func openBrowserRequest(url string) (int, string, error) { - certPool := x509.NewCertPool() - data, err := ioutil.ReadFile("testdata/ca.pem") +func openBrowserRequest(ctx context.Context, url string) (*http.Response, string, error) { + c, err := client() if err != nil { - return 0, "", xerrors.Errorf("could not read certificate authority: %w", err) - } - if !certPool.AppendCertsFromPEM(data) { - return 0, "", fmt.Errorf("could not append certificate data") + return nil, "", xerrors.Errorf("could not create client: %w", err) } + resp, err := getWithContext(ctx, c, url) - // we add our custom CA, otherwise the client will throw an invalid certificate error. - client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{RootCAs: certPool}}} - resp, err := client.Get(url) if err != nil { - return 0, "", xerrors.Errorf("could not send a request: %w", err) + return nil, "", xerrors.Errorf("could not send a request: %w", err) } + defer resp.Body.Close() b, err := ioutil.ReadAll(resp.Body) if err != nil { - return resp.StatusCode, "", xerrors.Errorf("could not read response body: %w", err) + return resp, "", xerrors.Errorf("could not read response body: %w", err) } - return resp.StatusCode, string(b), nil + return resp, string(b), nil } type authServerHandler struct { @@ -335,3 +328,11 @@ func (h *authServerHandler) serveHTTP(w http.ResponseWriter, r *http.Request) er } return nil } + +func getWithContext(ctx context.Context, c *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, err + } + return c.Do(req) +} diff --git a/server.go b/server.go index 751ef1c..0c58799 100644 --- a/server.go +++ b/server.go @@ -2,22 +2,21 @@ package oauth2cli import ( "context" - "crypto/rand" - "encoding/binary" "fmt" "net/http" "github.com/int128/listener" + shared "github.com/int128/oauth2cli/internal" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" ) func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) { - state, err := newOAuth2State() + state, err := shared.NewOAuth2State() if err != nil { return "", xerrors.Errorf("error while state parameter generation: %w", err) } - l, err := listener.New(expandAddresses(c.LocalServerAddress, c.LocalServerPort)) + l, err := listener.New(shared.ExpandAddresses(c.LocalServerAddress, c.LocalServerPort)) if err != nil { return "", xerrors.Errorf("error while starting a local server: %w", err) } @@ -91,21 +90,6 @@ func receiveCodeViaLocalServer(ctx context.Context, c *Config) (string, error) { return resp.code, resp.err } -func expandAddresses(address string, ports []int) (addresses []string) { - for _, port := range ports { - addresses = append(addresses, fmt.Sprintf("%s:%d", address, port)) - } - return -} - -func newOAuth2State() (string, error) { - var n uint64 - if err := binary.Read(rand.Reader, binary.LittleEndian, &n); err != nil { - return "", xerrors.Errorf("error while reading random: %w", err) - } - return fmt.Sprintf("%x", n), nil -} - type authorizationResponse struct { code string // non-empty if a valid code is received err error // non-nil if an error is received or any error occurs From eb78ad7c4fbd31bcda1662076a759427a6f17f93 Mon Sep 17 00:00:00 2001 From: Martin Vladev Date: Thu, 17 Oct 2019 15:55:35 +0300 Subject: [PATCH 2/2] fix gocilint errors --- example/implicit/main.go | 7 ++++- implicit.go | 17 +++++++++--- implicit_test.go | 36 +++++-------------------- internal/implicit/implicit.go | 43 +++++++++++++++++++++++------- internal/implicit/implicit_test.go | 3 +++ internal/shared.go | 2 ++ internal/shared_test.go | 6 ++--- 7 files changed, 67 insertions(+), 47 deletions(-) diff --git a/example/implicit/main.go b/example/implicit/main.go index e8954fa..1790b4b 100644 --- a/example/implicit/main.go +++ b/example/implicit/main.go @@ -22,6 +22,7 @@ type cmdOptions struct { func main() { var o cmdOptions + flag.StringVar(&o.clientID, "client-id", "", "OAuth Client ID") flag.StringVar(&o.localServerCert, "local-server-cert", "", "Path to a certificate file for the local server") flag.StringVar(&o.localServerKey, "local-server-key", "", "Path to a key file for the local server") @@ -33,6 +34,7 @@ Open https://console.cloud.google.com/apis/credentials and create a client. Then set the following options:`) flag.PrintDefaults() os.Exit(1) + return } @@ -40,13 +42,16 @@ Then set the following options:`) log.Printf("Certificate and key are required") flag.PrintDefaults() os.Exit(1) + return } ctx := context.Background() ready := make(chan string, 1) + var eg errgroup.Group + eg.Go(func() error { select { case url, ok := <-ready: @@ -63,7 +68,6 @@ Then set the following options:`) } }) eg.Go(func() error { - defer close(ready) token, nonce, err := oauth2cli.GeTokenIDTokenImplicitly(ctx, &implicit.ServerConfig{ LocalServerPort: []int{8000}, @@ -83,6 +87,7 @@ Then set the following options:`) log.Printf("You got a valid token: %+v\nnonce: %q", *token, nonce) return nil }) + if err := eg.Wait(); err != nil { log.Printf("error while authorization: %s", err) } diff --git a/implicit.go b/implicit.go index 41c739b..0032f96 100644 --- a/implicit.go +++ b/implicit.go @@ -3,7 +3,7 @@ package oauth2cli import ( "context" - implict_types "github.com/int128/oauth2cli/implicit" + implicit_types "github.com/int128/oauth2cli/implicit" shared "github.com/int128/oauth2cli/internal" implicit_int "github.com/int128/oauth2cli/internal/implicit" "golang.org/x/oauth2" @@ -22,17 +22,20 @@ import ( // 5. Post the URL fragment via JavaScript to a local endpoint. // 6. Return the token. // -func GetTokenImplicitly(ctx context.Context, c *implict_types.ServerConfig) (token *oauth2.Token, err error) { +func GetTokenImplicitly(ctx context.Context, c *implicit_types.ServerConfig) (token *oauth2.Token, err error) { if c.LocalServerMiddleware == nil { c.LocalServerMiddleware = shared.DefaultMiddleware } + if c.LocalServerSuccessHTML == "" { c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML } + token, _, err = implicit_int.ReceiveTokenViaLocalServer(ctx, c, []string{"token"}) if err != nil { return token, xerrors.Errorf("error while receiving token: %w", err) } + return token, err } @@ -50,17 +53,20 @@ func GetTokenImplicitly(ctx context.Context, c *implict_types.ServerConfig) (tok // // Note: it's up to the consumer to validate the id_token with the nonce value. // -func GetIDTokenImplicitly(ctx context.Context, c *implict_types.ServerConfig) (token *oauth2.Token, nonce string, err error) { +func GetIDTokenImplicitly(ctx context.Context, c *implicit_types.ServerConfig) (token *oauth2.Token, nonce string, err error) { if c.LocalServerMiddleware == nil { c.LocalServerMiddleware = shared.DefaultMiddleware } + if c.LocalServerSuccessHTML == "" { c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML } + token, nonce, err = implicit_int.ReceiveTokenViaLocalServer(ctx, c, []string{"id_token"}) if err != nil { return token, nonce, xerrors.Errorf("error while receiving token: %w", err) } + return token, nonce, err } @@ -78,16 +84,19 @@ func GetIDTokenImplicitly(ctx context.Context, c *implict_types.ServerConfig) (t // // Note: it's up to the consumer to validate the id_token with the nonce value. // -func GeTokenIDTokenImplicitly(ctx context.Context, c *implict_types.ServerConfig) (token *oauth2.Token, nonce string, err error) { +func GeTokenIDTokenImplicitly(ctx context.Context, c *implicit_types.ServerConfig) (token *oauth2.Token, nonce string, err error) { if c.LocalServerMiddleware == nil { c.LocalServerMiddleware = shared.DefaultMiddleware } + if c.LocalServerSuccessHTML == "" { c.LocalServerSuccessHTML = DefaultLocalServerSuccessHTML } + token, nonce, err = implicit_int.ReceiveTokenViaLocalServer(ctx, c, []string{"token", "id_token"}) if err != nil { return token, nonce, xerrors.Errorf("error while receiving token: %w", err) } + return token, nonce, err } diff --git a/implicit_test.go b/implicit_test.go index 2d7b127..b1a5be1 100644 --- a/implicit_test.go +++ b/implicit_test.go @@ -36,8 +36,8 @@ func TestGetTokenImplicitly(t *testing.T) { LocalServerKeyFile: "testdata/cert-key.pem", LocalServerMiddleware: loggingMiddleware(t), } + t.Run("Success", func(t *testing.T) { successfulTokenImplicitTest(t, cfg) }) - // t.Run("ErrorAuthResponse", func(t *testing.T) { errorAuthResponseTest(t, cfg) }) } type implicitAuthServerHandler struct { @@ -68,8 +68,7 @@ func (h *implicitAuthServerHandler) serveHTTP(w http.ResponseWriter, r *http.Req return xerrors.New("redirect_uri is missing") } to := h.NewAuthResponse(scope, nonce, state, redirectURI) - http.Redirect(w, r, to, 302) - + http.Redirect(w, r, to, http.StatusFound) default: http.NotFound(w, r) } @@ -77,7 +76,6 @@ func (h *implicitAuthServerHandler) serveHTTP(w http.ResponseWriter, r *http.Req } func successfulTokenImplicitTest(t *testing.T, cfg *implicit.ServerConfig) { - ctx, cancel := context.WithTimeout(context.TODO(), 1*time.Hour) defer cancel() h := implicitAuthServerHandler{ @@ -95,7 +93,9 @@ func successfulTokenImplicitTest(t *testing.T, cfg *implicit.ServerConfig) { }, } s := httptest.NewServer(&h) + defer s.Close() + openBrowserCh := make(chan string) defer close(openBrowserCh) @@ -155,7 +155,6 @@ func successfulTokenImplicitTest(t *testing.T, cfg *implicit.ServerConfig) { if err := eg.Wait(); err != nil { t.Errorf("error: %+v", err) } - } // returns a random port between 1024 and 32767 @@ -163,28 +162,7 @@ func randomPort() int { return 1024 + rand.New(rand.NewSource(time.Now().UnixNano())).Intn(31744) } -func openBrowserRequestImplicitly(ctx context.Context, url string) (int, string, error) { - c, err := client() - if err != nil { - return 0, "", xerrors.Errorf("could not create client: %w", err) - } - resp, err := getWithContext(ctx, c, url) - - if err != nil { - return 0, "", xerrors.Errorf("could not send a request: %w", err) - } - l := resp.Request.Response.Header.Get("Location") - fmt.Printf("Location: %s", l) - - defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return resp.StatusCode, "", xerrors.Errorf("could not read response body: %w", err) - } - return resp.StatusCode, string(b), nil -} - -func postFragment(ctx context.Context, postUrl string, r *http.Response) (*http.Response, error) { +func postFragment(ctx context.Context, postURL string, r *http.Response) (*http.Response, error) { c, err := client() if err != nil { return nil, xerrors.Errorf("could not create client: %w", err) @@ -193,9 +171,9 @@ func postFragment(ctx context.Context, postUrl string, r *http.Response) (*http. if err != nil { return nil, xerrors.Errorf("could not paste location url: %w", err) } - p, err := url.Parse(postUrl) + p, err := url.Parse(postURL) if err != nil { - return nil, xerrors.Errorf("could not parse postUrl: %w", err) + return nil, xerrors.Errorf("could not parse postURL: %w", err) } p.RawQuery = locationURL.Fragment diff --git a/internal/implicit/implicit.go b/internal/implicit/implicit.go index 36823ff..83847e6 100644 --- a/internal/implicit/implicit.go +++ b/internal/implicit/implicit.go @@ -43,15 +43,15 @@ type localServerHandler struct { // query get changed func (h *localServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() - fmt.Println(h.redirectPath) + switch { - case r.Method == "GET" && r.URL.Path == h.redirectPath && q.Get("error") != "": + case r.Method == http.MethodGet && r.URL.Path == h.redirectPath && q.Get("error") != "": h.responseCh <- h.handleErrorResponse(w, r) - case r.Method == "POST" && r.URL.Path == h.redirectPath: + case r.Method == http.MethodPost && r.URL.Path == h.redirectPath: h.responseCh <- h.handleTokenResponse(w, r) - case r.Method == "GET" && r.URL.Path == h.redirectPath: + case r.Method == http.MethodGet && r.URL.Path == h.redirectPath: h.handleRawTokenResponse(w, r) - case r.Method == "GET" && r.URL.Path == "/": + case r.Method == http.MethodGet && r.URL.Path == "/": h.handleIndex(w, r) default: http.NotFound(w, r) @@ -68,15 +68,16 @@ fetch("%s?" + content, {method: "POST"}) ` -func (h *localServerHandler) handleRawTokenResponse(w http.ResponseWriter, r *http.Request) { +func (h *localServerHandler) handleRawTokenResponse(w http.ResponseWriter, _ *http.Request) { w.Header().Add("Content-Type", "text/html") + if _, err := fmt.Fprintf(w, JSPoster, h.redirectPath); err != nil { http.Error(w, "server error", http.StatusInternalServerError) } } func (h *localServerHandler) handleIndex(w http.ResponseWriter, r *http.Request) { url := h.redirectURL() - http.Redirect(w, r, url, 302) + http.Redirect(w, r, url, http.StatusFound) } func (h *localServerHandler) handleTokenResponse(w http.ResponseWriter, r *http.Request) *AuthorizationResponse { @@ -96,6 +97,7 @@ func (h *localServerHandler) handleTokenResponse(w http.ResponseWriter, r *http. http.Error(w, "server error", 500) return &AuthorizationResponse{err: xerrors.Errorf("access_token missing in authentication response when requesting token")} } + if token.TokenType = vals.Get("token_type"); token.TokenType == "" { http.Error(w, "server error", 500) return &AuthorizationResponse{err: xerrors.Errorf("token_type missing in authentication response when requesting token")} @@ -108,21 +110,26 @@ func (h *localServerHandler) handleTokenResponse(w http.ResponseWriter, r *http. http.Error(w, "server error", 500) return &AuthorizationResponse{err: xerrors.Errorf("id_token missing in authentication response when requesting id_token")} } + vals.Set("id_token", idToken) } e := vals.Get("expires_in") expires, _ := strconv.Atoi(e) + if expires != 0 { token.Expiry = Now().Add(time.Duration(expires) * time.Second) } + token = token.WithExtra(vals) w.Header().Add("Content-Type", "text/html") + if _, err := fmt.Fprintf(w, h.config.LocalServerSuccessHTML); err != nil { http.Error(w, "server error", 500) return &AuthorizationResponse{err: xerrors.Errorf("error while writing response body: %w", err)} } + return &AuthorizationResponse{token: token, nonce: h.nonce} } @@ -131,6 +138,7 @@ func (h *localServerHandler) handleErrorResponse(w http.ResponseWriter, r *http. errorCode, errorDescription := q.Get("error"), q.Get("error_description") http.Error(w, "authorization error", 500) + return &AuthorizationResponse{err: xerrors.Errorf("authorization error from server: %s %s", errorCode, errorDescription)} } @@ -140,6 +148,7 @@ func (h *localServerHandler) hasResponse(seek string) bool { return true } } + return false } @@ -163,9 +172,11 @@ func (h *localServerHandler) redirectURL() string { "response_type": {strings.Join(h.responseTypes, " ")}, "client_id": {c.ClientID}, } + if c.RedirectURL != "" { v.Set("redirect_uri", c.RedirectURL) } + if len(c.Scopes) > 0 { v.Set("scope", strings.Join(c.Scopes, " ")) } @@ -175,12 +186,15 @@ func (h *localServerHandler) redirectURL() string { if h.nonce != "" { v.Set("nonce", h.nonce) } + if strings.Contains(c.AuthURL, "?") { buf.WriteByte('&') } else { buf.WriteByte('?') } + buf.WriteString(v.Encode()) + return buf.String() } @@ -189,19 +203,23 @@ func ReceiveTokenViaLocalServer(ctx context.Context, c *types.ServerConfig, resp if err != nil { return nil, "", xerrors.Errorf("error while state parameter generation: %w", err) } + nonce, err = shared.NewOAuth2State() if err != nil { return nil, "", xerrors.Errorf("error while nonce parameter generation: %w", err) } + l, err := listener.New(shared.ExpandAddresses(c.LocalServerAddress, c.LocalServerPort)) if err != nil { return nil, "", xerrors.Errorf("error while starting a local server: %w", err) } + defer l.Close() if c.LocalServerCertFile == "" || c.LocalServerKeyFile == "" { return nil, "", xerrors.Errorf("LocalServerCertFile and LocalServerKeyFile must be set when using implicit flow") } + var redirectPath = "implicit" l.URL.Scheme = "https" @@ -236,8 +254,12 @@ func ReceiveTokenViaLocalServer(ctx context.Context, c *types.ServerConfig, resp responseTypes: responseTypes, }), } - var resp *AuthorizationResponse - var eg errgroup.Group + + var ( + resp *AuthorizationResponse + eg errgroup.Group + ) + eg.Go(func() error { for { select { @@ -280,9 +302,12 @@ func ReceiveTokenViaLocalServer(ctx context.Context, c *types.ServerConfig, resp if err := eg.Wait(); err != nil { return nil, "", xerrors.Errorf("error while authorization: %w", err) } + if resp == nil { return nil, "", xerrors.New("no authorization response") } + ctx.Done() + return resp.token, resp.nonce, resp.err } diff --git a/internal/implicit/implicit_test.go b/internal/implicit/implicit_test.go index 992f25d..eaca9ef 100644 --- a/internal/implicit/implicit_test.go +++ b/internal/implicit/implicit_test.go @@ -15,6 +15,7 @@ func Test_localServerHandler_redirectURL(t *testing.T) { state string responseTypes []string } + tests := []struct { name string fields fields @@ -43,7 +44,9 @@ func Test_localServerHandler_redirectURL(t *testing.T) { }, }, } + for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { h := &localServerHandler{ config: tt.fields.config, diff --git a/internal/shared.go b/internal/shared.go index 390584f..1a0878e 100644 --- a/internal/shared.go +++ b/internal/shared.go @@ -14,6 +14,7 @@ func ExpandAddresses(address string, ports []int) (addresses []string) { for _, port := range ports { addresses = append(addresses, fmt.Sprintf("%s:%d", address, port)) } + return } @@ -23,6 +24,7 @@ func NewOAuth2State() (string, error) { if err := binary.Read(rand.Reader, binary.LittleEndian, &n); err != nil { return "", xerrors.Errorf("error while reading random: %w", err) } + return fmt.Sprintf("%x", n), nil } diff --git a/internal/shared_test.go b/internal/shared_test.go index 9e0583e..11b1300 100644 --- a/internal/shared_test.go +++ b/internal/shared_test.go @@ -7,9 +7,7 @@ import ( ) func TestDefaultMiddleware(t *testing.T) { - t.Run("same handler is returned", func(t *testing.T) { - if got := DefaultMiddleware(http.DefaultServeMux); !reflect.DeepEqual(got, http.DefaultServeMux) { t.Errorf("DefaultMiddleware() = %v, want %v", got, http.DefaultServeMux) } @@ -17,9 +15,7 @@ func TestDefaultMiddleware(t *testing.T) { } func TestNewOAuth2State(t *testing.T) { - t.Run("different results are returned", func(t *testing.T) { - s1, err := NewOAuth2State() if err != nil { t.Errorf("unexpected error calling NewOAuth2State(): %v", err) @@ -40,6 +36,7 @@ func TestExpandAddresses(t *testing.T) { address string ports []int } + tests := []struct { name string args args @@ -49,6 +46,7 @@ func TestExpandAddresses(t *testing.T) { {"multiple ports port", args{"0.0.0.0", []int{80, 8080}}, []string{"0.0.0.0:80", "0.0.0.0:8080"}}, } for _, tt := range tests { + tt := tt t.Run(tt.name, func(t *testing.T) { if gotAddresses := ExpandAddresses(tt.args.address, tt.args.ports); !reflect.DeepEqual(gotAddresses, tt.wantAddresses) { t.Errorf("ExpandAddresses() = %v, want %v", gotAddresses, tt.wantAddresses)