From 29da7938906c8df5d8831bebce1acea2076d995a Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 6 Jan 2026 10:13:26 -0600 Subject: [PATCH 01/36] environment added to opengraphschema service + principal kinds - lots of questions about validations --- LICENSE.header | 2 +- cmd/api/src/database/kind.go | 19 + cmd/api/src/model/graphschema.go | 30 +- cmd/api/src/model/kind.go | 6 + .../services/opengraphschema/environment.go | 149 +++++ .../opengraphschema/environment_test.go | 535 ++++++++++++++++++ .../opengraphschema/mocks/opengraphschema.go | 192 +++++++ .../opengraphschema/opengraphschema.go | 56 ++ packages/csharp/graphschema/PropertyNames.cs | 2 +- packages/go/graphschema/ad/ad.go | 2 +- packages/go/graphschema/azure/azure.go | 2 +- packages/go/graphschema/common/common.go | 2 +- packages/go/graphschema/graph.go | 2 +- .../bh-shared-ui/src/graphSchema.ts | 2 +- .../SelectedDetailsTabs.test.tsx | 2 +- 15 files changed, 994 insertions(+), 9 deletions(-) create mode 100644 cmd/api/src/database/kind.go create mode 100644 cmd/api/src/model/kind.go create mode 100644 cmd/api/src/services/opengraphschema/environment.go create mode 100644 cmd/api/src/services/opengraphschema/environment_test.go create mode 100644 cmd/api/src/services/opengraphschema/mocks/opengraphschema.go create mode 100644 cmd/api/src/services/opengraphschema/opengraphschema.go diff --git a/LICENSE.header b/LICENSE.header index 5d5c596b1c..958be83318 100644 --- a/LICENSE.header +++ b/LICENSE.header @@ -1,4 +1,4 @@ -Copyright 2025 Specter Ops, Inc. +Copyright 2026 Specter Ops, Inc. Licensed under the Apache License, Version 2.0 you may not use this file except in compliance with the License. diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go new file mode 100644 index 0000000000..c89049ceb8 --- /dev/null +++ b/cmd/api/src/database/kind.go @@ -0,0 +1,19 @@ +package database + +import ( + "context" + "fmt" + + "github.com/specterops/bloodhound/cmd/api/src/model" +) + +// GetKindById gets a row from the kind table by id. +func (s *BloodhoundDB) GetKindById(ctx context.Context, id int32) (model.Kind, error) { + var kind model.Kind + + // var kind Kind + return kind, CheckError(s.db.WithContext(ctx).Raw(fmt.Sprintf(` + SELECT id, name + FROM %s WHERE id = ?`, kindTable), id).First(&kind), + ) +} diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index 32fcfbbf48..cc343e5197 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -16,7 +16,10 @@ package model -import "time" +import ( + "fmt" + "time" +) type GraphSchemaExtensions []GraphSchemaExtension @@ -98,6 +101,20 @@ func (GraphSchemaEdgeKind) TableName() string { return "schema_edge_kinds" } +// GraphSchemaEnvironments - slice of environments +type GraphSchemaEnvironments []SchemaEnvironment + +// ToMapKeyedOnEnvironmentAndSource - converts a list of graph schema environments to a map based on environment kind id and source id +func (s GraphSchemaEnvironments) ToMapKeyedOnEnvironmentAndSource() map[string]SchemaEnvironment { + m := make(map[string]SchemaEnvironment) + for _, env := range s { + // Key is environment id + source id separated with an underscore e.g., kindid_sourceid + key := fmt.Sprintf("%d_%d", env.EnvironmentKindId, env.SourceKindId) + m[key] = env + } + return m +} + type SchemaEnvironment struct { Serial SchemaExtensionId int32 `json:"schema_extension_id"` @@ -144,3 +161,14 @@ type GraphSchemaEdgeKindWithNamedSchema struct { } type GraphSchemaEdgeKindsWithNamedSchema []GraphSchemaEdgeKindWithNamedSchema + +type SchemaEnvironmentPrincipalKinds []SchemaEnvironmentPrincipalKind + +type SchemaEnvironmentPrincipalKind struct { + EnvironmentId int32 `json:"environment_id"` + PrincipalKind int32 `json:"principal_kind"` +} + +func (SchemaEnvironmentPrincipalKind) TableName() string { + return "schema_environments_principal_kinds" +} diff --git a/cmd/api/src/model/kind.go b/cmd/api/src/model/kind.go new file mode 100644 index 0000000000..786fa0a5bf --- /dev/null +++ b/cmd/api/src/model/kind.go @@ -0,0 +1,6 @@ +package model + +type Kind struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/cmd/api/src/services/opengraphschema/environment.go b/cmd/api/src/services/opengraphschema/environment.go new file mode 100644 index 0000000000..492da32a49 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/environment.go @@ -0,0 +1,149 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema + +import ( + "context" + "errors" + "fmt" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" +) + +func (o *OpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, environment model.SchemaEnvironment, principalKinds []model.SchemaEnvironmentPrincipalKind) error { + // Validate environment + if err := o.validateSchemaEnvironment(ctx, environment); err != nil { + return fmt.Errorf("error validating schema environment: %w", err) + } + + // Validate principal kinds + for _, kind := range principalKinds { + if err := o.validateSchemaEnvironmentPrincipalKind(ctx, kind.PrincipalKind); err != nil { + return fmt.Errorf("error validating principal kind: %w", err) + } + } + + // Upsert the environment + id, err := o.upsertSchemaEnvironment(ctx, environment) + if err != nil { + return fmt.Errorf("error upserting schema environment: %w", err) + } + + // Upsert principal kinds for environment + if err := o.upsertPrincipalKinds(ctx, id, principalKinds); err != nil { + return fmt.Errorf("error upserting principal kinds: %w", err) + } + + return nil +} + +func (o *OpenGraphSchemaService) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { + if existing, err := o.openGraphSchemaRepository.GetSchemaEnvironmentById(ctx, graphSchema.ID); err != nil && !errors.Is(err, database.ErrNotFound) { + return 0, fmt.Errorf("error retrieving schema environment id %d: %w", graphSchema.ID, err) + } else if !errors.Is(err, database.ErrNotFound) { + // Environment exists - delete it first + if err := o.openGraphSchemaRepository.DeleteSchemaEnvironment(ctx, existing.ID); err != nil { + return 0, fmt.Errorf("error deleting schema environment %d: %w", existing.ID, err) + } + } + + // Create Environment + if created, err := o.openGraphSchemaRepository.CreateSchemaEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { + return 0, fmt.Errorf("error creating schema environment: %w", err) + } else { + return created.ID, nil + } +} + +/* +Validations: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#10-validation-rules-for-environments + 1. Ensure the specified environmentKind exists in kinds table + ** QUESTION: Documentation states to use the kind table. Is there already a database method to query the kind table to do this validation or was I supposed to create it? + 2. Ensure the specified sourceKind exists in source_kinds table (create if it doesn't, reactivate if it does) +*/ +func (o *OpenGraphSchemaService) validateSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) error { + // Validate environment kind id exists in kinds table + if _, err := o.openGraphSchemaRepository.GetKindById(ctx, graphSchema.EnvironmentKindId); err != nil { + return fmt.Errorf("error retrieving environment kind: %w", err) + } + + // Get all source kinds + if sourceKinds, err := o.openGraphSchemaRepository.GetSourceKinds(ctx); err != nil { + return fmt.Errorf("error retrieving source kinds: %w", err) + } else { + // Check if source kind exists + found := false + for _, kind := range sourceKinds { + if graphSchema.SourceKindId == int32(kind.ID) { + found = true + break + } + } + + if !found { + /* + ** QUESTION: Example Environment Schema: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#9-environments-and-principal-kinds + The RFC example uses source kind names (strings) in the environment schema, but our model uses IDs (int32). To register a new source kind, we need the + kind name/data, not just the ID. Cannot register with only an ID - kind id/name is required for registration. + RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error + + For now, this validates that the source kind should exist. + */ + return fmt.Errorf("invalid source kind id %d", graphSchema.SourceKindId) + } + } + + return nil +} + +func (o *OpenGraphSchemaService) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { + if existingKinds, err := o.openGraphSchemaRepository.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) + } else if !errors.Is(err, database.ErrNotFound) { + // Delete all existing principal kinds + for _, kind := range existingKinds { + if err := o.openGraphSchemaRepository.DeleteSchemaEnvironmentPrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { + return fmt.Errorf("error deleting principal kind %d for environment %d: %w", kind.PrincipalKind, kind.EnvironmentId, err) + } + } + } + + // Create the new principal kinds + for _, kind := range principalKinds { + if _, err := o.openGraphSchemaRepository.CreateSchemaEnvironmentPrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { + return fmt.Errorf("error creating principal kind %d for environment %d: %w", kind.PrincipalKind, environmentID, err) + } + } + + return nil +} + +/* +Validations: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#10-validation-rules-for-environments +1. Ensure all principalKinds exist in kinds table. + +** QUESTION: Documentation states to use the kind table. Is there already a database method to query the kind table to do this validation or was I supposed to create it? +*/ +func (o *OpenGraphSchemaService) validateSchemaEnvironmentPrincipalKind(ctx context.Context, kindID int32) error { + if _, err := o.openGraphSchemaRepository.GetKindById(ctx, kindID); err != nil && !errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("error retrieving kind by id: %w", err) + } else if errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("invalid principal kind id %d", kindID) + } + + return nil +} diff --git a/cmd/api/src/services/opengraphschema/environment_test.go b/cmd/api/src/services/opengraphschema/environment_test.go new file mode 100644 index 0000000000..991509a6ed --- /dev/null +++ b/cmd/api/src/services/opengraphschema/environment_test.go @@ -0,0 +1,535 @@ +package opengraphschema_test + +import ( + "context" + "errors" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" + schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { + type mocks struct { + mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository + } + type args struct { + environment model.SchemaEnvironment + principalKinds []model.SchemaEnvironmentPrincipalKind + } + tests := []struct { + name string + mocks mocks + setupMocks func(t *testing.T, mock *mocks) + args args + expected error + }{ + { + name: "Error: Validation - Failed to retrieve environment kind", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(1), + SourceKindId: int32(1), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, errors.New("error")) + }, + expected: errors.New("error validating schema environment: error retrieving environment kind: error"), + }, + { + name: "Error: Validation - Environment Kind doesn't exist in Kinds table", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(1), + SourceKindId: int32(1), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, database.ErrNotFound) + }, + expected: errors.New("error validating schema environment: error retrieving environment kind: entity not found"), + }, + { + name: "Error: Validation - Failed to retrieve source kinds", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(1), + SourceKindId: int32(1), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{}, errors.New("error")) + }, + expected: errors.New("error validating schema environment: error retrieving source kinds: error"), + }, + { + name: "Error: Validation - Source Kind doesn't exist", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(3), + SourceKindId: int32(1), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + }, + expected: errors.New("error validating schema environment: invalid source kind id 1"), + }, + { + name: "Error: Validation - Failed to retrieve principal kind", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + EnvironmentId: int32(1), + PrincipalKind: int32(99), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(99)).Return(model.Kind{}, errors.New("error")) + }, + expected: errors.New("error validating principal kind: error retrieving kind by id: error"), + }, + { + name: "Error: Validation - Principal Kind doesn't exist", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + EnvironmentId: int32(1), + PrincipalKind: int32(99), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(99)).Return(model.Kind{}, database.ErrNotFound) + }, + expected: errors.New("error validating principal kind: invalid principal kind id 99"), + }, + { + name: "Error: GetSchemaEnvironmentById fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, errors.New("error")) + }, + expected: errors.New("error upserting schema environment: error retrieving schema environment id 0: error"), + }, + { + name: "Error: DeleteSchemaEnvironment fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 5}, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(errors.New("error")) + }, + expected: errors.New("error upserting schema environment: error deleting schema environment 5: error"), + }, + { + name: "Error: CreateSchemaEnvironment fails after delete", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 5}, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) + }, + expected: errors.New("error upserting schema environment: error creating schema environment: error"), + }, + { + name: "Error: CreateSchemaEnvironment fails on new environment", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) + }, + expected: errors.New("error upserting schema environment: error creating schema environment: error"), + }, + { + name: "Error: GetSchemaEnvironmentPrincipalKindsByEnvironmentId fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, errors.New("error")) + }, + expected: errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), + }, + { + name: "Error: DeleteSchemaEnvironmentPrincipalKind fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return([]model.SchemaEnvironmentPrincipalKind{ + { + EnvironmentId: int32(10), + PrincipalKind: int32(5), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(5)).Return(errors.New("error")) + }, + expected: errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), + }, + { + name: "Error: CreateSchemaEnvironmentPrincipalKind fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, errors.New("error")) + }, + expected: errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error"), + }, + { + name: "Success: Create new environment with principal kinds", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + { + PrincipalKind: int32(4), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind3"), + }, + { + ID: 4, + Name: graph.StringKind("kind4"), + }, + }, nil) + // Principal kind validations + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(4)).Return(model.Kind{}, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + }, + expected: nil, + }, + { + name: "Success: Update existing environment and replace principal kinds", + args: args{ + environment: model.SchemaEnvironment{ + Serial: model.Serial{ID: 5}, + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + + // Environment upsert (delete and recreate) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(5)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 5}, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert (delete old, create new) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return([]model.SchemaEnvironmentPrincipalKind{ + { + EnvironmentId: int32(10), + PrincipalKind: int32(99), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(99)).Return(nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + }, + expected: nil, + }, + { + name: "Success: Create environment with no principal kinds", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert (no existing, no new) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) + }, + expected: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + mocks := &mocks{ + mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), + } + + tt.setupMocks(t, mocks) + + graphService := opengraphschema.NewOpenGraphSchemaService(mocks.mockOpenGraphSchema) + + err := graphService.UpsertSchemaEnvironmentWithPrincipalKinds(context.Background(), tt.args.environment, tt.args.principalKinds) + if tt.expected != nil { + assert.EqualError(t, tt.expected, err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go new file mode 100644 index 0000000000..7723197208 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -0,0 +1,192 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema (interfaces: OpenGraphSchemaRepository) +// +// Generated by this command: +// +// mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/opengraphschema.go -package=mocks . OpenGraphSchemaRepository +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + database "github.com/specterops/bloodhound/cmd/api/src/database" + model "github.com/specterops/bloodhound/cmd/api/src/model" + graph "github.com/specterops/dawgs/graph" + gomock "go.uber.org/mock/gomock" +) + +// MockOpenGraphSchemaRepository is a mock of OpenGraphSchemaRepository interface. +type MockOpenGraphSchemaRepository struct { + ctrl *gomock.Controller + recorder *MockOpenGraphSchemaRepositoryMockRecorder + isgomock struct{} +} + +// MockOpenGraphSchemaRepositoryMockRecorder is the mock recorder for MockOpenGraphSchemaRepository. +type MockOpenGraphSchemaRepositoryMockRecorder struct { + mock *MockOpenGraphSchemaRepository +} + +// NewMockOpenGraphSchemaRepository creates a new mock instance. +func NewMockOpenGraphSchemaRepository(ctrl *gomock.Controller) *MockOpenGraphSchemaRepository { + mock := &MockOpenGraphSchemaRepository{ctrl: ctrl} + mock.recorder = &MockOpenGraphSchemaRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryMockRecorder { + return m.recorder +} + +// CreateSchemaEnvironment mocks base method. +func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironment(ctx context.Context, schemaExtensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaEnvironment", ctx, schemaExtensionId, environmentKindId, sourceKindId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaEnvironment indicates an expected call of CreateSchemaEnvironment. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) CreateSchemaEnvironment(ctx, schemaExtensionId, environmentKindId, sourceKindId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).CreateSchemaEnvironment), ctx, schemaExtensionId, environmentKindId, sourceKindId) +} + +// CreateSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + +// DeleteSchemaEnvironment mocks base method. +func (m *MockOpenGraphSchemaRepository) DeleteSchemaEnvironment(ctx context.Context, schemaEnvironmentId int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaEnvironment", ctx, schemaEnvironmentId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaEnvironment indicates an expected call of DeleteSchemaEnvironment. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironment(ctx, schemaEnvironmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironment), ctx, schemaEnvironmentId) +} + +// DeleteSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockOpenGraphSchemaRepository) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + +// GetKindById mocks base method. +func (m *MockOpenGraphSchemaRepository) GetKindById(ctx context.Context, id int32) (model.Kind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKindById", ctx, id) + ret0, _ := ret[0].(model.Kind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKindById indicates an expected call of GetKindById. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetKindById(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindById", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetKindById), ctx, id) +} + +// GetSchemaEnvironmentById mocks base method. +func (m *MockOpenGraphSchemaRepository) GetSchemaEnvironmentById(ctx context.Context, schemaEnvironmentId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentById", ctx, schemaEnvironmentId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentById indicates an expected call of GetSchemaEnvironmentById. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentById(ctx, schemaEnvironmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentById), ctx, schemaEnvironmentId) +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. +func (m *MockOpenGraphSchemaRepository) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) +} + +// GetSourceKinds mocks base method. +func (m *MockOpenGraphSchemaRepository) GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSourceKinds", ctx) + ret0, _ := ret[0].([]database.SourceKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSourceKinds indicates an expected call of GetSourceKinds. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSourceKinds(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSourceKinds), ctx) +} + +// RegisterSourceKind mocks base method. +func (m *MockOpenGraphSchemaRepository) RegisterSourceKind(ctx context.Context) func(graph.Kind) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterSourceKind", ctx) + ret0, _ := ret[0].(func(graph.Kind) error) + return ret0 +} + +// RegisterSourceKind indicates an expected call of RegisterSourceKind. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go new file mode 100644 index 0000000000..1ff658041d --- /dev/null +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -0,0 +1,56 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema + +//go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/opengraphschema.go -package=mocks . OpenGraphSchemaRepository + +import ( + "context" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/dawgs/graph" +) + +// OpenGraphSchemaRepository - +type OpenGraphSchemaRepository interface { + // Kinds + GetKindById(ctx context.Context, id int32) (model.Kind, error) + + // Environment + CreateSchemaEnvironment(ctx context.Context, schemaExtensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) + GetSchemaEnvironmentById(ctx context.Context, schemaEnvironmentId int32) (model.SchemaEnvironment, error) + DeleteSchemaEnvironment(ctx context.Context, schemaEnvironmentId int32) error + + // Source Kinds + RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error + GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) + + // Principal Kinds + CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) + GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) + DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error +} + +type OpenGraphSchemaService struct { + openGraphSchemaRepository OpenGraphSchemaRepository +} + +func NewOpenGraphSchemaService(openGraphSchemaRepository OpenGraphSchemaRepository) *OpenGraphSchemaService { + return &OpenGraphSchemaService{ + openGraphSchemaRepository: openGraphSchemaRepository, + } +} diff --git a/packages/csharp/graphschema/PropertyNames.cs b/packages/csharp/graphschema/PropertyNames.cs index fc3d6b08a4..2031b9df97 100644 --- a/packages/csharp/graphschema/PropertyNames.cs +++ b/packages/csharp/graphschema/PropertyNames.cs @@ -1,5 +1,5 @@ /* - Copyright 2025 Specter Ops, Inc. + Copyright 2026 Specter Ops, Inc. Licensed under the Apache License, Version 2.0 you may not use this file except in compliance with the License. diff --git a/packages/go/graphschema/ad/ad.go b/packages/go/graphschema/ad/ad.go index c7a589ab21..c9411682e6 100644 --- a/packages/go/graphschema/ad/ad.go +++ b/packages/go/graphschema/ad/ad.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/go/graphschema/azure/azure.go b/packages/go/graphschema/azure/azure.go index 3ea1483949..709eab6828 100644 --- a/packages/go/graphschema/azure/azure.go +++ b/packages/go/graphschema/azure/azure.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/go/graphschema/common/common.go b/packages/go/graphschema/common/common.go index f18f99a486..057e552ab7 100644 --- a/packages/go/graphschema/common/common.go +++ b/packages/go/graphschema/common/common.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/go/graphschema/graph.go b/packages/go/graphschema/graph.go index aedef13acd..acb1385859 100644 --- a/packages/go/graphschema/graph.go +++ b/packages/go/graphschema/graph.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/javascript/bh-shared-ui/src/graphSchema.ts b/packages/javascript/bh-shared-ui/src/graphSchema.ts index 7ede0cf403..c92263f53e 100644 --- a/packages/javascript/bh-shared-ui/src/graphSchema.ts +++ b/packages/javascript/bh-shared-ui/src/graphSchema.ts @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/SelectedDetailsTabs/SelectedDetailsTabs.test.tsx b/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/SelectedDetailsTabs/SelectedDetailsTabs.test.tsx index dab44ff51a..e5d1bb755d 100644 --- a/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/SelectedDetailsTabs/SelectedDetailsTabs.test.tsx +++ b/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/SelectedDetailsTabs/SelectedDetailsTabs.test.tsx @@ -173,4 +173,4 @@ describe('Selected Details Tabs', () => { expect(objectTab).toBeEnabled(); }); }); -}); \ No newline at end of file +}); From a7f351eaa20e8fbfbdb6091bbc9ce9f3a6872e65 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 6 Jan 2026 18:05:38 -0600 Subject: [PATCH 02/36] did an overhaul of everything after talking with Cody and pulled a couple of Lawson's changes to test the environment upsert with a http endpoint --- cmd/api/src/api/mocks/authenticator.go | 2 +- cmd/api/src/api/registration/registration.go | 3 +- cmd/api/src/api/registration/v2.go | 2 + .../src/api/v2/mocks/graphschemaextensions.go | 72 +++ cmd/api/src/api/v2/model.go | 3 + cmd/api/src/api/v2/opengraphschema.go | 62 +++ cmd/api/src/daemons/datapipe/mocks/cleanup.go | 2 +- cmd/api/src/database/graphschema.go | 41 ++ cmd/api/src/database/kind.go | 34 +- cmd/api/src/database/kind_integration_test.go | 75 +++ .../database/migration/migrations/v8.5.0.sql | 4 + cmd/api/src/database/mocks/auth.go | 2 +- cmd/api/src/database/mocks/db.go | 17 +- cmd/api/src/database/sourcekinds.go | 28 + .../database/sourcekinds_integration_test.go | 50 ++ cmd/api/src/model/graphschema.go | 15 - cmd/api/src/model/kind.go | 15 + cmd/api/src/queries/mocks/graph.go | 2 +- cmd/api/src/services/agi/mocks/mock.go | 2 +- .../src/services/dataquality/mocks/mock.go | 2 +- cmd/api/src/services/entrypoint.go | 20 +- cmd/api/src/services/fs/mocks/fs.go | 2 +- cmd/api/src/services/graphify/mocks/ingest.go | 2 +- cmd/api/src/services/oidc/mocks/oidc.go | 2 +- .../services/opengraphschema/environment.go | 161 +++--- .../opengraphschema/environment_test.go | 509 +++++++++--------- .../opengraphschema/mocks/opengraphschema.go | 26 +- .../opengraphschema/opengraphschema.go | 4 +- cmd/api/src/services/saml/mocks/saml.go | 2 +- cmd/api/src/services/upload/mocks/mock.go | 2 +- .../src/utils/validation/mocks/validator.go | 2 +- cmd/api/src/vendormocks/dawgs/graph/mock.go | 2 +- cmd/api/src/vendormocks/io/fs/mock.go | 2 +- .../neo4j/neo4j-go-driver/v5/neo4j/mock.go | 2 +- packages/go/crypto/mocks/digest.go | 2 +- 35 files changed, 769 insertions(+), 404 deletions(-) create mode 100644 cmd/api/src/api/v2/mocks/graphschemaextensions.go create mode 100644 cmd/api/src/api/v2/opengraphschema.go create mode 100644 cmd/api/src/database/kind_integration_test.go diff --git a/cmd/api/src/api/mocks/authenticator.go b/cmd/api/src/api/mocks/authenticator.go index 32cbd5053e..8811703cb3 100644 --- a/cmd/api/src/api/mocks/authenticator.go +++ b/cmd/api/src/api/mocks/authenticator.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/api/registration/registration.go b/cmd/api/src/api/registration/registration.go index e06a494d23..bd3938c32a 100644 --- a/cmd/api/src/api/registration/registration.go +++ b/cmd/api/src/api/registration/registration.go @@ -62,6 +62,7 @@ func RegisterFossRoutes( authenticator api.Authenticator, authorizer auth.Authorizer, ingestSchema upload.IngestSchema, + openGraphSchemaService v2.OpenGraphSchemaService, ) { router.With(func() mux.MiddlewareFunc { return middleware.DefaultRateLimitMiddleware(rdms) @@ -80,6 +81,6 @@ func RegisterFossRoutes( routerInst.PathPrefix("/ui", static.AssetHandler), ) - var resources = v2.NewResources(rdms, graphDB, cfg, apiCache, graphQuery, collectorManifests, authorizer, authenticator, ingestSchema) + var resources = v2.NewResources(rdms, graphDB, cfg, apiCache, graphQuery, collectorManifests, authorizer, authenticator, ingestSchema, openGraphSchemaService) NewV2API(resources, routerInst) } diff --git a/cmd/api/src/api/registration/v2.go b/cmd/api/src/api/registration/v2.go index 5b04638b4b..0bf9a66f44 100644 --- a/cmd/api/src/api/registration/v2.go +++ b/cmd/api/src/api/registration/v2.go @@ -364,5 +364,7 @@ func NewV2API(resources v2.Resources, routerInst *router.Router) { routerInst.POST("/api/v2/custom-nodes", resources.CreateCustomNodeKind).RequireAuth(), routerInst.PUT(fmt.Sprintf("/api/v2/custom-nodes/{%s}", v2.CustomNodeKindParameter), resources.UpdateCustomNodeKind).RequireAuth(), routerInst.DELETE(fmt.Sprintf("/api/v2/custom-nodes/{%s}", v2.CustomNodeKindParameter), resources.DeleteCustomNodeKind).RequireAuth(), + + routerInst.PUT("/api/v2/extensions", resources.OpenGraphSchemaIngest), ) } diff --git a/cmd/api/src/api/v2/mocks/graphschemaextensions.go b/cmd/api/src/api/v2/mocks/graphschemaextensions.go new file mode 100644 index 0000000000..96cb240811 --- /dev/null +++ b/cmd/api/src/api/v2/mocks/graphschemaextensions.go @@ -0,0 +1,72 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/specterops/bloodhound/cmd/api/src/api/v2 (interfaces: OpenGraphSchemaService) +// +// Generated by this command: +// +// mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphschemaextensions.go -package=mocks . OpenGraphSchemaService +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + gomock "go.uber.org/mock/gomock" +) + +// MockOpenGraphSchemaService is a mock of OpenGraphSchemaService interface. +type MockOpenGraphSchemaService struct { + ctrl *gomock.Controller + recorder *MockOpenGraphSchemaServiceMockRecorder + isgomock struct{} +} + +// MockOpenGraphSchemaServiceMockRecorder is the mock recorder for MockOpenGraphSchemaService. +type MockOpenGraphSchemaServiceMockRecorder struct { + mock *MockOpenGraphSchemaService +} + +// NewMockOpenGraphSchemaService creates a new mock instance. +func NewMockOpenGraphSchemaService(ctrl *gomock.Controller) *MockOpenGraphSchemaService { + mock := &MockOpenGraphSchemaService{ctrl: ctrl} + mock.recorder = &MockOpenGraphSchemaServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOpenGraphSchemaService) EXPECT() *MockOpenGraphSchemaServiceMockRecorder { + return m.recorder +} + +// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. +func (m *MockOpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []v2.Environment) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environments) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. +func (mr *MockOpenGraphSchemaServiceMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environments any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaService)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environments) +} diff --git a/cmd/api/src/api/v2/model.go b/cmd/api/src/api/v2/model.go index e7f2a3246c..c66cd94019 100644 --- a/cmd/api/src/api/v2/model.go +++ b/cmd/api/src/api/v2/model.go @@ -115,6 +115,7 @@ type Resources struct { Authenticator api.Authenticator IngestSchema upload.IngestSchema FileService fs.Service + openGraphSchemaService OpenGraphSchemaService } func NewResources( @@ -127,6 +128,7 @@ func NewResources( authorizer auth.Authorizer, authenticator api.Authenticator, ingestSchema upload.IngestSchema, + openGraphSchemaService OpenGraphSchemaService, ) Resources { return Resources{ Decoder: schema.NewDecoder(), @@ -141,5 +143,6 @@ func NewResources( Authenticator: authenticator, IngestSchema: ingestSchema, FileService: &fs.Client{}, + openGraphSchemaService: openGraphSchemaService, } } diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go new file mode 100644 index 0000000000..50446d6a04 --- /dev/null +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -0,0 +1,62 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package v2 + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/specterops/bloodhound/cmd/api/src/api" +) + +//go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphschemaextensions.go -package=mocks . OpenGraphSchemaService +type OpenGraphSchemaService interface { + UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []Environment) error +} + +type SchemaUploadRequest struct { + ID int32 `` + Environments []Environment `json:"environments"` +} + +type Environment struct { + EnvironmentKind string `json:"environmentKind"` + SourceKind string `json:"sourceKind"` + PrincipalKinds []string `json:"principalKinds"` +} + +func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { + var ( + ctx = request.Context() + ) + + var req SchemaUploadRequest + if err := json.NewDecoder(request.Body).Decode(&req); err != nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponsePayloadUnmarshalError, request), response) + return + } + + if err := s.openGraphSchemaService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("Error upserting schema environment with principal kinds: %w", err), request), response) + return + } + + api.WriteBasicResponse(request.Context(), map[string]string{ + "message": "Schema environments uploaded successfully", + }, http.StatusCreated, response) +} diff --git a/cmd/api/src/daemons/datapipe/mocks/cleanup.go b/cmd/api/src/daemons/datapipe/mocks/cleanup.go index ed40825455..d2f7e31808 100644 --- a/cmd/api/src/daemons/datapipe/mocks/cleanup.go +++ b/cmd/api/src/daemons/datapipe/mocks/cleanup.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 7c52d14b6d..3801cc5248 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -607,6 +607,47 @@ func (s *BloodhoundDB) DeleteSchemaRelationshipFinding(ctx context.Context, find return nil } +func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + var envPrincipalKind model.SchemaEnvironmentPrincipalKind + + if result := s.db.WithContext(ctx).Raw(` + INSERT INTO schema_environments_principal_kinds (environment_id, principal_kind) + VALUES (?, ?) + RETURNING environment_id, principal_kind`, + environmentId, principalKind).Scan(&envPrincipalKind); result.Error != nil { + return model.SchemaEnvironmentPrincipalKind{}, CheckError(result) + } + + return envPrincipalKind, nil +} + +func (s *BloodhoundDB) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + var envPrincipalKinds model.SchemaEnvironmentPrincipalKinds + + if result := s.db.WithContext(ctx).Raw(` + SELECT environment_id, principal_kind + FROM schema_environments_principal_kinds + WHERE environment_id = ?`, + environmentId).Scan(&envPrincipalKinds); result.Error != nil { + return nil, CheckError(result) + } + + return envPrincipalKinds, nil +} + +func (s *BloodhoundDB) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error { + if result := s.db.WithContext(ctx).Exec(` + DELETE FROM schema_environments_principal_kinds + WHERE environment_id = ? AND principal_kind = ?`, + environmentId, principalKind); result.Error != nil { + return CheckError(result) + } else if result.RowsAffected == 0 { + return ErrNotFound + } + + return nil +} + func parseFiltersAndPagination(filters model.Filters, sort model.Sort, skip, limit int) (FilterAndPagination, error) { var ( filtersAndPagination FilterAndPagination diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go index c89049ceb8..87632f6ae3 100644 --- a/cmd/api/src/database/kind.go +++ b/cmd/api/src/database/kind.go @@ -1,19 +1,37 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package database import ( "context" - "fmt" "github.com/specterops/bloodhound/cmd/api/src/model" ) -// GetKindById gets a row from the kind table by id. -func (s *BloodhoundDB) GetKindById(ctx context.Context, id int32) (model.Kind, error) { +func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Kind, error) { + const query = ` + SELECT id, name + FROM kind + WHERE name = $1; + ` + var kind model.Kind + if err := s.db.Raw(query, name).Scan(&kind).Error; err != nil { + return model.Kind{}, err + } - // var kind Kind - return kind, CheckError(s.db.WithContext(ctx).Raw(fmt.Sprintf(` - SELECT id, name - FROM %s WHERE id = ?`, kindTable), id).First(&kind), - ) + return kind, nil } diff --git a/cmd/api/src/database/kind_integration_test.go b/cmd/api/src/database/kind_integration_test.go new file mode 100644 index 0000000000..089db89dc9 --- /dev/null +++ b/cmd/api/src/database/kind_integration_test.go @@ -0,0 +1,75 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration + +package database_test + +import ( + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/stretchr/testify/assert" +) + +// the v7.3.0 migration initializes the kind table with Tag_Tier_Zero, so we're +// simply testing the kind exists +func TestGetKindByName(t *testing.T) { + type args struct { + name string + } + type want struct { + err error + kind model.Kind + } + tests := []struct { + name string + args args + setup func() IntegrationTestSuite + want want + }{ + { + name: "Success: Retrieves Kind Tag_Tier_Zero by name", + args: args{ + name: "Tag_Tier_Zero", + }, + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + want: want{ + err: nil, + kind: model.Kind{ + ID: 1, + Name: "Tag_Tier_Zero", + }, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + kind, err := testSuite.BHDatabase.GetKindByName(testSuite.Context, testCase.args.name) + if testCase.want.err != nil { + assert.EqualError(t, testCase.want.err, err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.want.kind, kind) + } + }) + } +} diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 204da86fec..78cf431053 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -184,3 +184,7 @@ $$ END IF; END $$; + +-- Insert a test schema extension +INSERT INTO schema_extensions (name, display_name, version, is_builtin) +VALUES ('test_schema', 'Test Schema', '1.0.0', false); diff --git a/cmd/api/src/database/mocks/auth.go b/cmd/api/src/database/mocks/auth.go index 6934b158d0..d696b4804c 100644 --- a/cmd/api/src/database/mocks/auth.go +++ b/cmd/api/src/database/mocks/auth.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 49772c522e..ae5c1b5e1c 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. @@ -2230,6 +2230,21 @@ func (mr *MockDatabaseMockRecorder) GetSharedSavedQueries(ctx, userID any) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSharedSavedQueries", reflect.TypeOf((*MockDatabase)(nil).GetSharedSavedQueries), ctx, userID) } +// GetSourceKindByName mocks base method. +func (m *MockDatabase) GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSourceKindByName", ctx, name) + ret0, _ := ret[0].(database.SourceKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSourceKindByName indicates an expected call of GetSourceKindByName. +func (mr *MockDatabaseMockRecorder) GetSourceKindByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindByName", reflect.TypeOf((*MockDatabase)(nil).GetSourceKindByName), ctx, name) +} + // GetSourceKinds mocks base method. func (m *MockDatabase) GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/sourcekinds.go b/cmd/api/src/database/sourcekinds.go index bf95599dd8..d69ed7dcd4 100644 --- a/cmd/api/src/database/sourcekinds.go +++ b/cmd/api/src/database/sourcekinds.go @@ -28,6 +28,7 @@ type SourceKindsData interface { GetSourceKinds(ctx context.Context) ([]SourceKind, error) DeactivateSourceKindsByName(ctx context.Context, kinds graph.Kinds) error RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error + GetSourceKindByName(ctx context.Context, name string) (SourceKind, error) } // RegisterSourceKind returns a function that inserts a source kind by name, @@ -93,6 +94,33 @@ func (s *BloodhoundDB) GetSourceKinds(ctx context.Context) ([]SourceKind, error) return out, nil } +func (s *BloodhoundDB) GetSourceKindByName(ctx context.Context, name string) (SourceKind, error) { + const query = ` + SELECT id, name, active + FROM source_kinds + WHERE name = $1 AND active = true; + ` + + type rawSourceKind struct { + ID int + Name string + Active bool + } + + var raw rawSourceKind + if err := s.db.Raw(query, name).Scan(&raw).Error; err != nil { + return SourceKind{}, err + } + + kind := SourceKind{ + ID: raw.ID, + Name: graph.StringKind(raw.Name), + Active: raw.Active, + } + + return kind, nil +} + func (s *BloodhoundDB) DeactivateSourceKindsByName(ctx context.Context, kinds graph.Kinds) error { if len(kinds) == 0 { return nil diff --git a/cmd/api/src/database/sourcekinds_integration_test.go b/cmd/api/src/database/sourcekinds_integration_test.go index 8d6cbbab45..329831d2cb 100644 --- a/cmd/api/src/database/sourcekinds_integration_test.go +++ b/cmd/api/src/database/sourcekinds_integration_test.go @@ -206,6 +206,56 @@ func TestGetSourceKinds(t *testing.T) { } } +func TestGetSourceKindByName(t *testing.T) { + type args struct { + name string + } + type want struct { + err error + sourceKind database.SourceKind + } + tests := []struct { + name string + args args + setup func() IntegrationTestSuite + want want + }{ + { + name: "Success: Retrieves Source Kinds by Name", + args: args{ + name: "AZBase", + }, + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + want: want{ + err: nil, + // the v8.0.0 migration initializes the source_kinds table with Base, AZBase, so we're + // simply testing the default returned source_kinds + sourceKind: database.SourceKind{ + ID: 2, + Name: graph.StringKind("AZBase"), + Active: true, + }, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + sourceKind, err := testSuite.BHDatabase.GetSourceKindByName(testSuite.Context, testCase.args.name) + if testCase.want.err != nil { + assert.EqualError(t, testCase.want.err, err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.want.sourceKind, sourceKind) + } + }) + } +} + func TestDeactivateSourceKindsByName(t *testing.T) { type args struct { sourceKind graph.Kinds diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index cc343e5197..2575af475d 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -17,7 +17,6 @@ package model import ( - "fmt" "time" ) @@ -101,20 +100,6 @@ func (GraphSchemaEdgeKind) TableName() string { return "schema_edge_kinds" } -// GraphSchemaEnvironments - slice of environments -type GraphSchemaEnvironments []SchemaEnvironment - -// ToMapKeyedOnEnvironmentAndSource - converts a list of graph schema environments to a map based on environment kind id and source id -func (s GraphSchemaEnvironments) ToMapKeyedOnEnvironmentAndSource() map[string]SchemaEnvironment { - m := make(map[string]SchemaEnvironment) - for _, env := range s { - // Key is environment id + source id separated with an underscore e.g., kindid_sourceid - key := fmt.Sprintf("%d_%d", env.EnvironmentKindId, env.SourceKindId) - m[key] = env - } - return m -} - type SchemaEnvironment struct { Serial SchemaExtensionId int32 `json:"schema_extension_id"` diff --git a/cmd/api/src/model/kind.go b/cmd/api/src/model/kind.go index 786fa0a5bf..8f6402c682 100644 --- a/cmd/api/src/model/kind.go +++ b/cmd/api/src/model/kind.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package model type Kind struct { diff --git a/cmd/api/src/queries/mocks/graph.go b/cmd/api/src/queries/mocks/graph.go index a536fce424..85b683c87c 100644 --- a/cmd/api/src/queries/mocks/graph.go +++ b/cmd/api/src/queries/mocks/graph.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/agi/mocks/mock.go b/cmd/api/src/services/agi/mocks/mock.go index d6f0342f8e..d8c4c5695b 100644 --- a/cmd/api/src/services/agi/mocks/mock.go +++ b/cmd/api/src/services/agi/mocks/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/dataquality/mocks/mock.go b/cmd/api/src/services/dataquality/mocks/mock.go index 11f0a8a31b..0ffa179ce6 100644 --- a/cmd/api/src/services/dataquality/mocks/mock.go +++ b/cmd/api/src/services/dataquality/mocks/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index 10ff850660..e0045fc3d2 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -38,6 +38,7 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/migrations" "github.com/specterops/bloodhound/cmd/api/src/model/appcfg" "github.com/specterops/bloodhound/cmd/api/src/queries" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/upload" "github.com/specterops/bloodhound/packages/go/cache" schema "github.com/specterops/bloodhound/packages/go/graphschema" @@ -109,18 +110,19 @@ func Entrypoint(ctx context.Context, cfg config.Configuration, connections boots startDelay := 0 * time.Second var ( - cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) - pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) - graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) - authorizer = auth.NewAuthorizer(connections.RDMS) - datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) - routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) - ctxInitializer = database.NewContextInitializer(connections.RDMS) - authenticator = api.NewAuthenticator(cfg, connections.RDMS, ctxInitializer) + cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) + pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) + graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) + authorizer = auth.NewAuthorizer(connections.RDMS) + datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) + routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) + ctxInitializer = database.NewContextInitializer(connections.RDMS) + authenticator = api.NewAuthenticator(cfg, connections.RDMS, ctxInitializer) + openGraphSchemaService = opengraphschema.NewOpenGraphSchemaService(connections.RDMS) ) registration.RegisterFossGlobalMiddleware(&routerInst, cfg, auth.NewIdentityResolver(), authenticator) - registration.RegisterFossRoutes(&routerInst, cfg, connections.RDMS, connections.Graph, graphQuery, apiCache, collectorManifests, authenticator, authorizer, ingestSchema) + registration.RegisterFossRoutes(&routerInst, cfg, connections.RDMS, connections.Graph, graphQuery, apiCache, collectorManifests, authenticator, authorizer, ingestSchema, openGraphSchemaService) // Set neo4j batch and flush sizes neo4jParameters := appcfg.GetNeo4jParameters(ctx, connections.RDMS) diff --git a/cmd/api/src/services/fs/mocks/fs.go b/cmd/api/src/services/fs/mocks/fs.go index 7d05415d18..5ba753d534 100644 --- a/cmd/api/src/services/fs/mocks/fs.go +++ b/cmd/api/src/services/fs/mocks/fs.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/graphify/mocks/ingest.go b/cmd/api/src/services/graphify/mocks/ingest.go index 81b68bbc28..c897f58c6d 100644 --- a/cmd/api/src/services/graphify/mocks/ingest.go +++ b/cmd/api/src/services/graphify/mocks/ingest.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/oidc/mocks/oidc.go b/cmd/api/src/services/oidc/mocks/oidc.go index 704420ef38..6faa71e5d5 100644 --- a/cmd/api/src/services/oidc/mocks/oidc.go +++ b/cmd/api/src/services/oidc/mocks/oidc.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/opengraphschema/environment.go b/cmd/api/src/services/opengraphschema/environment.go index 492da32a49..5c45693808 100644 --- a/cmd/api/src/services/opengraphschema/environment.go +++ b/cmd/api/src/services/opengraphschema/environment.go @@ -20,37 +20,106 @@ import ( "errors" "fmt" + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" "github.com/specterops/bloodhound/cmd/api/src/database" "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/dawgs/graph" ) -func (o *OpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, environment model.SchemaEnvironment, principalKinds []model.SchemaEnvironmentPrincipalKind) error { - // Validate environment - if err := o.validateSchemaEnvironment(ctx, environment); err != nil { - return fmt.Errorf("error validating schema environment: %w", err) - } +// UpsertSchemaEnvironmentWithPrincipalKinds takes a slice of environments, validates and translates each environment. +// The translation is used to upsert the environments into the database. +// If an existing environment is found to already exist in the database, the existing environment will be removed and the new one will be uploaded. +func (o *OpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []v2.Environment) error { + for _, env := range environments { + environment := model.SchemaEnvironment{ + SchemaExtensionId: schemaExtensionId, + } - // Validate principal kinds - for _, kind := range principalKinds { - if err := o.validateSchemaEnvironmentPrincipalKind(ctx, kind.PrincipalKind); err != nil { - return fmt.Errorf("error validating principal kind: %w", err) + if updatedEnv, principalKinds, err := o.validateAndTranslateEnvironment(ctx, environment, env); err != nil { + return fmt.Errorf("error validating and translating environment: %w", err) + } else if envID, err := o.upsertSchemaEnvironment(ctx, updatedEnv); err != nil { + return fmt.Errorf("error upserting schema environment: %w", err) + } else if err := o.upsertPrincipalKinds(ctx, envID, principalKinds); err != nil { + return fmt.Errorf("error upserting principal kinds: %w", err) } } - // Upsert the environment - id, err := o.upsertSchemaEnvironment(ctx, environment) - if err != nil { - return fmt.Errorf("error upserting schema environment: %w", err) + return nil +} + +// validateAndTranslateEnvironment validates that the environment kind, source kind, and the principal kinds exist in the database. +// It is then translated from the API model to the Database model to prepare it for insert. +func (o *OpenGraphSchemaService) validateAndTranslateEnvironment(ctx context.Context, environment model.SchemaEnvironment, env v2.Environment) (model.SchemaEnvironment, []model.SchemaEnvironmentPrincipalKind, error) { + if envKind, err := o.validateAndTranslateEnvironmentKind(ctx, env.EnvironmentKind); err != nil { + return model.SchemaEnvironment{}, nil, err + } else if sourceKindID, err := o.validateAndTranslateSourceKind(ctx, env.SourceKind); err != nil { + return model.SchemaEnvironment{}, nil, err + } else if principalKinds, err := o.validateAndTranslatePrincipalKinds(ctx, env.PrincipalKinds); err != nil { + return model.SchemaEnvironment{}, nil, err + } else { + // Update environment with translated IDs + environment.EnvironmentKindId = int32(envKind.ID) + environment.SourceKindId = sourceKindID + + return environment, principalKinds, nil } +} - // Upsert principal kinds for environment - if err := o.upsertPrincipalKinds(ctx, id, principalKinds); err != nil { - return fmt.Errorf("error upserting principal kinds: %w", err) +// validateAndTranslateEnvironmentKind validates that the environment kind exists in the kinds table. +func (o *OpenGraphSchemaService) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { + if envKind, err := o.openGraphSchemaRepository.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, database.ErrNotFound) { + return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) + } else if errors.Is(err, database.ErrNotFound){ + return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) + } else { + return envKind, nil } +} - return nil +// validateAndTranslateSourceKind validates that the source kind exists in the source_kinds table. +// If not found, it registers the source kind and returns its ID so it can be added to the Environment object. +func (o *OpenGraphSchemaService) validateAndTranslateSourceKind(ctx context.Context, sourceKindName string) (int32, error) { + if sourceKind, err := o.openGraphSchemaRepository.GetSourceKindByName(ctx, sourceKindName); err != nil && !errors.Is(err, database.ErrNotFound) { + return 0, fmt.Errorf("error retrieving source kind '%s': %w", sourceKindName, err) + } else if err == nil { + return int32(sourceKind.ID), nil + } + + // If source kind is not found, register it. If it exists and is inactive, it will automatically update as active. + kindType := graph.StringKind(sourceKindName) + if err := o.openGraphSchemaRepository.RegisterSourceKind(ctx)(kindType); err != nil { + return 0, fmt.Errorf("error registering source kind '%s': %w", sourceKindName, err) + } + + if sourceKind, err := o.openGraphSchemaRepository.GetSourceKindByName(ctx, sourceKindName); err != nil { + return 0, fmt.Errorf("error retrieving newly registered source kind '%s': %w", sourceKindName, err) + } else { + return int32(sourceKind.ID), nil + } +} + +// validateAndTranslatePrincipalKinds ensures all principalKinds exist in the kinds table. +// It also translates them to IDs so they can be upserted into the database. +func (o *OpenGraphSchemaService) validateAndTranslatePrincipalKinds(ctx context.Context, principalKindNames []string) ([]model.SchemaEnvironmentPrincipalKind, error) { + principalKinds := make([]model.SchemaEnvironmentPrincipalKind, len(principalKindNames)) + + for i, kindName := range principalKindNames { + if kind, err := o.openGraphSchemaRepository.GetKindByName(ctx, kindName); err != nil && !errors.Is(err, database.ErrNotFound) { + return nil, fmt.Errorf("error retrieving principal kind by name '%s': %w", kindName, err) + } else if errors.Is(err, database.ErrNotFound){ + return nil, fmt.Errorf("principal kind '%s' not found", kindName) + } else { + principalKinds[i] = model.SchemaEnvironmentPrincipalKind{ + PrincipalKind: int32(kind.ID), + } + } + } + + return principalKinds, nil } +// upsertSchemaEnvironment creates or updates a schema environment. +// If an environment with the given ID exists, it deletes it first before creating the new one. func (o *OpenGraphSchemaService) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { if existing, err := o.openGraphSchemaRepository.GetSchemaEnvironmentById(ctx, graphSchema.ID); err != nil && !errors.Is(err, database.ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment id %d: %w", graphSchema.ID, err) @@ -69,47 +138,7 @@ func (o *OpenGraphSchemaService) upsertSchemaEnvironment(ctx context.Context, gr } } -/* -Validations: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#10-validation-rules-for-environments - 1. Ensure the specified environmentKind exists in kinds table - ** QUESTION: Documentation states to use the kind table. Is there already a database method to query the kind table to do this validation or was I supposed to create it? - 2. Ensure the specified sourceKind exists in source_kinds table (create if it doesn't, reactivate if it does) -*/ -func (o *OpenGraphSchemaService) validateSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) error { - // Validate environment kind id exists in kinds table - if _, err := o.openGraphSchemaRepository.GetKindById(ctx, graphSchema.EnvironmentKindId); err != nil { - return fmt.Errorf("error retrieving environment kind: %w", err) - } - - // Get all source kinds - if sourceKinds, err := o.openGraphSchemaRepository.GetSourceKinds(ctx); err != nil { - return fmt.Errorf("error retrieving source kinds: %w", err) - } else { - // Check if source kind exists - found := false - for _, kind := range sourceKinds { - if graphSchema.SourceKindId == int32(kind.ID) { - found = true - break - } - } - - if !found { - /* - ** QUESTION: Example Environment Schema: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#9-environments-and-principal-kinds - The RFC example uses source kind names (strings) in the environment schema, but our model uses IDs (int32). To register a new source kind, we need the - kind name/data, not just the ID. Cannot register with only an ID - kind id/name is required for registration. - RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error - - For now, this validates that the source kind should exist. - */ - return fmt.Errorf("invalid source kind id %d", graphSchema.SourceKindId) - } - } - - return nil -} - +// upsertPrincipalKinds deletes all existing principal kinds for an environment and creates new ones. func (o *OpenGraphSchemaService) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { if existingKinds, err := o.openGraphSchemaRepository.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) @@ -131,19 +160,3 @@ func (o *OpenGraphSchemaService) upsertPrincipalKinds(ctx context.Context, envir return nil } - -/* -Validations: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#10-validation-rules-for-environments -1. Ensure all principalKinds exist in kinds table. - -** QUESTION: Documentation states to use the kind table. Is there already a database method to query the kind table to do this validation or was I supposed to create it? -*/ -func (o *OpenGraphSchemaService) validateSchemaEnvironmentPrincipalKind(ctx context.Context, kindID int32) error { - if _, err := o.openGraphSchemaRepository.GetKindById(ctx, kindID); err != nil && !errors.Is(err, database.ErrNotFound) { - return fmt.Errorf("error retrieving kind by id: %w", err) - } else if errors.Is(err, database.ErrNotFound) { - return fmt.Errorf("invalid principal kind id %d", kindID) - } - - return nil -} diff --git a/cmd/api/src/services/opengraphschema/environment_test.go b/cmd/api/src/services/opengraphschema/environment_test.go index 991509a6ed..e50362a423 100644 --- a/cmd/api/src/services/opengraphschema/environment_test.go +++ b/cmd/api/src/services/opengraphschema/environment_test.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package opengraphschema_test import ( @@ -5,6 +20,7 @@ import ( "errors" "testing" + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" "github.com/specterops/bloodhound/cmd/api/src/database" "github.com/specterops/bloodhound/cmd/api/src/model" "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" @@ -19,8 +35,8 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository } type args struct { - environment model.SchemaEnvironment - principalKinds []model.SchemaEnvironmentPrincipalKind + schemaExtensionId int32 + environments []v2.Environment } tests := []struct { name string @@ -29,179 +45,188 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes args args expected error }{ + // Validation: Environment Kind { - name: "Error: Validation - Failed to retrieve environment kind", + name: "Error: openGraphSchemaRepository.GetKindByName environment kind name not found in the database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(1), - SourceKindId: int32(1), + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, errors.New("error")) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{}, database.ErrNotFound) }, - expected: errors.New("error validating schema environment: error retrieving environment kind: error"), + expected: errors.New("error validating and translating environment: environment kind 'Domain' not found"), }, { - name: "Error: Validation - Environment Kind doesn't exist in Kinds table", + name: "Error: openGraphSchemaRepository.GetKindByName failed to retrieve environment kind from database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(1), - SourceKindId: int32(1), + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{}, errors.New("error")) }, - expected: errors.New("error validating schema environment: error retrieving environment kind: entity not found"), + expected: errors.New("error validating and translating environment: error retrieving environment kind 'Domain': error"), }, + // Validation: Source Kind { - name: "Error: Validation - Failed to retrieve source kinds", + name: "Error: validateAndTranslateSourceKind failed to retrieve source kind from database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(1), - SourceKindId: int32(1), + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{}, errors.New("error")) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, errors.New("error")) }, - expected: errors.New("error validating schema environment: error retrieving source kinds: error"), + expected: errors.New("error validating and translating environment: error retrieving source kind 'Base': error"), }, { - name: "Error: Validation - Source Kind doesn't exist", + name: "Error: validateAndTranslateSourceKind source kind name doesn't exist in database, registration fails", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(3), - SourceKindId: int32(1), + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { + return errors.New("error") + }) }, - expected: errors.New("error validating schema environment: invalid source kind id 1"), + expected: errors.New("error validating and translating environment: error registering source kind 'Base': error"), }, { - name: "Error: Validation - Failed to retrieve principal kind", + name: "Error: validateAndTranslateSourceKind source kind name doesn't exist in database, registration succeeds but fetch fails", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ { - EnvironmentId: int32(1), - PrincipalKind: int32(99), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { + return nil + }) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, errors.New("error")) + }, + expected: errors.New("error validating and translating environment: error retrieving newly registered source kind 'Base': error"), + }, + // Validation: Principal Kind + { + name: "Error: validateAndTranslatePrincipalKinds principal kind not found in database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ { - ID: 3, - Name: graph.StringKind("kind"), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "InvalidKind"}, }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(99)).Return(model.Kind{}, errors.New("error")) + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 2}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "InvalidKind").Return(model.Kind{}, database.ErrNotFound) }, - expected: errors.New("error validating principal kind: error retrieving kind by id: error"), + expected: errors.New("error validating and translating environment: principal kind 'InvalidKind' not found"), }, { - name: "Error: Validation - Principal Kind doesn't exist", + name: "Error: validateAndTranslatePrincipalKinds failed to retrieve principal kind from database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ { - EnvironmentId: int32(1), - PrincipalKind: int32(99), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "InvalidKind"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(99)).Return(model.Kind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 2}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "InvalidKind").Return(model.Kind{}, errors.New("error")) }, - expected: errors.New("error validating principal kind: invalid principal kind id 99"), + expected: errors.New("error validating and translating environment: error retrieving principal kind by name 'InvalidKind': error"), }, + // Upsert Schema Environment { - name: "Error: GetSchemaEnvironmentById fails", + name: "Error: upsertSchemaEnvironment error retrieving schema environment from database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, errors.New("error")) }, expected: errors.New("error upserting schema environment: error retrieving schema environment id 0: error"), }, { - name: "Error: DeleteSchemaEnvironment fails", + name: "Error: upsertSchemaEnvironment error deleting schema environment", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ Serial: model.Serial{ID: 5}, }, nil) @@ -210,24 +235,21 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes expected: errors.New("error upserting schema environment: error deleting schema environment 5: error"), }, { - name: "Error: CreateSchemaEnvironment fails after delete", + name: "Error: upsertSchemaEnvironment error creating schema environment after deletion", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ Serial: model.Serial{ID: 5}, }, nil) @@ -237,55 +259,45 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes expected: errors.New("error upserting schema environment: error creating schema environment: error"), }, { - name: "Error: CreateSchemaEnvironment fails on new environment", + name: "Error: upsertSchemaEnvironment error creating new schema environment", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) }, expected: errors.New("error upserting schema environment: error creating schema environment: error"), }, + // Upsert Principal Kinds { - name: "Error: GetSchemaEnvironmentPrincipalKindsByEnvironmentId fails", + name: "Error: upsertPrincipalKinds error getting principal kinds by environment id", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) // Environment upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) @@ -299,31 +311,23 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes expected: errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), }, { - name: "Error: DeleteSchemaEnvironmentPrincipalKind fails", + name: "Error: upsertPrincipalKinds error deleting principal kinds", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) // Environment upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) @@ -343,31 +347,23 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes expected: errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), }, { - name: "Error: CreateSchemaEnvironmentPrincipalKind fails", + name: "Error: upsertPrincipalKinds error creating principal kinds", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) // Environment upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) @@ -384,37 +380,22 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes { name: "Success: Create new environment with principal kinds", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), - }, - { - PrincipalKind: int32(4), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "Computer"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind3"), - }, - { - ID: 4, - Name: graph.StringKind("kind4"), - }, - }, nil) - // Principal kind validations - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(4)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Computer").Return(model.Kind{ID: 5}, nil) // Environment upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) @@ -424,89 +405,87 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes // Principal kinds upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(5)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) }, expected: nil, }, { - name: "Success: Update existing environment and replace principal kinds", + name: "Success: Create environment with source kind registration", args: args{ - environment: model.SchemaEnvironment{ - Serial: model.Serial{ID: 5}, - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), + EnvironmentKind: "Domain", + SourceKind: "NewSource", + PrincipalKinds: []string{}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + // Source kind not found, register it + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "NewSource").Return(database.SourceKind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { + return nil + }) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "NewSource").Return(database.SourceKind{ID: 10}, nil) - // Environment upsert (delete and recreate) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(5)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 5}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(10)).Return(model.SchemaEnvironment{ Serial: model.Serial{ID: 10}, }, nil) - // Principal kinds upsert (delete old, create new) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return([]model.SchemaEnvironmentPrincipalKind{ - { - EnvironmentId: int32(10), - PrincipalKind: int32(99), - }, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(99)).Return(nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) }, expected: nil, }, { - name: "Success: Create environment with no principal kinds", + name: "Success: Process multiple environments", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKind: "AzureAD", + SourceKind: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - - // Environment upsert + // First environment + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ Serial: model.Serial{ID: 10}, }, nil) - - // Principal kinds upsert (no existing, no new) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + + // Second environment + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "AzureAD").Return(model.Kind{ID: 5}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "AzureHound").Return(database.SourceKind{ID: 6}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Group").Return(model.Kind{ID: 7}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(5), int32(6)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 11}, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(11)).Return(nil, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(11), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(11), int32(7)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) }, expected: nil, }, @@ -524,9 +503,9 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes graphService := opengraphschema.NewOpenGraphSchemaService(mocks.mockOpenGraphSchema) - err := graphService.UpsertSchemaEnvironmentWithPrincipalKinds(context.Background(), tt.args.environment, tt.args.principalKinds) + err := graphService.UpsertSchemaEnvironmentWithPrincipalKinds(context.Background(), tt.args.schemaExtensionId, tt.args.environments) if tt.expected != nil { - assert.EqualError(t, tt.expected, err.Error()) + assert.EqualError(t, err, tt.expected.Error()) } else { assert.NoError(t, err) } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index 7723197208..a3793b1572 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -117,19 +117,19 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironmentPrin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) } -// GetKindById mocks base method. -func (m *MockOpenGraphSchemaRepository) GetKindById(ctx context.Context, id int32) (model.Kind, error) { +// GetKindByName mocks base method. +func (m *MockOpenGraphSchemaRepository) GetKindByName(ctx context.Context, name string) (model.Kind, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetKindById", ctx, id) + ret := m.ctrl.Call(m, "GetKindByName", ctx, name) ret0, _ := ret[0].(model.Kind) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetKindById indicates an expected call of GetKindById. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetKindById(ctx, id any) *gomock.Call { +// GetKindByName indicates an expected call of GetKindByName. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetKindByName(ctx, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindById", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetKindById), ctx, id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetKindByName), ctx, name) } // GetSchemaEnvironmentById mocks base method. @@ -162,19 +162,19 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentPrincip return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) } -// GetSourceKinds mocks base method. -func (m *MockOpenGraphSchemaRepository) GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) { +// GetSourceKindByName mocks base method. +func (m *MockOpenGraphSchemaRepository) GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSourceKinds", ctx) - ret0, _ := ret[0].([]database.SourceKind) + ret := m.ctrl.Call(m, "GetSourceKindByName", ctx, name) + ret0, _ := ret[0].(database.SourceKind) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetSourceKinds indicates an expected call of GetSourceKinds. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSourceKinds(ctx any) *gomock.Call { +// GetSourceKindByName indicates an expected call of GetSourceKindByName. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSourceKindByName(ctx, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSourceKinds), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindByName", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSourceKindByName), ctx, name) } // RegisterSourceKind mocks base method. diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index 1ff658041d..cea2845ee9 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -28,7 +28,7 @@ import ( // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { // Kinds - GetKindById(ctx context.Context, id int32) (model.Kind, error) + GetKindByName(ctx context.Context, name string) (model.Kind, error) // Environment CreateSchemaEnvironment(ctx context.Context, schemaExtensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) @@ -37,7 +37,7 @@ type OpenGraphSchemaRepository interface { // Source Kinds RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error - GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) + GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) // Principal Kinds CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) diff --git a/cmd/api/src/services/saml/mocks/saml.go b/cmd/api/src/services/saml/mocks/saml.go index f71839355b..e6029f5654 100644 --- a/cmd/api/src/services/saml/mocks/saml.go +++ b/cmd/api/src/services/saml/mocks/saml.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/upload/mocks/mock.go b/cmd/api/src/services/upload/mocks/mock.go index bd1cd94039..a3b90e7fb3 100644 --- a/cmd/api/src/services/upload/mocks/mock.go +++ b/cmd/api/src/services/upload/mocks/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/utils/validation/mocks/validator.go b/cmd/api/src/utils/validation/mocks/validator.go index 309fe51330..536c0385a5 100644 --- a/cmd/api/src/utils/validation/mocks/validator.go +++ b/cmd/api/src/utils/validation/mocks/validator.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/vendormocks/dawgs/graph/mock.go b/cmd/api/src/vendormocks/dawgs/graph/mock.go index 45ea562541..14f78cb94d 100644 --- a/cmd/api/src/vendormocks/dawgs/graph/mock.go +++ b/cmd/api/src/vendormocks/dawgs/graph/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/vendormocks/io/fs/mock.go b/cmd/api/src/vendormocks/io/fs/mock.go index 5c094281bd..f90a77b286 100644 --- a/cmd/api/src/vendormocks/io/fs/mock.go +++ b/cmd/api/src/vendormocks/io/fs/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/vendormocks/neo4j/neo4j-go-driver/v5/neo4j/mock.go b/cmd/api/src/vendormocks/neo4j/neo4j-go-driver/v5/neo4j/mock.go index a4adb00b99..1e3a6ab710 100644 --- a/cmd/api/src/vendormocks/neo4j/neo4j-go-driver/v5/neo4j/mock.go +++ b/cmd/api/src/vendormocks/neo4j/neo4j-go-driver/v5/neo4j/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/go/crypto/mocks/digest.go b/packages/go/crypto/mocks/digest.go index 3c1acaa2ce..ddd118ab95 100644 --- a/packages/go/crypto/mocks/digest.go +++ b/packages/go/crypto/mocks/digest.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. From bc97cb007ff88dd577ea85d7f8bf98fac78ebf4d Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 6 Jan 2026 18:14:41 -0600 Subject: [PATCH 03/36] cleanup --- cmd/api/src/api/v2/opengraphschema.go | 8 ++++---- cmd/api/src/database/migration/migrations/v8.5.0.sql | 4 ---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index 50446d6a04..2a6459876f 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -40,6 +40,7 @@ type Environment struct { PrincipalKinds []string `json:"principalKinds"` } +// TODO: Implement this - barebones in order to test handler. func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { var ( ctx = request.Context() @@ -51,12 +52,11 @@ func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request * return } + // TODO: Pass Extension ID instead of harcoded value if err := s.openGraphSchemaService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { - api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("Error upserting schema environment with principal kinds: %w", err), request), response) + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("error upserting environment with principal kinds: %v", err), request), response) return } - api.WriteBasicResponse(request.Context(), map[string]string{ - "message": "Schema environments uploaded successfully", - }, http.StatusCreated, response) + api.WriteBasicResponse(request.Context(), "Success", http.StatusCreated, response) } diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 78cf431053..204da86fec 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -184,7 +184,3 @@ $$ END IF; END $$; - --- Insert a test schema extension -INSERT INTO schema_extensions (name, display_name, version, is_builtin) -VALUES ('test_schema', 'Test Schema', '1.0.0', false); From bce213baeee6d305988e2e03f8cd38dee4cf3f56 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 7 Jan 2026 12:38:52 -0600 Subject: [PATCH 04/36] transactions that hopefully satisfy both databases --- cmd/api/src/api/v2/opengraphschema.go | 12 ++++----- .../src/services/opengraphschema/extension.go | 25 +++++++++++++++++++ .../opengraphschema/mocks/opengraphschema.go | 20 +++++++++++++++ .../opengraphschema/opengraphschema.go | 5 ++++ 4 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 cmd/api/src/services/opengraphschema/extension.go diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index 2a6459876f..6600045f60 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -26,11 +26,10 @@ import ( //go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphschemaextensions.go -package=mocks . OpenGraphSchemaService type OpenGraphSchemaService interface { - UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []Environment) error + UpsertGraphSchemaExtension(ctx context.Context, req GraphSchemaExtension) error } -type SchemaUploadRequest struct { - ID int32 `` +type GraphSchemaExtension struct { Environments []Environment `json:"environments"` } @@ -46,15 +45,14 @@ func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request * ctx = request.Context() ) - var req SchemaUploadRequest + var req GraphSchemaExtension if err := json.NewDecoder(request.Body).Decode(&req); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponsePayloadUnmarshalError, request), response) return } - // TODO: Pass Extension ID instead of harcoded value - if err := s.openGraphSchemaService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { - api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("error upserting environment with principal kinds: %v", err), request), response) + if err := s.openGraphSchemaService.UpsertGraphSchemaExtension(ctx, req); err != nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("error upserting graph schema extension: %v", err), request), response) return } diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go new file mode 100644 index 0000000000..22848aa05f --- /dev/null +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -0,0 +1,25 @@ +package opengraphschema + +import ( + "context" + "fmt" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" +) + +func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { + return o.transactor.WithTransaction(ctx, func(repo OpenGraphSchemaRepository) error { + txService := &OpenGraphSchemaService{ + openGraphSchemaRepository: repo, + transactor: o.transactor, + } + + // Upsert environments with principal kinds + // TODO: Temporary hardcoded extension ID + if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + } + + return nil + }) +} diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index a3793b1572..8ff7dabfa8 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -27,6 +27,7 @@ package mocks import ( context "context" + sql "database/sql" reflect "reflect" database "github.com/specterops/bloodhound/cmd/api/src/database" @@ -190,3 +191,22 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) } + +// Transaction mocks base method. +func (m *MockOpenGraphSchemaRepository) Transaction(ctx context.Context, fn func(*database.BloodhoundDB) error, opts ...*sql.TxOptions) error { + m.ctrl.T.Helper() + varargs := []any{ctx, fn} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Transaction", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Transaction indicates an expected call of Transaction. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Transaction(ctx, fn any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, fn}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Transaction), varargs...) +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index cea2845ee9..a37535f639 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -47,6 +47,11 @@ type OpenGraphSchemaRepository interface { type OpenGraphSchemaService struct { openGraphSchemaRepository OpenGraphSchemaRepository + transactor Transactor +} + +type Transactor interface { + WithTransaction(ctx context.Context, fn func(repo OpenGraphSchemaRepository) error) error } func NewOpenGraphSchemaService(openGraphSchemaRepository OpenGraphSchemaRepository) *OpenGraphSchemaService { From 89998ee5b21626835ed46d9cb7aee32fd5a769cf Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 7 Jan 2026 12:43:49 -0600 Subject: [PATCH 05/36] just prepare --- .../src/api/v2/mocks/graphschemaextensions.go | 12 ++++---- cmd/api/src/config/config.go | 2 +- cmd/api/src/config/default.go | 2 +- .../services/opengraphschema/environment.go | 4 +-- .../src/services/opengraphschema/extension.go | 29 ++++++++++++++----- .../opengraphschema/mocks/opengraphschema.go | 20 ------------- 6 files changed, 32 insertions(+), 37 deletions(-) diff --git a/cmd/api/src/api/v2/mocks/graphschemaextensions.go b/cmd/api/src/api/v2/mocks/graphschemaextensions.go index 96cb240811..21803d8b7b 100644 --- a/cmd/api/src/api/v2/mocks/graphschemaextensions.go +++ b/cmd/api/src/api/v2/mocks/graphschemaextensions.go @@ -57,16 +57,16 @@ func (m *MockOpenGraphSchemaService) EXPECT() *MockOpenGraphSchemaServiceMockRec return m.recorder } -// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. -func (m *MockOpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []v2.Environment) error { +// UpsertGraphSchemaExtension mocks base method. +func (m *MockOpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environments) + ret := m.ctrl.Call(m, "UpsertGraphSchemaExtension", ctx, req) ret0, _ := ret[0].(error) return ret0 } -// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. -func (mr *MockOpenGraphSchemaServiceMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environments any) *gomock.Call { +// UpsertGraphSchemaExtension indicates an expected call of UpsertGraphSchemaExtension. +func (mr *MockOpenGraphSchemaServiceMockRecorder) UpsertGraphSchemaExtension(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaService)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environments) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertGraphSchemaExtension", reflect.TypeOf((*MockOpenGraphSchemaService)(nil).UpsertGraphSchemaExtension), ctx, req) } diff --git a/cmd/api/src/config/config.go b/cmd/api/src/config/config.go index b928d38d3a..aa8a218aef 100644 --- a/cmd/api/src/config/config.go +++ b/cmd/api/src/config/config.go @@ -162,7 +162,7 @@ type Configuration struct { DisableCypherComplexityLimit bool `json:"disable_cypher_complexity_limit"` DisableIngest bool `json:"disable_ingest"` DisableMigrations bool `json:"disable_migrations"` - DisableTimeoutLimit bool `json:"disable_timeout_limit"` + DisableTimeoutLimit bool `json:"disable_timeout_limit"` GraphQueryMemoryLimit uint16 `json:"graph_query_memory_limit"` EnableTextLogger bool `json:"enable_text_logger"` RecreateDefaultAdmin bool `json:"recreate_default_admin"` diff --git a/cmd/api/src/config/default.go b/cmd/api/src/config/default.go index 656a308978..15efb8a6e0 100644 --- a/cmd/api/src/config/default.go +++ b/cmd/api/src/config/default.go @@ -65,7 +65,7 @@ func NewDefaultConfiguration() (Configuration, error) { DisableCypherComplexityLimit: false, DisableIngest: false, DisableMigrations: false, - DisableTimeoutLimit: false, + DisableTimeoutLimit: false, EnableCypherMutations: false, RecreateDefaultAdmin: false, ForceDownloadEmbeddedCollectors: false, diff --git a/cmd/api/src/services/opengraphschema/environment.go b/cmd/api/src/services/opengraphschema/environment.go index 5c45693808..662d429637 100644 --- a/cmd/api/src/services/opengraphschema/environment.go +++ b/cmd/api/src/services/opengraphschema/environment.go @@ -69,7 +69,7 @@ func (o *OpenGraphSchemaService) validateAndTranslateEnvironment(ctx context.Con func (o *OpenGraphSchemaService) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { if envKind, err := o.openGraphSchemaRepository.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, database.ErrNotFound) { return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) - } else if errors.Is(err, database.ErrNotFound){ + } else if errors.Is(err, database.ErrNotFound) { return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) } else { return envKind, nil @@ -106,7 +106,7 @@ func (o *OpenGraphSchemaService) validateAndTranslatePrincipalKinds(ctx context. for i, kindName := range principalKindNames { if kind, err := o.openGraphSchemaRepository.GetKindByName(ctx, kindName); err != nil && !errors.Is(err, database.ErrNotFound) { return nil, fmt.Errorf("error retrieving principal kind by name '%s': %w", kindName, err) - } else if errors.Is(err, database.ErrNotFound){ + } else if errors.Is(err, database.ErrNotFound) { return nil, fmt.Errorf("principal kind '%s' not found", kindName) } else { principalKinds[i] = model.SchemaEnvironmentPrincipalKind{ diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 22848aa05f..5c8cbeb456 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package opengraphschema import ( @@ -8,18 +23,18 @@ import ( ) func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - return o.transactor.WithTransaction(ctx, func(repo OpenGraphSchemaRepository) error { - txService := &OpenGraphSchemaService{ - openGraphSchemaRepository: repo, - transactor: o.transactor, - } + return o.transactor.WithTransaction(ctx, func(repo OpenGraphSchemaRepository) error { + txService := &OpenGraphSchemaService{ + openGraphSchemaRepository: repo, + transactor: o.transactor, + } // Upsert environments with principal kinds // TODO: Temporary hardcoded extension ID - if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { return fmt.Errorf("failed to upload environments with principal kinds: %w", err) } return nil - }) + }) } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index 8ff7dabfa8..a3793b1572 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -27,7 +27,6 @@ package mocks import ( context "context" - sql "database/sql" reflect "reflect" database "github.com/specterops/bloodhound/cmd/api/src/database" @@ -191,22 +190,3 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) } - -// Transaction mocks base method. -func (m *MockOpenGraphSchemaRepository) Transaction(ctx context.Context, fn func(*database.BloodhoundDB) error, opts ...*sql.TxOptions) error { - m.ctrl.T.Helper() - varargs := []any{ctx, fn} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Transaction", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// Transaction indicates an expected call of Transaction. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Transaction(ctx, fn any, opts ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, fn}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Transaction), varargs...) -} From 6ded4e314ca557d3507a0bc31dbee3ae6a0e2627 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 7 Jan 2026 15:03:34 -0600 Subject: [PATCH 06/36] abandoned the transactor --- cmd/api/src/database/db.go | 37 +++++++++++++++++++ .../src/services/opengraphschema/extension.go | 35 +++++++++++------- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 7f776d595f..8dd772f842 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -238,6 +238,43 @@ func (s *BloodhoundDB) Transaction(ctx context.Context, fn func(tx *BloodhoundDB }, opts...) } +/* Manual Transaction Control +The following methods provide manual control over transactions as an alternative to the automatic Transaction method above. +Use these when you need explicit control over when to commit or rollback. +- BeginTransaction +- Commit +- Rollback +*/ + +// BeginTransaction starts a new database transaction and returns a transactional-aware BloodhoundDB. +func (s *BloodhoundDB) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*BloodhoundDB, error) { + tx := s.db.WithContext(ctx).Begin(opts...) + if tx.Error != nil { + return nil, fmt.Errorf("error beginning transaction: %w", tx.Error) + } + + return &BloodhoundDB{ + db: tx, + idResolver: s.idResolver, + }, nil +} + +// Commit commits the transaction and releases the database connection back to the pool. +func (s *BloodhoundDB) Commit() error { + if err := s.db.Commit().Error; err != nil { + return fmt.Errorf("error committing transaction: %w", err) + } + return nil +} + +// Rollback rolls back the transaction and releases the database connection back to the pool. +func (s *BloodhoundDB) Rollback() error { + if err := s.db.Rollback().Error; err != nil { + return fmt.Errorf("error rolling back transaction: %w", err) + } + return nil +} + func OpenDatabase(connection string) (*gorm.DB, error) { gormConfig := &gorm.Config{ Logger: &GormLogAdapter{ diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 5c8cbeb456..e19ff8aaf3 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -20,21 +20,30 @@ import ( "fmt" v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/database" ) func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - return o.transactor.WithTransaction(ctx, func(repo OpenGraphSchemaRepository) error { - txService := &OpenGraphSchemaService{ - openGraphSchemaRepository: repo, - transactor: o.transactor, - } + db, ok := o.openGraphSchemaRepository.(*database.BloodhoundDB) + if !ok { + return fmt.Errorf("database not found: unable to begin transaction") + } - // Upsert environments with principal kinds - // TODO: Temporary hardcoded extension ID - if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { - return fmt.Errorf("failed to upload environments with principal kinds: %w", err) - } + tx, err := db.BeginTransaction(ctx) + if err != nil { + return err + } - return nil - }) + txService := &OpenGraphSchemaService{ + openGraphSchemaRepository: tx, + } + + // TODO: Temporary hardcoded extension ID + // Upsert environments with principal kinds + if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + tx.Rollback() + return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + } + + return tx.Commit() } From 6c082854ed18a26f9836ad99bd3a86fe37de3844 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 7 Jan 2026 15:07:13 -0600 Subject: [PATCH 07/36] abandoned the transactor --- cmd/api/src/services/opengraphschema/opengraphschema.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index a37535f639..cea2845ee9 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -47,11 +47,6 @@ type OpenGraphSchemaRepository interface { type OpenGraphSchemaService struct { openGraphSchemaRepository OpenGraphSchemaRepository - transactor Transactor -} - -type Transactor interface { - WithTransaction(ctx context.Context, fn func(repo OpenGraphSchemaRepository) error) error } func NewOpenGraphSchemaService(openGraphSchemaRepository OpenGraphSchemaRepository) *OpenGraphSchemaService { From 4ae5dfbc616b05505bed4d15455c4ded2a696b47 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 8 Jan 2026 11:16:29 -0600 Subject: [PATCH 08/36] entry pointed corrected so transactions can live on the interface again and all is well in the world --- .../src/services/opengraphschema/extension.go | 13 ++--- .../opengraphschema/mocks/opengraphschema.go | 49 +++++++++++++++++++ .../opengraphschema/opengraphschema.go | 6 +++ 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index e19ff8aaf3..aefcff32e3 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -20,27 +20,22 @@ import ( "fmt" v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/database" ) func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - db, ok := o.openGraphSchemaRepository.(*database.BloodhoundDB) - if !ok { - return fmt.Errorf("database not found: unable to begin transaction") - } - - tx, err := db.BeginTransaction(ctx) + tx, err := o.openGraphSchemaRepository.BeginTransaction(ctx) if err != nil { return err } - txService := &OpenGraphSchemaService{ + // Create service with transaction + transactionalService := &OpenGraphSchemaService{ openGraphSchemaRepository: tx, } // TODO: Temporary hardcoded extension ID // Upsert environments with principal kinds - if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + if err := transactionalService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { tx.Rollback() return fmt.Errorf("failed to upload environments with principal kinds: %w", err) } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index a3793b1572..e9f00e40dc 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -27,6 +27,7 @@ package mocks import ( context "context" + sql "database/sql" reflect "reflect" database "github.com/specterops/bloodhound/cmd/api/src/database" @@ -59,6 +60,40 @@ func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryM return m.recorder } +// BeginTransaction mocks base method. +func (m *MockOpenGraphSchemaRepository) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*database.BloodhoundDB, error) { + m.ctrl.T.Helper() + varargs := []any{ctx} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BeginTransaction", varargs...) + ret0, _ := ret[0].(*database.BloodhoundDB) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTransaction indicates an expected call of BeginTransaction. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) BeginTransaction(ctx any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTransaction", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).BeginTransaction), varargs...) +} + +// Commit mocks base method. +func (m *MockOpenGraphSchemaRepository) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Commit)) +} + // CreateSchemaEnvironment mocks base method. func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironment(ctx context.Context, schemaExtensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { m.ctrl.T.Helper() @@ -190,3 +225,17 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) } + +// Rollback mocks base method. +func (m *MockOpenGraphSchemaRepository) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Rollback)) +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index cea2845ee9..21e61bd088 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -19,6 +19,7 @@ package opengraphschema import ( "context" + "database/sql" "github.com/specterops/bloodhound/cmd/api/src/database" "github.com/specterops/bloodhound/cmd/api/src/model" @@ -27,6 +28,11 @@ import ( // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { + // TX + BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*database.BloodhoundDB, error) + Commit() error + Rollback() error + // Kinds GetKindByName(ctx context.Context, name string) (model.Kind, error) From a3d942284dff58a408ac14e4ef5c39dde6bb5fc0 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 8 Jan 2026 11:56:44 -0600 Subject: [PATCH 09/36] some cleanup --- cmd/api/src/api/v2/opengraphschema.go | 2 +- cmd/api/src/database/db.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index 6600045f60..a49248fdf0 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -56,5 +56,5 @@ func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request * return } - api.WriteBasicResponse(request.Context(), "Success", http.StatusCreated, response) + response.WriteHeader(http.StatusCreated) } diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 8dd772f842..c8ea2a713a 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -246,7 +246,7 @@ Use these when you need explicit control over when to commit or rollback. - Rollback */ -// BeginTransaction starts a new database transaction and returns a transactional-aware BloodhoundDB. +// BeginTransaction starts a new database transaction and returns a transactional-aware connection of BloodhoundDB. func (s *BloodhoundDB) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*BloodhoundDB, error) { tx := s.db.WithContext(ctx).Begin(opts...) if tx.Error != nil { From 5987407d881d2064e459fe0f4d9b9d2fe0489fa0 Mon Sep 17 00:00:00 2001 From: Conrad Weidenkeller Date: Mon, 29 Dec 2025 07:46:16 -0600 Subject: [PATCH 10/36] feat(graphschema): Add Environment principal kinds BED-7076 --- cmd/api/src/database/graphschema.go | 44 +++ .../database/graphschema_integration_test.go | 254 ++++++++++++++++++ .../database/migration/migrations/v8.5.0.sql | 1 + cmd/api/src/database/mocks/db.go | 44 +++ cmd/api/src/model/graphschema.go | 12 + 5 files changed, 355 insertions(+) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 85ed961b08..70613a9769 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -65,6 +65,9 @@ type OpenGraphSchema interface { GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) UpdateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) DeleteRemediation(ctx context.Context, findingId int32) error + CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) + GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) + DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error } const DuplicateKeyValueErrorString = "duplicate key value violates unique constraint" @@ -708,6 +711,47 @@ func (s *BloodhoundDB) DeleteRemediation(ctx context.Context, findingId int32) e return nil } +func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + var envPrincipalKind model.SchemaEnvironmentPrincipalKind + + if result := s.db.WithContext(ctx).Raw(` + INSERT INTO schema_environments_principal_kinds (environment_id, principal_kind, created_at) + VALUES (?, ?, NOW()) + RETURNING environment_id, principal_kind, created_at`, + environmentId, principalKind).Scan(&envPrincipalKind); result.Error != nil { + return model.SchemaEnvironmentPrincipalKind{}, CheckError(result) + } + + return envPrincipalKind, nil +} + +func (s *BloodhoundDB) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + var envPrincipalKinds model.SchemaEnvironmentPrincipalKinds + + if result := s.db.WithContext(ctx).Raw(` + SELECT environment_id, principal_kind, created_at + FROM schema_environments_principal_kinds + WHERE environment_id = ?`, + environmentId).Scan(&envPrincipalKinds); result.Error != nil { + return nil, CheckError(result) + } + + return envPrincipalKinds, nil +} + +func (s *BloodhoundDB) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error { + if result := s.db.WithContext(ctx).Exec(` + DELETE FROM schema_environments_principal_kinds + WHERE environment_id = ? AND principal_kind = ?`, + environmentId, principalKind); result.Error != nil { + return CheckError(result) + } else if result.RowsAffected == 0 { + return ErrNotFound + } + + return nil +} + func parseFiltersAndPagination(filters model.Filters, sort model.Sort, skip, limit int) (FilterAndPagination, error) { var ( filtersAndPagination FilterAndPagination diff --git a/cmd/api/src/database/graphschema_integration_test.go b/cmd/api/src/database/graphschema_integration_test.go index 6cf19ce078..a55588f26f 100644 --- a/cmd/api/src/database/graphschema_integration_test.go +++ b/cmd/api/src/database/graphschema_integration_test.go @@ -2162,3 +2162,257 @@ func TestDeleteRemediation(t *testing.T) { }) } } + +func TestCreateSchemaEnvironmentPrincipalKind(t *testing.T) { + type args struct { + environmentId int32 + principalKind int32 + } + type want struct { + res model.SchemaEnvironmentPrincipalKind + err error + } + tests := []struct { + name string + setup func() IntegrationTestSuite + args args + want want + }{ + { + name: "Success: schema environment principal kind created", + setup: func() IntegrationTestSuite { + t.Helper() + testSuite := setupIntegrationTestSuite(t) + + _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "EnvPrincipalKindExt", "Env Principal Kind Extension", "v1.0.0") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + require.NoError(t, err) + + return testSuite + }, + args: args{ + environmentId: 1, + principalKind: 1, + }, + want: want{ + res: model.SchemaEnvironmentPrincipalKind{ + EnvironmentId: 1, + PrincipalKind: 1, + }, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + result, err := testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) + if testCase.want.err != nil { + assert.ErrorIs(t, err, testCase.want.err) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.want.res.EnvironmentId, result.EnvironmentId) + assert.Equal(t, testCase.want.res.PrincipalKind, result.PrincipalKind) + } + }) + } +} + +func TestGetSchemaEnvironmentPrincipalKindsByEnvironmentId(t *testing.T) { + type args struct { + environmentId int32 + } + type want struct { + count int + err error + } + tests := []struct { + name string + setup func() IntegrationTestSuite + args args + want want + }{ + { + name: "Success: get schema environment principal kinds by environment id", + setup: func() IntegrationTestSuite { + t.Helper() + testSuite := setupIntegrationTestSuite(t) + + _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetEnvPrincipalKindExt", "Get Env Principal Kind Extension", "v1.0.0") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 1) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 2) + require.NoError(t, err) + + return testSuite + }, + args: args{ + environmentId: 1, + }, + want: want{ + count: 2, + }, + }, + { + name: "Success: returns empty slice when no principal kinds exist", + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + args: args{ + environmentId: 9999, + }, + want: want{ + count: 0, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + result, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) + if testCase.want.err != nil { + assert.ErrorIs(t, err, testCase.want.err) + } else { + assert.NoError(t, err) + assert.Len(t, result, testCase.want.count) + } + }) + } +} + +func TestDeleteSchemaEnvironmentPrincipalKind(t *testing.T) { + type args struct { + environmentId int32 + principalKind int32 + } + type want struct { + err error + } + tests := []struct { + name string + setup func() IntegrationTestSuite + args args + want want + }{ + { + name: "Success: delete schema environment principal kind", + setup: func() IntegrationTestSuite { + t.Helper() + testSuite := setupIntegrationTestSuite(t) + + _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "DeleteEnvPrincipalKindExt", "Delete Env Principal Kind Extension", "v1.0.0") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 1) + require.NoError(t, err) + + return testSuite + }, + args: args{ + environmentId: 1, + principalKind: 1, + }, + want: want{ + err: nil, + }, + }, + { + name: "Fail: schema environment principal kind not found", + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + args: args{ + environmentId: 9999, + principalKind: 9999, + }, + want: want{ + err: database.ErrNotFound, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + err := testSuite.BHDatabase.DeleteSchemaEnvironmentPrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) + if testCase.want.err != nil { + assert.ErrorIs(t, err, testCase.want.err) + } else { + assert.NoError(t, err) + result, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) + assert.NoError(t, err) + assert.Len(t, result, 0) + } + }) + } +} + +func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { + t.Parallel() + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extension, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "CascadeTestExtension", "Cascade Test Extension", "v1.0.0") + require.NoError(t, err) + + nodeKind, err := testSuite.BHDatabase.CreateGraphSchemaNodeKind(testSuite.Context, "CascadeTestNodeKind", extension.ID, "Cascade Test Node Kind", "Test description", false, "fa-test", "#000000") + require.NoError(t, err) + + property, err := testSuite.BHDatabase.CreateGraphSchemaProperty(testSuite.Context, extension.ID, "cascade_test_property", "Cascade Test Property", "string", "Test description") + require.NoError(t, err) + + edgeKind, err := testSuite.BHDatabase.CreateGraphSchemaEdgeKind(testSuite.Context, "CascadeTestEdgeKind", extension.ID, "Test description", true) + require.NoError(t, err) + + environment, err := testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, extension.ID, nodeKind.ID, nodeKind.ID) + require.NoError(t, err) + + relationshipFinding, err := testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, extension.ID, edgeKind.ID, environment.ID, "CascadeTestFinding", "Cascade Test Finding") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateRemediation(testSuite.Context, relationshipFinding.ID, "Short desc", "Long desc", "Short remediation", "Long remediation") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, environment.ID, nodeKind.ID) + require.NoError(t, err) + + err = testSuite.BHDatabase.DeleteGraphSchemaExtension(testSuite.Context, extension.ID) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.GetGraphSchemaNodeKindById(testSuite.Context, nodeKind.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetGraphSchemaPropertyById(testSuite.Context, property.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetGraphSchemaEdgeKindById(testSuite.Context, edgeKind.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, environment.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetSchemaRelationshipFindingById(testSuite.Context, relationshipFinding.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetRemediationByFindingId(testSuite.Context, relationshipFinding.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + principalKinds, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, environment.ID) + assert.NoError(t, err) + assert.Len(t, principalKinds, 0) +} diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 204da86fec..91e6636af6 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -148,6 +148,7 @@ CREATE INDEX IF NOT EXISTS idx_schema_remediations_content_type ON schema_remedi CREATE TABLE IF NOT EXISTS schema_environments_principal_kinds ( environment_id INTEGER NOT NULL REFERENCES schema_environments(id) ON DELETE CASCADE, principal_kind INTEGER NOT NULL REFERENCES kind(id), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, PRIMARY KEY(environment_id, principal_kind) ); diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 22f96083e9..cf16fece30 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -585,6 +585,21 @@ func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, en return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) } +// CreateSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockDatabase) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. +func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + // CreateSchemaRelationshipFinding mocks base method. func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() @@ -986,6 +1001,20 @@ func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) } +// DeleteSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockDatabase) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. +func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + // DeleteSchemaRelationshipFinding mocks base method. func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -2158,6 +2187,21 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) } +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. +func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. +func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) +} + // GetSchemaEnvironments mocks base method. func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index 38c1ac2867..397ba77afc 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -136,6 +136,18 @@ func (Remediation) TableName() string { return "schema_remediations" } +type SchemaEnvironmentPrincipalKinds []SchemaEnvironmentPrincipalKind + +type SchemaEnvironmentPrincipalKind struct { + EnvironmentId int32 `json:"environment_id"` + PrincipalKind int32 `json:"principal_kind"` + CreatedAt time.Time `json:"created_at"` +} + +func (SchemaEnvironmentPrincipalKind) TableName() string { + return "schema_environments_principal_kinds" +} + func (GraphSchemaEdgeKind) ValidFilters() map[string][]FilterOperator { return ValidFilters{ "is_traversable": {Equals, NotEquals}, From e76b52dd1518c97764c6e08a2c5247f3abfd5397 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:04:13 -0600 Subject: [PATCH 11/36] pretty decent sized refactor to move from service layer to database layer + integration tests --- cmd/api/src/database/db.go | 40 +- cmd/api/src/database/graphschema.go | 21 + cmd/api/src/database/kind.go | 14 +- .../database/migration/migrations/v8.5.0.sql | 2 +- cmd/api/src/database/mocks/db.go | 74 +++ cmd/api/src/database/sourcekinds.go | 10 +- .../src/database/upsert_schema_environment.go | 160 ++++++ ...ert_schema_environment_integration_test.go | 305 +++++++++++ .../services/opengraphschema/environment.go | 162 ------ .../opengraphschema/environment_test.go | 514 ------------------ .../src/services/opengraphschema/extension.go | 21 +- .../opengraphschema/mocks/opengraphschema.go | 182 +------ .../opengraphschema/opengraphschema.go | 27 +- .../opengraphschema/opengraphschema_test.go | 499 +++++++++++++++++ 14 files changed, 1095 insertions(+), 936 deletions(-) create mode 100644 cmd/api/src/database/upsert_schema_environment.go create mode 100644 cmd/api/src/database/upsert_schema_environment_integration_test.go delete mode 100644 cmd/api/src/services/opengraphschema/environment.go delete mode 100644 cmd/api/src/services/opengraphschema/environment_test.go create mode 100644 cmd/api/src/services/opengraphschema/opengraphschema_test.go diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index c8ea2a713a..acee7c7914 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -189,6 +189,9 @@ type Database interface { // OpenGraph Schema OpenGraphSchema + + // Kind + Kind } type BloodhoundDB struct { @@ -238,43 +241,6 @@ func (s *BloodhoundDB) Transaction(ctx context.Context, fn func(tx *BloodhoundDB }, opts...) } -/* Manual Transaction Control -The following methods provide manual control over transactions as an alternative to the automatic Transaction method above. -Use these when you need explicit control over when to commit or rollback. -- BeginTransaction -- Commit -- Rollback -*/ - -// BeginTransaction starts a new database transaction and returns a transactional-aware connection of BloodhoundDB. -func (s *BloodhoundDB) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*BloodhoundDB, error) { - tx := s.db.WithContext(ctx).Begin(opts...) - if tx.Error != nil { - return nil, fmt.Errorf("error beginning transaction: %w", tx.Error) - } - - return &BloodhoundDB{ - db: tx, - idResolver: s.idResolver, - }, nil -} - -// Commit commits the transaction and releases the database connection back to the pool. -func (s *BloodhoundDB) Commit() error { - if err := s.db.Commit().Error; err != nil { - return fmt.Errorf("error committing transaction: %w", err) - } - return nil -} - -// Rollback rolls back the transaction and releases the database connection back to the pool. -func (s *BloodhoundDB) Rollback() error { - if err := s.db.Rollback().Error; err != nil { - return fmt.Errorf("error rolling back transaction: %w", err) - } - return nil -} - func OpenDatabase(connection string) (*gorm.DB, error) { gormConfig := &gorm.Config{ Logger: &GormLogAdapter{ diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 3801cc5248..b7e81099be 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -52,6 +52,7 @@ type OpenGraphSchema interface { GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) + GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error @@ -59,6 +60,10 @@ type OpenGraphSchema interface { CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error + + CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) + GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) + DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error } const DuplicateKeyValueErrorString = "duplicate key value violates unique constraint" @@ -546,6 +551,22 @@ func (s *BloodhoundDB) GetSchemaEnvironmentById(ctx context.Context, environment return schemaEnvironment, nil } +// GetSchemaEnvironmentByKinds - retrieves an environment by its environment kind and source kind. +func (s *BloodhoundDB) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + var env model.SchemaEnvironment + + if result := s.db.WithContext(ctx).Raw( + "SELECT * FROM schema_environments WHERE environment_kind_id = ? AND source_kind_id = ? AND deleted_at IS NULL", + environmentKindId, sourceKindId, + ).Scan(&env); result.Error != nil { + return model.SchemaEnvironment{}, CheckError(result) + } else if result.RowsAffected == 0 { + return model.SchemaEnvironment{}, ErrNotFound + } + + return env, nil +} + // DeleteSchemaEnvironment - deletes a schema environment by id. func (s *BloodhoundDB) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error { var schemaEnvironment model.SchemaEnvironment diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go index 87632f6ae3..0487f28375 100644 --- a/cmd/api/src/database/kind.go +++ b/cmd/api/src/database/kind.go @@ -21,6 +21,10 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/model" ) +type Kind interface { + GetKindByName(ctx context.Context, name string) (model.Kind, error) +} + func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Kind, error) { const query = ` SELECT id, name @@ -29,8 +33,14 @@ func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Ki ` var kind model.Kind - if err := s.db.Raw(query, name).Scan(&kind).Error; err != nil { - return model.Kind{}, err + result := s.db.WithContext(ctx).Raw(query, name).Scan(&kind) + + if result.Error != nil { + return model.Kind{}, result.Error + } + + if result.RowsAffected == 0 || kind.ID == 0 { + return model.Kind{}, ErrNotFound } return kind, nil diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 204da86fec..e73d008e6e 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -93,7 +93,7 @@ CREATE TABLE IF NOT EXISTS schema_environments ( id SERIAL, schema_extension_id INTEGER NOT NULL REFERENCES schema_extensions(id) ON DELETE CASCADE, environment_kind_id INTEGER NOT NULL REFERENCES kind(id), - source_kind_id INTEGER NOT NULL REFERENCES kind(id), + source_kind_id INTEGER NOT NULL REFERENCES source_kinds(id), PRIMARY KEY (id), created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index ae5c1b5e1c..dd57b73085 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -570,6 +570,21 @@ func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, en return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) } +// CreateSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockDatabase) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. +func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + // CreateSchemaRelationshipFinding mocks base method. func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() @@ -957,6 +972,20 @@ func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) } +// DeleteSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockDatabase) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. +func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + // DeleteSchemaRelationshipFinding mocks base method. func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -1889,6 +1918,21 @@ func (mr *MockDatabaseMockRecorder) GetInstallation(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstallation", reflect.TypeOf((*MockDatabase)(nil).GetInstallation), ctx) } +// GetKindByName mocks base method. +func (m *MockDatabase) GetKindByName(ctx context.Context, name string) (model.Kind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKindByName", ctx, name) + ret0, _ := ret[0].(model.Kind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKindByName indicates an expected call of GetKindByName. +func (mr *MockDatabaseMockRecorder) GetKindByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockDatabase)(nil).GetKindByName), ctx, name) +} + // GetLatestAssetGroupCollection mocks base method. func (m *MockDatabase) GetLatestAssetGroupCollection(ctx context.Context, assetGroupID int32) (model.AssetGroupCollection, error) { m.ctrl.T.Helper() @@ -2114,6 +2158,36 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) } +// GetSchemaEnvironmentByKinds mocks base method. +func (m *MockDatabase) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentByKinds", ctx, environmentKindId, sourceKindId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentByKinds indicates an expected call of GetSchemaEnvironmentByKinds. +func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environmentKindId, sourceKindId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. +func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. +func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) +} + // GetSchemaEnvironments mocks base method. func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/sourcekinds.go b/cmd/api/src/database/sourcekinds.go index d69ed7dcd4..6aa76b8a24 100644 --- a/cmd/api/src/database/sourcekinds.go +++ b/cmd/api/src/database/sourcekinds.go @@ -108,8 +108,14 @@ func (s *BloodhoundDB) GetSourceKindByName(ctx context.Context, name string) (So } var raw rawSourceKind - if err := s.db.Raw(query, name).Scan(&raw).Error; err != nil { - return SourceKind{}, err + result := s.db.WithContext(ctx).Raw(query, name).Scan(&raw) + + if result.Error != nil { + return SourceKind{}, result.Error + } + + if result.RowsAffected == 0 || raw.ID == 0 { + return SourceKind{}, ErrNotFound } kind := SourceKind{ diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go new file mode 100644 index 0000000000..f19e39cd28 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -0,0 +1,160 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "context" + "errors" + "fmt" + + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/dawgs/graph" +) + +// UpsertSchemaEnvironmentWithPrincipalKinds creates or updates an environment with its principal kinds. +// If an environment with the same environment kind and source kind exists, it will be replaced. +// +// NOTE: This method should be called within a transaction. The caller is responsible for transaction management. +func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error { + environment := model.SchemaEnvironment{ + SchemaExtensionId: schemaExtensionId, + } + + envKind, err := s.validateAndTranslateEnvironmentKind(ctx, environmentKind) + if err != nil { + return err + } + + sourceKindID, err := s.validateAndTranslateSourceKind(ctx, sourceKind) + if err != nil { + return err + } + + translatedPrincipalKinds, err := s.validateAndTranslatePrincipalKinds(ctx, principalKinds) + if err != nil { + return err + } + + environment.EnvironmentKindId = int32(envKind.ID) + environment.SourceKindId = sourceKindID + + envID, err := s.upsertSchemaEnvironment(ctx, environment) + if err != nil { + return fmt.Errorf("error upserting schema environment: %w", err) + } + + if err := s.upsertPrincipalKinds(ctx, envID, translatedPrincipalKinds); err != nil { + return fmt.Errorf("error upserting principal kinds: %w", err) + } + + return nil +} + +// validateAndTranslateEnvironmentKind validates that the environment kind exists in the kinds table. +func (s *BloodhoundDB) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { + if envKind, err := s.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, ErrNotFound) { + return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) + } else if errors.Is(err, ErrNotFound) { + return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) + } else { + return envKind, nil + } +} + +// validateAndTranslateSourceKind validates that the source kind exists in the source_kinds table. +// If not found, it registers the source kind and returns its ID so it can be added to the Environment object. +func (s *BloodhoundDB) validateAndTranslateSourceKind(ctx context.Context, sourceKindName string) (int32, error) { + if sourceKind, err := s.GetSourceKindByName(ctx, sourceKindName); err != nil && !errors.Is(err, ErrNotFound) { + return 0, fmt.Errorf("error retrieving source kind '%s': %w", sourceKindName, err) + } else if err == nil { + return int32(sourceKind.ID), nil + } + + // If source kind is not found, register it. If it exists and is inactive, it will automatically update as active. + kindType := graph.StringKind(sourceKindName) + if err := s.RegisterSourceKind(ctx)(kindType); err != nil { + return 0, fmt.Errorf("error registering source kind '%s': %w", sourceKindName, err) + } + + if sourceKind, err := s.GetSourceKindByName(ctx, sourceKindName); err != nil { + return 0, fmt.Errorf("error retrieving newly registered source kind '%s': %w", sourceKindName, err) + } else { + return int32(sourceKind.ID), nil + } +} + +// validateAndTranslatePrincipalKinds ensures all principalKinds exist in the kinds table. +// It also translates them to IDs so they can be upserted into the database. +func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, principalKindNames []string) ([]model.SchemaEnvironmentPrincipalKind, error) { + principalKinds := make([]model.SchemaEnvironmentPrincipalKind, len(principalKindNames)) + + for i, kindName := range principalKindNames { + if kind, err := s.GetKindByName(ctx, kindName); err != nil && !errors.Is(err, ErrNotFound) { + return nil, fmt.Errorf("error retrieving principal kind by name '%s': %w", kindName, err) + } else if errors.Is(err, ErrNotFound) { + return nil, fmt.Errorf("principal kind '%s' not found", kindName) + } else { + principalKinds[i] = model.SchemaEnvironmentPrincipalKind{ + PrincipalKind: int32(kind.ID), + } + } + } + + return principalKinds, nil +} + +// upsertSchemaEnvironment creates or updates a schema environment. +// If an environment with the given ID exists, it deletes it first before creating the new one. +func (s *BloodhoundDB) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { + if existing, err := s.GetSchemaEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { + return 0, fmt.Errorf("error retrieving schema environment: %w", err) + } else if !errors.Is(err, ErrNotFound) { + // Environment exists - delete it first + if err := s.DeleteSchemaEnvironment(ctx, existing.ID); err != nil { + return 0, fmt.Errorf("error deleting schema environment %d: %w", existing.ID, err) + } + } + + // Create Environment + if created, err := s.CreateSchemaEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { + return 0, fmt.Errorf("error creating schema environment: %w", err) + } else { + return created.ID, nil + } +} + +// upsertPrincipalKinds deletes all existing principal kinds for an environment and creates new ones. +func (s *BloodhoundDB) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { + if existingKinds, err := s.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, ErrNotFound) { + return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) + } else if !errors.Is(err, ErrNotFound) { + // Delete all existing principal kinds + for _, kind := range existingKinds { + if err := s.DeleteSchemaEnvironmentPrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { + return fmt.Errorf("error deleting principal kind %d for environment %d: %w", kind.PrincipalKind, kind.EnvironmentId, err) + } + } + } + + // Create the new principal kinds + for _, kind := range principalKinds { + if _, err := s.CreateSchemaEnvironmentPrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { + return fmt.Errorf("error creating principal kind %d for environment %d: %w", kind.PrincipalKind, environmentID, err) + } + } + + return nil +} diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go new file mode 100644 index 0000000000..2ba131cd3c --- /dev/null +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -0,0 +1,305 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { + type args struct { + environmentKind string + sourceKind string + principalKinds []string + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 + args args + assert func(t *testing.T, db *database.BloodhoundDB) + expectedError string + }{ + { + name: "Success: Create new environment with principal kinds", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "Base", + principalKinds: []string{"Tag_Tier_Zero", "Tag_Owned"}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments)) + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) + + expectedKindIDs := make([]int32, len(expectedPrincipalKindNames)) + for i, name := range expectedPrincipalKindNames { + kind, err := db.GetKindByName(context.Background(), name) + assert.NoError(t, err) + expectedKindIDs[i] = int32(kind.ID) + } + + actualKindIDs := make([]int32, len(principalKinds)) + for i, pk := range principalKinds { + assert.Equal(t, environments[0].ID, pk.EnvironmentId) + actualKindIDs[i] = pk.PrincipalKind + } + + assert.ElementsMatch(t, expectedKindIDs, actualKindIDs) + }, + }, + { + name: "Success: Upsert replaces existing environment", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + err = db.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + return tx.UpsertSchemaEnvironmentWithPrincipalKinds( + context.Background(), + ext.ID, + "Tag_Tier_Zero", + "Base", + []string{"Tag_Owned"}, + ) + }) + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "Base", + principalKinds: []string{"Tag_Tier_Zero"}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments), "Should only have one environment (old one deleted)") + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds)) + + expectedKind, err := db.GetKindByName(context.Background(), expectedPrincipalKindNames[0]) + assert.NoError(t, err) + + assert.Equal(t, int32(expectedKind.ID), principalKinds[0].PrincipalKind) + assert.Equal(t, environments[0].ID, principalKinds[0].EnvironmentId) + }, + }, + { + name: "Success: Source kind auto-registers", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "NewSource", + principalKinds: []string{"Tag_Tier_Zero"}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + sourceKind, err := db.GetSourceKindByName(context.Background(), "NewSource") + assert.NoError(t, err) + assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments)) + assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds)) + }, + }, + { + name: "Error: Environment kind not found", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "NonExistent", + sourceKind: "Base", + principalKinds: []string{}, + }, + expectedError: "environment kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Error: Principal kind not found", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "Base", + principalKinds: []string{"NonExistent"}, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Rollback: Partial failure on second principal kind", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "Base", + principalKinds: []string{"Tag_Owned", "NonExistent"}, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Success: Multiple environments with different combinations", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + err = db.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + return tx.UpsertSchemaEnvironmentWithPrincipalKinds( + context.Background(), + ext.ID, + "Tag_Tier_Zero", + "Base", + []string{"Tag_Tier_Zero"}, + ) + }) + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Owned", + sourceKind: "Base", + principalKinds: []string{"Tag_Owned"}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 2, len(environments), "Should have two different environments") + + for _, env := range environments { + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), env.ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extensionID := tt.setupData(t, testSuite.BHDatabase) + + // Wrap the call in a transaction + err := testSuite.BHDatabase.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + return tx.UpsertSchemaEnvironmentWithPrincipalKinds( + context.Background(), + extensionID, + tt.args.environmentKind, + tt.args.sourceKind, + tt.args.principalKinds, + ) + }) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase) + } + } else { + require.NoError(t, err) + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase) + } + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/environment.go b/cmd/api/src/services/opengraphschema/environment.go deleted file mode 100644 index 662d429637..0000000000 --- a/cmd/api/src/services/opengraphschema/environment.go +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2026 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 -package opengraphschema - -import ( - "context" - "errors" - "fmt" - - v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/database" - "github.com/specterops/bloodhound/cmd/api/src/model" - "github.com/specterops/dawgs/graph" -) - -// UpsertSchemaEnvironmentWithPrincipalKinds takes a slice of environments, validates and translates each environment. -// The translation is used to upsert the environments into the database. -// If an existing environment is found to already exist in the database, the existing environment will be removed and the new one will be uploaded. -func (o *OpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []v2.Environment) error { - for _, env := range environments { - environment := model.SchemaEnvironment{ - SchemaExtensionId: schemaExtensionId, - } - - if updatedEnv, principalKinds, err := o.validateAndTranslateEnvironment(ctx, environment, env); err != nil { - return fmt.Errorf("error validating and translating environment: %w", err) - } else if envID, err := o.upsertSchemaEnvironment(ctx, updatedEnv); err != nil { - return fmt.Errorf("error upserting schema environment: %w", err) - } else if err := o.upsertPrincipalKinds(ctx, envID, principalKinds); err != nil { - return fmt.Errorf("error upserting principal kinds: %w", err) - } - } - - return nil -} - -// validateAndTranslateEnvironment validates that the environment kind, source kind, and the principal kinds exist in the database. -// It is then translated from the API model to the Database model to prepare it for insert. -func (o *OpenGraphSchemaService) validateAndTranslateEnvironment(ctx context.Context, environment model.SchemaEnvironment, env v2.Environment) (model.SchemaEnvironment, []model.SchemaEnvironmentPrincipalKind, error) { - if envKind, err := o.validateAndTranslateEnvironmentKind(ctx, env.EnvironmentKind); err != nil { - return model.SchemaEnvironment{}, nil, err - } else if sourceKindID, err := o.validateAndTranslateSourceKind(ctx, env.SourceKind); err != nil { - return model.SchemaEnvironment{}, nil, err - } else if principalKinds, err := o.validateAndTranslatePrincipalKinds(ctx, env.PrincipalKinds); err != nil { - return model.SchemaEnvironment{}, nil, err - } else { - // Update environment with translated IDs - environment.EnvironmentKindId = int32(envKind.ID) - environment.SourceKindId = sourceKindID - - return environment, principalKinds, nil - } -} - -// validateAndTranslateEnvironmentKind validates that the environment kind exists in the kinds table. -func (o *OpenGraphSchemaService) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { - if envKind, err := o.openGraphSchemaRepository.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, database.ErrNotFound) { - return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) - } else if errors.Is(err, database.ErrNotFound) { - return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) - } else { - return envKind, nil - } -} - -// validateAndTranslateSourceKind validates that the source kind exists in the source_kinds table. -// If not found, it registers the source kind and returns its ID so it can be added to the Environment object. -func (o *OpenGraphSchemaService) validateAndTranslateSourceKind(ctx context.Context, sourceKindName string) (int32, error) { - if sourceKind, err := o.openGraphSchemaRepository.GetSourceKindByName(ctx, sourceKindName); err != nil && !errors.Is(err, database.ErrNotFound) { - return 0, fmt.Errorf("error retrieving source kind '%s': %w", sourceKindName, err) - } else if err == nil { - return int32(sourceKind.ID), nil - } - - // If source kind is not found, register it. If it exists and is inactive, it will automatically update as active. - kindType := graph.StringKind(sourceKindName) - if err := o.openGraphSchemaRepository.RegisterSourceKind(ctx)(kindType); err != nil { - return 0, fmt.Errorf("error registering source kind '%s': %w", sourceKindName, err) - } - - if sourceKind, err := o.openGraphSchemaRepository.GetSourceKindByName(ctx, sourceKindName); err != nil { - return 0, fmt.Errorf("error retrieving newly registered source kind '%s': %w", sourceKindName, err) - } else { - return int32(sourceKind.ID), nil - } -} - -// validateAndTranslatePrincipalKinds ensures all principalKinds exist in the kinds table. -// It also translates them to IDs so they can be upserted into the database. -func (o *OpenGraphSchemaService) validateAndTranslatePrincipalKinds(ctx context.Context, principalKindNames []string) ([]model.SchemaEnvironmentPrincipalKind, error) { - principalKinds := make([]model.SchemaEnvironmentPrincipalKind, len(principalKindNames)) - - for i, kindName := range principalKindNames { - if kind, err := o.openGraphSchemaRepository.GetKindByName(ctx, kindName); err != nil && !errors.Is(err, database.ErrNotFound) { - return nil, fmt.Errorf("error retrieving principal kind by name '%s': %w", kindName, err) - } else if errors.Is(err, database.ErrNotFound) { - return nil, fmt.Errorf("principal kind '%s' not found", kindName) - } else { - principalKinds[i] = model.SchemaEnvironmentPrincipalKind{ - PrincipalKind: int32(kind.ID), - } - } - } - - return principalKinds, nil -} - -// upsertSchemaEnvironment creates or updates a schema environment. -// If an environment with the given ID exists, it deletes it first before creating the new one. -func (o *OpenGraphSchemaService) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { - if existing, err := o.openGraphSchemaRepository.GetSchemaEnvironmentById(ctx, graphSchema.ID); err != nil && !errors.Is(err, database.ErrNotFound) { - return 0, fmt.Errorf("error retrieving schema environment id %d: %w", graphSchema.ID, err) - } else if !errors.Is(err, database.ErrNotFound) { - // Environment exists - delete it first - if err := o.openGraphSchemaRepository.DeleteSchemaEnvironment(ctx, existing.ID); err != nil { - return 0, fmt.Errorf("error deleting schema environment %d: %w", existing.ID, err) - } - } - - // Create Environment - if created, err := o.openGraphSchemaRepository.CreateSchemaEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { - return 0, fmt.Errorf("error creating schema environment: %w", err) - } else { - return created.ID, nil - } -} - -// upsertPrincipalKinds deletes all existing principal kinds for an environment and creates new ones. -func (o *OpenGraphSchemaService) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { - if existingKinds, err := o.openGraphSchemaRepository.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, database.ErrNotFound) { - return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) - } else if !errors.Is(err, database.ErrNotFound) { - // Delete all existing principal kinds - for _, kind := range existingKinds { - if err := o.openGraphSchemaRepository.DeleteSchemaEnvironmentPrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { - return fmt.Errorf("error deleting principal kind %d for environment %d: %w", kind.PrincipalKind, kind.EnvironmentId, err) - } - } - } - - // Create the new principal kinds - for _, kind := range principalKinds { - if _, err := o.openGraphSchemaRepository.CreateSchemaEnvironmentPrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { - return fmt.Errorf("error creating principal kind %d for environment %d: %w", kind.PrincipalKind, environmentID, err) - } - } - - return nil -} diff --git a/cmd/api/src/services/opengraphschema/environment_test.go b/cmd/api/src/services/opengraphschema/environment_test.go deleted file mode 100644 index e50362a423..0000000000 --- a/cmd/api/src/services/opengraphschema/environment_test.go +++ /dev/null @@ -1,514 +0,0 @@ -// Copyright 2026 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 -package opengraphschema_test - -import ( - "context" - "errors" - "testing" - - v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/database" - "github.com/specterops/bloodhound/cmd/api/src/model" - "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" - schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" - "github.com/specterops/dawgs/graph" - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" -) - -func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { - type mocks struct { - mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository - } - type args struct { - schemaExtensionId int32 - environments []v2.Environment - } - tests := []struct { - name string - mocks mocks - setupMocks func(t *testing.T, mock *mocks) - args args - expected error - }{ - // Validation: Environment Kind - { - name: "Error: openGraphSchemaRepository.GetKindByName environment kind name not found in the database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{}, database.ErrNotFound) - }, - expected: errors.New("error validating and translating environment: environment kind 'Domain' not found"), - }, - { - name: "Error: openGraphSchemaRepository.GetKindByName failed to retrieve environment kind from database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{}, errors.New("error")) - }, - expected: errors.New("error validating and translating environment: error retrieving environment kind 'Domain': error"), - }, - // Validation: Source Kind - { - name: "Error: validateAndTranslateSourceKind failed to retrieve source kind from database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, errors.New("error")) - }, - expected: errors.New("error validating and translating environment: error retrieving source kind 'Base': error"), - }, - { - name: "Error: validateAndTranslateSourceKind source kind name doesn't exist in database, registration fails", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { - return errors.New("error") - }) - }, - expected: errors.New("error validating and translating environment: error registering source kind 'Base': error"), - }, - { - name: "Error: validateAndTranslateSourceKind source kind name doesn't exist in database, registration succeeds but fetch fails", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { - return nil - }) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, errors.New("error")) - }, - expected: errors.New("error validating and translating environment: error retrieving newly registered source kind 'Base': error"), - }, - // Validation: Principal Kind - { - name: "Error: validateAndTranslatePrincipalKinds principal kind not found in database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 2}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "InvalidKind").Return(model.Kind{}, database.ErrNotFound) - }, - expected: errors.New("error validating and translating environment: principal kind 'InvalidKind' not found"), - }, - { - name: "Error: validateAndTranslatePrincipalKinds failed to retrieve principal kind from database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 2}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "InvalidKind").Return(model.Kind{}, errors.New("error")) - }, - expected: errors.New("error validating and translating environment: error retrieving principal kind by name 'InvalidKind': error"), - }, - // Upsert Schema Environment - { - name: "Error: upsertSchemaEnvironment error retrieving schema environment from database", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, errors.New("error")) - }, - expected: errors.New("error upserting schema environment: error retrieving schema environment id 0: error"), - }, - { - name: "Error: upsertSchemaEnvironment error deleting schema environment", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 5}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(errors.New("error")) - }, - expected: errors.New("error upserting schema environment: error deleting schema environment 5: error"), - }, - { - name: "Error: upsertSchemaEnvironment error creating schema environment after deletion", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 5}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) - }, - expected: errors.New("error upserting schema environment: error creating schema environment: error"), - }, - { - name: "Error: upsertSchemaEnvironment error creating new schema environment", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) - }, - expected: errors.New("error upserting schema environment: error creating schema environment: error"), - }, - // Upsert Principal Kinds - { - name: "Error: upsertPrincipalKinds error getting principal kinds by environment id", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, errors.New("error")) - }, - expected: errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), - }, - { - name: "Error: upsertPrincipalKinds error deleting principal kinds", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return([]model.SchemaEnvironmentPrincipalKind{ - { - EnvironmentId: int32(10), - PrincipalKind: int32(5), - }, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(5)).Return(errors.New("error")) - }, - expected: errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), - }, - { - name: "Error: upsertPrincipalKinds error creating principal kinds", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, errors.New("error")) - }, - expected: errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error"), - }, - { - name: "Success: Create new environment with principal kinds", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "Computer"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Computer").Return(model.Kind{ID: 5}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(5)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - }, - expected: nil, - }, - { - name: "Success: Create environment with source kind registration", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "NewSource", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - // Source kind not found, register it - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "NewSource").Return(database.SourceKind{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { - return nil - }) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "NewSource").Return(database.SourceKind{ID: 10}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(10)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - }, - expected: nil, - }, - { - name: "Success: Process multiple environments", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - { - EnvironmentKind: "AzureAD", - SourceKind: "AzureHound", - PrincipalKinds: []string{"User", "Group"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // First environment - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - - // Second environment - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "AzureAD").Return(model.Kind{ID: 5}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "AzureHound").Return(database.SourceKind{ID: 6}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Group").Return(model.Kind{ID: 7}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(5), int32(6)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 11}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(11)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(11), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(11), int32(7)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - }, - expected: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - mocks := &mocks{ - mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), - } - - tt.setupMocks(t, mocks) - - graphService := opengraphschema.NewOpenGraphSchemaService(mocks.mockOpenGraphSchema) - - err := graphService.UpsertSchemaEnvironmentWithPrincipalKinds(context.Background(), tt.args.schemaExtensionId, tt.args.environments) - if tt.expected != nil { - assert.EqualError(t, err, tt.expected.Error()) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index aefcff32e3..76e960f3b4 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -23,22 +23,11 @@ import ( ) func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - tx, err := o.openGraphSchemaRepository.BeginTransaction(ctx) - if err != nil { - return err + for _, env := range req.Environments { + if err := o.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { + return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + } } - // Create service with transaction - transactionalService := &OpenGraphSchemaService{ - openGraphSchemaRepository: tx, - } - - // TODO: Temporary hardcoded extension ID - // Upsert environments with principal kinds - if err := transactionalService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { - tx.Rollback() - return fmt.Errorf("failed to upload environments with principal kinds: %w", err) - } - - return tx.Commit() + return nil } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index e9f00e40dc..9c3e712987 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -27,12 +27,8 @@ package mocks import ( context "context" - sql "database/sql" reflect "reflect" - database "github.com/specterops/bloodhound/cmd/api/src/database" - model "github.com/specterops/bloodhound/cmd/api/src/model" - graph "github.com/specterops/dawgs/graph" gomock "go.uber.org/mock/gomock" ) @@ -60,182 +56,16 @@ func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryM return m.recorder } -// BeginTransaction mocks base method. -func (m *MockOpenGraphSchemaRepository) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*database.BloodhoundDB, error) { +// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. +func (m *MockOpenGraphSchemaRepository) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind, sourceKind string, principalKinds []string) error { m.ctrl.T.Helper() - varargs := []any{ctx} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "BeginTransaction", varargs...) - ret0, _ := ret[0].(*database.BloodhoundDB) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// BeginTransaction indicates an expected call of BeginTransaction. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) BeginTransaction(ctx any, opts ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTransaction", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).BeginTransaction), varargs...) -} - -// Commit mocks base method. -func (m *MockOpenGraphSchemaRepository) Commit() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") - ret0, _ := ret[0].(error) - return ret0 -} - -// Commit indicates an expected call of Commit. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Commit() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Commit)) -} - -// CreateSchemaEnvironment mocks base method. -func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironment(ctx context.Context, schemaExtensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironment", ctx, schemaExtensionId, environmentKindId, sourceKindId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironment indicates an expected call of CreateSchemaEnvironment. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) CreateSchemaEnvironment(ctx, schemaExtensionId, environmentKindId, sourceKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).CreateSchemaEnvironment), ctx, schemaExtensionId, environmentKindId, sourceKindId) -} - -// CreateSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - -// DeleteSchemaEnvironment mocks base method. -func (m *MockOpenGraphSchemaRepository) DeleteSchemaEnvironment(ctx context.Context, schemaEnvironmentId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironment", ctx, schemaEnvironmentId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironment indicates an expected call of DeleteSchemaEnvironment. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironment(ctx, schemaEnvironmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironment), ctx, schemaEnvironmentId) -} - -// DeleteSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockOpenGraphSchemaRepository) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - -// GetKindByName mocks base method. -func (m *MockOpenGraphSchemaRepository) GetKindByName(ctx context.Context, name string) (model.Kind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetKindByName", ctx, name) - ret0, _ := ret[0].(model.Kind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetKindByName indicates an expected call of GetKindByName. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetKindByName(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetKindByName), ctx, name) -} - -// GetSchemaEnvironmentById mocks base method. -func (m *MockOpenGraphSchemaRepository) GetSchemaEnvironmentById(ctx context.Context, schemaEnvironmentId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentById", ctx, schemaEnvironmentId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentById indicates an expected call of GetSchemaEnvironmentById. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentById(ctx, schemaEnvironmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentById), ctx, schemaEnvironmentId) -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. -func (m *MockOpenGraphSchemaRepository) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) -} - -// GetSourceKindByName mocks base method. -func (m *MockOpenGraphSchemaRepository) GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSourceKindByName", ctx, name) - ret0, _ := ret[0].(database.SourceKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSourceKindByName indicates an expected call of GetSourceKindByName. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSourceKindByName(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindByName", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSourceKindByName), ctx, name) -} - -// RegisterSourceKind mocks base method. -func (m *MockOpenGraphSchemaRepository) RegisterSourceKind(ctx context.Context) func(graph.Kind) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RegisterSourceKind", ctx) - ret0, _ := ret[0].(func(graph.Kind) error) - return ret0 -} - -// RegisterSourceKind indicates an expected call of RegisterSourceKind. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) -} - -// Rollback mocks base method. -func (m *MockOpenGraphSchemaRepository) Rollback() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Rollback") + ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) ret0, _ := ret[0].(error) return ret0 } -// Rollback indicates an expected call of Rollback. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Rollback() *gomock.Call { +// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Rollback)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) } diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index 21e61bd088..6de81562dc 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -19,36 +19,11 @@ package opengraphschema import ( "context" - "database/sql" - - "github.com/specterops/bloodhound/cmd/api/src/database" - "github.com/specterops/bloodhound/cmd/api/src/model" - "github.com/specterops/dawgs/graph" ) // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { - // TX - BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*database.BloodhoundDB, error) - Commit() error - Rollback() error - - // Kinds - GetKindByName(ctx context.Context, name string) (model.Kind, error) - - // Environment - CreateSchemaEnvironment(ctx context.Context, schemaExtensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironmentById(ctx context.Context, schemaEnvironmentId int32) (model.SchemaEnvironment, error) - DeleteSchemaEnvironment(ctx context.Context, schemaEnvironmentId int32) error - - // Source Kinds - RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error - GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) - - // Principal Kinds - CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) - GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) - DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error + UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error } type OpenGraphSchemaService struct { diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go new file mode 100644 index 0000000000..5b2ac311ad --- /dev/null +++ b/cmd/api/src/services/opengraphschema/opengraphschema_test.go @@ -0,0 +1,499 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema_test + +import ( + "context" + "errors" + "testing" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" + schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { + type mocks struct { + mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository + } + type args struct { + schemaExtensionId int32 + environments []v2.Environment + } + tests := []struct { + name string + setupMocks func(t *testing.T, m *mocks) + args args + expected error + }{ + // UpsertSchemaEnvironmentWithPrincipalKinds + // Validation: Environment Kind + { + name: "Error: environment kind name not found in the database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("environment kind 'Domain' not found")) + }, + expected: errors.New("failed to upload environments with principal kinds: environment kind 'Domain' not found"), + }, + { + name: "Error: failed to retrieve environment kind from database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error retrieving environment kind 'Domain': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error retrieving environment kind 'Domain': error"), + }, + // Validation: Source Kind + { + name: "Error: failed to retrieve source kind from database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error retrieving source kind 'Base': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error retrieving source kind 'Base': error"), + }, + { + name: "Error: source kind name doesn't exist in database, registration fails", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error registering source kind 'Base': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error registering source kind 'Base': error"), + }, + { + name: "Error: source kind name doesn't exist in database, registration succeeds but fetch fails", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error retrieving newly registered source kind 'Base': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error retrieving newly registered source kind 'Base': error"), + }, + // Validation: Principal Kind + { + name: "Error: principal kind not found in database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "InvalidKind"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User", "InvalidKind"}, + ).Return(errors.New("principal kind 'InvalidKind' not found")) + }, + expected: errors.New("failed to upload environments with principal kinds: principal kind 'InvalidKind' not found"), + }, + { + name: "Error: failed to retrieve principal kind from database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "InvalidKind"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User", "InvalidKind"}, + ).Return(errors.New("error retrieving principal kind by name 'InvalidKind': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error retrieving principal kind by name 'InvalidKind': error"), + }, + // Upsert Schema Environment + { + name: "Error: error retrieving schema environment from database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error upserting schema environment: error retrieving schema environment id 0: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error retrieving schema environment id 0: error"), + }, + { + name: "Error: error deleting schema environment", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error upserting schema environment: error deleting schema environment 5: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error deleting schema environment 5: error"), + }, + { + name: "Error: error creating schema environment after deletion", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), + }, + { + name: "Error: error creating new schema environment", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), + }, + // Upsert Principal Kinds + { + name: "Error: error getting principal kinds by environment id", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User"}, + ).Return(errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), + }, + { + name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error deleting principal kinds", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User"}, + ).Return(errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), + }, + { + name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error creating principal kinds", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User"}, + ).Return(errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error creating principal kind 3 for environment 10: error"), + }, + { + name: "Success: Create new environment with principal kinds", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "Computer"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User", "Computer"}, + ).Return(nil) + }, + expected: nil, + }, + { + name: "Success: Create environment with source kind registration", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "NewSource", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "NewSource", + []string{}, + ).Return(nil) + }, + expected: nil, + }, + { + name: "Success: Process multiple environments", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKind: "AzureAD", + SourceKind: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + // First environment + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User"}, + ).Return(nil) + + // Second environment + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "AzureAD", + "AzureHound", + []string{"User", "Group"}, + ).Return(nil) + }, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + m := &mocks{ + mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), + } + + tt.setupMocks(t, m) + + service := opengraphschema.NewOpenGraphSchemaService(m.mockOpenGraphSchema) + + err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ + Environments: tt.args.environments, + }) + + if tt.expected != nil { + assert.EqualError(t, err, tt.expected.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} From 38186bed4c70611d6a277ac85b26b8330329df22 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:31:19 -0600 Subject: [PATCH 12/36] pulled in Conrad's changes --- cmd/api/src/database/graphschema.go | 17 +++++++++++++++++ cmd/api/src/database/mocks/db.go | 3 --- cmd/api/src/model/graphschema.go | 15 +-------------- cmd/ui/src/rendering/utils/dagre/dagre.test.ts | 2 +- cmd/ui/src/rendering/utils/dagre/dagre.ts | 3 ++- .../PrivilegeZones/Details/ObjectCountPanel.tsx | 2 +- 6 files changed, 22 insertions(+), 20 deletions(-) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 70613a9769..26173a3224 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -53,6 +53,7 @@ type OpenGraphSchema interface { GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) + GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error @@ -538,6 +539,22 @@ func (s *BloodhoundDB) GetSchemaEnvironments(ctx context.Context) ([]model.Schem return result, CheckError(s.db.WithContext(ctx).Find(&result)) } +// GetSchemaEnvironmentByKinds - retrieves an environment by its environment kind and source kind. +func (s *BloodhoundDB) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + var env model.SchemaEnvironment + + if result := s.db.WithContext(ctx).Raw( + "SELECT * FROM schema_environments WHERE environment_kind_id = ? AND source_kind_id = ? AND deleted_at IS NULL", + environmentKindId, sourceKindId, + ).Scan(&env); result.Error != nil { + return model.SchemaEnvironment{}, CheckError(result) + } else if result.RowsAffected == 0 { + return model.SchemaEnvironment{}, ErrNotFound + } + + return env, nil +} + // GetSchemaEnvironmentById - retrieves a schema environment by id. func (s *BloodhoundDB) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { var schemaEnvironment model.SchemaEnvironment diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 4c6f03c883..b911c1ad42 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2202,7 +2202,6 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) } -<<<<<<< HEAD // GetSchemaEnvironmentByKinds mocks base method. func (m *MockDatabase) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { m.ctrl.T.Helper() @@ -2218,8 +2217,6 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environment return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) } -======= ->>>>>>> origin/BED-7076 // GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index 44985d3dcc..397ba77afc 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -16,9 +16,7 @@ package model -import ( - "time" -) +import "time" type GraphSchemaExtensions []GraphSchemaExtension @@ -170,14 +168,3 @@ type GraphSchemaEdgeKindWithNamedSchema struct { } type GraphSchemaEdgeKindsWithNamedSchema []GraphSchemaEdgeKindWithNamedSchema - -type SchemaEnvironmentPrincipalKinds []SchemaEnvironmentPrincipalKind - -type SchemaEnvironmentPrincipalKind struct { - EnvironmentId int32 `json:"environment_id"` - PrincipalKind int32 `json:"principal_kind"` -} - -func (SchemaEnvironmentPrincipalKind) TableName() string { - return "schema_environments_principal_kinds" -} diff --git a/cmd/ui/src/rendering/utils/dagre/dagre.test.ts b/cmd/ui/src/rendering/utils/dagre/dagre.test.ts index 0a6ab82b09..c80538ddf3 100644 --- a/cmd/ui/src/rendering/utils/dagre/dagre.test.ts +++ b/cmd/ui/src/rendering/utils/dagre/dagre.test.ts @@ -16,9 +16,9 @@ import dagre from '@dagrejs/dagre'; import Graph from 'graphology'; +import { getEdgeDataFromKey } from 'src/ducks/graph/utils'; import { NODE_DEFAULT_SIZE, applyNodePositionsFromGraphlibGraph, copySigmaNodesToGraphlibGraph } from './'; import { copySigmaEdgesToGraphlibGraph } from './dagre'; -import { getEdgeDataFromKey } from 'src/ducks/graph/utils'; const sigmaGraph = new Graph(); const graphlibGraph = new dagre.graphlib.Graph(); diff --git a/cmd/ui/src/rendering/utils/dagre/dagre.ts b/cmd/ui/src/rendering/utils/dagre/dagre.ts index c62ac36ca0..4a519fa99a 100644 --- a/cmd/ui/src/rendering/utils/dagre/dagre.ts +++ b/cmd/ui/src/rendering/utils/dagre/dagre.ts @@ -98,7 +98,8 @@ export const copySigmaEdgesToGraphlibGraph = ( ): void => { sigmaGraph.forEachEdge((edge: string) => { const edgeData = getEdgeDataFromKey(edge); - if (edgeData !== null) graphlibGraph.setEdge(edgeData.source, edgeData.target, { label: edgeData.label, points: [] }); + if (edgeData !== null) + graphlibGraph.setEdge(edgeData.source, edgeData.target, { label: edgeData.label, points: [] }); }); }; diff --git a/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/ObjectCountPanel.tsx b/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/ObjectCountPanel.tsx index 8439b6ae57..fafdc11d83 100644 --- a/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/ObjectCountPanel.tsx +++ b/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/ObjectCountPanel.tsx @@ -19,8 +19,8 @@ import { FC } from 'react'; import { useQuery } from 'react-query'; import { NodeIcon } from '../../../components'; import { useEnvironmentIdList } from '../../../hooks'; -import { apiClient } from '../../../utils'; import { ENVIRONMENT_AGGREGATION_SUPPORTED_ROUTES } from '../../../routes'; +import { apiClient } from '../../../utils'; const ObjectCountPanel: FC<{ tagId: string }> = ({ tagId }) => { const environments = useEnvironmentIdList(ENVIRONMENT_AGGREGATION_SUPPORTED_ROUTES, false); From ab74bf194d8f4635121f02669980bd7d5c893335 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:49:28 -0600 Subject: [PATCH 13/36] code rabbit comments --- cmd/api/src/api/v2/opengraphschema.go | 2 +- cmd/api/src/database/kind_integration_test.go | 2 +- .../database/sourcekinds_integration_test.go | 8 ++++---- .../src/database/upsert_schema_environment.go | 2 +- .../src/services/opengraphschema/extension.go | 1 + .../opengraphschema/opengraphschema_test.go | 19 +------------------ 6 files changed, 9 insertions(+), 25 deletions(-) diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index a49248fdf0..24cf755b49 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -39,7 +39,7 @@ type Environment struct { PrincipalKinds []string `json:"principalKinds"` } -// TODO: Implement this - barebones in order to test handler. +// TODO: Implement this - skeleton endpoint to simply test the handler. func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { var ( ctx = request.Context() diff --git a/cmd/api/src/database/kind_integration_test.go b/cmd/api/src/database/kind_integration_test.go index 089db89dc9..8c541dfc55 100644 --- a/cmd/api/src/database/kind_integration_test.go +++ b/cmd/api/src/database/kind_integration_test.go @@ -65,7 +65,7 @@ func TestGetKindByName(t *testing.T) { kind, err := testSuite.BHDatabase.GetKindByName(testSuite.Context, testCase.args.name) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) assert.Equal(t, testCase.want.kind, kind) diff --git a/cmd/api/src/database/sourcekinds_integration_test.go b/cmd/api/src/database/sourcekinds_integration_test.go index 329831d2cb..0d8931d504 100644 --- a/cmd/api/src/database/sourcekinds_integration_test.go +++ b/cmd/api/src/database/sourcekinds_integration_test.go @@ -142,7 +142,7 @@ func TestRegisterSourceKind(t *testing.T) { err := testSuite.BHDatabase.RegisterSourceKind(testSuite.Context)(testCase.args.sourceKind) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) } @@ -197,7 +197,7 @@ func TestGetSourceKinds(t *testing.T) { sourceKinds, err := testSuite.BHDatabase.GetSourceKinds(testSuite.Context) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) assert.Equal(t, testCase.want.sourceKinds, sourceKinds) @@ -247,7 +247,7 @@ func TestGetSourceKindByName(t *testing.T) { sourceKind, err := testSuite.BHDatabase.GetSourceKindByName(testSuite.Context, testCase.args.name) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) assert.Equal(t, testCase.want.sourceKind, sourceKind) @@ -423,7 +423,7 @@ func TestDeactivateSourceKindsByName(t *testing.T) { err := testSuite.BHDatabase.DeactivateSourceKindsByName(testSuite.Context, testCase.args.sourceKind) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) } diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index f19e39cd28..71b3107c44 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -117,7 +117,7 @@ func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, p } // upsertSchemaEnvironment creates or updates a schema environment. -// If an environment with the given ID exists, it deletes it first before creating the new one. +// If an environment with the given kinds exists, it deletes it first before creating the new one. func (s *BloodhoundDB) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { if existing, err := s.GetSchemaEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment: %w", err) diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 76e960f3b4..85d44ba1da 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -24,6 +24,7 @@ import ( func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { for _, env := range req.Environments { + // TODO: Update temporary hardcoded extensionID once extension work is complete if err := o.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { return fmt.Errorf("failed to upload environments with principal kinds: %w", err) } diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go index 5b2ac311ad..a07253548b 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema_test.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema_test.go @@ -32,8 +32,7 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository } type args struct { - schemaExtensionId int32 - environments []v2.Environment + environments []v2.Environment } tests := []struct { name string @@ -46,7 +45,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: environment kind name not found in the database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -70,7 +68,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: failed to retrieve environment kind from database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -95,7 +92,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: failed to retrieve source kind from database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -119,7 +115,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: source kind name doesn't exist in database, registration fails", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -143,7 +138,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: source kind name doesn't exist in database, registration succeeds but fetch fails", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -168,7 +162,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: principal kind not found in database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -192,7 +185,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: failed to retrieve principal kind from database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -217,7 +209,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error retrieving schema environment from database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -241,7 +232,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error deleting schema environment", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -265,7 +255,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error creating schema environment after deletion", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -289,7 +278,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error creating new schema environment", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -314,7 +302,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error getting principal kinds by environment id", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -338,7 +325,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error deleting principal kinds", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -386,7 +372,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Success: Create new environment with principal kinds", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -410,7 +395,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Success: Create environment with source kind registration", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -434,7 +418,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Success: Process multiple environments", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", From d63427fcc9682293ad851f29176feb29d8f929a5 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:52:16 -0600 Subject: [PATCH 14/36] updated to add integration flag --- .../src/database/upsert_schema_environment_integration_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index 2ba131cd3c..7207ea0d2b 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -13,6 +13,8 @@ // limitations under the License. // // SPDX-License-Identifier: Apache-2.0 + +//go:build integration package database_test import ( From bad51d25101ed88598b1e320d28602bff9aafeac Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:54:51 -0600 Subject: [PATCH 15/36] missed an arg --- .../src/database/upsert_schema_environment_integration_test.go | 1 + cmd/api/src/services/opengraphschema/opengraphschema_test.go | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index 7207ea0d2b..bd32cce3d2 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -15,6 +15,7 @@ // SPDX-License-Identifier: Apache-2.0 //go:build integration + package database_test import ( diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go index a07253548b..b15effc82f 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema_test.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema_test.go @@ -348,7 +348,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error creating principal kinds", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", From 6e4ba47c3b7cfdc7a08efb970a6bad58444abb32 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 15:21:46 -0600 Subject: [PATCH 16/36] peer review changes --- cmd/api/src/services/opengraphschema/extension.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 85d44ba1da..8b16e1e35b 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -22,10 +22,10 @@ import ( v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" ) -func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { +func (s *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { for _, env := range req.Environments { // TODO: Update temporary hardcoded extensionID once extension work is complete - if err := o.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { + if err := s.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { return fmt.Errorf("failed to upload environments with principal kinds: %w", err) } } From 535fcb70f51199d1a97b5fb00cad9371ba999f41 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 12:38:53 -0600 Subject: [PATCH 17/36] majority of logic for findings and remediation but still working on tests --- cmd/api/src/api/v2/opengraphschema.go | 17 ++ cmd/api/src/database/graphschema.go | 19 ++ cmd/api/src/database/mocks/db.go | 15 ++ .../src/database/upsert_schema_environment.go | 14 +- .../src/database/upsert_schema_extension.go | 65 ++++++ ...psert_schema_extension_integration_test.go | 202 ++++++++++++++++++ cmd/api/src/database/upsert_schema_finding.go | 87 ++++++++ .../upsert_schema_finding_integration_test.go | 145 +++++++++++++ .../src/database/upsert_schema_remediation.go | 41 ++++ ...ert_schema_remediation_integration_test.go | 145 +++++++++++++ cmd/api/src/model/kind.go | 2 +- .../src/services/opengraphschema/extension.go | 36 +++- .../opengraphschema/mocks/opengraphschema.go | 13 +- .../opengraphschema/opengraphschema.go | 4 +- 14 files changed, 785 insertions(+), 20 deletions(-) create mode 100644 cmd/api/src/database/upsert_schema_extension.go create mode 100644 cmd/api/src/database/upsert_schema_extension_integration_test.go create mode 100644 cmd/api/src/database/upsert_schema_finding.go create mode 100644 cmd/api/src/database/upsert_schema_finding_integration_test.go create mode 100644 cmd/api/src/database/upsert_schema_remediation.go create mode 100644 cmd/api/src/database/upsert_schema_remediation_integration_test.go diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index 24cf755b49..ac8eeeb532 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -31,6 +31,7 @@ type OpenGraphSchemaService interface { type GraphSchemaExtension struct { Environments []Environment `json:"environments"` + Findings []Finding `json:"findings"` } type Environment struct { @@ -39,6 +40,22 @@ type Environment struct { PrincipalKinds []string `json:"principalKinds"` } +type Finding struct { + Name string `json:"name"` + DisplayName string `json:"displayName"` + SourceKind string `json:"sourceKind"` + RelationshipKind string `json:"relationshipKind"` + EnvironmentKind string `json:"environmentKind"` + Remediation Remediation `json:"remediation"` +} + +type Remediation struct { + ShortDescription string `json:"shortDescription"` + LongDescription string `json:"longDescription"` + ShortRemediation string `json:"shortRemediation"` + LongRemediation string `json:"longRemediation"` +} + // TODO: Implement this - skeleton endpoint to simply test the handler. func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { var ( diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 26173a3224..a0caf92b83 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -60,12 +60,14 @@ type OpenGraphSchema interface { CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) + GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error CreateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) UpdateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) DeleteRemediation(ctx context.Context, findingId int32) error + CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error @@ -620,6 +622,23 @@ func (s *BloodhoundDB) GetSchemaRelationshipFindingById(ctx context.Context, fin return finding, nil } +// GetSchemaRelationshipFindingByName - retrieves a schema relationship finding by finding name. +func (s *BloodhoundDB) GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { + var finding model.SchemaRelationshipFinding + + if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` + SELECT id, schema_extension_id, relationship_kind_id, environment_id, name, display_name, created_at + FROM %s WHERE name = ?`, + finding.TableName()), + name).Scan(&finding); result.Error != nil { + return model.SchemaRelationshipFinding{}, CheckError(result) + } else if result.RowsAffected == 0 { + return model.SchemaRelationshipFinding{}, ErrNotFound + } + + return finding, nil +} + // DeleteSchemaRelationshipFinding - deletes a schema relationship finding by id. func (s *BloodhoundDB) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { var finding model.SchemaRelationshipFinding diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index b911c1ad42..afcf62b39e 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2262,6 +2262,21 @@ func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingById(ctx, findin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingById), ctx, findingId) } +// GetSchemaRelationshipFindingByName mocks base method. +func (m *MockDatabase) GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingByName", ctx, name) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaRelationshipFindingByName indicates an expected call of GetSchemaRelationshipFindingByName. +func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingByName", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingByName), ctx, name) +} + // GetScopeForSavedQuery mocks base method. func (m *MockDatabase) GetScopeForSavedQuery(ctx context.Context, queryID int64, userID uuid.UUID) (database.SavedQueryScopeMap, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index 71b3107c44..0cef01ef7e 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -26,14 +26,12 @@ import ( // UpsertSchemaEnvironmentWithPrincipalKinds creates or updates an environment with its principal kinds. // If an environment with the same environment kind and source kind exists, it will be replaced. -// -// NOTE: This method should be called within a transaction. The caller is responsible for transaction management. func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error { environment := model.SchemaEnvironment{ SchemaExtensionId: schemaExtensionId, } - envKind, err := s.validateAndTranslateEnvironmentKind(ctx, environmentKind) + envKindID, err := s.validateAndTranslateEnvironmentKind(ctx, environmentKind) if err != nil { return err } @@ -48,7 +46,7 @@ func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Con return err } - environment.EnvironmentKindId = int32(envKind.ID) + environment.EnvironmentKindId = envKindID environment.SourceKindId = sourceKindID envID, err := s.upsertSchemaEnvironment(ctx, environment) @@ -64,13 +62,13 @@ func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Con } // validateAndTranslateEnvironmentKind validates that the environment kind exists in the kinds table. -func (s *BloodhoundDB) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { +func (s *BloodhoundDB) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (int32, error) { if envKind, err := s.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, ErrNotFound) { - return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) + return 0, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) } else if errors.Is(err, ErrNotFound) { - return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) + return 0, fmt.Errorf("environment kind '%s' not found", environmentKindName) } else { - return envKind, nil + return envKind.ID, nil } } diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go new file mode 100644 index 0000000000..2321128f47 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -0,0 +1,65 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "context" + "fmt" +) + +type EnvironmentInput struct { + EnvironmentKindName string + SourceKindName string + PrincipalKinds []string +} + +type FindingInput struct { + Name string + DisplayName string + RelationshipKindName string + EnvironmentKindName string + SourceKindName string + RemediationInput RemediationInput +} + +type RemediationInput struct { + ShortDescription string + LongDescription string + ShortRemediation string + LongRemediation string +} + +func (s *BloodhoundDB) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []EnvironmentInput, findings []FindingInput) error { + return s.Transaction(ctx, func(tx *BloodhoundDB) error { + for _, env := range environments { + if err := tx.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, extensionID, env.EnvironmentKindName, env.SourceKindName, env.PrincipalKinds); err != nil { + return fmt.Errorf("failed to upload environment with principal kinds: %w", err) + } + } + + for _, finding := range findings { + if schemaFinding, err := tx.UpsertFinding(ctx, extensionID, finding.SourceKindName, finding.RelationshipKindName, finding.EnvironmentKindName, finding.Name, finding.DisplayName); err != nil { + return fmt.Errorf("failed to upsert finding: %w", err) + } else { + if err := tx.UpsertRemediation(ctx, schemaFinding.ID, finding.RemediationInput.ShortDescription, finding.RemediationInput.LongDescription, finding.RemediationInput.ShortRemediation, finding.RemediationInput.LongRemediation); err != nil { + return fmt.Errorf("failed to upsert remediation: %w", err) + } + } + } + + return nil + }) +} diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go new file mode 100644 index 0000000000..df0f1f0337 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -0,0 +1,202 @@ +package database_test + +// import ( +// "context" +// "testing" + +// "github.com/specterops/bloodhound/cmd/api/src/database" +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/require" +// ) + +// func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { +// type args struct { +// environments []database.EnvironmentInput +// findings []database.FindingInput +// } +// tests := []struct { +// name string +// setupData func(t *testing.T, db *database.BloodhoundDB) int32 // Returns extensionId +// args args +// assert func(t *testing.T, db *database.BloodhoundDB, extensionId int32) +// expectedError string +// }{ +// { +// name: "Success: Create new environments and findings with remediations", +// setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { +// t.Helper() +// ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") +// require.NoError(t, err) + +// _, err = db.CreateSchemaEnvironment(context.Background(), ext.ID, int32(1), int32(1)) +// require.NoError(t, err) + +// return ext.ID +// }, +// args: args{ +// environments: []database.EnvironmentInput{ +// { +// EnvironmentKindName: "Tag_Tier_Zero", +// SourceKindName: "Base", +// PrincipalKinds: []string{"Tag_Owned", "Tag_Tier_Zero"}, +// }, +// }, +// findings: []database.FindingInput{ +// { +// Name: "T0WriteOwner", +// DisplayName: "Write Owner", +// RelationshipKindName: "Tag_Tier_Zero", +// EnvironmentKindName: "Tag_Tier_Zero", +// RemediationInput: database.RemediationInput{ +// ShortDescription: "User has write owner permission", +// LongDescription: "This permission allows modifying object owner", +// ShortRemediation: "Remove write owner permissions", +// LongRemediation: "Review and remove unnecessary permissions", +// }, +// }, +// { +// Name: "T0DCSync", +// DisplayName: "DCSync Attack", +// RelationshipKindName: "Tag_Tier_Zero", +// EnvironmentKindName: "Tag_Tier_Zero", +// RemediationInput: database.RemediationInput{ +// ShortDescription: "Principal can perform DCSync", +// LongDescription: "Can extract password hashes", +// ShortRemediation: "Revoke replication permissions", +// LongRemediation: "Remove DS-Replication-Get-Changes permissions", +// }, +// }, +// }, +// }, +// assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { +// t.Helper() + +// // Verify findings were created +// finding1, err := db.GetSchemaRelationshipFindingByName(context.Background(), "T0WriteOwner") +// require.NoError(t, err) +// assert.Equal(t, extensionId, finding1.SchemaExtensionId) +// assert.Equal(t, "Write Owner", finding1.DisplayName) + +// finding2, err := db.GetSchemaRelationshipFindingByName(context.Background(), "T0DCSync") +// require.NoError(t, err) +// assert.Equal(t, extensionId, finding2.SchemaExtensionId) +// assert.Equal(t, "DCSync Attack", finding2.DisplayName) + +// // Verify remediations were created +// remediation1, err := db.GetRemediationByFindingId(context.Background(), finding1.ID) +// require.NoError(t, err) +// assert.Equal(t, "User has write owner permission", remediation1.ShortDescription) +// assert.Equal(t, "Remove write owner permissions", remediation1.ShortRemediation) + +// remediation2, err := db.GetRemediationByFindingId(context.Background(), finding2.ID) +// require.NoError(t, err) +// assert.Equal(t, "Principal can perform DCSync", remediation2.ShortDescription) +// assert.Equal(t, "Revoke replication permissions", remediation2.ShortRemediation) +// }, +// }, +// // { +// // name: "Success: Update existing findings and remediations", +// // setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { +// // t.Helper() +// // ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt2", "Test2", "v1.0.0") +// // require.NoError(t, err) + +// // env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) +// // require.NoError(t, err) + +// // // Create initial finding with remediation +// // finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "ExistingFinding", "Old Display Name") +// // require.NoError(t, err) + +// // _, err = db.CreateRemediation(context.Background(), finding.ID, "old short", "old long", "old short rem", "old long rem") +// // require.NoError(t, err) + +// // return ext.ID +// // }, +// // args: args{ +// // environments: []database.EnvironmentInput{ +// // { +// // EnvironmentKind: "Domain", +// // SourceKind: "Base", +// // PrincipalKinds: []string{"User"}, +// // }, +// // }, +// // findings: []database.FindingInput{ +// // { +// // Name: "ExistingFinding", +// // DisplayName: "Updated Display Name", +// // RelationshipKindName: "WriteOwner", +// // EnvironmentKindName: "Domain", +// // RemediationInput: database.RemediationInput{ +// // ShortDescription: "updated short", +// // LongDescription: "updated long", +// // ShortRemediation: "updated short rem", +// // LongRemediation: "updated long rem", +// // }, +// // }, +// // }, +// // }, +// // assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { +// // t.Helper() + +// // // Verify finding was updated (deleted and recreated) +// // finding, err := db.GetSchemaRelationshipFindingByName(context.Background(), "ExistingFinding") +// // require.NoError(t, err) +// // assert.Equal(t, extensionId, finding.SchemaExtensionId) +// // assert.Equal(t, "Updated Display Name", finding.DisplayName) + +// // // Verify remediation was updated +// // remediation, err := db.GetRemediationByFindingId(context.Background(), finding.ID) +// // require.NoError(t, err) +// // assert.Equal(t, "updated short", remediation.ShortDescription) +// // assert.Equal(t, "updated long", remediation.LongDescription) +// // assert.Equal(t, "updated short rem", remediation.ShortRemediation) +// // assert.Equal(t, "updated long rem", remediation.LongRemediation) +// // }, +// // }, +// // { +// // name: "Success: Empty environments and findings", +// // setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { +// // t.Helper() +// // ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt3", "Test3", "v1.0.0") +// // require.NoError(t, err) + +// // return ext.ID +// // }, +// // args: args{ +// // environments: []database.EnvironmentInput{}, +// // findings: []database.FindingInput{}, +// // }, +// // assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { +// // t.Helper() +// // // Nothing to assert - just verify no error occurred +// // }, +// // }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// testSuite := setupIntegrationTestSuite(t) +// defer teardownIntegrationTestSuite(t, &testSuite) + +// extensionId := tt.setupData(t, testSuite.BHDatabase) + +// err := testSuite.BHDatabase.UpsertGraphSchemaExtension( +// context.Background(), +// extensionId, +// tt.args.environments, +// tt.args.findings, +// ) + +// if tt.expectedError != "" { +// require.Error(t, err) +// assert.Contains(t, err.Error(), tt.expectedError) +// } else { +// require.NoError(t, err) +// } + +// if tt.assert != nil { +// tt.assert(t, testSuite.BHDatabase, extensionId) +// } +// }) +// } +// } diff --git a/cmd/api/src/database/upsert_schema_finding.go b/cmd/api/src/database/upsert_schema_finding.go new file mode 100644 index 0000000000..97abfc9dea --- /dev/null +++ b/cmd/api/src/database/upsert_schema_finding.go @@ -0,0 +1,87 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "context" + "errors" + "fmt" + + "github.com/specterops/bloodhound/cmd/api/src/model" +) + +// UpsertFinding validates and upserts a finding. +// If a finding with the same name exists, it will be deleted and re-created. +func (s *BloodhoundDB) UpsertFinding(ctx context.Context, extensionId int32, sourceKindName, relationshipKindName, environmentKind string, name, displayName string) (model.SchemaRelationshipFinding, error) { + relationshipKindId, err := s.validateAndTranslateRelationshipKind(ctx, relationshipKindName) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + environmentKindId, err := s.validateAndTranslateEnvironmentKind(ctx, environmentKind) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + sourceKindId, err := s.validateAndTranslateSourceKind(ctx, sourceKindName) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + environment, err := s.GetSchemaEnvironmentByKinds(ctx, environmentKindId, sourceKindId) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + finding, err := s.upsertFinding(ctx, extensionId, relationshipKindId, environment.ID, name, displayName) + if err != nil { + return model.SchemaRelationshipFinding{}, err + } + + return finding, nil +} + +// validateAndTranslateRelationshipKind validates that the relationship kind exists in the kinds table. +func (s *BloodhoundDB) validateAndTranslateRelationshipKind(ctx context.Context, relationshipKindName string) (int32, error) { + if relationshipKind, err := s.GetKindByName(ctx, relationshipKindName); err != nil && !errors.Is(err, ErrNotFound) { + return 0, fmt.Errorf("error retrieving relationship kind '%s': %w", relationshipKindName, err) + } else if errors.Is(err, ErrNotFound) { + return 0, fmt.Errorf("relationship kind '%s' not found", relationshipKindName) + } else { + return relationshipKind.ID, nil + } +} + +// upsertFinding creates or updates a schema relationship finding. +// If an environment with the given kinds exists, it deletes it first before creating the new one. +func (s *BloodhoundDB) upsertFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { + if existing, err := s.GetSchemaRelationshipFindingByName(ctx, name); err != nil && !errors.Is(err, ErrNotFound) { + return model.SchemaRelationshipFinding{}, fmt.Errorf("error retrieving schema relationship finding: %w", err) + } else if err == nil { + // Finding exists - delete it first + if err := s.DeleteSchemaRelationshipFinding(ctx, existing.ID); err != nil { + return model.SchemaRelationshipFinding{}, fmt.Errorf("error deleting schema relationship finding %d: %w", existing.ID, err) + } + } + + finding, err := s.CreateSchemaRelationshipFinding(ctx, extensionId, relationshipKindId, environmentId, name, displayName) + if err != nil { + return model.SchemaRelationshipFinding{}, fmt.Errorf("error creating schema relationship finding: %w", err) + } + + return finding, nil + +} diff --git a/cmd/api/src/database/upsert_schema_finding_integration_test.go b/cmd/api/src/database/upsert_schema_finding_integration_test.go new file mode 100644 index 0000000000..c9d4fe5991 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_finding_integration_test.go @@ -0,0 +1,145 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +//go:build integration + +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertFinding(t *testing.T) { + type args struct { + sourceKindName, relationshipKindName, environmentKind, name, displayName string + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 // Returns extensionId + args args + assert func(t *testing.T, db *database.BloodhoundDB, extensionId int32) + expectedError string + }{ + { + name: "Success: Update existing finding - delete and re-create", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + // Create finding + _, err = db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "Finding Name", "Finding Display Name") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + sourceKindName: "Base", + relationshipKindName: "Tag_Tier_Zero", + environmentKind: "Tag_Tier_Zero", + // Name triggers upsert so this needs to match the finding's name that we want to update + name: "Finding Name", + displayName: "Updated Display Name", + }, + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { + t.Helper() + + finding, err := db.GetSchemaRelationshipFindingByName(context.Background(), "Finding Name") + require.NoError(t, err) + + assert.Equal(t, extensionId, finding.SchemaExtensionId) + assert.Equal(t, "Finding Name", finding.Name) + assert.Equal(t, "Updated Display Name", finding.DisplayName) + }, + }, + { + name: "Success: Create finding when none exists", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt2", "Test2", "v1.0.0") + require.NoError(t, err) + + _, err = db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + // No finding created since we're testing the creation workflow + return ext.ID + }, + args: args{ + sourceKindName: "Base", + relationshipKindName: "Tag_Tier_Zero", + environmentKind: "Tag_Tier_Zero", + name: "Finding", + displayName: "Finding Display Name", + }, + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { + t.Helper() + + finding, err := db.GetSchemaRelationshipFindingByName(context.Background(), "Finding") + require.NoError(t, err) + + assert.Equal(t, extensionId, finding.SchemaExtensionId) + assert.Equal(t, "Finding", finding.Name) + assert.Equal(t, "Finding Display Name", finding.DisplayName) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extensionId := tt.setupData(t, testSuite.BHDatabase) + + var findingResponse model.SchemaRelationshipFinding + // Wrap the call in a transaction + err := testSuite.BHDatabase.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + finding, err := tx.UpsertFinding( + context.Background(), + extensionId, + tt.args.sourceKindName, + tt.args.relationshipKindName, + tt.args.environmentKind, + tt.args.name, + tt.args.displayName, + ) + findingResponse = finding + return err + }) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + assert.NotZero(t, findingResponse.ID, "Finding should have been created/updated") + } + + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase, extensionId) + } + }) + } +} diff --git a/cmd/api/src/database/upsert_schema_remediation.go b/cmd/api/src/database/upsert_schema_remediation.go new file mode 100644 index 0000000000..b1843545c8 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_remediation.go @@ -0,0 +1,41 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "context" + "errors" + "fmt" +) + +// UpsertRemediation validates and upserts a remediation. +// If the remediation exists for the finding ID, it is updated. If it doesn't already exist, it is created. +// Findings information must be inserted first before inserting remediation information. +func (s *BloodhoundDB) UpsertRemediation(ctx context.Context, findingId int32, shortDescription, longDescription, shortRemediation, longRemediation string) error { + if _, err := s.GetRemediationByFindingId(ctx, findingId); err != nil && !errors.Is(err, ErrNotFound) { + return fmt.Errorf("error retrieving remediation by finding id '%d': %w", findingId, err) + } else if err == nil { + // Remediation exists - update it + if _, err := s.UpdateRemediation(ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation); err != nil { + return fmt.Errorf("error updating remediation by finding id '%d': %w", findingId, err) + } + } else { + if _, err := s.CreateRemediation(ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation); err != nil { + return fmt.Errorf("error creating remediation by finding id '%d': %w", findingId, err) + } + } + return nil +} diff --git a/cmd/api/src/database/upsert_schema_remediation_integration_test.go b/cmd/api/src/database/upsert_schema_remediation_integration_test.go new file mode 100644 index 0000000000..4115faa0fa --- /dev/null +++ b/cmd/api/src/database/upsert_schema_remediation_integration_test.go @@ -0,0 +1,145 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +//go:build integration + +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertRemediation(t *testing.T) { + type args struct { + shortDescription, longDescription, shortRemediation, longRemediation string + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 // Returns findingID + args args + assert func(t *testing.T, db *database.BloodhoundDB, findingId int32) + expectedError string + }{ + { + name: "Success: Update existing remediation", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "Finding", "Finding Display Name") + require.NoError(t, err) + + _, err = db.CreateRemediation(context.Background(), finding.ID, "short", "long", "short rem", "long rem") + require.NoError(t, err) + + return finding.ID + }, + args: args{ + shortDescription: "updated short description", + longDescription: "updated long description", + shortRemediation: "updated short remediation", + longRemediation: "updated long remediation", + }, + assert: func(t *testing.T, db *database.BloodhoundDB, findingId int32) { + t.Helper() + + remediation, err := db.GetRemediationByFindingId(context.Background(), findingId) + require.NoError(t, err) + + assert.Equal(t, findingId, remediation.FindingID) + assert.Equal(t, "updated short description", remediation.ShortDescription) + assert.Equal(t, "updated long description", remediation.LongDescription) + assert.Equal(t, "updated short remediation", remediation.ShortRemediation) + assert.Equal(t, "updated long remediation", remediation.LongRemediation) + }, + }, + { + name: "Success: Create remediation when none exists", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + // Create Finding but do not create Remediation + finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "Finding2", "Finding 2 Display Name") + require.NoError(t, err) + + return finding.ID + }, + args: args{ + shortDescription: "new short description", + longDescription: "new long description", + shortRemediation: "new short remediation", + longRemediation: "new long remediation", + }, + assert: func(t *testing.T, db *database.BloodhoundDB, findingId int32) { + t.Helper() + + remediation, err := db.GetRemediationByFindingId(context.Background(), findingId) + require.NoError(t, err) + + assert.Equal(t, findingId, remediation.FindingID) + assert.Equal(t, "new short description", remediation.ShortDescription) + assert.Equal(t, "new long description", remediation.LongDescription) + assert.Equal(t, "new short remediation", remediation.ShortRemediation) + assert.Equal(t, "new long remediation", remediation.LongRemediation) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + findingId := tt.setupData(t, testSuite.BHDatabase) + + // Wrap the call in a transaction + err := testSuite.BHDatabase.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + return tx.UpsertRemediation( + context.Background(), + findingId, + tt.args.shortDescription, + tt.args.longDescription, + tt.args.shortRemediation, + tt.args.longRemediation, + ) + }) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + } + + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase, findingId) + } + }) + } +} diff --git a/cmd/api/src/model/kind.go b/cmd/api/src/model/kind.go index 8f6402c682..95d3ad8f25 100644 --- a/cmd/api/src/model/kind.go +++ b/cmd/api/src/model/kind.go @@ -16,6 +16,6 @@ package model type Kind struct { - ID int `json:"id"` + ID int32 `json:"id"` Name string `json:"name"` } diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 8b16e1e35b..e02e3b4481 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -20,15 +20,43 @@ import ( "fmt" v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/database" ) func (s *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - for _, env := range req.Environments { - // TODO: Update temporary hardcoded extensionID once extension work is complete - if err := s.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { - return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + var ( + environments = make([]database.EnvironmentInput, len(req.Environments)) + findings = make([]database.FindingInput, len(req.Findings)) + ) + + for i, environment := range req.Environments { + environments[i] = database.EnvironmentInput{ + EnvironmentKindName: environment.EnvironmentKind, + SourceKindName: environment.SourceKind, + PrincipalKinds: environment.PrincipalKinds, + } + } + + for i, finding := range req.Findings { + findings[i] = database.FindingInput{ + Name: finding.Name, + DisplayName: finding.DisplayName, + SourceKindName: finding.SourceKind, + RelationshipKindName: finding.RelationshipKind, + EnvironmentKindName: finding.EnvironmentKind, + RemediationInput: database.RemediationInput{ + ShortDescription: finding.Remediation.ShortDescription, + LongDescription: finding.Remediation.LongDescription, + ShortRemediation: finding.Remediation.ShortRemediation, + LongRemediation: finding.Remediation.LongRemediation, + }, } } + err := s.openGraphSchemaRepository.UpsertGraphSchemaExtension(ctx, 1, environments, findings) + if err != nil { + return fmt.Errorf("error upserting graph extension: %w", err) + } + return nil } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index 9c3e712987..2653fb152b 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -29,6 +29,7 @@ import ( context "context" reflect "reflect" + database "github.com/specterops/bloodhound/cmd/api/src/database" gomock "go.uber.org/mock/gomock" ) @@ -56,16 +57,16 @@ func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryM return m.recorder } -// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. -func (m *MockOpenGraphSchemaRepository) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind, sourceKind string, principalKinds []string) error { +// UpsertGraphSchemaExtension mocks base method. +func (m *MockOpenGraphSchemaRepository) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput, findings []database.FindingInput) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) + ret := m.ctrl.Call(m, "UpsertGraphSchemaExtension", ctx, extensionID, environments, findings) ret0, _ := ret[0].(error) return ret0 } -// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds any) *gomock.Call { +// UpsertGraphSchemaExtension indicates an expected call of UpsertGraphSchemaExtension. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertGraphSchemaExtension(ctx, extensionID, environments, findings any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertGraphSchemaExtension", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertGraphSchemaExtension), ctx, extensionID, environments, findings) } diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index 6de81562dc..cf9fb9f318 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -19,11 +19,13 @@ package opengraphschema import ( "context" + + "github.com/specterops/bloodhound/cmd/api/src/database" ) // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { - UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error + UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput, findings []database.FindingInput) error } type OpenGraphSchemaService struct { From 85b9407a4b0b91771ceeae218c6ea5c264d4f025 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 12:54:47 -0600 Subject: [PATCH 18/36] merge conflicts --- cmd/api/src/services/entrypoint.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index ac8361e87f..76a4f33efd 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -38,8 +38,8 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/migrations" "github.com/specterops/bloodhound/cmd/api/src/model/appcfg" "github.com/specterops/bloodhound/cmd/api/src/queries" - "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/dogtags" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/upload" "github.com/specterops/bloodhound/packages/go/cache" schema "github.com/specterops/bloodhound/packages/go/graphschema" From be020633f177f6f052cf88db865fab5242e0bf3b Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 13:10:36 -0600 Subject: [PATCH 19/36] updated to be in line with other PR --- cmd/api/src/api/registration/registration.go | 2 +- cmd/api/src/api/v2/model.go | 4 +- .../src/database/upsert_schema_environment.go | 2 - .../src/database/upsert_schema_extension.go | 24 + ...psert_schema_extension_integration_test.go | 418 +++++++++++++++ .../src/services/opengraphschema/extension.go | 20 +- .../opengraphschema/extension_test.go | 164 ++++++ .../opengraphschema/mocks/opengraphschema.go | 13 +- .../opengraphschema/opengraphschema.go | 4 +- .../opengraphschema/opengraphschema_test.go | 481 ------------------ 10 files changed, 635 insertions(+), 497 deletions(-) create mode 100644 cmd/api/src/database/upsert_schema_extension.go create mode 100644 cmd/api/src/database/upsert_schema_extension_integration_test.go create mode 100644 cmd/api/src/services/opengraphschema/extension_test.go delete mode 100644 cmd/api/src/services/opengraphschema/opengraphschema_test.go diff --git a/cmd/api/src/api/registration/registration.go b/cmd/api/src/api/registration/registration.go index 25eb6e2370..20cf7ff25a 100644 --- a/cmd/api/src/api/registration/registration.go +++ b/cmd/api/src/api/registration/registration.go @@ -63,8 +63,8 @@ func RegisterFossRoutes( authenticator api.Authenticator, authorizer auth.Authorizer, ingestSchema upload.IngestSchema, - openGraphSchemaService v2.OpenGraphSchemaService, dogtagsService dogtags.Service, + openGraphSchemaService v2.OpenGraphSchemaService, ) { router.With(func() mux.MiddlewareFunc { return middleware.DefaultRateLimitMiddleware(rdms) diff --git a/cmd/api/src/api/v2/model.go b/cmd/api/src/api/v2/model.go index 1a4c501490..a91b512275 100644 --- a/cmd/api/src/api/v2/model.go +++ b/cmd/api/src/api/v2/model.go @@ -130,8 +130,8 @@ func NewResources( authorizer auth.Authorizer, authenticator api.Authenticator, ingestSchema upload.IngestSchema, - openGraphSchemaService OpenGraphSchemaService, dogtagsService dogtags.Service, + openGraphSchemaService OpenGraphSchemaService, ) Resources { return Resources{ Decoder: schema.NewDecoder(), @@ -146,7 +146,7 @@ func NewResources( Authenticator: authenticator, IngestSchema: ingestSchema, FileService: &fs.Client{}, - openGraphSchemaService: openGraphSchemaService, DogTags: dogtagsService, + openGraphSchemaService: openGraphSchemaService, } } diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index 71b3107c44..f65022bfb3 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -26,8 +26,6 @@ import ( // UpsertSchemaEnvironmentWithPrincipalKinds creates or updates an environment with its principal kinds. // If an environment with the same environment kind and source kind exists, it will be replaced. -// -// NOTE: This method should be called within a transaction. The caller is responsible for transaction management. func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error { environment := model.SchemaEnvironment{ SchemaExtensionId: schemaExtensionId, diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go new file mode 100644 index 0000000000..d1a9ea8b25 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -0,0 +1,24 @@ +package database + +import ( + "context" + "fmt" +) + +type EnvironmentInput struct { + EnvironmentKindName string + SourceKindName string + PrincipalKinds []string +} + +func (s *BloodhoundDB) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []EnvironmentInput) error { + return s.Transaction(ctx, func(tx *BloodhoundDB) error { + for _, env := range environments { + if err := tx.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, extensionID, env.EnvironmentKindName, env.SourceKindName, env.PrincipalKinds); err != nil { + return fmt.Errorf("failed to upload environment with principal kinds: %w", err) + } + } + + return nil + }) +} diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go new file mode 100644 index 0000000000..b406021252 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -0,0 +1,418 @@ +//go:build integration + +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { + type args struct { + environments []database.EnvironmentInput + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 + args args + assert func(t *testing.T, db *database.BloodhoundDB) + expectedError string + }{ + { + name: "Success: Create environment with principal kinds", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero", "Tag_Owned"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments)) + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) + + expectedKindIDs := make([]int32, len(expectedPrincipalKindNames)) + for i, name := range expectedPrincipalKindNames { + kind, err := db.GetKindByName(context.Background(), name) + assert.NoError(t, err) + expectedKindIDs[i] = int32(kind.ID) + } + + actualKindIDs := make([]int32, len(principalKinds)) + for i, pk := range principalKinds { + assert.Equal(t, environments[0].ID, pk.EnvironmentId) + actualKindIDs[i] = pk.PrincipalKind + } + + assert.ElementsMatch(t, expectedKindIDs, actualKindIDs) + }, + }, + { + name: "Success: Create multiple environments", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + { + EnvironmentKindName: "Tag_Owned", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Owned"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 2, len(environments), "Should have two environments") + + // Verify first environment + env1PrincipalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(env1PrincipalKinds)) + + // Verify second environment + env2PrincipalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[1].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(env2PrincipalKinds)) + }, + }, + { + name: "Success: Upsert replaces existing environment", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + // Create initial environment + err = db.UpsertGraphSchemaExtension(context.Background(), ext.ID, []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Owned"}, + }, + }) + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments), "Should only have one environment (old one replaced)") + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds)) + + expectedKind, err := db.GetKindByName(context.Background(), expectedPrincipalKindNames[0]) + assert.NoError(t, err) + + assert.Equal(t, int32(expectedKind.ID), principalKinds[0].PrincipalKind) + assert.Equal(t, environments[0].ID, principalKinds[0].EnvironmentId) + }, + }, + { + name: "Success: Source kind auto-registers", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "NewSource", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + sourceKind, err := db.GetSourceKindByName(context.Background(), "NewSource") + assert.NoError(t, err) + assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments)) + assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds)) + }, + }, + { + name: "Success: Multiple environments with different source kinds", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + { + EnvironmentKindName: "Tag_Owned", + SourceKindName: "NewSource", + PrincipalKinds: []string{"Tag_Owned"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify NewSource was auto-registered + sourceKind, err := db.GetSourceKindByName(context.Background(), "NewSource") + assert.NoError(t, err) + assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 2, len(environments), "Should have two environments") + + for _, env := range environments { + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), env.ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") + } + }, + }, + { + name: "Error: First environment has invalid environment kind", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "NonExistent", + SourceKindName: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + expectedError: "environment kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Error: First environment has invalid principal kind", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"NonExistent"}, + }, + }, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Rollback: Second environment fails, first should rollback", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + { + EnvironmentKindName: "NonExistent", + SourceKindName: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + expectedError: "environment kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify complete transaction rollback - no environments created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environments should exist after rollback") + }, + }, + { + name: "Rollback: Second environment has invalid principal kind", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + { + EnvironmentKindName: "Tag_Owned", + SourceKindName: "Base", + PrincipalKinds: []string{"NonExistent"}, + }, + }, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify complete transaction rollback - no environments created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environments should exist after rollback") + }, + }, + { + name: "Rollback: Partial failure in first environment's principal kinds", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Owned", "NonExistent"}, + }, + }, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extensionID := tt.setupData(t, testSuite.BHDatabase) + + err := testSuite.BHDatabase.UpsertGraphSchemaExtension( + context.Background(), + extensionID, + tt.args.environments, + ) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase) + } + } else { + require.NoError(t, err) + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase) + } + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 8b16e1e35b..a16df48021 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -20,15 +20,27 @@ import ( "fmt" v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/database" ) func (s *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - for _, env := range req.Environments { - // TODO: Update temporary hardcoded extensionID once extension work is complete - if err := s.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { - return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + var ( + environments = make([]database.EnvironmentInput, len(req.Environments)) + ) + + for i, environment := range req.Environments { + environments[i] = database.EnvironmentInput{ + EnvironmentKindName: environment.EnvironmentKind, + SourceKindName: environment.SourceKind, + PrincipalKinds: environment.PrincipalKinds, } } + // TODO: Temporary hardcoded value but needs to be updated to pass in the extension ID + err := s.openGraphSchemaRepository.UpsertGraphSchemaExtension(ctx, 1, environments) + if err != nil { + return fmt.Errorf("error upserting graph extension: %w", err) + } + return nil } diff --git a/cmd/api/src/services/opengraphschema/extension_test.go b/cmd/api/src/services/opengraphschema/extension_test.go new file mode 100644 index 0000000000..0f81f8b881 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/extension_test.go @@ -0,0 +1,164 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema_test + +import ( + "context" + "errors" + "testing" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" + schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { + type mocks struct { + mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository + } + type args struct { + environments []v2.Environment + } + tests := []struct { + name string + setupMocks func(t *testing.T, m *mocks) + args args + expected error + }{ + { + name: "Error: openGraphSchemaRepository.UpsertGraphSchemaExtension error", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User"}, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + ).Return(errors.New("error")) + }, + expected: errors.New("error upserting graph extension: error"), + }, + { + name: "Success: single environment", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "Computer"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User", "Computer"}, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + ).Return(nil) + }, + expected: nil, + }, + { + name: "Success: multiple environments", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKind: "AzureAD", + SourceKind: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKindName: "AzureAD", + SourceKindName: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + ).Return(nil) + }, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + m := &mocks{ + mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), + } + + tt.setupMocks(t, m) + + service := opengraphschema.NewOpenGraphSchemaService(m.mockOpenGraphSchema) + + err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ + Environments: tt.args.environments, + }) + + if tt.expected != nil { + assert.EqualError(t, err, tt.expected.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index 9c3e712987..f1e1539a2a 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -29,6 +29,7 @@ import ( context "context" reflect "reflect" + database "github.com/specterops/bloodhound/cmd/api/src/database" gomock "go.uber.org/mock/gomock" ) @@ -56,16 +57,16 @@ func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryM return m.recorder } -// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. -func (m *MockOpenGraphSchemaRepository) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind, sourceKind string, principalKinds []string) error { +// UpsertGraphSchemaExtension mocks base method. +func (m *MockOpenGraphSchemaRepository) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) + ret := m.ctrl.Call(m, "UpsertGraphSchemaExtension", ctx, extensionID, environments) ret0, _ := ret[0].(error) return ret0 } -// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds any) *gomock.Call { +// UpsertGraphSchemaExtension indicates an expected call of UpsertGraphSchemaExtension. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertGraphSchemaExtension(ctx, extensionID, environments any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertGraphSchemaExtension", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertGraphSchemaExtension), ctx, extensionID, environments) } diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index 6de81562dc..74322b76bf 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -19,11 +19,13 @@ package opengraphschema import ( "context" + + "github.com/specterops/bloodhound/cmd/api/src/database" ) // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { - UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error + UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput) error } type OpenGraphSchemaService struct { diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go deleted file mode 100644 index b15effc82f..0000000000 --- a/cmd/api/src/services/opengraphschema/opengraphschema_test.go +++ /dev/null @@ -1,481 +0,0 @@ -// Copyright 2026 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 -package opengraphschema_test - -import ( - "context" - "errors" - "testing" - - v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" - schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" -) - -func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { - type mocks struct { - mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository - } - type args struct { - environments []v2.Environment - } - tests := []struct { - name string - setupMocks func(t *testing.T, m *mocks) - args args - expected error - }{ - // UpsertSchemaEnvironmentWithPrincipalKinds - // Validation: Environment Kind - { - name: "Error: environment kind name not found in the database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("environment kind 'Domain' not found")) - }, - expected: errors.New("failed to upload environments with principal kinds: environment kind 'Domain' not found"), - }, - { - name: "Error: failed to retrieve environment kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving environment kind 'Domain': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving environment kind 'Domain': error"), - }, - // Validation: Source Kind - { - name: "Error: failed to retrieve source kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving source kind 'Base': error"), - }, - { - name: "Error: source kind name doesn't exist in database, registration fails", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error registering source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error registering source kind 'Base': error"), - }, - { - name: "Error: source kind name doesn't exist in database, registration succeeds but fetch fails", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving newly registered source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving newly registered source kind 'Base': error"), - }, - // Validation: Principal Kind - { - name: "Error: principal kind not found in database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "InvalidKind"}, - ).Return(errors.New("principal kind 'InvalidKind' not found")) - }, - expected: errors.New("failed to upload environments with principal kinds: principal kind 'InvalidKind' not found"), - }, - { - name: "Error: failed to retrieve principal kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "InvalidKind"}, - ).Return(errors.New("error retrieving principal kind by name 'InvalidKind': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving principal kind by name 'InvalidKind': error"), - }, - // Upsert Schema Environment - { - name: "Error: error retrieving schema environment from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error retrieving schema environment id 0: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error retrieving schema environment id 0: error"), - }, - { - name: "Error: error deleting schema environment", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error deleting schema environment 5: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error deleting schema environment 5: error"), - }, - { - name: "Error: error creating schema environment after deletion", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), - }, - { - name: "Error: error creating new schema environment", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), - }, - // Upsert Principal Kinds - { - name: "Error: error getting principal kinds by environment id", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), - }, - { - name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error deleting principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), - }, - { - name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error creating principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error creating principal kind 3 for environment 10: error"), - }, - { - name: "Success: Create new environment with principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "Computer"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "Computer"}, - ).Return(nil) - }, - expected: nil, - }, - { - name: "Success: Create environment with source kind registration", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "NewSource", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "NewSource", - []string{}, - ).Return(nil) - }, - expected: nil, - }, - { - name: "Success: Process multiple environments", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - { - EnvironmentKind: "AzureAD", - SourceKind: "AzureHound", - PrincipalKinds: []string{"User", "Group"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - // First environment - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(nil) - - // Second environment - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "AzureAD", - "AzureHound", - []string{"User", "Group"}, - ).Return(nil) - }, - expected: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - m := &mocks{ - mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), - } - - tt.setupMocks(t, m) - - service := opengraphschema.NewOpenGraphSchemaService(m.mockOpenGraphSchema) - - err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ - Environments: tt.args.environments, - }) - - if tt.expected != nil { - assert.EqualError(t, err, tt.expected.Error()) - } else { - assert.NoError(t, err) - } - }) - } -} From 094d52a6fb27a63418a12257c15e95f6f590331d Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 13:13:20 -0600 Subject: [PATCH 20/36] just prepare --- cmd/api/src/database/upsert_schema_extension.go | 15 +++++++++++++++ .../upsert_schema_extension_integration_test.go | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go index d1a9ea8b25..8c564d6aa4 100644 --- a/cmd/api/src/database/upsert_schema_extension.go +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package database import ( diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index b406021252..9a71677d59 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 //go:build integration package database_test From be0c964b944ad8e3c832a6416f478589d58cbc54 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 13:19:21 -0600 Subject: [PATCH 21/36] just prepare --- cmd/api/src/database/upsert_schema_extension.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go index 8c564d6aa4..e27f20af87 100644 --- a/cmd/api/src/database/upsert_schema_extension.go +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, From 712d7763b32608dcc5ef790a81fead4c1ae85545 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 15:52:13 -0600 Subject: [PATCH 22/36] tests --- ...psert_schema_extension_integration_test.go | 407 +++++++-------- .../opengraphschema/extension_test.go | 285 +++++++++++ .../opengraphschema/opengraphschema_test.go | 481 ------------------ 3 files changed, 492 insertions(+), 681 deletions(-) create mode 100644 cmd/api/src/services/opengraphschema/extension_test.go delete mode 100644 cmd/api/src/services/opengraphschema/opengraphschema_test.go diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index df0f1f0337..76a3d122a7 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -1,202 +1,209 @@ package database_test -// import ( -// "context" -// "testing" - -// "github.com/specterops/bloodhound/cmd/api/src/database" -// "github.com/stretchr/testify/assert" -// "github.com/stretchr/testify/require" -// ) - -// func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { -// type args struct { -// environments []database.EnvironmentInput -// findings []database.FindingInput -// } -// tests := []struct { -// name string -// setupData func(t *testing.T, db *database.BloodhoundDB) int32 // Returns extensionId -// args args -// assert func(t *testing.T, db *database.BloodhoundDB, extensionId int32) -// expectedError string -// }{ -// { -// name: "Success: Create new environments and findings with remediations", -// setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { -// t.Helper() -// ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") -// require.NoError(t, err) - -// _, err = db.CreateSchemaEnvironment(context.Background(), ext.ID, int32(1), int32(1)) -// require.NoError(t, err) - -// return ext.ID -// }, -// args: args{ -// environments: []database.EnvironmentInput{ -// { -// EnvironmentKindName: "Tag_Tier_Zero", -// SourceKindName: "Base", -// PrincipalKinds: []string{"Tag_Owned", "Tag_Tier_Zero"}, -// }, -// }, -// findings: []database.FindingInput{ -// { -// Name: "T0WriteOwner", -// DisplayName: "Write Owner", -// RelationshipKindName: "Tag_Tier_Zero", -// EnvironmentKindName: "Tag_Tier_Zero", -// RemediationInput: database.RemediationInput{ -// ShortDescription: "User has write owner permission", -// LongDescription: "This permission allows modifying object owner", -// ShortRemediation: "Remove write owner permissions", -// LongRemediation: "Review and remove unnecessary permissions", -// }, -// }, -// { -// Name: "T0DCSync", -// DisplayName: "DCSync Attack", -// RelationshipKindName: "Tag_Tier_Zero", -// EnvironmentKindName: "Tag_Tier_Zero", -// RemediationInput: database.RemediationInput{ -// ShortDescription: "Principal can perform DCSync", -// LongDescription: "Can extract password hashes", -// ShortRemediation: "Revoke replication permissions", -// LongRemediation: "Remove DS-Replication-Get-Changes permissions", -// }, -// }, -// }, -// }, -// assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { -// t.Helper() - -// // Verify findings were created -// finding1, err := db.GetSchemaRelationshipFindingByName(context.Background(), "T0WriteOwner") -// require.NoError(t, err) -// assert.Equal(t, extensionId, finding1.SchemaExtensionId) -// assert.Equal(t, "Write Owner", finding1.DisplayName) - -// finding2, err := db.GetSchemaRelationshipFindingByName(context.Background(), "T0DCSync") -// require.NoError(t, err) -// assert.Equal(t, extensionId, finding2.SchemaExtensionId) -// assert.Equal(t, "DCSync Attack", finding2.DisplayName) - -// // Verify remediations were created -// remediation1, err := db.GetRemediationByFindingId(context.Background(), finding1.ID) -// require.NoError(t, err) -// assert.Equal(t, "User has write owner permission", remediation1.ShortDescription) -// assert.Equal(t, "Remove write owner permissions", remediation1.ShortRemediation) - -// remediation2, err := db.GetRemediationByFindingId(context.Background(), finding2.ID) -// require.NoError(t, err) -// assert.Equal(t, "Principal can perform DCSync", remediation2.ShortDescription) -// assert.Equal(t, "Revoke replication permissions", remediation2.ShortRemediation) -// }, -// }, -// // { -// // name: "Success: Update existing findings and remediations", -// // setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { -// // t.Helper() -// // ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt2", "Test2", "v1.0.0") -// // require.NoError(t, err) - -// // env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) -// // require.NoError(t, err) - -// // // Create initial finding with remediation -// // finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "ExistingFinding", "Old Display Name") -// // require.NoError(t, err) - -// // _, err = db.CreateRemediation(context.Background(), finding.ID, "old short", "old long", "old short rem", "old long rem") -// // require.NoError(t, err) - -// // return ext.ID -// // }, -// // args: args{ -// // environments: []database.EnvironmentInput{ -// // { -// // EnvironmentKind: "Domain", -// // SourceKind: "Base", -// // PrincipalKinds: []string{"User"}, -// // }, -// // }, -// // findings: []database.FindingInput{ -// // { -// // Name: "ExistingFinding", -// // DisplayName: "Updated Display Name", -// // RelationshipKindName: "WriteOwner", -// // EnvironmentKindName: "Domain", -// // RemediationInput: database.RemediationInput{ -// // ShortDescription: "updated short", -// // LongDescription: "updated long", -// // ShortRemediation: "updated short rem", -// // LongRemediation: "updated long rem", -// // }, -// // }, -// // }, -// // }, -// // assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { -// // t.Helper() - -// // // Verify finding was updated (deleted and recreated) -// // finding, err := db.GetSchemaRelationshipFindingByName(context.Background(), "ExistingFinding") -// // require.NoError(t, err) -// // assert.Equal(t, extensionId, finding.SchemaExtensionId) -// // assert.Equal(t, "Updated Display Name", finding.DisplayName) - -// // // Verify remediation was updated -// // remediation, err := db.GetRemediationByFindingId(context.Background(), finding.ID) -// // require.NoError(t, err) -// // assert.Equal(t, "updated short", remediation.ShortDescription) -// // assert.Equal(t, "updated long", remediation.LongDescription) -// // assert.Equal(t, "updated short rem", remediation.ShortRemediation) -// // assert.Equal(t, "updated long rem", remediation.LongRemediation) -// // }, -// // }, -// // { -// // name: "Success: Empty environments and findings", -// // setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { -// // t.Helper() -// // ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt3", "Test3", "v1.0.0") -// // require.NoError(t, err) - -// // return ext.ID -// // }, -// // args: args{ -// // environments: []database.EnvironmentInput{}, -// // findings: []database.FindingInput{}, -// // }, -// // assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { -// // t.Helper() -// // // Nothing to assert - just verify no error occurred -// // }, -// // }, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// testSuite := setupIntegrationTestSuite(t) -// defer teardownIntegrationTestSuite(t, &testSuite) - -// extensionId := tt.setupData(t, testSuite.BHDatabase) - -// err := testSuite.BHDatabase.UpsertGraphSchemaExtension( -// context.Background(), -// extensionId, -// tt.args.environments, -// tt.args.findings, -// ) - -// if tt.expectedError != "" { -// require.Error(t, err) -// assert.Contains(t, err.Error(), tt.expectedError) -// } else { -// require.NoError(t, err) -// } - -// if tt.assert != nil { -// tt.assert(t, testSuite.BHDatabase, extensionId) -// } -// }) -// } -// } +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { + type args struct { + environments []database.EnvironmentInput + findings []database.FindingInput + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 // Returns extensionId + args args + assert func(t *testing.T, db *database.BloodhoundDB, extensionId int32) + expectedError string + }{ + { + name: "Success: Create new environments and findings with remediations", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + _, err = db.CreateSchemaEnvironment(context.Background(), ext.ID, int32(1), int32(1)) + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Owned", "Tag_Tier_Zero"}, + }, + }, + findings: []database.FindingInput{ + { + Name: "Name 1", + DisplayName: "Display Name 1", + RelationshipKindName: "Tag_Tier_Zero", + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + { + Name: "Name 2", + DisplayName: "Display Name 2", + RelationshipKindName: "Tag_Tier_Zero", + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { + t.Helper() + + // Verify findings were created + finding1, err := db.GetSchemaRelationshipFindingByName(context.Background(), "Name 1") + require.NoError(t, err) + assert.Equal(t, extensionId, finding1.SchemaExtensionId) + assert.Equal(t, "Display Name 1", finding1.DisplayName) + + finding2, err := db.GetSchemaRelationshipFindingByName(context.Background(), "Name 2") + require.NoError(t, err) + assert.Equal(t, extensionId, finding2.SchemaExtensionId) + assert.Equal(t, "Display Name 2", finding2.DisplayName) + + // Verify remediations were created + remediation1, err := db.GetRemediationByFindingId(context.Background(), finding1.ID) + require.NoError(t, err) + assert.Equal(t, "Short Description", remediation1.ShortDescription) + assert.Equal(t, "Long Description", remediation1.LongDescription) + assert.Equal(t, "Short Remediation", remediation1.ShortRemediation) + assert.Equal(t, "Long Remediation", remediation1.LongRemediation) + + remediation2, err := db.GetRemediationByFindingId(context.Background(), finding2.ID) + require.NoError(t, err) + assert.Equal(t, "Short Description", remediation2.ShortDescription) + assert.Equal(t, "Long Description", remediation2.LongDescription) + assert.Equal(t, "Short Remediation", remediation2.ShortRemediation) + assert.Equal(t, "Long Remediation", remediation2.LongRemediation) + }, + }, + { + name: "Success: Update existing findings and remediations", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt2", "Test2", "v1.0.0") + require.NoError(t, err) + + env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + require.NoError(t, err) + + // Create initial finding with remediation + finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "ExistingFinding", "Old Display Name") + require.NoError(t, err) + + _, err = db.CreateRemediation(context.Background(), finding.ID, "old short", "old long", "old short rem", "old long rem") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + }, + findings: []database.FindingInput{ + { + Name: "ExistingFinding", + DisplayName: "Updated Display Name", + RelationshipKindName: "Tag_Tier_Zero", + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Updated Short Description", + LongDescription: "Updated Long Description", + ShortRemediation: "Updated Short Remediation", + LongRemediation: "Updated Long Remediation", + }, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { + t.Helper() + + // Verify finding was updated (deleted and recreated) + finding, err := db.GetSchemaRelationshipFindingByName(context.Background(), "ExistingFinding") + require.NoError(t, err) + assert.Equal(t, extensionId, finding.SchemaExtensionId) + assert.Equal(t, "Updated Display Name", finding.DisplayName) + + // Verify remediation was updated + remediation, err := db.GetRemediationByFindingId(context.Background(), finding.ID) + require.NoError(t, err) + assert.Equal(t, "Updated Short Description", remediation.ShortDescription) + assert.Equal(t, "Updated Long Description", remediation.LongDescription) + assert.Equal(t, "Updated Short Remediation", remediation.ShortRemediation) + assert.Equal(t, "Updated Long Remediation", remediation.LongRemediation) + }, + }, + { + name: "Success: Empty environments and findings", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt3", "Test3", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{}, + findings: []database.FindingInput{}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB, extensionId int32) { + t.Helper() + // Nothing to assert - just verify no error occurred + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extensionId := tt.setupData(t, testSuite.BHDatabase) + + err := testSuite.BHDatabase.UpsertGraphSchemaExtension( + context.Background(), + extensionId, + tt.args.environments, + tt.args.findings, + ) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + } else { + require.NoError(t, err) + } + + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase, extensionId) + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/extension_test.go b/cmd/api/src/services/opengraphschema/extension_test.go new file mode 100644 index 0000000000..42cf4714b1 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/extension_test.go @@ -0,0 +1,285 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema_test + +import ( + "context" + "errors" + "testing" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" + schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { + type mocks struct { + mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository + } + type args struct { + environments []v2.Environment + findings []v2.Finding + } + tests := []struct { + name string + setupMocks func(t *testing.T, m *mocks) + args args + expected error + }{ + { + name: "Error: openGraphSchemaRepository.UpsertGraphSchemaExtension error", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + findings: []v2.Finding{ + { + Name: "Finding", + DisplayName: "DisplayName", + RelationshipKind: "Domain", + EnvironmentKind: "Domain", + SourceKind: "Base", + Remediation: v2.Remediation{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User"}, + }, + } + expectedFindings := []database.FindingInput{ + { + Name: "Finding", + DisplayName: "DisplayName", + RelationshipKindName: "Domain", + EnvironmentKindName: "Domain", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + expectedFindings, + ).Return(errors.New("error")) + }, + expected: errors.New("error upserting graph extension: error"), + }, + { + name: "Success: single environment with single finding", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "Computer"}, + }, + }, + findings: []v2.Finding{ + { + Name: "Finding", + DisplayName: "DisplayName", + RelationshipKind: "Domain", + EnvironmentKind: "Domain", + SourceKind: "Base", + Remediation: v2.Remediation{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User", "Computer"}, + }, + } + expectedFindings := []database.FindingInput{ + { + Name: "Finding", + DisplayName: "DisplayName", + RelationshipKindName: "Domain", + EnvironmentKindName: "Domain", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + expectedFindings, + ).Return(nil) + }, + expected: nil, + }, + { + name: "Success: multiple environments with multiple findings", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKind: "AzureAD", + SourceKind: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, + }, + findings: []v2.Finding{ + { + Name: "Finding1", + DisplayName: "DisplayName1", + RelationshipKind: "Domain", + EnvironmentKind: "Domain", + SourceKind: "Base", + Remediation: v2.Remediation{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + { + Name: "Finding2", + DisplayName: "DisplayName2", + RelationshipKind: "Domain", + EnvironmentKind: "Domain", + SourceKind: "Base", + Remediation: v2.Remediation{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKindName: "AzureAD", + SourceKindName: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, + } + expectedFindings := []database.FindingInput{ + { + Name: "Finding1", + DisplayName: "DisplayName1", + RelationshipKindName: "Domain", + EnvironmentKindName: "Domain", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + { + Name: "Finding2", + DisplayName: "DisplayName2", + RelationshipKindName: "Domain", + EnvironmentKindName: "Domain", + SourceKindName: "Base", + RemediationInput: database.RemediationInput{ + ShortDescription: "Short Description", + LongDescription: "Long Description", + ShortRemediation: "Short Remediation", + LongRemediation: "Long Remediation", + }, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + expectedFindings, + ).Return(nil) + }, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + m := &mocks{ + mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), + } + + tt.setupMocks(t, m) + + service := opengraphschema.NewOpenGraphSchemaService(m.mockOpenGraphSchema) + + err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ + Environments: tt.args.environments, + Findings: tt.args.findings, + }) + + if tt.expected != nil { + assert.EqualError(t, err, tt.expected.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go deleted file mode 100644 index b15effc82f..0000000000 --- a/cmd/api/src/services/opengraphschema/opengraphschema_test.go +++ /dev/null @@ -1,481 +0,0 @@ -// Copyright 2026 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 -package opengraphschema_test - -import ( - "context" - "errors" - "testing" - - v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" - schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" -) - -func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { - type mocks struct { - mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository - } - type args struct { - environments []v2.Environment - } - tests := []struct { - name string - setupMocks func(t *testing.T, m *mocks) - args args - expected error - }{ - // UpsertSchemaEnvironmentWithPrincipalKinds - // Validation: Environment Kind - { - name: "Error: environment kind name not found in the database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("environment kind 'Domain' not found")) - }, - expected: errors.New("failed to upload environments with principal kinds: environment kind 'Domain' not found"), - }, - { - name: "Error: failed to retrieve environment kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving environment kind 'Domain': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving environment kind 'Domain': error"), - }, - // Validation: Source Kind - { - name: "Error: failed to retrieve source kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving source kind 'Base': error"), - }, - { - name: "Error: source kind name doesn't exist in database, registration fails", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error registering source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error registering source kind 'Base': error"), - }, - { - name: "Error: source kind name doesn't exist in database, registration succeeds but fetch fails", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving newly registered source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving newly registered source kind 'Base': error"), - }, - // Validation: Principal Kind - { - name: "Error: principal kind not found in database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "InvalidKind"}, - ).Return(errors.New("principal kind 'InvalidKind' not found")) - }, - expected: errors.New("failed to upload environments with principal kinds: principal kind 'InvalidKind' not found"), - }, - { - name: "Error: failed to retrieve principal kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "InvalidKind"}, - ).Return(errors.New("error retrieving principal kind by name 'InvalidKind': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving principal kind by name 'InvalidKind': error"), - }, - // Upsert Schema Environment - { - name: "Error: error retrieving schema environment from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error retrieving schema environment id 0: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error retrieving schema environment id 0: error"), - }, - { - name: "Error: error deleting schema environment", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error deleting schema environment 5: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error deleting schema environment 5: error"), - }, - { - name: "Error: error creating schema environment after deletion", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), - }, - { - name: "Error: error creating new schema environment", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), - }, - // Upsert Principal Kinds - { - name: "Error: error getting principal kinds by environment id", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), - }, - { - name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error deleting principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), - }, - { - name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error creating principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error creating principal kind 3 for environment 10: error"), - }, - { - name: "Success: Create new environment with principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "Computer"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "Computer"}, - ).Return(nil) - }, - expected: nil, - }, - { - name: "Success: Create environment with source kind registration", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "NewSource", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "NewSource", - []string{}, - ).Return(nil) - }, - expected: nil, - }, - { - name: "Success: Process multiple environments", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - { - EnvironmentKind: "AzureAD", - SourceKind: "AzureHound", - PrincipalKinds: []string{"User", "Group"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - // First environment - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(nil) - - // Second environment - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "AzureAD", - "AzureHound", - []string{"User", "Group"}, - ).Return(nil) - }, - expected: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - m := &mocks{ - mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), - } - - tt.setupMocks(t, m) - - service := opengraphschema.NewOpenGraphSchemaService(m.mockOpenGraphSchema) - - err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ - Environments: tt.args.environments, - }) - - if tt.expected != nil { - assert.EqualError(t, err, tt.expected.Error()) - } else { - assert.NoError(t, err) - } - }) - } -} From c2758beac2e2b9dd0c599a39a59c461619c650cd Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 14 Jan 2026 11:37:14 -0600 Subject: [PATCH 23/36] just prepare --- ...psert_schema_extension_integration_test.go | 25 +++++++++++++++---- cmd/api/src/database/upsert_schema_finding.go | 2 +- .../upsert_schema_finding_integration_test.go | 4 +-- .../src/database/upsert_schema_remediation.go | 2 +- .../opengraphschema/extension_test.go | 2 +- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 76a3d122a7..2c75a61e7d 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package database_test import ( @@ -38,7 +53,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { { EnvironmentKindName: "Tag_Tier_Zero", SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Owned", "Tag_Tier_Zero"}, + PrincipalKinds: []string{"Tag_Owned", "Tag_Tier_Zero"}, }, }, findings: []database.FindingInput{ @@ -47,7 +62,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { DisplayName: "Display Name 1", RelationshipKindName: "Tag_Tier_Zero", EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", + SourceKindName: "Base", RemediationInput: database.RemediationInput{ ShortDescription: "Short Description", LongDescription: "Long Description", @@ -60,7 +75,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { DisplayName: "Display Name 2", RelationshipKindName: "Tag_Tier_Zero", EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", + SourceKindName: "Base", RemediationInput: database.RemediationInput{ ShortDescription: "Short Description", LongDescription: "Long Description", @@ -124,7 +139,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { { EnvironmentKindName: "Tag_Tier_Zero", SourceKindName: "Base", - PrincipalKinds: []string{"Tag_Tier_Zero"}, + PrincipalKinds: []string{"Tag_Tier_Zero"}, }, }, findings: []database.FindingInput{ @@ -133,7 +148,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { DisplayName: "Updated Display Name", RelationshipKindName: "Tag_Tier_Zero", EnvironmentKindName: "Tag_Tier_Zero", - SourceKindName: "Base", + SourceKindName: "Base", RemediationInput: database.RemediationInput{ ShortDescription: "Updated Short Description", LongDescription: "Updated Long Description", diff --git a/cmd/api/src/database/upsert_schema_finding.go b/cmd/api/src/database/upsert_schema_finding.go index 97abfc9dea..b41044556c 100644 --- a/cmd/api/src/database/upsert_schema_finding.go +++ b/cmd/api/src/database/upsert_schema_finding.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cmd/api/src/database/upsert_schema_finding_integration_test.go b/cmd/api/src/database/upsert_schema_finding_integration_test.go index c9d4fe5991..ddc8c56a98 100644 --- a/cmd/api/src/database/upsert_schema_finding_integration_test.go +++ b/cmd/api/src/database/upsert_schema_finding_integration_test.go @@ -55,7 +55,7 @@ func TestBloodhoundDB_UpsertFinding(t *testing.T) { return ext.ID }, args: args{ - sourceKindName: "Base", + sourceKindName: "Base", relationshipKindName: "Tag_Tier_Zero", environmentKind: "Tag_Tier_Zero", // Name triggers upsert so this needs to match the finding's name that we want to update @@ -87,7 +87,7 @@ func TestBloodhoundDB_UpsertFinding(t *testing.T) { return ext.ID }, args: args{ - sourceKindName: "Base", + sourceKindName: "Base", relationshipKindName: "Tag_Tier_Zero", environmentKind: "Tag_Tier_Zero", name: "Finding", diff --git a/cmd/api/src/database/upsert_schema_remediation.go b/cmd/api/src/database/upsert_schema_remediation.go index b1843545c8..cff7f98e87 100644 --- a/cmd/api/src/database/upsert_schema_remediation.go +++ b/cmd/api/src/database/upsert_schema_remediation.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cmd/api/src/services/opengraphschema/extension_test.go b/cmd/api/src/services/opengraphschema/extension_test.go index 42cf4714b1..98277daa57 100644 --- a/cmd/api/src/services/opengraphschema/extension_test.go +++ b/cmd/api/src/services/opengraphschema/extension_test.go @@ -272,7 +272,7 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ Environments: tt.args.environments, - Findings: tt.args.findings, + Findings: tt.args.findings, }) if tt.expected != nil { From 86560a5a56e7bb04fd8c1a0091fb36c8499f6714 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 14 Jan 2026 11:40:59 -0600 Subject: [PATCH 24/36] updated error message --- cmd/api/src/database/upsert_schema_extension.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go index e27f20af87..d189a1d40f 100644 --- a/cmd/api/src/database/upsert_schema_extension.go +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -30,7 +30,7 @@ func (s *BloodhoundDB) UpsertGraphSchemaExtension(ctx context.Context, extension return s.Transaction(ctx, func(tx *BloodhoundDB) error { for _, env := range environments { if err := tx.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, extensionID, env.EnvironmentKindName, env.SourceKindName, env.PrincipalKinds); err != nil { - return fmt.Errorf("failed to upload environment with principal kinds: %w", err) + return fmt.Errorf("failed to upsert environment with principal kinds: %w", err) } } From c75a03b909da0ad211166ee2def46e316a5fc69f Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 14 Jan 2026 11:57:56 -0600 Subject: [PATCH 25/36] just prepare --- .../src/database/upsert_schema_extension_integration_test.go | 2 +- cmd/api/src/services/entrypoint.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 2c75a61e7d..517c9bb686 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index ac8361e87f..76a4f33efd 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -38,8 +38,8 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/migrations" "github.com/specterops/bloodhound/cmd/api/src/model/appcfg" "github.com/specterops/bloodhound/cmd/api/src/queries" - "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/dogtags" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/upload" "github.com/specterops/bloodhound/packages/go/cache" schema "github.com/specterops/bloodhound/packages/go/graphschema" From a6b04d830bb1b212ae610eae1bc4695eb0ac69ea Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 14 Jan 2026 12:11:08 -0600 Subject: [PATCH 26/36] integration flag added --- .../src/database/upsert_schema_extension_integration_test.go | 2 ++ cmd/api/src/database/upsert_schema_finding_integration_test.go | 1 + .../src/database/upsert_schema_remediation_integration_test.go | 1 + 3 files changed, 4 insertions(+) diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 517c9bb686..41989b0673 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -13,6 +13,8 @@ // limitations under the License. // // SPDX-License-Identifier: Apache-2.0 + +//go:build integration package database_test import ( diff --git a/cmd/api/src/database/upsert_schema_finding_integration_test.go b/cmd/api/src/database/upsert_schema_finding_integration_test.go index ddc8c56a98..fcb1d0940f 100644 --- a/cmd/api/src/database/upsert_schema_finding_integration_test.go +++ b/cmd/api/src/database/upsert_schema_finding_integration_test.go @@ -13,6 +13,7 @@ // limitations under the License. // // SPDX-License-Identifier: Apache-2.0 + //go:build integration package database_test diff --git a/cmd/api/src/database/upsert_schema_remediation_integration_test.go b/cmd/api/src/database/upsert_schema_remediation_integration_test.go index 4115faa0fa..1798be8774 100644 --- a/cmd/api/src/database/upsert_schema_remediation_integration_test.go +++ b/cmd/api/src/database/upsert_schema_remediation_integration_test.go @@ -13,6 +13,7 @@ // limitations under the License. // // SPDX-License-Identifier: Apache-2.0 + //go:build integration package database_test From 84cbd1753be8f43add5b5d1d6d138713268fb781 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 14 Jan 2026 12:12:43 -0600 Subject: [PATCH 27/36] just prepare --- cmd/api/src/database/upsert_schema_extension_integration_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 41989b0673..7d530481c6 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -15,6 +15,7 @@ // SPDX-License-Identifier: Apache-2.0 //go:build integration + package database_test import ( From b47622662e6134b7cc0cdfc9461d0d5289110c24 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 14 Jan 2026 13:28:18 -0600 Subject: [PATCH 28/36] addressing code rabbit comments --- cmd/api/src/database/graphschema.go | 3 +++ cmd/api/src/database/upsert_schema_finding.go | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 69540d5917..3cdaa7e809 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -811,6 +811,9 @@ func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, VALUES (?, ?, NOW()) RETURNING environment_id, principal_kind, created_at`, environmentId, principalKind).Scan(&envPrincipalKind); result.Error != nil { + if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) { + return model.SchemaEnvironmentPrincipalKind{}, result.Error + } return model.SchemaEnvironmentPrincipalKind{}, CheckError(result) } diff --git a/cmd/api/src/database/upsert_schema_finding.go b/cmd/api/src/database/upsert_schema_finding.go index b41044556c..7c1391c910 100644 --- a/cmd/api/src/database/upsert_schema_finding.go +++ b/cmd/api/src/database/upsert_schema_finding.go @@ -66,7 +66,7 @@ func (s *BloodhoundDB) validateAndTranslateRelationshipKind(ctx context.Context, } // upsertFinding creates or updates a schema relationship finding. -// If an environment with the given kinds exists, it deletes it first before creating the new one. +// If a finding with the given name exists, it deletes it first before creating the new one. func (s *BloodhoundDB) upsertFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { if existing, err := s.GetSchemaRelationshipFindingByName(ctx, name); err != nil && !errors.Is(err, ErrNotFound) { return model.SchemaRelationshipFinding{}, fmt.Errorf("error retrieving schema relationship finding: %w", err) From 51e44b5516fb5667f2e48ffd302660947ce8d021 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 14 Jan 2026 13:29:49 -0600 Subject: [PATCH 29/36] main changes --- cmd/api/src/database/graphschema.go | 4 ++-- cmd/api/src/services/entrypoint.go | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 3cdaa7e809..368a0feee3 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -812,8 +812,8 @@ func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, RETURNING environment_id, principal_kind, created_at`, environmentId, principalKind).Scan(&envPrincipalKind); result.Error != nil { if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) { - return model.SchemaEnvironmentPrincipalKind{}, result.Error - } + return model.SchemaEnvironmentPrincipalKind{}, result.Error + } return model.SchemaEnvironmentPrincipalKind{}, CheckError(result) } diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index fd2eebb60e..cb87ebfd2b 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -119,13 +119,13 @@ func Entrypoint(ctx context.Context, cfg config.Configuration, connections boots startDelay := 0 * time.Second var ( - cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) - pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) - graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) - authorizer = auth.NewAuthorizer(connections.RDMS) - datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) - routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) - authenticator = api.NewAuthenticator(cfg, connections.RDMS, api.NewAuthExtensions(cfg, connections.RDMS)) + cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) + pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) + graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) + authorizer = auth.NewAuthorizer(connections.RDMS) + datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) + routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) + authenticator = api.NewAuthenticator(cfg, connections.RDMS, api.NewAuthExtensions(cfg, connections.RDMS)) openGraphSchemaService = opengraphschema.NewOpenGraphSchemaService(connections.RDMS) ) From 1cab880bcc2240eceef5c76ca591fc42fa03af7d Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 15 Jan 2026 15:50:05 -0600 Subject: [PATCH 30/36] peer review changes --- cmd/api/src/database/graphschema.go | 28 +++--- .../database/graphschema_integration_test.go | 18 ++-- .../database/migration/migrations/v8.6.0.sql | 22 +++++ cmd/api/src/database/mocks/db.go | 88 +++++++++---------- .../src/database/upsert_schema_environment.go | 24 ++--- ...ert_schema_environment_integration_test.go | 8 +- ...psert_schema_extension_integration_test.go | 12 +-- 7 files changed, 113 insertions(+), 87 deletions(-) create mode 100644 cmd/api/src/database/migration/migrations/v8.6.0.sql diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index df2ed45b18..b6e2348d2a 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -66,9 +66,10 @@ type OpenGraphSchema interface { GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) UpdateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) DeleteRemediation(ctx context.Context, findingId int32) error - CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) - GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) - DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error + + CreatePrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) + GetPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) + DeletePrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error } const ( @@ -240,10 +241,10 @@ func (s *BloodhoundDB) GetGraphSchemaNodeKinds(ctx context.Context, filters mode if filterAndPagination, err := parseFiltersAndPagination(filters, sort, skip, limit); err != nil { return schemaNodeKinds, 0, err } else { - sqlStr := fmt.Sprintf(`SELECT nk.id, k.name, nk.schema_extension_id, nk.display_name, nk.description, + sqlStr := fmt.Sprintf(`SELECT nk.id, k.name, nk.schema_extension_id, nk.display_name, nk.description, nk.is_display_kind, nk.icon, nk.icon_color, nk.created_at, nk.updated_at, nk.deleted_at FROM %s nk - JOIN %s k ON nk.kind_id = k.id + JOIN %s k ON nk.kind_id = k.id %s %s %s`, model.GraphSchemaNodeKind{}.TableName(), kindTable, filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit) if result := s.db.WithContext(ctx).Raw(sqlStr, filterAndPagination.Filter.params...).Scan(&schemaNodeKinds); result.Error != nil { @@ -289,7 +290,7 @@ func (s *BloodhoundDB) UpdateGraphSchemaNodeKind(ctx context.Context, schemaNode RETURNING id, kind_id, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at ) SELECT updated_row.id, %s.name, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at - FROM updated_row + FROM updated_row JOIN %s ON %s.id = updated_row.kind_id`, schemaNodeKind.TableName(), kindTable, kindTable, kindTable), schemaNodeKind.SchemaExtensionId, schemaNodeKind.DisplayName, schemaNodeKind.Description, schemaNodeKind.IsDisplayKind, schemaNodeKind.Icon, @@ -437,7 +438,7 @@ func (s *BloodhoundDB) CreateGraphSchemaEdgeKind(ctx context.Context, name strin RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at ) SELECT ie.id, ie.schema_extension_id, dk.name, ie.description, ie.is_traversable, ie.created_at, ie.updated_at, ie.deleted_at - FROM inserted_edges ie + FROM inserted_edges ie JOIN dawgs_kind dk ON ie.kind_id = dk.id;`, name, schemaExtensionId, description, isTraversable).Scan(&schemaEdgeKind); result.Error != nil { if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) { return schemaEdgeKind, fmt.Errorf("%w: %v", ErrDuplicateSchemaEdgeKindName, result.Error) @@ -458,10 +459,10 @@ func (s *BloodhoundDB) GetGraphSchemaEdgeKinds(ctx context.Context, edgeKindFilt if filterAndPagination, err := parseFiltersAndPagination(edgeKindFilters, sort, skip, limit); err != nil { return schemaEdgeKinds, 0, err } else { - sqlStr := fmt.Sprintf(`SELECT ek.id, k.name, ek.schema_extension_id, ek.description, ek.is_traversable, + sqlStr := fmt.Sprintf(`SELECT ek.id, k.name, ek.schema_extension_id, ek.description, ek.is_traversable, ek.created_at, ek.updated_at, ek.deleted_at FROM %s ek - JOIN %s k ON ek.kind_id = k.id + JOIN %s k ON ek.kind_id = k.id %s %s %s`, model.GraphSchemaEdgeKind{}.TableName(), kindTable, filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit) @@ -543,7 +544,7 @@ func (s *BloodhoundDB) UpdateGraphSchemaEdgeKind(ctx context.Context, schemaEdge SET schema_extension_id = ?, description = ?, is_traversable = ?, updated_at = NOW() WHERE id = ? RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at - ) + ) SELECT updated_row.id, %s.name, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at FROM updated_row JOIN %s ON %s.id = updated_row.kind_id`, @@ -784,7 +785,7 @@ func (s *BloodhoundDB) DeleteRemediation(ctx context.Context, findingId int32) e return nil } -func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { +func (s *BloodhoundDB) CreatePrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { var envPrincipalKind model.SchemaEnvironmentPrincipalKind if result := s.db.WithContext(ctx).Raw(` @@ -798,7 +799,8 @@ func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, return envPrincipalKind, nil } -func (s *BloodhoundDB) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { +// GetPrincipalKindsByEnvironmentID - retrieves a schema environments principal kind by environment id. +func (s *BloodhoundDB) GetPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { var envPrincipalKinds model.SchemaEnvironmentPrincipalKinds if result := s.db.WithContext(ctx).Raw(` @@ -812,7 +814,7 @@ func (s *BloodhoundDB) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx con return envPrincipalKinds, nil } -func (s *BloodhoundDB) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error { +func (s *BloodhoundDB) DeletePrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error { if result := s.db.WithContext(ctx).Exec(` DELETE FROM schema_environments_principal_kinds WHERE environment_id = ? AND principal_kind = ?`, diff --git a/cmd/api/src/database/graphschema_integration_test.go b/cmd/api/src/database/graphschema_integration_test.go index 28db314d6a..412a5d2d1d 100644 --- a/cmd/api/src/database/graphschema_integration_test.go +++ b/cmd/api/src/database/graphschema_integration_test.go @@ -2224,7 +2224,7 @@ func TestCreateSchemaEnvironmentPrincipalKind(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - result, err := testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) + result, err := testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { @@ -2262,10 +2262,10 @@ func TestGetSchemaEnvironmentPrincipalKindsByEnvironmentId(t *testing.T) { _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 1) + _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 1) require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 2) + _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 2) require.NoError(t, err) return testSuite @@ -2295,7 +2295,7 @@ func TestGetSchemaEnvironmentPrincipalKindsByEnvironmentId(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - result, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) + result, err := testSuite.BHDatabase.GetPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { @@ -2332,7 +2332,7 @@ func TestDeleteSchemaEnvironmentPrincipalKind(t *testing.T) { _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 1) + _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 1) require.NoError(t, err) return testSuite @@ -2364,12 +2364,12 @@ func TestDeleteSchemaEnvironmentPrincipalKind(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - err := testSuite.BHDatabase.DeleteSchemaEnvironmentPrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) + err := testSuite.BHDatabase.DeletePrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { assert.NoError(t, err) - result, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) + result, err := testSuite.BHDatabase.GetPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) assert.NoError(t, err) assert.Len(t, result, 0) } @@ -2403,7 +2403,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { _, err = testSuite.BHDatabase.CreateRemediation(testSuite.Context, relationshipFinding.ID, "Short desc", "Long desc", "Short remediation", "Long remediation") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, environment.ID, nodeKind.ID) + _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, environment.ID, nodeKind.ID) require.NoError(t, err) err = testSuite.BHDatabase.DeleteGraphSchemaExtension(testSuite.Context, extension.ID) @@ -2427,7 +2427,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { _, err = testSuite.BHDatabase.GetRemediationByFindingId(testSuite.Context, relationshipFinding.ID) assert.ErrorIs(t, err, database.ErrNotFound) - principalKinds, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, environment.ID) + principalKinds, err := testSuite.BHDatabase.GetPrincipalKindsByEnvironmentId(testSuite.Context, environment.ID) assert.NoError(t, err) assert.Len(t, principalKinds, 0) } diff --git a/cmd/api/src/database/migration/migrations/v8.6.0.sql b/cmd/api/src/database/migration/migrations/v8.6.0.sql new file mode 100644 index 0000000000..fbb20a81b0 --- /dev/null +++ b/cmd/api/src/database/migration/migrations/v8.6.0.sql @@ -0,0 +1,22 @@ +-- Copyright 2026 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- Drop the column with the incorrect foreign key +ALTER TABLE IF EXISTS schema_environments + DROP COLUMN IF EXISTS source_kind_id; + +-- Add the column back with the correct foreign key reference +ALTER TABLE IF EXISTS schema_environments + ADD COLUMN source_kind_id INTEGER NOT NULL REFERENCES source_kinds(id); diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index b911c1ad42..df0c1b70fb 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -461,6 +461,21 @@ func (mr *MockDatabaseMockRecorder) CreateOIDCProvider(ctx, name, issuer, client return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOIDCProvider", reflect.TypeOf((*MockDatabase)(nil).CreateOIDCProvider), ctx, name, issuer, clientID, config) } +// CreatePrincipalKind mocks base method. +func (m *MockDatabase) CreatePrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreatePrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreatePrincipalKind indicates an expected call of CreatePrincipalKind. +func (mr *MockDatabaseMockRecorder) CreatePrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreatePrincipalKind), ctx, environmentId, principalKind) +} + // CreateRemediation mocks base method. func (m *MockDatabase) CreateRemediation(ctx context.Context, findingId int32, shortDescription, longDescription, shortRemediation, longRemediation string) (model.Remediation, error) { m.ctrl.T.Helper() @@ -585,21 +600,6 @@ func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, en return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) } -// CreateSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockDatabase) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. -func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - // CreateSchemaRelationshipFinding mocks base method. func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() @@ -926,6 +926,20 @@ func (mr *MockDatabaseMockRecorder) DeleteIngestTask(ctx, ingestTask any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteIngestTask", reflect.TypeOf((*MockDatabase)(nil).DeleteIngestTask), ctx, ingestTask) } +// DeletePrincipalKind mocks base method. +func (m *MockDatabase) DeletePrincipalKind(ctx context.Context, environmentId, principalKind int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePrincipalKind indicates an expected call of DeletePrincipalKind. +func (mr *MockDatabaseMockRecorder) DeletePrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeletePrincipalKind), ctx, environmentId, principalKind) +} + // DeleteRemediation mocks base method. func (m *MockDatabase) DeleteRemediation(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -1001,20 +1015,6 @@ func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) } -// DeleteSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockDatabase) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. -func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - // DeleteSchemaRelationshipFinding mocks base method. func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -2007,6 +2007,21 @@ func (mr *MockDatabaseMockRecorder) GetPermission(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermission", reflect.TypeOf((*MockDatabase)(nil).GetPermission), ctx, id) } +// GetPrincipalKindsByEnvironmentId mocks base method. +func (m *MockDatabase) GetPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPrincipalKindsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPrincipalKindsByEnvironmentId indicates an expected call of GetPrincipalKindsByEnvironmentId. +func (mr *MockDatabaseMockRecorder) GetPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetPrincipalKindsByEnvironmentId), ctx, environmentId) +} + // GetPublicSavedQueries mocks base method. func (m *MockDatabase) GetPublicSavedQueries(ctx context.Context) (model.SavedQueries, error) { m.ctrl.T.Helper() @@ -2217,21 +2232,6 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environment return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) } -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) -} - // GetSchemaEnvironments mocks base method. func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index f65022bfb3..df1401db18 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -49,13 +49,13 @@ func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Con environment.EnvironmentKindId = int32(envKind.ID) environment.SourceKindId = sourceKindID - envID, err := s.upsertSchemaEnvironment(ctx, environment) + envID, err := s.replaceSchemaEnvironment(ctx, environment) if err != nil { - return fmt.Errorf("error upserting schema environment: %w", err) + return fmt.Errorf("error replacing or creating schema environment: %w", err) } - if err := s.upsertPrincipalKinds(ctx, envID, translatedPrincipalKinds); err != nil { - return fmt.Errorf("error upserting principal kinds: %w", err) + if err := s.replacePrincipalKinds(ctx, envID, translatedPrincipalKinds); err != nil { + return fmt.Errorf("error replacing principal kinds: %w", err) } return nil @@ -114,9 +114,11 @@ func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, p return principalKinds, nil } -// upsertSchemaEnvironment creates or updates a schema environment. +// replaceSchemaEnvironment creates or updates a schema environment. // If an environment with the given kinds exists, it deletes it first before creating the new one. -func (s *BloodhoundDB) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { +// The unique constraint on (environment_kind_id, source_kind_id) of the Schema Environment table ensures no +// duplicate pairs exist, enabling this upsert logic. +func (s *BloodhoundDB) replaceSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { if existing, err := s.GetSchemaEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment: %w", err) } else if !errors.Is(err, ErrNotFound) { @@ -134,14 +136,14 @@ func (s *BloodhoundDB) upsertSchemaEnvironment(ctx context.Context, graphSchema } } -// upsertPrincipalKinds deletes all existing principal kinds for an environment and creates new ones. -func (s *BloodhoundDB) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { - if existingKinds, err := s.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, ErrNotFound) { +// replacePrincipalKinds deletes all existing principal kinds for an environment and creates new ones. +func (s *BloodhoundDB) replacePrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { + if existingKinds, err := s.GetPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, ErrNotFound) { return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) } else if !errors.Is(err, ErrNotFound) { // Delete all existing principal kinds for _, kind := range existingKinds { - if err := s.DeleteSchemaEnvironmentPrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { + if err := s.DeletePrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { return fmt.Errorf("error deleting principal kind %d for environment %d: %w", kind.PrincipalKind, kind.EnvironmentId, err) } } @@ -149,7 +151,7 @@ func (s *BloodhoundDB) upsertPrincipalKinds(ctx context.Context, environmentID i // Create the new principal kinds for _, kind := range principalKinds { - if _, err := s.CreateSchemaEnvironmentPrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { + if _, err := s.CreatePrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { return fmt.Errorf("error creating principal kind %d for environment %d: %w", kind.PrincipalKind, environmentID, err) } } diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index bd32cce3d2..b811b10adb 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -64,7 +64,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(environments)) - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) @@ -118,7 +118,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(environments), "Should only have one environment (old one deleted)") - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds)) @@ -155,7 +155,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.Equal(t, 1, len(environments)) assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds)) }, @@ -265,7 +265,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.Equal(t, 2, len(environments), "Should have two different environments") for _, env := range environments { - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), env.ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), env.ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") } diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 9a71677d59..2f998cd054 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -65,7 +65,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(environments)) - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) @@ -116,12 +116,12 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.Equal(t, 2, len(environments), "Should have two environments") // Verify first environment - env1PrincipalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + env1PrincipalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(env1PrincipalKinds)) // Verify second environment - env2PrincipalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[1].ID) + env2PrincipalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[1].ID) assert.NoError(t, err) assert.Equal(t, 1, len(env2PrincipalKinds)) }, @@ -163,7 +163,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(environments), "Should only have one environment (old one replaced)") - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds)) @@ -204,7 +204,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.Equal(t, 1, len(environments)) assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds)) }, @@ -245,7 +245,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.Equal(t, 2, len(environments), "Should have two environments") for _, env := range environments { - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), env.ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), env.ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") } From 471fc4e7bc0422f3990759019cb3550b732ffe3e Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 15 Jan 2026 15:56:38 -0600 Subject: [PATCH 31/36] peer review changes - more naming --- cmd/api/src/database/graphschema.go | 30 ++-- .../database/graphschema_integration_test.go | 72 ++++----- cmd/api/src/database/mocks/db.go | 148 +++++++++--------- .../src/database/upsert_schema_environment.go | 6 +- ...ert_schema_environment_integration_test.go | 14 +- ...psert_schema_extension_integration_test.go | 20 +-- cmd/api/src/services/entrypoint.go | 14 +- 7 files changed, 152 insertions(+), 152 deletions(-) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index b6e2348d2a..3588f3f79b 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -52,11 +52,11 @@ type OpenGraphSchema interface { GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) - CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) - DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error + CreateEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) + GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) + GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) + GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) + DeleteEnvironment(ctx context.Context, environmentId int32) error CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) @@ -572,8 +572,8 @@ func (s *BloodhoundDB) DeleteGraphSchemaEdgeKind(ctx context.Context, schemaEdge return nil } -// CreateSchemaEnvironment - creates a new schema_environment. -func (s *BloodhoundDB) CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) { +// CreateEnvironment - creates a new schema_environment. +func (s *BloodhoundDB) CreateEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) { var schemaEnvironment model.SchemaEnvironment if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` @@ -590,14 +590,14 @@ func (s *BloodhoundDB) CreateSchemaEnvironment(ctx context.Context, extensionId return schemaEnvironment, nil } -// GetSchemaEnvironments - retrieves list of schema environments. -func (s *BloodhoundDB) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { +// GetEnvironments - retrieves list of schema environments. +func (s *BloodhoundDB) GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { var result []model.SchemaEnvironment return result, CheckError(s.db.WithContext(ctx).Find(&result)) } -// GetSchemaEnvironmentByKinds - retrieves an environment by its environment kind and source kind. -func (s *BloodhoundDB) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { +// GetEnvironmentByKinds - retrieves an environment by its environment kind and source kind. +func (s *BloodhoundDB) GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { var env model.SchemaEnvironment if result := s.db.WithContext(ctx).Raw( @@ -612,8 +612,8 @@ func (s *BloodhoundDB) GetSchemaEnvironmentByKinds(ctx context.Context, environm return env, nil } -// GetSchemaEnvironmentById - retrieves a schema environment by id. -func (s *BloodhoundDB) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { +// GetEnvironmentById - retrieves a schema environment by id. +func (s *BloodhoundDB) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { var schemaEnvironment model.SchemaEnvironment if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` @@ -629,8 +629,8 @@ func (s *BloodhoundDB) GetSchemaEnvironmentById(ctx context.Context, environment return schemaEnvironment, nil } -// DeleteSchemaEnvironment - deletes a schema environment by id. -func (s *BloodhoundDB) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error { +// DeleteEnvironment - deletes a schema environment by id. +func (s *BloodhoundDB) DeleteEnvironment(ctx context.Context, environmentId int32) error { var schemaEnvironment model.SchemaEnvironment if result := s.db.WithContext(ctx).Exec(fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, schemaEnvironment.TableName()), environmentId); result.Error != nil { diff --git a/cmd/api/src/database/graphschema_integration_test.go b/cmd/api/src/database/graphschema_integration_test.go index 412a5d2d1d..b4431b7d8f 100644 --- a/cmd/api/src/database/graphschema_integration_test.go +++ b/cmd/api/src/database/graphschema_integration_test.go @@ -1213,7 +1213,7 @@ func TestCreateSchemaEnvironment(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - got, err := testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, testCase.args.extensionId, testCase.args.environmentKindId, testCase.args.sourceKindId) + got, err := testSuite.BHDatabase.CreateEnvironment(testSuite.Context, testCase.args.extensionId, testCase.args.environmentKindId, testCase.args.sourceKindId) if testCase.want.err != nil { assert.EqualError(t, err, testCase.want.err.Error()) } else { @@ -1261,7 +1261,7 @@ func TestGetSchemaEnvironments(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "Extension1", "DisplayName", "v1.0.0") require.NoError(t, err) // Create Environments - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) require.NoError(t, err) return testSuite @@ -1289,9 +1289,9 @@ func TestGetSchemaEnvironments(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "Extension1", "DisplayName", "v1.0.0") require.NoError(t, err) // Create Environments - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(2), int32(2)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(2), int32(2)) require.NoError(t, err) return testSuite @@ -1328,7 +1328,7 @@ func TestGetSchemaEnvironments(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - got, err := testSuite.BHDatabase.GetSchemaEnvironments(testSuite.Context) + got, err := testSuite.BHDatabase.GetEnvironments(testSuite.Context) if testCase.want.err != nil { assert.EqualError(t, err, testCase.want.err.Error()) } else { @@ -1372,7 +1372,7 @@ func TestGetSchemaEnvironmentById(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "Extension1", "DisplayName", "v1.0.0") require.NoError(t, err) // Create Environment - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) require.NoError(t, err) return testSuite @@ -1409,7 +1409,7 @@ func TestGetSchemaEnvironmentById(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - got, err := testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, testCase.args.environmentId) + got, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, testCase.args.environmentId) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { @@ -1451,7 +1451,7 @@ func TestDeleteSchemaEnvironment(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "Extension1", "DisplayName", "v1.0.0") require.NoError(t, err) // Create Environment - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) require.NoError(t, err) return testSuite @@ -1481,14 +1481,14 @@ func TestDeleteSchemaEnvironment(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - err := testSuite.BHDatabase.DeleteSchemaEnvironment(testSuite.Context, testCase.args.environmentId) + err := testSuite.BHDatabase.DeleteEnvironment(testSuite.Context, testCase.args.environmentId) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { assert.NoError(t, err) // Verify deletion by trying to get the environment - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, testCase.args.environmentId) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, testCase.args.environmentId) assert.ErrorIs(t, err, database.ErrNotFound) } }) @@ -1506,21 +1506,21 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { // Create two environments in a single transaction err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - _, err := tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err := tx.CreateEnvironment(testSuite.Context, 1, 1, 1) if err != nil { return err } - _, err = tx.CreateSchemaEnvironment(testSuite.Context, 1, 2, 2) + _, err = tx.CreateEnvironment(testSuite.Context, 1, 2, 2) return err }) require.NoError(t, err) // Verify both environments were created - env1, err := testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 1) + env1, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 1) require.NoError(t, err) assert.Equal(t, int32(1), env1.EnvironmentKindId) - env2, err := testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 2) + env2, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 2) require.NoError(t, err) assert.Equal(t, int32(2), env2.EnvironmentKindId) }) @@ -1536,7 +1536,7 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { // Create one environment, then fail - should rollback expectedErr := fmt.Errorf("intentional error to trigger rollback") err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - _, err := tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err := tx.CreateEnvironment(testSuite.Context, 1, 1, 1) if err != nil { return err } @@ -1545,7 +1545,7 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { require.ErrorIs(t, err, expectedErr) // Verify the environment was NOT created (rolled back) - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 1) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 1) assert.ErrorIs(t, err, database.ErrNotFound) }) @@ -1559,18 +1559,18 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { // Create one environment, then try to create a duplicate - should rollback both err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - _, err := tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err := tx.CreateEnvironment(testSuite.Context, 1, 1, 1) if err != nil { return err } // Try to create duplicate (same environment_kind_id + source_kind_id) - will fail - _, err = tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = tx.CreateEnvironment(testSuite.Context, 1, 1, 1) return err }) require.Error(t, err) // Verify the first environment was NOT created (rolled back due to second failure) - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 1) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 1) assert.ErrorIs(t, err, database.ErrNotFound) }) @@ -1584,16 +1584,16 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { // Create and delete in same transaction err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - env, err := tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + env, err := tx.CreateEnvironment(testSuite.Context, 1, 1, 1) if err != nil { return err } - return tx.DeleteSchemaEnvironment(testSuite.Context, env.ID) + return tx.DeleteEnvironment(testSuite.Context, env.ID) }) require.NoError(t, err) // Verify the environment does not exist - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 1) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 1) assert.ErrorIs(t, err, database.ErrNotFound) }) } @@ -1625,7 +1625,7 @@ func TestCreateSchemaRelationshipFinding(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "FindingExtension", "Finding Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) return testSuite @@ -1657,7 +1657,7 @@ func TestCreateSchemaRelationshipFinding(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "FindingExtension2", "Finding Extension 2", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "DuplicateName", "Display Name") @@ -1724,7 +1724,7 @@ func TestGetSchemaRelationshipFindingById(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetFindingExt", "Get Finding Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "GetByIdFinding", "Get By ID Finding") @@ -1798,7 +1798,7 @@ func TestDeleteSchemaRelationshipFinding(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "DeleteFindingExt", "Delete Finding Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "DeleteFinding", "Delete Finding") @@ -1870,7 +1870,7 @@ func TestCreateRemediation(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "RemediationExt", "Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "RemediationFinding", "Remediation Finding") @@ -1942,7 +1942,7 @@ func TestGetRemediationByFindingId(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetRemediationExt", "Get Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "GetRemediationFinding", "Get Remediation Finding") @@ -2022,7 +2022,7 @@ func TestUpdateRemediation(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "UpdateRemediationExt", "Update Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "UpdateRemediationFinding", "Update Remediation Finding") @@ -2059,7 +2059,7 @@ func TestUpdateRemediation(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "UpsertRemediationExt", "Upsert Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "UpsertRemediationFinding", "Upsert Remediation Finding") @@ -2130,7 +2130,7 @@ func TestDeleteRemediation(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "DeleteRemediationExt", "Delete Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "DeleteRemediationFinding", "Delete Remediation Finding") @@ -2202,7 +2202,7 @@ func TestCreateSchemaEnvironmentPrincipalKind(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "EnvPrincipalKindExt", "Env Principal Kind Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) return testSuite @@ -2259,7 +2259,7 @@ func TestGetSchemaEnvironmentPrincipalKindsByEnvironmentId(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetEnvPrincipalKindExt", "Get Env Principal Kind Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 1) @@ -2329,7 +2329,7 @@ func TestDeleteSchemaEnvironmentPrincipalKind(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "DeleteEnvPrincipalKindExt", "Delete Env Principal Kind Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 1) @@ -2394,7 +2394,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { edgeKind, err := testSuite.BHDatabase.CreateGraphSchemaEdgeKind(testSuite.Context, "CascadeTestEdgeKind", extension.ID, "Test description", true) require.NoError(t, err) - environment, err := testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, extension.ID, nodeKind.ID, nodeKind.ID) + environment, err := testSuite.BHDatabase.CreateEnvironment(testSuite.Context, extension.ID, nodeKind.ID, nodeKind.ID) require.NoError(t, err) relationshipFinding, err := testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, extension.ID, edgeKind.ID, environment.ID, "CascadeTestFinding", "Cascade Test Finding") @@ -2418,7 +2418,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { _, err = testSuite.BHDatabase.GetGraphSchemaEdgeKindById(testSuite.Context, edgeKind.ID) assert.ErrorIs(t, err, database.ErrNotFound) - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, environment.ID) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, environment.ID) assert.ErrorIs(t, err, database.ErrNotFound) _, err = testSuite.BHDatabase.GetSchemaRelationshipFindingById(testSuite.Context, relationshipFinding.ID) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index df0c1b70fb..c23d6e150f 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -341,6 +341,21 @@ func (mr *MockDatabaseMockRecorder) CreateCustomNodeKinds(ctx, customNodeKind an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCustomNodeKinds", reflect.TypeOf((*MockDatabase)(nil).CreateCustomNodeKinds), ctx, customNodeKind) } +// CreateEnvironment mocks base method. +func (m *MockDatabase) CreateEnvironment(ctx context.Context, extensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateEnvironment", ctx, extensionId, environmentKindId, sourceKindId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateEnvironment indicates an expected call of CreateEnvironment. +func (mr *MockDatabaseMockRecorder) CreateEnvironment(ctx, extensionId, environmentKindId, sourceKindId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateEnvironment), ctx, extensionId, environmentKindId, sourceKindId) +} + // CreateGraphSchemaEdgeKind mocks base method. func (m *MockDatabase) CreateGraphSchemaEdgeKind(ctx context.Context, name string, schemaExtensionId int32, description string, isTraversable bool) (model.GraphSchemaEdgeKind, error) { m.ctrl.T.Helper() @@ -585,21 +600,6 @@ func (mr *MockDatabaseMockRecorder) CreateSavedQueryPermissionsToUsers(ctx, quer return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSavedQueryPermissionsToUsers", reflect.TypeOf((*MockDatabase)(nil).CreateSavedQueryPermissionsToUsers), varargs...) } -// CreateSchemaEnvironment mocks base method. -func (m *MockDatabase) CreateSchemaEnvironment(ctx context.Context, extensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironment", ctx, extensionId, environmentKindId, sourceKindId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironment indicates an expected call of CreateSchemaEnvironment. -func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, environmentKindId, sourceKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) -} - // CreateSchemaRelationshipFinding mocks base method. func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() @@ -842,6 +842,20 @@ func (mr *MockDatabaseMockRecorder) DeleteCustomNodeKind(ctx, kindName any) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCustomNodeKind", reflect.TypeOf((*MockDatabase)(nil).DeleteCustomNodeKind), ctx, kindName) } +// DeleteEnvironment mocks base method. +func (m *MockDatabase) DeleteEnvironment(ctx context.Context, environmentId int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteEnvironment", ctx, environmentId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteEnvironment indicates an expected call of DeleteEnvironment. +func (mr *MockDatabaseMockRecorder) DeleteEnvironment(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteEnvironment), ctx, environmentId) +} + // DeleteEnvironmentTargetedAccessControlForUser mocks base method. func (m *MockDatabase) DeleteEnvironmentTargetedAccessControlForUser(ctx context.Context, user model.User) error { m.ctrl.T.Helper() @@ -1001,20 +1015,6 @@ func (mr *MockDatabaseMockRecorder) DeleteSavedQueryPermissionsForUsers(ctx, que return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSavedQueryPermissionsForUsers", reflect.TypeOf((*MockDatabase)(nil).DeleteSavedQueryPermissionsForUsers), varargs...) } -// DeleteSchemaEnvironment mocks base method. -func (m *MockDatabase) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironment", ctx, environmentId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironment indicates an expected call of DeleteSchemaEnvironment. -func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) -} - // DeleteSchemaRelationshipFinding mocks base method. func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -1702,6 +1702,36 @@ func (mr *MockDatabaseMockRecorder) GetDatapipeStatus(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDatapipeStatus", reflect.TypeOf((*MockDatabase)(nil).GetDatapipeStatus), ctx) } +// GetEnvironmentById mocks base method. +func (m *MockDatabase) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnvironmentById", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnvironmentById indicates an expected call of GetEnvironmentById. +func (mr *MockDatabaseMockRecorder) GetEnvironmentById(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentById), ctx, environmentId) +} + +// GetEnvironmentByKinds mocks base method. +func (m *MockDatabase) GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnvironmentByKinds", ctx, environmentKindId, sourceKindId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnvironmentByKinds indicates an expected call of GetEnvironmentByKinds. +func (mr *MockDatabaseMockRecorder) GetEnvironmentByKinds(ctx, environmentKindId, sourceKindId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentByKinds), ctx, environmentKindId, sourceKindId) +} + // GetEnvironmentTargetedAccessControlForUser mocks base method. func (m *MockDatabase) GetEnvironmentTargetedAccessControlForUser(ctx context.Context, user model.User) ([]model.EnvironmentTargetedAccessControl, error) { m.ctrl.T.Helper() @@ -1717,6 +1747,21 @@ func (mr *MockDatabaseMockRecorder) GetEnvironmentTargetedAccessControlForUser(c return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentTargetedAccessControlForUser", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentTargetedAccessControlForUser), ctx, user) } +// GetEnvironments mocks base method. +func (m *MockDatabase) GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnvironments", ctx) + ret0, _ := ret[0].([]model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnvironments indicates an expected call of GetEnvironments. +func (mr *MockDatabaseMockRecorder) GetEnvironments(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironments", reflect.TypeOf((*MockDatabase)(nil).GetEnvironments), ctx) +} + // GetFlag mocks base method. func (m *MockDatabase) GetFlag(ctx context.Context, id int32) (appcfg.FeatureFlag, error) { m.ctrl.T.Helper() @@ -2202,51 +2247,6 @@ func (mr *MockDatabaseMockRecorder) GetSavedQueryPermissions(ctx, queryID any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSavedQueryPermissions", reflect.TypeOf((*MockDatabase)(nil).GetSavedQueryPermissions), ctx, queryID) } -// GetSchemaEnvironmentById mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentById", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentById indicates an expected call of GetSchemaEnvironmentById. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) -} - -// GetSchemaEnvironmentByKinds mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentByKinds", ctx, environmentKindId, sourceKindId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentByKinds indicates an expected call of GetSchemaEnvironmentByKinds. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environmentKindId, sourceKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) -} - -// GetSchemaEnvironments mocks base method. -func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironments", ctx) - ret0, _ := ret[0].([]model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironments indicates an expected call of GetSchemaEnvironments. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironments(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironments", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironments), ctx) -} - // GetSchemaRelationshipFindingById mocks base method. func (m *MockDatabase) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index df1401db18..180ec7702a 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -119,17 +119,17 @@ func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, p // The unique constraint on (environment_kind_id, source_kind_id) of the Schema Environment table ensures no // duplicate pairs exist, enabling this upsert logic. func (s *BloodhoundDB) replaceSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { - if existing, err := s.GetSchemaEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { + if existing, err := s.GetEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment: %w", err) } else if !errors.Is(err, ErrNotFound) { // Environment exists - delete it first - if err := s.DeleteSchemaEnvironment(ctx, existing.ID); err != nil { + if err := s.DeleteEnvironment(ctx, existing.ID); err != nil { return 0, fmt.Errorf("error deleting schema environment %d: %w", existing.ID, err) } } // Create Environment - if created, err := s.CreateSchemaEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { + if created, err := s.CreateEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { return 0, fmt.Errorf("error creating schema environment: %w", err) } else { return created.ID, nil diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index b811b10adb..41409b4b99 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -60,7 +60,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments)) @@ -114,7 +114,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments), "Should only have one environment (old one deleted)") @@ -150,7 +150,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.NoError(t, err) assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments)) assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) @@ -179,7 +179,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -203,7 +203,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -227,7 +227,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -260,7 +260,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert: func(t *testing.T, db *database.BloodhoundDB) { t.Helper() - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 2, len(environments), "Should have two different environments") diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 2f998cd054..e4fcae0287 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -61,7 +61,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments)) @@ -111,7 +111,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert: func(t *testing.T, db *database.BloodhoundDB) { t.Helper() - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 2, len(environments), "Should have two environments") @@ -159,7 +159,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments), "Should only have one environment (old one replaced)") @@ -199,7 +199,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.NoError(t, err) assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments)) assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) @@ -240,7 +240,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.NoError(t, err) assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 2, len(environments), "Should have two environments") @@ -274,7 +274,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -302,7 +302,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -335,7 +335,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify complete transaction rollback - no environments created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environments should exist after rollback") }, @@ -368,7 +368,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify complete transaction rollback - no environments created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environments should exist after rollback") }, @@ -396,7 +396,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index fd2eebb60e..cb87ebfd2b 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -119,13 +119,13 @@ func Entrypoint(ctx context.Context, cfg config.Configuration, connections boots startDelay := 0 * time.Second var ( - cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) - pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) - graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) - authorizer = auth.NewAuthorizer(connections.RDMS) - datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) - routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) - authenticator = api.NewAuthenticator(cfg, connections.RDMS, api.NewAuthExtensions(cfg, connections.RDMS)) + cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) + pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) + graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) + authorizer = auth.NewAuthorizer(connections.RDMS) + datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) + routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) + authenticator = api.NewAuthenticator(cfg, connections.RDMS, api.NewAuthExtensions(cfg, connections.RDMS)) openGraphSchemaService = opengraphschema.NewOpenGraphSchemaService(connections.RDMS) ) From 1de71fc7ff6e2554ca9e03f0a0f8f2f00c873b9b Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 15 Jan 2026 16:08:07 -0600 Subject: [PATCH 32/36] migration update --- cmd/api/src/database/migration/migrations/v8.6.0.sql | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/api/src/database/migration/migrations/v8.6.0.sql b/cmd/api/src/database/migration/migrations/v8.6.0.sql index fbb20a81b0..219f6bb54d 100644 --- a/cmd/api/src/database/migration/migrations/v8.6.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.6.0.sql @@ -13,10 +13,10 @@ -- limitations under the License. -- -- SPDX-License-Identifier: Apache-2.0 --- Drop the column with the incorrect foreign key + ALTER TABLE IF EXISTS schema_environments - DROP COLUMN IF EXISTS source_kind_id; + DROP CONSTRAINT IF EXISTS schema_environments_source_kind_id_fkey; --- Add the column back with the correct foreign key reference ALTER TABLE IF EXISTS schema_environments - ADD COLUMN source_kind_id INTEGER NOT NULL REFERENCES source_kinds(id); + ADD CONSTRAINT schema_environments_source_kind_id_fkey + FOREIGN KEY (source_kind_id) REFERENCES source_kinds(id); From a2427b6ce803ded0144d3f00b0f686ffd4a95640 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 15 Jan 2026 16:08:13 -0600 Subject: [PATCH 33/36] migration update --- cmd/api/src/database/migration/migrations/v8.5.0.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 348e8abf84..b2b94b3e80 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -92,7 +92,7 @@ CREATE TABLE IF NOT EXISTS schema_environments ( id SERIAL, schema_extension_id INTEGER NOT NULL REFERENCES schema_extensions(id) ON DELETE CASCADE, environment_kind_id INTEGER NOT NULL REFERENCES kind(id), - source_kind_id INTEGER NOT NULL REFERENCES source_kinds(id), + source_kind_id INTEGER NOT NULL REFERENCES kind(id), PRIMARY KEY (id), created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, From 38456ce1443397e4dab0bad7b19106d81f46994a Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 16 Jan 2026 11:21:44 -0600 Subject: [PATCH 34/36] getting 6853 in line with peer review comments from 6852 --- cmd/api/src/database/mocks/db.go | 60 ------------------- ...psert_schema_extension_integration_test.go | 4 +- cmd/api/src/database/upsert_schema_finding.go | 8 +-- .../upsert_schema_finding_integration_test.go | 4 +- ...ert_schema_remediation_integration_test.go | 4 +- 5 files changed, 10 insertions(+), 70 deletions(-) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 7c65f3ac2a..27d9762599 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2247,66 +2247,6 @@ func (mr *MockDatabaseMockRecorder) GetSavedQueryPermissions(ctx, queryID any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSavedQueryPermissions", reflect.TypeOf((*MockDatabase)(nil).GetSavedQueryPermissions), ctx, queryID) } -// GetSchemaEnvironmentById mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentById", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentById indicates an expected call of GetSchemaEnvironmentById. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) -} - -// GetSchemaEnvironmentByKinds mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentByKinds", ctx, environmentKindId, sourceKindId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentByKinds indicates an expected call of GetSchemaEnvironmentByKinds. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environmentKindId, sourceKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) -} - -// GetSchemaEnvironments mocks base method. -func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironments", ctx) - ret0, _ := ret[0].([]model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironments indicates an expected call of GetSchemaEnvironments. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironments(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironments", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironments), ctx) -} - // GetSchemaRelationshipFindingById mocks base method. func (m *MockDatabase) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 4c2387c972..6c67d9c682 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -45,7 +45,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") require.NoError(t, err) - _, err = db.CreateSchemaEnvironment(context.Background(), ext.ID, int32(1), int32(1)) + _, err = db.CreateEnvironment(context.Background(), ext.ID, int32(1), int32(1)) require.NoError(t, err) return ext.ID @@ -124,7 +124,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt2", "Test2", "v1.0.0") require.NoError(t, err) - env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + env, err := db.CreateEnvironment(context.Background(), ext.ID, 1, 1) require.NoError(t, err) // Create initial finding with remediation diff --git a/cmd/api/src/database/upsert_schema_finding.go b/cmd/api/src/database/upsert_schema_finding.go index 7c1391c910..d2d7b995cd 100644 --- a/cmd/api/src/database/upsert_schema_finding.go +++ b/cmd/api/src/database/upsert_schema_finding.go @@ -41,12 +41,12 @@ func (s *BloodhoundDB) UpsertFinding(ctx context.Context, extensionId int32, sou return model.SchemaRelationshipFinding{}, err } - environment, err := s.GetSchemaEnvironmentByKinds(ctx, environmentKindId, sourceKindId) + environment, err := s.GetEnvironmentByKinds(ctx, environmentKindId, sourceKindId) if err != nil { return model.SchemaRelationshipFinding{}, err } - finding, err := s.upsertFinding(ctx, extensionId, relationshipKindId, environment.ID, name, displayName) + finding, err := s.replaceFinding(ctx, extensionId, relationshipKindId, environment.ID, name, displayName) if err != nil { return model.SchemaRelationshipFinding{}, err } @@ -65,9 +65,9 @@ func (s *BloodhoundDB) validateAndTranslateRelationshipKind(ctx context.Context, } } -// upsertFinding creates or updates a schema relationship finding. +// replaceFinding creates or updates a schema relationship finding. // If a finding with the given name exists, it deletes it first before creating the new one. -func (s *BloodhoundDB) upsertFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { +func (s *BloodhoundDB) replaceFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { if existing, err := s.GetSchemaRelationshipFindingByName(ctx, name); err != nil && !errors.Is(err, ErrNotFound) { return model.SchemaRelationshipFinding{}, fmt.Errorf("error retrieving schema relationship finding: %w", err) } else if err == nil { diff --git a/cmd/api/src/database/upsert_schema_finding_integration_test.go b/cmd/api/src/database/upsert_schema_finding_integration_test.go index fcb1d0940f..70297b5e90 100644 --- a/cmd/api/src/database/upsert_schema_finding_integration_test.go +++ b/cmd/api/src/database/upsert_schema_finding_integration_test.go @@ -46,7 +46,7 @@ func TestBloodhoundDB_UpsertFinding(t *testing.T) { ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") require.NoError(t, err) - env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + env, err := db.CreateEnvironment(context.Background(), ext.ID, 1, 1) require.NoError(t, err) // Create finding @@ -81,7 +81,7 @@ func TestBloodhoundDB_UpsertFinding(t *testing.T) { ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt2", "Test2", "v1.0.0") require.NoError(t, err) - _, err = db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + _, err = db.CreateEnvironment(context.Background(), ext.ID, 1, 1) require.NoError(t, err) // No finding created since we're testing the creation workflow diff --git a/cmd/api/src/database/upsert_schema_remediation_integration_test.go b/cmd/api/src/database/upsert_schema_remediation_integration_test.go index 1798be8774..78616cbfb5 100644 --- a/cmd/api/src/database/upsert_schema_remediation_integration_test.go +++ b/cmd/api/src/database/upsert_schema_remediation_integration_test.go @@ -45,7 +45,7 @@ func TestBloodhoundDB_UpsertRemediation(t *testing.T) { ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") require.NoError(t, err) - env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + env, err := db.CreateEnvironment(context.Background(), ext.ID, 1, 1) require.NoError(t, err) finding, err := db.CreateSchemaRelationshipFinding(context.Background(), ext.ID, 1, env.ID, "Finding", "Finding Display Name") @@ -82,7 +82,7 @@ func TestBloodhoundDB_UpsertRemediation(t *testing.T) { ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") require.NoError(t, err) - env, err := db.CreateSchemaEnvironment(context.Background(), ext.ID, 1, 1) + env, err := db.CreateEnvironment(context.Background(), ext.ID, 1, 1) require.NoError(t, err) // Create Finding but do not create Remediation From c5765930979b0615e516c8529000ae5d431e4745 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 16 Jan 2026 12:55:52 -0600 Subject: [PATCH 35/36] updated per 6852 peer review --- cmd/api/src/config/config.go | 2 +- cmd/api/src/config/default.go | 2 +- cmd/api/src/database/mocks/db.go | 118 +++++++++--------- cmd/api/src/database/upsert_schema_finding.go | 2 + 4 files changed, 63 insertions(+), 61 deletions(-) diff --git a/cmd/api/src/config/config.go b/cmd/api/src/config/config.go index dd84db8406..2ee7652fd1 100644 --- a/cmd/api/src/config/config.go +++ b/cmd/api/src/config/config.go @@ -159,7 +159,7 @@ type Configuration struct { EnableAPILogging bool `json:"enable_api_logging"` EnableCypherMutations bool `json:"enable_cypher_mutations"` DisableAnalysis bool `json:"disable_analysis"` - DisableAPIKeys bool `json:"disable_api_keys"` + DisableAPIKeys bool `json:"disable_api_keys"` DisableCypherComplexityLimit bool `json:"disable_cypher_complexity_limit"` DisableIngest bool `json:"disable_ingest"` DisableMigrations bool `json:"disable_migrations"` diff --git a/cmd/api/src/config/default.go b/cmd/api/src/config/default.go index 6b35df66ca..cf1c3ac903 100644 --- a/cmd/api/src/config/default.go +++ b/cmd/api/src/config/default.go @@ -62,7 +62,7 @@ func NewDefaultConfiguration() (Configuration, error) { EnableStartupWaitPeriod: true, EnableAPILogging: true, DisableAnalysis: false, - DisableAPIKeys: false, + DisableAPIKeys: false, DisableCypherComplexityLimit: false, DisableIngest: false, DisableMigrations: false, diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 27d9762599..189653f653 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -491,6 +491,21 @@ func (mr *MockDatabaseMockRecorder) CreatePrincipalKind(ctx, environmentId, prin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreatePrincipalKind), ctx, environmentId, principalKind) } +// CreateRelationshipFinding mocks base method. +func (m *MockDatabase) CreateRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateRelationshipFinding", ctx, extensionId, relationshipKindId, environmentId, name, displayName) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateRelationshipFinding indicates an expected call of CreateRelationshipFinding. +func (mr *MockDatabaseMockRecorder) CreateRelationshipFinding(ctx, extensionId, relationshipKindId, environmentId, name, displayName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).CreateRelationshipFinding), ctx, extensionId, relationshipKindId, environmentId, name, displayName) +} + // CreateRemediation mocks base method. func (m *MockDatabase) CreateRemediation(ctx context.Context, findingId int32, shortDescription, longDescription, shortRemediation, longRemediation string) (model.Remediation, error) { m.ctrl.T.Helper() @@ -600,21 +615,6 @@ func (mr *MockDatabaseMockRecorder) CreateSavedQueryPermissionsToUsers(ctx, quer return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSavedQueryPermissionsToUsers", reflect.TypeOf((*MockDatabase)(nil).CreateSavedQueryPermissionsToUsers), varargs...) } -// CreateSchemaRelationshipFinding mocks base method. -func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaRelationshipFinding", ctx, extensionId, relationshipKindId, environmentId, name, displayName) - ret0, _ := ret[0].(model.SchemaRelationshipFinding) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaRelationshipFinding indicates an expected call of CreateSchemaRelationshipFinding. -func (mr *MockDatabaseMockRecorder) CreateSchemaRelationshipFinding(ctx, extensionId, relationshipKindId, environmentId, name, displayName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaRelationshipFinding), ctx, extensionId, relationshipKindId, environmentId, name, displayName) -} - // CreateUser mocks base method. func (m *MockDatabase) CreateUser(ctx context.Context, user model.User) (model.User, error) { m.ctrl.T.Helper() @@ -954,6 +954,20 @@ func (mr *MockDatabaseMockRecorder) DeletePrincipalKind(ctx, environmentId, prin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeletePrincipalKind), ctx, environmentId, principalKind) } +// DeleteRelationshipFinding mocks base method. +func (m *MockDatabase) DeleteRelationshipFinding(ctx context.Context, findingId int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteRelationshipFinding", ctx, findingId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteRelationshipFinding indicates an expected call of DeleteRelationshipFinding. +func (mr *MockDatabaseMockRecorder) DeleteRelationshipFinding(ctx, findingId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).DeleteRelationshipFinding), ctx, findingId) +} + // DeleteRemediation mocks base method. func (m *MockDatabase) DeleteRemediation(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -1015,20 +1029,6 @@ func (mr *MockDatabaseMockRecorder) DeleteSavedQueryPermissionsForUsers(ctx, que return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSavedQueryPermissionsForUsers", reflect.TypeOf((*MockDatabase)(nil).DeleteSavedQueryPermissionsForUsers), varargs...) } -// DeleteSchemaRelationshipFinding mocks base method. -func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaRelationshipFinding", ctx, findingId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaRelationshipFinding indicates an expected call of DeleteSchemaRelationshipFinding. -func (mr *MockDatabaseMockRecorder) DeleteSchemaRelationshipFinding(ctx, findingId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaRelationshipFinding), ctx, findingId) -} - // DeleteSelectorNodesByNodeId mocks base method. func (m *MockDatabase) DeleteSelectorNodesByNodeId(ctx context.Context, selectorId int, nodeId graph.ID) error { m.ctrl.T.Helper() @@ -2082,6 +2082,36 @@ func (mr *MockDatabaseMockRecorder) GetPublicSavedQueries(ctx any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicSavedQueries", reflect.TypeOf((*MockDatabase)(nil).GetPublicSavedQueries), ctx) } +// GetRelationshipFindingById mocks base method. +func (m *MockDatabase) GetRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRelationshipFindingById", ctx, findingId) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRelationshipFindingById indicates an expected call of GetRelationshipFindingById. +func (mr *MockDatabaseMockRecorder) GetRelationshipFindingById(ctx, findingId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRelationshipFindingById", reflect.TypeOf((*MockDatabase)(nil).GetRelationshipFindingById), ctx, findingId) +} + +// GetRelationshipFindingByName mocks base method. +func (m *MockDatabase) GetRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetRelationshipFindingByName", ctx, name) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetRelationshipFindingByName indicates an expected call of GetRelationshipFindingByName. +func (mr *MockDatabaseMockRecorder) GetRelationshipFindingByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRelationshipFindingByName", reflect.TypeOf((*MockDatabase)(nil).GetRelationshipFindingByName), ctx, name) +} + // GetRemediationByFindingId mocks base method. func (m *MockDatabase) GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) { m.ctrl.T.Helper() @@ -2247,36 +2277,6 @@ func (mr *MockDatabaseMockRecorder) GetSavedQueryPermissions(ctx, queryID any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSavedQueryPermissions", reflect.TypeOf((*MockDatabase)(nil).GetSavedQueryPermissions), ctx, queryID) } -// GetSchemaRelationshipFindingById mocks base method. -func (m *MockDatabase) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingById", ctx, findingId) - ret0, _ := ret[0].(model.SchemaRelationshipFinding) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaRelationshipFindingById indicates an expected call of GetSchemaRelationshipFindingById. -func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingById(ctx, findingId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingById), ctx, findingId) -} - -// GetSchemaRelationshipFindingByName mocks base method. -func (m *MockDatabase) GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingByName", ctx, name) - ret0, _ := ret[0].(model.SchemaRelationshipFinding) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaRelationshipFindingByName indicates an expected call of GetSchemaRelationshipFindingByName. -func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingByName(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingByName", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingByName), ctx, name) -} - // GetScopeForSavedQuery mocks base method. func (m *MockDatabase) GetScopeForSavedQuery(ctx context.Context, queryID int64, userID uuid.UUID) (database.SavedQueryScopeMap, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_finding.go b/cmd/api/src/database/upsert_schema_finding.go index d2d7b995cd..9bca51a127 100644 --- a/cmd/api/src/database/upsert_schema_finding.go +++ b/cmd/api/src/database/upsert_schema_finding.go @@ -41,6 +41,8 @@ func (s *BloodhoundDB) UpsertFinding(ctx context.Context, extensionId int32, sou return model.SchemaRelationshipFinding{}, err } + // The unique constraint on (environment_kind_id, source_kind_id) of the Schema Environment table ensures no + // duplicate pairs exist, enabling this logic. environment, err := s.GetEnvironmentByKinds(ctx, environmentKindId, sourceKindId) if err != nil { return model.SchemaRelationshipFinding{}, err From cd60110effff5755008a735fbadfffb20107877c Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 16 Jan 2026 13:44:32 -0600 Subject: [PATCH 36/36] mocks --- cmd/api/src/database/mocks/db.go | 118 +++++++++++++++---------------- 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 189653f653..27d9762599 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -491,21 +491,6 @@ func (mr *MockDatabaseMockRecorder) CreatePrincipalKind(ctx, environmentId, prin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreatePrincipalKind), ctx, environmentId, principalKind) } -// CreateRelationshipFinding mocks base method. -func (m *MockDatabase) CreateRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateRelationshipFinding", ctx, extensionId, relationshipKindId, environmentId, name, displayName) - ret0, _ := ret[0].(model.SchemaRelationshipFinding) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateRelationshipFinding indicates an expected call of CreateRelationshipFinding. -func (mr *MockDatabaseMockRecorder) CreateRelationshipFinding(ctx, extensionId, relationshipKindId, environmentId, name, displayName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).CreateRelationshipFinding), ctx, extensionId, relationshipKindId, environmentId, name, displayName) -} - // CreateRemediation mocks base method. func (m *MockDatabase) CreateRemediation(ctx context.Context, findingId int32, shortDescription, longDescription, shortRemediation, longRemediation string) (model.Remediation, error) { m.ctrl.T.Helper() @@ -615,6 +600,21 @@ func (mr *MockDatabaseMockRecorder) CreateSavedQueryPermissionsToUsers(ctx, quer return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSavedQueryPermissionsToUsers", reflect.TypeOf((*MockDatabase)(nil).CreateSavedQueryPermissionsToUsers), varargs...) } +// CreateSchemaRelationshipFinding mocks base method. +func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaRelationshipFinding", ctx, extensionId, relationshipKindId, environmentId, name, displayName) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaRelationshipFinding indicates an expected call of CreateSchemaRelationshipFinding. +func (mr *MockDatabaseMockRecorder) CreateSchemaRelationshipFinding(ctx, extensionId, relationshipKindId, environmentId, name, displayName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaRelationshipFinding), ctx, extensionId, relationshipKindId, environmentId, name, displayName) +} + // CreateUser mocks base method. func (m *MockDatabase) CreateUser(ctx context.Context, user model.User) (model.User, error) { m.ctrl.T.Helper() @@ -954,20 +954,6 @@ func (mr *MockDatabaseMockRecorder) DeletePrincipalKind(ctx, environmentId, prin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeletePrincipalKind), ctx, environmentId, principalKind) } -// DeleteRelationshipFinding mocks base method. -func (m *MockDatabase) DeleteRelationshipFinding(ctx context.Context, findingId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteRelationshipFinding", ctx, findingId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteRelationshipFinding indicates an expected call of DeleteRelationshipFinding. -func (mr *MockDatabaseMockRecorder) DeleteRelationshipFinding(ctx, findingId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).DeleteRelationshipFinding), ctx, findingId) -} - // DeleteRemediation mocks base method. func (m *MockDatabase) DeleteRemediation(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -1029,6 +1015,20 @@ func (mr *MockDatabaseMockRecorder) DeleteSavedQueryPermissionsForUsers(ctx, que return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSavedQueryPermissionsForUsers", reflect.TypeOf((*MockDatabase)(nil).DeleteSavedQueryPermissionsForUsers), varargs...) } +// DeleteSchemaRelationshipFinding mocks base method. +func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaRelationshipFinding", ctx, findingId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaRelationshipFinding indicates an expected call of DeleteSchemaRelationshipFinding. +func (mr *MockDatabaseMockRecorder) DeleteSchemaRelationshipFinding(ctx, findingId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaRelationshipFinding), ctx, findingId) +} + // DeleteSelectorNodesByNodeId mocks base method. func (m *MockDatabase) DeleteSelectorNodesByNodeId(ctx context.Context, selectorId int, nodeId graph.ID) error { m.ctrl.T.Helper() @@ -2082,36 +2082,6 @@ func (mr *MockDatabaseMockRecorder) GetPublicSavedQueries(ctx any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicSavedQueries", reflect.TypeOf((*MockDatabase)(nil).GetPublicSavedQueries), ctx) } -// GetRelationshipFindingById mocks base method. -func (m *MockDatabase) GetRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRelationshipFindingById", ctx, findingId) - ret0, _ := ret[0].(model.SchemaRelationshipFinding) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetRelationshipFindingById indicates an expected call of GetRelationshipFindingById. -func (mr *MockDatabaseMockRecorder) GetRelationshipFindingById(ctx, findingId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRelationshipFindingById", reflect.TypeOf((*MockDatabase)(nil).GetRelationshipFindingById), ctx, findingId) -} - -// GetRelationshipFindingByName mocks base method. -func (m *MockDatabase) GetRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRelationshipFindingByName", ctx, name) - ret0, _ := ret[0].(model.SchemaRelationshipFinding) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetRelationshipFindingByName indicates an expected call of GetRelationshipFindingByName. -func (mr *MockDatabaseMockRecorder) GetRelationshipFindingByName(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRelationshipFindingByName", reflect.TypeOf((*MockDatabase)(nil).GetRelationshipFindingByName), ctx, name) -} - // GetRemediationByFindingId mocks base method. func (m *MockDatabase) GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) { m.ctrl.T.Helper() @@ -2277,6 +2247,36 @@ func (mr *MockDatabaseMockRecorder) GetSavedQueryPermissions(ctx, queryID any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSavedQueryPermissions", reflect.TypeOf((*MockDatabase)(nil).GetSavedQueryPermissions), ctx, queryID) } +// GetSchemaRelationshipFindingById mocks base method. +func (m *MockDatabase) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingById", ctx, findingId) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaRelationshipFindingById indicates an expected call of GetSchemaRelationshipFindingById. +func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingById(ctx, findingId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingById), ctx, findingId) +} + +// GetSchemaRelationshipFindingByName mocks base method. +func (m *MockDatabase) GetSchemaRelationshipFindingByName(ctx context.Context, name string) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingByName", ctx, name) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaRelationshipFindingByName indicates an expected call of GetSchemaRelationshipFindingByName. +func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingByName", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingByName), ctx, name) +} + // GetScopeForSavedQuery mocks base method. func (m *MockDatabase) GetScopeForSavedQuery(ctx context.Context, queryID int64, userID uuid.UUID) (database.SavedQueryScopeMap, error) { m.ctrl.T.Helper()