package indexworker

import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"strings"
	"time"

	"github.com/gobuffalo/pop/v6"
	"github.com/sirupsen/logrus"
	"github.com/supabase/auth/internal/conf"
)

// ErrAdvisoryLockAlreadyAcquired is returned when another process already holds the advisory lock
var ErrAdvisoryLockAlreadyAcquired = errors.New("advisory lock already acquired by another process")

// CreateIndexes ensures that the necessary indexes on the users table exist.
// If the indexes already exist and are valid, it skips creation.
// It uses a Postgres advisory lock to prevent concurrent index creation
// by multiple processes.
// Returns an error either from index creation failure (partial or complete) or if the advisory lock
// could not be acquired.
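//
// A caller might wire it up roughly like this (hypothetical usage; the actual caller lives
// outside this package and may differ):
//
//	if err := indexworker.CreateIndexes(ctx, config, le); err != nil {
//		if errors.Is(err, indexworker.ErrAdvisoryLockAlreadyAcquired) {
//			return nil // another instance is already building the indexes
//		}
//		return err
//	}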
func CreateIndexes(ctx context.Context, config *conf.GlobalConfiguration, le *logrus.Entry) error {
	u, err := url.Parse(config.DB.URL)
	if err != nil {
		le.Fatalf("Error parsing db connection url: %+v", err)
	}
	if config.DB.Driver == "" {
		config.DB.Driver = u.Scheme
	}

	// Tag the session with an application_name so the worker is easy to identify in
	// pg_stat_activity.
	processedURL := config.DB.URL
	if len(u.Query()) != 0 {
		processedURL = fmt.Sprintf("%s&application_name=auth_index_worker", processedURL)
	} else {
		processedURL = fmt.Sprintf("%s?application_name=auth_index_worker", processedURL)
	}
	deets := &pop.ConnectionDetails{
		Dialect: config.DB.Driver,
		URL:     processedURL,
	}
	deets.Options = map[string]string{
		"Namespace": config.DB.Namespace,
	}

	db, err := pop.NewConnection(deets)
	if err != nil {
		le.Fatalf("Error opening db connection: %+v", err)
	}
	defer db.Close()

	if err := db.Open(); err != nil {
		le.Fatalf("Error checking database connection: %+v", err)
	}
	db = db.WithContext(ctx)

	// Try to obtain an advisory lock to ensure only one index worker is creating indexes at
	// a time. The lock is session-level, so it is also released automatically if this
	// connection is closed or the process dies.
	lockName := "auth_index_worker"
	var lockAcquired bool
	lockQuery := fmt.Sprintf("SELECT pg_try_advisory_lock(hashtext('%s')::bigint)", lockName)

	if err := db.RawQuery(lockQuery).First(&lockAcquired); err != nil {
		le.Errorf("Failed to attempt advisory lock acquisition: %+v", err)
		return err
	}

	if !lockAcquired {
		le.Infof("Another process is currently creating indexes. Skipping index creation.")
		return ErrAdvisoryLockAlreadyAcquired
	}

	le.Infof("Successfully acquired advisory lock for index creation.")

	// Ensure the lock is released on function exit
	defer func() {
		unlockQuery := fmt.Sprintf("SELECT pg_advisory_unlock(hashtext('%s')::bigint)", lockName)
		var unlocked bool
		if err := db.RawQuery(unlockQuery).First(&unlocked); err != nil {
			if ctx.Err() != nil {
				le.Infof("Context cancelled. Advisory lock will be released upon session termination.")
			} else {
				le.Errorf("Failed to release advisory lock: %+v", err)
			}
		} else if unlocked {
			le.Infof("Successfully released advisory lock.")
		} else {
			le.Warnf("Advisory lock was not held when attempting to release.")
		}
	}()

	indexes := getUsersIndexes(config.DB.Namespace)
	indexNames := make([]string, len(indexes))
	for i, idx := range indexes {
		indexNames[i] = idx.name
	}

	// Check existing indexes and their statuses. If all exist and are valid, skip creation.
	existingIndexes, err := getIndexStatuses(db, config.DB.Namespace, indexNames)
	if err != nil {
		le.Warnf("Failed to check existing indexes: %+v. Proceeding with index creation.", err)
	} else if len(existingIndexes) == len(indexes) {
		allHealthy := true
		for _, idx := range existingIndexes {
			if !idx.IsValid || !idx.IsReady {
				le.Infof("Index %s exists but is not healthy (valid: %v, ready: %v)", idx.IndexName, idx.IsValid, idx.IsReady)
				allHealthy = false
				break
			}
		}

		if allHealthy {
			le.Infof("All %d indexes on %s.users already exist and are healthy. Skipping index creation.", len(indexes), config.DB.Namespace)
			return nil
		}
	} else {
		le.Infof("Found %d of %d expected indexes. Proceeding with index creation.", len(existingIndexes), len(indexes))
	}

	userCount, err := getApproximateUserCount(db, config.DB.Namespace)
	if err != nil {
		le.Warnf("Failed to get approximate user count: %+v. Proceeding with index creation.", err)
	} else {
		le.Infof("Approximate user count: %d.", userCount)
	}
	le.Infof("Starting index creation...")

	// First, clean up any invalid indexes left over from previous interrupted attempts
	dropInvalidIndexes(db, le, config.DB.Namespace, indexNames)

	// Create the indexes one by one
	var failedIndexes []string
	totalStartTime := time.Now()

	for _, idx := range indexes {
		startTime := time.Now()
		le.Infof("Creating index: %s", idx.name)

		if err := db.RawQuery(idx.query).Exec(); err != nil {
			duration := time.Since(startTime).Milliseconds()

			le.Errorf("Failed to create index %s after %d ms: %v", idx.name, duration, err)
			failedIndexes = append(failedIndexes, idx.name)
		} else {
			duration := time.Since(startTime).Milliseconds()
			le.Infof("Successfully created index %s in %d ms", idx.name, duration)
		}
	}

	totalDuration := time.Since(totalStartTime).Milliseconds()

	if len(failedIndexes) > 0 {
		le.Warnf("Index creation completed in %d ms with some failures: %v", totalDuration, failedIndexes)
		return fmt.Errorf("failed to create indexes: %v", failedIndexes)
	}

	le.Infof("All indexes created successfully in %d ms", totalDuration)
	return nil
}
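
// Index builds on a large users table can take a while. Because the connection above sets
// application_name=auth_index_worker, the progress of each CREATE INDEX CONCURRENTLY
// statement can be watched from another session (PostgreSQL 12+), for example:
//
//	SELECT a.application_name, p.phase, p.blocks_done, p.blocks_total
//	FROM pg_stat_progress_create_index p
//	JOIN pg_stat_activity a ON a.pid = p.pid
//	WHERE a.application_name = 'auth_index_worker';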

// usersIndex pairs an index name with the statement that creates it.
type usersIndex struct {
	name  string
	query string
}

// getUsersIndexes returns the list of indexes to create on the users table.
// Note: CREATE INDEX CONCURRENTLY cannot be used inside a transaction block, so each
// statement is executed on its own.
func getUsersIndexes(namespace string) []usersIndex {
	return []usersIndex{
		// for exact-match queries, sorting, and LIKE '%term%' (trigram) searches on email
		{
			name: "idx_users_email",
			query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_email
				ON %q.users USING btree (email);`, namespace),
		},
		{
			name: "idx_users_email_trgm",
			query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_email_trgm
				ON %q.users USING gin (email gin_trgm_ops);`, namespace),
		},
		// enables exact-match and prefix searches and sorting by phone number
		{
			name: "idx_users_phone_pattern",
			query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_phone_pattern
				ON %q.users USING btree (phone text_pattern_ops);`, namespace),
		},
		// for range queries and sorting on created_at and last_sign_in_at
		{
			name: "idx_users_created_at_desc",
			query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_created_at_desc
				ON %q.users (created_at DESC);`, namespace),
		},
		{
			name: "idx_users_last_sign_in_at_desc",
			query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_last_sign_in_at_desc
				ON %q.users (last_sign_in_at DESC);`, namespace),
		},
		// trigram index on the name field in raw_user_meta_data JSONB - enables fast LIKE '%term%' searches
		{
			name: "idx_users_name_trgm",
			query: fmt.Sprintf(`CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_name_trgm
				ON %q.users USING gin ((raw_user_meta_data->>'name') gin_trgm_ops)
				WHERE raw_user_meta_data->>'name' IS NOT NULL;`, namespace),
		},
	}
}
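
// The gin_trgm_ops indexes above assume the pg_trgm extension is already installed in the
// target database. A minimal guard a caller could run beforehand, assuming the connecting
// role is allowed to create extensions (illustrative only; nothing in this package calls it):
func ensurePgTrgmExtension(db *pop.Connection) error {
	// CREATE EXTENSION IF NOT EXISTS is a no-op when pg_trgm is already present.
	return db.RawQuery("CREATE EXTENSION IF NOT EXISTS pg_trgm").Exec()
}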

type indexStatus struct {
	IndexName string `db:"index_name"`
	IsValid   bool   `db:"is_valid"`
	IsReady   bool   `db:"is_ready"`
}

// getIndexStatuses checks the status of the given indexes in the specified namespace.
// The index names are interpolated directly because they come from the fixed list in
// getUsersIndexes, not from user input.
func getIndexStatuses(db *pop.Connection, namespace string, indexNames []string) ([]indexStatus, error) {
	quotedNames := make([]string, len(indexNames))
	for i, idx := range indexNames {
		quotedNames[i] = fmt.Sprintf("'%s'", idx)
	}
	indexNamesStr := strings.Join(quotedNames, ",")

	query := fmt.Sprintf(`
		SELECT c.relname AS index_name, i.indisvalid AS is_valid, i.indisready AS is_ready
		FROM pg_index i
		JOIN pg_class c ON c.oid = i.indexrelid
		JOIN pg_namespace n ON n.oid = c.relnamespace
		WHERE n.nspname = '%s'
		AND c.relname IN (%s)
	`, namespace, indexNamesStr)

	var existingIndexes []indexStatus
	if err := db.RawQuery(query).All(&existingIndexes); err != nil {
		return nil, err
	}

	return existingIndexes, nil
}

// getApproximateUserCount returns an approximate count of users in the users table to avoid a full table scan.
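// The estimate comes from pg_class.reltuples, which is maintained by VACUUM, ANALYZE and
// autovacuum; on PostgreSQL 14+ it is -1 for tables that have never been vacuumed or
// analyzed, so the value is informational only.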
func getApproximateUserCount(db *pop.Connection, namespace string) (int64, error) {
	var userCount int64
	countQuery := fmt.Sprintf("SELECT reltuples::BIGINT FROM pg_class WHERE oid = '%q.users'::regclass;", namespace)

	if err := db.RawQuery(countQuery).First(&userCount); err != nil {
		return 0, err
	}

	return userCount, nil
}

// dropInvalidIndexes drops any invalid indexes left over from previous interrupted attempts
func dropInvalidIndexes(db *pop.Connection, le *logrus.Entry, namespace string, indexNames []string) {
	quotedNames := make([]string, len(indexNames))
	for i, idx := range indexNames {
		quotedNames[i] = fmt.Sprintf("'%s'", idx)
	}
	indexNamesStr := strings.Join(quotedNames, ",")

	// Query the system catalog for invalid indexes (typically left behind by interrupted
	// CREATE INDEX CONCURRENTLY operations)
	cleanupQuery := fmt.Sprintf(`
		SELECT c.relname AS index_name
		FROM pg_index i
		JOIN pg_class c ON c.oid = i.indexrelid
		JOIN pg_namespace n ON n.oid = c.relnamespace
		WHERE n.nspname = '%s'
		AND NOT i.indisvalid
		AND c.relname IN (%s)
	`, namespace, indexNamesStr)

	type invalidIndex struct {
		IndexName string `db:"index_name"`
	}
	var invalidIndexes []invalidIndex
	if err := db.RawQuery(cleanupQuery).All(&invalidIndexes); err != nil {
		le.Warnf("Failed to look up invalid indexes: %+v", err)
		return
	}

	for _, idx := range invalidIndexes {
		le.Warnf("Dropping invalid index from previous interrupted run: %s", idx.IndexName)
		dropQuery := fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %q.%q", namespace, idx.IndexName)
		if err := db.RawQuery(dropQuery).Exec(); err != nil {
			le.Errorf("Failed to drop invalid index %s: %v", idx.IndexName, err)
		}
	}
}