From b98176e9a71a18949be7aae37dc7fe79fc63b5b1 Mon Sep 17 00:00:00 2001 From: Sam Schaub Date: Tue, 28 Oct 2025 21:30:51 +0000 Subject: [PATCH] feat: add keepalive feature to tear down streams in their absence --- .../v1/StreamingSubscriberConnection.java | 158 ++++++++++++ .../v1/StreamingSubscriberConnectionTest.java | 244 ++++++++++++++++++ 2 files changed, 402 insertions(+) diff --git a/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnection.java b/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnection.java index 319dd31f5..7cc6de22c 100644 --- a/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnection.java +++ b/google-cloud-pubsub/src/main/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnection.java @@ -63,6 +63,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; @@ -95,6 +96,7 @@ final class StreamingSubscriberConnection extends AbstractApiService implements private final SubscriberStub subscriberStub; private final int channelAffinity; + private final long protocolVersion; private final String subscription; private final SubscriptionName subscriptionNameObject; private final ScheduledExecutorService systemExecutor; @@ -127,6 +129,17 @@ final class StreamingSubscriberConnection extends AbstractApiService implements private OpenTelemetryPubsubTracer tracer = new OpenTelemetryPubsubTracer(null, false); private final SubscriberShutdownSettings subscriberShutdownSettings; + private final boolean enableKeepalive; + private static final long KEEP_ALIVE_SUPPORT_VERSION = 2; + private static final Duration CLIENT_PING_INTERVAL = Duration.ofSeconds(30); + private ScheduledFuture pingSchedulerHandle; + + private static final Duration SERVER_TIMEOUT_DURATION = Duration.ofSeconds(45); + private static final Duration SERVER_PING_TIMEOUT_DURATION = Duration.ofSeconds(15); + private final AtomicLong lastServerResponseTime; + private final AtomicLong lastClientPingTime; + private ScheduledFuture serverMonitorHandle; + private StreamingSubscriberConnection(Builder builder) { subscription = builder.subscription; subscriptionNameObject = SubscriptionName.parse(builder.subscription); @@ -154,6 +167,7 @@ private StreamingSubscriberConnection(Builder builder) { subscriberStub = builder.subscriberStub; channelAffinity = builder.channelAffinity; + protocolVersion = builder.protocolVersion; MessageDispatcher.Builder messageDispatcherBuilder; if (builder.receiver != null) { @@ -190,6 +204,9 @@ private StreamingSubscriberConnection(Builder builder) { flowControlSettings = builder.flowControlSettings; useLegacyFlowControl = builder.useLegacyFlowControl; + enableKeepalive = protocolVersion >= KEEP_ALIVE_SUPPORT_VERSION; + lastServerResponseTime = new AtomicLong(clock.nanoTime()); + lastClientPingTime = new AtomicLong(clock.nanoTime()); } public StreamingSubscriberConnection setExactlyOnceDeliveryEnabled( @@ -218,6 +235,12 @@ protected void doStop() { } finally { lock.unlock(); } + + if (enableKeepalive) { + stopClientPinger(); + stopServerMonitor(); + } + runShutdown(); notifyStopped(); } @@ -266,6 +289,10 @@ public void onStart(StreamController controller) { @Override public void onResponse(StreamingPullResponse response) { + if (enableKeepalive) { + lastServerResponseTime.set(clock.nanoTime()); + } + channelReconnectBackoffMillis.set(INITIAL_CHANNEL_RECONNECT_BACKOFF.toMillis()); boolean exactlyOnceDeliveryEnabledResponse = @@ -295,11 +322,19 @@ public void onResponse(StreamingPullResponse response) { @Override public void onError(Throwable t) { + if (enableKeepalive) { + stopClientPinger(); + stopServerMonitor(); + } errorFuture.setException(t); } @Override public void onComplete() { + if (enableKeepalive) { + stopClientPinger(); + stopServerMonitor(); + } logger.fine("Streaming pull terminated successfully!"); errorFuture.set(null); } @@ -336,6 +371,7 @@ private void initialize() { this.useLegacyFlowControl ? 0 : valueOrZero(flowControlSettings.getMaxOutstandingRequestBytes())) + .setProtocolVersion(protocolVersion) .build()); /** @@ -350,6 +386,11 @@ private void initialize() { lock.unlock(); } + if (enableKeepalive) { + startClientPinger(); + startServerMonitor(); + } + ApiFutures.addCallback( errorFuture, new ApiFutureCallback() { @@ -366,6 +407,10 @@ public void onSuccess(@Nullable Void result) { @Override public void onFailure(Throwable cause) { + if (enableKeepalive) { + stopClientPinger(); + stopServerMonitor(); + } if (!isAlive()) { // we don't care about subscription failures when we're no longer running. logger.log(Level.FINE, "pull failure after service no longer running", cause); @@ -410,6 +455,113 @@ private boolean isAlive() { return state == State.RUNNING || state == State.STARTING; } + private void startClientPinger() { + if (pingSchedulerHandle != null) { + pingSchedulerHandle.cancel(false); + } + + pingSchedulerHandle = + systemExecutor.scheduleAtFixedRate( + () -> { + try { + lock.lock(); + try { + if (clientStream != null && isAlive()) { + clientStream.send(StreamingPullRequest.newBuilder().build()); + lastClientPingTime.set(clock.nanoTime()); + logger.log(Level.FINEST, "Sent client keepalive ping"); + } + } finally { + lock.unlock(); + } + } catch (Exception e) { + logger.log(Level.FINE, "Error sending client keepalive ping", e); + } + }, + CLIENT_PING_INTERVAL.getSeconds(), + CLIENT_PING_INTERVAL.getSeconds(), + TimeUnit.SECONDS); + } + + private void stopClientPinger() { + if (pingSchedulerHandle != null) { + pingSchedulerHandle.cancel(false); + pingSchedulerHandle = null; + } + } + + private void startServerMonitor() { + if (serverMonitorHandle != null) { + serverMonitorHandle.cancel(false); + } + + Duration checkInterval = Duration.ofSeconds(15); + serverMonitorHandle = + systemExecutor.scheduleAtFixedRate( + () -> { + try { + if (!isAlive()) { + return; + } + + long now = clock.nanoTime(); + long lastResponse = lastServerResponseTime.get(); + Duration elapsedSinceResponse = Duration.ofNanos(now - lastResponse); + + if (elapsedSinceResponse.compareTo(SERVER_TIMEOUT_DURATION) <= 0) { + return; + } + + long lastPing = lastClientPingTime.get(); + if (lastPing > lastResponse) { + Duration elapsedSincePing = Duration.ofNanos(now - lastPing); + if (elapsedSincePing.compareTo(SERVER_PING_TIMEOUT_DURATION) <= 0) { + // waiting for response from ping + return; + } + + logger.log( + Level.WARNING, + "No response from server for {0} seconds, and no response to ping sent {1}" + + " seconds ago. Closing stream.", + new Object[] { + elapsedSinceResponse.getSeconds(), elapsedSincePing.getSeconds() + }); + } else { + logger.log( + Level.WARNING, + "No response from server for {0} seconds. Closing stream.", + elapsedSinceResponse.getSeconds()); + } + + lock.lock(); + try { + if (clientStream != null) { + clientStream.closeSendWithError( + Status.UNAVAILABLE + .withDescription("Keepalive timeout with server") + .asException()); + } + } finally { + lock.unlock(); + } + stopServerMonitor(); + } catch (Exception e) { + logger.log(Level.FINE, "Error in server keepaliver monitor", e); + } + }, + checkInterval.getSeconds(), + checkInterval.getSeconds(), + TimeUnit.SECONDS); + } + + private void stopServerMonitor() { + if (serverMonitorHandle != null) { + serverMonitorHandle.cancel(false); + serverMonitorHandle = null; + } + } + public void setResponseOutstandingMessages(AckResponse ackResponse) { // We will close the futures with ackResponse - if there are multiple references to the same // future they will be handled appropriately @@ -769,6 +921,7 @@ public static final class Builder { private Distribution ackLatencyDistribution; private SubscriberStub subscriberStub; private int channelAffinity; + private long protocolVersion; private FlowController flowController; private FlowControlSettings flowControlSettings; private boolean useLegacyFlowControl; @@ -840,6 +993,11 @@ public Builder setChannelAffinity(int channelAffinity) { return this; } + public Builder setProtocolVersion(long protocolVersion) { + this.protocolVersion = protocolVersion; + return this; + } + public Builder setFlowController(FlowController flowController) { this.flowController = flowController; return this; diff --git a/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnectionTest.java b/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnectionTest.java index f79825d85..676b7172c 100644 --- a/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnectionTest.java +++ b/google-cloud-pubsub/src/test/java/com/google/cloud/pubsub/v1/StreamingSubscriberConnectionTest.java @@ -28,12 +28,18 @@ import com.google.api.gax.core.Distribution; import com.google.api.gax.grpc.GrpcStatusCode; import com.google.api.gax.rpc.ApiException; +import com.google.api.gax.rpc.BidiStreamingCallable; +import com.google.api.gax.rpc.ClientStream; +import com.google.api.gax.rpc.ResponseObserver; import com.google.api.gax.rpc.StatusCode; +import com.google.api.gax.rpc.StreamController; import com.google.cloud.pubsub.v1.stub.SubscriberStub; import com.google.common.collect.Lists; import com.google.protobuf.Any; import com.google.pubsub.v1.AcknowledgeRequest; import com.google.pubsub.v1.ModifyAckDeadlineRequest; +import com.google.pubsub.v1.StreamingPullRequest; +import com.google.pubsub.v1.StreamingPullResponse; import com.google.rpc.ErrorInfo; import com.google.rpc.Status; import io.grpc.Status.Code; @@ -44,11 +50,13 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TestName; +import org.mockito.ArgumentCaptor; /** Tests for {@link StreamingSubscriberConnection}. */ public class StreamingSubscriberConnectionTest { @@ -86,6 +94,11 @@ public class StreamingSubscriberConnectionTest { private static Duration ACK_EXPIRATION_PADDING_DEFAULT_DURATION = Duration.ofSeconds(10); private static int MAX_DURATION_PER_ACK_EXTENSION_DEFAULT_SECONDS = 10; + private static final long KEEP_ALIVE_SUPPORT_VERSION = 2; + private static final Duration CLIENT_PING_INTERVAL = Duration.ofSeconds(30); + private static final Duration SERVER_TIMEOUT_DURATION = Duration.ofSeconds(45); + private static final Duration MAX_ACK_EXTENSION_PERIOD = Duration.ofMinutes(60); + @Before public void setUp() { systemExecutor = new FakeScheduledExecutorService(); @@ -670,6 +683,227 @@ public void testMaxPerRequestChanges() { } } + @Test + public void testClientPinger_pingSent() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(any(ResponseObserver.class), any())) + .thenReturn(mockClientStream); + + StreamingSubscriberConnection streamingSubscriberConnection = + getKeepaliveStreamingSubscriberConnection(); + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(StreamingPullRequest.class); + // 1 initial request + 2 pings + verify(mockClientStream, times(3)).send(requestCaptor.capture()); + List requests = requestCaptor.getAllValues(); + + StreamingPullRequest initialRequest = requests.get(0); + assertEquals(MOCK_SUBSCRIPTION_NAME, initialRequest.getSubscription()); + assertEquals(KEEP_ALIVE_SUPPORT_VERSION, initialRequest.getProtocolVersion()); + assertEquals(0, initialRequest.getMaxOutstandingMessages()); + + StreamingPullRequest firstPing = requests.get(1); + assertEquals(StreamingPullRequest.getDefaultInstance(), firstPing); + + StreamingPullRequest secondPing = requests.get(2); + assertEquals(StreamingPullRequest.getDefaultInstance(), secondPing); + + streamingSubscriberConnection.stopAsync(); + streamingSubscriberConnection.awaitTerminated(); + + // No more pings + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + verify(mockClientStream, times(3)).send(any(StreamingPullRequest.class)); + } + + @Test + public void testClientPinger_pingsNotSentWhenDisabled() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(any(ResponseObserver.class), any())) + .thenReturn(mockClientStream); + + StreamingSubscriberConnection streamingSubscriberConnection = + getStreamingSubscriberConnection(false); // keepalive disabled + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + // Initial request. + verify(mockClientStream, times(1)).send(any(StreamingPullRequest.class)); + + // No pings + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + systemExecutor.advanceTime(CLIENT_PING_INTERVAL); + + verify(mockClientStream, times(1)).send(any(StreamingPullRequest.class)); + } + + @Test + public void testServerMonitor_timesOut() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + ArgumentCaptor> observerCaptor = + ArgumentCaptor.forClass(ResponseObserver.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(observerCaptor.capture(), any())) + .thenReturn(mockClientStream); + + AtomicInteger pingCount = new AtomicInteger(0); + doAnswer( + (invocation) -> { + StreamingPullRequest req = invocation.getArgument(0); + // Pings are empty requests + if (req.getSubscription().isEmpty()) { + if (pingCount.incrementAndGet() > 1) { + throw new RuntimeException("ping failed"); + } + } + return null; + }) + .when(mockClientStream) + .send(any(StreamingPullRequest.class)); + + StreamingSubscriberConnection streamingSubscriberConnection = + getKeepaliveStreamingSubscriberConnection(); + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + ResponseObserver observer = observerCaptor.getValue(); + StreamController mockController = mock(StreamController.class); + observer.onStart(mockController); + + // Should not time out yet + // a ping will be sent at 30s. a monitor check will happen at 45s. + // last ping was 15s ago so no timeout. + systemExecutor.advanceTime(SERVER_TIMEOUT_DURATION); + verify(mockClientStream, never()).closeSendWithError(any(Exception.class)); + + // first ping at 30s ok. second at 60s fails. + // monitor check at 60s. last ping 30s > 15s. timeout. + systemExecutor.advanceTime(Duration.ofSeconds(30)); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(mockClientStream, times(1)).closeSendWithError(exceptionCaptor.capture()); + StatusException exception = (StatusException) exceptionCaptor.getValue(); + assertEquals(Code.UNAVAILABLE, exception.getStatus().getCode()); + assertEquals("Keepalive timeout with server", exception.getStatus().getDescription()); + } + + @Test + public void testServerMonitor_timesOutWhenPingsFail() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + ArgumentCaptor> observerCaptor = + ArgumentCaptor.forClass(ResponseObserver.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(observerCaptor.capture(), any())) + .thenReturn(mockClientStream); + + AtomicInteger pingCount = new AtomicInteger(0); + doAnswer( + (invocation) -> { + StreamingPullRequest req = invocation.getArgument(0); + // Pings are empty requests + if (req.getSubscription().isEmpty()) { + if (pingCount.incrementAndGet() > 1) { + throw new RuntimeException("ping failed"); + } + } + return null; + }) + .when(mockClientStream) + .send(any(StreamingPullRequest.class)); + + StreamingSubscriberConnection streamingSubscriberConnection = + getKeepaliveStreamingSubscriberConnection(); + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + ResponseObserver observer = observerCaptor.getValue(); + StreamController mockController = mock(StreamController.class); + observer.onStart(mockController); + + // First ping at 30s will succeed. + // Pings at 60s and 90s will fail to send. + // Monitor checks every 15s. + // At 90s, last successful ping was 60s ago. Timeout should occur. + systemExecutor.advanceTime(Duration.ofSeconds(90)); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(mockClientStream, times(1)).closeSendWithError(exceptionCaptor.capture()); + StatusException exception = (StatusException) exceptionCaptor.getValue(); + assertEquals(Code.UNAVAILABLE, exception.getStatus().getCode()); + assertEquals("Keepalive timeout with server", exception.getStatus().getDescription()); + } + + @Test + public void testServerMonitor_doesNotTimeOutWithResponse() { + BidiStreamingCallable mockStreamingCallable = + mock(BidiStreamingCallable.class); + ClientStream mockClientStream = mock(ClientStream.class); + ArgumentCaptor> observerCaptor = + ArgumentCaptor.forClass(ResponseObserver.class); + when(mockSubscriberStub.streamingPullCallable()).thenReturn(mockStreamingCallable); + when(mockStreamingCallable.splitCall(observerCaptor.capture(), any())) + .thenReturn(mockClientStream); + + AtomicInteger pingCount = new AtomicInteger(0); + doAnswer( + (invocation) -> { + StreamingPullRequest req = invocation.getArgument(0); + // Pings are empty requests + if (req.getSubscription().isEmpty()) { + if (pingCount.incrementAndGet() > 2) { + throw new RuntimeException("ping failed"); + } + } + return null; + }) + .when(mockClientStream) + .send(any(StreamingPullRequest.class)); + + StreamingSubscriberConnection streamingSubscriberConnection = + getKeepaliveStreamingSubscriberConnection(); + + streamingSubscriberConnection.startAsync(); + streamingSubscriberConnection.awaitRunning(); + + ResponseObserver observer = observerCaptor.getValue(); + StreamController mockController = mock(StreamController.class); + observer.onStart(mockController); + + // Advance 30s and simulate a response. + systemExecutor.advanceTime(Duration.ofSeconds(30)); + observer.onResponse(StreamingPullResponse.getDefaultInstance()); + + // last response at 30s. pings at 30s, 60s. + // checks at 45s (elapsed 15), 60s (elapsed 30) 75s (elapsed 45) + // should not timeout + systemExecutor.advanceTime(Duration.ofSeconds(45)); + verify(mockClientStream, never()).closeSendWithError(any(Exception.class)); + + // last response at 30s. pings at 30s and 60s ok. next check at 90s fails. + // elapsed since response is 60s. elasped since ping is 30s. Should timeout. + systemExecutor.advanceTime(Duration.ofSeconds(30)); + verify(mockClientStream, times(1)).closeSendWithError(any(Exception.class)); + } + private StreamingSubscriberConnection getStreamingSubscriberConnection( boolean exactlyOnceDeliveryEnabled) { StreamingSubscriberConnection streamingSubscriberConnection = @@ -682,11 +916,21 @@ private StreamingSubscriberConnection getStreamingSubscriberConnection( return streamingSubscriberConnection; } + private StreamingSubscriberConnection getKeepaliveStreamingSubscriberConnection() { + StreamingSubscriberConnection streamingSubscriberConnection = + getStreamingSubscriberConnectionFromBuilder( + StreamingSubscriberConnection.newBuilder(mock(MessageReceiverWithAckResponse.class)) + .setProtocolVersion(KEEP_ALIVE_SUPPORT_VERSION)); + + return streamingSubscriberConnection; + } + private StreamingSubscriberConnection getStreamingSubscriberConnectionFromBuilder( StreamingSubscriberConnection.Builder builder) { return builder .setSubscription(MOCK_SUBSCRIPTION_NAME) .setAckExpirationPadding(ACK_EXPIRATION_PADDING_DEFAULT_DURATION) + .setMaxAckExtensionPeriod(MAX_ACK_EXTENSION_PERIOD) .setAckLatencyDistribution(mock(Distribution.class)) .setSubscriberStub(mockSubscriberStub) .setChannelAffinity(0)