Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 121 additions & 1 deletion src/main/java/org/qed/Generated/CalciteGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import java.util.concurrent.atomic.AtomicInteger;

import javax.annotation.processing.Generated;

public class CalciteGenerator implements CodeGenerator<CalciteGenerator.Env> {

@Override
Expand Down Expand Up @@ -279,7 +281,81 @@ public Env transformFilter(Env env, RelRN.Filter filter) {

@Override
public Env transformPred(Env env, RexRN.Pred pred) {
return env.focus(env.symbols().get(pred.operator().getName()));
// ONLY apply to the very specific JoinCommute pattern where:
// 1. We have exactly 2 JoinField sources
// 2. The first JoinField has ordinal 1 and second has ordinal 0 (indicating argument swap)
// This is the EXACT pattern from JoinCommute rule: pred($1, $0) vs pred($0, $1)
boolean isExactJoinCommuteSwapPattern = pred.sources().size() == 2 &&
pred.sources().get(0) instanceof RexRN.JoinField joinField1 &&
pred.sources().get(1) instanceof RexRN.JoinField joinField2 &&
joinField1.ordinal() == 1 && // First arg is ordinal 1 (left -> right in swap)
joinField2.ordinal() == 0; // Second arg is ordinal 0 (right -> left in swap)

if (isExactJoinCommuteSwapPattern) {
// This is the JoinCommute swapped predicate - transform with swapped arguments
var currentEnv = env;
var transformedArgs = Seq.<String>empty();

// Process arguments in reverse order to swap them back
var sources = pred.sources();
var reversedSources = Seq.of(sources.get(1), sources.get(0));

for (var arg : reversedSources) {
currentEnv = transform(currentEnv, arg);
transformedArgs = transformedArgs.appended(currentEnv.current());
currentEnv = currentEnv.focus(env.current());
}

// Create a new predicate call with transformed arguments
String argsString = transformedArgs.joinToString(", ");

// Extract operator from the original join condition
String operatorCall = "((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator()";

return currentEnv.focus(env.current() + ".call(" + operatorCall + ", " + argsString + ")");
} else {
return env.focus(env.symbols().get(pred.operator().getName()));
}
}

@Override
public Env transformJoinField(Env env, RexRN.JoinField joinField) {
// For JoinCommute: we need to calculate absolute field positions in the swapped join

// Get the original join condition to extract the actual field indices
var origJoinDecl = env.declare("(LogicalJoin) call.rel(0)");
var envWithOrigJoin = origJoinDecl.getValue();
var conditionDecl = envWithOrigJoin.declare("(org.apache.calcite.rex.RexCall) " + origJoinDecl.getKey() + ".getCondition()");
var envWithCondition = conditionDecl.getValue();

if (joinField.ordinal() == 0) {
// Ordinal 0 = Left table in original join
// Extract the left operand field index from original condition
var leftFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(0)).getIndex()");
var envWithLeftField = leftFieldDecl.getValue();

// In swapped join: Left table is now at input 1
// Use field(2, 1, leftFieldIndex) syntax
return envWithLeftField.focus(env.current() + ".field(2, 1, " + leftFieldDecl.getKey() + ")");
}
else if (joinField.ordinal() == 1) {
// Ordinal 1 = Right table in original join
// Extract the right operand field index from original condition
var rightFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(1)).getIndex()");
var envWithRightField = rightFieldDecl.getValue();

// Right table field index needs to be adjusted since it was originally after left table
var leftColCountDecl = envWithRightField.declare("call.rel(1).getRowType().getFieldCount()");
var envWithLeftCount = leftColCountDecl.getValue();
var adjustedRightFieldDecl = envWithLeftCount.declare(rightFieldDecl.getKey() + " - " + leftColCountDecl.getKey());
var envWithAdjustedRightField = adjustedRightFieldDecl.getValue();

// In swapped join: Right table is now at input 0
// Use field(2, 0, adjustedRightFieldIndex) syntax
return envWithAdjustedRightField.focus(env.current() + ".field(2, 0, " + adjustedRightFieldDecl.getKey() + ")");
} else {
throw new UnsupportedOperationException("Unsupported join field ordinal: " + joinField.ordinal());
}
}

@Override
Expand Down Expand Up @@ -429,6 +505,50 @@ public Env transformEmpty(Env env, RelRN.Empty empty) {
return env.focus(env.current() + ".empty()");
}


