Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 32 additions & 21 deletions pkg/rid/store/cockroach/identification_service_area.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ func (r *repo) fetchISAs(ctx context.Context, query string, args ...interface{})
for rows.Next() {
i := new(ridmodels.IdentificationServiceArea)

var updateTime time.Time
var (
updateTime time.Time
count int
)

err := rows.Scan(
&i.ID,
Expand All @@ -46,10 +49,14 @@ func (r *repo) fetchISAs(ctx context.Context, query string, args ...interface{})
&i.EndTime,
&writer,
&updateTime,
&count,
)
if err != nil {
return nil, stacktrace.Propagate(err, "Error scanning ISA row")
}
if count > dssmodels.MaxResultLimit {
return nil, stacktrace.NewError("Query returned %d ISAs which exceeds the maximum allowed %d", count, dssmodels.MaxResultLimit)
}
i.Writer = writer.String
i.SetCells(cids)
i.Version = dssmodels.VersionFromTime(updateTime)
Expand Down Expand Up @@ -80,7 +87,9 @@ func (r *repo) fetchISA(ctx context.Context, query string, args ...interface{})
// Returns nil, nil if not found
func (r *repo) GetISA(ctx context.Context, id dssmodels.ID, forUpdate bool) (*ridmodels.IdentificationServiceArea, error) {
var query = fmt.Sprintf(`
SELECT %s FROM
SELECT
%s,1 -- placeholder for count
FROM
identification_service_areas
WHERE
id = $1
Expand Down Expand Up @@ -108,7 +117,8 @@ func (r *repo) InsertISA(ctx context.Context, isa *ridmodels.IdentificationServi
VALUES
($1, $2, $3, $4, $5, $6, $7, transaction_timestamp())
RETURNING
%s`, isaFields, isaFields)
%s,1 -- placeholder for count
`, isaFields, isaFields)
)

cids := make([]int64, len(isa.Cells))
Expand Down Expand Up @@ -143,7 +153,8 @@ func (r *repo) UpdateISA(ctx context.Context, isa *ridmodels.IdentificationServi
SET (%s) = ($1, $2, $3, $4, $5, $7, transaction_timestamp())
WHERE id = $1 AND updated_at = $6
RETURNING
%s`, updateISAFields, isaFields)
%s,1 -- placeholder for count
`, updateISAFields, isaFields)
)

cids, err := dssql.CellUnionToCellIdsWithValidation(isa.Cells)
Expand All @@ -169,7 +180,9 @@ func (r *repo) DeleteISA(ctx context.Context, isa *ridmodels.IdentificationServi
id = $1
AND
updated_at = $2
RETURNING %s`, isaFields)
RETURNING
%s,1 -- placeholder for count
`, isaFields)
)
id, err := isa.ID.PgUUID()
if err != nil {
Expand All @@ -187,7 +200,7 @@ func (r *repo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *tim
// Make them real values (not pointers), on the model layer.
isasInCellsQuery = fmt.Sprintf(`
SELECT
%s
%s, COUNT(*) OVER() -- placeholder for count
FROM
identification_service_areas
WHERE
Expand All @@ -196,7 +209,7 @@ func (r *repo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *tim
COALESCE(starts_at <= $2, true)
AND
cells && $3
LIMIT $4`, isaFields)
`, isaFields)
)

if len(cells) == 0 {
Expand All @@ -207,7 +220,7 @@ func (r *repo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *tim
return nil, stacktrace.NewError("Earliest start time is missing")
}

return r.fetchISAs(ctx, isasInCellsQuery, earliest, latest, dssql.CellUnionToCellIds(cells), dssmodels.MaxResultLimit)
return r.fetchISAs(ctx, isasInCellsQuery, earliest, latest, dssql.CellUnionToCellIds(cells))
}

// ListExpiredISAs lists all expired ISAs based on writer.
Expand All @@ -219,18 +232,16 @@ func (r *repo) ListExpiredISAs(ctx context.Context, writer string) ([]*ridmodels
writerQuery = "'' OR writer = NULL"
}

var (
isasInCellsQuery = fmt.Sprintf(`
SELECT
%s
FROM
identification_service_areas
WHERE
ends_at + INTERVAL '%d' MINUTE <= CURRENT_TIMESTAMP
AND
(writer = %s)
LIMIT $1`, isaFields, expiredDurationInMin, writerQuery)
)
var isasInCellsQuery = fmt.Sprintf(`
SELECT
%s, COUNT(*) OVER() -- placeholder for count
FROM
identification_service_areas
WHERE
ends_at + INTERVAL '%d' MINUTE <= CURRENT_TIMESTAMP
AND
(writer = %s)
`, isaFields, expiredDurationInMin, writerQuery)

