diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d2b246..f0eef71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added +- `rootly login` — browser-based OAuth2 authentication with PKCE (no API key needed) +- `rootly logout` — clear stored OAuth tokens +- OAuth2 auto-refresh transport using `golang.org/x/oauth2` +- Auto-append `/api` for localhost endpoints (no need to pass `--api-host=localhost:22166/api`) +- `http://` scheme auto-detection for localhost/127.0.0.1 endpoints + ### Changed +- OAuth tokens stored in `~/.rootly-cli/config.yaml` under `oauth` key (single config file) +- API client uses OAuth Bearer tokens when available, falls back to API key +- Auth-exempt commands use `Annotations["skipAuth"]` instead of hardcoded name list - Switch `oncall who` and `oncall shifts` to unified `/v1/oncalls` endpoint with richer data (escalation policy, level, user email) - Add new filter flags: `--schedule-id`, `--service-id`, `--escalation-policy-id`, `--user-id`, `--time-zone`, `--earliest` - Table output now includes Escalation Policy, Level, and Email columns @@ -15,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Fix `oncall schedules` 404 error (use correct `/v1/schedules` endpoint) +- Windows test compatibility (`USERPROFILE` alongside `HOME`) ### Removed - Removed legacy `/v1/shifts` endpoint usage and associated `Shift`/`ShiftsResult` types @@ -54,7 +65,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Homebrew tap distribution - GitHub Actions CI (lint, test, build) and release workflows -[Unreleased]: https://github.com/rootlyhq/rootly-cli/compare/v0.1.4...HEAD +[Unreleased]: https://github.com/rootlyhq/rootly-cli/compare/v0.1.5...HEAD +[0.1.5]: https://github.com/rootlyhq/rootly-cli/compare/v0.1.4...v0.1.5 [0.1.4]: https://github.com/rootlyhq/rootly-cli/compare/v0.1.3...v0.1.4 [0.1.3]: https://github.com/rootlyhq/rootly-cli/compare/v0.1.2...v0.1.3 [0.1.2]: https://github.com/rootlyhq/rootly-cli/releases/tag/v0.1.2 diff --git a/go.mod b/go.mod index 4d2cc6c..6c95a06 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/rootlyhq/rootly-go v0.8.0 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 + golang.org/x/oauth2 v0.28.0 gopkg.in/yaml.v3 v3.0.1 ) diff --git a/go.sum b/go.sum index fc77e22..2e99bbf 100644 --- a/go.sum +++ b/go.sum @@ -92,6 +92,8 @@ github.com/woodsbury/decimal128 v1.4.0 h1:xJATj7lLu4f2oObouMt2tgGiElE5gO6mSWUjQs github.com/woodsbury/decimal128 v1.4.0/go.mod h1:BP46FUrVjVhdTbKT+XuQh2xfQaGki9LMIRJSFuh6THU= go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= +golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/internal/api/client.go b/internal/api/client.go index 46be3f9..1d6d838 100644 --- a/internal/api/client.go +++ b/internal/api/client.go @@ -14,6 +14,7 @@ import ( rootly "github.com/rootlyhq/rootly-go" "github.com/rootlyhq/rootly-cli/internal/config" + "github.com/rootlyhq/rootly-cli/internal/oauth" ) // Version is set by the main package to include in User-Agent @@ -733,30 +734,60 @@ func parseIncidentDetailRelationships(incident *Incident, a *incidentDetailAttri // NewClient creates a stateless API client for CLI usage. func NewClient(cfg *config.Config) (*Client, error) { endpoint := cfg.Endpoint - if endpoint != "" && !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") { - endpoint = "https://" + endpoint + if endpoint != "" { + endpoint = ensureScheme(endpoint) } - // Build HTTP client with optional debug transport - httpClient := http.DefaultClient - if cfg.Debug { - transport := http.DefaultTransport - if httpClient.Transport != nil { - transport = httpClient.Transport + // Determine auth: use OAuth tokens only when no API key is set + useOAuth := false + if cfg.APIKey == "" { + if tokens, err := oauth.LoadTokens(); err == nil && tokens.AccessToken != "" { + useOAuth = true } - httpClient = &http.Client{ - Transport: &debugTransport{transport: transport}, + } + + // Build base transport + var transport http.RoundTripper + transport = http.DefaultTransport + + if cfg.Debug { + transport = &debugTransport{transport: transport} + } + + var httpClient *http.Client + if useOAuth { + authBaseURL := oauth.DeriveAuthBaseURL(cfg.Endpoint) + clientID := oauth.LoadCachedClientID() + oauthCfg := oauth.NewConfig(authBaseURL, clientID) + var err error + httpClient, err = oauth.NewHTTPClient(oauthCfg, transport, "rootly-cli/"+Version) + if err != nil { + return nil, fmt.Errorf("failed to create OAuth client: %w", err) } + } else { + httpClient = &http.Client{Transport: transport} } - client, err := rootly.NewClientWithResponses(endpoint, - rootly.WithHTTPClient(httpClient), - rootly.WithRequestEditorFn(func(ctx context.Context, req *http.Request) error { + var reqEditorFn rootly.RequestEditorFn + if useOAuth { + // OAuth transport handles Authorization header + reqEditorFn = func(ctx context.Context, req *http.Request) error { + req.Header.Set("Content-Type", "application/vnd.api+json") + req.Header.Set("User-Agent", "rootly-cli/"+Version) + return nil + } + } else { + reqEditorFn = func(ctx context.Context, req *http.Request) error { req.Header.Set("Authorization", "Bearer "+cfg.APIKey) req.Header.Set("Content-Type", "application/vnd.api+json") req.Header.Set("User-Agent", "rootly-cli/"+Version) return nil - }), + } + } + + client, err := rootly.NewClientWithResponses(endpoint, + rootly.WithHTTPClient(httpClient), + rootly.WithRequestEditorFn(reqEditorFn), ) if err != nil { return nil, fmt.Errorf("failed to create rootly client: %w", err) @@ -764,12 +795,30 @@ func NewClient(cfg *config.Config) (*Client, error) { return &Client{ client: client, - endpoint: cfg.Endpoint, + endpoint: endpoint, apiKey: cfg.APIKey, httpClient: httpClient, }, nil } +// ensureScheme adds a scheme if missing, using http for localhost/127.0.0.1. +// For localhost without an explicit path, it also appends /api since the +// Rails monolith serves the API under /api/v1 rather than /v1. +func ensureScheme(endpoint string) string { + if strings.HasPrefix(endpoint, "http://") || strings.HasPrefix(endpoint, "https://") { + return endpoint + } + if strings.HasPrefix(endpoint, "localhost") || strings.HasPrefix(endpoint, "127.0.0.1") { + result := "http://" + endpoint + // Auto-append /api for localhost if no path is present + if !strings.Contains(endpoint, "/") { + result += "/api" + } + return result + } + return "https://" + endpoint +} + func (c *Client) ValidateAPIKey(ctx context.Context) error { // Use /v1/users/me endpoint to validate the API key resp, err := c.client.GetCurrentUserWithResponse(ctx) @@ -865,9 +914,6 @@ func (c *Client) ListIncidentsCLI(ctx context.Context, page, pageSize int, sort // Build URL with query parameters baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } url := fmt.Sprintf("%s/v1/incidents?page[number]=%d&page[size]=%d", baseURL, page, pageSize) if sort != "" { @@ -967,9 +1013,7 @@ func (c *Client) ListIncidentsCLI(ctx context.Context, page, pageSize int, sort func (c *Client) GetIncidentByID(ctx context.Context, id string) (*Incident, error) { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/incidents/%s?include=roles,causes,incident_types,functionalities,services,environments,groups,user", baseURL, id) req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) if err != nil { @@ -1042,9 +1086,7 @@ func (c *Client) CreateIncident(ctx context.Context, title string, opts map[stri // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/incidents", baseURL) req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyBytes))) @@ -1120,9 +1162,7 @@ func (c *Client) UpdateIncident(ctx context.Context, id string, opts map[string] // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/incidents/%s", baseURL, id) req, err := http.NewRequestWithContext(ctx, "PUT", url, strings.NewReader(string(bodyBytes))) @@ -1173,9 +1213,7 @@ func (c *Client) UpdateIncident(ctx context.Context, id string, opts map[string] func (c *Client) DeleteIncident(ctx context.Context, id string) error { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/incidents/%s", baseURL, id) req, err := http.NewRequestWithContext(ctx, "DELETE", url, http.NoBody) @@ -1292,9 +1330,6 @@ func (c *Client) ListAlertsCLI(ctx context.Context, page, pageSize int, sort str // Build URL with query parameters baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } url := fmt.Sprintf("%s/v1/alerts?page[number]=%d&page[size]=%d", baseURL, page, pageSize) if sort != "" { @@ -1394,9 +1429,7 @@ func (c *Client) ListAlertsCLI(ctx context.Context, page, pageSize int, sort str func (c *Client) GetAlertByID(ctx context.Context, id string) (*Alert, error) { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/alerts/%s?include=services,environments,groups,responders,alert_urgency", baseURL, id) req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) if err != nil { @@ -1632,9 +1665,7 @@ func (c *Client) CreateAlertCLI(ctx context.Context, summary string, opts map[st // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/alerts", baseURL) req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyBytes))) @@ -1713,9 +1744,7 @@ func (c *Client) UpdateAlertCLI(ctx context.Context, id string, opts map[string] // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/alerts/%s", baseURL, id) req, err := http.NewRequestWithContext(ctx, "PUT", url, strings.NewReader(string(bodyBytes))) @@ -1766,9 +1795,7 @@ func (c *Client) UpdateAlertCLI(ctx context.Context, id string, opts map[string] func (c *Client) AcknowledgeAlertCLI(ctx context.Context, id string) error { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/alerts/%s/acknowledge", baseURL, id) req, err := http.NewRequestWithContext(ctx, "POST", url, http.NoBody) @@ -1804,9 +1831,7 @@ func (c *Client) AcknowledgeAlertCLI(ctx context.Context, id string) error { func (c *Client) ResolveAlertCLI(ctx context.Context, id, resolutionMessage string, resolveIncidents bool) error { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/alerts/%s/resolve", baseURL, id) var reqBody io.Reader = http.NoBody @@ -1876,9 +1901,6 @@ func (c *Client) ListServicesCLI(ctx context.Context, page, pageSize int, sort s // Build URL with query parameters baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } url := fmt.Sprintf("%s/v1/services?page[number]=%d&page[size]=%d", baseURL, page, pageSize) if sort != "" { @@ -1981,9 +2003,7 @@ func (c *Client) ListServicesCLI(ctx context.Context, page, pageSize int, sort s func (c *Client) GetServiceByID(ctx context.Context, id string) (*Service, error) { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/services/%s?include=owner_group", baseURL, id) req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) if err != nil { @@ -2107,9 +2127,7 @@ func (c *Client) CreateService(ctx context.Context, name string, opts map[string // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/services", baseURL) req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyBytes))) @@ -2206,9 +2224,7 @@ func (c *Client) UpdateService(ctx context.Context, id string, opts map[string]s // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/services/%s", baseURL, id) req, err := http.NewRequestWithContext(ctx, "PUT", url, strings.NewReader(string(bodyBytes))) @@ -2283,9 +2299,7 @@ func (c *Client) UpdateService(ctx context.Context, id string, opts map[string]s func (c *Client) DeleteService(ctx context.Context, id string) error { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/services/%s", baseURL, id) req, err := http.NewRequestWithContext(ctx, "DELETE", url, http.NoBody) @@ -2329,9 +2343,6 @@ func (c *Client) ListTeamsCLI(ctx context.Context, page, pageSize int, sort stri // Build URL with query parameters baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } url := fmt.Sprintf("%s/v1/teams?page[number]=%d&page[size]=%d", baseURL, page, pageSize) if sort != "" { @@ -2436,9 +2447,7 @@ func (c *Client) ListTeamsCLI(ctx context.Context, page, pageSize int, sort stri func (c *Client) GetTeamByID(ctx context.Context, id string) (*Team, error) { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/teams/%s?include=users", baseURL, id) req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody) if err != nil { @@ -2563,9 +2572,7 @@ func (c *Client) CreateTeam(ctx context.Context, name string, opts map[string]st // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/teams", baseURL) req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyBytes))) @@ -2655,9 +2662,7 @@ func (c *Client) UpdateTeam(ctx context.Context, id string, opts map[string]stri // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/teams/%s", baseURL, id) req, err := http.NewRequestWithContext(ctx, "PATCH", url, strings.NewReader(string(bodyBytes))) @@ -2736,9 +2741,7 @@ func (c *Client) UpdateTeam(ctx context.Context, id string, opts map[string]stri func (c *Client) DeleteTeam(ctx context.Context, id string) error { // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/teams/%s", baseURL, id) req, err := http.NewRequestWithContext(ctx, "DELETE", url, http.NoBody) @@ -2838,9 +2841,6 @@ func (c *Client) ListSchedulesCLI(ctx context.Context, page, pageSize int, filte // Build URL with query parameters baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } url := fmt.Sprintf("%s/v1/schedules?page[number]=%d&page[size]=%d", baseURL, page, pageSize) @@ -2929,9 +2929,6 @@ func (c *Client) ListSchedulesCLI(ctx context.Context, page, pageSize int, filte // ListOnCallsCLI lists on-call entries using the unified /v1/oncalls endpoint. func (c *Client) ListOnCallsCLI(ctx context.Context, params OnCallsParams) (*OnCallsResult, error) { baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } url := fmt.Sprintf("%s/v1/oncalls?", baseURL) @@ -3126,9 +3123,7 @@ func (c *Client) CreatePulseCLI(ctx context.Context, summary string, opts PulseO // Build URL baseURL := c.endpoint - if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { - baseURL = "https://" + baseURL - } + url := fmt.Sprintf("%s/v1/pulses", baseURL) req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyBytes))) diff --git a/internal/cmd/alerts/alerts.go b/internal/cmd/alerts/alerts.go index 885703a..faa1e3d 100644 --- a/internal/cmd/alerts/alerts.go +++ b/internal/cmd/alerts/alerts.go @@ -8,6 +8,7 @@ import ( "github.com/rootlyhq/rootly-cli/internal/api" "github.com/rootlyhq/rootly-cli/internal/config" + "github.com/rootlyhq/rootly-cli/internal/oauth" ) // AlertsCmd is the parent command for all alert operations @@ -40,7 +41,9 @@ var AlertsCmd = &cobra.Command{ func getAPIClient() (*api.Client, error) { token := viper.GetString("api_key") if token == "" { - return nil, fmt.Errorf("API key required: set ROOTLY_API_KEY or add api_key to ~/.rootly-cli/config.yaml") + if !oauth.HasTokens() { + return nil, fmt.Errorf("authentication required: run 'rootly login' or set ROOTLY_API_KEY") + } } endpoint := viper.GetString("api_host") if endpoint == "" { diff --git a/internal/cmd/alerts/cmd_test.go b/internal/cmd/alerts/cmd_test.go index 31e0ea5..011f5aa 100644 --- a/internal/cmd/alerts/cmd_test.go +++ b/internal/cmd/alerts/cmd_test.go @@ -155,6 +155,9 @@ func TestRunListJSON(t *testing.T) { func TestRunListNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().Int("page", 1, "") @@ -167,8 +170,8 @@ func TestRunListNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } @@ -360,6 +363,9 @@ func TestRunAckSuccess(t *testing.T) { func TestRunAckNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() @@ -367,8 +373,8 @@ func TestRunAckNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } @@ -464,6 +470,9 @@ func TestRunResolveWithIncidents(t *testing.T) { func TestRunResolveNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().String("message", "", "") @@ -473,8 +482,8 @@ func TestRunResolveNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } diff --git a/internal/cmd/auth/cmd_test.go b/internal/cmd/auth/cmd_test.go new file mode 100644 index 0000000..3f748e3 --- /dev/null +++ b/internal/cmd/auth/cmd_test.go @@ -0,0 +1,55 @@ +package auth + +import ( + "bytes" + "os" + "testing" + "time" + + "github.com/rootlyhq/rootly-cli/internal/oauth" +) + +func TestLogoutCmd(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + // Save tokens first + tokens := &oauth.TokenData{ + AccessToken: "test", + RefreshToken: "test", + ExpiresAt: time.Now().Add(time.Hour), + } + oauth.SaveTokens(tokens) + + // Run logout + var buf bytes.Buffer + LogoutCmd.SetOut(&buf) + LogoutCmd.SetErr(&buf) + LogoutCmd.SetArgs([]string{}) + + if err := LogoutCmd.Execute(); err != nil { + t.Fatalf("logout failed: %v", err) + } + + // Verify tokens cleared + if _, err := oauth.LoadTokens(); !os.IsNotExist(err) { + t.Errorf("expected tokens to be cleared, got err: %v", err) + } +} + +func TestLogoutCmd_NoTokens(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + var buf bytes.Buffer + LogoutCmd.SetOut(&buf) + LogoutCmd.SetErr(&buf) + LogoutCmd.SetArgs([]string{}) + + // Should not error when no tokens exist + if err := LogoutCmd.Execute(); err != nil { + t.Fatalf("logout with no tokens should not fail: %v", err) + } +} diff --git a/internal/cmd/auth/login.go b/internal/cmd/auth/login.go new file mode 100644 index 0000000..1e61000 --- /dev/null +++ b/internal/cmd/auth/login.go @@ -0,0 +1,230 @@ +package auth + +import ( + "context" + "errors" + "fmt" + "html" + "net" + "net/http" + "os/exec" + "runtime" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "golang.org/x/oauth2" + + "github.com/rootlyhq/rootly-cli/internal/config" + xoauth "github.com/rootlyhq/rootly-cli/internal/oauth" +) + +var LoginCmd = &cobra.Command{ + Use: "login", + Short: "Authenticate with Rootly via browser-based OAuth2", + Long: `Opens your browser to authenticate with Rootly using OAuth2 Authorization Code + PKCE. + +No configuration is needed — just run "rootly login" and follow the browser prompts. +By default, connects to api.rootly.com. Use --api-host to target a different environment.`, + Example: ` # Login to Rootly (production) + rootly login + + # Login to a local dev server + rootly login --api-host=localhost:22166`, + Annotations: map[string]string{"skipAuth": "true"}, + RunE: runLogin, +} + +func init() { + LoginCmd.Flags().String("client-id", "", "Override OAuth2 client ID (for debugging)") + _ = LoginCmd.Flags().MarkHidden("client-id") +} + +// resolveClientID returns a client ID, registering dynamically if needed. +func resolveClientID(ctx context.Context, authBaseURL string, cmd *cobra.Command) (string, error) { + // Allow explicit override for debugging + if clientID, _ := cmd.Flags().GetString("client-id"); clientID != "" { + return clientID, nil + } + + // Check cached client_id + if clientID := xoauth.LoadCachedClientID(); clientID != "" { + return clientID, nil + } + + // Register dynamically + return registerAndCache(ctx, authBaseURL, cmd) +} + +// registerAndCache calls POST /oauth/register and saves the client_id. +func registerAndCache(ctx context.Context, authBaseURL string, cmd *cobra.Command) (string, error) { + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Registering OAuth client...\n") + clientID, err := xoauth.RegisterClient(ctx, authBaseURL) + if err != nil { + return "", err + } + if err := xoauth.SaveClientID(clientID); err != nil { + return "", fmt.Errorf("failed to cache client ID: %w", err) + } + return clientID, nil +} + +// httpClientWithTimeout is used for pre-flight checks (HEAD) and registration. +var httpClientWithTimeout = &http.Client{Timeout: 10 * time.Second} + +func runLogin(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + apiHost := viper.GetString("api_host") + if apiHost == "" { + apiHost = config.DefaultEndpoint + } + + authBaseURL := xoauth.DeriveAuthBaseURL(apiHost) + + clientID, err := resolveClientID(ctx, authBaseURL, cmd) + if err != nil { + return err + } + + tok, err := doOAuthFlow(ctx, cmd, authBaseURL, clientID) + if err != nil { + // If authorize returned 404, the cached client_id may be stale — re-register once + if isAuthorize404(err) { + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "OAuth client not found, re-registering...\n") + _ = xoauth.ClearClientID() + clientID, regErr := registerAndCache(ctx, authBaseURL, cmd) + if regErr != nil { + return regErr + } + tok, err = doOAuthFlow(ctx, cmd, authBaseURL, clientID) + if err != nil { + return err + } + } else { + return err + } + } + + if err := xoauth.SaveOAuth2Token(tok); err != nil { + return fmt.Errorf("failed to save tokens: %w", err) + } + + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Login successful! Tokens saved.\n") + return nil +} + +// errAuthorize404 is returned when the authorize endpoint returns 404. +var errAuthorize404 = errors.New("authorization endpoint returned 404") + +func isAuthorize404(err error) bool { + return errors.Is(err, errAuthorize404) +} + +func doOAuthFlow(ctx context.Context, cmd *cobra.Command, authBaseURL, clientID string) (*oauth2.Token, error) { + cfg := xoauth.NewConfig(authBaseURL, clientID) + + verifier := oauth2.GenerateVerifier() + + state, err := xoauth.GenerateState() + if err != nil { + return nil, fmt.Errorf("failed to generate state: %w", err) + } + + // Build authorization URL with PKCE (S256 challenge derived from verifier) + authURL := cfg.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier)) + + // Pre-check: HEAD the authorize URL to detect 404 (stale client_id) + // Done before starting the callback server to avoid binding a port unnecessarily. + headReq, headReqErr := http.NewRequestWithContext(ctx, http.MethodHead, authURL, http.NoBody) + if headReqErr == nil { + if headResp, headErr := httpClientWithTimeout.Do(headReq); headErr == nil { + _ = headResp.Body.Close() + if headResp.StatusCode == http.StatusNotFound { + return nil, errAuthorize404 + } + } + } + + // Channel to receive the authorization code + codeCh := make(chan string, 1) + errCh := make(chan error, 1) + + // Start local callback server + mux := http.NewServeMux() + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + errCh <- fmt.Errorf("state mismatch") + http.Error(w, "State mismatch", http.StatusBadRequest) + return + } + if errMsg := r.URL.Query().Get("error"); errMsg != "" { + desc := r.URL.Query().Get("error_description") + errCh <- fmt.Errorf("authorization error: %s — %s", errMsg, desc) + _, _ = fmt.Fprintf(w, "

