Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified lib/darwin_amd64/libsql_experimental.a
Binary file not shown.
42 changes: 42 additions & 0 deletions lib/include/libsql.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

#include <stdint.h>

#define LIBSQL_INT 1

#define LIBSQL_FLOAT 2

#define LIBSQL_TEXT 3

#define LIBSQL_BLOB 4

#define LIBSQL_NULL 5

typedef struct libsql_connection libsql_connection;

typedef struct libsql_database libsql_database;
Expand All @@ -17,6 +27,16 @@ typedef struct libsql_stmt libsql_stmt;

typedef const libsql_database *libsql_database_t;

typedef struct {
const char *db_path;
const char *primary_url;
const char *auth_token;
char read_your_writes;
const char *encryption_key;
int sync_interval;
char with_webpki;
} libsql_config;

typedef const libsql_connection *libsql_connection_t;

typedef const libsql_stmt *libsql_stmt_t;
Expand Down Expand Up @@ -46,16 +66,36 @@ int libsql_open_sync(const char *db_path,
libsql_database_t *out_db,
const char **out_err_msg);

int libsql_open_sync_with_webpki(const char *db_path,
const char *primary_url,
const char *auth_token,
char read_your_writes,
const char *encryption_key,
libsql_database_t *out_db,
const char **out_err_msg);

int libsql_open_sync_with_config(libsql_config config, libsql_database_t *out_db, const char **out_err_msg);

int libsql_open_ext(const char *url, libsql_database_t *out_db, const char **out_err_msg);

int libsql_open_file(const char *url, libsql_database_t *out_db, const char **out_err_msg);

int libsql_open_remote(const char *url, const char *auth_token, libsql_database_t *out_db, const char **out_err_msg);

int libsql_open_remote_with_webpki(const char *url,
const char *auth_token,
libsql_database_t *out_db,
const char **out_err_msg);

void libsql_close(libsql_database_t db);

int libsql_connect(libsql_database_t db, libsql_connection_t *out_conn, const char **out_err_msg);

int libsql_load_extension(libsql_connection_t conn,
const char *path,
const char *entry_point,
const char **out_err_msg);

int libsql_reset(libsql_connection_t conn, const char **out_err_msg);

void libsql_disconnect(libsql_connection_t conn);
Expand All @@ -76,6 +116,8 @@ int libsql_query_stmt(libsql_stmt_t stmt, libsql_rows_t *out_rows, const char **

int libsql_execute_stmt(libsql_stmt_t stmt, const char **out_err_msg);

int libsql_reset_stmt(libsql_stmt_t stmt, const char **out_err_msg);

void libsql_free_stmt(libsql_stmt_t stmt);

int libsql_query(libsql_connection_t conn, const char *sql, libsql_rows_t *out_rows, const char **out_err_msg);
Expand Down
44 changes: 41 additions & 3 deletions libsql.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
sqldriver "database/sql/driver"
"errors"
"fmt"
"golang.org/x/exp/slices"
"io"
"net/url"
"regexp"
Expand All @@ -40,11 +41,17 @@ func init() {
sql.Register("libsql", driver{})
}

type extension struct {
path string
entryPoint string
}

type config struct {
authToken *string
readYourWrites *bool
encryptionKey *string
syncInterval *time.Duration
extensions []extension
}

type Option interface {
Expand Down Expand Up @@ -103,6 +110,16 @@ func WithSyncInterval(interval time.Duration) Option {
})
}

func WithExtension(path, entryPoint string) Option {
return option(func(o *config) error {
if slices.ContainsFunc(o.extensions, func(e extension) bool { return e.path == path }) {
return fmt.Errorf("extension %s already added", path)
}
o.extensions = append(o.extensions, extension{path, entryPoint})
return nil
})
}

func NewEmbeddedReplicaConnector(dbPath string, primaryUrl string, opts ...Option) (*Connector, error) {
var config config
errs := make([]error, 0, len(opts))
Expand Down Expand Up @@ -130,7 +147,7 @@ func NewEmbeddedReplicaConnector(dbPath string, primaryUrl string, opts ...Optio
if config.syncInterval != nil {
syncInterval = *config.syncInterval
}
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, readYourWrites, encryptionKey, syncInterval)
return openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken, readYourWrites, encryptionKey, syncInterval, config.extensions)
}

type driver struct{}
Expand Down Expand Up @@ -191,7 +208,7 @@ func openRemoteConnector(primaryUrl, authToken string) (*Connector, error) {
return &Connector{nativeDbPtr: nativeDbPtr}, nil
}

func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, readYourWrites bool, encryptionKey string, syncInterval time.Duration) (*Connector, error) {
func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, readYourWrites bool, encryptionKey string, syncInterval time.Duration, extensions []extension) (*Connector, error) {
var closeCh chan struct{}
var closeAckCh chan struct{}
nativeDbPtr, err := libsqlOpenWithSync(dbPath, primaryUrl, authToken, readYourWrites, encryptionKey)
Expand Down Expand Up @@ -224,10 +241,11 @@ func openEmbeddedReplicaConnector(dbPath, primaryUrl, authToken string, readYour
}
}()
}
return &Connector{nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil
return &Connector{extensions: extensions, nativeDbPtr: nativeDbPtr, closeCh: closeCh, closeAckCh: closeAckCh}, nil
}

type Connector struct {
extensions []extension
nativeDbPtr C.libsql_database_t
closeCh chan<- struct{}
closeAckCh <-chan struct{}
Expand Down Expand Up @@ -256,6 +274,26 @@ func (c *Connector) Connect(ctx context.Context) (sqldriver.Conn, error) {
if err != nil {
return nil, err
}
for _, ext := range c.extensions {
err := func() error {
extPath := C.CString(ext.path)
defer C.free(unsafe.Pointer(extPath))
var extEntryPoint *C.char = nil
if ext.entryPoint != "" {
extEntryPoint = C.CString(ext.entryPoint)
defer C.free(unsafe.Pointer(extEntryPoint))
}
var errMsg *C.char
statusCode := C.libsql_load_extension(nativeConnPtr, extPath, extEntryPoint, &errMsg)
if statusCode != 0 {
return libsqlError(fmt.Sprintf("failed to load extension %s %s", ext.path, ext.entryPoint), statusCode, errMsg)
}
return nil
}()
if err != nil {
return nil, err
}
}
return &conn{nativePtr: nativeConnPtr}, nil
}

Expand Down