diff --git a/internal/cmd/drive.go b/internal/cmd/drive.go index 34c61578..137b96b7 100644 --- a/internal/cmd/drive.go +++ b/internal/cmd/drive.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "path/filepath" + "regexp" "strings" "google.golang.org/api/drive/v3" @@ -21,6 +22,14 @@ import ( var newDriveService = googleapi.NewDrive +var ( + driveSearchFieldComparisonPattern = regexp.MustCompile(`(?i)\b(?:mimeType|name|fullText|trashed|starred|modifiedTime|createdTime|viewedByMeTime|visibility)\b\s*(?:!=|<=|>=|=|<|>)`) + driveSearchContainsPattern = regexp.MustCompile(`(?i)\b(?:name|fullText)\b\s+contains\s+'`) + driveSearchMembershipPattern = regexp.MustCompile(`(?i)'[^']+'\s+in\s+(?:parents|owners|writers|readers)`) + driveSearchHasPattern = regexp.MustCompile(`(?i)\b(?:properties|appProperties)\b\s+has\s+\{`) + driveTrashedPattern = regexp.MustCompile(`(?i)\btrashed\b`) +) + const ( driveMimeGoogleDoc = "application/vnd.google-apps.document" driveMimeGoogleSheet = "application/vnd.google-apps.spreadsheet" @@ -956,15 +965,38 @@ func buildDriveListQuery(folderID string, userQuery string) string { } else { q = parent } - if !strings.Contains(q, "trashed") { + if !hasDriveTrashedPredicate(q) { q += " and trashed = false" } return q } func buildDriveSearchQuery(text string) string { - q := fmt.Sprintf("fullText contains '%s'", escapeDriveQueryString(text)) - return q + " and trashed = false" + q := strings.TrimSpace(text) + if q == "" { + return "trashed = false" + } + if !looksLikeDriveFilterQuery(q) { + return fmt.Sprintf("fullText contains '%s' and trashed = false", escapeDriveQueryString(q)) + } + if !hasDriveTrashedPredicate(q) { + q += " and trashed = false" + } + return q +} + +func looksLikeDriveFilterQuery(q string) bool { + if strings.EqualFold(q, "sharedWithMe") { + return true + } + return driveSearchFieldComparisonPattern.MatchString(q) || + driveSearchContainsPattern.MatchString(q) || + driveSearchMembershipPattern.MatchString(q) || + driveSearchHasPattern.MatchString(q) +} + +func hasDriveTrashedPredicate(q string) bool { + return driveTrashedPattern.MatchString(q) } func escapeDriveQueryString(s string) string { diff --git a/internal/cmd/drive_search_more_test.go b/internal/cmd/drive_search_more_test.go index a1ca8950..44be0b72 100644 --- a/internal/cmd/drive_search_more_test.go +++ b/internal/cmd/drive_search_more_test.go @@ -162,3 +162,62 @@ func TestDriveSearchCmd_NoResultsAndEmptyQuery(t *testing.T) { t.Fatalf("expected empty query error") } } + +func TestDriveSearchCmd_PassesThroughDriveFilterQueries(t *testing.T) { + origNew := newDriveService + t.Cleanup(func() { newDriveService = origNew }) + + const query = "mimeType = 'application/vnd.google-apps.document'" + const wantQ = query + " and trashed = false" + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.NotFound(w, r) + return + } + if errMsg := driveAllDrivesQueryError(r); errMsg != "" { + http.Error(w, errMsg, http.StatusBadRequest) + return + } + if got := r.URL.Query().Get("q"); got != wantQ { + http.Error(w, "unexpected query: "+got, http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "files": []map[string]any{ + { + "id": "f1", + "name": "Doc", + "mimeType": "application/vnd.google-apps.document", + "modifiedTime": "2025-12-12T14:37:47Z", + }, + }, + }) + })) + t.Cleanup(srv.Close) + + svc, err := drive.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newDriveService = func(context.Context, string) (*drive.Service, error) { return svc, nil } + + flags := &RootFlags{Account: "a@b.com"} + var errBuf bytes.Buffer + u, uiErr := ui.New(ui.Options{Stdout: io.Discard, Stderr: &errBuf, Color: "never"}) + if uiErr != nil { + t.Fatalf("ui.New: %v", uiErr) + } + ctx := ui.WithUI(context.Background(), u) + _ = captureStdout(t, func() { + cmd := &DriveSearchCmd{} + if execErr := runKong(t, cmd, []string{query}, ctx, flags); execErr != nil { + t.Fatalf("execute: %v", execErr) + } + }) +} diff --git a/internal/cmd/drive_test.go b/internal/cmd/drive_test.go index a61077e2..1871d862 100644 --- a/internal/cmd/drive_test.go +++ b/internal/cmd/drive_test.go @@ -30,6 +30,30 @@ func TestBuildDriveSearchQuery(t *testing.T) { if got != "fullText contains 'hello world' and trashed = false" { t.Fatalf("unexpected: %q", got) } + + t.Run("passes through filter query", func(t *testing.T) { + got := buildDriveSearchQuery("mimeType = 'application/vnd.google-apps.document'") + want := "mimeType = 'application/vnd.google-apps.document' and trashed = false" + if got != want { + t.Fatalf("unexpected: %q", got) + } + }) + + t.Run("plain text containing trashed still appends trashed=false", func(t *testing.T) { + got := buildDriveSearchQuery("trashed") + want := "fullText contains 'trashed' and trashed = false" + if got != want { + t.Fatalf("unexpected: %q", got) + } + }) + + t.Run("does not add trashed when already present", func(t *testing.T) { + got := buildDriveSearchQuery("mimeType != 'application/vnd.google-apps.folder' and TrAsHeD = true") + want := "mimeType != 'application/vnd.google-apps.folder' and TrAsHeD = true" + if got != want { + t.Fatalf("unexpected: %q", got) + } + }) } func TestEscapeDriveQueryString(t *testing.T) { @@ -50,3 +74,64 @@ func TestFormatDriveSize(t *testing.T) { t.Fatalf("unexpected: %q", got) } } + +func TestLooksLikeDriveFilterQuery(t *testing.T) { + tests := []struct { + name string + query string + want bool + }{ + // --- Should return true (filter queries) --- + + // Field comparisons + {name: "mimeType equals", query: "mimeType = 'application/vnd.google-apps.document'", want: true}, + {name: "name not equals", query: "name != 'untitled'", want: true}, + {name: "modifiedTime greater than", query: "modifiedTime > '2024-01-01'", want: true}, + {name: "trashed equals", query: "trashed = true", want: true}, + {name: "starred equals", query: "starred = false", want: true}, + {name: "createdTime less than", query: "createdTime < '2023-06-01'", want: true}, + {name: "viewedByMeTime gte", query: "viewedByMeTime >= '2024-01-01'", want: true}, + {name: "visibility equals", query: "visibility = 'anyoneWithLink'", want: true}, + + // Contains + {name: "name contains", query: "name contains 'report'", want: true}, + {name: "fullText contains", query: "fullText contains 'budget'", want: true}, + + // Membership (in) + {name: "in parents", query: "'folder123' in parents", want: true}, + {name: "in owners", query: "'user@example.com' in owners", want: true}, + {name: "in writers", query: "'user@example.com' in writers", want: true}, + {name: "in readers", query: "'reader@example.com' in readers", want: true}, + + // Has property + {name: "properties has", query: "properties has { key='department' and value='finance' }", want: true}, + {name: "appProperties has", query: "appProperties has { key='project' and value='alpha' }", want: true}, + + // sharedWithMe (case-insensitive) + {name: "sharedWithMe exact", query: "sharedWithMe", want: true}, + {name: "sharedWithMe uppercase", query: "SHAREDWITHME", want: true}, + {name: "sharedWithMe mixed case", query: "SharedWithMe", want: true}, + + // Compound queries + {name: "compound mimeType and name contains", query: "mimeType = 'application/pdf' and name contains 'report'", want: true}, + {name: "compound trashed and starred", query: "trashed = false and starred = true", want: true}, + + // --- Should return false (natural language / plain text) --- + {name: "plain text meeting notes", query: "meeting notes", want: false}, + {name: "plain text find my documents", query: "find my documents", want: false}, + {name: "plain text trashed files", query: "trashed files", want: false}, + {name: "plain text hello world", query: "hello world", want: false}, + {name: "plain text important", query: "important", want: false}, + {name: "empty string", query: "", want: false}, + {name: "whitespace only", query: " ", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := looksLikeDriveFilterQuery(tt.query) + if got != tt.want { + t.Errorf("looksLikeDriveFilterQuery(%q) = %v, want %v", tt.query, got, tt.want) + } + }) + } +} diff --git a/internal/cmd/gmail_mime.go b/internal/cmd/gmail_mime.go index af6a783b..f2d56535 100644 --- a/internal/cmd/gmail_mime.go +++ b/internal/cmd/gmail_mime.go @@ -65,21 +65,21 @@ func buildRFC822(opts mailOptions, cfg *rfc822Config) ([]byte, error) { } } - writeHeader(&b, "From", opts.From) + writeHeader(&b, "From", formatAddressHeader(opts.From)) if len(opts.To) > 0 { - writeHeader(&b, "To", strings.Join(opts.To, ", ")) + writeHeader(&b, "To", formatAddressHeaders(opts.To)) } if len(opts.Cc) > 0 { - writeHeader(&b, "Cc", strings.Join(opts.Cc, ", ")) + writeHeader(&b, "Cc", formatAddressHeaders(opts.Cc)) } if len(opts.Bcc) > 0 { - writeHeader(&b, "Bcc", strings.Join(opts.Bcc, ", ")) + writeHeader(&b, "Bcc", formatAddressHeaders(opts.Bcc)) } if strings.TrimSpace(opts.ReplyTo) != "" { if err := validateHeaderValue(opts.ReplyTo); err != nil { return nil, fmt.Errorf("invalid Reply-To: %w", err) } - writeHeader(&b, "Reply-To", strings.TrimSpace(opts.ReplyTo)) + writeHeader(&b, "Reply-To", formatAddressHeader(opts.ReplyTo)) } if err := validateHeaderValue(opts.Subject); err != nil { return nil, fmt.Errorf("invalid Subject: %w", err) @@ -217,6 +217,33 @@ func writeHeader(b *bytes.Buffer, name, value string) { b.WriteString("\r\n") } +func formatAddressHeader(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return trimmed + } + addr, err := mail.ParseAddress(trimmed) + if err != nil { + return trimmed + } + if strings.TrimSpace(addr.Name) == "" { + return addr.Address + } + return addr.String() +} + +func formatAddressHeaders(values []string) string { + formatted := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + continue + } + formatted = append(formatted, formatAddressHeader(trimmed)) + } + return strings.Join(formatted, ", ") +} + func wrapBase64(b []byte) string { s := base64.StdEncoding.EncodeToString(b) const width = 76 diff --git a/internal/cmd/gmail_mime_test.go b/internal/cmd/gmail_mime_test.go index 7cd00ff4..8961703f 100644 --- a/internal/cmd/gmail_mime_test.go +++ b/internal/cmd/gmail_mime_test.go @@ -140,6 +140,47 @@ func TestBuildRFC822UTF8Subject(t *testing.T) { } } +func TestBuildRFC822UTF8FromDisplayName(t *testing.T) { + raw, err := buildRFC822(mailOptions{ + From: "Sérgio Bastos • Importrust ", + To: []string{"c@d.com"}, + Subject: "Hi", + Body: "Hello", + }, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + s := string(raw) + if !strings.Contains(s, "From: =?utf-8?") { + t.Fatalf("expected encoded-word From header: %q", s) + } + if !strings.Contains(s, "") { + t.Fatalf("expected alias email in From header: %q", s) + } + if strings.Contains(s, "From: Sérgio Bastos • Importrust ") { + t.Fatalf("expected From header to be RFC 2047 encoded: %q", s) + } +} + +func TestBuildRFC822PlainFromAddressStaysUnwrapped(t *testing.T) { + raw, err := buildRFC822(mailOptions{ + From: "a@b.com", + To: []string{"c@d.com"}, + Subject: "Hi", + Body: "Hello", + }, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + s := string(raw) + if !strings.Contains(s, "From: a@b.com\r\n") { + t.Fatalf("expected plain From address, got: %q", s) + } + if strings.Contains(s, "From: \r\n") { + t.Fatalf("unexpected wrapped From address: %q", s) + } +} + func TestBuildRFC822ReplyToHeader(t *testing.T) { raw, err := buildRFC822(mailOptions{ From: "a@b.com", @@ -255,3 +296,48 @@ func TestRandomMessageID(t *testing.T) { t.Fatalf("unexpected: %q", id) } } + +func TestFormatAddressHeaderUnparseable(t *testing.T) { + input := "not an email at all" + got := formatAddressHeader(input) + if got != input { + t.Fatalf("expected unparseable input returned unchanged, got: %q", got) + } +} + +func TestFormatAddressHeadersMixed(t *testing.T) { + input := []string{"Alice ", "c@d.com", "Sérgio Bastos "} + got := formatAddressHeaders(input) + + // Should contain all three addresses comma-separated. + parts := strings.SplitN(got, ", ", 3) + if len(parts) != 3 { + t.Fatalf("expected 3 comma-separated parts, got %d: %q", len(parts), got) + } + + // First part: display name "Alice" with address a@b.com. + if !strings.Contains(parts[0], "Alice") || !strings.Contains(parts[0], "a@b.com") { + t.Fatalf("unexpected first part: %q", parts[0]) + } + + // Second part: plain address, no angle brackets. + if parts[1] != "c@d.com" { + t.Fatalf("expected plain address c@d.com, got: %q", parts[1]) + } + + // Third part: non-ASCII name must be RFC 2047 encoded. + if !strings.Contains(parts[2], "=?utf-8?") { + t.Fatalf("expected RFC 2047 encoded name in third part, got: %q", parts[2]) + } + if !strings.Contains(parts[2], "s@b.com") { + t.Fatalf("expected address s@b.com in third part, got: %q", parts[2]) + } +} + +func TestFormatAddressHeadersFiltersEmpty(t *testing.T) { + got := formatAddressHeaders([]string{"a@b.com", "", "b@c.com"}) + expected := "a@b.com, b@c.com" + if got != expected { + t.Fatalf("expected %q, got %q", expected, got) + } +} diff --git a/internal/cmd/gmail_send_batches_test.go b/internal/cmd/gmail_send_batches_test.go index 97e6e379..c70656ed 100644 --- a/internal/cmd/gmail_send_batches_test.go +++ b/internal/cmd/gmail_send_batches_test.go @@ -52,6 +52,7 @@ func TestSendGmailBatches_WithTracking(t *testing.T) { Enabled: true, WorkerURL: "https://example.com", TrackingKey: mustTrackingKey(t), + AdminKey: "test-admin-key", } batches := buildSendBatches( diff --git a/internal/cmd/gmail_track.go b/internal/cmd/gmail_track.go index dd2e69ef..655a4710 100644 --- a/internal/cmd/gmail_track.go +++ b/internal/cmd/gmail_track.go @@ -5,4 +5,5 @@ type GmailTrackCmd struct { Setup GmailTrackSetupCmd `cmd:"" help:"Set up email tracking (deploy Cloudflare Worker)"` Opens GmailTrackOpensCmd `cmd:"" help:"Query email opens"` Status GmailTrackStatusCmd `cmd:"" help:"Show tracking configuration status"` + Key GmailTrackKeyCmd `cmd:"" help:"Manage tracking key rotation"` } diff --git a/internal/cmd/gmail_track_cmd_test.go b/internal/cmd/gmail_track_cmd_test.go index 42f55a40..65964322 100644 --- a/internal/cmd/gmail_track_cmd_test.go +++ b/internal/cmd/gmail_track_cmd_test.go @@ -49,6 +49,51 @@ func TestGmailTrackSetupAndStatus(t *testing.T) { } } +func TestGmailTrackKeyRotate(t *testing.T) { + setupTrackingEnv(t) + + if err := tracking.SaveSecrets("a@b.com", "legacy-key", "admin-key"); err != nil { + t.Fatalf("SaveSecrets: %v", err) + } + + cfg := &tracking.Config{ + Enabled: true, + WorkerURL: "https://example.com", + WorkerName: "gog-email-tracker-test", + DatabaseName: "gog-email-tracker-test", + TrackingKeyVersions: []int{1}, + TrackingCurrentKeyVersion: 1, + SecretsInKeyring: true, + } + if err := tracking.SaveConfig("a@b.com", cfg); err != nil { + t.Fatalf("SaveConfig: %v", err) + } + + out := captureStdout(t, func() { + _ = captureStderr(t, func() { + if err := Execute([]string{"--account", "a@b.com", "--no-input", "gmail", "track", "key", "rotate", "--no-deploy"}); err != nil { + t.Fatalf("Execute: %v", err) + } + }) + }) + + if !strings.Contains(out, "tracking_key_rotated\ttrue") { + t.Fatalf("unexpected rotate output: %q", out) + } + if !strings.Contains(out, "current_version\t2") { + t.Fatalf("expected key rotation version, got %q", out) + } + + reloaded, err := tracking.LoadConfig("a@b.com") + if err != nil { + t.Fatalf("LoadConfig: %v", err) + } + + if reloaded.TrackingCurrentKeyVersion != 2 { + t.Fatalf("expected rotated version 2, got %d", reloaded.TrackingCurrentKeyVersion) + } +} + func TestGmailTrackStatus_NotConfigured(t *testing.T) { setupTrackingEnv(t) @@ -70,6 +115,9 @@ func TestGmailTrackOpens(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { case strings.Contains(r.URL.Path, "/q/"): + if r.Header.Get("Authorization") != "Bearer adminkey" { + t.Fatalf("unexpected /q/ auth: %q", r.Header.Get("Authorization")) + } w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ "tracking_id": "tid", diff --git a/internal/cmd/gmail_track_key.go b/internal/cmd/gmail_track_key.go new file mode 100644 index 00000000..2099adbd --- /dev/null +++ b/internal/cmd/gmail_track_key.go @@ -0,0 +1,147 @@ +package cmd + +import ( + "context" + "fmt" + "sort" + "strings" + + "github.com/steipete/gogcli/internal/tracking" + "github.com/steipete/gogcli/internal/ui" +) + +type GmailTrackKeyCmd struct { + Rotate GmailTrackKeyRotateCmd `cmd:"" help:"Rotate tracking encryption key"` +} + +type GmailTrackKeyRotateCmd struct { + NoDeploy bool `name:"no-deploy" help:"Update local config only and skip worker deploy"` +} + +func (c *GmailTrackKeyRotateCmd) Run(ctx context.Context, flags *RootFlags) error { + u := ui.FromContext(ctx) + + account, cfg, err := loadTrackingConfigForAccount(flags) + if err != nil { + return err + } + + if strings.TrimSpace(cfg.WorkerName) == "" || strings.TrimSpace(cfg.WorkerURL) == "" { + return fmt.Errorf("tracking not configured; run 'gog gmail track setup' first") + } + + if strings.TrimSpace(cfg.AdminKey) == "" { + return fmt.Errorf("tracking admin key not configured; run 'gog gmail track setup' again") + } + + if !cfg.IsConfigured() { + return fmt.Errorf("tracking not configured; run 'gog gmail track setup' first") + } + + currentVersion := cfg.TrackingCurrentKeyVersion + if currentVersion <= 0 { + currentVersion = 1 + } + + knownVersions := append([]int{}, cfg.TrackingKeyVersions...) + if len(knownVersions) == 0 { + knownVersions = []int{currentVersion} + } + + trackingKeys, detectedCurrentVersion, err := tracking.LoadTrackingKeys(account, knownVersions, currentVersion) + if err != nil { + return fmt.Errorf("load tracking keys: %w", err) + } + if detectedCurrentVersion <= 0 { + return fmt.Errorf("invalid tracking key version: %d", detectedCurrentVersion) + } + + for version, key := range trackingKeys { + if strings.TrimSpace(key) == "" { + return fmt.Errorf("missing tracking key for version %d", version) + } + } + + nextVersion := detectedCurrentVersion + for version := range trackingKeys { + if version > nextVersion { + nextVersion = version + } + } + nextVersion++ + if nextVersion <= 0 || nextVersion > 255 { + return fmt.Errorf("invalid tracking key version: %d", nextVersion) + } + + nextKey, err := tracking.GenerateKey() + if err != nil { + return fmt.Errorf("generate tracking key: %w", err) + } + + trackingKeys[nextVersion] = nextKey + + // Keep sorted list of configured versions for local state and future rotation. + versions := make([]int, 0, len(trackingKeys)) + for version := range trackingKeys { + versions = append(versions, version) + } + sort.Ints(versions) + + request := map[string]any{ + "account": account, + "worker_name": cfg.WorkerName, + "database_name": cfg.DatabaseName, + "tracking_current_version": nextVersion, + "tracking_key_versions": versions, + "deploy": !c.NoDeploy, + } + if err := dryRunExit(ctx, flags, "gmail.track.key.rotate", request); err != nil { + return err + } + + if !c.NoDeploy { + dbName := strings.TrimSpace(cfg.DatabaseName) + if dbName == "" { + dbName = strings.TrimSpace(cfg.WorkerName) + } + dbID, deployErr := tracking.DeployWorker(ctx, u.Err(), tracking.DeployOptions{ + WorkerDir: "internal/tracking/worker", + WorkerName: cfg.WorkerName, + DatabaseName: dbName, + TrackingKeys: trackingKeys, + TrackingCurrentVersion: nextVersion, + AdminKey: cfg.AdminKey, + }) + if deployErr != nil { + return deployErr + } + cfg.DatabaseID = dbID + } + + if err := tracking.SaveTrackingKeys(account, trackingKeys, nextVersion, cfg.AdminKey); err != nil { + return fmt.Errorf("save tracking keys: %w", err) + } + + cfg.TrackingKeyVersions = versions + cfg.TrackingCurrentKeyVersion = nextVersion + cfg.TrackingKey = "" + cfg.SecretsInKeyring = true + + if err := tracking.SaveConfig(account, cfg); err != nil { + return fmt.Errorf("save tracking config: %w", err) + } + + if c.NoDeploy { + u.Out().Printf("tracking_key_rotated\t%t", true) + u.Out().Printf("current_version\t%d", nextVersion) + u.Out().Printf("tracking_keys\t%v", versions) + u.Err().Println("No deploy selected; rotate was stored locally only.") + return nil + } + + u.Out().Printf("tracking_key_rotated\t%t", true) + u.Out().Printf("current_version\t%d", nextVersion) + u.Out().Printf("tracking_keys\t%v", versions) + + return nil +} diff --git a/internal/cmd/gmail_track_opens.go b/internal/cmd/gmail_track_opens.go index 6922fc1b..23102391 100644 --- a/internal/cmd/gmail_track_opens.go +++ b/internal/cmd/gmail_track_opens.go @@ -51,6 +51,7 @@ func (c *GmailTrackOpensCmd) queryByTrackingID(ctx context.Context, cfg *trackin if err != nil { return fmt.Errorf("build request: %w", err) } + req.Header.Set("Authorization", "Bearer "+cfg.AdminKey) resp, err := http.DefaultClient.Do(req) if err != nil { diff --git a/internal/cmd/gmail_track_setup.go b/internal/cmd/gmail_track_setup.go index 7ea69a9b..f68b4395 100644 --- a/internal/cmd/gmail_track_setup.go +++ b/internal/cmd/gmail_track_setup.go @@ -134,6 +134,8 @@ func (c *GmailTrackSetupCmd) Run(ctx context.Context, flags *RootFlags) error { cfg.SecretsInKeyring = true cfg.TrackingKey = "" cfg.AdminKey = "" + cfg.TrackingCurrentKeyVersion = 1 + cfg.TrackingKeyVersions = []int{1} if c.Deploy { dbID, deployErr := tracking.DeployWorker(ctx, u.Err(), tracking.DeployOptions{ diff --git a/internal/cmd/gmail_track_status.go b/internal/cmd/gmail_track_status.go index d415538d..d4c0687c 100644 --- a/internal/cmd/gmail_track_status.go +++ b/internal/cmd/gmail_track_status.go @@ -40,6 +40,12 @@ func (c *GmailTrackStatusCmd) Run(ctx context.Context, flags *RootFlags) error { u.Out().Printf("database_id\t%s", cfg.DatabaseID) } u.Out().Printf("admin_configured\t%t", strings.TrimSpace(cfg.AdminKey) != "") + if cfg.TrackingCurrentKeyVersion > 0 { + u.Out().Printf("tracking_current_key_version\t%d", cfg.TrackingCurrentKeyVersion) + } + if len(cfg.TrackingKeyVersions) > 0 { + u.Out().Printf("tracking_key_versions\t%v", cfg.TrackingKeyVersions) + } return nil } diff --git a/internal/tracking/config.go b/internal/tracking/config.go index 13f90daf..022913dd 100644 --- a/internal/tracking/config.go +++ b/internal/tracking/config.go @@ -25,6 +25,8 @@ type Config struct { DatabaseID string `json:"database_id,omitempty"` SecretsInKeyring bool `json:"secrets_in_keyring,omitempty"` TrackingKey string `json:"tracking_key,omitempty"` + TrackingKeyVersions []int `json:"tracking_key_versions,omitempty"` + TrackingCurrentKeyVersion int `json:"tracking_current_key_version,omitempty"` AdminKey string `json:"admin_key,omitempty"` } @@ -173,28 +175,96 @@ func SaveConfig(account string, cfg *Config) error { // IsConfigured returns true if tracking is set up. func (c *Config) IsConfigured() bool { - return c.Enabled && c.WorkerURL != "" && c.TrackingKey != "" + hasKey := strings.TrimSpace(c.TrackingKey) != "" || + (len(c.TrackingKeyVersions) > 0 && c.TrackingCurrentKeyVersion > 0) + return c.Enabled && c.WorkerURL != "" && c.AdminKey != "" && hasKey } func hydrateConfig(account string, cfg *Config) (*Config, error) { + cfg.TrackingCurrentKeyVersion = normalizedTrackingCurrentKeyVersion(cfg.TrackingCurrentKeyVersion) + + if len(cfg.TrackingKeyVersions) == 0 { + cfg.TrackingKeyVersions = append(cfg.TrackingKeyVersions, cfg.TrackingCurrentKeyVersion) + cfg.TrackingKeyVersions = dedupeSortedInts(cfg.TrackingKeyVersions) + } + if strings.TrimSpace(cfg.TrackingKey) == "" || strings.TrimSpace(cfg.AdminKey) == "" || cfg.SecretsInKeyring { - trackingKey, adminKey, secretErr := LoadSecrets(account) + trackingKeys, trackingCurrentVersion, secretErr := LoadTrackingKeys(account, cfg.TrackingKeyVersions, cfg.TrackingCurrentKeyVersion) + if secretErr != nil { + return nil, secretErr + } + secretTrackingKey, secretAdminKey, secretErr := LoadSecrets(account) if secretErr != nil { return nil, secretErr } if strings.TrimSpace(cfg.TrackingKey) == "" { - cfg.TrackingKey = trackingKey + cfg.TrackingKey = trackingKeys[trackingCurrentVersion] + if cfg.TrackingKey == "" { + cfg.TrackingKey = secretTrackingKey + } } if strings.TrimSpace(cfg.AdminKey) == "" { - cfg.AdminKey = adminKey + cfg.AdminKey = secretAdminKey } + + cfg.TrackingCurrentKeyVersion = normalizedTrackingCurrentKeyVersion(trackingCurrentVersion) + + if len(cfg.TrackingKeyVersions) == 0 && trackingCurrentVersion > 0 { + cfg.TrackingKeyVersions = []int{trackingCurrentVersion} + } + + cfg.TrackingKeyVersions = dedupeSortedInts(cfg.TrackingKeyVersions) + } + + cfg.TrackingCurrentKeyVersion = normalizedTrackingCurrentKeyVersion(cfg.TrackingCurrentKeyVersion) + + if len(cfg.TrackingKeyVersions) == 0 { + cfg.TrackingKeyVersions = []int{cfg.TrackingCurrentKeyVersion} } return cfg, nil } +func normalizedTrackingCurrentKeyVersion(version int) int { + if version > 0 { + return version + } + + return 1 +} + +func dedupeSortedInts(values []int) []int { + if len(values) == 0 { + return nil + } + + seen := map[int]struct{}{} + for _, value := range values { + seen[value] = struct{}{} + } + + delete(seen, 0) + delete(seen, -1) + + deduped := make([]int, 0, len(seen)) + for value := range seen { + deduped = append(deduped, value) + } + + // simple in-place sort + for i := 1; i < len(deduped); i++ { + j := i + for j > 0 && deduped[j-1] > deduped[j] { + deduped[j-1], deduped[j] = deduped[j], deduped[j-1] + j-- + } + } + + return deduped +} + func normalizeAccount(account string) string { return strings.ToLower(strings.TrimSpace(account)) } diff --git a/internal/tracking/crypto.go b/internal/tracking/crypto.go index 4e51635d..4f57a61d 100644 --- a/internal/tracking/crypto.go +++ b/internal/tracking/crypto.go @@ -8,6 +8,9 @@ import ( "encoding/json" "errors" "fmt" + "sort" + "strconv" + "strings" ) var errCiphertextTooShort = errors.New("ciphertext too short") @@ -20,8 +23,25 @@ type PixelPayload struct { SentAt int64 `json:"t"` } -// Encrypt encrypts a PixelPayload into a URL-safe base64 blob using AES-GCM +const ( + defaultTrackingKeyVersion = "1" +) + +// Encrypt encrypts a PixelPayload into a URL-safe base64 blob using AES-GCM. func Encrypt(payload *PixelPayload, keyBase64 string) (string, error) { + return EncryptWithVersion(payload, keyBase64, defaultTrackingKeyVersion) +} + +// EncryptWithVersion encrypts with an explicit 1-byte key version prefix. +func EncryptWithVersion(payload *PixelPayload, keyBase64 string, keyVersion string) (string, error) { + version, err := strconv.Atoi(strings.TrimSpace(keyVersion)) + if err != nil { + return "", fmt.Errorf("invalid key version: %w", err) + } + if version < 1 || version > 255 { + return "", fmt.Errorf("invalid key version: %d", version) + } + key, err := base64.StdEncoding.DecodeString(keyBase64) if err != nil { return "", fmt.Errorf("decode key: %w", err) @@ -47,24 +67,121 @@ func Encrypt(payload *PixelPayload, keyBase64 string) (string, error) { return "", fmt.Errorf("nonce: %w", err) } - ciphertext := aead.Seal(nonce, nonce, plaintext, nil) + payloadWithVersion := make([]byte, 0, 1+len(nonce)+aead.Overhead()+len(plaintext)) + payloadWithVersion = append(payloadWithVersion, byte(version)) + payloadWithVersion = append(payloadWithVersion, nonce...) + ciphertext := aead.Seal(payloadWithVersion, nonce, plaintext, nil) // URL-safe base64 encode return base64.RawURLEncoding.EncodeToString(ciphertext), nil } -// Decrypt decrypts a URL-safe base64 blob using AES-GCM +// Decrypt decrypts a URL-safe base64 blob using AES-GCM. func Decrypt(blob string, keyBase64 string) (*PixelPayload, error) { - key, err := base64.StdEncoding.DecodeString(keyBase64) - if err != nil { - return nil, fmt.Errorf("decode key: %w", err) - } + return DecryptWithVersions(blob, map[string]string{ + defaultTrackingKeyVersion: keyBase64, + }) +} +// DecryptWithVersions decrypts a blob by trying the available key versions. +func DecryptWithVersions(blob string, keysByVersion map[string]string) (*PixelPayload, error) { ciphertext, err := base64.RawURLEncoding.DecodeString(blob) if err != nil { return nil, fmt.Errorf("decode blob: %w", err) } + if len(ciphertext) == 0 { + return nil, errCiphertextTooShort + } + + orders, err := decryptionVersionOrder(ciphertext, keysByVersion) + if err != nil { + return nil, err + } + + for _, order := range orders { + key, ok := keysByVersion[order] + if !ok || key == "" { + continue + } + + plaintext, decryptErr := decryptWithOffset(ciphertext, key, 1) + if decryptErr == nil { + var payload PixelPayload + unmarshalErr := json.Unmarshal(plaintext, &payload) + if unmarshalErr == nil { + return &payload, nil + } + if err == nil { + err = unmarshalErr + } + continue + } + if err == nil { + err = decryptErr + } + } + + for _, order := range keyVersions(keysByVersion) { + key, ok := keysByVersion[order] + if !ok || key == "" { + continue + } + + plaintext, decryptErr := decryptWithOffset(ciphertext, key, 0) + if decryptErr == nil { + var payload PixelPayload + unmarshalErr := json.Unmarshal(plaintext, &payload) + if unmarshalErr == nil { + return &payload, nil + } + if err == nil { + err = unmarshalErr + } + continue + } + if err == nil { + err = decryptErr + } + } + + if err == nil { + return nil, errCiphertextTooShort + } + + return nil, fmt.Errorf("decrypt: %w", err) +} + +func decryptionVersionOrder(ciphertext []byte, keysByVersion map[string]string) ([]string, error) { + versions := keyVersions(keysByVersion) + if len(versions) == 0 { + return nil, errors.New("no tracking keys configured") + } + + prefix := int(ciphertext[0]) + prefixVersion := strconv.Itoa(prefix) + for i, version := range versions { + if version == prefixVersion { + result := make([]string, 0, len(versions)) + result = append(result, prefixVersion) + for j, v := range versions { + if j != i { + result = append(result, v) + } + } + return result, nil + } + } + + return versions, nil +} + +func decryptWithOffset(blob []byte, keyBase64 string, nonceOffset int) ([]byte, error) { + key, err := base64.StdEncoding.DecodeString(keyBase64) + if err != nil { + return nil, fmt.Errorf("decode key: %w", err) + } + block, err := aes.NewCipher(key) if err != nil { return nil, fmt.Errorf("new cipher: %w", err) @@ -75,24 +192,40 @@ func Decrypt(blob string, keyBase64 string) (*PixelPayload, error) { return nil, fmt.Errorf("new gcm: %w", err) } - if len(ciphertext) < aead.NonceSize() { + if len(blob) <= nonceOffset { return nil, errCiphertextTooShort } - nonce := ciphertext[:aead.NonceSize()] - ciphertext = ciphertext[aead.NonceSize():] + if len(blob) < nonceOffset+aead.NonceSize() { + return nil, errCiphertextTooShort + } - plaintext, err := aead.Open(nil, nonce, ciphertext, nil) + nonce := blob[nonceOffset : nonceOffset+aead.NonceSize()] + cipherPayload := blob[nonceOffset+aead.NonceSize():] + + plaintext, err := aead.Open(nil, nonce, cipherPayload, nil) if err != nil { return nil, fmt.Errorf("decrypt: %w", err) } - var payload PixelPayload - if err := json.Unmarshal(plaintext, &payload); err != nil { - return nil, fmt.Errorf("unmarshal payload: %w", err) + return plaintext, nil +} + +func keyVersions(keysByVersion map[string]string) []string { + versions := make([]string, 0, len(keysByVersion)) + for version := range keysByVersion { + if _, err := strconv.Atoi(version); err == nil { + versions = append(versions, version) + } } - return &payload, nil + sort.Slice(versions, func(i, j int) bool { + iv, _ := strconv.Atoi(versions[i]) + jv, _ := strconv.Atoi(versions[j]) + return iv < jv + }) + + return versions } // GenerateKey generates a new 256-bit AES key as base64 @@ -104,3 +237,4 @@ func GenerateKey() (string, error) { return base64.StdEncoding.EncodeToString(key), nil } + diff --git a/internal/tracking/crypto_test.go b/internal/tracking/crypto_test.go index 434f9d3e..de6bb42e 100644 --- a/internal/tracking/crypto_test.go +++ b/internal/tracking/crypto_test.go @@ -1,6 +1,7 @@ package tracking import ( + "encoding/base64" "testing" "time" ) @@ -61,6 +62,81 @@ func TestEncryptProducesURLSafeOutput(t *testing.T) { } } +func TestEncryptWithVersionAndDecryptWithVersions(t *testing.T) { + key1, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + key2, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + payload := &PixelPayload{ + Recipient: "test@example.com", + SubjectHash: "abc123", + SentAt: time.Now().Unix(), + } + + encrypted, err := EncryptWithVersion(payload, key1, "1") + if err != nil { + t.Fatalf("EncryptWithVersion failed: %v", err) + } + + decrypted, err := DecryptWithVersions(encrypted, map[string]string{ + "2": key2, + "1": key1, + }) + if err != nil { + t.Fatalf("DecryptWithVersions failed: %v", err) + } + + if decrypted.Recipient != payload.Recipient { + t.Errorf("Recipient mismatch: got %q, want %q", decrypted.Recipient, payload.Recipient) + } +} + +func TestDecryptLegacyBlobWithVersions(t *testing.T) { + key, err := GenerateKey() + if err != nil { + t.Fatalf("GenerateKey failed: %v", err) + } + + payload := &PixelPayload{ + Recipient: "test@example.com", + SubjectHash: "abc123", + SentAt: time.Now().Unix(), + } + + versioned, err := EncryptWithVersion(payload, key, "1") + if err != nil { + t.Fatalf("EncryptWithVersion failed: %v", err) + } + + raw, err := base64.RawURLEncoding.DecodeString(versioned) + if err != nil { + t.Fatalf("decode versioned blob: %v", err) + } + + if len(raw) == 0 { + t.Fatalf("unexpected empty blob") + } + + legacy := base64.RawURLEncoding.EncodeToString(raw[1:]) + + decrypted, err := DecryptWithVersions(legacy, map[string]string{ + "1": key, + }) + if err != nil { + t.Fatalf("DecryptWithVersions legacy: %v", err) + } + + if decrypted.SubjectHash != payload.SubjectHash { + t.Errorf("SubjectHash mismatch: got %q, want %q", decrypted.SubjectHash, payload.SubjectHash) + } +} + func TestDecryptWithWrongKeyFails(t *testing.T) { key1, _ := GenerateKey() key2, _ := GenerateKey() diff --git a/internal/tracking/deploy.go b/internal/tracking/deploy.go index 92ead3bd..f5fef5f6 100644 --- a/internal/tracking/deploy.go +++ b/internal/tracking/deploy.go @@ -21,7 +21,9 @@ type DeployOptions struct { WorkerName string DatabaseName string TrackingKey string + TrackingKeys map[int]string AdminKey string + TrackingCurrentVersion int } var ( @@ -84,10 +86,28 @@ func DeployWorker(ctx context.Context, logger DeployLogger, opts DeployOptions) return "", runErr } - if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(opts.TrackingKey+"\n"), "secret", "put", "TRACKING_KEY", "--name", opts.WorkerName); runErr != nil { + trackingKeys, currentVersion, trackingErr := resolveTrackingDeploymentSecrets(opts) + if trackingErr != nil { + return "", trackingErr + } + + for _, version := range trackingKeys { + secretName := fmt.Sprintf("TRACKING_KEY_V%d", version.version) + if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(version.key+"\n"), "secret", "put", secretName, "--name", opts.WorkerName); runErr != nil { + return "", runErr + } + } + + if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(fmt.Sprintf("%d\n", currentVersion)), "secret", "put", "TRACKING_KEY_CURRENT_VERSION", "--name", opts.WorkerName); runErr != nil { return "", runErr } + if legacyKey, ok := trackingKeyByVersion(trackingKeys, currentVersion); ok { + if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(legacyKey+"\n"), "secret", "put", "TRACKING_KEY", "--name", opts.WorkerName); runErr != nil { + return "", runErr + } + } + if runErr := runWranglerCommand(ctx, workerDir, strings.NewReader(opts.AdminKey+"\n"), "secret", "put", "ADMIN_KEY", "--name", opts.WorkerName); runErr != nil { return "", runErr } @@ -109,6 +129,85 @@ func DeployWorker(ctx context.Context, logger DeployLogger, opts DeployOptions) return dbID, nil } +type trackingDeploymentSecret struct { + version int + key string +} + +func resolveTrackingDeploymentSecrets(opts DeployOptions) ([]trackingDeploymentSecret, int, error) { + if len(opts.TrackingKeys) > 0 { + versions := sortedTrackingVersions(opts.TrackingKeys) + if len(versions) == 0 { + return nil, 0, fmt.Errorf("tracking key map is empty") + } + + currentVersion := normalizedTrackingCurrentVersion(opts.TrackingCurrentVersion, versions) + secrets := make([]trackingDeploymentSecret, 0, len(versions)) + for _, version := range versions { + key := strings.TrimSpace(opts.TrackingKeys[version]) + if key == "" { + continue + } + secrets = append(secrets, trackingDeploymentSecret{version: version, key: key}) + } + if len(secrets) == 0 { + return nil, 0, fmt.Errorf("tracking key map is empty") + } + + return secrets, currentVersion, nil + } + + if strings.TrimSpace(opts.TrackingKey) == "" { + return nil, 0, fmt.Errorf("missing tracking key") + } + + return []trackingDeploymentSecret{ + { + version: 1, + key: opts.TrackingKey, + }, + }, 1, nil +} + +func normalizedTrackingCurrentVersion(current int, versions []int) int { + if current > 0 { + for _, version := range versions { + if version == current { + return current + } + } + return versions[len(versions)-1] + } + + return versions[len(versions)-1] +} + +func sortedTrackingVersions(keys map[int]string) []int { + versions := make([]int, 0, len(keys)) + for version := range keys { + versions = append(versions, version) + } + // simple ascending order + for i := 1; i < len(versions); i++ { + j := i + for j > 0 && versions[j-1] > versions[j] { + versions[j-1], versions[j] = versions[j], versions[j-1] + j-- + } + } + + return versions +} + +func trackingKeyByVersion(secrets []trackingDeploymentSecret, current int) (string, bool) { + for _, s := range secrets { + if s.version == current { + return s.key, true + } + } + return "", false +} + func ensureD1Database(ctx context.Context, workerDir, dbName string) (string, error) { out, err := runWranglerCommandOutput(ctx, workerDir, nil, "d1", "create", dbName) if err != nil { diff --git a/internal/tracking/pixel.go b/internal/tracking/pixel.go index afd3dd25..d27b5e6b 100644 --- a/internal/tracking/pixel.go +++ b/internal/tracking/pixel.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "errors" "fmt" + "strconv" "time" ) @@ -23,7 +24,12 @@ func GeneratePixelURL(cfg *Config, recipient, subject string) (string, string, e SentAt: time.Now().Unix(), } - blob, err := Encrypt(payload, cfg.TrackingKey) + trackingKeyVersion := cfg.TrackingCurrentKeyVersion + if trackingKeyVersion == 0 { + trackingKeyVersion = 1 + } + + blob, err := EncryptWithVersion(payload, cfg.TrackingKey, strconv.Itoa(trackingKeyVersion)) if err != nil { return "", "", fmt.Errorf("encrypt payload: %w", err) } diff --git a/internal/tracking/pixel_test.go b/internal/tracking/pixel_test.go index b261f5e4..89a3c9fd 100644 --- a/internal/tracking/pixel_test.go +++ b/internal/tracking/pixel_test.go @@ -11,6 +11,7 @@ func TestGeneratePixelURL(t *testing.T) { Enabled: true, WorkerURL: "https://test.workers.dev", TrackingKey: key, + AdminKey: "test-admin-key", } pixelURL, blob, err := GeneratePixelURL(cfg, "test@example.com", "Hello World") diff --git a/internal/tracking/secrets.go b/internal/tracking/secrets.go index 96abea5f..b2f17d31 100644 --- a/internal/tracking/secrets.go +++ b/internal/tracking/secrets.go @@ -3,7 +3,9 @@ package tracking import ( "errors" "fmt" + "sort" "strings" + "strconv" "github.com/99designs/keyring" @@ -18,25 +20,127 @@ var ( const ( legacyTrackingKeySecretKey = "tracking/tracking_key" legacyAdminKeySecretKey = "tracking/admin_key" - trackingKeySecretSuffix = "tracking_key" - adminKeySecretSuffix = "admin_key" + trackingKeySecretSuffix = "tracking_key" + adminKeySecretSuffix = "admin_key" + trackingKeysCurrentVersionSuffix = "tracking_key_current_version" + trackingKeyVersionSecretPrefix = "tracking_key_v" ) -func SaveSecrets(account, trackingKey, adminKey string) error { +func LoadSecrets(account string) (trackingKey, adminKey string, err error) { account = normalizeAccount(account) if account == "" { - return errMissingAccount + return "", "", errMissingAccount + } + + keys, currentVersion, err := LoadTrackingKeys(account, nil, 0) + if err != nil { + return "", "", fmt.Errorf("read tracking keys: %w", err) + } + if currentVersion > 0 { + if key := strings.TrimSpace(keys[currentVersion]); key != "" { + trackingKey = key + } } if trackingKey == "" { - return errMissingTrackingKey + trackingKey, err = readSecretWithFallback(scopedSecretKey(account, trackingKeySecretSuffix), legacyTrackingKeySecretKey) + if err != nil { + return "", "", fmt.Errorf("read tracking key: %w", err) + } + } + + adminKey, err = readSecretWithFallback(scopedSecretKey(account, adminKeySecretSuffix), legacyAdminKeySecretKey) + if err != nil { + return "", "", fmt.Errorf("read admin key: %w", err) + } + + return trackingKey, adminKey, nil +} + +func LoadTrackingKeys(account string, versions []int, currentVersion int) (map[int]string, int, error) { + account = normalizeAccount(account) + if account == "" { + return nil, 0, errMissingAccount + } + + if currentVersion == 0 { + currentVersion = currentTrackingKeyVersion(account) + } + + if currentVersion == 0 { + currentVersion = 1 + } + + keys := map[int]string{} + versionsToLoad := append([]int{}, versions...) + if len(versionsToLoad) == 0 { + versionsToLoad = []int{currentVersion} + } + + for _, version := range versionsToLoad { + if version <= 0 { + continue + } + + raw, keyErr := secrets.GetSecret(scopedSecretKey(account, keyVersionSecret(version))) + if keyErr != nil { + if !errors.Is(keyErr, keyring.ErrKeyNotFound) { + return nil, 0, fmt.Errorf("read tracking key v%d: %w", version, keyErr) + } + continue + } + key := strings.TrimSpace(string(raw)) + if key == "" { + continue + } + keys[version] = key + } + + if len(keys) == 0 { + legacyKey, legacyErr := readSecretWithFallback(scopedSecretKey(account, trackingKeySecretSuffix), legacyTrackingKeySecretKey) + if legacyErr != nil { + return nil, 0, fmt.Errorf("read tracking key: %w", legacyErr) + } + keys[1] = legacyKey + currentVersion = 1 } + return keys, currentVersion, nil +} + +func SaveTrackingKeys(account string, trackingKeys map[int]string, currentVersion int, adminKey string) error { + account = normalizeAccount(account) + if account == "" { + return errMissingAccount + } + + currentKey := strings.TrimSpace(trackingKeys[currentVersion]) + if currentVersion <= 0 || currentVersion > 255 { + return fmt.Errorf("invalid tracking key version: %d", currentVersion) + } + if currentKey == "" { + return errMissingTrackingKey + } if adminKey == "" { return errMissingAdminKey } - if err := secrets.SetSecret(scopedSecretKey(account, trackingKeySecretSuffix), []byte(trackingKey)); err != nil { + for _, version := range sortedKeyVersions(trackingKeys) { + key := strings.TrimSpace(trackingKeys[version]) + if key == "" { + continue + } + if err := secrets.SetSecret(scopedSecretKey(account, keyVersionSecret(version)), []byte(key)); err != nil { + return fmt.Errorf("store tracking key v%d: %w", version, err) + } + } + + if err := secrets.SetSecret(scopedSecretKey(account, trackingKeysCurrentVersionSuffix), []byte(strconv.Itoa(currentVersion))); err != nil { + return fmt.Errorf("store tracking current key version: %w", err) + } + + // Keep legacy key name available for compatibility with older worker deployments. + if err := secrets.SetSecret(scopedSecretKey(account, trackingKeySecretSuffix), []byte(currentKey)); err != nil { return fmt.Errorf("store tracking key: %w", err) } @@ -47,23 +151,44 @@ func SaveSecrets(account, trackingKey, adminKey string) error { return nil } -func LoadSecrets(account string) (trackingKey, adminKey string, err error) { +func SaveSecrets(account, trackingKey, adminKey string) error { account = normalizeAccount(account) if account == "" { - return "", "", errMissingAccount + return errMissingAccount } - trackingKey, err = readSecretWithFallback(scopedSecretKey(account, trackingKeySecretSuffix), legacyTrackingKeySecretKey) + return SaveTrackingKeys(account, map[int]string{1: trackingKey}, 1, adminKey) +} + +func currentTrackingKeyVersion(account string) int { + raw, err := secrets.GetSecret(scopedSecretKey(account, trackingKeysCurrentVersionSuffix)) if err != nil { - return "", "", fmt.Errorf("read tracking key: %w", err) + return 0 } - adminKey, err = readSecretWithFallback(scopedSecretKey(account, adminKeySecretSuffix), legacyAdminKeySecretKey) - if err != nil { - return "", "", fmt.Errorf("read admin key: %w", err) + parsed, parseErr := strconv.Atoi(strings.TrimSpace(string(raw))) + if parseErr != nil { + return 0 } - return trackingKey, adminKey, nil + if parsed <= 0 || parsed > 255 { + return 0 + } + + return parsed +} + +func keyVersionSecret(version int) string { + return trackingKeyVersionSecretPrefix + strconv.Itoa(version) +} + +func sortedKeyVersions(keys map[int]string) []int { + versions := make([]int, 0, len(keys)) + for version := range keys { + versions = append(versions, version) + } + sort.Ints(versions) + return versions } func readSecretWithFallback(primary, legacy string) (string, error) { diff --git a/internal/tracking/secrets_test.go b/internal/tracking/secrets_test.go index f083a0d5..18777ab2 100644 --- a/internal/tracking/secrets_test.go +++ b/internal/tracking/secrets_test.go @@ -59,3 +59,31 @@ func TestScopedSecretKey(t *testing.T) { t.Fatalf("unexpected scoped key: %q", got) } } + +func TestSaveAndLoadTrackingKeys(t *testing.T) { + setupTrackingKeyringEnv(t) + + account := "a@b.com" + if err := SaveTrackingKeys(account, map[int]string{ + 1: "old-key", + 2: "new-key", + }, 2, "admin-key"); err != nil { + t.Fatalf("SaveTrackingKeys: %v", err) + } + + keys, currentVersion, err := LoadTrackingKeys(account, []int{1, 2}, 2) + if err != nil { + t.Fatalf("LoadTrackingKeys: %v", err) + } + + if currentVersion != 2 { + t.Fatalf("unexpected current version: %d", currentVersion) + } + + if got := keys[1]; got != "old-key" { + t.Fatalf("missing key v1: %q", got) + } + if got := keys[2]; got != "new-key" { + t.Fatalf("missing key v2: %q", got) + } +} diff --git a/internal/tracking/worker/src/bot.test.ts b/internal/tracking/worker/src/bot.test.ts index ffe697df..3828be66 100644 --- a/internal/tracking/worker/src/bot.test.ts +++ b/internal/tracking/worker/src/bot.test.ts @@ -1,34 +1,65 @@ import { describe, it, expect } from 'vitest'; import { detectBot } from './bot'; +function normalHeaders() { + return new Headers({ + Accept: 'image/gif', + Referer: 'https://mail.google.com', + }); +} + describe('detectBot', () => { it('treats GoogleImageProxy as real human', () => { - const result = detectBot('GoogleImageProxy', '66.249.88.1', null); + const result = detectBot('GoogleImageProxy', '66.249.88.1', null, normalHeaders()); expect(result.isBot).toBe(false); expect(result.botType).toBe('gmail_proxy'); }); it('detects Apple Mail Privacy Protection', () => { - const result = detectBot('Mozilla/5.0', '17.253.144.10', null); + const result = detectBot('Mozilla/5.0', '17.253.144.10', null, normalHeaders()); expect(result.isBot).toBe(true); expect(result.botType).toBe('apple_mpp'); }); it('detects Outlook prefetch', () => { - const result = detectBot('Microsoft Outlook 16.0', '1.2.3.4', null); + const result = detectBot('Microsoft Outlook 16.0', '1.2.3.4', null, normalHeaders()); expect(result.isBot).toBe(true); expect(result.botType).toBe('outlook_prefetch'); }); it('detects rapid opens as prefetch', () => { - const result = detectBot('Mozilla/5.0', '1.2.3.4', 500); + const result = detectBot('Mozilla/5.0', '1.2.3.4', 500, normalHeaders()); expect(result.isBot).toBe(true); expect(result.botType).toBe('prefetch'); }); it('treats normal opens as human', () => { - const result = detectBot('Mozilla/5.0 Chrome', '1.2.3.4', 5000); + const result = detectBot('Mozilla/5.0 Chrome', '1.2.3.4', 5000, normalHeaders()); expect(result.isBot).toBe(false); expect(result.botType).toBeNull(); }); + + it('treats bot-like requests with missing request headers as bots', () => { + const result = detectBot('Mozilla/5.0', '1.2.3.4', 5000, new Headers()); + expect(result.isBot).toBe(true); + expect(result.botType).toBe('missing_headers'); + }); + + it('treats Cloudflare managed bots as bots', () => { + const result = detectBot('Mozilla/5.0', '1.2.3.4', 5000, normalHeaders(), { + verifiedBot: true, + score: 99, + }); + expect(result.isBot).toBe(true); + expect(result.botType).toBe('bot_managed'); + }); + + it('treats low bot score as bot', () => { + const result = detectBot('Mozilla/5.0', '1.2.3.4', 5000, normalHeaders(), { + verifiedBot: false, + score: 10, + }); + expect(result.isBot).toBe(true); + expect(result.botType).toBe('low_bot_score'); + }); }); diff --git a/internal/tracking/worker/src/bot.ts b/internal/tracking/worker/src/bot.ts index be37a7fc..75ada222 100644 --- a/internal/tracking/worker/src/bot.ts +++ b/internal/tracking/worker/src/bot.ts @@ -3,6 +3,11 @@ export interface BotDetectionResult { botType: string | null; } +type BotManagementInfo = { + verifiedBot?: boolean; + score?: number; +}; + // Apple Private Relay IP ranges (simplified - real impl would use full list) const APPLE_RELAY_PREFIXES = [ '17.', // Apple corporate @@ -12,18 +17,41 @@ const APPLE_RELAY_PREFIXES = [ export function detectBot( userAgent: string, ip: string, - timeSinceDeliveryMs: number | null + timeSinceDeliveryMs: number | null, + headers: Headers = new Headers(), + botManagement?: BotManagementInfo ): BotDetectionResult { + const hasAcceptHeader = headers.get('Accept') !== null; + const hasRefererHeader = headers.get('Referer') !== null || headers.get('referer') !== null; + + // Known bot signals from Cloudflare + if (botManagement?.verifiedBot) { + return { isBot: true, botType: 'bot_managed' }; + } + + if (botManagement?.score !== undefined && botManagement.score < 20) { + return { isBot: true, botType: 'low_bot_score' }; + } + // Gmail Image Proxy = real human (Gmail proxies on their behalf) if (userAgent.includes('GoogleImageProxy')) { return { isBot: false, botType: 'gmail_proxy' }; } + if (!hasAcceptHeader && !hasRefererHeader) { + return { isBot: true, botType: 'missing_headers' }; + } + // Apple Mail Privacy Protection if (APPLE_RELAY_PREFIXES.some(prefix => ip.startsWith(prefix))) { return { isBot: true, botType: 'apple_mpp' }; } + // Missing or suspicious User-Agent + if (!userAgent || userAgent === 'unknown' || userAgent.trim().length < 8) { + return { isBot: true, botType: 'invalid_user_agent' }; + } + // Outlook prefetch if (userAgent.includes('Outlook-iOS') || userAgent.includes('Microsoft Outlook') || diff --git a/internal/tracking/worker/src/crypto.test.ts b/internal/tracking/worker/src/crypto.test.ts index 4780faed..c5c8a7ca 100644 --- a/internal/tracking/worker/src/crypto.test.ts +++ b/internal/tracking/worker/src/crypto.test.ts @@ -28,4 +28,30 @@ describe('crypto', () => { await expect(decrypt('invalid', key)).rejects.toThrow(); }); + + it('decrypts versioned payload with the matching version key', async () => { + const key = await importKey(testKey); + const payload = { r: 'test@example.com', s: 'abc123', t: 1704067200 }; + + const legacyBlob = await encrypt(payload, key); + const raw = Uint8Array.from(atob(legacyBlob), c => c.charCodeAt(0)); + const versionedRaw = Uint8Array.from([2, ...raw]); + const versionedBlob = btoa(String.fromCharCode(...versionedRaw)) + .replace(/\+/g, '-') + .replace(/\//g, '_') + .replace(/=+$/, ''); + + const decrypted = await decrypt(versionedBlob, { 2: testKey }); + expect(decrypted).toEqual(payload); + }); + + it('falls back to legacy format without a version prefix', async () => { + const key = await importKey(testKey); + const payload = { r: 'test@example.com', s: 'abc123', t: 1704067200 }; + + const blob = await encrypt(payload, key); + + const decrypted = await decrypt(blob, { 1: testKey }); + expect(decrypted).toEqual(payload); + }); }); diff --git a/internal/tracking/worker/src/crypto.ts b/internal/tracking/worker/src/crypto.ts index 6d02e861..f741a7e1 100644 --- a/internal/tracking/worker/src/crypto.ts +++ b/internal/tracking/worker/src/crypto.ts @@ -2,6 +2,9 @@ import type { PixelPayload } from './types'; const ALGORITHM = 'AES-GCM'; const IV_LENGTH = 12; +type TrackingKeys = Record; + +type TrackingKeysInput = TrackingKeys | string; export async function importKey(base64Key: string): Promise { const keyBytes = Uint8Array.from(atob(base64Key), c => c.charCodeAt(0)); @@ -14,23 +17,44 @@ export async function importKey(base64Key: string): Promise { ); } -export async function decrypt(blob: string, key: CryptoKey): Promise { - // URL-safe base64 decode - const base64 = blob.replace(/-/g, '+').replace(/_/g, '/'); - const padded = base64 + '='.repeat((4 - base64.length % 4) % 4); - const combined = Uint8Array.from(atob(padded), c => c.charCodeAt(0)); +export async function decrypt(blob: string, keysByVersion: TrackingKeysInput): Promise { + const keyMap: TrackingKeys = normalizeKeys(keysByVersion); + const combined = decodeBlob(blob); + const versions = sortedVersions(keyMap); + if (versions.length === 0) { + throw new Error('missing tracking keys'); + } - const iv = combined.slice(0, IV_LENGTH); - const ciphertext = combined.slice(IV_LENGTH); + const candidateVersion = combined.length > 0 ? combined[0] : -1; + const orderedVersions = orderedDecryptionVersions(candidateVersion, versions); + const firstAttemptOffset = candidateVersion >= 1 && candidateVersion <= 255 ? 1 : 0; + const fallbackOffset = firstAttemptOffset === 1 ? 0 : 1; - const decrypted = await crypto.subtle.decrypt( - { name: ALGORITHM, iv }, - key, - ciphertext - ); + for (const version of orderedVersions) { + const key = keyMap[version]; + const payload = await tryDecrypt(combined, key, firstAttemptOffset); + if (!payload) { + continue; + } + const parsed = parsePayload(payload); + if (parsed) { + return parsed; + } + } - const text = new TextDecoder().decode(decrypted); - return JSON.parse(text) as PixelPayload; + for (const version of orderedVersions) { + const key = keyMap[version]; + const payload = await tryDecrypt(combined, key, fallbackOffset); + if (!payload) { + continue; + } + const parsed = parsePayload(payload); + if (parsed) { + return parsed; + } + } + + throw new Error('decrypt failed'); } export async function encrypt(payload: PixelPayload, key: CryptoKey): Promise { @@ -51,3 +75,104 @@ export async function encrypt(payload: PixelPayload, key: CryptoKey): Promise 255) { + continue; + } + if (typeof key !== 'string') { + continue; + } + const trimmed = key.trim(); + if (trimmed === '') { + continue; + } + result[numericVersion] = trimmed; + } + + return result; +} + +function decodeBlob(blob: string): Uint8Array { + const base64 = blob.replace(/-/g, '+').replace(/_/g, '/'); + const padded = base64 + '='.repeat((4 - base64.length % 4) % 4); + return Uint8Array.from(atob(padded), c => c.charCodeAt(0)); +} + +function sortedVersions(keysByVersion: Record): number[] { + const versions = Object.keys(keysByVersion) + .map(v => parseInt(v, 10)) + .filter(v => Number.isFinite(v) && v > 0); + + versions.sort((a, b) => a - b); + return versions; +} + +async function tryDecrypt( + combined: Uint8Array, + base64Key: string | undefined, + nonceOffset: number +): Promise { + if (!base64Key) { + return null; + } + try { + const key = Uint8Array.from(atob(base64Key), c => c.charCodeAt(0)); + const importedKey = await crypto.subtle.importKey( + 'raw', + key, + { name: ALGORITHM }, + false, + ['decrypt'] + ); + + if (combined.length < nonceOffset + IV_LENGTH) { + return null; + } + + const iv = combined.slice(nonceOffset, nonceOffset + IV_LENGTH); + const ciphertext = combined.slice(nonceOffset + IV_LENGTH); + + return await crypto.subtle.decrypt( + { name: ALGORITHM, iv }, + importedKey, + ciphertext + ); + } catch { + return null; + } +} + +function orderedDecryptionVersions(candidateVersion: number, versions: number[]): number[] { + if (candidateVersion < 1 || candidateVersion > 255) { + return versions; + } + + const ordered = [...versions]; + const candidateIndex = ordered.indexOf(candidateVersion); + if (candidateIndex < 0) { + return ordered; + } + + return [ordered[candidateIndex], ...ordered.slice(0, candidateIndex), ...ordered.slice(candidateIndex + 1)]; +} + +function parsePayload(payload: ArrayBuffer): PixelPayload | null { + try { + const text = new TextDecoder().decode(payload); + return JSON.parse(text) as PixelPayload; + } catch { + return null; + } +} diff --git a/internal/tracking/worker/src/index.ts b/internal/tracking/worker/src/index.ts index faad5142..a765b270 100644 --- a/internal/tracking/worker/src/index.ts +++ b/internal/tracking/worker/src/index.ts @@ -1,8 +1,44 @@ import type { Env, PixelPayload } from './types'; -import { importKey, decrypt } from './crypto'; +import { decrypt } from './crypto'; import { detectBot } from './bot'; import { pixelResponse } from './pixel'; +const RATE_LIMIT_WINDOW_SECONDS = 60 * 60; +const RATE_LIMIT_MAX_REQUESTS = 100; +const DEDUPE_WINDOW_SQL = '-1 hour'; + +function trackingKeysFromEnv(env: Env): Record { + const keys: Record = {}; + + for (const [name, value] of Object.entries(env)) { + if (typeof value !== 'string') { + continue; + } + + if (!name.startsWith('TRACKING_KEY_V')) { + continue; + } + + const versionText = name.substring('TRACKING_KEY_V'.length); + const version = Number.parseInt(versionText, 10); + if (!Number.isFinite(version) || version < 1 || version > 255) { + continue; + } + if (value.trim() === '') { + continue; + } + + keys[version] = value.trim(); + } + + const legacyKey = typeof env.TRACKING_KEY === 'string' ? env.TRACKING_KEY.trim() : ''; + if (Object.keys(keys).length === 0 && legacyKey !== '') { + keys[1] = legacyKey; + } + + return keys; +} + export default { async fetch(request: Request, env: Env): Promise { const url = new URL(request.url); @@ -41,11 +77,11 @@ async function handlePixel(request: Request, env: Env, path: string): Promise { + const key = `rate:${ip}`; + + try { + const raw = await rateStore.get(key); + const current = raw !== null ? parseInt(raw, 10) : 0; + const next = Number.isFinite(current) && current >= 0 ? current + 1 : 1; + + await rateStore.put(key, String(next), { expirationTtl: RATE_LIMIT_WINDOW_SECONDS }); + return next > RATE_LIMIT_MAX_REQUESTS; + } catch (error) { + console.error('Rate limit check failed:', error); + return false; + } +} + +async function hasRecentOpen(db: D1Database, trackingId: string, ip: string): Promise { + try { + const existing = await db.prepare(` + SELECT 1 FROM opens + WHERE tracking_id = ? AND ip = ? AND opened_at > datetime('now', ?) + LIMIT 1 + `).bind(trackingId, ip, DEDUPE_WINDOW_SQL).first(); + + return existing !== null; + } catch (error) { + console.error('Failed to check duplicate open:', error); + return false; + } +} + async function handleQuery(request: Request, env: Env, path: string): Promise { + // Require admin authentication to prevent leaking IP/location data + const authHeader = request.headers.get('Authorization'); + if (!authHeader || authHeader !== `Bearer ${env.ADMIN_KEY}`) { + return new Response('Unauthorized', { status: 401 }); + } + const blob = path.slice(3); // Remove '/q/' - const key = await importKey(env.TRACKING_KEY); + const trackingKeys = trackingKeysFromEnv(env); let payload: PixelPayload; try { - payload = await decrypt(blob, key); + payload = await decrypt(blob, trackingKeys); } catch { return new Response('Invalid tracking ID', { status: 400 }); } diff --git a/internal/tracking/worker/src/types.ts b/internal/tracking/worker/src/types.ts index f31bdb2d..ebbec931 100644 --- a/internal/tracking/worker/src/types.ts +++ b/internal/tracking/worker/src/types.ts @@ -1,7 +1,13 @@ export interface Env { DB: D1Database; - TRACKING_KEY: string; + TRACKING_KEY?: string; + TRACKING_KEY_CURRENT_VERSION?: string; + TRACKING_KEY_V1?: string; + TRACKING_KEY_V2?: string; + TRACKING_KEY_V3?: string; ADMIN_KEY: string; + RATE_KV: KVNamespace; + [key: string]: string | D1Database | KVNamespace | undefined; } export interface PixelPayload { diff --git a/internal/tracking/worker/wrangler.toml b/internal/tracking/worker/wrangler.toml index 11910c70..5d1bf302 100644 --- a/internal/tracking/worker/wrangler.toml +++ b/internal/tracking/worker/wrangler.toml @@ -6,3 +6,8 @@ compatibility_date = "2024-12-01" binding = "DB" database_name = "gog-email-tracker" database_id = "placeholder-will-be-replaced" + +[[kv_namespaces]] +binding = "RATE_KV" +id = "placeholder-will-be-replaced" +preview_id = "placeholder-preview-will-be-replaced"