From 2db9e3e1faaa3bb1e66e5e090a4e6fcc1153b239 Mon Sep 17 00:00:00 2001 From: Conrad Weidenkeller Date: Thu, 15 Jan 2026 19:26:23 -0600 Subject: [PATCH] feat(db): Add GetRelationshipFindingByName BED-6893 --- cmd/api/src/database/graphschema.go | 32 ++++++-- .../database/graphschema_integration_test.go | 75 +++++++++++++++++++ cmd/api/src/database/mocks/db.go | 15 ++++ 3 files changed, 115 insertions(+), 7 deletions(-) diff --git a/cmd/api/src/database/graphschema.go b/cmd/api/src/database/graphschema.go index 03d5e53014..e54926b3c0 100644 --- a/cmd/api/src/database/graphschema.go +++ b/cmd/api/src/database/graphschema.go @@ -59,6 +59,7 @@ type OpenGraphSchema interface { CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error) GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error) + GetSchemaRelationshipFindingByName(ctx context.Context, findingName string) (model.SchemaRelationshipFinding, error) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error CreateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error) @@ -239,10 +240,10 @@ func (s *BloodhoundDB) GetGraphSchemaNodeKinds(ctx context.Context, filters mode if filterAndPagination, err := parseFiltersAndPagination(filters, sort, skip, limit); err != nil { return schemaNodeKinds, 0, err } else { - sqlStr := fmt.Sprintf(`SELECT nk.id, k.name, nk.schema_extension_id, nk.display_name, nk.description, + sqlStr := fmt.Sprintf(`SELECT nk.id, k.name, nk.schema_extension_id, nk.display_name, nk.description, nk.is_display_kind, nk.icon, nk.icon_color, nk.created_at, nk.updated_at, nk.deleted_at FROM %s nk - JOIN %s k ON nk.kind_id = k.id + JOIN %s k ON nk.kind_id = k.id %s %s %s`, model.GraphSchemaNodeKind{}.TableName(), kindTable, filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit) if result := s.db.WithContext(ctx).Raw(sqlStr, filterAndPagination.Filter.params...).Scan(&schemaNodeKinds); result.Error != nil { @@ -288,7 +289,7 @@ func (s *BloodhoundDB) UpdateGraphSchemaNodeKind(ctx context.Context, schemaNode RETURNING id, kind_id, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at ) SELECT updated_row.id, %s.name, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at - FROM updated_row + FROM updated_row JOIN %s ON %s.id = updated_row.kind_id`, schemaNodeKind.TableName(), kindTable, kindTable, kindTable), schemaNodeKind.SchemaExtensionId, schemaNodeKind.DisplayName, schemaNodeKind.Description, schemaNodeKind.IsDisplayKind, schemaNodeKind.Icon, @@ -436,7 +437,7 @@ func (s *BloodhoundDB) CreateGraphSchemaEdgeKind(ctx context.Context, name strin RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at ) SELECT ie.id, ie.schema_extension_id, dk.name, ie.description, ie.is_traversable, ie.created_at, ie.updated_at, ie.deleted_at - FROM inserted_edges ie + FROM inserted_edges ie JOIN dawgs_kind dk ON ie.kind_id = dk.id;`, name, schemaExtensionId, description, isTraversable).Scan(&schemaEdgeKind); result.Error != nil { if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) { return schemaEdgeKind, fmt.Errorf("%w: %v", ErrDuplicateSchemaEdgeKindName, result.Error) @@ -457,10 +458,10 @@ func (s *BloodhoundDB) GetGraphSchemaEdgeKinds(ctx context.Context, edgeKindFilt if filterAndPagination, err := parseFiltersAndPagination(edgeKindFilters, sort, skip, limit); err != nil { return schemaEdgeKinds, 0, err } else { - sqlStr := fmt.Sprintf(`SELECT ek.id, k.name, ek.schema_extension_id, ek.description, ek.is_traversable, + sqlStr := fmt.Sprintf(`SELECT ek.id, k.name, ek.schema_extension_id, ek.description, ek.is_traversable, ek.created_at, ek.updated_at, ek.deleted_at FROM %s ek - JOIN %s k ON ek.kind_id = k.id + JOIN %s k ON ek.kind_id = k.id %s %s %s`, model.GraphSchemaEdgeKind{}.TableName(), kindTable, filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit) @@ -542,7 +543,7 @@ func (s *BloodhoundDB) UpdateGraphSchemaEdgeKind(ctx context.Context, schemaEdge SET schema_extension_id = ?, description = ?, is_traversable = ?, updated_at = NOW() WHERE id = ? RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at - ) + ) SELECT updated_row.id, %s.name, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at FROM updated_row JOIN %s ON %s.id = updated_row.kind_id`, @@ -659,6 +660,23 @@ func (s *BloodhoundDB) GetSchemaRelationshipFindingById(ctx context.Context, fin return finding, nil } +// GetSchemaRelationshipFindingByName - retrieves a schema relationship finding by name. +func (s *BloodhoundDB) GetSchemaRelationshipFindingByName(ctx context.Context, findingName string) (model.SchemaRelationshipFinding, error) { + var finding model.SchemaRelationshipFinding + + if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(` + SELECT id, schema_extension_id, relationship_kind_id, environment_id, name, display_name, created_at + FROM %s WHERE name = ?`, + finding.TableName()), + findingName).Scan(&finding); result.Error != nil { + return model.SchemaRelationshipFinding{}, CheckError(result) + } else if result.RowsAffected == 0 { + return model.SchemaRelationshipFinding{}, ErrNotFound + } + + return finding, nil +} + // DeleteSchemaRelationshipFinding - deletes a schema relationship finding by id. func (s *BloodhoundDB) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error { var finding model.SchemaRelationshipFinding diff --git a/cmd/api/src/database/graphschema_integration_test.go b/cmd/api/src/database/graphschema_integration_test.go index 28db314d6a..646f4e86d5 100644 --- a/cmd/api/src/database/graphschema_integration_test.go +++ b/cmd/api/src/database/graphschema_integration_test.go @@ -1776,6 +1776,81 @@ func TestGetSchemaRelationshipFindingById(t *testing.T) { } } +func TestGetSchemaRelationshipFindingByName(t *testing.T) { + type args struct { + name string + } + type want struct { + res model.SchemaRelationshipFinding + err error + } + tests := []struct { + name string + setup func() IntegrationTestSuite + args args + want want + }{ + { + name: "Success: get schema relationship finding by name", + setup: func() IntegrationTestSuite { + t.Helper() + testSuite := setupIntegrationTestSuite(t) + + _, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetFindingExt", "Get Finding Extension", "v1.0.0") + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1) + require.NoError(t, err) + + _, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "GetByNameFinding", "Get By Name Finding") + require.NoError(t, err) + + return testSuite + }, + args: args{ + name: "GetByNameFinding", + }, + want: want{ + res: model.SchemaRelationshipFinding{ + ID: 1, + SchemaExtensionId: 1, + RelationshipKindId: 1, + EnvironmentId: 1, + Name: "GetByNameFinding", + DisplayName: "Get By Name Finding", + }, + }, + }, + { + name: "Fail: schema relationship finding not found", + setup: func() IntegrationTestSuite { + return setupIntegrationTestSuite(t) + }, + args: args{ + name: "NotFound", + }, + want: want{ + err: database.ErrNotFound, + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + testSuite := testCase.setup() + defer teardownIntegrationTestSuite(t, &testSuite) + + got, err := testSuite.BHDatabase.GetSchemaRelationshipFindingByName(testSuite.Context, testCase.args.name) + if testCase.want.err != nil { + assert.ErrorIs(t, err, testCase.want.err) + } else { + got.CreatedAt = time.Time{} + assert.Equal(t, testCase.want.res, got) + assert.NoError(t, err) + } + }) + } +} + func TestDeleteSchemaRelationshipFinding(t *testing.T) { type args struct { findingId int32 diff --git a/cmd/api/src/database/mocks/db.go b/cmd/api/src/database/mocks/db.go index cf16fece30..dc43e0d1ca 100644 --- a/cmd/api/src/database/mocks/db.go +++ b/cmd/api/src/database/mocks/db.go @@ -2232,6 +2232,21 @@ func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingById(ctx, findin return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingById", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingById), ctx, findingId) } +// GetSchemaRelationshipFindingByName mocks base method. +func (m *MockDatabase) GetSchemaRelationshipFindingByName(ctx context.Context, findingName string) (model.SchemaRelationshipFinding, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSchemaRelationshipFindingByName", ctx, findingName) + ret0, _ := ret[0].(model.SchemaRelationshipFinding) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSchemaRelationshipFindingByName indicates an expected call of GetSchemaRelationshipFindingByName. +func (mr *MockDatabaseMockRecorder) GetSchemaRelationshipFindingByName(ctx, findingName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSchemaRelationshipFindingByName", reflect.TypeOf((*MockDatabase)(nil).GetSchemaRelationshipFindingByName), ctx, findingName) +} + // GetScopeForSavedQuery mocks base method. func (m *MockDatabase) GetScopeForSavedQuery(ctx context.Context, queryID int64, userID uuid.UUID) (database.SavedQueryScopeMap, error) { m.ctrl.T.Helper()