Skip to content

Commit 4667c42

Browse files
committed
feat: support statement_timeout and transaction_timeout property
Add a statement_timeout connection property that is used as the default timeout for the execution of all statements that are executed on a connection. The timeout is only used for the actual execution, and not attached to the iterator that is returned for a query. This also means that a query that is executed without the DirectExecuteQuery option, will ignore the statement_timeout value. Also adds a transaction_timeout property that is additionally used for all statements in a read/write transaction. The deadline of the transaction is calculated at the start of the transaction, and all statements in the transaction get this deadline, unless the statement already has an earlier deadline from for example a statement_timeout or a context deadline. This change also fixes some issues with deadlines when using the gRPC API of SpannerLib. The context that is used for an RPC invocation is cancelled after the RPC has finished. This context should therefore not be used as the context for any query execution, as the context is attached to the row iterator, and would cancel the query execution halfway. Fixes #574 Fixes #575
1 parent ad05fde commit 4667c42

17 files changed

+698
-67
lines changed

conn.go

Lines changed: 180 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"database/sql"
2020
"database/sql/driver"
2121
"errors"
22+
"fmt"
2223
"log/slog"
2324
"slices"
2425
"sync"
@@ -831,6 +832,69 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
831832
return &stmt{conn: c, query: parsedSQL, numArgs: len(args), execOptions: execOptions}, nil
832833
}
833834

835+
// Adds any statement or transaction timeout to the given context. The deadline of the returned
836+
// context will be the earliest of:
837+
// 1. Any existing deadline on the input context.
838+
// 2. Any existing transaction deadline.
839+
// 3. A deadline calculated from the current time + the value of statement_timeout.
840+
func (c *conn) addStatementAndTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc, error) {
841+
var statementDeadline time.Time
842+
var transactionDeadline time.Time
843+
var deadline time.Time
844+
var hasStatementDeadline bool
845+
var hasTransactionDeadline bool
846+
847+
// Check if the connection has a value for statement_timeout.
848+
statementTimeout := propertyStatementTimeout.GetValueOrDefault(c.state)
849+
if statementTimeout != time.Duration(0) {
850+
hasStatementDeadline = true
851+
statementDeadline = time.Now().Add(statementTimeout)
852+
}
853+
// Check if the current transaction has a deadline.
854+
transactionDeadline, hasTransactionDeadline, err := c.transactionDeadline()
855+
if err != nil {
856+
return nil, nil, err
857+
}
858+
859+
// If there is no statement_timeout and no current transaction deadline,
860+
// then can just use the input context as-is.
861+
if !hasStatementDeadline && !hasTransactionDeadline {
862+
return ctx, func() {}, nil
863+
}
864+
865+
// If there is both a transaction and a statement deadline, then we use the earliest
866+
// of those two.
867+
if hasTransactionDeadline && hasStatementDeadline {
868+
if statementDeadline.Before(transactionDeadline) {
869+
deadline = statementDeadline
870+
} else {
871+
deadline = transactionDeadline
872+
}
873+
} else if hasStatementDeadline {
874+
deadline = statementDeadline
875+
} else {
876+
deadline = transactionDeadline
877+
}
878+
// context.WithDeadline automatically selects the earliest deadline of
879+
// the existing deadline on the context and the given deadline.
880+
newCtx, cancel := context.WithDeadline(ctx, deadline)
881+
return newCtx, cancel, nil
882+
}
883+
884+
// transactionDeadline returns the deadline of the current transaction
885+
// on the connection. This also activates the transaction if it is not
886+
// yet activated.
887+
func (c *conn) transactionDeadline() (time.Time, bool, error) {
888+
if c.tx == nil {
889+
return time.Time{}, false, nil
890+
}
891+
if err := c.tx.ensureActivated(); err != nil {
892+
return time.Time{}, false, err
893+
}
894+
deadline, hasDeadline := c.tx.deadline()
895+
return deadline, hasDeadline, nil
896+
}
897+
834898
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
835899
// Execute client side statement if it is one.
836900
clientStmt, err := c.parser.ParseClientSideStatement(query)
@@ -849,13 +913,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
849913
return c.queryContext(ctx, query, execOptions, args)
850914
}
851915

