From c76238c807d12f35254974dc3f1699fd9168f35d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Knut=20Olav=20L=C3=B8ite?= Date: Thu, 23 Oct 2025 10:48:34 +0200 Subject: [PATCH] feat: support connect_timeout property Add support for a connect_timeout property that sets the maximum time that the driver will wait when creating a new connection to Spanner. This property only affects the very first creation of a connection to Spanner for a given connector. Once a connection has been established, all further connection creations will not use this value. Fixes #576 --- connection_properties.go | 9 ++++++++ connectionstate/converters.go | 2 +- driver.go | 11 +++++++++ driver_with_mockserver_test.go | 41 ++++++++++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 1 deletion(-) diff --git a/connection_properties.go b/connection_properties.go index 31218cd7..370c293e 100644 --- a/connection_properties.go +++ b/connection_properties.go @@ -461,6 +461,15 @@ var propertyDisableStatementCache = createConnectionProperty( connectionstate.ContextStartup, connectionstate.ConvertBool, ) +var propertyConnectTimeout = createConnectionProperty( + "connect_timeout", + "The amount of time to wait before timing out when creating a new connection.", + 0, + false, + nil, + connectionstate.ContextStartup, + connectionstate.ConvertDuration, +) // Generated read-only properties. These cannot be set by the user anywhere. var propertyCommitTimestamp = createReadOnlyConnectionProperty( diff --git a/connectionstate/converters.go b/connectionstate/converters.go index 19646a3a..8b950237 100644 --- a/connectionstate/converters.go +++ b/connectionstate/converters.go @@ -113,7 +113,7 @@ func parseTimestamp(re *regexp.Regexp, params string) (time.Time, error) { func parseDuration(re *regexp.Regexp, value string) (time.Duration, error) { matches := matchesToMap(re, value) if matches["duration"] == "" && matches["number"] == "" && matches["null"] == "" { - return 0, spanner.ToSpannerError(status.Error(codes.InvalidArgument, fmt.Sprintf("No duration found: %v", value))) + return 0, spanner.ToSpannerError(status.Error(codes.InvalidArgument, fmt.Sprintf("No or invalid duration found: %v", value))) } if matches["duration"] != "" { d, err := time.ParseDuration(matches["duration"]) diff --git a/driver.go b/driver.go index 5359f5ab..93f0c8a0 100644 --- a/driver.go +++ b/driver.go @@ -732,6 +732,17 @@ func openDriverConn(ctx context.Context, c *connector) (driver.Conn, error) { c.connectorConfig.Project, c.connectorConfig.Instance, c.connectorConfig.Database) + if value, ok := c.initialPropertyValues[propertyConnectTimeout.Key()]; ok { + if timeout, err := value.GetValue(); err == nil { + if duration, ok := timeout.(time.Duration); ok { + var cancel context.CancelFunc + // This will set the actual timeout of the context to the lower of the + // current context timeout (if any) and the value from the connection property. + ctx, cancel = context.WithTimeout(ctx, duration) + defer cancel() + } + } + } if err := c.increaseConnCount(ctx, databaseName, opts); err != nil { return nil, err diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index b059aa63..1ed093c8 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -5555,6 +5555,47 @@ func TestReturnResultSetMetadataAndStats(t *testing.T) { } } +func TestConnectTimeout(t *testing.T) { + t.Parallel() + + server, _, serverTeardown := setupMockedTestServerWithDialect(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) + defer serverTeardown() + db, err := sql.Open( + "spanner", + fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;connect_timeout=1ms", server.Address)) + if err != nil { + t.Fatal(err) + } + defer silentClose(db) + + // Make the ExecuteStreamingSql method a bit slow, so the query that is used to detect the dialect responds a bit slowly. + server.TestSpanner.PutExecutionTime(testutil.MethodExecuteStreamingSql, testutil.SimulatedExecutionTime{MinimumExecutionTime: time.Millisecond * 10}) + + // Try to get/create a connection using a context without a deadline. + // This will cause the connect_timeout to be used. + c, err := db.Conn(context.Background()) + if g, w := spanner.ErrCode(err), codes.DeadlineExceeded; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } else if c != nil { + _ = c.Close() + } +} + +func TestInvalidConnectTimeout(t *testing.T) { + t.Parallel() + + server, _, serverTeardown := setupMockedTestServerWithDialect(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL) + defer serverTeardown() + db, err := sql.Open( + "spanner", + fmt.Sprintf("%s/projects/p/instances/i/databases/d?useplaintext=true;connect_timeout='very long'", server.Address)) + if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w { + t.Fatalf("error code mismatch\n Got: %v\nWant: %v", g, w) + } else if db != nil { + defer silentClose(db) + } +} + func numeric(v string) big.Rat { res, _ := big.NewRat(1, 1).SetString(v) return *res