diff --git a/cmd/api/src/api/registration/registration.go b/cmd/api/src/api/registration/registration.go index 179298a059..4b0c02b57c 100644 --- a/cmd/api/src/api/registration/registration.go +++ b/cmd/api/src/api/registration/registration.go @@ -63,6 +63,7 @@ func RegisterFossRoutes( authenticator api.Authenticator, authorizer auth.Authorizer, ingestSchema upload.IngestSchema, + openGraphSchemaService v2.OpenGraphSchemaService, dogtagsService dogtags.Service, ) { router.With(func() mux.MiddlewareFunc { @@ -82,6 +83,7 @@ func RegisterFossRoutes( routerInst.PathPrefix("/ui", static.AssetHandler), ) - var resources = v2.NewResources(rdms, graphDB, cfg, apiCache, graphQuery, collectorManifests, authorizer, authenticator, ingestSchema, dogtagsService) + var resources = v2.NewResources(rdms, graphDB, cfg, apiCache, graphQuery, collectorManifests, authorizer, + authenticator, ingestSchema, openGraphSchemaService, dogtagsService) NewV2API(resources, routerInst) } diff --git a/cmd/api/src/api/registration/v2.go b/cmd/api/src/api/registration/v2.go index 16d524e135..056c744a54 100644 --- a/cmd/api/src/api/registration/v2.go +++ b/cmd/api/src/api/registration/v2.go @@ -367,5 +367,8 @@ func NewV2API(resources v2.Resources, routerInst *router.Router) { routerInst.POST("/api/v2/custom-nodes", resources.CreateCustomNodeKind).RequireAuth(), routerInst.PUT(fmt.Sprintf("/api/v2/custom-nodes/{%s}", v2.CustomNodeKindParameter), resources.UpdateCustomNodeKind).RequireAuth(), routerInst.DELETE(fmt.Sprintf("/api/v2/custom-nodes/{%s}", v2.CustomNodeKindParameter), resources.DeleteCustomNodeKind).RequireAuth(), + + // Open Graph Schema Ingest + routerInst.PUT("/api/v2/extensions", resources.OpenGraphSchemaIngest).RequireAuth(), ) } diff --git a/cmd/api/src/api/v2/mocks/graphschemaextensions.go b/cmd/api/src/api/v2/mocks/graphschemaextensions.go new file mode 100644 index 0000000000..f59b0b56e7 --- /dev/null +++ b/cmd/api/src/api/v2/mocks/graphschemaextensions.go @@ -0,0 +1,73 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/specterops/bloodhound/cmd/api/src/api/v2 (interfaces: OpenGraphSchemaService) +// +// Generated by this command: +// +// mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphschemaextensions.go -package=mocks . OpenGraphSchemaService +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + model "github.com/specterops/bloodhound/cmd/api/src/model" + gomock "go.uber.org/mock/gomock" +) + +// MockOpenGraphSchemaService is a mock of OpenGraphSchemaService interface. +type MockOpenGraphSchemaService struct { + ctrl *gomock.Controller + recorder *MockOpenGraphSchemaServiceMockRecorder + isgomock struct{} +} + +// MockOpenGraphSchemaServiceMockRecorder is the mock recorder for MockOpenGraphSchemaService. +type MockOpenGraphSchemaServiceMockRecorder struct { + mock *MockOpenGraphSchemaService +} + +// NewMockOpenGraphSchemaService creates a new mock instance. +func NewMockOpenGraphSchemaService(ctrl *gomock.Controller) *MockOpenGraphSchemaService { + mock := &MockOpenGraphSchemaService{ctrl: ctrl} + mock.recorder = &MockOpenGraphSchemaServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOpenGraphSchemaService) EXPECT() *MockOpenGraphSchemaServiceMockRecorder { + return m.recorder +} + +// UpsertOpenGraphExtension mocks base method. +func (m *MockOpenGraphSchemaService) UpsertOpenGraphExtension(ctx context.Context, graphSchema model.GraphSchema) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertOpenGraphExtension", ctx, graphSchema) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertOpenGraphExtension indicates an expected call of UpsertOpenGraphExtension. +func (mr *MockOpenGraphSchemaServiceMockRecorder) UpsertOpenGraphExtension(ctx, graphSchema any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertOpenGraphExtension", reflect.TypeOf((*MockOpenGraphSchemaService)(nil).UpsertOpenGraphExtension), ctx, graphSchema) +} diff --git a/cmd/api/src/api/v2/model.go b/cmd/api/src/api/v2/model.go index f9a08e13c5..a1b455db3d 100644 --- a/cmd/api/src/api/v2/model.go +++ b/cmd/api/src/api/v2/model.go @@ -117,6 +117,7 @@ type Resources struct { IngestSchema upload.IngestSchema FileService fs.Service DogTags dogtags.Service + openGraphSchemaService OpenGraphSchemaService } func NewResources( @@ -129,6 +130,7 @@ func NewResources( authorizer auth.Authorizer, authenticator api.Authenticator, ingestSchema upload.IngestSchema, + openGraphSchemaService OpenGraphSchemaService, dogtagsService dogtags.Service, ) Resources { return Resources{ @@ -144,6 +146,7 @@ func NewResources( Authenticator: authenticator, IngestSchema: ingestSchema, FileService: &fs.Client{}, + openGraphSchemaService: openGraphSchemaService, DogTags: dogtagsService, } } diff --git a/cmd/api/src/api/v2/opengraphschema.go b/cmd/api/src/api/v2/opengraphschema.go new file mode 100644 index 0000000000..6eb2f15d15 --- /dev/null +++ b/cmd/api/src/api/v2/opengraphschema.go @@ -0,0 +1,117 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package v2 + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/specterops/bloodhound/cmd/api/src/api" + "github.com/specterops/bloodhound/cmd/api/src/auth" + ctx2 "github.com/specterops/bloodhound/cmd/api/src/ctx" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/bloodhound/cmd/api/src/model/appcfg" + "github.com/specterops/bloodhound/cmd/api/src/model/ingest" + bhUtils "github.com/specterops/bloodhound/cmd/api/src/utils" + "github.com/specterops/bloodhound/packages/go/headers" + "github.com/specterops/bloodhound/packages/go/mediatypes" +) + +//go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphschemaextensions.go -package=mocks . OpenGraphSchemaService + +type OpenGraphSchemaService interface { + UpsertOpenGraphExtension(ctx context.Context, graphSchema model.GraphSchema) (bool, error) +} + +func (s Resources) OpenGraphSchemaIngest(response http.ResponseWriter, request *http.Request) { + var ( + ctx = request.Context() + err error + flag appcfg.FeatureFlag + + updated bool + + extractExtensionData func(file io.Reader) (model.GraphSchema, error) + graphSchemaPayload model.GraphSchema + ) + + // TODO: what to return if feature flag is not enabled + if flag, err = s.DB.GetFlagByKey(ctx, appcfg.FeatureOpenGraphExtensionManagement); err != nil { + api.HandleDatabaseError(request, response, err) + } else if !flag.Enabled { + response.WriteHeader(http.StatusNotFound) + } else if user, isUser := auth.GetUserFromAuthCtx(ctx2.FromRequest(request).AuthCtx); !isUser { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "No associated "+ + "user found", request), response) + } else if !user.Roles.Has(model.Role{Name: auth.RoleAdministrator}) { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusForbidden, "user does not "+ + "have sufficient permissions to create or update an extension", request), response) + } else if request.Body == nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, "open graph "+ + "extension payload cannot be empty", request), response) + } else { + request.Body = http.MaxBytesReader(response, request.Body, api.DefaultAPIPayloadReadLimitBytes) + defer request.Body.Close() + switch { + case bhUtils.HeaderMatches(request.Header, headers.ContentType.String(), mediatypes.ApplicationJson.String()): + extractExtensionData = extractExtensionDataFromJSON + case bhUtils.HeaderMatches(request.Header, headers.ContentType.String(), ingest.AllowedZipFileUploadTypes...): + fallthrough + // extractExtensionData = extractExtensionDataFromZipFile - will be needed for a future + default: + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusUnsupportedMediaType, + fmt.Sprintf("%s; Content type must be application/json", + fmt.Errorf("invalid content-type: %s", request.Header[headers.ContentType.String()])), request), response) + return + } + + if graphSchemaPayload, err = extractExtensionData(request.Body); err != nil { + api.WriteErrorResponse(request.Context(), api.BuildErrorResponse(http.StatusBadRequest, err.Error(), request), response) + return + } + + if updated, err = s.openGraphSchemaService.UpsertOpenGraphExtension(ctx, graphSchemaPayload); err != nil { + switch { + // TODO: more error types (ex: validation) + default: + api.WriteErrorResponse(ctx, api.BuildErrorResponse(http.StatusInternalServerError, fmt.Sprintf("unable to update graph schema: %v", err), request), response) + return + } + } else if updated { + response.WriteHeader(http.StatusOK) + } else { + response.WriteHeader(http.StatusCreated) + } + } +} + +// extractExtensionDataFromJSON - extracts a model.GraphSchema from the incoming payload. Will return an error if there +// are any extra fields or if the decoder fails to decode the payload. +func extractExtensionDataFromJSON(payload io.Reader) (model.GraphSchema, error) { + var ( + err error + decoder = json.NewDecoder(payload) + graphSchema model.GraphSchema + ) + decoder.DisallowUnknownFields() + if err = decoder.Decode(&graphSchema); err != nil { + return graphSchema, err + } + return graphSchema, nil +} diff --git a/cmd/api/src/database/db.go b/cmd/api/src/database/db.go index 7f776d595f..d1d0fc168e 100644 --- a/cmd/api/src/database/db.go +++ b/cmd/api/src/database/db.go @@ -187,8 +187,7 @@ type Database interface { // Environment Targeted Access Control EnvironmentTargetedAccessControlData - // OpenGraph Schema - OpenGraphSchema + GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) } type BloodhoundDB struct { diff --git a/cmd/api/src/database/genericmapdiff.go b/cmd/api/src/database/genericmapdiff.go new file mode 100644 index 0000000000..f6d8e2ce2a --- /dev/null +++ b/cmd/api/src/database/genericmapdiff.go @@ -0,0 +1,148 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package database + +import ( + "context" +) + +// TODO: Where should these live?? + +// MapDiffActions - Actions required to sync two maps +// +// 1. ItemsToUpdate (Updates): Represents the items present in both the SourceMap and DestinationMap, +// (SourceMap ∩ DestinationMap). This is the overlapping portion of both circles of a Venn diagram. +// +// 2. ItemsToDelete (Deletes): Represents the items present exclusively in the DestinationMap +// that are absent in the SourceMap, (DestinationMap - SourceMap). This represents items in the right +// circle of a Venn diagram that are *not* in the intersection. +// +// 3. ItemsToInsert (Inserts): Represents the items present exclusively in the SourceMap that are absent +// in the DestinationMap, (SourceMap - DestinationMap). This represents items in the left circle of a Venn +// diagram that are *not* in the intersection. +type MapDiffActions[V any] struct { + ItemsToDelete []V + ItemsToUpdate []V + ItemsToInsert []V +} + +// GenerateMapDiffActions compares two maps (`SourceMap` and `DestinationMap`) using their +// keys (`K`) to compute the required synchronization actions based on set theory. +// +// The function generates three lists of *values* (`V`) representing the operations needed to make +// `DestinationMap` an exact replica of `SourceMap`: +// +// 1. ItemsToUpdate (Updates): Represents the items present in both the SourceMap and DestinationMap, +// (SourceMap ∩ DestinationMap). This is the overlapping portion of both circles of a Venn diagram. +// +// 2. ItemsToDelete (Deletes): Represents the items present exclusively in the DestinationMap that are +// absent in the SourceMap, (DestinationMap - SourceMap). This represents items in the right circle +// of a Venn diagram that are *not* in the intersection. +// +// 3. ItemsToInsert (Inserts): Represents the items present exclusively in the SourceMap that are absent +// in the DestinationMap, (SourceMap - DestinationMap). This represents items in the left circle of a Venn +// diagram that are *not* in the intersection. +// +// An optional onMatch function can be provided to perform struct updates. +func GenerateMapDiffActions[K comparable, V any](src, dst map[K]V, onMatch func(*V, *V)) MapDiffActions[V] { + + var ( + actions = MapDiffActions[V]{ + ItemsToDelete: make([]V, 0), + ItemsToUpdate: make([]V, 0), + ItemsToInsert: make([]V, 0), + } + ) + + if src == nil { + src = make(map[K]V) + } else if dst == nil { + dst = make(map[K]V) + } + + // 1. Identify keys to delete from the dst + // These will be keys that exist in dst but not in src + for k, v := range dst { + if _, exists := src[k]; !exists { + actions.ItemsToDelete = append(actions.ItemsToDelete, v) + } + } + + // 2. Identify keys to upsert (all keys in src) + for k, v := range src { + + // Retrieve the existing value from dst map, if it exists + dstVal, existsInDst := dst[k] + + // Pass the key, the src value pointer, and the dst value pointer + if existsInDst { + if onMatch != nil { + onMatch(&v, &dstVal) + } + actions.ItemsToUpdate = append(actions.ItemsToUpdate, v) + } else { + // If it's a new key, pass nil for the dst value pointer + actions.ItemsToInsert = append(actions.ItemsToInsert, v) + } + } + + return actions +} + +// HandleMapDiffAction iterates through a set of MapDiffActions and executes the +// provided callback functions for items marked for deletion, update, or insertion. +// +// The function processes actions in the following order: +// 1. Deletions (using deleteFunc) +// 2. Updates (using updateFunc) +// 3. Insertions (using insertFunc) +// +// It returns the first error encountered during any of the operations, halting +// further processing. If all operations succeed, it returns an updated slice of +// V, consisting of items returned from the update and insert functions. +func HandleMapDiffAction[V any](ctx context.Context, actions MapDiffActions[V], deleteFunc func(context.Context, V) error, updateFunc, insertFunc func(context.Context, V) (V, error)) ([]V, error) { + var ( + err error + updatedItem V + updatedItems = make([]V, 0) + ) + if len(actions.ItemsToDelete) > 0 { + for _, itemToDelete := range actions.ItemsToDelete { + if err = deleteFunc(ctx, itemToDelete); err != nil { + return updatedItems, err + } + } + } + + if len(actions.ItemsToUpdate) > 0 { + for _, itemToUpdate := range actions.ItemsToUpdate { + if updatedItem, err = updateFunc(ctx, itemToUpdate); err != nil { + return updatedItems, err + } + updatedItems = append(updatedItems, updatedItem) + } + } + if len(actions.ItemsToInsert) > 0 { + for _, itemToInsert := range actions.ItemsToInsert { + if updatedItem, err = insertFunc(ctx, itemToInsert); err != nil { + return updatedItems, err + } + updatedItems = append(updatedItems, updatedItem) + } + } + return updatedItems, nil +} diff --git a/cmd/api/src/database/genericmapdiff_test.go b/cmd/api/src/database/genericmapdiff_test.go new file mode 100644 index 0000000000..e1231ec966 --- /dev/null +++ b/cmd/api/src/database/genericmapdiff_test.go @@ -0,0 +1,231 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package database + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +type testMapDiffStruct struct { + foo string + bar int + z string +} + +func swapZTestFunc(src, dst *testMapDiffStruct) { + src.z = dst.z +} + +func TestDiffMapsToSyncActions(t *testing.T) { + var ( + testStruct1 = testMapDiffStruct{foo: "foo_1", bar: 1, z: "z_1"} + testStruct2 = testMapDiffStruct{foo: "foo_2", bar: 2, z: "z_2"} + updatedTestStruct = testMapDiffStruct{foo: "foo_1", bar: 4, z: "z_4"} + testStruct3 = testMapDiffStruct{foo: "foo_3", bar: 3, z: "z_3"} + ) + + type args[K comparable, V any] struct { + dst map[K]V + src map[K]V + onUpsert func(*V, *V) + } + type testCase[K comparable, V any] struct { + name string + args args[K, V] + want MapDiffActions[V] + } + tests := []testCase[string, testMapDiffStruct]{ + { + name: "empty src", + args: args[string, testMapDiffStruct]{ + dst: map[string]testMapDiffStruct{ + testStruct1.foo: testStruct1, + testStruct2.foo: testStruct2, + }, + src: map[string]testMapDiffStruct{}, + onUpsert: swapZTestFunc, + }, + want: MapDiffActions[testMapDiffStruct]{ + ItemsToDelete: []testMapDiffStruct{testStruct1, testStruct2}, + ItemsToUpdate: []testMapDiffStruct{}, + ItemsToInsert: []testMapDiffStruct{}, + }, + }, + { + name: "empty dst", + args: args[string, testMapDiffStruct]{ + dst: map[string]testMapDiffStruct{}, + src: map[string]testMapDiffStruct{ + updatedTestStruct.foo: updatedTestStruct, + testStruct3.foo: testStruct3, + }, + onUpsert: swapZTestFunc, + }, + want: MapDiffActions[testMapDiffStruct]{ + ItemsToDelete: []testMapDiffStruct{}, + ItemsToUpdate: []testMapDiffStruct{}, + ItemsToInsert: []testMapDiffStruct{updatedTestStruct, testStruct3}, + }, + }, + { + name: "success - convertGraphSchemaNodeKinds", + args: args[string, testMapDiffStruct]{ + dst: map[string]testMapDiffStruct{ + testStruct1.foo: testStruct1, + testStruct2.foo: testStruct2, + }, + src: map[string]testMapDiffStruct{ + updatedTestStruct.foo: updatedTestStruct, + testStruct3.foo: testStruct3, + }, + onUpsert: swapZTestFunc, + }, + want: MapDiffActions[testMapDiffStruct]{ + ItemsToDelete: []testMapDiffStruct{testStruct2}, + ItemsToUpdate: []testMapDiffStruct{{foo: "foo_1", bar: 4, z: "z_1"}}, + ItemsToInsert: []testMapDiffStruct{testStruct3}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GenerateMapDiffActions(tt.args.src, tt.args.dst, tt.args.onUpsert) + compareTestMapDiffStructs(t, got.ItemsToInsert, tt.want.ItemsToInsert) + compareTestMapDiffStructs(t, got.ItemsToUpdate, tt.want.ItemsToUpdate) + compareTestMapDiffStructs(t, got.ItemsToDelete, tt.want.ItemsToDelete) + }) + } +} + +func compareTestMapDiffStructs(t *testing.T, got, want []testMapDiffStruct) { + t.Helper() + require.Equalf(t, len(want), len(got), "length mismatch of GraphSchemaEdgeKinds") + require.ElementsMatchf(t, want, got, "GraphSchemaEdgeKinds mismatch") +} + +func TestHandleMapDiffAction(t *testing.T) { + var ( + testStruct1 = testMapDiffStruct{foo: "foo_1", bar: 1, z: "z_1"} + testStruct2 = testMapDiffStruct{foo: "foo_2", bar: 2, z: "z_2"} + testStruct3 = testMapDiffStruct{foo: "foo_3", bar: 3, z: "z_3"} + ) + + type args[V any] struct { + ctx context.Context + actions MapDiffActions[testMapDiffStruct] + deleteFunc func(context.Context, V) error + updateFunc func(context.Context, V) (V, error) + insertFunc func(context.Context, V) (V, error) + } + type testCase[V any] struct { + name string + args args[V] + wantErr error + want []testMapDiffStruct + } + tests := []testCase[testMapDiffStruct]{ + { + name: "fail - error during delete func", + args: args[testMapDiffStruct]{ + ctx: context.Background(), + actions: MapDiffActions[testMapDiffStruct]{ + ItemsToDelete: []testMapDiffStruct{testStruct1}, + ItemsToUpdate: []testMapDiffStruct{testStruct2}, + ItemsToInsert: []testMapDiffStruct{testStruct3}, + }, + deleteFunc: testDeleteMapDiffStructFuncError, + updateFunc: testMapDiffStructFunc, + insertFunc: testMapDiffStructFunc, + }, + wantErr: fmt.Errorf("test map diff func error")}, + { + name: "fail - error during update func", + args: args[testMapDiffStruct]{ + ctx: context.Background(), + actions: MapDiffActions[testMapDiffStruct]{ + ItemsToDelete: []testMapDiffStruct{testStruct1}, + ItemsToUpdate: []testMapDiffStruct{testStruct2}, + ItemsToInsert: []testMapDiffStruct{testStruct3}, + }, + deleteFunc: testDeleteMapDiffStructFunc, + updateFunc: testMapDiffStructFuncError, + insertFunc: testMapDiffStructFunc, + }, + wantErr: fmt.Errorf("test map diff func error")}, + { + name: "fail - error during insert func", + args: args[testMapDiffStruct]{ + ctx: context.Background(), + actions: MapDiffActions[testMapDiffStruct]{ + ItemsToDelete: []testMapDiffStruct{testStruct1}, + ItemsToUpdate: []testMapDiffStruct{testStruct2}, + ItemsToInsert: []testMapDiffStruct{testStruct3}, + }, + deleteFunc: testDeleteMapDiffStructFunc, + updateFunc: testMapDiffStructFunc, + insertFunc: testMapDiffStructFuncError, + }, + wantErr: fmt.Errorf("test map diff func error"), + }, + { + name: "success", + args: args[testMapDiffStruct]{ + ctx: context.Background(), + actions: MapDiffActions[testMapDiffStruct]{ + ItemsToDelete: []testMapDiffStruct{testStruct1}, + ItemsToUpdate: []testMapDiffStruct{testStruct2}, + ItemsToInsert: []testMapDiffStruct{testStruct3}, + }, + deleteFunc: testDeleteMapDiffStructFunc, + updateFunc: testMapDiffStructFunc, + insertFunc: testMapDiffStructFunc, + }, + wantErr: nil, + want: []testMapDiffStruct{testStruct2, testStruct3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if updatedItems, err := HandleMapDiffAction(tt.args.ctx, tt.args.actions, tt.args.deleteFunc, + tt.args.updateFunc, tt.args.insertFunc); tt.wantErr != nil { + require.EqualErrorf(t, err, tt.wantErr.Error(), "HandleMapDiffAction(%v, %v)", tt.args.ctx, tt.args.actions) + } else { + require.NoError(t, err) + require.Equalf(t, updatedItems, tt.want, "HandleMapDiffAction(%v, %v)", tt.args.ctx, tt.args.actions) + } + }) + } +} + +func testDeleteMapDiffStructFunc(_ context.Context, t testMapDiffStruct) error { + return nil +} +func testDeleteMapDiffStructFuncError(_ context.Context, t testMapDiffStruct) error { + return fmt.Errorf("test map diff func error") +} + +func testMapDiffStructFunc(_ context.Context, t testMapDiffStruct) (testMapDiffStruct, error) { + return t, nil +} + +func testMapDiffStructFuncError(_ context.Context, t testMapDiffStruct) (testMapDiffStruct, error) { + return t, fmt.Errorf("test map diff func error") +} diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 03d5e53014..68ac88d81c 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -25,54 +25,7 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/model" ) -type OpenGraphSchema interface { - CreateGraphSchemaExtension(ctx context.Context, name string, displayName string, version string) (model.GraphSchemaExtension, error) - GetGraphSchemaExtensionById(ctx context.Context, extensionId int32) (model.GraphSchemaExtension, error) - GetGraphSchemaExtensions(ctx context.Context, extensionFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaExtensions, int, error) - UpdateGraphSchemaExtension(ctx context.Context, extension model.GraphSchemaExtension) (model.GraphSchemaExtension, error) - DeleteGraphSchemaExtension(ctx context.Context, extensionId int32) error - - CreateGraphSchemaNodeKind(ctx context.Context, name string, extensionId int32, displayName string, description string, isDisplayKind bool, icon, iconColor string) (model.GraphSchemaNodeKind, error) - GetGraphSchemaNodeKindById(ctx context.Context, schemaNodeKindID int32) (model.GraphSchemaNodeKind, error) - GetGraphSchemaNodeKinds(ctx context.Context, nodeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaNodeKinds, int, error) - UpdateGraphSchemaNodeKind(ctx context.Context, schemaNodeKind model.GraphSchemaNodeKind) (model.GraphSchemaNodeKind, error) - DeleteGraphSchemaNodeKind(ctx context.Context, schemaNodeKindId int32) error - - CreateGraphSchemaProperty(ctx context.Context, extensionId int32, name string, displayName string, dataType string, description string) (model.GraphSchemaProperty, error) - GetGraphSchemaPropertyById(ctx context.Context, extensionPropertyId int32) (model.GraphSchemaProperty, error) - GetGraphSchemaProperties(ctx context.Context, filters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaProperties, int, error) - UpdateGraphSchemaProperty(ctx context.Context, property model.GraphSchemaProperty) (model.GraphSchemaProperty, error) - DeleteGraphSchemaProperty(ctx context.Context, propertyID int32) error - - CreateGraphSchemaEdgeKind(ctx context.Context, name string, schemaExtensionId int32, description string, isTraversable bool) (model.GraphSchemaEdgeKind, error) - GetGraphSchemaEdgeKinds(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKinds, int, error) - GetGraphSchemaEdgeKindById(ctx context.Context, schemaEdgeKindId int32) (model.GraphSchemaEdgeKind, error) - UpdateGraphSchemaEdgeKind(ctx context.Context, schemaEdgeKind model.GraphSchemaEdgeKind) (model.GraphSchemaEdgeKind, error) - DeleteGraphSchemaEdgeKind(ctx context.Context, schemaEdgeKindId int32) error - - GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) - - CreateSchemaEnvironment(ctx context.Context, extensionId int32, environmentKindId int32, sourceKindId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) - GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) - DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error - - CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) - GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) - DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error - - CreateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) - GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) - UpdateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) - DeleteRemediation(ctx context.Context, findingId int32) error - CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) - GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) - DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId int32, principalKind int32) error -} - -const ( - DuplicateKeyValueErrorString = "duplicate key value violates unique constraint" -) +const DuplicateKeyValueErrorString = "duplicate key value violates unique constraint" type FilterAndPagination struct { Filter sqlFilter diff --git a/cmd/api/src/database/migration/migrations/v8.6.0.sql b/cmd/api/src/database/migration/migrations/v8.6.0.sql new file mode 100644 index 0000000000..eddbe1e8dc --- /dev/null +++ b/cmd/api/src/database/migration/migrations/v8.6.0.sql @@ -0,0 +1,26 @@ +-- Copyright 2026 Specter Ops, Inc. +-- +-- Licensed under the Apache License, Version 2.0 +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. +-- +-- SPDX-License-Identifier: Apache-2.0 + +-- OpenGraph Schema Extension Management feature flag +INSERT INTO feature_flags (created_at, updated_at, key, name, description, enabled, user_updatable) +VALUES (current_timestamp, + current_timestamp, + 'opengraph_extension_management', + 'OpenGraph Schema Extension Management', + 'Enable OpenGraph Schema Extension Management', + false, + false) +ON CONFLICT DO NOTHING; diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index cf16fece30..d49f498b3f 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -341,66 +341,6 @@ func (mr *MockDatabaseMockRecorder) CreateCustomNodeKinds(ctx, customNodeKind an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCustomNodeKinds", reflect.TypeOf((*MockDatabase)(nil).CreateCustomNodeKinds), ctx, customNodeKind) } -// CreateGraphSchemaEdgeKind mocks base method. -func (m *MockDatabase) CreateGraphSchemaEdgeKind(ctx context.Context, name string, schemaExtensionId int32, description string, isTraversable bool) (model.GraphSchemaEdgeKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateGraphSchemaEdgeKind", ctx, name, schemaExtensionId, description, isTraversable) - ret0, _ := ret[0].(model.GraphSchemaEdgeKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateGraphSchemaEdgeKind indicates an expected call of CreateGraphSchemaEdgeKind. -func (mr *MockDatabaseMockRecorder) CreateGraphSchemaEdgeKind(ctx, name, schemaExtensionId, description, isTraversable any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGraphSchemaEdgeKind", reflect.TypeOf((*MockDatabase)(nil).CreateGraphSchemaEdgeKind), ctx, name, schemaExtensionId, description, isTraversable) -} - -// CreateGraphSchemaExtension mocks base method. -func (m *MockDatabase) CreateGraphSchemaExtension(ctx context.Context, name, displayName, version string) (model.GraphSchemaExtension, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateGraphSchemaExtension", ctx, name, displayName, version) - ret0, _ := ret[0].(model.GraphSchemaExtension) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateGraphSchemaExtension indicates an expected call of CreateGraphSchemaExtension. -func (mr *MockDatabaseMockRecorder) CreateGraphSchemaExtension(ctx, name, displayName, version any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGraphSchemaExtension", reflect.TypeOf((*MockDatabase)(nil).CreateGraphSchemaExtension), ctx, name, displayName, version) -} - -// CreateGraphSchemaNodeKind mocks base method. -func (m *MockDatabase) CreateGraphSchemaNodeKind(ctx context.Context, name string, extensionId int32, displayName, description string, isDisplayKind bool, icon, iconColor string) (model.GraphSchemaNodeKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateGraphSchemaNodeKind", ctx, name, extensionId, displayName, description, isDisplayKind, icon, iconColor) - ret0, _ := ret[0].(model.GraphSchemaNodeKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateGraphSchemaNodeKind indicates an expected call of CreateGraphSchemaNodeKind. -func (mr *MockDatabaseMockRecorder) CreateGraphSchemaNodeKind(ctx, name, extensionId, displayName, description, isDisplayKind, icon, iconColor any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGraphSchemaNodeKind", reflect.TypeOf((*MockDatabase)(nil).CreateGraphSchemaNodeKind), ctx, name, extensionId, displayName, description, isDisplayKind, icon, iconColor) -} - -// CreateGraphSchemaProperty mocks base method. -func (m *MockDatabase) CreateGraphSchemaProperty(ctx context.Context, extensionId int32, name, displayName, dataType, description string) (model.GraphSchemaProperty, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateGraphSchemaProperty", ctx, extensionId, name, displayName, dataType, description) - ret0, _ := ret[0].(model.GraphSchemaProperty) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateGraphSchemaProperty indicates an expected call of CreateGraphSchemaProperty. -func (mr *MockDatabaseMockRecorder) CreateGraphSchemaProperty(ctx, extensionId, name, displayName, dataType, description any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGraphSchemaProperty", reflect.TypeOf((*MockDatabase)(nil).CreateGraphSchemaProperty), ctx, extensionId, name, displayName, dataType, description) -} - // CreateIngestJob mocks base method. func (m *MockDatabase) CreateIngestJob(ctx context.Context, job model.IngestJob) (model.IngestJob, error) { m.ctrl.T.Helper() @@ -461,21 +401,6 @@ func (mr *MockDatabaseMockRecorder) CreateOIDCProvider(ctx, name, issuer, client return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOIDCProvider", reflect.TypeOf((*MockDatabase)(nil).CreateOIDCProvider), ctx, name, issuer, clientID, config) } -// CreateRemediation mocks base method. -func (m *MockDatabase) CreateRemediation(ctx context.Context, findingId int32, shortDescription, longDescription, shortRemediation, longRemediation string) (model.Remediation, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateRemediation", ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation) - ret0, _ := ret[0].(model.Remediation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateRemediation indicates an expected call of CreateRemediation. -func (mr *MockDatabaseMockRecorder) CreateRemediation(ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateRemediation", reflect.TypeOf((*MockDatabase)(nil).CreateRemediation), ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation) -} - // CreateSAMLIdentityProvider mocks base method. func (m *MockDatabase) CreateSAMLIdentityProvider(ctx context.Context, samlProvider model.SAMLProvider, config model.SSOProviderConfig) (model.SAMLProvider, error) { m.ctrl.T.Helper() @@ -570,51 +495,6 @@ func (mr *MockDatabaseMockRecorder) CreateSavedQueryPermissionsToUsers(ctx, quer return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSavedQueryPermissionsToUsers", reflect.TypeOf((*MockDatabase)(nil).CreateSavedQueryPermissionsToUsers), varargs...) } -// CreateSchemaEnvironment mocks base method. -func (m *MockDatabase) CreateSchemaEnvironment(ctx context.Context, extensionId, environmentKindId, sourceKindId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironment", ctx, extensionId, environmentKindId, sourceKindId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironment indicates an expected call of CreateSchemaEnvironment. -func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironment(ctx, extensionId, environmentKindId, sourceKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironment), ctx, extensionId, environmentKindId, sourceKindId) -} - -// CreateSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockDatabase) CreateSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) (model.SchemaEnvironmentPrincipalKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaEnvironmentPrincipalKind indicates an expected call of CreateSchemaEnvironmentPrincipalKind. -func (mr *MockDatabaseMockRecorder) CreateSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - -// CreateSchemaRelationshipFinding mocks base method. -func (m *MockDatabase) CreateSchemaRelationshipFinding(ctx context.Context, extensionId, relationshipKindId, environmentId int32, name, displayName string) (model.SchemaRelationshipFinding, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSchemaRelationshipFinding", ctx, extensionId, relationshipKindId, environmentId, name, displayName) - ret0, _ := ret[0].(model.SchemaRelationshipFinding) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateSchemaRelationshipFinding indicates an expected call of CreateSchemaRelationshipFinding. -func (mr *MockDatabaseMockRecorder) CreateSchemaRelationshipFinding(ctx, extensionId, relationshipKindId, environmentId, name, displayName any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSchemaRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).CreateSchemaRelationshipFinding), ctx, extensionId, relationshipKindId, environmentId, name, displayName) -} - // CreateUser mocks base method. func (m *MockDatabase) CreateUser(ctx context.Context, user model.User) (model.User, error) { m.ctrl.T.Helper() @@ -856,62 +736,6 @@ func (mr *MockDatabaseMockRecorder) DeleteEnvironmentTargetedAccessControlForUse return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteEnvironmentTargetedAccessControlForUser", reflect.TypeOf((*MockDatabase)(nil).DeleteEnvironmentTargetedAccessControlForUser), ctx, user) } -// DeleteGraphSchemaEdgeKind mocks base method. -func (m *MockDatabase) DeleteGraphSchemaEdgeKind(ctx context.Context, schemaEdgeKindId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteGraphSchemaEdgeKind", ctx, schemaEdgeKindId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteGraphSchemaEdgeKind indicates an expected call of DeleteGraphSchemaEdgeKind. -func (mr *MockDatabaseMockRecorder) DeleteGraphSchemaEdgeKind(ctx, schemaEdgeKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGraphSchemaEdgeKind", reflect.TypeOf((*MockDatabase)(nil).DeleteGraphSchemaEdgeKind), ctx, schemaEdgeKindId) -} - -// DeleteGraphSchemaExtension mocks base method. -func (m *MockDatabase) DeleteGraphSchemaExtension(ctx context.Context, extensionId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteGraphSchemaExtension", ctx, extensionId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteGraphSchemaExtension indicates an expected call of DeleteGraphSchemaExtension. -func (mr *MockDatabaseMockRecorder) DeleteGraphSchemaExtension(ctx, extensionId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGraphSchemaExtension", reflect.TypeOf((*MockDatabase)(nil).DeleteGraphSchemaExtension), ctx, extensionId) -} - -// DeleteGraphSchemaNodeKind mocks base method. -func (m *MockDatabase) DeleteGraphSchemaNodeKind(ctx context.Context, schemaNodeKindId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteGraphSchemaNodeKind", ctx, schemaNodeKindId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteGraphSchemaNodeKind indicates an expected call of DeleteGraphSchemaNodeKind. -func (mr *MockDatabaseMockRecorder) DeleteGraphSchemaNodeKind(ctx, schemaNodeKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGraphSchemaNodeKind", reflect.TypeOf((*MockDatabase)(nil).DeleteGraphSchemaNodeKind), ctx, schemaNodeKindId) -} - -// DeleteGraphSchemaProperty mocks base method. -func (m *MockDatabase) DeleteGraphSchemaProperty(ctx context.Context, propertyID int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteGraphSchemaProperty", ctx, propertyID) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteGraphSchemaProperty indicates an expected call of DeleteGraphSchemaProperty. -func (mr *MockDatabaseMockRecorder) DeleteGraphSchemaProperty(ctx, propertyID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGraphSchemaProperty", reflect.TypeOf((*MockDatabase)(nil).DeleteGraphSchemaProperty), ctx, propertyID) -} - // DeleteIngestTask mocks base method. func (m *MockDatabase) DeleteIngestTask(ctx context.Context, ingestTask model.IngestTask) error { m.ctrl.T.Helper() @@ -926,20 +750,6 @@ func (mr *MockDatabaseMockRecorder) DeleteIngestTask(ctx, ingestTask any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteIngestTask", reflect.TypeOf((*MockDatabase)(nil).DeleteIngestTask), ctx, ingestTask) } -// DeleteRemediation mocks base method. -func (m *MockDatabase) DeleteRemediation(ctx context.Context, findingId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteRemediation", ctx, findingId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteRemediation indicates an expected call of DeleteRemediation. -func (mr *MockDatabaseMockRecorder) DeleteRemediation(ctx, findingId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteRemediation", reflect.TypeOf((*MockDatabase)(nil).DeleteRemediation), ctx, findingId) -} - // DeleteSSOProvider mocks base method. func (m *MockDatabase) DeleteSSOProvider(ctx context.Context, id int) error { m.ctrl.T.Helper() @@ -987,48 +797,6 @@ func (mr *MockDatabaseMockRecorder) DeleteSavedQueryPermissionsForUsers(ctx, que return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSavedQueryPermissionsForUsers", reflect.TypeOf((*MockDatabase)(nil).DeleteSavedQueryPermissionsForUsers), varargs...) } -// DeleteSchemaEnvironment mocks base method. -func (m *MockDatabase) DeleteSchemaEnvironment(ctx context.Context, environmentId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironment", ctx, environmentId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironment indicates an expected call of DeleteSchemaEnvironment. -func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironment(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironment", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironment), ctx, environmentId) -} - -// DeleteSchemaEnvironmentPrincipalKind mocks base method. -func (m *MockDatabase) DeleteSchemaEnvironmentPrincipalKind(ctx context.Context, environmentId, principalKind int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaEnvironmentPrincipalKind", ctx, environmentId, principalKind) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaEnvironmentPrincipalKind indicates an expected call of DeleteSchemaEnvironmentPrincipalKind. -func (mr *MockDatabaseMockRecorder) DeleteSchemaEnvironmentPrincipalKind(ctx, environmentId, principalKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaEnvironmentPrincipalKind", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaEnvironmentPrincipalKind), ctx, environmentId, principalKind) -} - -// DeleteSchemaRelationshipFinding mocks base method. -func (m *MockDatabase) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteSchemaRelationshipFinding", ctx, findingId) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteSchemaRelationshipFinding indicates an expected call of DeleteSchemaRelationshipFinding. -func (mr *MockDatabaseMockRecorder) DeleteSchemaRelationshipFinding(ctx, findingId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteSchemaRelationshipFinding", reflect.TypeOf((*MockDatabase)(nil).DeleteSchemaRelationshipFinding), ctx, findingId) -} - // DeleteSelectorNodesByNodeId mocks base method. func (m *MockDatabase) DeleteSelectorNodesByNodeId(ctx context.Context, selectorId int, nodeId graph.ID) error { m.ctrl.T.Helper() @@ -1747,37 +1515,6 @@ func (mr *MockDatabaseMockRecorder) GetFlagByKey(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFlagByKey", reflect.TypeOf((*MockDatabase)(nil).GetFlagByKey), arg0, arg1) } -// GetGraphSchemaEdgeKindById mocks base method. -func (m *MockDatabase) GetGraphSchemaEdgeKindById(ctx context.Context, schemaEdgeKindId int32) (model.GraphSchemaEdgeKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGraphSchemaEdgeKindById", ctx, schemaEdgeKindId) - ret0, _ := ret[0].(model.GraphSchemaEdgeKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetGraphSchemaEdgeKindById indicates an expected call of GetGraphSchemaEdgeKindById. -func (mr *MockDatabaseMockRecorder) GetGraphSchemaEdgeKindById(ctx, schemaEdgeKindId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaEdgeKindById", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaEdgeKindById), ctx, schemaEdgeKindId) -} - -// GetGraphSchemaEdgeKinds mocks base method. -func (m *MockDatabase) GetGraphSchemaEdgeKinds(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKinds, int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGraphSchemaEdgeKinds", ctx, edgeKindFilters, sort, skip, limit) - ret0, _ := ret[0].(model.GraphSchemaEdgeKinds) - ret1, _ := ret[1].(int) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// GetGraphSchemaEdgeKinds indicates an expected call of GetGraphSchemaEdgeKinds. -func (mr *MockDatabaseMockRecorder) GetGraphSchemaEdgeKinds(ctx, edgeKindFilters, sort, skip, limit any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaEdgeKinds", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaEdgeKinds), ctx, edgeKindFilters, sort, skip, limit) -} - // GetGraphSchemaEdgeKindsWithSchemaName mocks base method. func (m *MockDatabase) GetGraphSchemaEdgeKindsWithSchemaName(ctx context.Context, edgeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaEdgeKindsWithNamedSchema, int, error) { m.ctrl.T.Helper() @@ -1794,99 +1531,6 @@ func (mr *MockDatabaseMockRecorder) GetGraphSchemaEdgeKindsWithSchemaName(ctx, e return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaEdgeKindsWithSchemaName", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaEdgeKindsWithSchemaName), ctx, edgeKindFilters, sort, skip, limit) } -// GetGraphSchemaExtensionById mocks base method. -func (m *MockDatabase) GetGraphSchemaExtensionById(ctx context.Context, extensionId int32) (model.GraphSchemaExtension, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGraphSchemaExtensionById", ctx, extensionId) - ret0, _ := ret[0].(model.GraphSchemaExtension) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetGraphSchemaExtensionById indicates an expected call of GetGraphSchemaExtensionById. -func (mr *MockDatabaseMockRecorder) GetGraphSchemaExtensionById(ctx, extensionId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaExtensionById", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaExtensionById), ctx, extensionId) -} - -// GetGraphSchemaExtensions mocks base method. -func (m *MockDatabase) GetGraphSchemaExtensions(ctx context.Context, extensionFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaExtensions, int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGraphSchemaExtensions", ctx, extensionFilters, sort, skip, limit) - ret0, _ := ret[0].(model.GraphSchemaExtensions) - ret1, _ := ret[1].(int) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// GetGraphSchemaExtensions indicates an expected call of GetGraphSchemaExtensions. -func (mr *MockDatabaseMockRecorder) GetGraphSchemaExtensions(ctx, extensionFilters, sort, skip, limit any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaExtensions", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaExtensions), ctx, extensionFilters, sort, skip, limit) -} - -// GetGraphSchemaNodeKindById mocks base method. -func (m *MockDatabase) GetGraphSchemaNodeKindById(ctx context.Context, schemaNodeKindID int32) (model.GraphSchemaNodeKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGraphSchemaNodeKindById", ctx, schemaNodeKindID) - ret0, _ := ret[0].(model.GraphSchemaNodeKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetGraphSchemaNodeKindById indicates an expected call of GetGraphSchemaNodeKindById. -func (mr *MockDatabaseMockRecorder) GetGraphSchemaNodeKindById(ctx, schemaNodeKindID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaNodeKindById", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaNodeKindById), ctx, schemaNodeKindID) -} - -// GetGraphSchemaNodeKinds mocks base method. -func (m *MockDatabase) GetGraphSchemaNodeKinds(ctx context.Context, nodeKindFilters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaNodeKinds, int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGraphSchemaNodeKinds", ctx, nodeKindFilters, sort, skip, limit) - ret0, _ := ret[0].(model.GraphSchemaNodeKinds) - ret1, _ := ret[1].(int) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// GetGraphSchemaNodeKinds indicates an expected call of GetGraphSchemaNodeKinds. -func (mr *MockDatabaseMockRecorder) GetGraphSchemaNodeKinds(ctx, nodeKindFilters, sort, skip, limit any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaNodeKinds", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaNodeKinds), ctx, nodeKindFilters, sort, skip, limit) -} - -// GetGraphSchemaProperties mocks base method. -func (m *MockDatabase) GetGraphSchemaProperties(ctx context.Context, filters model.Filters, sort model.Sort, skip, limit int) (model.GraphSchemaProperties, int, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGraphSchemaProperties", ctx, filters, sort, skip, limit) - ret0, _ := ret[0].(model.GraphSchemaProperties) - ret1, _ := ret[1].(int) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// GetGraphSchemaProperties indicates an expected call of GetGraphSchemaProperties. -func (mr *MockDatabaseMockRecorder) GetGraphSchemaProperties(ctx, filters, sort, skip, limit any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaProperties", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaProperties), ctx, filters, sort, skip, limit) -} - -// GetGraphSchemaPropertyById mocks base method. -func (m *MockDatabase) GetGraphSchemaPropertyById(ctx context.Context, extensionPropertyId int32) (model.GraphSchemaProperty, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetGraphSchemaPropertyById", ctx, extensionPropertyId) - ret0, _ := ret[0].(model.GraphSchemaProperty) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetGraphSchemaPropertyById indicates an expected call of GetGraphSchemaPropertyById. -func (mr *MockDatabaseMockRecorder) GetGraphSchemaPropertyById(ctx, extensionPropertyId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGraphSchemaPropertyById", reflect.TypeOf((*MockDatabase)(nil).GetGraphSchemaPropertyById), ctx, extensionPropertyId) -} - // GetIngestJob mocks base method. func (m *MockDatabase) GetIngestJob(ctx context.Context, id int64) (model.IngestJob, error) { m.ctrl.T.Helper() @@ -2007,21 +1651,6 @@ func (mr *MockDatabaseMockRecorder) GetPublicSavedQueries(ctx any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPublicSavedQueries", reflect.TypeOf((*MockDatabase)(nil).GetPublicSavedQueries), ctx) } -// GetRemediationByFindingId mocks base method. -func (m *MockDatabase) GetRemediationByFindingId(ctx context.Context, findingId int32) (model.Remediation, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetRemediationByFindingId", ctx, findingId) - ret0, _ := ret[0].(model.Remediation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetRemediationByFindingId indicates an expected call of GetRemediationByFindingId. -func (mr *MockDatabaseMockRecorder) GetRemediationByFindingId(ctx, findingId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRemediationByFindingId", reflect.TypeOf((*MockDatabase)(nil).GetRemediationByFindingId), ctx, findingId) -} - // GetRole mocks base method. func (m *MockDatabase) GetRole(ctx context.Context, id int32) (model.Role, error) { m.ctrl.T.Helper() @@ -2172,66 +1801,6 @@ func (mr *MockDatabaseMockRecorder) GetSavedQueryPermissions(ctx, queryID any) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSavedQueryPermissions", reflect.TypeOf((*MockDatabase)(nil).GetSavedQueryPermissions), ctx, queryID) } -// GetSchemaEnvironmentById mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentById(ctx context.Context, environmentId int32) (model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentById", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentById indicates an expected call of GetSchemaEnvironmentById. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentById(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentById), ctx, environmentId) -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId mocks base method. -func (m *MockDatabase) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx context.Context, environmentId int32) (model.SchemaEnvironmentPrincipalKinds, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", ctx, environmentId) - ret0, _ := ret[0].(model.SchemaEnvironmentPrincipalKinds) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironmentPrincipalKindsByEnvironmentId indicates an expected call of GetSchemaEnvironmentPrincipalKindsByEnvironmentId. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironmentPrincipalKindsByEnvironmentId(ctx, environmentId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironmentPrincipalKindsByEnvironmentId", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironmentPrincipalKindsByEnvironmentId), ctx, environmentId) -} - -// GetSchemaEnvironments mocks base method. -func (m *MockDatabase) GetSchemaEnvironments(ctx context.Context) ([]model.SchemaEnvironment, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaEnvironments", ctx) - ret0, _ := ret[0].([]model.SchemaEnvironment) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaEnvironments indicates an expected call of GetSchemaEnvironments. -func (mr *MockDatabaseMockRecorder) GetSchemaEnvironments(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaEnvironments", reflect.TypeOf((*MockDatabase)(nil).GetSchemaEnvironments), ctx) -} - -// GetSchemaRelationshipFindingById mocks base method. -func (m *MockDatabase) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingById", ctx, findingId) - ret0, _ := ret[0].(model.SchemaRelationshipFinding) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSchemaRelationshipFindingById indicates an expected call of GetSchemaRelationshipFindingById. -func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingById(ctx, findingId any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingById), ctx, findingId) -} - // GetScopeForSavedQuery mocks base method. func (m *MockDatabase) GetScopeForSavedQuery(ctx context.Context, queryID int64, userID uuid.UUID) (database.SavedQueryScopeMap, error) { m.ctrl.T.Helper() @@ -2880,66 +2449,6 @@ func (mr *MockDatabaseMockRecorder) UpdateCustomNodeKind(ctx, customNodeKind any return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateCustomNodeKind", reflect.TypeOf((*MockDatabase)(nil).UpdateCustomNodeKind), ctx, customNodeKind) } -// UpdateGraphSchemaEdgeKind mocks base method. -func (m *MockDatabase) UpdateGraphSchemaEdgeKind(ctx context.Context, schemaEdgeKind model.GraphSchemaEdgeKind) (model.GraphSchemaEdgeKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateGraphSchemaEdgeKind", ctx, schemaEdgeKind) - ret0, _ := ret[0].(model.GraphSchemaEdgeKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateGraphSchemaEdgeKind indicates an expected call of UpdateGraphSchemaEdgeKind. -func (mr *MockDatabaseMockRecorder) UpdateGraphSchemaEdgeKind(ctx, schemaEdgeKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGraphSchemaEdgeKind", reflect.TypeOf((*MockDatabase)(nil).UpdateGraphSchemaEdgeKind), ctx, schemaEdgeKind) -} - -// UpdateGraphSchemaExtension mocks base method. -func (m *MockDatabase) UpdateGraphSchemaExtension(ctx context.Context, extension model.GraphSchemaExtension) (model.GraphSchemaExtension, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateGraphSchemaExtension", ctx, extension) - ret0, _ := ret[0].(model.GraphSchemaExtension) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateGraphSchemaExtension indicates an expected call of UpdateGraphSchemaExtension. -func (mr *MockDatabaseMockRecorder) UpdateGraphSchemaExtension(ctx, extension any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGraphSchemaExtension", reflect.TypeOf((*MockDatabase)(nil).UpdateGraphSchemaExtension), ctx, extension) -} - -// UpdateGraphSchemaNodeKind mocks base method. -func (m *MockDatabase) UpdateGraphSchemaNodeKind(ctx context.Context, schemaNodeKind model.GraphSchemaNodeKind) (model.GraphSchemaNodeKind, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateGraphSchemaNodeKind", ctx, schemaNodeKind) - ret0, _ := ret[0].(model.GraphSchemaNodeKind) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateGraphSchemaNodeKind indicates an expected call of UpdateGraphSchemaNodeKind. -func (mr *MockDatabaseMockRecorder) UpdateGraphSchemaNodeKind(ctx, schemaNodeKind any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGraphSchemaNodeKind", reflect.TypeOf((*MockDatabase)(nil).UpdateGraphSchemaNodeKind), ctx, schemaNodeKind) -} - -// UpdateGraphSchemaProperty mocks base method. -func (m *MockDatabase) UpdateGraphSchemaProperty(ctx context.Context, property model.GraphSchemaProperty) (model.GraphSchemaProperty, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateGraphSchemaProperty", ctx, property) - ret0, _ := ret[0].(model.GraphSchemaProperty) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateGraphSchemaProperty indicates an expected call of UpdateGraphSchemaProperty. -func (mr *MockDatabaseMockRecorder) UpdateGraphSchemaProperty(ctx, property any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGraphSchemaProperty", reflect.TypeOf((*MockDatabase)(nil).UpdateGraphSchemaProperty), ctx, property) -} - // UpdateIngestJob mocks base method. func (m *MockDatabase) UpdateIngestJob(ctx context.Context, job model.IngestJob) error { m.ctrl.T.Helper() @@ -2983,21 +2492,6 @@ func (mr *MockDatabaseMockRecorder) UpdateOIDCProvider(ctx, ssoProvider any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOIDCProvider", reflect.TypeOf((*MockDatabase)(nil).UpdateOIDCProvider), ctx, ssoProvider) } -// UpdateRemediation mocks base method. -func (m *MockDatabase) UpdateRemediation(ctx context.Context, findingId int32, shortDescription, longDescription, shortRemediation, longRemediation string) (model.Remediation, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateRemediation", ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation) - ret0, _ := ret[0].(model.Remediation) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateRemediation indicates an expected call of UpdateRemediation. -func (mr *MockDatabaseMockRecorder) UpdateRemediation(ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateRemediation", reflect.TypeOf((*MockDatabase)(nil).UpdateRemediation), ctx, findingId, shortDescription, longDescription, shortRemediation, longRemediation) -} - // UpdateSAMLIdentityProvider mocks base method. func (m *MockDatabase) UpdateSAMLIdentityProvider(ctx context.Context, ssoProvider model.SSOProvider) (model.SAMLProvider, error) { m.ctrl.T.Helper() diff --git a/cmd/api/src/database/upsertExtension.go b/cmd/api/src/database/upsertExtension.go new file mode 100644 index 0000000000..bed2a1751a --- /dev/null +++ b/cmd/api/src/database/upsertExtension.go @@ -0,0 +1,232 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package database + +import ( + "context" + "errors" + "fmt" + "strconv" + + "github.com/specterops/bloodhound/cmd/api/src/model" +) + +// UpsertOpenGraphExtension - compares then upserts the incoming model.GraphSchema with the one stored in +// the BloodHoundDB. +// +// During development, it was decided to push the logic of how extensions are upserted down to the database +// layer due to difficulties decoupling the database and service layers while still providing transactional +// consistency. The following functions use models intended for the service layer and call the database public +// methods directly, rather than using an interface. +func (s *BloodhoundDB) UpsertOpenGraphExtension(ctx context.Context, graphSchema model.GraphSchema) (bool, error) { + var ( + tx = s.db.WithContext(ctx).Begin() + bloodhoundDBTransaction = BloodhoundDB{db: tx} + ) + // Check for an immediate error after beginning the transaction + if err := tx.Error; err != nil { + return false, err + } + + defer func() { + tx.Rollback() // rollback is a no-op if the tx has already been committed, todo: confirm + }() + + if _, schemaExists, err := bloodhoundDBTransaction.upsertGraphSchemaExtension(ctx, graphSchema); err != nil { + return false, err + } else if err = tx.Commit().Error; err != nil { + return false, err + } else { + return schemaExists, nil + + } +} + +// upsertGraphSchemaExtension - upserts the model.GraphSchema portion of a model.GraphExtension. TODO: replace with entire extension model. +func (s *BloodhoundDB) upsertGraphSchemaExtension(ctx context.Context, graphSchema model.GraphSchema) (model.GraphSchema, bool, error) { + var ( + err error + schemaExists bool + existingGraphSchema model.GraphSchema + + nodeKindActions MapDiffActions[model.GraphSchemaNodeKind] + propertyActions MapDiffActions[model.GraphSchemaProperty] + edgeKindActions MapDiffActions[model.GraphSchemaEdgeKind] + ) + + if existingGraphSchema, err = s.GetGraphSchemaByExtensionName(ctx, graphSchema.GraphSchemaExtension.Name); err != nil { + if !errors.Is(err, ErrNotFound) { + return graphSchema, schemaExists, err + } else { + // extension does not exist so create extension + if graphSchema.GraphSchemaExtension, err = s.CreateGraphSchemaExtension(ctx, graphSchema.GraphSchemaExtension.Name, + graphSchema.GraphSchemaExtension.DisplayName, graphSchema.GraphSchemaExtension.Version); err != nil { + return graphSchema, schemaExists, err + } + } + } else { + // extension exists, transfer model.Serial and update + schemaExists = true + if graphSchema.GraphSchemaExtension.IsBuiltin { + return graphSchema, schemaExists, fmt.Errorf("cannot modify a built-in graph schema extension") + } + graphSchema.GraphSchemaExtension.Serial = existingGraphSchema.GraphSchemaExtension.Serial + if graphSchema.GraphSchemaExtension, err = s.UpdateGraphSchemaExtension(ctx, graphSchema.GraphSchemaExtension); err != nil { + return graphSchema, schemaExists, err + } + } + + // explicitly assign nodes, properties and edges their extension id + for idx := range graphSchema.GraphSchemaNodeKinds { + graphSchema.GraphSchemaNodeKinds[idx].SchemaExtensionId = graphSchema.GraphSchemaExtension.ID + } + for idx := range graphSchema.GraphSchemaEdgeKinds { + graphSchema.GraphSchemaEdgeKinds[idx].SchemaExtensionId = graphSchema.GraphSchemaExtension.ID + } + for idx := range graphSchema.GraphSchemaProperties { + graphSchema.GraphSchemaProperties[idx].SchemaExtensionId = graphSchema.GraphSchemaExtension.ID + } + + // GenerateMapDiffActions compares the incoming graph schema extension (src) with the one stored + // in the schema database (dst). It generates actions (inserts, updates and deletes) that + // HandleMapDiffAction apply to the database to upsert the incoming graph schema. These actions + // are generated for nodes, edges and properties atm. + nodeKindActions = GenerateMapDiffActions(graphSchema.GraphSchemaNodeKinds.ToMapKeyedOnName(), + existingGraphSchema.GraphSchemaNodeKinds.ToMapKeyedOnName(), convertGraphSchemaNodeKinds) + edgeKindActions = GenerateMapDiffActions(graphSchema.GraphSchemaEdgeKinds.ToMapKeyedOnName(), + existingGraphSchema.GraphSchemaEdgeKinds.ToMapKeyedOnName(), convertGraphSchemaEdgeKinds) + propertyActions = GenerateMapDiffActions(graphSchema.GraphSchemaProperties.ToMapKeyedOnName(), + existingGraphSchema.GraphSchemaProperties.ToMapKeyedOnName(), convertGraphSchemaProperties) + + if graphSchema.GraphSchemaNodeKinds, err = HandleMapDiffAction(ctx, nodeKindActions, + s.deleteGraphSchemaNodeKind, s.updateGraphSchemaNodeKind, s.createGraphSchemaNodeKind); err != nil { + return graphSchema, schemaExists, err + } else if graphSchema.GraphSchemaEdgeKinds, err = HandleMapDiffAction(ctx, edgeKindActions, + s.deleteGraphSchemaEdgeKind, s.updateGraphSchemaEdgeKind, s.createGraphSchemaEdgeKind); err != nil { + return graphSchema, schemaExists, err + } else if graphSchema.GraphSchemaProperties, err = HandleMapDiffAction(ctx, propertyActions, + s.deleteGraphSchemaProperty, s.updateGraphSchemaProperty, s.createGraphSchemaProperty); err != nil { + return graphSchema, schemaExists, err + } + return graphSchema, schemaExists, nil +} + +// GetGraphSchemaByExtensionName - returns a graph schema extension with nodes, edges and properties. Will return +// ErrNotFound if the extension does not exist. +func (s *BloodhoundDB) GetGraphSchemaByExtensionName(ctx context.Context, extensionName string) (model.GraphSchema, error) { + var graphSchema = model.GraphSchema{ + GraphSchemaProperties: make(model.GraphSchemaProperties, 0), + GraphSchemaEdgeKinds: make(model.GraphSchemaEdgeKinds, 0), + GraphSchemaNodeKinds: make(model.GraphSchemaNodeKinds, 0), + } + + if extensions, totalRecords, err := s.GetGraphSchemaExtensions(ctx, + model.Filters{"name": []model.Filter{{ // check to see if extension exists + Operator: model.Equals, + Value: extensionName, + SetOperator: model.FilterAnd, + }}}, model.Sort{}, 0, 1); err != nil && !errors.Is(err, ErrNotFound) { + return model.GraphSchema{}, err + } else if totalRecords == 0 || errors.Is(err, ErrNotFound) { + return model.GraphSchema{}, ErrNotFound + } else { + graphSchema.GraphSchemaExtension = extensions[0] + if graphSchema.GraphSchemaNodeKinds, _, err = s.GetGraphSchemaNodeKinds(ctx, + model.Filters{"schema_extension_id": []model.Filter{{ + Operator: model.Equals, + Value: strconv.FormatInt(int64(graphSchema.GraphSchemaExtension.ID), 10), + SetOperator: model.FilterAnd, + }}}, model.Sort{}, 0, 0); err != nil && !errors.Is(err, ErrNotFound) { + return model.GraphSchema{}, err + } else if graphSchema.GraphSchemaEdgeKinds, _, err = s.GetGraphSchemaEdgeKinds(ctx, + model.Filters{"schema_extension_id": []model.Filter{{ + Operator: model.Equals, + Value: strconv.FormatInt(int64(graphSchema.GraphSchemaExtension.ID), 10), + SetOperator: model.FilterAnd, + }}}, model.Sort{}, 0, 0); err != nil && !errors.Is(err, ErrNotFound) { + return model.GraphSchema{}, err + } else if graphSchema.GraphSchemaProperties, _, err = s.GetGraphSchemaProperties(ctx, + model.Filters{"schema_extension_id": []model.Filter{{ + Operator: model.Equals, + Value: strconv.FormatInt(int64(graphSchema.GraphSchemaExtension.ID), 10), + SetOperator: model.FilterAnd, + }}}, model.Sort{}, 0, 0); err != nil && !errors.Is(err, ErrNotFound) { + return model.GraphSchema{}, err + } + return graphSchema, nil + } +} + +// convertGraphSchemaNodeKinds - reassigns model.Serial and SchemaExtensionId data from dst to src if neither is nil. +func convertGraphSchemaNodeKinds(src, dst *model.GraphSchemaNodeKind) { + if dst == nil || src == nil { + return + } + src.Serial = dst.Serial +} + +func convertGraphSchemaEdgeKinds(src, dst *model.GraphSchemaEdgeKind) { + if dst == nil || src == nil { + return + } + src.Serial = dst.Serial +} + +func convertGraphSchemaProperties(src, dst *model.GraphSchemaProperty) { + if dst == nil || src == nil { + return + } + src.Serial = dst.Serial +} + +func (s *BloodhoundDB) deleteGraphSchemaNodeKind(ctx context.Context, nodeKind model.GraphSchemaNodeKind) error { + return s.DeleteGraphSchemaNodeKind(ctx, nodeKind.ID) +} + +func (s *BloodhoundDB) updateGraphSchemaNodeKind(ctx context.Context, nodeKind model.GraphSchemaNodeKind) (model.GraphSchemaNodeKind, error) { + return s.UpdateGraphSchemaNodeKind(ctx, nodeKind) +} + +func (s *BloodhoundDB) createGraphSchemaNodeKind(ctx context.Context, nodeKind model.GraphSchemaNodeKind) (model.GraphSchemaNodeKind, error) { + return s.CreateGraphSchemaNodeKind(ctx, nodeKind.Name, nodeKind.SchemaExtensionId, nodeKind.DisplayName, + nodeKind.Description, nodeKind.IsDisplayKind, nodeKind.Icon, nodeKind.IconColor) +} + +func (s *BloodhoundDB) deleteGraphSchemaEdgeKind(ctx context.Context, edgeKind model.GraphSchemaEdgeKind) error { + return s.DeleteGraphSchemaEdgeKind(ctx, edgeKind.ID) +} + +func (s *BloodhoundDB) updateGraphSchemaEdgeKind(ctx context.Context, edgeKind model.GraphSchemaEdgeKind) (model.GraphSchemaEdgeKind, error) { + return s.UpdateGraphSchemaEdgeKind(ctx, edgeKind) +} + +func (s *BloodhoundDB) createGraphSchemaEdgeKind(ctx context.Context, edgeKind model.GraphSchemaEdgeKind) (model.GraphSchemaEdgeKind, error) { + return s.CreateGraphSchemaEdgeKind(ctx, edgeKind.Name, edgeKind.SchemaExtensionId, edgeKind.Description, + edgeKind.IsTraversable) +} + +func (s *BloodhoundDB) deleteGraphSchemaProperty(ctx context.Context, property model.GraphSchemaProperty) error { + return s.DeleteGraphSchemaProperty(ctx, property.ID) +} + +func (s *BloodhoundDB) updateGraphSchemaProperty(ctx context.Context, property model.GraphSchemaProperty) (model.GraphSchemaProperty, error) { + return s.UpdateGraphSchemaProperty(ctx, property) +} + +func (s *BloodhoundDB) createGraphSchemaProperty(ctx context.Context, property model.GraphSchemaProperty) (model.GraphSchemaProperty, error) { + return s.CreateGraphSchemaProperty(ctx, property.SchemaExtensionId, property.Name, + property.DisplayName, property.DataType, property.Description) +} diff --git a/cmd/api/src/database/upsertExtension_test.go b/cmd/api/src/database/upsertExtension_test.go new file mode 100644 index 0000000000..9a94a7841a --- /dev/null +++ b/cmd/api/src/database/upsertExtension_test.go @@ -0,0 +1,291 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +//go:build integration + +package database_test + +import ( + "context" + "testing" + + "github.com/specterops/bloodhound/cmd/api/src/database" + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/stretchr/testify/require" +) + +func TestDatabase_GetGraphSchemaByExtensionName(t *testing.T) { + t.Parallel() + + testSuite := setupIntegrationTestSuite(t) + + defer teardownIntegrationTestSuite(t, &testSuite) + + var ( + err error + testExtensionName = "test_extension" + testExtension = model.GraphSchemaExtension{ + Name: testExtensionName, + Version: "1.0.0", + IsBuiltin: false, + DisplayName: "Test Extension", + } + nodeKind1 = model.GraphSchemaNodeKind{ + Name: "Test_Node_Kind_1", + SchemaExtensionId: testExtension.ID, + DisplayName: "Test Node Kind 1", + Description: "a test node kind", + IsDisplayKind: true, + Icon: "user", + IconColor: "blue", + } + edgeKind1 = model.GraphSchemaEdgeKind{ + SchemaExtensionId: testExtension.ID, + Name: "Test_Edge_Kind_1", + Description: "Test Edge Kind 1", + IsTraversable: true, + } + property1 = model.GraphSchemaProperty{ + SchemaExtensionId: testExtension.ID, + Name: "Test_Property_1", + DisplayName: "Test Property 1", + DataType: "string", + Description: "Test Property 1", + } + ) + + type fields struct { + setup func(t *testing.T) model.GraphSchema + teardown func(t *testing.T) + } + type args struct { + ctx context.Context + extensionName string + } + tests := []struct { + name string + fields fields + args args + wantErr error + }{ + + { + name: "success - existing but empty GraphSchemaExtension", // schema extension exists but there are no nodes, edges or properties linked to the extension + fields: fields{ + setup: func(t *testing.T) model.GraphSchema { + testExtension, err = testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, testExtension.Name, testExtension.DisplayName, + testExtension.Version) + require.NoError(t, err) + return model.GraphSchema{ + GraphSchemaExtension: testExtension, + } + }, + teardown: func(t *testing.T) { + err = testSuite.BHDatabase.DeleteGraphSchemaExtension(testSuite.Context, testExtension.ID) + require.NoError(t, err) + }, + }, + args: args{ + ctx: context.Background(), + extensionName: testExtension.Name, + }, + }, + + { + name: "success - GraphSchemaExtension", + fields: fields{ + setup: func(t *testing.T) model.GraphSchema { + testExtension, err = testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, testExtension.Name, testExtension.DisplayName, + testExtension.Version) + require.NoError(t, err) + nodeKind1, err = testSuite.BHDatabase.CreateGraphSchemaNodeKind(testSuite.Context, nodeKind1.Name, + testExtension.ID, nodeKind1.DisplayName, nodeKind1.Description, nodeKind1.IsDisplayKind, + nodeKind1.Icon, nodeKind1.IconColor) + require.NoError(t, err) + edgeKind1, err = testSuite.BHDatabase.CreateGraphSchemaEdgeKind(testSuite.Context, edgeKind1.Name, + testExtension.ID, edgeKind1.Description, edgeKind1.IsTraversable) + require.NoError(t, err) + property1, err = testSuite.BHDatabase.CreateGraphSchemaProperty(testSuite.Context, testExtension.ID, + property1.Name, property1.DisplayName, property1.DataType, property1.Description) + require.NoError(t, err) + + return model.GraphSchema{ + GraphSchemaExtension: testExtension, + GraphSchemaNodeKinds: model.GraphSchemaNodeKinds{nodeKind1}, + GraphSchemaEdgeKinds: model.GraphSchemaEdgeKinds{edgeKind1}, + GraphSchemaProperties: model.GraphSchemaProperties{property1}, + } + }, + teardown: func(t *testing.T) { + err = testSuite.BHDatabase.DeleteGraphSchemaExtension(testSuite.Context, testExtension.ID) + require.NoError(t, err) + }, + }, + args: args{ + ctx: context.Background(), + extensionName: testExtensionName, + }, + }, + { + name: "success - no GetGraphSchemaExtensions results", // Will result in new graph schema extension + fields: fields{ + setup: func(t *testing.T) model.GraphSchema { + return model.GraphSchema{} + }, + teardown: func(t *testing.T) {}, + }, + args: args{ + ctx: context.Background(), + extensionName: "non_existing_extension", + }, + wantErr: database.ErrNotFound, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + want := tt.fields.setup(t) + if got, err := testSuite.BHDatabase.GetGraphSchemaByExtensionName(tt.args.ctx, tt.args.extensionName); tt.wantErr != nil { + require.EqualError(t, err, tt.wantErr.Error()) + return + } else { + require.NoError(t, err) + compareGraphSchema(t, got, want) + tt.fields.teardown(t) + } + }) + } +} + +func TestBloodhoundDB_UpsertGraphSchemaExtension(t *testing.T) { + t.Parallel() + testSuite := setupIntegrationTestSuite(t) + defer teardownIntegrationTestSuite(t, &testSuite) + + var ( + err error + got bool + gotGraphSchema model.GraphSchema + + testExtensionName = "Upsert_New_Test_Extension" + testExtension = model.GraphSchemaExtension{ + Name: testExtensionName, + Version: "1.0.0", + DisplayName: "Test Extension", + } + nodeKind1 = model.GraphSchemaNodeKind{ + Name: "Upsert_New_Test_Node_Kind_1", + SchemaExtensionId: testExtension.ID, + DisplayName: "Test Node Kind 1", + Description: "a test node kind", + IsDisplayKind: true, + Icon: "user", + IconColor: "blue", + } + edgeKind1 = model.GraphSchemaEdgeKind{ + SchemaExtensionId: testExtension.ID, + Name: "Upsert_New_Test_Edge_Kind_1", + Description: "Test Edge Kind 1", + IsTraversable: true, + } + property1 = model.GraphSchemaProperty{ + SchemaExtensionId: testExtension.ID, + Name: "Upsert_New_Test_Property_1", + DisplayName: "Test Property 1", + DataType: "string", + Description: "Test Property 1", + } + testGraphSchema = model.GraphSchema{ + GraphSchemaExtension: testExtension, + GraphSchemaNodeKinds: model.GraphSchemaNodeKinds{nodeKind1}, + GraphSchemaEdgeKinds: model.GraphSchemaEdgeKinds{edgeKind1}, + GraphSchemaProperties: model.GraphSchemaProperties{property1}, + } + ) + + type fields struct { + setup func(t *testing.T) + teardown func(t *testing.T) + } + + type args struct { + ctx context.Context + graphSchema model.GraphSchema + } + tests := []struct { + name string + fields fields + args args + want bool + wantErr error + }{ + { + name: "success - create new OpenGraph extension", + fields: fields{ + setup: func(t *testing.T) {}, + teardown: func(t *testing.T) { + gotGraphSchema, err = testSuite.BHDatabase.GetGraphSchemaByExtensionName(testSuite.Context, testExtensionName) + require.NoError(t, err) + compareGraphSchema(t, gotGraphSchema, testGraphSchema) + + err = testSuite.BHDatabase.DeleteGraphSchemaExtension(testSuite.Context, gotGraphSchema.GraphSchemaExtension.ID) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.GetGraphSchemaByExtensionName(testSuite.Context, testExtension.Name) + require.Equal(t, database.ErrNotFound, err) + }, + }, + args: args{ + ctx: testSuite.Context, + graphSchema: testGraphSchema, + }, + wantErr: nil, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.fields.setup(t) + + if got, err = testSuite.BHDatabase.UpsertOpenGraphExtension(tt.args.ctx, tt.args.graphSchema); tt.wantErr != nil { + require.EqualError(t, err, tt.wantErr.Error()) + } else { + require.NoError(t, err) + require.Equalf(t, tt.want, got, "UpsertOpenGraphExtension(%v, %v)", tt.args.ctx, tt.args.graphSchema) + } + tt.fields.teardown(t) + }) + } +} + +func compareGraphSchemaExtension(t *testing.T, got, want model.GraphSchemaExtension) { + t.Helper() + require.Greaterf(t, got.ID, int32(0), "GraphSchemaExtension - ID mismatch - got: %v", got.ID) + require.Equalf(t, want.Name, got.Name, "GraphSchemaExtension - name mismatch - got %v, want %v", got.Name, want.Name) + require.Equalf(t, want.DisplayName, got.DisplayName, "GraphSchemaExtension - display_name mismatch - got %v, want %v", got.DisplayName, want.DisplayName) + require.Equalf(t, want.Version, got.Version, "GraphSchemaExtension - version mismatch - got %v, want %v", got.Version, want.Version) + require.Equalf(t, want.IsBuiltin, got.IsBuiltin, "GraphSchemaExtension - is_built mismatch - got %t, want %t", got.IsBuiltin, want.IsBuiltin) + require.Equalf(t, false, got.CreatedAt.IsZero(), "GraphSchemaExtension - created_at mismatch - got: %s", got.CreatedAt.String()) + require.Equalf(t, false, got.UpdatedAt.IsZero(), "GraphSchemaExtension - updated_at mismatch - got: %s", got.UpdatedAt.String()) + require.Equalf(t, false, got.DeletedAt.Valid, "GraphSchemaExtension - deleted_at is not null") +} + +func compareGraphSchema(t *testing.T, got, want model.GraphSchema) { + t.Helper() + compareGraphSchemaExtension(t, got.GraphSchemaExtension, want.GraphSchemaExtension) + compareGraphSchemaNodeKinds(t, got.GraphSchemaNodeKinds, want.GraphSchemaNodeKinds) + compareGraphSchemaEdgeKinds(t, got.GraphSchemaEdgeKinds, want.GraphSchemaEdgeKinds) + compareGraphSchemaProperties(t, got.GraphSchemaProperties, want.GraphSchemaProperties) +} diff --git a/cmd/api/src/model/appcfg/flag.go b/cmd/api/src/model/appcfg/flag.go index 99f622ba29..fb6b7b8a40 100644 --- a/cmd/api/src/model/appcfg/flag.go +++ b/cmd/api/src/model/appcfg/flag.go @@ -25,24 +25,25 @@ import ( // AvailableFlags has been removed and the db feature_flags table is the source of truth. Feature flag defaults should be added via migration *.sql files. const ( - FeatureButterflyAnalysis = "butterfly_analysis" - FeatureEnableSAMLSSO = "enable_saml_sso" - FeatureScopeCollectionByOU = "scope_collection_by_ou" - FeatureAzureSupport = "azure_support" - FeatureEntityPanelCaching = "entity_panel_cache" - FeatureAdcs = "adcs" - FeatureClearGraphData = "clear_graph_data" - FeatureRiskExposureNewCalculation = "risk_exposure_new_calculation" - FeatureFedRAMPEULA = "fedramp_eula" - FeatureDarkMode = "dark_mode" - FeatureAutoTagT0ParentObjects = "auto_tag_t0_parent_objects" - FeatureOIDCSupport = "oidc_support" - FeatureNTLMPostProcessing = "ntlm_post_processing" - FeatureTierManagement = "tier_management_engine" - FeatureChangelog = "changelog" - FeatureETAC = "environment_targeted_access_control" - FeatureOpenGraphSearch = "opengraph_search" - FeatureClientBearerAuth = "client_bearer_auth" + FeatureButterflyAnalysis = "butterfly_analysis" + FeatureEnableSAMLSSO = "enable_saml_sso" + FeatureScopeCollectionByOU = "scope_collection_by_ou" + FeatureAzureSupport = "azure_support" + FeatureEntityPanelCaching = "entity_panel_cache" + FeatureAdcs = "adcs" + FeatureClearGraphData = "clear_graph_data" + FeatureRiskExposureNewCalculation = "risk_exposure_new_calculation" + FeatureFedRAMPEULA = "fedramp_eula" + FeatureDarkMode = "dark_mode" + FeatureAutoTagT0ParentObjects = "auto_tag_t0_parent_objects" + FeatureOIDCSupport = "oidc_support" + FeatureNTLMPostProcessing = "ntlm_post_processing" + FeatureTierManagement = "tier_management_engine" + FeatureChangelog = "changelog" + FeatureETAC = "environment_targeted_access_control" + FeatureOpenGraphSearch = "opengraph_search" + FeatureClientBearerAuth = "client_bearer_auth" + FeatureOpenGraphExtensionManagement = "opengraph_extension_management" ) // FeatureFlag defines the most basic details of what a feature flag must contain to be actionable. Feature flags should be diff --git a/cmd/api/src/model/graphschema.go b/cmd/api/src/model/graphschema.go index 397ba77afc..767f9f6680 100644 --- a/cmd/api/src/model/graphschema.go +++ b/cmd/api/src/model/graphschema.go @@ -18,6 +18,18 @@ package model import "time" +type GraphExtension struct { + GraphSchema GraphSchema +} + +// GraphSchema - +type GraphSchema struct { + GraphSchemaExtension GraphSchemaExtension `json:"extension"` + GraphSchemaProperties GraphSchemaProperties `json:"properties"` + GraphSchemaEdgeKinds GraphSchemaEdgeKinds `json:"edge_kinds"` + GraphSchemaNodeKinds GraphSchemaNodeKinds `json:"node_kinds"` +} + type GraphSchemaExtensions []GraphSchemaExtension type GraphSchemaExtension struct { @@ -46,17 +58,26 @@ func (s GraphSchemaExtension) AuditData() AuditData { // GraphSchemaNodeKinds - slice of node kinds type GraphSchemaNodeKinds []GraphSchemaNodeKind +// ToMapKeyedOnName - converts a list of graph schema node kinds to a map based on name +func (g GraphSchemaNodeKinds) ToMapKeyedOnName() map[string]GraphSchemaNodeKind { + result := make(map[string]GraphSchemaNodeKind, 0) + for _, kind := range g { + result[kind.Name] = kind + } + return result +} + // GraphSchemaNodeKind - represents a node kind for an extension type GraphSchemaNodeKind struct { Serial - Name string - SchemaExtensionId int32 // indicates which extension this node kind belongs to - DisplayName string // can be different from name but usually isn't other than Base/Entity - Description string // human-readable description of the node kind - IsDisplayKind bool // indicates if this kind should supersede others and be displayed - Icon string // font-awesome icon for the registered node kind - IconColor string // icon hex color + Name string `json:"name"` + SchemaExtensionId int32 `json:"schema_extension_id"` // indicates which extension this node kind belongs to + DisplayName string `json:"display_name"` // can be different from name but usually isn't other than Base/Entity + Description string `json:"description"` // human-readable description of the node kind + IsDisplayKind bool `json:"is_display_kind"` // indicates if this kind should supersede others and be displayed + Icon string `json:"icon"` // font-awesome icon for the registered node kind + IconColor string `json:"icon_color"` // icon hex color } // TableName - Retrieve table name @@ -67,6 +88,15 @@ func (GraphSchemaNodeKind) TableName() string { // GraphSchemaProperties - slice of graph schema properties. type GraphSchemaProperties []GraphSchemaProperty +// ToMapKeyedOnName - converts a list of graph schema properties to a map keyed on name +func (g GraphSchemaProperties) ToMapKeyedOnName() map[string]GraphSchemaProperty { + result := make(map[string]GraphSchemaProperty, 0) + for _, kind := range g { + result[kind.Name] = kind + } + return result +} + // GraphSchemaProperty - represents a property that an edge or node kind can have. Grouped by schema extension. type GraphSchemaProperty struct { Serial @@ -82,16 +112,25 @@ func (GraphSchemaProperty) TableName() string { return "schema_properties" } -// GraphSchemaEdgeKinds - slice of model.GraphSchemaEdgeKind +// GraphSchemaEdgeKinds - slice of GraphSchemaEdgeKind type GraphSchemaEdgeKinds []GraphSchemaEdgeKind +// ToMapKeyedOnName - converts a list of graph schema edge kinds to a map keyed on name +func (g GraphSchemaEdgeKinds) ToMapKeyedOnName() map[string]GraphSchemaEdgeKind { + result := make(map[string]GraphSchemaEdgeKind, 0) + for _, kind := range g { + result[kind.Name] = kind + } + return result +} + // GraphSchemaEdgeKind - represents an edge kind for an extension type GraphSchemaEdgeKind struct { Serial - SchemaExtensionId int32 // indicates which extension this edge kind belongs to - Name string - Description string - IsTraversable bool // indicates whether the edge-kind is a traversable path + SchemaExtensionId int32 `json:"schema_extension_id"` // indicates which extension this edge kind belongs to + Name string `json:"name"` + Description string `json:"description"` + IsTraversable bool `json:"is_traversable"` // indicates whether the edge-kind is a traversable path } func (GraphSchemaEdgeKind) TableName() string { diff --git a/cmd/api/src/services/entrypoint.go b/cmd/api/src/services/entrypoint.go index 4ab5a38745..a16dc35330 100644 --- a/cmd/api/src/services/entrypoint.go +++ b/cmd/api/src/services/entrypoint.go @@ -39,6 +39,7 @@ import ( "github.com/specterops/bloodhound/cmd/api/src/model/appcfg" "github.com/specterops/bloodhound/cmd/api/src/queries" "github.com/specterops/bloodhound/cmd/api/src/services/dogtags" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema" "github.com/specterops/bloodhound/cmd/api/src/services/upload" "github.com/specterops/bloodhound/packages/go/cache" schema "github.com/specterops/bloodhound/packages/go/graphschema" @@ -118,17 +119,19 @@ func Entrypoint(ctx context.Context, cfg config.Configuration, connections boots startDelay := 0 * time.Second var ( - cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) - pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) - graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) - authorizer = auth.NewAuthorizer(connections.RDMS) - datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) - routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) - authenticator = api.NewAuthenticator(cfg, connections.RDMS, api.NewAuthExtensions(cfg, connections.RDMS)) + cl = changelog.NewChangelog(connections.Graph, connections.RDMS, changelog.DefaultOptions()) + pipeline = datapipe.NewPipeline(ctx, cfg, connections.RDMS, connections.Graph, graphQueryCache, ingestSchema, cl) + graphQuery = queries.NewGraphQuery(connections.Graph, graphQueryCache, cfg) + authorizer = auth.NewAuthorizer(connections.RDMS) + datapipeDaemon = datapipe.NewDaemon(pipeline, startDelay, time.Duration(cfg.DatapipeInterval)*time.Second, connections.RDMS) + routerInst = router.NewRouter(cfg, authorizer, fmt.Sprintf(bootstrap.ContentSecurityPolicy, "", "")) + authenticator = api.NewAuthenticator(cfg, connections.RDMS, api.NewAuthExtensions(cfg, connections.RDMS)) + openGraphSchemaService = opengraphschema.NewOpenGraphSchemaService(connections.RDMS, connections.Graph) ) registration.RegisterFossGlobalMiddleware(&routerInst, cfg, auth.NewIdentityResolver(), authenticator) - registration.RegisterFossRoutes(&routerInst, cfg, connections.RDMS, connections.Graph, graphQuery, apiCache, collectorManifests, authenticator, authorizer, ingestSchema, dogtagsService) + registration.RegisterFossRoutes(&routerInst, cfg, connections.RDMS, connections.Graph, graphQuery, apiCache, + collectorManifests, authenticator, authorizer, ingestSchema, openGraphSchemaService, dogtagsService) // Set neo4j batch and flush sizes neo4jParameters := appcfg.GetNeo4jParameters(ctx, connections.RDMS) diff --git a/cmd/api/src/services/opengraphschema/mocks/graphdbkindrepository.go b/cmd/api/src/services/opengraphschema/mocks/graphdbkindrepository.go new file mode 100644 index 0000000000..cdad3155e3 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/mocks/graphdbkindrepository.go @@ -0,0 +1,71 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema (interfaces: GraphDBKindRepository) +// +// Generated by this command: +// +// mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphdbkindrepository.go -package=mocks . GraphDBKindRepository +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockGraphDBKindRepository is a mock of GraphDBKindRepository interface. +type MockGraphDBKindRepository struct { + ctrl *gomock.Controller + recorder *MockGraphDBKindRepositoryMockRecorder + isgomock struct{} +} + +// MockGraphDBKindRepositoryMockRecorder is the mock recorder for MockGraphDBKindRepository. +type MockGraphDBKindRepositoryMockRecorder struct { + mock *MockGraphDBKindRepository +} + +// NewMockGraphDBKindRepository creates a new mock instance. +func NewMockGraphDBKindRepository(ctrl *gomock.Controller) *MockGraphDBKindRepository { + mock := &MockGraphDBKindRepository{ctrl: ctrl} + mock.recorder = &MockGraphDBKindRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockGraphDBKindRepository) EXPECT() *MockGraphDBKindRepositoryMockRecorder { + return m.recorder +} + +// RefreshKinds mocks base method. +func (m *MockGraphDBKindRepository) RefreshKinds(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RefreshKinds", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// RefreshKinds indicates an expected call of RefreshKinds. +func (mr *MockGraphDBKindRepositoryMockRecorder) RefreshKinds(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RefreshKinds", reflect.TypeOf((*MockGraphDBKindRepository)(nil).RefreshKinds), ctx) +} diff --git a/cmd/api/src/services/opengraphschema/mocks/opengraphschemarepository.go b/cmd/api/src/services/opengraphschema/mocks/opengraphschemarepository.go new file mode 100644 index 0000000000..6d14d9b183 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/mocks/opengraphschemarepository.go @@ -0,0 +1,73 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema (interfaces: OpenGraphSchemaRepository) +// +// Generated by this command: +// +// mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/opengraphschemarepository.go -package=mocks . OpenGraphSchemaRepository +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + model "github.com/specterops/bloodhound/cmd/api/src/model" + gomock "go.uber.org/mock/gomock" +) + +// MockOpenGraphSchemaRepository is a mock of OpenGraphSchemaRepository interface. +type MockOpenGraphSchemaRepository struct { + ctrl *gomock.Controller + recorder *MockOpenGraphSchemaRepositoryMockRecorder + isgomock struct{} +} + +// MockOpenGraphSchemaRepositoryMockRecorder is the mock recorder for MockOpenGraphSchemaRepository. +type MockOpenGraphSchemaRepositoryMockRecorder struct { + mock *MockOpenGraphSchemaRepository +} + +// NewMockOpenGraphSchemaRepository creates a new mock instance. +func NewMockOpenGraphSchemaRepository(ctrl *gomock.Controller) *MockOpenGraphSchemaRepository { + mock := &MockOpenGraphSchemaRepository{ctrl: ctrl} + mock.recorder = &MockOpenGraphSchemaRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOpenGraphSchemaRepository) EXPECT() *MockOpenGraphSchemaRepositoryMockRecorder { + return m.recorder +} + +// UpsertOpenGraphExtension mocks base method. +func (m *MockOpenGraphSchemaRepository) UpsertOpenGraphExtension(ctx context.Context, graphSchema model.GraphSchema) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertOpenGraphExtension", ctx, graphSchema) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertOpenGraphExtension indicates an expected call of UpsertOpenGraphExtension. +func (mr *MockOpenGraphSchemaRepositoryMockRecorder) UpsertOpenGraphExtension(ctx, graphSchema any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertOpenGraphExtension", reflect.TypeOf((*MockOpenGraphSchemaRepository)(nil).UpsertOpenGraphExtension), ctx, graphSchema) +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema.go b/cmd/api/src/services/opengraphschema/opengraphschema.go new file mode 100644 index 0000000000..3703f11565 --- /dev/null +++ b/cmd/api/src/services/opengraphschema/opengraphschema.go @@ -0,0 +1,97 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema + +// Mocks + +//go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/opengraphschemarepository.go -package=mocks . OpenGraphSchemaRepository +//go:generate go run go.uber.org/mock/mockgen -copyright_file ../../../../../LICENSE.header -destination=./mocks/graphdbkindrepository.go -package=mocks . GraphDBKindRepository + +import ( + "context" + "errors" + "fmt" + "log/slog" + + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/bloodhound/packages/go/bhlog/attr" +) + +// OpenGraphSchemaRepository - +type OpenGraphSchemaRepository interface { + UpsertOpenGraphExtension(ctx context.Context, graphSchema model.GraphSchema) (bool, error) +} + +// GraphDBKindRepository - +type GraphDBKindRepository interface { + // RefreshKinds refreshes the in memory kinds maps + RefreshKinds(ctx context.Context) error +} + +// OpenGraphSchemaService - +type OpenGraphSchemaService struct { + openGraphSchemaRepository OpenGraphSchemaRepository + graphDBKindRepository GraphDBKindRepository +} + +func NewOpenGraphSchemaService(openGraphSchemaExtensionRepository OpenGraphSchemaRepository, graphDBKindRepository GraphDBKindRepository) *OpenGraphSchemaService { + return &OpenGraphSchemaService{ + openGraphSchemaRepository: openGraphSchemaExtensionRepository, + graphDBKindRepository: graphDBKindRepository, + } +} + +// UpsertOpenGraphExtension - validates the incoming graph schema, passes it to the DB layer for upserting and if successful +// updates the in memory kinds map. +func (o *OpenGraphSchemaService) UpsertOpenGraphExtension(ctx context.Context, graphSchema model.GraphSchema) (bool, error) { + var ( + err error + schemaExists bool + ) + + if err = validateGraphSchemaModel(graphSchema); err != nil { + return schemaExists, fmt.Errorf("graph schema validation error: %w", err) + } else if schemaExists, err = o.openGraphSchemaRepository.UpsertOpenGraphExtension(ctx, graphSchema); err != nil { + return schemaExists, err + } else if err = o.graphDBKindRepository.RefreshKinds(ctx); err != nil { + slog.WarnContext(ctx, "OpenGraphSchema: refreshing graph kind maps failed", attr.Error(err)) + } + return schemaExists, nil +} + +// validateGraphSchemaModel - Ensures the incoming model.GraphSchema has an extension name, node kinds exist, and +// there are no duplicate kinds. +func validateGraphSchemaModel(graphSchema model.GraphSchema) error { + var kinds = make(map[string]any, 0) + if graphSchema.GraphSchemaExtension.Name == "" { + return errors.New("graph schema extension name is required") + } else if len(graphSchema.GraphSchemaNodeKinds) == 0 { + return errors.New("graph schema node kinds is required") + } + for _, kind := range graphSchema.GraphSchemaNodeKinds { + if _, ok := kinds[kind.Name]; ok { + return fmt.Errorf("graph kind: %s is already registered", kind.Name) + } + kinds[kind.Name] = struct{}{} + } + for _, kind := range graphSchema.GraphSchemaEdgeKinds { + if _, ok := kinds[kind.Name]; ok { + return fmt.Errorf("graph kind: %s is already registered", kind.Name) + } + kinds[kind.Name] = struct{}{} + } + return nil +} diff --git a/cmd/api/src/services/opengraphschema/opengraphschema_test.go b/cmd/api/src/services/opengraphschema/opengraphschema_test.go new file mode 100644 index 0000000000..bcc15f191b --- /dev/null +++ b/cmd/api/src/services/opengraphschema/opengraphschema_test.go @@ -0,0 +1,330 @@ +// Copyright 2025 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 +package opengraphschema + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/specterops/bloodhound/cmd/api/src/model" + "github.com/specterops/bloodhound/cmd/api/src/services/opengraphschema/mocks" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +// TestOpenGraphSchemaService_UpsertGraphSchemaExtension - +// +// Mocks: +// +// GenerateMapSynchronizationDiffActions does not preserve ordering so the following CRUD mocks +// perform a Do function which removes the provided kind/property from the item-function map. +// The last Do function for each CRUD Operation will check to see if their respective wantAction's +// map length is 0 to ensure all actions are accounted for. +func TestOpenGraphSchemaService_UpsertGraphSchemaExtension(t *testing.T) { + t.Parallel() + + var ( + mockCtrl = gomock.NewController(t) + + mockOpenGraphSchemaRepository = mocks.NewMockOpenGraphSchemaRepository(mockCtrl) + mockGraphDBKindsRepository = mocks.NewMockGraphDBKindRepository(mockCtrl) + + existingExtension1 = model.GraphSchemaExtension{ + Serial: model.Serial{ + ID: 1, + Basic: model.Basic{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + Name: "test_extension_1", + DisplayName: "Test Extension 1", + Version: "1.0.0", + IsBuiltin: false, + } + + _ = model.GraphSchemaExtension{ + Serial: model.Serial{ + ID: 1, + Basic: model.Basic{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + Name: "test_extension_1", + DisplayName: "Test Extension 1", + Version: "1.0.0", + IsBuiltin: true, + } + + newExtension1 = model.GraphSchemaExtension{ + Name: "test_extension_2", + DisplayName: "Test Extension 2", + Version: "1.0.0", + IsBuiltin: false, + } + + _ = model.GraphSchemaExtension{ + Name: "test_extension_1", + DisplayName: "Test Extension 1", + Version: "2.0.0", + IsBuiltin: false, + } + + existingNodeKind1 = model.GraphSchemaNodeKind{ + Serial: model.Serial{ + ID: 1, + Basic: model.Basic{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + Name: "node_kind_1", + SchemaExtensionId: existingExtension1.ID, + DisplayName: "Node Kind 1", + Description: "a test node kind", + IsDisplayKind: true, + Icon: "desktop", + IconColor: "blue", + } + existingNodeKind2 = model.GraphSchemaNodeKind{ + Serial: model.Serial{ + ID: 2, + Basic: model.Basic{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + Name: "node_kind_2", + SchemaExtensionId: existingExtension1.ID, + DisplayName: "Node Kind 2", + Description: "a test node kind", + IsDisplayKind: true, + Icon: "user", + IconColor: "red", + } + newNodeKind1 = model.GraphSchemaNodeKind{ + Name: "new_node_kind_1", + DisplayName: "New Node Kind 1", + Description: "a test node kind", + IsDisplayKind: true, + Icon: "desktop", + IconColor: "blue", + } + newNodeKind2 = model.GraphSchemaNodeKind{ + Name: "new_node_kind_2", + DisplayName: "New Node Kind 2", + Description: "a test node kind", + IsDisplayKind: true, + Icon: "user", + IconColor: "green", + } + + existingEdgeKind1 = model.GraphSchemaEdgeKind{ + Serial: model.Serial{ + ID: 1, + Basic: model.Basic{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + SchemaExtensionId: existingExtension1.ID, + Name: "edge_kind_1", + Description: "a test edge kind", + IsTraversable: true, + } + existingEdgeKind2 = model.GraphSchemaEdgeKind{ + Serial: model.Serial{ + ID: 2, + Basic: model.Basic{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + SchemaExtensionId: existingExtension1.ID, + Name: "edge_kind_2", + Description: "a test edge kind", + IsTraversable: true, + } + newEdgeKind1 = model.GraphSchemaEdgeKind{ + Name: "new_edge_kind_1", + Description: "a test edge kind", + IsTraversable: true, + } + newEdgeKind2 = model.GraphSchemaEdgeKind{ + Name: "new_edge_kind_2", + Description: "a test edge kind", + IsTraversable: true, + } + existingProperty1 = model.GraphSchemaProperty{ + Serial: model.Serial{ + ID: 1, + Basic: model.Basic{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + SchemaExtensionId: existingExtension1.ID, + Name: "property_1", + DisplayName: "Property 1", + DataType: "string", + Description: "a test property", + } + existingProperty2 = model.GraphSchemaProperty{ + Serial: model.Serial{ + ID: 2, + Basic: model.Basic{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + }, + SchemaExtensionId: existingExtension1.ID, + Name: "property_2", + DisplayName: "Property 2", + DataType: "integer", + Description: "a test property", + } + newProperty1 = model.GraphSchemaProperty{ + Name: "property_1", + DisplayName: "Property 1", + DataType: "integer", + Description: "a test property", + } + newProperty2 = model.GraphSchemaProperty{ + Name: "property_2", + DisplayName: "Property 2", + DataType: "string", + Description: "a test property", + } + + _ = model.GraphSchema{ + GraphSchemaExtension: newExtension1, + GraphSchemaNodeKinds: model.GraphSchemaNodeKinds{newNodeKind1, newNodeKind2}, + GraphSchemaEdgeKinds: model.GraphSchemaEdgeKinds{newEdgeKind1, newEdgeKind2}, + GraphSchemaProperties: model.GraphSchemaProperties{newProperty1, newProperty2}, + } + _ = model.GraphSchema{ + GraphSchemaExtension: existingExtension1, + GraphSchemaNodeKinds: model.GraphSchemaNodeKinds{existingNodeKind1, existingNodeKind2}, + GraphSchemaEdgeKinds: model.GraphSchemaEdgeKinds{existingEdgeKind1, existingEdgeKind2}, + GraphSchemaProperties: model.GraphSchemaProperties{existingProperty1, existingProperty2}, + } + ) + + defer mockCtrl.Finish() + + type fields struct { + setupOpenGraphSchemaRepositoryMock func(t *testing.T, mock *mocks.MockOpenGraphSchemaRepository) + setupGraphDBKindsRepositoryMock func(t *testing.T, mock *mocks.MockGraphDBKindRepository) + } + type args struct { + ctx context.Context + graphSchema model.GraphSchema + } + + tests := []struct { + name string + fields fields + args args + wantErr error + wantUpdated bool + }{ + { + name: "fail - invalid graph schema", + fields: fields{ + setupOpenGraphSchemaRepositoryMock: func(t *testing.T, mock *mocks.MockOpenGraphSchemaRepository) {}, + setupGraphDBKindsRepositoryMock: func(t *testing.T, mock *mocks.MockGraphDBKindRepository) {}, + }, + args: args{ + ctx: context.Background(), + graphSchema: model.GraphSchema{}, + }, + wantErr: fmt.Errorf("validation error"), + wantUpdated: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.fields.setupOpenGraphSchemaRepositoryMock(t, mockOpenGraphSchemaRepository) + tt.fields.setupGraphDBKindsRepositoryMock(t, mockGraphDBKindsRepository) + + o := &OpenGraphSchemaService{ + openGraphSchemaRepository: mockOpenGraphSchemaRepository, + graphDBKindRepository: mockGraphDBKindsRepository, + } + updated, err := o.UpsertOpenGraphExtension(tt.args.ctx, tt.args.graphSchema) + if tt.wantErr != nil { + require.ErrorContains(t, err, tt.wantErr.Error(), "UpsertOpenGraphExtension() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantUpdated != updated { + require.Fail(t, "expected graph schema to be updated") + } + }) + } +} + +func Test_validateGraphSchemaModel(t *testing.T) { + type args struct { + graphSchema model.GraphSchema + } + tests := []struct { + name string + args args + wantErr require.ErrorAssertionFunc + }{ + { + name: "fail - empty extension name", + args: args{ + graphSchema: model.GraphSchema{}, + }, + wantErr: require.Error, + }, + { + name: "fail - empty graph schema nodes", + args: args{ + graphSchema: model.GraphSchema{ + GraphSchemaExtension: model.GraphSchemaExtension{ + Name: "Test extension", + }, + }, + }, + wantErr: require.Error, + }, + { + name: "success - valid model.GraphSchemaExtension", + args: args{ + graphSchema: model.GraphSchema{ + GraphSchemaExtension: model.GraphSchemaExtension{ + Name: "Test extension", + }, + GraphSchemaNodeKinds: model.GraphSchemaNodeKinds{{ + Name: "node kind 1", + SchemaExtensionId: 1, + }}, + }, + }, + wantErr: require.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.wantErr(t, validateGraphSchemaModel(tt.args.graphSchema), fmt.Sprintf("validateGraphSchemaModel(%v)", tt.args.graphSchema)) + }) + } +} diff --git a/packages/go/openapi/doc/openapi.json b/packages/go/openapi/doc/openapi.json index 62323d7c8f..58c86c355b 100644 --- a/packages/go/openapi/doc/openapi.json +++ b/packages/go/openapi/doc/openapi.json @@ -17433,6 +17433,42 @@ } } } + }, + "/api/v2/extensions": { + "put": { + "operationId": "UpsertOpenGraphExtension", + "summary": "Upserts the OpenGraph Extension", + "description": "Upserts the OpenGraph extension", + "tags": [ + "OpenGraph" + ], + "responses": { + "200": { + "description": "OK" + }, + "201": { + "description": "CREATED" + }, + "400": { + "$ref": "#/components/responses/bad-request" + }, + "401": { + "$ref": "#/components/responses/unauthorized" + }, + "403": { + "$ref": "#/components/responses/forbidden" + }, + "404": { + "$ref": "#/components/responses/not-found" + }, + "429": { + "$ref": "#/components/responses/too-many-requests" + }, + "500": { + "$ref": "#/components/responses/internal-server-error" + } + } + } } }, "components": { @@ -21526,7 +21562,8 @@ "Groups", "Data Quality", "Datapipe", - "Cypher" + "Cypher", + "OpenGraph" ] }, { diff --git a/packages/go/openapi/src/openapi.yaml b/packages/go/openapi/src/openapi.yaml index 190329da3b..21999aecf0 100644 --- a/packages/go/openapi/src/openapi.yaml +++ b/packages/go/openapi/src/openapi.yaml @@ -180,6 +180,7 @@ x-tagGroups: - Data Quality - Datapipe - Cypher + - OpenGraph - name: Enterprise Only tags: - EULA @@ -747,6 +748,10 @@ paths: /api/v2/meta/{object_id}: $ref: './paths/meta-entity.meta.id.yaml' + # open graph extensions + /api/v2/extensions: + $ref: './paths/opengraph.extension-upsert.yaml' + components: ## # SECURITY diff --git a/packages/go/openapi/src/paths/opengraph.extension-upsert.yaml b/packages/go/openapi/src/paths/opengraph.extension-upsert.yaml new file mode 100644 index 0000000000..0b6dfa06d2 --- /dev/null +++ b/packages/go/openapi/src/paths/opengraph.extension-upsert.yaml @@ -0,0 +1,39 @@ +# Copyright 2026 Specter Ops, Inc. +# +# Licensed under the Apache License, Version 2.0 +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +put: + operationId: UpsertOpenGraphExtension + summary: Upserts the OpenGraph Extension + description: Upserts the OpenGraph extension + tags: + - OpenGraph + responses: + 200: + description: OK + 201: + description: CREATED + 400: + $ref: './../responses/bad-request.yaml' + 401: + $ref: './../responses/unauthorized.yaml' + 403: + $ref: './../responses/forbidden.yaml' + 404: + $ref: './../responses/not-found.yaml' + 429: + $ref: './../responses/too-many-requests.yaml' + 500: + $ref: './../responses/internal-server-error.yaml'