Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions cmd/api/src/database/graphschema.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type OpenGraphSchema interface {

CreateSchemaRelationshipFinding(ctx context.Context, extensionId int32, relationshipKindId int32, environmentId int32, name string, displayName string) (model.SchemaRelationshipFinding, error)
GetSchemaRelationshipFindingById(ctx context.Context, findingId int32) (model.SchemaRelationshipFinding, error)
GetSchemaRelationshipFindingByName(ctx context.Context, findingName string) (model.SchemaRelationshipFinding, error)
DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error

CreateRemediation(ctx context.Context, findingId int32, shortDescription string, longDescription string, shortRemediation string, longRemediation string) (model.Remediation, error)
Expand Down Expand Up @@ -239,10 +240,10 @@ func (s *BloodhoundDB) GetGraphSchemaNodeKinds(ctx context.Context, filters mode
if filterAndPagination, err := parseFiltersAndPagination(filters, sort, skip, limit); err != nil {
return schemaNodeKinds, 0, err
} else {
sqlStr := fmt.Sprintf(`SELECT nk.id, k.name, nk.schema_extension_id, nk.display_name, nk.description,
sqlStr := fmt.Sprintf(`SELECT nk.id, k.name, nk.schema_extension_id, nk.display_name, nk.description,
nk.is_display_kind, nk.icon, nk.icon_color, nk.created_at, nk.updated_at, nk.deleted_at
FROM %s nk
JOIN %s k ON nk.kind_id = k.id
JOIN %s k ON nk.kind_id = k.id
%s %s %s`,
model.GraphSchemaNodeKind{}.TableName(), kindTable, filterAndPagination.WhereClause, filterAndPagination.OrderSql, filterAndPagination.SkipLimit)
if result := s.db.WithContext(ctx).Raw(sqlStr, filterAndPagination.Filter.params...).Scan(&schemaNodeKinds); result.Error != nil {
Expand Down Expand Up @@ -288,7 +289,7 @@ func (s *BloodhoundDB) UpdateGraphSchemaNodeKind(ctx context.Context, schemaNode
RETURNING id, kind_id, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at
)
SELECT updated_row.id, %s.name, schema_extension_id, display_name, description, is_display_kind, icon, icon_color, created_at, updated_at, deleted_at
FROM updated_row
FROM updated_row
JOIN %s ON %s.id = updated_row.kind_id`,
schemaNodeKind.TableName(), kindTable, kindTable, kindTable), schemaNodeKind.SchemaExtensionId,
schemaNodeKind.DisplayName, schemaNodeKind.Description, schemaNodeKind.IsDisplayKind, schemaNodeKind.Icon,
Expand Down Expand Up @@ -436,7 +437,7 @@ func (s *BloodhoundDB) CreateGraphSchemaEdgeKind(ctx context.Context, name strin
RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at
)
SELECT ie.id, ie.schema_extension_id, dk.name, ie.description, ie.is_traversable, ie.created_at, ie.updated_at, ie.deleted_at
FROM inserted_edges ie
FROM inserted_edges ie
JOIN dawgs_kind dk ON ie.kind_id = dk.id;`, name, schemaExtensionId, description, isTraversable).Scan(&schemaEdgeKind); result.Error != nil {
if strings.Contains(result.Error.Error(), DuplicateKeyValueErrorString) {
return schemaEdgeKind, fmt.Errorf("%w: %v", ErrDuplicateSchemaEdgeKindName, result.Error)
Expand All @@ -457,10 +458,10 @@ func (s *BloodhoundDB) GetGraphSchemaEdgeKinds(ctx context.Context, edgeKindFilt
if filterAndPagination, err := parseFiltersAndPagination(edgeKindFilters, sort, skip, limit); err != nil {
return schemaEdgeKinds, 0, err
} else {
sqlStr := fmt.Sprintf(`SELECT ek.id, k.name, ek.schema_extension_id, ek.description, ek.is_traversable,
sqlStr := fmt.Sprintf(`SELECT ek.id, k.name, ek.schema_extension_id, ek.description, ek.is_traversable,
ek.created_at, ek.updated_at, ek.deleted_at
FROM %s ek
JOIN %s k ON ek.kind_id = k.id
JOIN %s k ON ek.kind_id = k.id
%s %s %s`,
model.GraphSchemaEdgeKind{}.TableName(), kindTable, filterAndPagination.WhereClause,
filterAndPagination.OrderSql, filterAndPagination.SkipLimit)
Expand Down Expand Up @@ -542,7 +543,7 @@ func (s *BloodhoundDB) UpdateGraphSchemaEdgeKind(ctx context.Context, schemaEdge
SET schema_extension_id = ?, description = ?, is_traversable = ?, updated_at = NOW()
WHERE id = ?
RETURNING id, kind_id, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at
)
)
SELECT updated_row.id, %s.name, schema_extension_id, description, is_traversable, created_at, updated_at, deleted_at
FROM updated_row
JOIN %s ON %s.id = updated_row.kind_id`,
Expand Down Expand Up @@ -659,6 +660,23 @@ func (s *BloodhoundDB) GetSchemaRelationshipFindingById(ctx context.Context, fin
return finding, nil
}

// GetSchemaRelationshipFindingByName - retrieves a schema relationship finding by name.
func (s *BloodhoundDB) GetSchemaRelationshipFindingByName(ctx context.Context, findingName string) (model.SchemaRelationshipFinding, error) {
var finding model.SchemaRelationshipFinding

if result := s.db.WithContext(ctx).Raw(fmt.Sprintf(`
SELECT id, schema_extension_id, relationship_kind_id, environment_id, name, display_name, created_at
FROM %s WHERE name = ?`,
finding.TableName()),
findingName).Scan(&finding); result.Error != nil {
return model.SchemaRelationshipFinding{}, CheckError(result)
} else if result.RowsAffected == 0 {
return model.SchemaRelationshipFinding{}, ErrNotFound
}

return finding, nil
}

// DeleteSchemaRelationshipFinding - deletes a schema relationship finding by id.
func (s *BloodhoundDB) DeleteSchemaRelationshipFinding(ctx context.Context, findingId int32) error {
var finding model.SchemaRelationshipFinding
Expand Down
75 changes: 75 additions & 0 deletions cmd/api/src/database/graphschema_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1776,6 +1776,81 @@ func TestGetSchemaRelationshipFindingById(t *testing.T) {
}
}

func TestGetSchemaRelationshipFindingByName(t *testing.T) {
type args struct {
name string
}
type want struct {
res model.SchemaRelationshipFinding
err error
}
tests := []struct {
name string
setup func() IntegrationTestSuite
args args
want want
}{
{
name: "Success: get schema relationship finding by name",
setup: func() IntegrationTestSuite {
t.Helper()
testSuite := setupIntegrationTestSuite(t)

_, err := testSuite.BHDatabase.CreateGraphSchemaExtension(testSuite.Context, "GetFindingExt", "Get Finding Extension", "v1.0.0")
require.NoError(t, err)

_, err = testSuite.BHDatabase.CreateSchemaEnvironment(testSuite.Context, 1, 1, 1)
require.NoError(t, err)

_, err = testSuite.BHDatabase.CreateSchemaRelationshipFinding(testSuite.Context, 1, 1, 1, "GetByNameFinding", "Get By Name Finding")
require.NoError(t, err)

return testSuite
},
args: args{
name: "GetByNameFinding",
},
want: want{
res: model.SchemaRelationshipFinding{
ID: 1,
SchemaExtensionId: 1,
RelationshipKindId: 1,
EnvironmentId: 1,
Name: "GetByNameFinding",
DisplayName: "Get By Name Finding",
},
},
},
{
name: "Fail: schema relationship finding not found",
setup: func() IntegrationTestSuite {
return setupIntegrationTestSuite(t)
},
args: args{
name: "NotFound",
},
want: want{
err: database.ErrNotFound,
},
},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
testSuite := testCase.setup()
defer teardownIntegrationTestSuite(t, &testSuite)

got, err := testSuite.BHDatabase.GetSchemaRelationshipFindingByName(testSuite.Context, testCase.args.name)
if testCase.want.err != nil {
assert.ErrorIs(t, err, testCase.want.err)
} else {
got.CreatedAt = time.Time{}
assert.Equal(t, testCase.want.res, got)
assert.NoError(t, err)
}
})
}
}

func TestDeleteSchemaRelationshipFinding(t *testing.T) {
type args struct {
findingId int32
Expand Down
15 changes: 15 additions & 0 deletions cmd/api/src/database/mocks/db.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.