diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index 24cf755b49..ac8eeeb532 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -31,6 +31,7 @@ type OpenGraphSchemaService interface { type GraphSchemaExtension struct { Environments []Environment `json:"environments"` + Findings []Finding `json:"findings"` } type Environment struct { @@ -39,6 +40,22 @@ type Environment struct { PrincipalKinds []string `json:"principalKinds"` } +type Finding struct { + Name string `json:"name"` + DisplayName string `json:"displayName"` + SourceKind string `json:"sourceKind"` + RelationshipKind string `json:"relationshipKind"` + EnvironmentKind string `json:"environmentKind"` + Remediation Remediation `json:"remediation"` +} + +type Remediation struct { + ShortDescription string `json:"shortDescription"` + LongDescription string `json:"longDescription"` + ShortRemediation string `json:"shortRemediation"` + LongRemediation string `json:"longRemediation"` +} + // TODO: Implement this - skeleton endpoint to simply test the handler. func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { var ( diff --git a/cmd/api/src/config/config.go b/cmd/api/src/config/config.go index dd84db8406..2ee7652fd1 100644 --- a/cmd/api/src/config/config.go +++ b/cmd/api/src/config/config.go @@ -159,7 +159,7 @@ type Configuration struct { EnableAPILogging bool `json:"enable_api_logging"` EnableCypherMutations bool `json:"enable_cypher_mutations"` DisableAnalysis bool `json:"disable_analysis"` - DisableAPIKeys bool `json:"disable_api_keys"` + DisableAPIKeys bool `json:"disable_api_keys"` DisableCypherComplexityLimit bool `json:"disable_cypher_complexity_limit"` DisableIngest bool `json:"disable_ingest"` DisableMigrations bool `json:"disable_migrations"` diff --git a/cmd/api/src/config/default.go b/cmd/api/src/config/default.go index 6b35df66ca..cf1c3ac903 100644 --- a/cmd/api/src/config/default.go +++ b/cmd/api/src/config/default.go @@ -62,7 +62,7 @@ func NewDefaultConfiguration() (Configuration, error) { EnableStartupWaitPeriod: true, EnableAPILogging: true, DisableAnalysis: false, - DisableAPIKeys: false, + DisableAPIKeys: false, DisableCypherComplexityLimit: false, DisableIngest: false, DisableMigrations: false, diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 3588f3f79b..0819e88257 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -60,6 +60,7 @@ type OpenGraphSchema interface { CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) + GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error CreateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) @@ -677,6 +678,23 @@ func (s *BloodhoundDB) GetSchemaRelationshipFindingById(ctx context.Context, fin return finding, nil } +// GetSchemaRelationshipFindingByName - retrieves a schema relationship finding by finding name. +func (s *BloodhoundDB) GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { + var finding model.SchemaRelationshipFinding + + if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` + SELECT id, schema_extension_id, relationship_kind_id, environment_id, name, display_name, created_at + FROM %s WHERE name = ?`, + finding.TableName()), + name).Scan(&finding); result.Error != nil { + return model.SchemaRelationshipFinding{}, CheckError(result) + } else if result.RowsAffected == 0 { + return model.SchemaRelationshipFinding{}, ErrNotFound + } + + return finding, nil +} + // DeleteSchemaRelationshipFinding - deletes a schema relationship finding by id. func (s *BloodhoundDB) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { var finding model.SchemaRelationshipFinding @@ -793,6 +811,9 @@ func (s *BloodhoundDB) CreatePrincipalKind(ctx context.Context, environmentId in VALUES (?, ?, NOW()) RETURNING environment_id, principal_kind, created_at`, environmentId, principalKind).Scan(&envPrincipalKind); result.Error != nil { + if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) { + return model.SchemaEnvironmentPrincipalKind{}, result.Error + } return model.SchemaEnvironmentPrincipalKind{}, CheckError(result) } diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index c23d6e150f..27d9762599 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2262,6 +2262,21 @@ func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingById(ctx, findin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingById), ctx, findingId) } +// GetSchemaRelationshipFindingByName mocks base method. +func (m *MockDatabase) GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingByName", ctx, name) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaRelationshipFindingByName indicates an expected call of GetSchemaRelationshipFindingByName. +func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingByName", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingByName), ctx, name) +} + // GetScopeForSavedQuery mocks base method. func (m *MockDatabase) GetScopeForSavedQuery(ctx context.Context, queryID int64, userID uuid.UUID) (database.SavedQueryScopeMap, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index 180ec7702a..f266623d9d 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -31,7 +31,7 @@ func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Con SchemaExtensionId: schemaExtensionId, } - envKind, err := s.validateAndTranslateEnvironmentKind(ctx, environmentKind) + envKindID, err := s.validateAndTranslateEnvironmentKind(ctx, environmentKind) if err != nil { return err } @@ -46,7 +46,7 @@ func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Con return err } - environment.EnvironmentKindId = int32(envKind.ID) + environment.EnvironmentKindId = envKindID environment.SourceKindId = sourceKindID envID, err := s.replaceSchemaEnvironment(ctx, environment) @@ -62,13 +62,13 @@ func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Con } // validateAndTranslateEnvironmentKind validates that the environment kind exists in the kinds table. -func (s *BloodhoundDB) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { +func (s *BloodhoundDB) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (int32, error) { if envKind, err := s.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, ErrNotFound) { - return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) + return 0, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) } else if errors.Is(err, ErrNotFound) { - return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) + return 0, fmt.Errorf("environment kind '%s' not found", environmentKindName) } else { - return envKind, nil + return envKind.ID, nil } } diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go index d189a1d40f..11dcabe02f 100644 --- a/cmd/api/src/database/upsert_schema_extension.go +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -26,7 +26,23 @@ type EnvironmentInput struct { PrincipalKinds []string } -func (s *BloodhoundDB) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []EnvironmentInput) error { +type FindingInput struct { + Name string + DisplayName string + RelationshipKindName string + EnvironmentKindName string + SourceKindName string + RemediationInput RemediationInput +} + +type RemediationInput struct { + ShortDescription string + LongDescription string + ShortRemediation string + LongRemediation string +} + +func (s *BloodhoundDB) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []EnvironmentInput, findings []FindingInput) error { return s.Transaction(ctx, func(tx *BloodhoundDB) error { for _, env := range environments { if err := tx.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, extensionID, env.EnvironmentKindName, env.SourceKindName, env.PrincipalKinds); err != nil { @@ -34,6 +50,16 @@ func (s *BloodhoundDB) UpsertGraphSchemaExtension(ctx context.Context, extension } } + for _, finding := range findings { + if schemaFinding, err := tx.UpsertFinding(ctx, extensionID, finding.SourceKindName, finding.RelationshipKindName, finding.EnvironmentKindName, finding.Name, finding.DisplayName); err != nil { + return fmt.Errorf("failed to upsert finding: %w", err) + } else { + if err := tx.UpsertRemediation(ctx, schemaFinding.ID, finding.RemediationInput.ShortDescription, finding.RemediationInput.LongDescription, finding.RemediationInput.ShortRemediation, finding.RemediationInput.LongRemediation); err != nil { + return fmt.Errorf("failed to upsert remediation: %w", err) + } + } + } + return nil }) } diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index e4fcae0287..6c67d9c682 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -22,7 +22,6 @@ import ( "testing" "github.com/specterops/bloodhound/cmd/api/src/database" - "github.com/specterops/dawgs/graph" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -30,66 +29,23 @@ import ( func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { type args struct { environments []database.EnvironmentInput + findings []database.FindingInput } tests := []struct { name string - setupData func(t *testing.T, db *database.BloodhoundDB) int32 + setupData func(t *testing.T, db *database.BloodhoundDB) int32 // Returns extensionId args args - assert func(t *testing.T, db *database.BloodhoundDB) + assert func(t *testing.T, db *database.BloodhoundDB, extensionId int32) expectedError string }{ { - name: "Success: Create environment with principal kinds", + name: "Success: Create new environments and findings with remediations", setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { t.Helper() ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") require.NoError(t, err) - return ext.ID - }, - args: args{ - environments: []database.EnvironmentInput{ - { - EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Tier_Zero", "Tag_Owned"}, - }, - }, - }, - assert: func(t *testing.T, db *database.BloodhoundDB) { - t.Helper() - - expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} - - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 1, len(environments)) - - principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) - assert.NoError(t, err) - assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) - - expectedKindIDs := make([]int32, len(expectedPrincipalKindNames)) - for i, name := range expectedPrincipalKindNames { - kind, err := db.GetKindByName(context.Background(), name) - assert.NoError(t, err) - expectedKindIDs[i] = int32(kind.ID) - } - - actualKindIDs := make([]int32, len(principalKinds)) - for i, pk := range principalKinds { - assert.Equal(t, environments[0].ID, pk.EnvironmentId) - actualKindIDs[i] = pk.PrincipalKind - } - - assert.ElementsMatch(t, expectedKindIDs, actualKindIDs) - }, - }, - { - name: "Success: Create multiple environments", - setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { - t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + _, err = db.CreateEnvironment(context.Background(), ext.ID, int32(1), int32(1)) require.NoError(t, err) return ext.ID @@ -99,219 +55,83 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { { EnvironmentKindName: "Tag_Tier_Zero", SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Tier_Zero"}, - }, - { - EnvironmentKindName: "Tag_Owned", - SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Owned"}, + PrincipalKinds: []string{"Tag_Owned", "Tag_Tier_Zero"}, }, }, - }, - assert: func(t *testing.T, db *database.BloodhoundDB) { - t.Helper() - - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 2, len(environments), "Should have two environments") - - // Verify first environment - env1PrincipalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) - assert.NoError(t, err) - assert.Equal(t, 1, len(env1PrincipalKinds)) - - // Verify second environment - env2PrincipalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[1].ID) - assert.NoError(t, err) - assert.Equal(t, 1, len(env2PrincipalKinds)) - }, - }, - { - name: "Success: Upsert replaces existing environment", - setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { - t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") - require.NoError(t, err) - - // Create initial environment - err = db.UpsertGraphSchemaExtension(context.Background(), ext.ID, []database.EnvironmentInput{ + findings: []database.FindingInput{ { - EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Owned"}, + Name: "Name 1", + DisplayName: "Display Name 1", + RelationshipKindName: "Tag_Tier_Zero", + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, }, - }) - require.NoError(t, err) - - return ext.ID - }, - args: args{ - environments: []database.EnvironmentInput{ { - EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Tier_Zero"}, + Name: "Name 2", + DisplayName: "Display Name 2", + RelationshipKindName: "Tag_Tier_Zero", + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, }, }, }, - assert: func(t *testing.T, db *database.BloodhoundDB) { + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { t.Helper() - expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} - - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 1, len(environments), "Should only have one environment (old one replaced)") - - principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) - assert.NoError(t, err) - assert.Equal(t, 1, len(principalKinds)) - - expectedKind, err := db.GetKindByName(context.Background(), expectedPrincipalKindNames[0]) - assert.NoError(t, err) - - assert.Equal(t, int32(expectedKind.ID), principalKinds[0].PrincipalKind) - assert.Equal(t, environments[0].ID, principalKinds[0].EnvironmentId) - }, - }, - { - name: "Success: Source kind auto-registers", - setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { - t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + // Verify findings were created + finding1, err := db.GetSchemaRelationshipFindingByName(context.Background(), "Name 1") require.NoError(t, err) + assert.Equal(t, extensionId, finding1.SchemaExtensionId) + assert.Equal(t, "Display Name 1", finding1.DisplayName) - return ext.ID - }, - args: args{ - environments: []database.EnvironmentInput{ - { - EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "NewSource", - PrincipalKinds: []string{"Tag_Tier_Zero"}, - }, - }, - }, - assert: func(t *testing.T, db *database.BloodhoundDB) { - t.Helper() - - sourceKind, err := db.GetSourceKindByName(context.Background(), "NewSource") - assert.NoError(t, err) - assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) - - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 1, len(environments)) - assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) - - principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) - assert.NoError(t, err) - assert.Equal(t, 1, len(principalKinds)) - }, - }, - { - name: "Success: Multiple environments with different source kinds", - setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { - t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + finding2, err := db.GetSchemaRelationshipFindingByName(context.Background(), "Name 2") require.NoError(t, err) + assert.Equal(t, extensionId, finding2.SchemaExtensionId) + assert.Equal(t, "Display Name 2", finding2.DisplayName) - return ext.ID - }, - args: args{ - environments: []database.EnvironmentInput{ - { - EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Tier_Zero"}, - }, - { - EnvironmentKindName: "Tag_Owned", - SourceKindName: "NewSource", - PrincipalKinds: []string{"Tag_Owned"}, - }, - }, - }, - assert: func(t *testing.T, db *database.BloodhoundDB) { - t.Helper() - - // Verify NewSource was auto-registered - sourceKind, err := db.GetSourceKindByName(context.Background(), "NewSource") - assert.NoError(t, err) - assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) - - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 2, len(environments), "Should have two environments") + // Verify remediations were created + remediation1, err := db.GetRemediationByFindingId(context.Background(), finding1.ID) + require.NoError(t, err) + assert.Equal(t, "Short Description", remediation1.ShortDescription) + assert.Equal(t, "Long Description", remediation1.LongDescription) + assert.Equal(t, "Short Remediation", remediation1.ShortRemediation) + assert.Equal(t, "Long Remediation", remediation1.LongRemediation) - for _, env := range environments { - principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), env.ID) - assert.NoError(t, err) - assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") - } + remediation2, err := db.GetRemediationByFindingId(context.Background(), finding2.ID) + require.NoError(t, err) + assert.Equal(t, "Short Description", remediation2.ShortDescription) + assert.Equal(t, "Long Description", remediation2.LongDescription) + assert.Equal(t, "Short Remediation", remediation2.ShortRemediation) + assert.Equal(t, "Long Remediation", remediation2.LongRemediation) }, }, { - name: "Error: First environment has invalid environment kind", + name: "Success: Update existing findings and remediations", setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt2", "Test2", "v1.0.0") require.NoError(t, err) - return ext.ID - }, - args: args{ - environments: []database.EnvironmentInput{ - { - EnvironmentKindName: "NonExistent", - SourceKindName: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - expectedError: "environment kind 'NonExistent' not found", - assert: func(t *testing.T, db *database.BloodhoundDB) { - t.Helper() - - // Verify transaction rolled back - no environment created - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 0, len(environments), "No environment should exist after rollback") - }, - }, - { - name: "Error: First environment has invalid principal kind", - setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { - t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + env, err := db.CreateEnvironment(context.Background(), ext.ID, 1, 1) require.NoError(t, err) - return ext.ID - }, - args: args{ - environments: []database.EnvironmentInput{ - { - EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", - PrincipalKinds: []string{"NonExistent"}, - }, - }, - }, - expectedError: "principal kind 'NonExistent' not found", - assert: func(t *testing.T, db *database.BloodhoundDB) { - t.Helper() + // Create initial finding with remediation + finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "ExistingFinding", "Old Display Name") + require.NoError(t, err) - // Verify transaction rolled back - no environment created - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 0, len(environments), "No environment should exist after rollback") - }, - }, - { - name: "Rollback: Second environment fails, first should rollback", - setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { - t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + _, err = db.CreateRemediation(context.Background(), finding.ID, "old short", "old long", "old short rem", "old long rem") require.NoError(t, err) return ext.ID @@ -323,110 +143,83 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { SourceKindName: "Base", PrincipalKinds: []string{"Tag_Tier_Zero"}, }, + }, + findings: []database.FindingInput{ { - EnvironmentKindName: "NonExistent", - SourceKindName: "Base", - PrincipalKinds: []string{}, + Name: "ExistingFinding", + DisplayName: "Updated Display Name", + RelationshipKindName: "Tag_Tier_Zero", + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Updated Short Description", + LongDescription: "Updated Long Description", + ShortRemediation: "Updated Short Remediation", + LongRemediation: "Updated Long Remediation", + }, }, }, }, - expectedError: "environment kind 'NonExistent' not found", - assert: func(t *testing.T, db *database.BloodhoundDB) { + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { t.Helper() - // Verify complete transaction rollback - no environments created - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 0, len(environments), "No environments should exist after rollback") - }, - }, - { - name: "Rollback: Second environment has invalid principal kind", - setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { - t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + // Verify finding was updated (deleted and recreated) + finding, err := db.GetSchemaRelationshipFindingByName(context.Background(), "ExistingFinding") require.NoError(t, err) + assert.Equal(t, extensionId, finding.SchemaExtensionId) + assert.Equal(t, "Updated Display Name", finding.DisplayName) - return ext.ID - }, - args: args{ - environments: []database.EnvironmentInput{ - { - EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Tier_Zero"}, - }, - { - EnvironmentKindName: "Tag_Owned", - SourceKindName: "Base", - PrincipalKinds: []string{"NonExistent"}, - }, - }, - }, - expectedError: "principal kind 'NonExistent' not found", - assert: func(t *testing.T, db *database.BloodhoundDB) { - t.Helper() - - // Verify complete transaction rollback - no environments created - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 0, len(environments), "No environments should exist after rollback") + // Verify remediation was updated + remediation, err := db.GetRemediationByFindingId(context.Background(), finding.ID) + require.NoError(t, err) + assert.Equal(t, "Updated Short Description", remediation.ShortDescription) + assert.Equal(t, "Updated Long Description", remediation.LongDescription) + assert.Equal(t, "Updated Short Remediation", remediation.ShortRemediation) + assert.Equal(t, "Updated Long Remediation", remediation.LongRemediation) }, }, { - name: "Rollback: Partial failure in first environment's principal kinds", + name: "Success: Empty environments and findings", setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { t.Helper() - ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt3", "Test3", "v1.0.0") require.NoError(t, err) return ext.ID }, args: args{ - environments: []database.EnvironmentInput{ - { - EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Owned", "NonExistent"}, - }, - }, + environments: []database.EnvironmentInput{}, + findings: []database.FindingInput{}, }, - expectedError: "principal kind 'NonExistent' not found", - assert: func(t *testing.T, db *database.BloodhoundDB) { + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { t.Helper() - - // Verify transaction rolled back - no environment created - environments, err := db.GetEnvironments(context.Background()) - assert.NoError(t, err) - assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + // Nothing to assert - just verify no error occurred }, }, } - for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { testSuite := setupIntegrationTestSuite(t) defer teardownIntegrationTestSuite(t, &testSuite) - extensionID := tt.setupData(t, testSuite.BHDatabase) + extensionId := tt.setupData(t, testSuite.BHDatabase) err := testSuite.BHDatabase.UpsertGraphSchemaExtension( context.Background(), - extensionID, + extensionId, tt.args.environments, + tt.args.findings, ) if tt.expectedError != "" { require.Error(t, err) assert.Contains(t, err.Error(), tt.expectedError) - if tt.assert != nil { - tt.assert(t, testSuite.BHDatabase) - } } else { require.NoError(t, err) - if tt.assert != nil { - tt.assert(t, testSuite.BHDatabase) - } + } + + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase, extensionId) } }) } diff --git a/cmd/api/src/database/upsert_schema_finding.go b/cmd/api/src/database/upsert_schema_finding.go new file mode 100644 index 0000000000..9bca51a127 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_finding.go @@ -0,0 +1,89 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "context" + "errors" + "fmt" + + "github.com/specterops/bloodhound/cmd/api/src/model" +) + +// UpsertFinding validates and upserts a finding. +// If a finding with the same name exists, it will be deleted and re-created. +func (s *BloodhoundDB) UpsertFinding(ctx context.Context, extensionId int32, sourceKindName, relationshipKindName, environmentKind string, name, displayName string) (model.SchemaRelationshipFinding, error) { + relationshipKindId, err := s.validateAndTranslateRelationshipKind(ctx, relationshipKindName) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + environmentKindId, err := s.validateAndTranslateEnvironmentKind(ctx, environmentKind) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + sourceKindId, err := s.validateAndTranslateSourceKind(ctx, sourceKindName) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + // The unique constraint on (environment_kind_id, source_kind_id) of the Schema Environment table ensures no + // duplicate pairs exist, enabling this logic. + environment, err := s.GetEnvironmentByKinds(ctx, environmentKindId, sourceKindId) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + finding, err := s.replaceFinding(ctx, extensionId, relationshipKindId, environment.ID, name, displayName) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + return finding, nil +} + +// validateAndTranslateRelationshipKind validates that the relationship kind exists in the kinds table. +func (s *BloodhoundDB) validateAndTranslateRelationshipKind(ctx context.Context, relationshipKindName string) (int32, error) { + if relationshipKind, err := s.GetKindByName(ctx, relationshipKindName); err != nil && !errors.Is(err, ErrNotFound) { + return 0, fmt.Errorf("error retrieving relationship kind '%s': %w", relationshipKindName, err) + } else if errors.Is(err, ErrNotFound) { + return 0, fmt.Errorf("relationship kind '%s' not found", relationshipKindName) + } else { + return relationshipKind.ID, nil + } +} + +// replaceFinding creates or updates a schema relationship finding. +// If a finding with the given name exists, it deletes it first before creating the new one. +func (s *BloodhoundDB) replaceFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { + if existing, err := s.GetSchemaRelationshipFindingByName(ctx, name); err != nil && !errors.Is(err, ErrNotFound) { + return model.SchemaRelationshipFinding{}, fmt.Errorf("error retrieving schema relationship finding: %w", err) + } else if err == nil { + // Finding exists - delete it first + if err := s.DeleteSchemaRelationshipFinding(ctx, existing.ID); err != nil { + return model.SchemaRelationshipFinding{}, fmt.Errorf("error deleting schema relationship finding %d: %w", existing.ID, err) + } + } + + finding, err := s.CreateSchemaRelationshipFinding(ctx, extensionId, relationshipKindId, environmentId, name, displayName) + if err != nil { + return model.SchemaRelationshipFinding{}, fmt.Errorf("error creating schema relationship finding: %w", err) + } + + return finding, nil + +} diff --git a/cmd/api/src/database/upsert_schema_finding_integration_test.go b/cmd/api/src/database/upsert_schema_finding_integration_test.go new file mode 100644 index 0000000000..70297b5e90 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_finding_integration_test.go @@ -0,0 +1,146 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration + +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertFinding(t *testing.T) { + type args struct { + sourceKindName, relationshipKindName, environmentKind, name, displayName string + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 // Returns extensionId + args args + assert func(t *testing.T, db *database.BloodhoundDB, extensionId int32) + expectedError string + }{ + { + name: "Success: Update existing finding - delete and re-create", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + env, err := db.CreateEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + // Create finding + _, err = db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "Finding Name", "Finding Display Name") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + sourceKindName: "Base", + relationshipKindName: "Tag_Tier_Zero", + environmentKind: "Tag_Tier_Zero", + // Name triggers upsert so this needs to match the finding's name that we want to update + name: "Finding Name", + displayName: "Updated Display Name", + }, + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { + t.Helper() + + finding, err := db.GetSchemaRelationshipFindingByName(context.Background(), "Finding Name") + require.NoError(t, err) + + assert.Equal(t, extensionId, finding.SchemaExtensionId) + assert.Equal(t, "Finding Name", finding.Name) + assert.Equal(t, "Updated Display Name", finding.DisplayName) + }, + }, + { + name: "Success: Create finding when none exists", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt2", "Test2", "v1.0.0") + require.NoError(t, err) + + _, err = db.CreateEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + // No finding created since we're testing the creation workflow + return ext.ID + }, + args: args{ + sourceKindName: "Base", + relationshipKindName: "Tag_Tier_Zero", + environmentKind: "Tag_Tier_Zero", + name: "Finding", + displayName: "Finding Display Name", + }, + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { + t.Helper() + + finding, err := db.GetSchemaRelationshipFindingByName(context.Background(), "Finding") + require.NoError(t, err) + + assert.Equal(t, extensionId, finding.SchemaExtensionId) + assert.Equal(t, "Finding", finding.Name) + assert.Equal(t, "Finding Display Name", finding.DisplayName) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extensionId := tt.setupData(t, testSuite.BHDatabase) + + var findingResponse model.SchemaRelationshipFinding + // Wrap the call in a transaction + err := testSuite.BHDatabase.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + finding, err := tx.UpsertFinding( + context.Background(), + extensionId, + tt.args.sourceKindName, + tt.args.relationshipKindName, + tt.args.environmentKind, + tt.args.name, + tt.args.displayName, + ) + findingResponse = finding + return err + }) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + assert.NotZero(t, findingResponse.ID, "Finding should have been created/updated") + } + + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase, extensionId) + } + }) + } +} diff --git a/cmd/api/src/database/upsert_schema_remediation.go b/cmd/api/src/database/upsert_schema_remediation.go new file mode 100644 index 0000000000..cff7f98e87 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_remediation.go @@ -0,0 +1,41 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "context" + "errors" + "fmt" +) + +// UpsertRemediation validates and upserts a remediation. +// If the remediation exists for the finding ID, it is updated. If it doesn't already exist, it is created. +// Findings information must be inserted first before inserting remediation information. +func (s *BloodhoundDB) UpsertRemediation(ctx context.Context, findingId int32, shortDescription, longDescription, shortRemediation, longRemediation string) error { + if _, err := s.GetRemediationByFindingId(ctx, findingId); err != nil && !errors.Is(err, ErrNotFound) { + return fmt.Errorf("error retrieving remediation by finding id '%d': %w", findingId, err) + } else if err == nil { + // Remediation exists - update it + if _, err := s.UpdateRemediation(ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation); err != nil { + return fmt.Errorf("error updating remediation by finding id '%d': %w", findingId, err) + } + } else { + if _, err := s.CreateRemediation(ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation); err != nil { + return fmt.Errorf("error creating remediation by finding id '%d': %w", findingId, err) + } + } + return nil +} diff --git a/cmd/api/src/database/upsert_schema_remediation_integration_test.go b/cmd/api/src/database/upsert_schema_remediation_integration_test.go new file mode 100644 index 0000000000..78616cbfb5 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_remediation_integration_test.go @@ -0,0 +1,146 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration + +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertRemediation(t *testing.T) { + type args struct { + shortDescription, longDescription, shortRemediation, longRemediation string + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 // Returns findingID + args args + assert func(t *testing.T, db *database.BloodhoundDB, findingId int32) + expectedError string + }{ + { + name: "Success: Update existing remediation", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + env, err := db.CreateEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "Finding", "Finding Display Name") + require.NoError(t, err) + + _, err = db.CreateRemediation(context.Background(), finding.ID, "short", "long", "short rem", "long rem") + require.NoError(t, err) + + return finding.ID + }, + args: args{ + shortDescription: "updated short description", + longDescription: "updated long description", + shortRemediation: "updated short remediation", + longRemediation: "updated long remediation", + }, + assert: func(t *testing.T, db *database.BloodhoundDB, findingId int32) { + t.Helper() + + remediation, err := db.GetRemediationByFindingId(context.Background(), findingId) + require.NoError(t, err) + + assert.Equal(t, findingId, remediation.FindingID) + assert.Equal(t, "updated short description", remediation.ShortDescription) + assert.Equal(t, "updated long description", remediation.LongDescription) + assert.Equal(t, "updated short remediation", remediation.ShortRemediation) + assert.Equal(t, "updated long remediation", remediation.LongRemediation) + }, + }, + { + name: "Success: Create remediation when none exists", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + env, err := db.CreateEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + // Create Finding but do not create Remediation + finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "Finding2", "Finding 2 Display Name") + require.NoError(t, err) + + return finding.ID + }, + args: args{ + shortDescription: "new short description", + longDescription: "new long description", + shortRemediation: "new short remediation", + longRemediation: "new long remediation", + }, + assert: func(t *testing.T, db *database.BloodhoundDB, findingId int32) { + t.Helper() + + remediation, err := db.GetRemediationByFindingId(context.Background(), findingId) + require.NoError(t, err) + + assert.Equal(t, findingId, remediation.FindingID) + assert.Equal(t, "new short description", remediation.ShortDescription) + assert.Equal(t, "new long description", remediation.LongDescription) + assert.Equal(t, "new short remediation", remediation.ShortRemediation) + assert.Equal(t, "new long remediation", remediation.LongRemediation) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + findingId := tt.setupData(t, testSuite.BHDatabase) + + // Wrap the call in a transaction + err := testSuite.BHDatabase.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + return tx.UpsertRemediation( + context.Background(), + findingId, + tt.args.shortDescription, + tt.args.longDescription, + tt.args.shortRemediation, + tt.args.longRemediation, + ) + }) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + } + + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase, findingId) + } + }) + } +} diff --git a/cmd/api/src/model/kind.go b/cmd/api/src/model/kind.go index 8f6402c682..95d3ad8f25 100644 --- a/cmd/api/src/model/kind.go +++ b/cmd/api/src/model/kind.go @@ -16,6 +16,6 @@ package model type Kind struct { - ID int `json:"id"` + ID int32 `json:"id"` Name string `json:"name"` } diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index a16df48021..58483bca8a 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -26,6 +26,7 @@ import ( func (s *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { var ( environments = make([]database.EnvironmentInput, len(req.Environments)) + findings = make([]database.FindingInput, len(req.Findings)) ) for i, environment := range req.Environments { @@ -36,8 +37,24 @@ func (s *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, } } + for i, finding := range req.Findings { + findings[i] = database.FindingInput{ + Name: finding.Name, + DisplayName: finding.DisplayName, + SourceKindName: finding.SourceKind, + RelationshipKindName: finding.RelationshipKind, + EnvironmentKindName: finding.EnvironmentKind, + RemediationInput: database.RemediationInput{ + ShortDescription: finding.Remediation.ShortDescription, + LongDescription: finding.Remediation.LongDescription, + ShortRemediation: finding.Remediation.ShortRemediation, + LongRemediation: finding.Remediation.LongRemediation, + }, + } + } + // TODO: Temporary hardcoded value but needs to be updated to pass in the extension ID - err := s.openGraphSchemaRepository.UpsertGraphSchemaExtension(ctx, 1, environments) + err := s.openGraphSchemaRepository.UpsertGraphSchemaExtension(ctx, 1, environments, findings) if err != nil { return fmt.Errorf("error upserting graph extension: %w", err) } diff --git a/cmd/api/src/services/opengraphschema/extension_test.go b/cmd/api/src/services/opengraphschema/extension_test.go index 0f81f8b881..98277daa57 100644 --- a/cmd/api/src/services/opengraphschema/extension_test.go +++ b/cmd/api/src/services/opengraphschema/extension_test.go @@ -34,6 +34,7 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { } type args struct { environments []v2.Environment + findings []v2.Finding } tests := []struct { name string @@ -51,6 +52,21 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { PrincipalKinds: []string{"User"}, }, }, + findings: []v2.Finding{ + { + Name: "Finding", + DisplayName: "DisplayName", + RelationshipKind: "Domain", + EnvironmentKind: "Domain", + SourceKind: "Base", + Remediation: v2.Remediation{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + }, }, setupMocks: func(t *testing.T, m *mocks) { t.Helper() @@ -61,16 +77,32 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { PrincipalKinds: []string{"User"}, }, } + expectedFindings := []database.FindingInput{ + { + Name: "Finding", + DisplayName: "DisplayName", + RelationshipKindName: "Domain", + EnvironmentKindName: "Domain", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + } m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( gomock.Any(), int32(1), expectedEnvs, + expectedFindings, ).Return(errors.New("error")) }, expected: errors.New("error upserting graph extension: error"), }, { - name: "Success: single environment", + name: "Success: single environment with single finding", args: args{ environments: []v2.Environment{ { @@ -79,6 +111,21 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { PrincipalKinds: []string{"User", "Computer"}, }, }, + findings: []v2.Finding{ + { + Name: "Finding", + DisplayName: "DisplayName", + RelationshipKind: "Domain", + EnvironmentKind: "Domain", + SourceKind: "Base", + Remediation: v2.Remediation{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + }, }, setupMocks: func(t *testing.T, m *mocks) { t.Helper() @@ -89,16 +136,32 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { PrincipalKinds: []string{"User", "Computer"}, }, } + expectedFindings := []database.FindingInput{ + { + Name: "Finding", + DisplayName: "DisplayName", + RelationshipKindName: "Domain", + EnvironmentKindName: "Domain", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + } m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( gomock.Any(), int32(1), expectedEnvs, + expectedFindings, ).Return(nil) }, expected: nil, }, { - name: "Success: multiple environments", + name: "Success: multiple environments with multiple findings", args: args{ environments: []v2.Environment{ { @@ -112,6 +175,34 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { PrincipalKinds: []string{"User", "Group"}, }, }, + findings: []v2.Finding{ + { + Name: "Finding1", + DisplayName: "DisplayName1", + RelationshipKind: "Domain", + EnvironmentKind: "Domain", + SourceKind: "Base", + Remediation: v2.Remediation{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + { + Name: "Finding2", + DisplayName: "DisplayName2", + RelationshipKind: "Domain", + EnvironmentKind: "Domain", + SourceKind: "Base", + Remediation: v2.Remediation{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + }, }, setupMocks: func(t *testing.T, m *mocks) { t.Helper() @@ -127,10 +218,39 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { PrincipalKinds: []string{"User", "Group"}, }, } + expectedFindings := []database.FindingInput{ + { + Name: "Finding1", + DisplayName: "DisplayName1", + RelationshipKindName: "Domain", + EnvironmentKindName: "Domain", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + { + Name: "Finding2", + DisplayName: "DisplayName2", + RelationshipKindName: "Domain", + EnvironmentKindName: "Domain", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + } m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( gomock.Any(), int32(1), expectedEnvs, + expectedFindings, ).Return(nil) }, expected: nil, @@ -152,6 +272,7 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ Environments: tt.args.environments, + Findings: tt.args.findings, }) if tt.expected != nil { diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index f1e1539a2a..2653fb152b 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -58,15 +58,15 @@ func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryM } // UpsertGraphSchemaExtension mocks base method. -func (m *MockOpenGraphSchemaRepository) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput) error { +func (m *MockOpenGraphSchemaRepository) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput, findings []database.FindingInput) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertGraphSchemaExtension", ctx, extensionID, environments) + ret := m.ctrl.Call(m, "UpsertGraphSchemaExtension", ctx, extensionID, environments, findings) ret0, _ := ret[0].(error) return ret0 } // UpsertGraphSchemaExtension indicates an expected call of UpsertGraphSchemaExtension. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertGraphSchemaExtension(ctx, extensionID, environments any) *gomock.Call { +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertGraphSchemaExtension(ctx, extensionID, environments, findings any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertGraphSchemaExtension", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertGraphSchemaExtension), ctx, extensionID, environments) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertGraphSchemaExtension", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertGraphSchemaExtension), ctx, extensionID, environments, findings) } diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index 74322b76bf..cf9fb9f318 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -25,7 +25,7 @@ import ( // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { - UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput) error + UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput, findings []database.FindingInput) error } type OpenGraphSchemaService struct {