diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
index cfa06f3cf0d5..5c9e19da160e 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/FileIO.java
@@ -1059,6 +1059,12 @@ public static FileNaming relativeFileNaming(
 
     abstract @Nullable Integer getMaxNumWritersPerBundle();
 
+    abstract @Nullable Integer getBatchSize();
+
+    abstract @Nullable Integer getBatchSizeBytes();
+
+    abstract @Nullable Duration getBatchMaxBufferingDuration();
+
     abstract @Nullable ErrorHandler<BadRecord, ?> getBadRecordErrorHandler();
 
     abstract Builder<DestinationT, UserT> toBuilder();
@@ -1112,6 +1118,13 @@ abstract Builder<DestinationT, UserT> setSharding(
       abstract Builder<DestinationT, UserT> setMaxNumWritersPerBundle(
           @Nullable Integer maxNumWritersPerBundle);
 
+      abstract Builder<DestinationT, UserT> setBatchSize(@Nullable Integer batchSize);
+
+      abstract Builder<DestinationT, UserT> setBatchSizeBytes(@Nullable Integer batchSizeBytes);
+
+      abstract Builder<DestinationT, UserT> setBatchMaxBufferingDuration(
+          @Nullable Duration batchMaxBufferingDuration);
+
       abstract Builder<DestinationT, UserT> setBadRecordErrorHandler(
           @Nullable ErrorHandler<BadRecord, ?> badRecordErrorHandler);
 
@@ -1301,6 +1314,7 @@ public Write<DestinationT, UserT> withDestinationCoder(Coder<Desti
      */
     public Write<DestinationT, UserT> withNumShards(int numShards) {
       checkArgument(numShards >= 0, "numShards must be non-negative, but was: %s", numShards);
+      checkArgument(!getAutoSharding(), "Cannot set numShards when withAutoSharding() is used");
       if (numShards == 0) {
         return withNumShards(null);
       }
@@ -1311,6 +1325,7 @@ public Write<DestinationT, UserT> withNumShards(int numShards) {
      * Like {@link #withNumShards(int)}. Specifying {@code null} means runner-determined sharding.
      */
     public Write<DestinationT, UserT> withNumShards(@Nullable ValueProvider<Integer> numShards) {
+      checkArgument(!getAutoSharding(), "Cannot set numShards when withAutoSharding() is used");
       return toBuilder().setNumShards(numShards).build();
     }
 
@@ -1321,6 +1336,7 @@ public Write<DestinationT, UserT> withNumShards(@Nullable ValueProvider<Integer
     public Write<DestinationT, UserT> withSharding(
         PTransform<PCollection<UserT>, PCollectionView<Integer>> sharding) {
       checkArgument(sharding != null, "sharding can not be null");
+      checkArgument(!getAutoSharding(), "Cannot set sharding when withAutoSharding() is used");
       return toBuilder().setSharding(sharding).build();
     }
 
@@ -1337,6 +1353,9 @@ public Write<DestinationT, UserT> withIgnoreWindowing() {
     }
 
     public Write<DestinationT, UserT> withAutoSharding() {
+      checkArgument(
+          getNumShards() == null && getSharding() == null,
+          "Cannot use withAutoSharding() when withNumShards() or withSharding() is set");
       return toBuilder().setAutoSharding(true).build();
     }
 
@@ -1366,6 +1385,44 @@ public Write<DestinationT, UserT> withBadRecordErrorHandler(
       return toBuilder().setBadRecordErrorHandler(errorHandler).build();
     }
 
+    /**
+     * Returns a new {@link Write} that will batch the input records using the specified batch
+     * size. The default value is {@link WriteFiles#FILE_TRIGGERING_RECORD_COUNT}.
+     *
+     * <p>This option is used only for writing unbounded data with auto-sharding.
+     */
+    public Write<DestinationT, UserT> withBatchSize(@Nullable Integer batchSize) {
+      checkArgument(
+          batchSize == null || batchSize > 0,
+          "batchSize must be positive, but was: %s",
+          batchSize);
+      return toBuilder().setBatchSize(batchSize).build();
+    }
+
+    /**
+     * Returns a new {@link Write} that will batch the input records using the specified batch
+     * size in bytes. The default value is {@link WriteFiles#FILE_TRIGGERING_BYTE_COUNT}.
+     *
+     * <p>This option is used only for writing unbounded data with auto-sharding.
+     */
+    public Write<DestinationT, UserT> withBatchSizeBytes(@Nullable Integer batchSizeBytes) {
+      checkArgument(
+          batchSizeBytes == null || batchSizeBytes > 0,
+          "batchSizeBytes must be positive, but was: %s",
+          batchSizeBytes);
+      return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
+    }
+
+    /**
+     * Returns a new {@link Write} that will batch the input records using the specified max
+     * buffering duration. The default value is {@link
+     * WriteFiles#FILE_TRIGGERING_RECORD_BUFFERING_DURATION}.
+     *
+     * <p>This option is used only for writing unbounded data with auto-sharding.
+     */
+    public Write<DestinationT, UserT> withBatchMaxBufferingDuration(
+        @Nullable Duration batchMaxBufferingDuration) {
+      checkArgument(
+          batchMaxBufferingDuration == null
+              || batchMaxBufferingDuration.isLongerThan(Duration.ZERO),
+          "batchMaxBufferingDuration must be positive, but was: %s",
+          batchMaxBufferingDuration);
+      return toBuilder().setBatchMaxBufferingDuration(batchMaxBufferingDuration).build();
+    }
+
     @VisibleForTesting
     Contextful<Fn<DestinationT, FileNaming>> resolveFileNamingFn() {
       if (getDynamic()) {
@@ -1482,6 +1539,15 @@ public WriteFilesResult<DestinationT> expand(PCollection<UserT> input) {
       if (getBadRecordErrorHandler() != null) {
         writeFiles = writeFiles.withBadRecordErrorHandler(getBadRecordErrorHandler());
       }
+      if (getBatchSize() != null) {
+        writeFiles = writeFiles.withBatchSize(getBatchSize());
+      }
+      if (getBatchSizeBytes() != null) {
+        writeFiles = writeFiles.withBatchSizeBytes(getBatchSizeBytes());
+      }
+      if (getBatchMaxBufferingDuration() != null) {
+        writeFiles = writeFiles.withBatchMaxBufferingDuration(getBatchMaxBufferingDuration());
+      }
       return input.apply(writeFiles);
     }
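For context, here is a minimal usage sketch of the new options (not part of this patch;
the pipeline, path, and threshold values below are hypothetical). The three thresholds
take effect only for unbounded input written with withAutoSharding(), and, as the tests
below exercise, a flush is triggered as soon as any one of them is reached:

    // Sketch: streaming file writes with custom flush thresholds.
    // Assumes "events" is an unbounded, windowed PCollection<String>.
    events.apply(
        FileIO.<String>write()
            .via(TextIO.sink())
            .to("/tmp/output") // hypothetical path
            .withPrefix("events")
            .withSuffix(".txt")
            .withAutoSharding() // required for the batching options to apply
            .withBatchSize(500) // flush after 500 buffered records,
            .withBatchSizeBytes(2 * 1024 * 1024) // or after ~2 MiB is buffered,
            .withBatchMaxBufferingDuration(Duration.standardSeconds(5))); // or after 5s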
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
index d3c1f6680bee..dffc4943bfab 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/FileIOTest.java
@@ -19,11 +19,14 @@
 import static org.apache.beam.sdk.io.fs.ResolveOptions.StandardResolveOptions.RESOLVE_FILE;
 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.isA;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
+import java.io.BufferedReader;
 import java.io.File;
 import java.io.FileNotFoundException;
 import java.io.FileOutputStream;
@@ -38,7 +41,9 @@
 import java.nio.file.Paths;
 import java.nio.file.StandardCopyOption;
 import java.nio.file.attribute.FileTime;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.zip.GZIPOutputStream;
@@ -46,6 +51,7 @@
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.io.fs.EmptyMatchTreatment;
 import org.apache.beam.sdk.io.fs.MatchResult;
+import org.apache.beam.sdk.io.fs.MatchResult.Metadata;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
@@ -53,23 +59,30 @@
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.testing.UsesUnboundedPCollections;
 import org.apache.beam.sdk.testing.UsesUnboundedSplittableParDo;
 import org.apache.beam.sdk.transforms.Contextful;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.transforms.Requirements;
 import org.apache.beam.sdk.transforms.SerializableFunctions;
 import org.apache.beam.sdk.transforms.View;
 import org.apache.beam.sdk.transforms.Watch;
+import org.apache.beam.sdk.transforms.windowing.AfterWatermark;
+import org.apache.beam.sdk.transforms.windowing.FixedWindows;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.transforms.windowing.Window;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollection.IsBounded;
 import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TypeDescriptor;
 import org.apache.beam.sdk.values.TypeDescriptors;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
 import org.joda.time.Duration;
 import org.junit.Rule;
 import org.junit.Test;
@@ -547,4 +560,130 @@ public void testFileIoDynamicNaming() throws IOException {
         "Output file shard 0 exists after pipeline completes",
         new File(outputFileName + "-0").exists());
   }
+
+  @Test
+  @Category({NeedsRunner.class, UsesUnboundedPCollections.class})
+  public void testWriteUnboundedWithCustomBatchSize() throws IOException {
+    File root = tmpFolder.getRoot();
+    List<String> inputs = Arrays.asList("one", "two", "three", "four", "five", "six");
+
+    PTransform<PCollection<String>, PCollection<String>> transform =
+        Window.<String>into(FixedWindows.of(Duration.standardSeconds(10)))
+            .triggering(AfterWatermark.pastEndOfWindow())
+            .withAllowedLateness(Duration.ZERO)
+            .discardingFiredPanes();
+
+    FileIO.Write<Void, String> write =
+        FileIO.<String>write()
+            .via(TextIO.sink())
+            .to(root.getAbsolutePath())
+            .withPrefix("output")
+            .withSuffix(".txt")
+            .withAutoSharding()
+            .withBatchSize(3)
+            .withBatchSizeBytes(1024 * 1024) // Set high to avoid triggering flushing by byte count.
+            .withBatchMaxBufferingDuration(
+                Duration.standardMinutes(1)); // Set high to avoid triggering flushing by duration.
+
+    // Prepare timestamps for the elements.
+    List<Long> timestamps = new ArrayList<>();
+    for (long i = 0; i < inputs.size(); i++) {
+      timestamps.add(i + 1);
+    }
+
+    p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
+        .setIsBoundedInternal(IsBounded.UNBOUNDED)
+        .apply(transform)
+        .apply(write);
+    p.run().waitUntilFinish();
+
+    // Verify that the custom batch parameters are set.
+    assertEquals(3, write.getBatchSize().intValue());
+    assertEquals(1024 * 1024, write.getBatchSizeBytes().intValue());
+    assertEquals(Duration.standardMinutes(1), write.getBatchMaxBufferingDuration());
+
+    // Verify file contents.
+    checkFileContents(root, "output", inputs);
+
+    // With auto-sharding, we can't assert on the exact number of output files, but because
+    // batch size is 3 and there are 6 elements, we expect at least 2 files.
+    final String pattern = new File(root, "output").getAbsolutePath() + "*";
+    List<Metadata> metadata =
+        FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+    assertTrue(metadata.size() >= 2);
+  }
+
+  @Test
+  @Category({NeedsRunner.class, UsesUnboundedPCollections.class})
+  public void testWriteUnboundedWithCustomBatchSizeBytes() throws IOException {
+    File root = tmpFolder.getRoot();
+    // The elements plus newline characters give a total of 4+4+6+5+5+4=28 bytes.
+    List<String> inputs = Arrays.asList("one", "two", "three", "four", "five", "six");
+    // Assign timestamps so that all elements fall into the same 10s window.
+    List<Long> timestamps = Arrays.asList(1L, 2L, 3L, 4L, 5L, 6L);
+
+    FileIO.Write<Void, String> write =
+        FileIO.<String>write()
+            .via(TextIO.sink())
+            .to(root.getAbsolutePath())
+            .withPrefix("output")
+            .withSuffix(".txt")
+            .withAutoSharding()
+            .withBatchSize(1000) // Set high to avoid flushing by record count.
+            .withBatchSizeBytes(10)
+            .withBatchMaxBufferingDuration(
+                Duration.standardMinutes(1)); // Set high to avoid flushing by duration.
+
+    p.apply(Create.timestamped(inputs, timestamps).withCoder(StringUtf8Coder.of()))
+        .setIsBoundedInternal(IsBounded.UNBOUNDED)
+        .apply(
+            Window.<String>into(FixedWindows.of(Duration.standardSeconds(10)))
+                .triggering(AfterWatermark.pastEndOfWindow())
+                .withAllowedLateness(Duration.ZERO)
+                .discardingFiredPanes())
+        .apply(write);
+
+    p.run().waitUntilFinish();
+
+    // Verify that the custom batch parameters are set.
+    assertEquals(1000, write.getBatchSize().intValue());
+    assertEquals(10, write.getBatchSizeBytes().intValue());
+    assertEquals(Duration.standardMinutes(1), write.getBatchMaxBufferingDuration());
+    checkFileContents(root, "output", inputs);
+
+    // With auto-sharding, we cannot assert on the exact number of output files. The BatchSizeBytes
+    // acts as a threshold for flushing; once buffer size reaches 10 bytes, a flush is triggered,
+    // but more items may be added before it completes. With 28 bytes total, we can only guarantee
+    // at least 2 files are produced.
+    final String pattern = new File(root, "output").getAbsolutePath() + "*";
+    List<Metadata> metadata =
+        FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+    assertTrue(metadata.size() >= 2);
+  }
+
+  static void checkFileContents(File rootDir, String prefix, List<String> inputs)
+      throws IOException {
+    List<File> outputFiles = Lists.newArrayList();
+    final String pattern = new File(rootDir, prefix).getAbsolutePath() + "*";
+    List<Metadata> metadata =
+        FileSystems.match(Collections.singletonList(pattern)).get(0).metadata();
+    for (Metadata meta : metadata) {
+      outputFiles.add(new File(meta.resourceId().toString()));
+    }
+    assertFalse("Should have produced at least 1 output file", outputFiles.isEmpty());
+
+    List<String> actual = Lists.newArrayList();
+    for (File outputFile : outputFiles) {
+      List<String> actualShard = Lists.newArrayList();
+      try (BufferedReader reader =
+          Files.newBufferedReader(outputFile.toPath(), StandardCharsets.UTF_8)) {
+        String line;
+        while ((line = reader.readLine()) != null) {
+          actualShard.add(line);
+        }
+      }
+      actual.addAll(actualShard);
+    }
+    assertThat(actual, containsInAnyOrder(inputs.toArray()));
+  }
 }
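As a companion note, the new checkArgument guards above make the sharding modes mutually
exclusive, so a misconfiguration like the following (hypothetical) snippet now fails fast
at graph-construction time with an IllegalArgumentException instead of building an
ambiguously sharded pipeline:

    // Throws: "Cannot set numShards when withAutoSharding() is used".
    FileIO.<String>write()
        .via(TextIO.sink())
        .to("/tmp/output") // hypothetical path
        .withAutoSharding()
        .withNumShards(5);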