Skip to content

Commit 93f38ba

Browse files
committed
feat: add transactions
1 parent 21a1c4a commit 93f38ba

File tree

13 files changed

+344
-94
lines changed

13 files changed

+344
-94
lines changed

conn.go

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ type SpannerConn interface {
5959
// RunBatch sends all batched DDL or DML statements to Spanner. This is a
6060
// no-op if no statements have been batched or if there is no active batch.
6161
RunBatch(ctx context.Context) error
62+
// RunDmlBatch sends all batched DML statements to Spanner. This is a
63+
// no-op if no statements have been batched or if there is no active DML batch.
64+
RunDmlBatch(ctx context.Context) (SpannerResult, error)
6265
// AbortBatch aborts the current DDL or DML batch and discards all batched
6366
// statements.
6467
AbortBatch() error
@@ -202,6 +205,11 @@ type SpannerConn interface {
202205
withTempBatchReadOnlyTransactionOptions(options *BatchReadOnlyTransactionOptions)
203206
}
204207

208+
type SpannerResult interface {
209+
driver.Result
210+
BatchRowsAffected() ([]int64, error)
211+
}
212+
205213
var _ SpannerConn = &conn{}
206214

207215
type conn struct {
@@ -261,6 +269,7 @@ type conn struct {
261269
// tempBatchReadOnlyTransactionOptions are temporarily set right before a
262270
// batch read-only transaction is started on a Spanner connection.
263271
tempBatchReadOnlyTransactionOptions *BatchReadOnlyTransactionOptions
272+
tempProtoTransactionOptions *spannerpb.TransactionOptions
264273
}
265274

266275
func (c *conn) UnderlyingClient() (*spanner.Client, error) {
@@ -444,6 +453,18 @@ func (c *conn) RunBatch(ctx context.Context) error {
444453
return err
445454
}
446455

456+
func (c *conn) RunDmlBatch(ctx context.Context) (SpannerResult, error) {
457+
res, err := c.runBatch(ctx)
458+
if err != nil {
459+
return nil, err
460+
}
461+
spannerRes, ok := res.(SpannerResult)
462+
if !ok {
463+
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "not a DML batch"))
464+
}
465+
return spannerRes, nil
466+
}
467+
447468
func (c *conn) AbortBatch() error {
448469
_, err := c.abortBatch()
449470
return err
@@ -522,7 +543,7 @@ func (c *conn) runDDLBatch(ctx context.Context) (driver.Result, error) {
522543
return c.execDDL(ctx, statements...)
523544
}
524545

525-
func (c *conn) runDMLBatch(ctx context.Context) (driver.Result, error) {
546+
func (c *conn) runDMLBatch(ctx context.Context) (SpannerResult, error) {
526547
statements := c.batch.statements
527548
options := c.batch.options
528549
options.QueryOptions.LastStatement = true
@@ -541,7 +562,7 @@ func (c *conn) abortBatch() (driver.Result, error) {
541562

542563
func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (driver.Result, error) {
543564
if c.batch != nil && c.batch.tp == dml {
544-
return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "This connection has an active DML batch"))
565+
return nil, spanner.ToSpannerError(status.Error(codes.FailedPrecondition, "This connection has an active DDL batch"))
545566
}
546567
if c.batch != nil && c.batch.tp == ddl {
547568
c.batch.statements = append(c.batch.statements, statements...)
@@ -567,7 +588,7 @@ func (c *conn) execDDL(ctx context.Context, statements ...spanner.Statement) (dr
567588
return driver.ResultNoRows, nil
568589
}
569590

570-
func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement, options ExecOptions) (driver.Result, error) {
591+
func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement, options ExecOptions) (SpannerResult, error) {
571592
if len(statements) == 0 {
572593
return &result{}, nil
573594
}
@@ -586,7 +607,7 @@ func (c *conn) execBatchDML(ctx context.Context, statements []spanner.Statement,
586607
return err
587608
}, options.TransactionOptions)
588609
}
589-
return &result{rowsAffected: sum(affected)}, err
610+
return &result{rowsAffected: sum(affected), batchUpdateCounts: affected}, err
590611
}
591612

592613
func sum(affected []int64) int64 {

conn_with_mockserver_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"cloud.google.com/go/spanner"
2525
"cloud.google.com/go/spanner/apiv1/spannerpb"
2626
"github.com/googleapis/go-sql-spanner/testutil"
27+
"google.golang.org/grpc/codes"
2728
)
2829

2930
func TestBeginTx(t *testing.T) {
@@ -48,6 +49,25 @@ func TestBeginTx(t *testing.T) {
4849
}
4950
}
5051

52+
func TestTwoTransactionsOnOneConn(t *testing.T) {
53+
t.Parallel()
54+
55+
db, _, teardown := setupTestDBConnection(t)
56+
defer teardown()
57+
ctx := context.Background()
58+
59+
c, _ := db.Conn(ctx)
60+
tx1, err := c.BeginTx(ctx, &sql.TxOptions{})
61+
defer tx1.Rollback()
62+
if err != nil {
63+
t.Fatal(err)
64+
}
65+
_, err = c.BeginTx(ctx, &sql.TxOptions{})
66+
if g, w := spanner.ErrCode(err), codes.FailedPrecondition; g != w {
67+
t.Fatalf("BeginTx error code mismatch\n Got: %v\nWant: %v", g, w)
68+
}
69+
}
70+
5171
func TestBeginTxWithIsolationLevel(t *testing.T) {
5272
t.Parallel()
5373

driver.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,10 @@ func BeginReadWriteTransaction(ctx context.Context, db *sql.DB, options ReadWrit
954954
// be active when we hit this point.
955955
go conn.Close()
956956
}
957+
return BeginReadWriteTransactionOnConn(ctx, conn, options)
958+
}
959+
960+
func BeginReadWriteTransactionOnConn(ctx context.Context, conn *sql.Conn, options ReadWriteTransactionOptions) (*sql.Tx, error) {
957961
if err := withTempReadWriteTransactionOptions(conn, &options); err != nil {
958962
return nil, err
959963
}
@@ -1005,6 +1009,10 @@ func BeginReadOnlyTransaction(ctx context.Context, db *sql.DB, options ReadOnlyT
10051009
// be active when we hit this point.
10061010
go conn.Close()
10071011
}
1012+
return BeginReadOnlyTransactionOnConn(ctx, conn, options)
1013+
}
1014+
1015+
func BeginReadOnlyTransactionOnConn(ctx context.Context, conn *sql.Conn, options ReadOnlyTransactionOptions) (*sql.Tx, error) {
10081016
if err := withTempReadOnlyTransactionOptions(conn, &options); err != nil {
10091017
return nil, err
10101018
}

driver_with_mockserver_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4801,9 +4801,8 @@ func setupTestDBConnectionWithParams(t *testing.T, params string) (db *sql.DB, s
48014801

48024802
func setupTestDBConnectionWithParamsAndDialect(t *testing.T, params string, dialect databasepb.DatabaseDialect) (db *sql.DB, server *testutil.MockedSpannerInMemTestServer, teardown func()) {
48034803
server, _, serverTeardown := setupMockedTestServerWithDialect(t, dialect)
4804-
db, err := sql.Open(
4805-
"spanner",
4806-
fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;%s", server.Address, params))
4804+
dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;%s", server.Address, params)
4805+
db, err := sql.Open("spanner", dsn)
48074806
if err != nil {
48084807
serverTeardown()
48094808
t.Fatal(err)

spannerlib/backend/db_pool.go

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,60 +3,35 @@ package backend
33
import (
44
"context"
55
"database/sql"
6-
"errors"
7-
"fmt"
8-
"sync"
96

107
spannerdriver "github.com/googleapis/go-sql-spanner"
118
)
129

10+
// Pool is a simple wrapper around sql.DB and contains a pool of connections.
1311
type Pool struct {
14-
Project string
15-
Instance string
16-
Database string
17-
18-
mu sync.Mutex
19-
entries map[string]*sql.DB
20-
}
21-
22-
func (pool *Pool) Close() (err error) {
23-
pool.mu.Lock()
24-
defer pool.mu.Unlock()
25-
for _, db := range pool.entries {
26-
err = errors.Join(err, db.Close())
27-
}
28-
return err
12+
db *sql.DB
2913
}
3014

31-
func (pool *Pool) Conn(ctx context.Context, project, instance, database string) (*sql.Conn, error) {
32-
if project == "" {
33-
project = pool.Project
34-
}
35-
if instance == "" {
36-
instance = pool.Instance
37-
}
38-
if database == "" {
39-
database = pool.Database
40-
}
41-
key := fmt.Sprintf("projects/%s/instances/%s/databases/%s", project, instance, database)
42-
pool.mu.Lock()
43-
defer pool.mu.Unlock()
44-
if db, ok := pool.entries[key]; ok {
45-
return db.Conn(ctx)
46-
}
47-
config := spannerdriver.ConnectorConfig{
48-
Project: project,
49-
Instance: instance,
50-
Database: database,
15+
func CreatePool(dsn string) (*Pool, error) {
16+
config, err := spannerdriver.ExtractConnectorConfig(dsn)
17+
if err != nil {
18+
return nil, err
5119
}
5220
connector, err := spannerdriver.CreateConnector(config)
5321
if err != nil {
5422
return nil, err
5523
}
5624
db := sql.OpenDB(connector)
57-
if pool.entries == nil {
58-
pool.entries = make(map[string]*sql.DB)
25+
pool := &Pool{
26+
db: db,
5927
}
60-
pool.entries[key] = db
61-
return db.Conn(ctx)
28+
return pool, nil
29+
}
30+
31+
func (pool *Pool) Close() (err error) {
32+
return pool.db.Close()
33+
}
34+
35+
func (pool *Pool) Conn(ctx context.Context) (*sql.Conn, error) {
36+
return pool.db.Conn(ctx)
6237
}

0 commit comments

Comments
 (0)