diff --git a/queue.go b/queue.go index 4eb7d39..9171332 100644 --- a/queue.go +++ b/queue.go @@ -3,6 +3,7 @@ package sqliteq import ( "database/sql" "fmt" + "strings" "sync/atomic" "time" @@ -44,7 +45,7 @@ func newQueue(db *sql.DB, tableName string, opts ...Option) (*Queue, error) { // initTable initializes the queue table if it doesn't exist func (q *Queue) initTable() error { createTableSQL := fmt.Sprintf(` - CREATE TABLE IF NOT EXISTS %s ( + CREATE TABLE IF NOT EXISTS %[1]s ( id INTEGER PRIMARY KEY AUTOINCREMENT, data BLOB NOT NULL, status TEXT NOT NULL, @@ -53,10 +54,14 @@ func (q *Queue) initTable() error { created_at TIMESTAMP, updated_at TIMESTAMP ); - CREATE INDEX IF NOT EXISTS %s_status_idx ON %s (status, created_at); - CREATE INDEX IF NOT EXISTS %s_status_ack_idx ON %s (status, ack); - CREATE INDEX IF NOT EXISTS %s_ack_id_idx ON %s (ack_id); - `, q.tableName, q.tableName, q.tableName, q.tableName, q.tableName, q.tableName, q.tableName) + CREATE INDEX IF NOT EXISTS %[2]s ON %[1]s (status, created_at); + CREATE INDEX IF NOT EXISTS %[3]s ON %[1]s (status, ack); + CREATE INDEX IF NOT EXISTS %[4]s ON %[1]s (ack_id); + `, + quoteIdent(q.tableName), + quoteIdent(q.tableName+"_status_idx"), + quoteIdent(q.tableName+"_status_ack_idx"), + quoteIdent(q.tableName+"_ack_id_idx")) _, err := q.client.Exec(createTableSQL) return err @@ -72,7 +77,8 @@ func (q *Queue) RequeueNoAckRows() { }() _, err = tx.Exec( - fmt.Sprintf("UPDATE %s SET status = 'pending', updated_at = ? WHERE status = 'processing' AND ack = 0", q.tableName), + fmt.Sprintf("UPDATE %s SET status = 'pending', updated_at = ? WHERE status = 'processing' AND ack = 0", + quoteIdent(q.tableName)), time.Now().UTC(), ) @@ -99,9 +105,8 @@ func (q *Queue) Enqueue(item any) bool { }() _, err = tx.Exec( - fmt.Sprintf("INSERT INTO %s (data, status, ack, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", q.tableName), - item, "pending", 0, now, now, - ) + fmt.Sprintf("INSERT INTO %s (data, status, ack, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", + quoteIdent(q.tableName)), item, "pending", 0, now, now) if err != nil { return false } @@ -135,7 +140,7 @@ func (q *Queue) dequeueInternal(withAckId bool) (item any, success bool, ackID s // Only dequeue pending items in FIFO order row := tx.QueryRow(fmt.Sprintf( "SELECT id, data, ack_id FROM %s WHERE status = 'pending' ORDER BY created_at ASC LIMIT 1", - q.tableName, + quoteIdent(q.tableName), )) // Use NullString to handle NULL values from database @@ -168,13 +173,14 @@ func (q *Queue) dequeueInternal(withAckId bool) (item any, success bool, ackID s // Update the item to processing status _, err = tx.Exec( - fmt.Sprintf("UPDATE %s SET status = 'processing', ack_id = ?, updated_at = ? WHERE id = ?", q.tableName), + fmt.Sprintf("UPDATE %s SET status = 'processing', ack_id = ?, updated_at = ? WHERE id = ?", + quoteIdent(q.tableName)), ackID, now, id, ) } else { // For regular Dequeue, just delete the item immediately _, err = tx.Exec( - fmt.Sprintf("DELETE FROM %s WHERE id = ?", q.tableName), + fmt.Sprintf("DELETE FROM %s WHERE id = ?", quoteIdent(q.tableName)), id, ) } @@ -224,13 +230,13 @@ func (q *Queue) Acknowledge(ackID string) bool { if q.removeOnComplete { // If removeOnComplete is true, delete the acknowledged item result, err = tx.Exec( - fmt.Sprintf("DELETE FROM %s WHERE ack_id = ? ", q.tableName), + fmt.Sprintf("DELETE FROM %s WHERE ack_id = ? ", quoteIdent(q.tableName)), ackID, ) } else { // Otherwise, mark it as completed and set ack to 1 (true in SQLite) result, err = tx.Exec( - fmt.Sprintf("UPDATE %s SET status = 'completed', ack = 1, updated_at = ? WHERE ack_id = ?", q.tableName), + fmt.Sprintf("UPDATE %s SET status = 'completed', ack = 1, updated_at = ? WHERE ack_id = ?", quoteIdent(q.tableName)), time.Now().UTC(), ackID, ) } @@ -253,7 +259,7 @@ func (q *Queue) Acknowledge(ackID string) bool { // Len returns the number of pending items in the queue func (q *Queue) Len() int { var count int - row := q.client.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE status = 'pending'", q.tableName)) + row := q.client.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE status = 'pending'", quoteIdent(q.tableName))) err := row.Scan(&count) if err != nil { return 0 @@ -263,7 +269,7 @@ func (q *Queue) Len() int { // Values returns all pending items in the queue func (q *Queue) Values() []any { - rows, err := q.client.Query(fmt.Sprintf("SELECT data FROM %s WHERE status = 'pending' ORDER BY created_at ASC", q.tableName)) + rows, err := q.client.Query(fmt.Sprintf("SELECT data FROM %s WHERE status = 'pending' ORDER BY created_at ASC", quoteIdent(q.tableName))) if err != nil { return nil } @@ -296,7 +302,7 @@ func (q *Queue) Purge() { } }() - _, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", q.tableName)) + _, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", quoteIdent(q.tableName))) if err != nil { tx.Rollback() return @@ -311,3 +317,11 @@ func (q *Queue) Close() error { return nil } + +// Applies quotes to an identifier escaping any internal quotes. +// See: https://www.sqlite.org/lang_keywords.html +func quoteIdent(name string) string { + // Replace quotes with dobule quotes + escaped := strings.ReplaceAll(name, `"`, `""`) + return `"` + escaped + `"` +} diff --git a/queue_test.go b/queue_test.go index 1499cc9..9f3881a 100644 --- a/queue_test.go +++ b/queue_test.go @@ -62,7 +62,7 @@ func TestSQLiteQueue(t *testing.T) { if !ok { t.Errorf("Expected []byte, got %T", data) } - + if string(byteData) != "test item 1" { t.Errorf("Expected 'test item 1', got '%s'", string(byteData)) } @@ -99,7 +99,7 @@ func TestSQLiteQueue(t *testing.T) { if !ok { t.Errorf("Expected []byte, got %T", item) } - + if string(byteData) != "42" { t.Errorf("Expected '42', got '%s'", string(byteData)) } @@ -324,3 +324,21 @@ func TestConcurrentOperations(t *testing.T) { t.Errorf("Expected empty queue, got length %d", q.Len()) } } + +func TestQuoteIdent(t *testing.T) { + tt := []struct{ input, want string }{ + {"foo", `"foo"`}, + {"spaces are here", `"spaces are here"`}, + {`"quoted"`, `"""quoted"""`}, + {``, `""`}, // belive it or not this is a valid table name + } + + for _, tc := range tt { + t.Run(tc.input, func(t *testing.T) { + got := quoteIdent(tc.input) + if got != tc.want { + t.Errorf("Unexpected quotes tabled (want %q, got %q)", tc.want, got) + } + }) + } +}