852-
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
916+
func (c *conn) queryContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
917+
ctx, cancelCause := context.WithCancelCause(ctx)
918+
cancel := func() {
919+
cancelCause(nil)
920+
}
921+
defer func() {
922+
if returnedErr != nil {
923+
cancel()
924+
}
925+
}()
853926
// Clear the commit timestamp of this connection before we execute the query.
854927
c.clearCommitResponse()
855928
// Check if the execution options contains an instruction to execute
856929
// a specific partition of a PartitionedQuery.
857930
if pq := execOptions.PartitionedQueryOptions.ExecutePartition.PartitionedQuery; pq != nil {
858-
return pq.execute(ctx, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
931+
return pq.execute(ctx, cancel, execOptions.PartitionedQueryOptions.ExecutePartition.Index)
859932
}
860933

861934
stmt, err := prepareSpannerStmt(c.parser, query, args)
@@ -869,7 +942,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
869942
if err != nil {
870943
return nil, err
871944
}
872-
return createDriverResultRows(res, execOptions), nil
945+
return createDriverResultRows(res, cancel, execOptions), nil
873946
}
874947
var iter rowIterator
875948
if c.tx == nil {
@@ -884,7 +957,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
884957
} else if execOptions.PartitionedQueryOptions.PartitionQuery {
885958
return nil, spanner.ToSpannerError(status.Errorf(codes.FailedPrecondition, "PartitionQuery is only supported in batch read-only transactions"))
886959
} else if execOptions.PartitionedQueryOptions.AutoPartitionQuery {
887-
return c.executeAutoPartitionedQuery(ctx, query, execOptions, args)
960+
return c.executeAutoPartitionedQuery(ctx, cancel, query, execOptions, args)
888961
} else {
889962
// The statement was either detected as being a query, or potentially not recognized at all.
890963
// In that case, just default to using a single-use read-only transaction and let Spanner
@@ -893,25 +966,75 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
893966
}
894967
} else {
895968
if execOptions.PartitionedQueryOptions.PartitionQuery {
969+
// The driver.Rows instance that is returned for partitionQuery does not
970+
// contain a context, and therefore also does not cancel the context when it is closed.
971+
defer cancel()
896972
return c.tx.partitionQuery(ctx, stmt, execOptions)
897973
}
898974
iter, err = c.tx.Query(ctx, stmt, statementInfo.StatementType, execOptions)
899975
if err != nil {
900976
return nil, err
901977
}
902978
}
903-
res := createRows(iter, execOptions)
979+
res := createRows(iter, cancel, execOptions)
904980
if execOptions.DirectExecuteQuery {
905-
// This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
906-
res.getColumns()
907-
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
908-
_ = res.Close()
909-
return nil, res.dirtyErr
981+
if err := c.directExecuteQuery(ctx, cancelCause, res, execOptions); err != nil {
982+
return nil, err
910983
}
911984
}
912985
return res, nil
913986
}
914987

988+
// directExecuteQuery blocks until the first PartialResultSet has been returned by Spanner. Any statement_timeout and/or
989+
// transaction_timeout is used while waiting for the first result to be returned.
990+
func (c *conn) directExecuteQuery(ctx context.Context, cancelQuery context.CancelCauseFunc, res *rows, execOptions *ExecOptions) error {
991+
statementCtx := ctx
992+
if execOptions.DirectExecuteContext != nil {
993+
statementCtx = execOptions.DirectExecuteContext
994+
}
995+
// Add the statement or transaction deadline to the context.
996+
statementCtx, cancelStatement, err := c.addStatementAndTransactionTimeout(statementCtx)
997+
if err != nil {
998+
return err
999+
}
1000+
defer cancelStatement()
1001+
1002+
// Asynchronously fetch the first partial result set from Spanner.
1003+
done := make(chan struct{})
1004+
go func() {
1005+
// Calling res.getColumns() ensures that the first PartialResultSet has been returned, as it contains the
1006+
// metadata of the query.
1007+
defer close(done)
1008+
res.getColumns()
1009+
}()
1010+
// Wait until either the done channel is closed or the context is done.
1011+
var statementErr error
1012+
select {
1013+
case <-statementCtx.Done():
1014+
statementErr = statementCtx.Err()
1015+
// Cancel the query execution.
1016+
cancelQuery(statementCtx.Err())
1017+
case <-done:
1018+
}
1019+
1020+
// Now wait until done channel is closed. This could be because the execution finished
1021+
// successfully, or because the context was cancelled, which again causes the execution
1022+
// to (eventually) fail.
1023+
<-done
1024+
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
1025+
_ = res.Close()
1026+
if statementErr != nil {
1027+
// Create a status error from the statement error and wrap both the Spanner error and the status error into
1028+
// one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request
1029+
// ID from the Spanner error.
1030+
s := status.FromContextError(statementErr)
1031+
return fmt.Errorf("%w: %w", s.Err(), res.dirtyErr)
1032+
}
1033+
return res.dirtyErr
1034+
}
1035+
return nil
1036+
}
1037+
9151038
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
9161039
// Execute client side statement if it is one.
9171040
stmt, err := c.parser.ParseClientSideStatement(query)
@@ -929,7 +1052,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
9291052
return c.execContext(ctx, query, execOptions, args)
9301053
}
9311054