Authorization Failed

%s

You can close this window.

", html.EscapeString(desc)) + return + } + code := r.URL.Query().Get("code") + if code == "" { + errCh <- fmt.Errorf("no code in callback") + http.Error(w, "Missing code", http.StatusBadRequest) + return + } + codeCh <- code + _, _ = fmt.Fprint(w, "

Login Successful!

You can close this window and return to the terminal.

") + }) + + lc := net.ListenConfig{} + listener, err := lc.Listen(ctx, "tcp", "localhost:"+xoauth.CallbackPort) + if err != nil { + return nil, fmt.Errorf("failed to start callback server on port %s: %w", xoauth.CallbackPort, err) + } + + server := &http.Server{Handler: mux} + go func() { _ = server.Serve(listener) }() + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + }() + + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Opening browser for authentication...\n") + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "If the browser doesn't open, visit:\n%s\n\n", authURL) + + if err := openBrowser(ctx, authURL); err != nil { + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to open browser: %v\n", err) + } + + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Waiting for authorization...\n") + + // Wait for callback + var code string + select { + case code = <-codeCh: + case err := <-errCh: + return nil, err + case <-time.After(5 * time.Minute): + return nil, fmt.Errorf("login timed out after 5 minutes") + } + + // Exchange code for tokens + tok, err := xoauth.ExchangeCode(ctx, cfg, code, verifier) + if err != nil { + return nil, err + } + + return tok, nil +} + +func openBrowser(ctx context.Context, url string) error { + switch runtime.GOOS { + case "darwin": + return exec.CommandContext(ctx, "open", url).Start() + case "linux": + return exec.CommandContext(ctx, "xdg-open", url).Start() + case "windows": + return exec.CommandContext(ctx, "rundll32", "url.dll,FileProtocolHandler", url).Start() + default: + return fmt.Errorf("unsupported platform") + } +} diff --git a/internal/cmd/auth/logout.go b/internal/cmd/auth/logout.go new file mode 100644 index 0000000..368d110 --- /dev/null +++ b/internal/cmd/auth/logout.go @@ -0,0 +1,22 @@ +package auth + +import ( + "fmt" + + "github.com/spf13/cobra" + + "github.com/rootlyhq/rootly-cli/internal/oauth" +) + +var LogoutCmd = &cobra.Command{ + Use: "logout", + Short: "Clear stored OAuth2 tokens", + Annotations: map[string]string{"skipAuth": "true"}, + RunE: func(cmd *cobra.Command, args []string) error { + if err := oauth.ClearTokens(); err != nil { + return fmt.Errorf("failed to clear tokens: %w", err) + } + _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Logged out. OAuth tokens cleared.\n") + return nil + }, +} diff --git a/internal/cmd/auth_register.go b/internal/cmd/auth_register.go new file mode 100644 index 0000000..74d0b5f --- /dev/null +++ b/internal/cmd/auth_register.go @@ -0,0 +1,8 @@ +package cmd + +import "github.com/rootlyhq/rootly-cli/internal/cmd/auth" + +func init() { + rootCmd.AddCommand(auth.LoginCmd) + rootCmd.AddCommand(auth.LogoutCmd) +} diff --git a/internal/cmd/completion.go b/internal/cmd/completion.go index 6b4949f..9b80b2c 100644 --- a/internal/cmd/completion.go +++ b/internal/cmd/completion.go @@ -28,6 +28,7 @@ PowerShell: DisableFlagsInUseLine: true, ValidArgs: []string{"bash", "zsh", "fish", "powershell"}, Args: cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs), + Annotations: map[string]string{"skipAuth": "true"}, RunE: func(cmd *cobra.Command, args []string) error { switch args[0] { case "bash": diff --git a/internal/cmd/incidents/cmd_test.go b/internal/cmd/incidents/cmd_test.go index 3f2e3a1..b1edbc7 100644 --- a/internal/cmd/incidents/cmd_test.go +++ b/internal/cmd/incidents/cmd_test.go @@ -383,6 +383,9 @@ func TestRunUpdateNoFlags(t *testing.T) { func TestRunListNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().Int("page", 1, "") @@ -395,8 +398,8 @@ func TestRunListNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } @@ -484,6 +487,9 @@ func TestRunDeleteNoConfirmNonInteractive(t *testing.T) { func TestRunDeleteNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().BoolP("yes", "y", false, "") @@ -493,8 +499,8 @@ func TestRunDeleteNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } diff --git a/internal/cmd/incidents/incidents.go b/internal/cmd/incidents/incidents.go index 87bc892..b6ecf23 100644 --- a/internal/cmd/incidents/incidents.go +++ b/internal/cmd/incidents/incidents.go @@ -8,6 +8,7 @@ import ( "github.com/rootlyhq/rootly-cli/internal/api" "github.com/rootlyhq/rootly-cli/internal/config" + "github.com/rootlyhq/rootly-cli/internal/oauth" ) // IncidentsCmd is the parent command for all incident operations @@ -37,7 +38,9 @@ var IncidentsCmd = &cobra.Command{ func getAPIClient() (*api.Client, error) { token := viper.GetString("api_key") if token == "" { - return nil, fmt.Errorf("API key required: set ROOTLY_API_KEY or add api_key to ~/.rootly-cli/config.yaml") + if !oauth.HasTokens() { + return nil, fmt.Errorf("authentication required: run 'rootly login' or set ROOTLY_API_KEY") + } } endpoint := viper.GetString("api_host") if endpoint == "" { diff --git a/internal/cmd/oncall/cmd_test.go b/internal/cmd/oncall/cmd_test.go index 6167c8e..6a1a6fd 100644 --- a/internal/cmd/oncall/cmd_test.go +++ b/internal/cmd/oncall/cmd_test.go @@ -202,6 +202,9 @@ func TestRunListPagination(t *testing.T) { func TestRunListNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().Int("page", 1, "") @@ -211,8 +214,8 @@ func TestRunListNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } @@ -564,6 +567,9 @@ func TestRunWhoWithScheduleFilter(t *testing.T) { func TestRunShiftsNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().Int("days", 7, "") @@ -578,14 +584,17 @@ func TestRunShiftsNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } func TestRunWhoNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().String("schedule-id", "", "") diff --git a/internal/cmd/oncall/oncall.go b/internal/cmd/oncall/oncall.go index 418d885..3587346 100644 --- a/internal/cmd/oncall/oncall.go +++ b/internal/cmd/oncall/oncall.go @@ -8,6 +8,7 @@ import ( "github.com/rootlyhq/rootly-cli/internal/api" "github.com/rootlyhq/rootly-cli/internal/config" + "github.com/rootlyhq/rootly-cli/internal/oauth" ) // OncallCmd is the parent command for all on-call operations @@ -40,7 +41,9 @@ Note: Schedules are managed in the Rootly UI. This command provides read-only ac func getAPIClient() (*api.Client, error) { token := viper.GetString("api_key") if token == "" { - return nil, fmt.Errorf("API key required: set ROOTLY_API_KEY or add api_key to ~/.rootly-cli/config.yaml") + if !oauth.HasTokens() { + return nil, fmt.Errorf("authentication required: run 'rootly login' or set ROOTLY_API_KEY") + } } endpoint := viper.GetString("api_host") if endpoint == "" { diff --git a/internal/cmd/pulse/cmd_test.go b/internal/cmd/pulse/cmd_test.go index 308d89d..f381b88 100644 --- a/internal/cmd/pulse/cmd_test.go +++ b/internal/cmd/pulse/cmd_test.go @@ -192,6 +192,9 @@ func TestRunCreateNoSummary(t *testing.T) { func TestRunCreateNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().StringP("labels", "l", "", "") @@ -205,8 +208,8 @@ func TestRunCreateNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API key") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } @@ -498,6 +501,9 @@ func TestRunRunNoArgs(t *testing.T) { func TestRunRunNoAPIKey(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().String("summary", "", "") @@ -511,8 +517,8 @@ func TestRunRunNoAPIKey(t *testing.T) { if err == nil { t.Fatal("expected error when no API key") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } diff --git a/internal/cmd/pulse/pulse.go b/internal/cmd/pulse/pulse.go index 6a1dfd9..2ad104f 100644 --- a/internal/cmd/pulse/pulse.go +++ b/internal/cmd/pulse/pulse.go @@ -8,6 +8,7 @@ import ( "github.com/rootlyhq/rootly-cli/internal/api" "github.com/rootlyhq/rootly-cli/internal/config" + "github.com/rootlyhq/rootly-cli/internal/oauth" ) // PulseCmd is the parent command for pulse operations @@ -34,7 +35,9 @@ var PulseCmd = &cobra.Command{ func getAPIClient() (*api.Client, error) { token := viper.GetString("api_key") if token == "" { - return nil, fmt.Errorf("API key required: set ROOTLY_API_KEY or add api_key to ~/.rootly-cli/config.yaml") + if !oauth.HasTokens() { + return nil, fmt.Errorf("authentication required: run 'rootly login' or set ROOTLY_API_KEY") + } } endpoint := viper.GetString("api_host") if endpoint == "" { diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 14a91b2..41fba8b 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -43,10 +43,21 @@ Or use a config file at ~/.rootly-cli/config.yaml: SilenceUsage: true, SilenceErrors: true, PersistentPreRunE: func(cmd *cobra.Command, args []string) error { - // Bind flags to viper + // Bind flags to viper (both local and inherited persistent flags) if err := viper.BindPFlags(cmd.Flags()); err != nil { return err } + if err := viper.BindPFlags(cmd.InheritedFlags()); err != nil { + return err + } + + // Explicitly bind hyphenated flags to underscore viper keys + if f := cmd.Flag("api-host"); f != nil && f.Changed { + viper.Set("api_host", f.Value.String()) + } + if f := cmd.Flag("api-key"); f != nil && f.Changed { + viper.Set("api_key", f.Value.String()) + } // Configure config file viper.SetConfigName("config") @@ -72,9 +83,8 @@ Or use a config file at ~/.rootly-cli/config.yaml: // Config file not found is OK - we'll use flags/env vars } - // Skip auth validation for commands that don't need it - cmdName := cmd.Name() - if cmdName == "version" || cmdName == "completion" || cmdName == "help" { + // Skip auth validation for commands that opt out + if cmd.Annotations["skipAuth"] == "true" { return nil } diff --git a/internal/cmd/services/cmd_test.go b/internal/cmd/services/cmd_test.go index d856f14..d2fe6ed 100644 --- a/internal/cmd/services/cmd_test.go +++ b/internal/cmd/services/cmd_test.go @@ -152,6 +152,9 @@ func TestRunListJSON(t *testing.T) { func TestRunListNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().Int("page", 1, "") @@ -164,8 +167,8 @@ func TestRunListNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } @@ -377,6 +380,9 @@ func TestRunDeleteNoConfirmNonInteractive(t *testing.T) { func TestRunDeleteNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().BoolP("yes", "y", false, "") @@ -386,8 +392,8 @@ func TestRunDeleteNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } diff --git a/internal/cmd/services/services.go b/internal/cmd/services/services.go index 4558e91..fbc7b6e 100644 --- a/internal/cmd/services/services.go +++ b/internal/cmd/services/services.go @@ -8,6 +8,7 @@ import ( "github.com/rootlyhq/rootly-cli/internal/api" "github.com/rootlyhq/rootly-cli/internal/config" + "github.com/rootlyhq/rootly-cli/internal/oauth" ) // ServicesCmd is the parent command for all service operations @@ -37,7 +38,9 @@ var ServicesCmd = &cobra.Command{ func getAPIClient() (*api.Client, error) { token := viper.GetString("api_key") if token == "" { - return nil, fmt.Errorf("API key required: set ROOTLY_API_KEY or add api_key to ~/.rootly-cli/config.yaml") + if !oauth.HasTokens() { + return nil, fmt.Errorf("authentication required: run 'rootly login' or set ROOTLY_API_KEY") + } } endpoint := viper.GetString("api_host") if endpoint == "" { diff --git a/internal/cmd/teams/cmd_test.go b/internal/cmd/teams/cmd_test.go index 56f5b0e..0440a2b 100644 --- a/internal/cmd/teams/cmd_test.go +++ b/internal/cmd/teams/cmd_test.go @@ -150,6 +150,9 @@ func TestRunListJSON(t *testing.T) { func TestRunListNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().Int("page", 1, "") @@ -161,8 +164,8 @@ func TestRunListNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } @@ -374,6 +377,9 @@ func TestRunDeleteNoConfirmNonInteractive(t *testing.T) { func TestRunDeleteNoToken(t *testing.T) { viper.Reset() defer viper.Reset() + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) cmd := newTestCmd() cmd.Flags().BoolP("yes", "y", false, "") @@ -383,8 +389,8 @@ func TestRunDeleteNoToken(t *testing.T) { if err == nil { t.Fatal("expected error when no API token") } - if !strings.Contains(err.Error(), "API key required") { - t.Errorf("expected 'API key required' error, got: %v", err) + if !strings.Contains(err.Error(), "authentication required") { + t.Errorf("expected 'authentication required' error, got: %v", err) } } diff --git a/internal/cmd/teams/teams.go b/internal/cmd/teams/teams.go index 5ea99fe..5956753 100644 --- a/internal/cmd/teams/teams.go +++ b/internal/cmd/teams/teams.go @@ -8,6 +8,7 @@ import ( "github.com/rootlyhq/rootly-cli/internal/api" "github.com/rootlyhq/rootly-cli/internal/config" + "github.com/rootlyhq/rootly-cli/internal/oauth" ) // TeamsCmd is the parent command for all team operations @@ -37,7 +38,9 @@ var TeamsCmd = &cobra.Command{ func getAPIClient() (*api.Client, error) { token := viper.GetString("api_key") if token == "" { - return nil, fmt.Errorf("API key required: set ROOTLY_API_KEY or add api_key to ~/.rootly-cli/config.yaml") + if !oauth.HasTokens() { + return nil, fmt.Errorf("authentication required: run 'rootly login' or set ROOTLY_API_KEY") + } } endpoint := viper.GetString("api_host") if endpoint == "" { diff --git a/internal/cmd/version.go b/internal/cmd/version.go index 8326279..ee5f56c 100644 --- a/internal/cmd/version.go +++ b/internal/cmd/version.go @@ -9,9 +9,10 @@ import ( ) var versionCmd = &cobra.Command{ - Use: "version", - Short: "Print version information", - Long: `Print the version, commit hash, and build date of the rootly CLI.`, + Use: "version", + Short: "Print version information", + Long: `Print the version, commit hash, and build date of the rootly CLI.`, + Annotations: map[string]string{"skipAuth": "true"}, RunE: func(cmd *cobra.Command, args []string) error { format := viper.GetString("format") diff --git a/internal/config/config.go b/internal/config/config.go index 303e9ab..a8bddac 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,12 +18,22 @@ const ( ) type Config struct { - APIKey string `yaml:"api_key"` - Endpoint string `yaml:"api_host"` - Timezone string `yaml:"timezone"` // TUI-specific - Language string `yaml:"language"` // TUI-specific - Layout string `yaml:"layout"` // TUI-specific - Debug bool `yaml:"-"` // Runtime only, not persisted + APIKey string `yaml:"api_key"` + Endpoint string `yaml:"api_host"` + Timezone string `yaml:"timezone"` // TUI-specific + Language string `yaml:"language"` // TUI-specific + Layout string `yaml:"layout"` // TUI-specific + OAuth *OAuthData `yaml:"oauth,omitempty"` + ClientID string `yaml:"client_id,omitempty"` // Dynamically registered OAuth client ID + Debug bool `yaml:"-"` // Runtime only, not persisted +} + +// OAuthData holds OAuth2 token data within the config file. +type OAuthData struct { + AccessToken string `yaml:"access_token"` + RefreshToken string `yaml:"refresh_token"` + ExpiresAt time.Time `yaml:"expires_at"` + TokenType string `yaml:"token_type"` } const DefaultTimezone = "UTC" diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go new file mode 100644 index 0000000..d4a89a3 --- /dev/null +++ b/internal/oauth/oauth.go @@ -0,0 +1,178 @@ +package oauth + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "strings" + + "golang.org/x/oauth2" + + "github.com/rootlyhq/rootly-cli/internal/config" +) + +const ( + CallbackPort = "19797" + RedirectURI = "http://localhost:" + CallbackPort + "/callback" +) + +// NewConfig creates an oauth2.Config for the given auth base URL and client ID. +func NewConfig(authBaseURL, clientID string) *oauth2.Config { + return &oauth2.Config{ + ClientID: clientID, + RedirectURL: RedirectURI, + Scopes: []string{"openid", "profile", "email", "all"}, + Endpoint: oauth2.Endpoint{ + AuthURL: authBaseURL + "/oauth/authorize", + TokenURL: authBaseURL + "/oauth/token", + AuthStyle: oauth2.AuthStyleInParams, + }, + } +} + +// registrationRequest is the payload for POST /oauth/register. +type registrationRequest struct { + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` + GrantTypes []string `json:"grant_types"` + ResponseTypes []string `json:"response_types"` +} + +// registrationResponse is the response from POST /oauth/register. +type registrationResponse struct { + ClientID string `json:"client_id"` +} + +// RegisterClient dynamically registers an OAuth client and returns the client_id. +func RegisterClient(ctx context.Context, authBaseURL string) (string, error) { + reqBody := registrationRequest{ + ClientName: "Rootly CLI", + RedirectURIs: []string{RedirectURI}, + TokenEndpointAuthMethod: "none", + GrantTypes: []string{"authorization_code"}, + ResponseTypes: []string{"code"}, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("failed to marshal registration request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, authBaseURL+"/oauth/register", bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("failed to create registration request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("failed to register OAuth client: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated { + return "", fmt.Errorf("could not register OAuth client (status %d)", resp.StatusCode) + } + + var regResp registrationResponse + if err := json.NewDecoder(resp.Body).Decode(®Resp); err != nil { + return "", fmt.Errorf("failed to parse registration response: %w", err) + } + + if regResp.ClientID == "" { + return "", fmt.Errorf("registration response missing client_id") + } + + return regResp.ClientID, nil +} + +// LoadCachedClientID reads the cached OAuth client_id from config. +func LoadCachedClientID() string { + cfg, err := config.Load() + if err != nil { + return "" + } + return cfg.ClientID +} + +// SaveClientID persists the OAuth client_id to config, preserving other fields. +func SaveClientID(clientID string) error { + cfg, err := config.Load() + if err != nil { + cfg = &config.Config{} + } + cfg.ClientID = clientID + return config.Save(cfg) +} + +// ClearClientID removes the cached client_id from config. +func ClearClientID() error { + cfg, err := config.Load() + if err != nil { + return nil // No config file means nothing to clear + } + cfg.ClientID = "" + return config.Save(cfg) +} + +// GenerateState creates a cryptographically random state parameter. +func GenerateState() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// ExchangeCode exchanges an authorization code for tokens using PKCE. +func ExchangeCode(ctx context.Context, cfg *oauth2.Config, code, codeVerifier string) (*oauth2.Token, error) { + return cfg.Exchange(ctx, code, oauth2.VerifierOption(codeVerifier)) +} + +// DeriveAuthBaseURL builds the OAuth base URL from the API host. +// For api.rootly.com it returns https://rootly.com. +// For localhost it returns http://localhost:. +func DeriveAuthBaseURL(apiHost string) string { + // Strip scheme to normalize, then re-apply appropriate scheme + scheme := "" + host := apiHost + if strings.HasPrefix(apiHost, "http://") { + scheme = "http://" + host = apiHost[7:] + } else if strings.HasPrefix(apiHost, "https://") { + scheme = "https://" + host = apiHost[8:] + } + + // Strip /api suffix (used for localhost API endpoints, not OAuth) + host = strings.TrimSuffix(host, "/api") + + if strings.HasPrefix(host, "localhost") || strings.HasPrefix(host, "127.0.0.1") { + if scheme == "" { + scheme = "http://" + } + return scheme + host + } + if strings.HasPrefix(host, "api.") { + return "https://" + host[4:] + } + if scheme == "" { + scheme = "https://" + } + return scheme + host +} + +// TokenSourceFromStored creates a token source that auto-refreshes using stored tokens. +func TokenSourceFromStored(cfg *oauth2.Config) (oauth2.TokenSource, error) { + stored, err := LoadTokens() + if err != nil { + return nil, err + } + tok := ToOAuth2Token(stored) + return cfg.TokenSource(context.Background(), tok), nil +} diff --git a/internal/oauth/oauth_test.go b/internal/oauth/oauth_test.go new file mode 100644 index 0000000..f52dd86 --- /dev/null +++ b/internal/oauth/oauth_test.go @@ -0,0 +1,152 @@ +package oauth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +func TestNewConfig(t *testing.T) { + cfg := NewConfig("https://rootly.com", "test-client-id") + + if cfg.ClientID != "test-client-id" { + t.Errorf("ClientID = %q, want %q", cfg.ClientID, "test-client-id") + } + if cfg.RedirectURL != "http://localhost:19797/callback" { + t.Errorf("RedirectURL = %q", cfg.RedirectURL) + } + if cfg.Endpoint.AuthURL != "https://rootly.com/oauth/authorize" { + t.Errorf("AuthURL = %q", cfg.Endpoint.AuthURL) + } + if cfg.Endpoint.TokenURL != "https://rootly.com/oauth/token" { + t.Errorf("TokenURL = %q", cfg.Endpoint.TokenURL) + } + if len(cfg.Scopes) != 4 { + t.Errorf("Scopes = %v", cfg.Scopes) + } +} + +func TestNewConfig_Localhost(t *testing.T) { + cfg := NewConfig("http://localhost:22166", "my-client") + + if cfg.Endpoint.AuthURL != "http://localhost:22166/oauth/authorize" { + t.Errorf("AuthURL = %q", cfg.Endpoint.AuthURL) + } + if cfg.Endpoint.TokenURL != "http://localhost:22166/oauth/token" { + t.Errorf("TokenURL = %q", cfg.Endpoint.TokenURL) + } +} + +func TestDeriveAuthBaseURL(t *testing.T) { + tests := []struct { + input string + want string + }{ + {"api.rootly.com", "https://rootly.com"}, + {"api.staging.rootly.com", "https://staging.rootly.com"}, + {"https://api.rootly.com", "https://rootly.com"}, + {"https://api.staging.rootly.com", "https://staging.rootly.com"}, + {"localhost:22166", "http://localhost:22166"}, + {"localhost:22166/api", "http://localhost:22166"}, + {"http://localhost:22166/api", "http://localhost:22166"}, + {"127.0.0.1:3000", "http://127.0.0.1:3000"}, + {"http://localhost:22166", "http://localhost:22166"}, + {"https://custom.example.com", "https://custom.example.com"}, + {"custom.example.com", "https://custom.example.com"}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := DeriveAuthBaseURL(tt.input) + if got != tt.want { + t.Errorf("DeriveAuthBaseURL(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestRegisterClient(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/register" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("unexpected method: %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content-type: %s", r.Header.Get("Content-Type")) + } + + var req registrationRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + if req.ClientName != "Rootly CLI" { + t.Errorf("ClientName = %q", req.ClientName) + } + + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(registrationResponse{ClientID: "dynamic-id-123"}) + })) + defer srv.Close() + + clientID, err := RegisterClient(context.Background(), srv.URL) + if err != nil { + t.Fatalf("RegisterClient() error: %v", err) + } + if clientID != "dynamic-id-123" { + t.Errorf("clientID = %q, want %q", clientID, "dynamic-id-123") + } +} + +func TestRegisterClient_NonCreated(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + _, err := RegisterClient(context.Background(), srv.URL) + if err == nil { + t.Fatal("expected error for non-201 status") + } +} + +func TestLoadSaveClearClientID(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + // Initially empty + if id := LoadCachedClientID(); id != "" { + t.Errorf("expected empty, got %q", id) + } + + // Save + if err := SaveClientID("cached-id"); err != nil { + t.Fatalf("SaveClientID: %v", err) + } + if id := LoadCachedClientID(); id != "cached-id" { + t.Errorf("got %q, want %q", id, "cached-id") + } + + // Clear + if err := ClearClientID(); err != nil { + t.Fatalf("ClearClientID: %v", err) + } + + // Verify cleared — load raw config to check + data, err := os.ReadFile(configPathForTest(tmpDir)) + if err != nil { + // File may not exist after clear, that's ok + return + } + if id := LoadCachedClientID(); id != "" { + t.Errorf("expected empty after clear, got %q (raw: %s)", id, string(data)) + } +} + +func configPathForTest(home string) string { + return home + "/.rootly-cli/config.yaml" +} diff --git a/internal/oauth/tokens.go b/internal/oauth/tokens.go new file mode 100644 index 0000000..483cc4c --- /dev/null +++ b/internal/oauth/tokens.go @@ -0,0 +1,131 @@ +package oauth + +import ( + "os" + "time" + + "golang.org/x/oauth2" + "gopkg.in/yaml.v3" + + "github.com/rootlyhq/rootly-cli/internal/config" +) + +// TokenData is an alias for config.OAuthData used within the oauth package. +type TokenData = config.OAuthData + +// LoadTokens reads OAuth tokens from ~/.rootly-cli/config.yaml. +func LoadTokens() (*TokenData, error) { + data, err := os.ReadFile(config.Path()) + if err != nil { + return nil, err + } + var cfg config.Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, err + } + if cfg.OAuth == nil || cfg.OAuth.AccessToken == "" { + return nil, os.ErrNotExist + } + return cfg.OAuth, nil +} + +// SaveTokens writes OAuth tokens into ~/.rootly-cli/config.yaml, +// preserving all other config fields. +func SaveTokens(t *TokenData) error { + path := config.Path() + + // Read existing config to preserve other fields + var cfg config.Config + if data, err := os.ReadFile(path); err == nil { + _ = yaml.Unmarshal(data, &cfg) + } + + cfg.OAuth = t + + if err := os.MkdirAll(config.Dir(), 0700); err != nil { + return err + } + data, err := yaml.Marshal(&cfg) + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +// SaveOAuth2Token converts and persists an oauth2.Token. +func SaveOAuth2Token(tok *oauth2.Token) error { + return SaveTokens(TokenDataFromOAuth2(tok)) +} + +// HasTokens returns true if OAuth tokens exist in the config file (cheap check). +func HasTokens() bool { + data, err := os.ReadFile(config.Path()) + if err != nil { + return false + } + // Quick check without full unmarshal + var cfg struct { + OAuth *struct { + AccessToken string `yaml:"access_token"` + } `yaml:"oauth"` + } + if err := yaml.Unmarshal(data, &cfg); err != nil { + return false + } + return cfg.OAuth != nil && cfg.OAuth.AccessToken != "" +} + +// ClearTokens removes OAuth tokens from ~/.rootly-cli/config.yaml, +// preserving all other config fields. +func ClearTokens() error { + path := config.Path() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var cfg config.Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return err + } + + if cfg.OAuth == nil { + return nil + } + + cfg.OAuth = nil + + out, err := yaml.Marshal(&cfg) + if err != nil { + return err + } + return os.WriteFile(path, out, 0600) +} + +// IsExpired returns true if the token is expired or within 30s of expiring. +func IsExpired(t *TokenData) bool { + return time.Now().After(t.ExpiresAt.Add(-30 * time.Second)) +} + +// ToOAuth2Token converts stored token data to an oauth2.Token. +func ToOAuth2Token(t *TokenData) *oauth2.Token { + return &oauth2.Token{ + AccessToken: t.AccessToken, + RefreshToken: t.RefreshToken, + TokenType: t.TokenType, + Expiry: t.ExpiresAt, + } +} + +// TokenDataFromOAuth2 converts an oauth2.Token to TokenData for storage. +func TokenDataFromOAuth2(tok *oauth2.Token) *TokenData { + return &TokenData{ + AccessToken: tok.AccessToken, + RefreshToken: tok.RefreshToken, + ExpiresAt: tok.Expiry, + TokenType: tok.TokenType, + } +} diff --git a/internal/oauth/tokens_test.go b/internal/oauth/tokens_test.go new file mode 100644 index 0000000..83a21fd --- /dev/null +++ b/internal/oauth/tokens_test.go @@ -0,0 +1,184 @@ +package oauth + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" +) + +// setTestHome sets HOME (and USERPROFILE on Windows) so os.UserHomeDir() returns tmpDir. +func setTestHome(t *testing.T, tmpDir string) { + t.Helper() + t.Setenv("HOME", tmpDir) + if runtime.GOOS == "windows" { + t.Setenv("USERPROFILE", tmpDir) + } +} + +func TestIsExpired(t *testing.T) { + tests := []struct { + name string + expires time.Time + want bool + }{ + {"future", time.Now().Add(10 * time.Minute), false}, + {"past", time.Now().Add(-10 * time.Minute), true}, + {"within 30s buffer", time.Now().Add(20 * time.Second), true}, + {"just outside buffer", time.Now().Add(60 * time.Second), false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + td := &TokenData{ExpiresAt: tt.expires} + if got := IsExpired(td); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSaveAndLoadTokens(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + tokens := &TokenData{ + AccessToken: "test-access", + RefreshToken: "test-refresh", + ExpiresAt: time.Now().Add(1 * time.Hour).Truncate(time.Second), + TokenType: "Bearer", + } + + if err := SaveTokens(tokens); err != nil { + t.Fatalf("SaveTokens: %v", err) + } + + // Verify file exists + path := filepath.Join(tmpDir, ".rootly-cli", "config.yaml") + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat config file: %v", err) + } + // File permission check (skip on Windows where permissions work differently) + if runtime.GOOS != "windows" { + if perm := info.Mode().Perm(); perm != 0600 { + t.Errorf("file permissions = %o, want 0600", perm) + } + } + + loaded, err := LoadTokens() + if err != nil { + t.Fatalf("LoadTokens: %v", err) + } + if loaded.AccessToken != tokens.AccessToken { + t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, tokens.AccessToken) + } + if loaded.RefreshToken != tokens.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, tokens.RefreshToken) + } +} + +func TestSaveTokens_PreservesExistingConfig(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Write a config with an API key first + dir := filepath.Join(tmpDir, ".rootly-cli") + os.MkdirAll(dir, 0700) + os.WriteFile(filepath.Join(dir, "config.yaml"), []byte("api_key: my-key\napi_host: custom.rootly.com\n"), 0600) + + // Save tokens + tokens := &TokenData{AccessToken: "tok", RefreshToken: "ref", ExpiresAt: time.Now().Add(time.Hour)} + if err := SaveTokens(tokens); err != nil { + t.Fatalf("SaveTokens: %v", err) + } + + // Verify existing fields preserved + data, _ := os.ReadFile(filepath.Join(dir, "config.yaml")) + content := string(data) + if !strings.Contains(content, "api_key: my-key") { + t.Errorf("api_key not preserved in config:\n%s", content) + } + if !strings.Contains(content, "api_host: custom.rootly.com") { + t.Errorf("api_host not preserved in config:\n%s", content) + } + if !strings.Contains(content, "access_token: tok") { + t.Errorf("oauth tokens not written:\n%s", content) + } +} + +func TestClearTokens(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + tokens := &TokenData{AccessToken: "x", RefreshToken: "y", ExpiresAt: time.Now().Add(time.Hour)} + _ = SaveTokens(tokens) + + if err := ClearTokens(); err != nil { + t.Fatalf("ClearTokens: %v", err) + } + + _, err := LoadTokens() + if err == nil { + t.Error("expected error after clearing tokens") + } +} + +func TestClearTokens_PreservesConfig(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + // Write config with API key + tokens + dir := filepath.Join(tmpDir, ".rootly-cli") + os.MkdirAll(dir, 0700) + os.WriteFile(filepath.Join(dir, "config.yaml"), []byte("api_key: my-key\noauth:\n access_token: tok\n refresh_token: ref\n"), 0600) + + if err := ClearTokens(); err != nil { + t.Fatalf("ClearTokens: %v", err) + } + + // API key should still be there + data, _ := os.ReadFile(filepath.Join(dir, "config.yaml")) + if !strings.Contains(string(data), "api_key: my-key") { + t.Errorf("api_key not preserved after clear:\n%s", string(data)) + } +} + +func TestClearTokens_NoFile(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + if err := ClearTokens(); err != nil { + t.Fatalf("ClearTokens on missing file: %v", err) + } +} + +func TestHasTokens(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + if HasTokens() { + t.Error("HasTokens should be false with no config") + } + + SaveTokens(&TokenData{AccessToken: "x", RefreshToken: "y", ExpiresAt: time.Now().Add(time.Hour)}) + + if !HasTokens() { + t.Error("HasTokens should be true after saving tokens") + } +} + +func TestLoadTokens_NoOAuthSection(t *testing.T) { + tmpDir := t.TempDir() + setTestHome(t, tmpDir) + + dir := filepath.Join(tmpDir, ".rootly-cli") + os.MkdirAll(dir, 0700) + os.WriteFile(filepath.Join(dir, "config.yaml"), []byte("api_key: my-key\n"), 0600) + + _, err := LoadTokens() + if err == nil { + t.Error("expected error when no oauth section exists") + } +} diff --git a/internal/oauth/transport.go b/internal/oauth/transport.go new file mode 100644 index 0000000..82a34d6 --- /dev/null +++ b/internal/oauth/transport.go @@ -0,0 +1,74 @@ +package oauth + +import ( + "fmt" + "net/http" + "strings" + + "golang.org/x/oauth2" +) + +// NewHTTPClient creates an http.Client that uses stored OAuth tokens with auto-refresh. +// The userAgent is set on all requests. The base transport is used for underlying HTTP calls. +func NewHTTPClient(cfg *oauth2.Config, base http.RoundTripper, userAgent string) (*http.Client, error) { + ts, err := TokenSourceFromStored(cfg) + if err != nil { + return nil, err + } + + // Wrap the token source to save refreshed tokens + ts = &persistingTokenSource{ + base: ts, + } + + transport := &oauth2.Transport{ + Source: ts, + Base: base, + } + + // Wrap with user-agent transport + return &http.Client{ + Transport: &userAgentTransport{ + base: transport, + userAgent: userAgent, + }, + }, nil +} + +// persistingTokenSource wraps a TokenSource and saves refreshed tokens to disk. +type persistingTokenSource struct { + base oauth2.TokenSource + lastAccessToken string +} + +func (p *persistingTokenSource) Token() (*oauth2.Token, error) { + tok, err := p.base.Token() + if err != nil { + // Surface a user-friendly message when refresh fails + errMsg := err.Error() + if strings.Contains(errMsg, "token") || strings.Contains(errMsg, "401") || strings.Contains(errMsg, "invalid_grant") { + return nil, fmt.Errorf("session expired — run 'rootly login' to re-authenticate: %w", err) + } + return nil, err + } + // Only save when the token was actually refreshed (new access token) + if tok.AccessToken != p.lastAccessToken { + p.lastAccessToken = tok.AccessToken + _ = SaveOAuth2Token(tok) + } + return tok, nil +} + +// userAgentTransport sets User-Agent on all requests. +type userAgentTransport struct { + base http.RoundTripper + userAgent string +} + +func (t *userAgentTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req2 := req.Clone(req.Context()) + if t.userAgent != "" { + req2.Header.Set("User-Agent", t.userAgent) + } + return t.base.RoundTrip(req2) +} diff --git a/internal/oauth/transport_test.go b/internal/oauth/transport_test.go new file mode 100644 index 0000000..b411a3c --- /dev/null +++ b/internal/oauth/transport_test.go @@ -0,0 +1,192 @@ +package oauth + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "golang.org/x/oauth2" +) + +func TestNewHTTPClient_SetsAuthAndUserAgent(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + tokens := &TokenData{ + AccessToken: "my-token", + RefreshToken: "my-refresh", + ExpiresAt: time.Now().Add(1 * time.Hour), + TokenType: "Bearer", + } + if err := SaveTokens(tokens); err != nil { + t.Fatal(err) + } + + var gotAuth, gotUA string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + gotUA = r.Header.Get("User-Agent") + w.WriteHeader(200) + })) + defer backend.Close() + + cfg := &oauth2.Config{ + ClientID: "test", + Endpoint: oauth2.Endpoint{ + TokenURL: backend.URL + "/oauth/token", + AuthStyle: oauth2.AuthStyleInParams, + }, + } + + client, err := NewHTTPClient(cfg, http.DefaultTransport, "rootly-cli/test") + if err != nil { + t.Fatalf("NewHTTPClient: %v", err) + } + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, backend.URL+"/test", http.NoBody) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + if gotAuth != "Bearer my-token" { + t.Errorf("Authorization = %q, want %q", gotAuth, "Bearer my-token") + } + if gotUA != "rootly-cli/test" { + t.Errorf("User-Agent = %q, want %q", gotUA, "rootly-cli/test") + } +} + +func TestNewHTTPClient_NoTokens(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + cfg := &oauth2.Config{ + ClientID: "test", + Endpoint: oauth2.Endpoint{TokenURL: "http://localhost/token"}, + } + + _, err := NewHTTPClient(cfg, http.DefaultTransport, "") + if err == nil { + t.Error("expected error when no tokens exist") + } +} + +func TestNewHTTPClient_RefreshesExpiredToken(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + // Token server for refresh + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": "fresh-token", + "refresh_token": "fresh-refresh", + "expires_in": 3600, + "token_type": "Bearer", + }) + })) + defer tokenServer.Close() + + // Save expired tokens + tokens := &TokenData{ + AccessToken: "expired-token", + RefreshToken: "valid-refresh", + ExpiresAt: time.Now().Add(-1 * time.Hour), + TokenType: "Bearer", + } + SaveTokens(tokens) + + cfg := &oauth2.Config{ + ClientID: "test", + Endpoint: oauth2.Endpoint{ + TokenURL: tokenServer.URL, + AuthStyle: oauth2.AuthStyleInParams, + }, + } + + client, err := NewHTTPClient(cfg, http.DefaultTransport, "") + if err != nil { + t.Fatalf("NewHTTPClient: %v", err) + } + + var gotAuth string + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(200) + })) + defer backend.Close() + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, backend.URL+"/test", http.NoBody) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + resp.Body.Close() + + if gotAuth != "Bearer fresh-token" { + t.Errorf("Authorization = %q, want %q", gotAuth, "Bearer fresh-token") + } + + // Verify refreshed tokens were persisted + saved, _ := LoadTokens() + if saved.AccessToken != "fresh-token" { + t.Errorf("saved AccessToken = %q", saved.AccessToken) + } +} + +func TestNewHTTPClient_RefreshFailsSuggestsLogin(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + // Token server rejects refresh with invalid_grant + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"invalid_grant","error_description":"refresh token revoked"}`)) + })) + defer tokenServer.Close() + + // Save expired tokens so refresh is triggered + tokens := &TokenData{ + AccessToken: "expired-token", + RefreshToken: "revoked-refresh", + ExpiresAt: time.Now().Add(-1 * time.Hour), + TokenType: "Bearer", + } + SaveTokens(tokens) + + cfg := &oauth2.Config{ + ClientID: "test", + Endpoint: oauth2.Endpoint{ + TokenURL: tokenServer.URL, + AuthStyle: oauth2.AuthStyleInParams, + }, + } + + client, err := NewHTTPClient(cfg, http.DefaultTransport, "") + if err != nil { + t.Fatalf("NewHTTPClient: %v", err) + } + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, tokenServer.URL+"/test", http.NoBody) + resp, err := client.Do(req) + if resp != nil { + resp.Body.Close() + } + if err == nil { + t.Fatal("expected error when refresh token is revoked") + } + if !strings.Contains(err.Error(), "rootly login") { + t.Errorf("error should suggest 'rootly login', got: %v", err) + } +}