Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 38 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -872,13 +872,14 @@
return createDriverResultRows(res, execOptions), nil
}
var iter rowIterator
ctx, cancel := context.WithCancelCause(ctx)

Check failure on line 875 in conn.go

View workflow job for this annotation

GitHub Actions / lint

the cancel function is not used on all paths (possible context leak)
if c.tx == nil {
if statementInfo.StatementType == parser.StatementTypeDml {
// Use a read/write transaction to execute the statement.
var commitResponse *spanner.CommitResponse
iter, commitResponse, err = c.execSingleQueryTransactional(ctx, c.client, stmt, statementInfo, execOptions)
if err != nil {
return nil, err

Check failure on line 882 in conn.go

View workflow job for this annotation

GitHub Actions / lint

this return statement may be reached without using the cancel var defined on line 875
}
c.setCommitResponse(commitResponse)
} else if execOptions.PartitionedQueryOptions.PartitionQuery {
Expand All @@ -902,14 +903,49 @@
}
res := createRows(iter, execOptions)
if execOptions.DirectExecuteQuery {
if err := c.directExecuteQuery(cancel, res, execOptions); err != nil {
return nil, err
}
}
return res, nil
}

func (c *conn) directExecuteQuery(cancel context.CancelCauseFunc, res *rows, execOptions *ExecOptions) error {
if execOptions.DirectExecuteContext == nil {
// This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata.
res.getColumns()
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
_ = res.Close()
return nil, res.dirtyErr
return res.dirtyErr
}
return nil
}
return res, nil

// Asynchronously fetch the first partial result set from Spanner.
done := make(chan struct{})
go func() {
defer close(done)
res.getColumns()
}()
// Wait until either the done channel is closed or the context is done.
select {
case <-execOptions.DirectExecuteContext.Done():
// Cancel the execution.
cancel(execOptions.DirectExecuteContext.Err())
case <-done:
}

// Now wait until done channel is closed. This could be because the execution finished
// successfully, or because the context was cancelled, which again causes the execution
// to (eventually) fail.
select {
case <-done:
}
if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) {
_ = res.Close()
return res.dirtyErr
}
return nil
}

func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
Expand Down
9 changes: 8 additions & 1 deletion driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,16 @@ type ExecOptions struct {
// order to move to the result set that contains the spannerpb.ResultSetStats.
ReturnResultSetStats bool

// DirectExecute determines whether a query is executed directly when the
// DirectExecuteQuery determines whether a query is executed directly when the
// [sql.DB.QueryContext] method is called, or whether the actual query execution
// is delayed until the first call to [sql.Rows.Next]. The default is to delay
// the execution. Set this flag to true to execute the query directly when
// [sql.DB.QueryContext] is called.
DirectExecuteQuery bool

// DirectExecuteContext is the context that is used for the execution of a query
// when DirectExecuteQuery is enabled.
DirectExecuteContext context.Context
}

