Skip to content
6 changes: 3 additions & 3 deletions src/main/java/org/qed/CodeGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ default E unimplementedTransform(E env, Object object) {
return env;
}

E preMatch();
E preMatch(String rulename);

default E onMatch(E env, RelRN pattern) {
return switch (pattern) {
Expand Down Expand Up @@ -96,9 +96,9 @@ default String translate(String name, E onMatch, E transform) {

default String generate(RRule rule) {
System.out.printf("Generating Rule: %s\n", rule.name());
var onMatch = postMatch(onMatch(preMatch(), rule.before()));
var onMatch = postMatch(onMatch(preMatch(rule.name()), rule.before()));
var transform = postTransform(transform(preTransform(onMatch), rule.after()));
return translate(rule.getClass().getSimpleName(), onMatch, transform);
return translate(rule.name(), onMatch, transform);
}

default E onMatchScan(E env, RelRN.Scan scan) {
Expand Down
83 changes: 65 additions & 18 deletions src/main/java/org/qed/Generated/CalciteGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
public class CalciteGenerator implements CodeGenerator<CalciteGenerator.Env> {

@Override
public Env preMatch() {
return Env.empty();
public Env preMatch(String rulename) {
return Env.empty(rulename);
}

@Override
Expand All @@ -36,13 +36,56 @@ public String translate(String name, Env onMatch, Env transform) {
var builder = new StringBuilder("package org.qed.Generated;\n\n");
builder.append("import org.apache.calcite.plan.RelOptRuleCall;\n");
builder.append("import org.apache.calcite.plan.RelRule;\n");
builder.append("import org.apache.calcite.plan.RelOptUtil;\n");
builder.append("import java.util.List;\n");
builder.append("import org.apache.calcite.rel.RelNode;\n");
builder.append("import org.apache.calcite.rel.core.JoinRelType;\n");
builder.append("import org.apache.calcite.rel.logical.*;\n\n");
builder.append("import org.apache.calcite.rel.logical.*;\n");
if (name.equals("ProjectFilterTranspose")) {
builder.append("import org.apache.calcite.rex.RexInputRef;\n");
builder.append("import org.apache.calcite.rex.RexShuttle;\n");
builder.append("import java.util.HashMap;\n");
}
builder.append("\n");
builder.append("public class " + name + " extends RelRule<" + name + ".Config> {\n");
builder.append("\tprotected " + name + "(Config config) {\n");
builder.append("\t\tsuper(config);\n");
builder.append("\t}\n\n");

if(name.equals("ProjectFilterTranspose")) {
builder.append(
"""
\tprivate static org.apache.calcite.rex.RexNode mapFilterToProjectedColumns(RelOptRuleCall call) {
\t\tvar filter = (LogicalFilter) call.rel(1);
\t\tvar project = (LogicalProject) call.rel(0);
\t\tvar rexBuilder = project.getCluster().getRexBuilder();
\t\t
\t\t// Create mapping from table column index to projected position
\t\tvar tableToProjectMapping = new HashMap<Integer, Integer>();
\t\tfor (int projectedPos = 0; projectedPos < project.getProjects().size(); projectedPos++) {
\t\t\tvar projectExpr = project.getProjects().get(projectedPos);
\t\t\tif (projectExpr instanceof RexInputRef inputRef) {
\t\t\t\ttableToProjectMapping.put(inputRef.getIndex(), projectedPos);
\t\t\t}
\t\t}
\t\t
\t\t// Rewrite filter condition to use projected positions
\t\treturn filter.getCondition().accept(new RexShuttle() {
\t\t\t@Override
\t\t\tpublic org.apache.calcite.rex.RexNode visitInputRef(RexInputRef inputRef) {
\t\t\t\tInteger projectedPos = tableToProjectMapping.get(inputRef.getIndex());
\t\t\t\tif (projectedPos != null) {
\t\t\t\t\treturn rexBuilder.makeInputRef(inputRef.getType(), projectedPos);
\t\t\t\t}
\t\t\t\treturn inputRef;
\t\t\t}
\t\t});
\t}

"""
);
}

builder.append("\t@Override\n\tpublic void onMatch(RelOptRuleCall call) {\n");
transform.statements().forEach(statement -> builder.append("\t\t").append(statement).append("\n"));
builder.append("\t}\n\n");
Expand Down Expand Up @@ -313,7 +356,17 @@ public Env transformPred(Env env, RexRN.Pred pred) {
String operatorCall = "((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator()";

return currentEnv.focus(env.current() + ".call(" + operatorCall + ", " + argsString + ")");
} else {
}
else if (pred.sources().anyMatch(source -> source instanceof RexRN.Proj)) {
return env.focus(
"RelOptUtil.pushFilterPastProject(((LogicalFilter) call.rel(0)).getCondition(), " +
"((LogicalProject) call.rel(1)))"
);
}
else if (env.rulename.equals("ProjectFilterTranspose")) {
return env.focus("mapFilterToProjectedColumns(call)");
}
else {
return env.focus(env.symbols().get(pred.operator().getName()));
}
}
Expand Down Expand Up @@ -472,15 +525,9 @@ public Env transformProj(Env env, RexRN.Proj proj) {

@Override
public Env transformProject(Env env, RelRN.Project project) {
// First transform the source relation
var source_transform = transform(env, project.source());
var source_expression = source_transform.current();

// Then transform the projection map
var map_transform = transform(source_transform, project.map());

// Combine the source and projection using the project operation
// This creates a projection on top of the source relation
return map_transform.focus(source_expression + ".project(" + map_transform.current() + ")");
}

Expand Down Expand Up @@ -550,26 +597,26 @@ public Env transformCustom(Env env, RelRN custom) {
}

public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq<String> statements,
ImmutableMap<String, String> symbols) {
public static Env empty() {
ImmutableMap<String, String> symbols, String rulename) {
public static Env empty(String rulename) {
return new Env(new AtomicInteger(), 0, "call.rel(0)", "/* Unspecified skeleton */", Seq.empty(),
ImmutableMap.empty());
ImmutableMap.empty(), rulename);
}

public Env next() {
return new Env(varId, rel + 1, "call.rel(" + (rel + 1) + ")", skeleton, statements, symbols);
return new Env(varId, rel + 1, "call.rel(" + (rel + 1) + ")", skeleton, statements, symbols, rulename);
}

public Env focus(String target) {
return new Env(varId, rel, target, skeleton, statements, symbols);
return new Env(varId, rel, target, skeleton, statements, symbols, rulename);
}

public Env state(String statement) {
return new Env(varId, rel, current, skeleton, statements.appended(statement), symbols);
return new Env(varId, rel, current, skeleton, statements.appended(statement), symbols, rulename);
}

public Env symbol(String symbol, String expression) {
return new Env(varId, rel, current, skeleton, statements, symbols.putted(symbol, expression));
return new Env(varId, rel, current, skeleton, statements, symbols.putted(symbol, expression), rulename);
}

public Tuple2<String, Env> declare(String expression) {
Expand All @@ -579,7 +626,7 @@ public Tuple2<String, Env> declare(String expression) {

public Env grow(String requirement) {
var vn = "s_" + varId.getAndIncrement();
return new Env(varId, rel, current, vn + " -> " + vn + "." + requirement, statements, symbols);
return new Env(varId, rel, current, vn + " -> " + vn + "." + requirement, statements, symbols, rulename);
}
}

Expand Down
38 changes: 24 additions & 14 deletions src/main/java/org/qed/Generated/CalciteTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.qed.*;
import org.reflections.Reflections;
import org.apache.calcite.rel.rules.*;

import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -52,7 +53,8 @@ public static Seq<RRule> ruleList() {
var concreteRuleClasses = ruleClasses.stream()
.filter(clazz -> !clazz.isInterface() &&
!Modifier.isAbstract(clazz.getModifiers()) &&
!clazz.getName().contains("$"))
!clazz.getName().contains("$") &&
clazz.getSimpleName().equals("ProjectFilterTranspose"))
.collect(Collectors.toSet());

var individuals = Seq.from(concreteRuleClasses)
Expand Down Expand Up @@ -84,26 +86,28 @@ public static void generate() {

public static void runAllTests() {
try {
org.qed.Generated.Tests.FilterIntoJoinTest.runTest();
org.qed.Generated.Tests.FilterMergeTest.runTest();
org.qed.Generated.Tests.FilterProjectTransposeTest.runTest();
org.qed.Generated.Tests.UnionMergeTest.runTest();
org.qed.Generated.Tests.IntersectMergeTest.runTest();
org.qed.Generated.Tests.FilterSetOpTransposeTest.runTest();
org.qed.Generated.Tests.JoinExtractFilterTest.runTest();
org.qed.Generated.Tests.SemiJoinFilterTransposeTest.runTest();
org.qed.Generated.Tests.MinusMergeTest.runTest();
// org.qed.Generated.Tests.FilterIntoJoinTest.runTest();
// org.qed.Generated.Tests.FilterMergeTest.runTest();
// org.qed.Generated.Tests.FilterProjectTransposeTest.runTest();
// org.qed.Generated.Tests.UnionMergeTest.runTest();
// org.qed.Generated.Tests.IntersectMergeTest.runTest();
// org.qed.Generated.Tests.FilterSetOpTransposeTest.runTest();
// org.qed.Generated.Tests.JoinExtractFilterTest.runTest();
// org.qed.Generated.Tests.SemiJoinFilterTransposeTest.runTest();
// org.qed.Generated.Tests.MinusMergeTest.runTest();
org.qed.Generated.Tests.ProjectFilterTransposeTest.runTest();
org.qed.Generated.Tests.JoinPushTransitivePredicatesTest.runTest();
org.qed.Generated.Tests.JoinCommuteTest.runTest();
// org.qed.Generated.Tests.JoinPushTransitivePredicatesTest.runTest();
// org.qed.Generated.Tests.JoinCommuteTest.runTest();
// org.qed.Generated.Tests.JoinConditionPushTest.runTest();
// org.qed.Generated.Tests.AggregateProjectMergeTest.runTest();
} catch (Exception e) {
System.out.println("Test failed: " + e.getMessage());
e.printStackTrace();
}
}

public static void main(String[] args) throws IOException {
var rule = new org.qed.Generated.RRuleInstances.JoinCommute();
var rule = new org.qed.Generated.RRuleInstances.FilterProjectTranspose();
System.out.println(rule.explain());
Files.createDirectories(Path.of(rulePath));
new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson());
Expand Down Expand Up @@ -154,7 +158,13 @@ public void verify(HepPlanner runner, RelNode source, RelNode target) {
System.out.println("> Actual rewritten RelNode:\n" + answerExplain);
System.out.println("> Expected rewritten RelNode:\n" + targetExplain);
}
else System.out.println("succeeded");
else
{
System.out.println("succeeded");
System.out.println("> Given source RelNode:\n" + source.explain());
System.out.println("> Actual rewritten RelNode:\n" + answerExplain);
System.out.println("> Expected rewritten RelNode:\n" + targetExplain);
}
return;
}
System.out.println("failed");
Expand Down
6 changes: 4 additions & 2 deletions src/main/java/org/qed/Generated/FilterProjectTranspose.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.RelOptUtil;
import java.util.List;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.*;
Expand All @@ -14,7 +16,7 @@ protected FilterProjectTranspose(Config config) {
@Override
public void onMatch(RelOptRuleCall call) {
var var_3 = call.builder();
call.transformTo(var_3.push(call.rel(2)).project(((LogicalProject) call.rel(0)).getProjects()).filter(((LogicalFilter) call.rel(1)).getCondition()).build());
call.transformTo(var_3.push(call.rel(2)).filter(RelOptUtil.pushFilterPastProject(((LogicalFilter) call.rel(0)).getCondition(), ((LogicalProject) call.rel(1)))).project(((LogicalProject) call.rel(1)).getProjects()).build());
}

public interface Config extends EmptyConfig {
Expand All @@ -32,7 +34,7 @@ default String description() {

@Override
default RelRule.OperandTransform operandSupplier() {
return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()));
return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()));
}

}
Expand Down
36 changes: 34 additions & 2 deletions src/main/java/org/qed/Generated/ProjectFilterTranspose.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,51 @@

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.RelOptUtil;
import java.util.List;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.logical.*;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexShuttle;
import java.util.HashMap;

public class ProjectFilterTranspose extends RelRule<ProjectFilterTranspose.Config> {
protected ProjectFilterTranspose(Config config) {
super(config);
}

private static org.apache.calcite.rex.RexNode mapFilterToProjectedColumns(RelOptRuleCall call) {
var filter = (LogicalFilter) call.rel(1);
var project = (LogicalProject) call.rel(0);
var rexBuilder = project.getCluster().getRexBuilder();

// Create mapping from table column index to projected position
var tableToProjectMapping = new HashMap<Integer, Integer>();
for (int projectedPos = 0; projectedPos < project.getProjects().size(); projectedPos++) {
var projectExpr = project.getProjects().get(projectedPos);
if (projectExpr instanceof RexInputRef inputRef) {
tableToProjectMapping.put(inputRef.getIndex(), projectedPos);
}
}

// Rewrite filter condition to use projected positions
return filter.getCondition().accept(new RexShuttle() {
@Override
public org.apache.calcite.rex.RexNode visitInputRef(RexInputRef inputRef) {
Integer projectedPos = tableToProjectMapping.get(inputRef.getIndex());
if (projectedPos != null) {
return rexBuilder.makeInputRef(inputRef.getType(), projectedPos);
}
return inputRef;
}
});
}

@Override
public void onMatch(RelOptRuleCall call) {
var var_3 = call.builder();
call.transformTo(var_3.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).project(((LogicalProject) call.rel(1)).getProjects()).build());
call.transformTo(var_3.push(call.rel(2)).project(((LogicalProject) call.rel(0)).getProjects()).filter(mapFilterToProjectedColumns(call)).build());
}

public interface Config extends EmptyConfig {
Expand All @@ -32,7 +64,7 @@ default String description() {

@Override
default RelRule.OperandTransform operandSupplier() {
return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()));
return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()));
}

}
Expand Down
Loading