@Override
public Env transformCustom(Env env, RelRN custom) {
return switch (custom) {
case org.qed.Generated.RRuleInstances.JoinCommute.ProjectionRelRN projection -> {
// Transform the source first - this builds the join
var sourceEnv = transform(env, projection.source());

// Get original table column counts
var leftTableDecl = sourceEnv.declare("call.rel(1)");
var envWithLeftTable = leftTableDecl.getValue();
var rightTableDecl = envWithLeftTable.declare("call.rel(2)");
var envWithRightTable = rightTableDecl.getValue();

var leftColCountDecl = envWithRightTable.declare(leftTableDecl.getKey() + ".getRowType().getFieldCount()");
var envWithLeftCount = leftColCountDecl.getValue();
var rightColCountDecl = envWithLeftCount.declare(rightTableDecl.getKey() + ".getRowType().getFieldCount()");
var envWithRightCount = rightColCountDecl.getValue();

// Create the projection indices as a List<Integer>
var projectionIndicesDecl = envWithRightCount.declare(
"java.util.stream.IntStream.concat(" +
// Left columns: rightColCount + 0, rightColCount + 1, ..., rightColCount + leftColCount - 1
"java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " +
rightColCountDecl.getKey() + " + " + leftColCountDecl.getKey() + "), " +
// Right columns: 0, 1, ..., rightColCount - 1
"java.util.stream.IntStream.range(0, " + rightColCountDecl.getKey() + ")" +
").boxed().collect(java.util.stream.Collectors.toList())"
);
var envWithProjectionIndices = projectionIndicesDecl.getValue();

// Convert List<Integer> to field references using RelBuilder.fields()
var fieldRefsDecl = envWithProjectionIndices.declare(
sourceEnv.current() + ".fields(" + projectionIndicesDecl.getKey() + ")"
);
var envWithFieldRefs = fieldRefsDecl.getValue();

// Apply projection using the field references list
yield envWithFieldRefs.focus(sourceEnv.current() + ".project(" + fieldRefsDecl.getKey() + ")");
}
default -> unimplementedTransform(env, custom);
};
}

public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq<String> statements,
ImmutableMap<String, String> symbols) {
public static Env empty() {
Expand Down
29 changes: 19 additions & 10 deletions src/main/java/org/qed/Generated/CalciteTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,21 @@ public static HepPlanner loadRule(RelOptRule rule) {
return new HepPlanner(builder.build());
}

public static HepPlanner loadRule(RelOptRule rule, int matchLimit) {
System.out.printf("Verifying Rule: %s (match limit: %d)\n", rule.getClass(), matchLimit);
var builder = new HepProgramBuilder()
.addMatchLimit(matchLimit)
.addRuleInstance(rule);
return new HepPlanner(builder.build());
}

public static Seq<RRule> ruleList() {
Reflections reflections = new Reflections("org.qed.Generated.RRuleInstances");

Set<Class<? extends RRule>> ruleClasses = reflections.getSubTypesOf(RRule.class);
var concreteRuleClasses = ruleClasses.stream()
.filter(clazz -> !clazz.isInterface() &&
!Modifier.isAbstract(clazz.getModifiers()) &&
!Modifier.isAbstract(clazz.getModifiers()) &&
!clazz.getName().contains("$"))
.collect(Collectors.toSet());

Expand Down Expand Up @@ -87,22 +95,23 @@ public static void runAllTests() {
org.qed.Generated.Tests.MinusMergeTest.runTest();
org.qed.Generated.Tests.ProjectFilterTransposeTest.runTest();
org.qed.Generated.Tests.JoinPushTransitivePredicatesTest.runTest();
org.qed.Generated.Tests.SemiJoinProjectTransposeTest.runTest();
org.qed.Generated.Tests.JoinCommuteTest.runTest();
} catch (Exception e) {
System.out.println("Test failed: " + e.getMessage());
e.printStackTrace();
}
}

public static void main(String[] args) throws IOException {
// var rule = new RRuleInstance.FilterSetOpTranspose();
// Files.createDirectories(Path.of(rulePath));
// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson());
// var rules = new RRuleInstance.JoinAssociate();
// Files.createDirectories(Path.of(rulePath));
// for (var rule : rules.family()) {
// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson());
// }
var rule = new org.qed.Generated.RRuleInstances.JoinCommute();
System.out.println(rule.explain());
Files.createDirectories(Path.of(rulePath));
new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson());
// var rules = new RRuleInstance.JoinAssociate();
// Files.createDirectories(Path.of(rulePath));
// for (var rule : rules.family()) {
// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson());
// }
generate();
runAllTests();
}
Expand Down
16 changes: 15 additions & 1 deletion src/main/java/org/qed/Generated/JoinCommute.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,21 @@ protected JoinCommute(Config config) {
@Override
public void onMatch(RelOptRuleCall call) {
var var_3 = call.builder();
call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(0)).getCondition()).build());
var var_4 = (LogicalJoin) call.rel(0);
var var_5 = (org.apache.calcite.rex.RexCall) var_4.getCondition();
var var_6 = ((org.apache.calcite.rex.RexInputRef) var_5.getOperands().get(0)).getIndex();
var var_7 = (LogicalJoin) call.rel(0);
var var_8 = (org.apache.calcite.rex.RexCall) var_7.getCondition();
var var_9 = ((org.apache.calcite.rex.RexInputRef) var_8.getOperands().get(1)).getIndex();
var var_10 = call.rel(1).getRowType().getFieldCount();
var var_11 = var_9 - var_10;
var var_12 = call.rel(1);
var var_13 = call.rel(2);
var var_14 = var_12.getRowType().getFieldCount();
var var_15 = var_13.getRowType().getFieldCount();
var var_16 = java.util.stream.IntStream.concat(java.util.stream.IntStream.range(var_15, var_15 + var_14), java.util.stream.IntStream.range(0, var_15)).boxed().collect(java.util.stream.Collectors.toList());
var var_17 = var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).fields(var_16);
call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).project(var_17).build());
}

