@@ -19,6 +19,7 @@ import (
1919 "database/sql"
2020 "database/sql/driver"
2121 "errors"
22+ "fmt"
2223 "log/slog"
2324 "slices"
2425 "sync"
@@ -831,6 +832,69 @@ func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, err
831832 return & stmt {conn : c , query : parsedSQL , numArgs : len (args ), execOptions : execOptions }, nil
832833}
833834
835+ // Adds any statement or transaction timeout to the given context. The deadline of the returned
836+ // context will be the earliest of:
837+ // 1. Any existing deadline on the input context.
838+ // 2. Any existing transaction deadline.
839+ // 3. A deadline calculated from the current time + the value of statement_timeout.
840+ func (c * conn ) addStatementAndTransactionTimeout (ctx context.Context ) (context.Context , context.CancelFunc , error ) {
841+ var statementDeadline time.Time
842+ var transactionDeadline time.Time
843+ var deadline time.Time
844+ var hasStatementDeadline bool
845+ var hasTransactionDeadline bool
846+
847+ // Check if the connection has a value for statement_timeout.
848+ statementTimeout := propertyStatementTimeout .GetValueOrDefault (c .state )
849+ if statementTimeout != time .Duration (0 ) {
850+ hasStatementDeadline = true
851+ statementDeadline = time .Now ().Add (statementTimeout )
852+ }
853+ // Check if the current transaction has a deadline.
854+ transactionDeadline , hasTransactionDeadline , err := c .transactionDeadline ()
855+ if err != nil {
856+ return nil , nil , err
857+ }
858+
859+ // If there is no statement_timeout and no current transaction deadline,
860+ // then can just use the input context as-is.
861+ if ! hasStatementDeadline && ! hasTransactionDeadline {
862+ return ctx , func () {}, nil
863+ }
864+
865+ // If there is both a transaction and a statement deadline, then we use the earliest
866+ // of those two.
867+ if hasTransactionDeadline && hasStatementDeadline {
868+ if statementDeadline .Before (transactionDeadline ) {
869+ deadline = statementDeadline
870+ } else {
871+ deadline = transactionDeadline
872+ }
873+ } else if hasStatementDeadline {
874+ deadline = statementDeadline
875+ } else {
876+ deadline = transactionDeadline
877+ }
878+ // context.WithDeadline automatically selects the earliest deadline of
879+ // the existing deadline on the context and the given deadline.
880+ newCtx , cancel := context .WithDeadline (ctx , deadline )
881+ return newCtx , cancel , nil
882+ }
883+
884+ // transactionDeadline returns the deadline of the current transaction
885+ // on the connection. This also activates the transaction if it is not
886+ // yet activated.
887+ func (c * conn ) transactionDeadline () (time.Time , bool , error ) {
888+ if c .tx == nil {
889+ return time.Time {}, false , nil
890+ }
891+ if err := c .tx .ensureActivated (); err != nil {
892+ return time.Time {}, false , err
893+ }
894+ deadline , hasDeadline := c .tx .deadline ()
895+ return deadline , hasDeadline , nil
896+ }
897+
834898func (c * conn ) QueryContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Rows , error ) {
835899 // Execute client side statement if it is one.
836900 clientStmt , err := c .parser .ParseClientSideStatement (query )
@@ -849,13 +913,22 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
849913 return c .queryContext (ctx , query , execOptions , args )
850914}
851915
852- func (c * conn ) queryContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
916+ func (c * conn ) queryContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedRows driver.Rows , returnedErr error ) {
917+ ctx , cancelCause := context .WithCancelCause (ctx )
918+ cancel := func () {
919+ cancelCause (nil )
920+ }
921+ defer func () {
922+ if returnedErr != nil {
923+ cancel ()
924+ }
925+ }()
853926 // Clear the commit timestamp of this connection before we execute the query.
854927 c .clearCommitResponse ()
855928 // Check if the execution options contains an instruction to execute
856929 // a specific partition of a PartitionedQuery.
857930 if pq := execOptions .PartitionedQueryOptions .ExecutePartition .PartitionedQuery ; pq != nil {
858- return pq .execute (ctx , execOptions .PartitionedQueryOptions .ExecutePartition .Index )
931+ return pq .execute (ctx , cancel , execOptions .PartitionedQueryOptions .ExecutePartition .Index )
859932 }
860933
861934 stmt , err := prepareSpannerStmt (c .parser , query , args )
@@ -869,7 +942,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
869942 if err != nil {
870943 return nil , err
871944 }
872- return createDriverResultRows (res , execOptions ), nil
945+ return createDriverResultRows (res , cancel , execOptions ), nil
873946 }
874947 var iter rowIterator
875948 if c .tx == nil {
@@ -884,7 +957,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
884957 } else if execOptions .PartitionedQueryOptions .PartitionQuery {
885958 return nil , spanner .ToSpannerError (status .Errorf (codes .FailedPrecondition , "PartitionQuery is only supported in batch read-only transactions" ))
886959 } else if execOptions .PartitionedQueryOptions .AutoPartitionQuery {
887- return c .executeAutoPartitionedQuery (ctx , query , execOptions , args )
960+ return c .executeAutoPartitionedQuery (ctx , cancel , query , execOptions , args )
888961 } else {
889962 // The statement was either detected as being a query, or potentially not recognized at all.
890963 // In that case, just default to using a single-use read-only transaction and let Spanner
@@ -893,25 +966,75 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec
893966 }
894967 } else {
895968 if execOptions .PartitionedQueryOptions .PartitionQuery {
969+ // The driver.Rows instance that is returned for partitionQuery does not
970+ // contain a context, and therefore also does not cancel the context when it is closed.
971+ defer cancel ()
896972 return c .tx .partitionQuery (ctx , stmt , execOptions )
897973 }
898974 iter , err = c .tx .Query (ctx , stmt , statementInfo .StatementType , execOptions )
899975 if err != nil {
900976 return nil , err
901977 }
902978 }
903- res := createRows (iter , execOptions )
979+ res := createRows (iter , cancel , execOptions )
904980 if execOptions .DirectExecuteQuery {
905- // This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
906- res .getColumns ()
907- if res .dirtyErr != nil && ! errors .Is (res .dirtyErr , iterator .Done ) {
908- _ = res .Close ()
909- return nil , res .dirtyErr
981+ if err := c .directExecuteQuery (ctx , cancelCause , res , execOptions ); err != nil {
982+ return nil , err
910983 }
911984 }
912985 return res , nil
913986}
914987
988+ // directExecuteQuery blocks until the first PartialResultSet has been returned by Spanner. Any statement_timeout and/or
989+ // transaction_timeout is used while waiting for the first result to be returned.
990+ func (c * conn ) directExecuteQuery (ctx context.Context , cancelQuery context.CancelCauseFunc , res * rows , execOptions * ExecOptions ) error {
991+ statementCtx := ctx
992+ if execOptions .DirectExecuteContext != nil {
993+ statementCtx = execOptions .DirectExecuteContext
994+ }
995+ // Add the statement or transaction deadline to the context.
996+ statementCtx , cancelStatement , err := c .addStatementAndTransactionTimeout (statementCtx )
997+ if err != nil {
998+ return err
999+ }
1000+ defer cancelStatement ()
1001+
1002+ // Asynchronously fetch the first partial result set from Spanner.
1003+ done := make (chan struct {})
1004+ go func () {
1005+ // Calling res.getColumns() ensures that the first PartialResultSet has been returned, as it contains the
1006+ // metadata of the query.
1007+ defer close (done )
1008+ res .getColumns ()
1009+ }()
1010+ // Wait until either the done channel is closed or the context is done.
1011+ var statementErr error
1012+ select {
1013+ case <- statementCtx .Done ():
1014+ statementErr = statementCtx .Err ()
1015+ // Cancel the query execution.
1016+ cancelQuery (statementCtx .Err ())
1017+ case <- done :
1018+ }
1019+
1020+ // Now wait until done channel is closed. This could be because the execution finished
1021+ // successfully, or because the context was cancelled, which again causes the execution
1022+ // to (eventually) fail.
1023+ <- done
1024+ if res .dirtyErr != nil && ! errors .Is (res .dirtyErr , iterator .Done ) {
1025+ _ = res .Close ()
1026+ if statementErr != nil {
1027+ // Create a status error from the statement error and wrap both the Spanner error and the status error into
1028+ // one error. This will preserve the DeadlineExceeded error code from statementErr, and include the request
1029+ // ID from the Spanner error.
1030+ s := status .FromContextError (statementErr )
1031+ return fmt .Errorf ("%w: %w" , s .Err (), res .dirtyErr )
1032+ }
1033+ return res .dirtyErr
1034+ }
1035+ return nil
1036+ }
1037+
9151038func (c * conn ) ExecContext (ctx context.Context , query string , args []driver.NamedValue ) (driver.Result , error ) {
9161039 // Execute client side statement if it is one.
9171040 stmt , err := c .parser .ParseClientSideStatement (query )
@@ -929,7 +1052,13 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
9291052 return c .execContext (ctx , query , execOptions , args )
9301053}
9311054
932- func (c * conn ) execContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Result , error ) {
1055+ func (c * conn ) execContext (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedResult driver.Result , returnedErr error ) {
1056+ // Add the statement/transaction deadline to the context.
1057+ ctx , cancel , err := c .addStatementAndTransactionTimeout (ctx )
1058+ if err != nil {
1059+ return nil , err
1060+ }
1061+ defer cancel ()
9331062 // Clear the commit timestamp of this connection before we execute the statement.
9341063 c .clearCommitResponse ()
9351064
@@ -1041,6 +1170,18 @@ func (c *conn) resetTransactionForRetry(ctx context.Context, errDuringCommit boo
10411170 return noTransaction ()
10421171 }
10431172 c .tx = c .prevTx
1173+ // If the aborted error happened during the Commit, then the transaction
1174+ // context has been cancelled, and we need to create a new one.
1175+ if rwTx , ok := c .tx .contextTransaction .(* readWriteTransaction ); ok {
1176+ newCtx , cancel := c .addTransactionTimeout (c .tx .ctx )
1177+ rwTx .ctx = newCtx
1178+ // Make sure that we cancel the new context when the transaction is closed.
1179+ origClose := rwTx .close
1180+ rwTx .close = func (result txResult , commitResponse * spanner.CommitResponse , commitErr error ) {
1181+ origClose (result , commitResponse , commitErr )
1182+ cancel ()
1183+ }
1184+ }
10441185 c .resetForRetry = true
10451186 } else if c .tx == nil {
10461187 return noTransaction ()
@@ -1248,6 +1389,17 @@ func (c *conn) beginTx(ctx context.Context, driverOpts driver.TxOptions, closeFu
12481389 return c .tx , nil
12491390}
12501391
1392+ // addTransactionTimeout creates a new derived context with the current transaction_timeout.
1393+ func (c * conn ) addTransactionTimeout (ctx context.Context ) (context.Context , context.CancelFunc ) {
1394+ timeout := propertyTransactionTimeout .GetValueOrDefault (c .state )
1395+ if timeout == time .Duration (0 ) {
1396+ return ctx , func () {}
1397+ }
1398+ // Note that this will set the actual deadline to the earliest of the existing deadline on ctx and the calculated
1399+ // deadline based on the timeout.
1400+ return context .WithTimeout (ctx , timeout )
1401+ }
1402+
12511403func (c * conn ) activateTransaction () (contextTransaction , error ) {
12521404 closeFunc := c .tx .close
12531405 if propertyTransactionReadOnly .GetValueOrDefault (c .state ) {
@@ -1283,19 +1435,23 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
12831435 opts := spanner.TransactionOptions {}
12841436 opts .BeginTransactionOption = c .convertDefaultBeginTransactionOption (propertyBeginTransactionOption .GetValueOrDefault (c .state ))
12851437
1286- tx , err := spanner .NewReadWriteStmtBasedTransactionWithCallbackForOptions (c .tx .ctx , c .client , opts , func () spanner.TransactionOptions {
1438+ // Add the current value of transaction_timeout to the context that is registered
1439+ // on the transaction.
1440+ ctx , cancel := c .addTransactionTimeout (c .tx .ctx )
1441+ tx , err := spanner .NewReadWriteStmtBasedTransactionWithCallbackForOptions (ctx , c .client , opts , func () spanner.TransactionOptions {
12871442 defer func () {
12881443 // Reset the transaction_tag after starting the transaction.
12891444 _ = propertyTransactionTag .ResetValue (c .state , connectionstate .ContextUser )
12901445 }()
12911446 return c .effectiveTransactionOptions (spannerpb .TransactionOptions_ISOLATION_LEVEL_UNSPECIFIED , c .options ( /*reset=*/ true ))
12921447 })
12931448 if err != nil {
1449+ cancel ()
12941450 return nil , err
12951451 }
12961452 logger := c .logger .With ("tx" , "rw" )
12971453 return & readWriteTransaction {
1298- ctx : c . tx . ctx ,
1454+ ctx : ctx ,
12991455 conn : c ,
13001456 logger : logger ,
13011457 rwTx : tx ,
@@ -1307,6 +1463,7 @@ func (c *conn) activateTransaction() (contextTransaction, error) {
13071463 } else {
13081464 closeFunc (txResultRollback )
13091465 }
1466+ cancel ()
13101467 },
13111468 retryAborts : sync .OnceValue (func () bool {
13121469 return c .RetryAbortsInternally ()
@@ -1371,7 +1528,15 @@ func queryInSingleUse(ctx context.Context, c *spanner.Client, statement spanner.
13711528 return c .Single ().WithTimestampBound (tb ).QueryWithOptions (ctx , statement , options .QueryOptions )
13721529}
13731530
1374- func (c * conn ) executeAutoPartitionedQuery (ctx context.Context , query string , execOptions * ExecOptions , args []driver.NamedValue ) (driver.Rows , error ) {
1531+ func (c * conn ) executeAutoPartitionedQuery (ctx context.Context , cancel context.CancelFunc , query string , execOptions * ExecOptions , args []driver.NamedValue ) (returnedRows driver.Rows , returnedErr error ) {
1532+ // The cancel() function is called by the returned Rows object when it is closed.
1533+ // However, if an error is returned instead of a Rows instance, we need to cancel
1534+ // the context when we return from this function.
1535+ defer func () {
1536+ if returnedErr != nil {
1537+ cancel ()
1538+ }
1539+ }()
13751540 tx , err := c .BeginTx (ctx , driver.TxOptions {ReadOnly : true , Isolation : withBatchReadOnly (driver .IsolationLevel (sql .LevelDefault ))})
13761541 if err != nil {
13771542 return nil , err
@@ -1383,6 +1548,7 @@ func (c *conn) executeAutoPartitionedQuery(ctx context.Context, query string, ex
13831548 }
13841549 if rows , ok := r .(* rows ); ok {
13851550 rows .close = func () error {
1551+ defer cancel ()
13861552 return tx .Commit ()
13871553 }
13881554 }
0 commit comments