diff --git a/cgosqlite/bulk.go b/cgosqlite/bulk.go new file mode 100644 index 0000000..0a0c4e7 --- /dev/null +++ b/cgosqlite/bulk.go @@ -0,0 +1,261 @@ +package cgosqlite + +// #include +// #include +// #include +// #include +// #include +// #include +// #include "bulk.h" +import "C" +import ( + "fmt" + "time" + "unsafe" + + "github.com/tailscale/sqlite/sqliteh" +) + +type valueType uint64 + +const ( + valueNULL = valueType(0) // matches VALUE_NULL + valueInt64 = valueType(1) // matches VALUE_INT64 + valueText = valueType(2) // matches VALUE_TEXT +) + +// value has the same memory layout as the type cValue. +type value struct { + valueType valueType + value uint64 // either int64 value or { off, len uint32 } into paramsText +} + +// bulkStmt implements sqliteh.BulkStmt. +// It is embedded into BulkQuery/BulkQueryRow/BulkExec to provide common features. +type bulkStmt struct { + active bool // true if a query has been started and not yet reset + stmt *Stmt + params []value + paramsText []byte + paramIndexes map[string]int +} + +func newBulkStmt(s *Stmt) *bulkStmt { + b := &bulkStmt{ + stmt: s, + params: make([]value, s.BindParameterCount()), + } + return b +} + +func (b *bulkStmt) clear() { + for i := range b.params { + b.params[i].valueType = valueNULL + } + b.paramsText = b.paramsText[:0] +} + +func (b *bulkStmt) ResetAndClear() { + if b.active { + b.stmt.ResetAndClear() + b.active = false + } + b.clear() +} + +func (b *bulkStmt) Finalize() error { + b.clear() + return b.stmt.Finalize() +} + +func (b *bulkStmt) ParamIndex(name string) int { + i, ok := b.paramIndexes[name] + if !ok { + if b.paramIndexes == nil { + b.paramIndexes = make(map[string]int) + } + i = b.stmt.BindParameterIndex(name) + b.paramIndexes[name] = i + } + return i +} + +func (b *bulkStmt) SetInt64(i int, value int64) { + v := b.params[i] + v.valueType = valueInt64 + v.value = uint64(value) +} + +func (b *bulkStmt) SetNull(i int) { + v := b.params[i] + v.valueType = valueNULL +} + +func (b *bulkStmt) SetText(i int, value []byte) { + off := len(b.paramsText) + b.paramsText = append(b.paramsText, value...) + + v := b.params[i] + v.valueType = valueText + v.value = uint64(off)<<32 | uint64(len(value)) +} + +type BulkExec struct { + *bulkStmt +} + +func NewBulkExec(stmt *Stmt) (*BulkExec, error) { + b := &BulkExec{bulkStmt: newBulkStmt(stmt)} + // TODO check for lack of return rows? + return b, nil +} + +func (b *BulkExec) Exec() (lastInsertRowID, changes int64, d time.Duration, err error) { + b.stmt.rowid, b.stmt.changes, b.stmt.duration = 0, 0, 0 + var params unsafe.Pointer + if len(b.params) > 0 { + params = unsafe.Pointer(&b.params[0]) + } + var paramsText unsafe.Pointer + if len(b.paramsText) > 0 { + paramsText = unsafe.Pointer(&b.paramsText[0]) + } + res := C.bulk_exec( + b.stmt.stmt, + (*C.struct_cValue)(params), + (*C.char)(paramsText), C.size_t(len(b.paramsText)), + &b.stmt.rowid, &b.stmt.changes, &b.stmt.duration, + ) + b.clear() + if sqliteh.Code(res) != sqliteh.SQLITE_DONE { + return lastInsertRowID, changes, d, errCode(res) + } + lastInsertRowID = int64(b.stmt.rowid) + changes = int64(b.stmt.changes) + d = time.Duration(b.stmt.duration) + return lastInsertRowID, changes, d, nil +} + +const dataArrLen = 128 // matches DATA_ARR_LEN in C + +type BulkQuery struct { + *bulkStmt + err error + dataArr [dataArrLen]value // backing array of data + data []value + dataText []byte + duration int64 // total duration of all time spent in sqlite + colCount int +} + +func NewBulkQuery(stmt *Stmt) (*BulkQuery, error) { + b := &BulkQuery{ + bulkStmt: newBulkStmt(stmt), + dataText: make([]byte, 1<<16), + colCount: stmt.ColumnCount(), + } + return b, nil +} + +func (b *BulkQuery) Int64(i int) int64 { + if b.data[i].valueType != valueInt64 { + panic(fmt.Sprintf("attempting to access column %d as int64, type is %v", i, b.data[i].valueType)) + } + return int64(b.data[i].value) +} + +func (b *BulkQuery) Null(i int) bool { return b.data[i].valueType == valueNULL } + +func (b *BulkQuery) Text(i int) []byte { + if b.data[i].valueType != valueText { + panic(fmt.Sprintf("attempting to access column %d as text, type is %v", i, b.data[i].valueType)) + } + v := b.data[i].value + voff := uint32(v >> 32) + vlen := uint32(v) + return b.dataText[voff : voff+vlen] +} + +func (b *BulkQuery) Error() error { return b.err } + +func (b *BulkQuery) resizeDataText() { + b.dataText = append(b.dataText, make([]byte, len(b.dataText))...) +} + +func (b *BulkQuery) Query() { + b.err = nil + b.active = false + b.stmt.duration = 0 + var params unsafe.Pointer + if len(b.params) > 0 { + params = unsafe.Pointer(&b.params[0]) + } + var paramsText unsafe.Pointer + if len(b.paramsText) > 0 { + paramsText = unsafe.Pointer(&b.paramsText[0]) + } + res := C.bulk_query( + b.stmt.stmt, + (*C.struct_cValue)(params), + (*C.char)(paramsText), C.size_t(len(b.paramsText)), + &b.stmt.duration, + ) + b.duration = int64(b.stmt.duration) + if res == C.SQLITE_ROW { + b.active = true + } + b.err = errCode(res) +} + +func (b *BulkQuery) Next() bool { + if len(b.data) > 0 { + b.data = b.data[b.colCount:] + if len(b.data) > 0 { + return true + } + } + if !b.active { + return false + } + + for { + b.stmt.duration = 0 + data := unsafe.Pointer(&b.dataArr[0]) + dataText := unsafe.Pointer(&b.dataText[0]) + var rowsRead C.size_t + res := C.bulk_query_step( + b.stmt.stmt, + (*C.struct_cValue)(data), + (*C.char)(dataText), C.size_t(len(b.dataText)), + &rowsRead, + &b.stmt.duration, + ) + + if rowsRead == 0 { + switch res { + case C.SQLITE_DONE: + b.active = false + return false + case C.BULK_TEXT_TOO_SMALL: + b.resizeDataText() + continue + default: + b.err = errCode(res) + return false + } + } else { + b.data = b.dataArr[:int(rowsRead)*b.colCount] + switch res { + case C.SQLITE_ROW, C.BULK_TEXT_TOO_SMALL: + b.active = true + return true + case C.SQLITE_DONE: + b.active = false + return true + default: + b.err = errCode(res) + return false + } + } + } +} diff --git a/cgosqlite/bulk.h b/cgosqlite/bulk.h new file mode 100644 index 0000000..383d15b --- /dev/null +++ b/cgosqlite/bulk.h @@ -0,0 +1,169 @@ +static void monotonic_clock_gettime(struct timespec* t) { + clock_gettime(CLOCK_MONOTONIC, t); +} + +int64_t ns_since(const struct timespec t1) +{ + struct timespec t2; + monotonic_clock_gettime(&t2); + return ((int64_t)t2.tv_sec - (int64_t)t1.tv_sec) * (int64_t)1000000000 + + ((int64_t)t2.tv_nsec - (int64_t)t1.tv_nsec); +} + +#define VALUE_NULL 0 +#define VALUE_INT64 1 +#define VALUE_TEXT 2 + +// cValue matches the Go type named value. +struct cValue { + uint64_t valueType; // one of VALUE_* + uint64_t value; // either int64 or off/len uint32 +}; + +static int bulk_bind(sqlite3_stmt* stmt, struct cValue* params, const char* text, size_t textLen) { + int count = sqlite3_bind_parameter_count(stmt); + int ret; + for (int i = 0; i < count; i++) { + switch (params[i].valueType) { + case VALUE_INT64: + ret = sqlite3_bind_int64(stmt, i+1, (sqlite3_int64)params[i].value); + if (ret) { + return ret; + } + break; + case VALUE_TEXT: { + uint32_t off = (uint32_t)(params[i].value>>32); + uint32_t len = (uint32_t)(params[i].value); + if (((size_t)off + (size_t)len) > textLen) { + return SQLITE_MISUSE; + } + const char* p = &text[off]; + ret = sqlite3_bind_text64(stmt, i+1, p, len, SQLITE_STATIC, SQLITE_UTF8); + if (ret) { + return ret; + } + break; + } + case VALUE_NULL: + default: + ret = sqlite3_bind_null(stmt, i+1); + if (ret) { + return ret; + } + break; + } + } + return 0; +} + +static int bulk_exec( + sqlite3_stmt* stmt, + struct cValue* params, + const char* text, size_t textLen, + sqlite3_int64* rowid, + sqlite3_int64* changes, + int64_t* duration_ns) { + struct timespec t1; + if (duration_ns) { + monotonic_clock_gettime(&t1); + } + int ret = bulk_bind(stmt, params, text, textLen); + if (ret) { + return ret; + } + ret = sqlite3_step(stmt); + if (ret != SQLITE_DONE) { + if (duration_ns) { + *duration_ns = ns_since(t1); + } + return ret; + } + sqlite3* db = sqlite3_db_handle(stmt); + *rowid = sqlite3_last_insert_rowid(db); + *changes = sqlite3_changes(db); + sqlite3_reset(stmt); + sqlite3_clear_bindings(stmt); + + if (duration_ns) { + *duration_ns = ns_since(t1); + } + return ret; +} + +#define DATA_ARR_LEN 128 +#define BULK_TEXT_TOO_SMALL -1 + +static int bulk_query_step( + sqlite3_stmt* stmt, + struct cValue* data, + char* dataText, size_t dataTextLen, + size_t* rowsRead, + int64_t* duration_ns) { + int ret; + int row; // data offset + int off = 0; // dataText offset + + int numCols = sqlite3_column_count(stmt); + for (row = 0; (row+1)*numCols < DATA_ARR_LEN; row++) { + for (int col = 0; col < numCols; col++) { + struct cValue* d = &data[row*numCols + col]; + switch (sqlite3_column_type(stmt, col)) { + case SQLITE_NULL: + d->valueType = VALUE_NULL; + break; + case SQLITE_INTEGER: + case SQLITE_FLOAT: + d->valueType = VALUE_INT64; + d->value = sqlite3_column_int64(stmt, col); + break; + case SQLITE_TEXT: + case SQLITE_BLOB: { + int len = sqlite3_column_bytes(stmt, col); + if (off+len >= dataTextLen) { + *rowsRead = row; // do not include the current row + return BULK_TEXT_TOO_SMALL; + } + memcpy(dataText+off, sqlite3_column_blob(stmt, col), len); + d->valueType = VALUE_TEXT; + d->value = ((uint64_t)off)<<32 | ((uint64_t)len); + off += len; + break; + } + } + } + + ret = sqlite3_step(stmt); + if (ret != SQLITE_ROW) { + *rowsRead = row + 1; + if (ret == SQLITE_DONE) { + sqlite3_reset(stmt); + sqlite3_clear_bindings(stmt); + } + return ret; + } + } + + *rowsRead = row; + return ret; +} + +static int bulk_query( + sqlite3_stmt* stmt, + struct cValue* params, + const char* text, size_t textLen, + int64_t* duration_ns) { + struct timespec t1; + if (duration_ns) { + monotonic_clock_gettime(&t1); + } + sqlite3_reset(stmt); + int ret = bulk_bind(stmt, params, text, textLen); + if (ret) { + return ret; + } + ret = sqlite3_step(stmt); + if (duration_ns) { + *duration_ns = ns_since(t1); + } + return ret; +} diff --git a/cgosqlite/bulk_test.go b/cgosqlite/bulk_test.go new file mode 100644 index 0000000..9ca9eed --- /dev/null +++ b/cgosqlite/bulk_test.go @@ -0,0 +1,207 @@ +package cgosqlite_test + +import ( + "testing" + + "github.com/tailscale/sqlite/cgosqlite" + "github.com/tailscale/sqlite/sqliteh" + "tailscale.com/tstest" +) + +func TestBulkExec(t *testing.T) { + db, err := cgosqlite.Open("file:mem?mode=memory", sqliteh.OpenFlagsDefault, "") + if err != nil { + t.Fatal(err) + } + defer func() { + err := db.Close() + if !t.Failed() { + if err != nil { + t.Error(err) + } + } + }() + + stmt, _, err := db.Prepare("CREATE TABLE t (c0, c1, c2);", 0) + if err != nil { + t.Fatal(err) + } + b, err := cgosqlite.NewBulkExec(stmt.(*cgosqlite.Stmt)) + if err != nil { + t.Fatal(err) + } + if _, _, _, err := b.Exec(); err != nil { + t.Fatal(err) + } + if err := b.Finalize(); err != nil { + t.Fatal(err) + } + + stmt, _, err = db.Prepare("INSERT INTO t (c0, c1, c2) VALUES (?, ?, ?);", 0) + if err != nil { + t.Fatal(err) + } + b, err = cgosqlite.NewBulkExec(stmt.(*cgosqlite.Stmt)) + if err != nil { + t.Fatal(err) + } + b.SetNull(0) + b.SetInt64(1, 42) + b.SetText(2, []byte("hello, world!")) + if _, changes, _, err := b.Exec(); err != nil { + t.Fatal(err) + } else if changes != 1 { + t.Errorf("changes=%d, want 1", changes) + } + msg := []byte("hello, world!") + + // TODO: this should be zero allocs, but for some reason it's 1? + err = tstest.MinAllocsPerRun(t, 1, func() { + b.SetNull(0) + b.SetInt64(1, 43) + b.SetText(2, msg) + if _, changes, _, err := b.Exec(); err != nil { + t.Fatal(err) + } else if changes != 1 { + t.Errorf("changes=%d, want 1", changes) + } + }) + if err != nil { + t.Fatal(err) + } + + if err := b.Finalize(); err != nil { + t.Fatal(err) + } +} + +func BenchmarkBulkExec(b *testing.B) { + dir := b.TempDir() + db, err := cgosqlite.Open("file:"+dir+"/testbulkdexec", sqliteh.OpenFlagsDefault, "") + if err != nil { + b.Fatal(err) + } + defer db.Close() + + stmt, _, err := db.Prepare("PRAGMA journal_mode=WAL;", 0) + if err != nil { + b.Fatal(err) + } + if _, err := stmt.Step(); err != nil { + b.Fatal(err) + } + stmt.Finalize() + + stmt, _, err = db.Prepare("PRAGMA synchronous=NORMAL;", 0) + if err != nil { + b.Fatal(err) + } + if _, err := stmt.Step(); err != nil { + b.Fatal(err) + } + stmt.Finalize() + + stmt, _, err = db.Prepare("CREATE TABLE t (c);", 0) + if err != nil { + b.Fatal(err) + } + if _, err := stmt.Step(); err != nil { + b.Fatal(err) + } + stmt.Finalize() + + stmt, _, err = db.Prepare("INSERT INTO t (c) VALUES (?);", 0) + if err != nil { + b.Fatal(err) + } + bstmt, err := cgosqlite.NewBulkExec(stmt.(*cgosqlite.Stmt)) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + bstmt.SetInt64(0, 42) + if _, _, _, err := bstmt.Exec(); err != nil { + b.Fatal(err) + } + } + + if err := bstmt.Finalize(); err != nil { + b.Fatal(err) + } +} + +func TestBulkQuery(t *testing.T) { + db, err := cgosqlite.Open("file:mem?mode=memory", sqliteh.OpenFlagsDefault, "") + if err != nil { + t.Fatal(err) + } + defer func() { + err := db.Close() + if !t.Failed() { + if err != nil { + t.Error(err) + } + } + }() + + stmt, _, err := db.Prepare("CREATE TABLE t (c0, c1, c2);", 0) + if err != nil { + t.Fatal(err) + } + if _, err := stmt.Step(); err != nil { + t.Fatal(err) + } + stmt.Finalize() + + stmt, _, err = db.Prepare("INSERT INTO t (c0, c1, c2) VALUES (?, ?, ?);", 0) + if err != nil { + t.Fatal(err) + } + const totalRows = 128*3 + 3 // not a clean multiple of dataArrLen + for i := 0; i < totalRows; i++ { + stmt.Reset() + stmt.BindInt64(1, int64(i)) + stmt.BindText64(2, "hello c1") + stmt.BindText64(3, "bye c2") + if _, err := stmt.Step(); err != nil { + t.Fatal(err) + } + } + stmt.Finalize() + + stmt, _, err = db.Prepare("SELECT c0, c1, c2 FROM t;", 0) + if err != nil { + t.Fatal(err) + } + bstmt, err := cgosqlite.NewBulkQuery(stmt.(*cgosqlite.Stmt)) + if err != nil { + t.Fatal(err) + } + bstmt.Query() + rows := 0 + for bstmt.Next() { + if got, want := bstmt.Int64(0), int64(rows); got != want { + t.Errorf("row %d, col0=%d, want %d", rows, got, rows) + } + if got, want := string(bstmt.Text(1)), "hello c1"; got != want { + t.Errorf("row %d, col1=%q, want %q", rows, got, want) + } + if got, want := string(bstmt.Text(2)), "bye c2"; got != want { + t.Errorf("row %d, col2=%q, want %q", rows, got, want) + } + rows++ + } + if err := bstmt.Error(); err != nil { + t.Fatal(err) + } + if rows != totalRows { + t.Fatalf("rows=%d, want totalRows %d", rows, totalRows) + } + if err := bstmt.Finalize(); err != nil { + t.Fatal(err) + } +} diff --git a/go.mod b/go.mod index b3d3be8..ba84d04 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,9 @@ module github.com/tailscale/sqlite -go 1.16 +go 1.18 + +require ( + github.com/google/go-cmp v0.5.7 // indirect + go4.org/mem v0.0.0-20210711025021-927187094b94 // indirect + tailscale.com v1.24.2 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..36220d2 --- /dev/null +++ b/go.sum @@ -0,0 +1,7 @@ +github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= +github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= +go4.org/mem v0.0.0-20210711025021-927187094b94 h1:OAAkygi2Js191AJP1Ds42MhJRgeofeKGjuoUqNp1QC4= +go4.org/mem v0.0.0-20210711025021-927187094b94/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +tailscale.com v1.24.2 h1:xNqEMKLHjqKwKUlggL2QEt1B+oit08w3SwZEIWCmqTg= +tailscale.com v1.24.2/go.mod h1:/z/lF98LSt9CjpwVT6pHh5Vwb1NGsM5/ACI4cLMJfvY= diff --git a/sqliteh/sqliteh.go b/sqliteh/sqliteh.go index 6dcf1db..805377a 100644 --- a/sqliteh/sqliteh.go +++ b/sqliteh/sqliteh.go @@ -851,3 +851,50 @@ func itoa(buf []byte, val int64) []byte { } return buf[i:] } + +type BulkStmt interface { + ResetAndClear() + Finalize() + + SetInt64(i int, v int64) + SetNull(i int) + SetText(i int, v []byte) + + ParamIndex(name string) int +} + +// BullkQuery executes an SQL statement. +// It is designed to minimize allocations and cgo calls. +type BulkQuery interface { + BulkStmt + + // Int64 reports the current row's column i value as an int64. + // Panics if there is no current row, or i is out of bounds, or value is not an int64. + Int64(i int) int64 + // Null reports if current row's column i value is NULL. + // Panics if there is no current row, or i is out of bounds. + Null(i int) bool + // Text reports the current row's column i value as text. + // The []byte is valid only as long as the BulkQuery is currently on this row. + // Panics if there is no current row, or i is out of bounds, or value is not an int64. + Text(i int) []byte + + // Query starts the query. + // Any error is reported in the Error method. + Query() + + // Next moves the query to the next row. + Next() bool + + // Error reports any error from the query. + Error() error +} + +// BullkExec executes an SQL statement that returns no rows. +// It is designed to minimize allocations and cgo calls. +type BullkExec interface { + BulkStmt + + // Exec executes the query. + Exec() (lastInsertRowID, changes int64, d time.Duration, err error) +}