932-
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Result, error) {
1055+
func (c *conn) execContext(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedResult driver.Result, returnedErr error) {
1056+
// Add the statement/transaction deadline to the context.
1057+
ctx, cancel, err := c.addStatementAndTransactionTimeout(ctx)
1058+
if err != nil {
1059+
return nil, err
1060+
}
1061+
defer cancel()
9331062
// Clear the commit timestamp of this connection before we execute the statement.
9341063
c.clearCommitResponse()
9351064

@@ -1041,6 +1170,18 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10411170
return noTransaction()
10421171
}
10431172
c.tx = c.prevTx
1173+
// If the aborted error happened during the Commit, then the transaction
1174+
// context has been cancelled, and we need to create a new one.
1175+
if rwTx, ok := c.tx.contextTransaction.(*readWriteTransaction); ok {
1176+
newCtx, cancel := c.addTransactionTimeout(c.tx.ctx)
1177+
rwTx.ctx = newCtx
1178+
// Make sure that we cancel the new context when the transaction is closed.
1179+
origClose := rwTx.close
1180+
rwTx.close = func(result txResult, commitResponse *spanner.CommitResponse, commitErr error) {
1181+
origClose(result, commitResponse, commitErr)
1182+
cancel()
1183+
}
1184+
}
10441185
c.resetForRetry = true
10451186
} else if c.tx == nil {
10461187
return noTransaction()
@@ -1248,6 +1389,17 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
12481389
return c.tx, nil
12491390
}
12501391

1392+
// addTransactionTimeout creates a new derived context with the current transaction_timeout.
1393+
func (c *conn) addTransactionTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
1394+
timeout := propertyTransactionTimeout.GetValueOrDefault(c.state)
1395+
if timeout == time.Duration(0) {
1396+
return ctx, func() {}
1397+
}
1398+
// Note that this will set the actual deadline to the earliest of the existing deadline on ctx and the calculated
1399+
// deadline based on the timeout.
1400+
return context.WithTimeout(ctx, timeout)
1401+
}
1402+
12511403
func (c *conn) activateTransaction() (contextTransaction, error) {
12521404
closeFunc := c.tx.close
12531405
if propertyTransactionReadOnly.GetValueOrDefault(c.state) {
@@ -1283,19 +1435,23 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
12831435
opts := spanner.TransactionOptions{}
12841436
opts.BeginTransactionOption = c.convertDefaultBeginTransactionOption(propertyBeginTransactionOption.GetValueOrDefault(c.state))
12851437

1286-
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(c.tx.ctx, c.client, opts, func() spanner.TransactionOptions {
1438+
// Add the current value of transaction_timeout to the context that is registered
1439+
// on the transaction.
1440+
ctx, cancel := c.addTransactionTimeout(c.tx.ctx)
1441+
tx, err := spanner.NewReadWriteStmtBasedTransactionWithCallbackForOptions(ctx, c.client, opts, func() spanner.TransactionOptions {
12871442
defer func() {
12881443
// Reset the transaction_tag after starting the transaction.
12891444
_ = propertyTransactionTag.ResetValue(c.state, connectionstate.ContextUser)
12901445
}()
12911446
return c.effectiveTransactionOptions(spannerpb.TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED, c.options( /*reset=*/ true))
12921447
})
12931448
if err != nil {
1449+
cancel()
12941450
return nil, err
12951451
}
12961452
logger := c.logger.With("tx", "rw")
12971453
return &readWriteTransaction{
1298-
ctx: c.tx.ctx,
1454+
ctx: ctx,
12991455
conn: c,
13001456
logger: logger,
13011457
rwTx: tx,
@@ -1307,6 +1463,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
13071463
} else {
13081464
closeFunc(txResultRollback)
13091465
}
1466+
cancel()
13101467
},
13111468
retryAborts: sync.OnceValue(func() bool {
13121469
return c.RetryAbortsInternally()
@@ -1371,7 +1528,15 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
13711528
return c.Single().WithTimestampBound(tb).QueryWithOptions(ctx, statement, options.QueryOptions)
13721529
}
13731530

1374-
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, execOptions *ExecOptions, args []driver.NamedValue) (driver.Rows, error) {
1531+
func (c *conn) executeAutoPartitionedQuery(ctx context.Context, cancel context.CancelFunc, query string, execOptions *ExecOptions, args []driver.NamedValue) (returnedRows driver.Rows, returnedErr error) {
1532+
// The cancel() function is called by the returned Rows object when it is closed.
1533+
// However, if an error is returned instead of a Rows instance, we need to cancel
1534+
// the context when we return from this function.
1535+
defer func() {
1536+
if returnedErr != nil {
1537+
cancel()
1538+
}
1539+
}()
13751540
tx, err := c.BeginTx(ctx, driver.TxOptions{ReadOnly: true, Isolation: withBatchReadOnly(driver.IsolationLevel(sql.LevelDefault))})
13761541
if err != nil {
13771542
return nil, err
@@ -1383,6 +1548,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
13831548
}
13841549
if rows, ok := r.(*rows); ok {
13851550
rows.close = func() error {
1551+
defer cancel()
13861552
return tx.Commit()
13871553
}
13881554
}

0 commit comments

Comments
 (0)