From 29da7938906c8df5d8831bebce1acea2076d995a Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 6 Jan 2026 10:13:26 -0600 Subject: [PATCH 01/25] environment added to opengraphschema service + principal kinds - lots of questions about validations --- LICENSE.header | 2 +- cmd/api/src/database/kind.go | 19 + cmd/api/src/model/graphschema.go | 30 +- cmd/api/src/model/kind.go | 6 + .../services/opengraphschema/environment.go | 149 +++++ .../opengraphschema/environment_test.go | 535 ++++++++++++++++++ .../opengraphschema/mocks/opengraphschema.go | 192 +++++++ .../opengraphschema/opengraphschema.go | 56 ++ packages/csharp/graphschema/PropertyNames.cs | 2 +- packages/go/graphschema/ad/ad.go | 2 +- packages/go/graphschema/azure/azure.go | 2 +- packages/go/graphschema/common/common.go | 2 +- packages/go/graphschema/graph.go | 2 +- .../bh-shared-ui/src/graphSchema.ts | 2 +- .../SelectedDetailsTabs.test.tsx | 2 +- 15 files changed, 994 insertions(+), 9 deletions(-) create mode 100644 cmd/api/src/database/kind.go create mode 100644 cmd/api/src/model/kind.go create mode 100644 cmd/api/src/services/opengraphschema/environment.go create mode 100644 cmd/api/src/services/opengraphschema/environment_test.go create mode 100644 cmd/api/src/services/opengraphschema/mocks/opengraphschema.go create mode 100644 cmd/api/src/services/opengraphschema/opengraphschema.go diff --git a/LICENSE.header b/LICENSE.header index 5d5c596b1c..958be83318 100644 --- a/LICENSE.header +++ b/LICENSE.header @@ -1,4 +1,4 @@ -Copyright 2025 Specter Ops, Inc. +Copyright 2026 Specter Ops, Inc. Licensed under the Apache License, Version 2.0 you may not use this file except in compliance with the License. diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go new file mode 100644 index 0000000000..c89049ceb8 --- /dev/null +++ b/cmd/api/src/database/kind.go @@ -0,0 +1,19 @@ +package database + +import ( + "context" + "fmt" + + "github.com/specterops/bloodhound/cmd/api/src/model" +) + +// GetKindById gets a row from the kind table by id. +func (s *BloodhoundDB) GetKindById(ctx context.Context, id int32) (model.Kind, error) { + var kind model.Kind + + // var kind Kind + return kind, CheckError(s.db.WithContext(ctx).Raw(fmt.Sprintf(` + SELECT id, name + FROM %s WHERE id = ?`, kindTable), id).First(&kind), + ) +} diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index 32fcfbbf48..cc343e5197 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -16,7 +16,10 @@ package model -import "time" +import ( + "fmt" + "time" +) type GraphSchemaExtensions []GraphSchemaExtension @@ -98,6 +101,20 @@ func (GraphSchemaEdgeKind) TableName() string { return "schema_edge_kinds" } +// GraphSchemaEnvironments - slice of environments +type GraphSchemaEnvironments []SchemaEnvironment + +// ToMapKeyedOnEnvironmentAndSource - converts a list of graph schema environments to a map based on environment kind id and source id +func (s GraphSchemaEnvironments) ToMapKeyedOnEnvironmentAndSource() map[string]SchemaEnvironment { + m := make(map[string]SchemaEnvironment) + for _, env := range s { + // Key is environment id + source id separated with an underscore e.g., kindid_sourceid + key := fmt.Sprintf("%d_%d", env.EnvironmentKindId, env.SourceKindId) + m[key] = env + } + return m +} + type SchemaEnvironment struct { Serial SchemaExtensionId int32 `json:"schema_extension_id"` @@ -144,3 +161,14 @@ type GraphSchemaEdgeKindWithNamedSchema struct { } type GraphSchemaEdgeKindsWithNamedSchema []GraphSchemaEdgeKindWithNamedSchema + +type SchemaEnvironmentPrincipalKinds []SchemaEnvironmentPrincipalKind + +type SchemaEnvironmentPrincipalKind struct { + EnvironmentId int32 `json:"environment_id"` + PrincipalKind int32 `json:"principal_kind"` +} + +func (SchemaEnvironmentPrincipalKind) TableName() string { + return "schema_environments_principal_kinds" +} diff --git a/cmd/api/src/model/kind.go b/cmd/api/src/model/kind.go new file mode 100644 index 0000000000..786fa0a5bf --- /dev/null +++ b/cmd/api/src/model/kind.go @@ -0,0 +1,6 @@ +package model + +type Kind struct { + ID int `json:"id"` + Name string `json:"name"` +} diff --git a/cmd/api/src/services/opengraphschema/environment.go b/cmd/api/src/services/opengraphschema/environment.go new file mode 100644 index 0000000000..492da32a49 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/environment.go @@ -0,0 +1,149 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema + +import ( + "context" + "errors" + "fmt" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" +) + +func (o *OpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, environment model.SchemaEnvironment, principalKinds []model.SchemaEnvironmentPrincipalKind) error { + // Validate environment + if err := o.validateSchemaEnvironment(ctx, environment); err != nil { + return fmt.Errorf("error validating schema environment: %w", err) + } + + // Validate principal kinds + for _, kind := range principalKinds { + if err := o.validateSchemaEnvironmentPrincipalKind(ctx, kind.PrincipalKind); err != nil { + return fmt.Errorf("error validating principal kind: %w", err) + } + } + + // Upsert the environment + id, err := o.upsertSchemaEnvironment(ctx, environment) + if err != nil { + return fmt.Errorf("error upserting schema environment: %w", err) + } + + // Upsert principal kinds for environment + if err := o.upsertPrincipalKinds(ctx, id, principalKinds); err != nil { + return fmt.Errorf("error upserting principal kinds: %w", err) + } + + return nil +} + +func (o *OpenGraphSchemaService) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { + if existing, err := o.openGraphSchemaRepository.GetSchemaEnvironmentById(ctx, graphSchema.ID); err != nil && !errors.Is(err, database.ErrNotFound) { + return 0, fmt.Errorf("error retrieving schema environment id %d: %w", graphSchema.ID, err) + } else if !errors.Is(err, database.ErrNotFound) { + // Environment exists - delete it first + if err := o.openGraphSchemaRepository.DeleteSchemaEnvironment(ctx, existing.ID); err != nil { + return 0, fmt.Errorf("error deleting schema environment %d: %w", existing.ID, err) + } + } + + // Create Environment + if created, err := o.openGraphSchemaRepository.CreateSchemaEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { + return 0, fmt.Errorf("error creating schema environment: %w", err) + } else { + return created.ID, nil + } +} + +/* +Validations: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#10-validation-rules-for-environments + 1. Ensure the specified environmentKind exists in kinds table + ** QUESTION: Documentation states to use the kind table. Is there already a database method to query the kind table to do this validation or was I supposed to create it? + 2. Ensure the specified sourceKind exists in source_kinds table (create if it doesn't, reactivate if it does) +*/ +func (o *OpenGraphSchemaService) validateSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) error { + // Validate environment kind id exists in kinds table + if _, err := o.openGraphSchemaRepository.GetKindById(ctx, graphSchema.EnvironmentKindId); err != nil { + return fmt.Errorf("error retrieving environment kind: %w", err) + } + + // Get all source kinds + if sourceKinds, err := o.openGraphSchemaRepository.GetSourceKinds(ctx); err != nil { + return fmt.Errorf("error retrieving source kinds: %w", err) + } else { + // Check if source kind exists + found := false + for _, kind := range sourceKinds { + if graphSchema.SourceKindId == int32(kind.ID) { + found = true + break + } + } + + if !found { + /* + ** QUESTION: Example Environment Schema: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#9-environments-and-principal-kinds + The RFC example uses source kind names (strings) in the environment schema, but our model uses IDs (int32). To register a new source kind, we need the + kind name/data, not just the ID. Cannot register with only an ID - kind id/name is required for registration. + RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error + + For now, this validates that the source kind should exist. + */ + return fmt.Errorf("invalid source kind id %d", graphSchema.SourceKindId) + } + } + + return nil +} + +func (o *OpenGraphSchemaService) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { + if existingKinds, err := o.openGraphSchemaRepository.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) + } else if !errors.Is(err, database.ErrNotFound) { + // Delete all existing principal kinds + for _, kind := range existingKinds { + if err := o.openGraphSchemaRepository.DeleteSchemaEnvironmentPrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { + return fmt.Errorf("error deleting principal kind %d for environment %d: %w", kind.PrincipalKind, kind.EnvironmentId, err) + } + } + } + + // Create the new principal kinds + for _, kind := range principalKinds { + if _, err := o.openGraphSchemaRepository.CreateSchemaEnvironmentPrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { + return fmt.Errorf("error creating principal kind %d for environment %d: %w", kind.PrincipalKind, environmentID, err) + } + } + + return nil +} + +/* +Validations: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#10-validation-rules-for-environments +1. Ensure all principalKinds exist in kinds table. + +** QUESTION: Documentation states to use the kind table. Is there already a database method to query the kind table to do this validation or was I supposed to create it? +*/ +func (o *OpenGraphSchemaService) validateSchemaEnvironmentPrincipalKind(ctx context.Context, kindID int32) error { + if _, err := o.openGraphSchemaRepository.GetKindById(ctx, kindID); err != nil && !errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("error retrieving kind by id: %w", err) + } else if errors.Is(err, database.ErrNotFound) { + return fmt.Errorf("invalid principal kind id %d", kindID) + } + + return nil +} diff --git a/cmd/api/src/services/opengraphschema/environment_test.go b/cmd/api/src/services/opengraphschema/environment_test.go new file mode 100644 index 0000000000..991509a6ed --- /dev/null +++ b/cmd/api/src/services/opengraphschema/environment_test.go @@ -0,0 +1,535 @@ +package opengraphschema_test + +import ( + "context" + "errors" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" + schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { + type mocks struct { + mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository + } + type args struct { + environment model.SchemaEnvironment + principalKinds []model.SchemaEnvironmentPrincipalKind + } + tests := []struct { + name string + mocks mocks + setupMocks func(t *testing.T, mock *mocks) + args args + expected error + }{ + { + name: "Error: Validation - Failed to retrieve environment kind", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(1), + SourceKindId: int32(1), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, errors.New("error")) + }, + expected: errors.New("error validating schema environment: error retrieving environment kind: error"), + }, + { + name: "Error: Validation - Environment Kind doesn't exist in Kinds table", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(1), + SourceKindId: int32(1), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, database.ErrNotFound) + }, + expected: errors.New("error validating schema environment: error retrieving environment kind: entity not found"), + }, + { + name: "Error: Validation - Failed to retrieve source kinds", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(1), + SourceKindId: int32(1), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{}, errors.New("error")) + }, + expected: errors.New("error validating schema environment: error retrieving source kinds: error"), + }, + { + name: "Error: Validation - Source Kind doesn't exist", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(3), + SourceKindId: int32(1), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + }, + expected: errors.New("error validating schema environment: invalid source kind id 1"), + }, + { + name: "Error: Validation - Failed to retrieve principal kind", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + EnvironmentId: int32(1), + PrincipalKind: int32(99), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(99)).Return(model.Kind{}, errors.New("error")) + }, + expected: errors.New("error validating principal kind: error retrieving kind by id: error"), + }, + { + name: "Error: Validation - Principal Kind doesn't exist", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(1), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + EnvironmentId: int32(1), + PrincipalKind: int32(99), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(99)).Return(model.Kind{}, database.ErrNotFound) + }, + expected: errors.New("error validating principal kind: invalid principal kind id 99"), + }, + { + name: "Error: GetSchemaEnvironmentById fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, errors.New("error")) + }, + expected: errors.New("error upserting schema environment: error retrieving schema environment id 0: error"), + }, + { + name: "Error: DeleteSchemaEnvironment fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 5}, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(errors.New("error")) + }, + expected: errors.New("error upserting schema environment: error deleting schema environment 5: error"), + }, + { + name: "Error: CreateSchemaEnvironment fails after delete", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 5}, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) + }, + expected: errors.New("error upserting schema environment: error creating schema environment: error"), + }, + { + name: "Error: CreateSchemaEnvironment fails on new environment", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) + }, + expected: errors.New("error upserting schema environment: error creating schema environment: error"), + }, + { + name: "Error: GetSchemaEnvironmentPrincipalKindsByEnvironmentId fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, errors.New("error")) + }, + expected: errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), + }, + { + name: "Error: DeleteSchemaEnvironmentPrincipalKind fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return([]model.SchemaEnvironmentPrincipalKind{ + { + EnvironmentId: int32(10), + PrincipalKind: int32(5), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(5)).Return(errors.New("error")) + }, + expected: errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), + }, + { + name: "Error: CreateSchemaEnvironmentPrincipalKind fails", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, errors.New("error")) + }, + expected: errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error"), + }, + { + name: "Success: Create new environment with principal kinds", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + { + PrincipalKind: int32(4), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind3"), + }, + { + ID: 4, + Name: graph.StringKind("kind4"), + }, + }, nil) + // Principal kind validations + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(4)).Return(model.Kind{}, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + }, + expected: nil, + }, + { + name: "Success: Update existing environment and replace principal kinds", + args: args{ + environment: model.SchemaEnvironment{ + Serial: model.Serial{ID: 5}, + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{ + { + PrincipalKind: int32(3), + }, + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + // Principal kind validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + + // Environment upsert (delete and recreate) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(5)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 5}, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert (delete old, create new) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return([]model.SchemaEnvironmentPrincipalKind{ + { + EnvironmentId: int32(10), + PrincipalKind: int32(99), + }, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(99)).Return(nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + }, + expected: nil, + }, + { + name: "Success: Create environment with no principal kinds", + args: args{ + environment: model.SchemaEnvironment{ + SchemaExtensionId: int32(3), + EnvironmentKindId: int32(3), + SourceKindId: int32(3), + }, + principalKinds: []model.SchemaEnvironmentPrincipalKind{}, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + // Environment validation + mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + { + ID: 3, + Name: graph.StringKind("kind"), + }, + }, nil) + + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 10}, + }, nil) + + // Principal kinds upsert (no existing, no new) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) + }, + expected: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + mocks := &mocks{ + mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), + } + + tt.setupMocks(t, mocks) + + graphService := opengraphschema.NewOpenGraphSchemaService(mocks.mockOpenGraphSchema) + + err := graphService.UpsertSchemaEnvironmentWithPrincipalKinds(context.Background(), tt.args.environment, tt.args.principalKinds) + if tt.expected != nil { + assert.EqualError(t, tt.expected, err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go new file mode 100644 index 0000000000..7723197208 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -0,0 +1,192 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema (interfaces: OpenGraphSchemaRepository) +// +// Generated by this command: +// +// mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/opengraphschema.go -package=mocks . OpenGraphSchemaRepository +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + database "github.com/specterops/bloodhound/cmd/api/src/database" + model "github.com/specterops/bloodhound/cmd/api/src/model" + graph "github.com/specterops/dawgs/graph" + gomock "go.uber.org/mock/gomock" +) + +// MockOpenGraphSchemaRepository is a mock of OpenGraphSchemaRepository interface. +type MockOpenGraphSchemaRepository struct { + ctrl *gomock.Controller + recorder *MockOpenGraphSchemaRepositoryMockRecorder + isgomock struct{} +} + +// MockOpenGraphSchemaRepositoryMockRecorder is the mock recorder for MockOpenGraphSchemaRepository. +type MockOpenGraphSchemaRepositoryMockRecorder struct { + mock *MockOpenGraphSchemaRepository +} + +// NewMockOpenGraphSchemaRepository creates a new mock instance. +func NewMockOpenGraphSchemaRepository(ctrl *gomock.Controller) *MockOpenGraphSchemaRepository { + mock := &MockOpenGraphSchemaRepository{ctrl: ctrl} + mock.recorder = &MockOpenGraphSchemaRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryMockRecorder { + return m.recorder +} + +// CreateSchemaEnvironment mocks base method. +func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironment(ctx context.Context, schemaExtensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaEnvironment", ctx, schemaExtensionId, environmentKindId, sourceKindId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaEnvironment indicates an expected call of CreateSchemaEnvironment. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) CreateSchemaEnvironment(ctx, schemaExtensionId, environmentKindId, sourceKindId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).CreateSchemaEnvironment), ctx, schemaExtensionId, environmentKindId, sourceKindId) +} + +// CreateSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + +// DeleteSchemaEnvironment mocks base method. +func (m *MockOpenGraphSchemaRepository) DeleteSchemaEnvironment(ctx context.Context, schemaEnvironmentId int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaEnvironment", ctx, schemaEnvironmentId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaEnvironment indicates an expected call of DeleteSchemaEnvironment. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironment(ctx, schemaEnvironmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironment), ctx, schemaEnvironmentId) +} + +// DeleteSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockOpenGraphSchemaRepository) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + +// GetKindById mocks base method. +func (m *MockOpenGraphSchemaRepository) GetKindById(ctx context.Context, id int32) (model.Kind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKindById", ctx, id) + ret0, _ := ret[0].(model.Kind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKindById indicates an expected call of GetKindById. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetKindById(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindById", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetKindById), ctx, id) +} + +// GetSchemaEnvironmentById mocks base method. +func (m *MockOpenGraphSchemaRepository) GetSchemaEnvironmentById(ctx context.Context, schemaEnvironmentId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentById", ctx, schemaEnvironmentId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentById indicates an expected call of GetSchemaEnvironmentById. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentById(ctx, schemaEnvironmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentById), ctx, schemaEnvironmentId) +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. +func (m *MockOpenGraphSchemaRepository) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) +} + +// GetSourceKinds mocks base method. +func (m *MockOpenGraphSchemaRepository) GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSourceKinds", ctx) + ret0, _ := ret[0].([]database.SourceKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSourceKinds indicates an expected call of GetSourceKinds. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSourceKinds(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSourceKinds), ctx) +} + +// RegisterSourceKind mocks base method. +func (m *MockOpenGraphSchemaRepository) RegisterSourceKind(ctx context.Context) func(graph.Kind) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterSourceKind", ctx) + ret0, _ := ret[0].(func(graph.Kind) error) + return ret0 +} + +// RegisterSourceKind indicates an expected call of RegisterSourceKind. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go new file mode 100644 index 0000000000..1ff658041d --- /dev/null +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -0,0 +1,56 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema + +//go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/opengraphschema.go -package=mocks . OpenGraphSchemaRepository + +import ( + "context" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/dawgs/graph" +) + +// OpenGraphSchemaRepository - +type OpenGraphSchemaRepository interface { + // Kinds + GetKindById(ctx context.Context, id int32) (model.Kind, error) + + // Environment + CreateSchemaEnvironment(ctx context.Context, schemaExtensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) + GetSchemaEnvironmentById(ctx context.Context, schemaEnvironmentId int32) (model.SchemaEnvironment, error) + DeleteSchemaEnvironment(ctx context.Context, schemaEnvironmentId int32) error + + // Source Kinds + RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error + GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) + + // Principal Kinds + CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) + GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) + DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error +} + +type OpenGraphSchemaService struct { + openGraphSchemaRepository OpenGraphSchemaRepository +} + +func NewOpenGraphSchemaService(openGraphSchemaRepository OpenGraphSchemaRepository) *OpenGraphSchemaService { + return &OpenGraphSchemaService{ + openGraphSchemaRepository: openGraphSchemaRepository, + } +} diff --git a/packages/csharp/graphschema/PropertyNames.cs b/packages/csharp/graphschema/PropertyNames.cs index fc3d6b08a4..2031b9df97 100644 --- a/packages/csharp/graphschema/PropertyNames.cs +++ b/packages/csharp/graphschema/PropertyNames.cs @@ -1,5 +1,5 @@ /* - Copyright 2025 Specter Ops, Inc. + Copyright 2026 Specter Ops, Inc. Licensed under the Apache License, Version 2.0 you may not use this file except in compliance with the License. diff --git a/packages/go/graphschema/ad/ad.go b/packages/go/graphschema/ad/ad.go index c7a589ab21..c9411682e6 100644 --- a/packages/go/graphschema/ad/ad.go +++ b/packages/go/graphschema/ad/ad.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/go/graphschema/azure/azure.go b/packages/go/graphschema/azure/azure.go index 3ea1483949..709eab6828 100644 --- a/packages/go/graphschema/azure/azure.go +++ b/packages/go/graphschema/azure/azure.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/go/graphschema/common/common.go b/packages/go/graphschema/common/common.go index f18f99a486..057e552ab7 100644 --- a/packages/go/graphschema/common/common.go +++ b/packages/go/graphschema/common/common.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/go/graphschema/graph.go b/packages/go/graphschema/graph.go index aedef13acd..acb1385859 100644 --- a/packages/go/graphschema/graph.go +++ b/packages/go/graphschema/graph.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/javascript/bh-shared-ui/src/graphSchema.ts b/packages/javascript/bh-shared-ui/src/graphSchema.ts index 7ede0cf403..c92263f53e 100644 --- a/packages/javascript/bh-shared-ui/src/graphSchema.ts +++ b/packages/javascript/bh-shared-ui/src/graphSchema.ts @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/SelectedDetailsTabs/SelectedDetailsTabs.test.tsx b/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/SelectedDetailsTabs/SelectedDetailsTabs.test.tsx index dab44ff51a..e5d1bb755d 100644 --- a/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/SelectedDetailsTabs/SelectedDetailsTabs.test.tsx +++ b/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/SelectedDetailsTabs/SelectedDetailsTabs.test.tsx @@ -173,4 +173,4 @@ describe('Selected Details Tabs', () => { expect(objectTab).toBeEnabled(); }); }); -}); \ No newline at end of file +}); From a7f351eaa20e8fbfbdb6091bbc9ce9f3a6872e65 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 6 Jan 2026 18:05:38 -0600 Subject: [PATCH 02/25] did an overhaul of everything after talking with Cody and pulled a couple of Lawson's changes to test the environment upsert with a http endpoint --- cmd/api/src/api/mocks/authenticator.go | 2 +- cmd/api/src/api/registration/registration.go | 3 +- cmd/api/src/api/registration/v2.go | 2 + .../src/api/v2/mocks/graphschemaextensions.go | 72 +++ cmd/api/src/api/v2/model.go | 3 + cmd/api/src/api/v2/opengraphschema.go | 62 +++ cmd/api/src/daemons/datapipe/mocks/cleanup.go | 2 +- cmd/api/src/database/graphschema.go | 41 ++ cmd/api/src/database/kind.go | 34 +- cmd/api/src/database/kind_integration_test.go | 75 +++ .../database/migration/migrations/v8.5.0.sql | 4 + cmd/api/src/database/mocks/auth.go | 2 +- cmd/api/src/database/mocks/db.go | 17 +- cmd/api/src/database/sourcekinds.go | 28 + .../database/sourcekinds_integration_test.go | 50 ++ cmd/api/src/model/graphschema.go | 15 - cmd/api/src/model/kind.go | 15 + cmd/api/src/queries/mocks/graph.go | 2 +- cmd/api/src/services/agi/mocks/mock.go | 2 +- .../src/services/dataquality/mocks/mock.go | 2 +- cmd/api/src/services/entrypoint.go | 20 +- cmd/api/src/services/fs/mocks/fs.go | 2 +- cmd/api/src/services/graphify/mocks/ingest.go | 2 +- cmd/api/src/services/oidc/mocks/oidc.go | 2 +- .../services/opengraphschema/environment.go | 161 +++--- .../opengraphschema/environment_test.go | 509 +++++++++--------- .../opengraphschema/mocks/opengraphschema.go | 26 +- .../opengraphschema/opengraphschema.go | 4 +- cmd/api/src/services/saml/mocks/saml.go | 2 +- cmd/api/src/services/upload/mocks/mock.go | 2 +- .../src/utils/validation/mocks/validator.go | 2 +- cmd/api/src/vendormocks/dawgs/graph/mock.go | 2 +- cmd/api/src/vendormocks/io/fs/mock.go | 2 +- .../neo4j/neo4j-go-driver/v5/neo4j/mock.go | 2 +- packages/go/crypto/mocks/digest.go | 2 +- 35 files changed, 769 insertions(+), 404 deletions(-) create mode 100644 cmd/api/src/api/v2/mocks/graphschemaextensions.go create mode 100644 cmd/api/src/api/v2/opengraphschema.go create mode 100644 cmd/api/src/database/kind_integration_test.go diff --git a/cmd/api/src/api/mocks/authenticator.go b/cmd/api/src/api/mocks/authenticator.go index 32cbd5053e..8811703cb3 100644 --- a/cmd/api/src/api/mocks/authenticator.go +++ b/cmd/api/src/api/mocks/authenticator.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/api/registration/registration.go b/cmd/api/src/api/registration/registration.go index e06a494d23..bd3938c32a 100644 --- a/cmd/api/src/api/registration/registration.go +++ b/cmd/api/src/api/registration/registration.go @@ -62,6 +62,7 @@ func RegisterFossRoutes( authenticator api.Authenticator, authorizer auth.Authorizer, ingestSchema upload.IngestSchema, + openGraphSchemaService v2.OpenGraphSchemaService, ) { router.With(func() mux.MiddlewareFunc { return middleware.DefaultRateLimitMiddleware(rdms) @@ -80,6 +81,6 @@ func RegisterFossRoutes( routerInst.PathPrefix("/ui", static.AssetHandler), ) - var resources = v2.NewResources(rdms, graphDB, cfg, apiCache, graphQuery, collectorManifests, authorizer, authenticator, ingestSchema) + var resources = v2.NewResources(rdms, graphDB, cfg, apiCache, graphQuery, collectorManifests, authorizer, authenticator, ingestSchema, openGraphSchemaService) NewV2API(resources, routerInst) } diff --git a/cmd/api/src/api/registration/v2.go b/cmd/api/src/api/registration/v2.go index 5b04638b4b..0bf9a66f44 100644 --- a/cmd/api/src/api/registration/v2.go +++ b/cmd/api/src/api/registration/v2.go @@ -364,5 +364,7 @@ func NewV2API(resources v2.Resources, routerInst *router.Router) { routerInst.POST("/api/v2/custom-nodes", resources.CreateCustomNodeKind).RequireAuth(), routerInst.PUT(fmt.Sprintf("/api/v2/custom-nodes/{%s}", v2.CustomNodeKindParameter), resources.UpdateCustomNodeKind).RequireAuth(), routerInst.DELETE(fmt.Sprintf("/api/v2/custom-nodes/{%s}", v2.CustomNodeKindParameter), resources.DeleteCustomNodeKind).RequireAuth(), + + routerInst.PUT("/api/v2/extensions", resources.OpenGraphSchemaIngest), ) } diff --git a/cmd/api/src/api/v2/mocks/graphschemaextensions.go b/cmd/api/src/api/v2/mocks/graphschemaextensions.go new file mode 100644 index 0000000000..96cb240811 --- /dev/null +++ b/cmd/api/src/api/v2/mocks/graphschemaextensions.go @@ -0,0 +1,72 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/specterops/bloodhound/cmd/api/src/api/v2 (interfaces: OpenGraphSchemaService) +// +// Generated by this command: +// +// mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphschemaextensions.go -package=mocks . OpenGraphSchemaService +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + gomock "go.uber.org/mock/gomock" +) + +// MockOpenGraphSchemaService is a mock of OpenGraphSchemaService interface. +type MockOpenGraphSchemaService struct { + ctrl *gomock.Controller + recorder *MockOpenGraphSchemaServiceMockRecorder + isgomock struct{} +} + +// MockOpenGraphSchemaServiceMockRecorder is the mock recorder for MockOpenGraphSchemaService. +type MockOpenGraphSchemaServiceMockRecorder struct { + mock *MockOpenGraphSchemaService +} + +// NewMockOpenGraphSchemaService creates a new mock instance. +func NewMockOpenGraphSchemaService(ctrl *gomock.Controller) *MockOpenGraphSchemaService { + mock := &MockOpenGraphSchemaService{ctrl: ctrl} + mock.recorder = &MockOpenGraphSchemaServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOpenGraphSchemaService) EXPECT() *MockOpenGraphSchemaServiceMockRecorder { + return m.recorder +} + +// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. +func (m *MockOpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []v2.Environment) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environments) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. +func (mr *MockOpenGraphSchemaServiceMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environments any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaService)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environments) +} diff --git a/cmd/api/src/api/v2/model.go b/cmd/api/src/api/v2/model.go index e7f2a3246c..c66cd94019 100644 --- a/cmd/api/src/api/v2/model.go +++ b/cmd/api/src/api/v2/model.go @@ -115,6 +115,7 @@ type Resources struct { Authenticator api.Authenticator IngestSchema upload.IngestSchema FileService fs.Service + openGraphSchemaService OpenGraphSchemaService } func NewResources( @@ -127,6 +128,7 @@ func NewResources( authorizer auth.Authorizer, authenticator api.Authenticator, ingestSchema upload.IngestSchema, + openGraphSchemaService OpenGraphSchemaService, ) Resources { return Resources{ Decoder: schema.NewDecoder(), @@ -141,5 +143,6 @@ func NewResources( Authenticator: authenticator, IngestSchema: ingestSchema, FileService: &fs.Client{}, + openGraphSchemaService: openGraphSchemaService, } } diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go new file mode 100644 index 0000000000..50446d6a04 --- /dev/null +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -0,0 +1,62 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package v2 + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/specterops/bloodhound/cmd/api/src/api" +) + +//go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphschemaextensions.go -package=mocks . OpenGraphSchemaService +type OpenGraphSchemaService interface { + UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []Environment) error +} + +type SchemaUploadRequest struct { + ID int32 `` + Environments []Environment `json:"environments"` +} + +type Environment struct { + EnvironmentKind string `json:"environmentKind"` + SourceKind string `json:"sourceKind"` + PrincipalKinds []string `json:"principalKinds"` +} + +func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { + var ( + ctx = request.Context() + ) + + var req SchemaUploadRequest + if err := json.NewDecoder(request.Body).Decode(&req); err != nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponsePayloadUnmarshalError, request), response) + return + } + + if err := s.openGraphSchemaService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("Error upserting schema environment with principal kinds: %w", err), request), response) + return + } + + api.WriteBasicResponse(request.Context(), map[string]string{ + "message": "Schema environments uploaded successfully", + }, http.StatusCreated, response) +} diff --git a/cmd/api/src/daemons/datapipe/mocks/cleanup.go b/cmd/api/src/daemons/datapipe/mocks/cleanup.go index ed40825455..d2f7e31808 100644 --- a/cmd/api/src/daemons/datapipe/mocks/cleanup.go +++ b/cmd/api/src/daemons/datapipe/mocks/cleanup.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 7c52d14b6d..3801cc5248 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -607,6 +607,47 @@ func (s *BloodhoundDB) DeleteSchemaRelationshipFinding(ctx context.Context, find return nil } +func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + var envPrincipalKind model.SchemaEnvironmentPrincipalKind + + if result := s.db.WithContext(ctx).Raw(` + INSERT INTO schema_environments_principal_kinds (environment_id, principal_kind) + VALUES (?, ?) + RETURNING environment_id, principal_kind`, + environmentId, principalKind).Scan(&envPrincipalKind); result.Error != nil { + return model.SchemaEnvironmentPrincipalKind{}, CheckError(result) + } + + return envPrincipalKind, nil +} + +func (s *BloodhoundDB) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + var envPrincipalKinds model.SchemaEnvironmentPrincipalKinds + + if result := s.db.WithContext(ctx).Raw(` + SELECT environment_id, principal_kind + FROM schema_environments_principal_kinds + WHERE environment_id = ?`, + environmentId).Scan(&envPrincipalKinds); result.Error != nil { + return nil, CheckError(result) + } + + return envPrincipalKinds, nil +} + +func (s *BloodhoundDB) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error { + if result := s.db.WithContext(ctx).Exec(` + DELETE FROM schema_environments_principal_kinds + WHERE environment_id = ? AND principal_kind = ?`, + environmentId, principalKind); result.Error != nil { + return CheckError(result) + } else if result.RowsAffected == 0 { + return ErrNotFound + } + + return nil +} + func parseFiltersAndPagination(filters model.Filters, sort model.Sort, skip, limit int) (FilterAndPagination, error) { var ( filtersAndPagination FilterAndPagination diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go index c89049ceb8..87632f6ae3 100644 --- a/cmd/api/src/database/kind.go +++ b/cmd/api/src/database/kind.go @@ -1,19 +1,37 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package database import ( "context" - "fmt" "github.com/specterops/bloodhound/cmd/api/src/model" ) -// GetKindById gets a row from the kind table by id. -func (s *BloodhoundDB) GetKindById(ctx context.Context, id int32) (model.Kind, error) { +func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Kind, error) { + const query = ` + SELECT id, name + FROM kind + WHERE name = $1; + ` + var kind model.Kind + if err := s.db.Raw(query, name).Scan(&kind).Error; err != nil { + return model.Kind{}, err + } - // var kind Kind - return kind, CheckError(s.db.WithContext(ctx).Raw(fmt.Sprintf(` - SELECT id, name - FROM %s WHERE id = ?`, kindTable), id).First(&kind), - ) + return kind, nil } diff --git a/cmd/api/src/database/kind_integration_test.go b/cmd/api/src/database/kind_integration_test.go new file mode 100644 index 0000000000..089db89dc9 --- /dev/null +++ b/cmd/api/src/database/kind_integration_test.go @@ -0,0 +1,75 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration + +package database_test + +import ( + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/stretchr/testify/assert" +) + +// the v7.3.0 migration initializes the kind table with Tag_Tier_Zero, so we're +// simply testing the kind exists +func TestGetKindByName(t *testing.T) { + type args struct { + name string + } + type want struct { + err error + kind model.Kind + } + tests := []struct { + name string + args args + setup func() IntegrationTestSuite + want want + }{ + { + name: "Success: Retrieves Kind Tag_Tier_Zero by name", + args: args{ + name: "Tag_Tier_Zero", + }, + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + want: want{ + err: nil, + kind: model.Kind{ + ID: 1, + Name: "Tag_Tier_Zero", + }, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + kind, err := testSuite.BHDatabase.GetKindByName(testSuite.Context, testCase.args.name) + if testCase.want.err != nil { + assert.EqualError(t, testCase.want.err, err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.want.kind, kind) + } + }) + } +} diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 204da86fec..78cf431053 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -184,3 +184,7 @@ $$ END IF; END $$; + +-- Insert a test schema extension +INSERT INTO schema_extensions (name, display_name, version, is_builtin) +VALUES ('test_schema', 'Test Schema', '1.0.0', false); diff --git a/cmd/api/src/database/mocks/auth.go b/cmd/api/src/database/mocks/auth.go index 6934b158d0..d696b4804c 100644 --- a/cmd/api/src/database/mocks/auth.go +++ b/cmd/api/src/database/mocks/auth.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 49772c522e..ae5c1b5e1c 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. @@ -2230,6 +2230,21 @@ func (mr *MockDatabaseMockRecorder) GetSharedSavedQueries(ctx, userID any) *gomo return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSharedSavedQueries", reflect.TypeOf((*MockDatabase)(nil).GetSharedSavedQueries), ctx, userID) } +// GetSourceKindByName mocks base method. +func (m *MockDatabase) GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSourceKindByName", ctx, name) + ret0, _ := ret[0].(database.SourceKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSourceKindByName indicates an expected call of GetSourceKindByName. +func (mr *MockDatabaseMockRecorder) GetSourceKindByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindByName", reflect.TypeOf((*MockDatabase)(nil).GetSourceKindByName), ctx, name) +} + // GetSourceKinds mocks base method. func (m *MockDatabase) GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/sourcekinds.go b/cmd/api/src/database/sourcekinds.go index bf95599dd8..d69ed7dcd4 100644 --- a/cmd/api/src/database/sourcekinds.go +++ b/cmd/api/src/database/sourcekinds.go @@ -28,6 +28,7 @@ type SourceKindsData interface { GetSourceKinds(ctx context.Context) ([]SourceKind, error) DeactivateSourceKindsByName(ctx context.Context, kinds graph.Kinds) error RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error + GetSourceKindByName(ctx context.Context, name string) (SourceKind, error) } // RegisterSourceKind returns a function that inserts a source kind by name, @@ -93,6 +94,33 @@ func (s *BloodhoundDB) GetSourceKinds(ctx context.Context) ([]SourceKind, error) return out, nil } +func (s *BloodhoundDB) GetSourceKindByName(ctx context.Context, name string) (SourceKind, error) { + const query = ` + SELECT id, name, active + FROM source_kinds + WHERE name = $1 AND active = true; + ` + + type rawSourceKind struct { + ID int + Name string + Active bool + } + + var raw rawSourceKind + if err := s.db.Raw(query, name).Scan(&raw).Error; err != nil { + return SourceKind{}, err + } + + kind := SourceKind{ + ID: raw.ID, + Name: graph.StringKind(raw.Name), + Active: raw.Active, + } + + return kind, nil +} + func (s *BloodhoundDB) DeactivateSourceKindsByName(ctx context.Context, kinds graph.Kinds) error { if len(kinds) == 0 { return nil diff --git a/cmd/api/src/database/sourcekinds_integration_test.go b/cmd/api/src/database/sourcekinds_integration_test.go index 8d6cbbab45..329831d2cb 100644 --- a/cmd/api/src/database/sourcekinds_integration_test.go +++ b/cmd/api/src/database/sourcekinds_integration_test.go @@ -206,6 +206,56 @@ func TestGetSourceKinds(t *testing.T) { } } +func TestGetSourceKindByName(t *testing.T) { + type args struct { + name string + } + type want struct { + err error + sourceKind database.SourceKind + } + tests := []struct { + name string + args args + setup func() IntegrationTestSuite + want want + }{ + { + name: "Success: Retrieves Source Kinds by Name", + args: args{ + name: "AZBase", + }, + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + want: want{ + err: nil, + // the v8.0.0 migration initializes the source_kinds table with Base, AZBase, so we're + // simply testing the default returned source_kinds + sourceKind: database.SourceKind{ + ID: 2, + Name: graph.StringKind("AZBase"), + Active: true, + }, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + sourceKind, err := testSuite.BHDatabase.GetSourceKindByName(testSuite.Context, testCase.args.name) + if testCase.want.err != nil { + assert.EqualError(t, testCase.want.err, err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.want.sourceKind, sourceKind) + } + }) + } +} + func TestDeactivateSourceKindsByName(t *testing.T) { type args struct { sourceKind graph.Kinds diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index cc343e5197..2575af475d 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -17,7 +17,6 @@ package model import ( - "fmt" "time" ) @@ -101,20 +100,6 @@ func (GraphSchemaEdgeKind) TableName() string { return "schema_edge_kinds" } -// GraphSchemaEnvironments - slice of environments -type GraphSchemaEnvironments []SchemaEnvironment - -// ToMapKeyedOnEnvironmentAndSource - converts a list of graph schema environments to a map based on environment kind id and source id -func (s GraphSchemaEnvironments) ToMapKeyedOnEnvironmentAndSource() map[string]SchemaEnvironment { - m := make(map[string]SchemaEnvironment) - for _, env := range s { - // Key is environment id + source id separated with an underscore e.g., kindid_sourceid - key := fmt.Sprintf("%d_%d", env.EnvironmentKindId, env.SourceKindId) - m[key] = env - } - return m -} - type SchemaEnvironment struct { Serial SchemaExtensionId int32 `json:"schema_extension_id"` diff --git a/cmd/api/src/model/kind.go b/cmd/api/src/model/kind.go index 786fa0a5bf..8f6402c682 100644 --- a/cmd/api/src/model/kind.go +++ b/cmd/api/src/model/kind.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package model type Kind struct { diff --git a/cmd/api/src/queries/mocks/graph.go b/cmd/api/src/queries/mocks/graph.go index a536fce424..85b683c87c 100644 --- a/cmd/api/src/queries/mocks/graph.go +++ b/cmd/api/src/queries/mocks/graph.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/agi/mocks/mock.go b/cmd/api/src/services/agi/mocks/mock.go index d6f0342f8e..d8c4c5695b 100644 --- a/cmd/api/src/services/agi/mocks/mock.go +++ b/cmd/api/src/services/agi/mocks/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/dataquality/mocks/mock.go b/cmd/api/src/services/dataquality/mocks/mock.go index 11f0a8a31b..0ffa179ce6 100644 --- a/cmd/api/src/services/dataquality/mocks/mock.go +++ b/cmd/api/src/services/dataquality/mocks/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index 10ff850660..e0045fc3d2 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -38,6 +38,7 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/migrations" "github.com/specterops/bloodhound/cmd/api/src/model/appcfg" "github.com/specterops/bloodhound/cmd/api/src/queries" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/upload" "github.com/specterops/bloodhound/packages/go/cache" schema "github.com/specterops/bloodhound/packages/go/graphschema" @@ -109,18 +110,19 @@ func Entrypoint(ctx context.Context, cfg config.Configuration, connections boots startDelay := 0 * time.Second var ( - cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) - pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) - graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) - authorizer = auth.NewAuthorizer(connections.RDMS) - datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) - routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) - ctxInitializer = database.NewContextInitializer(connections.RDMS) - authenticator = api.NewAuthenticator(cfg, connections.RDMS, ctxInitializer) + cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) + pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) + graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) + authorizer = auth.NewAuthorizer(connections.RDMS) + datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) + routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) + ctxInitializer = database.NewContextInitializer(connections.RDMS) + authenticator = api.NewAuthenticator(cfg, connections.RDMS, ctxInitializer) + openGraphSchemaService = opengraphschema.NewOpenGraphSchemaService(connections.RDMS) ) registration.RegisterFossGlobalMiddleware(&routerInst, cfg, auth.NewIdentityResolver(), authenticator) - registration.RegisterFossRoutes(&routerInst, cfg, connections.RDMS, connections.Graph, graphQuery, apiCache, collectorManifests, authenticator, authorizer, ingestSchema) + registration.RegisterFossRoutes(&routerInst, cfg, connections.RDMS, connections.Graph, graphQuery, apiCache, collectorManifests, authenticator, authorizer, ingestSchema, openGraphSchemaService) // Set neo4j batch and flush sizes neo4jParameters := appcfg.GetNeo4jParameters(ctx, connections.RDMS) diff --git a/cmd/api/src/services/fs/mocks/fs.go b/cmd/api/src/services/fs/mocks/fs.go index 7d05415d18..5ba753d534 100644 --- a/cmd/api/src/services/fs/mocks/fs.go +++ b/cmd/api/src/services/fs/mocks/fs.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/graphify/mocks/ingest.go b/cmd/api/src/services/graphify/mocks/ingest.go index 81b68bbc28..c897f58c6d 100644 --- a/cmd/api/src/services/graphify/mocks/ingest.go +++ b/cmd/api/src/services/graphify/mocks/ingest.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/oidc/mocks/oidc.go b/cmd/api/src/services/oidc/mocks/oidc.go index 704420ef38..6faa71e5d5 100644 --- a/cmd/api/src/services/oidc/mocks/oidc.go +++ b/cmd/api/src/services/oidc/mocks/oidc.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/opengraphschema/environment.go b/cmd/api/src/services/opengraphschema/environment.go index 492da32a49..5c45693808 100644 --- a/cmd/api/src/services/opengraphschema/environment.go +++ b/cmd/api/src/services/opengraphschema/environment.go @@ -20,37 +20,106 @@ import ( "errors" "fmt" + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" "github.com/specterops/bloodhound/cmd/api/src/database" "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/dawgs/graph" ) -func (o *OpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, environment model.SchemaEnvironment, principalKinds []model.SchemaEnvironmentPrincipalKind) error { - // Validate environment - if err := o.validateSchemaEnvironment(ctx, environment); err != nil { - return fmt.Errorf("error validating schema environment: %w", err) - } +// UpsertSchemaEnvironmentWithPrincipalKinds takes a slice of environments, validates and translates each environment. +// The translation is used to upsert the environments into the database. +// If an existing environment is found to already exist in the database, the existing environment will be removed and the new one will be uploaded. +func (o *OpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []v2.Environment) error { + for _, env := range environments { + environment := model.SchemaEnvironment{ + SchemaExtensionId: schemaExtensionId, + } - // Validate principal kinds - for _, kind := range principalKinds { - if err := o.validateSchemaEnvironmentPrincipalKind(ctx, kind.PrincipalKind); err != nil { - return fmt.Errorf("error validating principal kind: %w", err) + if updatedEnv, principalKinds, err := o.validateAndTranslateEnvironment(ctx, environment, env); err != nil { + return fmt.Errorf("error validating and translating environment: %w", err) + } else if envID, err := o.upsertSchemaEnvironment(ctx, updatedEnv); err != nil { + return fmt.Errorf("error upserting schema environment: %w", err) + } else if err := o.upsertPrincipalKinds(ctx, envID, principalKinds); err != nil { + return fmt.Errorf("error upserting principal kinds: %w", err) } } - // Upsert the environment - id, err := o.upsertSchemaEnvironment(ctx, environment) - if err != nil { - return fmt.Errorf("error upserting schema environment: %w", err) + return nil +} + +// validateAndTranslateEnvironment validates that the environment kind, source kind, and the principal kinds exist in the database. +// It is then translated from the API model to the Database model to prepare it for insert. +func (o *OpenGraphSchemaService) validateAndTranslateEnvironment(ctx context.Context, environment model.SchemaEnvironment, env v2.Environment) (model.SchemaEnvironment, []model.SchemaEnvironmentPrincipalKind, error) { + if envKind, err := o.validateAndTranslateEnvironmentKind(ctx, env.EnvironmentKind); err != nil { + return model.SchemaEnvironment{}, nil, err + } else if sourceKindID, err := o.validateAndTranslateSourceKind(ctx, env.SourceKind); err != nil { + return model.SchemaEnvironment{}, nil, err + } else if principalKinds, err := o.validateAndTranslatePrincipalKinds(ctx, env.PrincipalKinds); err != nil { + return model.SchemaEnvironment{}, nil, err + } else { + // Update environment with translated IDs + environment.EnvironmentKindId = int32(envKind.ID) + environment.SourceKindId = sourceKindID + + return environment, principalKinds, nil } +} - // Upsert principal kinds for environment - if err := o.upsertPrincipalKinds(ctx, id, principalKinds); err != nil { - return fmt.Errorf("error upserting principal kinds: %w", err) +// validateAndTranslateEnvironmentKind validates that the environment kind exists in the kinds table. +func (o *OpenGraphSchemaService) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { + if envKind, err := o.openGraphSchemaRepository.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, database.ErrNotFound) { + return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) + } else if errors.Is(err, database.ErrNotFound){ + return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) + } else { + return envKind, nil } +} - return nil +// validateAndTranslateSourceKind validates that the source kind exists in the source_kinds table. +// If not found, it registers the source kind and returns its ID so it can be added to the Environment object. +func (o *OpenGraphSchemaService) validateAndTranslateSourceKind(ctx context.Context, sourceKindName string) (int32, error) { + if sourceKind, err := o.openGraphSchemaRepository.GetSourceKindByName(ctx, sourceKindName); err != nil && !errors.Is(err, database.ErrNotFound) { + return 0, fmt.Errorf("error retrieving source kind '%s': %w", sourceKindName, err) + } else if err == nil { + return int32(sourceKind.ID), nil + } + + // If source kind is not found, register it. If it exists and is inactive, it will automatically update as active. + kindType := graph.StringKind(sourceKindName) + if err := o.openGraphSchemaRepository.RegisterSourceKind(ctx)(kindType); err != nil { + return 0, fmt.Errorf("error registering source kind '%s': %w", sourceKindName, err) + } + + if sourceKind, err := o.openGraphSchemaRepository.GetSourceKindByName(ctx, sourceKindName); err != nil { + return 0, fmt.Errorf("error retrieving newly registered source kind '%s': %w", sourceKindName, err) + } else { + return int32(sourceKind.ID), nil + } +} + +// validateAndTranslatePrincipalKinds ensures all principalKinds exist in the kinds table. +// It also translates them to IDs so they can be upserted into the database. +func (o *OpenGraphSchemaService) validateAndTranslatePrincipalKinds(ctx context.Context, principalKindNames []string) ([]model.SchemaEnvironmentPrincipalKind, error) { + principalKinds := make([]model.SchemaEnvironmentPrincipalKind, len(principalKindNames)) + + for i, kindName := range principalKindNames { + if kind, err := o.openGraphSchemaRepository.GetKindByName(ctx, kindName); err != nil && !errors.Is(err, database.ErrNotFound) { + return nil, fmt.Errorf("error retrieving principal kind by name '%s': %w", kindName, err) + } else if errors.Is(err, database.ErrNotFound){ + return nil, fmt.Errorf("principal kind '%s' not found", kindName) + } else { + principalKinds[i] = model.SchemaEnvironmentPrincipalKind{ + PrincipalKind: int32(kind.ID), + } + } + } + + return principalKinds, nil } +// upsertSchemaEnvironment creates or updates a schema environment. +// If an environment with the given ID exists, it deletes it first before creating the new one. func (o *OpenGraphSchemaService) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { if existing, err := o.openGraphSchemaRepository.GetSchemaEnvironmentById(ctx, graphSchema.ID); err != nil && !errors.Is(err, database.ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment id %d: %w", graphSchema.ID, err) @@ -69,47 +138,7 @@ func (o *OpenGraphSchemaService) upsertSchemaEnvironment(ctx context.Context, gr } } -/* -Validations: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#10-validation-rules-for-environments - 1. Ensure the specified environmentKind exists in kinds table - ** QUESTION: Documentation states to use the kind table. Is there already a database method to query the kind table to do this validation or was I supposed to create it? - 2. Ensure the specified sourceKind exists in source_kinds table (create if it doesn't, reactivate if it does) -*/ -func (o *OpenGraphSchemaService) validateSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) error { - // Validate environment kind id exists in kinds table - if _, err := o.openGraphSchemaRepository.GetKindById(ctx, graphSchema.EnvironmentKindId); err != nil { - return fmt.Errorf("error retrieving environment kind: %w", err) - } - - // Get all source kinds - if sourceKinds, err := o.openGraphSchemaRepository.GetSourceKinds(ctx); err != nil { - return fmt.Errorf("error retrieving source kinds: %w", err) - } else { - // Check if source kind exists - found := false - for _, kind := range sourceKinds { - if graphSchema.SourceKindId == int32(kind.ID) { - found = true - break - } - } - - if !found { - /* - ** QUESTION: Example Environment Schema: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#9-environments-and-principal-kinds - The RFC example uses source kind names (strings) in the environment schema, but our model uses IDs (int32). To register a new source kind, we need the - kind name/data, not just the ID. Cannot register with only an ID - kind id/name is required for registration. - RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error - - For now, this validates that the source kind should exist. - */ - return fmt.Errorf("invalid source kind id %d", graphSchema.SourceKindId) - } - } - - return nil -} - +// upsertPrincipalKinds deletes all existing principal kinds for an environment and creates new ones. func (o *OpenGraphSchemaService) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { if existingKinds, err := o.openGraphSchemaRepository.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, database.ErrNotFound) { return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) @@ -131,19 +160,3 @@ func (o *OpenGraphSchemaService) upsertPrincipalKinds(ctx context.Context, envir return nil } - -/* -Validations: https://github.com/SpecterOps/BloodHound/blob/73b569a340ef5cd459b383e3e42e707b201193ee/rfc/bh-rfc-4.md#10-validation-rules-for-environments -1. Ensure all principalKinds exist in kinds table. - -** QUESTION: Documentation states to use the kind table. Is there already a database method to query the kind table to do this validation or was I supposed to create it? -*/ -func (o *OpenGraphSchemaService) validateSchemaEnvironmentPrincipalKind(ctx context.Context, kindID int32) error { - if _, err := o.openGraphSchemaRepository.GetKindById(ctx, kindID); err != nil && !errors.Is(err, database.ErrNotFound) { - return fmt.Errorf("error retrieving kind by id: %w", err) - } else if errors.Is(err, database.ErrNotFound) { - return fmt.Errorf("invalid principal kind id %d", kindID) - } - - return nil -} diff --git a/cmd/api/src/services/opengraphschema/environment_test.go b/cmd/api/src/services/opengraphschema/environment_test.go index 991509a6ed..e50362a423 100644 --- a/cmd/api/src/services/opengraphschema/environment_test.go +++ b/cmd/api/src/services/opengraphschema/environment_test.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package opengraphschema_test import ( @@ -5,6 +20,7 @@ import ( "errors" "testing" + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" "github.com/specterops/bloodhound/cmd/api/src/database" "github.com/specterops/bloodhound/cmd/api/src/model" "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" @@ -19,8 +35,8 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository } type args struct { - environment model.SchemaEnvironment - principalKinds []model.SchemaEnvironmentPrincipalKind + schemaExtensionId int32 + environments []v2.Environment } tests := []struct { name string @@ -29,179 +45,188 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes args args expected error }{ + // Validation: Environment Kind { - name: "Error: Validation - Failed to retrieve environment kind", + name: "Error: openGraphSchemaRepository.GetKindByName environment kind name not found in the database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(1), - SourceKindId: int32(1), + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, errors.New("error")) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{}, database.ErrNotFound) }, - expected: errors.New("error validating schema environment: error retrieving environment kind: error"), + expected: errors.New("error validating and translating environment: environment kind 'Domain' not found"), }, { - name: "Error: Validation - Environment Kind doesn't exist in Kinds table", + name: "Error: openGraphSchemaRepository.GetKindByName failed to retrieve environment kind from database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(1), - SourceKindId: int32(1), + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{}, errors.New("error")) }, - expected: errors.New("error validating schema environment: error retrieving environment kind: entity not found"), + expected: errors.New("error validating and translating environment: error retrieving environment kind 'Domain': error"), }, + // Validation: Source Kind { - name: "Error: Validation - Failed to retrieve source kinds", + name: "Error: validateAndTranslateSourceKind failed to retrieve source kind from database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(1), - SourceKindId: int32(1), + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(1)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{}, errors.New("error")) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, errors.New("error")) }, - expected: errors.New("error validating schema environment: error retrieving source kinds: error"), + expected: errors.New("error validating and translating environment: error retrieving source kind 'Base': error"), }, { - name: "Error: Validation - Source Kind doesn't exist", + name: "Error: validateAndTranslateSourceKind source kind name doesn't exist in database, registration fails", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(3), - SourceKindId: int32(1), + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { + return errors.New("error") + }) }, - expected: errors.New("error validating schema environment: invalid source kind id 1"), + expected: errors.New("error validating and translating environment: error registering source kind 'Base': error"), }, { - name: "Error: Validation - Failed to retrieve principal kind", + name: "Error: validateAndTranslateSourceKind source kind name doesn't exist in database, registration succeeds but fetch fails", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ { - EnvironmentId: int32(1), - PrincipalKind: int32(99), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { + return nil + }) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, errors.New("error")) + }, + expected: errors.New("error validating and translating environment: error retrieving newly registered source kind 'Base': error"), + }, + // Validation: Principal Kind + { + name: "Error: validateAndTranslatePrincipalKinds principal kind not found in database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ { - ID: 3, - Name: graph.StringKind("kind"), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "InvalidKind"}, }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(99)).Return(model.Kind{}, errors.New("error")) + }, + }, + setupMocks: func(t *testing.T, mocks *mocks) { + t.Helper() + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 2}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "InvalidKind").Return(model.Kind{}, database.ErrNotFound) }, - expected: errors.New("error validating principal kind: error retrieving kind by id: error"), + expected: errors.New("error validating and translating environment: principal kind 'InvalidKind' not found"), }, { - name: "Error: Validation - Principal Kind doesn't exist", + name: "Error: validateAndTranslatePrincipalKinds failed to retrieve principal kind from database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(1), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ { - EnvironmentId: int32(1), - PrincipalKind: int32(99), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "InvalidKind"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(99)).Return(model.Kind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 2}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "InvalidKind").Return(model.Kind{}, errors.New("error")) }, - expected: errors.New("error validating principal kind: invalid principal kind id 99"), + expected: errors.New("error validating and translating environment: error retrieving principal kind by name 'InvalidKind': error"), }, + // Upsert Schema Environment { - name: "Error: GetSchemaEnvironmentById fails", + name: "Error: upsertSchemaEnvironment error retrieving schema environment from database", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, errors.New("error")) }, expected: errors.New("error upserting schema environment: error retrieving schema environment id 0: error"), }, { - name: "Error: DeleteSchemaEnvironment fails", + name: "Error: upsertSchemaEnvironment error deleting schema environment", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ Serial: model.Serial{ID: 5}, }, nil) @@ -210,24 +235,21 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes expected: errors.New("error upserting schema environment: error deleting schema environment 5: error"), }, { - name: "Error: CreateSchemaEnvironment fails after delete", + name: "Error: upsertSchemaEnvironment error creating schema environment after deletion", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ Serial: model.Serial{ID: 5}, }, nil) @@ -237,55 +259,45 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes expected: errors.New("error upserting schema environment: error creating schema environment: error"), }, { - name: "Error: CreateSchemaEnvironment fails on new environment", + name: "Error: upsertSchemaEnvironment error creating new schema environment", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) }, expected: errors.New("error upserting schema environment: error creating schema environment: error"), }, + // Upsert Principal Kinds { - name: "Error: GetSchemaEnvironmentPrincipalKindsByEnvironmentId fails", + name: "Error: upsertPrincipalKinds error getting principal kinds by environment id", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) // Environment upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) @@ -299,31 +311,23 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes expected: errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), }, { - name: "Error: DeleteSchemaEnvironmentPrincipalKind fails", + name: "Error: upsertPrincipalKinds error deleting principal kinds", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) // Environment upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) @@ -343,31 +347,23 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes expected: errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), }, { - name: "Error: CreateSchemaEnvironmentPrincipalKind fails", + name: "Error: upsertPrincipalKinds error creating principal kinds", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) // Environment upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) @@ -384,37 +380,22 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes { name: "Success: Create new environment with principal kinds", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), - }, - { - PrincipalKind: int32(4), + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "Computer"}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind3"), - }, - { - ID: 4, - Name: graph.StringKind("kind4"), - }, - }, nil) - // Principal kind validations - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(4)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Computer").Return(model.Kind{ID: 5}, nil) // Environment upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) @@ -424,89 +405,87 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes // Principal kinds upsert mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(5)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) }, expected: nil, }, { - name: "Success: Update existing environment and replace principal kinds", + name: "Success: Create environment with source kind registration", args: args{ - environment: model.SchemaEnvironment{ - Serial: model.Serial{ID: 5}, - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), - }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{ + schemaExtensionId: int32(3), + environments: []v2.Environment{ { - PrincipalKind: int32(3), + EnvironmentKind: "Domain", + SourceKind: "NewSource", + PrincipalKinds: []string{}, }, }, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - // Principal kind validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) + // Validation and translation + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + // Source kind not found, register it + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "NewSource").Return(database.SourceKind{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { + return nil + }) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "NewSource").Return(database.SourceKind{ID: 10}, nil) - // Environment upsert (delete and recreate) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(5)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 5}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ + // Environment upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(10)).Return(model.SchemaEnvironment{ Serial: model.Serial{ID: 10}, }, nil) - // Principal kinds upsert (delete old, create new) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return([]model.SchemaEnvironmentPrincipalKind{ - { - EnvironmentId: int32(10), - PrincipalKind: int32(99), - }, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(99)).Return(nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + // Principal kinds upsert + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) }, expected: nil, }, { - name: "Success: Create environment with no principal kinds", + name: "Success: Process multiple environments", args: args{ - environment: model.SchemaEnvironment{ - SchemaExtensionId: int32(3), - EnvironmentKindId: int32(3), - SourceKindId: int32(3), + schemaExtensionId: int32(3), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKind: "AzureAD", + SourceKind: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, }, - principalKinds: []model.SchemaEnvironmentPrincipalKind{}, }, setupMocks: func(t *testing.T, mocks *mocks) { t.Helper() - // Environment validation - mocks.mockOpenGraphSchema.EXPECT().GetKindById(gomock.Any(), int32(3)).Return(model.Kind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKinds(gomock.Any()).Return([]database.SourceKind{ - { - ID: 3, - Name: graph.StringKind("kind"), - }, - }, nil) - - // Environment upsert + // First environment + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ Serial: model.Serial{ID: 10}, }, nil) - - // Principal kinds upsert (no existing, no new) mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + + // Second environment + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "AzureAD").Return(model.Kind{ID: 5}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "AzureHound").Return(database.SourceKind{ID: 6}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Group").Return(model.Kind{ID: 7}, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(5), int32(6)).Return(model.SchemaEnvironment{ + Serial: model.Serial{ID: 11}, + }, nil) + mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(11)).Return(nil, database.ErrNotFound) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(11), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) + mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(11), int32(7)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) }, expected: nil, }, @@ -524,9 +503,9 @@ func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *tes graphService := opengraphschema.NewOpenGraphSchemaService(mocks.mockOpenGraphSchema) - err := graphService.UpsertSchemaEnvironmentWithPrincipalKinds(context.Background(), tt.args.environment, tt.args.principalKinds) + err := graphService.UpsertSchemaEnvironmentWithPrincipalKinds(context.Background(), tt.args.schemaExtensionId, tt.args.environments) if tt.expected != nil { - assert.EqualError(t, tt.expected, err.Error()) + assert.EqualError(t, err, tt.expected.Error()) } else { assert.NoError(t, err) } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index 7723197208..a3793b1572 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -117,19 +117,19 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironmentPrin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) } -// GetKindById mocks base method. -func (m *MockOpenGraphSchemaRepository) GetKindById(ctx context.Context, id int32) (model.Kind, error) { +// GetKindByName mocks base method. +func (m *MockOpenGraphSchemaRepository) GetKindByName(ctx context.Context, name string) (model.Kind, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetKindById", ctx, id) + ret := m.ctrl.Call(m, "GetKindByName", ctx, name) ret0, _ := ret[0].(model.Kind) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetKindById indicates an expected call of GetKindById. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetKindById(ctx, id any) *gomock.Call { +// GetKindByName indicates an expected call of GetKindByName. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetKindByName(ctx, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindById", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetKindById), ctx, id) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetKindByName), ctx, name) } // GetSchemaEnvironmentById mocks base method. @@ -162,19 +162,19 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentPrincip return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) } -// GetSourceKinds mocks base method. -func (m *MockOpenGraphSchemaRepository) GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) { +// GetSourceKindByName mocks base method. +func (m *MockOpenGraphSchemaRepository) GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSourceKinds", ctx) - ret0, _ := ret[0].([]database.SourceKind) + ret := m.ctrl.Call(m, "GetSourceKindByName", ctx, name) + ret0, _ := ret[0].(database.SourceKind) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetSourceKinds indicates an expected call of GetSourceKinds. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSourceKinds(ctx any) *gomock.Call { +// GetSourceKindByName indicates an expected call of GetSourceKindByName. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSourceKindByName(ctx, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSourceKinds), ctx) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindByName", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSourceKindByName), ctx, name) } // RegisterSourceKind mocks base method. diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index 1ff658041d..cea2845ee9 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -28,7 +28,7 @@ import ( // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { // Kinds - GetKindById(ctx context.Context, id int32) (model.Kind, error) + GetKindByName(ctx context.Context, name string) (model.Kind, error) // Environment CreateSchemaEnvironment(ctx context.Context, schemaExtensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) @@ -37,7 +37,7 @@ type OpenGraphSchemaRepository interface { // Source Kinds RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error - GetSourceKinds(ctx context.Context) ([]database.SourceKind, error) + GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) // Principal Kinds CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) diff --git a/cmd/api/src/services/saml/mocks/saml.go b/cmd/api/src/services/saml/mocks/saml.go index f71839355b..e6029f5654 100644 --- a/cmd/api/src/services/saml/mocks/saml.go +++ b/cmd/api/src/services/saml/mocks/saml.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/services/upload/mocks/mock.go b/cmd/api/src/services/upload/mocks/mock.go index bd1cd94039..a3b90e7fb3 100644 --- a/cmd/api/src/services/upload/mocks/mock.go +++ b/cmd/api/src/services/upload/mocks/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/utils/validation/mocks/validator.go b/cmd/api/src/utils/validation/mocks/validator.go index 309fe51330..536c0385a5 100644 --- a/cmd/api/src/utils/validation/mocks/validator.go +++ b/cmd/api/src/utils/validation/mocks/validator.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/vendormocks/dawgs/graph/mock.go b/cmd/api/src/vendormocks/dawgs/graph/mock.go index 45ea562541..14f78cb94d 100644 --- a/cmd/api/src/vendormocks/dawgs/graph/mock.go +++ b/cmd/api/src/vendormocks/dawgs/graph/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/vendormocks/io/fs/mock.go b/cmd/api/src/vendormocks/io/fs/mock.go index 5c094281bd..f90a77b286 100644 --- a/cmd/api/src/vendormocks/io/fs/mock.go +++ b/cmd/api/src/vendormocks/io/fs/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/cmd/api/src/vendormocks/neo4j/neo4j-go-driver/v5/neo4j/mock.go b/cmd/api/src/vendormocks/neo4j/neo4j-go-driver/v5/neo4j/mock.go index a4adb00b99..1e3a6ab710 100644 --- a/cmd/api/src/vendormocks/neo4j/neo4j-go-driver/v5/neo4j/mock.go +++ b/cmd/api/src/vendormocks/neo4j/neo4j-go-driver/v5/neo4j/mock.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. diff --git a/packages/go/crypto/mocks/digest.go b/packages/go/crypto/mocks/digest.go index 3c1acaa2ce..ddd118ab95 100644 --- a/packages/go/crypto/mocks/digest.go +++ b/packages/go/crypto/mocks/digest.go @@ -1,4 +1,4 @@ -// Copyright 2025 Specter Ops, Inc. +// Copyright 2026 Specter Ops, Inc. // // Licensed under the Apache License, Version 2.0 // you may not use this file except in compliance with the License. From bc97cb007ff88dd577ea85d7f8bf98fac78ebf4d Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 6 Jan 2026 18:14:41 -0600 Subject: [PATCH 03/25] cleanup --- cmd/api/src/api/v2/opengraphschema.go | 8 ++++---- cmd/api/src/database/migration/migrations/v8.5.0.sql | 4 ---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index 50446d6a04..2a6459876f 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -40,6 +40,7 @@ type Environment struct { PrincipalKinds []string `json:"principalKinds"` } +// TODO: Implement this - barebones in order to test handler. func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { var ( ctx = request.Context() @@ -51,12 +52,11 @@ func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request * return } + // TODO: Pass Extension ID instead of harcoded value if err := s.openGraphSchemaService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { - api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("Error upserting schema environment with principal kinds: %w", err), request), response) + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("error upserting environment with principal kinds: %v", err), request), response) return } - api.WriteBasicResponse(request.Context(), map[string]string{ - "message": "Schema environments uploaded successfully", - }, http.StatusCreated, response) + api.WriteBasicResponse(request.Context(), "Success", http.StatusCreated, response) } diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 78cf431053..204da86fec 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -184,7 +184,3 @@ $$ END IF; END $$; - --- Insert a test schema extension -INSERT INTO schema_extensions (name, display_name, version, is_builtin) -VALUES ('test_schema', 'Test Schema', '1.0.0', false); From bce213baeee6d305988e2e03f8cd38dee4cf3f56 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 7 Jan 2026 12:38:52 -0600 Subject: [PATCH 04/25] transactions that hopefully satisfy both databases --- cmd/api/src/api/v2/opengraphschema.go | 12 ++++----- .../src/services/opengraphschema/extension.go | 25 +++++++++++++++++++ .../opengraphschema/mocks/opengraphschema.go | 20 +++++++++++++++ .../opengraphschema/opengraphschema.go | 5 ++++ 4 files changed, 55 insertions(+), 7 deletions(-) create mode 100644 cmd/api/src/services/opengraphschema/extension.go diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index 2a6459876f..6600045f60 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -26,11 +26,10 @@ import ( //go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphschemaextensions.go -package=mocks . OpenGraphSchemaService type OpenGraphSchemaService interface { - UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []Environment) error + UpsertGraphSchemaExtension(ctx context.Context, req GraphSchemaExtension) error } -type SchemaUploadRequest struct { - ID int32 `` +type GraphSchemaExtension struct { Environments []Environment `json:"environments"` } @@ -46,15 +45,14 @@ func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request * ctx = request.Context() ) - var req SchemaUploadRequest + var req GraphSchemaExtension if err := json.NewDecoder(request.Body).Decode(&req); err != nil { api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, api.ErrorResponsePayloadUnmarshalError, request), response) return } - // TODO: Pass Extension ID instead of harcoded value - if err := s.openGraphSchemaService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { - api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("error upserting environment with principal kinds: %v", err), request), response) + if err := s.openGraphSchemaService.UpsertGraphSchemaExtension(ctx, req); err != nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("error upserting graph schema extension: %v", err), request), response) return } diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go new file mode 100644 index 0000000000..22848aa05f --- /dev/null +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -0,0 +1,25 @@ +package opengraphschema + +import ( + "context" + "fmt" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" +) + +func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { + return o.transactor.WithTransaction(ctx, func(repo OpenGraphSchemaRepository) error { + txService := &OpenGraphSchemaService{ + openGraphSchemaRepository: repo, + transactor: o.transactor, + } + + // Upsert environments with principal kinds + // TODO: Temporary hardcoded extension ID + if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + } + + return nil + }) +} diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index a3793b1572..8ff7dabfa8 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -27,6 +27,7 @@ package mocks import ( context "context" + sql "database/sql" reflect "reflect" database "github.com/specterops/bloodhound/cmd/api/src/database" @@ -190,3 +191,22 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) } + +// Transaction mocks base method. +func (m *MockOpenGraphSchemaRepository) Transaction(ctx context.Context, fn func(*database.BloodhoundDB) error, opts ...*sql.TxOptions) error { + m.ctrl.T.Helper() + varargs := []any{ctx, fn} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Transaction", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Transaction indicates an expected call of Transaction. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Transaction(ctx, fn any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx, fn}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Transaction), varargs...) +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index cea2845ee9..a37535f639 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -47,6 +47,11 @@ type OpenGraphSchemaRepository interface { type OpenGraphSchemaService struct { openGraphSchemaRepository OpenGraphSchemaRepository + transactor Transactor +} + +type Transactor interface { + WithTransaction(ctx context.Context, fn func(repo OpenGraphSchemaRepository) error) error } func NewOpenGraphSchemaService(openGraphSchemaRepository OpenGraphSchemaRepository) *OpenGraphSchemaService { From 89998ee5b21626835ed46d9cb7aee32fd5a769cf Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 7 Jan 2026 12:43:49 -0600 Subject: [PATCH 05/25] just prepare --- .../src/api/v2/mocks/graphschemaextensions.go | 12 ++++---- cmd/api/src/config/config.go | 2 +- cmd/api/src/config/default.go | 2 +- .../services/opengraphschema/environment.go | 4 +-- .../src/services/opengraphschema/extension.go | 29 ++++++++++++++----- .../opengraphschema/mocks/opengraphschema.go | 20 ------------- 6 files changed, 32 insertions(+), 37 deletions(-) diff --git a/cmd/api/src/api/v2/mocks/graphschemaextensions.go b/cmd/api/src/api/v2/mocks/graphschemaextensions.go index 96cb240811..21803d8b7b 100644 --- a/cmd/api/src/api/v2/mocks/graphschemaextensions.go +++ b/cmd/api/src/api/v2/mocks/graphschemaextensions.go @@ -57,16 +57,16 @@ func (m *MockOpenGraphSchemaService) EXPECT() *MockOpenGraphSchemaServiceMockRec return m.recorder } -// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. -func (m *MockOpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []v2.Environment) error { +// UpsertGraphSchemaExtension mocks base method. +func (m *MockOpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environments) + ret := m.ctrl.Call(m, "UpsertGraphSchemaExtension", ctx, req) ret0, _ := ret[0].(error) return ret0 } -// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. -func (mr *MockOpenGraphSchemaServiceMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environments any) *gomock.Call { +// UpsertGraphSchemaExtension indicates an expected call of UpsertGraphSchemaExtension. +func (mr *MockOpenGraphSchemaServiceMockRecorder) UpsertGraphSchemaExtension(ctx, req any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaService)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environments) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertGraphSchemaExtension", reflect.TypeOf((*MockOpenGraphSchemaService)(nil).UpsertGraphSchemaExtension), ctx, req) } diff --git a/cmd/api/src/config/config.go b/cmd/api/src/config/config.go index b928d38d3a..aa8a218aef 100644 --- a/cmd/api/src/config/config.go +++ b/cmd/api/src/config/config.go @@ -162,7 +162,7 @@ type Configuration struct { DisableCypherComplexityLimit bool `json:"disable_cypher_complexity_limit"` DisableIngest bool `json:"disable_ingest"` DisableMigrations bool `json:"disable_migrations"` - DisableTimeoutLimit bool `json:"disable_timeout_limit"` + DisableTimeoutLimit bool `json:"disable_timeout_limit"` GraphQueryMemoryLimit uint16 `json:"graph_query_memory_limit"` EnableTextLogger bool `json:"enable_text_logger"` RecreateDefaultAdmin bool `json:"recreate_default_admin"` diff --git a/cmd/api/src/config/default.go b/cmd/api/src/config/default.go index 656a308978..15efb8a6e0 100644 --- a/cmd/api/src/config/default.go +++ b/cmd/api/src/config/default.go @@ -65,7 +65,7 @@ func NewDefaultConfiguration() (Configuration, error) { DisableCypherComplexityLimit: false, DisableIngest: false, DisableMigrations: false, - DisableTimeoutLimit: false, + DisableTimeoutLimit: false, EnableCypherMutations: false, RecreateDefaultAdmin: false, ForceDownloadEmbeddedCollectors: false, diff --git a/cmd/api/src/services/opengraphschema/environment.go b/cmd/api/src/services/opengraphschema/environment.go index 5c45693808..662d429637 100644 --- a/cmd/api/src/services/opengraphschema/environment.go +++ b/cmd/api/src/services/opengraphschema/environment.go @@ -69,7 +69,7 @@ func (o *OpenGraphSchemaService) validateAndTranslateEnvironment(ctx context.Con func (o *OpenGraphSchemaService) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { if envKind, err := o.openGraphSchemaRepository.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, database.ErrNotFound) { return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) - } else if errors.Is(err, database.ErrNotFound){ + } else if errors.Is(err, database.ErrNotFound) { return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) } else { return envKind, nil @@ -106,7 +106,7 @@ func (o *OpenGraphSchemaService) validateAndTranslatePrincipalKinds(ctx context. for i, kindName := range principalKindNames { if kind, err := o.openGraphSchemaRepository.GetKindByName(ctx, kindName); err != nil && !errors.Is(err, database.ErrNotFound) { return nil, fmt.Errorf("error retrieving principal kind by name '%s': %w", kindName, err) - } else if errors.Is(err, database.ErrNotFound){ + } else if errors.Is(err, database.ErrNotFound) { return nil, fmt.Errorf("principal kind '%s' not found", kindName) } else { principalKinds[i] = model.SchemaEnvironmentPrincipalKind{ diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 22848aa05f..5c8cbeb456 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package opengraphschema import ( @@ -8,18 +23,18 @@ import ( ) func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - return o.transactor.WithTransaction(ctx, func(repo OpenGraphSchemaRepository) error { - txService := &OpenGraphSchemaService{ - openGraphSchemaRepository: repo, - transactor: o.transactor, - } + return o.transactor.WithTransaction(ctx, func(repo OpenGraphSchemaRepository) error { + txService := &OpenGraphSchemaService{ + openGraphSchemaRepository: repo, + transactor: o.transactor, + } // Upsert environments with principal kinds // TODO: Temporary hardcoded extension ID - if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { return fmt.Errorf("failed to upload environments with principal kinds: %w", err) } return nil - }) + }) } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index 8ff7dabfa8..a3793b1572 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -27,7 +27,6 @@ package mocks import ( context "context" - sql "database/sql" reflect "reflect" database "github.com/specterops/bloodhound/cmd/api/src/database" @@ -191,22 +190,3 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) } - -// Transaction mocks base method. -func (m *MockOpenGraphSchemaRepository) Transaction(ctx context.Context, fn func(*database.BloodhoundDB) error, opts ...*sql.TxOptions) error { - m.ctrl.T.Helper() - varargs := []any{ctx, fn} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "Transaction", varargs...) - ret0, _ := ret[0].(error) - return ret0 -} - -// Transaction indicates an expected call of Transaction. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Transaction(ctx, fn any, opts ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx, fn}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Transaction), varargs...) -} From 6ded4e314ca557d3507a0bc31dbee3ae6a0e2627 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 7 Jan 2026 15:03:34 -0600 Subject: [PATCH 06/25] abandoned the transactor --- cmd/api/src/database/db.go | 37 +++++++++++++++++++ .../src/services/opengraphschema/extension.go | 35 +++++++++++------- 2 files changed, 59 insertions(+), 13 deletions(-) diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 7f776d595f..8dd772f842 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -238,6 +238,43 @@ func (s *BloodhoundDB) Transaction(ctx context.Context, fn func(tx *BloodhoundDB }, opts...) } +/* Manual Transaction Control +The following methods provide manual control over transactions as an alternative to the automatic Transaction method above. +Use these when you need explicit control over when to commit or rollback. +- BeginTransaction +- Commit +- Rollback +*/ + +// BeginTransaction starts a new database transaction and returns a transactional-aware BloodhoundDB. +func (s *BloodhoundDB) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*BloodhoundDB, error) { + tx := s.db.WithContext(ctx).Begin(opts...) + if tx.Error != nil { + return nil, fmt.Errorf("error beginning transaction: %w", tx.Error) + } + + return &BloodhoundDB{ + db: tx, + idResolver: s.idResolver, + }, nil +} + +// Commit commits the transaction and releases the database connection back to the pool. +func (s *BloodhoundDB) Commit() error { + if err := s.db.Commit().Error; err != nil { + return fmt.Errorf("error committing transaction: %w", err) + } + return nil +} + +// Rollback rolls back the transaction and releases the database connection back to the pool. +func (s *BloodhoundDB) Rollback() error { + if err := s.db.Rollback().Error; err != nil { + return fmt.Errorf("error rolling back transaction: %w", err) + } + return nil +} + func OpenDatabase(connection string) (*gorm.DB, error) { gormConfig := &gorm.Config{ Logger: &GormLogAdapter{ diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 5c8cbeb456..e19ff8aaf3 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -20,21 +20,30 @@ import ( "fmt" v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/database" ) func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - return o.transactor.WithTransaction(ctx, func(repo OpenGraphSchemaRepository) error { - txService := &OpenGraphSchemaService{ - openGraphSchemaRepository: repo, - transactor: o.transactor, - } + db, ok := o.openGraphSchemaRepository.(*database.BloodhoundDB) + if !ok { + return fmt.Errorf("database not found: unable to begin transaction") + } - // Upsert environments with principal kinds - // TODO: Temporary hardcoded extension ID - if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { - return fmt.Errorf("failed to upload environments with principal kinds: %w", err) - } + tx, err := db.BeginTransaction(ctx) + if err != nil { + return err + } - return nil - }) + txService := &OpenGraphSchemaService{ + openGraphSchemaRepository: tx, + } + + // TODO: Temporary hardcoded extension ID + // Upsert environments with principal kinds + if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + tx.Rollback() + return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + } + + return tx.Commit() } From 6c082854ed18a26f9836ad99bd3a86fe37de3844 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 7 Jan 2026 15:07:13 -0600 Subject: [PATCH 07/25] abandoned the transactor --- cmd/api/src/services/opengraphschema/opengraphschema.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index a37535f639..cea2845ee9 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -47,11 +47,6 @@ type OpenGraphSchemaRepository interface { type OpenGraphSchemaService struct { openGraphSchemaRepository OpenGraphSchemaRepository - transactor Transactor -} - -type Transactor interface { - WithTransaction(ctx context.Context, fn func(repo OpenGraphSchemaRepository) error) error } func NewOpenGraphSchemaService(openGraphSchemaRepository OpenGraphSchemaRepository) *OpenGraphSchemaService { From 4ae5dfbc616b05505bed4d15455c4ded2a696b47 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 8 Jan 2026 11:16:29 -0600 Subject: [PATCH 08/25] entry pointed corrected so transactions can live on the interface again and all is well in the world --- .../src/services/opengraphschema/extension.go | 13 ++--- .../opengraphschema/mocks/opengraphschema.go | 49 +++++++++++++++++++ .../opengraphschema/opengraphschema.go | 6 +++ 3 files changed, 59 insertions(+), 9 deletions(-) diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index e19ff8aaf3..aefcff32e3 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -20,27 +20,22 @@ import ( "fmt" v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/database" ) func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - db, ok := o.openGraphSchemaRepository.(*database.BloodhoundDB) - if !ok { - return fmt.Errorf("database not found: unable to begin transaction") - } - - tx, err := db.BeginTransaction(ctx) + tx, err := o.openGraphSchemaRepository.BeginTransaction(ctx) if err != nil { return err } - txService := &OpenGraphSchemaService{ + // Create service with transaction + transactionalService := &OpenGraphSchemaService{ openGraphSchemaRepository: tx, } // TODO: Temporary hardcoded extension ID // Upsert environments with principal kinds - if err := txService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { + if err := transactionalService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { tx.Rollback() return fmt.Errorf("failed to upload environments with principal kinds: %w", err) } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index a3793b1572..e9f00e40dc 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -27,6 +27,7 @@ package mocks import ( context "context" + sql "database/sql" reflect "reflect" database "github.com/specterops/bloodhound/cmd/api/src/database" @@ -59,6 +60,40 @@ func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryM return m.recorder } +// BeginTransaction mocks base method. +func (m *MockOpenGraphSchemaRepository) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*database.BloodhoundDB, error) { + m.ctrl.T.Helper() + varargs := []any{ctx} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BeginTransaction", varargs...) + ret0, _ := ret[0].(*database.BloodhoundDB) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BeginTransaction indicates an expected call of BeginTransaction. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) BeginTransaction(ctx any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{ctx}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTransaction", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).BeginTransaction), varargs...) +} + +// Commit mocks base method. +func (m *MockOpenGraphSchemaRepository) Commit() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit") + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Commit() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Commit)) +} + // CreateSchemaEnvironment mocks base method. func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironment(ctx context.Context, schemaExtensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { m.ctrl.T.Helper() @@ -190,3 +225,17 @@ func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) } + +// Rollback mocks base method. +func (m *MockOpenGraphSchemaRepository) Rollback() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback") + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Rollback() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Rollback)) +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index cea2845ee9..21e61bd088 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -19,6 +19,7 @@ package opengraphschema import ( "context" + "database/sql" "github.com/specterops/bloodhound/cmd/api/src/database" "github.com/specterops/bloodhound/cmd/api/src/model" @@ -27,6 +28,11 @@ import ( // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { + // TX + BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*database.BloodhoundDB, error) + Commit() error + Rollback() error + // Kinds GetKindByName(ctx context.Context, name string) (model.Kind, error) From a3d942284dff58a408ac14e4ef5c39dde6bb5fc0 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 8 Jan 2026 11:56:44 -0600 Subject: [PATCH 09/25] some cleanup --- cmd/api/src/api/v2/opengraphschema.go | 2 +- cmd/api/src/database/db.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index 6600045f60..a49248fdf0 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -56,5 +56,5 @@ func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request * return } - api.WriteBasicResponse(request.Context(), "Success", http.StatusCreated, response) + response.WriteHeader(http.StatusCreated) } diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 8dd772f842..c8ea2a713a 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -246,7 +246,7 @@ Use these when you need explicit control over when to commit or rollback. - Rollback */ -// BeginTransaction starts a new database transaction and returns a transactional-aware BloodhoundDB. +// BeginTransaction starts a new database transaction and returns a transactional-aware connection of BloodhoundDB. func (s *BloodhoundDB) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*BloodhoundDB, error) { tx := s.db.WithContext(ctx).Begin(opts...) if tx.Error != nil { From 5987407d881d2064e459fe0f4d9b9d2fe0489fa0 Mon Sep 17 00:00:00 2001 From: Conrad Weidenkeller Date: Mon, 29 Dec 2025 07:46:16 -0600 Subject: [PATCH 10/25] feat(graphschema): Add Environment principal kinds BED-7076 --- cmd/api/src/database/graphschema.go | 44 +++ .../database/graphschema_integration_test.go | 254 ++++++++++++++++++ .../database/migration/migrations/v8.5.0.sql | 1 + cmd/api/src/database/mocks/db.go | 44 +++ cmd/api/src/model/graphschema.go | 12 + 5 files changed, 355 insertions(+) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 85ed961b08..70613a9769 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -65,6 +65,9 @@ type OpenGraphSchema interface { GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) UpdateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) DeleteRemediation(ctx context.Context, findingId int32) error + CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) + GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) + DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error } const DuplicateKeyValueErrorString = "duplicate key value violates unique constraint" @@ -708,6 +711,47 @@ func (s *BloodhoundDB) DeleteRemediation(ctx context.Context, findingId int32) e return nil } +func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + var envPrincipalKind model.SchemaEnvironmentPrincipalKind + + if result := s.db.WithContext(ctx).Raw(` + INSERT INTO schema_environments_principal_kinds (environment_id, principal_kind, created_at) + VALUES (?, ?, NOW()) + RETURNING environment_id, principal_kind, created_at`, + environmentId, principalKind).Scan(&envPrincipalKind); result.Error != nil { + return model.SchemaEnvironmentPrincipalKind{}, CheckError(result) + } + + return envPrincipalKind, nil +} + +func (s *BloodhoundDB) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + var envPrincipalKinds model.SchemaEnvironmentPrincipalKinds + + if result := s.db.WithContext(ctx).Raw(` + SELECT environment_id, principal_kind, created_at + FROM schema_environments_principal_kinds + WHERE environment_id = ?`, + environmentId).Scan(&envPrincipalKinds); result.Error != nil { + return nil, CheckError(result) + } + + return envPrincipalKinds, nil +} + +func (s *BloodhoundDB) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error { + if result := s.db.WithContext(ctx).Exec(` + DELETE FROM schema_environments_principal_kinds + WHERE environment_id = ? AND principal_kind = ?`, + environmentId, principalKind); result.Error != nil { + return CheckError(result) + } else if result.RowsAffected == 0 { + return ErrNotFound + } + + return nil +} + func parseFiltersAndPagination(filters model.Filters, sort model.Sort, skip, limit int) (FilterAndPagination, error) { var ( filtersAndPagination FilterAndPagination diff --git a/cmd/api/src/database/graphschema_integration_test.go b/cmd/api/src/database/graphschema_integration_test.go index 6cf19ce078..a55588f26f 100644 --- a/cmd/api/src/database/graphschema_integration_test.go +++ b/cmd/api/src/database/graphschema_integration_test.go @@ -2162,3 +2162,257 @@ func TestDeleteRemediation(t *testing.T) { }) } } + +func TestCreateSchemaEnvironmentPrincipalKind(t *testing.T) { + type args struct { + environmentId int32 + principalKind int32 + } + type want struct { + res model.SchemaEnvironmentPrincipalKind + err error + } + tests := []struct { + name string + setup func() IntegrationTestSuite + args args + want want + }{ + { + name: "Success: schema environment principal kind created", + setup: func() IntegrationTestSuite { + t.Helper() + testSuite := setupIntegrationTestSuite(t) + + _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "EnvPrincipalKindExt", "Env Principal Kind Extension", "v1.0.0") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + require.NoError(t, err) + + return testSuite + }, + args: args{ + environmentId: 1, + principalKind: 1, + }, + want: want{ + res: model.SchemaEnvironmentPrincipalKind{ + EnvironmentId: 1, + PrincipalKind: 1, + }, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + result, err := testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) + if testCase.want.err != nil { + assert.ErrorIs(t, err, testCase.want.err) + } else { + assert.NoError(t, err) + assert.Equal(t, testCase.want.res.EnvironmentId, result.EnvironmentId) + assert.Equal(t, testCase.want.res.PrincipalKind, result.PrincipalKind) + } + }) + } +} + +func TestGetSchemaEnvironmentPrincipalKindsByEnvironmentId(t *testing.T) { + type args struct { + environmentId int32 + } + type want struct { + count int + err error + } + tests := []struct { + name string + setup func() IntegrationTestSuite + args args + want want + }{ + { + name: "Success: get schema environment principal kinds by environment id", + setup: func() IntegrationTestSuite { + t.Helper() + testSuite := setupIntegrationTestSuite(t) + + _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetEnvPrincipalKindExt", "Get Env Principal Kind Extension", "v1.0.0") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 1) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 2) + require.NoError(t, err) + + return testSuite + }, + args: args{ + environmentId: 1, + }, + want: want{ + count: 2, + }, + }, + { + name: "Success: returns empty slice when no principal kinds exist", + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + args: args{ + environmentId: 9999, + }, + want: want{ + count: 0, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + result, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) + if testCase.want.err != nil { + assert.ErrorIs(t, err, testCase.want.err) + } else { + assert.NoError(t, err) + assert.Len(t, result, testCase.want.count) + } + }) + } +} + +func TestDeleteSchemaEnvironmentPrincipalKind(t *testing.T) { + type args struct { + environmentId int32 + principalKind int32 + } + type want struct { + err error + } + tests := []struct { + name string + setup func() IntegrationTestSuite + args args + want want + }{ + { + name: "Success: delete schema environment principal kind", + setup: func() IntegrationTestSuite { + t.Helper() + testSuite := setupIntegrationTestSuite(t) + + _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "DeleteEnvPrincipalKindExt", "Delete Env Principal Kind Extension", "v1.0.0") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 1) + require.NoError(t, err) + + return testSuite + }, + args: args{ + environmentId: 1, + principalKind: 1, + }, + want: want{ + err: nil, + }, + }, + { + name: "Fail: schema environment principal kind not found", + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + args: args{ + environmentId: 9999, + principalKind: 9999, + }, + want: want{ + err: database.ErrNotFound, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + err := testSuite.BHDatabase.DeleteSchemaEnvironmentPrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) + if testCase.want.err != nil { + assert.ErrorIs(t, err, testCase.want.err) + } else { + assert.NoError(t, err) + result, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) + assert.NoError(t, err) + assert.Len(t, result, 0) + } + }) + } +} + +func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { + t.Parallel() + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extension, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "CascadeTestExtension", "Cascade Test Extension", "v1.0.0") + require.NoError(t, err) + + nodeKind, err := testSuite.BHDatabase.CreateGraphSchemaNodeKind(testSuite.Context, "CascadeTestNodeKind", extension.ID, "Cascade Test Node Kind", "Test description", false, "fa-test", "#000000") + require.NoError(t, err) + + property, err := testSuite.BHDatabase.CreateGraphSchemaProperty(testSuite.Context, extension.ID, "cascade_test_property", "Cascade Test Property", "string", "Test description") + require.NoError(t, err) + + edgeKind, err := testSuite.BHDatabase.CreateGraphSchemaEdgeKind(testSuite.Context, "CascadeTestEdgeKind", extension.ID, "Test description", true) + require.NoError(t, err) + + environment, err := testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, extension.ID, nodeKind.ID, nodeKind.ID) + require.NoError(t, err) + + relationshipFinding, err := testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, extension.ID, edgeKind.ID, environment.ID, "CascadeTestFinding", "Cascade Test Finding") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateRemediation(testSuite.Context, relationshipFinding.ID, "Short desc", "Long desc", "Short remediation", "Long remediation") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, environment.ID, nodeKind.ID) + require.NoError(t, err) + + err = testSuite.BHDatabase.DeleteGraphSchemaExtension(testSuite.Context, extension.ID) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.GetGraphSchemaNodeKindById(testSuite.Context, nodeKind.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetGraphSchemaPropertyById(testSuite.Context, property.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetGraphSchemaEdgeKindById(testSuite.Context, edgeKind.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, environment.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetSchemaRelationshipFindingById(testSuite.Context, relationshipFinding.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + _, err = testSuite.BHDatabase.GetRemediationByFindingId(testSuite.Context, relationshipFinding.ID) + assert.ErrorIs(t, err, database.ErrNotFound) + + principalKinds, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, environment.ID) + assert.NoError(t, err) + assert.Len(t, principalKinds, 0) +} diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 204da86fec..91e6636af6 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -148,6 +148,7 @@ CREATE INDEX IF NOT EXISTS idx_schema_remediations_content_type ON schema_remedi CREATE TABLE IF NOT EXISTS schema_environments_principal_kinds ( environment_id INTEGER NOT NULL REFERENCES schema_environments(id) ON DELETE CASCADE, principal_kind INTEGER NOT NULL REFERENCES kind(id), + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, PRIMARY KEY(environment_id, principal_kind) ); diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 22f96083e9..cf16fece30 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -585,6 +585,21 @@ func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, en return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) } +// CreateSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockDatabase) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. +func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + // CreateSchemaRelationshipFinding mocks base method. func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() @@ -986,6 +1001,20 @@ func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) } +// DeleteSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockDatabase) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. +func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + // DeleteSchemaRelationshipFinding mocks base method. func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -2158,6 +2187,21 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) } +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. +func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. +func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) +} + // GetSchemaEnvironments mocks base method. func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index 38c1ac2867..397ba77afc 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -136,6 +136,18 @@ func (Remediation) TableName() string { return "schema_remediations" } +type SchemaEnvironmentPrincipalKinds []SchemaEnvironmentPrincipalKind + +type SchemaEnvironmentPrincipalKind struct { + EnvironmentId int32 `json:"environment_id"` + PrincipalKind int32 `json:"principal_kind"` + CreatedAt time.Time `json:"created_at"` +} + +func (SchemaEnvironmentPrincipalKind) TableName() string { + return "schema_environments_principal_kinds" +} + func (GraphSchemaEdgeKind) ValidFilters() map[string][]FilterOperator { return ValidFilters{ "is_traversable": {Equals, NotEquals}, From e76b52dd1518c97764c6e08a2c5247f3abfd5397 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:04:13 -0600 Subject: [PATCH 11/25] pretty decent sized refactor to move from service layer to database layer + integration tests --- cmd/api/src/database/db.go | 40 +- cmd/api/src/database/graphschema.go | 21 + cmd/api/src/database/kind.go | 14 +- .../database/migration/migrations/v8.5.0.sql | 2 +- cmd/api/src/database/mocks/db.go | 74 +++ cmd/api/src/database/sourcekinds.go | 10 +- .../src/database/upsert_schema_environment.go | 160 ++++++ ...ert_schema_environment_integration_test.go | 305 +++++++++++ .../services/opengraphschema/environment.go | 162 ------ .../opengraphschema/environment_test.go | 514 ------------------ .../src/services/opengraphschema/extension.go | 21 +- .../opengraphschema/mocks/opengraphschema.go | 182 +------ .../opengraphschema/opengraphschema.go | 27 +- .../opengraphschema/opengraphschema_test.go | 499 +++++++++++++++++ 14 files changed, 1095 insertions(+), 936 deletions(-) create mode 100644 cmd/api/src/database/upsert_schema_environment.go create mode 100644 cmd/api/src/database/upsert_schema_environment_integration_test.go delete mode 100644 cmd/api/src/services/opengraphschema/environment.go delete mode 100644 cmd/api/src/services/opengraphschema/environment_test.go create mode 100644 cmd/api/src/services/opengraphschema/opengraphschema_test.go diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index c8ea2a713a..acee7c7914 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -189,6 +189,9 @@ type Database interface { // OpenGraph Schema OpenGraphSchema + + // Kind + Kind } type BloodhoundDB struct { @@ -238,43 +241,6 @@ func (s *BloodhoundDB) Transaction(ctx context.Context, fn func(tx *BloodhoundDB }, opts...) } -/* Manual Transaction Control -The following methods provide manual control over transactions as an alternative to the automatic Transaction method above. -Use these when you need explicit control over when to commit or rollback. -- BeginTransaction -- Commit -- Rollback -*/ - -// BeginTransaction starts a new database transaction and returns a transactional-aware connection of BloodhoundDB. -func (s *BloodhoundDB) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*BloodhoundDB, error) { - tx := s.db.WithContext(ctx).Begin(opts...) - if tx.Error != nil { - return nil, fmt.Errorf("error beginning transaction: %w", tx.Error) - } - - return &BloodhoundDB{ - db: tx, - idResolver: s.idResolver, - }, nil -} - -// Commit commits the transaction and releases the database connection back to the pool. -func (s *BloodhoundDB) Commit() error { - if err := s.db.Commit().Error; err != nil { - return fmt.Errorf("error committing transaction: %w", err) - } - return nil -} - -// Rollback rolls back the transaction and releases the database connection back to the pool. -func (s *BloodhoundDB) Rollback() error { - if err := s.db.Rollback().Error; err != nil { - return fmt.Errorf("error rolling back transaction: %w", err) - } - return nil -} - func OpenDatabase(connection string) (*gorm.DB, error) { gormConfig := &gorm.Config{ Logger: &GormLogAdapter{ diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 3801cc5248..b7e81099be 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -52,6 +52,7 @@ type OpenGraphSchema interface { GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) + GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error @@ -59,6 +60,10 @@ type OpenGraphSchema interface { CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error + + CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) + GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) + DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error } const DuplicateKeyValueErrorString = "duplicate key value violates unique constraint" @@ -546,6 +551,22 @@ func (s *BloodhoundDB) GetSchemaEnvironmentById(ctx context.Context, environment return schemaEnvironment, nil } +// GetSchemaEnvironmentByKinds - retrieves an environment by its environment kind and source kind. +func (s *BloodhoundDB) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + var env model.SchemaEnvironment + + if result := s.db.WithContext(ctx).Raw( + "SELECT * FROM schema_environments WHERE environment_kind_id = ? AND source_kind_id = ? AND deleted_at IS NULL", + environmentKindId, sourceKindId, + ).Scan(&env); result.Error != nil { + return model.SchemaEnvironment{}, CheckError(result) + } else if result.RowsAffected == 0 { + return model.SchemaEnvironment{}, ErrNotFound + } + + return env, nil +} + // DeleteSchemaEnvironment - deletes a schema environment by id. func (s *BloodhoundDB) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error { var schemaEnvironment model.SchemaEnvironment diff --git a/cmd/api/src/database/kind.go b/cmd/api/src/database/kind.go index 87632f6ae3..0487f28375 100644 --- a/cmd/api/src/database/kind.go +++ b/cmd/api/src/database/kind.go @@ -21,6 +21,10 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/model" ) +type Kind interface { + GetKindByName(ctx context.Context, name string) (model.Kind, error) +} + func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Kind, error) { const query = ` SELECT id, name @@ -29,8 +33,14 @@ func (s *BloodhoundDB) GetKindByName(ctx context.Context, name string) (model.Ki ` var kind model.Kind - if err := s.db.Raw(query, name).Scan(&kind).Error; err != nil { - return model.Kind{}, err + result := s.db.WithContext(ctx).Raw(query, name).Scan(&kind) + + if result.Error != nil { + return model.Kind{}, result.Error + } + + if result.RowsAffected == 0 || kind.ID == 0 { + return model.Kind{}, ErrNotFound } return kind, nil diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 204da86fec..e73d008e6e 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -93,7 +93,7 @@ CREATE TABLE IF NOT EXISTS schema_environments ( id SERIAL, schema_extension_id INTEGER NOT NULL REFERENCES schema_extensions(id) ON DELETE CASCADE, environment_kind_id INTEGER NOT NULL REFERENCES kind(id), - source_kind_id INTEGER NOT NULL REFERENCES kind(id), + source_kind_id INTEGER NOT NULL REFERENCES source_kinds(id), PRIMARY KEY (id), created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index ae5c1b5e1c..dd57b73085 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -570,6 +570,21 @@ func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, en return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) } +// CreateSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockDatabase) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. +func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + // CreateSchemaRelationshipFinding mocks base method. func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() @@ -957,6 +972,20 @@ func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) } +// DeleteSchemaEnvironmentPrincipalKind mocks base method. +func (m *MockDatabase) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. +func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) +} + // DeleteSchemaRelationshipFinding mocks base method. func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -1889,6 +1918,21 @@ func (mr *MockDatabaseMockRecorder) GetInstallation(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInstallation", reflect.TypeOf((*MockDatabase)(nil).GetInstallation), ctx) } +// GetKindByName mocks base method. +func (m *MockDatabase) GetKindByName(ctx context.Context, name string) (model.Kind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKindByName", ctx, name) + ret0, _ := ret[0].(model.Kind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKindByName indicates an expected call of GetKindByName. +func (mr *MockDatabaseMockRecorder) GetKindByName(ctx, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockDatabase)(nil).GetKindByName), ctx, name) +} + // GetLatestAssetGroupCollection mocks base method. func (m *MockDatabase) GetLatestAssetGroupCollection(ctx context.Context, assetGroupID int32) (model.AssetGroupCollection, error) { m.ctrl.T.Helper() @@ -2114,6 +2158,36 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) } +// GetSchemaEnvironmentByKinds mocks base method. +func (m *MockDatabase) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentByKinds", ctx, environmentKindId, sourceKindId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentByKinds indicates an expected call of GetSchemaEnvironmentByKinds. +func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environmentKindId, sourceKindId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. +func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. +func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) +} + // GetSchemaEnvironments mocks base method. func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/sourcekinds.go b/cmd/api/src/database/sourcekinds.go index d69ed7dcd4..6aa76b8a24 100644 --- a/cmd/api/src/database/sourcekinds.go +++ b/cmd/api/src/database/sourcekinds.go @@ -108,8 +108,14 @@ func (s *BloodhoundDB) GetSourceKindByName(ctx context.Context, name string) (So } var raw rawSourceKind - if err := s.db.Raw(query, name).Scan(&raw).Error; err != nil { - return SourceKind{}, err + result := s.db.WithContext(ctx).Raw(query, name).Scan(&raw) + + if result.Error != nil { + return SourceKind{}, result.Error + } + + if result.RowsAffected == 0 || raw.ID == 0 { + return SourceKind{}, ErrNotFound } kind := SourceKind{ diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go new file mode 100644 index 0000000000..f19e39cd28 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -0,0 +1,160 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "context" + "errors" + "fmt" + + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/dawgs/graph" +) + +// UpsertSchemaEnvironmentWithPrincipalKinds creates or updates an environment with its principal kinds. +// If an environment with the same environment kind and source kind exists, it will be replaced. +// +// NOTE: This method should be called within a transaction. The caller is responsible for transaction management. +func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error { + environment := model.SchemaEnvironment{ + SchemaExtensionId: schemaExtensionId, + } + + envKind, err := s.validateAndTranslateEnvironmentKind(ctx, environmentKind) + if err != nil { + return err + } + + sourceKindID, err := s.validateAndTranslateSourceKind(ctx, sourceKind) + if err != nil { + return err + } + + translatedPrincipalKinds, err := s.validateAndTranslatePrincipalKinds(ctx, principalKinds) + if err != nil { + return err + } + + environment.EnvironmentKindId = int32(envKind.ID) + environment.SourceKindId = sourceKindID + + envID, err := s.upsertSchemaEnvironment(ctx, environment) + if err != nil { + return fmt.Errorf("error upserting schema environment: %w", err) + } + + if err := s.upsertPrincipalKinds(ctx, envID, translatedPrincipalKinds); err != nil { + return fmt.Errorf("error upserting principal kinds: %w", err) + } + + return nil +} + +// validateAndTranslateEnvironmentKind validates that the environment kind exists in the kinds table. +func (s *BloodhoundDB) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { + if envKind, err := s.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, ErrNotFound) { + return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) + } else if errors.Is(err, ErrNotFound) { + return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) + } else { + return envKind, nil + } +} + +// validateAndTranslateSourceKind validates that the source kind exists in the source_kinds table. +// If not found, it registers the source kind and returns its ID so it can be added to the Environment object. +func (s *BloodhoundDB) validateAndTranslateSourceKind(ctx context.Context, sourceKindName string) (int32, error) { + if sourceKind, err := s.GetSourceKindByName(ctx, sourceKindName); err != nil && !errors.Is(err, ErrNotFound) { + return 0, fmt.Errorf("error retrieving source kind '%s': %w", sourceKindName, err) + } else if err == nil { + return int32(sourceKind.ID), nil + } + + // If source kind is not found, register it. If it exists and is inactive, it will automatically update as active. + kindType := graph.StringKind(sourceKindName) + if err := s.RegisterSourceKind(ctx)(kindType); err != nil { + return 0, fmt.Errorf("error registering source kind '%s': %w", sourceKindName, err) + } + + if sourceKind, err := s.GetSourceKindByName(ctx, sourceKindName); err != nil { + return 0, fmt.Errorf("error retrieving newly registered source kind '%s': %w", sourceKindName, err) + } else { + return int32(sourceKind.ID), nil + } +} + +// validateAndTranslatePrincipalKinds ensures all principalKinds exist in the kinds table. +// It also translates them to IDs so they can be upserted into the database. +func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, principalKindNames []string) ([]model.SchemaEnvironmentPrincipalKind, error) { + principalKinds := make([]model.SchemaEnvironmentPrincipalKind, len(principalKindNames)) + + for i, kindName := range principalKindNames { + if kind, err := s.GetKindByName(ctx, kindName); err != nil && !errors.Is(err, ErrNotFound) { + return nil, fmt.Errorf("error retrieving principal kind by name '%s': %w", kindName, err) + } else if errors.Is(err, ErrNotFound) { + return nil, fmt.Errorf("principal kind '%s' not found", kindName) + } else { + principalKinds[i] = model.SchemaEnvironmentPrincipalKind{ + PrincipalKind: int32(kind.ID), + } + } + } + + return principalKinds, nil +} + +// upsertSchemaEnvironment creates or updates a schema environment. +// If an environment with the given ID exists, it deletes it first before creating the new one. +func (s *BloodhoundDB) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { + if existing, err := s.GetSchemaEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { + return 0, fmt.Errorf("error retrieving schema environment: %w", err) + } else if !errors.Is(err, ErrNotFound) { + // Environment exists - delete it first + if err := s.DeleteSchemaEnvironment(ctx, existing.ID); err != nil { + return 0, fmt.Errorf("error deleting schema environment %d: %w", existing.ID, err) + } + } + + // Create Environment + if created, err := s.CreateSchemaEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { + return 0, fmt.Errorf("error creating schema environment: %w", err) + } else { + return created.ID, nil + } +} + +// upsertPrincipalKinds deletes all existing principal kinds for an environment and creates new ones. +func (s *BloodhoundDB) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { + if existingKinds, err := s.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, ErrNotFound) { + return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) + } else if !errors.Is(err, ErrNotFound) { + // Delete all existing principal kinds + for _, kind := range existingKinds { + if err := s.DeleteSchemaEnvironmentPrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { + return fmt.Errorf("error deleting principal kind %d for environment %d: %w", kind.PrincipalKind, kind.EnvironmentId, err) + } + } + } + + // Create the new principal kinds + for _, kind := range principalKinds { + if _, err := s.CreateSchemaEnvironmentPrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { + return fmt.Errorf("error creating principal kind %d for environment %d: %w", kind.PrincipalKind, environmentID, err) + } + } + + return nil +} diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go new file mode 100644 index 0000000000..2ba131cd3c --- /dev/null +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -0,0 +1,305 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { + type args struct { + environmentKind string + sourceKind string + principalKinds []string + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 + args args + assert func(t *testing.T, db *database.BloodhoundDB) + expectedError string + }{ + { + name: "Success: Create new environment with principal kinds", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "Base", + principalKinds: []string{"Tag_Tier_Zero", "Tag_Owned"}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments)) + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) + + expectedKindIDs := make([]int32, len(expectedPrincipalKindNames)) + for i, name := range expectedPrincipalKindNames { + kind, err := db.GetKindByName(context.Background(), name) + assert.NoError(t, err) + expectedKindIDs[i] = int32(kind.ID) + } + + actualKindIDs := make([]int32, len(principalKinds)) + for i, pk := range principalKinds { + assert.Equal(t, environments[0].ID, pk.EnvironmentId) + actualKindIDs[i] = pk.PrincipalKind + } + + assert.ElementsMatch(t, expectedKindIDs, actualKindIDs) + }, + }, + { + name: "Success: Upsert replaces existing environment", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + err = db.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + return tx.UpsertSchemaEnvironmentWithPrincipalKinds( + context.Background(), + ext.ID, + "Tag_Tier_Zero", + "Base", + []string{"Tag_Owned"}, + ) + }) + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "Base", + principalKinds: []string{"Tag_Tier_Zero"}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments), "Should only have one environment (old one deleted)") + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds)) + + expectedKind, err := db.GetKindByName(context.Background(), expectedPrincipalKindNames[0]) + assert.NoError(t, err) + + assert.Equal(t, int32(expectedKind.ID), principalKinds[0].PrincipalKind) + assert.Equal(t, environments[0].ID, principalKinds[0].EnvironmentId) + }, + }, + { + name: "Success: Source kind auto-registers", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "NewSource", + principalKinds: []string{"Tag_Tier_Zero"}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + sourceKind, err := db.GetSourceKindByName(context.Background(), "NewSource") + assert.NoError(t, err) + assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments)) + assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds)) + }, + }, + { + name: "Error: Environment kind not found", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "NonExistent", + sourceKind: "Base", + principalKinds: []string{}, + }, + expectedError: "environment kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Error: Principal kind not found", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "Base", + principalKinds: []string{"NonExistent"}, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Rollback: Partial failure on second principal kind", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Tier_Zero", + sourceKind: "Base", + principalKinds: []string{"Tag_Owned", "NonExistent"}, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Success: Multiple environments with different combinations", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + err = db.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + return tx.UpsertSchemaEnvironmentWithPrincipalKinds( + context.Background(), + ext.ID, + "Tag_Tier_Zero", + "Base", + []string{"Tag_Tier_Zero"}, + ) + }) + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environmentKind: "Tag_Owned", + sourceKind: "Base", + principalKinds: []string{"Tag_Owned"}, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 2, len(environments), "Should have two different environments") + + for _, env := range environments { + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), env.ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extensionID := tt.setupData(t, testSuite.BHDatabase) + + // Wrap the call in a transaction + err := testSuite.BHDatabase.Transaction(context.Background(), func(tx *database.BloodhoundDB) error { + return tx.UpsertSchemaEnvironmentWithPrincipalKinds( + context.Background(), + extensionID, + tt.args.environmentKind, + tt.args.sourceKind, + tt.args.principalKinds, + ) + }) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase) + } + } else { + require.NoError(t, err) + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase) + } + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/environment.go b/cmd/api/src/services/opengraphschema/environment.go deleted file mode 100644 index 662d429637..0000000000 --- a/cmd/api/src/services/opengraphschema/environment.go +++ /dev/null @@ -1,162 +0,0 @@ -// Copyright 2026 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 -package opengraphschema - -import ( - "context" - "errors" - "fmt" - - v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/database" - "github.com/specterops/bloodhound/cmd/api/src/model" - "github.com/specterops/dawgs/graph" -) - -// UpsertSchemaEnvironmentWithPrincipalKinds takes a slice of environments, validates and translates each environment. -// The translation is used to upsert the environments into the database. -// If an existing environment is found to already exist in the database, the existing environment will be removed and the new one will be uploaded. -func (o *OpenGraphSchemaService) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environments []v2.Environment) error { - for _, env := range environments { - environment := model.SchemaEnvironment{ - SchemaExtensionId: schemaExtensionId, - } - - if updatedEnv, principalKinds, err := o.validateAndTranslateEnvironment(ctx, environment, env); err != nil { - return fmt.Errorf("error validating and translating environment: %w", err) - } else if envID, err := o.upsertSchemaEnvironment(ctx, updatedEnv); err != nil { - return fmt.Errorf("error upserting schema environment: %w", err) - } else if err := o.upsertPrincipalKinds(ctx, envID, principalKinds); err != nil { - return fmt.Errorf("error upserting principal kinds: %w", err) - } - } - - return nil -} - -// validateAndTranslateEnvironment validates that the environment kind, source kind, and the principal kinds exist in the database. -// It is then translated from the API model to the Database model to prepare it for insert. -func (o *OpenGraphSchemaService) validateAndTranslateEnvironment(ctx context.Context, environment model.SchemaEnvironment, env v2.Environment) (model.SchemaEnvironment, []model.SchemaEnvironmentPrincipalKind, error) { - if envKind, err := o.validateAndTranslateEnvironmentKind(ctx, env.EnvironmentKind); err != nil { - return model.SchemaEnvironment{}, nil, err - } else if sourceKindID, err := o.validateAndTranslateSourceKind(ctx, env.SourceKind); err != nil { - return model.SchemaEnvironment{}, nil, err - } else if principalKinds, err := o.validateAndTranslatePrincipalKinds(ctx, env.PrincipalKinds); err != nil { - return model.SchemaEnvironment{}, nil, err - } else { - // Update environment with translated IDs - environment.EnvironmentKindId = int32(envKind.ID) - environment.SourceKindId = sourceKindID - - return environment, principalKinds, nil - } -} - -// validateAndTranslateEnvironmentKind validates that the environment kind exists in the kinds table. -func (o *OpenGraphSchemaService) validateAndTranslateEnvironmentKind(ctx context.Context, environmentKindName string) (model.Kind, error) { - if envKind, err := o.openGraphSchemaRepository.GetKindByName(ctx, environmentKindName); err != nil && !errors.Is(err, database.ErrNotFound) { - return model.Kind{}, fmt.Errorf("error retrieving environment kind '%s': %w", environmentKindName, err) - } else if errors.Is(err, database.ErrNotFound) { - return model.Kind{}, fmt.Errorf("environment kind '%s' not found", environmentKindName) - } else { - return envKind, nil - } -} - -// validateAndTranslateSourceKind validates that the source kind exists in the source_kinds table. -// If not found, it registers the source kind and returns its ID so it can be added to the Environment object. -func (o *OpenGraphSchemaService) validateAndTranslateSourceKind(ctx context.Context, sourceKindName string) (int32, error) { - if sourceKind, err := o.openGraphSchemaRepository.GetSourceKindByName(ctx, sourceKindName); err != nil && !errors.Is(err, database.ErrNotFound) { - return 0, fmt.Errorf("error retrieving source kind '%s': %w", sourceKindName, err) - } else if err == nil { - return int32(sourceKind.ID), nil - } - - // If source kind is not found, register it. If it exists and is inactive, it will automatically update as active. - kindType := graph.StringKind(sourceKindName) - if err := o.openGraphSchemaRepository.RegisterSourceKind(ctx)(kindType); err != nil { - return 0, fmt.Errorf("error registering source kind '%s': %w", sourceKindName, err) - } - - if sourceKind, err := o.openGraphSchemaRepository.GetSourceKindByName(ctx, sourceKindName); err != nil { - return 0, fmt.Errorf("error retrieving newly registered source kind '%s': %w", sourceKindName, err) - } else { - return int32(sourceKind.ID), nil - } -} - -// validateAndTranslatePrincipalKinds ensures all principalKinds exist in the kinds table. -// It also translates them to IDs so they can be upserted into the database. -func (o *OpenGraphSchemaService) validateAndTranslatePrincipalKinds(ctx context.Context, principalKindNames []string) ([]model.SchemaEnvironmentPrincipalKind, error) { - principalKinds := make([]model.SchemaEnvironmentPrincipalKind, len(principalKindNames)) - - for i, kindName := range principalKindNames { - if kind, err := o.openGraphSchemaRepository.GetKindByName(ctx, kindName); err != nil && !errors.Is(err, database.ErrNotFound) { - return nil, fmt.Errorf("error retrieving principal kind by name '%s': %w", kindName, err) - } else if errors.Is(err, database.ErrNotFound) { - return nil, fmt.Errorf("principal kind '%s' not found", kindName) - } else { - principalKinds[i] = model.SchemaEnvironmentPrincipalKind{ - PrincipalKind: int32(kind.ID), - } - } - } - - return principalKinds, nil -} - -// upsertSchemaEnvironment creates or updates a schema environment. -// If an environment with the given ID exists, it deletes it first before creating the new one. -func (o *OpenGraphSchemaService) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { - if existing, err := o.openGraphSchemaRepository.GetSchemaEnvironmentById(ctx, graphSchema.ID); err != nil && !errors.Is(err, database.ErrNotFound) { - return 0, fmt.Errorf("error retrieving schema environment id %d: %w", graphSchema.ID, err) - } else if !errors.Is(err, database.ErrNotFound) { - // Environment exists - delete it first - if err := o.openGraphSchemaRepository.DeleteSchemaEnvironment(ctx, existing.ID); err != nil { - return 0, fmt.Errorf("error deleting schema environment %d: %w", existing.ID, err) - } - } - - // Create Environment - if created, err := o.openGraphSchemaRepository.CreateSchemaEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { - return 0, fmt.Errorf("error creating schema environment: %w", err) - } else { - return created.ID, nil - } -} - -// upsertPrincipalKinds deletes all existing principal kinds for an environment and creates new ones. -func (o *OpenGraphSchemaService) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { - if existingKinds, err := o.openGraphSchemaRepository.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, database.ErrNotFound) { - return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) - } else if !errors.Is(err, database.ErrNotFound) { - // Delete all existing principal kinds - for _, kind := range existingKinds { - if err := o.openGraphSchemaRepository.DeleteSchemaEnvironmentPrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { - return fmt.Errorf("error deleting principal kind %d for environment %d: %w", kind.PrincipalKind, kind.EnvironmentId, err) - } - } - } - - // Create the new principal kinds - for _, kind := range principalKinds { - if _, err := o.openGraphSchemaRepository.CreateSchemaEnvironmentPrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { - return fmt.Errorf("error creating principal kind %d for environment %d: %w", kind.PrincipalKind, environmentID, err) - } - } - - return nil -} diff --git a/cmd/api/src/services/opengraphschema/environment_test.go b/cmd/api/src/services/opengraphschema/environment_test.go deleted file mode 100644 index e50362a423..0000000000 --- a/cmd/api/src/services/opengraphschema/environment_test.go +++ /dev/null @@ -1,514 +0,0 @@ -// Copyright 2026 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 -package opengraphschema_test - -import ( - "context" - "errors" - "testing" - - v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/database" - "github.com/specterops/bloodhound/cmd/api/src/model" - "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" - schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" - "github.com/specterops/dawgs/graph" - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" -) - -func TestOpenGraphSchemaService_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { - type mocks struct { - mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository - } - type args struct { - schemaExtensionId int32 - environments []v2.Environment - } - tests := []struct { - name string - mocks mocks - setupMocks func(t *testing.T, mock *mocks) - args args - expected error - }{ - // Validation: Environment Kind - { - name: "Error: openGraphSchemaRepository.GetKindByName environment kind name not found in the database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{}, database.ErrNotFound) - }, - expected: errors.New("error validating and translating environment: environment kind 'Domain' not found"), - }, - { - name: "Error: openGraphSchemaRepository.GetKindByName failed to retrieve environment kind from database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{}, errors.New("error")) - }, - expected: errors.New("error validating and translating environment: error retrieving environment kind 'Domain': error"), - }, - // Validation: Source Kind - { - name: "Error: validateAndTranslateSourceKind failed to retrieve source kind from database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, errors.New("error")) - }, - expected: errors.New("error validating and translating environment: error retrieving source kind 'Base': error"), - }, - { - name: "Error: validateAndTranslateSourceKind source kind name doesn't exist in database, registration fails", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { - return errors.New("error") - }) - }, - expected: errors.New("error validating and translating environment: error registering source kind 'Base': error"), - }, - { - name: "Error: validateAndTranslateSourceKind source kind name doesn't exist in database, registration succeeds but fetch fails", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { - return nil - }) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{}, errors.New("error")) - }, - expected: errors.New("error validating and translating environment: error retrieving newly registered source kind 'Base': error"), - }, - // Validation: Principal Kind - { - name: "Error: validateAndTranslatePrincipalKinds principal kind not found in database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 2}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "InvalidKind").Return(model.Kind{}, database.ErrNotFound) - }, - expected: errors.New("error validating and translating environment: principal kind 'InvalidKind' not found"), - }, - { - name: "Error: validateAndTranslatePrincipalKinds failed to retrieve principal kind from database", - args: args{ - schemaExtensionId: int32(1), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 1}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 2}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "InvalidKind").Return(model.Kind{}, errors.New("error")) - }, - expected: errors.New("error validating and translating environment: error retrieving principal kind by name 'InvalidKind': error"), - }, - // Upsert Schema Environment - { - name: "Error: upsertSchemaEnvironment error retrieving schema environment from database", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, errors.New("error")) - }, - expected: errors.New("error upserting schema environment: error retrieving schema environment id 0: error"), - }, - { - name: "Error: upsertSchemaEnvironment error deleting schema environment", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 5}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(errors.New("error")) - }, - expected: errors.New("error upserting schema environment: error deleting schema environment 5: error"), - }, - { - name: "Error: upsertSchemaEnvironment error creating schema environment after deletion", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 5}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironment(gomock.Any(), int32(5)).Return(nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) - }, - expected: errors.New("error upserting schema environment: error creating schema environment: error"), - }, - { - name: "Error: upsertSchemaEnvironment error creating new schema environment", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{}, errors.New("error")) - }, - expected: errors.New("error upserting schema environment: error creating schema environment: error"), - }, - // Upsert Principal Kinds - { - name: "Error: upsertPrincipalKinds error getting principal kinds by environment id", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, errors.New("error")) - }, - expected: errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), - }, - { - name: "Error: upsertPrincipalKinds error deleting principal kinds", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return([]model.SchemaEnvironmentPrincipalKind{ - { - EnvironmentId: int32(10), - PrincipalKind: int32(5), - }, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().DeleteSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(5)).Return(errors.New("error")) - }, - expected: errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), - }, - { - name: "Error: upsertPrincipalKinds error creating principal kinds", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 3}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(3)).Return(model.SchemaEnvironmentPrincipalKind{}, errors.New("error")) - }, - expected: errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error"), - }, - { - name: "Success: Create new environment with principal kinds", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "Computer"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Computer").Return(model.Kind{ID: 5}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(5)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - }, - expected: nil, - }, - { - name: "Success: Create environment with source kind registration", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "NewSource", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // Validation and translation - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - // Source kind not found, register it - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "NewSource").Return(database.SourceKind{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().RegisterSourceKind(gomock.Any()).Return(func(kind graph.Kind) error { - return nil - }) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "NewSource").Return(database.SourceKind{ID: 10}, nil) - - // Environment upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(10)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - - // Principal kinds upsert - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - }, - expected: nil, - }, - { - name: "Success: Process multiple environments", - args: args{ - schemaExtensionId: int32(3), - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - { - EnvironmentKind: "AzureAD", - SourceKind: "AzureHound", - PrincipalKinds: []string{"User", "Group"}, - }, - }, - }, - setupMocks: func(t *testing.T, mocks *mocks) { - t.Helper() - // First environment - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Domain").Return(model.Kind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "Base").Return(database.SourceKind{ID: 3}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(3), int32(3)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 10}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(10)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(10), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - - // Second environment - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "AzureAD").Return(model.Kind{ID: 5}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSourceKindByName(gomock.Any(), "AzureHound").Return(database.SourceKind{ID: 6}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "User").Return(model.Kind{ID: 4}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetKindByName(gomock.Any(), "Group").Return(model.Kind{ID: 7}, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentById(gomock.Any(), int32(0)).Return(model.SchemaEnvironment{}, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironment(gomock.Any(), int32(3), int32(5), int32(6)).Return(model.SchemaEnvironment{ - Serial: model.Serial{ID: 11}, - }, nil) - mocks.mockOpenGraphSchema.EXPECT().GetSchemaEnvironmentPrincipalKindsByEnvironmentId(gomock.Any(), int32(11)).Return(nil, database.ErrNotFound) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(11), int32(4)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - mocks.mockOpenGraphSchema.EXPECT().CreateSchemaEnvironmentPrincipalKind(gomock.Any(), int32(11), int32(7)).Return(model.SchemaEnvironmentPrincipalKind{}, nil) - }, - expected: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - mocks := &mocks{ - mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), - } - - tt.setupMocks(t, mocks) - - graphService := opengraphschema.NewOpenGraphSchemaService(mocks.mockOpenGraphSchema) - - err := graphService.UpsertSchemaEnvironmentWithPrincipalKinds(context.Background(), tt.args.schemaExtensionId, tt.args.environments) - if tt.expected != nil { - assert.EqualError(t, err, tt.expected.Error()) - } else { - assert.NoError(t, err) - } - }) - } -} diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index aefcff32e3..76e960f3b4 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -23,22 +23,11 @@ import ( ) func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - tx, err := o.openGraphSchemaRepository.BeginTransaction(ctx) - if err != nil { - return err + for _, env := range req.Environments { + if err := o.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { + return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + } } - // Create service with transaction - transactionalService := &OpenGraphSchemaService{ - openGraphSchemaRepository: tx, - } - - // TODO: Temporary hardcoded extension ID - // Upsert environments with principal kinds - if err := transactionalService.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, req.Environments); err != nil { - tx.Rollback() - return fmt.Errorf("failed to upload environments with principal kinds: %w", err) - } - - return tx.Commit() + return nil } diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index e9f00e40dc..9c3e712987 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -27,12 +27,8 @@ package mocks import ( context "context" - sql "database/sql" reflect "reflect" - database "github.com/specterops/bloodhound/cmd/api/src/database" - model "github.com/specterops/bloodhound/cmd/api/src/model" - graph "github.com/specterops/dawgs/graph" gomock "go.uber.org/mock/gomock" ) @@ -60,182 +56,16 @@ func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryM return m.recorder } -// BeginTransaction mocks base method. -func (m *MockOpenGraphSchemaRepository) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*database.BloodhoundDB, error) { +// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. +func (m *MockOpenGraphSchemaRepository) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind, sourceKind string, principalKinds []string) error { m.ctrl.T.Helper() - varargs := []any{ctx} - for _, a := range opts { - varargs = append(varargs, a) - } - ret := m.ctrl.Call(m, "BeginTransaction", varargs...) - ret0, _ := ret[0].(*database.BloodhoundDB) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// BeginTransaction indicates an expected call of BeginTransaction. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) BeginTransaction(ctx any, opts ...any) *gomock.Call { - mr.mock.ctrl.T.Helper() - varargs := append([]any{ctx}, opts...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BeginTransaction", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).BeginTransaction), varargs...) -} - -// Commit mocks base method. -func (m *MockOpenGraphSchemaRepository) Commit() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Commit") - ret0, _ := ret[0].(error) - return ret0 -} - -// Commit indicates an expected call of Commit. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Commit() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Commit)) -} - -// CreateSchemaEnvironment mocks base method. -func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironment(ctx context.Context, schemaExtensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironment", ctx, schemaExtensionId, environmentKindId, sourceKindId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironment indicates an expected call of CreateSchemaEnvironment. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) CreateSchemaEnvironment(ctx, schemaExtensionId, environmentKindId, sourceKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).CreateSchemaEnvironment), ctx, schemaExtensionId, environmentKindId, sourceKindId) -} - -// CreateSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockOpenGraphSchemaRepository) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - -// DeleteSchemaEnvironment mocks base method. -func (m *MockOpenGraphSchemaRepository) DeleteSchemaEnvironment(ctx context.Context, schemaEnvironmentId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironment", ctx, schemaEnvironmentId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironment indicates an expected call of DeleteSchemaEnvironment. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironment(ctx, schemaEnvironmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironment), ctx, schemaEnvironmentId) -} - -// DeleteSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockOpenGraphSchemaRepository) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - -// GetKindByName mocks base method. -func (m *MockOpenGraphSchemaRepository) GetKindByName(ctx context.Context, name string) (model.Kind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetKindByName", ctx, name) - ret0, _ := ret[0].(model.Kind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetKindByName indicates an expected call of GetKindByName. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetKindByName(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKindByName", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetKindByName), ctx, name) -} - -// GetSchemaEnvironmentById mocks base method. -func (m *MockOpenGraphSchemaRepository) GetSchemaEnvironmentById(ctx context.Context, schemaEnvironmentId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentById", ctx, schemaEnvironmentId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentById indicates an expected call of GetSchemaEnvironmentById. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentById(ctx, schemaEnvironmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentById), ctx, schemaEnvironmentId) -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. -func (m *MockOpenGraphSchemaRepository) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) -} - -// GetSourceKindByName mocks base method. -func (m *MockOpenGraphSchemaRepository) GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSourceKindByName", ctx, name) - ret0, _ := ret[0].(database.SourceKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSourceKindByName indicates an expected call of GetSourceKindByName. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) GetSourceKindByName(ctx, name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSourceKindByName", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).GetSourceKindByName), ctx, name) -} - -// RegisterSourceKind mocks base method. -func (m *MockOpenGraphSchemaRepository) RegisterSourceKind(ctx context.Context) func(graph.Kind) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RegisterSourceKind", ctx) - ret0, _ := ret[0].(func(graph.Kind) error) - return ret0 -} - -// RegisterSourceKind indicates an expected call of RegisterSourceKind. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) RegisterSourceKind(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterSourceKind", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).RegisterSourceKind), ctx) -} - -// Rollback mocks base method. -func (m *MockOpenGraphSchemaRepository) Rollback() error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Rollback") + ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) ret0, _ := ret[0].(error) return ret0 } -// Rollback indicates an expected call of Rollback. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) Rollback() *gomock.Call { +// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).Rollback)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) } diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index 21e61bd088..6de81562dc 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -19,36 +19,11 @@ package opengraphschema import ( "context" - "database/sql" - - "github.com/specterops/bloodhound/cmd/api/src/database" - "github.com/specterops/bloodhound/cmd/api/src/model" - "github.com/specterops/dawgs/graph" ) // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { - // TX - BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*database.BloodhoundDB, error) - Commit() error - Rollback() error - - // Kinds - GetKindByName(ctx context.Context, name string) (model.Kind, error) - - // Environment - CreateSchemaEnvironment(ctx context.Context, schemaExtensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironmentById(ctx context.Context, schemaEnvironmentId int32) (model.SchemaEnvironment, error) - DeleteSchemaEnvironment(ctx context.Context, schemaEnvironmentId int32) error - - // Source Kinds - RegisterSourceKind(ctx context.Context) func(sourceKind graph.Kind) error - GetSourceKindByName(ctx context.Context, name string) (database.SourceKind, error) - - // Principal Kinds - CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) - GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) - DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error + UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error } type OpenGraphSchemaService struct { diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go new file mode 100644 index 0000000000..5b2ac311ad --- /dev/null +++ b/cmd/api/src/services/opengraphschema/opengraphschema_test.go @@ -0,0 +1,499 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema_test + +import ( + "context" + "errors" + "testing" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" + schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { + type mocks struct { + mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository + } + type args struct { + schemaExtensionId int32 + environments []v2.Environment + } + tests := []struct { + name string + setupMocks func(t *testing.T, m *mocks) + args args + expected error + }{ + // UpsertSchemaEnvironmentWithPrincipalKinds + // Validation: Environment Kind + { + name: "Error: environment kind name not found in the database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("environment kind 'Domain' not found")) + }, + expected: errors.New("failed to upload environments with principal kinds: environment kind 'Domain' not found"), + }, + { + name: "Error: failed to retrieve environment kind from database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error retrieving environment kind 'Domain': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error retrieving environment kind 'Domain': error"), + }, + // Validation: Source Kind + { + name: "Error: failed to retrieve source kind from database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error retrieving source kind 'Base': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error retrieving source kind 'Base': error"), + }, + { + name: "Error: source kind name doesn't exist in database, registration fails", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error registering source kind 'Base': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error registering source kind 'Base': error"), + }, + { + name: "Error: source kind name doesn't exist in database, registration succeeds but fetch fails", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error retrieving newly registered source kind 'Base': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error retrieving newly registered source kind 'Base': error"), + }, + // Validation: Principal Kind + { + name: "Error: principal kind not found in database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "InvalidKind"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User", "InvalidKind"}, + ).Return(errors.New("principal kind 'InvalidKind' not found")) + }, + expected: errors.New("failed to upload environments with principal kinds: principal kind 'InvalidKind' not found"), + }, + { + name: "Error: failed to retrieve principal kind from database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "InvalidKind"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User", "InvalidKind"}, + ).Return(errors.New("error retrieving principal kind by name 'InvalidKind': error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error retrieving principal kind by name 'InvalidKind': error"), + }, + // Upsert Schema Environment + { + name: "Error: error retrieving schema environment from database", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error upserting schema environment: error retrieving schema environment id 0: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error retrieving schema environment id 0: error"), + }, + { + name: "Error: error deleting schema environment", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error upserting schema environment: error deleting schema environment 5: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error deleting schema environment 5: error"), + }, + { + name: "Error: error creating schema environment after deletion", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), + }, + { + name: "Error: error creating new schema environment", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{}, + ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), + }, + // Upsert Principal Kinds + { + name: "Error: error getting principal kinds by environment id", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User"}, + ).Return(errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), + }, + { + name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error deleting principal kinds", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User"}, + ).Return(errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), + }, + { + name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error creating principal kinds", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User"}, + ).Return(errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error")) + }, + expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error creating principal kind 3 for environment 10: error"), + }, + { + name: "Success: Create new environment with principal kinds", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "Computer"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User", "Computer"}, + ).Return(nil) + }, + expected: nil, + }, + { + name: "Success: Create environment with source kind registration", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "NewSource", + PrincipalKinds: []string{}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "NewSource", + []string{}, + ).Return(nil) + }, + expected: nil, + }, + { + name: "Success: Process multiple environments", + args: args{ + schemaExtensionId: int32(1), + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKind: "AzureAD", + SourceKind: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + // First environment + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "Domain", + "Base", + []string{"User"}, + ).Return(nil) + + // Second environment + m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( + gomock.Any(), + int32(1), + "AzureAD", + "AzureHound", + []string{"User", "Group"}, + ).Return(nil) + }, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + m := &mocks{ + mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), + } + + tt.setupMocks(t, m) + + service := opengraphschema.NewOpenGraphSchemaService(m.mockOpenGraphSchema) + + err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ + Environments: tt.args.environments, + }) + + if tt.expected != nil { + assert.EqualError(t, err, tt.expected.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} From 38186bed4c70611d6a277ac85b26b8330329df22 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:31:19 -0600 Subject: [PATCH 12/25] pulled in Conrad's changes --- cmd/api/src/database/graphschema.go | 17 +++++++++++++++++ cmd/api/src/database/mocks/db.go | 3 --- cmd/api/src/model/graphschema.go | 15 +-------------- cmd/ui/src/rendering/utils/dagre/dagre.test.ts | 2 +- cmd/ui/src/rendering/utils/dagre/dagre.ts | 3 ++- .../PrivilegeZones/Details/ObjectCountPanel.tsx | 2 +- 6 files changed, 22 insertions(+), 20 deletions(-) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 70613a9769..26173a3224 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -53,6 +53,7 @@ type OpenGraphSchema interface { GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) + GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error @@ -538,6 +539,22 @@ func (s *BloodhoundDB) GetSchemaEnvironments(ctx context.Context) ([]model.Schem return result, CheckError(s.db.WithContext(ctx).Find(&result)) } +// GetSchemaEnvironmentByKinds - retrieves an environment by its environment kind and source kind. +func (s *BloodhoundDB) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + var env model.SchemaEnvironment + + if result := s.db.WithContext(ctx).Raw( + "SELECT * FROM schema_environments WHERE environment_kind_id = ? AND source_kind_id = ? AND deleted_at IS NULL", + environmentKindId, sourceKindId, + ).Scan(&env); result.Error != nil { + return model.SchemaEnvironment{}, CheckError(result) + } else if result.RowsAffected == 0 { + return model.SchemaEnvironment{}, ErrNotFound + } + + return env, nil +} + // GetSchemaEnvironmentById - retrieves a schema environment by id. func (s *BloodhoundDB) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { var schemaEnvironment model.SchemaEnvironment diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index 4c6f03c883..b911c1ad42 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2202,7 +2202,6 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) } -<<<<<<< HEAD // GetSchemaEnvironmentByKinds mocks base method. func (m *MockDatabase) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { m.ctrl.T.Helper() @@ -2218,8 +2217,6 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environment return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) } -======= ->>>>>>> origin/BED-7076 // GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index 44985d3dcc..397ba77afc 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -16,9 +16,7 @@ package model -import ( - "time" -) +import "time" type GraphSchemaExtensions []GraphSchemaExtension @@ -170,14 +168,3 @@ type GraphSchemaEdgeKindWithNamedSchema struct { } type GraphSchemaEdgeKindsWithNamedSchema []GraphSchemaEdgeKindWithNamedSchema - -type SchemaEnvironmentPrincipalKinds []SchemaEnvironmentPrincipalKind - -type SchemaEnvironmentPrincipalKind struct { - EnvironmentId int32 `json:"environment_id"` - PrincipalKind int32 `json:"principal_kind"` -} - -func (SchemaEnvironmentPrincipalKind) TableName() string { - return "schema_environments_principal_kinds" -} diff --git a/cmd/ui/src/rendering/utils/dagre/dagre.test.ts b/cmd/ui/src/rendering/utils/dagre/dagre.test.ts index 0a6ab82b09..c80538ddf3 100644 --- a/cmd/ui/src/rendering/utils/dagre/dagre.test.ts +++ b/cmd/ui/src/rendering/utils/dagre/dagre.test.ts @@ -16,9 +16,9 @@ import dagre from '@dagrejs/dagre'; import Graph from 'graphology'; +import { getEdgeDataFromKey } from 'src/ducks/graph/utils'; import { NODE_DEFAULT_SIZE, applyNodePositionsFromGraphlibGraph, copySigmaNodesToGraphlibGraph } from './'; import { copySigmaEdgesToGraphlibGraph } from './dagre'; -import { getEdgeDataFromKey } from 'src/ducks/graph/utils'; const sigmaGraph = new Graph(); const graphlibGraph = new dagre.graphlib.Graph(); diff --git a/cmd/ui/src/rendering/utils/dagre/dagre.ts b/cmd/ui/src/rendering/utils/dagre/dagre.ts index c62ac36ca0..4a519fa99a 100644 --- a/cmd/ui/src/rendering/utils/dagre/dagre.ts +++ b/cmd/ui/src/rendering/utils/dagre/dagre.ts @@ -98,7 +98,8 @@ export const copySigmaEdgesToGraphlibGraph = ( ): void => { sigmaGraph.forEachEdge((edge: string) => { const edgeData = getEdgeDataFromKey(edge); - if (edgeData !== null) graphlibGraph.setEdge(edgeData.source, edgeData.target, { label: edgeData.label, points: [] }); + if (edgeData !== null) + graphlibGraph.setEdge(edgeData.source, edgeData.target, { label: edgeData.label, points: [] }); }); }; diff --git a/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/ObjectCountPanel.tsx b/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/ObjectCountPanel.tsx index 8439b6ae57..fafdc11d83 100644 --- a/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/ObjectCountPanel.tsx +++ b/packages/javascript/bh-shared-ui/src/views/PrivilegeZones/Details/ObjectCountPanel.tsx @@ -19,8 +19,8 @@ import { FC } from 'react'; import { useQuery } from 'react-query'; import { NodeIcon } from '../../../components'; import { useEnvironmentIdList } from '../../../hooks'; -import { apiClient } from '../../../utils'; import { ENVIRONMENT_AGGREGATION_SUPPORTED_ROUTES } from '../../../routes'; +import { apiClient } from '../../../utils'; const ObjectCountPanel: FC<{ tagId: string }> = ({ tagId }) => { const environments = useEnvironmentIdList(ENVIRONMENT_AGGREGATION_SUPPORTED_ROUTES, false); From ab74bf194d8f4635121f02669980bd7d5c893335 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:49:28 -0600 Subject: [PATCH 13/25] code rabbit comments --- cmd/api/src/api/v2/opengraphschema.go | 2 +- cmd/api/src/database/kind_integration_test.go | 2 +- .../database/sourcekinds_integration_test.go | 8 ++++---- .../src/database/upsert_schema_environment.go | 2 +- .../src/services/opengraphschema/extension.go | 1 + .../opengraphschema/opengraphschema_test.go | 19 +------------------ 6 files changed, 9 insertions(+), 25 deletions(-) diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go index a49248fdf0..24cf755b49 100644 --- a/cmd/api/src/api/v2/opengraphschema.go +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -39,7 +39,7 @@ type Environment struct { PrincipalKinds []string `json:"principalKinds"` } -// TODO: Implement this - barebones in order to test handler. +// TODO: Implement this - skeleton endpoint to simply test the handler. func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { var ( ctx = request.Context() diff --git a/cmd/api/src/database/kind_integration_test.go b/cmd/api/src/database/kind_integration_test.go index 089db89dc9..8c541dfc55 100644 --- a/cmd/api/src/database/kind_integration_test.go +++ b/cmd/api/src/database/kind_integration_test.go @@ -65,7 +65,7 @@ func TestGetKindByName(t *testing.T) { kind, err := testSuite.BHDatabase.GetKindByName(testSuite.Context, testCase.args.name) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) assert.Equal(t, testCase.want.kind, kind) diff --git a/cmd/api/src/database/sourcekinds_integration_test.go b/cmd/api/src/database/sourcekinds_integration_test.go index 329831d2cb..0d8931d504 100644 --- a/cmd/api/src/database/sourcekinds_integration_test.go +++ b/cmd/api/src/database/sourcekinds_integration_test.go @@ -142,7 +142,7 @@ func TestRegisterSourceKind(t *testing.T) { err := testSuite.BHDatabase.RegisterSourceKind(testSuite.Context)(testCase.args.sourceKind) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) } @@ -197,7 +197,7 @@ func TestGetSourceKinds(t *testing.T) { sourceKinds, err := testSuite.BHDatabase.GetSourceKinds(testSuite.Context) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) assert.Equal(t, testCase.want.sourceKinds, sourceKinds) @@ -247,7 +247,7 @@ func TestGetSourceKindByName(t *testing.T) { sourceKind, err := testSuite.BHDatabase.GetSourceKindByName(testSuite.Context, testCase.args.name) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) assert.Equal(t, testCase.want.sourceKind, sourceKind) @@ -423,7 +423,7 @@ func TestDeactivateSourceKindsByName(t *testing.T) { err := testSuite.BHDatabase.DeactivateSourceKindsByName(testSuite.Context, testCase.args.sourceKind) if testCase.want.err != nil { - assert.EqualError(t, testCase.want.err, err.Error()) + assert.EqualError(t, err, testCase.want.err.Error()) } else { assert.NoError(t, err) } diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index f19e39cd28..71b3107c44 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -117,7 +117,7 @@ func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, p } // upsertSchemaEnvironment creates or updates a schema environment. -// If an environment with the given ID exists, it deletes it first before creating the new one. +// If an environment with the given kinds exists, it deletes it first before creating the new one. func (s *BloodhoundDB) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { if existing, err := s.GetSchemaEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment: %w", err) diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 76e960f3b4..85d44ba1da 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -24,6 +24,7 @@ import ( func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { for _, env := range req.Environments { + // TODO: Update temporary hardcoded extensionID once extension work is complete if err := o.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { return fmt.Errorf("failed to upload environments with principal kinds: %w", err) } diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go index 5b2ac311ad..a07253548b 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema_test.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema_test.go @@ -32,8 +32,7 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository } type args struct { - schemaExtensionId int32 - environments []v2.Environment + environments []v2.Environment } tests := []struct { name string @@ -46,7 +45,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: environment kind name not found in the database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -70,7 +68,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: failed to retrieve environment kind from database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -95,7 +92,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: failed to retrieve source kind from database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -119,7 +115,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: source kind name doesn't exist in database, registration fails", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -143,7 +138,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: source kind name doesn't exist in database, registration succeeds but fetch fails", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -168,7 +162,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: principal kind not found in database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -192,7 +185,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: failed to retrieve principal kind from database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -217,7 +209,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error retrieving schema environment from database", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -241,7 +232,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error deleting schema environment", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -265,7 +255,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error creating schema environment after deletion", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -289,7 +278,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error creating new schema environment", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -314,7 +302,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: error getting principal kinds by environment id", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -338,7 +325,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error deleting principal kinds", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -386,7 +372,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Success: Create new environment with principal kinds", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -410,7 +395,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Success: Create environment with source kind registration", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", @@ -434,7 +418,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Success: Process multiple environments", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", From d63427fcc9682293ad851f29176feb29d8f929a5 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:52:16 -0600 Subject: [PATCH 14/25] updated to add integration flag --- .../src/database/upsert_schema_environment_integration_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index 2ba131cd3c..7207ea0d2b 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -13,6 +13,8 @@ // limitations under the License. // // SPDX-License-Identifier: Apache-2.0 + +//go:build integration package database_test import ( From bad51d25101ed88598b1e320d28602bff9aafeac Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 09:54:51 -0600 Subject: [PATCH 15/25] missed an arg --- .../src/database/upsert_schema_environment_integration_test.go | 1 + cmd/api/src/services/opengraphschema/opengraphschema_test.go | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index 7207ea0d2b..bd32cce3d2 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -15,6 +15,7 @@ // SPDX-License-Identifier: Apache-2.0 //go:build integration + package database_test import ( diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go index a07253548b..b15effc82f 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema_test.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema_test.go @@ -348,7 +348,6 @@ func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { { name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error creating principal kinds", args: args{ - schemaExtensionId: int32(1), environments: []v2.Environment{ { EnvironmentKind: "Domain", From 6e4ba47c3b7cfdc7a08efb970a6bad58444abb32 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Fri, 9 Jan 2026 15:21:46 -0600 Subject: [PATCH 16/25] peer review changes --- cmd/api/src/services/opengraphschema/extension.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 85d44ba1da..8b16e1e35b 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -22,10 +22,10 @@ import ( v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" ) -func (o *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { +func (s *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { for _, env := range req.Environments { // TODO: Update temporary hardcoded extensionID once extension work is complete - if err := o.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { + if err := s.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { return fmt.Errorf("failed to upload environments with principal kinds: %w", err) } } From 85b9407a4b0b91771ceeae218c6ea5c264d4f025 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 12:54:47 -0600 Subject: [PATCH 17/25] merge conflicts --- cmd/api/src/services/entrypoint.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index ac8361e87f..76a4f33efd 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -38,8 +38,8 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/migrations" "github.com/specterops/bloodhound/cmd/api/src/model/appcfg" "github.com/specterops/bloodhound/cmd/api/src/queries" - "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/dogtags" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/upload" "github.com/specterops/bloodhound/packages/go/cache" schema "github.com/specterops/bloodhound/packages/go/graphschema" From be020633f177f6f052cf88db865fab5242e0bf3b Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 13:10:36 -0600 Subject: [PATCH 18/25] updated to be in line with other PR --- cmd/api/src/api/registration/registration.go | 2 +- cmd/api/src/api/v2/model.go | 4 +- .../src/database/upsert_schema_environment.go | 2 - .../src/database/upsert_schema_extension.go | 24 + ...psert_schema_extension_integration_test.go | 418 +++++++++++++++ .../src/services/opengraphschema/extension.go | 20 +- .../opengraphschema/extension_test.go | 164 ++++++ .../opengraphschema/mocks/opengraphschema.go | 13 +- .../opengraphschema/opengraphschema.go | 4 +- .../opengraphschema/opengraphschema_test.go | 481 ------------------ 10 files changed, 635 insertions(+), 497 deletions(-) create mode 100644 cmd/api/src/database/upsert_schema_extension.go create mode 100644 cmd/api/src/database/upsert_schema_extension_integration_test.go create mode 100644 cmd/api/src/services/opengraphschema/extension_test.go delete mode 100644 cmd/api/src/services/opengraphschema/opengraphschema_test.go diff --git a/cmd/api/src/api/registration/registration.go b/cmd/api/src/api/registration/registration.go index 25eb6e2370..20cf7ff25a 100644 --- a/cmd/api/src/api/registration/registration.go +++ b/cmd/api/src/api/registration/registration.go @@ -63,8 +63,8 @@ func RegisterFossRoutes( authenticator api.Authenticator, authorizer auth.Authorizer, ingestSchema upload.IngestSchema, - openGraphSchemaService v2.OpenGraphSchemaService, dogtagsService dogtags.Service, + openGraphSchemaService v2.OpenGraphSchemaService, ) { router.With(func() mux.MiddlewareFunc { return middleware.DefaultRateLimitMiddleware(rdms) diff --git a/cmd/api/src/api/v2/model.go b/cmd/api/src/api/v2/model.go index 1a4c501490..a91b512275 100644 --- a/cmd/api/src/api/v2/model.go +++ b/cmd/api/src/api/v2/model.go @@ -130,8 +130,8 @@ func NewResources( authorizer auth.Authorizer, authenticator api.Authenticator, ingestSchema upload.IngestSchema, - openGraphSchemaService OpenGraphSchemaService, dogtagsService dogtags.Service, + openGraphSchemaService OpenGraphSchemaService, ) Resources { return Resources{ Decoder: schema.NewDecoder(), @@ -146,7 +146,7 @@ func NewResources( Authenticator: authenticator, IngestSchema: ingestSchema, FileService: &fs.Client{}, - openGraphSchemaService: openGraphSchemaService, DogTags: dogtagsService, + openGraphSchemaService: openGraphSchemaService, } } diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index 71b3107c44..f65022bfb3 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -26,8 +26,6 @@ import ( // UpsertSchemaEnvironmentWithPrincipalKinds creates or updates an environment with its principal kinds. // If an environment with the same environment kind and source kind exists, it will be replaced. -// -// NOTE: This method should be called within a transaction. The caller is responsible for transaction management. func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error { environment := model.SchemaEnvironment{ SchemaExtensionId: schemaExtensionId, diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go new file mode 100644 index 0000000000..d1a9ea8b25 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -0,0 +1,24 @@ +package database + +import ( + "context" + "fmt" +) + +type EnvironmentInput struct { + EnvironmentKindName string + SourceKindName string + PrincipalKinds []string +} + +func (s *BloodhoundDB) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []EnvironmentInput) error { + return s.Transaction(ctx, func(tx *BloodhoundDB) error { + for _, env := range environments { + if err := tx.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, extensionID, env.EnvironmentKindName, env.SourceKindName, env.PrincipalKinds); err != nil { + return fmt.Errorf("failed to upload environment with principal kinds: %w", err) + } + } + + return nil + }) +} diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go new file mode 100644 index 0000000000..b406021252 --- /dev/null +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -0,0 +1,418 @@ +//go:build integration + +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/dawgs/graph" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { + type args struct { + environments []database.EnvironmentInput + } + tests := []struct { + name string + setupData func(t *testing.T, db *database.BloodhoundDB) int32 + args args + assert func(t *testing.T, db *database.BloodhoundDB) + expectedError string + }{ + { + name: "Success: Create environment with principal kinds", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero", "Tag_Owned"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments)) + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) + + expectedKindIDs := make([]int32, len(expectedPrincipalKindNames)) + for i, name := range expectedPrincipalKindNames { + kind, err := db.GetKindByName(context.Background(), name) + assert.NoError(t, err) + expectedKindIDs[i] = int32(kind.ID) + } + + actualKindIDs := make([]int32, len(principalKinds)) + for i, pk := range principalKinds { + assert.Equal(t, environments[0].ID, pk.EnvironmentId) + actualKindIDs[i] = pk.PrincipalKind + } + + assert.ElementsMatch(t, expectedKindIDs, actualKindIDs) + }, + }, + { + name: "Success: Create multiple environments", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + { + EnvironmentKindName: "Tag_Owned", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Owned"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 2, len(environments), "Should have two environments") + + // Verify first environment + env1PrincipalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(env1PrincipalKinds)) + + // Verify second environment + env2PrincipalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[1].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(env2PrincipalKinds)) + }, + }, + { + name: "Success: Upsert replaces existing environment", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + // Create initial environment + err = db.UpsertGraphSchemaExtension(context.Background(), ext.ID, []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Owned"}, + }, + }) + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments), "Should only have one environment (old one replaced)") + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds)) + + expectedKind, err := db.GetKindByName(context.Background(), expectedPrincipalKindNames[0]) + assert.NoError(t, err) + + assert.Equal(t, int32(expectedKind.ID), principalKinds[0].PrincipalKind) + assert.Equal(t, environments[0].ID, principalKinds[0].EnvironmentId) + }, + }, + { + name: "Success: Source kind auto-registers", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "NewSource", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + sourceKind, err := db.GetSourceKindByName(context.Background(), "NewSource") + assert.NoError(t, err) + assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 1, len(environments)) + assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) + + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds)) + }, + }, + { + name: "Success: Multiple environments with different source kinds", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + { + EnvironmentKindName: "Tag_Owned", + SourceKindName: "NewSource", + PrincipalKinds: []string{"Tag_Owned"}, + }, + }, + }, + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify NewSource was auto-registered + sourceKind, err := db.GetSourceKindByName(context.Background(), "NewSource") + assert.NoError(t, err) + assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) + + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 2, len(environments), "Should have two environments") + + for _, env := range environments { + principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), env.ID) + assert.NoError(t, err) + assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") + } + }, + }, + { + name: "Error: First environment has invalid environment kind", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "NonExistent", + SourceKindName: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + expectedError: "environment kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Error: First environment has invalid principal kind", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"NonExistent"}, + }, + }, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + { + name: "Rollback: Second environment fails, first should rollback", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + { + EnvironmentKindName: "NonExistent", + SourceKindName: "Base", + PrincipalKinds: []string{}, + }, + }, + }, + expectedError: "environment kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify complete transaction rollback - no environments created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environments should exist after rollback") + }, + }, + { + name: "Rollback: Second environment has invalid principal kind", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Tier_Zero"}, + }, + { + EnvironmentKindName: "Tag_Owned", + SourceKindName: "Base", + PrincipalKinds: []string{"NonExistent"}, + }, + }, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify complete transaction rollback - no environments created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environments should exist after rollback") + }, + }, + { + name: "Rollback: Partial failure in first environment's principal kinds", + setupData: func(t *testing.T, db *database.BloodhoundDB) int32 { + t.Helper() + ext, err := db.CreateGraphSchemaExtension(context.Background(), "TestExt", "Test", "v1.0.0") + require.NoError(t, err) + + return ext.ID + }, + args: args{ + environments: []database.EnvironmentInput{ + { + EnvironmentKindName: "Tag_Tier_Zero", + SourceKindName: "Base", + PrincipalKinds: []string{"Tag_Owned", "NonExistent"}, + }, + }, + }, + expectedError: "principal kind 'NonExistent' not found", + assert: func(t *testing.T, db *database.BloodhoundDB) { + t.Helper() + + // Verify transaction rolled back - no environment created + environments, err := db.GetSchemaEnvironments(context.Background()) + assert.NoError(t, err) + assert.Equal(t, 0, len(environments), "No environment should exist after rollback") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + extensionID := tt.setupData(t, testSuite.BHDatabase) + + err := testSuite.BHDatabase.UpsertGraphSchemaExtension( + context.Background(), + extensionID, + tt.args.environments, + ) + + if tt.expectedError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectedError) + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase) + } + } else { + require.NoError(t, err) + if tt.assert != nil { + tt.assert(t, testSuite.BHDatabase) + } + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/extension.go b/cmd/api/src/services/opengraphschema/extension.go index 8b16e1e35b..a16df48021 100644 --- a/cmd/api/src/services/opengraphschema/extension.go +++ b/cmd/api/src/services/opengraphschema/extension.go @@ -20,15 +20,27 @@ import ( "fmt" v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/database" ) func (s *OpenGraphSchemaService) UpsertGraphSchemaExtension(ctx context.Context, req v2.GraphSchemaExtension) error { - for _, env := range req.Environments { - // TODO: Update temporary hardcoded extensionID once extension work is complete - if err := s.openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, 1, env.EnvironmentKind, env.SourceKind, env.PrincipalKinds); err != nil { - return fmt.Errorf("failed to upload environments with principal kinds: %w", err) + var ( + environments = make([]database.EnvironmentInput, len(req.Environments)) + ) + + for i, environment := range req.Environments { + environments[i] = database.EnvironmentInput{ + EnvironmentKindName: environment.EnvironmentKind, + SourceKindName: environment.SourceKind, + PrincipalKinds: environment.PrincipalKinds, } } + // TODO: Temporary hardcoded value but needs to be updated to pass in the extension ID + err := s.openGraphSchemaRepository.UpsertGraphSchemaExtension(ctx, 1, environments) + if err != nil { + return fmt.Errorf("error upserting graph extension: %w", err) + } + return nil } diff --git a/cmd/api/src/services/opengraphschema/extension_test.go b/cmd/api/src/services/opengraphschema/extension_test.go new file mode 100644 index 0000000000..0f81f8b881 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/extension_test.go @@ -0,0 +1,164 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema_test + +import ( + "context" + "errors" + "testing" + + v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" + schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { + type mocks struct { + mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository + } + type args struct { + environments []v2.Environment + } + tests := []struct { + name string + setupMocks func(t *testing.T, m *mocks) + args args + expected error + }{ + { + name: "Error: openGraphSchemaRepository.UpsertGraphSchemaExtension error", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User"}, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + ).Return(errors.New("error")) + }, + expected: errors.New("error upserting graph extension: error"), + }, + { + name: "Success: single environment", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User", "Computer"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User", "Computer"}, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + ).Return(nil) + }, + expected: nil, + }, + { + name: "Success: multiple environments", + args: args{ + environments: []v2.Environment{ + { + EnvironmentKind: "Domain", + SourceKind: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKind: "AzureAD", + SourceKind: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, + }, + }, + setupMocks: func(t *testing.T, m *mocks) { + t.Helper() + expectedEnvs := []database.EnvironmentInput{ + { + EnvironmentKindName: "Domain", + SourceKindName: "Base", + PrincipalKinds: []string{"User"}, + }, + { + EnvironmentKindName: "AzureAD", + SourceKindName: "AzureHound", + PrincipalKinds: []string{"User", "Group"}, + }, + } + m.mockOpenGraphSchema.EXPECT().UpsertGraphSchemaExtension( + gomock.Any(), + int32(1), + expectedEnvs, + ).Return(nil) + }, + expected: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + m := &mocks{ + mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), + } + + tt.setupMocks(t, m) + + service := opengraphschema.NewOpenGraphSchemaService(m.mockOpenGraphSchema) + + err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ + Environments: tt.args.environments, + }) + + if tt.expected != nil { + assert.EqualError(t, err, tt.expected.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go index 9c3e712987..f1e1539a2a 100644 --- a/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschema.go @@ -29,6 +29,7 @@ import ( context "context" reflect "reflect" + database "github.com/specterops/bloodhound/cmd/api/src/database" gomock "go.uber.org/mock/gomock" ) @@ -56,16 +57,16 @@ func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryM return m.recorder } -// UpsertSchemaEnvironmentWithPrincipalKinds mocks base method. -func (m *MockOpenGraphSchemaRepository) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind, sourceKind string, principalKinds []string) error { +// UpsertGraphSchemaExtension mocks base method. +func (m *MockOpenGraphSchemaRepository) UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertSchemaEnvironmentWithPrincipalKinds", ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) + ret := m.ctrl.Call(m, "UpsertGraphSchemaExtension", ctx, extensionID, environments) ret0, _ := ret[0].(error) return ret0 } -// UpsertSchemaEnvironmentWithPrincipalKinds indicates an expected call of UpsertSchemaEnvironmentWithPrincipalKinds. -func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertSchemaEnvironmentWithPrincipalKinds(ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds any) *gomock.Call { +// UpsertGraphSchemaExtension indicates an expected call of UpsertGraphSchemaExtension. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertGraphSchemaExtension(ctx, extensionID, environments any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertSchemaEnvironmentWithPrincipalKinds", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertSchemaEnvironmentWithPrincipalKinds), ctx, schemaExtensionId, environmentKind, sourceKind, principalKinds) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertGraphSchemaExtension", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertGraphSchemaExtension), ctx, extensionID, environments) } diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go index 6de81562dc..74322b76bf 100644 --- a/cmd/api/src/services/opengraphschema/opengraphschema.go +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -19,11 +19,13 @@ package opengraphschema import ( "context" + + "github.com/specterops/bloodhound/cmd/api/src/database" ) // OpenGraphSchemaRepository - type OpenGraphSchemaRepository interface { - UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Context, schemaExtensionId int32, environmentKind string, sourceKind string, principalKinds []string) error + UpsertGraphSchemaExtension(ctx context.Context, extensionID int32, environments []database.EnvironmentInput) error } type OpenGraphSchemaService struct { diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go deleted file mode 100644 index b15effc82f..0000000000 --- a/cmd/api/src/services/opengraphschema/opengraphschema_test.go +++ /dev/null @@ -1,481 +0,0 @@ -// Copyright 2026 Specter Ops, Inc. -// -// Licensed under the Apache License, Version 2.0 -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// SPDX-License-Identifier: Apache-2.0 -package opengraphschema_test - -import ( - "context" - "errors" - "testing" - - v2 "github.com/specterops/bloodhound/cmd/api/src/api/v2" - "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" - schemamocks "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" - "github.com/stretchr/testify/assert" - "go.uber.org/mock/gomock" -) - -func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { - type mocks struct { - mockOpenGraphSchema *schemamocks.MockOpenGraphSchemaRepository - } - type args struct { - environments []v2.Environment - } - tests := []struct { - name string - setupMocks func(t *testing.T, m *mocks) - args args - expected error - }{ - // UpsertSchemaEnvironmentWithPrincipalKinds - // Validation: Environment Kind - { - name: "Error: environment kind name not found in the database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("environment kind 'Domain' not found")) - }, - expected: errors.New("failed to upload environments with principal kinds: environment kind 'Domain' not found"), - }, - { - name: "Error: failed to retrieve environment kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving environment kind 'Domain': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving environment kind 'Domain': error"), - }, - // Validation: Source Kind - { - name: "Error: failed to retrieve source kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving source kind 'Base': error"), - }, - { - name: "Error: source kind name doesn't exist in database, registration fails", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error registering source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error registering source kind 'Base': error"), - }, - { - name: "Error: source kind name doesn't exist in database, registration succeeds but fetch fails", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error retrieving newly registered source kind 'Base': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving newly registered source kind 'Base': error"), - }, - // Validation: Principal Kind - { - name: "Error: principal kind not found in database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "InvalidKind"}, - ).Return(errors.New("principal kind 'InvalidKind' not found")) - }, - expected: errors.New("failed to upload environments with principal kinds: principal kind 'InvalidKind' not found"), - }, - { - name: "Error: failed to retrieve principal kind from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "InvalidKind"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "InvalidKind"}, - ).Return(errors.New("error retrieving principal kind by name 'InvalidKind': error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error retrieving principal kind by name 'InvalidKind': error"), - }, - // Upsert Schema Environment - { - name: "Error: error retrieving schema environment from database", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error retrieving schema environment id 0: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error retrieving schema environment id 0: error"), - }, - { - name: "Error: error deleting schema environment", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error deleting schema environment 5: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error deleting schema environment 5: error"), - }, - { - name: "Error: error creating schema environment after deletion", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), - }, - { - name: "Error: error creating new schema environment", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{}, - ).Return(errors.New("error upserting schema environment: error creating schema environment: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting schema environment: error creating schema environment: error"), - }, - // Upsert Principal Kinds - { - name: "Error: error getting principal kinds by environment id", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error retrieving existing principal kinds for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error retrieving existing principal kinds for environment 10: error"), - }, - { - name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error deleting principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error deleting principal kind 5 for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error deleting principal kind 5 for environment 10: error"), - }, - { - name: "Error: openGraphSchemaRepository.UpsertSchemaEnvironmentWithPrincipalKinds error creating principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(errors.New("error upserting principal kinds: error creating principal kind 3 for environment 10: error")) - }, - expected: errors.New("failed to upload environments with principal kinds: error upserting principal kinds: error creating principal kind 3 for environment 10: error"), - }, - { - name: "Success: Create new environment with principal kinds", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User", "Computer"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User", "Computer"}, - ).Return(nil) - }, - expected: nil, - }, - { - name: "Success: Create environment with source kind registration", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "NewSource", - PrincipalKinds: []string{}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "NewSource", - []string{}, - ).Return(nil) - }, - expected: nil, - }, - { - name: "Success: Process multiple environments", - args: args{ - environments: []v2.Environment{ - { - EnvironmentKind: "Domain", - SourceKind: "Base", - PrincipalKinds: []string{"User"}, - }, - { - EnvironmentKind: "AzureAD", - SourceKind: "AzureHound", - PrincipalKinds: []string{"User", "Group"}, - }, - }, - }, - setupMocks: func(t *testing.T, m *mocks) { - t.Helper() - // First environment - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "Domain", - "Base", - []string{"User"}, - ).Return(nil) - - // Second environment - m.mockOpenGraphSchema.EXPECT().UpsertSchemaEnvironmentWithPrincipalKinds( - gomock.Any(), - int32(1), - "AzureAD", - "AzureHound", - []string{"User", "Group"}, - ).Return(nil) - }, - expected: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - - m := &mocks{ - mockOpenGraphSchema: schemamocks.NewMockOpenGraphSchemaRepository(ctrl), - } - - tt.setupMocks(t, m) - - service := opengraphschema.NewOpenGraphSchemaService(m.mockOpenGraphSchema) - - err := service.UpsertGraphSchemaExtension(context.Background(), v2.GraphSchemaExtension{ - Environments: tt.args.environments, - }) - - if tt.expected != nil { - assert.EqualError(t, err, tt.expected.Error()) - } else { - assert.NoError(t, err) - } - }) - } -} From 094d52a6fb27a63418a12257c15e95f6f590331d Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 13:13:20 -0600 Subject: [PATCH 19/25] just prepare --- cmd/api/src/database/upsert_schema_extension.go | 15 +++++++++++++++ .../upsert_schema_extension_integration_test.go | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go index d1a9ea8b25..8c564d6aa4 100644 --- a/cmd/api/src/database/upsert_schema_extension.go +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 package database import ( diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index b406021252..9a71677d59 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -1,3 +1,18 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 //go:build integration package database_test From be0c964b944ad8e3c832a6416f478589d58cbc54 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Tue, 13 Jan 2026 13:19:21 -0600 Subject: [PATCH 20/25] just prepare --- cmd/api/src/database/upsert_schema_extension.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go index 8c564d6aa4..e27f20af87 100644 --- a/cmd/api/src/database/upsert_schema_extension.go +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -4,7 +4,7 @@ // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, From 86560a5a56e7bb04fd8c1a0091fb36c8499f6714 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Wed, 14 Jan 2026 11:40:59 -0600 Subject: [PATCH 21/25] updated error message --- cmd/api/src/database/upsert_schema_extension.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/database/upsert_schema_extension.go b/cmd/api/src/database/upsert_schema_extension.go index e27f20af87..d189a1d40f 100644 --- a/cmd/api/src/database/upsert_schema_extension.go +++ b/cmd/api/src/database/upsert_schema_extension.go @@ -30,7 +30,7 @@ func (s *BloodhoundDB) UpsertGraphSchemaExtension(ctx context.Context, extension return s.Transaction(ctx, func(tx *BloodhoundDB) error { for _, env := range environments { if err := tx.UpsertSchemaEnvironmentWithPrincipalKinds(ctx, extensionID, env.EnvironmentKindName, env.SourceKindName, env.PrincipalKinds); err != nil { - return fmt.Errorf("failed to upload environment with principal kinds: %w", err) + return fmt.Errorf("failed to upsert environment with principal kinds: %w", err) } } From 1cab880bcc2240eceef5c76ca591fc42fa03af7d Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 15 Jan 2026 15:50:05 -0600 Subject: [PATCH 22/25] peer review changes --- cmd/api/src/database/graphschema.go | 28 +++--- .../database/graphschema_integration_test.go | 18 ++-- .../database/migration/migrations/v8.6.0.sql | 22 +++++ cmd/api/src/database/mocks/db.go | 88 +++++++++---------- .../src/database/upsert_schema_environment.go | 24 ++--- ...ert_schema_environment_integration_test.go | 8 +- ...psert_schema_extension_integration_test.go | 12 +-- 7 files changed, 113 insertions(+), 87 deletions(-) create mode 100644 cmd/api/src/database/migration/migrations/v8.6.0.sql diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index df2ed45b18..b6e2348d2a 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -66,9 +66,10 @@ type OpenGraphSchema interface { GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) UpdateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) DeleteRemediation(ctx context.Context, findingId int32) error - CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) - GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) - DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error + + CreatePrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) + GetPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) + DeletePrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error } const ( @@ -240,10 +241,10 @@ func (s *BloodhoundDB) GetGraphSchemaNodeKinds(ctx context.Context, filters mode if filterAndPagination, err := parseFiltersAndPagination(filters, sort, skip, limit); err != nil { return schemaNodeKinds, 0, err } else { - sqlStr := fmt.Sprintf(`SELECT nk.id, k.name, nk.schema_extension_id, nk.display_name, nk.description, + sqlStr := fmt.Sprintf(`SELECT nk.id, k.name, nk.schema_extension_id, nk.display_name, nk.description, nk.is_display_kind, nk.icon, nk.icon_color, nk.created_at, nk.updated_at, nk.deleted_at FROM %s nk - JOIN %s k ON nk.kind_id = k.id + JOIN %s k ON nk.kind_id = k.id %s %s %s`, model.GraphSchemaNodeKind{}.TableName(), kindTable, filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit) if result := s.db.WithContext(ctx).Raw(sqlStr, filterAndPagination.Filter.params...).Scan(&schemaNodeKinds); result.Error != nil { @@ -289,7 +290,7 @@ func (s *BloodhoundDB) UpdateGraphSchemaNodeKind(ctx context.Context, schemaNode RETURNING id, kind_id, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at ) SELECT updated_row.id, %s.name, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at - FROM updated_row + FROM updated_row JOIN %s ON %s.id = updated_row.kind_id`, schemaNodeKind.TableName(), kindTable, kindTable, kindTable), schemaNodeKind.SchemaExtensionId, schemaNodeKind.DisplayName, schemaNodeKind.Description, schemaNodeKind.IsDisplayKind, schemaNodeKind.Icon, @@ -437,7 +438,7 @@ func (s *BloodhoundDB) CreateGraphSchemaEdgeKind(ctx context.Context, name strin RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at ) SELECT ie.id, ie.schema_extension_id, dk.name, ie.description, ie.is_traversable, ie.created_at, ie.updated_at, ie.deleted_at - FROM inserted_edges ie + FROM inserted_edges ie JOIN dawgs_kind dk ON ie.kind_id = dk.id;`, name, schemaExtensionId, description, isTraversable).Scan(&schemaEdgeKind); result.Error != nil { if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) { return schemaEdgeKind, fmt.Errorf("%w: %v", ErrDuplicateSchemaEdgeKindName, result.Error) @@ -458,10 +459,10 @@ func (s *BloodhoundDB) GetGraphSchemaEdgeKinds(ctx context.Context, edgeKindFilt if filterAndPagination, err := parseFiltersAndPagination(edgeKindFilters, sort, skip, limit); err != nil { return schemaEdgeKinds, 0, err } else { - sqlStr := fmt.Sprintf(`SELECT ek.id, k.name, ek.schema_extension_id, ek.description, ek.is_traversable, + sqlStr := fmt.Sprintf(`SELECT ek.id, k.name, ek.schema_extension_id, ek.description, ek.is_traversable, ek.created_at, ek.updated_at, ek.deleted_at FROM %s ek - JOIN %s k ON ek.kind_id = k.id + JOIN %s k ON ek.kind_id = k.id %s %s %s`, model.GraphSchemaEdgeKind{}.TableName(), kindTable, filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit) @@ -543,7 +544,7 @@ func (s *BloodhoundDB) UpdateGraphSchemaEdgeKind(ctx context.Context, schemaEdge SET schema_extension_id = ?, description = ?, is_traversable = ?, updated_at = NOW() WHERE id = ? RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at - ) + ) SELECT updated_row.id, %s.name, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at FROM updated_row JOIN %s ON %s.id = updated_row.kind_id`, @@ -784,7 +785,7 @@ func (s *BloodhoundDB) DeleteRemediation(ctx context.Context, findingId int32) e return nil } -func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { +func (s *BloodhoundDB) CreatePrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { var envPrincipalKind model.SchemaEnvironmentPrincipalKind if result := s.db.WithContext(ctx).Raw(` @@ -798,7 +799,8 @@ func (s *BloodhoundDB) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, return envPrincipalKind, nil } -func (s *BloodhoundDB) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { +// GetPrincipalKindsByEnvironmentID - retrieves a schema environments principal kind by environment id. +func (s *BloodhoundDB) GetPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { var envPrincipalKinds model.SchemaEnvironmentPrincipalKinds if result := s.db.WithContext(ctx).Raw(` @@ -812,7 +814,7 @@ func (s *BloodhoundDB) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx con return envPrincipalKinds, nil } -func (s *BloodhoundDB) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error { +func (s *BloodhoundDB) DeletePrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error { if result := s.db.WithContext(ctx).Exec(` DELETE FROM schema_environments_principal_kinds WHERE environment_id = ? AND principal_kind = ?`, diff --git a/cmd/api/src/database/graphschema_integration_test.go b/cmd/api/src/database/graphschema_integration_test.go index 28db314d6a..412a5d2d1d 100644 --- a/cmd/api/src/database/graphschema_integration_test.go +++ b/cmd/api/src/database/graphschema_integration_test.go @@ -2224,7 +2224,7 @@ func TestCreateSchemaEnvironmentPrincipalKind(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - result, err := testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) + result, err := testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { @@ -2262,10 +2262,10 @@ func TestGetSchemaEnvironmentPrincipalKindsByEnvironmentId(t *testing.T) { _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 1) + _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 1) require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 2) + _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 2) require.NoError(t, err) return testSuite @@ -2295,7 +2295,7 @@ func TestGetSchemaEnvironmentPrincipalKindsByEnvironmentId(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - result, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) + result, err := testSuite.BHDatabase.GetPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { @@ -2332,7 +2332,7 @@ func TestDeleteSchemaEnvironmentPrincipalKind(t *testing.T) { _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, 1, 1) + _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 1) require.NoError(t, err) return testSuite @@ -2364,12 +2364,12 @@ func TestDeleteSchemaEnvironmentPrincipalKind(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - err := testSuite.BHDatabase.DeleteSchemaEnvironmentPrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) + err := testSuite.BHDatabase.DeletePrincipalKind(testSuite.Context, testCase.args.environmentId, testCase.args.principalKind) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { assert.NoError(t, err) - result, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) + result, err := testSuite.BHDatabase.GetPrincipalKindsByEnvironmentId(testSuite.Context, testCase.args.environmentId) assert.NoError(t, err) assert.Len(t, result, 0) } @@ -2403,7 +2403,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { _, err = testSuite.BHDatabase.CreateRemediation(testSuite.Context, relationshipFinding.ID, "Short desc", "Long desc", "Short remediation", "Long remediation") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironmentPrincipalKind(testSuite.Context, environment.ID, nodeKind.ID) + _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, environment.ID, nodeKind.ID) require.NoError(t, err) err = testSuite.BHDatabase.DeleteGraphSchemaExtension(testSuite.Context, extension.ID) @@ -2427,7 +2427,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { _, err = testSuite.BHDatabase.GetRemediationByFindingId(testSuite.Context, relationshipFinding.ID) assert.ErrorIs(t, err, database.ErrNotFound) - principalKinds, err := testSuite.BHDatabase.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(testSuite.Context, environment.ID) + principalKinds, err := testSuite.BHDatabase.GetPrincipalKindsByEnvironmentId(testSuite.Context, environment.ID) assert.NoError(t, err) assert.Len(t, principalKinds, 0) } diff --git a/cmd/api/src/database/migration/migrations/v8.6.0.sql b/cmd/api/src/database/migration/migrations/v8.6.0.sql new file mode 100644 index 0000000000..fbb20a81b0 --- /dev/null +++ b/cmd/api/src/database/migration/migrations/v8.6.0.sql @@ -0,0 +1,22 @@ +-- Copyright 2026 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 +-- Drop the column with the incorrect foreign key +ALTER TABLE IF EXISTS schema_environments + DROP COLUMN IF EXISTS source_kind_id; + +-- Add the column back with the correct foreign key reference +ALTER TABLE IF EXISTS schema_environments + ADD COLUMN source_kind_id INTEGER NOT NULL REFERENCES source_kinds(id); diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index b911c1ad42..df0c1b70fb 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -461,6 +461,21 @@ func (mr *MockDatabaseMockRecorder) CreateOIDCProvider(ctx, name, issuer, client return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOIDCProvider", reflect.TypeOf((*MockDatabase)(nil).CreateOIDCProvider), ctx, name, issuer, clientID, config) } +// CreatePrincipalKind mocks base method. +func (m *MockDatabase) CreatePrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreatePrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreatePrincipalKind indicates an expected call of CreatePrincipalKind. +func (mr *MockDatabaseMockRecorder) CreatePrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreatePrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreatePrincipalKind), ctx, environmentId, principalKind) +} + // CreateRemediation mocks base method. func (m *MockDatabase) CreateRemediation(ctx context.Context, findingId int32, shortDescription, longDescription, shortRemediation, longRemediation string) (model.Remediation, error) { m.ctrl.T.Helper() @@ -585,21 +600,6 @@ func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, en return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) } -// CreateSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockDatabase) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. -func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - // CreateSchemaRelationshipFinding mocks base method. func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() @@ -926,6 +926,20 @@ func (mr *MockDatabaseMockRecorder) DeleteIngestTask(ctx, ingestTask any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteIngestTask", reflect.TypeOf((*MockDatabase)(nil).DeleteIngestTask), ctx, ingestTask) } +// DeletePrincipalKind mocks base method. +func (m *MockDatabase) DeletePrincipalKind(ctx context.Context, environmentId, principalKind int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeletePrincipalKind", ctx, environmentId, principalKind) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeletePrincipalKind indicates an expected call of DeletePrincipalKind. +func (mr *MockDatabaseMockRecorder) DeletePrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeletePrincipalKind), ctx, environmentId, principalKind) +} + // DeleteRemediation mocks base method. func (m *MockDatabase) DeleteRemediation(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -1001,20 +1015,6 @@ func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) } -// DeleteSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockDatabase) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. -func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - // DeleteSchemaRelationshipFinding mocks base method. func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -2007,6 +2007,21 @@ func (mr *MockDatabaseMockRecorder) GetPermission(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermission", reflect.TypeOf((*MockDatabase)(nil).GetPermission), ctx, id) } +// GetPrincipalKindsByEnvironmentId mocks base method. +func (m *MockDatabase) GetPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetPrincipalKindsByEnvironmentId", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetPrincipalKindsByEnvironmentId indicates an expected call of GetPrincipalKindsByEnvironmentId. +func (mr *MockDatabaseMockRecorder) GetPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetPrincipalKindsByEnvironmentId), ctx, environmentId) +} + // GetPublicSavedQueries mocks base method. func (m *MockDatabase) GetPublicSavedQueries(ctx context.Context) (model.SavedQueries, error) { m.ctrl.T.Helper() @@ -2217,21 +2232,6 @@ func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environment return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) } -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) -} - // GetSchemaEnvironments mocks base method. func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index f65022bfb3..df1401db18 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -49,13 +49,13 @@ func (s *BloodhoundDB) UpsertSchemaEnvironmentWithPrincipalKinds(ctx context.Con environment.EnvironmentKindId = int32(envKind.ID) environment.SourceKindId = sourceKindID - envID, err := s.upsertSchemaEnvironment(ctx, environment) + envID, err := s.replaceSchemaEnvironment(ctx, environment) if err != nil { - return fmt.Errorf("error upserting schema environment: %w", err) + return fmt.Errorf("error replacing or creating schema environment: %w", err) } - if err := s.upsertPrincipalKinds(ctx, envID, translatedPrincipalKinds); err != nil { - return fmt.Errorf("error upserting principal kinds: %w", err) + if err := s.replacePrincipalKinds(ctx, envID, translatedPrincipalKinds); err != nil { + return fmt.Errorf("error replacing principal kinds: %w", err) } return nil @@ -114,9 +114,11 @@ func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, p return principalKinds, nil } -// upsertSchemaEnvironment creates or updates a schema environment. +// replaceSchemaEnvironment creates or updates a schema environment. // If an environment with the given kinds exists, it deletes it first before creating the new one. -func (s *BloodhoundDB) upsertSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { +// The unique constraint on (environment_kind_id, source_kind_id) of the Schema Environment table ensures no +// duplicate pairs exist, enabling this upsert logic. +func (s *BloodhoundDB) replaceSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { if existing, err := s.GetSchemaEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment: %w", err) } else if !errors.Is(err, ErrNotFound) { @@ -134,14 +136,14 @@ func (s *BloodhoundDB) upsertSchemaEnvironment(ctx context.Context, graphSchema } } -// upsertPrincipalKinds deletes all existing principal kinds for an environment and creates new ones. -func (s *BloodhoundDB) upsertPrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { - if existingKinds, err := s.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, ErrNotFound) { +// replacePrincipalKinds deletes all existing principal kinds for an environment and creates new ones. +func (s *BloodhoundDB) replacePrincipalKinds(ctx context.Context, environmentID int32, principalKinds []model.SchemaEnvironmentPrincipalKind) error { + if existingKinds, err := s.GetPrincipalKindsByEnvironmentId(ctx, environmentID); err != nil && !errors.Is(err, ErrNotFound) { return fmt.Errorf("error retrieving existing principal kinds for environment %d: %w", environmentID, err) } else if !errors.Is(err, ErrNotFound) { // Delete all existing principal kinds for _, kind := range existingKinds { - if err := s.DeleteSchemaEnvironmentPrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { + if err := s.DeletePrincipalKind(ctx, kind.EnvironmentId, kind.PrincipalKind); err != nil { return fmt.Errorf("error deleting principal kind %d for environment %d: %w", kind.PrincipalKind, kind.EnvironmentId, err) } } @@ -149,7 +151,7 @@ func (s *BloodhoundDB) upsertPrincipalKinds(ctx context.Context, environmentID i // Create the new principal kinds for _, kind := range principalKinds { - if _, err := s.CreateSchemaEnvironmentPrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { + if _, err := s.CreatePrincipalKind(ctx, environmentID, kind.PrincipalKind); err != nil { return fmt.Errorf("error creating principal kind %d for environment %d: %w", kind.PrincipalKind, environmentID, err) } } diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index bd32cce3d2..b811b10adb 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -64,7 +64,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(environments)) - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) @@ -118,7 +118,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(environments), "Should only have one environment (old one deleted)") - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds)) @@ -155,7 +155,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.Equal(t, 1, len(environments)) assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds)) }, @@ -265,7 +265,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.Equal(t, 2, len(environments), "Should have two different environments") for _, env := range environments { - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), env.ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), env.ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") } diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 9a71677d59..2f998cd054 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -65,7 +65,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(environments)) - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, len(expectedPrincipalKindNames), len(principalKinds)) @@ -116,12 +116,12 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.Equal(t, 2, len(environments), "Should have two environments") // Verify first environment - env1PrincipalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + env1PrincipalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(env1PrincipalKinds)) // Verify second environment - env2PrincipalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[1].ID) + env2PrincipalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[1].ID) assert.NoError(t, err) assert.Equal(t, 1, len(env2PrincipalKinds)) }, @@ -163,7 +163,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 1, len(environments), "Should only have one environment (old one replaced)") - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds)) @@ -204,7 +204,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.Equal(t, 1, len(environments)) assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), environments[0].ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds)) }, @@ -245,7 +245,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.Equal(t, 2, len(environments), "Should have two environments") for _, env := range environments { - principalKinds, err := db.GetSchemaEnvironmentPrincipalKindsByEnvironmentId(context.Background(), env.ID) + principalKinds, err := db.GetPrincipalKindsByEnvironmentId(context.Background(), env.ID) assert.NoError(t, err) assert.Equal(t, 1, len(principalKinds), "Each environment should have one principal kind") } From 471fc4e7bc0422f3990759019cb3550b732ffe3e Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 15 Jan 2026 15:56:38 -0600 Subject: [PATCH 23/25] peer review changes - more naming --- cmd/api/src/database/graphschema.go | 30 ++-- .../database/graphschema_integration_test.go | 72 ++++----- cmd/api/src/database/mocks/db.go | 148 +++++++++--------- .../src/database/upsert_schema_environment.go | 6 +- ...ert_schema_environment_integration_test.go | 14 +- ...psert_schema_extension_integration_test.go | 20 +-- cmd/api/src/services/entrypoint.go | 14 +- 7 files changed, 152 insertions(+), 152 deletions(-) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index b6e2348d2a..3588f3f79b 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -52,11 +52,11 @@ type OpenGraphSchema interface { GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) - CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) - DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error + CreateEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) + GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) + GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) + GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) + DeleteEnvironment(ctx context.Context, environmentId int32) error CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) @@ -572,8 +572,8 @@ func (s *BloodhoundDB) DeleteGraphSchemaEdgeKind(ctx context.Context, schemaEdge return nil } -// CreateSchemaEnvironment - creates a new schema_environment. -func (s *BloodhoundDB) CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) { +// CreateEnvironment - creates a new schema_environment. +func (s *BloodhoundDB) CreateEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) { var schemaEnvironment model.SchemaEnvironment if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` @@ -590,14 +590,14 @@ func (s *BloodhoundDB) CreateSchemaEnvironment(ctx context.Context, extensionId return schemaEnvironment, nil } -// GetSchemaEnvironments - retrieves list of schema environments. -func (s *BloodhoundDB) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { +// GetEnvironments - retrieves list of schema environments. +func (s *BloodhoundDB) GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { var result []model.SchemaEnvironment return result, CheckError(s.db.WithContext(ctx).Find(&result)) } -// GetSchemaEnvironmentByKinds - retrieves an environment by its environment kind and source kind. -func (s *BloodhoundDB) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { +// GetEnvironmentByKinds - retrieves an environment by its environment kind and source kind. +func (s *BloodhoundDB) GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { var env model.SchemaEnvironment if result := s.db.WithContext(ctx).Raw( @@ -612,8 +612,8 @@ func (s *BloodhoundDB) GetSchemaEnvironmentByKinds(ctx context.Context, environm return env, nil } -// GetSchemaEnvironmentById - retrieves a schema environment by id. -func (s *BloodhoundDB) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { +// GetEnvironmentById - retrieves a schema environment by id. +func (s *BloodhoundDB) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { var schemaEnvironment model.SchemaEnvironment if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` @@ -629,8 +629,8 @@ func (s *BloodhoundDB) GetSchemaEnvironmentById(ctx context.Context, environment return schemaEnvironment, nil } -// DeleteSchemaEnvironment - deletes a schema environment by id. -func (s *BloodhoundDB) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error { +// DeleteEnvironment - deletes a schema environment by id. +func (s *BloodhoundDB) DeleteEnvironment(ctx context.Context, environmentId int32) error { var schemaEnvironment model.SchemaEnvironment if result := s.db.WithContext(ctx).Exec(fmt.Sprintf(`DELETE FROM %s WHERE id = ?`, schemaEnvironment.TableName()), environmentId); result.Error != nil { diff --git a/cmd/api/src/database/graphschema_integration_test.go b/cmd/api/src/database/graphschema_integration_test.go index 412a5d2d1d..b4431b7d8f 100644 --- a/cmd/api/src/database/graphschema_integration_test.go +++ b/cmd/api/src/database/graphschema_integration_test.go @@ -1213,7 +1213,7 @@ func TestCreateSchemaEnvironment(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - got, err := testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, testCase.args.extensionId, testCase.args.environmentKindId, testCase.args.sourceKindId) + got, err := testSuite.BHDatabase.CreateEnvironment(testSuite.Context, testCase.args.extensionId, testCase.args.environmentKindId, testCase.args.sourceKindId) if testCase.want.err != nil { assert.EqualError(t, err, testCase.want.err.Error()) } else { @@ -1261,7 +1261,7 @@ func TestGetSchemaEnvironments(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "Extension1", "DisplayName", "v1.0.0") require.NoError(t, err) // Create Environments - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) require.NoError(t, err) return testSuite @@ -1289,9 +1289,9 @@ func TestGetSchemaEnvironments(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "Extension1", "DisplayName", "v1.0.0") require.NoError(t, err) // Create Environments - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(2), int32(2)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(2), int32(2)) require.NoError(t, err) return testSuite @@ -1328,7 +1328,7 @@ func TestGetSchemaEnvironments(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - got, err := testSuite.BHDatabase.GetSchemaEnvironments(testSuite.Context) + got, err := testSuite.BHDatabase.GetEnvironments(testSuite.Context) if testCase.want.err != nil { assert.EqualError(t, err, testCase.want.err.Error()) } else { @@ -1372,7 +1372,7 @@ func TestGetSchemaEnvironmentById(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "Extension1", "DisplayName", "v1.0.0") require.NoError(t, err) // Create Environment - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) require.NoError(t, err) return testSuite @@ -1409,7 +1409,7 @@ func TestGetSchemaEnvironmentById(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - got, err := testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, testCase.args.environmentId) + got, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, testCase.args.environmentId) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { @@ -1451,7 +1451,7 @@ func TestDeleteSchemaEnvironment(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "Extension1", "DisplayName", "v1.0.0") require.NoError(t, err) // Create Environment - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, defaultSchemaExtensionID, int32(1), int32(1)) require.NoError(t, err) return testSuite @@ -1481,14 +1481,14 @@ func TestDeleteSchemaEnvironment(t *testing.T) { testSuite := testCase.setup() defer teardownIntegrationTestSuite(t, &testSuite) - err := testSuite.BHDatabase.DeleteSchemaEnvironment(testSuite.Context, testCase.args.environmentId) + err := testSuite.BHDatabase.DeleteEnvironment(testSuite.Context, testCase.args.environmentId) if testCase.want.err != nil { assert.ErrorIs(t, err, testCase.want.err) } else { assert.NoError(t, err) // Verify deletion by trying to get the environment - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, testCase.args.environmentId) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, testCase.args.environmentId) assert.ErrorIs(t, err, database.ErrNotFound) } }) @@ -1506,21 +1506,21 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { // Create two environments in a single transaction err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - _, err := tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err := tx.CreateEnvironment(testSuite.Context, 1, 1, 1) if err != nil { return err } - _, err = tx.CreateSchemaEnvironment(testSuite.Context, 1, 2, 2) + _, err = tx.CreateEnvironment(testSuite.Context, 1, 2, 2) return err }) require.NoError(t, err) // Verify both environments were created - env1, err := testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 1) + env1, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 1) require.NoError(t, err) assert.Equal(t, int32(1), env1.EnvironmentKindId) - env2, err := testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 2) + env2, err := testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 2) require.NoError(t, err) assert.Equal(t, int32(2), env2.EnvironmentKindId) }) @@ -1536,7 +1536,7 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { // Create one environment, then fail - should rollback expectedErr := fmt.Errorf("intentional error to trigger rollback") err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - _, err := tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err := tx.CreateEnvironment(testSuite.Context, 1, 1, 1) if err != nil { return err } @@ -1545,7 +1545,7 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { require.ErrorIs(t, err, expectedErr) // Verify the environment was NOT created (rolled back) - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 1) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 1) assert.ErrorIs(t, err, database.ErrNotFound) }) @@ -1559,18 +1559,18 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { // Create one environment, then try to create a duplicate - should rollback both err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - _, err := tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err := tx.CreateEnvironment(testSuite.Context, 1, 1, 1) if err != nil { return err } // Try to create duplicate (same environment_kind_id + source_kind_id) - will fail - _, err = tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = tx.CreateEnvironment(testSuite.Context, 1, 1, 1) return err }) require.Error(t, err) // Verify the first environment was NOT created (rolled back due to second failure) - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 1) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 1) assert.ErrorIs(t, err, database.ErrNotFound) }) @@ -1584,16 +1584,16 @@ func TestTransaction_SchemaEnvironment(t *testing.T) { // Create and delete in same transaction err = testSuite.BHDatabase.Transaction(testSuite.Context, func(tx *database.BloodhoundDB) error { - env, err := tx.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + env, err := tx.CreateEnvironment(testSuite.Context, 1, 1, 1) if err != nil { return err } - return tx.DeleteSchemaEnvironment(testSuite.Context, env.ID) + return tx.DeleteEnvironment(testSuite.Context, env.ID) }) require.NoError(t, err) // Verify the environment does not exist - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, 1) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, 1) assert.ErrorIs(t, err, database.ErrNotFound) }) } @@ -1625,7 +1625,7 @@ func TestCreateSchemaRelationshipFinding(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "FindingExtension", "Finding Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) return testSuite @@ -1657,7 +1657,7 @@ func TestCreateSchemaRelationshipFinding(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "FindingExtension2", "Finding Extension 2", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "DuplicateName", "Display Name") @@ -1724,7 +1724,7 @@ func TestGetSchemaRelationshipFindingById(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetFindingExt", "Get Finding Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "GetByIdFinding", "Get By ID Finding") @@ -1798,7 +1798,7 @@ func TestDeleteSchemaRelationshipFinding(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "DeleteFindingExt", "Delete Finding Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "DeleteFinding", "Delete Finding") @@ -1870,7 +1870,7 @@ func TestCreateRemediation(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "RemediationExt", "Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "RemediationFinding", "Remediation Finding") @@ -1942,7 +1942,7 @@ func TestGetRemediationByFindingId(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetRemediationExt", "Get Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "GetRemediationFinding", "Get Remediation Finding") @@ -2022,7 +2022,7 @@ func TestUpdateRemediation(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "UpdateRemediationExt", "Update Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "UpdateRemediationFinding", "Update Remediation Finding") @@ -2059,7 +2059,7 @@ func TestUpdateRemediation(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "UpsertRemediationExt", "Upsert Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "UpsertRemediationFinding", "Upsert Remediation Finding") @@ -2130,7 +2130,7 @@ func TestDeleteRemediation(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "DeleteRemediationExt", "Delete Remediation Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "DeleteRemediationFinding", "Delete Remediation Finding") @@ -2202,7 +2202,7 @@ func TestCreateSchemaEnvironmentPrincipalKind(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "EnvPrincipalKindExt", "Env Principal Kind Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) return testSuite @@ -2259,7 +2259,7 @@ func TestGetSchemaEnvironmentPrincipalKindsByEnvironmentId(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetEnvPrincipalKindExt", "Get Env Principal Kind Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 1) @@ -2329,7 +2329,7 @@ func TestDeleteSchemaEnvironmentPrincipalKind(t *testing.T) { _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "DeleteEnvPrincipalKindExt", "Delete Env Principal Kind Extension", "v1.0.0") require.NoError(t, err) - _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + _, err = testSuite.BHDatabase.CreateEnvironment(testSuite.Context, 1, 1, 1) require.NoError(t, err) _, err = testSuite.BHDatabase.CreatePrincipalKind(testSuite.Context, 1, 1) @@ -2394,7 +2394,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { edgeKind, err := testSuite.BHDatabase.CreateGraphSchemaEdgeKind(testSuite.Context, "CascadeTestEdgeKind", extension.ID, "Test description", true) require.NoError(t, err) - environment, err := testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, extension.ID, nodeKind.ID, nodeKind.ID) + environment, err := testSuite.BHDatabase.CreateEnvironment(testSuite.Context, extension.ID, nodeKind.ID, nodeKind.ID) require.NoError(t, err) relationshipFinding, err := testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, extension.ID, edgeKind.ID, environment.ID, "CascadeTestFinding", "Cascade Test Finding") @@ -2418,7 +2418,7 @@ func TestDeleteSchemaExtension_CascadeDeletesAllDependents(t *testing.T) { _, err = testSuite.BHDatabase.GetGraphSchemaEdgeKindById(testSuite.Context, edgeKind.ID) assert.ErrorIs(t, err, database.ErrNotFound) - _, err = testSuite.BHDatabase.GetSchemaEnvironmentById(testSuite.Context, environment.ID) + _, err = testSuite.BHDatabase.GetEnvironmentById(testSuite.Context, environment.ID) assert.ErrorIs(t, err, database.ErrNotFound) _, err = testSuite.BHDatabase.GetSchemaRelationshipFindingById(testSuite.Context, relationshipFinding.ID) diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index df0c1b70fb..c23d6e150f 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -341,6 +341,21 @@ func (mr *MockDatabaseMockRecorder) CreateCustomNodeKinds(ctx, customNodeKind an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCustomNodeKinds", reflect.TypeOf((*MockDatabase)(nil).CreateCustomNodeKinds), ctx, customNodeKind) } +// CreateEnvironment mocks base method. +func (m *MockDatabase) CreateEnvironment(ctx context.Context, extensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateEnvironment", ctx, extensionId, environmentKindId, sourceKindId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateEnvironment indicates an expected call of CreateEnvironment. +func (mr *MockDatabaseMockRecorder) CreateEnvironment(ctx, extensionId, environmentKindId, sourceKindId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateEnvironment), ctx, extensionId, environmentKindId, sourceKindId) +} + // CreateGraphSchemaEdgeKind mocks base method. func (m *MockDatabase) CreateGraphSchemaEdgeKind(ctx context.Context, name string, schemaExtensionId int32, description string, isTraversable bool) (model.GraphSchemaEdgeKind, error) { m.ctrl.T.Helper() @@ -585,21 +600,6 @@ func (mr *MockDatabaseMockRecorder) CreateSavedQueryPermissionsToUsers(ctx, quer return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSavedQueryPermissionsToUsers", reflect.TypeOf((*MockDatabase)(nil).CreateSavedQueryPermissionsToUsers), varargs...) } -// CreateSchemaEnvironment mocks base method. -func (m *MockDatabase) CreateSchemaEnvironment(ctx context.Context, extensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironment", ctx, extensionId, environmentKindId, sourceKindId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironment indicates an expected call of CreateSchemaEnvironment. -func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, environmentKindId, sourceKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) -} - // CreateSchemaRelationshipFinding mocks base method. func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() @@ -842,6 +842,20 @@ func (mr *MockDatabaseMockRecorder) DeleteCustomNodeKind(ctx, kindName any) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteCustomNodeKind", reflect.TypeOf((*MockDatabase)(nil).DeleteCustomNodeKind), ctx, kindName) } +// DeleteEnvironment mocks base method. +func (m *MockDatabase) DeleteEnvironment(ctx context.Context, environmentId int32) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteEnvironment", ctx, environmentId) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteEnvironment indicates an expected call of DeleteEnvironment. +func (mr *MockDatabaseMockRecorder) DeleteEnvironment(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteEnvironment), ctx, environmentId) +} + // DeleteEnvironmentTargetedAccessControlForUser mocks base method. func (m *MockDatabase) DeleteEnvironmentTargetedAccessControlForUser(ctx context.Context, user model.User) error { m.ctrl.T.Helper() @@ -1001,20 +1015,6 @@ func (mr *MockDatabaseMockRecorder) DeleteSavedQueryPermissionsForUsers(ctx, que return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSavedQueryPermissionsForUsers", reflect.TypeOf((*MockDatabase)(nil).DeleteSavedQueryPermissionsForUsers), varargs...) } -// DeleteSchemaEnvironment mocks base method. -func (m *MockDatabase) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironment", ctx, environmentId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironment indicates an expected call of DeleteSchemaEnvironment. -func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) -} - // DeleteSchemaRelationshipFinding mocks base method. func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { m.ctrl.T.Helper() @@ -1702,6 +1702,36 @@ func (mr *MockDatabaseMockRecorder) GetDatapipeStatus(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDatapipeStatus", reflect.TypeOf((*MockDatabase)(nil).GetDatapipeStatus), ctx) } +// GetEnvironmentById mocks base method. +func (m *MockDatabase) GetEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnvironmentById", ctx, environmentId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnvironmentById indicates an expected call of GetEnvironmentById. +func (mr *MockDatabaseMockRecorder) GetEnvironmentById(ctx, environmentId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentById), ctx, environmentId) +} + +// GetEnvironmentByKinds mocks base method. +func (m *MockDatabase) GetEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnvironmentByKinds", ctx, environmentKindId, sourceKindId) + ret0, _ := ret[0].(model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnvironmentByKinds indicates an expected call of GetEnvironmentByKinds. +func (mr *MockDatabaseMockRecorder) GetEnvironmentByKinds(ctx, environmentKindId, sourceKindId any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentByKinds), ctx, environmentKindId, sourceKindId) +} + // GetEnvironmentTargetedAccessControlForUser mocks base method. func (m *MockDatabase) GetEnvironmentTargetedAccessControlForUser(ctx context.Context, user model.User) ([]model.EnvironmentTargetedAccessControl, error) { m.ctrl.T.Helper() @@ -1717,6 +1747,21 @@ func (mr *MockDatabaseMockRecorder) GetEnvironmentTargetedAccessControlForUser(c return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironmentTargetedAccessControlForUser", reflect.TypeOf((*MockDatabase)(nil).GetEnvironmentTargetedAccessControlForUser), ctx, user) } +// GetEnvironments mocks base method. +func (m *MockDatabase) GetEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetEnvironments", ctx) + ret0, _ := ret[0].([]model.SchemaEnvironment) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetEnvironments indicates an expected call of GetEnvironments. +func (mr *MockDatabaseMockRecorder) GetEnvironments(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnvironments", reflect.TypeOf((*MockDatabase)(nil).GetEnvironments), ctx) +} + // GetFlag mocks base method. func (m *MockDatabase) GetFlag(ctx context.Context, id int32) (appcfg.FeatureFlag, error) { m.ctrl.T.Helper() @@ -2202,51 +2247,6 @@ func (mr *MockDatabaseMockRecorder) GetSavedQueryPermissions(ctx, queryID any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSavedQueryPermissions", reflect.TypeOf((*MockDatabase)(nil).GetSavedQueryPermissions), ctx, queryID) } -// GetSchemaEnvironmentById mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentById", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentById indicates an expected call of GetSchemaEnvironmentById. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) -} - -// GetSchemaEnvironmentByKinds mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentByKinds(ctx context.Context, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentByKinds", ctx, environmentKindId, sourceKindId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentByKinds indicates an expected call of GetSchemaEnvironmentByKinds. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentByKinds(ctx, environmentKindId, sourceKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentByKinds", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentByKinds), ctx, environmentKindId, sourceKindId) -} - -// GetSchemaEnvironments mocks base method. -func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironments", ctx) - ret0, _ := ret[0].([]model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironments indicates an expected call of GetSchemaEnvironments. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironments(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironments", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironments), ctx) -} - // GetSchemaRelationshipFindingById mocks base method. func (m *MockDatabase) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsert_schema_environment.go b/cmd/api/src/database/upsert_schema_environment.go index df1401db18..180ec7702a 100644 --- a/cmd/api/src/database/upsert_schema_environment.go +++ b/cmd/api/src/database/upsert_schema_environment.go @@ -119,17 +119,17 @@ func (s *BloodhoundDB) validateAndTranslatePrincipalKinds(ctx context.Context, p // The unique constraint on (environment_kind_id, source_kind_id) of the Schema Environment table ensures no // duplicate pairs exist, enabling this upsert logic. func (s *BloodhoundDB) replaceSchemaEnvironment(ctx context.Context, graphSchema model.SchemaEnvironment) (int32, error) { - if existing, err := s.GetSchemaEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { + if existing, err := s.GetEnvironmentByKinds(ctx, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil && !errors.Is(err, ErrNotFound) { return 0, fmt.Errorf("error retrieving schema environment: %w", err) } else if !errors.Is(err, ErrNotFound) { // Environment exists - delete it first - if err := s.DeleteSchemaEnvironment(ctx, existing.ID); err != nil { + if err := s.DeleteEnvironment(ctx, existing.ID); err != nil { return 0, fmt.Errorf("error deleting schema environment %d: %w", existing.ID, err) } } // Create Environment - if created, err := s.CreateSchemaEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { + if created, err := s.CreateEnvironment(ctx, graphSchema.SchemaExtensionId, graphSchema.EnvironmentKindId, graphSchema.SourceKindId); err != nil { return 0, fmt.Errorf("error creating schema environment: %w", err) } else { return created.ID, nil diff --git a/cmd/api/src/database/upsert_schema_environment_integration_test.go b/cmd/api/src/database/upsert_schema_environment_integration_test.go index b811b10adb..41409b4b99 100644 --- a/cmd/api/src/database/upsert_schema_environment_integration_test.go +++ b/cmd/api/src/database/upsert_schema_environment_integration_test.go @@ -60,7 +60,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments)) @@ -114,7 +114,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments), "Should only have one environment (old one deleted)") @@ -150,7 +150,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert.NoError(t, err) assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments)) assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) @@ -179,7 +179,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -203,7 +203,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -227,7 +227,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -260,7 +260,7 @@ func TestBloodhoundDB_UpsertSchemaEnvironmentWithPrincipalKinds(t *testing.T) { assert: func(t *testing.T, db *database.BloodhoundDB) { t.Helper() - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 2, len(environments), "Should have two different environments") diff --git a/cmd/api/src/database/upsert_schema_extension_integration_test.go b/cmd/api/src/database/upsert_schema_extension_integration_test.go index 2f998cd054..e4fcae0287 100644 --- a/cmd/api/src/database/upsert_schema_extension_integration_test.go +++ b/cmd/api/src/database/upsert_schema_extension_integration_test.go @@ -61,7 +61,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { expectedPrincipalKindNames := []string{"Tag_Tier_Zero", "Tag_Owned"} - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments)) @@ -111,7 +111,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert: func(t *testing.T, db *database.BloodhoundDB) { t.Helper() - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 2, len(environments), "Should have two environments") @@ -159,7 +159,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { expectedPrincipalKindNames := []string{"Tag_Tier_Zero"} - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments), "Should only have one environment (old one replaced)") @@ -199,7 +199,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.NoError(t, err) assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(environments)) assert.Equal(t, int32(sourceKind.ID), environments[0].SourceKindId) @@ -240,7 +240,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { assert.NoError(t, err) assert.Equal(t, graph.StringKind("NewSource"), sourceKind.Name) - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 2, len(environments), "Should have two environments") @@ -274,7 +274,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -302,7 +302,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, @@ -335,7 +335,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify complete transaction rollback - no environments created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environments should exist after rollback") }, @@ -368,7 +368,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify complete transaction rollback - no environments created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environments should exist after rollback") }, @@ -396,7 +396,7 @@ func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { t.Helper() // Verify transaction rolled back - no environment created - environments, err := db.GetSchemaEnvironments(context.Background()) + environments, err := db.GetEnvironments(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(environments), "No environment should exist after rollback") }, diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index fd2eebb60e..cb87ebfd2b 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -119,13 +119,13 @@ func Entrypoint(ctx context.Context, cfg config.Configuration, connections boots startDelay := 0 * time.Second var ( - cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) - pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) - graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) - authorizer = auth.NewAuthorizer(connections.RDMS) - datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) - routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) - authenticator = api.NewAuthenticator(cfg, connections.RDMS, api.NewAuthExtensions(cfg, connections.RDMS)) + cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) + pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) + graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) + authorizer = auth.NewAuthorizer(connections.RDMS) + datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) + routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) + authenticator = api.NewAuthenticator(cfg, connections.RDMS, api.NewAuthExtensions(cfg, connections.RDMS)) openGraphSchemaService = opengraphschema.NewOpenGraphSchemaService(connections.RDMS) ) From 1de71fc7ff6e2554ca9e03f0a0f8f2f00c873b9b Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 15 Jan 2026 16:08:07 -0600 Subject: [PATCH 24/25] migration update --- cmd/api/src/database/migration/migrations/v8.6.0.sql | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cmd/api/src/database/migration/migrations/v8.6.0.sql b/cmd/api/src/database/migration/migrations/v8.6.0.sql index fbb20a81b0..219f6bb54d 100644 --- a/cmd/api/src/database/migration/migrations/v8.6.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.6.0.sql @@ -13,10 +13,10 @@ -- limitations under the License. -- -- SPDX-License-Identifier: Apache-2.0 --- Drop the column with the incorrect foreign key + ALTER TABLE IF EXISTS schema_environments - DROP COLUMN IF EXISTS source_kind_id; + DROP CONSTRAINT IF EXISTS schema_environments_source_kind_id_fkey; --- Add the column back with the correct foreign key reference ALTER TABLE IF EXISTS schema_environments - ADD COLUMN source_kind_id INTEGER NOT NULL REFERENCES source_kinds(id); + ADD CONSTRAINT schema_environments_source_kind_id_fkey + FOREIGN KEY (source_kind_id) REFERENCES source_kinds(id); From a2427b6ce803ded0144d3f00b0f686ffd4a95640 Mon Sep 17 00:00:00 2001 From: Katherine Powderly Date: Thu, 15 Jan 2026 16:08:13 -0600 Subject: [PATCH 25/25] migration update --- cmd/api/src/database/migration/migrations/v8.5.0.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/api/src/database/migration/migrations/v8.5.0.sql b/cmd/api/src/database/migration/migrations/v8.5.0.sql index 348e8abf84..b2b94b3e80 100644 --- a/cmd/api/src/database/migration/migrations/v8.5.0.sql +++ b/cmd/api/src/database/migration/migrations/v8.5.0.sql @@ -92,7 +92,7 @@ CREATE TABLE IF NOT EXISTS schema_environments ( id SERIAL, schema_extension_id INTEGER NOT NULL REFERENCES schema_extensions(id) ON DELETE CASCADE, environment_kind_id INTEGER NOT NULL REFERENCES kind(id), - source_kind_id INTEGER NOT NULL REFERENCES source_kinds(id), + source_kind_id INTEGER NOT NULL REFERENCES kind(id), PRIMARY KEY (id), created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp, updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT current_timestamp,