return r.fetchISAs(ctx, isasInCellsQuery, dssmodels.MaxResultLimit)
return r.fetchISAs(ctx, isasInCellsQuery)
}
39 changes: 25 additions & 14 deletions pkg/rid/store/cockroach/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ func (r *repo) process(ctx context.Context, query string, args ...interface{}) (
for rows.Next() {
s := new(ridmodels.Subscription)

var updateTime time.Time
var (
updateTime time.Time
count int
)

err := rows.Scan(
&s.ID,
Expand All @@ -47,12 +50,15 @@ func (r *repo) process(ctx context.Context, query string, args ...interface{}) (
&s.EndTime,
&writer,
&updateTime,
&count,
)
if err != nil {
return nil, stacktrace.Propagate(err, "Error scanning Subscription row")
}
if count > dssmodels.MaxResultLimit {
return nil, stacktrace.NewError("Query returned %d subscriptions which exceeds the maximum allowed %d", count, dssmodels.MaxResultLimit)
}
s.Writer = writer.String

s.SetCells(cids)
s.Version = dssmodels.VersionFromTime(updateTime)
payload = append(payload, s)
Expand Down Expand Up @@ -113,7 +119,8 @@ func (r *repo) MaxSubscriptionCountInCellsByOwner(ctx context.Context, cells s2.
func (r *repo) GetSubscription(ctx context.Context, id dssmodels.ID) (*ridmodels.Subscription, error) {
// TODO(steeling) we should enforce startTime and endTime to not be null at the DB level.
var query = fmt.Sprintf(`
SELECT %s FROM subscriptions
SELECT %s, COUNT(*) OVER() -- placeholder for count
FROM subscriptions
WHERE id = $1`, subscriptionFields)
uid, err := id.PgUUID()
if err != nil {
Expand All @@ -132,7 +139,8 @@ func (r *repo) UpdateSubscription(ctx context.Context, s *ridmodels.Subscription
SET (%s) = ($1, $2, $3, $4, $5, $6, $7, transaction_timestamp())
WHERE id = $1 AND updated_at = $8
RETURNING
%s`, updateSubscriptionFields, subscriptionFields)
%s,1 -- placeholder for count
`, updateSubscriptionFields, subscriptionFields)
)

cids, err := dssql.CellUnionToCellIdsWithValidation(s.Cells)
Expand Down Expand Up @@ -167,7 +175,8 @@ func (r *repo) InsertSubscription(ctx context.Context, s *ridmodels.Subscription
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, transaction_timestamp())
RETURNING
%s`, subscriptionFields, subscriptionFields)
%s,1 -- placeholder for count
`, subscriptionFields, subscriptionFields)
)

cids, err := dssql.CellUnionToCellIdsWithValidation(s.Cells)
Expand Down Expand Up @@ -202,7 +211,8 @@ func (r *repo) DeleteSubscription(ctx context.Context, s *ridmodels.Subscription
WHERE
id = $1
AND updated_at = $2
RETURNING %s`, subscriptionFields)
RETURNING %s,1 -- placeholder for count
`, subscriptionFields)
)
id, err := s.ID.PgUUID()
if err != nil {
Expand All @@ -219,7 +229,8 @@ func (r *repo) UpdateNotificationIdxsInCells(ctx context.Context, cells s2.CellU
WHERE
cells && $1
AND ends_at >= $2
RETURNING %s`, subscriptionFields)
RETURNING %s,1 -- placeholder for count
`, subscriptionFields)

return r.process(
ctx, updateQuery, dssql.CellUnionToCellIds(cells), r.clock.Now())
Expand All @@ -230,29 +241,29 @@ func (r *repo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]*
var (
query = fmt.Sprintf(`
SELECT
%s
%s, COUNT(*) OVER() -- placeholder for count
FROM
subscriptions
WHERE
cells && $1
AND
ends_at >= $2
LIMIT $3`, subscriptionFields)
`, subscriptionFields)
)

if len(cells) == 0 {
return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided")
}

return r.process(ctx, query, dssql.CellUnionToCellIds(cells), r.clock.Now(), dssmodels.MaxResultLimit)
return r.process(ctx, query, dssql.CellUnionToCellIds(cells), r.clock.Now())
}

// SearchSubscriptionsByOwner returns all subscriptions in "cells".
func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnion, owner dssmodels.Owner) ([]*ridmodels.Subscription, error) {
var (
query = fmt.Sprintf(`
SELECT
%s
%s, COUNT(*) OVER() -- placeholder for count
FROM
subscriptions
WHERE
Expand All @@ -261,14 +272,14 @@ func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnio
subscriptions.owner = $2
AND
ends_at >= $3
LIMIT $4`, subscriptionFields)
`, subscriptionFields)
)

if len(cells) == 0 {
return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided")
}

return r.process(ctx, query, dssql.CellUnionToCellIds(cells), owner, r.clock.Now(), dssmodels.MaxResultLimit)
return r.process(ctx, query, dssql.CellUnionToCellIds(cells), owner, r.clock.Now())
}

// ListExpiredSubscriptions lists all expired Subscriptions based on writer.
Expand All @@ -283,7 +294,7 @@ func (r *repo) ListExpiredSubscriptions(ctx context.Context, writer string) ([]*
var (
query = fmt.Sprintf(`
SELECT
%s
%s, COUNT(*) OVER() -- placeholder for count
FROM
subscriptions
WHERE
Expand Down
20 changes: 11 additions & 9 deletions pkg/scd/store/cockroach/constraints.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func (c *repo) fetchConstraints(ctx context.Context, q dsssql.Queryable, query s
var (
c = new(scdmodels.Constraint)
updatedAt time.Time
count int
)
err := rows.Scan(
&c.ID,
Expand All @@ -77,10 +78,14 @@ func (c *repo) fetchConstraints(ctx context.Context, q dsssql.Queryable, query s
&c.EndTime,
&cids,
&updatedAt,
&count,
)
if err != nil {
return nil, stacktrace.Propagate(err, "Error scanning Constraint row")
}
if count > dssmodels.MaxResultLimit {
return nil, stacktrace.NewError("Query returned %d Constraints which exceeds the maximum allowed %d", count, dssmodels.MaxResultLimit)
}
c.Cells = geo.CellUnionFromInt64(cids)
c.OVN = scdmodels.NewOVNFromTime(updatedAt, c.ID.String())
payload = append(payload, c)
Expand Down Expand Up @@ -110,7 +115,7 @@ func (c *repo) GetConstraint(ctx context.Context, id dssmodels.ID) (*scdmodels.C
var (
query = fmt.Sprintf(`
SELECT
%s
%s, COUNT(*) OVER() -- placeholder for count
FROM
scd_constraints
WHERE
Expand Down Expand Up @@ -142,11 +147,9 @@ func (c *repo) UpsertConstraint(ctx context.Context, s *scdmodels.Constraint) (*
ends_at = $8,
cells = $9,
updated_at = transaction_timestamp()
RETURNING %s
`,
constraintFieldsWithoutPrefix,
constraintFieldsWithPrefix,
)
RETURNING
%s,1 -- placeholder for count
`, constraintFieldsWithoutPrefix, constraintFieldsWithPrefix)
)

cids, err := dsssql.CellUnionToCellIdsWithValidation(s.Cells)
Expand Down Expand Up @@ -207,7 +210,7 @@ func (c *repo) SearchConstraints(ctx context.Context, v4d *dssmodels.Volume4D) (
var (
query = fmt.Sprintf(`
SELECT
%s
%s, COUNT(*) OVER() -- placeholder for count
FROM
scd_constraints
WHERE
Expand All @@ -216,7 +219,6 @@ func (c *repo) SearchConstraints(ctx context.Context, v4d *dssmodels.Volume4D) (
COALESCE(starts_at <= $3, true)
AND
COALESCE(ends_at >= $2, true)
LIMIT $4
`, constraintFieldsWithoutPrefix)
)

Expand All @@ -232,7 +234,7 @@ func (c *repo) SearchConstraints(ctx context.Context, v4d *dssmodels.Volume4D) (
}

constraints, err := c.fetchConstraints(
ctx, c.q, query, dsssql.CellUnionToCellIds(cells), v4d.StartTime, v4d.EndTime, dssmodels.MaxResultLimit)
ctx, c.q, query, dsssql.CellUnionToCellIds(cells), v4d.StartTime, v4d.EndTime)
if err != nil {
return nil, stacktrace.Propagate(err, "Error fetching Constraints")
}
Expand Down
Loading