func (dest *ExecOptions) merge(src *ExecOptions) {
Expand All @@ -231,6 +235,9 @@ func (dest *ExecOptions) merge(src *ExecOptions) {
if src.DirectExecuteQuery {
dest.DirectExecuteQuery = src.DirectExecuteQuery
}
if src.DirectExecuteContext != nil {
dest.DirectExecuteContext = src.DirectExecuteContext
}
if src.AutocommitDMLMode != Unspecified {
dest.AutocommitDMLMode = src.AutocommitDMLMode
}
Expand Down
23 changes: 16 additions & 7 deletions spannerlib/api/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spanne
// BeginTransaction starts a new transaction on the given connection.
// A connection can have at most one transaction at any time. This function therefore returns an error if the
// connection has an active transaction.
//
// NOTE: The context that is passed in to this function is registered as the transaction context. The transaction is
// invalidated if the context is cancelled. The context that is passed in to this function should therefore not be a
// context that is cancelled right after calling this function.
func BeginTransaction(ctx context.Context, poolId, connId int64, txOpts *spannerpb.TransactionOptions) error {
conn, err := findConnection(poolId, connId)
if err != nil {
Expand Down Expand Up @@ -93,11 +97,15 @@ func Rollback(ctx context.Context, poolId, connId int64) error {
}

func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) {
return ExecuteWithDirectExecuteContext(ctx, nil, poolId, connId, executeSqlRequest)
}

func ExecuteWithDirectExecuteContext(ctx, directExecuteContext context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) {
conn, err := findConnection(poolId, connId)
if err != nil {
return 0, err
}
return conn.Execute(ctx, executeSqlRequest)
return conn.Execute(ctx, directExecuteContext, executeSqlRequest)
}

func ExecuteBatch(ctx context.Context, poolId, connId int64, statements *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) {
Expand Down Expand Up @@ -300,16 +308,16 @@ func (conn *Connection) closeResults(ctx context.Context) {
})
}

func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) {
return execute(ctx, conn, conn.backend, statement)
func (conn *Connection) Execute(ctx, directExecuteContext context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) {
return execute(ctx, directExecuteContext, conn, conn.backend, statement)
}

func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) {
return executeBatch(ctx, conn, conn.backend, statements)
}

func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) {
params := extractParams(statement)
func execute(ctx, directExecuteContext context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) {
params := extractParams(directExecuteContext, statement)
it, err := executor.QueryContext(ctx, statement.Sql, params...)
if err != nil {
return 0, err
Expand Down Expand Up @@ -397,7 +405,7 @@ func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecut
Params: statement.Params,
ParamTypes: statement.ParamTypes,
}
params := extractParams(request)
params := extractParams(nil, request)
_, err := executor.ExecContext(ctx, statement.Sql, params...)
if err != nil {
return nil, err
Expand All @@ -423,7 +431,7 @@ func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecut
return &response, nil
}

func extractParams(statement *spannerpb.ExecuteSqlRequest) []any {
func extractParams(directExecuteContext context.Context, statement *spannerpb.ExecuteSqlRequest) []any {
paramsLen := 1
if statement.Params != nil {
paramsLen = 1 + len(statement.Params.Fields)
Expand All @@ -436,6 +444,7 @@ func extractParams(statement *spannerpb.ExecuteSqlRequest) []any {
ReturnResultSetMetadata: true,
ReturnResultSetStats: true,
DirectExecuteQuery: true,
DirectExecuteContext: directExecuteContext,
})
if statement.Params != nil {
if statement.ParamTypes == nil {
Expand Down
31 changes: 7 additions & 24 deletions spannerlib/grpc-server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,35 +102,23 @@ func (s *spannerLibServer) CloseConnection(ctx context.Context, connection *pb.C
return &emptypb.Empty{}, nil
}

func contextWithSameDeadline(ctx context.Context) context.Context {
newContext := context.Background()
if deadline, ok := ctx.Deadline(); ok {
// Ignore the returned cancel function here, as the context will be closed when the Rows object is closed.
//goland:noinspection GoVetLostCancel
newContext, _ = context.WithDeadline(newContext, deadline)
}
return newContext
}

func (s *spannerLibServer) Execute(ctx context.Context, request *pb.ExecuteRequest) (*pb.Rows, error) {
// Create a new context that is used for the query. We need to do this, because the context that is passed in to
// this function will be cancelled once the RPC call finishes. That again would cause further calls to Next on the
// underlying rows object to fail with a 'Context cancelled' error.
queryContext := contextWithSameDeadline(ctx)
id, err := api.Execute(queryContext, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest)
func (s *spannerLibServer) Execute(ctx context.Context, request *pb.ExecuteRequest) (returnedRows *pb.Rows, returnedErr error) {
// Only use the context of the gRPC invocation for the DirectExecute option. That is: It is only used
// for fetching the first results, and can be cancelled after that.
id, err := api.ExecuteWithDirectExecuteContext(context.Background(), ctx, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest)
if err != nil {
return nil, err
}
return &pb.Rows{Connection: request.Connection, Id: id}, nil
}

func (s *spannerLibServer) ExecuteStreaming(request *pb.ExecuteRequest, stream grpc.ServerStreamingServer[pb.RowData]) error {
queryContext := contextWithSameDeadline(stream.Context())
queryContext := stream.Context()
id, err := api.Execute(queryContext, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest)
if err != nil {
return err
}
defer func() { _ = api.CloseRows(queryContext, request.Connection.Pool.Id, request.Connection.Id, id) }()
defer func() { _ = api.CloseRows(context.Background(), request.Connection.Pool.Id, request.Connection.Id, id) }()
rows := &pb.Rows{Connection: request.Connection, Id: id}
metadata, err := api.Metadata(queryContext, request.Connection.Pool.Id, request.Connection.Id, id)
if err != nil {
Expand Down Expand Up @@ -214,12 +202,7 @@ func (s *spannerLibServer) BeginTransaction(ctx context.Context, request *pb.Beg
// Create a new context that is used for the transaction. We need to do this, because the context that is passed in
// to this function will be cancelled once the RPC call finishes. That again would cause further calls on
// the underlying transaction to fail with a 'Context cancelled' error.
txContext := context.Background()
if deadline, ok := ctx.Deadline(); ok {
// Ignore the returned cancel function here, as the context will be closed when the transaction is closed.
//goland:noinspection GoVetLostCancel
txContext, _ = context.WithDeadline(txContext, deadline)
}
txContext := context.WithoutCancel(ctx)
err := api.BeginTransaction(txContext, request.Connection.Pool.Id, request.Connection.Id, request.TransactionOptions)
if err != nil {
return nil, err
Expand Down
Loading
Loading