diff --git a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy index 7db9e56e7194..260dada6e934 100644 --- a/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy +++ b/buildSrc/src/main/groovy/org/apache/beam/gradle/BeamModulePlugin.groovy @@ -727,6 +727,7 @@ class BeamModulePlugin implements Plugin { commons_math3 : "org.apache.commons:commons-math3:3.6.1", dbcp2 : "org.apache.commons:commons-dbcp2:$dbcp2_version", error_prone_annotations : "com.google.errorprone:error_prone_annotations:$errorprone_version", + envoy_control_plane_api : "io.envoyproxy.controlplane:api:1.0.49", failsafe : "dev.failsafe:failsafe:3.3.0", flogger_system_backend : "com.google.flogger:flogger-system-backend:0.7.4", gax : "com.google.api:gax", // google_cloud_platform_libraries_bom sets version diff --git a/examples/java/build.gradle b/examples/java/build.gradle index 5334538cc09f..8ae0b47ef923 100644 --- a/examples/java/build.gradle +++ b/examples/java/build.gradle @@ -54,6 +54,7 @@ dependencies { implementation project(":sdks:java:extensions:python") implementation project(":sdks:java:io:google-cloud-platform") implementation project(":sdks:java:io:kafka") + implementation project(":sdks:java:io:components") implementation project(":sdks:java:extensions:ml") implementation library.java.avro implementation library.java.bigdataoss_util diff --git a/examples/java/src/main/java/org/apache/beam/examples/RateLimiterSimple.java b/examples/java/src/main/java/org/apache/beam/examples/RateLimiterSimple.java new file mode 100644 index 000000000000..89e2d5d06802 --- /dev/null +++ b/examples/java/src/main/java/org/apache/beam/examples/RateLimiterSimple.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.examples; + +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.io.components.ratelimiter.EnvoyRateLimiterContext; +import org.apache.beam.sdk.io.components.ratelimiter.EnvoyRateLimiterFactory; +import org.apache.beam.sdk.io.components.ratelimiter.RateLimiter; +import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterContext; +import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterFactory; +import org.apache.beam.sdk.io.components.ratelimiter.RateLimiterOptions; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A simple example demonstrating how to use the {@link RateLimiter} in a custom {@link DoFn}. + * + *

This pipeline creates a small set of elements and processes them using a DoFn that calls an + * external service (simulated). The processing is rate-limited using an Envoy Rate Limit Service. + * + *

To run this example, you need a running Envoy Rate Limit Service. + */ +public class RateLimiterSimple { + + public interface Options extends PipelineOptions { + @Description("Address of the Envoy Rate Limit Service(eg:localhost:8081)") + String getRateLimiterAddress(); + + void setRateLimiterAddress(String value); + + @Description("Domain for the Rate Limit Service(eg:mydomain)") + String getRateLimiterDomain(); + + void setRateLimiterDomain(String value); + } + + static class CallExternalServiceFn extends DoFn { + private final String rlsAddress; + private final String rlsDomain; + private transient @Nullable RateLimiter rateLimiter; + private static final Logger LOG = LoggerFactory.getLogger(CallExternalServiceFn.class); + + public CallExternalServiceFn(String rlsAddress, String rlsDomain) { + this.rlsAddress = rlsAddress; + this.rlsDomain = rlsDomain; + } + + @Setup + public void setup() { + // Create the RateLimiterOptions. + RateLimiterOptions options = RateLimiterOptions.builder().setAddress(rlsAddress).build(); + + // Static RateLimtier with pre-configured domain and descriptors + RateLimiterFactory factory = new EnvoyRateLimiterFactory(options); + RateLimiterContext context = + EnvoyRateLimiterContext.builder() + .setDomain(rlsDomain) + .addDescriptor("database", "users") + .build(); + this.rateLimiter = factory.getLimiter(context); + } + + @Teardown + public void teardown() { + if (rateLimiter != null) { + try { + rateLimiter.close(); + } catch (Exception e) { + throw new RuntimeException("Failed to close RateLimiter", e); + } + } + } + + @ProcessElement + public void processElement(ProcessContext c) throws Exception { + String element = c.element(); + try { + Preconditions.checkNotNull(rateLimiter).allow(1); + } catch (Exception e) { + throw new RuntimeException("Failed to acquire rate limit token", e); + } + + // Simulate external API call + LOG.info("Processing: " + element); + Thread.sleep(100); + c.output("Processed: " + element); + } + } + + public 
static void main(String[] args) { + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + Pipeline p = Pipeline.create(options); + + p.apply( + "CreateItems", + Create.of( + IntStream.range(0, 100).mapToObj(i -> "item" + i).collect(Collectors.toList()))) + .apply( + "CallExternalService", + ParDo.of( + new CallExternalServiceFn( + options.getRateLimiterAddress(), options.getRateLimiterDomain()))); + + p.run().waitUntilFinish(); + } +} diff --git a/sdks/java/io/components/build.gradle b/sdks/java/io/components/build.gradle index 25bf95772110..b19c58e658c7 100644 --- a/sdks/java/io/components/build.gradle +++ b/sdks/java/io/components/build.gradle @@ -18,7 +18,7 @@ plugins { id 'org.apache.beam.module' } applyJavaNature( - automaticModuleName: 'org.apache.beam.sdk.io.components', + automaticModuleName: 'org.apache.beam.sdk.io.components', ) description = "Apache Beam :: SDKs :: Java :: IO :: Components" @@ -26,6 +26,11 @@ ext.summary = "Components for building fully featured IOs" dependencies { implementation project(path: ":sdks:java:core", configuration: "shadow") + implementation library.java.auto_value_annotations + implementation library.java.envoy_control_plane_api + implementation library.java.grpc_api + implementation library.java.grpc_stub + implementation library.java.grpc_protobuf implementation library.java.protobuf_java permitUnusedDeclared library.java.protobuf_java // BEAM-11761 implementation library.java.slf4j_api diff --git a/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiter.java b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiter.java new file mode 100644 index 000000000000..9fc3da80dca4 --- /dev/null +++ b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiter.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more 
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import java.io.IOException; + +/** + * A lightweight handle for an Envoy-based rate limiter. + * + *

Delegates work to the {@link EnvoyRateLimiterFactory} using the baked-in {@link + * EnvoyRateLimiterContext}. + */ +public class EnvoyRateLimiter implements RateLimiter { + private final EnvoyRateLimiterFactory factory; + private final EnvoyRateLimiterContext context; + + public EnvoyRateLimiter(EnvoyRateLimiterFactory factory, EnvoyRateLimiterContext context) { + this.factory = factory; + this.context = context; + } + + @Override + public boolean allow(int permits) throws IOException, InterruptedException { + return factory.allow(context, permits); + } + + @Override + public void close() throws Exception { + factory.close(); + } +} diff --git a/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterContext.java b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterContext.java new file mode 100644 index 000000000000..b710e74b914f --- /dev/null +++ b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterContext.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import com.google.auto.value.AutoValue; +import java.util.Map; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.checkerframework.checker.nullness.qual.NonNull; + +/** + * Context for an Envoy Rate Limiter call. + * + *

Contains the domain and descriptors required to define a specific rate limit bucket. + */ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract class EnvoyRateLimiterContext implements RateLimiterContext { + + public abstract String getDomain(); + + public abstract ImmutableMap getDescriptors(); + + public static Builder builder() { + return new AutoValue_EnvoyRateLimiterContext.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setDomain(@NonNull String domain); + + public abstract ImmutableMap.Builder descriptorsBuilder(); + + public Builder addDescriptor(@NonNull String key, @NonNull String value) { + descriptorsBuilder().put(key, value); + return this; + } + + public Builder setDescriptors(@NonNull Map descriptors) { + descriptorsBuilder().putAll(descriptors); + return this; + } + + public abstract EnvoyRateLimiterContext build(); + } +} diff --git a/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterFactory.java b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterFactory.java new file mode 100644 index 000000000000..5a27c309d4ec --- /dev/null +++ b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterFactory.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import io.envoyproxy.envoy.extensions.common.ratelimit.v3.RateLimitDescriptor; +import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitRequest; +import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitResponse; +import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitServiceGrpc; +import io.grpc.StatusRuntimeException; +import java.io.IOException; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.beam.sdk.io.components.throttling.ThrottlingSignaler; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.util.Sleeper; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** A {@link RateLimiterFactory} for Envoy Rate Limit Service. 
*/ +public class EnvoyRateLimiterFactory implements RateLimiterFactory { + private static final Logger LOG = LoggerFactory.getLogger(EnvoyRateLimiterFactory.class); + private static final int RPC_RETRY_COUNT = 3; + private static final long RPC_RETRY_DELAY_MILLIS = 5000; + + private final RateLimiterOptions options; + + private transient volatile @Nullable RateLimitServiceGrpc.RateLimitServiceBlockingStub stub; + private transient @Nullable RateLimiterClientCache clientCache; + private final ThrottlingSignaler throttlingSignaler; + private final Sleeper sleeper; + + private final Counter requestsTotal; + private final Counter requestsAllowed; + private final Counter requestsThrottled; + private final Counter rpcErrors; + private final Counter rpcRetries; + private final Distribution rpcLatency; + + public EnvoyRateLimiterFactory(RateLimiterOptions options) { + this(options, Sleeper.DEFAULT); + } + + @VisibleForTesting + EnvoyRateLimiterFactory(RateLimiterOptions options, Sleeper sleeper) { + this.options = options; + this.sleeper = sleeper; + String namespace = EnvoyRateLimiterFactory.class.getName(); + this.throttlingSignaler = new ThrottlingSignaler(namespace); + this.requestsTotal = Metrics.counter(namespace, "ratelimit-requests-total"); + this.requestsAllowed = Metrics.counter(namespace, "ratelimit-requests-allowed"); + this.requestsThrottled = Metrics.counter(namespace, "ratelimit-requests-throttled"); + this.rpcErrors = Metrics.counter(namespace, "ratelimit-rpc-errors"); + this.rpcRetries = Metrics.counter(namespace, "ratelimit-rpc-retries"); + this.rpcLatency = Metrics.distribution(namespace, "ratelimit-rpc-latency-ms"); + } + + @Override + public synchronized void close() { + if (clientCache != null) { + clientCache.release(); + clientCache = null; + } + stub = null; + } + + private void init() { + if (stub != null) { + return; + } + synchronized (this) { + if (stub == null) { + RateLimiterClientCache cache = 
RateLimiterClientCache.getOrCreate(options.getAddress()); + this.clientCache = cache; + stub = RateLimitServiceGrpc.newBlockingStub(cache.getChannel()); + } + } + } + + @VisibleForTesting + void setStub(RateLimitServiceGrpc.RateLimitServiceBlockingStub stub) { + this.stub = stub; + } + + @Override + public RateLimiter getLimiter(RateLimiterContext context) { + if (!(context instanceof EnvoyRateLimiterContext)) { + throw new IllegalArgumentException( + "EnvoyRateLimiterFactory requires EnvoyRateLimiterContext"); + } + return new EnvoyRateLimiter(this, (EnvoyRateLimiterContext) context); + } + + @Override + public boolean allow(RateLimiterContext context, int permits) + throws IOException, InterruptedException { + if (!(context instanceof EnvoyRateLimiterContext)) { + throw new IllegalArgumentException( + "EnvoyRateLimiterFactory requires EnvoyRateLimiterContext, got: " + + context.getClass().getName()); + } + EnvoyRateLimiterContext envoyContext = (EnvoyRateLimiterContext) context; + Preconditions.checkArgument(permits >= 0, "Permits must be non-negative"); + return callEnvoy(envoyContext, permits); + } + + private boolean callEnvoy(EnvoyRateLimiterContext context, int tokens) + throws IOException, InterruptedException { + + init(); + RateLimitServiceGrpc.RateLimitServiceBlockingStub currentStub = stub; + if (currentStub == null) { + throw new IllegalStateException("RateLimitServiceStub is null"); + } + + Map descriptors = context.getDescriptors(); + RateLimitDescriptor.Builder descriptorBuilder = RateLimitDescriptor.newBuilder(); + + for (Map.Entry entry : descriptors.entrySet()) { + descriptorBuilder.addEntries( + RateLimitDescriptor.Entry.newBuilder() + .setKey(entry.getKey()) + .setValue(entry.getValue()) + .build()); + } + + RateLimitRequest request = + RateLimitRequest.newBuilder() + .setDomain(context.getDomain()) + .setHitsAddend(tokens) + .addDescriptors(descriptorBuilder.build()) + .build(); + + Integer maxRetries = options.getMaxRetries(); + long 
timeoutMillis = options.getTimeout().toMillis(); + + requestsTotal.inc(); + int attempt = 0; + while (true) { + if (maxRetries != null && attempt > maxRetries) { + return false; + } + + // RPC Retry Loop + RateLimitResponse response = null; + long startTime = System.currentTimeMillis(); + for (int i = 0; i < RPC_RETRY_COUNT; i++) { + try { + response = + currentStub + .withDeadlineAfter(timeoutMillis, java.util.concurrent.TimeUnit.MILLISECONDS) + .shouldRateLimit(request); + long endTime = System.currentTimeMillis(); + rpcLatency.update(endTime - startTime); + break; + } catch (StatusRuntimeException e) { + rpcErrors.inc(); + if (i == RPC_RETRY_COUNT - 1) { + LOG.error("RateLimitService call failed after {} attempts", RPC_RETRY_COUNT, e); + throw new IOException("Failed to call Rate Limit Service", e); + } + rpcRetries.inc(); + LOG.warn("RateLimitService call failed, retrying", e); + if (sleeper != null) { + sleeper.sleep(RPC_RETRY_DELAY_MILLIS); + } + } + } + + if (response == null) { + throw new IOException("Failed to get response from Rate Limit Service"); + } + + if (response.getOverallCode() == RateLimitResponse.Code.OK) { + requestsAllowed.inc(); + return true; + } else if (response.getOverallCode() == RateLimitResponse.Code.OVER_LIMIT) { + long sleepMillis = 0; + for (RateLimitResponse.DescriptorStatus status : response.getStatusesList()) { + if (status.getCode() == RateLimitResponse.Code.OVER_LIMIT + && status.hasDurationUntilReset()) { + long durationMillis = + status.getDurationUntilReset().getSeconds() * 1000 + + status.getDurationUntilReset().getNanos() / 1_000_000; + if (durationMillis > sleepMillis) { + sleepMillis = durationMillis; + } + } + } + + if (sleepMillis == 0) { + sleepMillis = 1000; + } + + long jitter = + (long) + (java.util.concurrent.ThreadLocalRandom.current().nextDouble() + * (0.01 * sleepMillis)); + sleepMillis += jitter; + + LOG.warn("Throttled by RLS, sleeping for {} ms", sleepMillis); + if (sleeper != null) { + 
requestsThrottled.inc(); + if (throttlingSignaler != null) { + throttlingSignaler.signalThrottling(sleepMillis); + } + sleeper.sleep(sleepMillis); + } + attempt++; + } else { + throw new IOException( + "Rate Limit Service returned unknown code: " + response.getOverallCode()); + } + } + } +} diff --git a/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiter.java b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiter.java new file mode 100644 index 000000000000..8c02654b3964 --- /dev/null +++ b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiter.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import java.io.IOException; +import java.io.Serializable; + +/** + * A RateLimiter allows to fetch permits from a rate limiter service and blocks execution when the + * rate limit is exceeded. + * + *

/**
 * A handle for obtaining permits from a rate limiter service; callers are blocked while the rate
 * limit is exceeded.
 *
 * <p>Implementations must be {@link Serializable} as they are passed to workers.
 */
public interface RateLimiter extends Serializable, AutoCloseable {

  /**
   * Blocks until the requested permits have been granted or the request has been rejected.
   *
   * @param permits Number of permits to acquire.
   * @return true if the request was allowed, false if it was rejected (and retries exceeded).
   * @throws IOException if there is an error communicating with the rate limiter service.
   * @throws InterruptedException if the thread is interrupted while waiting.
   */
  boolean allow(int permits) throws IOException, InterruptedException;
}
+ */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A static cache for {@link ManagedChannel}s to Rate Limit Service. + * + *

This class ensures that multiple DoFn instances (threads) in the same Worker sharing the same + * RLS address will share a single {@link ManagedChannel}. + * + *

It uses reference counting to close the channel when it is no longer in use by any RateLimiter + * instance. + */ +public class RateLimiterClientCache { + private static final Logger LOG = LoggerFactory.getLogger(RateLimiterClientCache.class); + private static final Map CACHE = new ConcurrentHashMap<>(); + + private final ManagedChannel channel; + private final String address; + private int refCount = 0; + + private RateLimiterClientCache(String address) { + this.address = address; + LOG.info("Creating new ManagedChannel for RLS at {}", address); + this.channel = ManagedChannelBuilder.forTarget(address).usePlaintext().build(); + } + + /** + * Gets or creates a cached client for the given address. Increments the reference count. + * Synchronized on the class to prevent race conditions when multiple instances call getOrCreate() + * simultaneously + */ + public static synchronized RateLimiterClientCache getOrCreate(String address) { + RateLimiterClientCache client = CACHE.get(address); + if (client == null) { + client = new RateLimiterClientCache(address); + CACHE.put(address, client); + } + client.refCount++; + LOG.debug("Referenced RLS Channel for {}. New RefCount: {}", address, client.refCount); + return client; + } + + public ManagedChannel getChannel() { + return channel; + } + + /** + * Releases the client. Decrements the reference count. If reference count reaches 0, the channel + * is shut down and removed from the cache. Synchronized on the class to prevent race conditions + * when multiple instances call release() simultaneously + */ + public void release() { + synchronized (RateLimiterClientCache.class) { + refCount--; + LOG.debug("Released RLS Channel for {}. 
New RefCount: {}", address, refCount); + if (refCount <= 0) { + LOG.info("Closing ManagedChannel for RLS at {}", address); + CACHE.remove(address); + channel.shutdown(); + try { + channel.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOG.error("Couldn't gracefully close gRPC channel={}", channel, e); + } + channel.shutdownNow(); + } + } + } +} diff --git a/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterContext.java b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterContext.java new file mode 100644 index 000000000000..6387bf5789e4 --- /dev/null +++ b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterContext.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import java.io.Serializable; + +/** + * A marker interface for context data required to check ratelimit. + * + *

/**
 * Marker interface for the context data needed to check a rate limit.
 *
 * <p>It declares no methods; implementations must be {@link Serializable}.
 */
public interface RateLimiterContext extends Serializable {}

Implementations must be {@link Serializable} as they are passed to workers. The factory + * typically manages the heavy connection (e.g. gRPC stub) and is thread-safe. + */ +public interface RateLimiterFactory extends Serializable, AutoCloseable { + + /** + * Creates a lightweight ratelimiter handle bound to a specific context. + * + *

Use this when passing ratelimiter to IO components, which doesn't need to know about the + * configuration or the underlying ratelimiter service details. This is also useful in DoFns when + * you want to use the ratelimiter in a static way based on the compile time context. + * + * @param context The context for the ratelimit. + * @return A {@link RateLimiter} handle. + */ + RateLimiter getLimiter(RateLimiterContext context); + + /** + * Blocks until the specified number of permits are acquired and returns true if the request was + * allowed or false if the request was rejected. + * + *

Use this for when the ratelimit namespace or descriptors are not known at compile time. + * allows you to use the ratelimiter in a dynamic way based on the runtime data. + * + * @param context The context for the ratelimit. + * @param permits Number of permits to acquire. + * @return true if the request is allowed, false if rejected. + * @throws IOException if there is an error communicating with the ratelimiter service. + * @throws InterruptedException if the thread is interrupted while waiting. + */ + boolean allow(RateLimiterContext context, int permits) throws IOException, InterruptedException; +} diff --git a/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterOptions.java b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterOptions.java new file mode 100644 index 000000000000..3b925609bd70 --- /dev/null +++ b/sdks/java/io/components/src/main/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterOptions.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import com.google.auto.value.AutoValue; +import java.io.Serializable; +import java.time.Duration; +import javax.annotation.Nullable; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; + +/** Configuration options for {@link RateLimiterFactory}: address, max retries, and timeout. */ +@DefaultSchema(AutoValueSchema.class) +@AutoValue +public abstract class RateLimiterOptions implements Serializable { + public abstract String getAddress(); + + @Nullable + public abstract Integer getMaxRetries(); + + public abstract Duration getTimeout(); + + public static Builder builder() { + return new AutoValue_RateLimiterOptions.Builder().setTimeout(Duration.ofSeconds(5)); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setAddress(String address); + + public abstract Builder setMaxRetries(@Nullable Integer maxRetries); + + public abstract Builder setTimeout(Duration timeout); + + abstract RateLimiterOptions autoBuild(); + + public RateLimiterOptions build() { + RateLimiterOptions options = autoBuild(); + Preconditions.checkArgument( + options.getTimeout().compareTo(Duration.ZERO) > 0, "Timeout must be positive"); + Integer maxRetries = options.getMaxRetries(); + if (maxRetries != null) { + Preconditions.checkArgument(maxRetries >= 0, "MaxRetries must be non-negative"); + } + return options; + } + } +} diff --git a/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterTest.java b/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterTest.java new file mode 100644 index 000000000000..e94d0b42eb3c --- /dev/null +++ b/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/EnvoyRateLimiterTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.verify; + +import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitRequest; +import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitResponse; +import io.envoyproxy.envoy.service.ratelimit.v3.RateLimitServiceGrpc; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import java.io.IOException; +import org.apache.beam.sdk.util.Sleeper; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +/** Tests for {@link EnvoyRateLimiterFactory}. 
*/ +@RunWith(JUnit4.class) +public class EnvoyRateLimiterTest { + @Mock private Sleeper sleeper; + + private EnvoyRateLimiterFactory factory; + private RateLimiterOptions options; + private EnvoyRateLimiterContext context; + + private Server server; + private ManagedChannel channel; + private TestRateLimitService service; + + @Before + public void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + options = + RateLimiterOptions.builder() + .setAddress("localhost:8081") + .setTimeout(java.time.Duration.ofSeconds(1)) + .build(); + + String serverName = InProcessServerBuilder.generateName(); + service = new TestRateLimitService(); + server = + InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(service) + .build() + .start(); + channel = InProcessChannelBuilder.forName(serverName).directExecutor().build(); + + factory = new EnvoyRateLimiterFactory(options, sleeper); + factory.setStub(RateLimitServiceGrpc.newBlockingStub(channel)); + + context = + EnvoyRateLimiterContext.builder() + .setDomain("test-domain") + .addDescriptor("key", "value") + .build(); + } + + @After + public void tearDown() { + if (channel != null) { + channel.shutdownNow(); + } + if (server != null) { + server.shutdownNow(); + } + } + + @Test + public void testAllow_OK() throws Exception { + service.responseToReturn = + RateLimitResponse.newBuilder().setOverallCode(RateLimitResponse.Code.OK).build(); + + assertTrue(factory.allow(context, 1)); + } + + @Test + public void testAllow_OverLimit() throws Exception { + service.responseToReturn = + RateLimitResponse.newBuilder() + .setOverallCode(RateLimitResponse.Code.OVER_LIMIT) + .addStatuses( + RateLimitResponse.DescriptorStatus.newBuilder() + .setCode(RateLimitResponse.Code.OVER_LIMIT) + .setDurationUntilReset( + com.google.protobuf.Duration.newBuilder().setSeconds(1).build()) + .build()) + .build(); + + factory = + new EnvoyRateLimiterFactory( + RateLimiterOptions.builder() + .setAddress("foo") + 
.setTimeout(java.time.Duration.ofSeconds(1)) + .setMaxRetries(1) + .build(), + sleeper); + factory.setStub(RateLimitServiceGrpc.newBlockingStub(channel)); + + assertFalse(factory.allow(context, 1)); + + // Verify sleep was called at least once. + verify(sleeper, org.mockito.Mockito.atLeastOnce()).sleep(anyLong()); + } + + @Test + public void testAllow_RpcError() throws Exception { + service.errorToThrow = Status.UNAVAILABLE.asRuntimeException(); + assertThrows(IOException.class, () -> factory.allow(context, 1)); + } + + @Test + public void testInvalidContext() { + assertThrows( + IllegalArgumentException.class, () -> factory.allow(new RateLimiterContext() {}, 1)); + } + + static class TestRateLimitService extends RateLimitServiceGrpc.RateLimitServiceImplBase { + volatile RateLimitResponse responseToReturn; + volatile RuntimeException errorToThrow; + + @Override + public void shouldRateLimit( + RateLimitRequest request, StreamObserver<RateLimitResponse> responseObserver) { + if (errorToThrow != null) { + responseObserver.onError(errorToThrow); + return; + } + if (responseToReturn != null) { + responseObserver.onNext(responseToReturn); + responseObserver.onCompleted(); + } else { + // No canned response configured: default to OK. + responseObserver.onNext( + RateLimitResponse.newBuilder().setOverallCode(RateLimitResponse.Code.OK).build()); + responseObserver.onCompleted(); + } + } + } +} diff --git a/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterClientCacheTest.java b/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterClientCacheTest.java new file mode 100644 index 000000000000..4eb61b279c34 --- /dev/null +++ b/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterClientCacheTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link RateLimiterClientCache}. 
*/ +@RunWith(JUnit4.class) +public class RateLimiterClientCacheTest { + + @Test + public void testGetOrCreate_SameAddress() { + String address = "addr1"; + RateLimiterClientCache client1 = RateLimiterClientCache.getOrCreate(address); + RateLimiterClientCache client2 = RateLimiterClientCache.getOrCreate(address); + + assertSame(client1, client2); + assertFalse(client1.getChannel().isShutdown()); + + // cleanup + client1.release(); + // client2 is still using the same channel + assertFalse(client1.getChannel().isShutdown()); + client2.release(); + assertTrue(client1.getChannel().isShutdown()); + } + + @Test + public void testGetOrCreate_DifferentAddress_ReturnsDifferentInstances() { + RateLimiterClientCache client1 = RateLimiterClientCache.getOrCreate("addr1"); + RateLimiterClientCache client2 = RateLimiterClientCache.getOrCreate("addr2"); + + assertNotSame(client1, client2); + + assertFalse(client1.getChannel().isShutdown()); + assertFalse(client2.getChannel().isShutdown()); + client1.release(); + assertTrue(client1.getChannel().isShutdown()); + client2.release(); + assertTrue(client2.getChannel().isShutdown()); + } + + @Test + public void testConcurrency() throws InterruptedException, ExecutionException { + int threads = 10; + int iterations = 100; + String address = "concurrent-addr"; + ExecutorService pool = Executors.newFixedThreadPool(threads); + List<Future<Boolean>> futures = new ArrayList<>(); + + for (int i = 0; i < threads; i++) { + futures.add( + pool.submit( + new Callable<Boolean>() { + @Override + public Boolean call() { + for (int j = 0; j < iterations; j++) { + RateLimiterClientCache client = RateLimiterClientCache.getOrCreate(address); + // do some tiny work + try { + Thread.sleep(1); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + client.release(); + } + return true; + } + })); + } + + for (Future<Boolean> f : futures) { + assertTrue(f.get()); + } + + pool.shutdown(); + assertTrue(pool.awaitTermination(5, TimeUnit.SECONDS)); + + // After all threads are done, 
cache should be empty or create new one cleanly + RateLimiterClientCache client = RateLimiterClientCache.getOrCreate(address); + assertFalse(client.getChannel().isShutdown()); + client.release(); + assertTrue(client.getChannel().isShutdown()); + } +} diff --git a/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterOptionsTest.java b/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterOptionsTest.java new file mode 100644 index 000000000000..cb8674b4e502 --- /dev/null +++ b/sdks/java/io/components/src/test/java/org/apache/beam/sdk/io/components/ratelimiter/RateLimiterOptionsTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.components.ratelimiter; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.time.Duration; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link RateLimiterOptions}. 
*/ +@RunWith(JUnit4.class) +public class RateLimiterOptionsTest { + + @Test + public void testValidOptions() { + RateLimiterOptions options = + RateLimiterOptions.builder() + .setAddress("localhost:8081") + .setTimeout(Duration.ofSeconds(1)) + .setMaxRetries(3) + .build(); + + assertEquals("localhost:8081", options.getAddress()); + assertEquals(Duration.ofSeconds(1), options.getTimeout()); + assertEquals(Integer.valueOf(3), options.getMaxRetries()); + } + + @Test + public void testNegativeTimeout() { + assertThrows( + IllegalArgumentException.class, + () -> + RateLimiterOptions.builder() + .setAddress("localhost:8081") + .setTimeout(Duration.ofSeconds(-1)) + .build()); + } + + @Test + public void testZeroTimeout() { + assertThrows( + IllegalArgumentException.class, + () -> + RateLimiterOptions.builder() + .setAddress("localhost:8081") + .setTimeout(Duration.ZERO) + .build()); + } + + @Test + public void testNegativeMaxRetries() { + assertThrows( + IllegalArgumentException.class, + () -> RateLimiterOptions.builder().setAddress("localhost:8081").setMaxRetries(-1).build()); + } + + @Test + public void testNullMaxRetriesIsAllowed() { + RateLimiterOptions options = + RateLimiterOptions.builder().setAddress("localhost:8081").setMaxRetries(null).build(); + org.junit.Assert.assertNull(options.getMaxRetries()); + } +} diff --git a/sdks/java/io/google-cloud-platform/build.gradle b/sdks/java/io/google-cloud-platform/build.gradle index 5dd3f9bb761d..ec92b0c5c7f4 100644 --- a/sdks/java/io/google-cloud-platform/build.gradle +++ b/sdks/java/io/google-cloud-platform/build.gradle @@ -162,6 +162,7 @@ dependencies { testImplementation project(path: ":runners:direct-java", configuration: "shadow") testImplementation project(":sdks:java:managed") testImplementation project(path: ":sdks:java:io:common") + implementation project(":sdks:java:io:components") testImplementation project(path: ":sdks:java:testing:test-utils") testImplementation library.java.commons_math3 testImplementation 
library.java.google_cloud_bigquery diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java index 450710112a1b..1568faf25928 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerIO.java @@ -75,6 +75,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.components.ratelimiter.RateLimiter; import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamMetrics; import org.apache.beam.sdk.io.gcp.spanner.changestreams.ChangeStreamsConstants; import org.apache.beam.sdk.io.gcp.spanner.changestreams.MetadataSpannerConfigFactory; @@ -1289,6 +1290,8 @@ public abstract static class Write extends PTransform, Spa abstract @Nullable PCollectionView getDialectView(); + abstract @Nullable RateLimiter getRateLimiter(); + abstract Builder toBuilder(); @AutoValue.Builder @@ -1310,6 +1313,8 @@ abstract static class Builder { abstract Builder setDialectView(PCollectionView dialect); + abstract Builder setRateLimiter(RateLimiter rateLimiter); + abstract Write build(); } @@ -1393,6 +1398,11 @@ public Write withUsingPlainTextChannel(ValueProvider plainText) { return withSpannerConfig(config.withUsingPlainTextChannel(plainText)); } + /** Specifies the {@link RateLimiter} to use to throttle IO. */ + public Write withRateLimiter(RateLimiter rateLimiter) { + return toBuilder().setRateLimiter(rateLimiter).build(); + } + /** * Specifies whether to use plaintext channel. 
* @@ -1697,7 +1707,10 @@ public SpannerWriteResult expand(PCollection input) { "Write batches to Spanner", ParDo.of( new WriteToSpannerFn( - spec.getSpannerConfig(), spec.getFailureMode(), FAILED_MUTATIONS_TAG)) + spec.getSpannerConfig(), + spec.getFailureMode(), + FAILED_MUTATIONS_TAG, + spec.getRateLimiter())) .withOutputTags(MAIN_OUT_TAG, TupleTagList.of(FAILED_MUTATIONS_TAG))); return new SpannerWriteResult( @@ -2458,11 +2471,17 @@ static class WriteToSpannerFn extends DoFn, Void> { private transient FluentBackoff bundleWriteBackoff; private transient LoadingCache writeMetricsByTableName; + private final @Nullable RateLimiter rateLimiter; + WriteToSpannerFn( - SpannerConfig spannerConfig, FailureMode failureMode, TupleTag failedTag) { + SpannerConfig spannerConfig, + FailureMode failureMode, + TupleTag failedTag, + @Nullable RateLimiter rateLimiter) { this.spannerConfig = spannerConfig; this.failureMode = failureMode; this.failedTag = failedTag; + this.rateLimiter = rateLimiter; } @Setup @@ -2491,8 +2510,11 @@ public ServiceCallMetric load(String tableName) { } @Teardown - public void teardown() { + public void teardown() throws Exception { spannerAccessor.close(); + if (rateLimiter != null) { + rateLimiter.close(); + } } @ProcessElement @@ -2618,7 +2640,7 @@ private static ServiceCallMetric buildWriteServiceCallMetric( /** Write the Mutations to Spanner, handling DEADLINE_EXCEEDED with backoff/retries. */ private void writeMutations(Iterable mutationIterable) - throws SpannerException, IOException { + throws SpannerException, IOException, InterruptedException { BackOff backoff = bundleWriteBackoff.backoff(); List mutations = ImmutableList.copyOf(mutationIterable); @@ -2626,6 +2648,9 @@ private void writeMutations(Iterable mutationIterable) Stopwatch timer = Stopwatch.createStarted(); // loop is broken on success, timeout backoff/retry attempts exceeded, or other failure. 
try { + if (rateLimiter != null) { + rateLimiter.allow(1); + } spannerWriteWithRetryIfSchemaChange(mutations); spannerWriteSuccess.inc(); return; diff --git a/settings.gradle.kts b/settings.gradle.kts index 4540fa4b597b..1d6e0e6e20ae 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -212,6 +212,7 @@ include(":sdks:java:io:azure-cosmos") include(":sdks:java:io:cassandra") include(":sdks:java:io:clickhouse") include(":sdks:java:io:common") +include(":sdks:java:io:components") include(":sdks:java:io:contextualtextio") include(":sdks:java:io:debezium") include(":sdks:java:io:debezium:expansion-service")