From 00ee228c03a2d8847cc1c942478022c226f08832 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Mon, 9 Jun 2025 19:23:09 -0700 Subject: [PATCH 1/2] sync with origin dsl --- .../org/qed/Generated/PruneEmptyFilter.java | 39 +++++++++++++++++++ .../qed/Generated/PruneEmptyIntersect.java | 39 +++++++++++++++++++ .../org/qed/Generated/PruneEmptyProject.java | 39 +++++++++++++++++++ .../org/qed/Generated/PruneLeftEmptyJoin.java | 39 +++++++++++++++++++ .../qed/Generated/PruneRightEmptyJoin.java | 39 +++++++++++++++++++ 5 files changed, 195 insertions(+) create mode 100644 src/main/java/org/qed/Generated/PruneEmptyFilter.java create mode 100644 src/main/java/org/qed/Generated/PruneEmptyIntersect.java create mode 100644 src/main/java/org/qed/Generated/PruneEmptyProject.java create mode 100644 src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java create mode 100644 src/main/java/org/qed/Generated/PruneRightEmptyJoin.java diff --git a/src/main/java/org/qed/Generated/PruneEmptyFilter.java b/src/main/java/org/qed/Generated/PruneEmptyFilter.java new file mode 100644 index 0000000..bcb125e --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyFilter.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyFilter extends RelRule { + protected PruneEmptyFilter(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_2 = call.builder(); + call.transformTo(var_2.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyFilter toRule() { + return new PruneEmptyFilter(this); + } + + @Override + default String description() { + return "PruneEmptyFilter"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyIntersect.java b/src/main/java/org/qed/Generated/PruneEmptyIntersect.java new file mode 100644 index 0000000..ddfcba3 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyIntersect.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyIntersect extends RelRule { + protected PruneEmptyIntersect(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().empty().intersect(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyIntersect toRule() { + return new PruneEmptyIntersect(this); + } + + @Override + default String description() { + return "PruneEmptyIntersect"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalIntersect.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyProject.java b/src/main/java/org/qed/Generated/PruneEmptyProject.java new file mode 100644 index 0000000..2241478 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyProject.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyProject extends RelRule { + protected PruneEmptyProject(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_2 = call.builder(); + call.transformTo(var_2.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyProject toRule() { + return new PruneEmptyProject(this); + } + + @Override + default String description() { + return "PruneEmptyProject"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java new file mode 100644 index 0000000..95774f0 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneLeftEmptyJoin extends RelRule { + protected PruneLeftEmptyJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneLeftEmptyJoin toRule() { + return new PruneLeftEmptyJoin(this); + } + + @Override + default String description() { + return "PruneLeftEmptyJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java new file mode 100644 index 0000000..5d4e1c3 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneRightEmptyJoin extends RelRule { + protected PruneRightEmptyJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneRightEmptyJoin toRule() { + return new PruneRightEmptyJoin(this); + } + + @Override + default String description() { + return "PruneRightEmptyJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} From b8833dd22525d6fb9d8e18865e00ea386764bcf0 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 14 Jun 2025 00:14:21 -0700 Subject: [PATCH 2/2] fixed JoinCommute --- .../org/qed/Generated/CalciteGenerator.java | 122 +++++++++++++++++- .../java/org/qed/Generated/CalciteTester.java | 29 +++-- .../java/org/qed/Generated/JoinCommute.java | 16 ++- .../JoinCommute.java | 28 ---- .../Generated/RRuleInstances/JoinCommute.java | 50 +++++++ .../SemiJoinProjectTransposeTest.java | 0 .../qed/Generated/Tests/JoinCommuteTest.java | 84 ++++++++++++ 7 files changed, 289 insertions(+), 40 deletions(-) delete mode 100644 src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java rename src/main/java/org/qed/Generated/{Tests-failed => Tests-Trivial}/SemiJoinProjectTransposeTest.java (100%) create mode 100644 src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java index a5f9418..32691cd 100644 --- a/src/main/java/org/qed/Generated/CalciteGenerator.java +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -11,6 +11,8 @@ import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.processing.Generated; + public class CalciteGenerator implements CodeGenerator { @Override @@ -279,7 +281,81 @@ public Env transformFilter(Env env, RelRN.Filter filter) { @Override public Env transformPred(Env env, RexRN.Pred pred) { - return env.focus(env.symbols().get(pred.operator().getName())); + // ONLY apply to the very specific JoinCommute pattern where: + // 1. We have exactly 2 JoinField sources + // 2. The first JoinField has ordinal 1 and second has ordinal 0 (indicating argument swap) + // This is the EXACT pattern from JoinCommute rule: pred($1, $0) vs pred($0, $1) + boolean isExactJoinCommuteSwapPattern = pred.sources().size() == 2 && + pred.sources().get(0) instanceof RexRN.JoinField joinField1 && + pred.sources().get(1) instanceof RexRN.JoinField joinField2 && + joinField1.ordinal() == 1 && // First arg is ordinal 1 (left -> right in swap) + joinField2.ordinal() == 0; // Second arg is ordinal 0 (right -> left in swap) + + if (isExactJoinCommuteSwapPattern) { + // This is the JoinCommute swapped predicate - transform with swapped arguments + var currentEnv = env; + var transformedArgs = Seq.empty(); + + // Process arguments in reverse order to swap them back + var sources = pred.sources(); + var reversedSources = Seq.of(sources.get(1), sources.get(0)); + + for (var arg : reversedSources) { + currentEnv = transform(currentEnv, arg); + transformedArgs = transformedArgs.appended(currentEnv.current()); + currentEnv = currentEnv.focus(env.current()); + } + + // Create a new predicate call with transformed arguments + String argsString = transformedArgs.joinToString(", "); + + // Extract operator from the original join condition + String operatorCall = "((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator()"; + + return currentEnv.focus(env.current() + ".call(" + operatorCall + ", " + argsString + ")"); + } else { + return env.focus(env.symbols().get(pred.operator().getName())); + } + } + + @Override + public Env transformJoinField(Env env, RexRN.JoinField joinField) { + // For JoinCommute: we need to calculate absolute field positions in the swapped join + + // Get the original join condition to extract the actual field indices + var origJoinDecl = env.declare("(LogicalJoin) call.rel(0)"); + var envWithOrigJoin = origJoinDecl.getValue(); + var conditionDecl = envWithOrigJoin.declare("(org.apache.calcite.rex.RexCall) " + origJoinDecl.getKey() + ".getCondition()"); + var envWithCondition = conditionDecl.getValue(); + + if (joinField.ordinal() == 0) { + // Ordinal 0 = Left table in original join + // Extract the left operand field index from original condition + var leftFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(0)).getIndex()"); + var envWithLeftField = leftFieldDecl.getValue(); + + // In swapped join: Left table is now at input 1 + // Use field(2, 1, leftFieldIndex) syntax + return envWithLeftField.focus(env.current() + ".field(2, 1, " + leftFieldDecl.getKey() + ")"); + } + else if (joinField.ordinal() == 1) { + // Ordinal 1 = Right table in original join + // Extract the right operand field index from original condition + var rightFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(1)).getIndex()"); + var envWithRightField = rightFieldDecl.getValue(); + + // Right table field index needs to be adjusted since it was originally after left table + var leftColCountDecl = envWithRightField.declare("call.rel(1).getRowType().getFieldCount()"); + var envWithLeftCount = leftColCountDecl.getValue(); + var adjustedRightFieldDecl = envWithLeftCount.declare(rightFieldDecl.getKey() + " - " + leftColCountDecl.getKey()); + var envWithAdjustedRightField = adjustedRightFieldDecl.getValue(); + + // In swapped join: Right table is now at input 0 + // Use field(2, 0, adjustedRightFieldIndex) syntax + return envWithAdjustedRightField.focus(env.current() + ".field(2, 0, " + adjustedRightFieldDecl.getKey() + ")"); + } else { + throw new UnsupportedOperationException("Unsupported join field ordinal: " + joinField.ordinal()); + } } @Override @@ -429,6 +505,50 @@ public Env transformEmpty(Env env, RelRN.Empty empty) { return env.focus(env.current() + ".empty()"); } + + @Override + public Env transformCustom(Env env, RelRN custom) { + return switch (custom) { + case org.qed.Generated.RRuleInstances.JoinCommute.ProjectionRelRN projection -> { + // Transform the source first - this builds the join + var sourceEnv = transform(env, projection.source()); + + // Get original table column counts + var leftTableDecl = sourceEnv.declare("call.rel(1)"); + var envWithLeftTable = leftTableDecl.getValue(); + var rightTableDecl = envWithLeftTable.declare("call.rel(2)"); + var envWithRightTable = rightTableDecl.getValue(); + + var leftColCountDecl = envWithRightTable.declare(leftTableDecl.getKey() + ".getRowType().getFieldCount()"); + var envWithLeftCount = leftColCountDecl.getValue(); + var rightColCountDecl = envWithLeftCount.declare(rightTableDecl.getKey() + ".getRowType().getFieldCount()"); + var envWithRightCount = rightColCountDecl.getValue(); + + // Create the projection indices as a List + var projectionIndicesDecl = envWithRightCount.declare( + "java.util.stream.IntStream.concat(" + + // Left columns: rightColCount + 0, rightColCount + 1, ..., rightColCount + leftColCount - 1 + "java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " + + rightColCountDecl.getKey() + " + " + leftColCountDecl.getKey() + "), " + + // Right columns: 0, 1, ..., rightColCount - 1 + "java.util.stream.IntStream.range(0, " + rightColCountDecl.getKey() + ")" + + ").boxed().collect(java.util.stream.Collectors.toList())" + ); + var envWithProjectionIndices = projectionIndicesDecl.getValue(); + + // Convert List to field references using RelBuilder.fields() + var fieldRefsDecl = envWithProjectionIndices.declare( + sourceEnv.current() + ".fields(" + projectionIndicesDecl.getKey() + ")" + ); + var envWithFieldRefs = fieldRefsDecl.getValue(); + + // Apply projection using the field references list + yield envWithFieldRefs.focus(sourceEnv.current() + ".project(" + fieldRefsDecl.getKey() + ")"); + } + default -> unimplementedTransform(env, custom); + }; + } + public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq statements, ImmutableMap symbols) { public static Env empty() { diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index cccd850..4a1d6b3 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -37,13 +37,21 @@ public static HepPlanner loadRule(RelOptRule rule) { return new HepPlanner(builder.build()); } + public static HepPlanner loadRule(RelOptRule rule, int matchLimit) { + System.out.printf("Verifying Rule: %s (match limit: %d)\n", rule.getClass(), matchLimit); + var builder = new HepProgramBuilder() + .addMatchLimit(matchLimit) + .addRuleInstance(rule); + return new HepPlanner(builder.build()); + } + public static Seq ruleList() { Reflections reflections = new Reflections("org.qed.Generated.RRuleInstances"); Set> ruleClasses = reflections.getSubTypesOf(RRule.class); var concreteRuleClasses = ruleClasses.stream() .filter(clazz -> !clazz.isInterface() && - !Modifier.isAbstract(clazz.getModifiers()) && + !Modifier.isAbstract(clazz.getModifiers()) && !clazz.getName().contains("$")) .collect(Collectors.toSet()); @@ -87,7 +95,7 @@ public static void runAllTests() { org.qed.Generated.Tests.MinusMergeTest.runTest(); org.qed.Generated.Tests.ProjectFilterTransposeTest.runTest(); org.qed.Generated.Tests.JoinPushTransitivePredicatesTest.runTest(); - org.qed.Generated.Tests.SemiJoinProjectTransposeTest.runTest(); + org.qed.Generated.Tests.JoinCommuteTest.runTest(); } catch (Exception e) { System.out.println("Test failed: " + e.getMessage()); e.printStackTrace(); @@ -95,14 +103,15 @@ public static void runAllTests() { } public static void main(String[] args) throws IOException { -// var rule = new RRuleInstance.FilterSetOpTranspose(); -// Files.createDirectories(Path.of(rulePath)); -// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); -// var rules = new RRuleInstance.JoinAssociate(); -// Files.createDirectories(Path.of(rulePath)); -// for (var rule : rules.family()) { -// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); -// } + var rule = new org.qed.Generated.RRuleInstances.JoinCommute(); + System.out.println(rule.explain()); + Files.createDirectories(Path.of(rulePath)); + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // var rules = new RRuleInstance.JoinAssociate(); + // Files.createDirectories(Path.of(rulePath)); + // for (var rule : rules.family()) { + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // } generate(); runAllTests(); } diff --git a/src/main/java/org/qed/Generated/JoinCommute.java b/src/main/java/org/qed/Generated/JoinCommute.java index 894330c..147fd52 100644 --- a/src/main/java/org/qed/Generated/JoinCommute.java +++ b/src/main/java/org/qed/Generated/JoinCommute.java @@ -14,7 +14,21 @@ protected JoinCommute(Config config) { @Override public void onMatch(RelOptRuleCall call) { var var_3 = call.builder(); - call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(0)).getCondition()).build()); + var var_4 = (LogicalJoin) call.rel(0); + var var_5 = (org.apache.calcite.rex.RexCall) var_4.getCondition(); + var var_6 = ((org.apache.calcite.rex.RexInputRef) var_5.getOperands().get(0)).getIndex(); + var var_7 = (LogicalJoin) call.rel(0); + var var_8 = (org.apache.calcite.rex.RexCall) var_7.getCondition(); + var var_9 = ((org.apache.calcite.rex.RexInputRef) var_8.getOperands().get(1)).getIndex(); + var var_10 = call.rel(1).getRowType().getFieldCount(); + var var_11 = var_9 - var_10; + var var_12 = call.rel(1); + var var_13 = call.rel(2); + var var_14 = var_12.getRowType().getFieldCount(); + var var_15 = var_13.getRowType().getFieldCount(); + var var_16 = java.util.stream.IntStream.concat(java.util.stream.IntStream.range(var_15, var_15 + var_14), java.util.stream.IntStream.range(0, var_15)).boxed().collect(java.util.stream.Collectors.toList()); + var var_17 = var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).fields(var_16); + call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).project(var_17).build()); } public interface Config extends EmptyConfig { diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java deleted file mode 100644 index 22f2344..0000000 --- a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.qed.Generated.RRuleInstances; - -import kala.collection.Map; -import kala.collection.Seq; -import org.apache.calcite.rel.core.JoinRelType; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.qed.RelRN; -import org.qed.RexRN; -import org.qed.RRule; -import org.qed.RuleBuilder; - -public record JoinCommute() implements RRule { - static final RelRN left = RelRN.scan("Left", "Left_Type"); - static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("pred", right); - - @Override - public RelRN before() { - return left.join(JoinRelType.INNER, joinCond, right); - } - - @Override - public RelRN after() { - // We need to swap the join fields in the condition - RexRN commutedJoinCond = right.joinPred("pred", left); - return right.join(JoinRelType.INNER, commutedJoinCond, left); - } -} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java new file mode 100644 index 0000000..d1580be --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java @@ -0,0 +1,50 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + SqlOperator predOp = RuleBuilder.create().genericPredicateOp(pred, true); + RexRN swappedPred = new RexRN.Pred(predOp, Seq.of( + new RexRN.JoinField(1, right, left), + new RexRN.JoinField(0, right, left) + )); + RelRN swappedJoin = right.join(JoinRelType.INNER, swappedPred, left); + + return new ProjectionRelRN(swappedJoin); + } + + // implementation for the column reordering projection + public static record ProjectionRelRN(RelRN source) implements RelRN { + @Override + public RelNode semantics() { + RuleBuilder builder = RuleBuilder.create(); + builder.push(source.semantics()); + + RexNode leftField = builder.field(1); // Left columns (now at position 1) + RexNode rightField = builder.field(0); // Right columns (now at position 0) + + builder.project(leftField, rightField); + + return builder.build(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests-failed/SemiJoinProjectTransposeTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/SemiJoinProjectTransposeTest.java similarity index 100% rename from src/main/java/org/qed/Generated/Tests-failed/SemiJoinProjectTransposeTest.java rename to src/main/java/org/qed/Generated/Tests-Trivial/SemiJoinProjectTransposeTest.java diff --git a/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java b/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java new file mode 100644 index 0000000..c6d0203 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java @@ -0,0 +1,84 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinCommuteTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create EMP table (8 columns like the real Calcite test) + var empTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // EMPNO + Tuple.of(RelType.fromString("VARCHAR", true), false), // ENAME + Tuple.of(RelType.fromString("VARCHAR", true), false), // JOB + Tuple.of(RelType.fromString("INTEGER", true), false), // MGR + Tuple.of(RelType.fromString("DATE", true), false), // HIREDATE + Tuple.of(RelType.fromString("DECIMAL", true), false), // SAL + Tuple.of(RelType.fromString("DECIMAL", true), false), // COMM + Tuple.of(RelType.fromString("INTEGER", true), false) // DEPTNO + )); + builder.addTable(empTable); + + // Create DEPT table (3 columns like the real Calcite test) + var deptTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // DEPTNO + Tuple.of(RelType.fromString("VARCHAR", true), false), // DNAME + Tuple.of(RelType.fromString("VARCHAR", true), false) // LOC + )); + builder.addTable(deptTable); + + // Build the "before" pattern: EMP JOIN DEPT ON EMP.DEPTNO = DEPT.DEPTNO + var empScan = builder.scan(empTable.getName()).build(); + var deptScan = builder.scan(deptTable.getName()).build(); + + var before = builder + .push(empScan) + .push(deptScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("equals", true), + builder.field(2, 0, 7), // EMP.DEPTNO (8th column, index 7) + builder.field(2, 1, 0) // DEPT.DEPTNO (1st column, index 0) + )) + .build(); + + // Build the "after" pattern: DEPT JOIN EMP ON DEPT.DEPTNO = EMP.DEPTNO + // followed by projection to restore original column order + var after = builder + .push(deptScan) + .push(empScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("equals", true), + builder.field(2, 1, 7), // EMP.DEPTNO (now at input 1, field 7) + builder.field(2, 0, 0) // DEPT.DEPTNO (now at input 0, field 0) + )) + .project( + // EMP columns first (now at positions 3-10 in the swapped join) + builder.field(3), // EMPNO + builder.field(4), // ENAME + builder.field(5), // JOB + builder.field(6), // MGR + builder.field(7), // HIREDATE + builder.field(8), // SAL + builder.field(9), // COMM + builder.field(10), // DEPTNO + // DEPT columns second (now at positions 0-2 in the swapped join) + builder.field(0), // DEPTNO0 + builder.field(1), // DNAME + builder.field(2) // LOC + ) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.JoinCommute.Config.DEFAULT.toRule(), 1); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinCommute test..."); + runTest(); + } +} \ No newline at end of file