From 7b3c56c8e1b0be9d84447ec62c3521ed20c24909 Mon Sep 17 00:00:00 2001 From: macronova Date: Mon, 29 Jan 2024 21:14:31 -0800 Subject: [PATCH 01/78] Wrap interface --- src/main/java/org/cosette/RelMatcher.java | 66 ++++----- src/main/java/org/cosette/RelRN.java | 92 +++++++++++++ src/main/java/org/cosette/RewriteRule.java | 8 ++ src/main/java/org/cosette/RexRN.java | 46 +++++++ src/main/java/org/cosette/RuleBuilder.java | 152 ++++++++++----------- 5 files changed, 255 insertions(+), 109 deletions(-) create mode 100644 src/main/java/org/cosette/RelRN.java create mode 100644 src/main/java/org/cosette/RewriteRule.java create mode 100644 src/main/java/org/cosette/RexRN.java diff --git a/src/main/java/org/cosette/RelMatcher.java b/src/main/java/org/cosette/RelMatcher.java index bd00517..a6e3c77 100644 --- a/src/main/java/org/cosette/RelMatcher.java +++ b/src/main/java/org/cosette/RelMatcher.java @@ -122,39 +122,39 @@ public static void main(String[] args) throws Exception { // var table_R = new CosetteTable("R", Map.of("a", new RelType.BaseType(SqlTypeName.INTEGER, true), "b", // new RelType.BaseType(SqlTypeName.INTEGER, true) // ), Set.empty(), Set.empty()); - var generator = new SchemaGenerator(); - generator.applyCreate("CREATE TABLE R (x INTEGER)"); - generator.applyCreate("CREATE TABLE S (a INTEGER)"); - var relBuilder = RuleBuilder.create(RawPlanner.generateConfig(generator.extractSchema())); - var query = relBuilder.scan("R").scan("S").join(JoinRelType.INNER, relBuilder.call(SqlStdOperatorTable.AND, - relBuilder.call(SqlStdOperatorTable.EQUALS, - relBuilder.call(SqlStdOperatorTable.PLUS, relBuilder.field(2, 0, 0), relBuilder.field(2, 1, 0)), - relBuilder.literal(0)), - relBuilder.call(SqlStdOperatorTable.EQUALS, relBuilder.field(2, 0, 0), relBuilder.field(2, 1, 0)))) - .build(); - var joinConditionPushRule = new RuleBuilder.JoinConditionPush(); - System.out.println("Example: SELECT * FROM R INNER JOIN S ON x + a = 0 AND x = a"); - System.out.println(); - System.out.println("Plan:"); - System.out.println(query.explain()); - System.out.println("Matcher:"); - System.out.println(joinConditionPushRule.getPattern().explain()); - var translator = joinConditionPushRule.match(query).get(); - var solver = translator.solver(); - for (var constraint : solver.getSygusConstraints()) { - System.out.println(constraint); - } - System.out.println(); - System.out.println("Can synthesize functions:"); - System.out.println(solver.checkSynth().hasSolution()); - System.out.println(); - for (var functionName : translator.declaredFunctions().store().keysView()) { - System.out.println("Synthesized function: " + functionName); - for (var functionComponent : translator.declaredFunctions().store().get(functionName).component1()) { - System.out.println(solver.getSynthSolution(functionComponent.component1())); - } - System.out.println(); - } +// var generator = new SchemaGenerator(); +// generator.applyCreate("CREATE TABLE R (x INTEGER)"); +// generator.applyCreate("CREATE TABLE S (a INTEGER)"); +// var relBuilder = RuleBuilder.create(RawPlanner.generateConfig(generator.extractSchema())); +// var query = relBuilder.scan("R").scan("S").join(JoinRelType.INNER, relBuilder.call(SqlStdOperatorTable.AND, +// relBuilder.call(SqlStdOperatorTable.EQUALS, +// relBuilder.call(SqlStdOperatorTable.PLUS, relBuilder.field(2, 0, 0), relBuilder.field(2, 1, 0)), +// relBuilder.literal(0)), +// relBuilder.call(SqlStdOperatorTable.EQUALS, relBuilder.field(2, 0, 0), relBuilder.field(2, 1, 0)))) +// .build(); +// var joinConditionPushRule = new RuleBuilder.JoinConditionPush(); +// System.out.println("Example: SELECT * FROM R INNER JOIN S ON x + a = 0 AND x = a"); +// System.out.println(); +// System.out.println("Plan:"); +// System.out.println(query.explain()); +// System.out.println("Matcher:"); +// System.out.println(joinConditionPushRule.getPattern().explain()); +// var translator = joinConditionPushRule.match(query).get(); +// var solver = translator.solver(); +// for (var constraint : solver.getSygusConstraints()) { +// System.out.println(constraint); +// } +// System.out.println(); +// System.out.println("Can synthesize functions:"); +// System.out.println(solver.checkSynth().hasSolution()); +// System.out.println(); +// for (var functionName : translator.declaredFunctions().store().keysView()) { +// System.out.println("Synthesized function: " + functionName); +// for (var functionComponent : translator.declaredFunctions().store().get(functionName).component1()) { +// System.out.println(solver.getSynthSolution(functionComponent.component1())); +// } +// System.out.println(); +// } } } diff --git a/src/main/java/org/cosette/RelRN.java b/src/main/java/org/cosette/RelRN.java new file mode 100644 index 0000000..39c9ec1 --- /dev/null +++ b/src/main/java/org/cosette/RelRN.java @@ -0,0 +1,92 @@ +package org.cosette; + +import kala.collection.Seq; +import kala.collection.Set; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.util.ImmutableBitSet; + +public interface RelRN { + RelNode semantics(); + + static RelType.VarType varType(String id, boolean nullable) { + return new RelType.VarType(id, nullable); + } + + static Scan scan(String id, RelType.VarType ty, boolean unique) { + return new Scan(id, ty, unique); + } + + default Filter filter(RexRN cond) { + return new Filter(cond, this); + } + + default Project project(Seq map) { + return new Project(map, this); + } + + default Join join(JoinRelType ty, RexRN cond, RelRN right) { + return new Join(ty, cond, this, right); + } + + default Union union(boolean all, RelRN... sources) { + return new Union(all, Seq.of(this).appendedAll(sources)); + } + + default Intersect intersect(boolean all, RelRN... sources) { + return new Intersect(all, Seq.of(this).appendedAll(sources)); + } + + record Scan(String id, RelType.VarType ty, boolean unique) implements RelRN { + + @Override + public RelNode semantics() { + var table = new CosetteTable(id, Seq.of("col-" + id), Seq.of(ty), unique ? Set.of(ImmutableBitSet.of(0)) : Set.empty(), Set.empty()); + return RuleBuilder.create().addTable(table).scan(id).build(); + } + } + + record Filter(RexRN cond, RelRN source) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(source.semantics()); + return builder.filter(cond.semantics().apply(builder)).build(); + } + } + + record Project(Seq map, RelRN source) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(source.semantics()); + return builder.project(map.map(m -> m.semantics().apply(builder))).build(); + } + } + + record Join(JoinRelType ty, RexRN cond, RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(left.semantics()).push(left.semantics()); + return builder.join(ty, cond.semantics().apply(builder)).build(); + } + } + + record Union(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).union(all, sources.size()).build(); + } + } + + record Intersect(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).intersect(all, sources.size()).build(); + } + } + +} diff --git a/src/main/java/org/cosette/RewriteRule.java b/src/main/java/org/cosette/RewriteRule.java new file mode 100644 index 0000000..e3a4aab --- /dev/null +++ b/src/main/java/org/cosette/RewriteRule.java @@ -0,0 +1,8 @@ +package org.cosette; + +public class RewriteRule { + + public static void main(String[] args) { + } + +} diff --git a/src/main/java/org/cosette/RexRN.java b/src/main/java/org/cosette/RexRN.java new file mode 100644 index 0000000..6c4fb74 --- /dev/null +++ b/src/main/java/org/cosette/RexRN.java @@ -0,0 +1,46 @@ +package org.cosette; + +import kala.collection.Seq; +import org.apache.calcite.rex.RexNode; + +import java.util.function.Function; + +public interface RexRN { + Function semantics(); + + static Field field(int ord) { + return new Field(ord); + } + + static Proj proj(String id, RelType.VarType ty, RexRN... sources) { + return new Proj(id, ty, Seq.of(sources)); + } + + static Pred pred(String id, RexRN... sources) { + return new Pred(id, Seq.of(sources)); + } + + record Field(int ord) implements RexRN { + @Override + public Function semantics() { + return builder -> builder.field(ord); + } + } + + record Proj(String id, RelType.VarType ty, Seq sources) implements RexRN { + + @Override + public Function semantics() { + return builder -> builder.call(builder.genericProjectionOp(id, ty), sources.map(s -> s.semantics().apply(builder))); + } + } + + record Pred(String id, Seq sources) implements RexRN { + + @Override + public Function semantics() { + return builder -> builder.call(builder.genericPredicateOp(id, false), sources.map(s -> s.semantics().apply(builder))); + } + } + +} diff --git a/src/main/java/org/cosette/RuleBuilder.java b/src/main/java/org/cosette/RuleBuilder.java index b15eb8b..b83eff8 100644 --- a/src/main/java/org/cosette/RuleBuilder.java +++ b/src/main/java/org/cosette/RuleBuilder.java @@ -129,81 +129,81 @@ public CosetteAggregateFunction(String name, RelType returnType) { } } - public interface CosetteRule { - - RelNode getPattern(); - - RelNode getTransformation(); - - default ImmutableSeq deriveAdditionalConstraints(Solver solver, RexTranslator.Declarations declarations) { - return ImmutableSeq.empty(); - } - - default Result match(RelNode target) { - return RelMatcher.check(getPattern(), target).map(translator -> translator.addConstraints( - deriveAdditionalConstraints(translator.solver(), translator.declaredFunctions()))); - } - - } - - public static class JoinConditionPush implements CosetteRule { - - private final RelNode pattern; - private final RelNode transform; - private final String bothPredicate = "joinBoth"; - private final String leftPredicate = "joinLeft"; - private final String rightPredicate = "joinRight"; - - - public JoinConditionPush() { - var builder = RuleBuilder.create(); - var tableNames = builder.sourceSimpleTables(Seq.of(1, 2)); - tableNames.forEach(builder::scan); - var joinBoth = builder.genericPredicateOp(bothPredicate, true); - var joinLeft = builder.genericPredicateOp(leftPredicate, true); - var joinRight = builder.genericPredicateOp(rightPredicate, true); - var joinCond = builder.and(builder.call(joinBoth, builder.joinFields()), - builder.call(joinLeft, builder.fields(2, 0)), builder.call(joinRight, builder.fields(2, 1))); - builder.join(JoinRelType.INNER, joinCond); - pattern = builder.build(); - builder.scan(tableNames.get(0)).filter(builder.call(joinLeft, builder.fields())); - builder.scan(tableNames.get(1)).filter(builder.call(joinRight, builder.fields())); - joinCond = builder.call(joinBoth, builder.joinFields()); - builder.join(JoinRelType.INNER, joinCond); - transform = builder.build(); - } - - @Override - public RelNode getPattern() { - return pattern; - } - - @Override - public RelNode getTransformation() { - return transform; - } - - @Override - public ImmutableSeq deriveAdditionalConstraints(Solver solver, RexTranslator.Declarations declarations) { - var bp = declarations.store().get(bothPredicate); - var lp = declarations.store().get(leftPredicate); - var rp = declarations.store().get(rightPredicate); - var lvs = lp.component2().mapIndexed((i, s) -> solver.declareSygusVar(leftPredicate + "-V" + i, s)); - var rvs = rp.component2().mapIndexed((i, s) -> solver.declareSygusVar(rightPredicate + "-V" + i, s)); - var bpc = solver.mkTerm(Kind.APPLY_UF, - bp.component1().map(Tuple2::component1).appendedAll(lvs).appendedAll(rvs).toArray(new Term[]{})); - var lpc = solver.mkTerm(Kind.APPLY_UF, - lp.component1().map(Tuple2::component1).appendedAll(lvs).toArray(new Term[]{})); - var rpc = solver.mkTerm(Kind.APPLY_UF, - rp.component1().map(Tuple2::component1).appendedAll(rvs).toArray(new Term[]{})); - var acl = solver.mkTerm(Kind.IMPLIES, lpc, - solver.mkTerm(Kind.EXISTS, solver.mkTerm(Kind.VARIABLE_LIST, rvs.toArray(new Term[]{})), - solver.mkTerm(Kind.AND, bpc, rpc))); - var acr = solver.mkTerm(Kind.IMPLIES, rpc, - solver.mkTerm(Kind.EXISTS, solver.mkTerm(Kind.VARIABLE_LIST, lvs.toArray(new Term[]{})), - solver.mkTerm(Kind.AND, bpc, lpc))); - return ImmutableSeq.of(acl, acr); - } - } +// public interface CosetteRule { +// +// RelNode getPattern(); +// +// RelNode getTransformation(); +// +// default ImmutableSeq deriveAdditionalConstraints(Solver solver, RexTranslator.Declarations declarations) { +// return ImmutableSeq.empty(); +// } +// +// default Result match(RelNode target) { +// return RelMatcher.check(getPattern(), target).map(translator -> translator.addConstraints( +// deriveAdditionalConstraints(translator.solver(), translator.declaredFunctions()))); +// } +// +// } +// +// public static class JoinConditionPush implements CosetteRule { +// +// private final RelNode pattern; +// private final RelNode transform; +// private final String bothPredicate = "joinBoth"; +// private final String leftPredicate = "joinLeft"; +// private final String rightPredicate = "joinRight"; +// +// +// public JoinConditionPush() { +// var builder = RuleBuilder.create(); +// var tableNames = builder.sourceSimpleTables(Seq.of(1, 2)); +// tableNames.forEach(builder::scan); +// var joinBoth = builder.genericPredicateOp(bothPredicate, true); +// var joinLeft = builder.genericPredicateOp(leftPredicate, true); +// var joinRight = builder.genericPredicateOp(rightPredicate, true); +// var joinCond = builder.and(builder.call(joinBoth, builder.joinFields()), +// builder.call(joinLeft, builder.fields(2, 0)), builder.call(joinRight, builder.fields(2, 1))); +// builder.join(JoinRelType.INNER, joinCond); +// pattern = builder.build(); +// builder.scan(tableNames.get(0)).filter(builder.call(joinLeft, builder.fields())); +// builder.scan(tableNames.get(1)).filter(builder.call(joinRight, builder.fields())); +// joinCond = builder.call(joinBoth, builder.joinFields()); +// builder.join(JoinRelType.INNER, joinCond); +// transform = builder.build(); +// } +// +// @Override +// public RelNode getPattern() { +// return pattern; +// } +// +// @Override +// public RelNode getTransformation() { +// return transform; +// } +// +// @Override +// public ImmutableSeq deriveAdditionalConstraints(Solver solver, RexTranslator.Declarations declarations) { +// var bp = declarations.store().get(bothPredicate); +// var lp = declarations.store().get(leftPredicate); +// var rp = declarations.store().get(rightPredicate); +// var lvs = lp.component2().mapIndexed((i, s) -> solver.declareSygusVar(leftPredicate + "-V" + i, s)); +// var rvs = rp.component2().mapIndexed((i, s) -> solver.declareSygusVar(rightPredicate + "-V" + i, s)); +// var bpc = solver.mkTerm(Kind.APPLY_UF, +// bp.component1().map(Tuple2::component1).appendedAll(lvs).appendedAll(rvs).toArray(new Term[]{})); +// var lpc = solver.mkTerm(Kind.APPLY_UF, +// lp.component1().map(Tuple2::component1).appendedAll(lvs).toArray(new Term[]{})); +// var rpc = solver.mkTerm(Kind.APPLY_UF, +// rp.component1().map(Tuple2::component1).appendedAll(rvs).toArray(new Term[]{})); +// var acl = solver.mkTerm(Kind.IMPLIES, lpc, +// solver.mkTerm(Kind.EXISTS, solver.mkTerm(Kind.VARIABLE_LIST, rvs.toArray(new Term[]{})), +// solver.mkTerm(Kind.AND, bpc, rpc))); +// var acr = solver.mkTerm(Kind.IMPLIES, rpc, +// solver.mkTerm(Kind.EXISTS, solver.mkTerm(Kind.VARIABLE_LIST, lvs.toArray(new Term[]{})), +// solver.mkTerm(Kind.AND, bpc, lpc))); +// return ImmutableSeq.of(acl, acr); +// } +// } } \ No newline at end of file From aef3ace23ae87cf818e1fa3a8d25a26722799024 Mon Sep 17 00:00:00 2001 From: macronova Date: Tue, 5 Mar 2024 20:19:04 -0800 Subject: [PATCH 02/78] Prototype new DSL --- flake.lock | 146 +++--- flake.nix | 6 +- pom.xml | 15 +- src/main/java/org/cosette/CodeGenerator.java | 491 ++++++++++++++++++ .../java/org/cosette/JSONDeserializer.java | 194 +++---- src/main/java/org/cosette/JSONSerializer.java | 160 +++--- src/main/java/org/cosette/RRule.java | 94 ++++ src/main/java/org/cosette/RelJSONShuttle.java | 364 ------------- src/main/java/org/cosette/RelMatcher.java | 2 - src/main/java/org/cosette/RelRN.java | 78 ++- src/main/java/org/cosette/RelType.java | 18 +- src/main/java/org/cosette/RewriteRule.java | 8 - src/main/java/org/cosette/RexRN.java | 77 ++- src/main/java/org/cosette/RuleBuilder.java | 7 - 14 files changed, 975 insertions(+), 685 deletions(-) create mode 100644 src/main/java/org/cosette/CodeGenerator.java create mode 100644 src/main/java/org/cosette/RRule.java delete mode 100644 src/main/java/org/cosette/RelJSONShuttle.java delete mode 100644 src/main/java/org/cosette/RewriteRule.java diff --git a/flake.lock b/flake.lock index c04c34d..b8e3f20 100644 --- a/flake.lock +++ b/flake.lock @@ -1,78 +1,78 @@ { "nodes": { - "cvc5-src": { - "flake": false, - "locked": { - "lastModified": 1689701904, - "narHash": "sha256-Jnrnx2WF3u917EbL/NBwjkrpRnWZIRM/qkjbgm/qmK0=", - "owner": "cvc5", - "repo": "cvc5", - "rev": "97a64fc16319ec21f3e31538eeff3da4636a6471", - "type": "github" - }, - "original": { - "owner": "cvc5", - "repo": "cvc5", - "type": "github" - } - }, - "flake-utils": { - "inputs": { - "systems": "systems" - }, - "locked": { - "lastModified": 1689068808, - "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", - "owner": "numtide", - "repo": "flake-utils", - "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", - "type": "github" - }, - "original": { - "owner": "numtide", - "repo": "flake-utils", - "type": "github" - } - }, - "nixpkgs": { - "locked": { - "lastModified": 1689631193, - "narHash": "sha256-AGSkBZaiTODQc8eT1rZDrQIjtb8JtFwJ0wVPzArlrnM=", - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "57695599bdc4f7bfe5d28cfa23f14b3d8bdf8a5f", - "type": "github" - }, - "original": { - "owner": "NixOS", - "ref": "nixpkgs-unstable", - "repo": "nixpkgs", - "type": "github" - } - }, - "root": { - "inputs": { - "cvc5-src": "cvc5-src", - "flake-utils": "flake-utils", - "nixpkgs": "nixpkgs" - } - }, - "systems": { - "locked": { - "lastModified": 1681028828, - "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", - "owner": "nix-systems", - "repo": "default", - "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", - "type": "github" - }, - "original": { - "owner": "nix-systems", - "repo": "default", - "type": "github" - } - } + "cvc5-src": { + "flake": false, + "locked": { + "lastModified": 1708192045, + "narHash": "sha256-hLb5qvWhAKeaGpR1HMH3URVXLuXjuvqapjx+loCa7Nc=", + "owner": "cvc5", + "repo": "cvc5", + "rev": "80878ee024f58a9a883479eed5dfe06402109e94", + "type": "github" + }, + "original": { + "owner": "cvc5", + "repo": "cvc5", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1705309234, + "narHash": "sha256-uNRRNRKmJyCRC/8y1RqBkqWBLM034y4qN7EprSdmgyA=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "1ef2e671c3b0c19053962c07dbda38332dcebf26", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1708247094, + "narHash": "sha256-H2VS7VwesetGDtIaaz4AMsRkPoSLEVzL/Ika8gnbUnE=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "045b51a3ae66f673ed44b5bbd1f4a341d96703bf", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "cvc5-src": "cvc5-src", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } }, "root": "root", "version": 7 - } +} diff --git a/flake.nix b/flake.nix index 1059c8a..296ddb6 100644 --- a/flake.nix +++ b/flake.nix @@ -27,7 +27,7 @@ libantlr3c antlr3_4 boost - jdk + jdk21 (python3.withPackages (ps: with ps; [ pyparsing toml ])) ]; @@ -48,10 +48,10 @@ devShells.default = pkgs.mkShell { packages = with pkgs; [ cvc5-java - jdk + jdk21 jetbrains.idea-community ]; - CVC5_JAVA = "${cvc5-java}/share/java/${cvc5-pname}.jar"; + CVC5_JAVA = "${cvc5-java}/share/java/cvc5.jar"; LD_LIBRARY_PATH = pkgs.lib.strings.makeLibraryPath [ cvc5-java ]; }; }); diff --git a/pom.xml b/pom.xml index 0f06421..d65d4bc 100644 --- a/pom.xml +++ b/pom.xml @@ -39,8 +39,8 @@ org.apache.maven.plugins maven-compiler-plugin - 19 - 19 + 21 + 21 --enable-preview @@ -75,8 +75,13 @@ org.slf4j - slf4j-nop - 2.0.5 + slf4j-api + 2.0.12 + + + org.slf4j + slf4j-simple + 2.0.12 org.glavo.kala @@ -86,7 +91,7 @@ io.github cvc5 - 1.0.5 + 1.1.1 system ${env.CVC5_JAVA} diff --git a/src/main/java/org/cosette/CodeGenerator.java b/src/main/java/org/cosette/CodeGenerator.java new file mode 100644 index 0000000..5d4ee57 --- /dev/null +++ b/src/main/java/org/cosette/CodeGenerator.java @@ -0,0 +1,491 @@ +package org.cosette; + +import kala.collection.Map; +import kala.collection.Seq; + +import java.util.concurrent.atomic.AtomicInteger; + +public interface CodeGenerator { + static void main(String[] args) { + // Calcite: + // - onMatch: Check RelNode type, register uninterpreted symbols + // - transform: Recursively build (depth first) with builder, keep builder expression + var calciteCodeGen = new CodeGenerator() { + final AtomicInteger variableIndex = new AtomicInteger(); + + Env assignVariable(Env env, String expression) { + var name = STR."var_\{variableIndex.getAndIncrement()}"; + return env.state(STR."var \{name} = \{expression};").express(Seq.of(name)); + } + + @Override + public Env preMatch() { + return assignVariable(Env.empty(), "call.rel(0)"); + } + + @Override + public Env preTransform(Env env) { + return assignVariable(env, "call.builder()"); + } + + @Override + public Env postTransform(Env env) { + var result_env = assignVariable(env, STR."\{env.expressions().first()}.build()"); + return result_env.state(STR."call.transformTo(\{result_env.expressions().first()});"); + } + + @Override + public String translate(Env ignored, Env transform) { + StringBuilder builder = new StringBuilder("@Override public void onMatch(RelOptRuleCall call) {\n"); + transform.statements().forEach(statement -> builder.append("\t").append(statement).append("\n")); + builder.append("}\n"); + return builder.toString(); + } + + @Override + public Env onMatchScan(Env env, RelRN.Scan scan) { + return env.symbol(scan.name(), env.expressions().first()); + } + + @Override + public Env onMatchFilter(Env env, RelRN.Filter filter) { + var rel = env.expressions().first(); + env = env.state(STR."if (!(\{rel} instanceof LogicalFilter)) { return; }"); + var rel_filter_env = assignVariable(env, STR."(LogicalFilter) \{rel}"); + var rel_filter = rel_filter_env.expressions().first(); + var source_env = assignVariable(rel_filter_env, STR."\{rel_filter}.getInput()"); + var source_match = onMatch(source_env, filter.source()); + var cond_env = assignVariable(source_match, STR."\{rel_filter}.getCondition()"); + return onMatch(cond_env, filter.cond()); + } + + @Override + public Env onMatchJoin(Env env, RelRN.Join join) { + var rel = env.expressions().first(); + env = env.state(STR."if (!(\{rel} instanceof LogicalJoin)) { return; }"); + var rel_join_env = assignVariable(env, STR."(LogicalJoin) \{rel}"); + var rel_join = rel_join_env.expressions().first(); + var join_type_env = assignVariable(rel_join_env, STR."\{rel_join}.getJoinType()"); + var join_type = join_type_env.expressions().first(); + var type_cond_expression = switch (join.ty()) { + + case INNER -> STR."\{join_type} == JoinRelType.INNER"; + case LEFT -> STR."\{join_type} == JoinRelType.LEFT"; + case RIGHT -> STR."\{join_type} == JoinRelType.RIGHT"; + case FULL -> STR."\{join_type} == JoinRelType.FULL"; + case SEMI -> STR."\{join_type} == JoinRelType.SEMI"; + case ANTI -> STR."\{join_type} == JoinRelType.ANTI"; + }; + var type_cond_env = assignVariable(join_type_env, type_cond_expression); + var type_match = type_cond_env.state(STR."if (!\{type_cond_env.expressions().first()}) { return; }"); + var left_source_env = assignVariable(type_match, STR."\{rel_join}.getLeft()"); + var left_match_env = onMatch(left_source_env, join.left()); + var right_source_env = assignVariable(left_match_env, STR."\{rel_join}.getRight()"); + var right_match_env = onMatch(right_source_env, join.right()); + var cond_source_env = assignVariable(right_match_env, STR."\{rel_join}.getCondition()"); + return onMatch(cond_source_env, join.cond()); + } + + @Override + public Env onMatchPred(Env env, RexRN.Pred pred) { + return env.symbol(pred.name(), env.expressions().first()); + } + + @Override + public Env onMatchCustom(Env env, RexRN custom) { + return switch (custom) { + case RRule.JoinConditionPush.JoinPred joinPred -> { + var pred = env.expressions().first(); + var breakdown_env = assignVariable(env, STR."customSplitFilter(\{pred})"); + var breakdown = breakdown_env.expressions().first(); + yield breakdown_env + .symbol(joinPred.bothPred(), STR."\{breakdown}.getBoth()") + .symbol(joinPred.leftPred(), STR."\{breakdown}.getLeft()") + .symbol(joinPred.rightPred(), STR."\{breakdown}.getRight()"); + } + default -> CodeGenerator.super.onMatchCustom(env, custom); + }; + } + + @Override + public Env transformScan(Env env, RelRN.Scan scan) { + var builder = env.expressions().first(); + var source_transform = STR."\{builder}.push(\{env.symbols().get(scan.name())})"; + return assignVariable(env, source_transform); + } + + @Override + public Env transformFilter(Env env, RelRN.Filter filter) { + var source_transform = transform(env, filter.source()); + var source_expression = source_transform.expressions().first(); + var cond_transform = transform(source_transform, filter.cond()); + return assignVariable(cond_transform, STR."\{source_expression}.filter(\{cond_transform.expressions().first()})"); + } + + @Override + public Env transformJoin(Env env, RelRN.Join join) { + var left_source_transform = transform(env, join.left()); + var right_source_transform = transform(left_source_transform, join.right()); + var source_expression = right_source_transform.expressions().first(); + var join_expression = switch (join.ty()) { + case INNER -> "JoinRelType.INNER"; + case LEFT -> "JoinRelType.LEFT"; + case RIGHT -> "JoinRelType.RIGHT"; + case FULL -> "JoinRelType.FULL"; + case SEMI -> "JoinRelType.SEMI"; + case ANTI -> "JoinRelType.ANTI"; + }; + var cond_transform = transform(right_source_transform, join.cond()); + return assignVariable(cond_transform, STR."\{source_expression}.join(\{join_expression}, \{cond_transform.expressions().first()})"); + } + + @Override + public Env transformPred(Env env, RexRN.Pred pred) { + return env.express(Seq.of(env.symbols().get(pred.name()))); + } + + @Override + public Env transformAnd(Env env, RexRN.And and) { + var source_transform = env; + var operands = Seq.empty(); + for (var source : and.sources()) { + source_transform = transform(source_transform, source); + operands = operands.appended(source_transform.expressions().first()); + source_transform = source_transform.express(env.expressions()); + } + return assignVariable(source_transform, STR."\{env.expressions().first()}.and(\{operands.joinToString(", ")})"); + } + }; + + // Cockroach: Recursively (depth-first) generate DSL + var cockroachCodeGen = new CodeGenerator() { + @Override + public String translate(Env onMatch, Env transform) { + return STR."\{onMatch.expressions().first()}\n=>\n\{transform.expressions().first()}\n"; + } + + @Override + public Env onMatchScan(Env env, RelRN.Scan scan) { + return env.express(Seq.of(STR."$\{scan.name()}:*")); + } + + @Override + public Env onMatchFilter(Env env, RelRN.Filter filter) { + var source_match = onMatch(env, filter.source()); + var source_expression = source_match.expressions().first(); + var cond_match = onMatch(source_match, filter.cond()); + var cond_expression = cond_match.expressions().first(); + return cond_match.express(Seq.of(STR."(Select \{source_expression} \{cond_expression})")); + } + + @Override + public Env onMatchJoin(Env env, RelRN.Join join) { + var left_source_match = onMatch(env, join.left()); + var left_source_expression = left_source_match.expressions().first(); + var right_source_match = onMatch(left_source_match, join.right()); + var right_source_expression = right_source_match.expressions().first(); + var join_expression = switch (join.ty()) { + case INNER -> "InnerJoin"; + case LEFT -> "LeftJoin"; + case RIGHT -> "RightJoin"; + case FULL -> "FullJoin"; + case SEMI -> "SemiJoin"; + case ANTI -> "AntiJoin"; + }; + var cond_match = onMatch(right_source_match, join.cond()); + var cond_expression = cond_match.expressions().first(); + return cond_match.express(Seq.of(STR."(\{join_expression} \{left_source_expression} \{right_source_expression} \{cond_expression} $private:*)")); + } + + @Override + public Env onMatchPred(Env env, RexRN.Pred pred) { + return env.symbol(pred.name(), STR."$\{pred.name()}").express(Seq.of(STR."$\{pred.name()}:*")); + } + + @Override + public Env onMatchCustom(Env env, RexRN custom) { + return switch (custom) { + case RRule.JoinConditionPush.JoinPred joinPred -> env + .symbol(joinPred.bothPred(), STR."(RemoveFiltersItem $\{joinPred.bothPred()} $item)") + .symbol(joinPred.leftPred(), "[(FiltersItem (MapJoinOpFilter $item $leftCols $equivSet))]") + .symbol(joinPred.rightPred(), "[(FiltersItem (MapJoinOpFilter $item $rightCols $equivSet))]") + .express(Seq.of("$on:[... $item:* & (CanMapJoinOpFilter $item $leftCols $equivSet) & (CanMapJoinOpFilter $item $rightCols $equivSet)...]")); + default -> CodeGenerator.super.onMatchCustom(env, custom); + }; + } + + @Override + public Env transformScan(Env env, RelRN.Scan scan) { + return env.express(Seq.of(STR."$\{scan.name()}")); + } + + @Override + public Env transformFilter(Env env, RelRN.Filter filter) { + var source_expression = transform(env, filter.source()).expressions().first(); + var cond_expression = transform(env, filter.cond()).expressions().first(); + return env.express(Seq.of(STR."(Select \{source_expression} \{cond_expression})")); + } + + @Override + public Env transformJoin(Env env, RelRN.Join join) { + var left_source_expression = transform(env, join.left()).expressions().first(); + var right_source_expression = transform(env, join.right()).expressions().first(); + var cond_expression = transform(env, join.cond()).expressions.first(); + return env.express(Seq.of(STR."((OpName) \{left_source_expression} \{right_source_expression} \{cond_expression} $private)")); + } + + @Override + public Env transformPred(Env env, RexRN.Pred pred) { + return env.express(Seq.of(env.symbols().get(pred.name()))); + } + + @Override + public Env transformAnd(Env env, RexRN.And and) { + var operands = and.sources().map(source -> transform(env, source).expressions().first()); + return env.express(Seq.of(STR."(ConcatFilters \{operands.joinToString(", ")})")); + } + }; + + // Note: Join is treated as if it is a custom operator + var filterMerge = new RRule.FilterMerge(); + var joinConditionPush = new RRule.JoinConditionPush(); + var calciteFilterMerge = calciteCodeGen.compose(filterMerge); + var calciteJoinConditionPush = calciteCodeGen.compose(joinConditionPush); + var cockroachFilterMerge = cockroachCodeGen.compose(filterMerge); + var cockroachJoinConditionPush = cockroachCodeGen.compose(joinConditionPush); + System.out.println(filterMerge.explain()); + System.out.println(calciteFilterMerge); + System.out.println(cockroachFilterMerge); + System.out.println(); + System.out.println(joinConditionPush.explain()); + System.out.println(calciteJoinConditionPush); + System.out.println(cockroachJoinConditionPush); + } + + default String unimplemented(String context, Object object) { + return STR."<--\{context}\{object.getClass().getName()}-->"; + } + + default Env unimplementedOnMatch(Env env, Object object) { + return env.express(Seq.of(unimplemented("Unspecified onMatch codegen: ", object))); + } + + default Env unimplementedTransform(Env env, Object object) { + return env.express(Seq.of(unimplemented("Unspecified transform codegen: ", object))); + } + + default Env preMatch() { + return Env.empty(); + } + + default Env onMatch(Env env, RelRN pattern) { + return switch (pattern) { + case RelRN.Scan scan -> onMatchScan(env, scan); + case RelRN.Filter filter -> onMatchFilter(env, filter); + case RelRN.Project project -> onMatchProject(env, project); + case RelRN.Join join -> onMatchJoin(env, join); + case RelRN.Union union -> onMatchUnion(env, union); + case RelRN.Intersect intersect -> onMatchIntersect(env, intersect); + default -> onMatchCustom(env, pattern); + }; + } + + default Env onMatch(Env env, RexRN pattern) { + return switch (pattern) { + case RexRN.Field field -> onMatchField(env, field); + case RexRN.JoinField joinField -> onMatchJoinField(env, joinField); + case RexRN.Proj proj -> onMatchProj(env, proj); + case RexRN.Pred pred -> onMatchPred(env, pred); + case RexRN.And and -> onMatchAnd(env, and); + case RexRN.Or or -> onMatchOr(env, or); + case RexRN.Not not -> onMatchNot(env, not); + default -> onMatchCustom(env, pattern); + }; + } + + default Env postMatch(Env env) { + return env; + } + + default Env preTransform(Env env) { + return env; + } + + default Env transform(Env env, RelRN target) { + return switch (target) { + case RelRN.Scan scan -> transformScan(env, scan); + case RelRN.Filter filter -> transformFilter(env, filter); + case RelRN.Project project -> transformProject(env, project); + case RelRN.Join join -> transformJoin(env, join); + case RelRN.Union union -> transformUnion(env, union); + case RelRN.Intersect intersect -> transformIntersect(env, intersect); + default -> transformCustom(env, target); + }; + } + + default Env transform(Env env, RexRN target) { + return switch (target) { + case RexRN.Field field -> transformField(env, field); + case RexRN.JoinField joinField -> transformJoinField(env, joinField); + case RexRN.Proj proj -> transformProj(env, proj); + case RexRN.Pred pred -> transformPred(env, pred); + case RexRN.And and -> transformAnd(env, and); + case RexRN.Or or -> transformOr(env, or); + case RexRN.Not not -> transformNot(env, not); + default -> transformCustom(env, target); + }; + } + + default Env postTransform(Env env) { + return env; + } + + default String translate(Env onMatch, Env transform) { + return unimplemented("Unspecified translation to target language: ", Env.empty()); + } + + default String compose(RRule rule) { + var onMatch = postMatch(onMatch(preMatch(), rule.before())); + var transform = postTransform(transform(preTransform(onMatch), rule.after())); + return translate(onMatch, transform); + } + + default Env onMatchScan(Env env, RelRN.Scan scan) { + return unimplementedOnMatch(env, scan); + } + + default Env onMatchFilter(Env env, RelRN.Filter filter) { + return unimplementedOnMatch(env, filter); + } + + default Env onMatchProject(Env env, RelRN.Project project) { + return unimplementedOnMatch(env, project); + } + + default Env onMatchJoin(Env env, RelRN.Join join) { + return unimplementedOnMatch(env, join); + } + + default Env onMatchUnion(Env env, RelRN.Union union) { + return unimplementedOnMatch(env, union); + } + + default Env onMatchIntersect(Env env, RelRN.Intersect intersect) { + return unimplementedOnMatch(env, intersect); + } + + default Env onMatchCustom(Env env, RelRN custom) { + return unimplementedOnMatch(env, custom); + } + + default Env onMatchField(Env env, RexRN.Field field) { + return unimplementedOnMatch(env, field); + } + + default Env onMatchJoinField(Env env, RexRN.JoinField joinField) { + return unimplementedOnMatch(env, joinField); + } + + default Env onMatchProj(Env env, RexRN.Proj proj) { + return unimplementedOnMatch(env, proj); + } + + default Env onMatchPred(Env env, RexRN.Pred pred) { + return unimplementedOnMatch(env, pred); + } + + default Env onMatchAnd(Env env, RexRN.And and) { + return unimplementedOnMatch(env, and); + } + + default Env onMatchOr(Env env, RexRN.Or or) { + return unimplementedOnMatch(env, or); + } + + default Env onMatchNot(Env env, RexRN.Not not) { + return unimplementedOnMatch(env, not); + } + + default Env onMatchCustom(Env env, RexRN custom) { + return unimplementedOnMatch(env, custom); + } + + default Env transformScan(Env env, RelRN.Scan scan) { + return unimplementedTransform(env, scan); + } + + default Env transformFilter(Env env, RelRN.Filter filter) { + return unimplementedTransform(env, filter); + } + + default Env transformProject(Env env, RelRN.Project project) { + return unimplementedTransform(env, project); + } + + default Env transformJoin(Env env, RelRN.Join join) { + return unimplementedTransform(env, join); + } + + default Env transformUnion(Env env, RelRN.Union union) { + return unimplementedTransform(env, union); + } + + default Env transformIntersect(Env env, RelRN.Intersect intersect) { + return unimplementedTransform(env, intersect); + } + + default Env transformCustom(Env env, RelRN custom) { + return unimplementedTransform(env, custom); + } + + default Env transformField(Env env, RexRN.Field field) { + return unimplementedTransform(env, field); + } + + default Env transformJoinField(Env env, RexRN.JoinField joinField) { + return unimplementedTransform(env, joinField); + } + + default Env transformProj(Env env, RexRN.Proj proj) { + return unimplementedTransform(env, proj); + } + + default Env transformPred(Env env, RexRN.Pred pred) { + return unimplementedTransform(env, pred); + } + + default Env transformAnd(Env env, RexRN.And and) { + return unimplementedTransform(env, and); + } + + default Env transformOr(Env env, RexRN.Or or) { + return unimplementedTransform(env, or); + } + + default Env transformNot(Env env, RexRN.Not not) { + return unimplementedTransform(env, not); + } + + default Env transformCustom(Env env, RexRN custom) { + return unimplementedTransform(env, custom); + } + + record Env(Seq expressions, Seq statements, Map symbols) { + public static Env empty() { + return new Env(Seq.empty(), Seq.empty(), Map.empty()); + } + + public Env express(Seq expressions) { + return new Env(expressions, statements, symbols); + } + + public Env state(String statement) { + return new Env(expressions, statements.appended(statement), symbols); + } + + public Env symbol(String symbol, String expression) { + return new Env(expressions, statements, symbols.toImmutableMap().putted(symbol, expression)); + } + } + +} diff --git a/src/main/java/org/cosette/JSONDeserializer.java b/src/main/java/org/cosette/JSONDeserializer.java index 2ecd6e4..f76f57b 100644 --- a/src/main/java/org/cosette/JSONDeserializer.java +++ b/src/main/java/org/cosette/JSONDeserializer.java @@ -34,6 +34,94 @@ public record JSONDeserializer() { private final static ObjectMapper mapper = new ObjectMapper(); + private static ImmutableSeq array(JsonNode node) throws Exception { + if (!node.isArray()) throw new Exception(); + return ImmutableSeq.from(node.elements()); + } + + private static ImmutableSeq array(JsonNode node, String path) throws Exception { + return array(node.required(path)); + } + + private static String string(JsonNode node) throws Exception { + if (!node.isTextual()) throw new Exception(); + return node.asText(); + } + + private static String string(JsonNode node, String path) throws Exception { + return string(node.required(path)); + } + + private static int integer(JsonNode node) throws Exception { + if (!node.isInt()) throw new Exception(); + return node.asInt(); + } + + private static int integer(JsonNode node, String path) throws Exception { + return integer(node.required(path)); + } + + private static boolean bool(JsonNode node) throws Exception { + if (!node.isBoolean()) throw new Exception(); + return node.asBoolean(); + } + + static SqlTypeName typeName(String name) { + name = switch (name) { + case "BOOL" -> "BOOLEAN"; + case "INT", "INT2", "INT4", "OID" -> "INTEGER"; + case "TIMESTAMPTZ" -> "TIMESTAMP"; + case "TIMETZ" -> "TIME"; + case "STRING" -> "VARCHAR"; + case "JSONB" -> "VARBINARY"; + default -> name; + }; + return Enum.valueOf(SqlTypeName.class, name); + } + + public static ImmutableSeq load(File file) throws Exception { + return new JSONDeserializer().deserialize(mapper.readTree(file)); + } + + public static void main(String[] args) throws Exception { + var refs = Seq.from(new File("RelOptRulesTest").listFiles()); + for (var file : refs) { + try { + var store = mapper.readTree(file); + new JSONDeserializer().deserialize(store); + } catch (Exception e) { + System.err.println("===> " + file.getName() + " <==="); + System.err.println(e.getMessage()); + System.err.println(); + } + } + } + + public ImmutableSeq deserialize(JsonNode node) throws Exception { + var builder = RuleBuilder.create(); + var tables = array(node, "schemas").mapChecked(schema -> { + var types = array(schema, "types").mapChecked(JSONDeserializer::string); + var nullabilities = array(schema, "nullable").mapChecked(JSONDeserializer::bool); + var name = schema.path("name").asText("DEFAULT_TABLE_NAME"); + var fields = schema.get("fields") == null ? + Seq.fill(types.size(), i -> String.format("DEFAULT_FIELD_NAME_%d", i)) : + array(schema, "fields").mapChecked(JSONDeserializer::string); + var keys = Set.from(array(schema, "key").map( + CheckedFunction.of(key -> ImmutableBitSet.of(array(key).mapChecked(JSONDeserializer::integer))))); + if (types.size() != nullabilities.size()) + throw new Exception("Expecting corresponding types and nullabilities"); + var sts = types.zip(nullabilities).map(tn -> { + var type = builder.getTypeFactory().createSqlType(typeName(tn.component1())); + return builder.getTypeFactory().createTypeWithNullability(type, tn.component2()); + }); + var table = new CosetteTable(name, fields, sts, keys, Set.empty()); + builder.addTable(table); + return table; + }); + var rel = new Rel(builder, ImmutableSeq.empty(), tables); + return array(node, "queries").mapChecked(rel); + } + private record Rel(RuleBuilder builder, ImmutableSeq globals, ImmutableSeq tables) implements CheckedFunction { Rel(RuleBuilder builder) { @@ -156,6 +244,15 @@ yield builder().push(sorted).sortLimit(rex().deserialize(content.required("offse private record Rex(RuleBuilder builder, ImmutableSeq globals, RexCorrelVariable local, ImmutableSeq tables) implements CheckedFunction { + static Seq ops = Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) + .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && + java.lang.reflect.Modifier.isStatic(f.getModifiers())).map(f -> { + var mist = Try.of(() -> f.get(null)).getOrNull(); + if (mist == null) return null; + if (mist instanceof SqlOperator op) return op; + return null; + }).filter(Objects::nonNull); + public RexNode resolve(int lvl) { assert lvl < globals().size() + local().getType().getFieldCount(); return lvl < globals().size() ? globals().get(lvl) : builder().getRexBuilder() @@ -177,15 +274,6 @@ public RelDataType type(String name) { return builder().getTypeFactory().createSqlType(typeName(name)); } - static Seq ops = Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) - .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && - java.lang.reflect.Modifier.isStatic(f.getModifiers())).map(f -> { - var mist = Try.of(() -> f.get(null)).getOrNull(); - if (mist == null) return null; - if (mist instanceof SqlOperator op) return op; - return null; - }).filter(Objects::nonNull); - SqlOperator op(String name, int arity) throws Exception { switch (name) { case "BOOL_AND" -> { @@ -273,93 +361,5 @@ public RexNode deserialize(JsonNode node) throws Exception { } } } - - private static ImmutableSeq array(JsonNode node) throws Exception { - if (!node.isArray()) throw new Exception(); - return ImmutableSeq.from(node.elements()); - } - - private static ImmutableSeq array(JsonNode node, String path) throws Exception { - return array(node.required(path)); - } - - private static String string(JsonNode node) throws Exception { - if (!node.isTextual()) throw new Exception(); - return node.asText(); - } - - private static String string(JsonNode node, String path) throws Exception { - return string(node.required(path)); - } - - private static int integer(JsonNode node) throws Exception { - if (!node.isInt()) throw new Exception(); - return node.asInt(); - } - - private static int integer(JsonNode node, String path) throws Exception { - return integer(node.required(path)); - } - - private static boolean bool(JsonNode node) throws Exception { - if (!node.isBoolean()) throw new Exception(); - return node.asBoolean(); - } - - static SqlTypeName typeName(String name) { - name = switch (name) { - case "BOOL" -> "BOOLEAN"; - case "INT", "INT2", "INT4", "OID" -> "INTEGER"; - case "TIMESTAMPTZ" -> "TIMESTAMP"; - case "TIMETZ" -> "TIME"; - case "STRING" -> "VARCHAR"; - case "JSONB" -> "VARBINARY"; - default -> name; - }; - return Enum.valueOf(SqlTypeName.class, name); - } - - public ImmutableSeq deserialize(JsonNode node) throws Exception { - var builder = RuleBuilder.create(); - var tables = array(node, "schemas").mapChecked(schema -> { - var types = array(schema, "types").mapChecked(JSONDeserializer::string); - var nullabilities = array(schema, "nullable").mapChecked(JSONDeserializer::bool); - var name = schema.path("name").asText("DEFAULT_TABLE_NAME"); - var fields = schema.get("fields") == null ? - Seq.fill(types.size(), i -> String.format("DEFAULT_FIELD_NAME_%d", i)) : - array(schema, "fields").mapChecked(JSONDeserializer::string); - var keys = Set.from(array(schema, "key").map( - CheckedFunction.of(key -> ImmutableBitSet.of(array(key).mapChecked(JSONDeserializer::integer))))); - if (types.size() != nullabilities.size()) - throw new Exception("Expecting corresponding types and nullabilities"); - var sts = types.zip(nullabilities).map(tn -> { - var type = builder.getTypeFactory().createSqlType(typeName(tn.component1())); - return builder.getTypeFactory().createTypeWithNullability(type, tn.component2()); - }); - var table = new CosetteTable(name, fields, sts, keys, Set.empty()); - builder.addTable(table); - return table; - }); - var rel = new Rel(builder, ImmutableSeq.empty(), tables); - return array(node, "queries").mapChecked(rel); - } - - public static ImmutableSeq load(File file) throws Exception { - return new JSONDeserializer().deserialize(mapper.readTree(file)); - } - - public static void main(String[] args) throws Exception { - var refs = Seq.from(new File("RelOptRulesTest").listFiles()); - for (var file : refs) { - try { - var store = mapper.readTree(file); - new JSONDeserializer().deserialize(store); - } catch (Exception e) { - System.err.println("===> " + file.getName() + " <==="); - System.err.println(e.getMessage()); - System.err.println(); - } - } - } } diff --git a/src/main/java/org/cosette/JSONSerializer.java b/src/main/java/org/cosette/JSONSerializer.java index e29204a..0dc2a1e 100644 --- a/src/main/java/org/cosette/JSONSerializer.java +++ b/src/main/java/org/cosette/JSONSerializer.java @@ -21,32 +21,61 @@ public record JSONSerializer(Env env) { private final static ObjectMapper mapper = new ObjectMapper(); - private record Rel(Env env) { - Rel() { - this(new Env(0, ImmutableMap.empty(), MutableList.create())); - } + private static ArrayNode array(Seq objs) { + return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); + } - private record Env(int lvl, ImmutableMap globals, MutableList tables) { - Env recorded(Set ids) { - return new Env(lvl, Seq.from(ids).foldLeft(globals, (g, id) -> g.putted(id, lvl)), tables); - } + private static ObjectNode object(Map fields) { + return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); + } - Env lifted(int d) { - return new Env(lvl + d, globals, tables); - } + private static BooleanNode bool(boolean b) { + return BooleanNode.valueOf(b); + } - int resolve(RelOptTable table) { - var idx = tables.indexOf(table); - if (idx == -1) { - idx = tables.size(); - tables.append(table); - } - return idx; - } + private static TextNode string(String s) { + return new TextNode(s); + } - public Rex.Env rex(int delta) { - return new Rex.Env(lvl, delta, globals, tables); - } + private static TextNode type(RelDataType type) { + return new TextNode(type.getSqlTypeName().getName()); + } + + private static IntNode integer(int i) { + return new IntNode(i); + } + + public static ObjectNode serialize(Seq relNodes) { + var shuttle = new Rel(); + var helps = array(relNodes.map(rel -> new TextNode(rel.explain()))); + var queries = array(relNodes.map(shuttle::serialize)); + var tables = shuttle.env.tables(); + var schemas = array(tables.map(table -> { + var visitor = new Rex(shuttle.env.rex(table.getRowType().getFieldCount())); + var cosette = table.unwrap(CosetteTable.class); + var fields = Seq.from(table.getRowType().getFieldList()); + return cosette == null ? + object(Map.of("name", string(Seq.from(table.getQualifiedName()).joinToString(".")), "fields", + array(fields.map(field -> string(field.getName()))), "types", + array(fields.map(field -> type(field.getType()))), "nullable", + array(fields.map(field -> bool(field.getType().isNullable()))), "key", + array((table.getKeys() != null ? Seq.from(table.getKeys()) : + Seq.empty()).map( + key -> array(Seq.from(key).map(JSONSerializer::integer)))), "guaranteed", + array(Seq.empty()))) : object(Map.of("name", string(cosette.getName()), "fields", + array(cosette.getColumnNames().map(JSONSerializer::string)), "types", + array(cosette.getColumnTypes().map(JSONSerializer::type)), "nullable", + array(cosette.getColumnTypes().map(type -> bool(type.isNullable()))), "key", + array(Seq.from(cosette.getKeys().map(key -> array(Seq.from(key).map(JSONSerializer::integer))))), + "guaranteed", array(cosette.getConstraints().map(visitor::serialize).toImmutableSeq()))); + })); + + return object(Map.of("schemas", schemas, "queries", queries, "help", helps)); + } + + private record Rel(Env env) { + Rel() { + this(new Env(0, ImmutableMap.empty(), MutableList.create())); } public JsonNode serialize(RelNode rel) { @@ -132,20 +161,32 @@ yield object(Map.of("sort", default -> throw new RuntimeException("Not implemented: " + rel.getRelTypeName()); }; } - } - private record Rex(Env env) { - private record Env(int base, int delta, ImmutableMap globals, - MutableList tables) { - public Rel.Env rel() { - return new Rel.Env(base + delta, globals, tables); + private record Env(int lvl, ImmutableMap globals, MutableList tables) { + Env recorded(Set ids) { + return new Env(lvl, Seq.from(ids).foldLeft(globals, (g, id) -> g.putted(id, lvl)), tables); } - int resolve(CorrelationId id) { - return globals.getOrThrow(id, () -> new RuntimeException("Correlation ID not declared")); + Env lifted(int d) { + return new Env(lvl + d, globals, tables); + } + + int resolve(RelOptTable table) { + var idx = tables.indexOf(table); + if (idx == -1) { + idx = tables.size(); + tables.append(table); + } + return idx; + } + + public Rex.Env rex(int delta) { + return new Rex.Env(lvl, delta, globals, tables); } } + } + private record Rex(Env env) { public JsonNode serialize(RexNode rex) { return switch (rex) { case RexInputRef inputRef -> object(Map.of("column", integer(inputRef.getIndex() + env.base()), "type", @@ -165,57 +206,16 @@ public JsonNode serialize(RexNode rex) { default -> throw new RuntimeException("Not implemented: " + rex.getKind()); }; } - } - - private static ArrayNode array(Seq objs) { - return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); - } - - private static ObjectNode object(Map fields) { - return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); - } - - private static BooleanNode bool(boolean b) { - return BooleanNode.valueOf(b); - } - - private static TextNode string(String s) { - return new TextNode(s); - } - - private static TextNode type(RelDataType type) { - return new TextNode(type.getSqlTypeName().getName()); - } - - private static IntNode integer(int i) { - return new IntNode(i); - } - public static ObjectNode serialize(Seq relNodes) { - var shuttle = new Rel(); - var helps = array(relNodes.map(rel -> new TextNode(rel.explain()))); - var queries = array(relNodes.map(shuttle::serialize)); - var tables = shuttle.env.tables(); - var schemas = array(tables.map(table -> { - var visitor = new Rex(shuttle.env.rex(table.getRowType().getFieldCount())); - var cosette = table.unwrap(CosetteTable.class); - var fields = Seq.from(table.getRowType().getFieldList()); - return cosette == null ? - object(Map.of("name", string(Seq.from(table.getQualifiedName()).joinToString(".")), "fields", - array(fields.map(field -> string(field.getName()))), "types", - array(fields.map(field -> type(field.getType()))), "nullable", - array(fields.map(field -> bool(field.getType().isNullable()))), "key", - array((table.getKeys() != null ? Seq.from(table.getKeys()) : - Seq.empty()).map( - key -> array(Seq.from(key).map(JSONSerializer::integer)))), "guaranteed", - array(Seq.empty()))) : object(Map.of("name", string(cosette.getName()), "fields", - array(cosette.getColumnNames().map(JSONSerializer::string)), "types", - array(cosette.getColumnTypes().map(JSONSerializer::type)), "nullable", - array(cosette.getColumnTypes().map(type -> bool(type.isNullable()))), "key", - array(Seq.from(cosette.getKeys().map(key -> array(Seq.from(key).map(JSONSerializer::integer))))), - "guaranteed", array(cosette.getConstraints().map(visitor::serialize).toImmutableSeq()))); - })); + private record Env(int base, int delta, ImmutableMap globals, + MutableList tables) { + public Rel.Env rel() { + return new Rel.Env(base + delta, globals, tables); + } - return object(Map.of("schemas", schemas, "queries", queries, "help", helps)); + int resolve(CorrelationId id) { + return globals.getOrThrow(id, () -> new RuntimeException("Correlation ID not declared")); + } + } } } diff --git a/src/main/java/org/cosette/RRule.java b/src/main/java/org/cosette/RRule.java new file mode 100644 index 0000000..15265ab --- /dev/null +++ b/src/main/java/org/cosette/RRule.java @@ -0,0 +1,94 @@ +package org.cosette; + +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; + +public interface RRule { + + RelRN before(); + + RelRN after(); + + default String explain() { + return STR."\{getClass().getName()}\n\{before().semantics().explain()}=>\n\{after().semantics().explain()}"; + } + + record FilterMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.pred("inner"); + static final RexRN outer = source.pred("outer"); + + @Override + public RelRN before() { + return source.filter(inner).filter(outer); + } + + @Override + public RelRN after() { + return source.filter(RexRN.and(inner, outer)); + } + } + + record FilterIntoJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + var join = left.join(JoinRelType.INNER, joinCond, right); + return join.filter("outer"); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + } + } + + record FilterProjectTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.filter(proj.pred("pred")).project(proj); + } + + @Override + public RelRN after() { + return source.project(proj).filter("pred"); + } + } + + record JoinConditionPush() implements RRule { + record JoinPred(RelRN left, RelRN right) implements RexRN { + + @Override + public RexNode semantics() { + return RexRN.and(left.joinPred(bothPred(), right), left.joinField(0, right).pred(leftPred()), left.joinField(1, right).pred(rightPred())).semantics(); + } + + public String bothPred() { return "both"; } + public String leftPred() { return "left"; } + public String rightPred() { return "right"; } + + } + + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final JoinPred joinPred = new JoinPred(left, right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinPred, right); + } + + @Override + public RelRN after() { + var leftRN = left.filter(joinPred.leftPred()); + var rightRN = right.filter(joinPred.rightPred()); + return leftRN.join(JoinRelType.INNER, joinPred.bothPred(), rightRN); + } + } +} diff --git a/src/main/java/org/cosette/RelJSONShuttle.java b/src/main/java/org/cosette/RelJSONShuttle.java deleted file mode 100644 index a17c94b..0000000 --- a/src/main/java/org/cosette/RelJSONShuttle.java +++ /dev/null @@ -1,364 +0,0 @@ -package org.cosette; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.*; -import kala.collection.Map; -import kala.collection.Seq; -import kala.collection.Set; -import kala.collection.immutable.ImmutableSeq; -import kala.control.Result; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.TableScan; -import org.apache.calcite.rel.logical.*; -import org.apache.calcite.rel.type.*; -import org.apache.calcite.rex.*; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.util.ImmutableBitSet; - -import java.io.IOException; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.List; - -public record RelJSONShuttle(Env env) { - private final static ObjectMapper mapper = new ObjectMapper(); - - private static ArrayNode array(Seq objs) { - return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); - } - - - private static Result, String> array(JsonNode jsonNode, String field) { - var arr = jsonNode.get(field); - if (arr == null || !arr.isArray()) { - return Result.err(String.format("Missing array field %s in:\n%s", field, jsonNode.toPrettyString())); - } - return Result.ok(ImmutableSeq.from(arr.elements())); - } - - private static ObjectNode object(Map fields) { - return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); - } - - private static Result object(JsonNode jsonNode, String field) { - var obj = jsonNode.get(field); - if (obj == null) { - return Result.err(String.format("Missing object field %s in:\n%s", field, jsonNode.toPrettyString())); - } - return Result.ok(obj); - } - - private static BooleanNode bool(boolean b) { - return b ? BooleanNode.TRUE : BooleanNode.FALSE; - } - - private static T unwrap(Result res) throws Exception { - if (res.isErr()) { - throw new Exception(res.getErr()); - } - return res.get(); - } - - public static void main(String[] args) throws IOException { - var res = RelJSONShuttle.deserializeFromJson(Paths.get("ElevatedRules/filterProjectTranspose.json")); - if (res.isErr()) { - System.out.println(res.getErr()); - } else { - res.get().forEach(r -> System.out.println(r.explain())); - } - } - - public static void serializeToJson(List relNodes, Path path) throws IOException { - var shuttle = new RelJSONShuttle(Env.empty()); - var helps = array(Seq.from(relNodes).map(rel -> new TextNode(rel.explain()))); - var queries = array(Seq.from(relNodes).map(shuttle::serialize)); - - var tables = shuttle.env.tables(); - var schemas = array(tables.map(table -> object(Map.of( - "name", new TextNode(table.getName()), - "fields", array(table.getColumnNames().map(TextNode::new)), - "types", array(table.getColumnTypes().map(type -> new TextNode(type.toString()))), - "nullable", array(table.getColumnTypes().map(RelDataType::isNullable).map(RelJSONShuttle::bool)), - "key", array(Seq.from(table.getKeys().map(key -> array(Seq.from(key).map(IntNode::new))))), - "guaranteed", array(table.getConstraints() - .map(check -> new RexJSONVisitor(shuttle.env.advanced(table.getColumnNames().size())).serialize(check)).toImmutableSeq()) - )))); - - var main = object(Map.of("schemas", schemas, "queries", queries, "help", helps)); - mapper.writerWithDefaultPrettyPrinter().writeValue(path.toFile(), main); - } - - public static Result, String> deserializeFromJson(Path path) throws IOException { - var node = mapper.readTree(path.toFile()); - var env = Env.empty(); - var tables = array(node, "schemas").flatMap(schemas -> { - var collected = ImmutableSeq.empty(); - for (var schema : schemas) { - try { - var tys = unwrap(array(schema, "types")); - var nbs = unwrap(array(schema, "nullable")); - var nm = unwrap(object(schema, "name")); - var fds = unwrap(array(schema, "fields")).map(JsonNode::asText); - var kys = unwrap(array(schema, "key")); - var kgs = Set.from(kys.map(kg -> ImmutableBitSet.of(Seq.from(kg.elements()).map(JsonNode::asInt)))); - if (tys.size() != nbs.size()) { - return Result.err("Expecting corresponding types and nullabilities"); - } - var sts = tys.zip(nbs).map(tn -> (RelDataType) RelType.fromString(tn.component1().asText(), - tn.component2().asBoolean())); - collected = collected.appended(new CosetteTable(nm.asText(), fds, sts, kgs, Set.empty())); - } catch (Exception e) { - return Result.err( - String.format("Broken table schemas: %s in\n%s", e.getMessage(), schema.toPrettyString())); - } - } - return Result.ok(collected); - }); - if (tables.isErr()) { - return Result.err(tables.getErr()); - } - env.tables().appendAll(tables.get()); - var queries = array(node, "queries"); - if (queries.isErr()) { - return Result.err(queries.getErr()); - } - var shuttle = new RelJSONShuttle(env); - return queries.get().map(q -> { - var builder = RuleBuilder.create(); - tables.get().forEach(builder::addTable); - return shuttle.deserialize(builder, q); - }).foldLeft(Result.ok(ImmutableSeq.empty()), (qs, qb) -> qs.flatMap(s -> qb.map(b -> s.appended(b.build())))); - } - - public JsonNode serialize(RelNode rel) { - return switch (rel) { - case TableScan scan -> - object(Map.of("scan", new IntNode(env.resolve(scan.getTable().unwrap(CosetteTable.class))))); - case LogicalValues values -> { - var visitor = new RexJSONVisitor(env); - var schema = array(Seq.from(values.getRowType().getFieldList()) - .map(field -> new TextNode(field.getType().toString()))); - var records = array(Seq.from(values.getTuples()) - .map(tuple -> array(Seq.from(tuple).map(visitor::serialize)))); - yield object(Map.of("values", object(Map.of("schema", schema, "content", records)))); - } - case LogicalFilter filter -> { - var visitor = new RexJSONVisitor(env.advanced(filter.getInput().getRowType().getFieldCount()) - .recorded(filter.getVariablesSet())); - yield object(Map.of("filter", - object(Map.of("condition", visitor.serialize(filter.getCondition()), "source", - serialize(filter.getInput()))))); - } - case LogicalProject project -> { - var visitor = new RexJSONVisitor(env.advanced(project.getInput().getRowType().getFieldCount()) - .recorded(project.getVariablesSet())); - var targets = array(Seq.from(project.getProjects()).map(visitor::serialize)); - yield object( - Map.of("project", object(Map.of("target", targets, "source", serialize(project.getInput()))))); - } - case LogicalJoin join -> { - var left = join.getLeft(); - var right = join.getRight(); - var visitor = new RexJSONVisitor( - env.advanced(left.getRowType().getFieldCount() + right.getRowType().getFieldCount()) - .recorded(join.getVariablesSet())); - yield object(Map.of("join", - object(Map.of("kind", new TextNode(join.getJoinType().toString()), "condition", - visitor.serialize(join.getCondition()), "left", serialize(left), "right", - serialize(right))))); - } - case LogicalCorrelate correlate -> { - var rightShuttle = new RelJSONShuttle(env.advanced(correlate.getLeft().getRowType().getFieldCount()) - .recorded(correlate.getVariablesSet()).advanced(0)); - yield object(Map.of("correlate", - array(Seq.of(serialize(correlate.getLeft()), rightShuttle.serialize(correlate.getRight()))))); - } - case LogicalAggregate aggregate -> { - var groupCount = aggregate.getGroupCount(); - var level = env.base(); - var types = Seq.from(aggregate.getInput().getRowType().getFieldList()) - .map(type -> new TextNode(type.getType().toString())); - var keyCols = array(Seq.from(aggregate.getGroupSet()) - .map(key -> object(Map.of("column", new IntNode(level + key), "type", types.get(key))))); - var keys = object(Map.of("project", - object(Map.of("target", keyCols, "source", serialize(aggregate.getInput()))))); - var conditions = array(Seq.from(aggregate.getGroupSet()).mapIndexed((i, key) -> { - var type = types.get(key); - var leftCol = object(Map.of("column", new IntNode(level + i), "type", type)); - var rightCol = object(Map.of("column", new IntNode(level + groupCount + key), "type", type)); - return object( - Map.of("operator", new TextNode("<=>"), "operand", array(Seq.of(leftCol, rightCol)), "type", - new TextNode("BOOLEAN"))); - })); - var condition = object(Map.of("operator", new TextNode("AND"), "operand", conditions, "type", - new TextNode("BOOLEAN"))); - var aggs = array(Seq.from(aggregate.getAggCallList()).map(call -> object( - Map.of("operator", new TextNode(call.getAggregation().getName()), "operand", - array(Seq.from(call.getArgList()).map(target -> object( - Map.of("column", new IntNode(level + groupCount + target), "type", - types.get(target))))), "distinct", bool(call.isDistinct()), - "ignoreNulls", bool(call.ignoreNulls()), "type", - new TextNode(call.getType().toString()))))); - var aggregated = object(Map.of("aggregate", object(Map.of("function", aggs, "source", - object(Map.of("filter", object(Map.of("condition", condition, "source", - new RelJSONShuttle(env.lifted(groupCount)).serialize(aggregate.getInput()))))))))); - yield object(Map.of("distinct", object(Map.of("correlate", array(Seq.of(keys, aggregated)))))); - } - case LogicalUnion union -> { - var result = object(Map.of("union", array(Seq.from(union.getInputs()).map(this::serialize)))); - yield union.all ? result : object(Map.of("distinct", result)); - } - case LogicalIntersect intersect when !intersect.all -> - object(Map.of("intersect", array(Seq.from(intersect.getInputs()).map(this::serialize)))); - case LogicalMinus minus when !minus.all -> - object(Map.of("except", array(Seq.from(minus.getInputs()).map(this::serialize)))); - case LogicalSort sort -> { - var types = Seq.from(sort.getInput().getRowType().getFieldList()) - .map(type -> new TextNode(type.getType().toString())); - var collations = array(Seq.from(sort.collation.getFieldCollations()).map(collation -> { - var index = collation.getFieldIndex(); - return array(Seq.of(new IntNode(index), types.get(index), new TextNode(collation.shortString()))); - })); - var args = object(Map.of("collation", collations, "source", serialize(sort.getInput()))); - var visitor = new RexJSONVisitor(env.advanced(sort.getInput().getRowType().getFieldCount())); - if (sort.offset != null) { - args.set("offset", visitor.serialize(sort.offset)); - } - if (sort.fetch != null) { - args.set("limit", visitor.serialize(sort.fetch)); - } - yield object(Map.of("sort", args)); - } - default -> throw new RuntimeException("Not implemented: " + rel.getRelTypeName()); - }; - } - - public Result deserialize(RuleBuilder builder, JsonNode jsonNode) { - var entry = jsonNode.fields().next(); - var kind = entry.getKey(); - var content = entry.getValue(); - return switch (kind) { - case String k when k.equals("scan") -> { - if (content.isInt() && 0 <= content.asInt() && content.asInt() < env.tables().size()) { - builder.scan(env.tables().get(content.asInt()).getName()); - yield Result.ok(builder); - } - yield Result.err(String.format("Missing table with index %s", content.toPrettyString())); - } - case String k when k.equals("values") -> { - try { - var et = unwrap(array(content, "schema")); - var rt = new RelRecordType(StructKind.FULLY_QUALIFIED, et.mapIndexed( - (i, t) -> (RelDataTypeField) new RelDataTypeFieldImpl(String.format("VALUES-%s", i), i, - RelType.fromString(t.asText(), true))).asJava()); - var vs = unwrap(array(content, "content")); - var vals = ImmutableSeq.>empty(); - for (var v : vs) { - var val = ImmutableSeq.empty(); - if (!v.isArray()) { - yield Result.err("Expecting tuple (JSON list) as value"); - } - for (var jl : Seq.from(v.elements())) { - var l = unwrap(new RexJSONVisitor(env).deserialize(builder, jl)); - if (l instanceof RexLiteral) { - val = val.appended((RexLiteral) l); - } else { - yield Result.err("Expecting literal expression"); - } - } - vals = vals.appended(val.asJava()); - } - builder.values(vals.asJava(), rt); - yield Result.ok(builder); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("filter") -> { - try { - var cond = unwrap(object(content, "condition")); - var source = unwrap(object(content, "source")); - var bs = unwrap(deserialize(builder, source)); - var c = unwrap(new RexJSONVisitor(env).deserialize(builder, cond)); - bs.filter(c); - yield Result.ok(bs); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("project") -> { - try { - var target = unwrap(array(content, "target")); - var source = unwrap(object(content, "source")); - var bs = unwrap(deserialize(builder, source)); - var ps = target.mapChecked(t -> unwrap(new RexJSONVisitor(env).deserialize(builder, t))); - bs.project(ps); - yield Result.ok(bs); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("join") -> Result.err("Not implemented yet"); - case String k when k.equals("correlate") -> Result.err("Not implemented yet"); - default -> Result.err(String.format("Unrecognized node:\n%s", jsonNode.toPrettyString())); - }; - } - - public record RexJSONVisitor(Env env) { - public JsonNode serialize(RexNode rex) { - return switch (rex) { - case RexInputRef inputRef -> - object(Map.of("column", new IntNode(inputRef.getIndex() + env.base()), "type", - new TextNode(inputRef.getType().toString()))); - case RexLiteral literal -> object(Map.of("operator", - new TextNode(literal.getValue() == null ? "NULL" : literal.getValue().toString()), "operand", - array(Seq.empty()), "type", new TextNode(literal.getType().toString()))); - case RexSubQuery subQuery -> - object(Map.of("operator", new TextNode(subQuery.getOperator().toString()), "operand", - array(Seq.from(subQuery.getOperands()).map(this::serialize)), "query", - new RelJSONShuttle(env.advanced(0)).serialize(subQuery.rel), "type", - new TextNode(subQuery.getType().toString()))); - case RexCall call -> object(Map.of("operator", new TextNode(call.getOperator().toString()), "operand", - array(Seq.from(call.getOperands()).map(this::serialize)), "type", - new TextNode(call.getType().toString()))); - case RexFieldAccess fieldAccess -> object(Map.of("column", new IntNode( - fieldAccess.getField().getIndex() + - env.resolve(((RexCorrelVariable) fieldAccess.getReferenceExpr()).id)), "type", - new TextNode(fieldAccess.getType().toString()))); - default -> throw new RuntimeException("Not implemented: " + rex.getKind()); - }; - } - - public Result deserialize(RuleBuilder builder, JsonNode jsonNode) { - if (jsonNode.has("column") && jsonNode.get("column").isInt()) { - // WARNING: THIS IS WRONG! NO ENVIRONMENT CONSIDERED! - return Result.ok(builder.field(jsonNode.get("column").asInt())); - } else if (jsonNode.has("operator") && jsonNode.get("operator").isTextual()) { - var op = jsonNode.get("operator").asText(); - try { - var args = unwrap(array(jsonNode, "operand")); - var ty = RelType.fromString(unwrap(object(jsonNode, "type")).asText(), true); - if (args.isEmpty()) { - return Result.ok(RexLiteral.fromJdbcString(ty, ty.getSqlTypeName(), op)); - } else { - var fields = args.mapChecked(expr -> unwrap(deserialize(builder, expr))); - for (var refl : Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) - .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && - java.lang.reflect.Modifier.isStatic(f.getModifiers()))) { - var mist = refl.get(null); - if (mist instanceof SqlOperator sqlOperator && sqlOperator.getName().equals(op)) { - return Result.ok(builder.call(sqlOperator, fields)); - } - } - return Result.ok(builder.call(builder.genericProjectionOp(op, ty), fields)); - } - } catch (Exception e) { - return Result.err(e.getMessage()); - } - } - return Result.err(String.format("Unrecognized node:\n%s", jsonNode.toPrettyString())); - } - } -} \ No newline at end of file diff --git a/src/main/java/org/cosette/RelMatcher.java b/src/main/java/org/cosette/RelMatcher.java index a6e3c77..20f85be 100644 --- a/src/main/java/org/cosette/RelMatcher.java +++ b/src/main/java/org/cosette/RelMatcher.java @@ -9,7 +9,6 @@ import kala.tuple.Tuple; import kala.tuple.Tuple2; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.logical.LogicalFilter; import org.apache.calcite.rel.logical.LogicalJoin; import org.apache.calcite.rel.logical.LogicalProject; @@ -17,7 +16,6 @@ import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import java.util.stream.IntStream; diff --git a/src/main/java/org/cosette/RelRN.java b/src/main/java/org/cosette/RelRN.java index 39c9ec1..91bb217 100644 --- a/src/main/java/org/cosette/RelRN.java +++ b/src/main/java/org/cosette/RelRN.java @@ -6,29 +6,69 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.util.ImmutableBitSet; +import java.util.stream.IntStream; + public interface RelRN { + static Scan scan(String id, RelType.VarType ty, boolean unique) { + return new Scan(id, ty, unique); + } + + static Scan scan(String id, String typeName) { + return scan(id, RexRN.varType(typeName, true), false); + } + RelNode semantics(); - static RelType.VarType varType(String id, boolean nullable) { - return new RelType.VarType(id, nullable); + default RexRN field(int ordinal) { + return new RexRN.Field(ordinal, this); } - static Scan scan(String id, RelType.VarType ty, boolean unique) { - return new Scan(id, ty, unique); + default Seq fields() { + return Seq.from(IntStream.range(0, semantics().getRowType().getFieldCount()).iterator()).map(this::field); + } + + default RexRN joinField(int ordinal, RelRN right) { + return new RexRN.JoinField(ordinal, this, right); + } + + default Seq joinFields(RelRN right) { + return Seq.from(IntStream.range(0, semantics().getRowType().getFieldCount() + right.semantics().getRowType().getFieldCount()).iterator()).map(i -> joinField(i, right)); + } + + default RexRN.Pred pred(String name) { + return new RexRN.Pred(name, true, fields()); + } + + default RexRN.Pred joinPred(String name, RelRN right) { + return new RexRN.Pred(name, true, joinFields(right)); + } + + default RexRN.Proj proj(String name, String type_name) { + return new RexRN.Proj(name, type_name, true, fields()); } default Filter filter(RexRN cond) { return new Filter(cond, this); } - default Project project(Seq map) { - return new Project(map, this); + default Filter filter(String name) { + return filter(pred(name)); + } + + default Project project(RexRN proj) { + return new Project(proj, this); + } + + default Project project(String name, String type_name) { + return project(proj(name, type_name)); } default Join join(JoinRelType ty, RexRN cond, RelRN right) { return new Join(ty, cond, this, right); } + default Join join(JoinRelType ty, String name, RelRN right) { return join(ty, joinPred(name, right), right); } + default Union union(boolean all, RelRN... sources) { return new Union(all, Seq.of(this).appendedAll(sources)); } @@ -37,40 +77,40 @@ default Intersect intersect(boolean all, RelRN... sources) { return new Intersect(all, Seq.of(this).appendedAll(sources)); } - record Scan(String id, RelType.VarType ty, boolean unique) implements RelRN { + record Scan(String name, RelType.VarType ty, boolean unique) implements RelRN { @Override public RelNode semantics() { - var table = new CosetteTable(id, Seq.of("col-" + id), Seq.of(ty), unique ? Set.of(ImmutableBitSet.of(0)) : Set.empty(), Set.empty()); - return RuleBuilder.create().addTable(table).scan(id).build(); + var table = new CosetteTable(name, Seq.of(STR."col-\{name}"), Seq.of(ty), unique ? Set.of(ImmutableBitSet.of(0)) : Set.empty(), Set.empty()); + return RuleBuilder.create().addTable(table).scan(name).build(); } } record Filter(RexRN cond, RelRN source) implements RelRN { @Override public RelNode semantics() { - var builder = RuleBuilder.create(); - builder.push(source.semantics()); - return builder.filter(cond.semantics().apply(builder)).build(); + return RuleBuilder.create().push(source.semantics()).filter(cond.semantics()).build(); } } - record Project(Seq map, RelRN source) implements RelRN { + record Project(RexRN map, RelRN source) implements RelRN { @Override public RelNode semantics() { - var builder = RuleBuilder.create(); - builder.push(source.semantics()); - return builder.project(map.map(m -> m.semantics().apply(builder))).build(); + return RuleBuilder.create().push(source.semantics()).project(map.semantics()).build(); } } record Join(JoinRelType ty, RexRN cond, RelRN left, RelRN right) implements RelRN { @Override public RelNode semantics() { - var builder = RuleBuilder.create(); - builder.push(left.semantics()).push(left.semantics()); - return builder.join(ty, cond.semantics().apply(builder)).build(); + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).join(ty, cond.semantics()).build(); } + + @Override + public RexRN field(int ordinal) { + return new RexRN.JoinField(ordinal, left, right); + } + } record Union(boolean all, Seq sources) implements RelRN { diff --git a/src/main/java/org/cosette/RelType.java b/src/main/java/org/cosette/RelType.java index 12ff173..d7b18be 100644 --- a/src/main/java/org/cosette/RelType.java +++ b/src/main/java/org/cosette/RelType.java @@ -7,6 +7,15 @@ import org.apache.calcite.sql.type.SqlTypeName; public sealed interface RelType extends RelDataType { + static RelType fromString(String name, boolean nullable) { + for (var tn : SqlTypeName.values()) { + if (tn.getName().equals(name)) { + return new BaseType(tn, nullable); + } + } + return new VarType(name, nullable); + } + final class VarType extends RelDataTypeImpl implements RelType { private final String name; private final boolean nullable; @@ -39,13 +48,4 @@ public BaseType(SqlTypeName typeName, boolean nullable) { super(RelDataTypeSystem.DEFAULT, typeName, nullable); } } - - static RelType fromString(String name, boolean nullable) { - for (var tn : SqlTypeName.values()) { - if (tn.getName().equals(name)) { - return new BaseType(tn, nullable); - } - } - return new VarType(name, nullable); - } } diff --git a/src/main/java/org/cosette/RewriteRule.java b/src/main/java/org/cosette/RewriteRule.java deleted file mode 100644 index e3a4aab..0000000 --- a/src/main/java/org/cosette/RewriteRule.java +++ /dev/null @@ -1,8 +0,0 @@ -package org.cosette; - -public class RewriteRule { - - public static void main(String[] args) { - } - -} diff --git a/src/main/java/org/cosette/RexRN.java b/src/main/java/org/cosette/RexRN.java index 6c4fb74..9a798a7 100644 --- a/src/main/java/org/cosette/RexRN.java +++ b/src/main/java/org/cosette/RexRN.java @@ -3,44 +3,85 @@ import kala.collection.Seq; import org.apache.calcite.rex.RexNode; -import java.util.function.Function; - public interface RexRN { - Function semantics(); - static Field field(int ord) { - return new Field(ord); + static RelType.VarType varType(String id, boolean nullable) { + return new RelType.VarType(id, nullable); + } + static And and(RexRN ...sources) { + return new And(Seq.from(sources)); } - static Proj proj(String id, RelType.VarType ty, RexRN... sources) { - return new Proj(id, ty, Seq.of(sources)); + RexNode semantics(); + + default Pred pred(String name) { + return new Pred(name, true, Seq.of(this)); } - static Pred pred(String id, RexRN... sources) { - return new Pred(id, Seq.of(sources)); + default Proj proj(String name, String type_name) { + return new Proj(name, type_name, true, Seq.of(this)); } - record Field(int ord) implements RexRN { + record Field(int ordinal, RelRN source) implements RexRN { + @Override - public Function semantics() { - return builder -> builder.field(ord); + public RexNode semantics() { + return RuleBuilder.create().push(source.semantics()).field(ordinal); } } - record Proj(String id, RelType.VarType ty, Seq sources) implements RexRN { + record JoinField(int ordinal, RelRN left, RelRN right) implements RexRN { @Override - public Function semantics() { - return builder -> builder.call(builder.genericProjectionOp(id, ty), sources.map(s -> s.semantics().apply(builder))); + public RexNode semantics() { + var leftCols = left.semantics().getRowType().getFieldCount(); + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).field(2, ordinal < leftCols ? + 0 : 1, ordinal < leftCols ? ordinal : ordinal - leftCols); } } - record Pred(String id, Seq sources) implements RexRN { + record Pred(String name, boolean nullable, Seq sources) implements RexRN { @Override - public Function semantics() { - return builder -> builder.call(builder.genericPredicateOp(id, false), sources.map(s -> s.semantics().apply(builder))); + public RexNode semantics() { + var builder = RuleBuilder.create(); + return builder.call(builder.genericPredicateOp(name, nullable), sources.map(RexRN::semantics)); } } + record Proj(String name, String type_name, boolean nullable, Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + var builder = RuleBuilder.create(); + return builder.call(builder.genericProjectionOp(name, varType(type_name, nullable)), + sources.map(RexRN::semantics)); + } + } + + record And(Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().and(sources.map(RexRN::semantics)); + } + } + + record Or(Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().or(sources.map(RexRN::semantics)); + } + } + + record Not(RexRN source) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().not(source.semantics()); + } + } + + } diff --git a/src/main/java/org/cosette/RuleBuilder.java b/src/main/java/org/cosette/RuleBuilder.java index b83eff8..c01e0cb 100644 --- a/src/main/java/org/cosette/RuleBuilder.java +++ b/src/main/java/org/cosette/RuleBuilder.java @@ -1,20 +1,13 @@ package org.cosette; -import io.github.cvc5.Kind; -import io.github.cvc5.Solver; -import io.github.cvc5.Term; import kala.collection.Seq; import kala.collection.Set; -import kala.collection.immutable.ImmutableSeq; -import kala.control.Result; import kala.tuple.Tuple; import kala.tuple.Tuple2; import kala.tuple.Tuple3; import org.apache.calcite.plan.Context; import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptSchema; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rex.RexNode; import org.apache.calcite.schema.SchemaPlus; import org.apache.calcite.sql.*; From 7c58e66e65d904084855f1410b49b50d0c384140 Mon Sep 17 00:00:00 2001 From: macronova Date: Mon, 29 Apr 2024 21:09:12 -0700 Subject: [PATCH 03/78] Bump progress so far --- src/main/java/org/cosette/RelMatcher.java | 158 ------- src/main/java/org/cosette/RuleBuilder.java | 202 --------- src/main/java/org/qed/CodeGenerator.java | 384 +++--------------- .../org/qed/Generated/CalciteGenerator.java | 204 ++++++++++ .../java/org/qed/Generated/CalciteTester.java | 105 +++++ .../org/qed/Generated/CalciteUtilities.java | 16 + .../java/org/qed/Generated/EmptyConfig.java | 33 ++ .../org/qed/Generated/FilterIntoJoin.java | 42 ++ .../java/org/qed/Generated/FilterMerge.java | 38 ++ .../qed/Generated/FilterProjectTranspose.java | 39 ++ src/main/java/org/qed/JSONDeserializer.java | 194 ++++----- src/main/java/org/qed/JSONSerializer.java | 160 ++++---- src/main/java/org/qed/RRule.java | 91 +---- src/main/java/org/qed/RRuleInstance.java | 356 ++++++++++++++++ src/main/java/org/qed/RelRN.java | 33 +- src/main/java/org/qed/RexRN.java | 32 +- src/main/java/org/qed/SchemaGenerator.java | 3 +- 17 files changed, 1138 insertions(+), 952 deletions(-) delete mode 100644 src/main/java/org/cosette/RelMatcher.java delete mode 100644 src/main/java/org/cosette/RuleBuilder.java create mode 100644 src/main/java/org/qed/Generated/CalciteGenerator.java create mode 100644 src/main/java/org/qed/Generated/CalciteTester.java create mode 100644 src/main/java/org/qed/Generated/CalciteUtilities.java create mode 100644 src/main/java/org/qed/Generated/EmptyConfig.java create mode 100644 src/main/java/org/qed/Generated/FilterIntoJoin.java create mode 100644 src/main/java/org/qed/Generated/FilterMerge.java create mode 100644 src/main/java/org/qed/Generated/FilterProjectTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstance.java diff --git a/src/main/java/org/cosette/RelMatcher.java b/src/main/java/org/cosette/RelMatcher.java deleted file mode 100644 index 20f85be..0000000 --- a/src/main/java/org/cosette/RelMatcher.java +++ /dev/null @@ -1,158 +0,0 @@ -package org.cosette; - -import kala.collection.Seq; -import kala.collection.Set; -import kala.collection.immutable.ImmutableMap; -import kala.collection.immutable.ImmutableSeq; -import kala.collection.immutable.ImmutableSet; -import kala.control.Result; -import kala.tuple.Tuple; -import kala.tuple.Tuple2; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rel.logical.LogicalJoin; -import org.apache.calcite.rel.logical.LogicalProject; -import org.apache.calcite.rel.logical.LogicalTableScan; -import org.apache.calcite.rel.type.RelDataTypeField; -import org.apache.calcite.rex.RexInputRef; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.util.ImmutableBitSet; - -import java.util.stream.IntStream; - -public record RelMatcher() { - - public static Result check(RelNode pattern, RelNode target) { - return relMatch(pattern, target).flatMap(MatchEnv::verify); - } - - private static Result relMatch(RelNode pattern, RelNode target) { - return switch (pattern) { - case LogicalTableScan scan when scan.getTable() - .unwrap(CosetteTable.class) instanceof CosetteTable source -> { - if (!source.getColumnTypes().allMatch(t -> t instanceof RelType.VarType)) { - yield Result.err("Types in pattern should all be Variable types."); - } - var scanPattern = source.getColumnTypes().mapIndexed( - (i, t) -> Tuple.of(i, (RelType.VarType) t, t.isNullable(), - source.getKeys().contains(ImmutableBitSet.of(i)))); - var vts = scanPattern.map(t -> Tuple.of(t.component3(), t.component4())); - if (vts.size() != Set.from(vts).size()) { - yield Result.err( - "Unable to match duplicate (nullability, uniqueness) pairs for generic types.\n\t" + vts); - } - var tts = Seq.from(target.getRowType().getFieldList()).map(RelDataTypeField::getType) - .mapIndexed(Tuple::of); - var cms = ImmutableSeq.>empty(); - for (var t : scanPattern) { - var matched = tts.filter(tt -> { - // TODO: Derive target column uniqueness - return (t.component3() || !tt.component2().isNullable()) && !t.component4(); - }).map(Tuple2::component1); - cms = cms.appended(matched); - } - yield Result.ok(new MatchEnv( - new MatchEnv.FieldReference(cms, new MatchEnv.ProductType(tts.map(Tuple2::component2))), - ImmutableMap.empty(), ImmutableSeq.empty())); - } - case LogicalFilter filter when target instanceof LogicalFilter node -> - relMatch(filter.getInput(), node.getInput()).flatMap( - inputEnv -> inputEnv.assertConstraint(filter.getCondition(), Seq.of(node.getCondition()))); - case LogicalProject project when target instanceof LogicalProject node -> - relMatch(project.getInput(), node.getInput()).flatMap(inputEnv -> { - // Propagate raw references while leaving the rest for generic projection - var fieldPattern = Seq.from(project.getProjects()); - if (fieldPattern.isEmpty()) { - return Result.ok(inputEnv.updateFieldReference(Seq.empty(), Seq.empty())); - } else if (fieldPattern.allMatch(p -> p.getClass() == RexInputRef.class)) { - return Result.ok(inputEnv.updateFieldReference(fieldPattern.map( - field -> inputEnv.fieldReference().correspondence() - .get(((RexInputRef) field).getIndex())), - Seq.from(node.getProjects()).map(RexNode::getType))); - } else if (fieldPattern.size() == 1) { - return inputEnv.assertConstraint(project.getProjects().get(0), Seq.from(node.getProjects())) - .map(env -> env.updateFieldReference( - Seq.of(Seq.from(IntStream.range(0, node.getProjects().size()).iterator())), - Seq.from(node.getRowType().getFieldList()).map(RelDataTypeField::getType))); - } else { - return Result.err("TODO: Implement better field matching mechanism"); - } - }); - case LogicalJoin join when target instanceof LogicalJoin node && - join.getJoinType().equals(node.getJoinType()) -> relMatch(join.getLeft(), node.getLeft()).flatMap( - leftEnv -> relMatch(join.getRight(), node.getRight()).map(rightEnv -> Tuple.of(leftEnv, rightEnv))) - .map(env -> { - var leftEnv = env.component1(); - var rightEnv = env.component2(); - var leftFields = leftEnv.fieldReference(); - var rightFields = rightEnv.fieldReference(); - return new MatchEnv(new MatchEnv.FieldReference(leftFields.correspondence().appendedAll( - rightFields.correspondence() - .map(corr -> corr.map(i -> i + leftFields.sourceTypes().elements().size()))), - new MatchEnv.ProductType(leftFields.sourceTypes().elements() - .appendedAll(rightFields.sourceTypes().elements()))), - rightEnv.typeConstraints().toImmutableSeq().foldLeft(leftEnv.typeConstraints(), - (lcs, rc) -> lcs.putted(rc.component1(), - lcs.getOrDefault(rc.component1(), ImmutableSet.empty()) - .addedAll(rc.component2()))), - leftEnv.synthConstraints().appendedAll(rightEnv.synthConstraints())); - }).flatMap(inputEnv -> inputEnv.assertConstraint(join.getCondition(), Seq.of(node.getCondition()))); - default -> Result.err(String.format("Cannot match %s type pattern with %s target", pattern.getRelTypeName(), - target.getRelTypeName())); - }; - } - - public static void main(String[] args) throws Exception { -// var solver = new Solver(); -// solver.setOption("produce-models", "true"); -// solver.setOption("sygus", "true"); -// solver.setLogic("ALL"); -// var i = solver.mkVar(solver.getIntegerSort(), "i"); -// var f = solver.synthFun("f", new Term[]{i}, solver.getIntegerSort()); -// var x = solver.declareSygusVar("x", solver.getIntegerSort()); -// var x2 = solver.declareSygusVar("x", solver.getIntegerSort()); -// solver.addSygusConstraint(solver.mkTerm(Kind.EQUAL, x, x2)); -// var res = solver.checkSynth(); -// if (res.hasSolution()) { -// var synth = solver.getSynthSolution(f); -// System.out.println(synth); -// } -// var table_R = new CosetteTable("R", Map.of("a", new RelType.BaseType(SqlTypeName.INTEGER, true), "b", -// new RelType.BaseType(SqlTypeName.INTEGER, true) -// ), Set.empty(), Set.empty()); -// var generator = new SchemaGenerator(); -// generator.applyCreate("CREATE TABLE R (x INTEGER)"); -// generator.applyCreate("CREATE TABLE S (a INTEGER)"); -// var relBuilder = RuleBuilder.create(RawPlanner.generateConfig(generator.extractSchema())); -// var query = relBuilder.scan("R").scan("S").join(JoinRelType.INNER, relBuilder.call(SqlStdOperatorTable.AND, -// relBuilder.call(SqlStdOperatorTable.EQUALS, -// relBuilder.call(SqlStdOperatorTable.PLUS, relBuilder.field(2, 0, 0), relBuilder.field(2, 1, 0)), -// relBuilder.literal(0)), -// relBuilder.call(SqlStdOperatorTable.EQUALS, relBuilder.field(2, 0, 0), relBuilder.field(2, 1, 0)))) -// .build(); -// var joinConditionPushRule = new RuleBuilder.JoinConditionPush(); -// System.out.println("Example: SELECT * FROM R INNER JOIN S ON x + a = 0 AND x = a"); -// System.out.println(); -// System.out.println("Plan:"); -// System.out.println(query.explain()); -// System.out.println("Matcher:"); -// System.out.println(joinConditionPushRule.getPattern().explain()); -// var translator = joinConditionPushRule.match(query).get(); -// var solver = translator.solver(); -// for (var constraint : solver.getSygusConstraints()) { -// System.out.println(constraint); -// } -// System.out.println(); -// System.out.println("Can synthesize functions:"); -// System.out.println(solver.checkSynth().hasSolution()); -// System.out.println(); -// for (var functionName : translator.declaredFunctions().store().keysView()) { -// System.out.println("Synthesized function: " + functionName); -// for (var functionComponent : translator.declaredFunctions().store().get(functionName).component1()) { -// System.out.println(solver.getSynthSolution(functionComponent.component1())); -// } -// System.out.println(); -// } - } -} - diff --git a/src/main/java/org/cosette/RuleBuilder.java b/src/main/java/org/cosette/RuleBuilder.java deleted file mode 100644 index c01e0cb..0000000 --- a/src/main/java/org/cosette/RuleBuilder.java +++ /dev/null @@ -1,202 +0,0 @@ -package org.cosette; - -import kala.collection.Seq; -import kala.collection.Set; -import kala.tuple.Tuple; -import kala.tuple.Tuple2; -import kala.tuple.Tuple3; -import org.apache.calcite.plan.Context; -import org.apache.calcite.plan.RelOptCluster; -import org.apache.calcite.plan.RelOptSchema; -import org.apache.calcite.rex.RexNode; -import org.apache.calcite.schema.SchemaPlus; -import org.apache.calcite.sql.*; -import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.tools.Frameworks; -import org.apache.calcite.tools.RelBuilder; -import org.apache.calcite.util.Optionality; -import org.checkerframework.checker.nullness.qual.Nullable; - -import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; - -public class RuleBuilder extends RelBuilder { - - private final AtomicInteger TABLE_ID_GENERATOR = new AtomicInteger(); - - private final SchemaPlus root; - - protected RuleBuilder(@Nullable Context context, RelOptCluster cluster, RelOptSchema relOptSchema, - SchemaPlus schema) { - super(context, cluster, relOptSchema); - root = schema; - } - - public static RuleBuilder create() { - var emptySchema = Frameworks.createRootSchema(true); - var config = Frameworks.newConfigBuilder().defaultSchema(emptySchema).build(); - return Frameworks.withPrepare(config, - (cluster, relOptSchema, rootSchema, statement) -> new RuleBuilder(config.getContext(), cluster, - relOptSchema, emptySchema)); - } - - public RuleBuilder addTable(CosetteTable table) { - root.add(table.getName(), table); - return this; - } - - /** - * Create a cosette table given the column types and whether they are unique (i.e. can be key) - * - * @param schema the list of column types and they are unique - * @return the table created from the given schema - */ - public CosetteTable createCosetteTable(Seq> schema) { - var identifier = "Table_" + TABLE_ID_GENERATOR.getAndIncrement(); - var cols = schema.mapIndexed( - (idx, tuple) -> Tuple.of(identifier + "_Column_" + idx, tuple.component1(), tuple.component2())); - return new CosetteTable(identifier, - cols.map(tuple -> Map.entry(tuple.component1(), tuple.component2())).toImmutableMap(), - Set.from(cols.filter(Tuple3::component3).map(tuple -> Set.of(tuple.component1()))), Set.of()); - } - - /** - * Create and return the names of the created simple tables after registering them to the builder - * - * @param typeIds the absolute value represents type id, while the sign indicates the uniqueness - * @return the names for the created tables - */ - public Seq sourceSimpleTables(Seq typeIds) { - return typeIds.map(id -> { - var identifier = "Table_" + TABLE_ID_GENERATOR.getAndIncrement(); - var colName = identifier + "_Column"; - var colType = new RelType.VarType("Type_" + (id < 0 ? -id : id), true); - var table = new CosetteTable(identifier, kala.collection.Map.of(colName, colType), - id < 0 ? Set.of(Set.of(colName)) : Set.of(), Set.empty()); - addTable(table); - return table.getName(); - }); - } - - public Seq joinFields() { - return Seq.from(fields(2, 0)).concat(fields(2, 1)); - } - - - public SqlAggFunction genericAggregateOp(String name, RelType aggregation) { - return new CosetteAggregateFunction(name, aggregation); - } - - public SqlOperator genericPredicateOp(String name, boolean nullable) { - return new CosetteFunction(name, new RelType.BaseType(SqlTypeName.BOOLEAN, nullable)); - } - - public SqlOperator genericProjectionOp(String name, RelType projection) { - return new CosetteFunction(name, projection); - } - - public static class CosetteFunction extends SqlFunction { - - private final RelType codomain; - - public CosetteFunction(String name, RelType returnType) { - super(name, SqlKind.OTHER_FUNCTION, opBinding -> { - var factory = opBinding.getTypeFactory(); - return factory.createTypeWithNullability(returnType, returnType.isNullable()); - }, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION); - codomain = returnType; - } - - public RelType getReturnType() { - return codomain; - } - } - - public static class CosetteAggregateFunction extends SqlAggFunction { - - public CosetteAggregateFunction(String name, RelType returnType) { - super(name, null, SqlKind.OTHER_FUNCTION, opBinding -> { - var factory = opBinding.getTypeFactory(); - return factory.createTypeWithNullability(returnType, returnType.isNullable()); - }, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION, false, false, Optionality.OPTIONAL); - } - } - -// public interface CosetteRule { -// -// RelNode getPattern(); -// -// RelNode getTransformation(); -// -// default ImmutableSeq deriveAdditionalConstraints(Solver solver, RexTranslator.Declarations declarations) { -// return ImmutableSeq.empty(); -// } -// -// default Result match(RelNode target) { -// return RelMatcher.check(getPattern(), target).map(translator -> translator.addConstraints( -// deriveAdditionalConstraints(translator.solver(), translator.declaredFunctions()))); -// } -// -// } -// -// public static class JoinConditionPush implements CosetteRule { -// -// private final RelNode pattern; -// private final RelNode transform; -// private final String bothPredicate = "joinBoth"; -// private final String leftPredicate = "joinLeft"; -// private final String rightPredicate = "joinRight"; -// -// -// public JoinConditionPush() { -// var builder = RuleBuilder.create(); -// var tableNames = builder.sourceSimpleTables(Seq.of(1, 2)); -// tableNames.forEach(builder::scan); -// var joinBoth = builder.genericPredicateOp(bothPredicate, true); -// var joinLeft = builder.genericPredicateOp(leftPredicate, true); -// var joinRight = builder.genericPredicateOp(rightPredicate, true); -// var joinCond = builder.and(builder.call(joinBoth, builder.joinFields()), -// builder.call(joinLeft, builder.fields(2, 0)), builder.call(joinRight, builder.fields(2, 1))); -// builder.join(JoinRelType.INNER, joinCond); -// pattern = builder.build(); -// builder.scan(tableNames.get(0)).filter(builder.call(joinLeft, builder.fields())); -// builder.scan(tableNames.get(1)).filter(builder.call(joinRight, builder.fields())); -// joinCond = builder.call(joinBoth, builder.joinFields()); -// builder.join(JoinRelType.INNER, joinCond); -// transform = builder.build(); -// } -// -// @Override -// public RelNode getPattern() { -// return pattern; -// } -// -// @Override -// public RelNode getTransformation() { -// return transform; -// } -// -// @Override -// public ImmutableSeq deriveAdditionalConstraints(Solver solver, RexTranslator.Declarations declarations) { -// var bp = declarations.store().get(bothPredicate); -// var lp = declarations.store().get(leftPredicate); -// var rp = declarations.store().get(rightPredicate); -// var lvs = lp.component2().mapIndexed((i, s) -> solver.declareSygusVar(leftPredicate + "-V" + i, s)); -// var rvs = rp.component2().mapIndexed((i, s) -> solver.declareSygusVar(rightPredicate + "-V" + i, s)); -// var bpc = solver.mkTerm(Kind.APPLY_UF, -// bp.component1().map(Tuple2::component1).appendedAll(lvs).appendedAll(rvs).toArray(new Term[]{})); -// var lpc = solver.mkTerm(Kind.APPLY_UF, -// lp.component1().map(Tuple2::component1).appendedAll(lvs).toArray(new Term[]{})); -// var rpc = solver.mkTerm(Kind.APPLY_UF, -// rp.component1().map(Tuple2::component1).appendedAll(rvs).toArray(new Term[]{})); -// var acl = solver.mkTerm(Kind.IMPLIES, lpc, -// solver.mkTerm(Kind.EXISTS, solver.mkTerm(Kind.VARIABLE_LIST, rvs.toArray(new Term[]{})), -// solver.mkTerm(Kind.AND, bpc, rpc))); -// var acr = solver.mkTerm(Kind.IMPLIES, rpc, -// solver.mkTerm(Kind.EXISTS, solver.mkTerm(Kind.VARIABLE_LIST, lvs.toArray(new Term[]{})), -// solver.mkTerm(Kind.AND, bpc, lpc))); -// return ImmutableSeq.of(acl, acr); -// } -// } - -} \ No newline at end of file diff --git a/src/main/java/org/qed/CodeGenerator.java b/src/main/java/org/qed/CodeGenerator.java index 5d4ee57..3fde6a0 100644 --- a/src/main/java/org/qed/CodeGenerator.java +++ b/src/main/java/org/qed/CodeGenerator.java @@ -1,284 +1,24 @@ -package org.cosette; - -import kala.collection.Map; -import kala.collection.Seq; - -import java.util.concurrent.atomic.AtomicInteger; - -public interface CodeGenerator { - static void main(String[] args) { - // Calcite: - // - onMatch: Check RelNode type, register uninterpreted symbols - // - transform: Recursively build (depth first) with builder, keep builder expression - var calciteCodeGen = new CodeGenerator() { - final AtomicInteger variableIndex = new AtomicInteger(); - - Env assignVariable(Env env, String expression) { - var name = STR."var_\{variableIndex.getAndIncrement()}"; - return env.state(STR."var \{name} = \{expression};").express(Seq.of(name)); - } - - @Override - public Env preMatch() { - return assignVariable(Env.empty(), "call.rel(0)"); - } - - @Override - public Env preTransform(Env env) { - return assignVariable(env, "call.builder()"); - } - - @Override - public Env postTransform(Env env) { - var result_env = assignVariable(env, STR."\{env.expressions().first()}.build()"); - return result_env.state(STR."call.transformTo(\{result_env.expressions().first()});"); - } - - @Override - public String translate(Env ignored, Env transform) { - StringBuilder builder = new StringBuilder("@Override public void onMatch(RelOptRuleCall call) {\n"); - transform.statements().forEach(statement -> builder.append("\t").append(statement).append("\n")); - builder.append("}\n"); - return builder.toString(); - } - - @Override - public Env onMatchScan(Env env, RelRN.Scan scan) { - return env.symbol(scan.name(), env.expressions().first()); - } - - @Override - public Env onMatchFilter(Env env, RelRN.Filter filter) { - var rel = env.expressions().first(); - env = env.state(STR."if (!(\{rel} instanceof LogicalFilter)) { return; }"); - var rel_filter_env = assignVariable(env, STR."(LogicalFilter) \{rel}"); - var rel_filter = rel_filter_env.expressions().first(); - var source_env = assignVariable(rel_filter_env, STR."\{rel_filter}.getInput()"); - var source_match = onMatch(source_env, filter.source()); - var cond_env = assignVariable(source_match, STR."\{rel_filter}.getCondition()"); - return onMatch(cond_env, filter.cond()); - } - - @Override - public Env onMatchJoin(Env env, RelRN.Join join) { - var rel = env.expressions().first(); - env = env.state(STR."if (!(\{rel} instanceof LogicalJoin)) { return; }"); - var rel_join_env = assignVariable(env, STR."(LogicalJoin) \{rel}"); - var rel_join = rel_join_env.expressions().first(); - var join_type_env = assignVariable(rel_join_env, STR."\{rel_join}.getJoinType()"); - var join_type = join_type_env.expressions().first(); - var type_cond_expression = switch (join.ty()) { - - case INNER -> STR."\{join_type} == JoinRelType.INNER"; - case LEFT -> STR."\{join_type} == JoinRelType.LEFT"; - case RIGHT -> STR."\{join_type} == JoinRelType.RIGHT"; - case FULL -> STR."\{join_type} == JoinRelType.FULL"; - case SEMI -> STR."\{join_type} == JoinRelType.SEMI"; - case ANTI -> STR."\{join_type} == JoinRelType.ANTI"; - }; - var type_cond_env = assignVariable(join_type_env, type_cond_expression); - var type_match = type_cond_env.state(STR."if (!\{type_cond_env.expressions().first()}) { return; }"); - var left_source_env = assignVariable(type_match, STR."\{rel_join}.getLeft()"); - var left_match_env = onMatch(left_source_env, join.left()); - var right_source_env = assignVariable(left_match_env, STR."\{rel_join}.getRight()"); - var right_match_env = onMatch(right_source_env, join.right()); - var cond_source_env = assignVariable(right_match_env, STR."\{rel_join}.getCondition()"); - return onMatch(cond_source_env, join.cond()); - } - - @Override - public Env onMatchPred(Env env, RexRN.Pred pred) { - return env.symbol(pred.name(), env.expressions().first()); - } - - @Override - public Env onMatchCustom(Env env, RexRN custom) { - return switch (custom) { - case RRule.JoinConditionPush.JoinPred joinPred -> { - var pred = env.expressions().first(); - var breakdown_env = assignVariable(env, STR."customSplitFilter(\{pred})"); - var breakdown = breakdown_env.expressions().first(); - yield breakdown_env - .symbol(joinPred.bothPred(), STR."\{breakdown}.getBoth()") - .symbol(joinPred.leftPred(), STR."\{breakdown}.getLeft()") - .symbol(joinPred.rightPred(), STR."\{breakdown}.getRight()"); - } - default -> CodeGenerator.super.onMatchCustom(env, custom); - }; - } - - @Override - public Env transformScan(Env env, RelRN.Scan scan) { - var builder = env.expressions().first(); - var source_transform = STR."\{builder}.push(\{env.symbols().get(scan.name())})"; - return assignVariable(env, source_transform); - } - - @Override - public Env transformFilter(Env env, RelRN.Filter filter) { - var source_transform = transform(env, filter.source()); - var source_expression = source_transform.expressions().first(); - var cond_transform = transform(source_transform, filter.cond()); - return assignVariable(cond_transform, STR."\{source_expression}.filter(\{cond_transform.expressions().first()})"); - } - - @Override - public Env transformJoin(Env env, RelRN.Join join) { - var left_source_transform = transform(env, join.left()); - var right_source_transform = transform(left_source_transform, join.right()); - var source_expression = right_source_transform.expressions().first(); - var join_expression = switch (join.ty()) { - case INNER -> "JoinRelType.INNER"; - case LEFT -> "JoinRelType.LEFT"; - case RIGHT -> "JoinRelType.RIGHT"; - case FULL -> "JoinRelType.FULL"; - case SEMI -> "JoinRelType.SEMI"; - case ANTI -> "JoinRelType.ANTI"; - }; - var cond_transform = transform(right_source_transform, join.cond()); - return assignVariable(cond_transform, STR."\{source_expression}.join(\{join_expression}, \{cond_transform.expressions().first()})"); - } - - @Override - public Env transformPred(Env env, RexRN.Pred pred) { - return env.express(Seq.of(env.symbols().get(pred.name()))); - } - - @Override - public Env transformAnd(Env env, RexRN.And and) { - var source_transform = env; - var operands = Seq.empty(); - for (var source : and.sources()) { - source_transform = transform(source_transform, source); - operands = operands.appended(source_transform.expressions().first()); - source_transform = source_transform.express(env.expressions()); - } - return assignVariable(source_transform, STR."\{env.expressions().first()}.and(\{operands.joinToString(", ")})"); - } - }; +package org.qed; - // Cockroach: Recursively (depth-first) generate DSL - var cockroachCodeGen = new CodeGenerator() { - @Override - public String translate(Env onMatch, Env transform) { - return STR."\{onMatch.expressions().first()}\n=>\n\{transform.expressions().first()}\n"; - } - - @Override - public Env onMatchScan(Env env, RelRN.Scan scan) { - return env.express(Seq.of(STR."$\{scan.name()}:*")); - } - - @Override - public Env onMatchFilter(Env env, RelRN.Filter filter) { - var source_match = onMatch(env, filter.source()); - var source_expression = source_match.expressions().first(); - var cond_match = onMatch(source_match, filter.cond()); - var cond_expression = cond_match.expressions().first(); - return cond_match.express(Seq.of(STR."(Select \{source_expression} \{cond_expression})")); - } - - @Override - public Env onMatchJoin(Env env, RelRN.Join join) { - var left_source_match = onMatch(env, join.left()); - var left_source_expression = left_source_match.expressions().first(); - var right_source_match = onMatch(left_source_match, join.right()); - var right_source_expression = right_source_match.expressions().first(); - var join_expression = switch (join.ty()) { - case INNER -> "InnerJoin"; - case LEFT -> "LeftJoin"; - case RIGHT -> "RightJoin"; - case FULL -> "FullJoin"; - case SEMI -> "SemiJoin"; - case ANTI -> "AntiJoin"; - }; - var cond_match = onMatch(right_source_match, join.cond()); - var cond_expression = cond_match.expressions().first(); - return cond_match.express(Seq.of(STR."(\{join_expression} \{left_source_expression} \{right_source_expression} \{cond_expression} $private:*)")); - } - - @Override - public Env onMatchPred(Env env, RexRN.Pred pred) { - return env.symbol(pred.name(), STR."$\{pred.name()}").express(Seq.of(STR."$\{pred.name()}:*")); - } - - @Override - public Env onMatchCustom(Env env, RexRN custom) { - return switch (custom) { - case RRule.JoinConditionPush.JoinPred joinPred -> env - .symbol(joinPred.bothPred(), STR."(RemoveFiltersItem $\{joinPred.bothPred()} $item)") - .symbol(joinPred.leftPred(), "[(FiltersItem (MapJoinOpFilter $item $leftCols $equivSet))]") - .symbol(joinPred.rightPred(), "[(FiltersItem (MapJoinOpFilter $item $rightCols $equivSet))]") - .express(Seq.of("$on:[... $item:* & (CanMapJoinOpFilter $item $leftCols $equivSet) & (CanMapJoinOpFilter $item $rightCols $equivSet)...]")); - default -> CodeGenerator.super.onMatchCustom(env, custom); - }; - } - - @Override - public Env transformScan(Env env, RelRN.Scan scan) { - return env.express(Seq.of(STR."$\{scan.name()}")); - } - - @Override - public Env transformFilter(Env env, RelRN.Filter filter) { - var source_expression = transform(env, filter.source()).expressions().first(); - var cond_expression = transform(env, filter.cond()).expressions().first(); - return env.express(Seq.of(STR."(Select \{source_expression} \{cond_expression})")); - } - - @Override - public Env transformJoin(Env env, RelRN.Join join) { - var left_source_expression = transform(env, join.left()).expressions().first(); - var right_source_expression = transform(env, join.right()).expressions().first(); - var cond_expression = transform(env, join.cond()).expressions.first(); - return env.express(Seq.of(STR."((OpName) \{left_source_expression} \{right_source_expression} \{cond_expression} $private)")); - } - - @Override - public Env transformPred(Env env, RexRN.Pred pred) { - return env.express(Seq.of(env.symbols().get(pred.name()))); - } - - @Override - public Env transformAnd(Env env, RexRN.And and) { - var operands = and.sources().map(source -> transform(env, source).expressions().first()); - return env.express(Seq.of(STR."(ConcatFilters \{operands.joinToString(", ")})")); - } - }; - - // Note: Join is treated as if it is a custom operator - var filterMerge = new RRule.FilterMerge(); - var joinConditionPush = new RRule.JoinConditionPush(); - var calciteFilterMerge = calciteCodeGen.compose(filterMerge); - var calciteJoinConditionPush = calciteCodeGen.compose(joinConditionPush); - var cockroachFilterMerge = cockroachCodeGen.compose(filterMerge); - var cockroachJoinConditionPush = cockroachCodeGen.compose(joinConditionPush); - System.out.println(filterMerge.explain()); - System.out.println(calciteFilterMerge); - System.out.println(cockroachFilterMerge); - System.out.println(); - System.out.println(joinConditionPush.explain()); - System.out.println(calciteJoinConditionPush); - System.out.println(cockroachJoinConditionPush); - } +public interface CodeGenerator { default String unimplemented(String context, Object object) { return STR."<--\{context}\{object.getClass().getName()}-->"; } - default Env unimplementedOnMatch(Env env, Object object) { - return env.express(Seq.of(unimplemented("Unspecified onMatch codegen: ", object))); + default E unimplementedOnMatch(E env, Object object) { + System.err.println(unimplemented("Unspecified onMatch codegen: ", object)); + return env; } - default Env unimplementedTransform(Env env, Object object) { - return env.express(Seq.of(unimplemented("Unspecified transform codegen: ", object))); + default E unimplementedTransform(E env, Object object) { + System.err.println(unimplemented("Unspecified transform codegen: ", object)); + return env; } - default Env preMatch() { - return Env.empty(); - } + E preMatch(); - default Env onMatch(Env env, RelRN pattern) { + default E onMatch(E env, RelRN pattern) { return switch (pattern) { case RelRN.Scan scan -> onMatchScan(env, scan); case RelRN.Filter filter -> onMatchFilter(env, filter); @@ -290,7 +30,7 @@ default Env onMatch(Env env, RelRN pattern) { }; } - default Env onMatch(Env env, RexRN pattern) { + default E onMatch(E env, RexRN pattern) { return switch (pattern) { case RexRN.Field field -> onMatchField(env, field); case RexRN.JoinField joinField -> onMatchJoinField(env, joinField); @@ -303,15 +43,15 @@ default Env onMatch(Env env, RexRN pattern) { }; } - default Env postMatch(Env env) { + default E postMatch(E env) { return env; } - default Env preTransform(Env env) { + default E preTransform(E env) { return env; } - default Env transform(Env env, RelRN target) { + default E transform(E env, RelRN target) { return switch (target) { case RelRN.Scan scan -> transformScan(env, scan); case RelRN.Filter filter -> transformFilter(env, filter); @@ -323,12 +63,12 @@ default Env transform(Env env, RelRN target) { }; } - default Env transform(Env env, RexRN target) { + default E transform(E env, RexRN target) { return switch (target) { case RexRN.Field field -> transformField(env, field); case RexRN.JoinField joinField -> transformJoinField(env, joinField); - case RexRN.Proj proj -> transformProj(env, proj); case RexRN.Pred pred -> transformPred(env, pred); + case RexRN.Proj proj -> transformProj(env, proj); case RexRN.And and -> transformAnd(env, and); case RexRN.Or or -> transformOr(env, or); case RexRN.Not not -> transformNot(env, not); @@ -336,156 +76,138 @@ default Env transform(Env env, RexRN target) { }; } - default Env postTransform(Env env) { + default E postTransform(E env) { return env; } - default String translate(Env onMatch, Env transform) { - return unimplemented("Unspecified translation to target language: ", Env.empty()); + default String translate(String name, E onMatch, E transform) { + return "Unspecified translation to target language"; } - default String compose(RRule rule) { + default String generate(RRule rule) { var onMatch = postMatch(onMatch(preMatch(), rule.before())); var transform = postTransform(transform(preTransform(onMatch), rule.after())); - return translate(onMatch, transform); + return translate(rule.getClass().getSimpleName(), onMatch, transform); } - default Env onMatchScan(Env env, RelRN.Scan scan) { + default E onMatchScan(E env, RelRN.Scan scan) { return unimplementedOnMatch(env, scan); } - default Env onMatchFilter(Env env, RelRN.Filter filter) { + default E onMatchFilter(E env, RelRN.Filter filter) { return unimplementedOnMatch(env, filter); } - default Env onMatchProject(Env env, RelRN.Project project) { + default E onMatchProject(E env, RelRN.Project project) { return unimplementedOnMatch(env, project); } - default Env onMatchJoin(Env env, RelRN.Join join) { + default E onMatchJoin(E env, RelRN.Join join) { return unimplementedOnMatch(env, join); } - default Env onMatchUnion(Env env, RelRN.Union union) { + default E onMatchUnion(E env, RelRN.Union union) { return unimplementedOnMatch(env, union); } - default Env onMatchIntersect(Env env, RelRN.Intersect intersect) { + default E onMatchIntersect(E env, RelRN.Intersect intersect) { return unimplementedOnMatch(env, intersect); } - default Env onMatchCustom(Env env, RelRN custom) { + default E onMatchCustom(E env, RelRN custom) { return unimplementedOnMatch(env, custom); } - default Env onMatchField(Env env, RexRN.Field field) { + default E onMatchField(E env, RexRN.Field field) { return unimplementedOnMatch(env, field); } - default Env onMatchJoinField(Env env, RexRN.JoinField joinField) { + default E onMatchJoinField(E env, RexRN.JoinField joinField) { return unimplementedOnMatch(env, joinField); } - default Env onMatchProj(Env env, RexRN.Proj proj) { - return unimplementedOnMatch(env, proj); + default E onMatchPred(E env, RexRN.Pred pred) { + return unimplementedOnMatch(env, pred); } - default Env onMatchPred(Env env, RexRN.Pred pred) { - return unimplementedOnMatch(env, pred); + default E onMatchProj(E env, RexRN.Proj proj) { + return unimplementedOnMatch(env, proj); } - default Env onMatchAnd(Env env, RexRN.And and) { + default E onMatchAnd(E env, RexRN.And and) { return unimplementedOnMatch(env, and); } - default Env onMatchOr(Env env, RexRN.Or or) { + default E onMatchOr(E env, RexRN.Or or) { return unimplementedOnMatch(env, or); } - default Env onMatchNot(Env env, RexRN.Not not) { + default E onMatchNot(E env, RexRN.Not not) { return unimplementedOnMatch(env, not); } - default Env onMatchCustom(Env env, RexRN custom) { + default E onMatchCustom(E env, RexRN custom) { return unimplementedOnMatch(env, custom); } - default Env transformScan(Env env, RelRN.Scan scan) { + default E transformScan(E env, RelRN.Scan scan) { return unimplementedTransform(env, scan); } - default Env transformFilter(Env env, RelRN.Filter filter) { + default E transformFilter(E env, RelRN.Filter filter) { return unimplementedTransform(env, filter); } - default Env transformProject(Env env, RelRN.Project project) { + default E transformProject(E env, RelRN.Project project) { return unimplementedTransform(env, project); } - default Env transformJoin(Env env, RelRN.Join join) { + default E transformJoin(E env, RelRN.Join join) { return unimplementedTransform(env, join); } - default Env transformUnion(Env env, RelRN.Union union) { + default E transformUnion(E env, RelRN.Union union) { return unimplementedTransform(env, union); } - default Env transformIntersect(Env env, RelRN.Intersect intersect) { + default E transformIntersect(E env, RelRN.Intersect intersect) { return unimplementedTransform(env, intersect); } - default Env transformCustom(Env env, RelRN custom) { + default E transformCustom(E env, RelRN custom) { return unimplementedTransform(env, custom); } - default Env transformField(Env env, RexRN.Field field) { + default E transformField(E env, RexRN.Field field) { return unimplementedTransform(env, field); } - default Env transformJoinField(Env env, RexRN.JoinField joinField) { + default E transformJoinField(E env, RexRN.JoinField joinField) { return unimplementedTransform(env, joinField); } - default Env transformProj(Env env, RexRN.Proj proj) { + default E transformProj(E env, RexRN.Proj proj) { return unimplementedTransform(env, proj); } - default Env transformPred(Env env, RexRN.Pred pred) { + default E transformPred(E env, RexRN.Pred pred) { return unimplementedTransform(env, pred); } - default Env transformAnd(Env env, RexRN.And and) { + default E transformAnd(E env, RexRN.And and) { return unimplementedTransform(env, and); } - default Env transformOr(Env env, RexRN.Or or) { + default E transformOr(E env, RexRN.Or or) { return unimplementedTransform(env, or); } - default Env transformNot(Env env, RexRN.Not not) { + default E transformNot(E env, RexRN.Not not) { return unimplementedTransform(env, not); } - default Env transformCustom(Env env, RexRN custom) { + default E transformCustom(E env, RexRN custom) { return unimplementedTransform(env, custom); } - record Env(Seq expressions, Seq statements, Map symbols) { - public static Env empty() { - return new Env(Seq.empty(), Seq.empty(), Map.empty()); - } - - public Env express(Seq expressions) { - return new Env(expressions, statements, symbols); - } - - public Env state(String statement) { - return new Env(expressions, statements.appended(statement), symbols); - } - - public Env symbol(String symbol, String expression) { - return new Env(expressions, statements, symbols.toImmutableMap().putted(symbol, expression)); - } - } - } diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java new file mode 100644 index 0000000..3448f73 --- /dev/null +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -0,0 +1,204 @@ +package org.qed.Generated; + +import kala.collection.Seq; +import kala.collection.immutable.ImmutableMap; +import kala.tuple.Tuple; +import kala.tuple.Tuple2; +import org.qed.CodeGenerator; +import org.qed.RelRN; +import org.qed.RexRN; + +import java.util.concurrent.atomic.AtomicInteger; + +public class CalciteGenerator implements CodeGenerator { + + @Override + public Env preMatch() { + return Env.empty(); + } + + @Override + public Env preTransform(Env env) { + var buildEnv = env.declare("call.builder()"); + return buildEnv.getValue().focus(buildEnv.getKey()); + } + + @Override + public Env postTransform(Env env) { + return env.state(STR."call.transformTo(\{env.current()}.build());"); + } + + @Override + public String translate(String name, Env onMatch, Env transform) { + var builder = new StringBuilder("package org.qed.Generated;\n\n"); + builder.append("import org.apache.calcite.plan.RelOptRuleCall;\n"); + builder.append("import org.apache.calcite.plan.RelRule;\n"); + builder.append("import org.apache.calcite.rel.RelNode;\n"); + builder.append("import org.apache.calcite.rel.core.JoinRelType;\n"); + builder.append("import org.apache.calcite.rel.logical.*;\n\n"); + builder.append(STR."public class \{name} extends RelRule<\{name}.Config> {\n"); + builder.append(STR."\tprotected \{name}(Config config) {\n"); + builder.append("\t\tsuper(config);\n"); + builder.append("\t}\n\n"); + builder.append("\t@Override\n\tpublic void onMatch(RelOptRuleCall call) {\n"); + transform.statements().forEach(statement -> builder.append("\t\t").append(statement).append("\n")); + builder.append("\t}\n\n"); + builder.append("\tpublic interface Config extends EmptyConfig {\n"); + builder.append("\t\tConfig DEFAULT = new Config() {};\n\n"); + builder.append(STR."\t\t@Override\n\t\tdefault \{name} toRule() {\n"); + builder.append(STR."\t\t\treturn new \{name}(this);\n"); + builder.append("\t\t}\n\n"); + builder.append("\t\t@Override\n\t\tdefault String description() {\n"); + builder.append(STR."\t\t\treturn \"\{name}\";\n"); + builder.append("\t\t}\n\n"); + builder.append("\t\t@Override\n\t\tdefault RelRule.OperandTransform operandSupplier() {\n"); + builder.append(STR."\t\t\treturn \{onMatch.skeleton()};\n"); + builder.append("\t\t}\n\n"); + builder.append("\t}\n"); + builder.append("}\n"); + return builder.toString(); + } + + @Override + public Env onMatchScan(Env env, RelRN.Scan scan) { + return env.symbol(scan.name(), env.current()).grow("operand(RelNode.class).anyInputs()"); + } + + @Override + public Env onMatchFilter(Env env, RelRN.Filter filter) { + var source_match = onMatch(env.next(), filter.source()); + var operator_match = source_match.grow(STR."operand(LogicalFilter.class).oneInput(\{source_match.skeleton()})"); + var condition_match = operator_match.focus(STR."((LogicalFilter) \{env.current()}).getCondition()"); + return onMatch(condition_match, filter.cond()); + } + + @Override + public Env onMatchProject(Env env, RelRN.Project project) { + var source_match = onMatch(env.next(), project.source()); + var operator_match = + source_match.grow(STR."operand(LogicalProject.class).oneInput(\{source_match.skeleton()})"); + var map_match = operator_match.focus(STR."((LogicalProject) \{env.current()}).getProjects()"); + return onMatch(map_match, project.map()); + } + + @Override + public Env onMatchPred(Env env, RexRN.Pred pred) { + return env.symbol(pred.name(), env.current()); + } + + @Override + public Env onMatchProj(Env env, RexRN.Proj proj) { + return env.symbol(proj.name(), env.current()); + } + + @Override + public Env onMatchJoin(Env env, RelRN.Join join) { + var current_join = STR."((LogicalJoin) \{env.current()})"; + // STR."\{join_env.current()}.getJoinType()" + var left_source_env = env.next(); + var left_match_env = onMatch(left_source_env, join.left()); + var right_source_env = left_match_env.next(); + var right_match_env = onMatch(right_source_env, join.right()); + var operator_match = + right_match_env.grow(STR."operand(LogicalJoin.class).inputs(\{left_match_env.skeleton()}, \{right_match_env.skeleton()})"); + var cond_source_env = operator_match.focus(STR."\{current_join}.getCondition()"); + return onMatch(cond_source_env, join.cond()); + } + + @Override + public Env transformScan(Env env, RelRN.Scan scan) { + return env.focus(STR."\{env.current()}.push(\{env.symbols().get(scan.name())})"); + } + +// @Override +// public Env onMatchCustom(Env env, RexRN custom) { +// return switch (custom) { +// case RRule.JoinConditionPush.JoinPred joinPred -> { +// var pred = env.expressions().first(); +// var breakdown_env = assignVariable(env, STR."customSplitFilter(\{pred})"); +// var breakdown = breakdown_env.expressions().first(); +// yield breakdown_env +// .symbol(joinPred.bothPred(), STR."\{breakdown}.getBoth()") +// .symbol(joinPred.leftPred(), STR."\{breakdown}.getLeft()") +// .symbol(joinPred.rightPred(), STR."\{breakdown}.getRight()"); +// } +// default -> CodeGenerator.super.onMatchCustom(env, custom); +// }; +// } + + @Override + public Env transformFilter(Env env, RelRN.Filter filter) { + var source_transform = transform(env, filter.source()); + var source_expression = source_transform.current(); + var cond_transform = transform(source_transform, filter.cond()); + return cond_transform.focus(STR."\{source_expression}.filter(\{cond_transform.current()})"); + } + + @Override + public Env transformPred(Env env, RexRN.Pred pred) { + return env.focus(env.symbols().get(pred.name())); + } + + @Override + public Env transformJoin(Env env, RelRN.Join join) { + var left_source_transform = transform(env, join.left()); + var right_source_transform = transform(left_source_transform, join.right()); + var source_expression = right_source_transform.current(); + var cond_transform = transform(right_source_transform, join.cond()); + var join_type = switch (join.ty()) { + case INNER -> "JoinRelType.INNER"; + case LEFT -> "JoinRelType.LEFT"; + case RIGHT -> "JoinRelType.RIGHT"; + case FULL -> "JoinRelType.FULL"; + case SEMI -> "JoinRelType.SEMI"; + case ANTI -> "JoinRelType.ANTI"; + }; + return cond_transform.focus(STR."\{source_expression}.join(\{join_type}, \{cond_transform.current()})"); + } + + @Override + public Env transformAnd(Env env, RexRN.And and) { + var source_transform = env; + var operands = Seq.empty(); + for (var source : and.sources()) { + source_transform = transform(source_transform, source); + operands = operands.appended(source_transform.current()); + source_transform = source_transform.focus(env.current()); + } + return source_transform.focus(STR."\{env.current()}.and(\{operands.joinToString(", ")})"); + } + + public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq statements, + ImmutableMap symbols) { + public static Env empty() { + return new Env(new AtomicInteger(), 0, "call.rel(0)", "/* Unspecified skeleton */", Seq.empty(), + ImmutableMap.empty()); + } + + public Env next() { + return new Env(varId, rel + 1, STR."call.rel(\{rel + 1})", skeleton, statements, symbols); + } + + public Env focus(String target) { + return new Env(varId, rel, target, skeleton, statements, symbols); + } + + public Env state(String statement) { + return new Env(varId, rel, current, skeleton, statements.appended(statement), symbols); + } + + public Env symbol(String symbol, String expression) { + return new Env(varId, rel, current, skeleton, statements, symbols.putted(symbol, expression)); + } + + public Tuple2 declare(String expression) { + var name = STR."var_\{varId.getAndIncrement()}"; + return Tuple.of(name, state(STR."var \{name} = \{expression};")); + } + + public Env grow(String requirement) { + var vn = STR."s_\{varId.getAndIncrement()}"; + return new Env(varId, rel, current, STR."\{vn} -> \{vn}.\{requirement}", statements, symbols); + } + } +} diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java new file mode 100644 index 0000000..c419a1b --- /dev/null +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -0,0 +1,105 @@ +package org.qed.Generated; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgramBuilder; +import org.apache.calcite.rel.RelNode; +import org.qed.JSONDeserializer; +import org.qed.RRule; +import org.qed.RRuleInstance; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.nio.file.Files; +import java.nio.file.Path; + +public class CalciteTester { + // Assuming that current working directory is the root of the project + public static String genPath = "src/main/java/org/qed/Generated"; + + public static HepPlanner loadRule(RelOptRule rule) { + var builder = new HepProgramBuilder().addRuleInstance(rule); + return new HepPlanner(builder.build()); + } + + public static Seq ruleList() { + var individuals = Seq.from(RRuleInstance.class.getClasses()).filter(RRule.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule) r); + var families = Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule.RRuleFamily) r); + return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); + } + + public static void verify() { + ruleList().forEachUnchecked(rule -> rule.dump(STR."dump/\{rule.name()}.json")); + } + + public static void generate() { + var tester = new CalciteTester(); + ruleList().forEach(r -> tester.serialize(r, genPath)); + } + + public static void main(String[] args) { + generate(); +// var tester = new CalciteTester(); +// var builder = RuleBuilder.create(); +// var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); +// builder.addTable(table); +// var before = builder.scan(table.getName()) +// .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) +// .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) +// .build(); +// var after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, +// builder.call(builder.genericPredicateOp("inner", true), builder.fields()), +// builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) +// .build(); +// var runner = loadRule(FilterMerge.Config.DEFAULT.toRule()); +// tester.verify(runner, before, after); +// before = builder.scan(table.getName()) +// .scan(table.getName()) +// .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) +// .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) +// .build(); +// after = builder.scan(table.getName()) +// .scan(table.getName()) +// .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, +// builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), +// builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) +// .build(); +// runner = loadRule(FilterIntoJoin.Config.DEFAULT.toRule()); +// tester.verify(runner, before, after); + } + + public void serialize(RRule rule, String path) { + var generator = new CalciteGenerator(); + var code_gen = generator.generate(rule); + try { + Files.write(Path.of(path, STR."\{rule.name()}.java"), code_gen.getBytes()); + } catch (IOException ioe) { + System.err.println(ioe.getMessage()); + } + } + + public void test(RelOptRule rule, Seq tests) { + System.out.println(STR."Testing rule \{rule.getClass().getSimpleName()}"); + var runner = loadRule(rule); + var exams = tests.mapUnchecked(t -> Tuple.of(t, JSONDeserializer.load(new File(t)))); + for (var entry : exams) { + if (entry.getValue().size() != 2) { + System.err.println(STR."\{entry.getKey()} does not have exactly two nodes, and thus is not a valid test"); + continue; + } + verify(runner, entry.getValue().get(0), entry.getValue().get(1)); + } + } + + public void verify(HepPlanner runner, RelNode source, RelNode target) { + runner.setRoot(source); + var answer = runner.findBestExp(); + System.out.println(STR."> Given source RelNode:\n\{source.explain()}"); + System.out.println(STR."> Actual rewritten RelNode:\n\{answer.explain()}"); + System.out.println(STR."> Expected rewritten RelNode:\n\{target.explain()}"); + } + +} diff --git a/src/main/java/org/qed/Generated/CalciteUtilities.java b/src/main/java/org/qed/Generated/CalciteUtilities.java new file mode 100644 index 0000000..22e545c --- /dev/null +++ b/src/main/java/org/qed/Generated/CalciteUtilities.java @@ -0,0 +1,16 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexNode; +import org.qed.RuleBuilder; + +import java.util.List; + +public record CalciteUtilities() { + public List compose(RelNode base, List inner, List outer) { + var builder = RuleBuilder.create(); + return RelOptUtil.pushPastProject(outer, (Project) builder.push(base).project(inner).build()); + } +} diff --git a/src/main/java/org/qed/Generated/EmptyConfig.java b/src/main/java/org/qed/Generated/EmptyConfig.java new file mode 100644 index 0000000..d784cfe --- /dev/null +++ b/src/main/java/org/qed/Generated/EmptyConfig.java @@ -0,0 +1,33 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.tools.RelBuilderFactory; +import org.checkerframework.checker.nullness.qual.Nullable; + +public interface EmptyConfig extends RelRule.Config { + @Override + default RelRule.Config withRelBuilderFactory(RelBuilderFactory factory) { + return this; + } + + @Override + default @Nullable String description() { + return "Unspecified Config Description"; + } + + @Override + default RelRule.Config withDescription(@Nullable String description) { + return this; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s -> s.operand(RelNode.class).anyInputs(); + } + + @Override + default RelRule.Config withOperandSupplier(RelRule.OperandTransform transform) { + return this; + } +} diff --git a/src/main/java/org/qed/Generated/FilterIntoJoin.java b/src/main/java/org/qed/Generated/FilterIntoJoin.java new file mode 100644 index 0000000..b288f4b --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterIntoJoin.java @@ -0,0 +1,42 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalJoin; + +public class FilterIntoJoin extends RelRule { + protected FilterIntoJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).push(call.rel(3)).join(JoinRelType.INNER, + var_4.push(call.rel(2)).push(call.rel(3)).and(((LogicalJoin) call.rel(1)).getCondition(), + ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterIntoJoin toRule() { + return new FilterIntoJoin(this); + } + + @Override + default String description() { + return "FilterIntoJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterMerge.java b/src/main/java/org/qed/Generated/FilterMerge.java new file mode 100644 index 0000000..3ee9946 --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterMerge.java @@ -0,0 +1,38 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.logical.LogicalFilter; + +public class FilterMerge extends RelRule { + protected FilterMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).filter(var_3.push(call.rel(2)).and(((LogicalFilter) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterMerge toRule() { + return new FilterMerge(this); + } + + @Override + default String description() { + return "FilterMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterProjectTranspose.java b/src/main/java/org/qed/Generated/FilterProjectTranspose.java new file mode 100644 index 0000000..decb19a --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterProjectTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalProject; + +public class FilterProjectTranspose extends RelRule { + protected FilterProjectTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.filter(((LogicalFilter) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterProjectTranspose toRule() { + return new FilterProjectTranspose(this); + } + + @Override + default String description() { + return "FilterProjectTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/JSONDeserializer.java b/src/main/java/org/qed/JSONDeserializer.java index c725c40..a121124 100644 --- a/src/main/java/org/qed/JSONDeserializer.java +++ b/src/main/java/org/qed/JSONDeserializer.java @@ -34,6 +34,94 @@ public record JSONDeserializer() { private final static ObjectMapper mapper = new ObjectMapper(); + private static ImmutableSeq array(JsonNode node) throws Exception { + if (!node.isArray()) throw new Exception(); + return ImmutableSeq.from(node.elements()); + } + + private static ImmutableSeq array(JsonNode node, String path) throws Exception { + return array(node.required(path)); + } + + private static String string(JsonNode node) throws Exception { + if (!node.isTextual()) throw new Exception(); + return node.asText(); + } + + private static String string(JsonNode node, String path) throws Exception { + return string(node.required(path)); + } + + private static int integer(JsonNode node) throws Exception { + if (!node.isInt()) throw new Exception(); + return node.asInt(); + } + + private static int integer(JsonNode node, String path) throws Exception { + return integer(node.required(path)); + } + + private static boolean bool(JsonNode node) throws Exception { + if (!node.isBoolean()) throw new Exception(); + return node.asBoolean(); + } + + static SqlTypeName typeName(String name) { + name = switch (name) { + case "BOOL" -> "BOOLEAN"; + case "INT", "INT2", "INT4", "OID" -> "INTEGER"; + case "TIMESTAMPTZ" -> "TIMESTAMP"; + case "TIMETZ" -> "TIME"; + case "STRING" -> "VARCHAR"; + case "JSONB" -> "VARBINARY"; + default -> name; + }; + return Enum.valueOf(SqlTypeName.class, name); + } + + public static ImmutableSeq load(File file) throws Exception { + return new JSONDeserializer().deserialize(mapper.readTree(file)); + } + + public static void main(String[] args) throws Exception { + var refs = Seq.from(new File("RelOptRulesTest").listFiles()); + for (var file : refs) { + try { + var store = mapper.readTree(file); + new JSONDeserializer().deserialize(store); + } catch (Exception e) { + System.err.println("===> " + file.getName() + " <==="); + System.err.println(e.getMessage()); + System.err.println(); + } + } + } + + public ImmutableSeq deserialize(JsonNode node) throws Exception { + var builder = RuleBuilder.create(); + var tables = array(node, "schemas").mapChecked(schema -> { + var types = array(schema, "types").mapChecked(JSONDeserializer::string); + var nullabilities = array(schema, "nullable").mapChecked(JSONDeserializer::bool); + var name = schema.path("name").asText("DEFAULT_TABLE_NAME"); + var fields = schema.get("fields") == null ? + Seq.fill(types.size(), i -> String.format("DEFAULT_FIELD_NAME_%d", i)) : + array(schema, "fields").mapChecked(JSONDeserializer::string); + var keys = Set.from(array(schema, "key").map( + CheckedFunction.of(key -> ImmutableBitSet.of(array(key).mapChecked(JSONDeserializer::integer))))); + if (types.size() != nullabilities.size()) + throw new Exception("Expecting corresponding types and nullabilities"); + var sts = types.zip(nullabilities).map(tn -> { + var type = builder.getTypeFactory().createSqlType(typeName(tn.component1())); + return builder.getTypeFactory().createTypeWithNullability(type, tn.component2()); + }); + var table = new QedTable(name, fields, sts, keys, Set.empty()); + builder.addTable(table); + return table; + }); + var rel = new Rel(builder, ImmutableSeq.empty(), tables); + return array(node, "queries").mapChecked(rel); + } + private record Rel(RuleBuilder builder, ImmutableSeq globals, ImmutableSeq tables) implements CheckedFunction { Rel(RuleBuilder builder) { @@ -156,6 +244,15 @@ yield builder().push(sorted).sortLimit(rex().deserialize(content.required("offse private record Rex(RuleBuilder builder, ImmutableSeq globals, RexCorrelVariable local, ImmutableSeq tables) implements CheckedFunction { + static Seq ops = Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) + .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && + java.lang.reflect.Modifier.isStatic(f.getModifiers())).map(f -> { + var mist = Try.of(() -> f.get(null)).getOrNull(); + if (mist == null) return null; + if (mist instanceof SqlOperator op) return op; + return null; + }).filter(Objects::nonNull); + public RexNode resolve(int lvl) { assert lvl < globals().size() + local().getType().getFieldCount(); return lvl < globals().size() ? globals().get(lvl) : builder().getRexBuilder() @@ -177,15 +274,6 @@ public RelDataType type(String name) { return builder().getTypeFactory().createSqlType(typeName(name)); } - static Seq ops = Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) - .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && - java.lang.reflect.Modifier.isStatic(f.getModifiers())).map(f -> { - var mist = Try.of(() -> f.get(null)).getOrNull(); - if (mist == null) return null; - if (mist instanceof SqlOperator op) return op; - return null; - }).filter(Objects::nonNull); - SqlOperator op(String name, int arity) throws Exception { switch (name) { case "BOOL_AND" -> { @@ -273,93 +361,5 @@ public RexNode deserialize(JsonNode node) throws Exception { } } } - - private static ImmutableSeq array(JsonNode node) throws Exception { - if (!node.isArray()) throw new Exception(); - return ImmutableSeq.from(node.elements()); - } - - private static ImmutableSeq array(JsonNode node, String path) throws Exception { - return array(node.required(path)); - } - - private static String string(JsonNode node) throws Exception { - if (!node.isTextual()) throw new Exception(); - return node.asText(); - } - - private static String string(JsonNode node, String path) throws Exception { - return string(node.required(path)); - } - - private static int integer(JsonNode node) throws Exception { - if (!node.isInt()) throw new Exception(); - return node.asInt(); - } - - private static int integer(JsonNode node, String path) throws Exception { - return integer(node.required(path)); - } - - private static boolean bool(JsonNode node) throws Exception { - if (!node.isBoolean()) throw new Exception(); - return node.asBoolean(); - } - - static SqlTypeName typeName(String name) { - name = switch (name) { - case "BOOL" -> "BOOLEAN"; - case "INT", "INT2", "INT4", "OID" -> "INTEGER"; - case "TIMESTAMPTZ" -> "TIMESTAMP"; - case "TIMETZ" -> "TIME"; - case "STRING" -> "VARCHAR"; - case "JSONB" -> "VARBINARY"; - default -> name; - }; - return Enum.valueOf(SqlTypeName.class, name); - } - - public ImmutableSeq deserialize(JsonNode node) throws Exception { - var builder = RuleBuilder.create(); - var tables = array(node, "schemas").mapChecked(schema -> { - var types = array(schema, "types").mapChecked(JSONDeserializer::string); - var nullabilities = array(schema, "nullable").mapChecked(JSONDeserializer::bool); - var name = schema.path("name").asText("DEFAULT_TABLE_NAME"); - var fields = schema.get("fields") == null ? - Seq.fill(types.size(), i -> String.format("DEFAULT_FIELD_NAME_%d", i)) : - array(schema, "fields").mapChecked(JSONDeserializer::string); - var keys = Set.from(array(schema, "key").map( - CheckedFunction.of(key -> ImmutableBitSet.of(array(key).mapChecked(JSONDeserializer::integer))))); - if (types.size() != nullabilities.size()) - throw new Exception("Expecting corresponding types and nullabilities"); - var sts = types.zip(nullabilities).map(tn -> { - var type = builder.getTypeFactory().createSqlType(typeName(tn.component1())); - return builder.getTypeFactory().createTypeWithNullability(type, tn.component2()); - }); - var table = new QedTable(name, fields, sts, keys, Set.empty()); - builder.addTable(table); - return table; - }); - var rel = new Rel(builder, ImmutableSeq.empty(), tables); - return array(node, "queries").mapChecked(rel); - } - - public static ImmutableSeq load(File file) throws Exception { - return new JSONDeserializer().deserialize(mapper.readTree(file)); - } - - public static void main(String[] args) throws Exception { - var refs = Seq.from(new File("RelOptRulesTest").listFiles()); - for (var file : refs) { - try { - var store = mapper.readTree(file); - new JSONDeserializer().deserialize(store); - } catch (Exception e) { - System.err.println("===> " + file.getName() + " <==="); - System.err.println(e.getMessage()); - System.err.println(); - } - } - } } diff --git a/src/main/java/org/qed/JSONSerializer.java b/src/main/java/org/qed/JSONSerializer.java index bf8f4ec..a3b818a 100644 --- a/src/main/java/org/qed/JSONSerializer.java +++ b/src/main/java/org/qed/JSONSerializer.java @@ -21,32 +21,61 @@ public record JSONSerializer(Env env) { private final static ObjectMapper mapper = new ObjectMapper(); - private record Rel(Env env) { - Rel() { - this(new Env(0, ImmutableMap.empty(), MutableList.create())); - } + private static ArrayNode array(Seq objs) { + return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); + } - private record Env(int lvl, ImmutableMap globals, MutableList tables) { - Env recorded(Set ids) { - return new Env(lvl, Seq.from(ids).foldLeft(globals, (g, id) -> g.putted(id, lvl)), tables); - } + private static ObjectNode object(Map fields) { + return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); + } - Env lifted(int d) { - return new Env(lvl + d, globals, tables); - } + private static BooleanNode bool(boolean b) { + return BooleanNode.valueOf(b); + } - int resolve(RelOptTable table) { - var idx = tables.indexOf(table); - if (idx == -1) { - idx = tables.size(); - tables.append(table); - } - return idx; - } + private static TextNode string(String s) { + return new TextNode(s); + } - public Rex.Env rex(int delta) { - return new Rex.Env(lvl, delta, globals, tables); - } + private static TextNode type(RelDataType type) { + return new TextNode(type.getSqlTypeName().getName()); + } + + private static IntNode integer(int i) { + return new IntNode(i); + } + + public static ObjectNode serialize(Seq relNodes) { + var shuttle = new Rel(); + var helps = array(relNodes.map(rel -> new TextNode(rel.explain()))); + var queries = array(relNodes.map(shuttle::serialize)); + var tables = shuttle.env.tables(); + var schemas = array(tables.map(table -> { + var visitor = new Rex(shuttle.env.rex(table.getRowType().getFieldCount())); + var qedTable = table.unwrap(QedTable.class); + var fields = Seq.from(table.getRowType().getFieldList()); + return qedTable == null ? + object(Map.of("name", string(Seq.from(table.getQualifiedName()).joinToString(".")), "fields", + array(fields.map(field -> string(field.getName()))), "types", + array(fields.map(field -> type(field.getType()))), "nullable", + array(fields.map(field -> bool(field.getType().isNullable()))), "key", + array((table.getKeys() != null ? Seq.from(table.getKeys()) : + Seq.empty()).map( + key -> array(Seq.from(key).map(JSONSerializer::integer)))), "guaranteed", + array(Seq.empty()))) : object(Map.of("name", string(qedTable.getName()), "fields", + array(qedTable.getColumnNames().map(JSONSerializer::string)), "types", + array(qedTable.getColumnTypes().map(JSONSerializer::type)), "nullable", + array(qedTable.getColumnTypes().map(type -> bool(type.isNullable()))), "key", + array(Seq.from(qedTable.getKeys().map(key -> array(Seq.from(key).map(JSONSerializer::integer))))), + "guaranteed", array(qedTable.getConstraints().map(visitor::serialize).toImmutableSeq()))); + })); + + return object(Map.of("schemas", schemas, "queries", queries, "help", helps)); + } + + private record Rel(Env env) { + Rel() { + this(new Env(0, ImmutableMap.empty(), MutableList.create())); } public JsonNode serialize(RelNode rel) { @@ -132,20 +161,32 @@ yield object(Map.of("sort", default -> throw new RuntimeException("Not implemented: " + rel.getRelTypeName()); }; } - } - private record Rex(Env env) { - private record Env(int base, int delta, ImmutableMap globals, - MutableList tables) { - public Rel.Env rel() { - return new Rel.Env(base + delta, globals, tables); + private record Env(int lvl, ImmutableMap globals, MutableList tables) { + Env recorded(Set ids) { + return new Env(lvl, Seq.from(ids).foldLeft(globals, (g, id) -> g.putted(id, lvl)), tables); } - int resolve(CorrelationId id) { - return globals.getOrThrow(id, () -> new RuntimeException("Correlation ID not declared")); + Env lifted(int d) { + return new Env(lvl + d, globals, tables); + } + + int resolve(RelOptTable table) { + var idx = tables.indexOf(table); + if (idx == -1) { + idx = tables.size(); + tables.append(table); + } + return idx; + } + + public Rex.Env rex(int delta) { + return new Rex.Env(lvl, delta, globals, tables); } } + } + private record Rex(Env env) { public JsonNode serialize(RexNode rex) { return switch (rex) { case RexInputRef inputRef -> object(Map.of("column", integer(inputRef.getIndex() + env.base()), "type", @@ -165,57 +206,16 @@ public JsonNode serialize(RexNode rex) { default -> throw new RuntimeException("Not implemented: " + rex.getKind()); }; } - } - - private static ArrayNode array(Seq objs) { - return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); - } - - private static ObjectNode object(Map fields) { - return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); - } - - private static BooleanNode bool(boolean b) { - return BooleanNode.valueOf(b); - } - - private static TextNode string(String s) { - return new TextNode(s); - } - - private static TextNode type(RelDataType type) { - return new TextNode(type.getSqlTypeName().getName()); - } - - private static IntNode integer(int i) { - return new IntNode(i); - } - public static ObjectNode serialize(Seq relNodes) { - var shuttle = new Rel(); - var helps = array(relNodes.map(rel -> new TextNode(rel.explain()))); - var queries = array(relNodes.map(shuttle::serialize)); - var tables = shuttle.env.tables(); - var schemas = array(tables.map(table -> { - var visitor = new Rex(shuttle.env.rex(table.getRowType().getFieldCount())); - var qedTable = table.unwrap(QedTable.class); - var fields = Seq.from(table.getRowType().getFieldList()); - return qedTable == null ? - object(Map.of("name", string(Seq.from(table.getQualifiedName()).joinToString(".")), "fields", - array(fields.map(field -> string(field.getName()))), "types", - array(fields.map(field -> type(field.getType()))), "nullable", - array(fields.map(field -> bool(field.getType().isNullable()))), "key", - array((table.getKeys() != null ? Seq.from(table.getKeys()) : - Seq.empty()).map( - key -> array(Seq.from(key).map(JSONSerializer::integer)))), "guaranteed", - array(Seq.empty()))) : object(Map.of("name", string(qedTable.getName()), "fields", - array(qedTable.getColumnNames().map(JSONSerializer::string)), "types", - array(qedTable.getColumnTypes().map(JSONSerializer::type)), "nullable", - array(qedTable.getColumnTypes().map(type -> bool(type.isNullable()))), "key", - array(Seq.from(qedTable.getKeys().map(key -> array(Seq.from(key).map(JSONSerializer::integer))))), - "guaranteed", array(qedTable.getConstraints().map(visitor::serialize).toImmutableSeq()))); - })); + private record Env(int base, int delta, ImmutableMap globals, + MutableList tables) { + public Rel.Env rel() { + return new Rel.Env(base + delta, globals, tables); + } - return object(Map.of("schemas", schemas, "queries", queries, "help", helps)); + int resolve(CorrelationId id) { + return globals.getOrThrow(id, () -> new RuntimeException("Correlation ID not declared")); + } + } } } diff --git a/src/main/java/org/qed/RRule.java b/src/main/java/org/qed/RRule.java index 15265ab..d7f3606 100644 --- a/src/main/java/org/qed/RRule.java +++ b/src/main/java/org/qed/RRule.java @@ -1,9 +1,18 @@ -package org.cosette; +package org.qed; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import kala.collection.Seq; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rex.RexNode; +import java.io.File; +import java.io.IOException; + public interface RRule { + interface RRuleFamily { + Seq family(); + } RelRN before(); @@ -13,82 +22,16 @@ default String explain() { return STR."\{getClass().getName()}\n\{before().semantics().explain()}=>\n\{after().semantics().explain()}"; } - record FilterMerge() implements RRule { - static final RelRN source = RelRN.scan("Source", "Source_Type"); - static final RexRN inner = source.pred("inner"); - static final RexRN outer = source.pred("outer"); - - @Override - public RelRN before() { - return source.filter(inner).filter(outer); - } - - @Override - public RelRN after() { - return source.filter(RexRN.and(inner, outer)); - } + default String name() { + return getClass().getSimpleName(); } - record FilterIntoJoin() implements RRule { - static final RelRN left = RelRN.scan("Left", "Left_Type"); - static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("join", right); - - @Override - public RelRN before() { - var join = left.join(JoinRelType.INNER, joinCond, right); - return join.filter("outer"); - } - - @Override - public RelRN after() { - return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); - } + default ObjectNode toJson() { + return JSONSerializer.serialize(Seq.of(before().semantics(), after().semantics())); } - record FilterProjectTranspose() implements RRule { - static final RelRN source = RelRN.scan("Source", "Source_Type"); - static final RexRN proj = source.proj("proj", "Project_Type"); - - @Override - public RelRN before() { - return source.filter(proj.pred("pred")).project(proj); - } - - @Override - public RelRN after() { - return source.project(proj).filter("pred"); - } - } - - record JoinConditionPush() implements RRule { - record JoinPred(RelRN left, RelRN right) implements RexRN { - - @Override - public RexNode semantics() { - return RexRN.and(left.joinPred(bothPred(), right), left.joinField(0, right).pred(leftPred()), left.joinField(1, right).pred(rightPred())).semantics(); - } - - public String bothPred() { return "both"; } - public String leftPred() { return "left"; } - public String rightPred() { return "right"; } - - } - - static final RelRN left = RelRN.scan("Left", "Left_Type"); - static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final JoinPred joinPred = new JoinPred(left, right); - - @Override - public RelRN before() { - return left.join(JoinRelType.INNER, joinPred, right); - } - - @Override - public RelRN after() { - var leftRN = left.filter(joinPred.leftPred()); - var rightRN = right.filter(joinPred.rightPred()); - return leftRN.join(JoinRelType.INNER, joinPred.bothPred(), rightRN); - } + default void dump(String path) throws IOException { + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(new File(path), toJson()); } } + diff --git a/src/main/java/org/qed/RRuleInstance.java b/src/main/java/org/qed/RRuleInstance.java new file mode 100644 index 0000000..49a565c --- /dev/null +++ b/src/main/java/org/qed/RRuleInstance.java @@ -0,0 +1,356 @@ +package org.qed; + +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; + +public record RRuleInstance() { + record FilterIntoJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + var join = left.join(JoinRelType.INNER, joinCond, right); + return join.filter("outer"); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + } + } + + record FilterMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.pred("inner"); + static final RexRN outer = source.pred("outer"); + + @Override + public RelRN before() { + return source.filter(inner).filter(outer); + } + + @Override + public RelRN after() { + return source.filter(RexRN.and(inner, outer)); + } + } + + record FilterProjectTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.filter(proj.pred("pred")).project(proj); + } + + @Override + public RelRN after() { + return source.project(proj).filter("pred"); + } + } + + record FilterReduceFalse() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.falseLiteral()); + } + + @Override + public RelRN after() { + return source.empty(); + } + } + + record FilterReduceTrue() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.trueLiteral()); + } + + @Override + public RelRN after() { + return source; + } + } + +// record FilterSetOpTransposeRule implements RRule { +// +// } + +// record IntersectMerge implements RRule { +// +// } + + record JoinConditionPush() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final JoinPred joinPred = new JoinPred(left, right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinPred, right); + } + + @Override + public RelRN after() { + var leftRN = left.filter(joinPred.leftPred()); + var rightRN = right.filter(joinPred.rightPred()); + return leftRN.join(JoinRelType.INNER, joinPred.bothPred(), rightRN); + } + + public record JoinPred(RelRN left, RelRN right) implements RexRN { + + @Override + public RexNode semantics() { + return RexRN.and(left.joinPred(bothPred(), right), left.joinField(0, right).pred(leftPred()), + left.joinField(1, right).pred(rightPred())).semantics(); + } + + public String bothPred() {return "both";} + + public String leftPred() {return "left";} + + public String rightPred() {return "right";} + + } + } + + record JoinAddRedundantSemiJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, pred, right).join(JoinRelType.INNER, pred, right); + } + } + + // Todo: explore join types, see line 102 of JoinAssociateRule + record JoinAssociate() implements RRule { + static final RelRN a = RelRN.scan("A", "A_Type"); + static final RelRN b = RelRN.scan("B", "A_Type"); + static final RelRN c = RelRN.scan("C", "A_Type"); + static final String pred_a = "pred_a"; + static final String pred_b = "pred_b"; + static final String pred_ab = "pred_ab"; + static final String pred_c = "pred_c"; + static final String pred_ac = "pred_ac"; + static final String pred_bc = "pred_bc"; + static final String pred_abc = "pred_abc"; + + @Override + public RelRN before() { + var ab = a.join(JoinRelType.INNER, RexRN.and( + new RexRN.JoinField(0, a, b).pred(pred_a), + new RexRN.JoinField(1, a, b).pred(pred_b), + new RexRN.Pred(pred_ab, true, a.joinFields(b)) + ), b); + return ab.join(JoinRelType.INNER, RexRN.and( + new RexRN.JoinField(2, ab, c).pred(pred_c), + new RexRN.Pred(pred_ac, true, a.joinFields(b, 0, 2)), + new RexRN.Pred(pred_bc, true, a.joinFields(b, 1, 2)), + new RexRN.Pred(pred_abc, true, a.joinFields(b)) + ), c); + } + + @Override + public RelRN after() { + var bc = b.join(JoinRelType.INNER, RexRN.and( + new RexRN.JoinField(0, a, b).pred(pred_b), + new RexRN.JoinField(1, a, b).pred(pred_c), + new RexRN.Pred(pred_bc, true, a.joinFields(b)) + ), c); + return a.join(JoinRelType.INNER, RexRN.and( + new RexRN.JoinField(0, bc, c).pred(pred_a), + new RexRN.Pred(pred_ab, true, a.joinFields(b, 0, 1)), + new RexRN.Pred(pred_ac, true, a.joinFields(b, 0, 2)), + new RexRN.Pred(pred_abc, true, a.joinFields(b)) + ), c); + } + } + +// record JoinCommute() implements RRule { +// static final RelRN left = RelRN.scan("Left", "Left_Type"); +// static final RelRN right = RelRN.scan("Right", "Right_Type"); +// static final String pred = "pred"; +// +// @Override +// public RelRN before() { +// return left.join(JoinRelType.INNER, pred, right); +// } +// +// @Override +// public RelRN after() { +// return right.join(JoinRelType.INNER, new RexRN.Pred( +// pred, true, right.joinFields(left, 1, 0) +// ), left).project("?"); +// } +// } + +// record JoinExtractFilter() implements RRule { +// +// } + +// record JoinProjectTranspose() implements RRule { +// +// } + + // JoinConditionPush? +// record JoinPushExpressions() implements RRule { +// +// } + + // JoinConditionPush? +// record JoinPushTransitivePredicates() implements RRule { +// +// } + +// record JoinToSemiJoin() implements RRule { +// +// } + +// record JoinLeftUnionTranspose() implements RRule { +// +// } + +// record JoinRightUnionTranspose() implements RRule { +// +// } + + record ProjectFilterTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + var pred = new ProjectFilterTranspose.ProjectFilter(source); + return source.filter(pred).project(pred.proj(), pred.projType()); + } + + @Override + public RelRN after() { + var pred = new ProjectFilterTranspose.ProjectFilter(source); + return source.project(pred.proj(), pred.projType()).filter(pred.pred()); + } + + public record ProjectFilter(RelRN source) implements RexRN { + @Override + public RexNode semantics() { + return source.pred(pred()).proj(proj(), projType()).semantics(); + } + + public String proj() { + return "proj"; + } + + public String projType() { + return "Project_Type"; + } + + public String pred() { + return "pred"; + } + } + } + +// record ProjectJoinRemove() implements RRule { +// +// @Override +// public RelRN before() { +// return null; +// } +// +// @Override +// public RelRN after() { +// return null; +// } +// } + +// record ProjectJoinJoinRemove() implements RRule { +// +// } + +// record ProjectJoinTranspose() implements RRule { +// +// } + + record ProjectMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.proj("inner", "Inner_Type"); + static final String outer = "outer"; + static final String outerType = "Outer_Type"; + + @Override + public RelRN before() { + return source.project(inner).project(outer, outerType); + } + + @Override + public RelRN after() { + return source.project(inner.proj(outer, outerType)); + } + } + +// record ProjectSetOpTranspose() implements RRule { +// +// } + + record ProjectRemove() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.project(source.field(0)); + } + + @Override + public RelRN after() { + return null; + } + } + +// record SemiJoinFilterTranspose() implements RRule { +// +// } + +// record SemiJoinJoinTranspose() implements RRule { +// +// } + +// record SemiJoinProjectTranspose() implements RRule { +// +// } + +// record SemiJoinRemove() implements RRule { +// +// } + +// record UnionMerge() implements RRule { +// +// } + +// record UnionRemove() implements RRule { +// +// } +} + +/* + * Semantically identical cases: + * FilterExpandIsNotDistinctFrom + * FilterScan + * JoinReduceExpression + * ProjectReduceExpression + * ProjectTableScan + */ diff --git a/src/main/java/org/qed/RelRN.java b/src/main/java/org/qed/RelRN.java index 91bb217..7035e96 100644 --- a/src/main/java/org/qed/RelRN.java +++ b/src/main/java/org/qed/RelRN.java @@ -1,4 +1,4 @@ -package org.cosette; +package org.qed; import kala.collection.Seq; import kala.collection.Set; @@ -6,6 +6,7 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.util.ImmutableBitSet; +import java.util.Arrays; import java.util.stream.IntStream; public interface RelRN { @@ -23,16 +24,25 @@ default RexRN field(int ordinal) { return new RexRN.Field(ordinal, this); } + default Seq fields(int... ordinals) { + return Seq.from(Arrays.stream(ordinals).iterator()).map(this::field); + } + default Seq fields() { - return Seq.from(IntStream.range(0, semantics().getRowType().getFieldCount()).iterator()).map(this::field); + return fields(IntStream.range(0, semantics().getRowType().getFieldCount()).toArray()); } default RexRN joinField(int ordinal, RelRN right) { return new RexRN.JoinField(ordinal, this, right); } + default Seq joinFields(RelRN right, int... ordinals) { + return Seq.from(Arrays.stream(ordinals).iterator()).map(i -> joinField(i, right)); + } + default Seq joinFields(RelRN right) { - return Seq.from(IntStream.range(0, semantics().getRowType().getFieldCount() + right.semantics().getRowType().getFieldCount()).iterator()).map(i -> joinField(i, right)); + return joinFields(right, IntStream.range(0, + semantics().getRowType().getFieldCount() + right.semantics().getRowType().getFieldCount()).toArray()); } default RexRN.Pred pred(String name) { @@ -67,7 +77,7 @@ default Join join(JoinRelType ty, RexRN cond, RelRN right) { return new Join(ty, cond, this, right); } - default Join join(JoinRelType ty, String name, RelRN right) { return join(ty, joinPred(name, right), right); } + default Join join(JoinRelType ty, String name, RelRN right) {return join(ty, joinPred(name, right), right);} default Union union(boolean all, RelRN... sources) { return new Union(all, Seq.of(this).appendedAll(sources)); @@ -77,11 +87,16 @@ default Intersect intersect(boolean all, RelRN... sources) { return new Intersect(all, Seq.of(this).appendedAll(sources)); } + default Empty empty() { + return new Empty(this); + } + record Scan(String name, RelType.VarType ty, boolean unique) implements RelRN { @Override public RelNode semantics() { - var table = new CosetteTable(name, Seq.of(STR."col-\{name}"), Seq.of(ty), unique ? Set.of(ImmutableBitSet.of(0)) : Set.empty(), Set.empty()); + var table = new QedTable(name, Seq.of(STR."col-\{name}"), Seq.of(ty), unique ? + Set.of(ImmutableBitSet.of(0)) : Set.empty(), Set.empty()); return RuleBuilder.create().addTable(table).scan(name).build(); } } @@ -129,4 +144,12 @@ public RelNode semantics() { } } + record Empty(RelRN sourceType) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().values(sourceType.semantics().getRowType()).build(); + } + } + } diff --git a/src/main/java/org/qed/RexRN.java b/src/main/java/org/qed/RexRN.java index 9a798a7..873d932 100644 --- a/src/main/java/org/qed/RexRN.java +++ b/src/main/java/org/qed/RexRN.java @@ -1,4 +1,4 @@ -package org.cosette; +package org.qed; import kala.collection.Seq; import org.apache.calcite.rex.RexNode; @@ -8,10 +8,19 @@ public interface RexRN { static RelType.VarType varType(String id, boolean nullable) { return new RelType.VarType(id, nullable); } - static And and(RexRN ...sources) { + + static And and(RexRN... sources) { return new And(Seq.from(sources)); } + static False falseLiteral() { + return new False(); + } + + static True trueLiteral() { + return new True(); + } + RexNode semantics(); default Pred pred(String name) { @@ -67,11 +76,11 @@ public RexNode semantics() { } } - record Or(Seq sources) implements RexRN { + record False() implements RexRN { @Override public RexNode semantics() { - return RuleBuilder.create().or(sources.map(RexRN::semantics)); + return RuleBuilder.create().literal(false); } } @@ -83,5 +92,20 @@ public RexNode semantics() { } } + record Or(Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().or(sources.map(RexRN::semantics)); + } + } + + record True() implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().literal(true); + } + } } diff --git a/src/main/java/org/qed/SchemaGenerator.java b/src/main/java/org/qed/SchemaGenerator.java index bd47b37..484b1ef 100644 --- a/src/main/java/org/qed/SchemaGenerator.java +++ b/src/main/java/org/qed/SchemaGenerator.java @@ -227,7 +227,8 @@ public void addTable(SqlCreateTable createTable) { } } var qedTable = new QedTable(createTable.name.toString(), names.zip( - types.zip(nullabilities).map(type -> new RelType.BaseType(type.component1(), type.component2()))) + types.zip(nullabilities).map(type -> new RelType.BaseType(type.component1(), + type.component2()))) .toImmutableMap(), ImmutableSet.from(keys), ImmutableSet.from(checkConstraints)); tables.put(createTable.name.toString(), qedTable); } From 517c9b0b26ea38bba9e24fc15139dc81cbe98710 Mon Sep 17 00:00:00 2001 From: macronova Date: Thu, 16 May 2024 16:15:29 -0700 Subject: [PATCH 04/78] Checkpoint for thesis --- .../org/qed/Generated/CalciteGenerator.java | 8 +- .../java/org/qed/Generated/CalciteTester.java | 21 +- src/main/java/org/qed/JSONSerializer.java | 11 +- src/main/java/org/qed/RRule.java | 85 +++- src/main/java/org/qed/RRuleInstance.java | 87 +++-- src/main/java/org/qed/RelJSONShuttle.java | 364 ------------------ src/main/java/org/qed/RelRN.java | 48 ++- src/main/java/org/qed/RelType.java | 4 + src/main/java/org/qed/RexRN.java | 24 +- 9 files changed, 225 insertions(+), 427 deletions(-) delete mode 100644 src/main/java/org/qed/RelJSONShuttle.java diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java index 3448f73..1262b2c 100644 --- a/src/main/java/org/qed/Generated/CalciteGenerator.java +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -83,12 +83,12 @@ public Env onMatchProject(Env env, RelRN.Project project) { @Override public Env onMatchPred(Env env, RexRN.Pred pred) { - return env.symbol(pred.name(), env.current()); + return env.symbol(pred.operator().getName(), env.current()); } @Override public Env onMatchProj(Env env, RexRN.Proj proj) { - return env.symbol(proj.name(), env.current()); + return env.symbol(proj.operator().getName(), env.current()); } @Override @@ -136,7 +136,7 @@ public Env transformFilter(Env env, RelRN.Filter filter) { @Override public Env transformPred(Env env, RexRN.Pred pred) { - return env.focus(env.symbols().get(pred.name())); + return env.focus(env.symbols().get(pred.operator().getName())); } @Override @@ -145,7 +145,7 @@ public Env transformJoin(Env env, RelRN.Join join) { var right_source_transform = transform(left_source_transform, join.right()); var source_expression = right_source_transform.current(); var cond_transform = transform(right_source_transform, join.cond()); - var join_type = switch (join.ty()) { + var join_type = switch (join.ty().semantics()) { case INNER -> "JoinRelType.INNER"; case LEFT -> "JoinRelType.LEFT"; case RIGHT -> "JoinRelType.RIGHT"; diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index c419a1b..61d250b 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -1,5 +1,6 @@ package org.qed.Generated; +import com.fasterxml.jackson.databind.ObjectMapper; import kala.collection.Seq; import kala.tuple.Tuple; import org.apache.calcite.plan.RelOptRule; @@ -7,6 +8,7 @@ import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.rel.RelNode; import org.qed.JSONDeserializer; +import org.qed.JSONSerializer; import org.qed.RRule; import org.qed.RRuleInstance; @@ -19,6 +21,7 @@ public class CalciteTester { // Assuming that current working directory is the root of the project public static String genPath = "src/main/java/org/qed/Generated"; + public static String rulePath = "rules"; public static HepPlanner loadRule(RelOptRule rule) { var builder = new HepProgramBuilder().addRuleInstance(rule); @@ -26,13 +29,16 @@ public static HepPlanner loadRule(RelOptRule rule) { } public static Seq ruleList() { - var individuals = Seq.from(RRuleInstance.class.getClasses()).filter(RRule.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule) r); - var families = Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule.RRuleFamily) r); + var individuals = + Seq.from(RRuleInstance.class.getClasses()).filter(RRule.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule) r); + System.out.println(Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor)); + var families = + Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule.RRuleFamily) r); return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); } public static void verify() { - ruleList().forEachUnchecked(rule -> rule.dump(STR."dump/\{rule.name()}.json")); + ruleList().forEachUnchecked(rule -> rule.dump(STR."\{rulePath}/\{rule.name()}.json")); } public static void generate() { @@ -40,8 +46,13 @@ public static void generate() { ruleList().forEach(r -> tester.serialize(r, genPath)); } - public static void main(String[] args) { - generate(); + public static void main(String[] args) throws IOException { + var rules = new RRuleInstance.JoinAssociate(); + Files.createDirectories(Path.of(rulePath)); + for (var rule : rules.family()) { + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, STR."\{rule.name()}-\{rule.info()}.json").toFile(), rule.toJson()); + } +// generate(); // var tester = new CalciteTester(); // var builder = RuleBuilder.create(); // var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); diff --git a/src/main/java/org/qed/JSONSerializer.java b/src/main/java/org/qed/JSONSerializer.java index a3b818a..1d077f8 100644 --- a/src/main/java/org/qed/JSONSerializer.java +++ b/src/main/java/org/qed/JSONSerializer.java @@ -38,6 +38,9 @@ private static TextNode string(String s) { } private static TextNode type(RelDataType type) { + if (type instanceof RelType.VarType varType) { + return new TextNode(varType.getName()); + } return new TextNode(type.getSqlTypeName().getName()); } @@ -45,6 +48,10 @@ private static IntNode integer(int i) { return new IntNode(i); } + private static String qualifiedTableName(RelOptTable table) { + return Seq.from(table.getQualifiedName()).joinToString("."); + } + public static ObjectNode serialize(Seq relNodes) { var shuttle = new Rel(); var helps = array(relNodes.map(rel -> new TextNode(rel.explain()))); @@ -55,7 +62,7 @@ public static ObjectNode serialize(Seq relNodes) { var qedTable = table.unwrap(QedTable.class); var fields = Seq.from(table.getRowType().getFieldList()); return qedTable == null ? - object(Map.of("name", string(Seq.from(table.getQualifiedName()).joinToString(".")), "fields", + object(Map.of("name", string(qualifiedTableName(table)), "fields", array(fields.map(field -> string(field.getName()))), "types", array(fields.map(field -> type(field.getType()))), "nullable", array(fields.map(field -> bool(field.getType().isNullable()))), "key", @@ -172,7 +179,7 @@ Env lifted(int d) { } int resolve(RelOptTable table) { - var idx = tables.indexOf(table); + var idx = tables.map(JSONSerializer::qualifiedTableName).indexOf(qualifiedTableName(table)); if (idx == -1) { idx = tables.size(); tables.append(table); diff --git a/src/main/java/org/qed/RRule.java b/src/main/java/org/qed/RRule.java index d7f3606..157c0e4 100644 --- a/src/main/java/org/qed/RRule.java +++ b/src/main/java/org/qed/RRule.java @@ -2,18 +2,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; +import kala.collection.Map; import kala.collection.Seq; -import org.apache.calcite.rel.core.JoinRelType; -import org.apache.calcite.rex.RexNode; import java.io.File; import java.io.IOException; public interface RRule { - interface RRuleFamily { - Seq family(); - } - RelRN before(); RelRN after(); @@ -26,6 +21,10 @@ default String name() { return getClass().getSimpleName(); } + default String info() { + return ""; + } + default ObjectNode toJson() { return JSONSerializer.serialize(Seq.of(before().semantics(), after().semantics())); } @@ -33,5 +32,79 @@ default ObjectNode toJson() { default void dump(String path) throws IOException { new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(new File(path), toJson()); } + + interface RRuleFamily { + Seq family(); + } + + record RRuleGenerator(RRule rule, + Seq assignments) implements RRuleFamily { + @Override + public Seq family() { + return assignments.map(assignment -> new RRule() { + + @Override + public RelRN before() { + return assignment.replaceMetaRelRN(rule.before()); + } + + @Override + public RelRN after() { + return assignment.replaceMetaRelRN(rule.after()); + } + + @Override + public String name() { + return rule.name(); + } + + @Override + public String info() { + return assignment.info(); + } + }); + } + + public record MetaAssignment( + Map joinTypeAssignment) { + public RelRN.Join.JoinType replaceMetaJoinType(RelRN.Join.JoinType joinType) { + return switch (joinType) { + case RelRN.Join.JoinType.MetaJoinType metaJoinType -> joinTypeAssignment.get(metaJoinType); + default -> joinType; + }; + } + + public RexRN replaceMetaRexRN(RexRN rexRN) { + return switch (rexRN) { + case RexRN.Field field -> replaceMetaRelRN(field.source()).field(field.ordinal()); + default -> customReplaceMetaRexRN(rexRN); + }; + } + + public RelRN replaceMetaRelRN(RelRN relRN) { + return switch (relRN) { + case RelRN.Filter filter -> + replaceMetaRelRN(filter.source()).filter(replaceMetaRexRN(filter.cond())); + case RelRN.Join join -> + replaceMetaRelRN(join.left()).join(replaceMetaJoinType(join.ty()), + replaceMetaRexRN(join.cond()), replaceMetaRelRN(join.right())); + default -> customReplaceMetaRelRN(relRN); + }; + } + + public RexRN customReplaceMetaRexRN(RexRN rexRN) { + return rexRN; + } + + public RelRN customReplaceMetaRelRN(RelRN relRN) { + return relRN; + } + + public String info() { + return joinTypeAssignment.joinToString("&", (m, c) -> STR."\{m.name()}=\{c.semantics()}"); + } + + } + } } diff --git a/src/main/java/org/qed/RRuleInstance.java b/src/main/java/org/qed/RRuleInstance.java index 49a565c..a889253 100644 --- a/src/main/java/org/qed/RRuleInstance.java +++ b/src/main/java/org/qed/RRuleInstance.java @@ -1,9 +1,12 @@ package org.qed; +import kala.collection.Map; +import kala.collection.Seq; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; -public record RRuleInstance() { +public interface RRuleInstance { record FilterIntoJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); @@ -139,46 +142,62 @@ public RelRN after() { } // Todo: explore join types, see line 102 of JoinAssociateRule - record JoinAssociate() implements RRule { + record JoinAssociate() implements RRule.RRuleFamily { static final RelRN a = RelRN.scan("A", "A_Type"); - static final RelRN b = RelRN.scan("B", "A_Type"); - static final RelRN c = RelRN.scan("C", "A_Type"); - static final String pred_a = "pred_a"; - static final String pred_b = "pred_b"; + static final RelRN b = RelRN.scan("B", "B_Type"); + static final RelRN c = RelRN.scan("C", "C_Type"); static final String pred_ab = "pred_ab"; - static final String pred_c = "pred_c"; - static final String pred_ac = "pred_ac"; static final String pred_bc = "pred_bc"; - static final String pred_abc = "pred_abc"; + static final RelRN.Join.JoinType.MetaJoinType mjt_0 = new RelRN.Join.JoinType.MetaJoinType("mjt_0"); + static final RelRN.Join.JoinType.MetaJoinType mjt_1 = new RelRN.Join.JoinType.MetaJoinType("mjt_1"); + static final RelRN.Join.JoinType.MetaJoinType mjt_2 = new RelRN.Join.JoinType.MetaJoinType("mjt_2"); + static final RelRN.Join.JoinType.MetaJoinType mjt_3 = new RelRN.Join.JoinType.MetaJoinType("mjt_3"); + + static final RelRN before_ab = a.join(mjt_0, RexRN.and( + a.joinPred(pred_ab, b), + new RexRN.JoinField(1, a, b).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), b); + + static final RelRN before = before_ab.join(mjt_1, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_bc, true), before_ab.joinFields(c, 1, 2)), + new RexRN.JoinField(1, before_ab, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after_bc = b.join(mjt_2, RexRN.and( + b.joinPred(pred_bc, c), + new RexRN.JoinField(0, b, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after = a.join(mjt_3, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_ab, true), a.joinFields(after_bc, 0, 1)), + new RexRN.JoinField(1, a, after_bc).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), after_bc); + + static final RRule template = new RRule() { + @Override + public RelRN before() { + return before; + } - @Override - public RelRN before() { - var ab = a.join(JoinRelType.INNER, RexRN.and( - new RexRN.JoinField(0, a, b).pred(pred_a), - new RexRN.JoinField(1, a, b).pred(pred_b), - new RexRN.Pred(pred_ab, true, a.joinFields(b)) - ), b); - return ab.join(JoinRelType.INNER, RexRN.and( - new RexRN.JoinField(2, ab, c).pred(pred_c), - new RexRN.Pred(pred_ac, true, a.joinFields(b, 0, 2)), - new RexRN.Pred(pred_bc, true, a.joinFields(b, 1, 2)), - new RexRN.Pred(pred_abc, true, a.joinFields(b)) - ), c); + @Override + public RelRN after() { + return after; + } + + @Override + public String name() { + return JoinAssociate.class.getSimpleName(); + } + }; + + static Seq assignments() { + var joinTypes = Seq.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL).map(RelRN.Join.JoinType.ConcreteJoinType::new); + return joinTypes.flatMap(jt0 -> joinTypes.flatMap(jt1 -> joinTypes.flatMap(jt2 -> joinTypes.map(jt3 -> new RRule.RRuleGenerator.MetaAssignment(Map.of(mjt_0, jt0, mjt_1, jt1, mjt_2, jt2, mjt_3, jt3)))))); } @Override - public RelRN after() { - var bc = b.join(JoinRelType.INNER, RexRN.and( - new RexRN.JoinField(0, a, b).pred(pred_b), - new RexRN.JoinField(1, a, b).pred(pred_c), - new RexRN.Pred(pred_bc, true, a.joinFields(b)) - ), c); - return a.join(JoinRelType.INNER, RexRN.and( - new RexRN.JoinField(0, bc, c).pred(pred_a), - new RexRN.Pred(pred_ab, true, a.joinFields(b, 0, 1)), - new RexRN.Pred(pred_ac, true, a.joinFields(b, 0, 2)), - new RexRN.Pred(pred_abc, true, a.joinFields(b)) - ), c); + public Seq family() { + return new RRule.RRuleGenerator(template, assignments()).family(); } } diff --git a/src/main/java/org/qed/RelJSONShuttle.java b/src/main/java/org/qed/RelJSONShuttle.java deleted file mode 100644 index 1472f09..0000000 --- a/src/main/java/org/qed/RelJSONShuttle.java +++ /dev/null @@ -1,364 +0,0 @@ -package org.qed; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.*; -import kala.collection.Map; -import kala.collection.Seq; -import kala.collection.Set; -import kala.collection.immutable.ImmutableSeq; -import kala.control.Result; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.TableScan; -import org.apache.calcite.rel.logical.*; -import org.apache.calcite.rel.type.*; -import org.apache.calcite.rex.*; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.util.ImmutableBitSet; - -import java.io.IOException; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.List; - -public record RelJSONShuttle(Env env) { - private final static ObjectMapper mapper = new ObjectMapper(); - - private static ArrayNode array(Seq objs) { - return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); - } - - - private static Result, String> array(JsonNode jsonNode, String field) { - var arr = jsonNode.get(field); - if (arr == null || !arr.isArray()) { - return Result.err(String.format("Missing array field %s in:\n%s", field, jsonNode.toPrettyString())); - } - return Result.ok(ImmutableSeq.from(arr.elements())); - } - - private static ObjectNode object(Map fields) { - return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); - } - - private static Result object(JsonNode jsonNode, String field) { - var obj = jsonNode.get(field); - if (obj == null) { - return Result.err(String.format("Missing object field %s in:\n%s", field, jsonNode.toPrettyString())); - } - return Result.ok(obj); - } - - private static BooleanNode bool(boolean b) { - return b ? BooleanNode.TRUE : BooleanNode.FALSE; - } - - private static T unwrap(Result res) throws Exception { - if (res.isErr()) { - throw new Exception(res.getErr()); - } - return res.get(); - } - - public static void main(String[] args) throws IOException { - var res = RelJSONShuttle.deserializeFromJson(Paths.get("ElevatedRules/filterProjectTranspose.json")); - if (res.isErr()) { - System.out.println(res.getErr()); - } else { - res.get().forEach(r -> System.out.println(r.explain())); - } - } - - public static void serializeToJson(List relNodes, Path path) throws IOException { - var shuttle = new RelJSONShuttle(Env.empty()); - var helps = array(Seq.from(relNodes).map(rel -> new TextNode(rel.explain()))); - var queries = array(Seq.from(relNodes).map(shuttle::serialize)); - - var tables = shuttle.env.tables(); - var schemas = array(tables.map(table -> object(Map.of( - "name", new TextNode(table.getName()), - "fields", array(table.getColumnNames().map(TextNode::new)), - "types", array(table.getColumnTypes().map(type -> new TextNode(type.toString()))), - "nullable", array(table.getColumnTypes().map(RelDataType::isNullable).map(RelJSONShuttle::bool)), - "key", array(Seq.from(table.getKeys().map(key -> array(Seq.from(key).map(IntNode::new))))), - "guaranteed", array(table.getConstraints() - .map(check -> new RexJSONVisitor(shuttle.env.advanced(table.getColumnNames().size())).serialize(check)).toImmutableSeq()) - )))); - - var main = object(Map.of("schemas", schemas, "queries", queries, "help", helps)); - mapper.writerWithDefaultPrettyPrinter().writeValue(path.toFile(), main); - } - - public static Result, String> deserializeFromJson(Path path) throws IOException { - var node = mapper.readTree(path.toFile()); - var env = Env.empty(); - var tables = array(node, "schemas").flatMap(schemas -> { - var collected = ImmutableSeq.empty(); - for (var schema : schemas) { - try { - var tys = unwrap(array(schema, "types")); - var nbs = unwrap(array(schema, "nullable")); - var nm = unwrap(object(schema, "name")); - var fds = unwrap(array(schema, "fields")).map(JsonNode::asText); - var kys = unwrap(array(schema, "key")); - var kgs = Set.from(kys.map(kg -> ImmutableBitSet.of(Seq.from(kg.elements()).map(JsonNode::asInt)))); - if (tys.size() != nbs.size()) { - return Result.err("Expecting corresponding types and nullabilities"); - } - var sts = tys.zip(nbs).map(tn -> (RelDataType) RelType.fromString(tn.component1().asText(), - tn.component2().asBoolean())); - collected = collected.appended(new QedTable(nm.asText(), fds, sts, kgs, Set.empty())); - } catch (Exception e) { - return Result.err( - String.format("Broken table schemas: %s in\n%s", e.getMessage(), schema.toPrettyString())); - } - } - return Result.ok(collected); - }); - if (tables.isErr()) { - return Result.err(tables.getErr()); - } - env.tables().appendAll(tables.get()); - var queries = array(node, "queries"); - if (queries.isErr()) { - return Result.err(queries.getErr()); - } - var shuttle = new RelJSONShuttle(env); - return queries.get().map(q -> { - var builder = RuleBuilder.create(); - tables.get().forEach(builder::addTable); - return shuttle.deserialize(builder, q); - }).foldLeft(Result.ok(ImmutableSeq.empty()), (qs, qb) -> qs.flatMap(s -> qb.map(b -> s.appended(b.build())))); - } - - public JsonNode serialize(RelNode rel) { - return switch (rel) { - case TableScan scan -> - object(Map.of("scan", new IntNode(env.resolve(scan.getTable().unwrap(QedTable.class))))); - case LogicalValues values -> { - var visitor = new RexJSONVisitor(env); - var schema = array(Seq.from(values.getRowType().getFieldList()) - .map(field -> new TextNode(field.getType().toString()))); - var records = array(Seq.from(values.getTuples()) - .map(tuple -> array(Seq.from(tuple).map(visitor::serialize)))); - yield object(Map.of("values", object(Map.of("schema", schema, "content", records)))); - } - case LogicalFilter filter -> { - var visitor = new RexJSONVisitor(env.advanced(filter.getInput().getRowType().getFieldCount()) - .recorded(filter.getVariablesSet())); - yield object(Map.of("filter", - object(Map.of("condition", visitor.serialize(filter.getCondition()), "source", - serialize(filter.getInput()))))); - } - case LogicalProject project -> { - var visitor = new RexJSONVisitor(env.advanced(project.getInput().getRowType().getFieldCount()) - .recorded(project.getVariablesSet())); - var targets = array(Seq.from(project.getProjects()).map(visitor::serialize)); - yield object( - Map.of("project", object(Map.of("target", targets, "source", serialize(project.getInput()))))); - } - case LogicalJoin join -> { - var left = join.getLeft(); - var right = join.getRight(); - var visitor = new RexJSONVisitor( - env.advanced(left.getRowType().getFieldCount() + right.getRowType().getFieldCount()) - .recorded(join.getVariablesSet())); - yield object(Map.of("join", - object(Map.of("kind", new TextNode(join.getJoinType().toString()), "condition", - visitor.serialize(join.getCondition()), "left", serialize(left), "right", - serialize(right))))); - } - case LogicalCorrelate correlate -> { - var rightShuttle = new RelJSONShuttle(env.advanced(correlate.getLeft().getRowType().getFieldCount()) - .recorded(correlate.getVariablesSet()).advanced(0)); - yield object(Map.of("correlate", - array(Seq.of(serialize(correlate.getLeft()), rightShuttle.serialize(correlate.getRight()))))); - } - case LogicalAggregate aggregate -> { - var groupCount = aggregate.getGroupCount(); - var level = env.base(); - var types = Seq.from(aggregate.getInput().getRowType().getFieldList()) - .map(type -> new TextNode(type.getType().toString())); - var keyCols = array(Seq.from(aggregate.getGroupSet()) - .map(key -> object(Map.of("column", new IntNode(level + key), "type", types.get(key))))); - var keys = object(Map.of("project", - object(Map.of("target", keyCols, "source", serialize(aggregate.getInput()))))); - var conditions = array(Seq.from(aggregate.getGroupSet()).mapIndexed((i, key) -> { - var type = types.get(key); - var leftCol = object(Map.of("column", new IntNode(level + i), "type", type)); - var rightCol = object(Map.of("column", new IntNode(level + groupCount + key), "type", type)); - return object( - Map.of("operator", new TextNode("<=>"), "operand", array(Seq.of(leftCol, rightCol)), "type", - new TextNode("BOOLEAN"))); - })); - var condition = object(Map.of("operator", new TextNode("AND"), "operand", conditions, "type", - new TextNode("BOOLEAN"))); - var aggs = array(Seq.from(aggregate.getAggCallList()).map(call -> object( - Map.of("operator", new TextNode(call.getAggregation().getName()), "operand", - array(Seq.from(call.getArgList()).map(target -> object( - Map.of("column", new IntNode(level + groupCount + target), "type", - types.get(target))))), "distinct", bool(call.isDistinct()), - "ignoreNulls", bool(call.ignoreNulls()), "type", - new TextNode(call.getType().toString()))))); - var aggregated = object(Map.of("aggregate", object(Map.of("function", aggs, "source", - object(Map.of("filter", object(Map.of("condition", condition, "source", - new RelJSONShuttle(env.lifted(groupCount)).serialize(aggregate.getInput()))))))))); - yield object(Map.of("distinct", object(Map.of("correlate", array(Seq.of(keys, aggregated)))))); - } - case LogicalUnion union -> { - var result = object(Map.of("union", array(Seq.from(union.getInputs()).map(this::serialize)))); - yield union.all ? result : object(Map.of("distinct", result)); - } - case LogicalIntersect intersect when !intersect.all -> - object(Map.of("intersect", array(Seq.from(intersect.getInputs()).map(this::serialize)))); - case LogicalMinus minus when !minus.all -> - object(Map.of("except", array(Seq.from(minus.getInputs()).map(this::serialize)))); - case LogicalSort sort -> { - var types = Seq.from(sort.getInput().getRowType().getFieldList()) - .map(type -> new TextNode(type.getType().toString())); - var collations = array(Seq.from(sort.collation.getFieldCollations()).map(collation -> { - var index = collation.getFieldIndex(); - return array(Seq.of(new IntNode(index), types.get(index), new TextNode(collation.shortString()))); - })); - var args = object(Map.of("collation", collations, "source", serialize(sort.getInput()))); - var visitor = new RexJSONVisitor(env.advanced(sort.getInput().getRowType().getFieldCount())); - if (sort.offset != null) { - args.set("offset", visitor.serialize(sort.offset)); - } - if (sort.fetch != null) { - args.set("limit", visitor.serialize(sort.fetch)); - } - yield object(Map.of("sort", args)); - } - default -> throw new RuntimeException("Not implemented: " + rel.getRelTypeName()); - }; - } - - public Result deserialize(RuleBuilder builder, JsonNode jsonNode) { - var entry = jsonNode.fields().next(); - var kind = entry.getKey(); - var content = entry.getValue(); - return switch (kind) { - case String k when k.equals("scan") -> { - if (content.isInt() && 0 <= content.asInt() && content.asInt() < env.tables().size()) { - builder.scan(env.tables().get(content.asInt()).getName()); - yield Result.ok(builder); - } - yield Result.err(String.format("Missing table with index %s", content.toPrettyString())); - } - case String k when k.equals("values") -> { - try { - var et = unwrap(array(content, "schema")); - var rt = new RelRecordType(StructKind.FULLY_QUALIFIED, et.mapIndexed( - (i, t) -> (RelDataTypeField) new RelDataTypeFieldImpl(String.format("VALUES-%s", i), i, - RelType.fromString(t.asText(), true))).asJava()); - var vs = unwrap(array(content, "content")); - var vals = ImmutableSeq.>empty(); - for (var v : vs) { - var val = ImmutableSeq.empty(); - if (!v.isArray()) { - yield Result.err("Expecting tuple (JSON list) as value"); - } - for (var jl : Seq.from(v.elements())) { - var l = unwrap(new RexJSONVisitor(env).deserialize(builder, jl)); - if (l instanceof RexLiteral) { - val = val.appended((RexLiteral) l); - } else { - yield Result.err("Expecting literal expression"); - } - } - vals = vals.appended(val.asJava()); - } - builder.values(vals.asJava(), rt); - yield Result.ok(builder); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("filter") -> { - try { - var cond = unwrap(object(content, "condition")); - var source = unwrap(object(content, "source")); - var bs = unwrap(deserialize(builder, source)); - var c = unwrap(new RexJSONVisitor(env).deserialize(builder, cond)); - bs.filter(c); - yield Result.ok(bs); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("project") -> { - try { - var target = unwrap(array(content, "target")); - var source = unwrap(object(content, "source")); - var bs = unwrap(deserialize(builder, source)); - var ps = target.mapChecked(t -> unwrap(new RexJSONVisitor(env).deserialize(builder, t))); - bs.project(ps); - yield Result.ok(bs); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("join") -> Result.err("Not implemented yet"); - case String k when k.equals("correlate") -> Result.err("Not implemented yet"); - default -> Result.err(String.format("Unrecognized node:\n%s", jsonNode.toPrettyString())); - }; - } - - public record RexJSONVisitor(Env env) { - public JsonNode serialize(RexNode rex) { - return switch (rex) { - case RexInputRef inputRef -> - object(Map.of("column", new IntNode(inputRef.getIndex() + env.base()), "type", - new TextNode(inputRef.getType().toString()))); - case RexLiteral literal -> object(Map.of("operator", - new TextNode(literal.getValue() == null ? "NULL" : literal.getValue().toString()), "operand", - array(Seq.empty()), "type", new TextNode(literal.getType().toString()))); - case RexSubQuery subQuery -> - object(Map.of("operator", new TextNode(subQuery.getOperator().toString()), "operand", - array(Seq.from(subQuery.getOperands()).map(this::serialize)), "query", - new RelJSONShuttle(env.advanced(0)).serialize(subQuery.rel), "type", - new TextNode(subQuery.getType().toString()))); - case RexCall call -> object(Map.of("operator", new TextNode(call.getOperator().toString()), "operand", - array(Seq.from(call.getOperands()).map(this::serialize)), "type", - new TextNode(call.getType().toString()))); - case RexFieldAccess fieldAccess -> object(Map.of("column", new IntNode( - fieldAccess.getField().getIndex() + - env.resolve(((RexCorrelVariable) fieldAccess.getReferenceExpr()).id)), "type", - new TextNode(fieldAccess.getType().toString()))); - default -> throw new RuntimeException("Not implemented: " + rex.getKind()); - }; - } - - public Result deserialize(RuleBuilder builder, JsonNode jsonNode) { - if (jsonNode.has("column") && jsonNode.get("column").isInt()) { - // WARNING: THIS IS WRONG! NO ENVIRONMENT CONSIDERED! - return Result.ok(builder.field(jsonNode.get("column").asInt())); - } else if (jsonNode.has("operator") && jsonNode.get("operator").isTextual()) { - var op = jsonNode.get("operator").asText(); - try { - var args = unwrap(array(jsonNode, "operand")); - var ty = RelType.fromString(unwrap(object(jsonNode, "type")).asText(), true); - if (args.isEmpty()) { - return Result.ok(RexLiteral.fromJdbcString(ty, ty.getSqlTypeName(), op)); - } else { - var fields = args.mapChecked(expr -> unwrap(deserialize(builder, expr))); - for (var refl : Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) - .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && - java.lang.reflect.Modifier.isStatic(f.getModifiers()))) { - var mist = refl.get(null); - if (mist instanceof SqlOperator sqlOperator && sqlOperator.getName().equals(op)) { - return Result.ok(builder.call(sqlOperator, fields)); - } - } - return Result.ok(builder.call(builder.genericProjectionOp(op, ty), fields)); - } - } catch (Exception e) { - return Result.err(e.getMessage()); - } - } - return Result.err(String.format("Unrecognized node:\n%s", jsonNode.toPrettyString())); - } - } -} \ No newline at end of file diff --git a/src/main/java/org/qed/RelRN.java b/src/main/java/org/qed/RelRN.java index 7035e96..bdc364f 100644 --- a/src/main/java/org/qed/RelRN.java +++ b/src/main/java/org/qed/RelRN.java @@ -4,6 +4,7 @@ import kala.collection.Set; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.util.ImmutableBitSet; import java.util.Arrays; @@ -45,16 +46,28 @@ default Seq joinFields(RelRN right) { semantics().getRowType().getFieldCount() + right.semantics().getRowType().getFieldCount()).toArray()); } + default RexRN.Pred pred(SqlOperator op) { + return new RexRN.Pred(op, fields()); + } + default RexRN.Pred pred(String name) { - return new RexRN.Pred(name, true, fields()); + return pred(RuleBuilder.create().genericPredicateOp(name, true)); + } + + default RexRN.Pred joinPred(SqlOperator op, RelRN right) { + return new RexRN.Pred(op, joinFields(right)); } default RexRN.Pred joinPred(String name, RelRN right) { - return new RexRN.Pred(name, true, joinFields(right)); + return joinPred(RuleBuilder.create().genericPredicateOp(name, true), right); + } + + default RexRN.Proj proj(SqlOperator op) { + return new RexRN.Proj(op, fields()); } default RexRN.Proj proj(String name, String type_name) { - return new RexRN.Proj(name, type_name, true, fields()); + return proj(RuleBuilder.create().genericProjectionOp(name, new RelType.VarType(type_name, true))); } default Filter filter(RexRN cond) { @@ -73,10 +86,14 @@ default Project project(String name, String type_name) { return project(proj(name, type_name)); } - default Join join(JoinRelType ty, RexRN cond, RelRN right) { + default Join join(Join.JoinType ty, RexRN cond, RelRN right) { return new Join(ty, cond, this, right); } + default Join join(JoinRelType ty, RexRN cond, RelRN right) { + return join(new Join.JoinType.ConcreteJoinType(ty), cond, right); + } + default Join join(JoinRelType ty, String name, RelRN right) {return join(ty, joinPred(name, right), right);} default Union union(boolean all, RelRN... sources) { @@ -115,10 +132,11 @@ public RelNode semantics() { } } - record Join(JoinRelType ty, RexRN cond, RelRN left, RelRN right) implements RelRN { + record Join(Join.JoinType ty, RexRN cond, RelRN left, RelRN right) implements RelRN { @Override public RelNode semantics() { - return RuleBuilder.create().push(left.semantics()).push(right.semantics()).join(ty, cond.semantics()).build(); + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).join(ty.semantics(), + cond.semantics()).build(); } @Override @@ -126,6 +144,24 @@ public RexRN field(int ordinal) { return new RexRN.JoinField(ordinal, left, right); } + public interface JoinType { + JoinRelType semantics(); + + record ConcreteJoinType(JoinRelType type) implements JoinType { + @Override + public JoinRelType semantics() { + return type; + } + } + + record MetaJoinType(String name) implements JoinType { + @Override + public JoinRelType semantics() { + return JoinRelType.INNER; + } + } + } + } record Union(boolean all, Seq sources) implements RelRN { diff --git a/src/main/java/org/qed/RelType.java b/src/main/java/org/qed/RelType.java index b1619cd..d2b4328 100644 --- a/src/main/java/org/qed/RelType.java +++ b/src/main/java/org/qed/RelType.java @@ -41,6 +41,10 @@ protected void generateTypeString(StringBuilder sb, boolean withDetail) { public boolean isNullable() { return nullable; } + + public String getName() { + return "INTEGER"; + } } final class BaseType extends BasicSqlType implements RelType { diff --git a/src/main/java/org/qed/RexRN.java b/src/main/java/org/qed/RexRN.java index 873d932..81d174f 100644 --- a/src/main/java/org/qed/RexRN.java +++ b/src/main/java/org/qed/RexRN.java @@ -2,6 +2,7 @@ import kala.collection.Seq; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; public interface RexRN { @@ -23,14 +24,23 @@ static True trueLiteral() { RexNode semantics(); + default Pred pred(SqlOperator op) { + return new Pred(op, Seq.of(this)); + } + default Pred pred(String name) { - return new Pred(name, true, Seq.of(this)); + return pred(RuleBuilder.create().genericPredicateOp(name, true)); + } + + default Proj proj(SqlOperator op) { + return new Proj(op, Seq.of(this)); } default Proj proj(String name, String type_name) { - return new Proj(name, type_name, true, Seq.of(this)); + return proj(RuleBuilder.create().genericProjectionOp(name, new RelType.VarType(type_name, true))); } + record Field(int ordinal, RelRN source) implements RexRN { @Override @@ -49,21 +59,23 @@ public RexNode semantics() { } } - record Pred(String name, boolean nullable, Seq sources) implements RexRN { + record Pred(SqlOperator operator, Seq sources) implements RexRN { @Override public RexNode semantics() { var builder = RuleBuilder.create(); - return builder.call(builder.genericPredicateOp(name, nullable), sources.map(RexRN::semantics)); +// builder.genericPredicateOp(name, nullable) + return builder.call(operator, sources.map(RexRN::semantics)); } } - record Proj(String name, String type_name, boolean nullable, Seq sources) implements RexRN { + record Proj(SqlOperator operator, Seq sources) implements RexRN { @Override public RexNode semantics() { var builder = RuleBuilder.create(); - return builder.call(builder.genericProjectionOp(name, varType(type_name, nullable)), +// builder.genericProjectionOp(name, varType(type_name, nullable)) + return builder.call(operator, sources.map(RexRN::semantics)); } } From 5a56caf2a3c7ee7cca00c2a53cdc4147affc75ac Mon Sep 17 00:00:00 2001 From: yushinliang Date: Sun, 30 Mar 2025 22:52:57 -0700 Subject: [PATCH 05/78] fill out rules ProjectJoinTranspose, SemiJoinFilterTranspose, SemiJoinJoinTranspose, SemiJoinProjectTranspose, SemiJoinRemove --- .vscode/settings.json | 3 + src/main/java/org/qed/RRuleInstance.java | 94 ++++++++++++++++++++---- 2 files changed, 82 insertions(+), 15 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..c5f3f6b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "java.configuration.updateBuildConfiguration": "interactive" +} \ No newline at end of file diff --git a/src/main/java/org/qed/RRuleInstance.java b/src/main/java/org/qed/RRuleInstance.java index a889253..9e5bc30 100644 --- a/src/main/java/org/qed/RRuleInstance.java +++ b/src/main/java/org/qed/RRuleInstance.java @@ -301,9 +301,22 @@ public String pred() { // // } -// record ProjectJoinTranspose() implements RRule { -// -// } + record ProjectJoinTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN proj = left.proj("proj", "Project_Type"); + static final String joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.INNER, joinCond, right); + } + } record ProjectMerge() implements RRule { static final RelRN source = RelRN.scan("Source", "Source_Type"); @@ -340,21 +353,72 @@ public RelRN after() { } } -// record SemiJoinFilterTranspose() implements RRule { -// -// } + record SemiJoinFilterTranspose() implements RRule { + static final RelRN left = RelRN.scan("left", "Left_Type"); + static final RelRN right = RelRN.scan("right", "Right_Type"); + static final RexRN pred = left.pred("pred"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, joinCond, right).filter(pred); + } + + @Override + public RelRN after() { + return left.filter(pred).join(JoinRelType.SEMI, joinCond, right); + } + } -// record SemiJoinJoinTranspose() implements RRule { -// -// } + record SemiJoinJoinTranspose() implements RRule { + static final RelRN r = RelRN.scan("R", "R_Type"); + static final RelRN s = RelRN.scan("S", "S_Type"); + static final RelRN t = RelRN.scan("T", "T_Type"); + static final RexRN semiCond = r.joinPred("semi", s); + static final RexRN joinCond = r.joinPred("join", t); -// record SemiJoinProjectTranspose() implements RRule { -// -// } + @Override + public RelRN before() { + return r.join(JoinRelType.INNER, joinCond, t).join(JoinRelType.SEMI, semiCond, s); + } -// record SemiJoinRemove() implements RRule { -// -// } + @Override + public RelRN after() { + return r.join(JoinRelType.SEMI, semiCond, s).join(JoinRelType.INNER, joinCond, t); + } + } + + record SemiJoinProjectTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN proj = left.proj("proj", "Project_Type"); + static final RexRN semiCond = left.joinPred("semi", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, semiCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.SEMI, semiCond, right); + } + } + + record SemiJoinRemove() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, RexRN.trueLiteral(), right); + } + + @Override + public RelRN after() { + return left; + } + } // record UnionMerge() implements RRule { // From be547496c2a939b7773919c8ed94597d3ba6cd48 Mon Sep 17 00:00:00 2001 From: yushinliang Date: Tue, 1 Apr 2025 16:24:41 -0700 Subject: [PATCH 06/78] try generate Calcite code with ProjectJoinTranspose --- .../java/org/qed/Generated/CalciteTester.java | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 61d250b..1c4f92a 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -3,6 +3,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import kala.collection.Seq; import kala.tuple.Tuple; + +import org.apache.calcite.jdbc.CalcitePrepare.SparkHandler.RuleSetBuilder; import org.apache.calcite.plan.RelOptRule; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgramBuilder; @@ -52,6 +54,34 @@ public static void main(String[] args) throws IOException { for (var rule : rules.family()) { new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, STR."\{rule.name()}-\{rule.info()}.json").toFile(), rule.toJson()); } + + + var r = new RRuleInstance.ProjectJoinTranspose(); + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, STR."\{r.name()}-\{r.info()}.json").toFile(), r.toJson()); + + generate(); + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(table); + var before = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .project(builder.call(builder.genericProjectionOp("proj", RelType.fromString("INTEGER", true)), builder.fields(0))) + .build(); + var leftProjected = builder.scan(table.getName()) + .project(builder.call(builder.genericProjectionOp("proj", RelType.fromString("INTEGER", true)), builder.fields(0))) + .build(); + var after = builder.push(leftProjected) + .push(builder.scan(table.getName())) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + var runner = loadRule(ProjectJoinTranspose.Config.DEFAULT.toRule()); + var tester = new CalciteTester(); + tester.verify(runner, before, after); + + + // generate(); // var tester = new CalciteTester(); // var builder = RuleBuilder.create(); From e8e58f05a92b8367a786267e3266fe6b241128f7 Mon Sep 17 00:00:00 2001 From: yushinliang Date: Wed, 2 Apr 2025 01:22:45 -0700 Subject: [PATCH 07/78] try generating calcite code for rules --- pom.xml | 2 - .../java/org/qed/Generated/CalciteTester.java | 60 ++++++++++++++++++- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/pom.xml b/pom.xml index 71ae84d..e0f9b76 100644 --- a/pom.xml +++ b/pom.xml @@ -105,8 +105,6 @@ io.github cvc5 1.1.1 - system - ${env.CVC5_JAVA} diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 1c4f92a..7e54ccc 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -54,7 +54,6 @@ public static void main(String[] args) throws IOException { for (var rule : rules.family()) { new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, STR."\{rule.name()}-\{rule.info()}.json").toFile(), rule.toJson()); } - var r = new RRuleInstance.ProjectJoinTranspose(); new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, STR."\{r.name()}-\{r.info()}.json").toFile(), r.toJson()); @@ -64,6 +63,7 @@ public static void main(String[] args) throws IOException { var builder = RuleBuilder.create(); var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); builder.addTable(table); + var before = builder.scan(table.getName()) .scan(table.getName()) .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) @@ -77,9 +77,63 @@ public static void main(String[] args) throws IOException { .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) .build(); var runner = loadRule(ProjectJoinTranspose.Config.DEFAULT.toRule()); - var tester = new CalciteTester(); tester.verify(runner, before, after); - + + before = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) + .build(); + var leftFiltered = builder.scan(table.getName()).filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) + after = builder.push(leftFiltered) + .push(builder.scan(table.getName())) + .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + runner = loadRule(SemiJoinFilterTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + var semiFirst = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + before = builder.push(semiFirst) + .push(builder.scan(table.getName())) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + var innerFirst = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + after = builder.push(innerFirst) + .push(builder.scan(table.getName())) + .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + runner = loadRule(SemiJoinJoinTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + before = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .project(builder.call(builder.genericProjectionOp("proj", RelType.fromString("INTEGER", true)), builder.fields(0))) + .build(); + leftProjected = builder.scan(table.getName()) + .project(builder.call(builder.genericProjectionOp("proj", RelType.fromString("INTEGER", true)), builder.fields(0))) + .build(); + after = builder.push(leftProjected) + .push(builder.scan(table.getName())) + .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + runner = loadRule(SemiJoinProjectTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + before = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .build(); + after = builder.scan(table.getName()).build(); + runner = loadRule(SemiJoinRemove.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + // generate(); From 51d00cd4ae887578691a26c2ba68b981f3301e7c Mon Sep 17 00:00:00 2001 From: wkaiz Date: Mon, 24 Feb 2025 17:32:13 -0800 Subject: [PATCH 08/78] changing string templates back to strings --- src/main/java/org/qed/CodeGenerator.java | 2 +- .../org/qed/Generated/CalciteGenerator.java | 46 +++++------ .../java/org/qed/Generated/CalciteTester.java | 79 +++++++++---------- src/main/java/org/qed/RRule.java | 9 ++- src/main/java/org/qed/RelRN.java | 2 +- 5 files changed, 71 insertions(+), 67 deletions(-) diff --git a/src/main/java/org/qed/CodeGenerator.java b/src/main/java/org/qed/CodeGenerator.java index 3fde6a0..fa0d1c6 100644 --- a/src/main/java/org/qed/CodeGenerator.java +++ b/src/main/java/org/qed/CodeGenerator.java @@ -3,7 +3,7 @@ public interface CodeGenerator { default String unimplemented(String context, Object object) { - return STR."<--\{context}\{object.getClass().getName()}-->"; + return "<--" + context + object.getClass().getName() + "-->"; } default E unimplementedOnMatch(E env, Object object) { diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java index 1262b2c..47581c6 100644 --- a/src/main/java/org/qed/Generated/CalciteGenerator.java +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -25,7 +25,7 @@ public Env preTransform(Env env) { @Override public Env postTransform(Env env) { - return env.state(STR."call.transformTo(\{env.current()}.build());"); + return env.state("call.transformTo(" + env.current() + ".build());"); } @Override @@ -36,8 +36,8 @@ public String translate(String name, Env onMatch, Env transform) { builder.append("import org.apache.calcite.rel.RelNode;\n"); builder.append("import org.apache.calcite.rel.core.JoinRelType;\n"); builder.append("import org.apache.calcite.rel.logical.*;\n\n"); - builder.append(STR."public class \{name} extends RelRule<\{name}.Config> {\n"); - builder.append(STR."\tprotected \{name}(Config config) {\n"); + builder.append("public class " + name + " extends RelRule<" + name + ".Config> {\n"); + builder.append("\tprotected " + name + "(Config config) {\n"); builder.append("\t\tsuper(config);\n"); builder.append("\t}\n\n"); builder.append("\t@Override\n\tpublic void onMatch(RelOptRuleCall call) {\n"); @@ -45,14 +45,14 @@ public String translate(String name, Env onMatch, Env transform) { builder.append("\t}\n\n"); builder.append("\tpublic interface Config extends EmptyConfig {\n"); builder.append("\t\tConfig DEFAULT = new Config() {};\n\n"); - builder.append(STR."\t\t@Override\n\t\tdefault \{name} toRule() {\n"); - builder.append(STR."\t\t\treturn new \{name}(this);\n"); + builder.append("\t\t@Override\n\t\tdefault " + name + " toRule() {\n"); + builder.append("\t\t\treturn new " + name + "(this);\n"); builder.append("\t\t}\n\n"); builder.append("\t\t@Override\n\t\tdefault String description() {\n"); - builder.append(STR."\t\t\treturn \"\{name}\";\n"); + builder.append("\t\t\treturn \"" + name + "\";\n"); builder.append("\t\t}\n\n"); builder.append("\t\t@Override\n\t\tdefault RelRule.OperandTransform operandSupplier() {\n"); - builder.append(STR."\t\t\treturn \{onMatch.skeleton()};\n"); + builder.append("\t\t\treturn " + onMatch.skeleton() + ";\n"); builder.append("\t\t}\n\n"); builder.append("\t}\n"); builder.append("}\n"); @@ -67,8 +67,8 @@ public Env onMatchScan(Env env, RelRN.Scan scan) { @Override public Env onMatchFilter(Env env, RelRN.Filter filter) { var source_match = onMatch(env.next(), filter.source()); - var operator_match = source_match.grow(STR."operand(LogicalFilter.class).oneInput(\{source_match.skeleton()})"); - var condition_match = operator_match.focus(STR."((LogicalFilter) \{env.current()}).getCondition()"); + var operator_match = source_match.grow("operand(LogicalFilter.class).oneInput(" + source_match.skeleton() + ")"); + var condition_match = operator_match.focus("((LogicalFilter) " + env.current() + ").getCondition()"); return onMatch(condition_match, filter.cond()); } @@ -76,8 +76,8 @@ public Env onMatchFilter(Env env, RelRN.Filter filter) { public Env onMatchProject(Env env, RelRN.Project project) { var source_match = onMatch(env.next(), project.source()); var operator_match = - source_match.grow(STR."operand(LogicalProject.class).oneInput(\{source_match.skeleton()})"); - var map_match = operator_match.focus(STR."((LogicalProject) \{env.current()}).getProjects()"); + source_match.grow("operand(LogicalProject.class).oneInput(" + source_match.skeleton() + ")"); + var map_match = operator_match.focus("((LogicalProject) " + env.current() + ").getProjects()"); return onMatch(map_match, project.map()); } @@ -93,21 +93,21 @@ public Env onMatchProj(Env env, RexRN.Proj proj) { @Override public Env onMatchJoin(Env env, RelRN.Join join) { - var current_join = STR."((LogicalJoin) \{env.current()})"; + var current_join = "((LogicalJoin) " + env.current() + ")"; // STR."\{join_env.current()}.getJoinType()" var left_source_env = env.next(); var left_match_env = onMatch(left_source_env, join.left()); var right_source_env = left_match_env.next(); var right_match_env = onMatch(right_source_env, join.right()); var operator_match = - right_match_env.grow(STR."operand(LogicalJoin.class).inputs(\{left_match_env.skeleton()}, \{right_match_env.skeleton()})"); - var cond_source_env = operator_match.focus(STR."\{current_join}.getCondition()"); + right_match_env.grow("operand(LogicalJoin.class).inputs(" + left_match_env.skeleton() + ", " + right_match_env.skeleton() + ")"); + var cond_source_env = operator_match.focus(current_join + ".getCondition()"); return onMatch(cond_source_env, join.cond()); } @Override public Env transformScan(Env env, RelRN.Scan scan) { - return env.focus(STR."\{env.current()}.push(\{env.symbols().get(scan.name())})"); + return env.focus(env.current() + ".push(" + env.symbols().get(scan.name()) + ")"); } // @Override @@ -131,7 +131,7 @@ public Env transformFilter(Env env, RelRN.Filter filter) { var source_transform = transform(env, filter.source()); var source_expression = source_transform.current(); var cond_transform = transform(source_transform, filter.cond()); - return cond_transform.focus(STR."\{source_expression}.filter(\{cond_transform.current()})"); + return cond_transform.focus(source_expression + ".filter(" + cond_transform.current() + ")"); } @Override @@ -153,7 +153,7 @@ public Env transformJoin(Env env, RelRN.Join join) { case SEMI -> "JoinRelType.SEMI"; case ANTI -> "JoinRelType.ANTI"; }; - return cond_transform.focus(STR."\{source_expression}.join(\{join_type}, \{cond_transform.current()})"); + return cond_transform.focus(source_expression + ".join(" + join_type + ", " + cond_transform.current() + ")"); } @Override @@ -165,7 +165,7 @@ public Env transformAnd(Env env, RexRN.And and) { operands = operands.appended(source_transform.current()); source_transform = source_transform.focus(env.current()); } - return source_transform.focus(STR."\{env.current()}.and(\{operands.joinToString(", ")})"); + return source_transform.focus(env.current() + ".and(" + operands.joinToString(", ") + ")"); } public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq statements, @@ -176,7 +176,7 @@ public static Env empty() { } public Env next() { - return new Env(varId, rel + 1, STR."call.rel(\{rel + 1})", skeleton, statements, symbols); + return new Env(varId, rel + 1, "call.rel(" + (rel + 1) + ")", skeleton, statements, symbols); } public Env focus(String target) { @@ -192,13 +192,13 @@ public Env symbol(String symbol, String expression) { } public Tuple2 declare(String expression) { - var name = STR."var_\{varId.getAndIncrement()}"; - return Tuple.of(name, state(STR."var \{name} = \{expression};")); + var name = "var_" + varId.getAndIncrement(); + return Tuple.of(name, state("var " + name + " = " + expression + ";")); } public Env grow(String requirement) { - var vn = STR."s_\{varId.getAndIncrement()}"; - return new Env(varId, rel, current, STR."\{vn} -> \{vn}.\{requirement}", statements, symbols); + var vn = "s_" + varId.getAndIncrement(); + return new Env(varId, rel, current, vn + " -> " + vn + "." + requirement, statements, symbols); } } } diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 61d250b..2b612b8 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -7,10 +7,9 @@ import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.rel.RelNode; -import org.qed.JSONDeserializer; -import org.qed.JSONSerializer; -import org.qed.RRule; -import org.qed.RRuleInstance; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.*; import java.io.File; import java.io.IOException; @@ -38,7 +37,7 @@ public static Seq ruleList() { } public static void verify() { - ruleList().forEachUnchecked(rule -> rule.dump(STR."\{rulePath}/\{rule.name()}.json")); + ruleList().forEachUnchecked(rule -> rule.dump(rulePath + "/" + rule.name() + ".json")); } public static void generate() { @@ -50,55 +49,55 @@ public static void main(String[] args) throws IOException { var rules = new RRuleInstance.JoinAssociate(); Files.createDirectories(Path.of(rulePath)); for (var rule : rules.family()) { - new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, STR."\{rule.name()}-\{rule.info()}.json").toFile(), rule.toJson()); + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); } -// generate(); -// var tester = new CalciteTester(); -// var builder = RuleBuilder.create(); -// var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); -// builder.addTable(table); -// var before = builder.scan(table.getName()) -// .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) -// .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) -// .build(); -// var after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, -// builder.call(builder.genericPredicateOp("inner", true), builder.fields()), -// builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) -// .build(); -// var runner = loadRule(FilterMerge.Config.DEFAULT.toRule()); -// tester.verify(runner, before, after); -// before = builder.scan(table.getName()) -// .scan(table.getName()) -// .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) -// .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) -// .build(); -// after = builder.scan(table.getName()) -// .scan(table.getName()) -// .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, -// builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), -// builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) -// .build(); -// runner = loadRule(FilterIntoJoin.Config.DEFAULT.toRule()); -// tester.verify(runner, before, after); + generate(); + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(table); + var before = builder.scan(table.getName()) + .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) + .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) + .build(); + var after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("inner", true), builder.fields()), + builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) + .build(); + var runner = loadRule(FilterMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + before = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) + .build(); + after = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), + builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) + .build(); + runner = loadRule(FilterIntoJoin.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); } public void serialize(RRule rule, String path) { var generator = new CalciteGenerator(); var code_gen = generator.generate(rule); try { - Files.write(Path.of(path, STR."\{rule.name()}.java"), code_gen.getBytes()); + Files.write(Path.of(path, rule.name() + ".java"), code_gen.getBytes()); } catch (IOException ioe) { System.err.println(ioe.getMessage()); } } public void test(RelOptRule rule, Seq tests) { - System.out.println(STR."Testing rule \{rule.getClass().getSimpleName()}"); + System.out.println("Testing rule " + rule.getClass().getSimpleName()); var runner = loadRule(rule); var exams = tests.mapUnchecked(t -> Tuple.of(t, JSONDeserializer.load(new File(t)))); for (var entry : exams) { if (entry.getValue().size() != 2) { - System.err.println(STR."\{entry.getKey()} does not have exactly two nodes, and thus is not a valid test"); + System.err.println(entry.getKey() + " does not have exactly two nodes, and thus is not a valid test"); continue; } verify(runner, entry.getValue().get(0), entry.getValue().get(1)); @@ -108,9 +107,9 @@ public void test(RelOptRule rule, Seq tests) { public void verify(HepPlanner runner, RelNode source, RelNode target) { runner.setRoot(source); var answer = runner.findBestExp(); - System.out.println(STR."> Given source RelNode:\n\{source.explain()}"); - System.out.println(STR."> Actual rewritten RelNode:\n\{answer.explain()}"); - System.out.println(STR."> Expected rewritten RelNode:\n\{target.explain()}"); + System.out.println("> Given source RelNode:\n" + source.explain()); + System.out.println("> Actual rewritten RelNode:\n" + answer.explain()); + System.out.println("> Expected rewritten RelNode:\n" + target.explain()); } } diff --git a/src/main/java/org/qed/RRule.java b/src/main/java/org/qed/RRule.java index 157c0e4..a548f9a 100644 --- a/src/main/java/org/qed/RRule.java +++ b/src/main/java/org/qed/RRule.java @@ -14,7 +14,12 @@ public interface RRule { RelRN after(); default String explain() { - return STR."\{getClass().getName()}\n\{before().semantics().explain()}=>\n\{after().semantics().explain()}"; + return getClass().getName() + + "\n" + + before().semantics().explain() + + "=>" + + "\n" + + after().semantics().explain(); } default String name() { @@ -101,7 +106,7 @@ public RelRN customReplaceMetaRelRN(RelRN relRN) { } public String info() { - return joinTypeAssignment.joinToString("&", (m, c) -> STR."\{m.name()}=\{c.semantics()}"); + return joinTypeAssignment.joinToString("&", (m, c) -> "{" + m.name() + "}=" + c.semantics()); } } diff --git a/src/main/java/org/qed/RelRN.java b/src/main/java/org/qed/RelRN.java index bdc364f..ec09567 100644 --- a/src/main/java/org/qed/RelRN.java +++ b/src/main/java/org/qed/RelRN.java @@ -112,7 +112,7 @@ record Scan(String name, RelType.VarType ty, boolean unique) implements RelRN { @Override public RelNode semantics() { - var table = new QedTable(name, Seq.of(STR."col-\{name}"), Seq.of(ty), unique ? + var table = new QedTable(name, Seq.of("col-" + name), Seq.of(ty), unique ? Set.of(ImmutableBitSet.of(0)) : Set.empty(), Set.empty()); return RuleBuilder.create().addTable(table).scan(name).build(); } From bb4027768e96009d83e955a87b90b1a8d40e5425 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Fri, 2 May 2025 16:05:22 -0700 Subject: [PATCH 09/78] add FilterSetOpTranspose IntersectMerge JoinExtractFilter SemiJoinFilterTranspose UnionMerge --- src/main/java/org/qed/CodeGenerator.java | 30 ++ .../org/qed/Generated/CalciteGenerator.java | 212 ++++++++++++ .../java/org/qed/Generated/CalciteTester.java | 310 ++++++++++++++++-- .../org/qed/Generated/FilterIntoJoin.java | 63 ++-- .../java/org/qed/Generated/FilterMerge.java | 61 ++-- .../qed/Generated/FilterProjectTranspose.java | 62 ++-- .../org/qed/Generated/FilterReduceFalse.java | 39 +++ .../org/qed/Generated/FilterReduceTrue.java | 39 +++ .../qed/Generated/FilterSetOpTranspose.java | 39 +++ .../org/qed/Generated/IntersectMerge.java | 39 +++ .../Generated/JoinAddRedundantSemiJoin.java | 39 +++ .../java/org/qed/Generated/JoinCommute.java | 39 +++ .../org/qed/Generated/JoinExtractFilter.java | 39 +++ .../java/org/qed/Generated/ProjectMerge.java | 39 +++ .../Generated/SemiJoinFilterTranspose.java | 39 +++ .../java/org/qed/Generated/UnionMerge.java | 39 +++ src/main/java/org/qed/RRuleInstance.java | 277 ++++++++++------ 17 files changed, 1184 insertions(+), 221 deletions(-) create mode 100644 src/main/java/org/qed/Generated/FilterReduceFalse.java create mode 100644 src/main/java/org/qed/Generated/FilterReduceTrue.java create mode 100644 src/main/java/org/qed/Generated/FilterSetOpTranspose.java create mode 100644 src/main/java/org/qed/Generated/IntersectMerge.java create mode 100644 src/main/java/org/qed/Generated/JoinAddRedundantSemiJoin.java create mode 100644 src/main/java/org/qed/Generated/JoinCommute.java create mode 100644 src/main/java/org/qed/Generated/JoinExtractFilter.java create mode 100644 src/main/java/org/qed/Generated/ProjectMerge.java create mode 100644 src/main/java/org/qed/Generated/SemiJoinFilterTranspose.java create mode 100644 src/main/java/org/qed/Generated/UnionMerge.java diff --git a/src/main/java/org/qed/CodeGenerator.java b/src/main/java/org/qed/CodeGenerator.java index fa0d1c6..85e1f6f 100644 --- a/src/main/java/org/qed/CodeGenerator.java +++ b/src/main/java/org/qed/CodeGenerator.java @@ -26,6 +26,7 @@ default E onMatch(E env, RelRN pattern) { case RelRN.Join join -> onMatchJoin(env, join); case RelRN.Union union -> onMatchUnion(env, union); case RelRN.Intersect intersect -> onMatchIntersect(env, intersect); + case RelRN.Empty empty -> onMatchEmpty(env, empty); default -> onMatchCustom(env, pattern); }; } @@ -39,6 +40,8 @@ default E onMatch(E env, RexRN pattern) { case RexRN.And and -> onMatchAnd(env, and); case RexRN.Or or -> onMatchOr(env, or); case RexRN.Not not -> onMatchNot(env, not); + case RexRN.True literal -> onMatchTrue(env, literal); + case RexRN.False literal -> onMatchFalse(env, literal); default -> onMatchCustom(env, pattern); }; } @@ -59,6 +62,7 @@ default E transform(E env, RelRN target) { case RelRN.Join join -> transformJoin(env, join); case RelRN.Union union -> transformUnion(env, union); case RelRN.Intersect intersect -> transformIntersect(env, intersect); + case RelRN.Empty empty -> transformEmpty(env, empty); default -> transformCustom(env, target); }; } @@ -72,6 +76,8 @@ default E transform(E env, RexRN target) { case RexRN.And and -> transformAnd(env, and); case RexRN.Or or -> transformOr(env, or); case RexRN.Not not -> transformNot(env, not); + case RexRN.True literal -> transformTrue(env, literal); + case RexRN.False literal -> transformFalse(env, literal); default -> transformCustom(env, target); }; } @@ -85,6 +91,7 @@ default String translate(String name, E onMatch, E transform) { } default String generate(RRule rule) { + System.out.printf("Generating Rule: %s\n", rule.name()); var onMatch = postMatch(onMatch(preMatch(), rule.before())); var transform = postTransform(transform(preTransform(onMatch), rule.after())); return translate(rule.getClass().getSimpleName(), onMatch, transform); @@ -150,6 +157,18 @@ default E onMatchCustom(E env, RexRN custom) { return unimplementedOnMatch(env, custom); } + default E onMatchTrue(E env, RexRN literal) { + return unimplementedOnMatch(env, literal); + } + + default E onMatchFalse(E env, RexRN literal) { + return unimplementedOnMatch(env, literal); + } + + default E onMatchEmpty(E env, RelRN.Empty empty) { + return unimplementedOnMatch(env, empty); + } + default E transformScan(E env, RelRN.Scan scan) { return unimplementedTransform(env, scan); } @@ -210,4 +229,15 @@ default E transformCustom(E env, RexRN custom) { return unimplementedTransform(env, custom); } + default E transformTrue(E env, RexRN literal) { + return unimplementedTransform(env, literal); + } + + default E transformFalse(E env, RexRN literal) { + return unimplementedTransform(env, literal); + } + + default E transformEmpty(E env, RelRN.Empty empty) { + return unimplementedTransform(env, empty); + } } diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java index 47581c6..6792cca 100644 --- a/src/main/java/org/qed/Generated/CalciteGenerator.java +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -7,6 +7,7 @@ import org.qed.CodeGenerator; import org.qed.RelRN; import org.qed.RexRN; +import org.qed.Generated.CalciteGenerator.Env; import java.util.concurrent.atomic.AtomicInteger; @@ -105,6 +106,118 @@ public Env onMatchJoin(Env env, RelRN.Join join) { return onMatch(cond_source_env, join.cond()); } + @Override + public Env onMatchAnd(Env env, RexRN.And and) { + // Process each source in the And condition + var current_env = env; + // Use a unique symbol name for the AND condition + String andSymbol = "and_" + env.varId.getAndIncrement(); + // Store the current expression as this AND node's symbol + current_env = current_env.symbol(andSymbol, current_env.current()); + + // Process each child source in the AND condition + for (var source : and.sources()) { + current_env = onMatch(current_env, source); + } + + return current_env; + } + + @Override + public Env onMatchUnion(Env env, RelRN.Union union) { + // Get the all flag from the union + boolean all = union.all(); + + // Process each source in the union + var current_env = env; + var skeletons = Seq.empty(); + + // Process all sources in the sequence + for (var source : union.sources()) { + var next_env = current_env.next(); + var source_env = onMatch(next_env, source); + skeletons = skeletons.appended(source_env.skeleton()); + current_env = source_env; + } + + // Build the input skeletons string for the operand + StringBuilder inputsBuilder = new StringBuilder(); + for (int i = 0; i < skeletons.size(); i++) { + if (i > 0) { + inputsBuilder.append(", "); + } + inputsBuilder.append(skeletons.get(i).toString()); + } + + // Create the union operand with the appropriate class based on the all flag + String operatorClass = all ? "LogicalUnionAll" : "LogicalUnion"; + return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); + } + + @Override + public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { + // Get the all flag from the intersect + boolean all = intersect.all(); + + // Process each source in the intersect + var current_env = env; + var skeletons = Seq.empty(); + + // Process all sources in the sequence + for (var source : intersect.sources()) { + var next_env = current_env.next(); + var source_env = onMatch(next_env, source); + skeletons = skeletons.appended(source_env.skeleton()); + current_env = source_env; + } + + // Build the input skeletons string for the operand + StringBuilder inputsBuilder = new StringBuilder(); + for (int i = 0; i < skeletons.size(); i++) { + if (i > 0) { + inputsBuilder.append(", "); + } + inputsBuilder.append(skeletons.get(i).toString()); + } + + // Create the intersect operand with the appropriate class based on the all flag + String operatorClass = all ? "LogicalIntersectAll" : "LogicalIntersect"; + return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); + } + + @Override + public Env onMatchField(Env env, RexRN.Field field) { + // Generate a unique symbolic name for this field + String fieldSymbol = "field_" + env.varId.getAndIncrement(); + + // Store the field expression in the environment's symbol table + return env.symbol(fieldSymbol, env.current()); + } + + @Override + public Env onMatchTrue(Env env, RexRN literal) { + // Create a unique symbol name for this true literal + String trueSymbol = "true_" + env.varId.getAndIncrement(); + + // Store the current expression as this true literal's symbol + return env.symbol(trueSymbol, env.current()); + } + + @Override + public Env onMatchFalse(Env env, RexRN literal) { + // Create a unique symbol name for this false literal + String falseSymbol = "false_" + env.varId.getAndIncrement(); + + // Store the current expression as this false literal's symbol + return env.symbol(falseSymbol, env.current()); + } + + @Override + public Env onMatchEmpty(Env env, RelRN.Empty empty) { + return env.grow("operand(LogicalValues.class).noInputs()"); + } + + @Override public Env transformScan(Env env, RelRN.Scan scan) { return env.focus(env.current() + ".push(" + env.symbols().get(scan.name()) + ")"); @@ -168,6 +281,105 @@ public Env transformAnd(Env env, RexRN.And and) { return source_transform.focus(env.current() + ".and(" + operands.joinToString(", ") + ")"); } + @Override + public Env transformUnion(Env env, RelRN.Union union) { + // Get the all flag from the union + boolean all = union.all(); + + // The number of sources + int sourceCount = union.sources().size(); + + // Transform each source + var current_env = env; + for (var source : union.sources()) { + current_env = transform(current_env, source); + } + + // Use the union method with the all flag and source count + // This matches the Calcite RelBuilder.union(boolean all, int n) signature + return current_env.focus(current_env.current() + ".union(" + all + ", " + sourceCount + ")"); + } + + @Override + public Env transformIntersect(Env env, RelRN.Intersect intersect) { + // Get the all flag from the intersect + boolean all = intersect.all(); + + // The number of sources + int sourceCount = intersect.sources().size(); + + // Transform each source + var current_env = env; + for (var source : intersect.sources()) { + current_env = transform(current_env, source); + } + + // Use the intersect method with the all flag and source count + // This matches the expected Calcite RelBuilder.intersect(boolean all, int n) signature + String methodName = all ? "intersectAll" : "intersect"; + return current_env.focus(current_env.current() + "." + methodName + "(" + all + ", " + sourceCount + ")"); + } + + @Override + public Env transformField(Env env, RexRN.Field field) { + // In Calcite, field references are typically created with a "field" method + // We'll need to pass some identifier for the field - use toString() if no specific field accessor is available + + // Assuming field has a method that returns some kind of identifier or name + // If not, we may need to adjust this implementation + return env.focus(env.current() + ".field(" + field + ")"); + } + + @Override + public Env transformProj(Env env, RexRN.Proj proj) { + // In Calcite, projections are typically created using the operator name + // This is similar to your transformPred implementation + + // Look up the symbol from the matching phase + if (!env.symbols().containsKey(proj.operator().getName())) { + throw new RuntimeException("Operator symbol not found: " + proj.operator().getName() + + ". Make sure onMatchProj is properly implemented."); + } + + // Return an environment focused on the expression for this projection + return env.focus(env.symbols().get(proj.operator().getName())); + } + + @Override + public Env transformProject(Env env, RelRN.Project project) { + // First transform the source relation + var source_transform = transform(env, project.source()); + var source_expression = source_transform.current(); + + // Then transform the projection map + var map_transform = transform(source_transform, project.map()); + + // Combine the source and projection using the project operation + // This creates a projection on top of the source relation + return map_transform.focus(source_expression + ".project(" + map_transform.current() + ")"); + } + + @Override + public Env transformTrue(Env env, RexRN literal) { + // In Calcite, true literals are typically represented using the + // rexBuilder.makeLiteral(true) method or just "TRUE" + return env.focus(env.current() + ".literal(true)"); + } + + @Override + public Env transformFalse(Env env, RexRN literal) { + // In Calcite, false literals are represented using the + // rexBuilder.makeLiteral(false) method or just "FALSE" + return env.focus(env.current() + ".literal(false)"); + } + + @Override + public Env transformEmpty(Env env, RelRN.Empty empty) { + // In Calcite, empty relations are created using the values() method with no tuples + // This creates a LogicalValues node with no rows + return env.focus(env.current() + ".empty()"); + } + public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq statements, ImmutableMap symbols) { public static Env empty() { diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 2b612b8..492d108 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -8,6 +8,9 @@ import org.apache.calcite.plan.hep.HepProgramBuilder; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalValues; +import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.qed.*; @@ -23,6 +26,7 @@ public class CalciteTester { public static String rulePath = "rules"; public static HepPlanner loadRule(RelOptRule rule) { + System.out.printf("Verifying Rule: %s\n", rule.getClass()); var builder = new HepProgramBuilder().addRuleInstance(rule); return new HepPlanner(builder.build()); } @@ -31,9 +35,11 @@ public static Seq ruleList() { var individuals = Seq.from(RRuleInstance.class.getClasses()).filter(RRule.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule) r); System.out.println(Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor)); - var families = - Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule.RRuleFamily) r); - return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); + /* To be restored */ + // var families = + // Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule.RRuleFamily) r); + // return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); + return individuals; } public static void verify() { @@ -46,39 +52,284 @@ public static void generate() { } public static void main(String[] args) throws IOException { - var rules = new RRuleInstance.JoinAssociate(); - Files.createDirectories(Path.of(rulePath)); - for (var rule : rules.family()) { - new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); - } - generate(); + // var rule = new RRuleInstance.FilterReduceTrue(); + // Files.createDirectories(Path.of(rulePath)); + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // var rules = new RRuleInstance.JoinAssociate(); + // Files.createDirectories(Path.of(rulePath)); + // for (var rule : rules.family()) { + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // } + // generate(); + + /* FilterIntoJoin */ var tester = new CalciteTester(); var builder = RuleBuilder.create(); var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); builder.addTable(table); var before = builder.scan(table.getName()) - .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) - .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) - .build(); - var after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, - builder.call(builder.genericPredicateOp("inner", true), builder.fields()), - builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) - .build(); - var runner = loadRule(FilterMerge.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - before = builder.scan(table.getName()) .scan(table.getName()) .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) .build(); - after = builder.scan(table.getName()) + var after = builder.scan(table.getName()) .scan(table.getName()) .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) .build(); - runner = loadRule(FilterIntoJoin.Config.DEFAULT.toRule()); + var runner = loadRule(FilterIntoJoin.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + /* FilterMerge */ + before = builder.scan(table.getName()) + .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) + .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) + .build(); + after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("inner", true), builder.fields()), + builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) + .build(); + runner = loadRule(FilterMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + /* FilterProjectTranspose */ + builder = RuleBuilder.create(); + table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + var scan = builder.scan(table.getName()).build(); + before = builder + .push(scan) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .project(builder.field(0)) + .build(); + after = builder + .push(scan) + .project(builder.field(0)) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .build(); + runner = loadRule(FilterProjectTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + /* UnionMerge */ + builder = RuleBuilder.create(); + table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + var firstUnion = builder.push(scan1).push(scan2).union(false).build(); + before = builder.push(firstUnion).push(scan3).union(false).build(); + after = builder.push(scan1).push(scan2).push(scan3).union(false, 3).build(); + runner = loadRule(UnionMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + /* IntersectMerge */ + builder = RuleBuilder.create(); + table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + scan1 = builder.scan(table.getName()).build(); + scan2 = builder.scan(table.getName()).build(); + scan3 = builder.scan(table.getName()).build(); + var firstIntersect = builder.push(scan1).push(scan2).intersect(false).build(); + before = builder.push(firstIntersect).push(scan3).intersect(false).build(); + after = builder.push(scan1).push(scan2).push(scan3).intersect(false, 3).build(); + runner = loadRule(IntersectMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + /* FilterSetOpTranspose */ + builder = RuleBuilder.create(); + table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + scan1 = builder.scan(table.getName()).build(); + scan2 = builder.scan(table.getName()).build(); + var union = builder.push(scan1).push(scan2).union(false).build(); + before = builder.push(union).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + var filteredScan1 = builder.push(scan1).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + var filteredScan2 = builder.push(scan2).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + after = builder.push(filteredScan1).push(filteredScan2).union(false).build(); + runner = loadRule(FilterSetOpTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + /* JoinExtractFilter */ + builder = RuleBuilder.create(); + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("VARCHAR", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + before = builder.push(leftScan).push(rightScan).join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.field(2, 0, 0), builder.field(2, 1, 0))).build(); + var trueJoin = builder.push(leftScan).push(rightScan).join(JoinRelType.INNER, builder.literal(true)).build(); + after = builder.push(trueJoin).filter(builder.call(builder.genericPredicateOp("join", true), builder.field(0), builder.field(1))).build(); + runner = loadRule(JoinExtractFilter.Config.DEFAULT.toRule()); tester.verify(runner, before, after); + + /* SemiJoinFilterTranspose */ + builder = RuleBuilder.create(); + leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + leftScan = builder.scan(leftTable.getName()).build(); + rightScan = builder.scan(rightTable.getName()).build(); + builder.push(leftScan); + builder.push(rightScan); + var joinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); + var semiJoin = builder.join(JoinRelType.SEMI, joinPredicate).build(); + builder.push(semiJoin); + var filterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); + before = builder.filter(filterPredicate).build(); + builder.push(leftScan); + var leftFilterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); + var filteredLeft = builder.filter(leftFilterPredicate).build(); + builder.push(filteredLeft); + builder.push(rightScan); + var afterJoinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); + after = builder.join(JoinRelType.SEMI, afterJoinPredicate).build(); + runner = loadRule(SemiJoinFilterTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + + /* JoinCommute */ + // TBD: failed + // builder = RuleBuilder.create(); + // leftTable = builder.createQedTable(Seq.of( + // Tuple.of(RelType.fromString("INTEGER", true), false) + // )); + // rightTable = builder.createQedTable(Seq.of( + // Tuple.of(RelType.fromString("INTEGER", true), false) + // )); + // builder.addTable(leftTable); + // builder.addTable(rightTable); + // leftScan = builder.scan(leftTable.getName()).build(); + // rightScan = builder.scan(rightTable.getName()).build(); + // before = builder + // .push(leftScan) + // .push(rightScan) + // .join( + // JoinRelType.INNER, + // builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)) + // ) + // .build(); + // after = builder + // .push(rightScan) + // .push(leftScan) + // .join( + // JoinRelType.INNER, + // builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)) + // ) + // .build(); + // runner = loadRule(JoinCommute.Config.DEFAULT.toRule()); + // tester.verify(runner, before, after); + + /* ProjectMerge */ + // TBD: Automatically optimized? + // builder = RuleBuilder.create(); + // table = builder.createQedTable(Seq.of( + // Tuple.of(RelType.fromString("INTEGER", true), false), + // Tuple.of(RelType.fromString("INTEGER", true), false), + // Tuple.of(RelType.fromString("INTEGER", true), false) + // )); + // builder.addTable(table); + // scan = builder.scan(table.getName()).build(); + // var innerProject = builder + // .push(scan) + // .project(builder.field(0), builder.field(1)) + // .build(); + // before = builder + // .push(innerProject) + // .project(builder.field(0)) + // .build(); + // after = builder + // .push(scan) + // .project(builder.field(0)) + // .build(); + // runner = loadRule(ProjectMerge.Config.DEFAULT.toRule()); + // tester.verify(runner, before, after); + + /* JoinAddRedundantSemiJoin */ + // TBD: failed + // builder = RuleBuilder.create(); + // leftTable = builder.createQedTable(Seq.of( + // Tuple.of(RelType.fromString("INTEGER", true), false) + // )); + // rightTable = builder.createQedTable(Seq.of( + // Tuple.of(RelType.fromString("INTEGER", true), false) + // )); + // builder.addTable(leftTable); + // builder.addTable(rightTable); + // leftScan = builder.scan(leftTable.getName()).build(); + // rightScan = builder.scan(rightTable.getName()).build(); + // before = builder + // .push(leftScan) + // .push(rightScan) + // .join( + // JoinRelType.INNER, + // builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)) + // ) + // .build(); + // after = builder + // .push(leftScan) + // .push(rightScan) + // .join(JoinRelType.SEMI, builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0))) + // .push(rightScan) + // .join(JoinRelType.INNER, builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0))) + // .build(); + // runner = loadRule(JoinAddRedundantSemiJoin.Config.DEFAULT.toRule()); + // tester.verify(runner, before, after); + + /* FilterReduceFalse */ + // TBD: Automatically optimized? + // builder = RuleBuilder.create(); + // table = builder.createQedTable(Seq.of( + // Tuple.of(RelType.fromString("INTEGER", true), false), + // Tuple.of(RelType.fromString("INTEGER", true), false) + // )); + // builder.addTable(table); + // scan = builder.scan(table.getName()).build(); + // before = builder + // .push(scan) + // .filter(builder.equals(builder.field(0), builder.literal(10))) + // .filter(builder.literal(false)) + // .build(); + // after = builder + // .push(scan) + // .empty() + // .build(); + // runner = loadRule(FilterReduceFalse.Config.DEFAULT.toRule()); + // tester.verify(runner, before, after); + + /* FilterReduceTrue */ + // TBD: Automatically optimized? And something wrong with this rule? + // builder = RuleBuilder.create(); + // table = builder.createQedTable(Seq.of( + // Tuple.of(RelType.fromString("INTEGER", true), false), + // Tuple.of(RelType.fromString("INTEGER", true), false) + // )); + // builder.addTable(table); + // scan = builder.scan(table.getName()).build(); + // before = builder + // .push(scan) + // .filter(builder.equals(builder.field(0), builder.literal(10))) + // .filter(builder.literal(true)) + // .filter(builder.equals(builder.field(1), builder.literal(20))) + // .build(); + // after = builder + // .push(scan) + // .filter(builder.equals(builder.field(0), builder.literal(10))) + // .filter(builder.equals(builder.field(1), builder.literal(20))) + // .build(); + // runner = loadRule(FilterReduceTrue.Config.DEFAULT.toRule()); + // tester.verify(runner, before, after); } public void serialize(RRule rule, String path) { @@ -107,9 +358,20 @@ public void test(RelOptRule rule, Seq tests) { public void verify(HepPlanner runner, RelNode source, RelNode target) { runner.setRoot(source); var answer = runner.findBestExp(); + + String answerExplain = answer.explain(); + String targetExplain = target.explain(); + + if(answerExplain.equals(targetExplain)) { + System.out.println("succeeded"); + // System.out.println("> Given source RelNode:\n" + source.explain()); + // System.out.println("> Actual rewritten RelNode:\n" + answerExplain); + // System.out.println("> Expected rewritten RelNode:\n" + targetExplain); + return; + } + System.out.println("failed"); System.out.println("> Given source RelNode:\n" + source.explain()); - System.out.println("> Actual rewritten RelNode:\n" + answer.explain()); - System.out.println("> Expected rewritten RelNode:\n" + target.explain()); + System.out.println("> Actual rewritten RelNode:\n" + answerExplain); + System.out.println("> Expected rewritten RelNode:\n" + targetExplain); } - } diff --git a/src/main/java/org/qed/Generated/FilterIntoJoin.java b/src/main/java/org/qed/Generated/FilterIntoJoin.java index b288f4b..12e58ce 100644 --- a/src/main/java/org/qed/Generated/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/FilterIntoJoin.java @@ -4,39 +4,36 @@ import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.JoinRelType; -import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rel.logical.LogicalJoin; +import org.apache.calcite.rel.logical.*; public class FilterIntoJoin extends RelRule { - protected FilterIntoJoin(Config config) { - super(config); - } - - @Override - public void onMatch(RelOptRuleCall call) { - var var_4 = call.builder(); - call.transformTo(var_4.push(call.rel(2)).push(call.rel(3)).join(JoinRelType.INNER, - var_4.push(call.rel(2)).push(call.rel(3)).and(((LogicalJoin) call.rel(1)).getCondition(), - ((LogicalFilter) call.rel(0)).getCondition())).build()); - } - - public interface Config extends EmptyConfig { - Config DEFAULT = new Config() {}; - - @Override - default FilterIntoJoin toRule() { - return new FilterIntoJoin(this); - } - - @Override - default String description() { - return "FilterIntoJoin"; - } - - @Override - default RelRule.OperandTransform operandSupplier() { - return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); - } - - } + protected FilterIntoJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).push(call.rel(3)).join(JoinRelType.INNER, var_4.push(call.rel(2)).push(call.rel(3)).and(((LogicalJoin) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterIntoJoin toRule() { + return new FilterIntoJoin(this); + } + + @Override + default String description() { + return "FilterIntoJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } } diff --git a/src/main/java/org/qed/Generated/FilterMerge.java b/src/main/java/org/qed/Generated/FilterMerge.java index 3ee9946..f943b74 100644 --- a/src/main/java/org/qed/Generated/FilterMerge.java +++ b/src/main/java/org/qed/Generated/FilterMerge.java @@ -3,36 +3,37 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; public class FilterMerge extends RelRule { - protected FilterMerge(Config config) { - super(config); - } - - @Override - public void onMatch(RelOptRuleCall call) { - var var_3 = call.builder(); - call.transformTo(var_3.push(call.rel(2)).filter(var_3.push(call.rel(2)).and(((LogicalFilter) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); - } - - public interface Config extends EmptyConfig { - Config DEFAULT = new Config() {}; - - @Override - default FilterMerge toRule() { - return new FilterMerge(this); - } - - @Override - default String description() { - return "FilterMerge"; - } - - @Override - default RelRule.OperandTransform operandSupplier() { - return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); - } - - } + protected FilterMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).filter(var_3.push(call.rel(2)).and(((LogicalFilter) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterMerge toRule() { + return new FilterMerge(this); + } + + @Override + default String description() { + return "FilterMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } } diff --git a/src/main/java/org/qed/Generated/FilterProjectTranspose.java b/src/main/java/org/qed/Generated/FilterProjectTranspose.java index decb19a..89a31a0 100644 --- a/src/main/java/org/qed/Generated/FilterProjectTranspose.java +++ b/src/main/java/org/qed/Generated/FilterProjectTranspose.java @@ -3,37 +3,37 @@ import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelRule; import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rel.logical.LogicalProject; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; public class FilterProjectTranspose extends RelRule { - protected FilterProjectTranspose(Config config) { - super(config); - } - - @Override - public void onMatch(RelOptRuleCall call) { - var var_3 = call.builder(); - call.transformTo(var_3.filter(((LogicalFilter) call.rel(1)).getCondition()).build()); - } - - public interface Config extends EmptyConfig { - Config DEFAULT = new Config() {}; - - @Override - default FilterProjectTranspose toRule() { - return new FilterProjectTranspose(this); - } - - @Override - default String description() { - return "FilterProjectTranspose"; - } - - @Override - default RelRule.OperandTransform operandSupplier() { - return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); - } - - } + protected FilterProjectTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).project(((LogicalProject) call.rel(0)).getProjects()).filter(((LogicalFilter) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterProjectTranspose toRule() { + return new FilterProjectTranspose(this); + } + + @Override + default String description() { + return "FilterProjectTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } } diff --git a/src/main/java/org/qed/Generated/FilterReduceFalse.java b/src/main/java/org/qed/Generated/FilterReduceFalse.java new file mode 100644 index 0000000..e9fb926 --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterReduceFalse.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterReduceFalse extends RelRule { + protected FilterReduceFalse(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterReduceFalse toRule() { + return new FilterReduceFalse(this); + } + + @Override + default String description() { + return "FilterReduceFalse"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterReduceTrue.java b/src/main/java/org/qed/Generated/FilterReduceTrue.java new file mode 100644 index 0000000..a1c125c --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterReduceTrue.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterReduceTrue extends RelRule { + protected FilterReduceTrue(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterReduceTrue toRule() { + return new FilterReduceTrue(this); + } + + @Override + default String description() { + return "FilterReduceTrue"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterSetOpTranspose.java b/src/main/java/org/qed/Generated/FilterSetOpTranspose.java new file mode 100644 index 0000000..fa3a6b4 --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterSetOpTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterSetOpTranspose extends RelRule { + protected FilterSetOpTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).push(call.rel(3)).filter(((LogicalFilter) call.rel(0)).getCondition()).union(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterSetOpTranspose toRule() { + return new FilterSetOpTranspose(this); + } + + @Override + default String description() { + return "FilterSetOpTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/IntersectMerge.java b/src/main/java/org/qed/Generated/IntersectMerge.java new file mode 100644 index 0000000..d274657 --- /dev/null +++ b/src/main/java/org/qed/Generated/IntersectMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class IntersectMerge extends RelRule { + protected IntersectMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).intersect(false, 3).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default IntersectMerge toRule() { + return new IntersectMerge(this); + } + + @Override + default String description() { + return "IntersectMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalIntersect.class).inputs(s_2 -> s_2.operand(LogicalIntersect.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/JoinAddRedundantSemiJoin.java b/src/main/java/org/qed/Generated/JoinAddRedundantSemiJoin.java new file mode 100644 index 0000000..dcc5c20 --- /dev/null +++ b/src/main/java/org/qed/Generated/JoinAddRedundantSemiJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class JoinAddRedundantSemiJoin extends RelRule { + protected JoinAddRedundantSemiJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).push(call.rel(2)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(0)).getCondition()).push(call.rel(2)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(0)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinAddRedundantSemiJoin toRule() { + return new JoinAddRedundantSemiJoin(this); + } + + @Override + default String description() { + return "JoinAddRedundantSemiJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/JoinCommute.java b/src/main/java/org/qed/Generated/JoinCommute.java new file mode 100644 index 0000000..894330c --- /dev/null +++ b/src/main/java/org/qed/Generated/JoinCommute.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class JoinCommute extends RelRule { + protected JoinCommute(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(0)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinCommute toRule() { + return new JoinCommute(this); + } + + @Override + default String description() { + return "JoinCommute"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/JoinExtractFilter.java b/src/main/java/org/qed/Generated/JoinExtractFilter.java new file mode 100644 index 0000000..132dc5e --- /dev/null +++ b/src/main/java/org/qed/Generated/JoinExtractFilter.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class JoinExtractFilter extends RelRule { + protected JoinExtractFilter(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).push(call.rel(2)).join(JoinRelType.INNER, var_3.push(call.rel(1)).push(call.rel(2)).literal(true)).filter(((LogicalJoin) call.rel(0)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinExtractFilter toRule() { + return new JoinExtractFilter(this); + } + + @Override + default String description() { + return "JoinExtractFilter"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/ProjectMerge.java b/src/main/java/org/qed/Generated/ProjectMerge.java new file mode 100644 index 0000000..0a4e458 --- /dev/null +++ b/src/main/java/org/qed/Generated/ProjectMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class ProjectMerge extends RelRule { + protected ProjectMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).project(((LogicalProject) call.rel(0)).getProjects()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default ProjectMerge toRule() { + return new ProjectMerge(this); + } + + @Override + default String description() { + return "ProjectMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/SemiJoinFilterTranspose.java b/src/main/java/org/qed/Generated/SemiJoinFilterTranspose.java new file mode 100644 index 0000000..cc175ad --- /dev/null +++ b/src/main/java/org/qed/Generated/SemiJoinFilterTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class SemiJoinFilterTranspose extends RelRule { + protected SemiJoinFilterTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).push(call.rel(3)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default SemiJoinFilterTranspose toRule() { + return new SemiJoinFilterTranspose(this); + } + + @Override + default String description() { + return "SemiJoinFilterTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/UnionMerge.java b/src/main/java/org/qed/Generated/UnionMerge.java new file mode 100644 index 0000000..d951fa8 --- /dev/null +++ b/src/main/java/org/qed/Generated/UnionMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class UnionMerge extends RelRule { + protected UnionMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).union(false, 3).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default UnionMerge toRule() { + return new UnionMerge(this); + } + + @Override + default String description() { + return "UnionMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalUnion.class).inputs(s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/RRuleInstance.java b/src/main/java/org/qed/RRuleInstance.java index a889253..e94556e 100644 --- a/src/main/java/org/qed/RRuleInstance.java +++ b/src/main/java/org/qed/RRuleInstance.java @@ -5,6 +5,8 @@ import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RRuleInstance.JoinAssociate; +// import org.qed.RRuleInstance.JoinConditionPush.JoinPred; public interface RRuleInstance { record FilterIntoJoin() implements RRule { @@ -83,47 +85,76 @@ public RelRN after() { } } -// record FilterSetOpTransposeRule implements RRule { -// -// } - -// record IntersectMerge implements RRule { -// -// } - - record JoinConditionPush() implements RRule { - static final RelRN left = RelRN.scan("Left", "Left_Type"); - static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final JoinPred joinPred = new JoinPred(left, right); - + record FilterSetOpTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Common_Type"); + static final RelRN right = RelRN.scan("Right", "Common_Type"); + @Override public RelRN before() { - return left.join(JoinRelType.INNER, joinPred, right); + RelRN projTmp = left.union(false, right); + return projTmp.filter(projTmp.pred("filter")); + } + + @Override + public RelRN after() { + RexRN leftPred = left.pred("filter"); + RexRN rightPred = right.pred("filter"); + return left.filter(leftPred).union(false, right.filter(rightPred)); } + } + record IntersectMerge() implements RRule { + // Use a common type for all relations to make them compatible + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + // Nested INTERSECT: (A INTERSECT B) INTERSECT C + return a.intersect(false, b).intersect(false, c); + } + @Override public RelRN after() { - var leftRN = left.filter(joinPred.leftPred()); - var rightRN = right.filter(joinPred.rightPred()); - return leftRN.join(JoinRelType.INNER, joinPred.bothPred(), rightRN); + // Flattened INTERSECT: A INTERSECT B INTERSECT C + return a.intersect(false, b, c); } + } - public record JoinPred(RelRN left, RelRN right) implements RexRN { + // record JoinConditionPush() implements RRule { + // static final RelRN left = RelRN.scan("Left", "Left_Type"); + // static final RelRN right = RelRN.scan("Right", "Right_Type"); + // static final JoinPred joinPred = new JoinPred(left, right); - @Override - public RexNode semantics() { - return RexRN.and(left.joinPred(bothPred(), right), left.joinField(0, right).pred(leftPred()), - left.joinField(1, right).pred(rightPred())).semantics(); - } + // @Override + // public RelRN before() { + // return left.join(JoinRelType.INNER, joinPred, right); + // } - public String bothPred() {return "both";} + // @Override + // public RelRN after() { + // var leftRN = left.filter(joinPred.leftPred()); + // var rightRN = right.filter(joinPred.rightPred()); + // return leftRN.join(JoinRelType.INNER, joinPred.bothPred(), rightRN); + // } - public String leftPred() {return "left";} + // public record JoinPred(RelRN left, RelRN right) implements RexRN { - public String rightPred() {return "right";} + // @Override + // public RexNode semantics() { + // return RexRN.and(left.joinPred(bothPred(), right), left.joinField(0, right).pred(leftPred()), + // left.joinField(1, right).pred(rightPred())).semantics(); + // } - } - } + // public String bothPred() {return "both";} + + // public String leftPred() {return "left";} + + // public String rightPred() {return "right";} + + // } + // } record JoinAddRedundantSemiJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); @@ -201,30 +232,42 @@ public Seq family() { } } -// record JoinCommute() implements RRule { -// static final RelRN left = RelRN.scan("Left", "Left_Type"); -// static final RelRN right = RelRN.scan("Right", "Right_Type"); -// static final String pred = "pred"; -// -// @Override -// public RelRN before() { -// return left.join(JoinRelType.INNER, pred, right); -// } -// -// @Override -// public RelRN after() { -// return right.join(JoinRelType.INNER, new RexRN.Pred( -// pred, true, right.joinFields(left, 1, 0) -// ), left).project("?"); -// } -// } + record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("pred", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + // We need to swap the join fields in the condition + RexRN commutedJoinCond = right.joinPred("pred", left); + return right.join(JoinRelType.INNER, commutedJoinCond, left); + } + } -// record JoinExtractFilter() implements RRule { -// -// } + record JoinExtractFilter() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.trueLiteral(), right).filter(joinCond); + } + } // record JoinProjectTranspose() implements RRule { -// +// // } // JoinConditionPush? @@ -238,52 +281,17 @@ public Seq family() { // } // record JoinToSemiJoin() implements RRule { -// +// // } // record JoinLeftUnionTranspose() implements RRule { -// +// // } // record JoinRightUnionTranspose() implements RRule { // // } - record ProjectFilterTranspose() implements RRule { - static final RelRN source = RelRN.scan("Source", "Source_Type"); - - @Override - public RelRN before() { - var pred = new ProjectFilterTranspose.ProjectFilter(source); - return source.filter(pred).project(pred.proj(), pred.projType()); - } - - @Override - public RelRN after() { - var pred = new ProjectFilterTranspose.ProjectFilter(source); - return source.project(pred.proj(), pred.projType()).filter(pred.pred()); - } - - public record ProjectFilter(RelRN source) implements RexRN { - @Override - public RexNode semantics() { - return source.pred(pred()).proj(proj(), projType()).semantics(); - } - - public String proj() { - return "proj"; - } - - public String projType() { - return "Project_Type"; - } - - public String pred() { - return "pred"; - } - } - } - // record ProjectJoinRemove() implements RRule { // // @Override @@ -322,27 +330,94 @@ public RelRN after() { } } -// record ProjectSetOpTranspose() implements RRule { -// -// } - - record ProjectRemove() implements RRule { - static final RelRN source = RelRN.scan("Source", "Source_Type"); - + //TBD: currently provable for UNION ALL while unprovable for UNION + // record ProjectSetOpTranspose() implements RRule { + // static final RelRN left = RelRN.scan("Left", "Common_Type"); + // static final RelRN right = RelRN.scan("Right", "Common_Type"); + + // @Override + // public RelRN before() { + // RelRN projTmp = left.union(true, right); + // return projTmp.project(projTmp.proj("proj", "Proj_Type")); + // } + + // @Override + // public RelRN after() { + // RelRN projA = left.project(left.proj("proj", "Proj_Type")); + // RelRN projB = right.project(right.proj("proj", "Proj_Type")); + // return projA.union(true, projB); + // } + // } + + + /* TBD: Already optimized by calcite? */ + // record ProjectRemove() implements RRule { + // static final RelRN source = RelRN.scan("Source", "Source_Type"); + + // @Override + // public RelRN before() { + // return source.project(source.field(0)); + // } + + // @Override + // public RelRN after() { + // return source; + // } + // } + + record UnionMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + @Override public RelRN before() { - return source.project(source.field(0)); + return a.union(false, b).union(false, c); } - + @Override public RelRN after() { - return null; + return a.union(false, b, c); } } -// record SemiJoinFilterTranspose() implements RRule { -// -// } + // TBD: generated java file doesn't work + // record PushFilterSemiJoin() implements RRule { + // static final RelRN left = RelRN.scan("Left", "Left_Type"); + // static final RelRN right = RelRN.scan("Right", "Right_Type"); + // static final RexRN joinCond = left.joinPred("join", right); + // static final RexRN leftFilter = left.pred("left_filter"); + + // @Override + // public RelRN before() { + // return left.join(JoinRelType.SEMI, RexRN.and(joinCond, leftFilter), right); + // } + + // @Override + // public RelRN after() { + // return left.filter(leftFilter).join(JoinRelType.SEMI, joinCond, right); + // } + // } + + record SemiJoinFilterTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + static final RexRN filterPred = left.pred("filter"); + + @Override + public RelRN before() { + // Semi-join followed by a filter + return left.join(JoinRelType.SEMI, joinCond, right).filter(filterPred); + } + + @Override + public RelRN after() { + // Push the filter before the semi-join + RelRN leftFiltered = left.filter(filterPred); + return leftFiltered.join(JoinRelType.SEMI, leftFiltered.joinPred("join", right), right); + } + } // record SemiJoinJoinTranspose() implements RRule { // @@ -356,10 +431,6 @@ public RelRN after() { // // } -// record UnionMerge() implements RRule { -// -// } - // record UnionRemove() implements RRule { // // } From 923118cd486cd63bd313054a7e3cee54ff373b07 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Fri, 2 May 2025 17:18:06 -0700 Subject: [PATCH 10/78] add rules FilterSetOpTranspose IntersectMerge JoinExtractFilter SemiJoinFilterTranspose UnionMerge --- src/main/java/org/qed/Generated/CalciteTester.java | 9 ++++----- src/main/java/org/qed/RRuleInstance.java | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 492d108..4727402 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -35,7 +35,6 @@ public static Seq ruleList() { var individuals = Seq.from(RRuleInstance.class.getClasses()).filter(RRule.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule) r); System.out.println(Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor)); - /* To be restored */ // var families = // Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule.RRuleFamily) r); // return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); @@ -52,16 +51,16 @@ public static void generate() { } public static void main(String[] args) throws IOException { - // var rule = new RRuleInstance.FilterReduceTrue(); - // Files.createDirectories(Path.of(rulePath)); - // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + var rule = new RRuleInstance.FilterSetOpTranspose(); + Files.createDirectories(Path.of(rulePath)); + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); // var rules = new RRuleInstance.JoinAssociate(); // Files.createDirectories(Path.of(rulePath)); // for (var rule : rules.family()) { // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); // } // generate(); - + /* FilterIntoJoin */ var tester = new CalciteTester(); var builder = RuleBuilder.create(); diff --git a/src/main/java/org/qed/RRuleInstance.java b/src/main/java/org/qed/RRuleInstance.java index e94556e..e3c5cc6 100644 --- a/src/main/java/org/qed/RRuleInstance.java +++ b/src/main/java/org/qed/RRuleInstance.java @@ -85,6 +85,7 @@ public RelRN after() { } } + // TBD: include intersect to make it a rule familiy record FilterSetOpTranspose() implements RRule { static final RelRN left = RelRN.scan("Left", "Common_Type"); static final RelRN right = RelRN.scan("Right", "Common_Type"); From 0e125f115bbc8d00ce2ace3890e66f3bdb202d97 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Mon, 12 May 2025 21:23:51 -0700 Subject: [PATCH 11/78] refactored rule instances --- pom.xml | 5 ++ .../java/org/qed/Generated/CalciteTester.java | 40 ++++++++--- .../qed/RRuleInstances/FilterIntoJoin.java | 27 ++++++++ .../org/qed/RRuleInstances/FilterMerge.java | 26 +++++++ .../FilterProjectTranspose.java | 25 +++++++ .../qed/RRuleInstances/FilterReduceFalse.java | 24 +++++++ .../qed/RRuleInstances/FilterReduceTrue.java | 24 +++++++ .../RRuleInstances/FilterSetOpTranspose.java | 28 ++++++++ .../qed/RRuleInstances/IntersectMerge.java | 29 ++++++++ .../JoinAddRedundantSemiJoin.java | 26 +++++++ .../org/qed/RRuleInstances/JoinAssociate.java | 69 +++++++++++++++++++ .../org/qed/RRuleInstances/JoinCommute.java | 28 ++++++++ .../qed/RRuleInstances/JoinExtractFilter.java | 26 +++++++ .../org/qed/RRuleInstances/ProjectMerge.java | 27 ++++++++ .../SemiJoinFilterTranspose.java | 30 ++++++++ .../org/qed/RRuleInstances/UnionMerge.java | 26 +++++++ 16 files changed, 451 insertions(+), 9 deletions(-) create mode 100644 src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java create mode 100644 src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/IntersectMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinAssociate.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinCommute.java create mode 100644 src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java create mode 100644 src/main/java/org/qed/RRuleInstances/ProjectMerge.java create mode 100644 src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java create mode 100644 src/main/java/org/qed/RRuleInstances/UnionMerge.java diff --git a/pom.xml b/pom.xml index 71ae84d..d6bcf67 100644 --- a/pom.xml +++ b/pom.xml @@ -108,5 +108,10 @@ system ${env.CVC5_JAVA} + + org.reflections + reflections + 0.10.2 + diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 4727402..0d5237a 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -13,12 +13,16 @@ import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.qed.*; +import org.reflections.Reflections; import java.io.File; import java.io.IOException; import java.lang.reflect.Constructor; +import java.lang.reflect.Modifier; import java.nio.file.Files; import java.nio.file.Path; +import java.util.Set; +import java.util.stream.Collectors; public class CalciteTester { // Assuming that current working directory is the root of the project @@ -32,11 +36,29 @@ public static HepPlanner loadRule(RelOptRule rule) { } public static Seq ruleList() { - var individuals = - Seq.from(RRuleInstance.class.getClasses()).filter(RRule.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule) r); - System.out.println(Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor)); - // var families = - // Seq.from(RRuleInstance.class.getClasses()).filter(RRule.RRuleFamily.class::isAssignableFrom).mapUnchecked(Class::getConstructor).mapUnchecked(Constructor::newInstance).map(r -> (RRule.RRuleFamily) r); + Reflections reflections = new Reflections("org.qed.RRuleInstances"); + + Set> ruleClasses = reflections.getSubTypesOf(RRule.class); + var concreteRuleClasses = ruleClasses.stream() + .filter(clazz -> !clazz.isInterface() && + !Modifier.isAbstract(clazz.getModifiers()) && + !clazz.getName().contains("$")) // Skip all inner classes + .collect(Collectors.toSet()); + + var individuals = Seq.from(concreteRuleClasses) + .mapUnchecked(Class::getConstructor) + .mapUnchecked(Constructor::newInstance) + .map(r -> (RRule) r); + + // var families = Seq.from(reflections.getSubTypesOf(RRule.RRuleFamily.class)) + // .filter(clazz -> !clazz.isInterface() && !Modifier.isAbstract(clazz.getModifiers())) + // .mapUnchecked(clazz -> { + // Constructor constructor = clazz.getDeclaredConstructor(); + // constructor.setAccessible(true); + // return constructor.newInstance(); + // }) + // .map(r -> (RRule.RRuleFamily) r); + // return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); return individuals; } @@ -51,15 +73,15 @@ public static void generate() { } public static void main(String[] args) throws IOException { - var rule = new RRuleInstance.FilterSetOpTranspose(); - Files.createDirectories(Path.of(rulePath)); - new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // var rule = new RRuleInstance.FilterSetOpTranspose(); + // Files.createDirectories(Path.of(rulePath)); + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); // var rules = new RRuleInstance.JoinAssociate(); // Files.createDirectories(Path.of(rulePath)); // for (var rule : rules.family()) { // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); // } - // generate(); + generate(); /* FilterIntoJoin */ var tester = new CalciteTester(); diff --git a/src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java new file mode 100644 index 0000000..c381c30 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java @@ -0,0 +1,27 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterIntoJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + var join = left.join(JoinRelType.INNER, joinCond, right); + return join.filter("outer"); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterMerge.java b/src/main/java/org/qed/RRuleInstances/FilterMerge.java new file mode 100644 index 0000000..a46ec17 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterMerge.java @@ -0,0 +1,26 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.pred("inner"); + static final RexRN outer = source.pred("outer"); + + @Override + public RelRN before() { + return source.filter(inner).filter(outer); + } + + @Override + public RelRN after() { + return source.filter(RexRN.and(inner, outer)); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java b/src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java new file mode 100644 index 0000000..ccd5f68 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java @@ -0,0 +1,25 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterProjectTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.filter(proj.pred("pred")).project(proj); + } + + @Override + public RelRN after() { + return source.project(proj).filter("pred"); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java b/src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java new file mode 100644 index 0000000..026a4c4 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java @@ -0,0 +1,24 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterReduceFalse() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.falseLiteral()); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java b/src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java new file mode 100644 index 0000000..be40a5c --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java @@ -0,0 +1,24 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterReduceTrue() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.trueLiteral()); + } + + @Override + public RelRN after() { + return source; + } +} diff --git a/src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java b/src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java new file mode 100644 index 0000000..85bc91c --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java @@ -0,0 +1,28 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterSetOpTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Common_Type"); + static final RelRN right = RelRN.scan("Right", "Common_Type"); + + @Override + public RelRN before() { + RelRN projTmp = left.union(false, right); + return projTmp.filter(projTmp.pred("filter")); + } + + @Override + public RelRN after() { + RexRN leftPred = left.pred("filter"); + RexRN rightPred = right.pred("filter"); + return left.filter(leftPred).union(false, right.filter(rightPred)); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/IntersectMerge.java b/src/main/java/org/qed/RRuleInstances/IntersectMerge.java new file mode 100644 index 0000000..a4ea680 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/IntersectMerge.java @@ -0,0 +1,29 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record IntersectMerge() implements RRule { + // Use a common type for all relations to make them compatible + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + // Nested INTERSECT: (A INTERSECT B) INTERSECT C + return a.intersect(false, b).intersect(false, c); + } + + @Override + public RelRN after() { + // Flattened INTERSECT: A INTERSECT B INTERSECT C + return a.intersect(false, b, c); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java b/src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java new file mode 100644 index 0000000..937ee61 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java @@ -0,0 +1,26 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinAddRedundantSemiJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, pred, right).join(JoinRelType.INNER, pred, right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinAssociate.java b/src/main/java/org/qed/RRuleInstances/JoinAssociate.java new file mode 100644 index 0000000..2c23ecc --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinAssociate.java @@ -0,0 +1,69 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinAssociate() implements RRule.RRuleFamily { + static final RelRN a = RelRN.scan("A", "A_Type"); + static final RelRN b = RelRN.scan("B", "B_Type"); + static final RelRN c = RelRN.scan("C", "C_Type"); + static final String pred_ab = "pred_ab"; + static final String pred_bc = "pred_bc"; + static final RelRN.Join.JoinType.MetaJoinType mjt_0 = new RelRN.Join.JoinType.MetaJoinType("mjt_0"); + static final RelRN.Join.JoinType.MetaJoinType mjt_1 = new RelRN.Join.JoinType.MetaJoinType("mjt_1"); + static final RelRN.Join.JoinType.MetaJoinType mjt_2 = new RelRN.Join.JoinType.MetaJoinType("mjt_2"); + static final RelRN.Join.JoinType.MetaJoinType mjt_3 = new RelRN.Join.JoinType.MetaJoinType("mjt_3"); + + static final RelRN before_ab = a.join(mjt_0, RexRN.and( + a.joinPred(pred_ab, b), + new RexRN.JoinField(1, a, b).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), b); + + static final RelRN before = before_ab.join(mjt_1, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_bc, true), before_ab.joinFields(c, 1, 2)), + new RexRN.JoinField(1, before_ab, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after_bc = b.join(mjt_2, RexRN.and( + b.joinPred(pred_bc, c), + new RexRN.JoinField(0, b, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after = a.join(mjt_3, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_ab, true), a.joinFields(after_bc, 0, 1)), + new RexRN.JoinField(1, a, after_bc).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), after_bc); + + static final RRule template = new RRule() { + @Override + public RelRN before() { + return before; + } + + @Override + public RelRN after() { + return after; + } + + @Override + public String name() { + return JoinAssociate.class.getSimpleName(); + } + }; + + static Seq assignments() { + var joinTypes = Seq.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL).map(RelRN.Join.JoinType.ConcreteJoinType::new); + return joinTypes.flatMap(jt0 -> joinTypes.flatMap(jt1 -> joinTypes.flatMap(jt2 -> joinTypes.map(jt3 -> new RRule.RRuleGenerator.MetaAssignment(Map.of(mjt_0, jt0, mjt_1, jt1, mjt_2, jt2, mjt_3, jt3)))))); + } + + @Override + public Seq family() { + return new RRule.RRuleGenerator(template, assignments()).family(); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinCommute.java b/src/main/java/org/qed/RRuleInstances/JoinCommute.java new file mode 100644 index 0000000..055ddf4 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinCommute.java @@ -0,0 +1,28 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("pred", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + // We need to swap the join fields in the condition + RexRN commutedJoinCond = right.joinPred("pred", left); + return right.join(JoinRelType.INNER, commutedJoinCond, left); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java b/src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java new file mode 100644 index 0000000..dacf5f3 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java @@ -0,0 +1,26 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinExtractFilter() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.trueLiteral(), right).filter(joinCond); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/ProjectMerge.java b/src/main/java/org/qed/RRuleInstances/ProjectMerge.java new file mode 100644 index 0000000..84124ef --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/ProjectMerge.java @@ -0,0 +1,27 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record ProjectMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.proj("inner", "Inner_Type"); + static final String outer = "outer"; + static final String outerType = "Outer_Type"; + + @Override + public RelRN before() { + return source.project(inner).project(outer, outerType); + } + + @Override + public RelRN after() { + return source.project(inner.proj(outer, outerType)); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java b/src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java new file mode 100644 index 0000000..a7193e8 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java @@ -0,0 +1,30 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record SemiJoinFilterTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + static final RexRN filterPred = left.pred("filter"); + + @Override + public RelRN before() { + // Semi-join followed by a filter + return left.join(JoinRelType.SEMI, joinCond, right).filter(filterPred); + } + + @Override + public RelRN after() { + // Push the filter before the semi-join + RelRN leftFiltered = left.filter(filterPred); + return leftFiltered.join(JoinRelType.SEMI, leftFiltered.joinPred("join", right), right); + } +} diff --git a/src/main/java/org/qed/RRuleInstances/UnionMerge.java b/src/main/java/org/qed/RRuleInstances/UnionMerge.java new file mode 100644 index 0000000..80ab22b --- /dev/null +++ b/src/main/java/org/qed/RRuleInstances/UnionMerge.java @@ -0,0 +1,26 @@ +package org.qed.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record UnionMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + return a.union(false, b).union(false, c); + } + + @Override + public RelRN after() { + return a.union(false, b, c); + } +} \ No newline at end of file From 1320f04ff5b42d176a708051228a7469d1e23df4 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Tue, 20 May 2025 19:25:31 -0700 Subject: [PATCH 12/78] removed unnecessary comments, refactored tests --- .../java/org/qed/Generated/CalciteTester.java | 292 ++---------------- .../RRuleInstances/FilterIntoJoin.java | 2 +- .../RRuleInstances/FilterMerge.java | 2 +- .../FilterProjectTranspose.java | 2 +- .../RRuleInstances/FilterReduceFalse.java | 2 +- .../RRuleInstances/FilterReduceTrue.java | 2 +- .../RRuleInstances/FilterSetOpTranspose.java | 2 +- .../RRuleInstances/IntersectMerge.java | 2 +- .../JoinAddRedundantSemiJoin.java | 2 +- .../RRuleInstances/JoinAssociate.java | 2 +- .../RRuleInstances/JoinCommute.java | 2 +- .../RRuleInstances/JoinExtractFilter.java | 2 +- .../RRuleInstances/ProjectMerge.java | 2 +- .../SemiJoinFilterTranspose.java | 2 +- .../RRuleInstances/UnionMerge.java | 2 +- .../Generated/Tests/FilterIntoJoinTest.java | 51 +++ .../qed/Generated/Tests/FilterMergeTest.java | 46 +++ .../Tests/FilterProjectTransposeTest.java | 52 ++++ .../Tests/FilterSetOpTransposeTest.java | 47 +++ .../Generated/Tests/IntersectMergeTest.java | 46 +++ .../Tests/JoinExtractFilterTest.java | 56 ++++ .../Tests/SemiJoinFilterTransposeTest.java | 60 ++++ .../qed/Generated/Tests/UnionMergeTest.java | 46 +++ 23 files changed, 438 insertions(+), 286 deletions(-) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/FilterIntoJoin.java (95%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/FilterMerge.java (94%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/FilterProjectTranspose.java (94%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/FilterReduceFalse.java (93%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/FilterReduceTrue.java (92%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/FilterSetOpTranspose.java (95%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/IntersectMerge.java (95%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/JoinAddRedundantSemiJoin.java (94%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/JoinAssociate.java (98%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/JoinCommute.java (95%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/JoinExtractFilter.java (94%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/ProjectMerge.java (94%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/SemiJoinFilterTranspose.java (95%) rename src/main/java/org/qed/{ => Generated}/RRuleInstances/UnionMerge.java (94%) create mode 100644 src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java create mode 100644 src/main/java/org/qed/Generated/Tests/FilterMergeTest.java create mode 100644 src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java create mode 100644 src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java create mode 100644 src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java create mode 100644 src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java create mode 100644 src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java create mode 100644 src/main/java/org/qed/Generated/Tests/UnionMergeTest.java diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 0d5237a..2582008 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -36,13 +36,13 @@ public static HepPlanner loadRule(RelOptRule rule) { } public static Seq ruleList() { - Reflections reflections = new Reflections("org.qed.RRuleInstances"); + Reflections reflections = new Reflections("org.qed.Generated.RRuleInstances"); Set> ruleClasses = reflections.getSubTypesOf(RRule.class); var concreteRuleClasses = ruleClasses.stream() .filter(clazz -> !clazz.isInterface() && !Modifier.isAbstract(clazz.getModifiers()) && - !clazz.getName().contains("$")) // Skip all inner classes + !clazz.getName().contains("$")) .collect(Collectors.toSet()); var individuals = Seq.from(concreteRuleClasses) @@ -72,6 +72,22 @@ public static void generate() { ruleList().forEach(r -> tester.serialize(r, genPath)); } + public static void runAllTests() { + try { + org.qed.Generated.Tests.FilterIntoJoinTest.runTest(); + org.qed.Generated.Tests.FilterMergeTest.runTest(); + org.qed.Generated.Tests.FilterProjectTransposeTest.runTest(); + org.qed.Generated.Tests.UnionMergeTest.runTest(); + org.qed.Generated.Tests.IntersectMergeTest.runTest(); + org.qed.Generated.Tests.FilterSetOpTransposeTest.runTest(); + org.qed.Generated.Tests.JoinExtractFilterTest.runTest(); + org.qed.Generated.Tests.SemiJoinFilterTransposeTest.runTest(); + } catch (Exception e) { + System.out.println("Test failed: " + e.getMessage()); + e.printStackTrace(); + } + } + public static void main(String[] args) throws IOException { // var rule = new RRuleInstance.FilterSetOpTranspose(); // Files.createDirectories(Path.of(rulePath)); @@ -81,276 +97,8 @@ public static void main(String[] args) throws IOException { // for (var rule : rules.family()) { // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); // } - generate(); - - /* FilterIntoJoin */ - var tester = new CalciteTester(); - var builder = RuleBuilder.create(); - var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); - builder.addTable(table); - var before = builder.scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) - .build(); - var after = builder.scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, - builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), - builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) - .build(); - var runner = loadRule(FilterIntoJoin.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - /* FilterMerge */ - before = builder.scan(table.getName()) - .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) - .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) - .build(); - after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, - builder.call(builder.genericPredicateOp("inner", true), builder.fields()), - builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) - .build(); - runner = loadRule(FilterMerge.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - /* FilterProjectTranspose */ - builder = RuleBuilder.create(); - table = builder.createQedTable(Seq.of( - Tuple.of(RelType.fromString("INTEGER", true), false), - Tuple.of(RelType.fromString("INTEGER", true), false) - )); - builder.addTable(table); - var scan = builder.scan(table.getName()).build(); - before = builder - .push(scan) - .filter(builder.equals(builder.field(0), builder.literal(10))) - .project(builder.field(0)) - .build(); - after = builder - .push(scan) - .project(builder.field(0)) - .filter(builder.equals(builder.field(0), builder.literal(10))) - .build(); - runner = loadRule(FilterProjectTranspose.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - /* UnionMerge */ - builder = RuleBuilder.create(); - table = builder.createQedTable(Seq.of( - Tuple.of(RelType.fromString("INTEGER", true), false) - )); - builder.addTable(table); - var scan1 = builder.scan(table.getName()).build(); - var scan2 = builder.scan(table.getName()).build(); - var scan3 = builder.scan(table.getName()).build(); - var firstUnion = builder.push(scan1).push(scan2).union(false).build(); - before = builder.push(firstUnion).push(scan3).union(false).build(); - after = builder.push(scan1).push(scan2).push(scan3).union(false, 3).build(); - runner = loadRule(UnionMerge.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - /* IntersectMerge */ - builder = RuleBuilder.create(); - table = builder.createQedTable(Seq.of( - Tuple.of(RelType.fromString("INTEGER", true), false) - )); - builder.addTable(table); - scan1 = builder.scan(table.getName()).build(); - scan2 = builder.scan(table.getName()).build(); - scan3 = builder.scan(table.getName()).build(); - var firstIntersect = builder.push(scan1).push(scan2).intersect(false).build(); - before = builder.push(firstIntersect).push(scan3).intersect(false).build(); - after = builder.push(scan1).push(scan2).push(scan3).intersect(false, 3).build(); - runner = loadRule(IntersectMerge.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - /* FilterSetOpTranspose */ - builder = RuleBuilder.create(); - table = builder.createQedTable(Seq.of( - Tuple.of(RelType.fromString("INTEGER", true), false) - )); - builder.addTable(table); - scan1 = builder.scan(table.getName()).build(); - scan2 = builder.scan(table.getName()).build(); - var union = builder.push(scan1).push(scan2).union(false).build(); - before = builder.push(union).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); - var filteredScan1 = builder.push(scan1).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); - var filteredScan2 = builder.push(scan2).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); - after = builder.push(filteredScan1).push(filteredScan2).union(false).build(); - runner = loadRule(FilterSetOpTranspose.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - /* JoinExtractFilter */ - builder = RuleBuilder.create(); - var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); - var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("VARCHAR", true), false))); - builder.addTable(leftTable); - builder.addTable(rightTable); - var leftScan = builder.scan(leftTable.getName()).build(); - var rightScan = builder.scan(rightTable.getName()).build(); - before = builder.push(leftScan).push(rightScan).join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.field(2, 0, 0), builder.field(2, 1, 0))).build(); - var trueJoin = builder.push(leftScan).push(rightScan).join(JoinRelType.INNER, builder.literal(true)).build(); - after = builder.push(trueJoin).filter(builder.call(builder.genericPredicateOp("join", true), builder.field(0), builder.field(1))).build(); - runner = loadRule(JoinExtractFilter.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - /* SemiJoinFilterTranspose */ - builder = RuleBuilder.create(); - leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); - rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); - builder.addTable(leftTable); - builder.addTable(rightTable); - leftScan = builder.scan(leftTable.getName()).build(); - rightScan = builder.scan(rightTable.getName()).build(); - builder.push(leftScan); - builder.push(rightScan); - var joinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); - var semiJoin = builder.join(JoinRelType.SEMI, joinPredicate).build(); - builder.push(semiJoin); - var filterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); - before = builder.filter(filterPredicate).build(); - builder.push(leftScan); - var leftFilterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); - var filteredLeft = builder.filter(leftFilterPredicate).build(); - builder.push(filteredLeft); - builder.push(rightScan); - var afterJoinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); - after = builder.join(JoinRelType.SEMI, afterJoinPredicate).build(); - runner = loadRule(SemiJoinFilterTranspose.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - /* JoinCommute */ - // TBD: failed - // builder = RuleBuilder.create(); - // leftTable = builder.createQedTable(Seq.of( - // Tuple.of(RelType.fromString("INTEGER", true), false) - // )); - // rightTable = builder.createQedTable(Seq.of( - // Tuple.of(RelType.fromString("INTEGER", true), false) - // )); - // builder.addTable(leftTable); - // builder.addTable(rightTable); - // leftScan = builder.scan(leftTable.getName()).build(); - // rightScan = builder.scan(rightTable.getName()).build(); - // before = builder - // .push(leftScan) - // .push(rightScan) - // .join( - // JoinRelType.INNER, - // builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)) - // ) - // .build(); - // after = builder - // .push(rightScan) - // .push(leftScan) - // .join( - // JoinRelType.INNER, - // builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)) - // ) - // .build(); - // runner = loadRule(JoinCommute.Config.DEFAULT.toRule()); - // tester.verify(runner, before, after); - - /* ProjectMerge */ - // TBD: Automatically optimized? - // builder = RuleBuilder.create(); - // table = builder.createQedTable(Seq.of( - // Tuple.of(RelType.fromString("INTEGER", true), false), - // Tuple.of(RelType.fromString("INTEGER", true), false), - // Tuple.of(RelType.fromString("INTEGER", true), false) - // )); - // builder.addTable(table); - // scan = builder.scan(table.getName()).build(); - // var innerProject = builder - // .push(scan) - // .project(builder.field(0), builder.field(1)) - // .build(); - // before = builder - // .push(innerProject) - // .project(builder.field(0)) - // .build(); - // after = builder - // .push(scan) - // .project(builder.field(0)) - // .build(); - // runner = loadRule(ProjectMerge.Config.DEFAULT.toRule()); - // tester.verify(runner, before, after); - - /* JoinAddRedundantSemiJoin */ - // TBD: failed - // builder = RuleBuilder.create(); - // leftTable = builder.createQedTable(Seq.of( - // Tuple.of(RelType.fromString("INTEGER", true), false) - // )); - // rightTable = builder.createQedTable(Seq.of( - // Tuple.of(RelType.fromString("INTEGER", true), false) - // )); - // builder.addTable(leftTable); - // builder.addTable(rightTable); - // leftScan = builder.scan(leftTable.getName()).build(); - // rightScan = builder.scan(rightTable.getName()).build(); - // before = builder - // .push(leftScan) - // .push(rightScan) - // .join( - // JoinRelType.INNER, - // builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)) - // ) - // .build(); - // after = builder - // .push(leftScan) - // .push(rightScan) - // .join(JoinRelType.SEMI, builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0))) - // .push(rightScan) - // .join(JoinRelType.INNER, builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0))) - // .build(); - // runner = loadRule(JoinAddRedundantSemiJoin.Config.DEFAULT.toRule()); - // tester.verify(runner, before, after); - - /* FilterReduceFalse */ - // TBD: Automatically optimized? - // builder = RuleBuilder.create(); - // table = builder.createQedTable(Seq.of( - // Tuple.of(RelType.fromString("INTEGER", true), false), - // Tuple.of(RelType.fromString("INTEGER", true), false) - // )); - // builder.addTable(table); - // scan = builder.scan(table.getName()).build(); - // before = builder - // .push(scan) - // .filter(builder.equals(builder.field(0), builder.literal(10))) - // .filter(builder.literal(false)) - // .build(); - // after = builder - // .push(scan) - // .empty() - // .build(); - // runner = loadRule(FilterReduceFalse.Config.DEFAULT.toRule()); - // tester.verify(runner, before, after); - - /* FilterReduceTrue */ - // TBD: Automatically optimized? And something wrong with this rule? - // builder = RuleBuilder.create(); - // table = builder.createQedTable(Seq.of( - // Tuple.of(RelType.fromString("INTEGER", true), false), - // Tuple.of(RelType.fromString("INTEGER", true), false) - // )); - // builder.addTable(table); - // scan = builder.scan(table.getName()).build(); - // before = builder - // .push(scan) - // .filter(builder.equals(builder.field(0), builder.literal(10))) - // .filter(builder.literal(true)) - // .filter(builder.equals(builder.field(1), builder.literal(20))) - // .build(); - // after = builder - // .push(scan) - // .filter(builder.equals(builder.field(0), builder.literal(10))) - // .filter(builder.equals(builder.field(1), builder.literal(20))) - // .build(); - // runner = loadRule(FilterReduceTrue.Config.DEFAULT.toRule()); - // tester.verify(runner, before, after); + // generate(); + runAllTests(); } public void serialize(RRule rule, String path) { diff --git a/src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java similarity index 95% rename from src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java rename to src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index c381c30..24165aa 100644 --- a/src/main/java/org/qed/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/FilterMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterMerge.java similarity index 94% rename from src/main/java/org/qed/RRuleInstances/FilterMerge.java rename to src/main/java/org/qed/Generated/RRuleInstances/FilterMerge.java index a46ec17..d90aa6d 100644 --- a/src/main/java/org/qed/RRuleInstances/FilterMerge.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterMerge.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterProjectTranspose.java similarity index 94% rename from src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java rename to src/main/java/org/qed/Generated/RRuleInstances/FilterProjectTranspose.java index ccd5f68..31a3a74 100644 --- a/src/main/java/org/qed/RRuleInstances/FilterProjectTranspose.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterProjectTranspose.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceFalse.java similarity index 93% rename from src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java rename to src/main/java/org/qed/Generated/RRuleInstances/FilterReduceFalse.java index 026a4c4..8b4dd4b 100644 --- a/src/main/java/org/qed/RRuleInstances/FilterReduceFalse.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceFalse.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceTrue.java similarity index 92% rename from src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java rename to src/main/java/org/qed/Generated/RRuleInstances/FilterReduceTrue.java index be40a5c..c3dd472 100644 --- a/src/main/java/org/qed/RRuleInstances/FilterReduceTrue.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceTrue.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterSetOpTranspose.java similarity index 95% rename from src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java rename to src/main/java/org/qed/Generated/RRuleInstances/FilterSetOpTranspose.java index 85bc91c..2dcf9e2 100644 --- a/src/main/java/org/qed/RRuleInstances/FilterSetOpTranspose.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterSetOpTranspose.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/IntersectMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/IntersectMerge.java similarity index 95% rename from src/main/java/org/qed/RRuleInstances/IntersectMerge.java rename to src/main/java/org/qed/Generated/RRuleInstances/IntersectMerge.java index a4ea680..27afe67 100644 --- a/src/main/java/org/qed/RRuleInstances/IntersectMerge.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/IntersectMerge.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinAddRedundantSemiJoin.java similarity index 94% rename from src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java rename to src/main/java/org/qed/Generated/RRuleInstances/JoinAddRedundantSemiJoin.java index 937ee61..ec54c50 100644 --- a/src/main/java/org/qed/RRuleInstances/JoinAddRedundantSemiJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinAddRedundantSemiJoin.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/JoinAssociate.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinAssociate.java similarity index 98% rename from src/main/java/org/qed/RRuleInstances/JoinAssociate.java rename to src/main/java/org/qed/Generated/RRuleInstances/JoinAssociate.java index 2c23ecc..059695e 100644 --- a/src/main/java/org/qed/RRuleInstances/JoinAssociate.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinAssociate.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/JoinCommute.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java similarity index 95% rename from src/main/java/org/qed/RRuleInstances/JoinCommute.java rename to src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java index 055ddf4..22f2344 100644 --- a/src/main/java/org/qed/RRuleInstances/JoinCommute.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinExtractFilter.java similarity index 94% rename from src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java rename to src/main/java/org/qed/Generated/RRuleInstances/JoinExtractFilter.java index dacf5f3..63db3ec 100644 --- a/src/main/java/org/qed/RRuleInstances/JoinExtractFilter.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinExtractFilter.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/ProjectMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/ProjectMerge.java similarity index 94% rename from src/main/java/org/qed/RRuleInstances/ProjectMerge.java rename to src/main/java/org/qed/Generated/RRuleInstances/ProjectMerge.java index 84124ef..7458002 100644 --- a/src/main/java/org/qed/RRuleInstances/ProjectMerge.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/ProjectMerge.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java similarity index 95% rename from src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java rename to src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java index a7193e8..d1bc90e 100644 --- a/src/main/java/org/qed/RRuleInstances/SemiJoinFilterTranspose.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/RRuleInstances/UnionMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/UnionMerge.java similarity index 94% rename from src/main/java/org/qed/RRuleInstances/UnionMerge.java rename to src/main/java/org/qed/Generated/RRuleInstances/UnionMerge.java index 80ab22b..b71d320 100644 --- a/src/main/java/org/qed/RRuleInstances/UnionMerge.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/UnionMerge.java @@ -1,4 +1,4 @@ -package org.qed.RRuleInstances; +package org.qed.Generated.RRuleInstances; import kala.collection.Map; import kala.collection.Seq; diff --git a/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java b/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java new file mode 100644 index 0000000..6abfbd2 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java @@ -0,0 +1,51 @@ + +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.FilterIntoJoin; +import org.qed.RuleBuilder; + +/** + * Test for the FilterIntoJoin rule. + */ +public class FilterIntoJoinTest { + + /** + * Run test for FilterIntoJoin rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(table); + + var before = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), + builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.FilterIntoJoin.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running FilterIntoJoin test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java b/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java new file mode 100644 index 0000000..33f63cc --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java @@ -0,0 +1,46 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.FilterMerge; +import org.qed.RuleBuilder; + +/** + * Test for the FilterMerge rule. + */ +public class FilterMergeTest { + + /** + * Run test for FilterMerge rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(table); + + var before = builder.scan(table.getName()) + .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) + .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("inner", true), builder.fields()), + builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.FilterMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running FilterMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java b/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java new file mode 100644 index 0000000..c030841 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java @@ -0,0 +1,52 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.FilterProjectTranspose; +import org.qed.RuleBuilder; + +/** + * Test for the FilterProjectTranspose rule. + */ +public class FilterProjectTransposeTest { + + /** + * Run test for FilterProjectTranspose rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan = builder.scan(table.getName()).build(); + + var before = builder + .push(scan) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .project(builder.field(0)) + .build(); + + var after = builder + .push(scan) + .project(builder.field(0)) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.FilterProjectTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running FilterProjectTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java b/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java new file mode 100644 index 0000000..50f0d1c --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java @@ -0,0 +1,47 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.FilterSetOpTranspose; +import org.qed.RuleBuilder; + +/** + * Test for the FilterSetOpTranspose rule. + */ +public class FilterSetOpTransposeTest { + + /** + * Run test for FilterSetOpTranspose rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + + var union = builder.push(scan1).push(scan2).union(false).build(); + var before = builder.push(union).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + + var filteredScan1 = builder.push(scan1).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + var filteredScan2 = builder.push(scan2).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + var after = builder.push(filteredScan1).push(filteredScan2).union(false).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.FilterSetOpTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running FilterSetOpTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java b/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java new file mode 100644 index 0000000..089d9ef --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java @@ -0,0 +1,46 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.IntersectMerge; +import org.qed.RuleBuilder; + +/** + * Test for the IntersectMerge rule. + */ +public class IntersectMergeTest { + + /** + * Run test for IntersectMerge rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + var firstIntersect = builder.push(scan1).push(scan2).intersect(false).build(); + var before = builder.push(firstIntersect).push(scan3).intersect(false).build(); + + var after = builder.push(scan1).push(scan2).push(scan3).intersect(false, 3).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.IntersectMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running IntersectMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java b/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java new file mode 100644 index 0000000..72833a6 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java @@ -0,0 +1,56 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.JoinExtractFilter; +import org.qed.RuleBuilder; + +/** + * Test for the JoinExtractFilter rule. + */ +public class JoinExtractFilterTest { + + /** + * Run test for JoinExtractFilter rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("VARCHAR", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + + var before = builder.push(leftScan) + .push(rightScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), + builder.field(2, 0, 0), builder.field(2, 1, 0))) + .build(); + + var trueJoin = builder.push(leftScan) + .push(rightScan) + .join(JoinRelType.INNER, builder.literal(true)) + .build(); + + var after = builder.push(trueJoin) + .filter(builder.call(builder.genericPredicateOp("join", true), builder.field(0), builder.field(1))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.JoinExtractFilter.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running JoinExtractFilter test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java b/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java new file mode 100644 index 0000000..2fdde4f --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java @@ -0,0 +1,60 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.SemiJoinFilterTranspose; +import org.qed.RuleBuilder; + +/** + * Test for the SemiJoinFilterTranspose rule. + */ +public class SemiJoinFilterTransposeTest { + + /** + * Run test for SemiJoinFilterTranspose rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + + // Build the "before" relation + builder.push(leftScan); + builder.push(rightScan); + var joinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); + var semiJoin = builder.join(JoinRelType.SEMI, joinPredicate).build(); + builder.push(semiJoin); + var filterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); + var before = builder.filter(filterPredicate).build(); + + // Build the expected "after" relation + builder.push(leftScan); + var leftFilterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); + var filteredLeft = builder.filter(leftFilterPredicate).build(); + builder.push(filteredLeft); + builder.push(rightScan); + var afterJoinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); + var after = builder.join(JoinRelType.SEMI, afterJoinPredicate).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.SemiJoinFilterTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running SemiJoinFilterTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java b/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java new file mode 100644 index 0000000..a29bef8 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java @@ -0,0 +1,46 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.UnionMerge; +import org.qed.RuleBuilder; + +/** + * Test for the UnionMerge rule. + */ +public class UnionMergeTest { + + /** + * Run test for UnionMerge rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + var firstUnion = builder.push(scan1).push(scan2).union(false).build(); + var before = builder.push(firstUnion).push(scan3).union(false).build(); + + var after = builder.push(scan1).push(scan2).push(scan3).union(false, 3).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.UnionMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running UnionMerge test..."); + runTest(); + } +} \ No newline at end of file From df26af8df352e8c541ef4941090050e64e15311c Mon Sep 17 00:00:00 2001 From: wkaiz Date: Fri, 23 May 2025 13:00:42 -0700 Subject: [PATCH 13/78] add rules projectfiltertranspose and joinpushtransitivepredicates --- pom.xml | 9 ---- .../java/org/qed/Generated/CalciteTester.java | 18 ++++---- .../JoinPushTransitivePredicates.java | 39 +++++++++++++++++ .../qed/Generated/ProjectFilterTranspose.java | 39 +++++++++++++++++ .../JoinPushTransitivePredicates.java | 23 ++++++++++ .../ProjectFilterTranspose.java | 20 +++++++++ .../JoinPushTransitivePredicatesTest.java | 42 +++++++++++++++++++ .../Tests/ProjectFilterTransposeTest.java | 41 ++++++++++++++++++ 8 files changed, 213 insertions(+), 18 deletions(-) create mode 100644 src/main/java/org/qed/Generated/JoinPushTransitivePredicates.java create mode 100644 src/main/java/org/qed/Generated/ProjectFilterTranspose.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/JoinPushTransitivePredicates.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/ProjectFilterTranspose.java create mode 100644 src/main/java/org/qed/Generated/Tests/JoinPushTransitivePredicatesTest.java create mode 100644 src/main/java/org/qed/Generated/Tests/ProjectFilterTransposeTest.java diff --git a/pom.xml b/pom.xml index d6bcf67..de798ae 100644 --- a/pom.xml +++ b/pom.xml @@ -21,7 +21,6 @@ -classpath org.qed.Main - ${args} @@ -54,7 +53,6 @@ 21 21 - --enable-preview @@ -101,13 +99,6 @@ kala-common 0.67.0 - - io.github - cvc5 - 1.1.1 - system - ${env.CVC5_JAVA} - org.reflections reflections diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 2582008..2af5c84 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -89,15 +89,15 @@ public static void runAllTests() { } public static void main(String[] args) throws IOException { - // var rule = new RRuleInstance.FilterSetOpTranspose(); - // Files.createDirectories(Path.of(rulePath)); - // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); - // var rules = new RRuleInstance.JoinAssociate(); - // Files.createDirectories(Path.of(rulePath)); - // for (var rule : rules.family()) { - // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); - // } - // generate(); +// var rule = new RRuleInstance.FilterSetOpTranspose(); +// Files.createDirectories(Path.of(rulePath)); +// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); +// var rules = new RRuleInstance.JoinAssociate(); +// Files.createDirectories(Path.of(rulePath)); +// for (var rule : rules.family()) { +// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); +// } + generate(); runAllTests(); } diff --git a/src/main/java/org/qed/Generated/JoinPushTransitivePredicates.java b/src/main/java/org/qed/Generated/JoinPushTransitivePredicates.java new file mode 100644 index 0000000..50c24a5 --- /dev/null +++ b/src/main/java/org/qed/Generated/JoinPushTransitivePredicates.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class JoinPushTransitivePredicates extends RelRule { + protected JoinPushTransitivePredicates(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).push(call.rel(3)).join(JoinRelType.INNER, var_4.push(call.rel(2)).push(call.rel(3)).and(((LogicalJoin) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinPushTransitivePredicates toRule() { + return new JoinPushTransitivePredicates(this); + } + + @Override + default String description() { + return "JoinPushTransitivePredicates"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/ProjectFilterTranspose.java b/src/main/java/org/qed/Generated/ProjectFilterTranspose.java new file mode 100644 index 0000000..cbbd7d8 --- /dev/null +++ b/src/main/java/org/qed/Generated/ProjectFilterTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class ProjectFilterTranspose extends RelRule { + protected ProjectFilterTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).project(((LogicalProject) call.rel(1)).getProjects()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default ProjectFilterTranspose toRule() { + return new ProjectFilterTranspose(this); + } + + @Override + default String description() { + return "ProjectFilterTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinPushTransitivePredicates.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinPushTransitivePredicates.java new file mode 100644 index 0000000..1aaa5d9 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinPushTransitivePredicates.java @@ -0,0 +1,23 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record JoinPushTransitivePredicates() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN cond1 = left.joinPred("cond1", right); + static final RexRN cond2 = left.joinPred("cond2", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, cond1, right).filter(cond2); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(cond1, cond2), right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/ProjectFilterTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/ProjectFilterTranspose.java new file mode 100644 index 0000000..8a17d2c --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/ProjectFilterTranspose.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record ProjectFilterTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.project(proj).filter("pred"); + } + + @Override + public RelRN after() { + return source.filter(proj.pred("pred")).project(proj); + } +} diff --git a/src/main/java/org/qed/Generated/Tests/JoinPushTransitivePredicatesTest.java b/src/main/java/org/qed/Generated/Tests/JoinPushTransitivePredicatesTest.java new file mode 100644 index 0000000..3da47c7 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/JoinPushTransitivePredicatesTest.java @@ -0,0 +1,42 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinPushTransitivePredicatesTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var before = builder.scan(table.getName()) + .scan(table.getName()).join(JoinRelType.INNER,builder.call(builder.genericPredicateOp("cond1", true), builder.joinFields())) + .filter(builder.call(builder.genericPredicateOp("cond2", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("cond1", true), builder.joinFields()), + builder.call(builder.genericPredicateOp("cond2", true), builder.joinFields()))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.JoinPushTransitivePredicates.Config.DEFAULT.toRule()); + + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running ProjectFilterTranspose test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests/ProjectFilterTransposeTest.java b/src/main/java/org/qed/Generated/Tests/ProjectFilterTransposeTest.java new file mode 100644 index 0000000..05bfc31 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/ProjectFilterTransposeTest.java @@ -0,0 +1,41 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class ProjectFilterTransposeTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan = builder.scan(table.getName()).build(); + + var before = builder + .push(scan) + .project(builder.field(0)) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .build(); + + var after = builder + .push(scan) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .project(builder.field(0)) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.ProjectFilterTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running ProjectFilterTranspose test..."); + runTest(); + } +} From 86dbedb5430d8c390cb2cc9cd37beff4425af563 Mon Sep 17 00:00:00 2001 From: Yushin Liang Date: Mon, 26 May 2025 00:21:14 +0800 Subject: [PATCH 14/78] verified SemiJoinRemove rule in PipelineInstance.java --- pom.xml | 8 +- .../java/org/qed/Generated/CalciteTester.java | 113 ------------------ .../org/qed/Generated/PipelineInstance.java | 47 ++++++++ .../org/qed/Generated/SemiJoinRemove.java | 39 ++++++ src/main/java/org/qed/RRuleInstance.java | 2 +- 5 files changed, 92 insertions(+), 117 deletions(-) create mode 100644 src/main/java/org/qed/Generated/PipelineInstance.java create mode 100644 src/main/java/org/qed/Generated/SemiJoinRemove.java diff --git a/pom.xml b/pom.xml index 77457e0..3daca99 100644 --- a/pom.xml +++ b/pom.xml @@ -52,9 +52,11 @@ org.apache.maven.plugins maven-compiler-plugin - 21 - 21 - --enable-preview + 23 + 23 + + --enable-preview + diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index cd98cef..0735617 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -88,118 +88,6 @@ public static void runAllTests() { System.out.println("Test failed: " + e.getMessage()); e.printStackTrace(); } -<<<<<<< HEAD - - var r = new RRuleInstance.ProjectJoinTranspose(); - new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, STR."\{r.name()}-\{r.info()}.json").toFile(), r.toJson()); - - generate(); - var tester = new CalciteTester(); - var builder = RuleBuilder.create(); - var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); - builder.addTable(table); - - var before = builder.scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .project(builder.call(builder.genericProjectionOp("proj", RelType.fromString("INTEGER", true)), builder.fields(0))) - .build(); - var leftProjected = builder.scan(table.getName()) - .project(builder.call(builder.genericProjectionOp("proj", RelType.fromString("INTEGER", true)), builder.fields(0))) - .build(); - var after = builder.push(leftProjected) - .push(builder.scan(table.getName())) - .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .build(); - var runner = loadRule(ProjectJoinTranspose.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - before = builder.scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) - .build(); - var leftFiltered = builder.scan(table.getName()).filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) - after = builder.push(leftFiltered) - .push(builder.scan(table.getName())) - .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .build(); - runner = loadRule(SemiJoinFilterTranspose.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - var semiFirst = builder.scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .build(); - before = builder.push(semiFirst) - .push(builder.scan(table.getName())) - .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .build(); - var innerFirst = builder.scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .build(); - after = builder.push(innerFirst) - .push(builder.scan(table.getName())) - .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .build(); - runner = loadRule(SemiJoinJoinTranspose.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - before = builder.scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .project(builder.call(builder.genericProjectionOp("proj", RelType.fromString("INTEGER", true)), builder.fields(0))) - .build(); - leftProjected = builder.scan(table.getName()) - .project(builder.call(builder.genericProjectionOp("proj", RelType.fromString("INTEGER", true)), builder.fields(0))) - .build(); - after = builder.push(leftProjected) - .push(builder.scan(table.getName())) - .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .build(); - runner = loadRule(SemiJoinProjectTranspose.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - before = builder.scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.SEMI, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) - .build(); - after = builder.scan(table.getName()).build(); - runner = loadRule(SemiJoinRemove.Config.DEFAULT.toRule()); - tester.verify(runner, before, after); - - - -// generate(); -// var tester = new CalciteTester(); -// var builder = RuleBuilder.create(); -// var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); -// builder.addTable(table); -// var before = builder.scan(table.getName()) -// .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) -// .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) -// .build(); -// var after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, -// builder.call(builder.genericPredicateOp("inner", true), builder.fields()), -// builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) -// .build(); -// var runner = loadRule(FilterMerge.Config.DEFAULT.toRule()); -// tester.verify(runner, before, after); -// before = builder.scan(table.getName()) -// .scan(table.getName()) -// .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) -// .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) -// .build(); -// after = builder.scan(table.getName()) -// .scan(table.getName()) -// .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, -// builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), -// builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) -// .build(); -// runner = loadRule(FilterIntoJoin.Config.DEFAULT.toRule()); -// tester.verify(runner, before, after); -======= } public static void main(String[] args) throws IOException { @@ -213,7 +101,6 @@ public static void main(String[] args) throws IOException { // } // generate(); runAllTests(); ->>>>>>> upstream/dsl } public void serialize(RRule rule, String path) { diff --git a/src/main/java/org/qed/Generated/PipelineInstance.java b/src/main/java/org/qed/Generated/PipelineInstance.java new file mode 100644 index 0000000..a42688a --- /dev/null +++ b/src/main/java/org/qed/Generated/PipelineInstance.java @@ -0,0 +1,47 @@ +package org.qed.Generated; + +import com.fasterxml.jackson.databind.ObjectMapper; +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.*; +import org.qed.RRuleInstance.SemiJoinRemove; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; + + +public class PipelineInstance { + static final String rulePath = "rules"; + static final String genPath = "src/main/java/org/qed/Generated"; + public static void runTest() { + var rule = new RRuleInstance.SemiJoinRemove(); + var builder = RuleBuilder.create(); + var tester = new CalciteTester(); + + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.SEMI, builder.literal(true)) + .build(); + var after = builder + .scan(table.getName()) + .build(); + + + // tester.serialize(rule, genPath); + var runner = tester.loadRule(org.qed.Generated.SemiJoinRemove.Config.DEFAULT.toRule()); + System.out.println("verifying:"); + tester.verify(runner, before, after); + } + + public static void main(String[] args) throws IOException { + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/SemiJoinRemove.java b/src/main/java/org/qed/Generated/SemiJoinRemove.java new file mode 100644 index 0000000..8d02d36 --- /dev/null +++ b/src/main/java/org/qed/Generated/SemiJoinRemove.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class SemiJoinRemove extends RelRule { + protected SemiJoinRemove(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(1)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default SemiJoinRemove toRule() { + return new SemiJoinRemove(this); + } + + @Override + default String description() { + return "SemiJoinRemove"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/RRuleInstance.java b/src/main/java/org/qed/RRuleInstance.java index 09c712a..f7e3711 100644 --- a/src/main/java/org/qed/RRuleInstance.java +++ b/src/main/java/org/qed/RRuleInstance.java @@ -314,7 +314,7 @@ record ProjectJoinTranspose() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); static final RexRN proj = left.proj("proj", "Project_Type"); - static final String joinCond = left.joinPred("join", right); + static final String joinCond = left.joinPred("join", right).toString(); @Override public RelRN before() { From 639d9edc76ba790e83002b38b7cce26b0887c321 Mon Sep 17 00:00:00 2001 From: Yushin Liang Date: Tue, 27 May 2025 14:19:34 +0800 Subject: [PATCH 15/78] generated semijoinremove calcite code --- .../org/qed/Generated/PipelineInstance.java | 47 ------------------- 1 file changed, 47 deletions(-) delete mode 100644 src/main/java/org/qed/Generated/PipelineInstance.java diff --git a/src/main/java/org/qed/Generated/PipelineInstance.java b/src/main/java/org/qed/Generated/PipelineInstance.java deleted file mode 100644 index a42688a..0000000 --- a/src/main/java/org/qed/Generated/PipelineInstance.java +++ /dev/null @@ -1,47 +0,0 @@ -package org.qed.Generated; - -import com.fasterxml.jackson.databind.ObjectMapper; -import kala.collection.Seq; -import kala.tuple.Tuple; -import org.apache.calcite.plan.RelOptRule; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.JoinRelType; -import org.qed.*; -import org.qed.RRuleInstance.SemiJoinRemove; - -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; - - -public class PipelineInstance { - static final String rulePath = "rules"; - static final String genPath = "src/main/java/org/qed/Generated"; - public static void runTest() { - var rule = new RRuleInstance.SemiJoinRemove(); - var builder = RuleBuilder.create(); - var tester = new CalciteTester(); - - var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); - builder.addTable(table); - - var before = builder - .scan(table.getName()) - .scan(table.getName()) - .join(JoinRelType.SEMI, builder.literal(true)) - .build(); - var after = builder - .scan(table.getName()) - .build(); - - - // tester.serialize(rule, genPath); - var runner = tester.loadRule(org.qed.Generated.SemiJoinRemove.Config.DEFAULT.toRule()); - System.out.println("verifying:"); - tester.verify(runner, before, after); - } - - public static void main(String[] args) throws IOException { - runTest(); - } -} \ No newline at end of file From d78e7259ed65fd2aad1657f5de97ea9206e8e183 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Tue, 27 May 2025 16:12:28 -0700 Subject: [PATCH 16/78] refactor semijoinremove --- .../RRuleInstances/SemiJoinRemove.java | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java new file mode 100644 index 0000000..d08c439 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java @@ -0,0 +1,21 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +record SemiJoinRemove() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, RexRN.trueLiteral(), right); + } + + @Override + public RelRN after() { + return left; + } +} From 8f8f65b6c0c3d3b513948b2dea79a49814d39daf Mon Sep 17 00:00:00 2001 From: wkaiz Date: Tue, 27 May 2025 16:42:34 -0700 Subject: [PATCH 17/78] removing comments; adding semijoinjointranspose and semijoinprojecttrasponse with its test case --- .../RRuleInstances/SemiJoinJoinTranspose.java | 25 +++++++++ .../SemiJoinProjectTranspose.java | 23 ++++++++ .../RRuleInstances/SemiJoinRemove.java | 2 +- .../qed/Generated/SemiJoinJoinTranspose.java | 39 +++++++++++++ .../Generated/SemiJoinProjectTranspose.java | 39 +++++++++++++ .../Generated/Tests/FilterIntoJoinTest.java | 11 +--- .../qed/Generated/Tests/FilterMergeTest.java | 12 +--- .../Tests/FilterProjectTransposeTest.java | 12 +--- .../Tests/FilterSetOpTransposeTest.java | 12 +--- .../Generated/Tests/IntersectMergeTest.java | 12 +--- .../Tests/JoinExtractFilterTest.java | 12 +--- .../Tests/SemiJoinFilterTransposeTest.java | 17 +----- .../Tests/SemiJoinProjectTransposeTest.java | 55 +++++++++++++++++++ .../qed/Generated/Tests/UnionMergeTest.java | 11 +--- 14 files changed, 192 insertions(+), 90 deletions(-) create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/SemiJoinJoinTranspose.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/SemiJoinProjectTranspose.java create mode 100644 src/main/java/org/qed/Generated/SemiJoinJoinTranspose.java create mode 100644 src/main/java/org/qed/Generated/SemiJoinProjectTranspose.java create mode 100644 src/main/java/org/qed/Generated/Tests/SemiJoinProjectTransposeTest.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinJoinTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinJoinTranspose.java new file mode 100644 index 0000000..b1a4a3d --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinJoinTranspose.java @@ -0,0 +1,25 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + + +public record SemiJoinJoinTranspose() implements RRule { + static final RelRN left = RelRN.scan("left", "left_Type"); + static final RelRN middle = RelRN.scan("middle", "middle_Type"); + static final RelRN right = RelRN.scan("right", "right_Type"); + static final RexRN semiCond = left.joinPred("semi", middle); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right).join(JoinRelType.SEMI, semiCond, middle); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, semiCond, middle).join(JoinRelType.INNER, joinCond, right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinProjectTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinProjectTranspose.java new file mode 100644 index 0000000..71d2dfe --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinProjectTranspose.java @@ -0,0 +1,23 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record SemiJoinProjectTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "left_type"); + static final RelRN right = RelRN.scan("Right", "right_type"); + static final RexRN proj = left.proj("proj", "proj_type"); + static final RexRN semiCond = left.joinPred("semi", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, semiCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.SEMI, semiCond, right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java index d08c439..3b21e3f 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java @@ -5,7 +5,7 @@ import org.qed.RelRN; import org.qed.RexRN; -record SemiJoinRemove() implements RRule { +public record SemiJoinRemove() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); diff --git a/src/main/java/org/qed/Generated/SemiJoinJoinTranspose.java b/src/main/java/org/qed/Generated/SemiJoinJoinTranspose.java new file mode 100644 index 0000000..33cc330 --- /dev/null +++ b/src/main/java/org/qed/Generated/SemiJoinJoinTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class SemiJoinJoinTranspose extends RelRule { + protected SemiJoinJoinTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(4)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(0)).getCondition()).push(call.rel(3)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default SemiJoinJoinTranspose toRule() { + return new SemiJoinJoinTranspose(this); + } + + @Override + default String description() { + return "SemiJoinJoinTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalJoin.class).inputs(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/SemiJoinProjectTranspose.java b/src/main/java/org/qed/Generated/SemiJoinProjectTranspose.java new file mode 100644 index 0000000..8ffffca --- /dev/null +++ b/src/main/java/org/qed/Generated/SemiJoinProjectTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class SemiJoinProjectTranspose extends RelRule { + protected SemiJoinProjectTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).project(((LogicalProject) call.rel(0)).getProjects()).push(call.rel(3)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default SemiJoinProjectTranspose toRule() { + return new SemiJoinProjectTranspose(this); + } + + @Override + default String description() { + return "SemiJoinProjectTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalProject.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java b/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java index 6abfbd2..f8fef01 100644 --- a/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java +++ b/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java @@ -10,14 +10,8 @@ import org.qed.Generated.RRuleInstances.FilterIntoJoin; import org.qed.RuleBuilder; -/** - * Test for the FilterIntoJoin rule. - */ public class FilterIntoJoinTest { - /** - * Run test for FilterIntoJoin rule. - */ public static void runTest() { var tester = new CalciteTester(); var builder = RuleBuilder.create(); @@ -40,10 +34,7 @@ public static void runTest() { var runner = CalciteTester.loadRule(org.qed.Generated.FilterIntoJoin.Config.DEFAULT.toRule()); tester.verify(runner, before, after); } - - /** - * Main method to run this test independently. - */ + public static void main(String[] args) { System.out.println("Running FilterIntoJoin test..."); runTest(); diff --git a/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java b/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java index 33f63cc..0e40d00 100644 --- a/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java +++ b/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java @@ -5,17 +5,10 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.qed.Generated.CalciteTester; import org.qed.RelType; -import org.qed.Generated.RRuleInstances.FilterMerge; import org.qed.RuleBuilder; -/** - * Test for the FilterMerge rule. - */ public class FilterMergeTest { - /** - * Run test for FilterMerge rule. - */ public static void runTest() { var tester = new CalciteTester(); var builder = RuleBuilder.create(); @@ -35,10 +28,7 @@ public static void runTest() { var runner = CalciteTester.loadRule(org.qed.Generated.FilterMerge.Config.DEFAULT.toRule()); tester.verify(runner, before, after); } - - /** - * Main method to run this test independently. - */ + public static void main(String[] args) { System.out.println("Running FilterMerge test..."); runTest(); diff --git a/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java b/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java index c030841..0131765 100644 --- a/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java +++ b/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java @@ -4,17 +4,10 @@ import kala.tuple.Tuple; import org.qed.Generated.CalciteTester; import org.qed.RelType; -import org.qed.Generated.RRuleInstances.FilterProjectTranspose; import org.qed.RuleBuilder; -/** - * Test for the FilterProjectTranspose rule. - */ public class FilterProjectTransposeTest { - /** - * Run test for FilterProjectTranspose rule. - */ public static void runTest() { var tester = new CalciteTester(); var builder = RuleBuilder.create(); @@ -41,10 +34,7 @@ public static void runTest() { var runner = CalciteTester.loadRule(org.qed.Generated.FilterProjectTranspose.Config.DEFAULT.toRule()); tester.verify(runner, before, after); } - - /** - * Main method to run this test independently. - */ + public static void main(String[] args) { System.out.println("Running FilterProjectTranspose test..."); runTest(); diff --git a/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java b/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java index 50f0d1c..488b474 100644 --- a/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java +++ b/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java @@ -4,17 +4,10 @@ import kala.tuple.Tuple; import org.qed.Generated.CalciteTester; import org.qed.RelType; -import org.qed.Generated.RRuleInstances.FilterSetOpTranspose; import org.qed.RuleBuilder; -/** - * Test for the FilterSetOpTranspose rule. - */ public class FilterSetOpTransposeTest { - /** - * Run test for FilterSetOpTranspose rule. - */ public static void runTest() { var tester = new CalciteTester(); var builder = RuleBuilder.create(); @@ -36,10 +29,7 @@ public static void runTest() { var runner = CalciteTester.loadRule(org.qed.Generated.FilterSetOpTranspose.Config.DEFAULT.toRule()); tester.verify(runner, before, after); } - - /** - * Main method to run this test independently. - */ + public static void main(String[] args) { System.out.println("Running FilterSetOpTranspose test..."); runTest(); diff --git a/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java b/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java index 089d9ef..ff4093d 100644 --- a/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java +++ b/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java @@ -4,17 +4,10 @@ import kala.tuple.Tuple; import org.qed.Generated.CalciteTester; import org.qed.RelType; -import org.qed.Generated.RRuleInstances.IntersectMerge; import org.qed.RuleBuilder; -/** - * Test for the IntersectMerge rule. - */ public class IntersectMergeTest { - /** - * Run test for IntersectMerge rule. - */ public static void runTest() { var tester = new CalciteTester(); var builder = RuleBuilder.create(); @@ -35,10 +28,7 @@ public static void runTest() { var runner = CalciteTester.loadRule(org.qed.Generated.IntersectMerge.Config.DEFAULT.toRule()); tester.verify(runner, before, after); } - - /** - * Main method to run this test independently. - */ + public static void main(String[] args) { System.out.println("Running IntersectMerge test..."); runTest(); diff --git a/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java b/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java index 72833a6..db2138a 100644 --- a/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java +++ b/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java @@ -5,17 +5,10 @@ import org.apache.calcite.rel.core.JoinRelType; import org.qed.Generated.CalciteTester; import org.qed.RelType; -import org.qed.Generated.RRuleInstances.JoinExtractFilter; import org.qed.RuleBuilder; -/** - * Test for the JoinExtractFilter rule. - */ public class JoinExtractFilterTest { - /** - * Run test for JoinExtractFilter rule. - */ public static void runTest() { var tester = new CalciteTester(); var builder = RuleBuilder.create(); @@ -45,10 +38,7 @@ public static void runTest() { var runner = CalciteTester.loadRule(org.qed.Generated.JoinExtractFilter.Config.DEFAULT.toRule()); tester.verify(runner, before, after); } - - /** - * Main method to run this test independently. - */ + public static void main(String[] args) { System.out.println("Running JoinExtractFilter test..."); runTest(); diff --git a/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java b/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java index 2fdde4f..11e404a 100644 --- a/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java +++ b/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java @@ -8,14 +8,8 @@ import org.qed.Generated.RRuleInstances.SemiJoinFilterTranspose; import org.qed.RuleBuilder; -/** - * Test for the SemiJoinFilterTranspose rule. - */ public class SemiJoinFilterTransposeTest { - /** - * Run test for SemiJoinFilterTranspose rule. - */ public static void runTest() { var tester = new CalciteTester(); var builder = RuleBuilder.create(); @@ -27,8 +21,7 @@ public static void runTest() { var leftScan = builder.scan(leftTable.getName()).build(); var rightScan = builder.scan(rightTable.getName()).build(); - - // Build the "before" relation + builder.push(leftScan); builder.push(rightScan); var joinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); @@ -36,8 +29,7 @@ public static void runTest() { builder.push(semiJoin); var filterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); var before = builder.filter(filterPredicate).build(); - - // Build the expected "after" relation + builder.push(leftScan); var leftFilterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); var filteredLeft = builder.filter(leftFilterPredicate).build(); @@ -49,10 +41,7 @@ public static void runTest() { var runner = CalciteTester.loadRule(org.qed.Generated.SemiJoinFilterTranspose.Config.DEFAULT.toRule()); tester.verify(runner, before, after); } - - /** - * Main method to run this test independently. - */ + public static void main(String[] args) { System.out.println("Running SemiJoinFilterTranspose test..."); runTest(); diff --git a/src/main/java/org/qed/Generated/Tests/SemiJoinProjectTransposeTest.java b/src/main/java/org/qed/Generated/Tests/SemiJoinProjectTransposeTest.java new file mode 100644 index 0000000..47d0755 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/SemiJoinProjectTransposeTest.java @@ -0,0 +1,55 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class SemiJoinProjectTransposeTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var leftTable = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), /*nullable?*/ false))); + var rightTable = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + + builder.addTable(leftTable); + builder.addTable(rightTable); + + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + + builder.push(leftScan); + builder.push(rightScan); + var joinPred = builder.equals( + builder.field(2, 0, 0), // left field 0 + builder.field(2, 1, 0)); // right field 0 + var semiJoin = builder.join(JoinRelType.SEMI, joinPred).build(); + builder.push(semiJoin); + + var before = builder.project(builder.field(0)).build(); + + builder.push(leftScan); + var projectedLeft = builder.project(builder.field(0)).build(); + + builder.push(projectedLeft); + builder.push(rightScan); + var afterJoinPred = builder.equals( + builder.field(2, 0, 0), + builder.field(2, 1, 0)); + var after = builder.join(JoinRelType.SEMI, afterJoinPred).build(); + + var runner = CalciteTester.loadRule( + org.qed.Generated.SemiJoinProjectTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running SemiJoinProjectTranspose test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java b/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java index a29bef8..3ad7e07 100644 --- a/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java +++ b/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java @@ -7,14 +7,8 @@ import org.qed.Generated.RRuleInstances.UnionMerge; import org.qed.RuleBuilder; -/** - * Test for the UnionMerge rule. - */ public class UnionMergeTest { - /** - * Run test for UnionMerge rule. - */ public static void runTest() { var tester = new CalciteTester(); var builder = RuleBuilder.create(); @@ -35,10 +29,7 @@ public static void runTest() { var runner = CalciteTester.loadRule(org.qed.Generated.UnionMerge.Config.DEFAULT.toRule()); tester.verify(runner, before, after); } - - /** - * Main method to run this test independently. - */ + public static void main(String[] args) { System.out.println("Running UnionMerge test..."); runTest(); From c127a561dadfede6885c647d6e1f42553b7d18cb Mon Sep 17 00:00:00 2001 From: joyemang33 Date: Sat, 31 May 2025 16:20:35 +0800 Subject: [PATCH 18/78] filter out .vscode --- .gitignore | 1 + .vscode/settings.json | 3 --- 2 files changed, 1 insertion(+), 3 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 1b536bd..e9fb4cd 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ target/ *.iml .mvn/wrapper/maven-wrapper.jar +.vscode diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index c5f3f6b..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "java.configuration.updateBuildConfiguration": "interactive" -} \ No newline at end of file From 827f2d52c6bc2d3bac89dd817f3d9d68daa4b44e Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 10:16:49 -0700 Subject: [PATCH 19/78] workflow try 1 --- .github/workflows/test-new-rules.yml | 204 +++++++++++++++++++++++++++ 1 file changed, 204 insertions(+) create mode 100644 .github/workflows/test-new-rules.yml diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml new file mode 100644 index 0000000..2133719 --- /dev/null +++ b/.github/workflows/test-new-rules.yml @@ -0,0 +1,204 @@ +name: Test New Rules + +on: + push: + paths: + - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' + pull_request: + paths: + - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' + +jobs: + # First job: detect which files changed + get-changed-files: + runs-on: ubuntu-latest + outputs: + changed_files: ${{ steps.changed-files.outputs.all_changed_files }} + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v41 + with: + files: | + src/main/java/org/qed/Generated/RRuleInstances/**/*.java + separator: ',' + + # Second job: test the changed rules + test-rules: + needs: get-changed-files + if: ${{ needs.get-changed-files.outputs.changed_files != '' }} + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Java + uses: actions/setup-java@v4 + with: + java-version: '11' + distribution: 'temurin' + + - name: Cache Maven dependencies + uses: actions/cache@v3 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + + - name: Build parser project + run: mvn -B compile --file pom.xml + + - name: Generate JSON for new/modified rules + run: | + mkdir -p rules + + # Create JSON generator + cat > JsonGenerator.java << 'EOF' + import org.qed.*; + import com.fasterxml.jackson.databind.ObjectMapper; + import java.nio.file.*; + + public class JsonGenerator { + public static void main(String[] args) throws Exception { + if (args.length == 0) { + System.err.println("Usage: java JsonGenerator "); + System.exit(1); + } + + String className = args[0]; + try { + Class clazz = Class.forName(className); + RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); + + ObjectMapper mapper = new ObjectMapper(); + Files.createDirectories(Path.of("rules")); + String fileName = rule.name() + "-" + rule.info() + ".json"; + mapper.writerWithDefaultPrettyPrinter().writeValue( + Path.of("rules", fileName).toFile(), + rule.toJson() + ); + System.out.println("Generated: " + fileName); + } catch (Exception e) { + System.err.println("Failed to generate JSON for " + className); + e.printStackTrace(); + System.exit(1); + } + } + } + EOF + + # Get classpath + CLASSPATH="target/classes:$(mvn dependency:build-classpath -q -DforceStdout)" + + # Compile generator + javac -cp "$CLASSPATH" JsonGenerator.java + + # Process each changed file + IFS=',' read -ra FILES <<< "${{ needs.get-changed-files.outputs.changed_files }}" + for file in "${FILES[@]}"; do + if [ -n "$file" ]; then + className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') + echo "Processing: $className" + java -cp ".:$CLASSPATH" JsonGenerator "$className" + fi + done + + rm -f JsonGenerator.java JsonGenerator.class + + - name: Setup qed-prover + run: | + # Option 1: Try to checkout qed-prover (update URL if needed) + # git clone https://github.com/qed-solver/prover.git || { + # echo "::error::Cannot clone qed-prover repository" + # echo "Please update the repository URL in the workflow" + # exit 1 + # } + + # Option 2: For testing, create a mock qed-prover + echo "::warning::Using mock qed-prover for testing. Replace with real qed-prover!" + mkdir -p qed-prover/target/release + + cat > qed-prover/target/release/qed-prover << 'EOF' + #!/bin/bash + # MOCK qed-prover - Replace with real implementation + json_file="$1" + rule_name=$(basename "$json_file" .json) + + # Mock: all rules pass for testing + echo '{"provable":true,"panicked":false,"complete_fragment":false,"equiv_class_duration":{"secs":0,"nanos":20638791},"equiv_class_timed_out":false,"smt_duration":{"secs":0,"nanos":55148417},"smt_timed_out":false,"nontrivial_perms":false,"translate_duration":{"secs":0,"nanos":1047250},"normal_duration":{"secs":0,"nanos":1880834},"stable_duration":{"secs":0,"nanos":28722792},"unify_duration":{"secs":0,"nanos":55404417},"total_duration":{"secs":0,"nanos":129852459}}' + EOF + + chmod +x qed-prover/target/release/qed-prover + + - name: Install jq + run: | + sudo apt-get update + sudo apt-get install -y jq + + - name: Test rules with qed-prover + run: | + echo "## Test Results" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + + failed_rules="" + passed_rules="" + + # Test the generated JSON files + for json_file in rules/*.json; do + if [ -f "$json_file" ]; then + rule_name=$(basename "$json_file" .json) + echo "Testing rule: $rule_name" + + # Run qed-prover + output=$(./qed-prover/target/release/qed-prover "$json_file" 2>&1) + + # Parse result + if echo "$output" | jq -e '.provable == true' > /dev/null 2>&1; then + echo "✅ $rule_name: PASSED" | tee -a $GITHUB_STEP_SUMMARY + passed_rules="$passed_rules$rule_name," + else + echo "❌ $rule_name: FAILED" | tee -a $GITHUB_STEP_SUMMARY + failed_rules="$failed_rules$rule_name," + fi + fi + done + + # Set outputs for use in other steps/jobs + echo "failed_rules=$failed_rules" >> $GITHUB_OUTPUT + echo "passed_rules=$passed_rules" >> $GITHUB_OUTPUT + + # Fail if any rules failed + if [ -n "$failed_rules" ]; then + echo "::error::Some rules failed verification: ${failed_rules%,}" + exit 1 + fi + + - name: Comment PR (if applicable) + if: github.event_name == 'pull_request' && always() + uses: actions/github-script@v7 + with: + script: | + const fs = require('fs'); + const summary = fs.readFileSync(process.env.GITHUB_STEP_SUMMARY, 'utf8'); + + github.rest.issues.createComment({ + issue_number: context.issue.number, + owner: context.repo.owner, + repo: context.repo.repo, + body: `## QED Prover Test Results\n\n${summary}` + }); + + - name: Upload test artifacts + if: always() + uses: actions/upload-artifact@v3 + with: + name: test-results + path: | + rules/*.json + rules/*.result + test-output-*.log \ No newline at end of file From a2c66bcc04004770e5c7bafd588efd1e81a8ae2f Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 11:17:27 -0700 Subject: [PATCH 20/78] workflow try 1 --- .github/workflows/test-new-rules.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 2133719..5ef0a53 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -1,4 +1,4 @@ -name: Test New Rules +name: Test New Rule on: push: From 34202e579edb87de7b20fe9ddb6ddf7d51caa9ae Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 11:21:06 -0700 Subject: [PATCH 21/78] workflow test 1 --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 24165aa..78ca4ee 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -22,6 +22,7 @@ public RelRN before() { @Override public RelRN after() { - return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + return left.join(JoinRelType.INNER, RexRN.and( + joinCond, left.joinPred("outer", right)), right); } } From 3b786ac49e81845b5f7a77c94cca32afc43e8f8d Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 11:25:46 -0700 Subject: [PATCH 22/78] workflow try 2 --- .github/workflows/test-new-rules.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 5ef0a53..00c27dc 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -194,11 +194,11 @@ jobs: }); - name: Upload test artifacts - if: always() - uses: actions/upload-artifact@v3 - with: - name: test-results - path: | - rules/*.json - rules/*.result - test-output-*.log \ No newline at end of file + if: always() + uses: actions/upload-artifact@v4 + with: + name: test-results + path: | + rules/*.json + rules/*.result + test-output-*.log \ No newline at end of file From 992131cc6ebb8eb8a912283645eddfaa559949ea Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 11:26:35 -0700 Subject: [PATCH 23/78] workflow test 2 --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 78ca4ee..24165aa 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -22,7 +22,6 @@ public RelRN before() { @Override public RelRN after() { - return left.join(JoinRelType.INNER, RexRN.and( - joinCond, left.joinPred("outer", right)), right); + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); } } From 07f97230ef8b3526743ce7a8c5b578e20877b849 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 11:29:13 -0700 Subject: [PATCH 24/78] workflow test 3 --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 24165aa..86770be 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -22,6 +22,6 @@ public RelRN before() { @Override public RelRN after() { - return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + return left.join(JoinRelType.SEMI, RexRN.and(joinCond, left.joinPred("outer", right)), right); } } From 09d36ee1a70c7d31e5190bb99436f8cb3d370169 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 11:39:34 -0700 Subject: [PATCH 25/78] new workflow try --- .github/workflows/test-new-rules.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 00c27dc..125f65a 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -195,10 +195,10 @@ jobs: - name: Upload test artifacts if: always() - uses: actions/upload-artifact@v4 - with: - name: test-results - path: | - rules/*.json - rules/*.result - test-output-*.log \ No newline at end of file + uses: actions/upload-artifact@v4 + with: + name: test-results + path: | + rules/*.json + rules/*.result + test-output-*.log \ No newline at end of file From c88b45e17d36ae5cb7db6f57b51801c0dd788e10 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 11:41:27 -0700 Subject: [PATCH 26/78] new workflow try --- .github/workflows/test-new-rules.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 125f65a..2b3f2e2 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -32,7 +32,6 @@ jobs: needs: get-changed-files if: ${{ needs.get-changed-files.outputs.changed_files != '' }} runs-on: ubuntu-latest - steps: - uses: actions/checkout@v4 From 04278429a24eef53da5a041ed01133a8036b127d Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 12:41:45 -0700 Subject: [PATCH 27/78] workflow new try --- .github/workflows/test-new-rules.yml | 179 ++++++++------------------- 1 file changed, 52 insertions(+), 127 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 2b3f2e2..d6b9123 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -1,4 +1,4 @@ -name: Test New Rule +name: Test All Rules on: push: @@ -9,38 +9,18 @@ on: - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' jobs: - # First job: detect which files changed - get-changed-files: - runs-on: ubuntu-latest - outputs: - changed_files: ${{ steps.changed-files.outputs.all_changed_files }} - steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v41 - with: - files: | - src/main/java/org/qed/Generated/RRuleInstances/**/*.java - separator: ',' - - # Second job: test the changed rules - test-rules: - needs: get-changed-files - if: ${{ needs.get-changed-files.outputs.changed_files != '' }} + test-all-rules: runs-on: ubuntu-latest + steps: - uses: actions/checkout@v4 - + - name: Set up Java uses: actions/setup-java@v4 with: java-version: '11' distribution: 'temurin' - + - name: Cache Maven dependencies uses: actions/cache@v3 with: @@ -48,37 +28,33 @@ jobs: key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} restore-keys: | ${{ runner.os }}-maven- - - - name: Build parser project + + - name: Build project run: mvn -B compile --file pom.xml - - - name: Generate JSON for new/modified rules + + - name: Generate JSON for all rules run: | - mkdir -p rules - - # Create JSON generator + mkdir -p tmp-rules + cat > JsonGenerator.java << 'EOF' import org.qed.*; import com.fasterxml.jackson.databind.ObjectMapper; import java.nio.file.*; - public class JsonGenerator { public static void main(String[] args) throws Exception { if (args.length == 0) { System.err.println("Usage: java JsonGenerator "); System.exit(1); } - String className = args[0]; try { Class clazz = Class.forName(className); RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); - ObjectMapper mapper = new ObjectMapper(); - Files.createDirectories(Path.of("rules")); + Files.createDirectories(Path.of("tmp-rules")); String fileName = rule.name() + "-" + rule.info() + ".json"; mapper.writerWithDefaultPrettyPrinter().writeValue( - Path.of("rules", fileName).toFile(), + Path.of("tmp-rules", fileName).toFile(), rule.toJson() ); System.out.println("Generated: " + fileName); @@ -90,114 +66,63 @@ jobs: } } EOF - - # Get classpath + CLASSPATH="target/classes:$(mvn dependency:build-classpath -q -DforceStdout)" - - # Compile generator javac -cp "$CLASSPATH" JsonGenerator.java - - # Process each changed file - IFS=',' read -ra FILES <<< "${{ needs.get-changed-files.outputs.changed_files }}" - for file in "${FILES[@]}"; do - if [ -n "$file" ]; then - className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') - echo "Processing: $className" - java -cp ".:$CLASSPATH" JsonGenerator "$className" - fi + + find src/main/java/org/qed/Generated/RRuleInstances -name '*.java' | while read file; do + className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') + echo "Generating for: $className" + java -cp ".:$CLASSPATH" JsonGenerator "$className" done - + rm -f JsonGenerator.java JsonGenerator.class - - - name: Setup qed-prover + + - name: Install Rust and clone qed-prover run: | - # Option 1: Try to checkout qed-prover (update URL if needed) - # git clone https://github.com/qed-solver/prover.git || { - # echo "::error::Cannot clone qed-prover repository" - # echo "Please update the repository URL in the workflow" - # exit 1 - # } - - # Option 2: For testing, create a mock qed-prover - echo "::warning::Using mock qed-prover for testing. Replace with real qed-prover!" - mkdir -p qed-prover/target/release - - cat > qed-prover/target/release/qed-prover << 'EOF' - #!/bin/bash - # MOCK qed-prover - Replace with real implementation - json_file="$1" - rule_name=$(basename "$json_file" .json) - - # Mock: all rules pass for testing - echo '{"provable":true,"panicked":false,"complete_fragment":false,"equiv_class_duration":{"secs":0,"nanos":20638791},"equiv_class_timed_out":false,"smt_duration":{"secs":0,"nanos":55148417},"smt_timed_out":false,"nontrivial_perms":false,"translate_duration":{"secs":0,"nanos":1047250},"normal_duration":{"secs":0,"nanos":1880834},"stable_duration":{"secs":0,"nanos":28722792},"unify_duration":{"secs":0,"nanos":55404417},"total_duration":{"secs":0,"nanos":129852459}}' - EOF - - chmod +x qed-prover/target/release/qed-prover - + sudo apt-get update + sudo apt-get install -y curl git + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + source $HOME/.cargo/env + git clone https://github.com/qed-solver/prover.git qed-prover + cd qed-prover + cargo build --release + + - name: Install jq run: | sudo apt-get update sudo apt-get install -y jq - - - name: Test rules with qed-prover + + - name: Test all rules run: | - echo "## Test Results" >> $GITHUB_STEP_SUMMARY + echo "## QED Prover Test Results" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY - + failed_rules="" passed_rules="" - - # Test the generated JSON files - for json_file in rules/*.json; do - if [ -f "$json_file" ]; then - rule_name=$(basename "$json_file" .json) - echo "Testing rule: $rule_name" - - # Run qed-prover - output=$(./qed-prover/target/release/qed-prover "$json_file" 2>&1) - - # Parse result - if echo "$output" | jq -e '.provable == true' > /dev/null 2>&1; then - echo "✅ $rule_name: PASSED" | tee -a $GITHUB_STEP_SUMMARY - passed_rules="$passed_rules$rule_name," - else - echo "❌ $rule_name: FAILED" | tee -a $GITHUB_STEP_SUMMARY - failed_rules="$failed_rules$rule_name," - fi + + for json_file in tmp-rules/*.json; do + rule_name=$(basename "$json_file" .json) + echo "Testing $rule_name" + output=$(./qed-prover/target/release/qed-prover "$json_file") + if echo "$output" | jq -e '.provable == true' > /dev/null; then + echo "✅ $rule_name: PASSED" | tee -a $GITHUB_STEP_SUMMARY + passed_rules="$passed_rules$rule_name," + else + echo "❌ $rule_name: FAILED" | tee -a $GITHUB_STEP_SUMMARY + failed_rules="$failed_rules$rule_name," fi done - - # Set outputs for use in other steps/jobs - echo "failed_rules=$failed_rules" >> $GITHUB_OUTPUT - echo "passed_rules=$passed_rules" >> $GITHUB_OUTPUT - - # Fail if any rules failed + if [ -n "$failed_rules" ]; then - echo "::error::Some rules failed verification: ${failed_rules%,}" + echo "::error::Some rules failed: ${failed_rules%,}" exit 1 fi - - - name: Comment PR (if applicable) - if: github.event_name == 'pull_request' && always() - uses: actions/github-script@v7 - with: - script: | - const fs = require('fs'); - const summary = fs.readFileSync(process.env.GITHUB_STEP_SUMMARY, 'utf8'); - - github.rest.issues.createComment({ - issue_number: context.issue.number, - owner: context.repo.owner, - repo: context.repo.repo, - body: `## QED Prover Test Results\n\n${summary}` - }); - + - name: Upload test artifacts - if: always() + if: always() uses: actions/upload-artifact@v4 with: - name: test-results - path: | - rules/*.json - rules/*.result - test-output-*.log \ No newline at end of file + name: rule-json + path: tmp-rules/ From f2a93bac27499f82a4bd4dc460889570ff02ba06 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 12:45:02 -0700 Subject: [PATCH 28/78] src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 86770be..c67e874 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -12,7 +12,7 @@ public record FilterIntoJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("join", right); + static final RexRN joinCond = left.joinPred("join", right); @Override public RelRN before() { From 1edfc9a820d54a572fb9619be0cea6c826bcb43f Mon Sep 17 00:00:00 2001 From: joyemang33 Date: Sat, 31 May 2025 19:55:06 +0000 Subject: [PATCH 29/78] update .gitignore --- .gitignore | 1 + src/main/java/org/qed/Generated/CalciteTester.java | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index e9fb4cd..870d0c2 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ target/ *.iml .mvn/wrapper/maven-wrapper.jar .vscode +.devcontainer \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index aefde13..4fed912 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -99,7 +99,7 @@ public static void main(String[] args) throws IOException { // for (var rule : rules.family()) { // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); // } - generate(); + generate(); runAllTests(); } From c33bb3fe9923fc7bf1e770884c94fceb8b2a08ea Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 15:38:17 -0700 Subject: [PATCH 30/78] new workflow --- .github/workflows/test-new-rules.yml | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index d6b9123..536375c 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -29,6 +29,21 @@ jobs: restore-keys: | ${{ runner.os }}-maven- + # Add this step to manually install CVC5 + - name: Install CVC5 dependency + run: | + # Download CVC5 JAR from GitHub releases or other source + wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.0/cvc5-1.2.0.jar -O cvc5.jar + + # Install it to local Maven repository + mvn install:install-file \ + -Dfile=cvc5.jar \ + -DgroupId=io.github.cvc5 \ + -DartifactId=cvc5 \ + -Dversion=1.2.1 \ + -Dpackaging=jar \ + -DgeneratePom=true + - name: Build project run: mvn -B compile --file pom.xml @@ -88,7 +103,6 @@ jobs: cd qed-prover cargo build --release - - name: Install jq run: | sudo apt-get update @@ -125,4 +139,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: rule-json - path: tmp-rules/ + path: tmp-rules/ \ No newline at end of file From fc46291834b5f86e150612803813323e83ec80bf Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 15:40:51 -0700 Subject: [PATCH 31/78] new test --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index c67e874..86770be 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -12,7 +12,7 @@ public record FilterIntoJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("join", right); + static final RexRN joinCond = left.joinPred("join", right); @Override public RelRN before() { From 71b9eb9a4bb9740b4419b7baf7433cfc80a835a4 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 15:47:30 -0700 Subject: [PATCH 32/78] updated pom.xml --- pom.xml | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pom.xml b/pom.xml index ccfa971..d194cc2 100644 --- a/pom.xml +++ b/pom.xml @@ -21,6 +21,7 @@ -classpath org.qed.Main + ${args} @@ -51,11 +52,9 @@ org.apache.maven.plugins maven-compiler-plugin - 23 - 23 - - --enable-preview - + 21 + 21 + --enable-preview @@ -103,9 +102,9 @@ 0.67.0 - io.github.cvc5 + io.github.p-org.solvers cvc5 - 1.2.1 + 0.0.7-v5 org.reflections @@ -113,4 +112,4 @@ 0.10.2 - + \ No newline at end of file From e74cc30c37dc1a344c3dfb349aa2aaaee9d7c694 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 15:49:18 -0700 Subject: [PATCH 33/78] new test --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 86770be..c67e874 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -12,7 +12,7 @@ public record FilterIntoJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("join", right); + static final RexRN joinCond = left.joinPred("join", right); @Override public RelRN before() { From de4e7ebe63f9a282eddc953e0131dfc88b716e5e Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 15:51:39 -0700 Subject: [PATCH 34/78] new workflow --- .github/workflows/test-new-rules.yml | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 536375c..d6b9123 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -29,21 +29,6 @@ jobs: restore-keys: | ${{ runner.os }}-maven- - # Add this step to manually install CVC5 - - name: Install CVC5 dependency - run: | - # Download CVC5 JAR from GitHub releases or other source - wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.0/cvc5-1.2.0.jar -O cvc5.jar - - # Install it to local Maven repository - mvn install:install-file \ - -Dfile=cvc5.jar \ - -DgroupId=io.github.cvc5 \ - -DartifactId=cvc5 \ - -Dversion=1.2.1 \ - -Dpackaging=jar \ - -DgeneratePom=true - - name: Build project run: mvn -B compile --file pom.xml @@ -103,6 +88,7 @@ jobs: cd qed-prover cargo build --release + - name: Install jq run: | sudo apt-get update @@ -139,4 +125,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: rule-json - path: tmp-rules/ \ No newline at end of file + path: tmp-rules/ From 86970480b671c1072ca1b7475db8785f026a6fd0 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 15:53:30 -0700 Subject: [PATCH 35/78] new test --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index c67e874..7ccfa1c 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -18,6 +18,7 @@ public record FilterIntoJoin() implements RRule { public RelRN before() { var join = left.join(JoinRelType.INNER, joinCond, right); return join.filter("outer"); + // return join.filter("outer"); } @Override From d7ddac7819dece36cc09b0db4acd463313864daa Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 15:57:03 -0700 Subject: [PATCH 36/78] workflow --- .github/workflows/test-new-rules.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index d6b9123..5ffc300 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Java uses: actions/setup-java@v4 with: - java-version: '11' + java-version: '21' distribution: 'temurin' - name: Cache Maven dependencies From 62a12dd14d3dd067f98f7cb8141325a7cbbbd2f0 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 15:57:38 -0700 Subject: [PATCH 37/78] new test --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 7ccfa1c..c67e874 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -18,7 +18,6 @@ public record FilterIntoJoin() implements RRule { public RelRN before() { var join = left.join(JoinRelType.INNER, joinCond, right); return join.filter("outer"); - // return join.filter("outer"); } @Override From 6913915e97799818252f8ca5c0e82bf25a199ba5 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 16:01:43 -0700 Subject: [PATCH 38/78] new workflow --- .github/workflows/test-new-rules.yml | 68 +++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 5ffc300..6e60e8b 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -29,6 +29,34 @@ jobs: restore-keys: | ${{ runner.os }}-maven- + # Build CVC5 from source with Java bindings + - name: Build and Install CVC5 + run: | + # Install build dependencies + sudo apt-get update + sudo apt-get install -y cmake libgmp-dev python3 python3-pip + + # Clone and build CVC5 + git clone https://github.com/cvc5/cvc5.git + cd cvc5 + ./configure.sh --java-bindings --auto-download --prefix=$PWD/build/install + cd build + make -j$(nproc) + make install + + # Find the built JAR file + CVC5_JAR=$(find install/share/java -name "cvc5*.jar" | grep -v sources | head -1) + echo "Found CVC5 JAR: $CVC5_JAR" + + # Install to Maven local repository + mvn install:install-file \ + -Dfile="$CVC5_JAR" \ + -DgroupId=io.github.cvc5 \ + -DartifactId=cvc5 \ + -Dversion=1.2.1 \ + -Dpackaging=jar \ + -DgeneratePom=true + - name: Build project run: mvn -B compile --file pom.xml @@ -36,9 +64,13 @@ jobs: run: | mkdir -p tmp-rules + # First, ensure all dependencies are downloaded + mvn dependency:resolve + cat > JsonGenerator.java << 'EOF' import org.qed.*; import com.fasterxml.jackson.databind.ObjectMapper; + import com.fasterxml.jackson.databind.node.ObjectNode; import java.nio.file.*; public class JsonGenerator { public static void main(String[] args) throws Exception { @@ -53,9 +85,14 @@ jobs: ObjectMapper mapper = new ObjectMapper(); Files.createDirectories(Path.of("tmp-rules")); String fileName = rule.name() + "-" + rule.info() + ".json"; + + // Get the JSON object + ObjectNode jsonNode = rule.toJson(); + + // Write to file mapper.writerWithDefaultPrettyPrinter().writeValue( Path.of("tmp-rules", fileName).toFile(), - rule.toJson() + jsonNode ); System.out.println("Generated: " + fileName); } catch (Exception e) { @@ -67,16 +104,36 @@ jobs: } EOF - CLASSPATH="target/classes:$(mvn dependency:build-classpath -q -DforceStdout)" + # Build the classpath more reliably + echo "Building classpath..." + MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) + CLASSPATH="target/classes:${MAVEN_CP}" + + echo "Compiling JsonGenerator..." javac -cp "$CLASSPATH" JsonGenerator.java + + if [ $? -ne 0 ]; then + echo "Failed to compile JsonGenerator.java" + echo "Classpath was: $CLASSPATH" + exit 1 + fi + echo "Finding rule files..." find src/main/java/org/qed/Generated/RRuleInstances -name '*.java' | while read file; do className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') - echo "Generating for: $className" - java -cp ".:$CLASSPATH" JsonGenerator "$className" + echo "Generating JSON for: $className" + java -cp ".:$CLASSPATH" JsonGenerator "$className" || echo "Failed to process $className" done rm -f JsonGenerator.java JsonGenerator.class + + # Check if any JSON files were generated + if [ ! "$(ls -A tmp-rules)" ]; then + echo "Error: No JSON files were generated!" + exit 1 + fi + + echo "Generated $(ls -1 tmp-rules/*.json | wc -l) JSON files" - name: Install Rust and clone qed-prover run: | @@ -88,7 +145,6 @@ jobs: cd qed-prover cargo build --release - - name: Install jq run: | sudo apt-get update @@ -125,4 +181,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: rule-json - path: tmp-rules/ + path: tmp-rules/ \ No newline at end of file From f5a43838edd43148d95de37711feb945dc03f571 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 16:03:23 -0700 Subject: [PATCH 39/78] new test --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index c67e874..473fca1 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -17,6 +17,7 @@ public record FilterIntoJoin() implements RRule { @Override public RelRN before() { var join = left.join(JoinRelType.INNER, joinCond, right); + // return join.filter("outer"); return join.filter("outer"); } From 8ead6a74fd94e5218b278998b69f2340559a9b70 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 17:02:38 -0700 Subject: [PATCH 40/78] new workflow --- .github/workflows/test-new-rules.yml | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 6e60e8b..7e2d126 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -139,11 +139,18 @@ jobs: run: | sudo apt-get update sudo apt-get install -y curl git - curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + + # Install Rust with nightly toolchain + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly source $HOME/.cargo/env + + # Verify we're using nightly + rustc --version + + # Clone and build qed-prover git clone https://github.com/qed-solver/prover.git qed-prover cd qed-prover - cargo build --release + cargo +nightly build --release - name: Install jq run: | From 5c00f9e1f178fff8741edaf87a69f2df37e45c9d Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 17:03:01 -0700 Subject: [PATCH 41/78] new test --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 473fca1..c67e874 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -17,7 +17,6 @@ public record FilterIntoJoin() implements RRule { @Override public RelRN before() { var join = left.join(JoinRelType.INNER, joinCond, right); - // return join.filter("outer"); return join.filter("outer"); } From 5b56bc6c8f8325104d0a70989b35f7f62bf4d11e Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 17:37:14 -0700 Subject: [PATCH 42/78] new workflow --- .github/workflows/test-new-rules.yml | 51 +++++++++++++++++++++++----- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 7e2d126..a6fff7e 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -152,10 +152,10 @@ jobs: cd qed-prover cargo +nightly build --release - - name: Install jq + - name: Install jq and z3 run: | sudo apt-get update - sudo apt-get install -y jq + sudo apt-get install -y jq z3 - name: Test all rules run: | @@ -164,20 +164,51 @@ jobs: failed_rules="" passed_rules="" + total_count=0 + passed_count=0 for json_file in tmp-rules/*.json; do + if [ ! -f "$json_file" ]; then + continue + fi + rule_name=$(basename "$json_file" .json) echo "Testing $rule_name" - output=$(./qed-prover/target/release/qed-prover "$json_file") - if echo "$output" | jq -e '.provable == true' > /dev/null; then - echo "✅ $rule_name: PASSED" | tee -a $GITHUB_STEP_SUMMARY - passed_rules="$passed_rules$rule_name," + total_count=$((total_count + 1)) + + # Run the prover - it creates a .result file + ./qed-prover/target/release/qed-prover "$json_file" || true + + # Check if the result file was created + result_file="${json_file%.json}.result" + if [ -f "$result_file" ]; then + # Parse the result file + if jq -e '.provable == true' "$result_file" > /dev/null 2>&1; then + echo "✅ $rule_name: PASSED" | tee -a $GITHUB_STEP_SUMMARY + passed_rules="$passed_rules$rule_name," + passed_count=$((passed_count + 1)) + else + # Check if it panicked or just failed to prove + panicked=$(jq -r '.panicked // false' "$result_file") + if [ "$panicked" = "true" ]; then + echo "❌ $rule_name: PANICKED" | tee -a $GITHUB_STEP_SUMMARY + else + echo "❌ $rule_name: FAILED TO PROVE" | tee -a $GITHUB_STEP_SUMMARY + fi + failed_rules="$failed_rules$rule_name," + fi else - echo "❌ $rule_name: FAILED" | tee -a $GITHUB_STEP_SUMMARY + echo "❌ $rule_name: NO RESULT FILE GENERATED" | tee -a $GITHUB_STEP_SUMMARY failed_rules="$failed_rules$rule_name," fi done + echo "" >> $GITHUB_STEP_SUMMARY + echo "### Summary" >> $GITHUB_STEP_SUMMARY + echo "- Total rules tested: $total_count" >> $GITHUB_STEP_SUMMARY + echo "- Passed: $passed_count" >> $GITHUB_STEP_SUMMARY + echo "- Failed: $((total_count - passed_count))" >> $GITHUB_STEP_SUMMARY + if [ -n "$failed_rules" ]; then echo "::error::Some rules failed: ${failed_rules%,}" exit 1 @@ -187,5 +218,7 @@ jobs: if: always() uses: actions/upload-artifact@v4 with: - name: rule-json - path: tmp-rules/ \ No newline at end of file + name: rule-test-results + path: | + tmp-rules/*.json + tmp-rules/*.result \ No newline at end of file From 440de538b41befd5a97c5b8a9c02166445f52ae9 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 17:37:36 -0700 Subject: [PATCH 43/78] new test --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index c67e874..473fca1 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -17,6 +17,7 @@ public record FilterIntoJoin() implements RRule { @Override public RelRN before() { var join = left.join(JoinRelType.INNER, joinCond, right); + // return join.filter("outer"); return join.filter("outer"); } From 3b82055503b498a829f324383e3e7fddae8c0383 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 18:06:22 -0700 Subject: [PATCH 44/78] new test --- .github/workflows/test-new-rules.yml | 153 +++--------------- .../RRuleInstances/FilterIntoJoin.java | 1 - 2 files changed, 25 insertions(+), 129 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index a6fff7e..e22af70 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -2,11 +2,7 @@ name: Test All Rules on: push: - paths: - - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' pull_request: - paths: - - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' jobs: test-all-rules: @@ -29,42 +25,12 @@ jobs: restore-keys: | ${{ runner.os }}-maven- - # Build CVC5 from source with Java bindings - - name: Build and Install CVC5 - run: | - # Install build dependencies - sudo apt-get update - sudo apt-get install -y cmake libgmp-dev python3 python3-pip - - # Clone and build CVC5 - git clone https://github.com/cvc5/cvc5.git - cd cvc5 - ./configure.sh --java-bindings --auto-download --prefix=$PWD/build/install - cd build - make -j$(nproc) - make install - - # Find the built JAR file - CVC5_JAR=$(find install/share/java -name "cvc5*.jar" | grep -v sources | head -1) - echo "Found CVC5 JAR: $CVC5_JAR" - - # Install to Maven local repository - mvn install:install-file \ - -Dfile="$CVC5_JAR" \ - -DgroupId=io.github.cvc5 \ - -DartifactId=cvc5 \ - -Dversion=1.2.1 \ - -Dpackaging=jar \ - -DgeneratePom=true - - name: Build project run: mvn -B compile --file pom.xml - name: Generate JSON for all rules run: | mkdir -p tmp-rules - - # First, ensure all dependencies are downloaded mvn dependency:resolve cat > JsonGenerator.java << 'EOF' @@ -74,143 +40,76 @@ jobs: import java.nio.file.*; public class JsonGenerator { public static void main(String[] args) throws Exception { - if (args.length == 0) { - System.err.println("Usage: java JsonGenerator "); - System.exit(1); - } String className = args[0]; - try { - Class clazz = Class.forName(className); - RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); - ObjectMapper mapper = new ObjectMapper(); - Files.createDirectories(Path.of("tmp-rules")); - String fileName = rule.name() + "-" + rule.info() + ".json"; - - // Get the JSON object - ObjectNode jsonNode = rule.toJson(); - - // Write to file - mapper.writerWithDefaultPrettyPrinter().writeValue( - Path.of("tmp-rules", fileName).toFile(), - jsonNode - ); - System.out.println("Generated: " + fileName); - } catch (Exception e) { - System.err.println("Failed to generate JSON for " + className); - e.printStackTrace(); - System.exit(1); - } + Class clazz = Class.forName(className); + RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); + ObjectMapper mapper = new ObjectMapper(); + String fileName = rule.name() + "-" + rule.info() + ".json"; + ObjectNode jsonNode = rule.toJson(); + mapper.writerWithDefaultPrettyPrinter().writeValue( + Path.of("tmp-rules", fileName).toFile(), + jsonNode + ); } } EOF - # Build the classpath more reliably - echo "Building classpath..." MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) CLASSPATH="target/classes:${MAVEN_CP}" - - echo "Compiling JsonGenerator..." javac -cp "$CLASSPATH" JsonGenerator.java - - if [ $? -ne 0 ]; then - echo "Failed to compile JsonGenerator.java" - echo "Classpath was: $CLASSPATH" - exit 1 - fi - echo "Finding rule files..." find src/main/java/org/qed/Generated/RRuleInstances -name '*.java' | while read file; do className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') - echo "Generating JSON for: $className" - java -cp ".:$CLASSPATH" JsonGenerator "$className" || echo "Failed to process $className" + java -cp ".:$CLASSPATH" JsonGenerator "$className" done rm -f JsonGenerator.java JsonGenerator.class - - # Check if any JSON files were generated - if [ ! "$(ls -A tmp-rules)" ]; then - echo "Error: No JSON files were generated!" - exit 1 - fi - - echo "Generated $(ls -1 tmp-rules/*.json | wc -l) JSON files" - - name: Install Rust and clone qed-prover + - name: Install dependencies run: | sudo apt-get update - sudo apt-get install -y curl git - - # Install Rust with nightly toolchain + sudo apt-get install -y jq z3 + wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.1.2/cvc5-Linux-x86_64 -O cvc5 + chmod +x cvc5 + sudo mv cvc5 /usr/local/bin/ + + - name: Build qed-prover + run: | curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly source $HOME/.cargo/env - - # Verify we're using nightly - rustc --version - - # Clone and build qed-prover git clone https://github.com/qed-solver/prover.git qed-prover cd qed-prover cargo +nightly build --release - - name: Install jq and z3 - run: | - sudo apt-get update - sudo apt-get install -y jq z3 - - name: Test all rules run: | echo "## QED Prover Test Results" >> $GITHUB_STEP_SUMMARY echo "" >> $GITHUB_STEP_SUMMARY failed_rules="" - passed_rules="" total_count=0 passed_count=0 for json_file in tmp-rules/*.json; do - if [ ! -f "$json_file" ]; then - continue - fi - rule_name=$(basename "$json_file" .json) - echo "Testing $rule_name" total_count=$((total_count + 1)) - - # Run the prover - it creates a .result file ./qed-prover/target/release/qed-prover "$json_file" || true - # Check if the result file was created result_file="${json_file%.json}.result" - if [ -f "$result_file" ]; then - # Parse the result file - if jq -e '.provable == true' "$result_file" > /dev/null 2>&1; then - echo "✅ $rule_name: PASSED" | tee -a $GITHUB_STEP_SUMMARY - passed_rules="$passed_rules$rule_name," - passed_count=$((passed_count + 1)) - else - # Check if it panicked or just failed to prove - panicked=$(jq -r '.panicked // false' "$result_file") - if [ "$panicked" = "true" ]; then - echo "❌ $rule_name: PANICKED" | tee -a $GITHUB_STEP_SUMMARY - else - echo "❌ $rule_name: FAILED TO PROVE" | tee -a $GITHUB_STEP_SUMMARY - fi - failed_rules="$failed_rules$rule_name," - fi + if [ -f "$result_file" ] && jq -e '.provable == true' "$result_file" > /dev/null 2>&1; then + echo "✅ $rule_name: PASSED" >> $GITHUB_STEP_SUMMARY + passed_count=$((passed_count + 1)) else - echo "❌ $rule_name: NO RESULT FILE GENERATED" | tee -a $GITHUB_STEP_SUMMARY + echo "❌ $rule_name: FAILED" >> $GITHUB_STEP_SUMMARY failed_rules="$failed_rules$rule_name," fi done echo "" >> $GITHUB_STEP_SUMMARY - echo "### Summary" >> $GITHUB_STEP_SUMMARY - echo "- Total rules tested: $total_count" >> $GITHUB_STEP_SUMMARY - echo "- Passed: $passed_count" >> $GITHUB_STEP_SUMMARY - echo "- Failed: $((total_count - passed_count))" >> $GITHUB_STEP_SUMMARY + echo "**Summary:** $passed_count/$total_count passed" >> $GITHUB_STEP_SUMMARY if [ -n "$failed_rules" ]; then - echo "::error::Some rules failed: ${failed_rules%,}" + echo "::error::Failed rules: ${failed_rules%,}" exit 1 fi @@ -219,6 +118,4 @@ jobs: uses: actions/upload-artifact@v4 with: name: rule-test-results - path: | - tmp-rules/*.json - tmp-rules/*.result \ No newline at end of file + path: tmp-rules/ \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 473fca1..c67e874 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -17,7 +17,6 @@ public record FilterIntoJoin() implements RRule { @Override public RelRN before() { var join = left.join(JoinRelType.INNER, joinCond, right); - // return join.filter("outer"); return join.filter("outer"); } From 9f6a48b503b146e4e9bcf402cc462f7495d3e44d Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 18:14:21 -0700 Subject: [PATCH 45/78] new test --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index c67e874..dd2159d 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -22,6 +22,7 @@ public RelRN before() { @Override public RelRN after() { - return left.join(JoinRelType.SEMI, RexRN.and(joinCond, left.joinPred("outer", right)), right); + return left.join( + JoinRelType.SEMI, RexRN.and(joinCond, left.joinPred("outer", right)), right); } } From 7dd23c3f557b547edd507de48bc6480e001022a1 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 18:17:18 -0700 Subject: [PATCH 46/78] delete join associate temporarily --- .../RRuleInstances/JoinAssociate.java | 69 ------------------- 1 file changed, 69 deletions(-) delete mode 100644 src/main/java/org/qed/Generated/RRuleInstances/JoinAssociate.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinAssociate.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinAssociate.java deleted file mode 100644 index 059695e..0000000 --- a/src/main/java/org/qed/Generated/RRuleInstances/JoinAssociate.java +++ /dev/null @@ -1,69 +0,0 @@ -package org.qed.Generated.RRuleInstances; - -import kala.collection.Map; -import kala.collection.Seq; -import org.apache.calcite.rel.core.JoinRelType; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.qed.RelRN; -import org.qed.RexRN; -import org.qed.RRule; -import org.qed.RuleBuilder; - -public record JoinAssociate() implements RRule.RRuleFamily { - static final RelRN a = RelRN.scan("A", "A_Type"); - static final RelRN b = RelRN.scan("B", "B_Type"); - static final RelRN c = RelRN.scan("C", "C_Type"); - static final String pred_ab = "pred_ab"; - static final String pred_bc = "pred_bc"; - static final RelRN.Join.JoinType.MetaJoinType mjt_0 = new RelRN.Join.JoinType.MetaJoinType("mjt_0"); - static final RelRN.Join.JoinType.MetaJoinType mjt_1 = new RelRN.Join.JoinType.MetaJoinType("mjt_1"); - static final RelRN.Join.JoinType.MetaJoinType mjt_2 = new RelRN.Join.JoinType.MetaJoinType("mjt_2"); - static final RelRN.Join.JoinType.MetaJoinType mjt_3 = new RelRN.Join.JoinType.MetaJoinType("mjt_3"); - - static final RelRN before_ab = a.join(mjt_0, RexRN.and( - a.joinPred(pred_ab, b), - new RexRN.JoinField(1, a, b).pred(SqlStdOperatorTable.IS_NOT_NULL) - ), b); - - static final RelRN before = before_ab.join(mjt_1, RexRN.and( - new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_bc, true), before_ab.joinFields(c, 1, 2)), - new RexRN.JoinField(1, before_ab, c).pred(SqlStdOperatorTable.IS_NOT_NULL) - ), c); - - static final RelRN after_bc = b.join(mjt_2, RexRN.and( - b.joinPred(pred_bc, c), - new RexRN.JoinField(0, b, c).pred(SqlStdOperatorTable.IS_NOT_NULL) - ), c); - - static final RelRN after = a.join(mjt_3, RexRN.and( - new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_ab, true), a.joinFields(after_bc, 0, 1)), - new RexRN.JoinField(1, a, after_bc).pred(SqlStdOperatorTable.IS_NOT_NULL) - ), after_bc); - - static final RRule template = new RRule() { - @Override - public RelRN before() { - return before; - } - - @Override - public RelRN after() { - return after; - } - - @Override - public String name() { - return JoinAssociate.class.getSimpleName(); - } - }; - - static Seq assignments() { - var joinTypes = Seq.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL).map(RelRN.Join.JoinType.ConcreteJoinType::new); - return joinTypes.flatMap(jt0 -> joinTypes.flatMap(jt1 -> joinTypes.flatMap(jt2 -> joinTypes.map(jt3 -> new RRule.RRuleGenerator.MetaAssignment(Map.of(mjt_0, jt0, mjt_1, jt1, mjt_2, jt2, mjt_3, jt3)))))); - } - - @Override - public Seq family() { - return new RRule.RRuleGenerator(template, assignments()).family(); - } -} From c878388e6c224dc631b7c609947012f92d02ead8 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 18:24:54 -0700 Subject: [PATCH 47/78] new test --- .github/workflows/test-new-rules.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index e22af70..44a4b8c 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -69,9 +69,10 @@ jobs: run: | sudo apt-get update sudo apt-get install -y jq z3 - wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.1.2/cvc5-Linux-x86_64 -O cvc5 - chmod +x cvc5 - sudo mv cvc5 /usr/local/bin/ + wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.1/cvc5-1.2.1-Linux-x86_64-static.zip + unzip cvc5-1.2.1-Linux-x86_64-static.zip + chmod +x cvc5-1.2.1/bin/cvc5 + sudo mv cvc5-1.2.1/bin/cvc5 /usr/local/bin/ - name: Build qed-prover run: | From e3047974349b023f096d8d78a60bfa0b9ab764b8 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 18:27:39 -0700 Subject: [PATCH 48/78] new test --- .github/workflows/test-new-rules.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index 44a4b8c..ec38f1a 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -69,10 +69,10 @@ jobs: run: | sudo apt-get update sudo apt-get install -y jq z3 - wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.1/cvc5-1.2.1-Linux-x86_64-static.zip - unzip cvc5-1.2.1-Linux-x86_64-static.zip - chmod +x cvc5-1.2.1/bin/cvc5 - sudo mv cvc5-1.2.1/bin/cvc5 /usr/local/bin/ + wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.1/cvc5-Linux-x86_64-static.zip + unzip cvc5-Linux-x86_64-static.zip + chmod +x cvc5/bin/cvc5 + sudo mv cvc5/bin/cvc5 /usr/local/bin/ - name: Build qed-prover run: | From 5a39e4edbc40cd0c8b3c26c14671eb9394e1672d Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 18:31:36 -0700 Subject: [PATCH 49/78] new test --- .github/workflows/test-new-rules.yml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index ec38f1a..d8cafc7 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -68,11 +68,13 @@ jobs: - name: Install dependencies run: | sudo apt-get update - sudo apt-get install -y jq z3 - wget https://github.com/cvc5/cvc5/releases/download/cvc5-1.2.1/cvc5-Linux-x86_64-static.zip - unzip cvc5-Linux-x86_64-static.zip - chmod +x cvc5/bin/cvc5 - sudo mv cvc5/bin/cvc5 /usr/local/bin/ + sudo apt-get install -y jq z3 cmake libgmp-dev + git clone --depth 1 --branch cvc5-1.2.1 https://github.com/cvc5/cvc5.git + cd cvc5 + ./configure.sh --auto-download + cd build + make -j$(nproc) cvc5 + sudo cp bin/cvc5 /usr/local/bin/ - name: Build qed-prover run: | From 36c6c46a3a59b1404e477e0a6cf30950f0798595 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 18:46:44 -0700 Subject: [PATCH 50/78] new test --- .github/workflows/test-new-rules.yml | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test-new-rules.yml index d8cafc7..6eac450 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test-new-rules.yml @@ -68,13 +68,17 @@ jobs: - name: Install dependencies run: | sudo apt-get update - sudo apt-get install -y jq z3 cmake libgmp-dev - git clone --depth 1 --branch cvc5-1.2.1 https://github.com/cvc5/cvc5.git - cd cvc5 - ./configure.sh --auto-download - cd build - make -j$(nproc) cvc5 - sudo cp bin/cvc5 /usr/local/bin/ + sudo apt-get install -y jq z3 + # Try to install cvc5 from package manager first, otherwise build from source + sudo apt-get install -y cvc5 || ( + sudo apt-get install -y cmake libgmp-dev && + git clone --depth 1 https://github.com/cvc5/cvc5.git && + cd cvc5 && + ./configure.sh --auto-download && + cd build && + make -j$(nproc) && + sudo make install + ) - name: Build qed-prover run: | From 06677975c94880b2fea3d83ee6464f4a63905fd3 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 19:55:40 -0700 Subject: [PATCH 51/78] restore FilterIntoJoin.java --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index dd2159d..92517fc 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -22,7 +22,6 @@ public RelRN before() { @Override public RelRN after() { - return left.join( - JoinRelType.SEMI, RexRN.and(joinCond, left.joinPred("outer", right)), right); + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); } } From e164dda67d2a09be660c11e1ac3b561b2dba1e98 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 20:03:59 -0700 Subject: [PATCH 52/78] update test.yml --- .github/workflows/{test-new-rules.yml => test.yml} | 6 +++++- .../org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) rename .github/workflows/{test-new-rules.yml => test.yml} (96%) diff --git a/.github/workflows/test-new-rules.yml b/.github/workflows/test.yml similarity index 96% rename from .github/workflows/test-new-rules.yml rename to .github/workflows/test.yml index 6eac450..d0c3032 100644 --- a/.github/workflows/test-new-rules.yml +++ b/.github/workflows/test.yml @@ -2,7 +2,11 @@ name: Test All Rules on: push: + paths: + - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' pull_request: + paths: + - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' jobs: test-all-rules: @@ -124,5 +128,5 @@ jobs: if: always() uses: actions/upload-artifact@v4 with: - name: rule-test-results + name: results path: tmp-rules/ \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 92517fc..24165aa 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -12,7 +12,7 @@ public record FilterIntoJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("join", right); + static final RexRN joinCond = left.joinPred("join", right); @Override public RelRN before() { From a25c7dda707946c1c7e8e2c40ee55f09c55155ef Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 22:53:20 -0700 Subject: [PATCH 53/78] delete unprovable rules --- .../Generated/RRuleInstances/JoinCommute.java | 28 ------------------- .../RRuleInstances/SemiJoinJoinTranspose.java | 25 ----------------- .../SemiJoinProjectTranspose.java | 23 --------------- .../RRuleInstances/SemiJoinRemove.java | 21 -------------- 4 files changed, 97 deletions(-) delete mode 100644 src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java delete mode 100644 src/main/java/org/qed/Generated/RRuleInstances/SemiJoinJoinTranspose.java delete mode 100644 src/main/java/org/qed/Generated/RRuleInstances/SemiJoinProjectTranspose.java delete mode 100644 src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java deleted file mode 100644 index 22f2344..0000000 --- a/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.qed.Generated.RRuleInstances; - -import kala.collection.Map; -import kala.collection.Seq; -import org.apache.calcite.rel.core.JoinRelType; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.qed.RelRN; -import org.qed.RexRN; -import org.qed.RRule; -import org.qed.RuleBuilder; - -public record JoinCommute() implements RRule { - static final RelRN left = RelRN.scan("Left", "Left_Type"); - static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("pred", right); - - @Override - public RelRN before() { - return left.join(JoinRelType.INNER, joinCond, right); - } - - @Override - public RelRN after() { - // We need to swap the join fields in the condition - RexRN commutedJoinCond = right.joinPred("pred", left); - return right.join(JoinRelType.INNER, commutedJoinCond, left); - } -} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinJoinTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinJoinTranspose.java deleted file mode 100644 index b1a4a3d..0000000 --- a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinJoinTranspose.java +++ /dev/null @@ -1,25 +0,0 @@ -package org.qed.Generated.RRuleInstances; - -import org.apache.calcite.rel.core.JoinRelType; -import org.qed.RRule; -import org.qed.RelRN; -import org.qed.RexRN; - - -public record SemiJoinJoinTranspose() implements RRule { - static final RelRN left = RelRN.scan("left", "left_Type"); - static final RelRN middle = RelRN.scan("middle", "middle_Type"); - static final RelRN right = RelRN.scan("right", "right_Type"); - static final RexRN semiCond = left.joinPred("semi", middle); - static final RexRN joinCond = left.joinPred("join", right); - - @Override - public RelRN before() { - return left.join(JoinRelType.INNER, joinCond, right).join(JoinRelType.SEMI, semiCond, middle); - } - - @Override - public RelRN after() { - return left.join(JoinRelType.SEMI, semiCond, middle).join(JoinRelType.INNER, joinCond, right); - } -} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinProjectTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinProjectTranspose.java deleted file mode 100644 index 71d2dfe..0000000 --- a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinProjectTranspose.java +++ /dev/null @@ -1,23 +0,0 @@ -package org.qed.Generated.RRuleInstances; - -import org.apache.calcite.rel.core.JoinRelType; -import org.qed.RRule; -import org.qed.RelRN; -import org.qed.RexRN; - -public record SemiJoinProjectTranspose() implements RRule { - static final RelRN left = RelRN.scan("Left", "left_type"); - static final RelRN right = RelRN.scan("Right", "right_type"); - static final RexRN proj = left.proj("proj", "proj_type"); - static final RexRN semiCond = left.joinPred("semi", right); - - @Override - public RelRN before() { - return left.join(JoinRelType.SEMI, semiCond, right).project(proj); - } - - @Override - public RelRN after() { - return left.project(proj).join(JoinRelType.SEMI, semiCond, right); - } -} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java deleted file mode 100644 index 3b21e3f..0000000 --- a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinRemove.java +++ /dev/null @@ -1,21 +0,0 @@ -package org.qed.Generated.RRuleInstances; - -import org.apache.calcite.rel.core.JoinRelType; -import org.qed.RRule; -import org.qed.RelRN; -import org.qed.RexRN; - -public record SemiJoinRemove() implements RRule { - static final RelRN left = RelRN.scan("Left", "Left_Type"); - static final RelRN right = RelRN.scan("Right", "Right_Type"); - - @Override - public RelRN before() { - return left.join(JoinRelType.SEMI, RexRN.trueLiteral(), right); - } - - @Override - public RelRN after() { - return left; - } -} From 7a8640942d163ca90ea5c10b0ff10aa5cd449182 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 23:19:40 -0700 Subject: [PATCH 54/78] test rule that cannot compile --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 24165aa..8c5212d 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -10,7 +10,7 @@ import org.qed.RuleBuilder; public record FilterIntoJoin() implements RRule { - static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN left = RelRN.scan("Left", "Left_Type"; static final RelRN right = RelRN.scan("Right", "Right_Type"); static final RexRN joinCond = left.joinPred("join", right); From 357dcb9a55a963847cb0645c83630bba19122f2f Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 23:33:25 -0700 Subject: [PATCH 55/78] update workflow --- .github/workflows/test.yml | 35 ++------------------------- scripts/generate-rule-json.sh | 45 +++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 33 deletions(-) create mode 100644 scripts/generate-rule-json.sh diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d0c3032..436ed22 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,44 +36,13 @@ jobs: run: | mkdir -p tmp-rules mvn dependency:resolve - - cat > JsonGenerator.java << 'EOF' - import org.qed.*; - import com.fasterxml.jackson.databind.ObjectMapper; - import com.fasterxml.jackson.databind.node.ObjectNode; - import java.nio.file.*; - public class JsonGenerator { - public static void main(String[] args) throws Exception { - String className = args[0]; - Class clazz = Class.forName(className); - RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); - ObjectMapper mapper = new ObjectMapper(); - String fileName = rule.name() + "-" + rule.info() + ".json"; - ObjectNode jsonNode = rule.toJson(); - mapper.writerWithDefaultPrettyPrinter().writeValue( - Path.of("tmp-rules", fileName).toFile(), - jsonNode - ); - } - } - EOF - - MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) - CLASSPATH="target/classes:${MAVEN_CP}" - javac -cp "$CLASSPATH" JsonGenerator.java - - find src/main/java/org/qed/Generated/RRuleInstances -name '*.java' | while read file; do - className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') - java -cp ".:$CLASSPATH" JsonGenerator "$className" - done - - rm -f JsonGenerator.java JsonGenerator.class + chmod +x scripts/generate-rule-json.sh + ./scripts/generate-rule-json.sh - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y jq z3 - # Try to install cvc5 from package manager first, otherwise build from source sudo apt-get install -y cvc5 || ( sudo apt-get install -y cmake libgmp-dev && git clone --depth 1 https://github.com/cvc5/cvc5.git && diff --git a/scripts/generate-rule-json.sh b/scripts/generate-rule-json.sh new file mode 100644 index 0000000..a34f5da --- /dev/null +++ b/scripts/generate-rule-json.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Script to generate JSON files for all RRule instances + +# Create temporary Java file for JSON generation +cat > JsonGenerator.java << 'EOF' +import org.qed.*; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.nio.file.*; + +public class JsonGenerator { + public static void main(String[] args) throws Exception { + String className = args[0]; + Class clazz = Class.forName(className); + RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); + ObjectMapper mapper = new ObjectMapper(); + String fileName = rule.name() + "-" + rule.info() + ".json"; + ObjectNode jsonNode = rule.toJson(); + mapper.writerWithDefaultPrettyPrinter().writeValue( + Path.of("tmp-rules", fileName).toFile(), + jsonNode + ); + } +} +EOF + +# Build classpath +MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) +CLASSPATH="target/classes:${MAVEN_CP}" + +# Compile the generator +javac -cp "$CLASSPATH" JsonGenerator.java + +# Generate JSON for each rule +find src/main/java/org/qed/Generated/RRuleInstances -name '*.java' | while read file; do + className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') + echo "Generating JSON for: $className" + java -cp ".:$CLASSPATH" JsonGenerator "$className" +done + +# Cleanup +rm -f JsonGenerator.java JsonGenerator.class + +echo "JSON generation complete. Files are in tmp-rules/" \ No newline at end of file From 248f0381915d7a272df0b91378d535fca7395c38 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 23:34:22 -0700 Subject: [PATCH 56/78] restore FilterIntoJoin.java to correct version --- .../java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 8c5212d..24165aa 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -10,7 +10,7 @@ import org.qed.RuleBuilder; public record FilterIntoJoin() implements RRule { - static final RelRN left = RelRN.scan("Left", "Left_Type"; + static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); static final RexRN joinCond = left.joinPred("join", right); From 1257866c0acf07feacc91610f6df4a5443ad1b08 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 23:53:12 -0700 Subject: [PATCH 57/78] update java version --- .github/workflows/test.yml | 2 +- pom.xml | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 436ed22..0663ca9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -18,7 +18,7 @@ jobs: - name: Set up Java uses: actions/setup-java@v4 with: - java-version: '21' + java-version: '23' distribution: 'temurin' - name: Cache Maven dependencies diff --git a/pom.xml b/pom.xml index d194cc2..fd29e81 100644 --- a/pom.xml +++ b/pom.xml @@ -21,7 +21,6 @@ -classpath org.qed.Main - ${args} @@ -52,9 +51,11 @@ org.apache.maven.plugins maven-compiler-plugin - 21 - 21 - --enable-preview + 23 + 23 + + --enable-preview + From a0b066b4ffba4d341c0d8068b4c2bcc2ed5045d8 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sat, 31 May 2025 23:54:45 -0700 Subject: [PATCH 58/78] test workflow --- .../qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java index d1bc90e..790c1d2 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java @@ -17,13 +17,11 @@ public record SemiJoinFilterTranspose() implements RRule { @Override public RelRN before() { - // Semi-join followed by a filter return left.join(JoinRelType.SEMI, joinCond, right).filter(filterPred); } @Override public RelRN after() { - // Push the filter before the semi-join RelRN leftFiltered = left.filter(filterPred); return leftFiltered.join(JoinRelType.SEMI, leftFiltered.joinPred("join", right), right); } From 2a2e9cebde85f280a4301e61fb36d469d8d0f7c0 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sun, 1 Jun 2025 00:20:17 -0700 Subject: [PATCH 59/78] update workflow --- .github/workflows/test.yml | 51 ++++----------------------------- scripts/build-qed-prover.sh | 0 scripts/install-dependencies.sh | 17 +++++++++++ scripts/test-rules.sh | 33 +++++++++++++++++++++ 4 files changed, 56 insertions(+), 45 deletions(-) create mode 100644 scripts/build-qed-prover.sh create mode 100644 scripts/install-dependencies.sh create mode 100644 scripts/test-rules.sh diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0663ca9..a1ed0b8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,57 +41,18 @@ jobs: - name: Install dependencies run: | - sudo apt-get update - sudo apt-get install -y jq z3 - sudo apt-get install -y cvc5 || ( - sudo apt-get install -y cmake libgmp-dev && - git clone --depth 1 https://github.com/cvc5/cvc5.git && - cd cvc5 && - ./configure.sh --auto-download && - cd build && - make -j$(nproc) && - sudo make install - ) + chmod +x scripts/install-dependencies.sh + ./scripts/install-dependencies.sh - name: Build qed-prover run: | - curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly - source $HOME/.cargo/env - git clone https://github.com/qed-solver/prover.git qed-prover - cd qed-prover - cargo +nightly build --release + chmod +x scripts/build-qed-prover.sh + ./scripts/build-qed-prover.sh - name: Test all rules run: | - echo "## QED Prover Test Results" >> $GITHUB_STEP_SUMMARY - echo "" >> $GITHUB_STEP_SUMMARY - - failed_rules="" - total_count=0 - passed_count=0 - - for json_file in tmp-rules/*.json; do - rule_name=$(basename "$json_file" .json) - total_count=$((total_count + 1)) - ./qed-prover/target/release/qed-prover "$json_file" || true - - result_file="${json_file%.json}.result" - if [ -f "$result_file" ] && jq -e '.provable == true' "$result_file" > /dev/null 2>&1; then - echo "✅ $rule_name: PASSED" >> $GITHUB_STEP_SUMMARY - passed_count=$((passed_count + 1)) - else - echo "❌ $rule_name: FAILED" >> $GITHUB_STEP_SUMMARY - failed_rules="$failed_rules$rule_name," - fi - done - - echo "" >> $GITHUB_STEP_SUMMARY - echo "**Summary:** $passed_count/$total_count passed" >> $GITHUB_STEP_SUMMARY - - if [ -n "$failed_rules" ]; then - echo "::error::Failed rules: ${failed_rules%,}" - exit 1 - fi + chmod +x scripts/test-rules.sh + ./scripts/test-rules.sh - name: Upload test artifacts if: always() diff --git a/scripts/build-qed-prover.sh b/scripts/build-qed-prover.sh new file mode 100644 index 0000000..e69de29 diff --git a/scripts/install-dependencies.sh b/scripts/install-dependencies.sh new file mode 100644 index 0000000..500879f --- /dev/null +++ b/scripts/install-dependencies.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Install required dependencies for qed-prover + +sudo apt-get update +sudo apt-get install -y jq z3 + +# Install cvc5 - try package manager first, otherwise build from source +sudo apt-get install -y cvc5 || ( + sudo apt-get install -y cmake libgmp-dev && + git clone --depth 1 https://github.com/cvc5/cvc5.git && + cd cvc5 && + ./configure.sh --auto-download && + cd build && + make -j$(nproc) && + sudo make install +) \ No newline at end of file diff --git a/scripts/test-rules.sh b/scripts/test-rules.sh new file mode 100644 index 0000000..5dc149b --- /dev/null +++ b/scripts/test-rules.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Test all generated rules with qed-prover + +echo "## QED Prover Test Results" >> $GITHUB_STEP_SUMMARY +echo "" >> $GITHUB_STEP_SUMMARY + +failed_rules="" +total_count=0 +passed_count=0 + +for json_file in tmp-rules/*.json; do + rule_name=$(basename "$json_file" .json) + total_count=$((total_count + 1)) + ./qed-prover/target/release/qed-prover "$json_file" || true + + result_file="${json_file%.json}.result" + if [ -f "$result_file" ] && jq -e '.provable == true' "$result_file" > /dev/null 2>&1; then + echo "✅ $rule_name: PASSED" >> $GITHUB_STEP_SUMMARY + passed_count=$((passed_count + 1)) + else + echo "❌ $rule_name: FAILED" >> $GITHUB_STEP_SUMMARY + failed_rules="$failed_rules$rule_name," + fi +done + +echo "" >> $GITHUB_STEP_SUMMARY +echo "**Summary:** $passed_count/$total_count passed" >> $GITHUB_STEP_SUMMARY + +if [ -n "$failed_rules" ]; then + echo "::error::Failed rules: ${failed_rules%,}" + exit 1 +fi \ No newline at end of file From d21779efe88739c0b688edbad9b44f7a907e4d80 Mon Sep 17 00:00:00 2001 From: zengzirong <122090719@link.cuhk.edu.cn> Date: Sun, 1 Jun 2025 00:43:15 -0700 Subject: [PATCH 60/78] new update --- scripts/build-qed-prover.sh | 10 +++ .../JoinAssociate.java | 69 +++++++++++++++++++ .../JoinCommute.java | 28 ++++++++ .../SemiJoinJoinTranspose.java | 25 +++++++ .../SemiJoinProjectTranspose.java | 23 +++++++ .../SemiJoinRemove.java | 21 ++++++ .../RRuleInstances/FilterIntoJoin.java | 2 +- 7 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinAssociate.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinJoinTranspose.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinProjectTranspose.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinRemove.java diff --git a/scripts/build-qed-prover.sh b/scripts/build-qed-prover.sh index e69de29..10fe4b0 100644 --- a/scripts/build-qed-prover.sh +++ b/scripts/build-qed-prover.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Build qed-prover with Rust nightly + +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly +source $HOME/.cargo/env + +git clone https://github.com/qed-solver/prover.git qed-prover +cd qed-prover +cargo +nightly build --release \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinAssociate.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinAssociate.java new file mode 100644 index 0000000..059695e --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinAssociate.java @@ -0,0 +1,69 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinAssociate() implements RRule.RRuleFamily { + static final RelRN a = RelRN.scan("A", "A_Type"); + static final RelRN b = RelRN.scan("B", "B_Type"); + static final RelRN c = RelRN.scan("C", "C_Type"); + static final String pred_ab = "pred_ab"; + static final String pred_bc = "pred_bc"; + static final RelRN.Join.JoinType.MetaJoinType mjt_0 = new RelRN.Join.JoinType.MetaJoinType("mjt_0"); + static final RelRN.Join.JoinType.MetaJoinType mjt_1 = new RelRN.Join.JoinType.MetaJoinType("mjt_1"); + static final RelRN.Join.JoinType.MetaJoinType mjt_2 = new RelRN.Join.JoinType.MetaJoinType("mjt_2"); + static final RelRN.Join.JoinType.MetaJoinType mjt_3 = new RelRN.Join.JoinType.MetaJoinType("mjt_3"); + + static final RelRN before_ab = a.join(mjt_0, RexRN.and( + a.joinPred(pred_ab, b), + new RexRN.JoinField(1, a, b).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), b); + + static final RelRN before = before_ab.join(mjt_1, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_bc, true), before_ab.joinFields(c, 1, 2)), + new RexRN.JoinField(1, before_ab, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after_bc = b.join(mjt_2, RexRN.and( + b.joinPred(pred_bc, c), + new RexRN.JoinField(0, b, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after = a.join(mjt_3, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_ab, true), a.joinFields(after_bc, 0, 1)), + new RexRN.JoinField(1, a, after_bc).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), after_bc); + + static final RRule template = new RRule() { + @Override + public RelRN before() { + return before; + } + + @Override + public RelRN after() { + return after; + } + + @Override + public String name() { + return JoinAssociate.class.getSimpleName(); + } + }; + + static Seq assignments() { + var joinTypes = Seq.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL).map(RelRN.Join.JoinType.ConcreteJoinType::new); + return joinTypes.flatMap(jt0 -> joinTypes.flatMap(jt1 -> joinTypes.flatMap(jt2 -> joinTypes.map(jt3 -> new RRule.RRuleGenerator.MetaAssignment(Map.of(mjt_0, jt0, mjt_1, jt1, mjt_2, jt2, mjt_3, jt3)))))); + } + + @Override + public Seq family() { + return new RRule.RRuleGenerator(template, assignments()).family(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java new file mode 100644 index 0000000..22f2344 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java @@ -0,0 +1,28 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("pred", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + // We need to swap the join fields in the condition + RexRN commutedJoinCond = right.joinPred("pred", left); + return right.join(JoinRelType.INNER, commutedJoinCond, left); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinJoinTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinJoinTranspose.java new file mode 100644 index 0000000..b1a4a3d --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinJoinTranspose.java @@ -0,0 +1,25 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + + +public record SemiJoinJoinTranspose() implements RRule { + static final RelRN left = RelRN.scan("left", "left_Type"); + static final RelRN middle = RelRN.scan("middle", "middle_Type"); + static final RelRN right = RelRN.scan("right", "right_Type"); + static final RexRN semiCond = left.joinPred("semi", middle); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right).join(JoinRelType.SEMI, semiCond, middle); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, semiCond, middle).join(JoinRelType.INNER, joinCond, right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinProjectTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinProjectTranspose.java new file mode 100644 index 0000000..71d2dfe --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinProjectTranspose.java @@ -0,0 +1,23 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record SemiJoinProjectTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "left_type"); + static final RelRN right = RelRN.scan("Right", "right_type"); + static final RexRN proj = left.proj("proj", "proj_type"); + static final RexRN semiCond = left.joinPred("semi", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, semiCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.SEMI, semiCond, right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinRemove.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinRemove.java new file mode 100644 index 0000000..3b21e3f --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinRemove.java @@ -0,0 +1,21 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record SemiJoinRemove() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, RexRN.trueLiteral(), right); + } + + @Override + public RelRN after() { + return left; + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java index 24165aa..d16f8ef 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -19,7 +19,7 @@ public RelRN before() { var join = left.join(JoinRelType.INNER, joinCond, right); return join.filter("outer"); } - + @Override public RelRN after() { return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); From e8c7847ea592bdceca441939e096d4f28adc3af9 Mon Sep 17 00:00:00 2001 From: zengzirong <121952797+zengzirong@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:19:27 -0700 Subject: [PATCH 61/78] CI Testing Add GitHub Actions workflow to verify code generation on push/PR --- .github/workflows/codegen-test.yml | 46 ++++++++ .../workflows/{test.yml => prover-test.yml} | 8 +- scripts/test-codegen.sh | 100 ++++++++++++++++++ .../java/org/qed/Generated/CalciteTester.java | 15 ++- .../SemiJoinProjectTransposeTest.java | 0 5 files changed, 159 insertions(+), 10 deletions(-) create mode 100644 .github/workflows/codegen-test.yml rename .github/workflows/{test.yml => prover-test.yml} (87%) create mode 100644 scripts/test-codegen.sh rename src/main/java/org/qed/Generated/{Tests => Tests-failed}/SemiJoinProjectTransposeTest.java (100%) diff --git a/.github/workflows/codegen-test.yml b/.github/workflows/codegen-test.yml new file mode 100644 index 0000000..85022d6 --- /dev/null +++ b/.github/workflows/codegen-test.yml @@ -0,0 +1,46 @@ +name: Test Code Generation + +on: + push: + pull_request: + +jobs: + test-code-generation: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Java + uses: actions/setup-java@v4 + with: + java-version: '23' + distribution: 'temurin' + + - name: Cache Maven dependencies + uses: actions/cache@v3 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + + - name: Build project + run: mvn -B compile --file pom.xml + + - name: Run Calcite tests + run: | + chmod +x scripts/test-codegen.sh + ./scripts/test-codegen.sh + + - name: Upload generated code + if: always() + uses: actions/upload-artifact@v4 + with: + name: generated-code + path: | + src/main/java/org/qed/Generated/*.java + !src/main/java/org/qed/Generated/CalciteTester.java + !src/main/java/org/qed/Generated/CalciteGenerator.java + !src/main/java/org/qed/Generated/CalciteUtilities.java + !src/main/java/org/qed/Generated/EmptyConfig.java \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/prover-test.yml similarity index 87% rename from .github/workflows/test.yml rename to .github/workflows/prover-test.yml index a1ed0b8..eca7d97 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/prover-test.yml @@ -1,15 +1,11 @@ -name: Test All Rules +name: Test Provability on: push: - paths: - - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' pull_request: - paths: - - 'src/main/java/org/qed/Generated/RRuleInstances/**/*.java' jobs: - test-all-rules: + test-provability: runs-on: ubuntu-latest steps: diff --git a/scripts/test-codegen.sh b/scripts/test-codegen.sh new file mode 100644 index 0000000..6ab614e --- /dev/null +++ b/scripts/test-codegen.sh @@ -0,0 +1,100 @@ +#!/bin/bash + +# Script to generate code for each rule and test whether the rules can be applied correctly + +echo "## Code Generation Test Results" >> $GITHUB_STEP_SUMMARY +echo "" >> $GITHUB_STEP_SUMMARY + +# Step 1: Generate code for each rule in RRuleInstances +# Create temporary Java file for code generation +cat > RuleGenerator.java << 'EOF' +import org.qed.Generated.CalciteTester; +import org.qed.*; +import java.nio.file.*; + +public class RuleGenerator { + public static void main(String[] args) throws Exception { + String className = args[0]; + Class clazz = Class.forName(className); + RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); + + CalciteTester tester = new CalciteTester(); + tester.serialize(rule, CalciteTester.genPath); + + System.out.println("Generated code for: " + rule.name()); + } +} +EOF + +# Build classpath +MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) +CLASSPATH="target/classes:${MAVEN_CP}" + +# Compile the generator +javac -cp "$CLASSPATH" RuleGenerator.java + +# Generate code for each rule +find src/main/java/org/qed/Generated/RRuleInstances -name '*.java' -not -path '*/RRuleInstances-unprovable/*' | while read file; do + className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') + java -cp ".:$CLASSPATH" RuleGenerator "$className" +done + +# Step 2: Check for missing tests +missing_tests="" +missing_count=0 +for rule_file in src/main/java/org/qed/Generated/RRuleInstances/*.java; do + rule_name=$(basename "$rule_file" .java) + if [ ! -f "src/main/java/org/qed/Generated/Tests/${rule_name}Test.java" ]; then + missing_tests="${missing_tests}- ${rule_name}\n" + missing_count=$((missing_count + 1)) + fi +done + +if [ $missing_count -gt 0 ]; then + echo "**⚠️ Warning: Missing tests for $missing_count rules:**" >> $GITHUB_STEP_SUMMARY + echo -e "$missing_tests" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY +fi + +# Step 3: Run all test classes +# Store results for summary +total_tests=0 +passed_tests=0 + +# Find all test files and run them +total_tests=0 +passed_tests=0 + +for test_file in src/main/java/org/qed/Generated/Tests/*Test.java; do + class_name=${test_file#src/main/java/} + class_name=${class_name%.java} + class_name=${class_name//\//.} + test_name=$(basename "$test_file" .java) + display_name=${test_name%Test} + total_tests=$((total_tests + 1)) + + # Run the test and capture output + if java -cp "$CLASSPATH" "$class_name" > /tmp/test_output.txt 2>&1; then + if grep -q "trivial" /tmp/test_output.txt; then + echo "⚠️ ${display_name}: TRIVIAL" >> $GITHUB_STEP_SUMMARY + elif grep -q "succeeded" /tmp/test_output.txt && ! grep -q "failed" /tmp/test_output.txt; then + echo "✅ ${display_name}: PASSED" >> $GITHUB_STEP_SUMMARY + passed_tests=$((passed_tests + 1)) + else + echo "❌ ${display_name}: FAILED" >> $GITHUB_STEP_SUMMARY + fi + else + echo "❌ ${display_name}: ERROR" >> $GITHUB_STEP_SUMMARY + fi +done + +# Clean up +rm -f RuleGenerator.java RuleGenerator.class /tmp/test_output.txt + +echo "" >> $GITHUB_STEP_SUMMARY +echo "**Summary:** $passed_tests/$total_tests passed" >> $GITHUB_STEP_SUMMARY + +# Exit with error if tests failed +if [ "$passed_tests" -ne "$total_tests" ]; then + exit 1 +fi \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 4fed912..949a2cf 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -84,6 +84,9 @@ public static void runAllTests() { org.qed.Generated.Tests.FilterSetOpTransposeTest.runTest(); org.qed.Generated.Tests.JoinExtractFilterTest.runTest(); org.qed.Generated.Tests.SemiJoinFilterTransposeTest.runTest(); + org.qed.Generated.Tests.ProjectFilterTransposeTest.runTest(); + org.qed.Generated.Tests.JoinPushTransitivePredicatesTest.runTest(); + org.qed.Generated.Tests.SemiJoinProjectTransposeTest.runTest(); } catch (Exception e) { System.out.println("Test failed: " + e.getMessage()); e.printStackTrace(); @@ -134,10 +137,14 @@ public void verify(HepPlanner runner, RelNode source, RelNode target) { String targetExplain = target.explain(); if(answerExplain.equals(targetExplain)) { - System.out.println("succeeded"); - // System.out.println("> Given source RelNode:\n" + source.explain()); - // System.out.println("> Actual rewritten RelNode:\n" + answerExplain); - // System.out.println("> Expected rewritten RelNode:\n" + targetExplain); + if(answerExplain.equals(source.explain())) + { + System.out.println("trivial"); + System.out.println("> Given source RelNode:\n" + source.explain()); + System.out.println("> Actual rewritten RelNode:\n" + answerExplain); + System.out.println("> Expected rewritten RelNode:\n" + targetExplain); + } + else System.out.println("succeeded"); return; } System.out.println("failed"); diff --git a/src/main/java/org/qed/Generated/Tests/SemiJoinProjectTransposeTest.java b/src/main/java/org/qed/Generated/Tests-failed/SemiJoinProjectTransposeTest.java similarity index 100% rename from src/main/java/org/qed/Generated/Tests/SemiJoinProjectTransposeTest.java rename to src/main/java/org/qed/Generated/Tests-failed/SemiJoinProjectTransposeTest.java From a202d28777bab1a473f002aaa297533a341b551f Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 5 Jun 2025 15:45:52 -0700 Subject: [PATCH 62/78] Adding Prune Rules --- .../RRuleInstances/PruneEmptyFilter.java | 20 +++++++++++++++++ .../RRuleInstances/PruneEmptyProject.java | 20 +++++++++++++++++ .../RRuleInstances/PruneLeftEmptyJoin.java | 22 +++++++++++++++++++ .../RRuleInstances/PruneRightEmptyJoin.java | 22 +++++++++++++++++++ 4 files changed, 84 insertions(+) create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyFilter.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyProject.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyFilter.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyFilter.java new file mode 100644 index 0000000..a0c5eaf --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyFilter.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyFilter() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN cond = source.pred("filter_cond"); + + @Override + public RelRN before() { + return source.empty().filter(cond); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyProject.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyProject.java new file mode 100644 index 0000000..0aac108 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyProject.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyProject() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.empty().project(proj); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java new file mode 100644 index 0000000..e38d07c --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java @@ -0,0 +1,22 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneLeftEmptyJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.empty().join(JoinRelType.LEFT, joinCond, right); + } + + @Override + public RelRN after() { + return left.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java new file mode 100644 index 0000000..5b4b390 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java @@ -0,0 +1,22 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneRightEmptyJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.RIGHT, joinCond, right.empty()); + } + + @Override + public RelRN after() { + return right.empty(); + } +} From f024bd5c5ccdfb87edc7681aadf2338ee6f07d85 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 5 Jun 2025 16:01:50 -0700 Subject: [PATCH 63/78] Fixing PruneEmptyJoins --- .../org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java | 4 ++-- .../org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java index e38d07c..cf78ac0 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java @@ -12,11 +12,11 @@ public record PruneLeftEmptyJoin() implements RRule { @Override public RelRN before() { - return left.empty().join(JoinRelType.LEFT, joinCond, right); + return left.empty().join(JoinRelType.RIGHT, joinCond, right); } @Override public RelRN after() { - return left.empty(); + return right; } } diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java index 5b4b390..498c0f9 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java @@ -12,11 +12,11 @@ public record PruneRightEmptyJoin() implements RRule { @Override public RelRN before() { - return left.join(JoinRelType.RIGHT, joinCond, right.empty()); + return left.join(JoinRelType.LEFT, joinCond, right.empty()); } @Override public RelRN after() { - return right.empty(); + return left; } } From 54376c3c778f53d0b8b57554e31332526d884260 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 5 Jun 2025 16:05:46 -0700 Subject: [PATCH 64/78] Fixing PruneEmptyJoin --- .../org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java | 2 +- .../org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java index cf78ac0..00f7e76 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java @@ -8,7 +8,7 @@ public record PruneLeftEmptyJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("join", right); + static final RexRN joinCond = left.empty().joinPred("join", right); @Override public RelRN before() { diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java index 498c0f9..2101e75 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java @@ -8,7 +8,7 @@ public record PruneRightEmptyJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("join", right); + static final RexRN joinCond = left.joinPred("join", right.empty()); @Override public RelRN before() { From 8df38178187756c3b738868d06a97aa5d76df181 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 5 Jun 2025 16:14:57 -0700 Subject: [PATCH 65/78] Adjusting --- .../org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java | 3 +-- .../org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java index 00f7e76..fbe5979 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java @@ -8,11 +8,10 @@ public record PruneLeftEmptyJoin() implements RRule { static final RelRN left = RelRN.scan("Left", "Left_Type"); static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.empty().joinPred("join", right); @Override public RelRN before() { - return left.empty().join(JoinRelType.RIGHT, joinCond, right); + return left.empty().join(JoinRelType.RIGHT, "pred", right); } @Override diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java index 2101e75..4b4b5c0 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java @@ -12,7 +12,7 @@ public record PruneRightEmptyJoin() implements RRule { @Override public RelRN before() { - return left.join(JoinRelType.LEFT, joinCond, right.empty()); + return left.join(JoinRelType.LEFT, "pred", right.empty()); } @Override From 9ac74e6910e4d2bfa829c3bf51c0215c1bfb1758 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 5 Jun 2025 16:25:16 -0700 Subject: [PATCH 66/78] Adding PruneEmptyIntersect --- .../PruneLeftEmptyJoin.java | 0 .../PruneRightEmptyJoin.java | 0 .../RRuleInstances/PruneEmptyIntersect.java | 20 +++++++++++++++++++ 3 files changed, 20 insertions(+) rename src/main/java/org/qed/Generated/{RRuleInstances => RRuleInstances-unprovable}/PruneLeftEmptyJoin.java (100%) rename src/main/java/org/qed/Generated/{RRuleInstances => RRuleInstances-unprovable}/PruneRightEmptyJoin.java (100%) create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyIntersect.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneLeftEmptyJoin.java similarity index 100% rename from src/main/java/org/qed/Generated/RRuleInstances/PruneLeftEmptyJoin.java rename to src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneLeftEmptyJoin.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneRightEmptyJoin.java similarity index 100% rename from src/main/java/org/qed/Generated/RRuleInstances/PruneRightEmptyJoin.java rename to src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneRightEmptyJoin.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyIntersect.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyIntersect.java new file mode 100644 index 0000000..fb4938f --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyIntersect.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyIntersect() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.intersect(false, b.empty()); + } + + @Override + public RelRN after() { + return a.empty().intersect(false, b.empty()); + } +} From 3837421154aa33f3d2ebd6892f350def6fc6ecf0 Mon Sep 17 00:00:00 2001 From: yushinliang <143054261+yushinliang@users.noreply.github.com> Date: Tue, 10 Jun 2025 01:20:36 +0800 Subject: [PATCH 67/78] added minus operator and generated code for MinusMerge (#16) --- .gitignore | 2 +- src/main/java/org/qed/CodeGenerator.java | 10 ++++ .../org/qed/Generated/CalciteGenerator.java | 49 +++++++++++++++++++ .../java/org/qed/Generated/CalciteTester.java | 1 + .../java/org/qed/Generated/MinusMerge.java | 39 +++++++++++++++ .../Generated/RRuleInstances/MinusMerge.java | 26 ++++++++++ .../qed/Generated/Tests/MinusMergeTest.java | 48 ++++++++++++++++++ src/main/java/org/qed/RelRN.java | 12 +++++ 8 files changed, 186 insertions(+), 1 deletion(-) create mode 100644 src/main/java/org/qed/Generated/MinusMerge.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/MinusMerge.java create mode 100644 src/main/java/org/qed/Generated/Tests/MinusMergeTest.java diff --git a/.gitignore b/.gitignore index 870d0c2..a2ef0ec 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,4 @@ target/ *.iml .mvn/wrapper/maven-wrapper.jar .vscode -.devcontainer \ No newline at end of file +.devcontainer diff --git a/src/main/java/org/qed/CodeGenerator.java b/src/main/java/org/qed/CodeGenerator.java index 85e1f6f..bf7da08 100644 --- a/src/main/java/org/qed/CodeGenerator.java +++ b/src/main/java/org/qed/CodeGenerator.java @@ -26,6 +26,7 @@ default E onMatch(E env, RelRN pattern) { case RelRN.Join join -> onMatchJoin(env, join); case RelRN.Union union -> onMatchUnion(env, union); case RelRN.Intersect intersect -> onMatchIntersect(env, intersect); + case RelRN.Minus minus -> onMatchMinus(env, minus); case RelRN.Empty empty -> onMatchEmpty(env, empty); default -> onMatchCustom(env, pattern); }; @@ -62,6 +63,7 @@ default E transform(E env, RelRN target) { case RelRN.Join join -> transformJoin(env, join); case RelRN.Union union -> transformUnion(env, union); case RelRN.Intersect intersect -> transformIntersect(env, intersect); + case RelRN.Minus minus -> transformMinus(env, minus); case RelRN.Empty empty -> transformEmpty(env, empty); default -> transformCustom(env, target); }; @@ -121,6 +123,10 @@ default E onMatchIntersect(E env, RelRN.Intersect intersect) { return unimplementedOnMatch(env, intersect); } + default E onMatchMinus(E env, RelRN.Minus minus) { + return unimplementedOnMatch(env, minus); + } + default E onMatchCustom(E env, RelRN custom) { return unimplementedOnMatch(env, custom); } @@ -193,6 +199,10 @@ default E transformIntersect(E env, RelRN.Intersect intersect) { return unimplementedTransform(env, intersect); } + default E transformMinus(E env, RelRN.Minus minus) { + return unimplementedTransform(env, minus); + } + default E transformCustom(E env, RelRN custom) { return unimplementedTransform(env, custom); } diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java index 6792cca..a5f9418 100644 --- a/src/main/java/org/qed/Generated/CalciteGenerator.java +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -185,6 +185,36 @@ public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); } + @Override + public Env onMatchMinus(Env env, RelRN.Minus minus) { + // Get the all flag from the minus + boolean all = minus.all(); + + // Process each source in the minus + var current_env = env; + var skeletons = Seq.empty(); + + // Process all sources in the sequence + for (var source : minus.sources()) { + var next_env = current_env.next(); + var source_env = onMatch(next_env, source); + skeletons = skeletons.appended(source_env.skeleton()); + current_env = source_env; + } + + // Build the input skeletons string for the operand + StringBuilder inputsBuilder = new StringBuilder(); + for (int i = 0; i < skeletons.size(); i++) { + if (i > 0) { + inputsBuilder.append(", "); + } + inputsBuilder.append(skeletons.get(i).toString()); + } + + // Create the minus operand + return current_env.grow("operand(LogicalMinus.class).inputs(" + inputsBuilder.toString() + ")"); + } + @Override public Env onMatchField(Env env, RexRN.Field field) { // Generate a unique symbolic name for this field @@ -320,6 +350,25 @@ public Env transformIntersect(Env env, RelRN.Intersect intersect) { return current_env.focus(current_env.current() + "." + methodName + "(" + all + ", " + sourceCount + ")"); } + @Override + public Env transformMinus(Env env, RelRN.Minus minus) { + // Get the all flag from the minus + boolean all = minus.all(); + + // The number of sources + int sourceCount = minus.sources().size(); + + // Transform each source + var current_env = env; + for (var source : minus.sources()) { + current_env = transform(current_env, source); + } + + // Use the minus method with the all flag and source count + // This matches the expected Calcite RelBuilder.minus(boolean all, int n) signature + return current_env.focus(current_env.current() + ".minus(" + all + ", " + sourceCount + ")"); + } + @Override public Env transformField(Env env, RexRN.Field field) { // In Calcite, field references are typically created with a "field" method diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index 949a2cf..cccd850 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -84,6 +84,7 @@ public static void runAllTests() { org.qed.Generated.Tests.FilterSetOpTransposeTest.runTest(); org.qed.Generated.Tests.JoinExtractFilterTest.runTest(); org.qed.Generated.Tests.SemiJoinFilterTransposeTest.runTest(); + org.qed.Generated.Tests.MinusMergeTest.runTest(); org.qed.Generated.Tests.ProjectFilterTransposeTest.runTest(); org.qed.Generated.Tests.JoinPushTransitivePredicatesTest.runTest(); org.qed.Generated.Tests.SemiJoinProjectTransposeTest.runTest(); diff --git a/src/main/java/org/qed/Generated/MinusMerge.java b/src/main/java/org/qed/Generated/MinusMerge.java new file mode 100644 index 0000000..a519ec2 --- /dev/null +++ b/src/main/java/org/qed/Generated/MinusMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class MinusMerge extends RelRule { + protected MinusMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).union(false, 2).minus(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default MinusMerge toRule() { + return new MinusMerge(this); + } + + @Override + default String description() { + return "MinusMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalMinus.class).inputs(s_2 -> s_2.operand(LogicalMinus.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/MinusMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/MinusMerge.java new file mode 100644 index 0000000..a080685 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/MinusMerge.java @@ -0,0 +1,26 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record MinusMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + return a.minus(false, b).minus(false, c); + } + + @Override + public RelRN after() { + return a.minus(false, b.union(false, c)); + } +} diff --git a/src/main/java/org/qed/Generated/Tests/MinusMergeTest.java b/src/main/java/org/qed/Generated/Tests/MinusMergeTest.java new file mode 100644 index 0000000..bc9985e --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/MinusMergeTest.java @@ -0,0 +1,48 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.UnionMerge; +import org.qed.RuleBuilder; + +/** + * Test for the MinusMerge rule. + */ +public class MinusMergeTest { + + /** + * Run test for MinusMerge rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + // (A − B) − C + var before = builder.push(scan1).push(scan2).minus(false, 2).push(scan3).minus(false, 2).build(); + + // A − (B ∪ C) + var union = builder.push(scan2).push(scan3).union(false).build(); + var after = builder.push(scan1).push(union).minus(false, 2).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.MinusMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running MinusMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RelRN.java b/src/main/java/org/qed/RelRN.java index ec09567..becc662 100644 --- a/src/main/java/org/qed/RelRN.java +++ b/src/main/java/org/qed/RelRN.java @@ -104,6 +104,10 @@ default Intersect intersect(boolean all, RelRN... sources) { return new Intersect(all, Seq.of(this).appendedAll(sources)); } + default Minus minus (boolean all, RelRN... sources) { + return new Minus(all, Seq.of(this).appendedAll(sources)); + } + default Empty empty() { return new Empty(this); } @@ -180,6 +184,14 @@ public RelNode semantics() { } } + record Minus(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).minus(all, sources.size()).build(); + } + } + record Empty(RelRN sourceType) implements RelRN { @Override From 3844e902548bd13572aba434dfb6782ea6e9a0c8 Mon Sep 17 00:00:00 2001 From: zengzirong <121952797+zengzirong@users.noreply.github.com> Date: Sat, 14 Jun 2025 02:05:39 -0700 Subject: [PATCH 68/78] Fixed JoinCommute (#17) * sync with origin dsl * fixed JoinCommute --- .../org/qed/Generated/CalciteGenerator.java | 122 +++++++++++++++++- .../java/org/qed/Generated/CalciteTester.java | 29 +++-- .../java/org/qed/Generated/JoinCommute.java | 16 ++- .../org/qed/Generated/PruneEmptyFilter.java | 39 ++++++ .../qed/Generated/PruneEmptyIntersect.java | 39 ++++++ .../org/qed/Generated/PruneEmptyProject.java | 39 ++++++ .../org/qed/Generated/PruneLeftEmptyJoin.java | 39 ++++++ .../qed/Generated/PruneRightEmptyJoin.java | 39 ++++++ .../JoinCommute.java | 28 ---- .../Generated/RRuleInstances/JoinCommute.java | 50 +++++++ .../SemiJoinProjectTransposeTest.java | 0 .../qed/Generated/Tests/JoinCommuteTest.java | 84 ++++++++++++ 12 files changed, 484 insertions(+), 40 deletions(-) create mode 100644 src/main/java/org/qed/Generated/PruneEmptyFilter.java create mode 100644 src/main/java/org/qed/Generated/PruneEmptyIntersect.java create mode 100644 src/main/java/org/qed/Generated/PruneEmptyProject.java create mode 100644 src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java create mode 100644 src/main/java/org/qed/Generated/PruneRightEmptyJoin.java delete mode 100644 src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java rename src/main/java/org/qed/Generated/{Tests-failed => Tests-Trivial}/SemiJoinProjectTransposeTest.java (100%) create mode 100644 src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java index a5f9418..32691cd 100644 --- a/src/main/java/org/qed/Generated/CalciteGenerator.java +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -11,6 +11,8 @@ import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.processing.Generated; + public class CalciteGenerator implements CodeGenerator { @Override @@ -279,7 +281,81 @@ public Env transformFilter(Env env, RelRN.Filter filter) { @Override public Env transformPred(Env env, RexRN.Pred pred) { - return env.focus(env.symbols().get(pred.operator().getName())); + // ONLY apply to the very specific JoinCommute pattern where: + // 1. We have exactly 2 JoinField sources + // 2. The first JoinField has ordinal 1 and second has ordinal 0 (indicating argument swap) + // This is the EXACT pattern from JoinCommute rule: pred($1, $0) vs pred($0, $1) + boolean isExactJoinCommuteSwapPattern = pred.sources().size() == 2 && + pred.sources().get(0) instanceof RexRN.JoinField joinField1 && + pred.sources().get(1) instanceof RexRN.JoinField joinField2 && + joinField1.ordinal() == 1 && // First arg is ordinal 1 (left -> right in swap) + joinField2.ordinal() == 0; // Second arg is ordinal 0 (right -> left in swap) + + if (isExactJoinCommuteSwapPattern) { + // This is the JoinCommute swapped predicate - transform with swapped arguments + var currentEnv = env; + var transformedArgs = Seq.empty(); + + // Process arguments in reverse order to swap them back + var sources = pred.sources(); + var reversedSources = Seq.of(sources.get(1), sources.get(0)); + + for (var arg : reversedSources) { + currentEnv = transform(currentEnv, arg); + transformedArgs = transformedArgs.appended(currentEnv.current()); + currentEnv = currentEnv.focus(env.current()); + } + + // Create a new predicate call with transformed arguments + String argsString = transformedArgs.joinToString(", "); + + // Extract operator from the original join condition + String operatorCall = "((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator()"; + + return currentEnv.focus(env.current() + ".call(" + operatorCall + ", " + argsString + ")"); + } else { + return env.focus(env.symbols().get(pred.operator().getName())); + } + } + + @Override + public Env transformJoinField(Env env, RexRN.JoinField joinField) { + // For JoinCommute: we need to calculate absolute field positions in the swapped join + + // Get the original join condition to extract the actual field indices + var origJoinDecl = env.declare("(LogicalJoin) call.rel(0)"); + var envWithOrigJoin = origJoinDecl.getValue(); + var conditionDecl = envWithOrigJoin.declare("(org.apache.calcite.rex.RexCall) " + origJoinDecl.getKey() + ".getCondition()"); + var envWithCondition = conditionDecl.getValue(); + + if (joinField.ordinal() == 0) { + // Ordinal 0 = Left table in original join + // Extract the left operand field index from original condition + var leftFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(0)).getIndex()"); + var envWithLeftField = leftFieldDecl.getValue(); + + // In swapped join: Left table is now at input 1 + // Use field(2, 1, leftFieldIndex) syntax + return envWithLeftField.focus(env.current() + ".field(2, 1, " + leftFieldDecl.getKey() + ")"); + } + else if (joinField.ordinal() == 1) { + // Ordinal 1 = Right table in original join + // Extract the right operand field index from original condition + var rightFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(1)).getIndex()"); + var envWithRightField = rightFieldDecl.getValue(); + + // Right table field index needs to be adjusted since it was originally after left table + var leftColCountDecl = envWithRightField.declare("call.rel(1).getRowType().getFieldCount()"); + var envWithLeftCount = leftColCountDecl.getValue(); + var adjustedRightFieldDecl = envWithLeftCount.declare(rightFieldDecl.getKey() + " - " + leftColCountDecl.getKey()); + var envWithAdjustedRightField = adjustedRightFieldDecl.getValue(); + + // In swapped join: Right table is now at input 0 + // Use field(2, 0, adjustedRightFieldIndex) syntax + return envWithAdjustedRightField.focus(env.current() + ".field(2, 0, " + adjustedRightFieldDecl.getKey() + ")"); + } else { + throw new UnsupportedOperationException("Unsupported join field ordinal: " + joinField.ordinal()); + } } @Override @@ -429,6 +505,50 @@ public Env transformEmpty(Env env, RelRN.Empty empty) { return env.focus(env.current() + ".empty()"); } + + @Override + public Env transformCustom(Env env, RelRN custom) { + return switch (custom) { + case org.qed.Generated.RRuleInstances.JoinCommute.ProjectionRelRN projection -> { + // Transform the source first - this builds the join + var sourceEnv = transform(env, projection.source()); + + // Get original table column counts + var leftTableDecl = sourceEnv.declare("call.rel(1)"); + var envWithLeftTable = leftTableDecl.getValue(); + var rightTableDecl = envWithLeftTable.declare("call.rel(2)"); + var envWithRightTable = rightTableDecl.getValue(); + + var leftColCountDecl = envWithRightTable.declare(leftTableDecl.getKey() + ".getRowType().getFieldCount()"); + var envWithLeftCount = leftColCountDecl.getValue(); + var rightColCountDecl = envWithLeftCount.declare(rightTableDecl.getKey() + ".getRowType().getFieldCount()"); + var envWithRightCount = rightColCountDecl.getValue(); + + // Create the projection indices as a List + var projectionIndicesDecl = envWithRightCount.declare( + "java.util.stream.IntStream.concat(" + + // Left columns: rightColCount + 0, rightColCount + 1, ..., rightColCount + leftColCount - 1 + "java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " + + rightColCountDecl.getKey() + " + " + leftColCountDecl.getKey() + "), " + + // Right columns: 0, 1, ..., rightColCount - 1 + "java.util.stream.IntStream.range(0, " + rightColCountDecl.getKey() + ")" + + ").boxed().collect(java.util.stream.Collectors.toList())" + ); + var envWithProjectionIndices = projectionIndicesDecl.getValue(); + + // Convert List to field references using RelBuilder.fields() + var fieldRefsDecl = envWithProjectionIndices.declare( + sourceEnv.current() + ".fields(" + projectionIndicesDecl.getKey() + ")" + ); + var envWithFieldRefs = fieldRefsDecl.getValue(); + + // Apply projection using the field references list + yield envWithFieldRefs.focus(sourceEnv.current() + ".project(" + fieldRefsDecl.getKey() + ")"); + } + default -> unimplementedTransform(env, custom); + }; + } + public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq statements, ImmutableMap symbols) { public static Env empty() { diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java index cccd850..4a1d6b3 100644 --- a/src/main/java/org/qed/Generated/CalciteTester.java +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -37,13 +37,21 @@ public static HepPlanner loadRule(RelOptRule rule) { return new HepPlanner(builder.build()); } + public static HepPlanner loadRule(RelOptRule rule, int matchLimit) { + System.out.printf("Verifying Rule: %s (match limit: %d)\n", rule.getClass(), matchLimit); + var builder = new HepProgramBuilder() + .addMatchLimit(matchLimit) + .addRuleInstance(rule); + return new HepPlanner(builder.build()); + } + public static Seq ruleList() { Reflections reflections = new Reflections("org.qed.Generated.RRuleInstances"); Set> ruleClasses = reflections.getSubTypesOf(RRule.class); var concreteRuleClasses = ruleClasses.stream() .filter(clazz -> !clazz.isInterface() && - !Modifier.isAbstract(clazz.getModifiers()) && + !Modifier.isAbstract(clazz.getModifiers()) && !clazz.getName().contains("$")) .collect(Collectors.toSet()); @@ -87,7 +95,7 @@ public static void runAllTests() { org.qed.Generated.Tests.MinusMergeTest.runTest(); org.qed.Generated.Tests.ProjectFilterTransposeTest.runTest(); org.qed.Generated.Tests.JoinPushTransitivePredicatesTest.runTest(); - org.qed.Generated.Tests.SemiJoinProjectTransposeTest.runTest(); + org.qed.Generated.Tests.JoinCommuteTest.runTest(); } catch (Exception e) { System.out.println("Test failed: " + e.getMessage()); e.printStackTrace(); @@ -95,14 +103,15 @@ public static void runAllTests() { } public static void main(String[] args) throws IOException { -// var rule = new RRuleInstance.FilterSetOpTranspose(); -// Files.createDirectories(Path.of(rulePath)); -// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); -// var rules = new RRuleInstance.JoinAssociate(); -// Files.createDirectories(Path.of(rulePath)); -// for (var rule : rules.family()) { -// new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); -// } + var rule = new org.qed.Generated.RRuleInstances.JoinCommute(); + System.out.println(rule.explain()); + Files.createDirectories(Path.of(rulePath)); + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // var rules = new RRuleInstance.JoinAssociate(); + // Files.createDirectories(Path.of(rulePath)); + // for (var rule : rules.family()) { + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // } generate(); runAllTests(); } diff --git a/src/main/java/org/qed/Generated/JoinCommute.java b/src/main/java/org/qed/Generated/JoinCommute.java index 894330c..147fd52 100644 --- a/src/main/java/org/qed/Generated/JoinCommute.java +++ b/src/main/java/org/qed/Generated/JoinCommute.java @@ -14,7 +14,21 @@ protected JoinCommute(Config config) { @Override public void onMatch(RelOptRuleCall call) { var var_3 = call.builder(); - call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(0)).getCondition()).build()); + var var_4 = (LogicalJoin) call.rel(0); + var var_5 = (org.apache.calcite.rex.RexCall) var_4.getCondition(); + var var_6 = ((org.apache.calcite.rex.RexInputRef) var_5.getOperands().get(0)).getIndex(); + var var_7 = (LogicalJoin) call.rel(0); + var var_8 = (org.apache.calcite.rex.RexCall) var_7.getCondition(); + var var_9 = ((org.apache.calcite.rex.RexInputRef) var_8.getOperands().get(1)).getIndex(); + var var_10 = call.rel(1).getRowType().getFieldCount(); + var var_11 = var_9 - var_10; + var var_12 = call.rel(1); + var var_13 = call.rel(2); + var var_14 = var_12.getRowType().getFieldCount(); + var var_15 = var_13.getRowType().getFieldCount(); + var var_16 = java.util.stream.IntStream.concat(java.util.stream.IntStream.range(var_15, var_15 + var_14), java.util.stream.IntStream.range(0, var_15)).boxed().collect(java.util.stream.Collectors.toList()); + var var_17 = var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).fields(var_16); + call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).project(var_17).build()); } public interface Config extends EmptyConfig { diff --git a/src/main/java/org/qed/Generated/PruneEmptyFilter.java b/src/main/java/org/qed/Generated/PruneEmptyFilter.java new file mode 100644 index 0000000..bcb125e --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyFilter.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyFilter extends RelRule { + protected PruneEmptyFilter(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_2 = call.builder(); + call.transformTo(var_2.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyFilter toRule() { + return new PruneEmptyFilter(this); + } + + @Override + default String description() { + return "PruneEmptyFilter"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyIntersect.java b/src/main/java/org/qed/Generated/PruneEmptyIntersect.java new file mode 100644 index 0000000..ddfcba3 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyIntersect.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyIntersect extends RelRule { + protected PruneEmptyIntersect(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().empty().intersect(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyIntersect toRule() { + return new PruneEmptyIntersect(this); + } + + @Override + default String description() { + return "PruneEmptyIntersect"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalIntersect.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyProject.java b/src/main/java/org/qed/Generated/PruneEmptyProject.java new file mode 100644 index 0000000..2241478 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyProject.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyProject extends RelRule { + protected PruneEmptyProject(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_2 = call.builder(); + call.transformTo(var_2.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyProject toRule() { + return new PruneEmptyProject(this); + } + + @Override + default String description() { + return "PruneEmptyProject"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java new file mode 100644 index 0000000..95774f0 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneLeftEmptyJoin extends RelRule { + protected PruneLeftEmptyJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneLeftEmptyJoin toRule() { + return new PruneLeftEmptyJoin(this); + } + + @Override + default String description() { + return "PruneLeftEmptyJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java new file mode 100644 index 0000000..5d4e1c3 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneRightEmptyJoin extends RelRule { + protected PruneRightEmptyJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneRightEmptyJoin toRule() { + return new PruneRightEmptyJoin(this); + } + + @Override + default String description() { + return "PruneRightEmptyJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java deleted file mode 100644 index 22f2344..0000000 --- a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinCommute.java +++ /dev/null @@ -1,28 +0,0 @@ -package org.qed.Generated.RRuleInstances; - -import kala.collection.Map; -import kala.collection.Seq; -import org.apache.calcite.rel.core.JoinRelType; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.qed.RelRN; -import org.qed.RexRN; -import org.qed.RRule; -import org.qed.RuleBuilder; - -public record JoinCommute() implements RRule { - static final RelRN left = RelRN.scan("Left", "Left_Type"); - static final RelRN right = RelRN.scan("Right", "Right_Type"); - static final RexRN joinCond = left.joinPred("pred", right); - - @Override - public RelRN before() { - return left.join(JoinRelType.INNER, joinCond, right); - } - - @Override - public RelRN after() { - // We need to swap the join fields in the condition - RexRN commutedJoinCond = right.joinPred("pred", left); - return right.join(JoinRelType.INNER, commutedJoinCond, left); - } -} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java new file mode 100644 index 0000000..d1580be --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java @@ -0,0 +1,50 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + SqlOperator predOp = RuleBuilder.create().genericPredicateOp(pred, true); + RexRN swappedPred = new RexRN.Pred(predOp, Seq.of( + new RexRN.JoinField(1, right, left), + new RexRN.JoinField(0, right, left) + )); + RelRN swappedJoin = right.join(JoinRelType.INNER, swappedPred, left); + + return new ProjectionRelRN(swappedJoin); + } + + // implementation for the column reordering projection + public static record ProjectionRelRN(RelRN source) implements RelRN { + @Override + public RelNode semantics() { + RuleBuilder builder = RuleBuilder.create(); + builder.push(source.semantics()); + + RexNode leftField = builder.field(1); // Left columns (now at position 1) + RexNode rightField = builder.field(0); // Right columns (now at position 0) + + builder.project(leftField, rightField); + + return builder.build(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests-failed/SemiJoinProjectTransposeTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/SemiJoinProjectTransposeTest.java similarity index 100% rename from src/main/java/org/qed/Generated/Tests-failed/SemiJoinProjectTransposeTest.java rename to src/main/java/org/qed/Generated/Tests-Trivial/SemiJoinProjectTransposeTest.java diff --git a/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java b/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java new file mode 100644 index 0000000..c6d0203 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java @@ -0,0 +1,84 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinCommuteTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create EMP table (8 columns like the real Calcite test) + var empTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // EMPNO + Tuple.of(RelType.fromString("VARCHAR", true), false), // ENAME + Tuple.of(RelType.fromString("VARCHAR", true), false), // JOB + Tuple.of(RelType.fromString("INTEGER", true), false), // MGR + Tuple.of(RelType.fromString("DATE", true), false), // HIREDATE + Tuple.of(RelType.fromString("DECIMAL", true), false), // SAL + Tuple.of(RelType.fromString("DECIMAL", true), false), // COMM + Tuple.of(RelType.fromString("INTEGER", true), false) // DEPTNO + )); + builder.addTable(empTable); + + // Create DEPT table (3 columns like the real Calcite test) + var deptTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // DEPTNO + Tuple.of(RelType.fromString("VARCHAR", true), false), // DNAME + Tuple.of(RelType.fromString("VARCHAR", true), false) // LOC + )); + builder.addTable(deptTable); + + // Build the "before" pattern: EMP JOIN DEPT ON EMP.DEPTNO = DEPT.DEPTNO + var empScan = builder.scan(empTable.getName()).build(); + var deptScan = builder.scan(deptTable.getName()).build(); + + var before = builder + .push(empScan) + .push(deptScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("equals", true), + builder.field(2, 0, 7), // EMP.DEPTNO (8th column, index 7) + builder.field(2, 1, 0) // DEPT.DEPTNO (1st column, index 0) + )) + .build(); + + // Build the "after" pattern: DEPT JOIN EMP ON DEPT.DEPTNO = EMP.DEPTNO + // followed by projection to restore original column order + var after = builder + .push(deptScan) + .push(empScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("equals", true), + builder.field(2, 1, 7), // EMP.DEPTNO (now at input 1, field 7) + builder.field(2, 0, 0) // DEPT.DEPTNO (now at input 0, field 0) + )) + .project( + // EMP columns first (now at positions 3-10 in the swapped join) + builder.field(3), // EMPNO + builder.field(4), // ENAME + builder.field(5), // JOB + builder.field(6), // MGR + builder.field(7), // HIREDATE + builder.field(8), // SAL + builder.field(9), // COMM + builder.field(10), // DEPTNO + // DEPT columns second (now at positions 0-2 in the swapped join) + builder.field(0), // DEPTNO0 + builder.field(1), // DNAME + builder.field(2) // LOC + ) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.JoinCommute.Config.DEFAULT.toRule(), 1); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinCommute test..."); + runTest(); + } +} \ No newline at end of file From f44de6b1921d457446abdb2941e9393949de09b9 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Wed, 18 Jun 2025 18:53:22 -0700 Subject: [PATCH 69/78] Adding Prune Rules and Its Tests --- .../FilterAggregateTranspose.java | 50 +++++++++++++++ .../RRuleInstances/PruneEmptyMinus.java | 20 ++++++ .../RRuleInstances/PruneEmptyUnion.java | 20 ++++++ .../RRuleInstances/PruneZeroRowsTable.java | 19 ++++++ .../Tests-Trivial/PruneEmptyFilterTest.java | 46 ++++++++++++++ .../Tests-Trivial/PruneEmptyMinusTest.java | 51 +++++++++++++++ .../Tests-Trivial/PruneEmptyProjectTest.java | 41 ++++++++++++ .../Tests-Trivial/PruneEmptyUnionTest.java | 62 +++++++++++++++++++ .../Tests-Trivial/PruneZeroRowsTableTest.java | 40 ++++++++++++ 9 files changed, 349 insertions(+) create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyMinus.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/PruneZeroRowsTable.java create mode 100644 src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyFilterTest.java create mode 100644 src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyMinusTest.java create mode 100644 src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyProjectTest.java create mode 100644 src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java create mode 100644 src/main/java/org/qed/Generated/Tests-Trivial/PruneZeroRowsTableTest.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java new file mode 100644 index 0000000..0d87ddf --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java @@ -0,0 +1,50 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.qed.RelRN; +import org.qed.RelType; +import org.qed.RexRN; +import org.qed.RRule; + +public record FilterAggregateTranspose() implements RRule { + + // Define a source relation with at least two columns. + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + // Define the grouping key for the aggregation, using the first column of the source. + static final Seq groupSet = Seq.of(source.field(0)); + + // Define an aggregate function call, e.g., SUM on the second column. + static final Seq aggCalls = Seq.of( + new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))) + ); + + // Define the aggregate node. Its output schema will be (group_key_type, sum_type). + static final RelRN aggregate = source.aggregate(groupSet, aggCalls); + + // Define a predicate that filters on the grouping key (the first column of the aggregate's output). + static final RexRN pred = aggregate.field(0).pred("pred"); + + /** + * The 'before' pattern represents a Filter applied on top of an Aggregate. + */ + @Override + public RelRN before() { + return aggregate.filter(pred); + } + + /** + * The 'after' pattern represents the transposed operators, where the Aggregate + * is applied on top of a Filter. + */ + @Override + public RelRN after() { + // The predicate is rewritten to apply to the aggregate's input (the source). + // The filter condition was on the first field of the aggregate's output (the group key), + // which corresponds to the first field of the original source. + RelRN filteredSource = source.filter(source.field(0).pred("pred")); + + // The aggregation is now applied to the filtered source. + return filteredSource.aggregate(groupSet, aggCalls); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyMinus.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyMinus.java new file mode 100644 index 0000000..09c424c --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyMinus.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyMinus() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.empty().minus(false, b); + } + + @Override + public RelRN after() { + return a.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java new file mode 100644 index 0000000..8d3ee24 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyUnion() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.union(false, b.empty()); + } + + @Override + public RelRN after() { + return a; + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneZeroRowsTable.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneZeroRowsTable.java new file mode 100644 index 0000000..393b226 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneZeroRowsTable.java @@ -0,0 +1,19 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneZeroRowsTable() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + + @Override + public RelRN before() { + return a; + } + + @Override + public RelRN after() { + return a; + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyFilterTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyFilterTest.java new file mode 100644 index 0000000..dde4f6e --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyFilterTest.java @@ -0,0 +1,46 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneEmptyFilterTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .filter( + builder.call( + builder.genericPredicateOp("filter_cond", true), + builder.fields() + ) + ) + .empty() + .build(); + + var after = builder + .scan(table.getName()) + .empty() + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Generated.PruneEmptyFilter.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyFilter test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyMinusTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyMinusTest.java new file mode 100644 index 0000000..d24d6ad --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyMinusTest.java @@ -0,0 +1,51 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; +import org.apache.calcite.rel.RelNode; + +public class PruneEmptyMinusTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + RelNode scanA = builder + .scan(table.getName()) + .build(); + + RelNode scanB = builder + .scan(table.getName()) + .build(); + + RelNode before = builder + .push(scanA) + .push(scanB) + .minus(false) + .empty() + .build(); + + RelNode after = builder + .push(scanA) + .empty() + .build(); + +// var runner = CalciteTester.loadRule( +// org.qed.Generated.PruneEmptyMinus.Config.DEFAULT.toRule() +// ); +// tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyMinus test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyProjectTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyProjectTest.java new file mode 100644 index 0000000..547da2c --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyProjectTest.java @@ -0,0 +1,41 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneEmptyProjectTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .empty() + .project(builder.field(0)) + .build(); + + var after = builder + .scan(table.getName()) + .empty() + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Generated.PruneEmptyProject.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyProject test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java new file mode 100644 index 0000000..a18b9eb --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java @@ -0,0 +1,62 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; +import org.apache.calcite.rel.RelNode; + +public class PruneEmptyUnionTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + RelNode scanA = builder + .scan(table.getName()) + .build(); + + RelNode emptyA = builder + .push(scanA) + .empty() + .build(); + + RelNode scanB = builder + .scan(table.getName()) + .build(); + + RelNode emptyB = builder + .push(scanB) + .empty() + .build(); + + RelNode before = builder + .push(scanA) + .push(scanB) + .union(false) + .empty() + .build(); + + RelNode after = builder + .push(emptyA) + .push(emptyB) + .union(false) + .build(); + +// var runner = CalciteTester.loadRule( +// org.qed.Generated.PruneEmptyUnion.Config.DEFAULT.toRule() +// ); +// tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyMinus test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneZeroRowsTableTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneZeroRowsTableTest.java new file mode 100644 index 0000000..b7a91ed --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneZeroRowsTableTest.java @@ -0,0 +1,40 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneZeroRowsTableTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .empty() + .build(); + + var after = builder + .scan(table.getName()) + .empty() + .build(); + +// var runner = CalciteTester.loadRule( +// org.qed.Generated.PruneZeroRowsTable.Config.DEFAULT.toRule() +// ); +// tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneZeroRowsTableTest..."); + runTest(); + } +} From 60918b86e47da2d03584806568e63d1eca3acdb1 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Wed, 18 Jun 2025 18:53:35 -0700 Subject: [PATCH 70/78] Testing Out Aggregate --- src/main/java/org/qed/CodeGenerator.java | 10 ++++++++++ src/main/java/org/qed/RelRN.java | 23 ++++++++++++++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/qed/CodeGenerator.java b/src/main/java/org/qed/CodeGenerator.java index bf7da08..7b5009e 100644 --- a/src/main/java/org/qed/CodeGenerator.java +++ b/src/main/java/org/qed/CodeGenerator.java @@ -28,6 +28,7 @@ default E onMatch(E env, RelRN pattern) { case RelRN.Intersect intersect -> onMatchIntersect(env, intersect); case RelRN.Minus minus -> onMatchMinus(env, minus); case RelRN.Empty empty -> onMatchEmpty(env, empty); + case RelRN.Aggregate aggregate -> onMatchAggregate(env, aggregate); default -> onMatchCustom(env, pattern); }; } @@ -65,6 +66,7 @@ default E transform(E env, RelRN target) { case RelRN.Intersect intersect -> transformIntersect(env, intersect); case RelRN.Minus minus -> transformMinus(env, minus); case RelRN.Empty empty -> transformEmpty(env, empty); + case RelRN.Aggregate aggregate -> transformAggregate(env, aggregate); default -> transformCustom(env, target); }; } @@ -250,4 +252,12 @@ default E transformFalse(E env, RexRN literal) { default E transformEmpty(E env, RelRN.Empty empty) { return unimplementedTransform(env, empty); } + + default E onMatchAggregate(E env, RelRN.Aggregate aggregate) { + return unimplementedOnMatch(env, aggregate); + } + + default E transformAggregate(E env, RelRN.Aggregate aggregate) { + return unimplementedTransform(env, aggregate); + } } diff --git a/src/main/java/org/qed/RelRN.java b/src/main/java/org/qed/RelRN.java index becc662..638985b 100644 --- a/src/main/java/org/qed/RelRN.java +++ b/src/main/java/org/qed/RelRN.java @@ -112,6 +112,10 @@ default Empty empty() { return new Empty(this); } + default Aggregate aggregate(Seq groupSet, Seq aggCalls) { + return new Aggregate(this, groupSet, aggCalls); + } + record Scan(String name, RelType.VarType ty, boolean unique) implements RelRN { @Override @@ -200,4 +204,21 @@ public RelNode semantics() { } } -} + record AggCall(String name, boolean distinct, RelType type, Seq operands) { + } + + record Aggregate(RelRN source, Seq groupSet, Seq aggCalls) implements RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(source.semantics()); + var groupKey = builder.groupKey(groupSet.map(RexRN::semantics)); + var calls = aggCalls.map(agg -> { + var aggFunc = builder.genericAggregateOp(agg.name(), agg.type()); + return builder.aggregateCall(aggFunc, agg.distinct(), null, agg.name(), agg.operands().map(RexRN::semantics).asJava()); + }); + return builder.aggregate(groupKey, calls).build(); + } + } + +} \ No newline at end of file From 28db8b2eb513197d5b94f7c03c5d89ce630e4f82 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Wed, 18 Jun 2025 19:01:11 -0700 Subject: [PATCH 71/78] Editing PruneEmptyUnion --- .../org/qed/Generated/PruneEmptyMinus.java | 39 +++++++++++++++++++ .../org/qed/Generated/PruneEmptyUnion.java | 39 +++++++++++++++++++ .../org/qed/Generated/PruneZeroRowsTable.java | 39 +++++++++++++++++++ .../RRuleInstances/PruneEmptyUnion.java | 4 +- 4 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/qed/Generated/PruneEmptyMinus.java create mode 100644 src/main/java/org/qed/Generated/PruneEmptyUnion.java create mode 100644 src/main/java/org/qed/Generated/PruneZeroRowsTable.java diff --git a/src/main/java/org/qed/Generated/PruneEmptyMinus.java b/src/main/java/org/qed/Generated/PruneEmptyMinus.java new file mode 100644 index 0000000..41bffd1 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyMinus.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyMinus extends RelRule { + protected PruneEmptyMinus(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyMinus toRule() { + return new PruneEmptyMinus(this); + } + + @Override + default String description() { + return "PruneEmptyMinus"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalMinus.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyUnion.java b/src/main/java/org/qed/Generated/PruneEmptyUnion.java new file mode 100644 index 0000000..df48a7d --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyUnion.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyUnion extends RelRule { + protected PruneEmptyUnion(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyUnion toRule() { + return new PruneEmptyUnion(this); + } + + @Override + default String description() { + return "PruneEmptyUnion"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneZeroRowsTable.java b/src/main/java/org/qed/Generated/PruneZeroRowsTable.java new file mode 100644 index 0000000..af02cf5 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneZeroRowsTable.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneZeroRowsTable extends RelRule { + protected PruneZeroRowsTable(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_1 = call.builder(); + call.transformTo(var_1.push(call.rel(0)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneZeroRowsTable toRule() { + return new PruneZeroRowsTable(this); + } + + @Override + default String description() { + return "PruneZeroRowsTable"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_0 -> s_0.operand(RelNode.class).anyInputs(); + } + + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java index 8d3ee24..74d3fc5 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java @@ -10,11 +10,11 @@ public record PruneEmptyUnion() implements RRule { @Override public RelRN before() { - return a.union(false, b.empty()); + return a.empty().union(false, b.empty()); } @Override public RelRN after() { - return a; + return a.empty(); } } From 192d6856198780d3897f58487f62db71c1037d40 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Wed, 18 Jun 2025 19:01:19 -0700 Subject: [PATCH 72/78] Adding Its Test --- .../qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java index a18b9eb..dcbbc9d 100644 --- a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java @@ -49,10 +49,10 @@ public static void runTest() { .union(false) .build(); -// var runner = CalciteTester.loadRule( -// org.qed.Generated.PruneEmptyUnion.Config.DEFAULT.toRule() -// ); -// tester.verify(runner, before, after); + var runner = CalciteTester.loadRule( + org.qed.Generated.PruneEmptyUnion.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); } public static void main(String[] args) { From 727b4343094d0b87531dcfbc6d74d11e5a5a03c5 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 19 Jun 2025 16:14:54 -0700 Subject: [PATCH 73/78] Removing comments --- rules/JoinCommute-.json | 72 +++++++++++++++++++ .../FilterAggregateTranspose.java | 21 ------ 2 files changed, 72 insertions(+), 21 deletions(-) create mode 100644 rules/JoinCommute-.json diff --git a/rules/JoinCommute-.json b/rules/JoinCommute-.json new file mode 100644 index 0000000..b6bdec7 --- /dev/null +++ b/rules/JoinCommute-.json @@ -0,0 +1,72 @@ +{ + "help" : [ "LogicalJoin(condition=[pred($0, $1)], joinType=[inner])\n LogicalTableScan(table=[[Left]])\n LogicalTableScan(table=[[Right]])\n", "LogicalProject(col-Left=[$1], col-Right=[$0])\n LogicalJoin(condition=[pred($1, $0)], joinType=[inner])\n LogicalTableScan(table=[[Right]])\n LogicalTableScan(table=[[Left]])\n" ], + "schemas" : [ { + "types" : [ "INTEGER" ], + "nullable" : [ true ], + "name" : "Left", + "guaranteed" : [ ], + "fields" : [ "col-Left" ], + "key" : [ ] + }, { + "types" : [ "INTEGER" ], + "nullable" : [ true ], + "name" : "Right", + "guaranteed" : [ ], + "fields" : [ "col-Right" ], + "key" : [ ] + } ], + "queries" : [ { + "join" : { + "condition" : { + "type" : "BOOLEAN", + "operand" : [ { + "column" : 0, + "type" : "INTEGER" + }, { + "column" : 1, + "type" : "INTEGER" + } ], + "operator" : "pred" + }, + "left" : { + "scan" : 0 + }, + "kind" : "INNER", + "right" : { + "scan" : 1 + } + } + }, { + "project" : { + "source" : { + "join" : { + "condition" : { + "type" : "BOOLEAN", + "operand" : [ { + "column" : 1, + "type" : "INTEGER" + }, { + "column" : 0, + "type" : "INTEGER" + } ], + "operator" : "pred" + }, + "left" : { + "scan" : 1 + }, + "kind" : "INNER", + "right" : { + "scan" : 0 + } + } + }, + "target" : [ { + "column" : 1, + "type" : "INTEGER" + }, { + "column" : 0, + "type" : "INTEGER" + } ] + } + } ] +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java index 0d87ddf..ff200db 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java @@ -8,43 +8,22 @@ public record FilterAggregateTranspose() implements RRule { - // Define a source relation with at least two columns. static final RelRN source = RelRN.scan("Source", "Source_Type"); - - // Define the grouping key for the aggregation, using the first column of the source. static final Seq groupSet = Seq.of(source.field(0)); - - // Define an aggregate function call, e.g., SUM on the second column. static final Seq aggCalls = Seq.of( new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))) ); - - // Define the aggregate node. Its output schema will be (group_key_type, sum_type). static final RelRN aggregate = source.aggregate(groupSet, aggCalls); - - // Define a predicate that filters on the grouping key (the first column of the aggregate's output). static final RexRN pred = aggregate.field(0).pred("pred"); - /** - * The 'before' pattern represents a Filter applied on top of an Aggregate. - */ @Override public RelRN before() { return aggregate.filter(pred); } - /** - * The 'after' pattern represents the transposed operators, where the Aggregate - * is applied on top of a Filter. - */ @Override public RelRN after() { - // The predicate is rewritten to apply to the aggregate's input (the source). - // The filter condition was on the first field of the aggregate's output (the group key), - // which corresponds to the first field of the original source. RelRN filteredSource = source.filter(source.field(0).pred("pred")); - - // The aggregation is now applied to the filtered source. return filteredSource.aggregate(groupSet, aggCalls); } } \ No newline at end of file From 30fc245b9e7c6737cbc612121d0b45f57d4fdcdf Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 19 Jun 2025 16:41:12 -0700 Subject: [PATCH 74/78] Updating FilterAggregateTranspose, should pass provability --- .../qed/Generated/RRuleInstances/FilterAggregateTranspose.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java index ff200db..f3ab0eb 100644 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java @@ -11,7 +11,7 @@ public record FilterAggregateTranspose() implements RRule { static final RelRN source = RelRN.scan("Source", "Source_Type"); static final Seq groupSet = Seq.of(source.field(0)); static final Seq aggCalls = Seq.of( - new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))) + new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(0))) ); static final RelRN aggregate = source.aggregate(groupSet, aggCalls); static final RexRN pred = aggregate.field(0).pred("pred"); From c6810e6762c657bd815c5750bbf4baeabc94f25a Mon Sep 17 00:00:00 2001 From: wkaiz Date: Sat, 28 Jun 2025 16:54:53 -0700 Subject: [PATCH 75/78] Adding Everything Except for FilterAggregateTranpose since other works --- pom.xml | 4 +-- .../FilterAggregateTranspose.java | 29 ------------------- 2 files changed, 2 insertions(+), 31 deletions(-) delete mode 100644 src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java diff --git a/pom.xml b/pom.xml index fd29e81..9bf55a1 100644 --- a/pom.xml +++ b/pom.xml @@ -51,8 +51,8 @@ org.apache.maven.plugins maven-compiler-plugin - 23 - 23 + 24 + 24 --enable-preview diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java deleted file mode 100644 index f3ab0eb..0000000 --- a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java +++ /dev/null @@ -1,29 +0,0 @@ -package org.qed.Generated.RRuleInstances; - -import kala.collection.Seq; -import org.qed.RelRN; -import org.qed.RelType; -import org.qed.RexRN; -import org.qed.RRule; - -public record FilterAggregateTranspose() implements RRule { - - static final RelRN source = RelRN.scan("Source", "Source_Type"); - static final Seq groupSet = Seq.of(source.field(0)); - static final Seq aggCalls = Seq.of( - new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(0))) - ); - static final RelRN aggregate = source.aggregate(groupSet, aggCalls); - static final RexRN pred = aggregate.field(0).pred("pred"); - - @Override - public RelRN before() { - return aggregate.filter(pred); - } - - @Override - public RelRN after() { - RelRN filteredSource = source.filter(source.field(0).pred("pred")); - return filteredSource.aggregate(groupSet, aggCalls); - } -} \ No newline at end of file From 75f0b8c07b7645bea9eaf0a38defafee40ff38de Mon Sep 17 00:00:00 2001 From: wkaiz Date: Sat, 28 Jun 2025 16:57:08 -0700 Subject: [PATCH 76/78] Reversing Target version for now --- pom.xml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 9bf55a1..fd29e81 100644 --- a/pom.xml +++ b/pom.xml @@ -51,8 +51,8 @@ org.apache.maven.plugins maven-compiler-plugin - 24 - 24 + 23 + 23 --enable-preview From 53e8cef2e544906464597c74cedf257d7fc30053 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Thu, 3 Jul 2025 17:46:54 -0700 Subject: [PATCH 77/78] Adding AggregateFilterTranspose and FilterAggregateTranspose --- .../org/qed/Generated/CalciteGenerator.java | 138 +++++++++++------- .../AggregateFilterTranspose.java | 30 ++++ .../FilterAggregateTranspose.java | 29 ++++ src/main/java/org/qed/RelRN.java | 2 +- 4 files changed, 142 insertions(+), 57 deletions(-) create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java index 32691cd..4a74dc1 100644 --- a/src/main/java/org/qed/Generated/CalciteGenerator.java +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -116,12 +116,12 @@ public Env onMatchAnd(Env env, RexRN.And and) { String andSymbol = "and_" + env.varId.getAndIncrement(); // Store the current expression as this AND node's symbol current_env = current_env.symbol(andSymbol, current_env.current()); - + // Process each child source in the AND condition for (var source : and.sources()) { current_env = onMatch(current_env, source); } - + return current_env; } @@ -129,11 +129,11 @@ public Env onMatchAnd(Env env, RexRN.And and) { public Env onMatchUnion(Env env, RelRN.Union union) { // Get the all flag from the union boolean all = union.all(); - + // Process each source in the union var current_env = env; var skeletons = Seq.empty(); - + // Process all sources in the sequence for (var source : union.sources()) { var next_env = current_env.next(); @@ -141,7 +141,7 @@ public Env onMatchUnion(Env env, RelRN.Union union) { skeletons = skeletons.appended(source_env.skeleton()); current_env = source_env; } - + // Build the input skeletons string for the operand StringBuilder inputsBuilder = new StringBuilder(); for (int i = 0; i < skeletons.size(); i++) { @@ -150,7 +150,7 @@ public Env onMatchUnion(Env env, RelRN.Union union) { } inputsBuilder.append(skeletons.get(i).toString()); } - + // Create the union operand with the appropriate class based on the all flag String operatorClass = all ? "LogicalUnionAll" : "LogicalUnion"; return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); @@ -160,11 +160,11 @@ public Env onMatchUnion(Env env, RelRN.Union union) { public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { // Get the all flag from the intersect boolean all = intersect.all(); - + // Process each source in the intersect var current_env = env; var skeletons = Seq.empty(); - + // Process all sources in the sequence for (var source : intersect.sources()) { var next_env = current_env.next(); @@ -172,7 +172,7 @@ public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { skeletons = skeletons.appended(source_env.skeleton()); current_env = source_env; } - + // Build the input skeletons string for the operand StringBuilder inputsBuilder = new StringBuilder(); for (int i = 0; i < skeletons.size(); i++) { @@ -181,7 +181,7 @@ public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { } inputsBuilder.append(skeletons.get(i).toString()); } - + // Create the intersect operand with the appropriate class based on the all flag String operatorClass = all ? "LogicalIntersectAll" : "LogicalIntersect"; return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); @@ -191,11 +191,11 @@ public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { public Env onMatchMinus(Env env, RelRN.Minus minus) { // Get the all flag from the minus boolean all = minus.all(); - + // Process each source in the minus var current_env = env; var skeletons = Seq.empty(); - + // Process all sources in the sequence for (var source : minus.sources()) { var next_env = current_env.next(); @@ -203,7 +203,7 @@ public Env onMatchMinus(Env env, RelRN.Minus minus) { skeletons = skeletons.appended(source_env.skeleton()); current_env = source_env; } - + // Build the input skeletons string for the operand StringBuilder inputsBuilder = new StringBuilder(); for (int i = 0; i < skeletons.size(); i++) { @@ -212,7 +212,7 @@ public Env onMatchMinus(Env env, RelRN.Minus minus) { } inputsBuilder.append(skeletons.get(i).toString()); } - + // Create the minus operand return current_env.grow("operand(LogicalMinus.class).inputs(" + inputsBuilder.toString() + ")"); } @@ -221,7 +221,7 @@ public Env onMatchMinus(Env env, RelRN.Minus minus) { public Env onMatchField(Env env, RexRN.Field field) { // Generate a unique symbolic name for this field String fieldSymbol = "field_" + env.varId.getAndIncrement(); - + // Store the field expression in the environment's symbol table return env.symbol(fieldSymbol, env.current()); } @@ -230,7 +230,7 @@ public Env onMatchField(Env env, RexRN.Field field) { public Env onMatchTrue(Env env, RexRN literal) { // Create a unique symbol name for this true literal String trueSymbol = "true_" + env.varId.getAndIncrement(); - + // Store the current expression as this true literal's symbol return env.symbol(trueSymbol, env.current()); } @@ -239,7 +239,7 @@ public Env onMatchTrue(Env env, RexRN literal) { public Env onMatchFalse(Env env, RexRN literal) { // Create a unique symbol name for this false literal String falseSymbol = "false_" + env.varId.getAndIncrement(); - + // Store the current expression as this false literal's symbol return env.symbol(falseSymbol, env.current()); } @@ -290,28 +290,28 @@ public Env transformPred(Env env, RexRN.Pred pred) { pred.sources().get(1) instanceof RexRN.JoinField joinField2 && joinField1.ordinal() == 1 && // First arg is ordinal 1 (left -> right in swap) joinField2.ordinal() == 0; // Second arg is ordinal 0 (right -> left in swap) - + if (isExactJoinCommuteSwapPattern) { // This is the JoinCommute swapped predicate - transform with swapped arguments var currentEnv = env; var transformedArgs = Seq.empty(); - + // Process arguments in reverse order to swap them back var sources = pred.sources(); var reversedSources = Seq.of(sources.get(1), sources.get(0)); - + for (var arg : reversedSources) { currentEnv = transform(currentEnv, arg); transformedArgs = transformedArgs.appended(currentEnv.current()); currentEnv = currentEnv.focus(env.current()); } - + // Create a new predicate call with transformed arguments String argsString = transformedArgs.joinToString(", "); - + // Extract operator from the original join condition String operatorCall = "((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator()"; - + return currentEnv.focus(env.current() + ".call(" + operatorCall + ", " + argsString + ")"); } else { return env.focus(env.symbols().get(pred.operator().getName())); @@ -327,30 +327,30 @@ public Env transformJoinField(Env env, RexRN.JoinField joinField) { var envWithOrigJoin = origJoinDecl.getValue(); var conditionDecl = envWithOrigJoin.declare("(org.apache.calcite.rex.RexCall) " + origJoinDecl.getKey() + ".getCondition()"); var envWithCondition = conditionDecl.getValue(); - + if (joinField.ordinal() == 0) { - // Ordinal 0 = Left table in original join - // Extract the left operand field index from original condition + // Ordinal 0 = Left table in original join + // Extract the left operand field index from original condition var leftFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(0)).getIndex()"); var envWithLeftField = leftFieldDecl.getValue(); - + // In swapped join: Left table is now at input 1 // Use field(2, 1, leftFieldIndex) syntax return envWithLeftField.focus(env.current() + ".field(2, 1, " + leftFieldDecl.getKey() + ")"); - } + } else if (joinField.ordinal() == 1) { - // Ordinal 1 = Right table in original join + // Ordinal 1 = Right table in original join // Extract the right operand field index from original condition var rightFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(1)).getIndex()"); var envWithRightField = rightFieldDecl.getValue(); - + // Right table field index needs to be adjusted since it was originally after left table var leftColCountDecl = envWithRightField.declare("call.rel(1).getRowType().getFieldCount()"); var envWithLeftCount = leftColCountDecl.getValue(); var adjustedRightFieldDecl = envWithLeftCount.declare(rightFieldDecl.getKey() + " - " + leftColCountDecl.getKey()); var envWithAdjustedRightField = adjustedRightFieldDecl.getValue(); - - // In swapped join: Right table is now at input 0 + + // In swapped join: Right table is now at input 0 // Use field(2, 0, adjustedRightFieldIndex) syntax return envWithAdjustedRightField.focus(env.current() + ".field(2, 0, " + adjustedRightFieldDecl.getKey() + ")"); } else { @@ -391,16 +391,16 @@ public Env transformAnd(Env env, RexRN.And and) { public Env transformUnion(Env env, RelRN.Union union) { // Get the all flag from the union boolean all = union.all(); - + // The number of sources int sourceCount = union.sources().size(); - + // Transform each source var current_env = env; for (var source : union.sources()) { current_env = transform(current_env, source); } - + // Use the union method with the all flag and source count // This matches the Calcite RelBuilder.union(boolean all, int n) signature return current_env.focus(current_env.current() + ".union(" + all + ", " + sourceCount + ")"); @@ -410,16 +410,16 @@ public Env transformUnion(Env env, RelRN.Union union) { public Env transformIntersect(Env env, RelRN.Intersect intersect) { // Get the all flag from the intersect boolean all = intersect.all(); - + // The number of sources int sourceCount = intersect.sources().size(); - + // Transform each source var current_env = env; for (var source : intersect.sources()) { current_env = transform(current_env, source); } - + // Use the intersect method with the all flag and source count // This matches the expected Calcite RelBuilder.intersect(boolean all, int n) signature String methodName = all ? "intersectAll" : "intersect"; @@ -430,16 +430,16 @@ public Env transformIntersect(Env env, RelRN.Intersect intersect) { public Env transformMinus(Env env, RelRN.Minus minus) { // Get the all flag from the minus boolean all = minus.all(); - + // The number of sources int sourceCount = minus.sources().size(); - + // Transform each source var current_env = env; for (var source : minus.sources()) { current_env = transform(current_env, source); } - + // Use the minus method with the all flag and source count // This matches the expected Calcite RelBuilder.minus(boolean all, int n) signature return current_env.focus(current_env.current() + ".minus(" + all + ", " + sourceCount + ")"); @@ -449,7 +449,7 @@ public Env transformMinus(Env env, RelRN.Minus minus) { public Env transformField(Env env, RexRN.Field field) { // In Calcite, field references are typically created with a "field" method // We'll need to pass some identifier for the field - use toString() if no specific field accessor is available - + // Assuming field has a method that returns some kind of identifier or name // If not, we may need to adjust this implementation return env.focus(env.current() + ".field(" + field + ")"); @@ -459,13 +459,13 @@ public Env transformField(Env env, RexRN.Field field) { public Env transformProj(Env env, RexRN.Proj proj) { // In Calcite, projections are typically created using the operator name // This is similar to your transformPred implementation - + // Look up the symbol from the matching phase if (!env.symbols().containsKey(proj.operator().getName())) { - throw new RuntimeException("Operator symbol not found: " + proj.operator().getName() + + throw new RuntimeException("Operator symbol not found: " + proj.operator().getName() + ". Make sure onMatchProj is properly implemented."); } - + // Return an environment focused on the expression for this projection return env.focus(env.symbols().get(proj.operator().getName())); } @@ -475,10 +475,10 @@ public Env transformProject(Env env, RelRN.Project project) { // First transform the source relation var source_transform = transform(env, project.source()); var source_expression = source_transform.current(); - + // Then transform the projection map var map_transform = transform(source_transform, project.map()); - + // Combine the source and projection using the project operation // This creates a projection on top of the source relation return map_transform.focus(source_expression + ".project(" + map_transform.current() + ")"); @@ -486,14 +486,14 @@ public Env transformProject(Env env, RelRN.Project project) { @Override public Env transformTrue(Env env, RexRN literal) { - // In Calcite, true literals are typically represented using the + // In Calcite, true literals are typically represented using the // rexBuilder.makeLiteral(true) method or just "TRUE" return env.focus(env.current() + ".literal(true)"); } @Override public Env transformFalse(Env env, RexRN literal) { - // In Calcite, false literals are represented using the + // In Calcite, false literals are represented using the // rexBuilder.makeLiteral(false) method or just "FALSE" return env.focus(env.current() + ".literal(false)"); } @@ -505,43 +505,43 @@ public Env transformEmpty(Env env, RelRN.Empty empty) { return env.focus(env.current() + ".empty()"); } - + @Override public Env transformCustom(Env env, RelRN custom) { return switch (custom) { case org.qed.Generated.RRuleInstances.JoinCommute.ProjectionRelRN projection -> { // Transform the source first - this builds the join var sourceEnv = transform(env, projection.source()); - + // Get original table column counts var leftTableDecl = sourceEnv.declare("call.rel(1)"); var envWithLeftTable = leftTableDecl.getValue(); var rightTableDecl = envWithLeftTable.declare("call.rel(2)"); var envWithRightTable = rightTableDecl.getValue(); - + var leftColCountDecl = envWithRightTable.declare(leftTableDecl.getKey() + ".getRowType().getFieldCount()"); var envWithLeftCount = leftColCountDecl.getValue(); var rightColCountDecl = envWithLeftCount.declare(rightTableDecl.getKey() + ".getRowType().getFieldCount()"); var envWithRightCount = rightColCountDecl.getValue(); - + // Create the projection indices as a List var projectionIndicesDecl = envWithRightCount.declare( "java.util.stream.IntStream.concat(" + // Left columns: rightColCount + 0, rightColCount + 1, ..., rightColCount + leftColCount - 1 - "java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " + + "java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " + rightColCountDecl.getKey() + " + " + leftColCountDecl.getKey() + "), " + // Right columns: 0, 1, ..., rightColCount - 1 "java.util.stream.IntStream.range(0, " + rightColCountDecl.getKey() + ")" + ").boxed().collect(java.util.stream.Collectors.toList())" ); var envWithProjectionIndices = projectionIndicesDecl.getValue(); - + // Convert List to field references using RelBuilder.fields() var fieldRefsDecl = envWithProjectionIndices.declare( sourceEnv.current() + ".fields(" + projectionIndicesDecl.getKey() + ")" ); var envWithFieldRefs = fieldRefsDecl.getValue(); - + // Apply projection using the field references list yield envWithFieldRefs.focus(sourceEnv.current() + ".project(" + fieldRefsDecl.getKey() + ")"); } @@ -582,4 +582,30 @@ public Env grow(String requirement) { return new Env(varId, rel, current, vn + " -> " + vn + "." + requirement, statements, symbols); } } + + @Override + public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { + var sourceMatch = onMatch(env.next(), aggregate.source()); + return sourceMatch.grow("operand(LogicalAggregate.class).oneInput(" + sourceMatch.skeleton() + ")"); + } + + @Override + public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { + var sourceTransform = transform(env, aggregate.source()); + String builderWithSource = sourceTransform.current(); + String originalAgg = "((LogicalAggregate) call.rel(0))"; + var groupSetDecl = sourceTransform.declare(originalAgg + ".getGroupSet()"); + var envWithGroupSet = groupSetDecl.getValue(); + var groupKeyDecl = envWithGroupSet.declare( + builderWithSource + ".groupKey(" + groupSetDecl.getKey() + ")" + ); + var envWithGroupKey = groupKeyDecl.getValue(); + var aggCallsDecl = envWithGroupKey.declare(originalAgg + ".getAggCallList()"); + var envWithAggCalls = aggCallsDecl.getValue(); + return envWithAggCalls.focus( + builderWithSource + ".aggregate(" + + groupKeyDecl.getKey() + ", " + + aggCallsDecl.getKey() + ")" + ); + } } diff --git a/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java new file mode 100644 index 0000000..ab35583 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java @@ -0,0 +1,30 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RelType; + +public record AggregateFilterTranspose() implements RRule { + static final RelRN source_col1 = RelRN.scan("Source1", "Int_Type"); + static final RelRN source_col2 = RelRN.scan("Source2", "Int_Type"); + static final RelRN source = source_col1.join(JoinRelType.INNER, RexRN.trueLiteral(), source_col2); + static final RexRN filterCondition = source.field(0).pred("cond"); + + @Override + public RelRN before() { + RelRN filteredSource = source.filter(filterCondition); + return new RelRN.Aggregate( + filteredSource, + Seq.of(filteredSource.field(0)), + Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(filteredSource.field(1)))) + ); + } + @Override + public RelRN after() { + RelRN.Aggregate aggregateOnSource = new RelRN.Aggregate(source, Seq.of(source.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))))); + return aggregateOnSource.filter(aggregateOnSource.field(0).pred("cond")); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java new file mode 100644 index 0000000..48c92d9 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java @@ -0,0 +1,29 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RelType; + +public record FilterAggregateTranspose() implements RRule { + static final RelRN source_col1 = RelRN.scan("Source1", "Int_Type"); + static final RelRN source_col2 = RelRN.scan("Source2", "Int_Type"); + static final RelRN source = source_col1.join(JoinRelType.INNER, RexRN.trueLiteral(), source_col2); + static final RelRN.Aggregate aggregate = new RelRN.Aggregate(source, Seq.of(source.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))))); + static final RexRN pushedCondition = aggregate.field(0).pred("pushed_cond"); + static final RexRN remainingCondition = aggregate.field(1).pred("remaining_cond"); + + @Override + public RelRN before() { + return aggregate.filter(RexRN.and(pushedCondition, remainingCondition)); + } + + @Override + public RelRN after() { + RelRN filteredSource = source.filter(source.field(0).pred("pushed_cond")); + RelRN.Aggregate newAggregate = new RelRN.Aggregate(filteredSource, Seq.of(filteredSource.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(filteredSource.field(1))))); + return newAggregate.filter(newAggregate.field(1).pred("remaining_cond")); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RelRN.java b/src/main/java/org/qed/RelRN.java index 638985b..0967127 100644 --- a/src/main/java/org/qed/RelRN.java +++ b/src/main/java/org/qed/RelRN.java @@ -207,7 +207,7 @@ public RelNode semantics() { record AggCall(String name, boolean distinct, RelType type, Seq operands) { } - record Aggregate(RelRN source, Seq groupSet, Seq aggCalls) implements RelRN { + record Aggregate(org.qed.RelRN source, Seq groupSet, Seq aggCalls) implements org.qed.RelRN { @Override public RelNode semantics() { var builder = RuleBuilder.create(); From ff3c353e20092f209e863cfd8f7fa620186efa66 Mon Sep 17 00:00:00 2001 From: wkaiz Date: Mon, 21 Jul 2025 18:56:58 -0700 Subject: [PATCH 78/78] Testing... --- .../RRuleInstances/ProjectAggregateMerge.java | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 src/main/java/org/qed/Generated/RRuleInstances/ProjectAggregateMerge.java diff --git a/src/main/java/org/qed/Generated/RRuleInstances/ProjectAggregateMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/ProjectAggregateMerge.java new file mode 100644 index 0000000..55bbfb9 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/ProjectAggregateMerge.java @@ -0,0 +1,42 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RelType; + +public record ProjectAggregateMerge() implements RRule { + // Create a two-column source relation using a join, as the aggregate needs multiple columns. + static final RelRN source_col1 = RelRN.scan("SourceCol1", "Int_Type"); + static final RelRN source_col2 = RelRN.scan("SourceCol2", "Int_Type"); + static final RelRN source = source_col1.join(JoinRelType.INNER, RexRN.trueLiteral(), source_col2); + + static final RelRN.Aggregate aggregate = new RelRN.Aggregate( + source, + Seq.of(source.field(0)), // group by key (from source_col1) + Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1)))) // sum(field from source_col2) + ); + + @Override + public RelRN before() { + // Project on top of the aggregate, selecting only the SUM result. + return aggregate.project(aggregate.field(1)); + } + + @Override + public RelRN after() { + // The 'after' state represents the merged operation. + // This would be a single aggregate operator that produces the projected result directly. + // Since the projection removes the group key, the ideal 'after' would be an aggregate + // that computes the sum for each group but does not output the group key. + // As a placeholder for this complex transformation, we create an aggregate + // with an empty group set, which results in a single row output (total sum). + return new RelRN.Aggregate( + source, + Seq.empty(), + Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1)))) + ); + } +} \ No newline at end of file