diff --git a/internal/auth/backupcodes/generator_boundary_test.go b/internal/auth/backupcodes/generator_boundary_test.go new file mode 100644 index 0000000..84d31cd --- /dev/null +++ b/internal/auth/backupcodes/generator_boundary_test.go @@ -0,0 +1,636 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2024 UnderNET + +package backupcodes + +import ( + "context" + "encoding/json" + "strings" + "sync" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/undernetirc/cservice-api/db/mocks" + "github.com/undernetirc/cservice-api/internal/auth/password" + "github.com/undernetirc/cservice-api/models" +) + +// TestValidateBackupCodeFormat_Boundary tests adversarial and boundary inputs that +// the builder's tests did not cover. +func TestValidateBackupCodeFormat_Boundary(t *testing.T) { + adversarialCases := []struct { + name string + code string + wantErr bool + }{ + // Unicode look-alikes + {"unicode hyphen (em-dash)", "abcde\u2014" + "12345", true}, + {"unicode hyphen (en-dash)", "abcde\u2013" + "12345", true}, + {"unicode full-width hyphen", "abcde\uff0d12345", true}, + {"unicode digits look-alikes", "abcde-\uff11\uff12\uff13\uff14\uff15", true}, + {"null bytes in code", "abc\x00e-12345", true}, + {"null byte as hyphen", "abcde\x0012345", true}, + {"very long string (1000 chars)", strings.Repeat("a", 1000), true}, + {"only a hyphen", "-", true}, + {"two hyphens, correct length otherwise", "abcd--2345", true}, + {"correct length but wrong format (all hyphens)", "-----12345", true}, + {"newline in code", "abcde\n12345", true}, + {"tab in code", "abcde\t12345", true}, + {"carriage return", "abcde\r12345", true}, + {"all hyphens", "----------", true}, + {"hyphen at start", "-abcde1234", true}, + {"hyphen at end", "abcde1234-", true}, + {"correct format with trailing newline", "abcde-12345\n", true}, + {"correct format with leading space", " abcde-12345", true}, + {"correct format with trailing space", "abcde-12345 ", true}, + // Non-ASCII characters that might slip through + {"chinese characters", "测试a-12345", true}, + {"emoji in code", "abc😀-12345", true}, + // Valid boundary cases + {"all zeros (valid)", "00000-00000", false}, + {"all nines (valid)", "99999-99999", false}, + {"mixed case max alphanumeric", "ZZZZZ-zzzzz", false}, + } + + for _, tc := range adversarialCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateBackupCodeFormat(tc.code) + if tc.wantErr { + assert.Error(t, err, "expected error for input: %q", tc.code) + } else { + assert.NoError(t, err, "expected no error for input: %q", tc.code) + } + }) + } +} + +// TestNormalizeBackupCode_Boundary tests edge cases in the normalization function. +func TestNormalizeBackupCode_Boundary(t *testing.T) { + cases := []struct { + name string + input string + expected string + }{ + // Empty and minimal inputs + {"empty string", "", ""}, + {"single space", " ", ""}, + {"all spaces", " ", ""}, + {"spaces only 10 chars", " ", ""}, + // After space removal, exactly 10 chars → hyphen inserted + {"10-char string after removing 5 spaces", "a b c d e 1 2 3 4 5", "abcde-12345"}, + // Unicode: unicode space (non-breaking) is NOT removed (only ASCII space) + {"non-breaking space (not removed)", "abcde\u00a012345", "abcde\u00a012345"}, + // Hyphen already present + {"already has hyphen at correct position", "abcde-12345", "abcde-12345"}, + // Multiple hyphens: no hyphen added because strings.Contains returns true + {"multiple hyphens stay as-is", "ab-c-de12345", "ab-c-de12345"}, + // Only 9 chars after space removal → no hyphen added + {"9 chars (no hyphen added)", "abcde1234", "abcde1234"}, + // 11 chars after space removal → no hyphen added + {"11 chars (no hyphen added)", "abcde123456", "abcde123456"}, + // Null bytes (treated as characters, not spaces) + {"null bytes stay in output", "abcde\x0012345", "abcde\x0012345"}, + // SQL injection-like input + {"SQL injection stays as-is (no spaces)", "';DROP-TABLE", "';DROP-TABLE"}, + {"SQL injection with spaces gets compacted", "'; DROP TABLE ", "';DROPTABLE"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + result := NormalizeBackupCode(tc.input) + assert.Equal(t, tc.expected, result) + }) + } +} + +// TestConsumeBackupCode_Boundary tests adversarial inputs to ConsumeBackupCode. +func TestConsumeBackupCode_Boundary(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + // Use low-cost bcrypt for test speed + hasher := password.NewBcryptHasher(&password.BcryptConfig{Cost: 4}) + validCode := "abcde-12345" + hash, err := hasher.GenerateHash(validCode) + require.NoError(t, err) + + storedCodes := []BackupCode{{Hash: hash}} + + t.Run("empty string input does not consume any code", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + consumed, err := gen.ConsumeBackupCode(ctx, userID, "", updatedBy) + + require.NoError(t, err) + assert.False(t, consumed, "empty string should not match any backup code") + mockDB.AssertExpectations(t) + }) + + t.Run("whitespace-only input does not consume any code", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + consumed, err := gen.ConsumeBackupCode(ctx, userID, " ", updatedBy) + + require.NoError(t, err) + assert.False(t, consumed, "whitespace-only input should not match any backup code") + mockDB.AssertExpectations(t) + }) + + t.Run("very long input (1000 chars) does not match valid code", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + longInput := strings.Repeat("a", 1000) + consumed, err := gen.ConsumeBackupCode(ctx, userID, longInput, updatedBy) + + require.NoError(t, err) + assert.False(t, consumed, "1000-char input should not match any 11-char backup code") + mockDB.AssertExpectations(t) + }) + + t.Run("unicode input does not match valid code", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + consumed, err := gen.ConsumeBackupCode(ctx, userID, "测试-12345", updatedBy) + + require.NoError(t, err) + assert.False(t, consumed, "unicode input should not match any backup code") + mockDB.AssertExpectations(t) + }) + + t.Run("SQL injection-like input does not match valid code", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + assert.NotPanics(t, func() { + consumed, err := gen.ConsumeBackupCode(ctx, userID, "'; DROP TABLE users; --", updatedBy) + require.NoError(t, err) + assert.False(t, consumed, "SQL injection input should not match any backup code") + }) + mockDB.AssertExpectations(t) + }) + + t.Run("null bytes in input do not match valid code", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + consumed, err := gen.ConsumeBackupCode(ctx, userID, "abcde\x00-12345", updatedBy) + + require.NoError(t, err) + assert.False(t, consumed, "input with null bytes should not match any backup code") + mockDB.AssertExpectations(t) + }) + + t.Run("input that normalizes to correct code is accepted", func(t *testing.T) { + // "a b c d e 1 2 3 4 5" → after normalization: "abcde12345" (10 chars) → "abcde-12345" + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + // ConsumeBackupCode calls GetBackupCodes (GetUserBackupCodes) and UpdateBackupCodes (GetUserBackupCodes + UpdateUserBackupCodes) + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Twice() + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.Anything).Return(nil).Once() + + // Input with spaces that NormalizeBackupCode converts to "abcde-12345" + normalizedInput := "a b c d e 1 2 3 4 5" + consumed, err := gen.ConsumeBackupCode(ctx, userID, normalizedInput, updatedBy) + + require.NoError(t, err) + assert.True(t, consumed, "normalized input should match the stored backup code") + mockDB.AssertExpectations(t) + }) +} + +// TestGetBackupCodes_CorruptMetadata tests behavior when the metadata stored in the +// database contains malformed JSON. +func TestGetBackupCodes_CorruptMetadata(t *testing.T) { + ctx := context.Background() + userID := int32(42) + + t.Run("malformed outer metadata JSON returns error", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{ + BackupCodes: []byte(`{this is not valid json`), + }, nil).Once() + + codes, err := gen.GetBackupCodes(ctx, userID) + + require.Error(t, err, "malformed outer metadata should return an error") + assert.Nil(t, codes) + assert.Contains(t, err.Error(), "failed to unmarshal backup codes metadata") + mockDB.AssertExpectations(t) + }) + + t.Run("valid outer metadata but malformed inner backup_codes JSON returns error", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + // Valid outer metadata, but backup_codes field contains invalid JSON + invalidInnerMetadata := Metadata{ + BackupCodes: `this is not valid json`, + GeneratedAt: "2025-06-22T10:30:00Z", + CodesRemaining: 5, + } + invalidJSON, err := json.Marshal(invalidInnerMetadata) + require.NoError(t, err) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{ + BackupCodes: invalidJSON, + }, nil).Once() + + codes, err := gen.GetBackupCodes(ctx, userID) + + require.Error(t, err, "malformed inner backup_codes JSON should return an error") + assert.Nil(t, codes) + assert.Contains(t, err.Error(), "failed to unmarshal backup codes") + mockDB.AssertExpectations(t) + }) + + t.Run("empty backup_codes field in metadata returns empty slice", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + // Valid outer metadata, backup_codes is an empty JSON array + emptyCodesMetadata := Metadata{ + BackupCodes: "[]", + GeneratedAt: "2025-06-22T10:30:00Z", + CodesRemaining: 0, + } + metadataJSON, err := json.Marshal(emptyCodesMetadata) + require.NoError(t, err) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{ + BackupCodes: metadataJSON, + }, nil).Once() + + codes, err := gen.GetBackupCodes(ctx, userID) + + assert.NoError(t, err) + assert.NotNil(t, codes, "empty array should return a non-nil empty slice") + assert.Len(t, codes, 0) + mockDB.AssertExpectations(t) + }) + + t.Run("backup_codes field is empty string returns error (invalid JSON)", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + emptyStringMetadata := Metadata{ + BackupCodes: "", + GeneratedAt: "2025-06-22T10:30:00Z", + CodesRemaining: 0, + } + metadataJSON, err := json.Marshal(emptyStringMetadata) + require.NoError(t, err) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{ + BackupCodes: metadataJSON, + }, nil).Once() + + codes, err := gen.GetBackupCodes(ctx, userID) + + // Empty string is not valid JSON → unmarshal error + assert.Error(t, err, "empty string in backup_codes should cause a JSON unmarshal error") + assert.Nil(t, codes) + mockDB.AssertExpectations(t) + }) +} + +// TestGetBackupCodesCount_Boundary tests the GetBackupCodesCount function which +// was not covered by the builder's tests. +func TestGetBackupCodesCount_Boundary(t *testing.T) { + ctx := context.Background() + userID := int32(42) + + t.Run("returns count from metadata", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + codes := []BackupCode{{Hash: "h1"}, {Hash: "h2"}, {Hash: "h3"}} + metadataJSON := buildMetadataJSON(t, codes, 3) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + count, err := gen.GetBackupCodesCount(ctx, userID) + + require.NoError(t, err) + assert.Equal(t, 3, count) + mockDB.AssertExpectations(t) + }) + + t.Run("returns 0 when no backup codes stored", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: nil}, nil).Once() + + count, err := gen.GetBackupCodesCount(ctx, userID) + + require.NoError(t, err) + assert.Equal(t, 0, count) + mockDB.AssertExpectations(t) + }) + + t.Run("returns error on DB failure", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{}, assert.AnError).Once() + + count, err := gen.GetBackupCodesCount(ctx, userID) + + require.Error(t, err) + assert.Equal(t, 0, count) + assert.Contains(t, err.Error(), "failed to get backup codes metadata") + mockDB.AssertExpectations(t) + }) + + t.Run("returns error on malformed metadata JSON", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: []byte(`{invalid json`)}, nil).Once() + + count, err := gen.GetBackupCodesCount(ctx, userID) + + require.Error(t, err) + assert.Equal(t, 0, count) + assert.Contains(t, err.Error(), "failed to unmarshal backup codes metadata") + mockDB.AssertExpectations(t) + }) + + t.Run("count of 0 remaining codes is returned correctly", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + // CodesRemaining is explicitly 0 (all codes used) + metadataJSON := buildMetadataJSON(t, []BackupCode{}, 0) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + count, err := gen.GetBackupCodesCount(ctx, userID) + + require.NoError(t, err) + assert.Equal(t, 0, count) + mockDB.AssertExpectations(t) + }) +} + +// TestGetBackupCodesGeneratedAt_Boundary tests the GetBackupCodesGeneratedAt function +// which was not covered by the builder's tests. +func TestGetBackupCodesGeneratedAt_Boundary(t *testing.T) { + ctx := context.Background() + userID := int32(42) + + t.Run("returns generated_at timestamp from metadata", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + metadataJSON := buildMetadataJSON(t, []BackupCode{{Hash: "h"}}, 1) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + + generatedAt, err := gen.GetBackupCodesGeneratedAt(ctx, userID) + + require.NoError(t, err) + assert.Equal(t, "2025-06-22T10:30:00Z", generatedAt) + mockDB.AssertExpectations(t) + }) + + t.Run("returns empty string when no backup codes stored", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: nil}, nil).Once() + + generatedAt, err := gen.GetBackupCodesGeneratedAt(ctx, userID) + + require.NoError(t, err) + assert.Empty(t, generatedAt) + mockDB.AssertExpectations(t) + }) + + t.Run("returns error on DB failure", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{}, assert.AnError).Once() + + generatedAt, err := gen.GetBackupCodesGeneratedAt(ctx, userID) + + require.Error(t, err) + assert.Empty(t, generatedAt) + assert.Contains(t, err.Error(), "failed to get backup codes metadata") + mockDB.AssertExpectations(t) + }) + + t.Run("returns error on malformed metadata JSON", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: []byte(`{invalid`)}, nil).Once() + + generatedAt, err := gen.GetBackupCodesGeneratedAt(ctx, userID) + + require.Error(t, err) + assert.Empty(t, generatedAt) + mockDB.AssertExpectations(t) + }) +} + +// TestGetBackupCodesReadStatus_Boundary tests the GetBackupCodesReadStatus function +// which was not covered by the builder's tests. +func TestGetBackupCodesReadStatus_Boundary(t *testing.T) { + ctx := context.Background() + userID := int32(42) + + t.Run("returns true when backup codes have been read", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{ + BackupCodesRead: pgtype.Bool{Bool: true, Valid: true}, + }, nil).Once() + + read, err := gen.GetBackupCodesReadStatus(ctx, userID) + + require.NoError(t, err) + assert.True(t, read) + mockDB.AssertExpectations(t) + }) + + t.Run("returns false when backup codes have not been read", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{ + BackupCodesRead: pgtype.Bool{Bool: false, Valid: true}, + }, nil).Once() + + read, err := gen.GetBackupCodesReadStatus(ctx, userID) + + require.NoError(t, err) + assert.False(t, read) + mockDB.AssertExpectations(t) + }) + + t.Run("returns error on DB failure", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{}, assert.AnError).Once() + + read, err := gen.GetBackupCodesReadStatus(ctx, userID) + + require.Error(t, err) + assert.False(t, read) + assert.Contains(t, err.Error(), "failed to get backup codes read status") + mockDB.AssertExpectations(t) + }) +} + +// TestConsumeBackupCode_ConcurrentSafety tests the goroutine safety of ConsumeBackupCode. +// Note: the logical TOCTOU (time-of-check-time-of-use) race — where two concurrent +// requests both read the same backup codes, find a match, and both "consume" it — +// is a semantic race that requires DB-level transactions to prevent. +// This test verifies there are no Go-level data races (safe for `go test -race`). +func TestConsumeBackupCode_ConcurrentSafety(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + hasher := password.NewBcryptHasher(&password.BcryptConfig{Cost: 4}) + code1, err := hasher.GenerateHash("abcde-11111") + require.NoError(t, err) + code2, err := hasher.GenerateHash("abcde-22222") + require.NoError(t, err) + + const numGoroutines = 5 + + t.Run("concurrent consumption of different codes is goroutine-safe", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + storedCodes := []BackupCode{{Hash: code1}, {Hash: code2}} + metadataJSON := buildMetadataJSON(t, storedCodes, 2) + + // Each goroutine reads then updates — mock returns full set each time + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Maybe() + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.Anything). + Return(nil).Maybe() + + var wg sync.WaitGroup + for range numGoroutines { + wg.Add(1) + go func() { + defer wg.Done() + assert.NotPanics(t, func() { + _, _ = gen.ConsumeBackupCode(ctx, userID, "abcde-11111", updatedBy) + }) + }() + } + wg.Wait() + }) +} + +// TestConsumeBackupCode_TOCTOURaceDocumentation documents the TOCTOU semantic race +// in ConsumeBackupCode. Two concurrent callers both reading the same codes list and +// both returning consumed=true for the same code reveals a design-level race condition. +// +// This is NOT a Go data race (no concurrent memory access) — it is a logical race +// that must be fixed at the database level with transactions or advisory locks. +func TestConsumeBackupCode_TOCTOURaceDocumentation(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + hasher := password.NewBcryptHasher(&password.BcryptConfig{Cost: 4}) + codeHash, err := hasher.GenerateHash("abcde-12345") + require.NoError(t, err) + + t.Run("two sequential calls to consume same code: first succeeds, second fails", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + storedCodes := []BackupCode{{Hash: codeHash}} + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + emptyMetadataJSON := buildMetadataJSON(t, []BackupCode{}, 0) + + // First consumption: reads full list, writes empty list + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: metadataJSON}, nil).Once() + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.MatchedBy(func(p models.UpdateUserBackupCodesParams) bool { + var m Metadata + _ = json.Unmarshal(p.BackupCodes, &m) + return m.CodesRemaining == 0 + })).Return(nil).Once() + + // Second consumption: reads empty list + mockDB.On("GetUserBackupCodes", mock.Anything, userID). + Return(models.GetUserBackupCodesRow{BackupCodes: emptyMetadataJSON}, nil).Once() + + firstConsumed, firstErr := gen.ConsumeBackupCode(ctx, userID, "abcde-12345", updatedBy) + require.NoError(t, firstErr) + assert.True(t, firstConsumed, "first consumption should succeed") + + secondConsumed, secondErr := gen.ConsumeBackupCode(ctx, userID, "abcde-12345", updatedBy) + require.NoError(t, secondErr) + assert.False(t, secondConsumed, "second consumption should fail — code already used") + + mockDB.AssertExpectations(t) + }) +} diff --git a/internal/auth/backupcodes/generator_test.go b/internal/auth/backupcodes/generator_test.go index 9209036..49abe89 100644 --- a/internal/auth/backupcodes/generator_test.go +++ b/internal/auth/backupcodes/generator_test.go @@ -4,12 +4,17 @@ package backupcodes import ( + "context" "encoding/json" "strings" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/undernetirc/cservice-api/db/mocks" + "github.com/undernetirc/cservice-api/internal/auth/password" + "github.com/undernetirc/cservice-api/models" ) func TestGenerateBackupCodes(t *testing.T) { @@ -280,3 +285,362 @@ func TestBackupCodesMetadata(t *testing.T) { assert.Equal(t, 5, metadata.CodesRemaining) }) } + +// buildMetadataJSON is a test helper that builds the nested metadata JSON +// structure matching what the database returns. +func buildMetadataJSON(t *testing.T, codes []BackupCode, codesRemaining int) []byte { + t.Helper() + codesJSON, err := json.Marshal(codes) + require.NoError(t, err) + + metadata := Metadata{ + BackupCodes: string(codesJSON), + GeneratedAt: "2025-06-22T10:30:00Z", + CodesRemaining: codesRemaining, + } + metadataJSON, err := json.Marshal(metadata) + require.NoError(t, err) + return metadataJSON +} + +func TestNewBackupCodeGenerator(t *testing.T) { + t.Run("creates generator with db", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + require.NotNil(t, gen) + assert.Equal(t, mockDB, gen.db) + }) + + t.Run("creates generator with nil db", func(t *testing.T) { + gen := NewBackupCodeGenerator(nil) + + require.NotNil(t, gen) + assert.Nil(t, gen.db) + }) +} + +func TestGenerateAndStoreBackupCodes(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + t.Run("successful generation and storage", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.MatchedBy(func(arg models.UpdateUserBackupCodesParams) bool { + return arg.ID == userID && len(arg.BackupCodes) > 0 && arg.LastUpdatedBy.String == updatedBy + })).Return(nil).Once() + + codes, err := gen.GenerateAndStoreBackupCodes(ctx, userID, updatedBy) + + require.NoError(t, err) + assert.Len(t, codes, BackupCodeCount) + for _, code := range codes { + assert.NoError(t, ValidateBackupCodeFormat(code)) + } + mockDB.AssertExpectations(t) + }) +} + +func TestGenerateAndStoreBackupCodes_DBError(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.Anything).Return(assert.AnError).Once() + + codes, err := gen.GenerateAndStoreBackupCodes(ctx, userID, updatedBy) + + require.Error(t, err) + assert.Nil(t, codes) + assert.Contains(t, err.Error(), "failed to store backup codes") + mockDB.AssertExpectations(t) +} + +func TestGetBackupCodes(t *testing.T) { + ctx := context.Background() + userID := int32(42) + + t.Run("successful retrieval", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + storedCodes := []BackupCode{ + {Hash: "$2a$04$abcdefghijklmnopqrstuuABCDEFGHIJKLMNOPQRSTU12345678"}, + {Hash: "$2a$04$zyxwvutsrqponmlkjihgfeZYXWVUTSRQPONMLKJIHGFE87654321"}, + } + metadataJSON := buildMetadataJSON(t, storedCodes, 2) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: metadataJSON, + }, nil).Once() + + codes, err := gen.GetBackupCodes(ctx, userID) + + require.NoError(t, err) + require.Len(t, codes, 2) + assert.Equal(t, storedCodes[0].Hash, codes[0].Hash) + assert.Equal(t, storedCodes[1].Hash, codes[1].Hash) + mockDB.AssertExpectations(t) + }) + + t.Run("empty backup codes", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: nil, + }, nil).Once() + + codes, err := gen.GetBackupCodes(ctx, userID) + + assert.NoError(t, err) + assert.Nil(t, codes) + mockDB.AssertExpectations(t) + }) +} + +func TestGetBackupCodes_NotFound(t *testing.T) { + ctx := context.Background() + userID := int32(999) + + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: []byte{}, + }, nil).Once() + + codes, err := gen.GetBackupCodes(ctx, userID) + + assert.NoError(t, err) + assert.Nil(t, codes) + mockDB.AssertExpectations(t) +} + +func TestGetBackupCodes_DBError(t *testing.T) { + ctx := context.Background() + userID := int32(42) + + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{}, assert.AnError).Once() + + codes, err := gen.GetBackupCodes(ctx, userID) + + require.Error(t, err) + assert.Nil(t, codes) + assert.Contains(t, err.Error(), "failed to retrieve backup codes") + mockDB.AssertExpectations(t) +} + +func TestUpdateBackupCodes(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + t.Run("successful update", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + existingCodes := []BackupCode{{Hash: "existinghash"}} + existingMetadata := buildMetadataJSON(t, existingCodes, 1) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: existingMetadata, + }, nil).Once() + + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.MatchedBy(func(arg models.UpdateUserBackupCodesParams) bool { + return arg.ID == userID && arg.LastUpdatedBy.String == updatedBy + })).Return(nil).Once() + + newCodes := []BackupCode{{Hash: "newhash1"}, {Hash: "newhash2"}} + err := gen.UpdateBackupCodes(ctx, userID, newCodes, updatedBy) + + require.NoError(t, err) + mockDB.AssertExpectations(t) + }) + + t.Run("preserves generated_at from existing metadata", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + existingCodes := []BackupCode{{Hash: "hash"}} + existingMetadata := buildMetadataJSON(t, existingCodes, 1) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: existingMetadata, + }, nil).Once() + + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.MatchedBy(func(arg models.UpdateUserBackupCodesParams) bool { + var metadata Metadata + if err := json.Unmarshal(arg.BackupCodes, &metadata); err != nil { + return false + } + return metadata.GeneratedAt == "2025-06-22T10:30:00Z" + })).Return(nil).Once() + + err := gen.UpdateBackupCodes(ctx, userID, []BackupCode{{Hash: "newhash"}}, updatedBy) + + require.NoError(t, err) + mockDB.AssertExpectations(t) + }) +} + +func TestUpdateBackupCodes_DBError(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + t.Run("error on get current metadata", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{}, assert.AnError).Once() + + err := gen.UpdateBackupCodes(ctx, userID, []BackupCode{{Hash: "h"}}, updatedBy) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to get current backup codes metadata") + mockDB.AssertExpectations(t) + }) + + t.Run("error on update", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + existingMetadata := buildMetadataJSON(t, []BackupCode{{Hash: "h"}}, 1) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: existingMetadata, + }, nil).Once() + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.Anything).Return(assert.AnError).Once() + + err := gen.UpdateBackupCodes(ctx, userID, []BackupCode{{Hash: "h"}}, updatedBy) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to update backup codes") + mockDB.AssertExpectations(t) + }) +} + +func TestConsumeBackupCode(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + plainCode := "abcde-12345" + + // Use low-cost bcrypt for test speed + hasher := password.NewBcryptHasher(&password.BcryptConfig{Cost: 4}) + hash, err := hasher.GenerateHash(plainCode) + require.NoError(t, err) + + t.Run("successful consumption", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + storedCodes := []BackupCode{ + {Hash: hash}, + {Hash: "$2a$04$otherhashvaluexxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}, + } + metadataJSON := buildMetadataJSON(t, storedCodes, 2) + + // First call: GetBackupCodes reads from DB + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: metadataJSON, + }, nil).Once() + + // Second call: UpdateBackupCodes reads current metadata + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: metadataJSON, + }, nil).Once() + + // UpdateBackupCodes writes the reduced set + mockDB.On("UpdateUserBackupCodes", mock.Anything, mock.MatchedBy(func(arg models.UpdateUserBackupCodesParams) bool { + var metadata Metadata + if err := json.Unmarshal(arg.BackupCodes, &metadata); err != nil { + return false + } + return metadata.CodesRemaining == 1 + })).Return(nil).Once() + + consumed, err := gen.ConsumeBackupCode(ctx, userID, plainCode, updatedBy) + + require.NoError(t, err) + assert.True(t, consumed) + mockDB.AssertExpectations(t) + }) +} + +func TestConsumeBackupCode_Invalid(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + // Use low-cost bcrypt hash for a different code + hasher := password.NewBcryptHasher(&password.BcryptConfig{Cost: 4}) + hash, err := hasher.GenerateHash("zzzzz-99999") + require.NoError(t, err) + + storedCodes := []BackupCode{{Hash: hash}} + metadataJSON := buildMetadataJSON(t, storedCodes, 1) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: metadataJSON, + }, nil).Once() + + consumed, err := gen.ConsumeBackupCode(ctx, userID, "wrong-codes", updatedBy) + + require.NoError(t, err) + assert.False(t, consumed) + mockDB.AssertExpectations(t) +} + +func TestConsumeBackupCode_AlreadyUsed(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + // No codes remaining — simulates all codes already consumed + metadataJSON := buildMetadataJSON(t, []BackupCode{}, 0) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{ + BackupCodes: metadataJSON, + }, nil).Once() + + consumed, err := gen.ConsumeBackupCode(ctx, userID, "abcde-12345", updatedBy) + + require.NoError(t, err) + assert.False(t, consumed) + mockDB.AssertExpectations(t) +} + +func TestConsumeBackupCode_DBError(t *testing.T) { + ctx := context.Background() + userID := int32(42) + updatedBy := "admin" + + mockDB := mocks.NewServiceInterface(t) + gen := NewBackupCodeGenerator(mockDB) + + mockDB.On("GetUserBackupCodes", mock.Anything, userID).Return(models.GetUserBackupCodesRow{}, assert.AnError).Once() + + consumed, err := gen.ConsumeBackupCode(ctx, userID, "abcde-12345", updatedBy) + + require.Error(t, err) + assert.False(t, consumed) + assert.Contains(t, err.Error(), "failed to get backup codes") + mockDB.AssertExpectations(t) +} diff --git a/internal/auth/reset/cleanup_test.go b/internal/auth/reset/cleanup_test.go new file mode 100644 index 0000000..718665f --- /dev/null +++ b/internal/auth/reset/cleanup_test.go @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2023 UnderNET + +package reset + +import ( + "context" + "errors" + "io" + "log/slog" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/undernetirc/cservice-api/db/mocks" + "github.com/undernetirc/cservice-api/models" +) + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestNewCleanupService(t *testing.T) { + t.Run("with_logger", func(t *testing.T) { + tm := NewTokenManager(nil, nil) + logger := discardLogger() + cs := NewCleanupService(tm, time.Hour, logger) + + require.NotNil(t, cs) + assert.Equal(t, tm, cs.tokenManager) + assert.Equal(t, time.Hour, cs.interval) + assert.Equal(t, logger, cs.logger) + assert.NotNil(t, cs.stopCh) + assert.NotNil(t, cs.doneCh) + }) + + t.Run("nil_logger_uses_default", func(t *testing.T) { + tm := NewTokenManager(nil, nil) + cs := NewCleanupService(tm, time.Hour, nil) + + require.NotNil(t, cs) + assert.NotNil(t, cs.logger) + }) +} + +func TestCleanupService_RunOnce(t *testing.T) { + t.Run("successful_cleanup", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + cs := NewCleanupService(tm, time.Hour, discardLogger()) + + statsBefore := models.GetPasswordResetTokenStatsRow{ + TotalTokens: 10, + UsedTokens: 2, + ExpiredTokens: 3, + ActiveTokens: 5, + } + statsAfter := models.GetPasswordResetTokenStatsRow{ + TotalTokens: 7, + UsedTokens: 2, + ExpiredTokens: 0, + ActiveTokens: 5, + } + + // GetTokenStats called twice: before and after cleanup + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(statsBefore, nil).Once() + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("DeleteExpiredPasswordResetTokens", mock.Anything, mock.Anything). + Return(nil).Once() + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(statsAfter, nil).Once() + + err := cs.RunOnce(context.Background()) + assert.NoError(t, err) + }) +} + +func TestCleanupService_RunOnce_DBError(t *testing.T) { + t.Run("stats_before_error", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + cs := NewCleanupService(tm, time.Hour, discardLogger()) + + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(models.GetPasswordResetTokenStatsRow{}, errors.New("db down")).Once() + + // RunOnce returns nil — errors are logged, not returned + err := cs.RunOnce(context.Background()) + assert.NoError(t, err) + }) + + t.Run("cleanup_error", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + cs := NewCleanupService(tm, time.Hour, discardLogger()) + + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(models.GetPasswordResetTokenStatsRow{TotalTokens: 5}, nil).Once() + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(errors.New("cleanup failed")).Once() + + err := cs.RunOnce(context.Background()) + assert.NoError(t, err) + }) +} + +func TestCleanupService_StartStop(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + cs := NewCleanupService(tm, 50*time.Millisecond, discardLogger()) + + // The initial performCleanup will call GetTokenStats — mock it to allow the goroutine to proceed + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(models.GetPasswordResetTokenStatsRow{}, nil).Maybe() + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + db.On("DeleteExpiredPasswordResetTokens", mock.Anything, mock.Anything). + Return(nil).Maybe() + + ctx := context.Background() + cs.Start(ctx) + + // Give it time to run the initial cleanup and at least one tick + time.Sleep(100 * time.Millisecond) + + // Stop should return (doneCh closes) + done := make(chan struct{}) + go func() { + cs.Stop() + close(done) + }() + + select { + case <-done: + // Success — Stop returned + case <-time.After(2 * time.Second): + t.Fatal("Stop did not return in time — goroutine may be stuck") + } +} + +func TestPerformCleanup(t *testing.T) { + t.Run("success", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + cs := NewCleanupService(tm, time.Hour, discardLogger()) + + statsBefore := models.GetPasswordResetTokenStatsRow{ + TotalTokens: 20, ExpiredTokens: 5, ActiveTokens: 10, UsedTokens: 5, + } + statsAfter := models.GetPasswordResetTokenStatsRow{ + TotalTokens: 15, ExpiredTokens: 0, ActiveTokens: 10, UsedTokens: 5, + } + + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(statsBefore, nil).Once() + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("DeleteExpiredPasswordResetTokens", mock.Anything, mock.Anything). + Return(nil).Once() + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(statsAfter, nil).Once() + + cs.performCleanup(context.Background()) + db.AssertExpectations(t) + }) + + t.Run("stats_before_error_returns_early", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + cs := NewCleanupService(tm, time.Hour, discardLogger()) + + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(models.GetPasswordResetTokenStatsRow{}, errors.New("stats failed")).Once() + + cs.performCleanup(context.Background()) + // CleanupExpiredPasswordResetTokens should NOT be called + db.AssertNotCalled(t, "CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything) + }) + + t.Run("cleanup_error_returns_early", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + cs := NewCleanupService(tm, time.Hour, discardLogger()) + + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(models.GetPasswordResetTokenStatsRow{TotalTokens: 10}, nil).Once() + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(errors.New("cleanup failed")).Once() + + cs.performCleanup(context.Background()) + // DeleteExpiredPasswordResetTokens should NOT be called + db.AssertNotCalled(t, "DeleteExpiredPasswordResetTokens", mock.Anything, mock.Anything) + }) + + t.Run("stats_after_error", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + cs := NewCleanupService(tm, time.Hour, discardLogger()) + + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(models.GetPasswordResetTokenStatsRow{TotalTokens: 10}, nil).Once() + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("DeleteExpiredPasswordResetTokens", mock.Anything, mock.Anything). + Return(nil).Once() + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(models.GetPasswordResetTokenStatsRow{}, errors.New("stats after failed")).Once() + + cs.performCleanup(context.Background()) + db.AssertExpectations(t) + }) +} diff --git a/internal/auth/reset/manager_boundary_test.go b/internal/auth/reset/manager_boundary_test.go new file mode 100644 index 0000000..e5ac694 --- /dev/null +++ b/internal/auth/reset/manager_boundary_test.go @@ -0,0 +1,528 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2023 UnderNET + +package reset + +import ( + "context" + "errors" + "math" + "strings" + "sync" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/undernetirc/cservice-api/db/mocks" + "github.com/undernetirc/cservice-api/models" +) + +// TestValidateToken_SecurityBoundary tests adversarial inputs to ValidateToken. +// The DB call uses parameterized queries, so SQL injection is not a runtime risk, +// but the token strings should be handled gracefully regardless of content. +func TestValidateToken_SecurityBoundary(t *testing.T) { + ctx := context.Background() + + adversarialTokens := []struct { + name string + token string + }{ + {"empty token", ""}, + {"whitespace-only token", " \t\n "}, + {"very long token (10000 chars)", strings.Repeat("x", 10000)}, + {"SQL injection attempt", "'; DROP TABLE password_reset_tokens; --"}, + {"unicode characters", "测试-reset-token-值"}, + {"null bytes", "token\x00withNull\x00bytes"}, + {"newline injection", "token\r\nX-Injected: evil"}, + {"path traversal", "../../../../etc/passwd"}, + {"only special chars", "!@#$%^&*()_+{}|:<>?"}, + } + + for _, tc := range adversarialTokens { + t.Run(tc.name, func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + // DB is expected to return an error (no matching token) + db.On("ValidatePasswordResetToken", mock.Anything, tc.token, mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows in result set")).Once() + + assert.NotPanics(t, func() { + result, err := tm.ValidateToken(ctx, tc.token) + assert.Nil(t, result) + assert.Error(t, err, "adversarial token should be rejected") + assert.Contains(t, err.Error(), "invalid or expired token") + }) + db.AssertExpectations(t) + }) + } +} + +// TestGetTokenTimeRemaining_Boundary tests edge cases in the time remaining calculation. +func TestGetTokenTimeRemaining_Boundary(t *testing.T) { + tm := NewTokenManager(nil, nil) + now := time.Now().Unix() + + t.Run("token expiring in exactly 1 second has positive remaining time", func(t *testing.T) { + token := &models.PasswordResetToken{ + ExpiresAt: int32(now + 1), + } + remaining := tm.GetTokenTimeRemaining(token) + assert.Equal(t, time.Second, remaining) + }) + + t.Run("token expiring exactly now returns 0", func(t *testing.T) { + token := &models.PasswordResetToken{ + ExpiresAt: int32(now), + } + remaining := tm.GetTokenTimeRemaining(token) + assert.Equal(t, time.Duration(0), remaining) + }) + + t.Run("token with ExpiresAt=0 returns 0 (treated as epoch, always expired)", func(t *testing.T) { + token := &models.PasswordResetToken{ + ExpiresAt: 0, + } + remaining := tm.GetTokenTimeRemaining(token) + // int64(0) - now is very negative → returns 0 + assert.Equal(t, time.Duration(0), remaining, + "ExpiresAt=0 (epoch) is far in the past, should return 0") + }) + + t.Run("token with int32 max ExpiresAt has very large remaining time", func(t *testing.T) { + // int32 max = 2147483647 (year 2038) + token := &models.PasswordResetToken{ + ExpiresAt: math.MaxInt32, + } + remaining := tm.GetTokenTimeRemaining(token) + // Should be many years in the future (positive) + assert.True(t, remaining > 0, "max int32 ExpiresAt should have large positive remaining time") + // Should be roughly 12+ years from now (2026 → 2038) + assert.True(t, remaining > 10*365*24*time.Hour, + "max int32 expiry should be at least 10 years from now (year 2038)") + }) + + t.Run("token expiring 1 hour from now has correct remaining", func(t *testing.T) { + token := &models.PasswordResetToken{ + ExpiresAt: int32(now + 3600), + } + remaining := tm.GetTokenTimeRemaining(token) + assert.True(t, remaining > 59*time.Minute) + assert.True(t, remaining <= time.Hour) + }) + + t.Run("token negative remaining (far past) returns 0", func(t *testing.T) { + token := &models.PasswordResetToken{ + ExpiresAt: int32(now - 86400), // 1 day ago + } + remaining := tm.GetTokenTimeRemaining(token) + assert.Equal(t, time.Duration(0), remaining) + }) +} + +// TestCreateToken_Boundary tests edge cases in CreateToken. +func TestCreateToken_Boundary(t *testing.T) { + ctx := context.Background() + + t.Run("userID=0 is accepted by CreateToken (boundary at zero)", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return([]models.PasswordResetToken{}, nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.Anything). + Return(models.PasswordResetToken{ + ID: 1, + Token: "generated-token", + UserID: pgtype.Int4{Int32: 0, Valid: true}, + }, nil).Once() + + result, err := tm.CreateToken(ctx, 0) + // The function does not validate userID — zero is passed through. + // This documents the boundary: userID=0 is accepted without validation. + require.NoError(t, err) + assert.NotNil(t, result) + }) + + t.Run("exactly MaxTokensPerUser-1 active tokens does NOT trigger invalidation", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + // MaxTokensPerUser defaults to 3; 2 active tokens should NOT trigger invalidation + existingTokens := []models.PasswordResetToken{ + {ID: 1}, {ID: 2}, + } + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return(existingTokens, nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.Anything). + Return(models.PasswordResetToken{ID: 3, Token: "new-token"}, nil).Once() + + result, err := tm.CreateToken(ctx, 100) + require.NoError(t, err) + assert.NotNil(t, result) + + // Verify InvalidateUserPasswordResetTokens was NOT called + db.AssertNotCalled(t, "InvalidateUserPasswordResetTokens", + mock.Anything, mock.Anything, mock.Anything) + db.AssertExpectations(t) + }) + + t.Run("exactly MaxTokensPerUser active tokens triggers invalidation", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + // MaxTokensPerUser defaults to 3; 3 active tokens should trigger invalidation + existingTokens := []models.PasswordResetToken{ + {ID: 1}, {ID: 2}, {ID: 3}, + } + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return(existingTokens, nil).Once() + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.Anything). + Return(models.PasswordResetToken{ID: 4, Token: "new-token"}, nil).Once() + + result, err := tm.CreateToken(ctx, 100) + require.NoError(t, err) + assert.NotNil(t, result) + db.AssertExpectations(t) + }) + + t.Run("MaxTokensPerUser+1 active tokens also triggers invalidation", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + // 4 active tokens (exceeds max of 3) should also trigger invalidation + existingTokens := []models.PasswordResetToken{ + {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, + } + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return(existingTokens, nil).Once() + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.Anything). + Return(models.PasswordResetToken{ID: 5, Token: "new-token"}, nil).Once() + + result, err := tm.CreateToken(ctx, 100) + require.NoError(t, err) + assert.NotNil(t, result) + db.AssertExpectations(t) + }) + + t.Run("MaxTokensPerUser=1 boundary: single active token triggers invalidation", func(t *testing.T) { + db := mocks.NewQuerier(t) + config := &Config{ + TokenLength: 32, + TokenLifetime: time.Hour, + CleanupInterval: 24 * time.Hour, + MaxTokensPerUser: 1, // Very restrictive: any existing token triggers invalidation + } + tm := NewTokenManager(db, config) + + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return([]models.PasswordResetToken{{ID: 1}}, nil).Once() + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.Anything). + Return(models.PasswordResetToken{ID: 2, Token: "new-token"}, nil).Once() + + result, err := tm.CreateToken(ctx, 100) + require.NoError(t, err) + assert.NotNil(t, result) + db.AssertExpectations(t) + }) +} + +// TestUseToken_SecurityBoundary tests the UseToken function with security-sensitive scenarios. +func TestUseToken_SecurityBoundary(t *testing.T) { + ctx := context.Background() + + t.Run("empty token string is rejected", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("ValidatePasswordResetToken", mock.Anything, "", mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows")).Once() + + err := tm.UseToken(ctx, "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid or expired token") + db.AssertExpectations(t) + }) + + t.Run("very long token string is handled without panic", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + longToken := strings.Repeat("x", 10000) + db.On("ValidatePasswordResetToken", mock.Anything, longToken, mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows")).Once() + + assert.NotPanics(t, func() { + err := tm.UseToken(ctx, longToken) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid or expired token") + }) + db.AssertExpectations(t) + }) + + t.Run("token reuse: UseToken after mark fails validates again and fails", func(t *testing.T) { + // Simulates a client that calls UseToken twice with the same token. + // Second call must fail because ValidatePasswordResetToken queries DB for active tokens. + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + // First UseToken call: validate succeeds, mark succeeds + db.On("ValidatePasswordResetToken", mock.Anything, "one-time-token", mock.Anything). + Return(models.PasswordResetToken{ID: 1, Token: "one-time-token"}, nil).Once() + db.On("MarkPasswordResetTokenAsUsed", mock.Anything, mock.Anything). + Return(nil).Once() + + // Second UseToken call: validate fails because token is now used + db.On("ValidatePasswordResetToken", mock.Anything, "one-time-token", mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows — token already used")).Once() + + // First use should succeed + err := tm.UseToken(ctx, "one-time-token") + assert.NoError(t, err, "first use should succeed") + + // Second use should fail (DB would reject it) + err = tm.UseToken(ctx, "one-time-token") + assert.Error(t, err, "second use of same token should fail") + assert.Contains(t, err.Error(), "invalid or expired token") + + db.AssertExpectations(t) + }) + + t.Run("SQL injection in token is handled safely (parameterized query)", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + injectionToken := "'; UPDATE users SET password='hacked'; --" + db.On("ValidatePasswordResetToken", mock.Anything, injectionToken, mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows")).Once() + + assert.NotPanics(t, func() { + err := tm.UseToken(ctx, injectionToken) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid or expired token") + }) + db.AssertExpectations(t) + }) +} + +// TestInvalidateUserTokens_Boundary tests edge cases in InvalidateUserTokens. +func TestInvalidateUserTokens_Boundary(t *testing.T) { + ctx := context.Background() + + t.Run("userID=0 is accepted (no validation of zero user ID)", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + + err := tm.InvalidateUserTokens(ctx, 0) + assert.NoError(t, err, "userID=0 should be accepted by InvalidateUserTokens") + db.AssertExpectations(t) + }) + + t.Run("userID=max int32 is accepted", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + + err := tm.InvalidateUserTokens(ctx, math.MaxInt32) + assert.NoError(t, err) + db.AssertExpectations(t) + }) +} + +// TestCleanupExpiredTokens_Boundary tests edge cases in CleanupExpiredTokens. +func TestCleanupExpiredTokens_Boundary(t *testing.T) { + ctx := context.Background() + + t.Run("cleanup with context already cancelled returns error from DB", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + cancelledCtx, cancel := context.WithCancel(ctx) + cancel() // Cancel immediately + + // The DB call will receive a cancelled context; mock returns an error + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(errors.New("context canceled")).Once() + + err := tm.CleanupExpiredTokens(cancelledCtx) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to mark expired tokens as deleted") + db.AssertExpectations(t) + }) +} + +// TestTokenManager_ConcurrentAccess tests that TokenManager methods are safe +// to call concurrently (no data races — run with `go test -race`). +func TestTokenManager_ConcurrentAccess(t *testing.T) { + ctx := context.Background() + + t.Run("concurrent ValidateToken calls are goroutine-safe", func(t *testing.T) { + const numGoroutines = 20 + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + // All concurrent calls return "not found" + db.On("ValidatePasswordResetToken", mock.Anything, mock.Anything, mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows")).Times(numGoroutines) + + var wg sync.WaitGroup + errs := make([]error, numGoroutines) + + for i := range numGoroutines { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, errs[idx] = tm.ValidateToken(ctx, "some-token") + }(i) + } + wg.Wait() + + for i, err := range errs { + assert.Error(t, err, "goroutine %d: ValidateToken should return error", i) + } + db.AssertExpectations(t) + }) + + t.Run("concurrent GetTokenTimeRemaining calls are goroutine-safe", func(t *testing.T) { + tm := NewTokenManager(nil, nil) + now := time.Now().Unix() + + token := &models.PasswordResetToken{ + ExpiresAt: int32(now + 3600), + } + + const numGoroutines = 50 + var wg sync.WaitGroup + results := make([]time.Duration, numGoroutines) + + for i := range numGoroutines { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx] = tm.GetTokenTimeRemaining(token) + }(i) + } + wg.Wait() + + for i, d := range results { + assert.True(t, d > 0, "goroutine %d: time remaining should be positive", i) + } + }) +} + +// TestCreateToken_GeneratesNonEmptyToken verifies that generated tokens are non-empty. +// The builder's test mocked the DB return value but did not verify the generated token length. +func TestCreateToken_GeneratesNonEmptyToken(t *testing.T) { + ctx := context.Background() + db := mocks.NewQuerier(t) + + config := &Config{ + TokenLength: 32, + TokenLifetime: time.Hour, + CleanupInterval: 24 * time.Hour, + MaxTokensPerUser: 3, + } + tm := NewTokenManager(db, config) + + var capturedToken string + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return([]models.PasswordResetToken{}, nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.MatchedBy(func(p models.CreatePasswordResetTokenParams) bool { + capturedToken = p.Token + return p.Token != "" + })).Return(models.PasswordResetToken{ID: 1, Token: "some-token"}, nil).Once() + + result, err := tm.CreateToken(ctx, 100) + require.NoError(t, err) + assert.NotNil(t, result) + + // Verify the token was non-empty and has expected length (alphanumeric, 32 chars) + assert.NotEmpty(t, capturedToken, "generated token must not be empty") + assert.Len(t, capturedToken, 32, "generated token should be exactly 32 characters long") + + // Token should be alphanumeric only + for _, ch := range capturedToken { + assert.True(t, + (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9'), + "token character %c should be alphanumeric", ch) + } + + db.AssertExpectations(t) +} + +// TestCreateToken_CustomTokenLength verifies that the TokenLength config is respected. +func TestCreateToken_CustomTokenLength(t *testing.T) { + ctx := context.Background() + + tokenLengths := []int{8, 16, 32, 64, 128} + + for _, length := range tokenLengths { + t.Run("token length "+string(rune('0'+length/10))+string(rune('0'+length%10)), func(t *testing.T) { + db := mocks.NewQuerier(t) + config := &Config{ + TokenLength: length, + TokenLifetime: time.Hour, + CleanupInterval: 24 * time.Hour, + MaxTokensPerUser: 3, + } + tm := NewTokenManager(db, config) + + var capturedLength int + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return([]models.PasswordResetToken{}, nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.MatchedBy(func(p models.CreatePasswordResetTokenParams) bool { + capturedLength = len(p.Token) + return true + })).Return(models.PasswordResetToken{ID: 1, Token: "tok"}, nil).Once() + + _, err := tm.CreateToken(ctx, 100) + require.NoError(t, err) + assert.Equal(t, length, capturedLength, + "token length should match configured TokenLength=%d", length) + db.AssertExpectations(t) + }) + } +} + +// TestValidateToken_ExpirationAtBoundary tests ValidateToken near the expiration boundary. +// ValidateToken passes the current time to the DB query for filtering. +// This test documents that the expiration logic is DB-side. +func TestValidateToken_ExpirationAtBoundary(t *testing.T) { + ctx := context.Background() + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + // DB returns an active token (expires in future) + now := time.Now().Unix() + activeToken := models.PasswordResetToken{ + ID: 1, + Token: "boundary-token", + UserID: pgtype.Int4{Int32: 100, Valid: true}, + ExpiresAt: int32(now + 1), // 1 second in future + } + + db.On("ValidatePasswordResetToken", mock.Anything, "boundary-token", mock.Anything). + Return(activeToken, nil).Once() + + result, err := tm.ValidateToken(ctx, "boundary-token") + require.NoError(t, err) + assert.NotNil(t, result) + // Verify the token's time remaining is positive + remaining := tm.GetTokenTimeRemaining(result) + assert.True(t, remaining > 0, "token expiring in 1 second should have positive remaining time") + db.AssertExpectations(t) +} diff --git a/internal/auth/reset/manager_test.go b/internal/auth/reset/manager_test.go index fac18bc..cbdda27 100644 --- a/internal/auth/reset/manager_test.go +++ b/internal/auth/reset/manager_test.go @@ -4,10 +4,16 @@ package reset import ( + "context" + "errors" "testing" "time" + "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/undernetirc/cservice-api/db/mocks" "github.com/undernetirc/cservice-api/models" ) @@ -120,3 +126,302 @@ func TestTokenManagerIntegration(t *testing.T) { remaining := tm.GetTokenTimeRemaining(token) assert.True(t, remaining > 0) } + +func TestCreateToken(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return([]models.PasswordResetToken{}, nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.Anything). + Return(models.PasswordResetToken{ + ID: 1, + Token: "generated-token", + UserID: pgtype.Int4{Int32: 100, Valid: true}, + }, nil).Once() + + result, err := tm.CreateToken(ctx, 100) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, int32(1), result.ID) + }) + + t.Run("success_invalidates_old_when_max_reached", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + existingTokens := []models.PasswordResetToken{ + {ID: 1}, {ID: 2}, {ID: 3}, + } + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return(existingTokens, nil).Once() + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.Anything). + Return(models.PasswordResetToken{ID: 4, Token: "new-token"}, nil).Once() + + result, err := tm.CreateToken(ctx, 100) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, int32(4), result.ID) + }) + + t.Run("db_error_on_get_active", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return(nil, errors.New("connection refused")).Once() + + result, err := tm.CreateToken(ctx, 100) + assert.Nil(t, result) + assert.ErrorContains(t, err, "failed to check active tokens") + }) +} + +func TestCreateToken_DBError(t *testing.T) { + ctx := context.Background() + + t.Run("db_error_on_create", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return([]models.PasswordResetToken{}, nil).Once() + db.On("CreatePasswordResetToken", mock.Anything, mock.Anything). + Return(models.PasswordResetToken{}, errors.New("insert failed")).Once() + + result, err := tm.CreateToken(ctx, 100) + assert.Nil(t, result) + assert.ErrorContains(t, err, "failed to create reset token") + }) + + t.Run("db_error_on_invalidate", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + existingTokens := []models.PasswordResetToken{ + {ID: 1}, {ID: 2}, {ID: 3}, + } + db.On("GetActivePasswordResetTokensByUserID", mock.Anything, mock.Anything, mock.Anything). + Return(existingTokens, nil).Once() + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(errors.New("update failed")).Once() + + result, err := tm.CreateToken(ctx, 100) + assert.Nil(t, result) + assert.ErrorContains(t, err, "failed to invalidate old tokens") + }) +} + +func TestValidateToken(t *testing.T) { + ctx := context.Background() + + t.Run("valid_token", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + expected := models.PasswordResetToken{ + ID: 1, + Token: "valid-token", + UserID: pgtype.Int4{Int32: 100, Valid: true}, + } + db.On("ValidatePasswordResetToken", mock.Anything, "valid-token", mock.Anything). + Return(expected, nil).Once() + + result, err := tm.ValidateToken(ctx, "valid-token") + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, int32(1), result.ID) + assert.Equal(t, "valid-token", result.Token) + }) +} + +func TestValidateToken_Expired(t *testing.T) { + ctx := context.Background() + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("ValidatePasswordResetToken", mock.Anything, "expired-token", mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows")).Once() + + result, err := tm.ValidateToken(ctx, "expired-token") + assert.Nil(t, result) + assert.ErrorContains(t, err, "invalid or expired token") +} + +func TestValidateToken_Invalid(t *testing.T) { + ctx := context.Background() + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("ValidatePasswordResetToken", mock.Anything, "nonexistent-token", mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows in result set")).Once() + + result, err := tm.ValidateToken(ctx, "nonexistent-token") + assert.Nil(t, result) + assert.ErrorContains(t, err, "invalid or expired token") +} + +func TestUseToken(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("ValidatePasswordResetToken", mock.Anything, "use-me", mock.Anything). + Return(models.PasswordResetToken{ID: 1, Token: "use-me"}, nil).Once() + db.On("MarkPasswordResetTokenAsUsed", mock.Anything, mock.Anything). + Return(nil).Once() + + err := tm.UseToken(ctx, "use-me") + assert.NoError(t, err) + }) + + t.Run("db_error_on_mark", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("ValidatePasswordResetToken", mock.Anything, "mark-fail", mock.Anything). + Return(models.PasswordResetToken{ID: 1, Token: "mark-fail"}, nil).Once() + db.On("MarkPasswordResetTokenAsUsed", mock.Anything, mock.Anything). + Return(errors.New("update failed")).Once() + + err := tm.UseToken(ctx, "mark-fail") + assert.ErrorContains(t, err, "failed to mark token as used") + }) +} + +func TestUseToken_AlreadyUsed(t *testing.T) { + ctx := context.Background() + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("ValidatePasswordResetToken", mock.Anything, "already-used", mock.Anything). + Return(models.PasswordResetToken{}, errors.New("no rows")).Once() + + err := tm.UseToken(ctx, "already-used") + assert.ErrorContains(t, err, "invalid or expired token") +} + +func TestInvalidateUserTokens(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + + err := tm.InvalidateUserTokens(ctx, 100) + assert.NoError(t, err) + }) + + t.Run("db_error", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("InvalidateUserPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(errors.New("connection lost")).Once() + + err := tm.InvalidateUserTokens(ctx, 100) + assert.ErrorContains(t, err, "failed to invalidate user tokens") + }) +} + +func TestCleanupExpiredTokens(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("DeleteExpiredPasswordResetTokens", mock.Anything, mock.Anything). + Return(nil).Once() + + err := tm.CleanupExpiredTokens(ctx) + assert.NoError(t, err) + }) + + t.Run("nothing_to_clean", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("DeleteExpiredPasswordResetTokens", mock.Anything, mock.Anything). + Return(nil).Once() + + err := tm.CleanupExpiredTokens(ctx) + assert.NoError(t, err) + }) + + t.Run("db_error_on_cleanup", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(errors.New("update failed")).Once() + + err := tm.CleanupExpiredTokens(ctx) + assert.ErrorContains(t, err, "failed to mark expired tokens as deleted") + }) + + t.Run("db_error_on_delete", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("CleanupExpiredPasswordResetTokens", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Once() + db.On("DeleteExpiredPasswordResetTokens", mock.Anything, mock.Anything). + Return(errors.New("delete failed")).Once() + + err := tm.CleanupExpiredTokens(ctx) + assert.ErrorContains(t, err, "failed to delete expired tokens") + }) +} + +func TestGetTokenStats(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + expected := models.GetPasswordResetTokenStatsRow{ + TotalTokens: 50, + UsedTokens: 20, + ExpiredTokens: 10, + ActiveTokens: 20, + } + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(expected, nil).Once() + + stats, err := tm.GetTokenStats(ctx) + require.NoError(t, err) + assert.NotNil(t, stats) + assert.Equal(t, int64(50), stats.TotalTokens) + assert.Equal(t, int64(20), stats.UsedTokens) + assert.Equal(t, int64(10), stats.ExpiredTokens) + assert.Equal(t, int64(20), stats.ActiveTokens) + }) + + t.Run("db_error", func(t *testing.T) { + db := mocks.NewQuerier(t) + tm := NewTokenManager(db, nil) + + db.On("GetPasswordResetTokenStats", mock.Anything, mock.Anything). + Return(models.GetPasswordResetTokenStatsRow{}, errors.New("query failed")).Once() + + stats, err := tm.GetTokenStats(ctx) + assert.Nil(t, stats) + assert.ErrorContains(t, err, "failed to get token stats") + }) +} diff --git a/internal/cron/service_test.go b/internal/cron/service_test.go new file mode 100644 index 0000000..8526992 --- /dev/null +++ b/internal/cron/service_test.go @@ -0,0 +1,405 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2025 UnderNET + +package cron + +import ( + "io" + "log/slog" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/undernetirc/cservice-api/internal/config" +) + +func createServiceTestLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +func TestNewService(t *testing.T) { + t.Run("creates service when enabled", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + + require.NoError(t, err) + require.NotNil(t, svc) + assert.True(t, svc.IsEnabled()) + assert.NotNil(t, svc.scheduler) + }) + + t.Run("creates disabled service when not enabled", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: false, + }, logger) + + require.NoError(t, err) + require.NotNil(t, svc) + assert.False(t, svc.IsEnabled()) + assert.Nil(t, svc.scheduler) + }) + + t.Run("uses default logger when nil", func(t *testing.T) { + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, nil) + + require.NoError(t, err) + require.NotNil(t, svc) + assert.NotNil(t, svc.logger) + }) + + t.Run("returns error for invalid timezone", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "Invalid/Timezone", + }, logger) + + assert.Error(t, err) + assert.Nil(t, svc) + assert.Contains(t, err.Error(), "failed to create cron scheduler") + }) + + t.Run("uses default logger when disabled and nil logger", func(t *testing.T) { + svc, err := NewService(ServiceConfig{ + Enabled: false, + }, nil) + + require.NoError(t, err) + require.NotNil(t, svc) + assert.NotNil(t, svc.logger) + assert.False(t, svc.IsEnabled()) + }) +} + +func TestService_Start(t *testing.T) { + t.Run("starts enabled service", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + err = svc.Start() + assert.NoError(t, err) + + // Clean up + svc.Stop() + }) + + t.Run("start on disabled service returns nil", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: false, + }, logger) + require.NoError(t, err) + + err = svc.Start() + assert.NoError(t, err) + }) +} + +func TestService_Stop(t *testing.T) { + t.Run("stops enabled service", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + err = svc.Start() + require.NoError(t, err) + + // Should not panic + svc.Stop() + }) + + t.Run("stop on disabled service does not panic", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: false, + }, logger) + require.NoError(t, err) + + // Should not panic + svc.Stop() + }) +} + +func TestService_SetupPasswordResetCleanup(t *testing.T) { + t.Run("skips setup when service is disabled", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: false, + }, logger) + require.NoError(t, err) + + cfg := ServiceConfig{ + PasswordResetCleanupCron: "*/5 * * * *", + } + + err = svc.SetupPasswordResetCleanup(nil, cfg) + assert.NoError(t, err) + }) + + t.Run("sets up cleanup job when enabled with valid config", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + // Set required password reset config values via viper + config.ServicePasswordResetTokenLength.Set(32) + config.ServicePasswordResetTokenLifetimeMinutes.Set(60) + config.ServicePasswordResetCleanupIntervalHours.Set(24) + config.ServicePasswordResetMaxTokensPerUser.Set(3) + + cfg := ServiceConfig{ + PasswordResetCleanupCron: "*/5 * * * *", + } + + err = svc.SetupPasswordResetCleanup(nil, cfg) + assert.NoError(t, err) + + // Verify job was registered + entries := svc.GetJobEntries() + assert.Len(t, entries, 1) + }) + + t.Run("returns error for invalid password reset config", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + // Set invalid token length to trigger config load error + config.ServicePasswordResetTokenLength.Set(1) // too short, min is 16 + defer config.ServicePasswordResetTokenLength.Set(32) + + cfg := ServiceConfig{ + PasswordResetCleanupCron: "*/5 * * * *", + } + + err = svc.SetupPasswordResetCleanup(nil, cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to load password reset config") + }) + + t.Run("returns error for invalid cron expression", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + // Set valid password reset config + config.ServicePasswordResetTokenLength.Set(32) + config.ServicePasswordResetTokenLifetimeMinutes.Set(60) + config.ServicePasswordResetCleanupIntervalHours.Set(24) + config.ServicePasswordResetMaxTokensPerUser.Set(3) + + cfg := ServiceConfig{ + PasswordResetCleanupCron: "invalid-cron", + } + + err = svc.SetupPasswordResetCleanup(nil, cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to schedule password reset cleanup job") + }) +} + +func TestService_AddCustomJob(t *testing.T) { + t.Run("adds job successfully", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + err = svc.AddCustomJob("*/5 * * * *", "test-job", func() {}) + assert.NoError(t, err) + + entries := svc.GetJobEntries() + assert.Len(t, entries, 1) + }) + + t.Run("returns error when service is disabled", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: false, + }, logger) + require.NoError(t, err) + + err = svc.AddCustomJob("*/5 * * * *", "test-job", func() {}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "cron service is disabled") + }) + + t.Run("adds multiple jobs", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + err = svc.AddCustomJob("*/5 * * * *", "job-1", func() {}) + require.NoError(t, err) + + err = svc.AddCustomJob("0 * * * *", "job-2", func() {}) + require.NoError(t, err) + + entries := svc.GetJobEntries() + assert.Len(t, entries, 2) + }) +} + +func TestService_AddCustomJob_InvalidSchedule(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + tests := []struct { + name string + cronExpr string + }{ + { + name: "empty expression", + cronExpr: "", + }, + { + name: "too few fields", + cronExpr: "0 0 *", + }, + { + name: "invalid field value", + cronExpr: "60 * * * *", + }, + { + name: "garbage input", + cronExpr: "not-a-cron", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := svc.AddCustomJob(tt.cronExpr, "bad-job", func() {}) + assert.Error(t, err) + }) + } +} + +func TestService_GetJobEntries(t *testing.T) { + t.Run("returns nil when service is disabled", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: false, + }, logger) + require.NoError(t, err) + + entries := svc.GetJobEntries() + assert.Nil(t, entries) + }) + + t.Run("returns empty slice when no jobs", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + entries := svc.GetJobEntries() + assert.Len(t, entries, 0) + }) + + t.Run("returns job info with valid fields", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + + err = svc.AddCustomJob("*/5 * * * *", "test-job", func() {}) + require.NoError(t, err) + + svc.Start() + defer svc.Stop() + + entries := svc.GetJobEntries() + require.Len(t, entries, 1) + assert.NotEmpty(t, entries[0].Schedule) + assert.False(t, entries[0].Next.IsZero()) + }) +} + +func TestService_IsEnabled(t *testing.T) { + t.Run("returns true when enabled", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: true, + PasswordResetCleanupCron: "*/5 * * * *", + TimeZone: "UTC", + }, logger) + require.NoError(t, err) + assert.True(t, svc.IsEnabled()) + }) + + t.Run("returns false when disabled", func(t *testing.T) { + logger := createServiceTestLogger() + svc, err := NewService(ServiceConfig{ + Enabled: false, + }, logger) + require.NoError(t, err) + assert.False(t, svc.IsEnabled()) + }) +} + +func TestLoadServiceConfigFromViper(t *testing.T) { + // Set config values + config.ServiceCronEnabled.Set(true) + config.ServiceCronPasswordResetCleanup.Set("0 */2 * * *") + config.ServiceCronTimeZone.Set("America/New_York") + defer func() { + config.ServiceCronEnabled.Set(false) + config.ServiceCronPasswordResetCleanup.Set("0 0 * * *") + config.ServiceCronTimeZone.Set("UTC") + }() + + cfg := LoadServiceConfigFromViper() + + assert.True(t, cfg.Enabled) + assert.Equal(t, "0 */2 * * *", cfg.PasswordResetCleanupCron) + assert.Equal(t, "America/New_York", cfg.TimeZone) +} diff --git a/internal/errors/manager_change_errors_test.go b/internal/errors/manager_change_errors_test.go new file mode 100644 index 0000000..9388b3b --- /dev/null +++ b/internal/errors/manager_change_errors_test.go @@ -0,0 +1,195 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2024 UnderNET + +package errors + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewManagerChangeErrorHandler(t *testing.T) { + handler := NewManagerChangeErrorHandler() + require.NotNil(t, handler) +} + +func TestManagerChangeErrorHandler_HandleBusinessRuleError(t *testing.T) { + handler := NewManagerChangeErrorHandler() + + tests := []struct { + name string + error error + expectedStatus int + expectedCode string + }{ + { + name: "forbidden error", + error: &MockValidationError{ + code: ErrCodeForbidden, + message: "User is not channel owner", + details: map[string]any{"error": "not owner"}, + }, + expectedStatus: http.StatusForbidden, + expectedCode: ErrCodeForbidden, + }, + { + name: "not found error", + error: &MockValidationError{ + code: ErrCodeNotFound, + message: "Channel not found or not registered", + details: nil, + }, + expectedStatus: http.StatusNotFound, + expectedCode: ErrCodeNotFound, + }, + { + name: "conflict error", + error: &MockValidationError{ + code: ErrCodeConflict, + message: "Channel already has a pending manager change request", + details: nil, + }, + expectedStatus: http.StatusConflict, + expectedCode: ErrCodeConflict, + }, + { + name: "bad request error", + error: &MockValidationError{ + code: ErrCodeBadRequest, + message: "Cooldown period active", + details: nil, + }, + expectedStatus: http.StatusBadRequest, + expectedCode: ErrCodeBadRequest, + }, + { + name: "database error", + error: &MockValidationError{ + code: ErrCodeDatabase, + message: "Failed to check pending requests", + details: map[string]any{"error": "connection refused"}, + }, + expectedStatus: http.StatusInternalServerError, + expectedCode: ErrCodeDatabase, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/manager-change", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + err := handler.HandleBusinessRuleError(c, tt.error) + require.NoError(t, err) + + assert.Equal(t, tt.expectedStatus, rec.Code) + + var response ErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response)) + + assert.Equal(t, "error", response.Status) + assert.Equal(t, tt.expectedCode, response.Error.Code) + assert.NotEmpty(t, response.Error.Message) + }) + } + + t.Run("non-validation error fallback", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/manager-change", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + regularError := errors.New("unexpected error") + + captureLogOutput(t, func() { + err := handler.HandleBusinessRuleError(c, regularError) + require.NoError(t, err) + }) + + assert.Equal(t, http.StatusInternalServerError, rec.Code) + + var response ErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &response)) + + assert.Equal(t, "error", response.Status) + assert.Equal(t, ErrCodeInternal, response.Error.Code) + }) +} + +func TestManagerChangeErrorHandler_MapBusinessRuleErrorToHTTPStatus(t *testing.T) { + handler := NewManagerChangeErrorHandler() + + tests := []struct { + errorCode string + expectedStatus int + }{ + {ErrCodeForbidden, http.StatusForbidden}, + {ErrCodeNotFound, http.StatusNotFound}, + {ErrCodeConflict, http.StatusConflict}, + {ErrCodeBadRequest, http.StatusBadRequest}, + {ErrCodeDatabase, http.StatusInternalServerError}, + {"UNKNOWN_ERROR", http.StatusInternalServerError}, + } + + for _, tt := range tests { + t.Run(tt.errorCode, func(t *testing.T) { + status := handler.mapBusinessRuleErrorToHTTPStatus(tt.errorCode) + assert.Equal(t, tt.expectedStatus, status) + }) + } +} + +func TestManagerChangeErrorHandler_GetErrorCategory(t *testing.T) { + handler := NewManagerChangeErrorHandler() + + tests := []struct { + errorCode string + expectedCategory string + }{ + {ErrCodeForbidden, "authorization"}, + {ErrCodeNotFound, "not_found"}, + {ErrCodeConflict, "business_rule"}, + {ErrCodeBadRequest, "validation"}, + {ErrCodeDatabase, "database"}, + {"UNKNOWN_ERROR", "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.errorCode, func(t *testing.T) { + category := handler.GetErrorCategory(tt.errorCode) + assert.Equal(t, tt.expectedCategory, category) + }) + } +} + +func TestManagerChangeErrorHandler_IsRetryableError(t *testing.T) { + handler := NewManagerChangeErrorHandler() + + tests := []struct { + errorCode string + retryable bool + }{ + {ErrCodeDatabase, true}, + {ErrCodeForbidden, false}, + {ErrCodeNotFound, false}, + {ErrCodeConflict, false}, + {ErrCodeBadRequest, false}, + {"UNKNOWN_ERROR", false}, + } + + for _, tt := range tests { + t.Run(tt.errorCode, func(t *testing.T) { + retryable := handler.IsRetryableError(tt.errorCode) + assert.Equal(t, tt.retryable, retryable) + }) + } +} diff --git a/internal/helper/channel_validation_test.go b/internal/helper/channel_validation_test.go index 4769573..466ad71 100644 --- a/internal/helper/channel_validation_test.go +++ b/internal/helper/channel_validation_test.go @@ -596,3 +596,432 @@ func TestValidateUserNoregStatus(t *testing.T) { }) } } + +func TestValidateChannelRegistrationRequest(t *testing.T) { + ctx := context.Background() + userID := int32(123) + + t.Run("full validation pass", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + req := &ChannelRegistrationRequest{ + ChannelName: "#goodchan", + Description: "A nice channel", + Supporters: []string{"supporter1", "supporter2"}, + } + + currentUser := models.GetUserRow{ID: userID, Username: "testuser"} + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ID: userID}).Return(currentUser, nil) + mockDB.On("GetSupportersByUsernames", mock.Anything, []string{"supporter1", "supporter2"}, mock.AnythingOfType("int32")).Return([]models.GetSupportersByUsernamesRow{ + {Username: "supporter1", IsOldEnough: true, HasFraudFlag: false, Email: pgtype.Text{String: "s1@example.com", Valid: true}}, + {Username: "supporter2", IsOldEnough: true, HasFraudFlag: false, Email: pgtype.Text{String: "s2@example.com", Valid: true}}, + }, nil) + mockDB.On("CheckMultipleSupportersNoregStatus", mock.Anything, []string{"supporter1", "supporter2"}).Return([]models.CheckMultipleSupportersNoregStatusRow{ + {Username: "supporter1", IsNoreg: false}, + {Username: "supporter2", IsNoreg: false}, + }, nil) + mockDB.On("CheckMultipleSupportersConcurrentSupports", mock.Anything, []string{"supporter1", "supporter2"}, mock.AnythingOfType("int32")).Return([]models.CheckMultipleSupportersConcurrentSupportsRow{ + {Username: "supporter1", ExceedsLimit: false}, + {Username: "supporter2", ExceedsLimit: false}, + }, nil) + + err := validator.ValidateChannelRegistrationRequest(ctx, req, userID) + assert.NoError(t, err) + mockDB.AssertExpectations(t) + }) + + t.Run("basic struct validation failure", func(t *testing.T) { + validator, _ := setupChannelValidationTest() + + req := &ChannelRegistrationRequest{ + ChannelName: "", // required field missing + Description: "A nice channel", + Supporters: []string{"supporter1", "supporter2"}, + } + + err := validator.ValidateChannelRegistrationRequest(ctx, req, userID) + assert.Error(t, err) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeValidation, validationErr.Code) + }) + + t.Run("channel name validation failure", func(t *testing.T) { + validator, _ := setupChannelValidationTest() + + req := &ChannelRegistrationRequest{ + ChannelName: "#bad channel", // contains space + Description: "A nice channel", + Supporters: []string{"supporter1", "supporter2"}, + } + + err := validator.ValidateChannelRegistrationRequest(ctx, req, userID) + assert.Error(t, err) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeInvalidChannelName, validationErr.Code) + }) + + t.Run("description validation failure", func(t *testing.T) { + validator, _ := setupChannelValidationTest() + + req := &ChannelRegistrationRequest{ + ChannelName: "#goodchan", + Description: "", + Supporters: []string{"supporter1", "supporter2"}, + } + + err := validator.ValidateChannelRegistrationRequest(ctx, req, userID) + assert.Error(t, err) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeInvalidDescription, validationErr.Code) + }) + + t.Run("supporter validation failure", func(t *testing.T) { + validator, _ := setupChannelValidationTest() + + req := &ChannelRegistrationRequest{ + ChannelName: "#goodchan", + Description: "A nice channel", + Supporters: []string{"supporter1"}, // insufficient, need 2 + } + + err := validator.ValidateChannelRegistrationRequest(ctx, req, userID) + assert.Error(t, err) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeInsufficientSupporters, validationErr.Code) + }) +} + +func TestValidateChannelRegistrationWithAdminBypass(t *testing.T) { + ctx := context.Background() + userID := int32(123) + adminLevel := int32(800) + + t.Run("admin goes through full validation - no bypass for basic validation", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + req := &ChannelRegistrationRequest{ + ChannelName: "#goodchan", + Description: "A nice channel", + Supporters: []string{"supporter1", "supporter2"}, + } + + currentUser := models.GetUserRow{ID: userID, Username: "testuser"} + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ID: userID}).Return(currentUser, nil) + mockDB.On("GetSupportersByUsernames", mock.Anything, []string{"supporter1", "supporter2"}, mock.AnythingOfType("int32")).Return([]models.GetSupportersByUsernamesRow{ + {Username: "supporter1", IsOldEnough: true, HasFraudFlag: false, Email: pgtype.Text{String: "s1@example.com", Valid: true}}, + {Username: "supporter2", IsOldEnough: true, HasFraudFlag: false, Email: pgtype.Text{String: "s2@example.com", Valid: true}}, + }, nil) + mockDB.On("CheckMultipleSupportersNoregStatus", mock.Anything, []string{"supporter1", "supporter2"}).Return([]models.CheckMultipleSupportersNoregStatusRow{ + {Username: "supporter1", IsNoreg: false}, + {Username: "supporter2", IsNoreg: false}, + }, nil) + mockDB.On("CheckMultipleSupportersConcurrentSupports", mock.Anything, []string{"supporter1", "supporter2"}, mock.AnythingOfType("int32")).Return([]models.CheckMultipleSupportersConcurrentSupportsRow{ + {Username: "supporter1", ExceedsLimit: false}, + {Username: "supporter2", ExceedsLimit: false}, + }, nil) + + bypasses, err := validator.ValidateChannelRegistrationWithAdminBypass(ctx, req, userID, adminLevel) + assert.NoError(t, err) + assert.Empty(t, bypasses) + mockDB.AssertExpectations(t) + }) + + t.Run("non-admin fails basic validation", func(t *testing.T) { + validator, _ := setupChannelValidationTest() + + req := &ChannelRegistrationRequest{ + ChannelName: "", // invalid + Description: "A nice channel", + Supporters: []string{"supporter1", "supporter2"}, + } + + bypasses, err := validator.ValidateChannelRegistrationWithAdminBypass(ctx, req, userID, int32(0)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeValidation, validationErr.Code) + }) + + t.Run("admin also fails basic struct validation - no bypass", func(t *testing.T) { + validator, _ := setupChannelValidationTest() + + req := &ChannelRegistrationRequest{ + ChannelName: "", // required field missing + Description: "A nice channel", + Supporters: []string{"supporter1", "supporter2"}, + } + + bypasses, err := validator.ValidateChannelRegistrationWithAdminBypass(ctx, req, userID, adminLevel) + assert.Error(t, err) + assert.Nil(t, bypasses) + }) +} + +func TestValidateUserNoregStatusWithAdminBypass(t *testing.T) { + ctx := context.Background() + userID := int32(123) + currentUser := models.GetUserRow{ID: userID, Username: "testuser"} + + t.Run("admin cannot bypass noreg - user is restricted", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ID: userID}).Return(currentUser, nil) + mockDB.On("CheckUserNoregStatus", mock.Anything, "testuser").Return(true, nil) + + bypasses, err := validator.ValidateUserNoregStatusWithAdminBypass(ctx, userID, int32(1000)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeUserRestricted, validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("user not restricted - passes", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ID: userID}).Return(currentUser, nil) + mockDB.On("CheckUserNoregStatus", mock.Anything, "testuser").Return(false, nil) + + bypasses, err := validator.ValidateUserNoregStatusWithAdminBypass(ctx, userID, int32(0)) + assert.NoError(t, err) + assert.Nil(t, bypasses) + mockDB.AssertExpectations(t) + }) +} + +func TestValidateUserChannelLimitsWithAdminBypass(t *testing.T) { + ctx := context.Background() + userID := int32(123) + + t.Run("admin bypasses multiple channel restriction", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + config.ServiceChannelRegAllowMultiple.Set(false) + + mockDB.On("GetUserChannels", mock.Anything, userID).Return([]models.GetUserChannelsRow{ + {Name: "#existing", ChannelID: 1, UserID: userID}, + }, nil) + + bypasses, err := validator.ValidateUserChannelLimitsWithAdminBypass(ctx, userID, int32(1)) + assert.NoError(t, err) + assert.Len(t, bypasses, 1) + assert.Equal(t, "MULTIPLE_CHANNEL_BYPASS", bypasses[0].BypassType) + mockDB.AssertExpectations(t) + }) + + t.Run("non-admin blocked by multiple channel restriction", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + config.ServiceChannelRegAllowMultiple.Set(false) + + mockDB.On("GetUserChannels", mock.Anything, userID).Return([]models.GetUserChannelsRow{ + {Name: "#existing", ChannelID: 1, UserID: userID}, + }, nil) + + bypasses, err := validator.ValidateUserChannelLimitsWithAdminBypass(ctx, userID, int32(0)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeChannelLimitExceeded, validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("admin bypasses general channel limit", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + config.ServiceChannelRegAllowMultiple.Set(true) + + // Simulate ValidateUserChannelLimits returning an error (at limit) + mockDB.On("GetUserChannelCount", mock.Anything, userID).Return(int64(1), nil) + mockDB.On("GetUserChannelLimit", mock.Anything, mock.AnythingOfType("models.GetUserChannelLimitParams")).Return(int32(1), nil) + + bypasses, err := validator.ValidateUserChannelLimitsWithAdminBypass(ctx, userID, int32(1)) + assert.NoError(t, err) + assert.Len(t, bypasses, 1) + assert.Equal(t, "CHANNEL_LIMIT_BYPASS", bypasses[0].BypassType) + mockDB.AssertExpectations(t) + }) + + t.Run("non-admin blocked by general channel limit", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + config.ServiceChannelRegAllowMultiple.Set(true) + + mockDB.On("GetUserChannelCount", mock.Anything, userID).Return(int64(1), nil) + mockDB.On("GetUserChannelLimit", mock.Anything, mock.AnythingOfType("models.GetUserChannelLimitParams")).Return(int32(1), nil) + + bypasses, err := validator.ValidateUserChannelLimitsWithAdminBypass(ctx, userID, int32(0)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeChannelLimitReached, validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("no existing channels - passes without bypass", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + config.ServiceChannelRegAllowMultiple.Set(false) + + mockDB.On("GetUserChannels", mock.Anything, userID).Return([]models.GetUserChannelsRow{}, nil) + mockDB.On("GetUserChannelCount", mock.Anything, userID).Return(int64(0), nil) + mockDB.On("GetUserChannelLimit", mock.Anything, mock.AnythingOfType("models.GetUserChannelLimitParams")).Return(int32(1), nil) + + bypasses, err := validator.ValidateUserChannelLimitsWithAdminBypass(ctx, userID, int32(0)) + assert.NoError(t, err) + assert.Empty(t, bypasses) + mockDB.AssertExpectations(t) + }) + + t.Run("database error on GetUserChannels", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + config.ServiceChannelRegAllowMultiple.Set(false) + + mockDB.On("GetUserChannels", mock.Anything, userID).Return([]models.GetUserChannelsRow(nil), assert.AnError) + + bypasses, err := validator.ValidateUserChannelLimitsWithAdminBypass(ctx, userID, int32(0)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeDatabaseError, validationErr.Code) + mockDB.AssertExpectations(t) + }) +} + +func TestValidatePendingRegistrationsWithAdminBypass(t *testing.T) { + ctx := context.Background() + userID := int32(123) + + t.Run("admin bypasses pending registration restriction", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("GetUserPendingRegistrations", mock.Anything, pgtype.Int4{Int32: userID, Valid: true}).Return(int64(1), nil) + + bypasses, err := validator.ValidatePendingRegistrationsWithAdminBypass(ctx, userID, int32(800)) + assert.NoError(t, err) + assert.Len(t, bypasses, 1) + assert.Equal(t, "PENDING_REGISTRATION_BYPASS", bypasses[0].BypassType) + mockDB.AssertExpectations(t) + }) + + t.Run("lower admin cannot bypass pending registration", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("GetUserPendingRegistrations", mock.Anything, pgtype.Int4{Int32: userID, Valid: true}).Return(int64(1), nil) + + bypasses, err := validator.ValidatePendingRegistrationsWithAdminBypass(ctx, userID, int32(799)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodePendingExists, validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("non-admin blocked by pending registration", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("GetUserPendingRegistrations", mock.Anything, pgtype.Int4{Int32: userID, Valid: true}).Return(int64(2), nil) + + bypasses, err := validator.ValidatePendingRegistrationsWithAdminBypass(ctx, userID, int32(0)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodePendingExists, validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("no pending registrations - passes", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("GetUserPendingRegistrations", mock.Anything, pgtype.Int4{Int32: userID, Valid: true}).Return(int64(0), nil) + + bypasses, err := validator.ValidatePendingRegistrationsWithAdminBypass(ctx, userID, int32(0)) + assert.NoError(t, err) + assert.Empty(t, bypasses) + mockDB.AssertExpectations(t) + }) + + t.Run("database error", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("GetUserPendingRegistrations", mock.Anything, pgtype.Int4{Int32: userID, Valid: true}).Return(int64(0), assert.AnError) + + bypasses, err := validator.ValidatePendingRegistrationsWithAdminBypass(ctx, userID, int32(0)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeDatabaseError, validationErr.Code) + mockDB.AssertExpectations(t) + }) +} + +func TestValidateChannelNameAvailabilityWithAdminBypass(t *testing.T) { + ctx := context.Background() + channelName := "#testchannel" + + t.Run("admin cannot bypass - name taken", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("CheckChannelNameExists", mock.Anything, channelName).Return(models.CheckChannelNameExistsRow{ID: 1, Name: channelName}, nil) + + bypasses, err := validator.ValidateChannelNameAvailabilityWithAdminBypass(ctx, channelName, int32(1000)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeChannelAlreadyExists, validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("name available - passes", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + mockDB.On("CheckChannelNameExists", mock.Anything, channelName).Return(models.CheckChannelNameExistsRow{}, assert.AnError) + + bypasses, err := validator.ValidateChannelNameAvailabilityWithAdminBypass(ctx, channelName, int32(0)) + assert.NoError(t, err) + assert.Nil(t, bypasses) + mockDB.AssertExpectations(t) + }) +} + +func TestValidateUserIRCActivityWithAdminBypass(t *testing.T) { + ctx := context.Background() + userID := int32(123) + + t.Run("admin cannot bypass - insufficient activity", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + oldTime := pgtype.Int4{Int32: int32(time.Now().Add(-8 * 24 * time.Hour).Unix()), Valid: true} + user := models.GetUserRow{ID: userID, Username: "testuser", LastSeen: oldTime} + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ID: userID}).Return(user, nil) + + bypasses, err := validator.ValidateUserIRCActivityWithAdminBypass(ctx, userID, int32(1000)) + assert.Error(t, err) + assert.Nil(t, bypasses) + validationErr, ok := err.(*ValidationError) + assert.True(t, ok) + assert.Equal(t, apierrors.ErrCodeInactiveUser, validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("recently active - passes", func(t *testing.T) { + validator, mockDB := setupChannelValidationTest() + + recentTime := pgtype.Int4{Int32: int32(time.Now().Add(-1 * time.Hour).Unix()), Valid: true} + user := models.GetUserRow{ID: userID, Username: "testuser", LastSeen: recentTime} + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ID: userID}).Return(user, nil) + + bypasses, err := validator.ValidateUserIRCActivityWithAdminBypass(ctx, userID, int32(0)) + assert.NoError(t, err) + assert.Nil(t, bypasses) + mockDB.AssertExpectations(t) + }) +} diff --git a/internal/helper/email_validation_test.go b/internal/helper/email_validation_test.go new file mode 100644 index 0000000..462da68 --- /dev/null +++ b/internal/helper/email_validation_test.go @@ -0,0 +1,461 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2023 UnderNET + +package helper + +import ( + "context" + "errors" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/undernetirc/cservice-api/db/mocks" + "github.com/undernetirc/cservice-api/internal/config" + "github.com/undernetirc/cservice-api/models" +) + +func setupEmailValidationTest() (*EmailLockValidator, *mocks.Querier) { + mockDB := &mocks.Querier{} + validator := NewEmailLockValidator(mockDB) + + // Set default config values for email lock + config.ServiceChannelRegLockedEmailDomains.Set([]string{"locked.com", "blocked.org"}) + config.ServiceChannelRegLockedEmailPatterns.Set([]string{"spam", "throwaway"}) + + return validator, mockDB +} + +func TestNewEmailLockValidator(t *testing.T) { + mockDB := &mocks.Querier{} + validator := NewEmailLockValidator(mockDB) + + require.NotNil(t, validator) + assert.Equal(t, mockDB, validator.db) +} + +func TestIsEmailLocked(t *testing.T) { + validator, _ := setupEmailValidationTest() + ctx := context.Background() + + tests := []struct { + name string + email string + expected bool + }{ + { + name: "empty email returns false", + email: "", + expected: false, + }, + { + name: "unlocked email domain", + email: "user@example.com", + expected: false, + }, + { + name: "locked email domain", + email: "user@locked.com", + expected: true, + }, + { + name: "locked email domain - blocked.org", + email: "user@blocked.org", + expected: true, + }, + { + name: "email matching locked pattern", + email: "spamuser@example.com", + expected: true, + }, + { + name: "email matching throwaway pattern", + email: "throwaway123@example.com", + expected: true, + }, + { + name: "normal email not matching any pattern", + email: "legitimate@example.com", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + locked, err := validator.IsEmailLocked(ctx, tt.email) + assert.NoError(t, err) + assert.Equal(t, tt.expected, locked) + }) + } +} + +func TestIsEmailLocked_CaseSensitivity(t *testing.T) { + validator, _ := setupEmailValidationTest() + ctx := context.Background() + + tests := []struct { + name string + email string + expected bool + }{ + { + name: "uppercase locked domain", + email: "USER@LOCKED.COM", + expected: true, + }, + { + name: "mixed case locked domain", + email: "User@Locked.Com", + expected: true, + }, + { + name: "uppercase locked pattern", + email: "SPAMUSER@example.com", + expected: true, + }, + { + name: "mixed case locked pattern", + email: "ThrowAway@example.com", + expected: true, + }, + { + name: "email with leading/trailing spaces", + email: " user@locked.com ", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + locked, err := validator.IsEmailLocked(ctx, tt.email) + assert.NoError(t, err) + assert.Equal(t, tt.expected, locked) + }) + } +} + +func TestValidateUserEmailNotLocked(t *testing.T) { + ctx := context.Background() + userID := int32(100) + + t.Run("user with unlocked email", func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + ID: userID, + }).Return(models.GetUserRow{ + ID: userID, + Username: "testuser", + Email: pgtype.Text{String: "user@example.com", Valid: true}, + }, nil).Once() + + err := validator.ValidateUserEmailNotLocked(ctx, userID) + assert.NoError(t, err) + mockDB.AssertExpectations(t) + }) + + t.Run("user with locked email", func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + ID: userID, + }).Return(models.GetUserRow{ + ID: userID, + Username: "testuser", + Email: pgtype.Text{String: "user@locked.com", Valid: true}, + }, nil).Once() + + err := validator.ValidateUserEmailNotLocked(ctx, userID) + assert.Error(t, err) + var validationErr *ValidationError + require.ErrorAs(t, err, &validationErr) + assert.Equal(t, "EMAIL_LOCKED", validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("user with invalid email", func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + ID: userID, + }).Return(models.GetUserRow{ + ID: userID, + Username: "testuser", + Email: pgtype.Text{Valid: false}, + }, nil).Once() + + err := validator.ValidateUserEmailNotLocked(ctx, userID) + assert.Error(t, err) + var validationErr *ValidationError + require.ErrorAs(t, err, &validationErr) + assert.Equal(t, "INVALID_EMAIL", validationErr.Code) + mockDB.AssertExpectations(t) + }) +} + +func TestValidateUserEmailNotLocked_DBError(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + ctx := context.Background() + userID := int32(100) + + dbErr := errors.New("connection refused") + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + ID: userID, + }).Return(models.GetUserRow{}, dbErr).Once() + + err := validator.ValidateUserEmailNotLocked(ctx, userID) + assert.Error(t, err) + var validationErr *ValidationError + require.ErrorAs(t, err, &validationErr) + assert.Equal(t, "DATABASE_ERROR", validationErr.Code) + mockDB.AssertExpectations(t) +} + +func TestValidateSupporterEmailNotLocked(t *testing.T) { + ctx := context.Background() + supporterUsername := "supporter1" + + t.Run("supporter with unlocked email", func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + Username: supporterUsername, + }).Return(models.GetUserRow{ + ID: 200, + Username: supporterUsername, + Email: pgtype.Text{String: "supporter@example.com", Valid: true}, + }, nil).Once() + + err := validator.ValidateSupporterEmailNotLocked(ctx, supporterUsername) + assert.NoError(t, err) + mockDB.AssertExpectations(t) + }) + + t.Run("supporter with locked email", func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + Username: supporterUsername, + }).Return(models.GetUserRow{ + ID: 200, + Username: supporterUsername, + Email: pgtype.Text{String: "supporter@blocked.org", Valid: true}, + }, nil).Once() + + err := validator.ValidateSupporterEmailNotLocked(ctx, supporterUsername) + assert.Error(t, err) + var validationErr *ValidationError + require.ErrorAs(t, err, &validationErr) + assert.Equal(t, "SUPPORTER_EMAIL_LOCKED", validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("supporter with invalid email", func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + Username: supporterUsername, + }).Return(models.GetUserRow{ + ID: 200, + Username: supporterUsername, + Email: pgtype.Text{Valid: false}, + }, nil).Once() + + err := validator.ValidateSupporterEmailNotLocked(ctx, supporterUsername) + assert.Error(t, err) + var validationErr *ValidationError + require.ErrorAs(t, err, &validationErr) + assert.Equal(t, "INVALID_EMAIL", validationErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("database error", func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + dbErr := errors.New("timeout") + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + Username: supporterUsername, + }).Return(models.GetUserRow{}, dbErr).Once() + + err := validator.ValidateSupporterEmailNotLocked(ctx, supporterUsername) + assert.Error(t, err) + var validationErr *ValidationError + require.ErrorAs(t, err, &validationErr) + assert.Equal(t, "DATABASE_ERROR", validationErr.Code) + mockDB.AssertExpectations(t) + }) +} + +func TestValidateUserEmailLock(t *testing.T) { + ctx := context.Background() + userID := int32(100) + + tests := []struct { + name string + user models.GetUserRow + dbErr error + wantErr bool + errSubstr string + }{ + { + name: "user with unlocked email", + user: models.GetUserRow{ + ID: userID, + Username: "testuser", + Email: pgtype.Text{String: "user@example.com", Valid: true}, + }, + wantErr: false, + }, + { + name: "user with locked email domain", + user: models.GetUserRow{ + ID: userID, + Username: "testuser", + Email: pgtype.Text{String: "user@locked.com", Valid: true}, + }, + wantErr: true, + errSubstr: "email domain/pattern is locked", + }, + { + name: "user with locked email pattern", + user: models.GetUserRow{ + ID: userID, + Username: "testuser", + Email: pgtype.Text{String: "spamuser@example.com", Valid: true}, + }, + wantErr: true, + errSubstr: "email domain/pattern is locked", + }, + { + name: "user with no email", + user: models.GetUserRow{ + ID: userID, + Username: "testuser", + Email: pgtype.Text{Valid: false}, + }, + wantErr: true, + errSubstr: "user has no email address", + }, + { + name: "database error", + dbErr: errors.New("connection refused"), + wantErr: true, + errSubstr: "failed to get user email", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + if tt.dbErr != nil { + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + ID: userID, + }).Return(models.GetUserRow{}, tt.dbErr).Once() + } else { + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + ID: userID, + }).Return(tt.user, nil).Once() + } + + err := validator.ValidateUserEmailLock(ctx, userID) + + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errSubstr) + } else { + assert.NoError(t, err) + } + + mockDB.AssertExpectations(t) + }) + } +} + +func TestValidateSupporterEmailLock(t *testing.T) { + ctx := context.Background() + supporterUsername := "supporter1" + + tests := []struct { + name string + user models.GetUserRow + dbErr error + wantErr bool + errSubstr string + }{ + { + name: "supporter with unlocked email", + user: models.GetUserRow{ + ID: 200, + Username: supporterUsername, + Email: pgtype.Text{String: "supporter@example.com", Valid: true}, + }, + wantErr: false, + }, + { + name: "supporter with locked email domain", + user: models.GetUserRow{ + ID: 200, + Username: supporterUsername, + Email: pgtype.Text{String: "supporter@blocked.org", Valid: true}, + }, + wantErr: true, + errSubstr: "supporter supporter1 email domain/pattern is locked", + }, + { + name: "supporter with locked email pattern", + user: models.GetUserRow{ + ID: 200, + Username: supporterUsername, + Email: pgtype.Text{String: "throwaway123@example.com", Valid: true}, + }, + wantErr: true, + errSubstr: "supporter supporter1 email domain/pattern is locked", + }, + { + name: "supporter with no email", + user: models.GetUserRow{ + ID: 200, + Username: supporterUsername, + Email: pgtype.Text{Valid: false}, + }, + wantErr: true, + errSubstr: "supporter supporter1 has no email address", + }, + { + name: "database error", + dbErr: errors.New("connection refused"), + wantErr: true, + errSubstr: "failed to get supporter email", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + validator, mockDB := setupEmailValidationTest() + + if tt.dbErr != nil { + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + Username: supporterUsername, + }).Return(models.GetUserRow{}, tt.dbErr).Once() + } else { + mockDB.On("GetUser", mock.Anything, models.GetUserParams{ + Username: supporterUsername, + }).Return(tt.user, nil).Once() + } + + err := validator.ValidateSupporterEmailLock(ctx, supporterUsername) + + if tt.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errSubstr) + } else { + assert.NoError(t, err) + } + + mockDB.AssertExpectations(t) + }) + } +} diff --git a/internal/helper/manager_change_validation_test.go b/internal/helper/manager_change_validation_test.go new file mode 100644 index 0000000..b15faad --- /dev/null +++ b/internal/helper/manager_change_validation_test.go @@ -0,0 +1,1147 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2023 UnderNET + +package helper + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + mocks "github.com/undernetirc/cservice-api/db/mocks" + apierrors "github.com/undernetirc/cservice-api/internal/errors" + "github.com/undernetirc/cservice-api/models" +) + +func TestNewManagerChangeValidator(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + require.NotNil(t, validator) +} + +// setupAllMocksForValid sets up mocks for a fully valid manager change scenario +func setupAllMocksForValid(mockDB *mocks.ServiceInterface, channelID, userID, newManagerID int32) { + oldRegisteredTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) // 1 year ago + oldSignupTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) // 1 year ago + + mockDB.On("CheckUserChannelOwnership", mock.Anything, userID, channelID). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: channelID, + RegisteredTs: pgtype.Int4{Int32: oldRegisteredTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, channelID). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: channelID, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldRegisteredTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: newManagerID, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldSignupTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, channelID, newManagerID). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: newManagerID, + SignupTs: pgtype.Int4{Int32: oldSignupTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, channelID). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, newManagerID). + Return(false, nil).Once() + + mockDB.On("CheckChannelSingleManager", mock.Anything, channelID). + Return(int64(1), nil).Once() + + mockDB.On("CheckUserCooldownStatus", mock.Anything, userID). + Return(models.CheckUserCooldownStatusRow{ + PostForms: 0, + Verificationdata: pgtype.Text{String: "verified", Valid: true}, + Email: pgtype.Text{String: "test@example.com", Valid: true}, + }, nil).Once() +} + +func TestValidateManagerChangeBusinessRules_Valid(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + channelID := int32(1) + userID := int32(100) + newManagerID := int32(200) + + setupAllMocksForValid(mockDB, channelID, userID, newManagerID) + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + channelID, + userID, + "newmanager", + "permanent", + ) + + assert.NoError(t, err) + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_InsufficientPerms(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{}, errors.New("no rows")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Equal(t, "User is not channel owner", valErr.Message) + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_ChannelNotFound(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{}, errors.New("no rows")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeNotFound, valErr.Code) + assert.Contains(t, valErr.Message, "Channel not found") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_ChannelTooNew(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + recentTs := int32(time.Now().Add(-30 * 24 * time.Hour).Unix()) // 30 days ago + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: recentTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: recentTs, Valid: true}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "90 days old") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_UserNotFound(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "unknownuser"}). + Return(models.GetUserRow{}, errors.New("no rows")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "unknownuser", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeNotFound, valErr.Code) + assert.Contains(t, valErr.Message, "New manager username not found") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_NewManagerNoAccess(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{}, errors.New("no rows")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "level 499 access") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_PendingExists(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{ + {ChannelID: 1}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeConflict, valErr.Code) + assert.Contains(t, valErr.Message, "pending manager change") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_SelfAssignment(t *testing.T) { + // Self-assignment is tested through the new manager account age and ownership checks. + // The validation doesn't explicitly check for self-assignment since a user with level 500 + // can't also have level 499 on the same channel. We test that error path via + // CheckNewManagerChannelAccess returning an error when the same user is both owner and target. + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + // User looks up themselves + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "selfuser"}). + Return(models.GetUserRow{ + ID: 100, + Username: "selfuser", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + // Self-assignment: user 100 trying to assign to user 100 - they can't have level 499 + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(100)). + Return(models.CheckNewManagerChannelAccessRow{}, errors.New("user does not have level 499")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "selfuser", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_NewManagerAccountTooNew_Permanent(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + recentSignupTs := int32(time.Now().Add(-30 * 24 * time.Hour).Unix()) // 30 days ago + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: recentSignupTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: recentSignupTs, Valid: true}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "90 days old for permanent") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_NewManagerAccountTooNew_Temporary(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + veryRecentSignupTs := int32(time.Now().Add(-10 * 24 * time.Hour).Unix()) // 10 days ago + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: veryRecentSignupTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: veryRecentSignupTs, Valid: true}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "temporary", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "30 days old for temporary") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_OwnsOtherChannels(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(true, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "already owns other channels") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_MultipleManagers(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(false, nil).Once() + + mockDB.On("CheckChannelSingleManager", mock.Anything, int32(1)). + Return(int64(3), nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "multiple managers") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_NoVerificationData(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(false, nil).Once() + + mockDB.On("CheckChannelSingleManager", mock.Anything, int32(1)). + Return(int64(1), nil).Once() + + mockDB.On("CheckUserCooldownStatus", mock.Anything, int32(100)). + Return(models.CheckUserCooldownStatusRow{ + PostForms: 0, + Verificationdata: pgtype.Text{String: "", Valid: false}, + Email: pgtype.Text{String: "test@example.com", Valid: true}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "verification information") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_NoEmail(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(false, nil).Once() + + mockDB.On("CheckChannelSingleManager", mock.Anything, int32(1)). + Return(int64(1), nil).Once() + + mockDB.On("CheckUserCooldownStatus", mock.Anything, int32(100)). + Return(models.CheckUserCooldownStatusRow{ + PostForms: 0, + Verificationdata: pgtype.Text{String: "verified", Valid: true}, + Email: pgtype.Text{String: "", Valid: false}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "email set") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_CooldownActive(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + futureTs := int32(time.Now().Add(24 * time.Hour).Unix()) // 1 day from now + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(false, nil).Once() + + mockDB.On("CheckChannelSingleManager", mock.Anything, int32(1)). + Return(int64(1), nil).Once() + + mockDB.On("CheckUserCooldownStatus", mock.Anything, int32(100)). + Return(models.CheckUserCooldownStatusRow{ + PostForms: futureTs, + Verificationdata: pgtype.Text{String: "verified", Valid: true}, + Email: pgtype.Text{String: "test@example.com", Valid: true}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeBadRequest, valErr.Code) + assert.Contains(t, valErr.Message, "form request after") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_AccountLocked(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(false, nil).Once() + + mockDB.On("CheckChannelSingleManager", mock.Anything, int32(1)). + Return(int64(1), nil).Once() + + mockDB.On("CheckUserCooldownStatus", mock.Anything, int32(100)). + Return(models.CheckUserCooldownStatusRow{ + PostForms: 666, + Verificationdata: pgtype.Text{String: "verified", Valid: true}, + Email: pgtype.Text{String: "test@example.com", Valid: true}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeForbidden, valErr.Code) + assert.Contains(t, valErr.Message, "locked from submitting forms") + mockDB.AssertExpectations(t) +} + +func TestValidateManagerChangeBusinessRules_DBError(t *testing.T) { + t.Run("pending requests db error", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow(nil), errors.New("connection refused")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeDatabase, valErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("owns other channels db error", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(false, errors.New("database timeout")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeDatabase, valErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("single manager db error", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(false, nil).Once() + + mockDB.On("CheckChannelSingleManager", mock.Anything, int32(1)). + Return(int64(0), errors.New("database error")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeDatabase, valErr.Code) + mockDB.AssertExpectations(t) + }) + + t.Run("cooldown status db error", func(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + mockDB.On("CheckUserOwnsOtherChannels", mock.Anything, int32(200)). + Return(false, nil).Once() + + mockDB.On("CheckChannelSingleManager", mock.Anything, int32(1)). + Return(int64(1), nil).Once() + + mockDB.On("CheckUserCooldownStatus", mock.Anything, int32(100)). + Return(models.CheckUserCooldownStatusRow{}, errors.New("database error")).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "permanent", + ) + + require.Error(t, err) + var valErr *ValidationError + require.True(t, errors.As(err, &valErr)) + assert.Equal(t, apierrors.ErrCodeDatabase, valErr.Code) + mockDB.AssertExpectations(t) + }) +} + +func TestValidateManagerChangeBusinessRules_TemporaryValid(t *testing.T) { + mockDB := mocks.NewServiceInterface(t) + validator := NewManagerChangeValidator(mockDB) + + oldTs := int32(time.Now().Add(-365 * 24 * time.Hour).Unix()) + + mockDB.On("CheckUserChannelOwnership", mock.Anything, int32(100), int32(1)). + Return(models.CheckUserChannelOwnershipRow{ + Name: "#testchannel", + ID: 1, + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckChannelExistsAndRegistered", mock.Anything, int32(1)). + Return(models.CheckChannelExistsAndRegisteredRow{ + ID: 1, + Name: "#testchannel", + RegisteredTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("GetUser", mock.Anything, models.GetUserParams{Username: "newmanager"}). + Return(models.GetUserRow{ + ID: 200, + Username: "newmanager", + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckNewManagerChannelAccess", mock.Anything, int32(1), int32(200)). + Return(models.CheckNewManagerChannelAccessRow{ + Username: "newmanager", + ID: 200, + SignupTs: pgtype.Int4{Int32: oldTs, Valid: true}, + }, nil).Once() + + mockDB.On("CheckExistingPendingRequests", mock.Anything, int32(1)). + Return([]models.CheckExistingPendingRequestsRow{}, nil).Once() + + // CheckUserOwnsOtherChannels is NOT called for temporary changes + mockDB.On("CheckChannelSingleManager", mock.Anything, int32(1)). + Return(int64(1), nil).Once() + + mockDB.On("CheckUserCooldownStatus", mock.Anything, int32(100)). + Return(models.CheckUserCooldownStatusRow{ + PostForms: 0, + Verificationdata: pgtype.Text{String: "verified", Valid: true}, + Email: pgtype.Text{String: "test@example.com", Valid: true}, + }, nil).Once() + + err := validator.ValidateManagerChangeBusinessRules( + context.Background(), + int32(1), + int32(100), + "newmanager", + "temporary", + ) + + assert.NoError(t, err) + mockDB.AssertExpectations(t) +} diff --git a/middlewares/business_metrics_test.go b/middlewares/business_metrics_test.go new file mode 100644 index 0000000..5a8a6dd --- /dev/null +++ b/middlewares/business_metrics_test.go @@ -0,0 +1,840 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2025 UnderNET + +package middlewares + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/metric/noop" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" + "go.opentelemetry.io/otel/sdk/metric/metricdata" + "go.opentelemetry.io/otel/sdk/resource" + + "github.com/undernetirc/cservice-api/internal/metrics" +) + +func createTestBusinessMetrics(t *testing.T) *metrics.BusinessMetrics { + t.Helper() + meter := noop.NewMeterProvider().Meter("test") + bm, err := metrics.NewBusinessMetrics(metrics.BusinessMetricsConfig{ + Meter: meter, + ServiceName: "test-service", + }) + require.NoError(t, err) + return bm +} + +func TestBusinessMetricsMiddleware(t *testing.T) { + t.Run("nil business metrics returns no-op middleware", func(t *testing.T) { + middleware := BusinessMetricsMiddleware(BusinessMetricsConfig{ + BusinessMetrics: nil, + }) + assert.NotNil(t, middleware) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("valid business metrics wraps handler", func(t *testing.T) { + bm := createTestBusinessMetrics(t) + middleware := BusinessMetricsMiddleware(BusinessMetricsConfig{ + BusinessMetrics: bm, + }) + assert.NotNil(t, middleware) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handlerCalled := false + handler := middleware(func(c echo.Context) error { + handlerCalled = true + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + assert.NoError(t, err) + assert.True(t, handlerCalled) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("skipper skips metrics recording", func(t *testing.T) { + bm := createTestBusinessMetrics(t) + middleware := BusinessMetricsMiddleware(BusinessMetricsConfig{ + BusinessMetrics: bm, + Skipper: func(_ echo.Context) bool { + return true + }, + }) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "ok") + }) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("handler error is propagated", func(t *testing.T) { + bm := createTestBusinessMetrics(t) + middleware := BusinessMetricsMiddleware(BusinessMetricsConfig{ + BusinessMetrics: bm, + }) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + expectedErr := echo.NewHTTPError(http.StatusInternalServerError, "handler error") + handler := middleware(func(_ echo.Context) error { + return expectedErr + }) + + err := handler(c) + assert.Equal(t, expectedErr, err) + }) +} + +func TestBusinessMetricsMiddleware_RecordsOnCompletion(t *testing.T) { + reader := sdkmetric.NewManualReader() + provider := sdkmetric.NewMeterProvider( + sdkmetric.WithResource(resource.Empty()), + sdkmetric.WithReader(reader), + ) + meter := provider.Meter("test") + + bm, err := metrics.NewBusinessMetrics(metrics.BusinessMetricsConfig{ + Meter: meter, + ServiceName: "test-service", + }) + require.NoError(t, err) + + middleware := BusinessMetricsMiddleware(BusinessMetricsConfig{ + BusinessMetrics: bm, + }) + + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/register", strings.NewReader(`{"username":"newuser","email":"newuser@example.com"}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := middleware(func(c echo.Context) error { + return c.String(http.StatusOK, "registered") + }) + + err = handler(c) + assert.NoError(t, err) + + ctx := context.Background() + rm := &metricdata.ResourceMetrics{} + err = reader.Collect(ctx, rm) + require.NoError(t, err) + assert.NotEmpty(t, rm.ScopeMetrics, "Expected metrics to be recorded after request completion") +} + +func TestRecordBusinessMetrics(t *testing.T) { + bm := createTestBusinessMetrics(t) + ctx := context.Background() + duration := 100 * time.Millisecond + + t.Run("registration endpoint POST", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/register", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + body := []byte(`{"username":"testuser","email":"test@example.com"}`) + + recordBusinessMetrics(ctx, bm, c, body, http.StatusOK, duration) + }) + + t.Run("activation endpoint POST", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/activate", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + body := []byte(`{"username":"testuser"}`) + + recordBusinessMetrics(ctx, bm, c, body, http.StatusOK, duration) + }) + + t.Run("login endpoint POST", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/login", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("user_id", int32(42)) + + recordBusinessMetrics(ctx, bm, c, nil, http.StatusOK, duration) + }) + + t.Run("logout endpoint POST", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/logout", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("user_id", int32(42)) + + recordBusinessMetrics(ctx, bm, c, nil, http.StatusOK, duration) + }) + + t.Run("channel search GET", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/api/v1/channels/search?q=test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("user_id", int32(10)) + + recordBusinessMetrics(ctx, bm, c, nil, http.StatusOK, duration) + }) + + t.Run("channel settings GET", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/api/v1/channels/5/settings", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("user_id", int32(10)) + c.SetParamNames("id") + c.SetParamValues("5") + + recordBusinessMetrics(ctx, bm, c, nil, http.StatusOK, duration) + }) + + t.Run("channel settings PUT", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPut, "/api/v1/channels/5/settings", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("user_id", int32(10)) + c.SetParamNames("id") + c.SetParamValues("5") + body := []byte(`{"description":"new description","url":"https://example.com"}`) + + recordBusinessMetrics(ctx, bm, c, body, http.StatusOK, duration) + }) + + t.Run("channel members POST", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/channels/5/members", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("user_id", int32(10)) + c.SetParamNames("id") + c.SetParamValues("5") + body := []byte(`{"user_id":20,"access_level":200}`) + + recordBusinessMetrics(ctx, bm, c, body, http.StatusOK, duration) + }) + + t.Run("channel members DELETE", func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodDelete, "/api/v1/channels/5/members/20", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.Set("user_id", int32(10)) + c.SetParamNames("id", "user_id") + c.SetParamValues("5", "20") + + recordBusinessMetrics(ctx, bm, c, nil, http.StatusOK, duration) + }) +} + +func TestRecordBusinessMetrics_VariousEndpoints(t *testing.T) { + bm := createTestBusinessMetrics(t) + ctx := context.Background() + duration := 50 * time.Millisecond + + tests := []struct { + name string + path string + method string + status int + body string + }{ + { + name: "general GET endpoint", + path: "/api/v1/users/me", + method: http.MethodGet, + status: http.StatusOK, + }, + { + name: "failed registration", + path: "/api/v1/register", + method: http.MethodPost, + status: http.StatusConflict, + body: `{"username":"existing","email":"existing@example.com"}`, + }, + { + name: "failed activation - invalid token", + path: "/api/v1/activate", + method: http.MethodPost, + status: http.StatusBadRequest, + body: `{"username":"testuser"}`, + }, + { + name: "failed activation - expired token", + path: "/api/v1/activate", + method: http.MethodPost, + status: http.StatusGone, + body: `{"username":"testuser"}`, + }, + { + name: "failed activation - not found", + path: "/api/v1/activate", + method: http.MethodPost, + status: http.StatusNotFound, + body: `{"username":"testuser"}`, + }, + { + name: "channel search with query param", + path: "/api/v1/channels/search?query=mychannel", + method: http.MethodGet, + status: http.StatusOK, + }, + { + name: "server error on channel settings", + path: "/api/v1/channels/5/settings", + method: http.MethodPut, + status: http.StatusInternalServerError, + body: `{"description":"test"}`, + }, + { + name: "failed member add", + path: "/api/v1/channels/5/members", + method: http.MethodPost, + status: http.StatusForbidden, + body: `{"user_id":20}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(_ *testing.T) { + e := echo.New() + req := httptest.NewRequest(tt.method, tt.path, nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + recordBusinessMetrics(ctx, bm, c, []byte(tt.body), tt.status, duration) + // Test passes if no panic occurs + }) + } +} + +func TestRecordBusinessMetrics_ErrorHandling(t *testing.T) { + bm := createTestBusinessMetrics(t) + ctx := context.Background() + duration := 10 * time.Millisecond + + t.Run("nil request body does not panic", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/register", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + assert.NotPanics(t, func() { + recordBusinessMetrics(ctx, bm, c, nil, http.StatusOK, duration) + }) + }) + + t.Run("invalid JSON body does not panic", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/register", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + assert.NotPanics(t, func() { + recordBusinessMetrics(ctx, bm, c, []byte("invalid json{{{"), http.StatusOK, duration) + }) + }) + + t.Run("empty body on activation does not panic", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/activate", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + assert.NotPanics(t, func() { + recordBusinessMetrics(ctx, bm, c, []byte{}, http.StatusOK, duration) + }) + }) + + t.Run("channel settings update with empty body", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPut, "/api/v1/channels/5/settings", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + assert.NotPanics(t, func() { + recordBusinessMetrics(ctx, bm, c, []byte{}, http.StatusOK, duration) + }) + }) + + t.Run("channel member add with malformed body", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/channels/5/members", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + assert.NotPanics(t, func() { + recordBusinessMetrics(ctx, bm, c, []byte("not json"), http.StatusOK, duration) + }) + }) + + t.Run("500 error does not break metrics", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodPost, "/api/v1/register", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + assert.NotPanics(t, func() { + recordBusinessMetrics(ctx, bm, c, []byte(`{"username":"u"}`), http.StatusInternalServerError, duration) + }) + }) +} + +func TestBusinessMetrics_ExtractorFunctions(t *testing.T) { + t.Run("extractUserID", func(t *testing.T) { + tests := []struct { + name string + setup func(echo.Context) + expected int32 + }{ + { + name: "no user_id in context", + setup: func(_ echo.Context) {}, + expected: 0, + }, + { + name: "int32 user_id", + setup: func(c echo.Context) { + c.Set("user_id", int32(42)) + }, + expected: 42, + }, + { + name: "int user_id", + setup: func(c echo.Context) { + c.Set("user_id", 99) + }, + expected: 99, + }, + { + name: "string user_id", + setup: func(c echo.Context) { + c.Set("user_id", "123") + }, + expected: 123, + }, + { + name: "invalid string user_id", + setup: func(c echo.Context) { + c.Set("user_id", "not-a-number") + }, + expected: 0, + }, + { + name: "unsupported type", + setup: func(c echo.Context) { + c.Set("user_id", float64(5.5)) + }, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + tt.setup(c) + + result := extractUserID(c) + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("extractChannelID", func(t *testing.T) { + tests := []struct { + name string + paramName string + paramValue string + expected int32 + }{ + { + name: "valid id param", + paramName: "id", + paramValue: "42", + expected: 42, + }, + { + name: "valid channel_id param", + paramName: "channel_id", + paramValue: "99", + expected: 99, + }, + { + name: "no param", + expected: 0, + }, + { + name: "invalid param", + paramName: "id", + paramValue: "abc", + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if tt.paramName != "" { + c.SetParamNames(tt.paramName) + c.SetParamValues(tt.paramValue) + } + + result := extractChannelID(c) + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("extractAccessLevel", func(t *testing.T) { + tests := []struct { + name string + setup func(echo.Context) + expected int + }{ + { + name: "default access level", + setup: func(_ echo.Context) {}, + expected: 100, + }, + { + name: "custom access level", + setup: func(c echo.Context) { + c.Set("access_level", 500) + }, + expected: 500, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + tt.setup(c) + + result := extractAccessLevel(c) + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("extractRegistrationInfo", func(t *testing.T) { + tests := []struct { + name string + body string + expectUser string + expectEmail string + }{ + { + name: "valid registration body", + body: `{"username":"newuser","email":"new@example.com"}`, + expectUser: "newuser", + expectEmail: "new@example.com", + }, + { + name: "empty body", + body: "", + expectUser: "", + expectEmail: "", + }, + { + name: "invalid JSON", + body: "not json", + expectUser: "", + expectEmail: "", + }, + { + name: "missing fields", + body: `{"other":"value"}`, + expectUser: "", + expectEmail: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + username, email := extractRegistrationInfo([]byte(tt.body)) + assert.Equal(t, tt.expectUser, username) + assert.Equal(t, tt.expectEmail, email) + }) + } + }) + + t.Run("extractUsernameFromActivation", func(t *testing.T) { + tests := []struct { + name string + body string + expected string + }{ + { + name: "valid body", + body: `{"username":"activateuser"}`, + expected: "activateuser", + }, + { + name: "empty body", + body: "", + expected: "", + }, + { + name: "invalid JSON", + body: "not json", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractUsernameFromActivation([]byte(tt.body)) + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("extractResultCount", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + rec := httptest.NewRecorder() + + c := e.NewContext(req, rec) + assert.Equal(t, 0, extractResultCount(c)) + + c.Set("result_count", 25) + assert.Equal(t, 25, extractResultCount(c)) + }) + + t.Run("extractUpdatedFields", func(t *testing.T) { + tests := []struct { + name string + body string + hasKeys bool + }{ + { + name: "valid body with fields", + body: `{"description":"test","url":"https://example.com"}`, + hasKeys: true, + }, + { + name: "empty JSON", + body: `{}`, + hasKeys: false, + }, + { + name: "invalid JSON", + body: "not json", + hasKeys: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fields := extractUpdatedFields([]byte(tt.body)) + if tt.hasKeys { + assert.NotEmpty(t, fields) + } else { + assert.Empty(t, fields) + } + }) + } + }) + + t.Run("extractTargetUserID", func(t *testing.T) { + tests := []struct { + name string + body string + expected int32 + }{ + { + name: "numeric user_id", + body: `{"user_id":42}`, + expected: 42, + }, + { + name: "string user_id", + body: `{"user_id":"99"}`, + expected: 99, + }, + { + name: "missing user_id", + body: `{"other":"value"}`, + expected: 0, + }, + { + name: "invalid JSON", + body: "not json", + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractTargetUserID([]byte(tt.body)) + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("extractTargetUserIDFromPath", func(t *testing.T) { + e := echo.New() + req := httptest.NewRequest(http.MethodDelete, "/channels/5/members/20", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + c.SetParamNames("id", "user_id") + c.SetParamValues("5", "20") + + assert.Equal(t, int32(20), extractTargetUserIDFromPath(c)) + + // No param + c2 := e.NewContext(httptest.NewRequest(http.MethodDelete, "/test", nil), httptest.NewRecorder()) + assert.Equal(t, int32(0), extractTargetUserIDFromPath(c2)) + }) + + t.Run("extractMemberAccessLevel", func(t *testing.T) { + tests := []struct { + name string + body string + expected int + }{ + { + name: "valid access_level", + body: `{"access_level":200}`, + expected: 200, + }, + { + name: "missing access_level", + body: `{"user_id":42}`, + expected: 100, + }, + { + name: "invalid JSON", + body: "not json", + expected: 100, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractMemberAccessLevel([]byte(tt.body)) + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("getRegistrationReason", func(t *testing.T) { + assert.Equal(t, "success", getRegistrationReason(200, true)) + assert.Equal(t, "invalid_data", getRegistrationReason(400, false)) + assert.Equal(t, "username_or_email_exists", getRegistrationReason(409, false)) + assert.Equal(t, "validation_failed", getRegistrationReason(422, false)) + assert.Equal(t, "server_error", getRegistrationReason(500, false)) + }) + + t.Run("getActivationReason", func(t *testing.T) { + assert.Equal(t, "success", getActivationReason(200, true)) + assert.Equal(t, "invalid_token", getActivationReason(400, false)) + assert.Equal(t, "token_not_found", getActivationReason(404, false)) + assert.Equal(t, "token_expired", getActivationReason(410, false)) + assert.Equal(t, "server_error", getActivationReason(500, false)) + }) + + t.Run("getOperationType", func(t *testing.T) { + tests := []struct { + path string + method string + expected string + }{ + {"/api/v1/register", "POST", "user_registration"}, + {"/api/v1/activate", "POST", "user_activation"}, + {"/api/v1/login", "POST", "user_login"}, + {"/api/v1/logout", "POST", "user_logout"}, + {"/api/v1/channels/search", "GET", "channel_search"}, + {"/api/v1/channels/5/settings", "GET", "channel_settings_view"}, + {"/api/v1/channels/5/settings", "PUT", "channel_settings_update"}, + {"/api/v1/channels/5/members", "POST", "channel_member_add"}, + {"/api/v1/channels/5/members/20", "DELETE", "channel_member_remove"}, + {"/api/v1/users/me", "GET", "general_api"}, + } + + for _, tt := range tests { + t.Run(tt.path+"_"+tt.method, func(t *testing.T) { + result := getOperationType(tt.path, tt.method) + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("getFeatureName", func(t *testing.T) { + tests := []struct { + path string + method string + expected string + }{ + {"/api/v1/register", "POST", "user_registration"}, + {"/api/v1/activate", "POST", "user_activation"}, + {"/api/v1/channels/search", "GET", "channel_search"}, + {"/api/v1/channels/5/settings", "PUT", "channel_settings"}, + {"/api/v1/channels/5/members", "POST", "channel_members"}, + {"/api/v1/users/me", "GET", ""}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + result := getFeatureName(tt.path, tt.method) + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("calculateErrorRate", func(t *testing.T) { + assert.Equal(t, 0.0, calculateErrorRate(200)) + assert.Equal(t, 0.0, calculateErrorRate(301)) + assert.Equal(t, 100.0, calculateErrorRate(400)) + assert.Equal(t, 100.0, calculateErrorRate(404)) + assert.Equal(t, 100.0, calculateErrorRate(500)) + }) +} diff --git a/middlewares/combined_auth_security_test.go b/middlewares/combined_auth_security_test.go new file mode 100644 index 0000000..76494b4 --- /dev/null +++ b/middlewares/combined_auth_security_test.go @@ -0,0 +1,680 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2025 UnderNET + +package middlewares + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/undernetirc/cservice-api/db/mocks" + "github.com/undernetirc/cservice-api/models" +) + +// TestCombinedAuth_SecurityBoundary_EmptyAuthHeader verifies that an empty Authorization +// header (header present with empty value) is treated the same as absent. +// Attack vector: sending "Authorization: " to skip JWT validation while avoiding +// the no-header check. +func TestCombinedAuth_SecurityBoundary_EmptyAuthHeader(t *testing.T) { + t.Run("empty Authorization header with JWT-only returns 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, true, false, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + + t.Run("empty Authorization header with both methods falls through to 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, true, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "") + // No X-API-Key header either + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) +} + +// TestCombinedAuth_SecurityBoundary_BearerEdgeCases tests malformed Bearer token formats. +func TestCombinedAuth_SecurityBoundary_BearerEdgeCases(t *testing.T) { + t.Run("Bearer with no token returns 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, true, false, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer ") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + + t.Run("Basic auth scheme is rejected by JWT-only config", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, true, false, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // Basic auth with valid-looking base64 credentials + req.Header.Set("Authorization", "Basic dXNlcjpwYXNzd29yZA==") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err, "Basic scheme should be rejected by JWT-only config") + }) + + t.Run("Authorization header with SQL injection pattern is rejected", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, true, false, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // Attempt SQL injection via Authorization header + req.Header.Set("Authorization", "Bearer '; DROP TABLE users; --") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + assert.NotPanics(t, func() { + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + }) +} + +// TestCombinedAuth_SecurityBoundary_NoAuthMethodsEnabled tests the degenerate case +// where both AllowJWT and AllowAPIKey are false. +func TestCombinedAuth_SecurityBoundary_NoAuthMethodsEnabled(t *testing.T) { + t.Run("no auth methods with Required=true always denies regardless of headers", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, false, false, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer some-token") + req.Header.Set("X-API-Key", "some-api-key") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + + t.Run("no auth methods with Required=false passes through", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, false, false, false) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("Authorization", "Bearer some-token") + req.Header.Set("X-API-Key", "some-api-key") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + }) +} + +// TestCombinedAuth_SecurityBoundary_DeletedKey tests whether a soft-deleted API key +// is correctly rejected. +// +// POTENTIAL BUG: The current authenticateAPIKey implementation does not check the +// Deleted field on the returned ApiKey row. If GetAPIKeyByHash returns a deleted key +// (because the SQL query does not filter it), the key will be authenticated. +// This test documents the expected behavior (rejection). +// +// NOTE: The SQL query in GetAPIKeyByHash filters out deleted keys at the DB level, +// so in practice deleted keys are never returned. This test uses mocks that bypass +// the SQL filter, so the middleware-level gap is documented but not exploitable. +func TestCombinedAuth_SecurityBoundary_DeletedKey(t *testing.T) { + t.Run("soft-deleted API key behavior documented", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + deletedKey := newValidAPIKeyRow(nil, nil) + deletedKey.Deleted = pgtype.Int2{Int16: 1, Valid: true} + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(deletedKey, nil).Once() + // The middleware spawns a background goroutine for UpdateAPIKeyLastUsed + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + // KNOWN GAP: authenticateAPIKey does not check the Deleted flag at the Go level. + // The SQL query filters deleted keys, so this is defense-in-depth only. + // If GetAPIKeyByHash ever changes to return deleted keys, add: + // if apiKey.Deleted.Valid && apiKey.Deleted.Int16 != 0 { return false, nil, nil } + if err == nil { + t.Log("SECURITY NOTE: authenticateAPIKey does not check Deleted flag; relies on SQL filter") + } + // Allow background goroutine to complete + time.Sleep(10 * time.Millisecond) + }) +} + +// TestCombinedAuth_SecurityBoundary_ExpirationBoundary tests edge cases in expiration logic. +// The current check is: if int64(ExpiresAt) < now → expired. +// This means ExpiresAt == now is NOT expired (off-by-one at the boundary second). +func TestCombinedAuth_SecurityBoundary_ExpirationBoundary(t *testing.T) { + t.Run("key expiring exactly at current second is treated as valid (boundary condition)", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + exactlyNow := int32(time.Now().Unix()) + apiKeyRow := newValidAPIKeyRow(nil, nil) + apiKeyRow.ExpiresAt = pgtype.Int4{Int32: exactlyNow, Valid: true} + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + // Document actual boundary behavior: ExpiresAt == now uses strict "<" so this is NOT expired. + // This may be intentional or a bug depending on design intent. + // The test logs the behavior rather than asserting a specific outcome. + t.Logf("Key at exact expiry second: err=%v, statusCode=%d (nil err = accepted, ErrUnauthorized = rejected)", + err, rec.Code) + }) + + t.Run("key expiring 1 second in the future is valid", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + oneSecondFuture := int32(time.Now().Unix() + 1) + apiKeyRow := newValidAPIKeyRow(nil, nil) + apiKeyRow.ExpiresAt = pgtype.Int4{Int32: oneSecondFuture, Valid: true} + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.NoError(t, err, "key expiring in 1 second should still be valid") + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("key expired 1 second ago is rejected", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + oneSecondAgo := int32(time.Now().Unix() - 1) + apiKeyRow := newValidAPIKeyRow(nil, nil) + apiKeyRow.ExpiresAt = pgtype.Int4{Int32: oneSecondAgo, Valid: true} + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err, "key expired 1 second ago must be rejected") + }) + + t.Run("key with ExpiresAt=0 and Valid=true is treated as no expiry", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + // ExpiresAt.Int32 = 0, Valid = true: the condition `apiKey.ExpiresAt.Int32 > 0` is false, + // so expiry check is skipped. This is tested to document the behavior. + apiKeyRow := newValidAPIKeyRow(nil, nil) + apiKeyRow.ExpiresAt = pgtype.Int4{Int32: 0, Valid: true} + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.NoError(t, err, "ExpiresAt=0 with Valid=true should bypass expiry check and accept the key") + }) +} + +// TestCombinedAuth_SecurityBoundary_VeryLongAPIKey tests that an oversized API key +// does not cause panics or resource exhaustion. SHA-256 hashing handles arbitrary input. +func TestCombinedAuth_SecurityBoundary_VeryLongAPIKey(t *testing.T) { + t.Run("1MB API key is hashed and fails gracefully", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("not found")).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + veryLongKey := strings.Repeat("a", 1024*1024) + req.Header.Set("X-API-Key", veryLongKey) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + assert.NotPanics(t, func() { + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + }) +} + +// TestCombinedAuth_SecurityBoundary_AdversarialAPIKeys tests API keys with adversarial +// content. All inputs are SHA-256 hashed before DB lookup, so injection is not possible, +// but the system must not panic or behave unexpectedly. +func TestCombinedAuth_SecurityBoundary_AdversarialAPIKeys(t *testing.T) { + adversarialKeys := []struct { + name string + key string + }{ + {"null bytes in key", "key\x00withNull\x00bytes"}, + {"SQL injection attempt", "'; DROP TABLE api_keys; --"}, + {"unicode characters", "测试-api-key-值"}, + {"newline header injection attempt", "key\r\nX-Injected: evil"}, + {"tab characters", "key\twith\ttabs"}, + {"only whitespace", " \t\n "}, + {"zero-width unicode", "key\u200b\u200czero\u200dwidth"}, + {"path traversal attempt", "../../../../etc/passwd"}, + {"null key", "\x00\x00\x00\x00"}, + } + + for _, tc := range adversarialKeys { + t.Run(tc.name, func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("not found")).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", tc.key) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + assert.NotPanics(t, func() { + err := handler(c) + // All adversarial keys must be rejected + assert.Equal(t, echo.ErrUnauthorized, err, + "adversarial API key %q must not authenticate", tc.name) + }) + }) + } +} + +// TestCombinedAuth_SecurityBoundary_MalformedIPRestrictions tests behavior when +// the IP restriction data stored in the database is malformed. +func TestCombinedAuth_SecurityBoundary_MalformedIPRestrictions(t *testing.T) { + t.Run("malformed IP restrictions JSON in database causes auth failure", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + apiKeyRow := newValidAPIKeyRow(nil, []byte(`this is not valid json`)) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + + t.Run("invalid IP address in X-Real-IP header with IP-restricted key causes auth failure", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + ipRestrictionsJSON, _ := json.Marshal([]string{"192.168.1.0/24"}) + apiKeyRow := newValidAPIKeyRow(nil, ipRestrictionsJSON) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + req.Header.Set("X-Real-IP", "not-a-valid-ip-address") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + // Invalid client IP should result in auth failure + assert.Equal(t, echo.ErrUnauthorized, err) + }) + + t.Run("malformed CIDR in database IP restrictions causes auth failure", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + // Valid JSON array, but contents are not valid CIDR notation + malformedCIDRJSON, _ := json.Marshal([]string{"not-a-cidr", "256.256.256.256/99"}) + apiKeyRow := newValidAPIKeyRow(nil, malformedCIDRJSON) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + req.Header.Set("X-Real-IP", "192.168.1.1") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + assert.NotPanics(t, func() { + err := handler(c) + // Malformed CIDR in DB should cause auth to fail (error path) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + }) +} + +// TestCombinedAuth_SecurityBoundary_ConcurrentRequests verifies that the middleware +// is safe to use concurrently (no data races, no panics under load). +func TestCombinedAuth_SecurityBoundary_ConcurrentRequests(t *testing.T) { + t.Run("50 concurrent API key validations are goroutine-safe", func(t *testing.T) { + const numGoroutines = 50 + + mockService := mocks.NewServiceInterface(t) + scopesJSON, _ := json.Marshal([]string{"read"}) + apiKeyRow := newValidAPIKeyRow(scopesJSON, nil) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Times(numGoroutines) + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + var wg sync.WaitGroup + authErrors := make([]error, numGoroutines) + + for i := range numGoroutines { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + authErrors[idx] = handler(c) + }(i) + } + + wg.Wait() + + for i, err := range authErrors { + assert.NoError(t, err, "goroutine %d should authenticate successfully", i) + } + }) + + t.Run("concurrent requests with mixed valid/invalid keys handle correctly", func(t *testing.T) { + const numValid = 25 + const numInvalid = 25 + + mockService := mocks.NewServiceInterface(t) + scopesJSON, _ := json.Marshal([]string{"read"}) + validRow := newValidAPIKeyRow(scopesJSON, nil) + + // Valid key lookups succeed + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(validRow, nil).Times(numValid) + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + var wg sync.WaitGroup + results := make([]error, numValid+numInvalid) + + for i := range numValid { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + results[idx] = handler(c) + }(i) + } + + // Invalid key goroutines use a separate mock setup + mockService2 := mocks.NewServiceInterface(t) + mockService2.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("not found")).Times(numInvalid) + cfg2 := newTestCombinedAuthConfig(mockService2, false, true, true) + + for i := range numInvalid { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "invalid-key-"+strings.Repeat("x", idx)) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + handler := CombinedAuth(cfg2)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + results[numValid+idx] = handler(c) + }(i) + } + + wg.Wait() + + // Valid requests should succeed + for i := range numValid { + assert.NoError(t, results[i], "valid request %d should succeed", i) + } + // Invalid requests should fail + for i := range numInvalid { + assert.Equal(t, echo.ErrUnauthorized, results[numValid+i], + "invalid request %d should return 401", i) + } + }) +} + +// TestAuthenticateAPIKey_SecurityBoundary_AdversarialInputs directly tests the +// authenticateAPIKey function with adversarial inputs. +func TestAuthenticateAPIKey_SecurityBoundary_AdversarialInputs(t *testing.T) { + t.Run("empty key string is hashed and returns not found", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("not found")).Once() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + // Empty key string — SHA-256("") is still a valid hash + authenticated, keyCtx, err := authenticateAPIKey(c, mockService, "") + + assert.False(t, authenticated) + assert.Nil(t, keyCtx) + assert.Error(t, err, "empty key should fail DB lookup") + }) + + t.Run("key with only whitespace is hashed and rejected", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("not found")).Once() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + authenticated, keyCtx, err := authenticateAPIKey(c, mockService, " \t\n ") + + assert.False(t, authenticated) + assert.Nil(t, keyCtx) + assert.Error(t, err) + }) + + t.Run("key with deleted flag behavior documented", func(t *testing.T) { + // Documents the current behavior: authenticateAPIKey does NOT check Deleted flag. + // The SQL query filters deleted keys at the DB level, so this is defense-in-depth. + mockService := mocks.NewServiceInterface(t) + + scopesJSON, _ := json.Marshal([]string{"read"}) + deletedKey := newValidAPIKeyRow(scopesJSON, nil) + deletedKey.Deleted = pgtype.Int2{Int16: 1, Valid: true} + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(deletedKey, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + authenticated, keyCtx, err := authenticateAPIKey(c, mockService, validAPIKey()) + + // KNOWN GAP: authenticateAPIKey does not check Deleted flag at Go level. + // SQL query filters deleted keys, making this defense-in-depth only. + t.Logf("Deleted key: authenticated=%v, ctx=%v, err=%v", authenticated, keyCtx, err) + if authenticated { + t.Log("SECURITY NOTE: authenticateAPIKey does not check Deleted flag; relies on SQL filter") + } + // Allow background goroutine to complete + time.Sleep(10 * time.Millisecond) + }) +} diff --git a/middlewares/combined_auth_test.go b/middlewares/combined_auth_test.go new file mode 100644 index 0000000..3ca6470 --- /dev/null +++ b/middlewares/combined_auth_test.go @@ -0,0 +1,546 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: Copyright (c) 2025 UnderNET + +package middlewares + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + echojwt "github.com/labstack/echo-jwt/v4" + "github.com/labstack/echo/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/undernetirc/cservice-api/db/mocks" + "github.com/undernetirc/cservice-api/internal/helper" + "github.com/undernetirc/cservice-api/models" +) + +func TestDefaultCombinedAuthConfig(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := DefaultCombinedAuthConfig(mockService) + + assert.True(t, cfg.AllowJWT, "AllowJWT should default to true") + assert.True(t, cfg.AllowAPIKey, "AllowAPIKey should default to true") + assert.True(t, cfg.Required, "Required should default to true") + assert.Equal(t, mockService, cfg.Service, "Service should be set") +} + +// newTestCombinedAuthConfig creates a CombinedAuthConfig for testing with API key support only. +// JWT is disabled to avoid needing real RSA keys in unit tests. +func newTestCombinedAuthConfig(service models.ServiceInterface, allowJWT, allowAPIKey, required bool) CombinedAuthConfig { + return CombinedAuthConfig{ + AllowJWT: allowJWT, + AllowAPIKey: allowAPIKey, + Required: required, + JWTConfig: echojwt.Config{ + SigningKey: []byte("test-secret-key-for-unit-tests"), + }, + Service: service, + } +} + +// validAPIKey returns a test API key and its SHA-256 hash for use in tests. +func validAPIKey() string { + return "test-api-key-12345" +} + +// newValidAPIKeyRow returns a models.ApiKey representing a valid, non-expired key. +func newValidAPIKeyRow(scopes []byte, ipRestrictions []byte) models.ApiKey { + return models.ApiKey{ + ID: 1, + Name: "test-key", + KeyHash: "hashed", + Scopes: scopes, + CreatedBy: 1, + CreatedAt: int32(time.Now().Unix()), + ExpiresAt: pgtype.Int4{Int32: 0, Valid: false}, + Deleted: pgtype.Int2{Int16: 0, Valid: true}, + IpRestrictions: ipRestrictions, + } +} + +func TestCombinedAuth_JWTOnly(t *testing.T) { + t.Run("no auth header with required=true returns 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, true, false, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + + t.Run("no auth header with required=false continues", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, true, false, false) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + }) +} + +func TestCombinedAuth_APIKeyOnly(t *testing.T) { + t.Run("valid API key authenticates successfully", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + scopesJSON, _ := json.Marshal([]string{"channels:read", "users:write"}) + apiKeyRow := newValidAPIKeyRow(scopesJSON, nil) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var capturedUser any + handler := CombinedAuth(cfg)(func(c echo.Context) error { + capturedUser = c.Get("user") + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + apiKeyCtx, ok := capturedUser.(*helper.APIKeyContext) + require.True(t, ok, "user context should be *helper.APIKeyContext") + assert.Equal(t, int32(1), apiKeyCtx.ID) + assert.Equal(t, "test-key", apiKeyCtx.Name) + assert.Equal(t, []string{"channels:read", "users:write"}, apiKeyCtx.Scopes) + assert.True(t, apiKeyCtx.IsAPIKey) + }) + + t.Run("no API key with required=true returns 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) +} + +func TestCombinedAuth_BothPresent(t *testing.T) { + t.Run("JWT fails, falls back to valid API key", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + scopesJSON, _ := json.Marshal([]string{"read"}) + apiKeyRow := newValidAPIKeyRow(scopesJSON, nil) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, true, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + // Set both headers - JWT will fail (no valid signing key configured), API key should succeed + req.Header.Set("Authorization", "Bearer invalid-jwt-token") + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var capturedUser any + handler := CombinedAuth(cfg)(func(c echo.Context) error { + capturedUser = c.Get("user") + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + require.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + + apiKeyCtx, ok := capturedUser.(*helper.APIKeyContext) + require.True(t, ok, "should fall back to API key auth") + assert.Equal(t, int32(1), apiKeyCtx.ID) + }) +} + +func TestCombinedAuth_InvalidAPIKey(t *testing.T) { + t.Run("key not found in database returns 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("no rows")).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", "nonexistent-key") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) +} + +func TestCombinedAuth_ExpiredKey(t *testing.T) { + t.Run("expired API key returns 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + pastTime := int32(time.Now().Add(-1 * time.Hour).Unix()) + apiKeyRow := newValidAPIKeyRow(nil, nil) + apiKeyRow.ExpiresAt = pgtype.Int4{Int32: pastTime, Valid: true} + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) +} + +func TestCombinedAuth_ScopeValidation(t *testing.T) { + t.Run("scopes are correctly parsed and set on context", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + scopes := []string{"channels:read", "users:write", "admin:manage"} + scopesJSON, _ := json.Marshal(scopes) + apiKeyRow := newValidAPIKeyRow(scopesJSON, nil) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var capturedUser any + handler := CombinedAuth(cfg)(func(c echo.Context) error { + capturedUser = c.Get("user") + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + require.NoError(t, err) + + apiKeyCtx, ok := capturedUser.(*helper.APIKeyContext) + require.True(t, ok) + assert.Equal(t, scopes, apiKeyCtx.Scopes) + }) + + t.Run("empty scopes results in nil scopes slice", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + apiKeyRow := newValidAPIKeyRow(nil, nil) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + var capturedUser any + handler := CombinedAuth(cfg)(func(c echo.Context) error { + capturedUser = c.Get("user") + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + require.NoError(t, err) + + apiKeyCtx, ok := capturedUser.(*helper.APIKeyContext) + require.True(t, ok) + assert.Nil(t, apiKeyCtx.Scopes) + }) +} + +func TestCombinedAuth_IPRestriction(t *testing.T) { + t.Run("allowed IP passes", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + ipRestrictionsJSON, _ := json.Marshal([]string{"192.168.1.0/24"}) + scopesJSON, _ := json.Marshal([]string{"read"}) + apiKeyRow := newValidAPIKeyRow(scopesJSON, ipRestrictionsJSON) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + mockService.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + req.Header.Set("X-Real-IP", "192.168.1.50") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + }) + + t.Run("denied IP returns 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + + ipRestrictionsJSON, _ := json.Marshal([]string{"192.168.1.0/24"}) + apiKeyRow := newValidAPIKeyRow(nil, ipRestrictionsJSON) + + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(apiKeyRow, nil).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + req.Header.Set("X-Real-IP", "10.0.0.1") + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) +} + +func TestAuthenticateAPIKey(t *testing.T) { + tests := []struct { + name string + setupMock func(m *mocks.ServiceInterface) + clientIP string + wantAuth bool + wantContext bool + wantErr bool + }{ + { + name: "valid key returns authenticated context", + setupMock: func(m *mocks.ServiceInterface) { + scopesJSON, _ := json.Marshal([]string{"read", "write"}) + m.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(newValidAPIKeyRow(scopesJSON, nil), nil).Once() + m.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + }, + wantAuth: true, + wantContext: true, + wantErr: false, + }, + { + name: "database error returns error", + setupMock: func(m *mocks.ServiceInterface) { + m.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("connection refused")).Once() + }, + wantAuth: false, + wantContext: false, + wantErr: true, + }, + { + name: "expired key returns false with no error", + setupMock: func(m *mocks.ServiceInterface) { + row := newValidAPIKeyRow(nil, nil) + row.ExpiresAt = pgtype.Int4{Int32: int32(time.Now().Add(-1 * time.Hour).Unix()), Valid: true} + m.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(row, nil).Once() + }, + wantAuth: false, + wantContext: false, + wantErr: false, + }, + { + name: "IP allowed passes restriction check", + setupMock: func(m *mocks.ServiceInterface) { + ipJSON, _ := json.Marshal([]string{"10.0.0.0/8"}) + m.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(newValidAPIKeyRow(nil, ipJSON), nil).Once() + m.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + }, + clientIP: "10.0.0.5", + wantAuth: true, + wantContext: true, + wantErr: false, + }, + { + name: "IP denied fails restriction check", + setupMock: func(m *mocks.ServiceInterface) { + ipJSON, _ := json.Marshal([]string{"10.0.0.0/8"}) + m.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(newValidAPIKeyRow(nil, ipJSON), nil).Once() + }, + clientIP: "172.16.0.1", + wantAuth: false, + wantContext: false, + wantErr: false, + }, + { + name: "invalid scope JSON returns error", + setupMock: func(m *mocks.ServiceInterface) { + row := newValidAPIKeyRow([]byte(`not valid json`), nil) + m.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(row, nil).Once() + }, + wantAuth: false, + wantContext: false, + wantErr: true, + }, + { + name: "empty scopes returns context with nil scopes", + setupMock: func(m *mocks.ServiceInterface) { + m.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(newValidAPIKeyRow(nil, nil), nil).Once() + m.On("UpdateAPIKeyLastUsed", mock.Anything, mock.Anything, mock.Anything). + Return(nil).Maybe() + }, + wantAuth: true, + wantContext: true, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + tt.setupMock(mockService) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tt.clientIP != "" { + req.Header.Set("X-Real-IP", tt.clientIP) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + authenticated, keyCtx, err := authenticateAPIKey(c, mockService, validAPIKey()) + + assert.Equal(t, tt.wantAuth, authenticated) + if tt.wantContext { + assert.NotNil(t, keyCtx) + assert.True(t, keyCtx.IsAPIKey) + } else { + assert.Nil(t, keyCtx) + } + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestCombinedAuth_DatabaseError(t *testing.T) { + t.Run("database error with required=true returns 401", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("database unavailable")).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, true) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.Equal(t, echo.ErrUnauthorized, err) + }) + + t.Run("database error with required=false continues", func(t *testing.T) { + mockService := mocks.NewServiceInterface(t) + mockService.On("GetAPIKeyByHash", mock.Anything, mock.AnythingOfType("string")). + Return(models.ApiKey{}, errors.New("database unavailable")).Once() + + cfg := newTestCombinedAuthConfig(mockService, false, true, false) + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, "/", nil) + req.Header.Set("X-API-Key", validAPIKey()) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + handler := CombinedAuth(cfg)(func(c echo.Context) error { + return c.String(http.StatusOK, "success") + }) + + err := handler(c) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, rec.Code) + }) +}