Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 31 additions & 17 deletions queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sqliteq
import (
"database/sql"
"fmt"
"strings"
"sync/atomic"
"time"

Expand Down Expand Up @@ -44,7 +45,7 @@ func newQueue(db *sql.DB, tableName string, opts ...Option) (*Queue, error) {
// initTable initializes the queue table if it doesn't exist
func (q *Queue) initTable() error {
createTableSQL := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
CREATE TABLE IF NOT EXISTS %[1]s (
id INTEGER PRIMARY KEY AUTOINCREMENT,
data BLOB NOT NULL,
status TEXT NOT NULL,
Expand All @@ -53,10 +54,14 @@ func (q *Queue) initTable() error {
created_at TIMESTAMP,
updated_at TIMESTAMP
);
CREATE INDEX IF NOT EXISTS %s_status_idx ON %s (status, created_at);
CREATE INDEX IF NOT EXISTS %s_status_ack_idx ON %s (status, ack);
CREATE INDEX IF NOT EXISTS %s_ack_id_idx ON %s (ack_id);
`, q.tableName, q.tableName, q.tableName, q.tableName, q.tableName, q.tableName, q.tableName)
CREATE INDEX IF NOT EXISTS %[2]s ON %[1]s (status, created_at);
CREATE INDEX IF NOT EXISTS %[3]s ON %[1]s (status, ack);
CREATE INDEX IF NOT EXISTS %[4]s ON %[1]s (ack_id);
`,
quoteIdent(q.tableName),
quoteIdent(q.tableName+"_status_idx"),
quoteIdent(q.tableName+"_status_ack_idx"),
quoteIdent(q.tableName+"_ack_id_idx"))

_, err := q.client.Exec(createTableSQL)
return err
Expand All @@ -72,7 +77,8 @@ func (q *Queue) RequeueNoAckRows() {
}()

_, err = tx.Exec(
fmt.Sprintf("UPDATE %s SET status = 'pending', updated_at = ? WHERE status = 'processing' AND ack = 0", q.tableName),
fmt.Sprintf("UPDATE %s SET status = 'pending', updated_at = ? WHERE status = 'processing' AND ack = 0",
quoteIdent(q.tableName)),
time.Now().UTC(),
)

Expand All @@ -99,9 +105,8 @@ func (q *Queue) Enqueue(item any) bool {
}()

_, err = tx.Exec(
fmt.Sprintf("INSERT INTO %s (data, status, ack, created_at, updated_at) VALUES (?, ?, ?, ?, ?)", q.tableName),
item, "pending", 0, now, now,
)
fmt.Sprintf("INSERT INTO %s (data, status, ack, created_at, updated_at) VALUES (?, ?, ?, ?, ?)",
quoteIdent(q.tableName)), item, "pending", 0, now, now)
if err != nil {
return false
}
Expand Down Expand Up @@ -135,7 +140,7 @@ func (q *Queue) dequeueInternal(withAckId bool) (item any, success bool, ackID s
// Only dequeue pending items in FIFO order
row := tx.QueryRow(fmt.Sprintf(
"SELECT id, data, ack_id FROM %s WHERE status = 'pending' ORDER BY created_at ASC LIMIT 1",
q.tableName,
quoteIdent(q.tableName),
))

// Use NullString to handle NULL values from database
Expand Down Expand Up @@ -168,13 +173,14 @@ func (q *Queue) dequeueInternal(withAckId bool) (item any, success bool, ackID s

// Update the item to processing status
_, err = tx.Exec(
fmt.Sprintf("UPDATE %s SET status = 'processing', ack_id = ?, updated_at = ? WHERE id = ?", q.tableName),
fmt.Sprintf("UPDATE %s SET status = 'processing', ack_id = ?, updated_at = ? WHERE id = ?",
quoteIdent(q.tableName)),
ackID, now, id,
)
} else {
// For regular Dequeue, just delete the item immediately
_, err = tx.Exec(
fmt.Sprintf("DELETE FROM %s WHERE id = ?", q.tableName),
fmt.Sprintf("DELETE FROM %s WHERE id = ?", quoteIdent(q.tableName)),
id,
)
}
Expand Down Expand Up @@ -224,13 +230,13 @@ func (q *Queue) Acknowledge(ackID string) bool {
if q.removeOnComplete {
// If removeOnComplete is true, delete the acknowledged item
result, err = tx.Exec(
fmt.Sprintf("DELETE FROM %s WHERE ack_id = ? ", q.tableName),
fmt.Sprintf("DELETE FROM %s WHERE ack_id = ? ", quoteIdent(q.tableName)),
ackID,
)
} else {
// Otherwise, mark it as completed and set ack to 1 (true in SQLite)
result, err = tx.Exec(
fmt.Sprintf("UPDATE %s SET status = 'completed', ack = 1, updated_at = ? WHERE ack_id = ?", q.tableName),
fmt.Sprintf("UPDATE %s SET status = 'completed', ack = 1, updated_at = ? WHERE ack_id = ?", quoteIdent(q.tableName)),
time.Now().UTC(), ackID,
)
}
Expand All @@ -253,7 +259,7 @@ func (q *Queue) Acknowledge(ackID string) bool {
// Len returns the number of pending items in the queue
func (q *Queue) Len() int {
var count int
row := q.client.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE status = 'pending'", q.tableName))
row := q.client.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s WHERE status = 'pending'", quoteIdent(q.tableName)))
err := row.Scan(&count)
if err != nil {
return 0
Expand All @@ -263,7 +269,7 @@ func (q *Queue) Len() int {

// Values returns all pending items in the queue
func (q *Queue) Values() []any {
rows, err := q.client.Query(fmt.Sprintf("SELECT data FROM %s WHERE status = 'pending' ORDER BY created_at ASC", q.tableName))
rows, err := q.client.Query(fmt.Sprintf("SELECT data FROM %s WHERE status = 'pending' ORDER BY created_at ASC", quoteIdent(q.tableName)))
if err != nil {
return nil
}
Expand Down Expand Up @@ -296,7 +302,7 @@ func (q *Queue) Purge() {
}
}()

_, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", q.tableName))
_, err = tx.Exec(fmt.Sprintf("DELETE FROM %s", quoteIdent(q.tableName)))
if err != nil {
tx.Rollback()
return
Expand All @@ -311,3 +317,11 @@ func (q *Queue) Close() error {

return nil
}

// Applies quotes to an identifier escaping any internal quotes.
// See: https://www.sqlite.org/lang_keywords.html
func quoteIdent(name string) string {
// Replace quotes with dobule quotes
escaped := strings.ReplaceAll(name, `"`, `""`)
return `"` + escaped + `"`
}
22 changes: 20 additions & 2 deletions queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestSQLiteQueue(t *testing.T) {
if !ok {
t.Errorf("Expected []byte, got %T", data)
}

if string(byteData) != "test item 1" {
t.Errorf("Expected 'test item 1', got '%s'", string(byteData))
}
Expand Down Expand Up @@ -99,7 +99,7 @@ func TestSQLiteQueue(t *testing.T) {
if !ok {
t.Errorf("Expected []byte, got %T", item)
}

if string(byteData) != "42" {
t.Errorf("Expected '42', got '%s'", string(byteData))
}
Expand Down Expand Up @@ -324,3 +324,21 @@ func TestConcurrentOperations(t *testing.T) {
t.Errorf("Expected empty queue, got length %d", q.Len())
}
}

func TestQuoteIdent(t *testing.T) {
tt := []struct{ input, want string }{
{"foo", `"foo"`},
{"spaces are here", `"spaces are here"`},
{`"quoted"`, `"""quoted"""`},
{``, `""`}, // belive it or not this is a valid table name
}

for _, tc := range tt {
t.Run(tc.input, func(t *testing.T) {
got := quoteIdent(tc.input)
if got != tc.want {
t.Errorf("Unexpected quotes tabled (want %q, got %q)", tc.want, got)
}
})
}
}