diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/RowIdPredicateVisitor.java b/paimon-common/src/main/java/org/apache/paimon/predicate/RowIdPredicateVisitor.java new file mode 100644 index 000000000000..9a0de2af524b --- /dev/null +++ b/paimon-common/src/main/java/org/apache/paimon/predicate/RowIdPredicateVisitor.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.predicate; + +import java.util.HashSet; +import java.util.Set; + +import static org.apache.paimon.table.SpecialFields.ROW_ID; + +/** + * The {@link PredicateVisitor} to extract a list of Row IDs from predicates. The returned Row IDs + * can be pushed down to manifest readers and file readers to enable efficient random access. + * + *

Note that there is a significant distinction between returning {@code null} and returning an + * empty set: + * + *

+ */ +public class RowIdPredicateVisitor implements PredicateVisitor> { + + @Override + public Set visit(LeafPredicate predicate) { + if (ROW_ID.name().equals(predicate.fieldName())) { + LeafFunction function = predicate.function(); + if (function instanceof Equal || function instanceof In) { + HashSet rowIds = new HashSet<>(); + for (Object literal : predicate.literals()) { + rowIds.add((Long) literal); + } + return rowIds; + } + } + return null; + } + + @Override + public Set visit(CompoundPredicate predicate) { + CompoundPredicate.Function function = predicate.function(); + HashSet rowIds = null; + // `And` means we should get the intersection of all children. + if (function instanceof And) { + for (Predicate child : predicate.children()) { + Set childSet = child.visit(this); + if (childSet == null) { + return null; + } + + if (rowIds == null) { + rowIds = new HashSet<>(childSet); + } else { + rowIds.retainAll(childSet); + } + + // shortcut for intersection + if (rowIds.isEmpty()) { + return rowIds; + } + } + } else if (function instanceof Or) { + // `Or` means we should get the union of all children + rowIds = new HashSet<>(); + for (Predicate child : predicate.children()) { + Set childSet = child.visit(this); + if (childSet == null) { + return null; + } + + rowIds.addAll(childSet); + } + } else { + // unexpected function type, just return null + return null; + } + return rowIds; + } + + @Override + public Set visit(TransformPredicate predicate) { + // do not support transform predicate now. + return null; + } +} diff --git a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/SystemTableSource.java b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/SystemTableSource.java index e99f265d03b4..818735c5550d 100644 --- a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/SystemTableSource.java +++ b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/source/SystemTableSource.java @@ -19,15 +19,23 @@ package org.apache.paimon.flink.source; import org.apache.paimon.CoreOptions; +import org.apache.paimon.annotation.VisibleForTesting; import org.apache.paimon.flink.FlinkConnectorOptions; +import org.apache.paimon.flink.LogicalTypeConversion; import org.apache.paimon.flink.NestedProjectedRowData; import org.apache.paimon.flink.PaimonDataStreamScanProvider; +import org.apache.paimon.flink.PredicateConverter; import org.apache.paimon.flink.Projection; import org.apache.paimon.options.Options; +import org.apache.paimon.predicate.PartitionPredicateVisitor; import org.apache.paimon.predicate.Predicate; +import org.apache.paimon.predicate.PredicateBuilder; +import org.apache.paimon.predicate.PredicateVisitor; +import org.apache.paimon.predicate.RowIdPredicateVisitor; import org.apache.paimon.table.DataTable; import org.apache.paimon.table.Table; import org.apache.paimon.table.source.ReadBuilder; +import org.apache.paimon.table.system.RowTrackingTable; import org.apache.flink.api.common.eventtime.WatermarkStrategy; import org.apache.flink.api.connector.source.Boundedness; @@ -36,16 +44,28 @@ import org.apache.flink.table.catalog.ObjectIdentifier; import org.apache.flink.table.connector.ChangelogMode; import org.apache.flink.table.data.RowData; +import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.types.logical.RowType; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.annotation.Nullable; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; + /** A {@link FlinkTableSource} for system table. */ public class SystemTableSource extends FlinkTableSource { + private static final Logger LOG = LoggerFactory.getLogger(SystemTableSource.class); private final boolean unbounded; private final int splitBatchSize; private final FlinkConnectorOptions.SplitAssignMode splitAssignMode; private final ObjectIdentifier tableIdentifier; + @Nullable private List rowIds; public SystemTableSource(Table table, boolean unbounded, ObjectIdentifier tableIdentifier) { super(table); @@ -62,6 +82,7 @@ public SystemTableSource( @Nullable Predicate predicate, @Nullable int[][] projectFields, @Nullable Long limit, + @Nullable List rowIds, int splitBatchSize, FlinkConnectorOptions.SplitAssignMode splitAssignMode, ObjectIdentifier tableIdentifier) { @@ -70,6 +91,62 @@ public SystemTableSource( this.splitBatchSize = splitBatchSize; this.splitAssignMode = splitAssignMode; this.tableIdentifier = tableIdentifier; + this.rowIds = rowIds; + } + + @Override + public Result applyFilters(List filters) { + List partitionKeys = table.partitionKeys(); + RowType rowType = LogicalTypeConversion.toLogicalType(table.rowType()); + + // The source must ensure the consumed filters are fully evaluated, otherwise the result + // of query will be wrong. + List unConsumedFilters = new ArrayList<>(); + List consumedFilters = new ArrayList<>(); + List converted = new ArrayList<>(); + PredicateVisitor onlyPartFieldsVisitor = + new PartitionPredicateVisitor(partitionKeys); + PredicateVisitor> rowIdVisitor = new RowIdPredicateVisitor(); + + Set rowIdSet = null; + for (ResolvedExpression filter : filters) { + Optional predicateOptional = PredicateConverter.convert(rowType, filter); + + if (!predicateOptional.isPresent()) { + unConsumedFilters.add(filter); + } else { + Predicate p = predicateOptional.get(); + if (isUnbounded() || !p.visit(onlyPartFieldsVisitor)) { + boolean rowIdFilterConsumed = false; + if (table instanceof RowTrackingTable) { + Set ids = p.visit(rowIdVisitor); + if (ids != null) { + rowIdFilterConsumed = true; + if (rowIdSet == null) { + rowIdSet = new HashSet<>(ids); + } else { + rowIdSet.retainAll(ids); + } + } + } + if (rowIdFilterConsumed) { + // do not need to add consumed RowId filters to predicate + consumedFilters.add(filter); + } else { + unConsumedFilters.add(filter); + converted.add(p); + } + } else { + consumedFilters.add(filter); + converted.add(p); + } + } + } + predicate = converted.isEmpty() ? null : PredicateBuilder.and(converted); + rowIds = rowIdSet == null ? null : new ArrayList<>(rowIdSet); + LOG.info("Consumed filters: {} of {}", consumedFilters, filters); + + return Result.of(filters, unConsumedFilters); } @Override @@ -97,6 +174,9 @@ public ScanRuntimeProvider getScanRuntimeProvider(ScanContext scanContext) { readBuilder.withFilter(predicate); } readBuilder.withPartitionFilter(partitionPredicate); + if (rowIds != null) { + readBuilder.withRowIds(rowIds); + } if (unbounded && table instanceof DataTable) { source = @@ -141,6 +221,7 @@ public SystemTableSource copy() { predicate, projectFields, limit, + rowIds, splitBatchSize, splitAssignMode, tableIdentifier); @@ -155,4 +236,9 @@ public String asSummaryString() { public boolean isUnbounded() { return unbounded; } + + @VisibleForTesting + public List getRowIds() { + return rowIds; + } } diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkTableSourceTest.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkTableSourceTest.java index cff9ab6f4d25..994c8766f8ee 100644 --- a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkTableSourceTest.java +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/FlinkTableSourceTest.java @@ -18,6 +18,7 @@ package org.apache.paimon.flink.source; +import org.apache.paimon.CoreOptions; import org.apache.paimon.fs.FileIO; import org.apache.paimon.fs.Path; import org.apache.paimon.fs.local.LocalFileIO; @@ -27,6 +28,7 @@ import org.apache.paimon.table.FileStoreTableFactory; import org.apache.paimon.table.Table; import org.apache.paimon.table.TableTestBase; +import org.apache.paimon.table.system.RowTrackingTable; import org.apache.paimon.types.DataTypes; import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableList; @@ -136,6 +138,107 @@ public void testApplyPartitionTable() throws Exception { .isEqualTo(ImmutableList.of(filters.get(1))); } + @Test + public void testApplyRowIdFilters() throws Exception { + FileIO fileIO = LocalFileIO.create(); + Path tablePath = new Path(String.format("%s/%s.db/%s", warehouse, database, "T")); + Schema schema = + Schema.newBuilder() + .column("col1", DataTypes.INT()) + .column("col2", DataTypes.STRING()) + .column("col3", DataTypes.DOUBLE()) + .column("p1", DataTypes.INT()) + .column("p2", DataTypes.STRING()) + .partitionKeys("p1", "p2") + .option(CoreOptions.ROW_TRACKING_ENABLED.key(), "true") + .option(CoreOptions.DATA_EVOLUTION_ENABLED.key(), "true") + .build(); + TableSchema tableSchema = new SchemaManager(fileIO, tablePath).createTable(schema); + Table table = + new RowTrackingTable( + FileStoreTableFactory.create(LocalFileIO.create(), tablePath, tableSchema)); + SystemTableSource tableSource = + new SystemTableSource(table, false, ObjectIdentifier.of("catalog1", "db1", "T")); + + List filters; + + // col1 = 1 && p1 = 1 => [p1 = 1], idList = NULL + filters = ImmutableList.of(col1Equal1(), p1Equal1()); + Assertions.assertThat(tableSource.applyFilters(filters).getRemainingFilters()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(filters.get(0))); + Assertions.assertThat(tableSource.getRowIds()).isNull(); + + // col1 = 1 || _ROW_ID = 1 => [col1 = 1 || _ROW_ID = 1], idList = NULL + filters = ImmutableList.of(or(col1Equal1(), rowIdEqual(1))); + Assertions.assertThat(tableSource.applyFilters(filters).getRemainingFilters()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(filters.get(0))); + Assertions.assertThat(tableSource.getRowIds()).isNull(); + + // _ROW_ID = 1 && col1 = 1 => [col1 = 1], idList = [1] + filters = ImmutableList.of(rowIdEqual(1), col1Equal1()); + Assertions.assertThat(tableSource.applyFilters(filters).getRemainingFilters()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(filters.get(1))); + Assertions.assertThat(tableSource.getRowIds()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(1L)); + + // _ROW_ID in (1, 2, 3) && col1 = 1 => [col1 = 1], idList = [1, 2, 3] + filters = ImmutableList.of(rowIdIn(1L, 2L, 3L), col1Equal1()); + Assertions.assertThat(tableSource.applyFilters(filters).getRemainingFilters()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(filters.get(1))); + Assertions.assertThat(tableSource.getRowIds()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(1L, 2L, 3L)); + + // _ROW_ID = 1 && _ROW_ID = 2 && p1 = 1 => None, idList = [] + filters = ImmutableList.of(rowIdEqual(1), rowIdEqual(2), p1Equal1()); + Assertions.assertThat(tableSource.applyFilters(filters).getRemainingFilters()).isEmpty(); + Assertions.assertThat(tableSource.getRowIds()).isEmpty(); + + // _ROW_ID = 1 && (_ROW_ID = 1 || _ROW_ID = 2) => None, idList = [1] + filters = ImmutableList.of(rowIdEqual(1), or(rowIdEqual(1), rowIdEqual(2))); + Assertions.assertThat(tableSource.applyFilters(filters).getRemainingFilters()).isEmpty(); + Assertions.assertThat(tableSource.getRowIds()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(1L)); + + // _ROW_ID in (1, 2, 3, 4) && _ROW_ID in (1, 4, 6, 9) => None, idList = [1, 4] + filters = ImmutableList.of(rowIdIn(1L, 2L, 3L, 4L), rowIdIn(1L, 4L, 6L, 9L)); + Assertions.assertThat(tableSource.applyFilters(filters).getRemainingFilters()).isEmpty(); + Assertions.assertThat(tableSource.getRowIds()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(1L, 4L)); + + // _ROW_ID in (1, 2, 3, 4) || _ROW_ID in (4, 5) => None, idList = [1, 2, 3, 4, 5] + filters = ImmutableList.of(or(rowIdIn(1L, 2L, 3L, 4L), rowIdIn(4L, 5L))); + Assertions.assertThat(tableSource.applyFilters(filters).getRemainingFilters()).isEmpty(); + Assertions.assertThat(tableSource.getRowIds()) + .containsExactlyInAnyOrderElementsOf(ImmutableList.of(1L, 2L, 3L, 4L, 5L)); + } + + private ResolvedExpression rowIdEqual(long literal) { + return CallExpression.anonymous( + BuiltInFunctionDefinitions.EQUALS, + ImmutableList.of( + new FieldReferenceExpression( + "_ROW_ID", org.apache.flink.table.api.DataTypes.BIGINT(), 0, 5), + new ValueLiteralExpression( + literal, org.apache.flink.table.api.DataTypes.BIGINT().notNull())), + org.apache.flink.table.api.DataTypes.BOOLEAN()); + } + + private ResolvedExpression rowIdIn(Long... literals) { + ImmutableList.Builder argsBuilder = ImmutableList.builder(); + argsBuilder.add( + new FieldReferenceExpression( + "_ROW_ID", org.apache.flink.table.api.DataTypes.BIGINT(), 0, 5)); + for (long literal : literals) { + argsBuilder.add( + new ValueLiteralExpression( + literal, org.apache.flink.table.api.DataTypes.BIGINT().notNull())); + } + return CallExpression.anonymous( + BuiltInFunctionDefinitions.IN, + argsBuilder.build(), + org.apache.flink.table.api.DataTypes.BOOLEAN()); + } + private ResolvedExpression col1Equal1() { return CallExpression.anonymous( BuiltInFunctionDefinitions.EQUALS, diff --git a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/RowIdPushDownITCase.java b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/RowIdPushDownITCase.java new file mode 100644 index 000000000000..d53968b81f46 --- /dev/null +++ b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/source/RowIdPushDownITCase.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.flink.source; + +import org.apache.paimon.flink.CatalogITCaseBase; + +import org.apache.paimon.shade.guava30.com.google.common.collect.ImmutableList; + +import org.apache.flink.types.Row; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.List; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +/** ITCase for RowId push down. */ +public class RowIdPushDownITCase extends CatalogITCaseBase { + + @Override + public List ddl() { + return ImmutableList.of( + "CREATE TABLE T (" + + "a INT, b INT, c STRING) PARTITIONED BY (a) " + + "WITH ('row-tracking.enabled'='true');"); + } + + @BeforeEach + @Override + public void before() throws IOException { + super.before(); + setParallelism(1); + batchSql("INSERT INTO T VALUES (1, 1, '1'), (2, 2, '2'), (3, 3, '3'), (4, 4, '4')"); + } + + @Test + public void testSimplePredicate() throws Exception { + List result; + + // 1. in + result = sql("SELECT * FROM T$row_tracking WHERE _ROW_ID IN (1, 2)"); + assertFilteredResultEquals(result, id -> id == 1 || id == 2); + + // 2. equal + result = sql("SELECT * FROM T$row_tracking WHERE _ROW_ID = 3"); + assertFilteredResultEquals(result, id -> id == 3); + + // 3. empty + result = sql("SELECT * FROM T$row_tracking WHERE _ROW_ID IN (5, 6)"); + Assertions.assertThat(result).isEmpty(); + } + + @Test + public void testCompoundPredicate() throws Exception { + List result; + + // 1. AND + result = sql("SELECT * FROM T$row_tracking WHERE _ROW_ID IN (1, 2, 3) AND _ROW_ID = 3"); + assertFilteredResultEquals(result, id -> id == 3); + + // 2. OR + result = sql("SELECT * FROM T$row_tracking WHERE _ROW_ID IN (1, 2) OR _ROW_ID = 3"); + assertFilteredResultEquals(result, id -> id == 1 || id == 2 || id == 3); + + // 3. AND with empty result + result = sql("SELECT * FROM T$row_tracking WHERE _ROW_ID IN (1, 2) AND _ROW_ID = 3"); + Assertions.assertThat(result).isEmpty(); + } + + private void assertFilteredResultEquals(List result, Predicate rowIdFilter) { + List fullScan = sql("SELECT * FROM T$row_tracking"); + Assertions.assertThat(result) + .containsExactlyInAnyOrderElementsOf( + fullScan.stream() + .filter( + row -> { + Long rowId = row.getFieldAs(3); + return rowIdFilter.test(rowId); + }) + .collect(Collectors.toList())); + } +}