public interface Config extends EmptyConfig {
Expand Down
39 changes: 39 additions & 0 deletions src/main/java/org/qed/Generated/PruneEmptyFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.qed.Generated;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.*;

public class PruneEmptyFilter extends RelRule<PruneEmptyFilter.Config> {
protected PruneEmptyFilter(Config config) {
super(config);
}

@Override
public void onMatch(RelOptRuleCall call) {
var var_2 = call.builder();
call.transformTo(var_2.empty().build());
}

public interface Config extends EmptyConfig {
Config DEFAULT = new Config() {};

@Override
default PruneEmptyFilter toRule() {
return new PruneEmptyFilter(this);
}

@Override
default String description() {
return "PruneEmptyFilter";
}

@Override
default RelRule.OperandTransform operandSupplier() {
return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs());
}

}
}
39 changes: 39 additions & 0 deletions src/main/java/org/qed/Generated/PruneEmptyIntersect.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.qed.Generated;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.*;

public class PruneEmptyIntersect extends RelRule<PruneEmptyIntersect.Config> {
protected PruneEmptyIntersect(Config config) {
super(config);
}

@Override
public void onMatch(RelOptRuleCall call) {
var var_3 = call.builder();
call.transformTo(var_3.empty().empty().intersect(false, 2).build());
}

public interface Config extends EmptyConfig {
Config DEFAULT = new Config() {};

@Override
default PruneEmptyIntersect toRule() {
return new PruneEmptyIntersect(this);
}

@Override
default String description() {
return "PruneEmptyIntersect";
}

@Override
default RelRule.OperandTransform operandSupplier() {
return s_2 -> s_2.operand(LogicalIntersect.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs());
}

}
}
39 changes: 39 additions & 0 deletions src/main/java/org/qed/Generated/PruneEmptyProject.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.qed.Generated;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.*;

public class PruneEmptyProject extends RelRule<PruneEmptyProject.Config> {
protected PruneEmptyProject(Config config) {
super(config);
}

@Override
public void onMatch(RelOptRuleCall call) {
var var_2 = call.builder();
call.transformTo(var_2.empty().build());
}

public interface Config extends EmptyConfig {
Config DEFAULT = new Config() {};

@Override
default PruneEmptyProject toRule() {
return new PruneEmptyProject(this);
}

@Override
default String description() {
return "PruneEmptyProject";
}

@Override
default RelRule.OperandTransform operandSupplier() {
return s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs());
}

}
}
39 changes: 39 additions & 0 deletions src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package org.qed.Generated;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.*;

public class PruneLeftEmptyJoin extends RelRule<PruneLeftEmptyJoin.Config> {
protected PruneLeftEmptyJoin(Config config) {
super(config);
}

@Override
public void onMatch(RelOptRuleCall call) {
var var_3 = call.builder();
call.transformTo(var_3.push(call.rel(2)).build());
}

public interface Config extends EmptyConfig {
Config DEFAULT = new Config() {};

@Override
default PruneLeftEmptyJoin toRule() {
return new PruneLeftEmptyJoin(this);
}

@Override
default String description() {
return "PruneLeftEmptyJoin";
}

@Override
default RelRule.OperandTransform operandSupplier() {
return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs());
}

}
}
Loading