diff --git a/pkg/models/models.go b/pkg/models/models.go index 7ec4c91e5..ebaedcdc6 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -33,11 +33,12 @@ const ( // 32 used to be the highest value allowed by strconv. The new value is 36, // although changes to this will result in RMW errors. versionBase = 32 - - // Set a max limit for the SELECT query result - MaxResultLimit = 10000 ) +// Set a max limit for looping on results +// NOTE: var so it can be overridden in unit tests +var MaxResultLimit = 10000 + // PgUUID converts an ID to a pgtype.UUID. // If the ID this is called on is nil, nil will be returned func (id *ID) PgUUID() (*pgtype.UUID, error) { diff --git a/pkg/rid/store/cockroach/identification_service_area.go b/pkg/rid/store/cockroach/identification_service_area.go index 75f423d17..f44f75382 100644 --- a/pkg/rid/store/cockroach/identification_service_area.go +++ b/pkg/rid/store/cockroach/identification_service_area.go @@ -32,6 +32,7 @@ func (r *repo) fetchISAs(ctx context.Context, query string, args ...interface{}) var cids []int64 var writer pgtype.Text + var count int for rows.Next() { i := new(ridmodels.IdentificationServiceArea) @@ -50,6 +51,10 @@ func (r *repo) fetchISAs(ctx context.Context, query string, args ...interface{}) if err != nil { return nil, stacktrace.Propagate(err, "Error scanning ISA row") } + count++ + if count > dssmodels.MaxResultLimit { + return nil, stacktrace.NewError("Result set exceeded max limit of %d", dssmodels.MaxResultLimit) + } i.Writer = writer.String i.SetCells(cids) i.Version = dssmodels.VersionFromTime(updateTime) @@ -196,7 +201,7 @@ func (r *repo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *tim COALESCE(starts_at <= $2, true) AND cells && $3 - LIMIT $4`, isaFields) + `, isaFields) ) if len(cells) == 0 { @@ -207,7 +212,7 @@ func (r *repo) SearchISAs(ctx context.Context, cells s2.CellUnion, earliest *tim return nil, stacktrace.NewError("Earliest start time is missing") } - return r.fetchISAs(ctx, isasInCellsQuery, earliest, latest, dssql.CellUnionToCellIds(cells), dssmodels.MaxResultLimit) + return r.fetchISAs(ctx, isasInCellsQuery, earliest, latest, dssql.CellUnionToCellIds(cells)) } // ListExpiredISAs lists all expired ISAs based on writer. @@ -229,8 +234,8 @@ func (r *repo) ListExpiredISAs(ctx context.Context, writer string) ([]*ridmodels ends_at + INTERVAL '%d' MINUTE <= CURRENT_TIMESTAMP AND (writer = %s) - LIMIT $1`, isaFields, expiredDurationInMin, writerQuery) + `, isaFields, expiredDurationInMin, writerQuery) ) - return r.fetchISAs(ctx, isasInCellsQuery, dssmodels.MaxResultLimit) + return r.fetchISAs(ctx, isasInCellsQuery) } diff --git a/pkg/rid/store/cockroach/subscriptions.go b/pkg/rid/store/cockroach/subscriptions.go index ab148966c..bafe62b85 100644 --- a/pkg/rid/store/cockroach/subscriptions.go +++ b/pkg/rid/store/cockroach/subscriptions.go @@ -32,6 +32,7 @@ func (r *repo) process(ctx context.Context, query string, args ...interface{}) ( var cids []int64 var writer pgtype.Text + var count int for rows.Next() { s := new(ridmodels.Subscription) @@ -51,8 +52,11 @@ func (r *repo) process(ctx context.Context, query string, args ...interface{}) ( if err != nil { return nil, stacktrace.Propagate(err, "Error scanning Subscription row") } + count++ + if count > dssmodels.MaxResultLimit { + return nil, stacktrace.NewError("Result set exceeded max limit of %d", dssmodels.MaxResultLimit) + } s.Writer = writer.String - s.SetCells(cids) s.Version = dssmodels.VersionFromTime(updateTime) payload = append(payload, s) @@ -237,14 +241,14 @@ func (r *repo) SearchSubscriptions(ctx context.Context, cells s2.CellUnion) ([]* cells && $1 AND ends_at >= $2 - LIMIT $3`, subscriptionFields) + `, subscriptionFields) ) if len(cells) == 0 { return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") } - return r.process(ctx, query, dssql.CellUnionToCellIds(cells), r.clock.Now(), dssmodels.MaxResultLimit) + return r.process(ctx, query, dssql.CellUnionToCellIds(cells), r.clock.Now()) } // SearchSubscriptionsByOwner returns all subscriptions in "cells". @@ -261,14 +265,14 @@ func (r *repo) SearchSubscriptionsByOwner(ctx context.Context, cells s2.CellUnio subscriptions.owner = $2 AND ends_at >= $3 - LIMIT $4`, subscriptionFields) + `, subscriptionFields) ) if len(cells) == 0 { return nil, stacktrace.NewErrorWithCode(dsserr.BadRequest, "no location provided") } - return r.process(ctx, query, dssql.CellUnionToCellIds(cells), owner, r.clock.Now(), dssmodels.MaxResultLimit) + return r.process(ctx, query, dssql.CellUnionToCellIds(cells), owner, r.clock.Now()) } // ListExpiredSubscriptions lists all expired Subscriptions based on writer. diff --git a/pkg/scd/store/cockroach/constraints.go b/pkg/scd/store/cockroach/constraints.go index 5fcc68aa2..115af8c8e 100644 --- a/pkg/scd/store/cockroach/constraints.go +++ b/pkg/scd/store/cockroach/constraints.go @@ -61,6 +61,7 @@ func (c *repo) fetchConstraints(ctx context.Context, q dsssql.Queryable, query s var payload []*scdmodels.Constraint var cids []int64 + var count int for rows.Next() { var ( c = new(scdmodels.Constraint) @@ -81,6 +82,10 @@ func (c *repo) fetchConstraints(ctx context.Context, q dsssql.Queryable, query s if err != nil { return nil, stacktrace.Propagate(err, "Error scanning Constraint row") } + count++ + if count > dssmodels.MaxResultLimit { + return nil, stacktrace.NewError("Result set exceeded max limit of %d", dssmodels.MaxResultLimit) + } c.Cells = geo.CellUnionFromInt64(cids) c.OVN = scdmodels.NewOVNFromTime(updatedAt, c.ID.String()) payload = append(payload, c) @@ -216,7 +221,6 @@ func (c *repo) SearchConstraints(ctx context.Context, v4d *dssmodels.Volume4D) ( COALESCE(starts_at <= $3, true) AND COALESCE(ends_at >= $2, true) - LIMIT $4 `, constraintFieldsWithoutPrefix) ) @@ -232,7 +236,7 @@ func (c *repo) SearchConstraints(ctx context.Context, v4d *dssmodels.Volume4D) ( } constraints, err := c.fetchConstraints( - ctx, c.q, query, dsssql.CellUnionToCellIds(cells), v4d.StartTime, v4d.EndTime, dssmodels.MaxResultLimit) + ctx, c.q, query, dsssql.CellUnionToCellIds(cells), v4d.StartTime, v4d.EndTime) if err != nil { return nil, stacktrace.Propagate(err, "Error fetching Constraints") } diff --git a/pkg/scd/store/cockroach/operational_intents.go b/pkg/scd/store/cockroach/operational_intents.go index 9f73a3b4e..2ad89ad38 100644 --- a/pkg/scd/store/cockroach/operational_intents.go +++ b/pkg/scd/store/cockroach/operational_intents.go @@ -66,6 +66,7 @@ func (s *repo) fetchOperationalIntents(ctx context.Context, q dsssql.Queryable, cids []int64 ussRequestedOVN pgtype.Text pastOVNs []string + count int ) ussAvailabilities := map[dssmodels.Manager]scdmodels.UssAvailabilityState{} for rows.Next() { @@ -92,6 +93,10 @@ func (s *repo) fetchOperationalIntents(ctx context.Context, q dsssql.Queryable, if err != nil { return nil, stacktrace.Propagate(err, "Error scanning Operation row") } + count++ + if count > dssmodels.MaxResultLimit { + return nil, stacktrace.NewError("Result set exceeded max limit of %d", dssmodels.MaxResultLimit) + } // If the managing USS has requested a specific OVN on this operational intent, it will be persisted in DB. // If not, a default DSS-generated OVN based on the last update time is used. @@ -292,7 +297,7 @@ func (s *repo) searchOperationalIntents(ctx context.Context, q dsssql.Queryable, COALESCE(scd_operations.ends_at >= $4, true) AND COALESCE(scd_operations.starts_at <= $5, true) - LIMIT $6`, operationFieldsWithPrefix) + `, operationFieldsWithPrefix) ) if v4d.SpatialVolume == nil || v4d.SpatialVolume.Footprint == nil { @@ -313,7 +318,6 @@ func (s *repo) searchOperationalIntents(ctx context.Context, q dsssql.Queryable, v4d.SpatialVolume.AltitudeHi, v4d.StartTime, v4d.EndTime, - dssmodels.MaxResultLimit, ) if err != nil { return nil, stacktrace.Propagate(err, "Error fetching Operations") @@ -371,12 +375,11 @@ func (s *repo) ListExpiredOperationalIntents(ctx context.Context, threshold time scd_operations.ends_at IS NOT NULL AND scd_operations.ends_at <= $1 OR scd_operations.ends_at IS NULL AND scd_operations.updated_at <= $1 -- use last update time as reference if there is no end time - LIMIT $2`, operationFieldsWithPrefix) + `, operationFieldsWithPrefix) result, err := s.fetchOperationalIntents( ctx, s.q, expiredOpIntentsQuery, threshold, - dssmodels.MaxResultLimit, ) if err != nil { return nil, stacktrace.Propagate(err, "Error fetching Operations") diff --git a/pkg/scd/store/cockroach/subscriptions.go b/pkg/scd/store/cockroach/subscriptions.go index 53c3d710e..567d87074 100644 --- a/pkg/scd/store/cockroach/subscriptions.go +++ b/pkg/scd/store/cockroach/subscriptions.go @@ -97,6 +97,7 @@ func (c *repo) fetchSubscriptions(ctx context.Context, q dsssql.Queryable, query var payload []*scdmodels.Subscription var cids []int64 + var count int for rows.Next() { var ( s = new(scdmodels.Subscription) @@ -120,10 +121,11 @@ func (c *repo) fetchSubscriptions(ctx context.Context, q dsssql.Queryable, query if err != nil { return nil, stacktrace.Propagate(err, "Error scanning Subscription row") } - s.Version = scdmodels.NewOVNFromTime(updatedAt, s.ID.String()) - if err != nil { - return nil, stacktrace.Propagate(err, "Error generating Subscription version") + count++ + if count > dssmodels.MaxResultLimit { + return nil, stacktrace.NewError("Result set exceeded max limit of %d", dssmodels.MaxResultLimit) } + s.Version = scdmodels.NewOVNFromTime(updatedAt, s.ID.String()) s.SetCells(cids) payload = append(payload, s) } @@ -309,7 +311,7 @@ func (c *repo) SearchSubscriptions(ctx context.Context, v4d *dssmodels.Volume4D) COALESCE(starts_at <= $3, true) AND COALESCE(ends_at >= $2, true) - LIMIT $4`, subscriptionFieldsWithPrefix) + `, subscriptionFieldsWithPrefix) ) // TODO: Lazily calculate & cache spatial covering so that it is only ever @@ -324,7 +326,7 @@ func (c *repo) SearchSubscriptions(ctx context.Context, v4d *dssmodels.Volume4D) } subscriptions, err := c.fetchSubscriptions( - ctx, c.q, query, dsssql.CellUnionToCellIds(cells), v4d.StartTime, v4d.EndTime, dssmodels.MaxResultLimit) + ctx, c.q, query, dsssql.CellUnionToCellIds(cells), v4d.StartTime, v4d.EndTime) if err != nil { return nil, stacktrace.Propagate(err, "Unable to fetch Subscriptions") } @@ -410,12 +412,11 @@ func (c *repo) ListExpiredSubscriptions(ctx context.Context, threshold time.Time scd_subscriptions.ends_at IS NOT NULL AND scd_subscriptions.ends_at <= $1 OR scd_subscriptions.ends_at IS NULL AND scd_subscriptions.updated_at <= $1 -- use last update time as reference if there is no end time - LIMIT $2`, subscriptionFieldsWithPrefix) + `, subscriptionFieldsWithPrefix) subscriptions, err := c.fetchSubscriptions( ctx, c.q, expiredSubsQuery, threshold, - dssmodels.MaxResultLimit, ) if err != nil { return nil, stacktrace.Propagate(err, "Unable to fetch Subscriptions") diff --git a/pkg/scd/store/cockroach/subscriptions_test.go b/pkg/scd/store/cockroach/subscriptions_test.go index cd4a11d97..b90712760 100644 --- a/pkg/scd/store/cockroach/subscriptions_test.go +++ b/pkg/scd/store/cockroach/subscriptions_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/interuss/dss/pkg/models" scdmodels "github.com/interuss/dss/pkg/scd/models" "github.com/stretchr/testify/require" @@ -115,3 +116,63 @@ func TestListExpiredSubscriptions(t *testing.T) { }) } } + +func TestListExpiredSubscriptionsMaxLimit(t *testing.T) { + // Set lower limit for testing + models.MaxResultLimit = 10 + var ( + ctx = context.Background() + store, tearDownStore = setUpStore(ctx, t) + ) + require.NotNil(t, store) + defer tearDownStore() + + r, err := store.Interact(ctx) + require.NoError(t, err) + + for range models.MaxResultLimit { + id := uuid.New() + subID := models.ID(id.String()) + sub := &scdmodels.Subscription{ + ID: subID, + NotificationIndex: 1, + Manager: "unittest", + StartTime: &start1, + EndTime: &end1, + USSBaseURL: "https://dummy.uss", + NotifyForOperationalIntents: true, + NotifyForConstraints: false, + ImplicitSubscription: true, + Cells: cells, + } + _, err = r.UpsertSubscription(ctx, sub) + require.NoError(t, err) + } + + timeRef := time.Date(2024, time.December, 15, 15, 0, 0, 0, time.UTC) + ttl := time.Hour * 24 * 30 + threshold := timeRef.Add(-ttl) + _, err = r.ListExpiredSubscriptions(ctx, threshold) + require.NoError(t, err) + + // Insert one more to exceed the limit + id := uuid.New() + subID := models.ID(id.String()) + sub := &scdmodels.Subscription{ + ID: subID, + NotificationIndex: 1, + Manager: "unittest", + StartTime: &start1, + EndTime: &end1, + USSBaseURL: "https://dummy.uss", + NotifyForOperationalIntents: true, + NotifyForConstraints: false, + ImplicitSubscription: true, + Cells: cells, + } + _, err = r.UpsertSubscription(ctx, sub) + require.NoError(t, err) + _, err = r.ListExpiredSubscriptions(ctx, threshold) + require.Error(t, err) + require.ErrorContainsf(t, err, "Result set exceeded max limit of", "%d", models.MaxResultLimit) +}