From b9beeeafc00bdd6389c805cbfe026010d63b5c4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 23 Oct 2025 18:36:34 +0200 Subject: [PATCH] chore: create a derived context in gRPC Create a derived context of the original gRPC context when creating a context that is used for the actual statement execution. This ensures that any values in the original gRPC context is also available in the query context. The context.WithoutCancel(..) function ensures that cancelling either the derived or the parent context does not affect the other. --- conn.go | 40 +++++- driver.go | 9 +- spannerlib/api/connection.go | 23 ++-- spannerlib/grpc-server/server.go | 31 ++--- spannerlib/grpc-server/server_test.go | 174 ++++++++++++++++++++++++++ statements.go | 2 + 6 files changed, 245 insertions(+), 34 deletions(-) diff --git a/conn.go b/conn.go index 278cff0c..6215bdb2 100644 --- a/conn.go +++ b/conn.go @@ -872,6 +872,7 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec return createDriverResultRows(res, execOptions), nil } var iter rowIterator + ctx, cancel := context.WithCancelCause(ctx) if c.tx == nil { if statementInfo.StatementType == parser.StatementTypeDml { // Use a read/write transaction to execute the statement. @@ -902,14 +903,49 @@ func (c *conn) queryContext(ctx context.Context, query string, execOptions *Exec } res := createRows(iter, execOptions) if execOptions.DirectExecuteQuery { + if err := c.directExecuteQuery(cancel, res, execOptions); err != nil { + return nil, err + } + } + return res, nil +} + +func (c *conn) directExecuteQuery(cancel context.CancelCauseFunc, res *rows, execOptions *ExecOptions) error { + if execOptions.DirectExecuteContext == nil { // This call to res.getColumns() triggers the execution of the statement, as it needs to fetch the metadata. res.getColumns() if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) { _ = res.Close() - return nil, res.dirtyErr + return res.dirtyErr } + return nil } - return res, nil + + // Asynchronously fetch the first partial result set from Spanner. + done := make(chan struct{}) + go func() { + defer close(done) + res.getColumns() + }() + // Wait until either the done channel is closed or the context is done. + select { + case <-execOptions.DirectExecuteContext.Done(): + // Cancel the execution. + cancel(execOptions.DirectExecuteContext.Err()) + case <-done: + } + + // Now wait until done channel is closed. This could be because the execution finished + // successfully, or because the context was cancelled, which again causes the execution + // to (eventually) fail. + select { + case <-done: + } + if res.dirtyErr != nil && !errors.Is(res.dirtyErr, iterator.Done) { + _ = res.Close() + return res.dirtyErr + } + return nil } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { diff --git a/driver.go b/driver.go index 5359f5ab..6db6ba9b 100644 --- a/driver.go +++ b/driver.go @@ -204,12 +204,16 @@ type ExecOptions struct { // order to move to the result set that contains the spannerpb.ResultSetStats. ReturnResultSetStats bool - // DirectExecute determines whether a query is executed directly when the + // DirectExecuteQuery determines whether a query is executed directly when the // [sql.DB.QueryContext] method is called, or whether the actual query execution // is delayed until the first call to [sql.Rows.Next]. The default is to delay // the execution. Set this flag to true to execute the query directly when // [sql.DB.QueryContext] is called. DirectExecuteQuery bool + + // DirectExecuteContext is the context that is used for the execution of a query + // when DirectExecuteQuery is enabled. + DirectExecuteContext context.Context } func (dest *ExecOptions) merge(src *ExecOptions) { @@ -231,6 +235,9 @@ func (dest *ExecOptions) merge(src *ExecOptions) { if src.DirectExecuteQuery { dest.DirectExecuteQuery = src.DirectExecuteQuery } + if src.DirectExecuteContext != nil { + dest.DirectExecuteContext = src.DirectExecuteContext + } if src.AutocommitDMLMode != Unspecified { dest.AutocommitDMLMode = src.AutocommitDMLMode } diff --git a/spannerlib/api/connection.go b/spannerlib/api/connection.go index ab256879..22f911cd 100644 --- a/spannerlib/api/connection.go +++ b/spannerlib/api/connection.go @@ -66,6 +66,10 @@ func WriteMutations(ctx context.Context, poolId, connId int64, mutations *spanne // BeginTransaction starts a new transaction on the given connection. // A connection can have at most one transaction at any time. This function therefore returns an error if the // connection has an active transaction. +// +// NOTE: The context that is passed in to this function is registered as the transaction context. The transaction is +// invalidated if the context is cancelled. The context that is passed in to this function should therefore not be a +// context that is cancelled right after calling this function. func BeginTransaction(ctx context.Context, poolId, connId int64, txOpts *spannerpb.TransactionOptions) error { conn, err := findConnection(poolId, connId) if err != nil { @@ -93,11 +97,15 @@ func Rollback(ctx context.Context, poolId, connId int64) error { } func Execute(ctx context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) { + return ExecuteWithDirectExecuteContext(ctx, nil, poolId, connId, executeSqlRequest) +} + +func ExecuteWithDirectExecuteContext(ctx, directExecuteContext context.Context, poolId, connId int64, executeSqlRequest *spannerpb.ExecuteSqlRequest) (int64, error) { conn, err := findConnection(poolId, connId) if err != nil { return 0, err } - return conn.Execute(ctx, executeSqlRequest) + return conn.Execute(ctx, directExecuteContext, executeSqlRequest) } func ExecuteBatch(ctx context.Context, poolId, connId int64, statements *spannerpb.ExecuteBatchDmlRequest) (*spannerpb.ExecuteBatchDmlResponse, error) { @@ -300,16 +308,16 @@ func (conn *Connection) closeResults(ctx context.Context) { }) } -func (conn *Connection) Execute(ctx context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) { - return execute(ctx, conn, conn.backend, statement) +func (conn *Connection) Execute(ctx, directExecuteContext context.Context, statement *spannerpb.ExecuteSqlRequest) (int64, error) { + return execute(ctx, directExecuteContext, conn, conn.backend, statement) } func (conn *Connection) ExecuteBatch(ctx context.Context, statements []*spannerpb.ExecuteBatchDmlRequest_Statement) (*spannerpb.ExecuteBatchDmlResponse, error) { return executeBatch(ctx, conn, conn.backend, statements) } -func execute(ctx context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { - params := extractParams(statement) +func execute(ctx, directExecuteContext context.Context, conn *Connection, executor queryExecutor, statement *spannerpb.ExecuteSqlRequest) (int64, error) { + params := extractParams(directExecuteContext, statement) it, err := executor.QueryContext(ctx, statement.Sql, params...) if err != nil { return 0, err @@ -397,7 +405,7 @@ func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecut Params: statement.Params, ParamTypes: statement.ParamTypes, } - params := extractParams(request) + params := extractParams(nil, request) _, err := executor.ExecContext(ctx, statement.Sql, params...) if err != nil { return nil, err @@ -423,7 +431,7 @@ func executeBatchDml(ctx context.Context, conn *Connection, executor queryExecut return &response, nil } -func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { +func extractParams(directExecuteContext context.Context, statement *spannerpb.ExecuteSqlRequest) []any { paramsLen := 1 if statement.Params != nil { paramsLen = 1 + len(statement.Params.Fields) @@ -436,6 +444,7 @@ func extractParams(statement *spannerpb.ExecuteSqlRequest) []any { ReturnResultSetMetadata: true, ReturnResultSetStats: true, DirectExecuteQuery: true, + DirectExecuteContext: directExecuteContext, }) if statement.Params != nil { if statement.ParamTypes == nil { diff --git a/spannerlib/grpc-server/server.go b/spannerlib/grpc-server/server.go index 24c82786..206237c4 100644 --- a/spannerlib/grpc-server/server.go +++ b/spannerlib/grpc-server/server.go @@ -102,22 +102,10 @@ func (s *spannerLibServer) CloseConnection(ctx context.Context, connection *pb.C return &emptypb.Empty{}, nil } -func contextWithSameDeadline(ctx context.Context) context.Context { - newContext := context.Background() - if deadline, ok := ctx.Deadline(); ok { - // Ignore the returned cancel function here, as the context will be closed when the Rows object is closed. - //goland:noinspection GoVetLostCancel - newContext, _ = context.WithDeadline(newContext, deadline) - } - return newContext -} - -func (s *spannerLibServer) Execute(ctx context.Context, request *pb.ExecuteRequest) (*pb.Rows, error) { - // Create a new context that is used for the query. We need to do this, because the context that is passed in to - // this function will be cancelled once the RPC call finishes. That again would cause further calls to Next on the - // underlying rows object to fail with a 'Context cancelled' error. - queryContext := contextWithSameDeadline(ctx) - id, err := api.Execute(queryContext, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest) +func (s *spannerLibServer) Execute(ctx context.Context, request *pb.ExecuteRequest) (returnedRows *pb.Rows, returnedErr error) { + // Only use the context of the gRPC invocation for the DirectExecute option. That is: It is only used + // for fetching the first results, and can be cancelled after that. + id, err := api.ExecuteWithDirectExecuteContext(context.Background(), ctx, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest) if err != nil { return nil, err } @@ -125,12 +113,12 @@ func (s *spannerLibServer) Execute(ctx context.Context, request *pb.ExecuteReque } func (s *spannerLibServer) ExecuteStreaming(request *pb.ExecuteRequest, stream grpc.ServerStreamingServer[pb.RowData]) error { - queryContext := contextWithSameDeadline(stream.Context()) + queryContext := stream.Context() id, err := api.Execute(queryContext, request.Connection.Pool.Id, request.Connection.Id, request.ExecuteSqlRequest) if err != nil { return err } - defer func() { _ = api.CloseRows(queryContext, request.Connection.Pool.Id, request.Connection.Id, id) }() + defer func() { _ = api.CloseRows(context.Background(), request.Connection.Pool.Id, request.Connection.Id, id) }() rows := &pb.Rows{Connection: request.Connection, Id: id} metadata, err := api.Metadata(queryContext, request.Connection.Pool.Id, request.Connection.Id, id) if err != nil { @@ -214,12 +202,7 @@ func (s *spannerLibServer) BeginTransaction(ctx context.Context, request *pb.Beg // Create a new context that is used for the transaction. We need to do this, because the context that is passed in // to this function will be cancelled once the RPC call finishes. That again would cause further calls on // the underlying transaction to fail with a 'Context cancelled' error. - txContext := context.Background() - if deadline, ok := ctx.Deadline(); ok { - // Ignore the returned cancel function here, as the context will be closed when the transaction is closed. - //goland:noinspection GoVetLostCancel - txContext, _ = context.WithDeadline(txContext, deadline) - } + txContext := context.WithoutCancel(ctx) err := api.BeginTransaction(txContext, request.Connection.Pool.Id, request.Connection.Id, request.TransactionOptions) if err != nil { return nil, err diff --git a/spannerlib/grpc-server/server_test.go b/spannerlib/grpc-server/server_test.go index 35cc3396..f4526718 100644 --- a/spannerlib/grpc-server/server_test.go +++ b/spannerlib/grpc-server/server_test.go @@ -2,20 +2,26 @@ package main import ( "context" + "errors" "fmt" + "io" "net" "os" "path/filepath" "reflect" "runtime" "testing" + "time" + "cloud.google.com/go/spanner" "cloud.google.com/go/spanner/admin/database/apiv1/databasepb" sppb "cloud.google.com/go/spanner/apiv1/spannerpb" "github.com/google/uuid" "github.com/googleapis/go-sql-spanner/testutil" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" pb "spannerlib/grpc-server/google/spannerlib/v1" ) @@ -142,6 +148,42 @@ func TestExecute(t *testing.T) { } } +func TestExecuteWithTimeout(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + client, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool, err := client.CreatePool(ctx, &pb.CreatePoolRequest{ConnectionString: dsn}) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + connection, err := client.CreateConnection(ctx, &pb.CreateConnectionRequest{Pool: pool}) + if err != nil { + t.Fatalf("failed to create connection: %v", err) + } + + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 2 * time.Millisecond}) + withTimeout, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + _, err = client.Execute(withTimeout, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}, + }) + if g, w := status.Code(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if _, err := client.ClosePool(ctx, pool); err != nil { + t.Fatalf("failed to close pool: %v", err) + } +} + func TestExecuteStreaming(t *testing.T) { t.Parallel() ctx := context.Background() @@ -194,6 +236,49 @@ func TestExecuteStreaming(t *testing.T) { } } +func TestExecuteStreamingWithTimeout(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + client, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool, err := client.CreatePool(ctx, &pb.CreatePoolRequest{ConnectionString: dsn}) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + connection, err := client.CreateConnection(ctx, &pb.CreateConnectionRequest{Pool: pool}) + if err != nil { + t.Fatalf("failed to create connection: %v", err) + } + + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: 2 * time.Millisecond}) + withTimeout, cancel := context.WithTimeout(ctx, time.Millisecond) + defer cancel() + stream, err := client.ExecuteStreaming(withTimeout, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}, + }) + // The timeout can happen here or while waiting for the first response. + if err != nil { + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + } + _, err = stream.Recv() + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } + + if _, err := client.ClosePool(ctx, pool); err != nil { + t.Fatalf("failed to close pool: %v", err) + } +} + func TestExecuteStreamingClientSideStatement(t *testing.T) { t.Parallel() ctx := context.Background() @@ -251,6 +336,95 @@ func TestExecuteStreamingClientSideStatement(t *testing.T) { } } +func TestExecuteStreamingCustomSql(t *testing.T) { + t.Parallel() + ctx := context.Background() + + server, teardown := setupMockSpannerServer(t) + defer teardown() + dsn := fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true", server.Address) + + client, cleanup := startTestSpannerLibServer(t) + defer cleanup() + + pool, err := client.CreatePool(ctx, &pb.CreatePoolRequest{ConnectionString: dsn}) + if err != nil { + t.Fatalf("failed to create pool: %v", err) + } + connection, err := client.CreateConnection(ctx, &pb.CreateConnectionRequest{Pool: pool}) + if err != nil { + t.Fatalf("failed to create connection: %v", err) + } + + stream, err := client.ExecuteStreaming(ctx, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: "begin"}, + }) + if err != nil { + t.Fatalf("failed to execute: %v", err) + } + row, err := stream.Recv() + if err != nil { + t.Fatalf("failed to receive row: %v", err) + } + if g, w := len(row.Data), 0; g != w { + t.Fatalf("row data length mismatch\n Got: %v\nWant: %v", g, w) + } + if _, err := stream.Recv(); !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got: %v", err) + } + + stream, err = client.ExecuteStreaming(ctx, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: testutil.SelectFooFromBar}, + }) + if err != nil { + t.Fatalf("failed to execute: %v", err) + } + numRows := 0 + for { + row, err := stream.Recv() + if err != nil { + t.Fatalf("failed to receive row: %v", err) + } + if len(row.Data) == 0 { + break + } + if g, w := len(row.Data), 1; g != w { + t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w) + } + if g, w := len(row.Data[0].Values), 1; g != w { + t.Fatalf("num values mismatch\n Got: %v\nWant: %v", g, w) + } + numRows++ + } + if g, w := numRows, 2; g != w { + t.Fatalf("num rows mismatch\n Got: %v\nWant: %v", g, w) + } + + stream, err = client.ExecuteStreaming(ctx, &pb.ExecuteRequest{ + Connection: connection, + ExecuteSqlRequest: &sppb.ExecuteSqlRequest{Sql: "commit"}, + }) + if err != nil { + t.Fatalf("failed to execute: %v", err) + } + row, err = stream.Recv() + if err != nil { + t.Fatalf("failed to receive row: %v", err) + } + if g, w := len(row.Data), 0; g != w { + t.Fatalf("row data length mismatch\n Got: %v\nWant: %v", g, w) + } + if _, err := stream.Recv(); !errors.Is(err, io.EOF) { + t.Fatalf("expected io.EOF, got: %v", err) + } + + if _, err := client.ClosePool(ctx, pool); err != nil { + t.Fatalf("failed to close pool: %v", err) + } +} + func TestExecuteBatch(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/statements.go b/statements.go index 147146b5..72f3ed95 100644 --- a/statements.go +++ b/statements.go @@ -282,6 +282,8 @@ func (s *executableBeginStatement) execContext(ctx context.Context, c *conn, opt if len(s.stmt.Identifiers) != len(s.stmt.Literals) { return nil, status.Errorf(codes.InvalidArgument, "statement contains %d identifiers, but %d values given", len(s.stmt.Identifiers), len(s.stmt.Literals)) } + // Make sure the transaction context is not cancelled when this context is cancelled. + ctx = context.WithoutCancel(ctx) _, err := c.BeginTx(ctx, driver.TxOptions{}) if err != nil { return nil, err