diff --git a/.github/workflows/codegen-test.yml b/.github/workflows/codegen-test.yml new file mode 100644 index 0000000..85022d6 --- /dev/null +++ b/.github/workflows/codegen-test.yml @@ -0,0 +1,46 @@ +name: Test Code Generation + +on: + push: + pull_request: + +jobs: + test-code-generation: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Java + uses: actions/setup-java@v4 + with: + java-version: '23' + distribution: 'temurin' + + - name: Cache Maven dependencies + uses: actions/cache@v3 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + + - name: Build project + run: mvn -B compile --file pom.xml + + - name: Run Calcite tests + run: | + chmod +x scripts/test-codegen.sh + ./scripts/test-codegen.sh + + - name: Upload generated code + if: always() + uses: actions/upload-artifact@v4 + with: + name: generated-code + path: | + src/main/java/org/qed/Generated/*.java + !src/main/java/org/qed/Generated/CalciteTester.java + !src/main/java/org/qed/Generated/CalciteGenerator.java + !src/main/java/org/qed/Generated/CalciteUtilities.java + !src/main/java/org/qed/Generated/EmptyConfig.java \ No newline at end of file diff --git a/.github/workflows/prover-test.yml b/.github/workflows/prover-test.yml new file mode 100644 index 0000000..eca7d97 --- /dev/null +++ b/.github/workflows/prover-test.yml @@ -0,0 +1,58 @@ +name: Test Provability + +on: + push: + pull_request: + +jobs: + test-provability: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Java + uses: actions/setup-java@v4 + with: + java-version: '23' + distribution: 'temurin' + + - name: Cache Maven dependencies + uses: actions/cache@v3 + with: + path: ~/.m2/repository + key: ${{ runner.os }}-maven-${{ hashFiles('**/pom.xml') }} + restore-keys: | + ${{ runner.os }}-maven- + + - name: Build project + run: mvn -B compile --file pom.xml + + - name: Generate JSON for all rules + run: | + mkdir -p tmp-rules + mvn dependency:resolve + chmod +x scripts/generate-rule-json.sh + ./scripts/generate-rule-json.sh + + - name: Install dependencies + run: | + chmod +x scripts/install-dependencies.sh + ./scripts/install-dependencies.sh + + - name: Build qed-prover + run: | + chmod +x scripts/build-qed-prover.sh + ./scripts/build-qed-prover.sh + + - name: Test all rules + run: | + chmod +x scripts/test-rules.sh + ./scripts/test-rules.sh + + - name: Upload test artifacts + if: always() + uses: actions/upload-artifact@v4 + with: + name: results + path: tmp-rules/ \ No newline at end of file diff --git a/.gitignore b/.gitignore index 1b536bd..a2ef0ec 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ target/ *.iml .mvn/wrapper/maven-wrapper.jar +.vscode +.devcontainer diff --git a/flake.lock b/flake.lock index 02f9ea2..b8e3f20 100644 --- a/flake.lock +++ b/flake.lock @@ -1,15 +1,31 @@ { "nodes": { + "cvc5-src": { + "flake": false, + "locked": { + "lastModified": 1708192045, + "narHash": "sha256-hLb5qvWhAKeaGpR1HMH3URVXLuXjuvqapjx+loCa7Nc=", + "owner": "cvc5", + "repo": "cvc5", + "rev": "80878ee024f58a9a883479eed5dfe06402109e94", + "type": "github" + }, + "original": { + "owner": "cvc5", + "repo": "cvc5", + "type": "github" + } + }, "flake-utils": { "inputs": { "systems": "systems" }, "locked": { - "lastModified": 1694529238, - "narHash": "sha256-zsNZZGTGnMOf9YpHKJqMSsa0dXbfmxeoJ7xHlrt+xmY=", + "lastModified": 1705309234, + "narHash": "sha256-uNRRNRKmJyCRC/8y1RqBkqWBLM034y4qN7EprSdmgyA=", "owner": "numtide", "repo": "flake-utils", - "rev": "ff7b65b44d01cf9ba6a71320833626af21126384", + "rev": "1ef2e671c3b0c19053962c07dbda38332dcebf26", "type": "github" }, "original": { @@ -20,11 +36,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1700108881, - "narHash": "sha256-+Lqybl8kj0+nD/IlAWPPG/RDTa47gff9nbei0u7BntE=", + "lastModified": 1708247094, + "narHash": "sha256-H2VS7VwesetGDtIaaz4AMsRkPoSLEVzL/Ika8gnbUnE=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "7414e9ee0b3e9903c24d3379f577a417f0aae5f1", + "rev": "045b51a3ae66f673ed44b5bbd1f4a341d96703bf", "type": "github" }, "original": { @@ -34,23 +50,11 @@ "type": "github" } }, - "parser-release": { - "flake": false, - "locked": { - "narHash": "sha256-zlM8w0O4yYD/GYRBpr5ezoEY934mN5UoW4Ij9YSkNxw=", - "type": "file", - "url": "https://github.com/qed-solver/parser/releases/download/latest/qed-parser-1.0-SNAPSHOT-jar-with-dependencies.jar" - }, - "original": { - "type": "file", - "url": "https://github.com/qed-solver/parser/releases/download/latest/qed-parser-1.0-SNAPSHOT-jar-with-dependencies.jar" - } - }, "root": { "inputs": { + "cvc5-src": "cvc5-src", "flake-utils": "flake-utils", - "nixpkgs": "nixpkgs", - "parser-release": "parser-release" + "nixpkgs": "nixpkgs" } }, "systems": { diff --git a/flake.nix b/flake.nix index 20d9a53..296ddb6 100644 --- a/flake.nix +++ b/flake.nix @@ -2,35 +2,57 @@ inputs = { nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; flake-utils.url = "github:numtide/flake-utils"; - parser-release = { + cvc5-src = { flake = false; - url = "https://github.com/qed-solver/parser/releases/download/latest/qed-parser-1.0-SNAPSHOT-jar-with-dependencies.jar"; + url = "github:cvc5/cvc5"; }; }; - outputs = { self, nixpkgs, flake-utils, parser-release }: + outputs = { self, nixpkgs, flake-utils, cvc5-src }: flake-utils.lib.eachDefaultSystem (system: let pkgs = nixpkgs.legacyPackages.${system}; - parser = pkgs.stdenv.mkDerivation { - name = "qed-parser"; - src = parser-release; - buildInputs = with pkgs; [ jre ]; - nativeBuildInputs = with pkgs; [ makeWrapper ]; - buildCommand = '' - jar=$out/share/java/qed-parser.jar - install -Dm444 $src $jar - makeWrapper ${pkgs.jre}/bin/java $out/bin/qed-parser --add-flags "--enable-preview --add-opens=java.base/java.lang.reflect=ALL-UNNAMED -jar $jar" + cvc5-pname = "cvc5-1.0.5"; + cvc5-java = pkgs.stdenv.mkDerivation { + name = cvc5-pname; + src = cvc5-src; + + nativeBuildInputs = with pkgs; [ pkg-config cmake flex ]; + + buildInputs = with pkgs; [ + cadical.dev + symfpu + gmp + gtest + libantlr3c + antlr3_4 + boost + jdk21 + (python3.withPackages (ps: with ps; [ pyparsing toml ])) + ]; + + cmakeFlags = [ + "-DBUILD_BINDINGS_JAVA=ON" + "-DBUILD_SHARED_LIBS=1" + "-DCMAKE_BUILD_TYPE=Production" + "-DANTLR3_JAR=${pkgs.antlr3_4}/lib/antlr/antlr-3.4-complete.jar" + ]; + + preConfigure = '' + patchShebangs ./src/ ''; + }; in { - packages.default = parser; devShells.default = pkgs.mkShell { packages = with pkgs; [ - jdk + cvc5-java + jdk21 jetbrains.idea-community ]; + CVC5_JAVA = "${cvc5-java}/share/java/cvc5.jar"; + LD_LIBRARY_PATH = pkgs.lib.strings.makeLibraryPath [ cvc5-java ]; }; }); } diff --git a/pom.xml b/pom.xml index 5f3ace2..fd29e81 100644 --- a/pom.xml +++ b/pom.xml @@ -21,7 +21,6 @@ -classpath org.qed.Main - ${args} @@ -52,9 +51,11 @@ org.apache.maven.plugins maven-compiler-plugin - 19 - 19 - --enable-preview + 23 + 23 + + --enable-preview + @@ -88,13 +89,28 @@ org.slf4j - slf4j-nop - 2.0.5 + slf4j-api + 2.0.12 + + + org.slf4j + slf4j-simple + 2.0.12 org.glavo.kala kala-common 0.67.0 + + io.github.p-org.solvers + cvc5 + 0.0.7-v5 + + + org.reflections + reflections + 0.10.2 + - + \ No newline at end of file diff --git a/rules/JoinCommute-.json b/rules/JoinCommute-.json new file mode 100644 index 0000000..b6bdec7 --- /dev/null +++ b/rules/JoinCommute-.json @@ -0,0 +1,72 @@ +{ + "help" : [ "LogicalJoin(condition=[pred($0, $1)], joinType=[inner])\n LogicalTableScan(table=[[Left]])\n LogicalTableScan(table=[[Right]])\n", "LogicalProject(col-Left=[$1], col-Right=[$0])\n LogicalJoin(condition=[pred($1, $0)], joinType=[inner])\n LogicalTableScan(table=[[Right]])\n LogicalTableScan(table=[[Left]])\n" ], + "schemas" : [ { + "types" : [ "INTEGER" ], + "nullable" : [ true ], + "name" : "Left", + "guaranteed" : [ ], + "fields" : [ "col-Left" ], + "key" : [ ] + }, { + "types" : [ "INTEGER" ], + "nullable" : [ true ], + "name" : "Right", + "guaranteed" : [ ], + "fields" : [ "col-Right" ], + "key" : [ ] + } ], + "queries" : [ { + "join" : { + "condition" : { + "type" : "BOOLEAN", + "operand" : [ { + "column" : 0, + "type" : "INTEGER" + }, { + "column" : 1, + "type" : "INTEGER" + } ], + "operator" : "pred" + }, + "left" : { + "scan" : 0 + }, + "kind" : "INNER", + "right" : { + "scan" : 1 + } + } + }, { + "project" : { + "source" : { + "join" : { + "condition" : { + "type" : "BOOLEAN", + "operand" : [ { + "column" : 1, + "type" : "INTEGER" + }, { + "column" : 0, + "type" : "INTEGER" + } ], + "operator" : "pred" + }, + "left" : { + "scan" : 1 + }, + "kind" : "INNER", + "right" : { + "scan" : 0 + } + } + }, + "target" : [ { + "column" : 1, + "type" : "INTEGER" + }, { + "column" : 0, + "type" : "INTEGER" + } ] + } + } ] +} \ No newline at end of file diff --git a/scripts/build-qed-prover.sh b/scripts/build-qed-prover.sh new file mode 100644 index 0000000..10fe4b0 --- /dev/null +++ b/scripts/build-qed-prover.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Build qed-prover with Rust nightly + +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly +source $HOME/.cargo/env + +git clone https://github.com/qed-solver/prover.git qed-prover +cd qed-prover +cargo +nightly build --release \ No newline at end of file diff --git a/scripts/generate-rule-json.sh b/scripts/generate-rule-json.sh new file mode 100644 index 0000000..a34f5da --- /dev/null +++ b/scripts/generate-rule-json.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +# Script to generate JSON files for all RRule instances + +# Create temporary Java file for JSON generation +cat > JsonGenerator.java << 'EOF' +import org.qed.*; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.nio.file.*; + +public class JsonGenerator { + public static void main(String[] args) throws Exception { + String className = args[0]; + Class clazz = Class.forName(className); + RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); + ObjectMapper mapper = new ObjectMapper(); + String fileName = rule.name() + "-" + rule.info() + ".json"; + ObjectNode jsonNode = rule.toJson(); + mapper.writerWithDefaultPrettyPrinter().writeValue( + Path.of("tmp-rules", fileName).toFile(), + jsonNode + ); + } +} +EOF + +# Build classpath +MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) +CLASSPATH="target/classes:${MAVEN_CP}" + +# Compile the generator +javac -cp "$CLASSPATH" JsonGenerator.java + +# Generate JSON for each rule +find src/main/java/org/qed/Generated/RRuleInstances -name '*.java' | while read file; do + className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') + echo "Generating JSON for: $className" + java -cp ".:$CLASSPATH" JsonGenerator "$className" +done + +# Cleanup +rm -f JsonGenerator.java JsonGenerator.class + +echo "JSON generation complete. Files are in tmp-rules/" \ No newline at end of file diff --git a/scripts/install-dependencies.sh b/scripts/install-dependencies.sh new file mode 100644 index 0000000..500879f --- /dev/null +++ b/scripts/install-dependencies.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +# Install required dependencies for qed-prover + +sudo apt-get update +sudo apt-get install -y jq z3 + +# Install cvc5 - try package manager first, otherwise build from source +sudo apt-get install -y cvc5 || ( + sudo apt-get install -y cmake libgmp-dev && + git clone --depth 1 https://github.com/cvc5/cvc5.git && + cd cvc5 && + ./configure.sh --auto-download && + cd build && + make -j$(nproc) && + sudo make install +) \ No newline at end of file diff --git a/scripts/test-codegen.sh b/scripts/test-codegen.sh new file mode 100644 index 0000000..6ab614e --- /dev/null +++ b/scripts/test-codegen.sh @@ -0,0 +1,100 @@ +#!/bin/bash + +# Script to generate code for each rule and test whether the rules can be applied correctly + +echo "## Code Generation Test Results" >> $GITHUB_STEP_SUMMARY +echo "" >> $GITHUB_STEP_SUMMARY + +# Step 1: Generate code for each rule in RRuleInstances +# Create temporary Java file for code generation +cat > RuleGenerator.java << 'EOF' +import org.qed.Generated.CalciteTester; +import org.qed.*; +import java.nio.file.*; + +public class RuleGenerator { + public static void main(String[] args) throws Exception { + String className = args[0]; + Class clazz = Class.forName(className); + RRule rule = (RRule) clazz.getDeclaredConstructor().newInstance(); + + CalciteTester tester = new CalciteTester(); + tester.serialize(rule, CalciteTester.genPath); + + System.out.println("Generated code for: " + rule.name()); + } +} +EOF + +# Build classpath +MAVEN_CP=$(mvn dependency:build-classpath -Dmdep.outputFile=/dev/stdout -q) +CLASSPATH="target/classes:${MAVEN_CP}" + +# Compile the generator +javac -cp "$CLASSPATH" RuleGenerator.java + +# Generate code for each rule +find src/main/java/org/qed/Generated/RRuleInstances -name '*.java' -not -path '*/RRuleInstances-unprovable/*' | while read file; do + className=$(echo "$file" | sed 's|src/main/java/||; s|/|.|g; s|\.java$||') + java -cp ".:$CLASSPATH" RuleGenerator "$className" +done + +# Step 2: Check for missing tests +missing_tests="" +missing_count=0 +for rule_file in src/main/java/org/qed/Generated/RRuleInstances/*.java; do + rule_name=$(basename "$rule_file" .java) + if [ ! -f "src/main/java/org/qed/Generated/Tests/${rule_name}Test.java" ]; then + missing_tests="${missing_tests}- ${rule_name}\n" + missing_count=$((missing_count + 1)) + fi +done + +if [ $missing_count -gt 0 ]; then + echo "**⚠️ Warning: Missing tests for $missing_count rules:**" >> $GITHUB_STEP_SUMMARY + echo -e "$missing_tests" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY +fi + +# Step 3: Run all test classes +# Store results for summary +total_tests=0 +passed_tests=0 + +# Find all test files and run them +total_tests=0 +passed_tests=0 + +for test_file in src/main/java/org/qed/Generated/Tests/*Test.java; do + class_name=${test_file#src/main/java/} + class_name=${class_name%.java} + class_name=${class_name//\//.} + test_name=$(basename "$test_file" .java) + display_name=${test_name%Test} + total_tests=$((total_tests + 1)) + + # Run the test and capture output + if java -cp "$CLASSPATH" "$class_name" > /tmp/test_output.txt 2>&1; then + if grep -q "trivial" /tmp/test_output.txt; then + echo "⚠️ ${display_name}: TRIVIAL" >> $GITHUB_STEP_SUMMARY + elif grep -q "succeeded" /tmp/test_output.txt && ! grep -q "failed" /tmp/test_output.txt; then + echo "✅ ${display_name}: PASSED" >> $GITHUB_STEP_SUMMARY + passed_tests=$((passed_tests + 1)) + else + echo "❌ ${display_name}: FAILED" >> $GITHUB_STEP_SUMMARY + fi + else + echo "❌ ${display_name}: ERROR" >> $GITHUB_STEP_SUMMARY + fi +done + +# Clean up +rm -f RuleGenerator.java RuleGenerator.class /tmp/test_output.txt + +echo "" >> $GITHUB_STEP_SUMMARY +echo "**Summary:** $passed_tests/$total_tests passed" >> $GITHUB_STEP_SUMMARY + +# Exit with error if tests failed +if [ "$passed_tests" -ne "$total_tests" ]; then + exit 1 +fi \ No newline at end of file diff --git a/scripts/test-rules.sh b/scripts/test-rules.sh new file mode 100644 index 0000000..5dc149b --- /dev/null +++ b/scripts/test-rules.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +# Test all generated rules with qed-prover + +echo "## QED Prover Test Results" >> $GITHUB_STEP_SUMMARY +echo "" >> $GITHUB_STEP_SUMMARY + +failed_rules="" +total_count=0 +passed_count=0 + +for json_file in tmp-rules/*.json; do + rule_name=$(basename "$json_file" .json) + total_count=$((total_count + 1)) + ./qed-prover/target/release/qed-prover "$json_file" || true + + result_file="${json_file%.json}.result" + if [ -f "$result_file" ] && jq -e '.provable == true' "$result_file" > /dev/null 2>&1; then + echo "✅ $rule_name: PASSED" >> $GITHUB_STEP_SUMMARY + passed_count=$((passed_count + 1)) + else + echo "❌ $rule_name: FAILED" >> $GITHUB_STEP_SUMMARY + failed_rules="$failed_rules$rule_name," + fi +done + +echo "" >> $GITHUB_STEP_SUMMARY +echo "**Summary:** $passed_count/$total_count passed" >> $GITHUB_STEP_SUMMARY + +if [ -n "$failed_rules" ]; then + echo "::error::Failed rules: ${failed_rules%,}" + exit 1 +fi \ No newline at end of file diff --git a/src/main/java/org/qed/CodeGenerator.java b/src/main/java/org/qed/CodeGenerator.java new file mode 100644 index 0000000..7b5009e --- /dev/null +++ b/src/main/java/org/qed/CodeGenerator.java @@ -0,0 +1,263 @@ +package org.qed; + +public interface CodeGenerator { + + default String unimplemented(String context, Object object) { + return "<--" + context + object.getClass().getName() + "-->"; + } + + default E unimplementedOnMatch(E env, Object object) { + System.err.println(unimplemented("Unspecified onMatch codegen: ", object)); + return env; + } + + default E unimplementedTransform(E env, Object object) { + System.err.println(unimplemented("Unspecified transform codegen: ", object)); + return env; + } + + E preMatch(); + + default E onMatch(E env, RelRN pattern) { + return switch (pattern) { + case RelRN.Scan scan -> onMatchScan(env, scan); + case RelRN.Filter filter -> onMatchFilter(env, filter); + case RelRN.Project project -> onMatchProject(env, project); + case RelRN.Join join -> onMatchJoin(env, join); + case RelRN.Union union -> onMatchUnion(env, union); + case RelRN.Intersect intersect -> onMatchIntersect(env, intersect); + case RelRN.Minus minus -> onMatchMinus(env, minus); + case RelRN.Empty empty -> onMatchEmpty(env, empty); + case RelRN.Aggregate aggregate -> onMatchAggregate(env, aggregate); + default -> onMatchCustom(env, pattern); + }; + } + + default E onMatch(E env, RexRN pattern) { + return switch (pattern) { + case RexRN.Field field -> onMatchField(env, field); + case RexRN.JoinField joinField -> onMatchJoinField(env, joinField); + case RexRN.Proj proj -> onMatchProj(env, proj); + case RexRN.Pred pred -> onMatchPred(env, pred); + case RexRN.And and -> onMatchAnd(env, and); + case RexRN.Or or -> onMatchOr(env, or); + case RexRN.Not not -> onMatchNot(env, not); + case RexRN.True literal -> onMatchTrue(env, literal); + case RexRN.False literal -> onMatchFalse(env, literal); + default -> onMatchCustom(env, pattern); + }; + } + + default E postMatch(E env) { + return env; + } + + default E preTransform(E env) { + return env; + } + + default E transform(E env, RelRN target) { + return switch (target) { + case RelRN.Scan scan -> transformScan(env, scan); + case RelRN.Filter filter -> transformFilter(env, filter); + case RelRN.Project project -> transformProject(env, project); + case RelRN.Join join -> transformJoin(env, join); + case RelRN.Union union -> transformUnion(env, union); + case RelRN.Intersect intersect -> transformIntersect(env, intersect); + case RelRN.Minus minus -> transformMinus(env, minus); + case RelRN.Empty empty -> transformEmpty(env, empty); + case RelRN.Aggregate aggregate -> transformAggregate(env, aggregate); + default -> transformCustom(env, target); + }; + } + + default E transform(E env, RexRN target) { + return switch (target) { + case RexRN.Field field -> transformField(env, field); + case RexRN.JoinField joinField -> transformJoinField(env, joinField); + case RexRN.Pred pred -> transformPred(env, pred); + case RexRN.Proj proj -> transformProj(env, proj); + case RexRN.And and -> transformAnd(env, and); + case RexRN.Or or -> transformOr(env, or); + case RexRN.Not not -> transformNot(env, not); + case RexRN.True literal -> transformTrue(env, literal); + case RexRN.False literal -> transformFalse(env, literal); + default -> transformCustom(env, target); + }; + } + + default E postTransform(E env) { + return env; + } + + default String translate(String name, E onMatch, E transform) { + return "Unspecified translation to target language"; + } + + default String generate(RRule rule) { + System.out.printf("Generating Rule: %s\n", rule.name()); + var onMatch = postMatch(onMatch(preMatch(), rule.before())); + var transform = postTransform(transform(preTransform(onMatch), rule.after())); + return translate(rule.getClass().getSimpleName(), onMatch, transform); + } + + default E onMatchScan(E env, RelRN.Scan scan) { + return unimplementedOnMatch(env, scan); + } + + default E onMatchFilter(E env, RelRN.Filter filter) { + return unimplementedOnMatch(env, filter); + } + + default E onMatchProject(E env, RelRN.Project project) { + return unimplementedOnMatch(env, project); + } + + default E onMatchJoin(E env, RelRN.Join join) { + return unimplementedOnMatch(env, join); + } + + default E onMatchUnion(E env, RelRN.Union union) { + return unimplementedOnMatch(env, union); + } + + default E onMatchIntersect(E env, RelRN.Intersect intersect) { + return unimplementedOnMatch(env, intersect); + } + + default E onMatchMinus(E env, RelRN.Minus minus) { + return unimplementedOnMatch(env, minus); + } + + default E onMatchCustom(E env, RelRN custom) { + return unimplementedOnMatch(env, custom); + } + + default E onMatchField(E env, RexRN.Field field) { + return unimplementedOnMatch(env, field); + } + + default E onMatchJoinField(E env, RexRN.JoinField joinField) { + return unimplementedOnMatch(env, joinField); + } + + default E onMatchPred(E env, RexRN.Pred pred) { + return unimplementedOnMatch(env, pred); + } + + default E onMatchProj(E env, RexRN.Proj proj) { + return unimplementedOnMatch(env, proj); + } + + default E onMatchAnd(E env, RexRN.And and) { + return unimplementedOnMatch(env, and); + } + + default E onMatchOr(E env, RexRN.Or or) { + return unimplementedOnMatch(env, or); + } + + default E onMatchNot(E env, RexRN.Not not) { + return unimplementedOnMatch(env, not); + } + + default E onMatchCustom(E env, RexRN custom) { + return unimplementedOnMatch(env, custom); + } + + default E onMatchTrue(E env, RexRN literal) { + return unimplementedOnMatch(env, literal); + } + + default E onMatchFalse(E env, RexRN literal) { + return unimplementedOnMatch(env, literal); + } + + default E onMatchEmpty(E env, RelRN.Empty empty) { + return unimplementedOnMatch(env, empty); + } + + default E transformScan(E env, RelRN.Scan scan) { + return unimplementedTransform(env, scan); + } + + default E transformFilter(E env, RelRN.Filter filter) { + return unimplementedTransform(env, filter); + } + + default E transformProject(E env, RelRN.Project project) { + return unimplementedTransform(env, project); + } + + default E transformJoin(E env, RelRN.Join join) { + return unimplementedTransform(env, join); + } + + default E transformUnion(E env, RelRN.Union union) { + return unimplementedTransform(env, union); + } + + default E transformIntersect(E env, RelRN.Intersect intersect) { + return unimplementedTransform(env, intersect); + } + + default E transformMinus(E env, RelRN.Minus minus) { + return unimplementedTransform(env, minus); + } + + default E transformCustom(E env, RelRN custom) { + return unimplementedTransform(env, custom); + } + + default E transformField(E env, RexRN.Field field) { + return unimplementedTransform(env, field); + } + + default E transformJoinField(E env, RexRN.JoinField joinField) { + return unimplementedTransform(env, joinField); + } + + default E transformProj(E env, RexRN.Proj proj) { + return unimplementedTransform(env, proj); + } + + default E transformPred(E env, RexRN.Pred pred) { + return unimplementedTransform(env, pred); + } + + default E transformAnd(E env, RexRN.And and) { + return unimplementedTransform(env, and); + } + + default E transformOr(E env, RexRN.Or or) { + return unimplementedTransform(env, or); + } + + default E transformNot(E env, RexRN.Not not) { + return unimplementedTransform(env, not); + } + + default E transformCustom(E env, RexRN custom) { + return unimplementedTransform(env, custom); + } + + default E transformTrue(E env, RexRN literal) { + return unimplementedTransform(env, literal); + } + + default E transformFalse(E env, RexRN literal) { + return unimplementedTransform(env, literal); + } + + default E transformEmpty(E env, RelRN.Empty empty) { + return unimplementedTransform(env, empty); + } + + default E onMatchAggregate(E env, RelRN.Aggregate aggregate) { + return unimplementedOnMatch(env, aggregate); + } + + default E transformAggregate(E env, RelRN.Aggregate aggregate) { + return unimplementedTransform(env, aggregate); + } +} diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java new file mode 100644 index 0000000..4a74dc1 --- /dev/null +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -0,0 +1,611 @@ +package org.qed.Generated; + +import kala.collection.Seq; +import kala.collection.immutable.ImmutableMap; +import kala.tuple.Tuple; +import kala.tuple.Tuple2; +import org.qed.CodeGenerator; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.Generated.CalciteGenerator.Env; + +import java.util.concurrent.atomic.AtomicInteger; + +import javax.annotation.processing.Generated; + +public class CalciteGenerator implements CodeGenerator { + + @Override + public Env preMatch() { + return Env.empty(); + } + + @Override + public Env preTransform(Env env) { + var buildEnv = env.declare("call.builder()"); + return buildEnv.getValue().focus(buildEnv.getKey()); + } + + @Override + public Env postTransform(Env env) { + return env.state("call.transformTo(" + env.current() + ".build());"); + } + + @Override + public String translate(String name, Env onMatch, Env transform) { + var builder = new StringBuilder("package org.qed.Generated;\n\n"); + builder.append("import org.apache.calcite.plan.RelOptRuleCall;\n"); + builder.append("import org.apache.calcite.plan.RelRule;\n"); + builder.append("import org.apache.calcite.rel.RelNode;\n"); + builder.append("import org.apache.calcite.rel.core.JoinRelType;\n"); + builder.append("import org.apache.calcite.rel.logical.*;\n\n"); + builder.append("public class " + name + " extends RelRule<" + name + ".Config> {\n"); + builder.append("\tprotected " + name + "(Config config) {\n"); + builder.append("\t\tsuper(config);\n"); + builder.append("\t}\n\n"); + builder.append("\t@Override\n\tpublic void onMatch(RelOptRuleCall call) {\n"); + transform.statements().forEach(statement -> builder.append("\t\t").append(statement).append("\n")); + builder.append("\t}\n\n"); + builder.append("\tpublic interface Config extends EmptyConfig {\n"); + builder.append("\t\tConfig DEFAULT = new Config() {};\n\n"); + builder.append("\t\t@Override\n\t\tdefault " + name + " toRule() {\n"); + builder.append("\t\t\treturn new " + name + "(this);\n"); + builder.append("\t\t}\n\n"); + builder.append("\t\t@Override\n\t\tdefault String description() {\n"); + builder.append("\t\t\treturn \"" + name + "\";\n"); + builder.append("\t\t}\n\n"); + builder.append("\t\t@Override\n\t\tdefault RelRule.OperandTransform operandSupplier() {\n"); + builder.append("\t\t\treturn " + onMatch.skeleton() + ";\n"); + builder.append("\t\t}\n\n"); + builder.append("\t}\n"); + builder.append("}\n"); + return builder.toString(); + } + + @Override + public Env onMatchScan(Env env, RelRN.Scan scan) { + return env.symbol(scan.name(), env.current()).grow("operand(RelNode.class).anyInputs()"); + } + + @Override + public Env onMatchFilter(Env env, RelRN.Filter filter) { + var source_match = onMatch(env.next(), filter.source()); + var operator_match = source_match.grow("operand(LogicalFilter.class).oneInput(" + source_match.skeleton() + ")"); + var condition_match = operator_match.focus("((LogicalFilter) " + env.current() + ").getCondition()"); + return onMatch(condition_match, filter.cond()); + } + + @Override + public Env onMatchProject(Env env, RelRN.Project project) { + var source_match = onMatch(env.next(), project.source()); + var operator_match = + source_match.grow("operand(LogicalProject.class).oneInput(" + source_match.skeleton() + ")"); + var map_match = operator_match.focus("((LogicalProject) " + env.current() + ").getProjects()"); + return onMatch(map_match, project.map()); + } + + @Override + public Env onMatchPred(Env env, RexRN.Pred pred) { + return env.symbol(pred.operator().getName(), env.current()); + } + + @Override + public Env onMatchProj(Env env, RexRN.Proj proj) { + return env.symbol(proj.operator().getName(), env.current()); + } + + @Override + public Env onMatchJoin(Env env, RelRN.Join join) { + var current_join = "((LogicalJoin) " + env.current() + ")"; + // STR."\{join_env.current()}.getJoinType()" + var left_source_env = env.next(); + var left_match_env = onMatch(left_source_env, join.left()); + var right_source_env = left_match_env.next(); + var right_match_env = onMatch(right_source_env, join.right()); + var operator_match = + right_match_env.grow("operand(LogicalJoin.class).inputs(" + left_match_env.skeleton() + ", " + right_match_env.skeleton() + ")"); + var cond_source_env = operator_match.focus(current_join + ".getCondition()"); + return onMatch(cond_source_env, join.cond()); + } + + @Override + public Env onMatchAnd(Env env, RexRN.And and) { + // Process each source in the And condition + var current_env = env; + // Use a unique symbol name for the AND condition + String andSymbol = "and_" + env.varId.getAndIncrement(); + // Store the current expression as this AND node's symbol + current_env = current_env.symbol(andSymbol, current_env.current()); + + // Process each child source in the AND condition + for (var source : and.sources()) { + current_env = onMatch(current_env, source); + } + + return current_env; + } + + @Override + public Env onMatchUnion(Env env, RelRN.Union union) { + // Get the all flag from the union + boolean all = union.all(); + + // Process each source in the union + var current_env = env; + var skeletons = Seq.empty(); + + // Process all sources in the sequence + for (var source : union.sources()) { + var next_env = current_env.next(); + var source_env = onMatch(next_env, source); + skeletons = skeletons.appended(source_env.skeleton()); + current_env = source_env; + } + + // Build the input skeletons string for the operand + StringBuilder inputsBuilder = new StringBuilder(); + for (int i = 0; i < skeletons.size(); i++) { + if (i > 0) { + inputsBuilder.append(", "); + } + inputsBuilder.append(skeletons.get(i).toString()); + } + + // Create the union operand with the appropriate class based on the all flag + String operatorClass = all ? "LogicalUnionAll" : "LogicalUnion"; + return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); + } + + @Override + public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { + // Get the all flag from the intersect + boolean all = intersect.all(); + + // Process each source in the intersect + var current_env = env; + var skeletons = Seq.empty(); + + // Process all sources in the sequence + for (var source : intersect.sources()) { + var next_env = current_env.next(); + var source_env = onMatch(next_env, source); + skeletons = skeletons.appended(source_env.skeleton()); + current_env = source_env; + } + + // Build the input skeletons string for the operand + StringBuilder inputsBuilder = new StringBuilder(); + for (int i = 0; i < skeletons.size(); i++) { + if (i > 0) { + inputsBuilder.append(", "); + } + inputsBuilder.append(skeletons.get(i).toString()); + } + + // Create the intersect operand with the appropriate class based on the all flag + String operatorClass = all ? "LogicalIntersectAll" : "LogicalIntersect"; + return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); + } + + @Override + public Env onMatchMinus(Env env, RelRN.Minus minus) { + // Get the all flag from the minus + boolean all = minus.all(); + + // Process each source in the minus + var current_env = env; + var skeletons = Seq.empty(); + + // Process all sources in the sequence + for (var source : minus.sources()) { + var next_env = current_env.next(); + var source_env = onMatch(next_env, source); + skeletons = skeletons.appended(source_env.skeleton()); + current_env = source_env; + } + + // Build the input skeletons string for the operand + StringBuilder inputsBuilder = new StringBuilder(); + for (int i = 0; i < skeletons.size(); i++) { + if (i > 0) { + inputsBuilder.append(", "); + } + inputsBuilder.append(skeletons.get(i).toString()); + } + + // Create the minus operand + return current_env.grow("operand(LogicalMinus.class).inputs(" + inputsBuilder.toString() + ")"); + } + + @Override + public Env onMatchField(Env env, RexRN.Field field) { + // Generate a unique symbolic name for this field + String fieldSymbol = "field_" + env.varId.getAndIncrement(); + + // Store the field expression in the environment's symbol table + return env.symbol(fieldSymbol, env.current()); + } + + @Override + public Env onMatchTrue(Env env, RexRN literal) { + // Create a unique symbol name for this true literal + String trueSymbol = "true_" + env.varId.getAndIncrement(); + + // Store the current expression as this true literal's symbol + return env.symbol(trueSymbol, env.current()); + } + + @Override + public Env onMatchFalse(Env env, RexRN literal) { + // Create a unique symbol name for this false literal + String falseSymbol = "false_" + env.varId.getAndIncrement(); + + // Store the current expression as this false literal's symbol + return env.symbol(falseSymbol, env.current()); + } + + @Override + public Env onMatchEmpty(Env env, RelRN.Empty empty) { + return env.grow("operand(LogicalValues.class).noInputs()"); + } + + + @Override + public Env transformScan(Env env, RelRN.Scan scan) { + return env.focus(env.current() + ".push(" + env.symbols().get(scan.name()) + ")"); + } + +// @Override +// public Env onMatchCustom(Env env, RexRN custom) { +// return switch (custom) { +// case RRule.JoinConditionPush.JoinPred joinPred -> { +// var pred = env.expressions().first(); +// var breakdown_env = assignVariable(env, STR."customSplitFilter(\{pred})"); +// var breakdown = breakdown_env.expressions().first(); +// yield breakdown_env +// .symbol(joinPred.bothPred(), STR."\{breakdown}.getBoth()") +// .symbol(joinPred.leftPred(), STR."\{breakdown}.getLeft()") +// .symbol(joinPred.rightPred(), STR."\{breakdown}.getRight()"); +// } +// default -> CodeGenerator.super.onMatchCustom(env, custom); +// }; +// } + + @Override + public Env transformFilter(Env env, RelRN.Filter filter) { + var source_transform = transform(env, filter.source()); + var source_expression = source_transform.current(); + var cond_transform = transform(source_transform, filter.cond()); + return cond_transform.focus(source_expression + ".filter(" + cond_transform.current() + ")"); + } + + @Override + public Env transformPred(Env env, RexRN.Pred pred) { + // ONLY apply to the very specific JoinCommute pattern where: + // 1. We have exactly 2 JoinField sources + // 2. The first JoinField has ordinal 1 and second has ordinal 0 (indicating argument swap) + // This is the EXACT pattern from JoinCommute rule: pred($1, $0) vs pred($0, $1) + boolean isExactJoinCommuteSwapPattern = pred.sources().size() == 2 && + pred.sources().get(0) instanceof RexRN.JoinField joinField1 && + pred.sources().get(1) instanceof RexRN.JoinField joinField2 && + joinField1.ordinal() == 1 && // First arg is ordinal 1 (left -> right in swap) + joinField2.ordinal() == 0; // Second arg is ordinal 0 (right -> left in swap) + + if (isExactJoinCommuteSwapPattern) { + // This is the JoinCommute swapped predicate - transform with swapped arguments + var currentEnv = env; + var transformedArgs = Seq.empty(); + + // Process arguments in reverse order to swap them back + var sources = pred.sources(); + var reversedSources = Seq.of(sources.get(1), sources.get(0)); + + for (var arg : reversedSources) { + currentEnv = transform(currentEnv, arg); + transformedArgs = transformedArgs.appended(currentEnv.current()); + currentEnv = currentEnv.focus(env.current()); + } + + // Create a new predicate call with transformed arguments + String argsString = transformedArgs.joinToString(", "); + + // Extract operator from the original join condition + String operatorCall = "((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator()"; + + return currentEnv.focus(env.current() + ".call(" + operatorCall + ", " + argsString + ")"); + } else { + return env.focus(env.symbols().get(pred.operator().getName())); + } + } + + @Override + public Env transformJoinField(Env env, RexRN.JoinField joinField) { + // For JoinCommute: we need to calculate absolute field positions in the swapped join + + // Get the original join condition to extract the actual field indices + var origJoinDecl = env.declare("(LogicalJoin) call.rel(0)"); + var envWithOrigJoin = origJoinDecl.getValue(); + var conditionDecl = envWithOrigJoin.declare("(org.apache.calcite.rex.RexCall) " + origJoinDecl.getKey() + ".getCondition()"); + var envWithCondition = conditionDecl.getValue(); + + if (joinField.ordinal() == 0) { + // Ordinal 0 = Left table in original join + // Extract the left operand field index from original condition + var leftFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(0)).getIndex()"); + var envWithLeftField = leftFieldDecl.getValue(); + + // In swapped join: Left table is now at input 1 + // Use field(2, 1, leftFieldIndex) syntax + return envWithLeftField.focus(env.current() + ".field(2, 1, " + leftFieldDecl.getKey() + ")"); + } + else if (joinField.ordinal() == 1) { + // Ordinal 1 = Right table in original join + // Extract the right operand field index from original condition + var rightFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(1)).getIndex()"); + var envWithRightField = rightFieldDecl.getValue(); + + // Right table field index needs to be adjusted since it was originally after left table + var leftColCountDecl = envWithRightField.declare("call.rel(1).getRowType().getFieldCount()"); + var envWithLeftCount = leftColCountDecl.getValue(); + var adjustedRightFieldDecl = envWithLeftCount.declare(rightFieldDecl.getKey() + " - " + leftColCountDecl.getKey()); + var envWithAdjustedRightField = adjustedRightFieldDecl.getValue(); + + // In swapped join: Right table is now at input 0 + // Use field(2, 0, adjustedRightFieldIndex) syntax + return envWithAdjustedRightField.focus(env.current() + ".field(2, 0, " + adjustedRightFieldDecl.getKey() + ")"); + } else { + throw new UnsupportedOperationException("Unsupported join field ordinal: " + joinField.ordinal()); + } + } + + @Override + public Env transformJoin(Env env, RelRN.Join join) { + var left_source_transform = transform(env, join.left()); + var right_source_transform = transform(left_source_transform, join.right()); + var source_expression = right_source_transform.current(); + var cond_transform = transform(right_source_transform, join.cond()); + var join_type = switch (join.ty().semantics()) { + case INNER -> "JoinRelType.INNER"; + case LEFT -> "JoinRelType.LEFT"; + case RIGHT -> "JoinRelType.RIGHT"; + case FULL -> "JoinRelType.FULL"; + case SEMI -> "JoinRelType.SEMI"; + case ANTI -> "JoinRelType.ANTI"; + }; + return cond_transform.focus(source_expression + ".join(" + join_type + ", " + cond_transform.current() + ")"); + } + + @Override + public Env transformAnd(Env env, RexRN.And and) { + var source_transform = env; + var operands = Seq.empty(); + for (var source : and.sources()) { + source_transform = transform(source_transform, source); + operands = operands.appended(source_transform.current()); + source_transform = source_transform.focus(env.current()); + } + return source_transform.focus(env.current() + ".and(" + operands.joinToString(", ") + ")"); + } + + @Override + public Env transformUnion(Env env, RelRN.Union union) { + // Get the all flag from the union + boolean all = union.all(); + + // The number of sources + int sourceCount = union.sources().size(); + + // Transform each source + var current_env = env; + for (var source : union.sources()) { + current_env = transform(current_env, source); + } + + // Use the union method with the all flag and source count + // This matches the Calcite RelBuilder.union(boolean all, int n) signature + return current_env.focus(current_env.current() + ".union(" + all + ", " + sourceCount + ")"); + } + + @Override + public Env transformIntersect(Env env, RelRN.Intersect intersect) { + // Get the all flag from the intersect + boolean all = intersect.all(); + + // The number of sources + int sourceCount = intersect.sources().size(); + + // Transform each source + var current_env = env; + for (var source : intersect.sources()) { + current_env = transform(current_env, source); + } + + // Use the intersect method with the all flag and source count + // This matches the expected Calcite RelBuilder.intersect(boolean all, int n) signature + String methodName = all ? "intersectAll" : "intersect"; + return current_env.focus(current_env.current() + "." + methodName + "(" + all + ", " + sourceCount + ")"); + } + + @Override + public Env transformMinus(Env env, RelRN.Minus minus) { + // Get the all flag from the minus + boolean all = minus.all(); + + // The number of sources + int sourceCount = minus.sources().size(); + + // Transform each source + var current_env = env; + for (var source : minus.sources()) { + current_env = transform(current_env, source); + } + + // Use the minus method with the all flag and source count + // This matches the expected Calcite RelBuilder.minus(boolean all, int n) signature + return current_env.focus(current_env.current() + ".minus(" + all + ", " + sourceCount + ")"); + } + + @Override + public Env transformField(Env env, RexRN.Field field) { + // In Calcite, field references are typically created with a "field" method + // We'll need to pass some identifier for the field - use toString() if no specific field accessor is available + + // Assuming field has a method that returns some kind of identifier or name + // If not, we may need to adjust this implementation + return env.focus(env.current() + ".field(" + field + ")"); + } + + @Override + public Env transformProj(Env env, RexRN.Proj proj) { + // In Calcite, projections are typically created using the operator name + // This is similar to your transformPred implementation + + // Look up the symbol from the matching phase + if (!env.symbols().containsKey(proj.operator().getName())) { + throw new RuntimeException("Operator symbol not found: " + proj.operator().getName() + + ". Make sure onMatchProj is properly implemented."); + } + + // Return an environment focused on the expression for this projection + return env.focus(env.symbols().get(proj.operator().getName())); + } + + @Override + public Env transformProject(Env env, RelRN.Project project) { + // First transform the source relation + var source_transform = transform(env, project.source()); + var source_expression = source_transform.current(); + + // Then transform the projection map + var map_transform = transform(source_transform, project.map()); + + // Combine the source and projection using the project operation + // This creates a projection on top of the source relation + return map_transform.focus(source_expression + ".project(" + map_transform.current() + ")"); + } + + @Override + public Env transformTrue(Env env, RexRN literal) { + // In Calcite, true literals are typically represented using the + // rexBuilder.makeLiteral(true) method or just "TRUE" + return env.focus(env.current() + ".literal(true)"); + } + + @Override + public Env transformFalse(Env env, RexRN literal) { + // In Calcite, false literals are represented using the + // rexBuilder.makeLiteral(false) method or just "FALSE" + return env.focus(env.current() + ".literal(false)"); + } + + @Override + public Env transformEmpty(Env env, RelRN.Empty empty) { + // In Calcite, empty relations are created using the values() method with no tuples + // This creates a LogicalValues node with no rows + return env.focus(env.current() + ".empty()"); + } + + + @Override + public Env transformCustom(Env env, RelRN custom) { + return switch (custom) { + case org.qed.Generated.RRuleInstances.JoinCommute.ProjectionRelRN projection -> { + // Transform the source first - this builds the join + var sourceEnv = transform(env, projection.source()); + + // Get original table column counts + var leftTableDecl = sourceEnv.declare("call.rel(1)"); + var envWithLeftTable = leftTableDecl.getValue(); + var rightTableDecl = envWithLeftTable.declare("call.rel(2)"); + var envWithRightTable = rightTableDecl.getValue(); + + var leftColCountDecl = envWithRightTable.declare(leftTableDecl.getKey() + ".getRowType().getFieldCount()"); + var envWithLeftCount = leftColCountDecl.getValue(); + var rightColCountDecl = envWithLeftCount.declare(rightTableDecl.getKey() + ".getRowType().getFieldCount()"); + var envWithRightCount = rightColCountDecl.getValue(); + + // Create the projection indices as a List + var projectionIndicesDecl = envWithRightCount.declare( + "java.util.stream.IntStream.concat(" + + // Left columns: rightColCount + 0, rightColCount + 1, ..., rightColCount + leftColCount - 1 + "java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " + + rightColCountDecl.getKey() + " + " + leftColCountDecl.getKey() + "), " + + // Right columns: 0, 1, ..., rightColCount - 1 + "java.util.stream.IntStream.range(0, " + rightColCountDecl.getKey() + ")" + + ").boxed().collect(java.util.stream.Collectors.toList())" + ); + var envWithProjectionIndices = projectionIndicesDecl.getValue(); + + // Convert List to field references using RelBuilder.fields() + var fieldRefsDecl = envWithProjectionIndices.declare( + sourceEnv.current() + ".fields(" + projectionIndicesDecl.getKey() + ")" + ); + var envWithFieldRefs = fieldRefsDecl.getValue(); + + // Apply projection using the field references list + yield envWithFieldRefs.focus(sourceEnv.current() + ".project(" + fieldRefsDecl.getKey() + ")"); + } + default -> unimplementedTransform(env, custom); + }; + } + + public record Env(AtomicInteger varId, int rel, String current, String skeleton, Seq statements, + ImmutableMap symbols) { + public static Env empty() { + return new Env(new AtomicInteger(), 0, "call.rel(0)", "/* Unspecified skeleton */", Seq.empty(), + ImmutableMap.empty()); + } + + public Env next() { + return new Env(varId, rel + 1, "call.rel(" + (rel + 1) + ")", skeleton, statements, symbols); + } + + public Env focus(String target) { + return new Env(varId, rel, target, skeleton, statements, symbols); + } + + public Env state(String statement) { + return new Env(varId, rel, current, skeleton, statements.appended(statement), symbols); + } + + public Env symbol(String symbol, String expression) { + return new Env(varId, rel, current, skeleton, statements, symbols.putted(symbol, expression)); + } + + public Tuple2 declare(String expression) { + var name = "var_" + varId.getAndIncrement(); + return Tuple.of(name, state("var " + name + " = " + expression + ";")); + } + + public Env grow(String requirement) { + var vn = "s_" + varId.getAndIncrement(); + return new Env(varId, rel, current, vn + " -> " + vn + "." + requirement, statements, symbols); + } + } + + @Override + public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { + var sourceMatch = onMatch(env.next(), aggregate.source()); + return sourceMatch.grow("operand(LogicalAggregate.class).oneInput(" + sourceMatch.skeleton() + ")"); + } + + @Override + public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { + var sourceTransform = transform(env, aggregate.source()); + String builderWithSource = sourceTransform.current(); + String originalAgg = "((LogicalAggregate) call.rel(0))"; + var groupSetDecl = sourceTransform.declare(originalAgg + ".getGroupSet()"); + var envWithGroupSet = groupSetDecl.getValue(); + var groupKeyDecl = envWithGroupSet.declare( + builderWithSource + ".groupKey(" + groupSetDecl.getKey() + ")" + ); + var envWithGroupKey = groupKeyDecl.getValue(); + var aggCallsDecl = envWithGroupKey.declare(originalAgg + ".getAggCallList()"); + var envWithAggCalls = aggCallsDecl.getValue(); + return envWithAggCalls.focus( + builderWithSource + ".aggregate(" + + groupKeyDecl.getKey() + ", " + + aggCallsDecl.getKey() + ")" + ); + } +} diff --git a/src/main/java/org/qed/Generated/CalciteTester.java b/src/main/java/org/qed/Generated/CalciteTester.java new file mode 100644 index 0000000..4a1d6b3 --- /dev/null +++ b/src/main/java/org/qed/Generated/CalciteTester.java @@ -0,0 +1,165 @@ +package org.qed.Generated; + +import com.fasterxml.jackson.databind.ObjectMapper; +import kala.collection.Seq; +import kala.tuple.Tuple; + +import org.apache.calcite.jdbc.CalcitePrepare.SparkHandler.RuleSetBuilder; +import org.apache.calcite.plan.RelOptRule; +import org.apache.calcite.plan.hep.HepPlanner; +import org.apache.calcite.plan.hep.HepProgramBuilder; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.LogicalFilter; +import org.apache.calcite.rel.logical.LogicalValues; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.*; +import org.reflections.Reflections; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.Modifier; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Set; +import java.util.stream.Collectors; + +public class CalciteTester { + // Assuming that current working directory is the root of the project + public static String genPath = "src/main/java/org/qed/Generated"; + public static String rulePath = "rules"; + + public static HepPlanner loadRule(RelOptRule rule) { + System.out.printf("Verifying Rule: %s\n", rule.getClass()); + var builder = new HepProgramBuilder().addRuleInstance(rule); + return new HepPlanner(builder.build()); + } + + public static HepPlanner loadRule(RelOptRule rule, int matchLimit) { + System.out.printf("Verifying Rule: %s (match limit: %d)\n", rule.getClass(), matchLimit); + var builder = new HepProgramBuilder() + .addMatchLimit(matchLimit) + .addRuleInstance(rule); + return new HepPlanner(builder.build()); + } + + public static Seq ruleList() { + Reflections reflections = new Reflections("org.qed.Generated.RRuleInstances"); + + Set> ruleClasses = reflections.getSubTypesOf(RRule.class); + var concreteRuleClasses = ruleClasses.stream() + .filter(clazz -> !clazz.isInterface() && + !Modifier.isAbstract(clazz.getModifiers()) && + !clazz.getName().contains("$")) + .collect(Collectors.toSet()); + + var individuals = Seq.from(concreteRuleClasses) + .mapUnchecked(Class::getConstructor) + .mapUnchecked(Constructor::newInstance) + .map(r -> (RRule) r); + + // var families = Seq.from(reflections.getSubTypesOf(RRule.RRuleFamily.class)) + // .filter(clazz -> !clazz.isInterface() && !Modifier.isAbstract(clazz.getModifiers())) + // .mapUnchecked(clazz -> { + // Constructor constructor = clazz.getDeclaredConstructor(); + // constructor.setAccessible(true); + // return constructor.newInstance(); + // }) + // .map(r -> (RRule.RRuleFamily) r); + + // return individuals.appendedAll(families.flatMap(RRule.RRuleFamily::family)); + return individuals; + } + + public static void verify() { + ruleList().forEachUnchecked(rule -> rule.dump(rulePath + "/" + rule.name() + ".json")); + } + + public static void generate() { + var tester = new CalciteTester(); + ruleList().forEach(r -> tester.serialize(r, genPath)); + } + + public static void runAllTests() { + try { + org.qed.Generated.Tests.FilterIntoJoinTest.runTest(); + org.qed.Generated.Tests.FilterMergeTest.runTest(); + org.qed.Generated.Tests.FilterProjectTransposeTest.runTest(); + org.qed.Generated.Tests.UnionMergeTest.runTest(); + org.qed.Generated.Tests.IntersectMergeTest.runTest(); + org.qed.Generated.Tests.FilterSetOpTransposeTest.runTest(); + org.qed.Generated.Tests.JoinExtractFilterTest.runTest(); + org.qed.Generated.Tests.SemiJoinFilterTransposeTest.runTest(); + org.qed.Generated.Tests.MinusMergeTest.runTest(); + org.qed.Generated.Tests.ProjectFilterTransposeTest.runTest(); + org.qed.Generated.Tests.JoinPushTransitivePredicatesTest.runTest(); + org.qed.Generated.Tests.JoinCommuteTest.runTest(); + } catch (Exception e) { + System.out.println("Test failed: " + e.getMessage()); + e.printStackTrace(); + } + } + + public static void main(String[] args) throws IOException { + var rule = new org.qed.Generated.RRuleInstances.JoinCommute(); + System.out.println(rule.explain()); + Files.createDirectories(Path.of(rulePath)); + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // var rules = new RRuleInstance.JoinAssociate(); + // Files.createDirectories(Path.of(rulePath)); + // for (var rule : rules.family()) { + // new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(Path.of(rulePath, rule.name() + "-" + rule.info() + ".json").toFile(), rule.toJson()); + // } + generate(); + runAllTests(); + } + + public void serialize(RRule rule, String path) { + var generator = new CalciteGenerator(); + var code_gen = generator.generate(rule); + try { + Files.write(Path.of(path, rule.name() + ".java"), code_gen.getBytes()); + } catch (IOException ioe) { + System.err.println(ioe.getMessage()); + } + } + + public void test(RelOptRule rule, Seq tests) { + System.out.println("Testing rule " + rule.getClass().getSimpleName()); + var runner = loadRule(rule); + var exams = tests.mapUnchecked(t -> Tuple.of(t, JSONDeserializer.load(new File(t)))); + for (var entry : exams) { + if (entry.getValue().size() != 2) { + System.err.println(entry.getKey() + " does not have exactly two nodes, and thus is not a valid test"); + continue; + } + verify(runner, entry.getValue().get(0), entry.getValue().get(1)); + } + } + + public void verify(HepPlanner runner, RelNode source, RelNode target) { + runner.setRoot(source); + var answer = runner.findBestExp(); + + String answerExplain = answer.explain(); + String targetExplain = target.explain(); + + if(answerExplain.equals(targetExplain)) { + if(answerExplain.equals(source.explain())) + { + System.out.println("trivial"); + System.out.println("> Given source RelNode:\n" + source.explain()); + System.out.println("> Actual rewritten RelNode:\n" + answerExplain); + System.out.println("> Expected rewritten RelNode:\n" + targetExplain); + } + else System.out.println("succeeded"); + return; + } + System.out.println("failed"); + System.out.println("> Given source RelNode:\n" + source.explain()); + System.out.println("> Actual rewritten RelNode:\n" + answerExplain); + System.out.println("> Expected rewritten RelNode:\n" + targetExplain); + } +} diff --git a/src/main/java/org/qed/Generated/CalciteUtilities.java b/src/main/java/org/qed/Generated/CalciteUtilities.java new file mode 100644 index 0000000..22e545c --- /dev/null +++ b/src/main/java/org/qed/Generated/CalciteUtilities.java @@ -0,0 +1,16 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rex.RexNode; +import org.qed.RuleBuilder; + +import java.util.List; + +public record CalciteUtilities() { + public List compose(RelNode base, List inner, List outer) { + var builder = RuleBuilder.create(); + return RelOptUtil.pushPastProject(outer, (Project) builder.push(base).project(inner).build()); + } +} diff --git a/src/main/java/org/qed/Generated/EmptyConfig.java b/src/main/java/org/qed/Generated/EmptyConfig.java new file mode 100644 index 0000000..d784cfe --- /dev/null +++ b/src/main/java/org/qed/Generated/EmptyConfig.java @@ -0,0 +1,33 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.tools.RelBuilderFactory; +import org.checkerframework.checker.nullness.qual.Nullable; + +public interface EmptyConfig extends RelRule.Config { + @Override + default RelRule.Config withRelBuilderFactory(RelBuilderFactory factory) { + return this; + } + + @Override + default @Nullable String description() { + return "Unspecified Config Description"; + } + + @Override + default RelRule.Config withDescription(@Nullable String description) { + return this; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s -> s.operand(RelNode.class).anyInputs(); + } + + @Override + default RelRule.Config withOperandSupplier(RelRule.OperandTransform transform) { + return this; + } +} diff --git a/src/main/java/org/qed/Generated/FilterIntoJoin.java b/src/main/java/org/qed/Generated/FilterIntoJoin.java new file mode 100644 index 0000000..12e58ce --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterIntoJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterIntoJoin extends RelRule { + protected FilterIntoJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).push(call.rel(3)).join(JoinRelType.INNER, var_4.push(call.rel(2)).push(call.rel(3)).and(((LogicalJoin) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterIntoJoin toRule() { + return new FilterIntoJoin(this); + } + + @Override + default String description() { + return "FilterIntoJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterMerge.java b/src/main/java/org/qed/Generated/FilterMerge.java new file mode 100644 index 0000000..f943b74 --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterMerge extends RelRule { + protected FilterMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).filter(var_3.push(call.rel(2)).and(((LogicalFilter) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterMerge toRule() { + return new FilterMerge(this); + } + + @Override + default String description() { + return "FilterMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterProjectTranspose.java b/src/main/java/org/qed/Generated/FilterProjectTranspose.java new file mode 100644 index 0000000..89a31a0 --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterProjectTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterProjectTranspose extends RelRule { + protected FilterProjectTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).project(((LogicalProject) call.rel(0)).getProjects()).filter(((LogicalFilter) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterProjectTranspose toRule() { + return new FilterProjectTranspose(this); + } + + @Override + default String description() { + return "FilterProjectTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterReduceFalse.java b/src/main/java/org/qed/Generated/FilterReduceFalse.java new file mode 100644 index 0000000..e9fb926 --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterReduceFalse.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterReduceFalse extends RelRule { + protected FilterReduceFalse(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterReduceFalse toRule() { + return new FilterReduceFalse(this); + } + + @Override + default String description() { + return "FilterReduceFalse"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterReduceTrue.java b/src/main/java/org/qed/Generated/FilterReduceTrue.java new file mode 100644 index 0000000..a1c125c --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterReduceTrue.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterReduceTrue extends RelRule { + protected FilterReduceTrue(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterReduceTrue toRule() { + return new FilterReduceTrue(this); + } + + @Override + default String description() { + return "FilterReduceTrue"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/FilterSetOpTranspose.java b/src/main/java/org/qed/Generated/FilterSetOpTranspose.java new file mode 100644 index 0000000..fa3a6b4 --- /dev/null +++ b/src/main/java/org/qed/Generated/FilterSetOpTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class FilterSetOpTranspose extends RelRule { + protected FilterSetOpTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).push(call.rel(3)).filter(((LogicalFilter) call.rel(0)).getCondition()).union(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default FilterSetOpTranspose toRule() { + return new FilterSetOpTranspose(this); + } + + @Override + default String description() { + return "FilterSetOpTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/IntersectMerge.java b/src/main/java/org/qed/Generated/IntersectMerge.java new file mode 100644 index 0000000..d274657 --- /dev/null +++ b/src/main/java/org/qed/Generated/IntersectMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class IntersectMerge extends RelRule { + protected IntersectMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).intersect(false, 3).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default IntersectMerge toRule() { + return new IntersectMerge(this); + } + + @Override + default String description() { + return "IntersectMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalIntersect.class).inputs(s_2 -> s_2.operand(LogicalIntersect.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/JoinAddRedundantSemiJoin.java b/src/main/java/org/qed/Generated/JoinAddRedundantSemiJoin.java new file mode 100644 index 0000000..dcc5c20 --- /dev/null +++ b/src/main/java/org/qed/Generated/JoinAddRedundantSemiJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class JoinAddRedundantSemiJoin extends RelRule { + protected JoinAddRedundantSemiJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).push(call.rel(2)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(0)).getCondition()).push(call.rel(2)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(0)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinAddRedundantSemiJoin toRule() { + return new JoinAddRedundantSemiJoin(this); + } + + @Override + default String description() { + return "JoinAddRedundantSemiJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/JoinCommute.java b/src/main/java/org/qed/Generated/JoinCommute.java new file mode 100644 index 0000000..147fd52 --- /dev/null +++ b/src/main/java/org/qed/Generated/JoinCommute.java @@ -0,0 +1,53 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class JoinCommute extends RelRule { + protected JoinCommute(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + var var_4 = (LogicalJoin) call.rel(0); + var var_5 = (org.apache.calcite.rex.RexCall) var_4.getCondition(); + var var_6 = ((org.apache.calcite.rex.RexInputRef) var_5.getOperands().get(0)).getIndex(); + var var_7 = (LogicalJoin) call.rel(0); + var var_8 = (org.apache.calcite.rex.RexCall) var_7.getCondition(); + var var_9 = ((org.apache.calcite.rex.RexInputRef) var_8.getOperands().get(1)).getIndex(); + var var_10 = call.rel(1).getRowType().getFieldCount(); + var var_11 = var_9 - var_10; + var var_12 = call.rel(1); + var var_13 = call.rel(2); + var var_14 = var_12.getRowType().getFieldCount(); + var var_15 = var_13.getRowType().getFieldCount(); + var var_16 = java.util.stream.IntStream.concat(java.util.stream.IntStream.range(var_15, var_15 + var_14), java.util.stream.IntStream.range(0, var_15)).boxed().collect(java.util.stream.Collectors.toList()); + var var_17 = var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).fields(var_16); + call.transformTo(var_3.push(call.rel(2)).push(call.rel(1)).join(JoinRelType.INNER, var_3.push(call.rel(2)).push(call.rel(1)).call(((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator(), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 1, var_6), var_3.push(call.rel(2)).push(call.rel(1)).field(2, 0, var_11))).project(var_17).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinCommute toRule() { + return new JoinCommute(this); + } + + @Override + default String description() { + return "JoinCommute"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/JoinExtractFilter.java b/src/main/java/org/qed/Generated/JoinExtractFilter.java new file mode 100644 index 0000000..132dc5e --- /dev/null +++ b/src/main/java/org/qed/Generated/JoinExtractFilter.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class JoinExtractFilter extends RelRule { + protected JoinExtractFilter(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).push(call.rel(2)).join(JoinRelType.INNER, var_3.push(call.rel(1)).push(call.rel(2)).literal(true)).filter(((LogicalJoin) call.rel(0)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinExtractFilter toRule() { + return new JoinExtractFilter(this); + } + + @Override + default String description() { + return "JoinExtractFilter"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/JoinPushTransitivePredicates.java b/src/main/java/org/qed/Generated/JoinPushTransitivePredicates.java new file mode 100644 index 0000000..50c24a5 --- /dev/null +++ b/src/main/java/org/qed/Generated/JoinPushTransitivePredicates.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class JoinPushTransitivePredicates extends RelRule { + protected JoinPushTransitivePredicates(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).push(call.rel(3)).join(JoinRelType.INNER, var_4.push(call.rel(2)).push(call.rel(3)).and(((LogicalJoin) call.rel(1)).getCondition(), ((LogicalFilter) call.rel(0)).getCondition())).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default JoinPushTransitivePredicates toRule() { + return new JoinPushTransitivePredicates(this); + } + + @Override + default String description() { + return "JoinPushTransitivePredicates"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/MinusMerge.java b/src/main/java/org/qed/Generated/MinusMerge.java new file mode 100644 index 0000000..a519ec2 --- /dev/null +++ b/src/main/java/org/qed/Generated/MinusMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class MinusMerge extends RelRule { + protected MinusMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).union(false, 2).minus(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default MinusMerge toRule() { + return new MinusMerge(this); + } + + @Override + default String description() { + return "MinusMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalMinus.class).inputs(s_2 -> s_2.operand(LogicalMinus.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/ProjectFilterTranspose.java b/src/main/java/org/qed/Generated/ProjectFilterTranspose.java new file mode 100644 index 0000000..cbbd7d8 --- /dev/null +++ b/src/main/java/org/qed/Generated/ProjectFilterTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class ProjectFilterTranspose extends RelRule { + protected ProjectFilterTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).project(((LogicalProject) call.rel(1)).getProjects()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default ProjectFilterTranspose toRule() { + return new ProjectFilterTranspose(this); + } + + @Override + default String description() { + return "ProjectFilterTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalFilter.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/ProjectMerge.java b/src/main/java/org/qed/Generated/ProjectMerge.java new file mode 100644 index 0000000..0a4e458 --- /dev/null +++ b/src/main/java/org/qed/Generated/ProjectMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class ProjectMerge extends RelRule { + protected ProjectMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).project(((LogicalProject) call.rel(0)).getProjects()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default ProjectMerge toRule() { + return new ProjectMerge(this); + } + + @Override + default String description() { + return "ProjectMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalProject.class).oneInput(s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyFilter.java b/src/main/java/org/qed/Generated/PruneEmptyFilter.java new file mode 100644 index 0000000..bcb125e --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyFilter.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyFilter extends RelRule { + protected PruneEmptyFilter(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_2 = call.builder(); + call.transformTo(var_2.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyFilter toRule() { + return new PruneEmptyFilter(this); + } + + @Override + default String description() { + return "PruneEmptyFilter"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalFilter.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyIntersect.java b/src/main/java/org/qed/Generated/PruneEmptyIntersect.java new file mode 100644 index 0000000..ddfcba3 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyIntersect.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyIntersect extends RelRule { + protected PruneEmptyIntersect(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().empty().intersect(false, 2).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyIntersect toRule() { + return new PruneEmptyIntersect(this); + } + + @Override + default String description() { + return "PruneEmptyIntersect"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalIntersect.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyMinus.java b/src/main/java/org/qed/Generated/PruneEmptyMinus.java new file mode 100644 index 0000000..41bffd1 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyMinus.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyMinus extends RelRule { + protected PruneEmptyMinus(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyMinus toRule() { + return new PruneEmptyMinus(this); + } + + @Override + default String description() { + return "PruneEmptyMinus"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalMinus.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyProject.java b/src/main/java/org/qed/Generated/PruneEmptyProject.java new file mode 100644 index 0000000..2241478 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyProject.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyProject extends RelRule { + protected PruneEmptyProject(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_2 = call.builder(); + call.transformTo(var_2.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyProject toRule() { + return new PruneEmptyProject(this); + } + + @Override + default String description() { + return "PruneEmptyProject"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_1 -> s_1.operand(LogicalProject.class).oneInput(s_0 -> s_0.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneEmptyUnion.java b/src/main/java/org/qed/Generated/PruneEmptyUnion.java new file mode 100644 index 0000000..df48a7d --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneEmptyUnion.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneEmptyUnion extends RelRule { + protected PruneEmptyUnion(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.empty().build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneEmptyUnion toRule() { + return new PruneEmptyUnion(this); + } + + @Override + default String description() { + return "PruneEmptyUnion"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java new file mode 100644 index 0000000..95774f0 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneLeftEmptyJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneLeftEmptyJoin extends RelRule { + protected PruneLeftEmptyJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(2)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneLeftEmptyJoin toRule() { + return new PruneLeftEmptyJoin(this); + } + + @Override + default String description() { + return "PruneLeftEmptyJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(LogicalValues.class).noInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java new file mode 100644 index 0000000..5d4e1c3 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneRightEmptyJoin.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneRightEmptyJoin extends RelRule { + protected PruneRightEmptyJoin(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_3 = call.builder(); + call.transformTo(var_3.push(call.rel(1)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneRightEmptyJoin toRule() { + return new PruneRightEmptyJoin(this); + } + + @Override + default String description() { + return "PruneRightEmptyJoin"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(LogicalValues.class).noInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/PruneZeroRowsTable.java b/src/main/java/org/qed/Generated/PruneZeroRowsTable.java new file mode 100644 index 0000000..af02cf5 --- /dev/null +++ b/src/main/java/org/qed/Generated/PruneZeroRowsTable.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class PruneZeroRowsTable extends RelRule { + protected PruneZeroRowsTable(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_1 = call.builder(); + call.transformTo(var_1.push(call.rel(0)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default PruneZeroRowsTable toRule() { + return new PruneZeroRowsTable(this); + } + + @Override + default String description() { + return "PruneZeroRowsTable"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_0 -> s_0.operand(RelNode.class).anyInputs(); + } + + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinAssociate.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinAssociate.java new file mode 100644 index 0000000..059695e --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/JoinAssociate.java @@ -0,0 +1,69 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinAssociate() implements RRule.RRuleFamily { + static final RelRN a = RelRN.scan("A", "A_Type"); + static final RelRN b = RelRN.scan("B", "B_Type"); + static final RelRN c = RelRN.scan("C", "C_Type"); + static final String pred_ab = "pred_ab"; + static final String pred_bc = "pred_bc"; + static final RelRN.Join.JoinType.MetaJoinType mjt_0 = new RelRN.Join.JoinType.MetaJoinType("mjt_0"); + static final RelRN.Join.JoinType.MetaJoinType mjt_1 = new RelRN.Join.JoinType.MetaJoinType("mjt_1"); + static final RelRN.Join.JoinType.MetaJoinType mjt_2 = new RelRN.Join.JoinType.MetaJoinType("mjt_2"); + static final RelRN.Join.JoinType.MetaJoinType mjt_3 = new RelRN.Join.JoinType.MetaJoinType("mjt_3"); + + static final RelRN before_ab = a.join(mjt_0, RexRN.and( + a.joinPred(pred_ab, b), + new RexRN.JoinField(1, a, b).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), b); + + static final RelRN before = before_ab.join(mjt_1, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_bc, true), before_ab.joinFields(c, 1, 2)), + new RexRN.JoinField(1, before_ab, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after_bc = b.join(mjt_2, RexRN.and( + b.joinPred(pred_bc, c), + new RexRN.JoinField(0, b, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after = a.join(mjt_3, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_ab, true), a.joinFields(after_bc, 0, 1)), + new RexRN.JoinField(1, a, after_bc).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), after_bc); + + static final RRule template = new RRule() { + @Override + public RelRN before() { + return before; + } + + @Override + public RelRN after() { + return after; + } + + @Override + public String name() { + return JoinAssociate.class.getSimpleName(); + } + }; + + static Seq assignments() { + var joinTypes = Seq.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL).map(RelRN.Join.JoinType.ConcreteJoinType::new); + return joinTypes.flatMap(jt0 -> joinTypes.flatMap(jt1 -> joinTypes.flatMap(jt2 -> joinTypes.map(jt3 -> new RRule.RRuleGenerator.MetaAssignment(Map.of(mjt_0, jt0, mjt_1, jt1, mjt_2, jt2, mjt_3, jt3)))))); + } + + @Override + public Seq family() { + return new RRule.RRuleGenerator(template, assignments()).family(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneLeftEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneLeftEmptyJoin.java new file mode 100644 index 0000000..fbe5979 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneLeftEmptyJoin.java @@ -0,0 +1,21 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneLeftEmptyJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.empty().join(JoinRelType.RIGHT, "pred", right); + } + + @Override + public RelRN after() { + return right; + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneRightEmptyJoin.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneRightEmptyJoin.java new file mode 100644 index 0000000..4b4b5c0 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/PruneRightEmptyJoin.java @@ -0,0 +1,22 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneRightEmptyJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right.empty()); + + @Override + public RelRN before() { + return left.join(JoinRelType.LEFT, "pred", right.empty()); + } + + @Override + public RelRN after() { + return left; + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinJoinTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinJoinTranspose.java new file mode 100644 index 0000000..b1a4a3d --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinJoinTranspose.java @@ -0,0 +1,25 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + + +public record SemiJoinJoinTranspose() implements RRule { + static final RelRN left = RelRN.scan("left", "left_Type"); + static final RelRN middle = RelRN.scan("middle", "middle_Type"); + static final RelRN right = RelRN.scan("right", "right_Type"); + static final RexRN semiCond = left.joinPred("semi", middle); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right).join(JoinRelType.SEMI, semiCond, middle); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, semiCond, middle).join(JoinRelType.INNER, joinCond, right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinProjectTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinProjectTranspose.java new file mode 100644 index 0000000..71d2dfe --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinProjectTranspose.java @@ -0,0 +1,23 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record SemiJoinProjectTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "left_type"); + static final RelRN right = RelRN.scan("Right", "right_type"); + static final RexRN proj = left.proj("proj", "proj_type"); + static final RexRN semiCond = left.joinPred("semi", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, semiCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.SEMI, semiCond, right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinRemove.java b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinRemove.java new file mode 100644 index 0000000..3b21e3f --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances-unprovable/SemiJoinRemove.java @@ -0,0 +1,21 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record SemiJoinRemove() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, RexRN.trueLiteral(), right); + } + + @Override + public RelRN after() { + return left; + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java new file mode 100644 index 0000000..ab35583 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java @@ -0,0 +1,30 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RelType; + +public record AggregateFilterTranspose() implements RRule { + static final RelRN source_col1 = RelRN.scan("Source1", "Int_Type"); + static final RelRN source_col2 = RelRN.scan("Source2", "Int_Type"); + static final RelRN source = source_col1.join(JoinRelType.INNER, RexRN.trueLiteral(), source_col2); + static final RexRN filterCondition = source.field(0).pred("cond"); + + @Override + public RelRN before() { + RelRN filteredSource = source.filter(filterCondition); + return new RelRN.Aggregate( + filteredSource, + Seq.of(filteredSource.field(0)), + Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(filteredSource.field(1)))) + ); + } + @Override + public RelRN after() { + RelRN.Aggregate aggregateOnSource = new RelRN.Aggregate(source, Seq.of(source.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))))); + return aggregateOnSource.filter(aggregateOnSource.field(0).pred("cond")); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java new file mode 100644 index 0000000..48c92d9 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java @@ -0,0 +1,29 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RelType; + +public record FilterAggregateTranspose() implements RRule { + static final RelRN source_col1 = RelRN.scan("Source1", "Int_Type"); + static final RelRN source_col2 = RelRN.scan("Source2", "Int_Type"); + static final RelRN source = source_col1.join(JoinRelType.INNER, RexRN.trueLiteral(), source_col2); + static final RelRN.Aggregate aggregate = new RelRN.Aggregate(source, Seq.of(source.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))))); + static final RexRN pushedCondition = aggregate.field(0).pred("pushed_cond"); + static final RexRN remainingCondition = aggregate.field(1).pred("remaining_cond"); + + @Override + public RelRN before() { + return aggregate.filter(RexRN.and(pushedCondition, remainingCondition)); + } + + @Override + public RelRN after() { + RelRN filteredSource = source.filter(source.field(0).pred("pushed_cond")); + RelRN.Aggregate newAggregate = new RelRN.Aggregate(filteredSource, Seq.of(filteredSource.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(filteredSource.field(1))))); + return newAggregate.filter(newAggregate.field(1).pred("remaining_cond")); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java new file mode 100644 index 0000000..d16f8ef --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterIntoJoin.java @@ -0,0 +1,27 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterIntoJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + var join = left.join(JoinRelType.INNER, joinCond, right); + return join.filter("outer"); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterMerge.java new file mode 100644 index 0000000..d90aa6d --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterMerge.java @@ -0,0 +1,26 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.pred("inner"); + static final RexRN outer = source.pred("outer"); + + @Override + public RelRN before() { + return source.filter(inner).filter(outer); + } + + @Override + public RelRN after() { + return source.filter(RexRN.and(inner, outer)); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterProjectTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterProjectTranspose.java new file mode 100644 index 0000000..31a3a74 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterProjectTranspose.java @@ -0,0 +1,25 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterProjectTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.filter(proj.pred("pred")).project(proj); + } + + @Override + public RelRN after() { + return source.project(proj).filter("pred"); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceFalse.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceFalse.java new file mode 100644 index 0000000..8b4dd4b --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceFalse.java @@ -0,0 +1,24 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterReduceFalse() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.falseLiteral()); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceTrue.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceTrue.java new file mode 100644 index 0000000..c3dd472 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterReduceTrue.java @@ -0,0 +1,24 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterReduceTrue() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.trueLiteral()); + } + + @Override + public RelRN after() { + return source; + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterSetOpTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterSetOpTranspose.java new file mode 100644 index 0000000..2dcf9e2 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterSetOpTranspose.java @@ -0,0 +1,28 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record FilterSetOpTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Common_Type"); + static final RelRN right = RelRN.scan("Right", "Common_Type"); + + @Override + public RelRN before() { + RelRN projTmp = left.union(false, right); + return projTmp.filter(projTmp.pred("filter")); + } + + @Override + public RelRN after() { + RexRN leftPred = left.pred("filter"); + RexRN rightPred = right.pred("filter"); + return left.filter(leftPred).union(false, right.filter(rightPred)); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/IntersectMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/IntersectMerge.java new file mode 100644 index 0000000..27afe67 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/IntersectMerge.java @@ -0,0 +1,29 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record IntersectMerge() implements RRule { + // Use a common type for all relations to make them compatible + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + // Nested INTERSECT: (A INTERSECT B) INTERSECT C + return a.intersect(false, b).intersect(false, c); + } + + @Override + public RelRN after() { + // Flattened INTERSECT: A INTERSECT B INTERSECT C + return a.intersect(false, b, c); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinAddRedundantSemiJoin.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinAddRedundantSemiJoin.java new file mode 100644 index 0000000..ec54c50 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinAddRedundantSemiJoin.java @@ -0,0 +1,26 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinAddRedundantSemiJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, pred, right).join(JoinRelType.INNER, pred, right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java new file mode 100644 index 0000000..d1580be --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinCommute.java @@ -0,0 +1,50 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + SqlOperator predOp = RuleBuilder.create().genericPredicateOp(pred, true); + RexRN swappedPred = new RexRN.Pred(predOp, Seq.of( + new RexRN.JoinField(1, right, left), + new RexRN.JoinField(0, right, left) + )); + RelRN swappedJoin = right.join(JoinRelType.INNER, swappedPred, left); + + return new ProjectionRelRN(swappedJoin); + } + + // implementation for the column reordering projection + public static record ProjectionRelRN(RelRN source) implements RelRN { + @Override + public RelNode semantics() { + RuleBuilder builder = RuleBuilder.create(); + builder.push(source.semantics()); + + RexNode leftField = builder.field(1); // Left columns (now at position 1) + RexNode rightField = builder.field(0); // Right columns (now at position 0) + + builder.project(leftField, rightField); + + return builder.build(); + } + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinExtractFilter.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinExtractFilter.java new file mode 100644 index 0000000..63db3ec --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinExtractFilter.java @@ -0,0 +1,26 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record JoinExtractFilter() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.trueLiteral(), right).filter(joinCond); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/JoinPushTransitivePredicates.java b/src/main/java/org/qed/Generated/RRuleInstances/JoinPushTransitivePredicates.java new file mode 100644 index 0000000..1aaa5d9 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/JoinPushTransitivePredicates.java @@ -0,0 +1,23 @@ +package org.qed.Generated.RRuleInstances; + +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record JoinPushTransitivePredicates() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN cond1 = left.joinPred("cond1", right); + static final RexRN cond2 = left.joinPred("cond2", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, cond1, right).filter(cond2); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(cond1, cond2), right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/MinusMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/MinusMerge.java new file mode 100644 index 0000000..a080685 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/MinusMerge.java @@ -0,0 +1,26 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record MinusMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + return a.minus(false, b).minus(false, c); + } + + @Override + public RelRN after() { + return a.minus(false, b.union(false, c)); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/ProjectAggregateMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/ProjectAggregateMerge.java new file mode 100644 index 0000000..55bbfb9 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/ProjectAggregateMerge.java @@ -0,0 +1,42 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RelType; + +public record ProjectAggregateMerge() implements RRule { + // Create a two-column source relation using a join, as the aggregate needs multiple columns. + static final RelRN source_col1 = RelRN.scan("SourceCol1", "Int_Type"); + static final RelRN source_col2 = RelRN.scan("SourceCol2", "Int_Type"); + static final RelRN source = source_col1.join(JoinRelType.INNER, RexRN.trueLiteral(), source_col2); + + static final RelRN.Aggregate aggregate = new RelRN.Aggregate( + source, + Seq.of(source.field(0)), // group by key (from source_col1) + Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1)))) // sum(field from source_col2) + ); + + @Override + public RelRN before() { + // Project on top of the aggregate, selecting only the SUM result. + return aggregate.project(aggregate.field(1)); + } + + @Override + public RelRN after() { + // The 'after' state represents the merged operation. + // This would be a single aggregate operator that produces the projected result directly. + // Since the projection removes the group key, the ideal 'after' would be an aggregate + // that computes the sum for each group but does not output the group key. + // As a placeholder for this complex transformation, we create an aggregate + // with an empty group set, which results in a single row output (total sum). + return new RelRN.Aggregate( + source, + Seq.empty(), + Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1)))) + ); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/ProjectFilterTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/ProjectFilterTranspose.java new file mode 100644 index 0000000..8a17d2c --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/ProjectFilterTranspose.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; + +public record ProjectFilterTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.project(proj).filter("pred"); + } + + @Override + public RelRN after() { + return source.filter(proj.pred("pred")).project(proj); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/ProjectMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/ProjectMerge.java new file mode 100644 index 0000000..7458002 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/ProjectMerge.java @@ -0,0 +1,27 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record ProjectMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.proj("inner", "Inner_Type"); + static final String outer = "outer"; + static final String outerType = "Outer_Type"; + + @Override + public RelRN before() { + return source.project(inner).project(outer, outerType); + } + + @Override + public RelRN after() { + return source.project(inner.proj(outer, outerType)); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyFilter.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyFilter.java new file mode 100644 index 0000000..a0c5eaf --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyFilter.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyFilter() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN cond = source.pred("filter_cond"); + + @Override + public RelRN before() { + return source.empty().filter(cond); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyIntersect.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyIntersect.java new file mode 100644 index 0000000..fb4938f --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyIntersect.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyIntersect() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.intersect(false, b.empty()); + } + + @Override + public RelRN after() { + return a.empty().intersect(false, b.empty()); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyMinus.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyMinus.java new file mode 100644 index 0000000..09c424c --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyMinus.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyMinus() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.empty().minus(false, b); + } + + @Override + public RelRN after() { + return a.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyProject.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyProject.java new file mode 100644 index 0000000..0aac108 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyProject.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyProject() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.empty().project(proj); + } + + @Override + public RelRN after() { + return source.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java new file mode 100644 index 0000000..74d3fc5 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneEmptyUnion.java @@ -0,0 +1,20 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneEmptyUnion() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + + @Override + public RelRN before() { + return a.empty().union(false, b.empty()); + } + + @Override + public RelRN after() { + return a.empty(); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/PruneZeroRowsTable.java b/src/main/java/org/qed/Generated/RRuleInstances/PruneZeroRowsTable.java new file mode 100644 index 0000000..393b226 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/PruneZeroRowsTable.java @@ -0,0 +1,19 @@ +package org.qed.Generated.RRuleInstances; + +import org.qed.RRule; +import org.qed.RelRN; +import org.qed.RexRN; + +public record PruneZeroRowsTable() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + + @Override + public RelRN before() { + return a; + } + + @Override + public RelRN after() { + return a; + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java new file mode 100644 index 0000000..790c1d2 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/SemiJoinFilterTranspose.java @@ -0,0 +1,28 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record SemiJoinFilterTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + static final RexRN filterPred = left.pred("filter"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, joinCond, right).filter(filterPred); + } + + @Override + public RelRN after() { + RelRN leftFiltered = left.filter(filterPred); + return leftFiltered.join(JoinRelType.SEMI, leftFiltered.joinPred("join", right), right); + } +} diff --git a/src/main/java/org/qed/Generated/RRuleInstances/UnionMerge.java b/src/main/java/org/qed/Generated/RRuleInstances/UnionMerge.java new file mode 100644 index 0000000..b71d320 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/UnionMerge.java @@ -0,0 +1,26 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RuleBuilder; + +public record UnionMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + return a.union(false, b).union(false, c); + } + + @Override + public RelRN after() { + return a.union(false, b, c); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/SemiJoinFilterTranspose.java b/src/main/java/org/qed/Generated/SemiJoinFilterTranspose.java new file mode 100644 index 0000000..cc175ad --- /dev/null +++ b/src/main/java/org/qed/Generated/SemiJoinFilterTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class SemiJoinFilterTranspose extends RelRule { + protected SemiJoinFilterTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).filter(((LogicalFilter) call.rel(0)).getCondition()).push(call.rel(3)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default SemiJoinFilterTranspose toRule() { + return new SemiJoinFilterTranspose(this); + } + + @Override + default String description() { + return "SemiJoinFilterTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalFilter.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/SemiJoinJoinTranspose.java b/src/main/java/org/qed/Generated/SemiJoinJoinTranspose.java new file mode 100644 index 0000000..33cc330 --- /dev/null +++ b/src/main/java/org/qed/Generated/SemiJoinJoinTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class SemiJoinJoinTranspose extends RelRule { + protected SemiJoinJoinTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(4)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(0)).getCondition()).push(call.rel(3)).join(JoinRelType.INNER, ((LogicalJoin) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default SemiJoinJoinTranspose toRule() { + return new SemiJoinJoinTranspose(this); + } + + @Override + default String description() { + return "SemiJoinJoinTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalJoin.class).inputs(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/SemiJoinProjectTranspose.java b/src/main/java/org/qed/Generated/SemiJoinProjectTranspose.java new file mode 100644 index 0000000..8ffffca --- /dev/null +++ b/src/main/java/org/qed/Generated/SemiJoinProjectTranspose.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class SemiJoinProjectTranspose extends RelRule { + protected SemiJoinProjectTranspose(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(2)).project(((LogicalProject) call.rel(0)).getProjects()).push(call.rel(3)).join(JoinRelType.SEMI, ((LogicalJoin) call.rel(1)).getCondition()).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default SemiJoinProjectTranspose toRule() { + return new SemiJoinProjectTranspose(this); + } + + @Override + default String description() { + return "SemiJoinProjectTranspose"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_3 -> s_3.operand(LogicalProject.class).oneInput(s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs())); + } + + } +} diff --git a/src/main/java/org/qed/Generated/SemiJoinRemove.java b/src/main/java/org/qed/Generated/SemiJoinRemove.java new file mode 100644 index 0000000..8d02d36 --- /dev/null +++ b/src/main/java/org/qed/Generated/SemiJoinRemove.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class SemiJoinRemove extends RelRule { + protected SemiJoinRemove(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_4 = call.builder(); + call.transformTo(var_4.push(call.rel(1)).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default SemiJoinRemove toRule() { + return new SemiJoinRemove(this); + } + + @Override + default String description() { + return "SemiJoinRemove"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_2 -> s_2.operand(LogicalJoin.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyFilterTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyFilterTest.java new file mode 100644 index 0000000..dde4f6e --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyFilterTest.java @@ -0,0 +1,46 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneEmptyFilterTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .filter( + builder.call( + builder.genericPredicateOp("filter_cond", true), + builder.fields() + ) + ) + .empty() + .build(); + + var after = builder + .scan(table.getName()) + .empty() + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Generated.PruneEmptyFilter.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyFilter test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyMinusTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyMinusTest.java new file mode 100644 index 0000000..d24d6ad --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyMinusTest.java @@ -0,0 +1,51 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; +import org.apache.calcite.rel.RelNode; + +public class PruneEmptyMinusTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + RelNode scanA = builder + .scan(table.getName()) + .build(); + + RelNode scanB = builder + .scan(table.getName()) + .build(); + + RelNode before = builder + .push(scanA) + .push(scanB) + .minus(false) + .empty() + .build(); + + RelNode after = builder + .push(scanA) + .empty() + .build(); + +// var runner = CalciteTester.loadRule( +// org.qed.Generated.PruneEmptyMinus.Config.DEFAULT.toRule() +// ); +// tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyMinus test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyProjectTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyProjectTest.java new file mode 100644 index 0000000..547da2c --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyProjectTest.java @@ -0,0 +1,41 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneEmptyProjectTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .empty() + .project(builder.field(0)) + .build(); + + var after = builder + .scan(table.getName()) + .empty() + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Generated.PruneEmptyProject.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyProject test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java new file mode 100644 index 0000000..dcbbc9d --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneEmptyUnionTest.java @@ -0,0 +1,62 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; +import org.apache.calcite.rel.RelNode; + +public class PruneEmptyUnionTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + RelNode scanA = builder + .scan(table.getName()) + .build(); + + RelNode emptyA = builder + .push(scanA) + .empty() + .build(); + + RelNode scanB = builder + .scan(table.getName()) + .build(); + + RelNode emptyB = builder + .push(scanB) + .empty() + .build(); + + RelNode before = builder + .push(scanA) + .push(scanB) + .union(false) + .empty() + .build(); + + RelNode after = builder + .push(emptyA) + .push(emptyB) + .union(false) + .build(); + + var runner = CalciteTester.loadRule( + org.qed.Generated.PruneEmptyUnion.Config.DEFAULT.toRule() + ); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneEmptyMinus test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/PruneZeroRowsTableTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/PruneZeroRowsTableTest.java new file mode 100644 index 0000000..b7a91ed --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/PruneZeroRowsTableTest.java @@ -0,0 +1,40 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class PruneZeroRowsTableTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var table = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false)) + ); + builder.addTable(table); + + var before = builder + .scan(table.getName()) + .empty() + .build(); + + var after = builder + .scan(table.getName()) + .empty() + .build(); + +// var runner = CalciteTester.loadRule( +// org.qed.Generated.PruneZeroRowsTable.Config.DEFAULT.toRule() +// ); +// tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running PruneZeroRowsTableTest..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests-Trivial/SemiJoinProjectTransposeTest.java b/src/main/java/org/qed/Generated/Tests-Trivial/SemiJoinProjectTransposeTest.java new file mode 100644 index 0000000..47d0755 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests-Trivial/SemiJoinProjectTransposeTest.java @@ -0,0 +1,55 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class SemiJoinProjectTransposeTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var leftTable = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), /*nullable?*/ false))); + var rightTable = builder.createQedTable( + Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + + builder.addTable(leftTable); + builder.addTable(rightTable); + + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + + builder.push(leftScan); + builder.push(rightScan); + var joinPred = builder.equals( + builder.field(2, 0, 0), // left field 0 + builder.field(2, 1, 0)); // right field 0 + var semiJoin = builder.join(JoinRelType.SEMI, joinPred).build(); + builder.push(semiJoin); + + var before = builder.project(builder.field(0)).build(); + + builder.push(leftScan); + var projectedLeft = builder.project(builder.field(0)).build(); + + builder.push(projectedLeft); + builder.push(rightScan); + var afterJoinPred = builder.equals( + builder.field(2, 0, 0), + builder.field(2, 1, 0)); + var after = builder.join(JoinRelType.SEMI, afterJoinPred).build(); + + var runner = CalciteTester.loadRule( + org.qed.Generated.SemiJoinProjectTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running SemiJoinProjectTranspose test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java b/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java new file mode 100644 index 0000000..f8fef01 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/FilterIntoJoinTest.java @@ -0,0 +1,42 @@ + +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.FilterIntoJoin; +import org.qed.RuleBuilder; + +public class FilterIntoJoinTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(table); + + var before = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), builder.joinFields())) + .filter(builder.call(builder.genericPredicateOp("pred", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("join", true), builder.joinFields()), + builder.call(builder.genericPredicateOp("pred", true), builder.joinFields()))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.FilterIntoJoin.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterIntoJoin test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java b/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java new file mode 100644 index 0000000..0e40d00 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/FilterMergeTest.java @@ -0,0 +1,36 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class FilterMergeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(table); + + var before = builder.scan(table.getName()) + .filter(builder.call(builder.genericPredicateOp("inner", true), builder.fields())) + .filter(builder.call(builder.genericPredicateOp("outer", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()).filter(builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("inner", true), builder.fields()), + builder.call(builder.genericPredicateOp("outer", true), builder.fields()))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.FilterMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java b/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java new file mode 100644 index 0000000..0131765 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/FilterProjectTransposeTest.java @@ -0,0 +1,42 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class FilterProjectTransposeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan = builder.scan(table.getName()).build(); + + var before = builder + .push(scan) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .project(builder.field(0)) + .build(); + + var after = builder + .push(scan) + .project(builder.field(0)) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.FilterProjectTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterProjectTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java b/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java new file mode 100644 index 0000000..488b474 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/FilterSetOpTransposeTest.java @@ -0,0 +1,37 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class FilterSetOpTransposeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + + var union = builder.push(scan1).push(scan2).union(false).build(); + var before = builder.push(union).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + + var filteredScan1 = builder.push(scan1).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + var filteredScan2 = builder.push(scan2).filter(builder.call(builder.genericPredicateOp("filter", true), builder.fields())).build(); + var after = builder.push(filteredScan1).push(filteredScan2).union(false).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.FilterSetOpTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running FilterSetOpTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java b/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java new file mode 100644 index 0000000..ff4093d --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/IntersectMergeTest.java @@ -0,0 +1,36 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class IntersectMergeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + var firstIntersect = builder.push(scan1).push(scan2).intersect(false).build(); + var before = builder.push(firstIntersect).push(scan3).intersect(false).build(); + + var after = builder.push(scan1).push(scan2).push(scan3).intersect(false, 3).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.IntersectMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running IntersectMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java b/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java new file mode 100644 index 0000000..c6d0203 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/JoinCommuteTest.java @@ -0,0 +1,84 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinCommuteTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + // Create EMP table (8 columns like the real Calcite test) + var empTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // EMPNO + Tuple.of(RelType.fromString("VARCHAR", true), false), // ENAME + Tuple.of(RelType.fromString("VARCHAR", true), false), // JOB + Tuple.of(RelType.fromString("INTEGER", true), false), // MGR + Tuple.of(RelType.fromString("DATE", true), false), // HIREDATE + Tuple.of(RelType.fromString("DECIMAL", true), false), // SAL + Tuple.of(RelType.fromString("DECIMAL", true), false), // COMM + Tuple.of(RelType.fromString("INTEGER", true), false) // DEPTNO + )); + builder.addTable(empTable); + + // Create DEPT table (3 columns like the real Calcite test) + var deptTable = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), // DEPTNO + Tuple.of(RelType.fromString("VARCHAR", true), false), // DNAME + Tuple.of(RelType.fromString("VARCHAR", true), false) // LOC + )); + builder.addTable(deptTable); + + // Build the "before" pattern: EMP JOIN DEPT ON EMP.DEPTNO = DEPT.DEPTNO + var empScan = builder.scan(empTable.getName()).build(); + var deptScan = builder.scan(deptTable.getName()).build(); + + var before = builder + .push(empScan) + .push(deptScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("equals", true), + builder.field(2, 0, 7), // EMP.DEPTNO (8th column, index 7) + builder.field(2, 1, 0) // DEPT.DEPTNO (1st column, index 0) + )) + .build(); + + // Build the "after" pattern: DEPT JOIN EMP ON DEPT.DEPTNO = EMP.DEPTNO + // followed by projection to restore original column order + var after = builder + .push(deptScan) + .push(empScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("equals", true), + builder.field(2, 1, 7), // EMP.DEPTNO (now at input 1, field 7) + builder.field(2, 0, 0) // DEPT.DEPTNO (now at input 0, field 0) + )) + .project( + // EMP columns first (now at positions 3-10 in the swapped join) + builder.field(3), // EMPNO + builder.field(4), // ENAME + builder.field(5), // JOB + builder.field(6), // MGR + builder.field(7), // HIREDATE + builder.field(8), // SAL + builder.field(9), // COMM + builder.field(10), // DEPTNO + // DEPT columns second (now at positions 0-2 in the swapped join) + builder.field(0), // DEPTNO0 + builder.field(1), // DNAME + builder.field(2) // LOC + ) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.JoinCommute.Config.DEFAULT.toRule(), 1); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinCommute test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java b/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java new file mode 100644 index 0000000..db2138a --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/JoinExtractFilterTest.java @@ -0,0 +1,46 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinExtractFilterTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("VARCHAR", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + + var before = builder.push(leftScan) + .push(rightScan) + .join(JoinRelType.INNER, builder.call(builder.genericPredicateOp("join", true), + builder.field(2, 0, 0), builder.field(2, 1, 0))) + .build(); + + var trueJoin = builder.push(leftScan) + .push(rightScan) + .join(JoinRelType.INNER, builder.literal(true)) + .build(); + + var after = builder.push(trueJoin) + .filter(builder.call(builder.genericPredicateOp("join", true), builder.field(0), builder.field(1))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.JoinExtractFilter.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running JoinExtractFilter test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/JoinPushTransitivePredicatesTest.java b/src/main/java/org/qed/Generated/Tests/JoinPushTransitivePredicatesTest.java new file mode 100644 index 0000000..3da47c7 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/JoinPushTransitivePredicatesTest.java @@ -0,0 +1,42 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class JoinPushTransitivePredicatesTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var before = builder.scan(table.getName()) + .scan(table.getName()).join(JoinRelType.INNER,builder.call(builder.genericPredicateOp("cond1", true), builder.joinFields())) + .filter(builder.call(builder.genericPredicateOp("cond2", true), builder.fields())) + .build(); + + var after = builder.scan(table.getName()) + .scan(table.getName()) + .join(JoinRelType.INNER, builder.call(SqlStdOperatorTable.AND, + builder.call(builder.genericPredicateOp("cond1", true), builder.joinFields()), + builder.call(builder.genericPredicateOp("cond2", true), builder.joinFields()))) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.JoinPushTransitivePredicates.Config.DEFAULT.toRule()); + + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running ProjectFilterTranspose test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests/MinusMergeTest.java b/src/main/java/org/qed/Generated/Tests/MinusMergeTest.java new file mode 100644 index 0000000..bc9985e --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/MinusMergeTest.java @@ -0,0 +1,48 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.UnionMerge; +import org.qed.RuleBuilder; + +/** + * Test for the MinusMerge rule. + */ +public class MinusMergeTest { + + /** + * Run test for MinusMerge rule. + */ + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + // (A − B) − C + var before = builder.push(scan1).push(scan2).minus(false, 2).push(scan3).minus(false, 2).build(); + + // A − (B ∪ C) + var union = builder.push(scan2).push(scan3).union(false).build(); + var after = builder.push(scan1).push(union).minus(false, 2).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.MinusMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + /** + * Main method to run this test independently. + */ + public static void main(String[] args) { + System.out.println("Running MinusMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/ProjectFilterTransposeTest.java b/src/main/java/org/qed/Generated/Tests/ProjectFilterTransposeTest.java new file mode 100644 index 0000000..05bfc31 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/ProjectFilterTransposeTest.java @@ -0,0 +1,41 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.RuleBuilder; + +public class ProjectFilterTransposeTest { + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false), + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan = builder.scan(table.getName()).build(); + + var before = builder + .push(scan) + .project(builder.field(0)) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .build(); + + var after = builder + .push(scan) + .filter(builder.equals(builder.field(0), builder.literal(10))) + .project(builder.field(0)) + .build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.ProjectFilterTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running ProjectFilterTranspose test..."); + runTest(); + } +} diff --git a/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java b/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java new file mode 100644 index 0000000..11e404a --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/SemiJoinFilterTransposeTest.java @@ -0,0 +1,49 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.SemiJoinFilterTranspose; +import org.qed.RuleBuilder; + +public class SemiJoinFilterTransposeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + + var leftTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + var rightTable = builder.createQedTable(Seq.of(Tuple.of(RelType.fromString("INTEGER", true), false))); + builder.addTable(leftTable); + builder.addTable(rightTable); + + var leftScan = builder.scan(leftTable.getName()).build(); + var rightScan = builder.scan(rightTable.getName()).build(); + + builder.push(leftScan); + builder.push(rightScan); + var joinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); + var semiJoin = builder.join(JoinRelType.SEMI, joinPredicate).build(); + builder.push(semiJoin); + var filterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); + var before = builder.filter(filterPredicate).build(); + + builder.push(leftScan); + var leftFilterPredicate = builder.call(builder.genericPredicateOp("filter", true), builder.field(0)); + var filteredLeft = builder.filter(leftFilterPredicate).build(); + builder.push(filteredLeft); + builder.push(rightScan); + var afterJoinPredicate = builder.equals(builder.field(2, 0, 0), builder.field(2, 1, 0)); + var after = builder.join(JoinRelType.SEMI, afterJoinPredicate).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.SemiJoinFilterTranspose.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running SemiJoinFilterTranspose test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java b/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java new file mode 100644 index 0000000..3ad7e07 --- /dev/null +++ b/src/main/java/org/qed/Generated/Tests/UnionMergeTest.java @@ -0,0 +1,37 @@ +package org.qed.Generated.Tests; + +import kala.collection.Seq; +import kala.tuple.Tuple; +import org.qed.Generated.CalciteTester; +import org.qed.RelType; +import org.qed.Generated.RRuleInstances.UnionMerge; +import org.qed.RuleBuilder; + +public class UnionMergeTest { + + public static void runTest() { + var tester = new CalciteTester(); + var builder = RuleBuilder.create(); + var table = builder.createQedTable(Seq.of( + Tuple.of(RelType.fromString("INTEGER", true), false) + )); + builder.addTable(table); + + var scan1 = builder.scan(table.getName()).build(); + var scan2 = builder.scan(table.getName()).build(); + var scan3 = builder.scan(table.getName()).build(); + + var firstUnion = builder.push(scan1).push(scan2).union(false).build(); + var before = builder.push(firstUnion).push(scan3).union(false).build(); + + var after = builder.push(scan1).push(scan2).push(scan3).union(false, 3).build(); + + var runner = CalciteTester.loadRule(org.qed.Generated.UnionMerge.Config.DEFAULT.toRule()); + tester.verify(runner, before, after); + } + + public static void main(String[] args) { + System.out.println("Running UnionMerge test..."); + runTest(); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/UnionMerge.java b/src/main/java/org/qed/Generated/UnionMerge.java new file mode 100644 index 0000000..d951fa8 --- /dev/null +++ b/src/main/java/org/qed/Generated/UnionMerge.java @@ -0,0 +1,39 @@ +package org.qed.Generated; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.logical.*; + +public class UnionMerge extends RelRule { + protected UnionMerge(Config config) { + super(config); + } + + @Override + public void onMatch(RelOptRuleCall call) { + var var_5 = call.builder(); + call.transformTo(var_5.push(call.rel(2)).push(call.rel(3)).push(call.rel(4)).union(false, 3).build()); + } + + public interface Config extends EmptyConfig { + Config DEFAULT = new Config() {}; + + @Override + default UnionMerge toRule() { + return new UnionMerge(this); + } + + @Override + default String description() { + return "UnionMerge"; + } + + @Override + default RelRule.OperandTransform operandSupplier() { + return s_4 -> s_4.operand(LogicalUnion.class).inputs(s_2 -> s_2.operand(LogicalUnion.class).inputs(s_0 -> s_0.operand(RelNode.class).anyInputs(), s_1 -> s_1.operand(RelNode.class).anyInputs()), s_3 -> s_3.operand(RelNode.class).anyInputs()); + } + + } +} diff --git a/src/main/java/org/qed/JSONDeserializer.java b/src/main/java/org/qed/JSONDeserializer.java index c725c40..a121124 100644 --- a/src/main/java/org/qed/JSONDeserializer.java +++ b/src/main/java/org/qed/JSONDeserializer.java @@ -34,6 +34,94 @@ public record JSONDeserializer() { private final static ObjectMapper mapper = new ObjectMapper(); + private static ImmutableSeq array(JsonNode node) throws Exception { + if (!node.isArray()) throw new Exception(); + return ImmutableSeq.from(node.elements()); + } + + private static ImmutableSeq array(JsonNode node, String path) throws Exception { + return array(node.required(path)); + } + + private static String string(JsonNode node) throws Exception { + if (!node.isTextual()) throw new Exception(); + return node.asText(); + } + + private static String string(JsonNode node, String path) throws Exception { + return string(node.required(path)); + } + + private static int integer(JsonNode node) throws Exception { + if (!node.isInt()) throw new Exception(); + return node.asInt(); + } + + private static int integer(JsonNode node, String path) throws Exception { + return integer(node.required(path)); + } + + private static boolean bool(JsonNode node) throws Exception { + if (!node.isBoolean()) throw new Exception(); + return node.asBoolean(); + } + + static SqlTypeName typeName(String name) { + name = switch (name) { + case "BOOL" -> "BOOLEAN"; + case "INT", "INT2", "INT4", "OID" -> "INTEGER"; + case "TIMESTAMPTZ" -> "TIMESTAMP"; + case "TIMETZ" -> "TIME"; + case "STRING" -> "VARCHAR"; + case "JSONB" -> "VARBINARY"; + default -> name; + }; + return Enum.valueOf(SqlTypeName.class, name); + } + + public static ImmutableSeq load(File file) throws Exception { + return new JSONDeserializer().deserialize(mapper.readTree(file)); + } + + public static void main(String[] args) throws Exception { + var refs = Seq.from(new File("RelOptRulesTest").listFiles()); + for (var file : refs) { + try { + var store = mapper.readTree(file); + new JSONDeserializer().deserialize(store); + } catch (Exception e) { + System.err.println("===> " + file.getName() + " <==="); + System.err.println(e.getMessage()); + System.err.println(); + } + } + } + + public ImmutableSeq deserialize(JsonNode node) throws Exception { + var builder = RuleBuilder.create(); + var tables = array(node, "schemas").mapChecked(schema -> { + var types = array(schema, "types").mapChecked(JSONDeserializer::string); + var nullabilities = array(schema, "nullable").mapChecked(JSONDeserializer::bool); + var name = schema.path("name").asText("DEFAULT_TABLE_NAME"); + var fields = schema.get("fields") == null ? + Seq.fill(types.size(), i -> String.format("DEFAULT_FIELD_NAME_%d", i)) : + array(schema, "fields").mapChecked(JSONDeserializer::string); + var keys = Set.from(array(schema, "key").map( + CheckedFunction.of(key -> ImmutableBitSet.of(array(key).mapChecked(JSONDeserializer::integer))))); + if (types.size() != nullabilities.size()) + throw new Exception("Expecting corresponding types and nullabilities"); + var sts = types.zip(nullabilities).map(tn -> { + var type = builder.getTypeFactory().createSqlType(typeName(tn.component1())); + return builder.getTypeFactory().createTypeWithNullability(type, tn.component2()); + }); + var table = new QedTable(name, fields, sts, keys, Set.empty()); + builder.addTable(table); + return table; + }); + var rel = new Rel(builder, ImmutableSeq.empty(), tables); + return array(node, "queries").mapChecked(rel); + } + private record Rel(RuleBuilder builder, ImmutableSeq globals, ImmutableSeq tables) implements CheckedFunction { Rel(RuleBuilder builder) { @@ -156,6 +244,15 @@ yield builder().push(sorted).sortLimit(rex().deserialize(content.required("offse private record Rex(RuleBuilder builder, ImmutableSeq globals, RexCorrelVariable local, ImmutableSeq tables) implements CheckedFunction { + static Seq ops = Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) + .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && + java.lang.reflect.Modifier.isStatic(f.getModifiers())).map(f -> { + var mist = Try.of(() -> f.get(null)).getOrNull(); + if (mist == null) return null; + if (mist instanceof SqlOperator op) return op; + return null; + }).filter(Objects::nonNull); + public RexNode resolve(int lvl) { assert lvl < globals().size() + local().getType().getFieldCount(); return lvl < globals().size() ? globals().get(lvl) : builder().getRexBuilder() @@ -177,15 +274,6 @@ public RelDataType type(String name) { return builder().getTypeFactory().createSqlType(typeName(name)); } - static Seq ops = Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) - .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && - java.lang.reflect.Modifier.isStatic(f.getModifiers())).map(f -> { - var mist = Try.of(() -> f.get(null)).getOrNull(); - if (mist == null) return null; - if (mist instanceof SqlOperator op) return op; - return null; - }).filter(Objects::nonNull); - SqlOperator op(String name, int arity) throws Exception { switch (name) { case "BOOL_AND" -> { @@ -273,93 +361,5 @@ public RexNode deserialize(JsonNode node) throws Exception { } } } - - private static ImmutableSeq array(JsonNode node) throws Exception { - if (!node.isArray()) throw new Exception(); - return ImmutableSeq.from(node.elements()); - } - - private static ImmutableSeq array(JsonNode node, String path) throws Exception { - return array(node.required(path)); - } - - private static String string(JsonNode node) throws Exception { - if (!node.isTextual()) throw new Exception(); - return node.asText(); - } - - private static String string(JsonNode node, String path) throws Exception { - return string(node.required(path)); - } - - private static int integer(JsonNode node) throws Exception { - if (!node.isInt()) throw new Exception(); - return node.asInt(); - } - - private static int integer(JsonNode node, String path) throws Exception { - return integer(node.required(path)); - } - - private static boolean bool(JsonNode node) throws Exception { - if (!node.isBoolean()) throw new Exception(); - return node.asBoolean(); - } - - static SqlTypeName typeName(String name) { - name = switch (name) { - case "BOOL" -> "BOOLEAN"; - case "INT", "INT2", "INT4", "OID" -> "INTEGER"; - case "TIMESTAMPTZ" -> "TIMESTAMP"; - case "TIMETZ" -> "TIME"; - case "STRING" -> "VARCHAR"; - case "JSONB" -> "VARBINARY"; - default -> name; - }; - return Enum.valueOf(SqlTypeName.class, name); - } - - public ImmutableSeq deserialize(JsonNode node) throws Exception { - var builder = RuleBuilder.create(); - var tables = array(node, "schemas").mapChecked(schema -> { - var types = array(schema, "types").mapChecked(JSONDeserializer::string); - var nullabilities = array(schema, "nullable").mapChecked(JSONDeserializer::bool); - var name = schema.path("name").asText("DEFAULT_TABLE_NAME"); - var fields = schema.get("fields") == null ? - Seq.fill(types.size(), i -> String.format("DEFAULT_FIELD_NAME_%d", i)) : - array(schema, "fields").mapChecked(JSONDeserializer::string); - var keys = Set.from(array(schema, "key").map( - CheckedFunction.of(key -> ImmutableBitSet.of(array(key).mapChecked(JSONDeserializer::integer))))); - if (types.size() != nullabilities.size()) - throw new Exception("Expecting corresponding types and nullabilities"); - var sts = types.zip(nullabilities).map(tn -> { - var type = builder.getTypeFactory().createSqlType(typeName(tn.component1())); - return builder.getTypeFactory().createTypeWithNullability(type, tn.component2()); - }); - var table = new QedTable(name, fields, sts, keys, Set.empty()); - builder.addTable(table); - return table; - }); - var rel = new Rel(builder, ImmutableSeq.empty(), tables); - return array(node, "queries").mapChecked(rel); - } - - public static ImmutableSeq load(File file) throws Exception { - return new JSONDeserializer().deserialize(mapper.readTree(file)); - } - - public static void main(String[] args) throws Exception { - var refs = Seq.from(new File("RelOptRulesTest").listFiles()); - for (var file : refs) { - try { - var store = mapper.readTree(file); - new JSONDeserializer().deserialize(store); - } catch (Exception e) { - System.err.println("===> " + file.getName() + " <==="); - System.err.println(e.getMessage()); - System.err.println(); - } - } - } } diff --git a/src/main/java/org/qed/JSONSerializer.java b/src/main/java/org/qed/JSONSerializer.java index bf8f4ec..1d077f8 100644 --- a/src/main/java/org/qed/JSONSerializer.java +++ b/src/main/java/org/qed/JSONSerializer.java @@ -21,32 +21,68 @@ public record JSONSerializer(Env env) { private final static ObjectMapper mapper = new ObjectMapper(); - private record Rel(Env env) { - Rel() { - this(new Env(0, ImmutableMap.empty(), MutableList.create())); + private static ArrayNode array(Seq objs) { + return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); + } + + private static ObjectNode object(Map fields) { + return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); + } + + private static BooleanNode bool(boolean b) { + return BooleanNode.valueOf(b); + } + + private static TextNode string(String s) { + return new TextNode(s); + } + + private static TextNode type(RelDataType type) { + if (type instanceof RelType.VarType varType) { + return new TextNode(varType.getName()); } + return new TextNode(type.getSqlTypeName().getName()); + } - private record Env(int lvl, ImmutableMap globals, MutableList tables) { - Env recorded(Set ids) { - return new Env(lvl, Seq.from(ids).foldLeft(globals, (g, id) -> g.putted(id, lvl)), tables); - } + private static IntNode integer(int i) { + return new IntNode(i); + } - Env lifted(int d) { - return new Env(lvl + d, globals, tables); - } + private static String qualifiedTableName(RelOptTable table) { + return Seq.from(table.getQualifiedName()).joinToString("."); + } - int resolve(RelOptTable table) { - var idx = tables.indexOf(table); - if (idx == -1) { - idx = tables.size(); - tables.append(table); - } - return idx; - } + public static ObjectNode serialize(Seq relNodes) { + var shuttle = new Rel(); + var helps = array(relNodes.map(rel -> new TextNode(rel.explain()))); + var queries = array(relNodes.map(shuttle::serialize)); + var tables = shuttle.env.tables(); + var schemas = array(tables.map(table -> { + var visitor = new Rex(shuttle.env.rex(table.getRowType().getFieldCount())); + var qedTable = table.unwrap(QedTable.class); + var fields = Seq.from(table.getRowType().getFieldList()); + return qedTable == null ? + object(Map.of("name", string(qualifiedTableName(table)), "fields", + array(fields.map(field -> string(field.getName()))), "types", + array(fields.map(field -> type(field.getType()))), "nullable", + array(fields.map(field -> bool(field.getType().isNullable()))), "key", + array((table.getKeys() != null ? Seq.from(table.getKeys()) : + Seq.empty()).map( + key -> array(Seq.from(key).map(JSONSerializer::integer)))), "guaranteed", + array(Seq.empty()))) : object(Map.of("name", string(qedTable.getName()), "fields", + array(qedTable.getColumnNames().map(JSONSerializer::string)), "types", + array(qedTable.getColumnTypes().map(JSONSerializer::type)), "nullable", + array(qedTable.getColumnTypes().map(type -> bool(type.isNullable()))), "key", + array(Seq.from(qedTable.getKeys().map(key -> array(Seq.from(key).map(JSONSerializer::integer))))), + "guaranteed", array(qedTable.getConstraints().map(visitor::serialize).toImmutableSeq()))); + })); - public Rex.Env rex(int delta) { - return new Rex.Env(lvl, delta, globals, tables); - } + return object(Map.of("schemas", schemas, "queries", queries, "help", helps)); + } + + private record Rel(Env env) { + Rel() { + this(new Env(0, ImmutableMap.empty(), MutableList.create())); } public JsonNode serialize(RelNode rel) { @@ -132,20 +168,32 @@ yield object(Map.of("sort", default -> throw new RuntimeException("Not implemented: " + rel.getRelTypeName()); }; } - } - private record Rex(Env env) { - private record Env(int base, int delta, ImmutableMap globals, - MutableList tables) { - public Rel.Env rel() { - return new Rel.Env(base + delta, globals, tables); + private record Env(int lvl, ImmutableMap globals, MutableList tables) { + Env recorded(Set ids) { + return new Env(lvl, Seq.from(ids).foldLeft(globals, (g, id) -> g.putted(id, lvl)), tables); } - int resolve(CorrelationId id) { - return globals.getOrThrow(id, () -> new RuntimeException("Correlation ID not declared")); + Env lifted(int d) { + return new Env(lvl + d, globals, tables); + } + + int resolve(RelOptTable table) { + var idx = tables.map(JSONSerializer::qualifiedTableName).indexOf(qualifiedTableName(table)); + if (idx == -1) { + idx = tables.size(); + tables.append(table); + } + return idx; + } + + public Rex.Env rex(int delta) { + return new Rex.Env(lvl, delta, globals, tables); } } + } + private record Rex(Env env) { public JsonNode serialize(RexNode rex) { return switch (rex) { case RexInputRef inputRef -> object(Map.of("column", integer(inputRef.getIndex() + env.base()), "type", @@ -165,57 +213,16 @@ public JsonNode serialize(RexNode rex) { default -> throw new RuntimeException("Not implemented: " + rex.getKind()); }; } - } - - private static ArrayNode array(Seq objs) { - return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); - } - - private static ObjectNode object(Map fields) { - return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); - } - - private static BooleanNode bool(boolean b) { - return BooleanNode.valueOf(b); - } - - private static TextNode string(String s) { - return new TextNode(s); - } - - private static TextNode type(RelDataType type) { - return new TextNode(type.getSqlTypeName().getName()); - } - - private static IntNode integer(int i) { - return new IntNode(i); - } - public static ObjectNode serialize(Seq relNodes) { - var shuttle = new Rel(); - var helps = array(relNodes.map(rel -> new TextNode(rel.explain()))); - var queries = array(relNodes.map(shuttle::serialize)); - var tables = shuttle.env.tables(); - var schemas = array(tables.map(table -> { - var visitor = new Rex(shuttle.env.rex(table.getRowType().getFieldCount())); - var qedTable = table.unwrap(QedTable.class); - var fields = Seq.from(table.getRowType().getFieldList()); - return qedTable == null ? - object(Map.of("name", string(Seq.from(table.getQualifiedName()).joinToString(".")), "fields", - array(fields.map(field -> string(field.getName()))), "types", - array(fields.map(field -> type(field.getType()))), "nullable", - array(fields.map(field -> bool(field.getType().isNullable()))), "key", - array((table.getKeys() != null ? Seq.from(table.getKeys()) : - Seq.empty()).map( - key -> array(Seq.from(key).map(JSONSerializer::integer)))), "guaranteed", - array(Seq.empty()))) : object(Map.of("name", string(qedTable.getName()), "fields", - array(qedTable.getColumnNames().map(JSONSerializer::string)), "types", - array(qedTable.getColumnTypes().map(JSONSerializer::type)), "nullable", - array(qedTable.getColumnTypes().map(type -> bool(type.isNullable()))), "key", - array(Seq.from(qedTable.getKeys().map(key -> array(Seq.from(key).map(JSONSerializer::integer))))), - "guaranteed", array(qedTable.getConstraints().map(visitor::serialize).toImmutableSeq()))); - })); + private record Env(int base, int delta, ImmutableMap globals, + MutableList tables) { + public Rel.Env rel() { + return new Rel.Env(base + delta, globals, tables); + } - return object(Map.of("schemas", schemas, "queries", queries, "help", helps)); + int resolve(CorrelationId id) { + return globals.getOrThrow(id, () -> new RuntimeException("Correlation ID not declared")); + } + } } } diff --git a/src/main/java/org/qed/RRule.java b/src/main/java/org/qed/RRule.java new file mode 100644 index 0000000..a548f9a --- /dev/null +++ b/src/main/java/org/qed/RRule.java @@ -0,0 +1,115 @@ +package org.qed; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import kala.collection.Map; +import kala.collection.Seq; + +import java.io.File; +import java.io.IOException; + +public interface RRule { + RelRN before(); + + RelRN after(); + + default String explain() { + return getClass().getName() + + "\n" + + before().semantics().explain() + + "=>" + + "\n" + + after().semantics().explain(); + } + + default String name() { + return getClass().getSimpleName(); + } + + default String info() { + return ""; + } + + default ObjectNode toJson() { + return JSONSerializer.serialize(Seq.of(before().semantics(), after().semantics())); + } + + default void dump(String path) throws IOException { + new ObjectMapper().writerWithDefaultPrettyPrinter().writeValue(new File(path), toJson()); + } + + interface RRuleFamily { + Seq family(); + } + + record RRuleGenerator(RRule rule, + Seq assignments) implements RRuleFamily { + @Override + public Seq family() { + return assignments.map(assignment -> new RRule() { + + @Override + public RelRN before() { + return assignment.replaceMetaRelRN(rule.before()); + } + + @Override + public RelRN after() { + return assignment.replaceMetaRelRN(rule.after()); + } + + @Override + public String name() { + return rule.name(); + } + + @Override + public String info() { + return assignment.info(); + } + }); + } + + public record MetaAssignment( + Map joinTypeAssignment) { + public RelRN.Join.JoinType replaceMetaJoinType(RelRN.Join.JoinType joinType) { + return switch (joinType) { + case RelRN.Join.JoinType.MetaJoinType metaJoinType -> joinTypeAssignment.get(metaJoinType); + default -> joinType; + }; + } + + public RexRN replaceMetaRexRN(RexRN rexRN) { + return switch (rexRN) { + case RexRN.Field field -> replaceMetaRelRN(field.source()).field(field.ordinal()); + default -> customReplaceMetaRexRN(rexRN); + }; + } + + public RelRN replaceMetaRelRN(RelRN relRN) { + return switch (relRN) { + case RelRN.Filter filter -> + replaceMetaRelRN(filter.source()).filter(replaceMetaRexRN(filter.cond())); + case RelRN.Join join -> + replaceMetaRelRN(join.left()).join(replaceMetaJoinType(join.ty()), + replaceMetaRexRN(join.cond()), replaceMetaRelRN(join.right())); + default -> customReplaceMetaRelRN(relRN); + }; + } + + public RexRN customReplaceMetaRexRN(RexRN rexRN) { + return rexRN; + } + + public RelRN customReplaceMetaRelRN(RelRN relRN) { + return relRN; + } + + public String info() { + return joinTypeAssignment.joinToString("&", (m, c) -> "{" + m.name() + "}=" + c.semantics()); + } + + } + } +} + diff --git a/src/main/java/org/qed/RRuleInstance.java b/src/main/java/org/qed/RRuleInstance.java new file mode 100644 index 0000000..f7e3711 --- /dev/null +++ b/src/main/java/org/qed/RRuleInstance.java @@ -0,0 +1,481 @@ +package org.qed; + +import kala.collection.Map; +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.qed.RRuleInstance.JoinAssociate; +// import org.qed.RRuleInstance.JoinConditionPush.JoinPred; + +public interface RRuleInstance { + record FilterIntoJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + var join = left.join(JoinRelType.INNER, joinCond, right); + return join.filter("outer"); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.and(joinCond, left.joinPred("outer", right)), right); + } + } + + record FilterMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.pred("inner"); + static final RexRN outer = source.pred("outer"); + + @Override + public RelRN before() { + return source.filter(inner).filter(outer); + } + + @Override + public RelRN after() { + return source.filter(RexRN.and(inner, outer)); + } + } + + record FilterProjectTranspose() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN proj = source.proj("proj", "Project_Type"); + + @Override + public RelRN before() { + return source.filter(proj.pred("pred")).project(proj); + } + + @Override + public RelRN after() { + return source.project(proj).filter("pred"); + } + } + + record FilterReduceFalse() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.falseLiteral()); + } + + @Override + public RelRN after() { + return source.empty(); + } + } + + record FilterReduceTrue() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + + @Override + public RelRN before() { + return source.filter(RexRN.trueLiteral()); + } + + @Override + public RelRN after() { + return source; + } + } + + // TBD: include intersect to make it a rule familiy + record FilterSetOpTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Common_Type"); + static final RelRN right = RelRN.scan("Right", "Common_Type"); + + @Override + public RelRN before() { + RelRN projTmp = left.union(false, right); + return projTmp.filter(projTmp.pred("filter")); + } + + @Override + public RelRN after() { + RexRN leftPred = left.pred("filter"); + RexRN rightPred = right.pred("filter"); + return left.filter(leftPred).union(false, right.filter(rightPred)); + } + } + + record IntersectMerge() implements RRule { + // Use a common type for all relations to make them compatible + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + // Nested INTERSECT: (A INTERSECT B) INTERSECT C + return a.intersect(false, b).intersect(false, c); + } + + @Override + public RelRN after() { + // Flattened INTERSECT: A INTERSECT B INTERSECT C + return a.intersect(false, b, c); + } + } + + // record JoinConditionPush() implements RRule { + // static final RelRN left = RelRN.scan("Left", "Left_Type"); + // static final RelRN right = RelRN.scan("Right", "Right_Type"); + // static final JoinPred joinPred = new JoinPred(left, right); + + // @Override + // public RelRN before() { + // return left.join(JoinRelType.INNER, joinPred, right); + // } + + // @Override + // public RelRN after() { + // var leftRN = left.filter(joinPred.leftPred()); + // var rightRN = right.filter(joinPred.rightPred()); + // return leftRN.join(JoinRelType.INNER, joinPred.bothPred(), rightRN); + // } + + // public record JoinPred(RelRN left, RelRN right) implements RexRN { + + // @Override + // public RexNode semantics() { + // return RexRN.and(left.joinPred(bothPred(), right), left.joinField(0, right).pred(leftPred()), + // left.joinField(1, right).pred(rightPred())).semantics(); + // } + + // public String bothPred() {return "both";} + + // public String leftPred() {return "left";} + + // public String rightPred() {return "right";} + + // } + // } + + record JoinAddRedundantSemiJoin() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final String pred = "pred"; + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, pred, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.SEMI, pred, right).join(JoinRelType.INNER, pred, right); + } + } + + // Todo: explore join types, see line 102 of JoinAssociateRule + record JoinAssociate() implements RRule.RRuleFamily { + static final RelRN a = RelRN.scan("A", "A_Type"); + static final RelRN b = RelRN.scan("B", "B_Type"); + static final RelRN c = RelRN.scan("C", "C_Type"); + static final String pred_ab = "pred_ab"; + static final String pred_bc = "pred_bc"; + static final RelRN.Join.JoinType.MetaJoinType mjt_0 = new RelRN.Join.JoinType.MetaJoinType("mjt_0"); + static final RelRN.Join.JoinType.MetaJoinType mjt_1 = new RelRN.Join.JoinType.MetaJoinType("mjt_1"); + static final RelRN.Join.JoinType.MetaJoinType mjt_2 = new RelRN.Join.JoinType.MetaJoinType("mjt_2"); + static final RelRN.Join.JoinType.MetaJoinType mjt_3 = new RelRN.Join.JoinType.MetaJoinType("mjt_3"); + + static final RelRN before_ab = a.join(mjt_0, RexRN.and( + a.joinPred(pred_ab, b), + new RexRN.JoinField(1, a, b).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), b); + + static final RelRN before = before_ab.join(mjt_1, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_bc, true), before_ab.joinFields(c, 1, 2)), + new RexRN.JoinField(1, before_ab, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after_bc = b.join(mjt_2, RexRN.and( + b.joinPred(pred_bc, c), + new RexRN.JoinField(0, b, c).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), c); + + static final RelRN after = a.join(mjt_3, RexRN.and( + new RexRN.Pred(RuleBuilder.create().genericPredicateOp(pred_ab, true), a.joinFields(after_bc, 0, 1)), + new RexRN.JoinField(1, a, after_bc).pred(SqlStdOperatorTable.IS_NOT_NULL) + ), after_bc); + + static final RRule template = new RRule() { + @Override + public RelRN before() { + return before; + } + + @Override + public RelRN after() { + return after; + } + + @Override + public String name() { + return JoinAssociate.class.getSimpleName(); + } + }; + + static Seq assignments() { + var joinTypes = Seq.of(JoinRelType.INNER, JoinRelType.LEFT, JoinRelType.RIGHT, JoinRelType.FULL).map(RelRN.Join.JoinType.ConcreteJoinType::new); + return joinTypes.flatMap(jt0 -> joinTypes.flatMap(jt1 -> joinTypes.flatMap(jt2 -> joinTypes.map(jt3 -> new RRule.RRuleGenerator.MetaAssignment(Map.of(mjt_0, jt0, mjt_1, jt1, mjt_2, jt2, mjt_3, jt3)))))); + } + + @Override + public Seq family() { + return new RRule.RRuleGenerator(template, assignments()).family(); + } + } + + record JoinCommute() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("pred", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + // We need to swap the join fields in the condition + RexRN commutedJoinCond = right.joinPred("pred", left); + return right.join(JoinRelType.INNER, commutedJoinCond, left); + } + } + + record JoinExtractFilter() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right); + } + + @Override + public RelRN after() { + return left.join(JoinRelType.INNER, RexRN.trueLiteral(), right).filter(joinCond); + } + } + +// record JoinProjectTranspose() implements RRule { +// +// } + + // JoinConditionPush? +// record JoinPushExpressions() implements RRule { +// +// } + + // JoinConditionPush? +// record JoinPushTransitivePredicates() implements RRule { +// +// } + +// record JoinToSemiJoin() implements RRule { +// +// } + +// record JoinLeftUnionTranspose() implements RRule { +// +// } + +// record JoinRightUnionTranspose() implements RRule { +// +// } + +// record ProjectJoinRemove() implements RRule { +// +// @Override +// public RelRN before() { +// return null; +// } +// +// @Override +// public RelRN after() { +// return null; +// } +// } + +// record ProjectJoinJoinRemove() implements RRule { +// +// } + + record ProjectJoinTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN proj = left.proj("proj", "Project_Type"); + static final String joinCond = left.joinPred("join", right).toString(); + + @Override + public RelRN before() { + return left.join(JoinRelType.INNER, joinCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.INNER, joinCond, right); + } + } + + record ProjectMerge() implements RRule { + static final RelRN source = RelRN.scan("Source", "Source_Type"); + static final RexRN inner = source.proj("inner", "Inner_Type"); + static final String outer = "outer"; + static final String outerType = "Outer_Type"; + + @Override + public RelRN before() { + return source.project(inner).project(outer, outerType); + } + + @Override + public RelRN after() { + return source.project(inner.proj(outer, outerType)); + } + } + + //TBD: currently provable for UNION ALL while unprovable for UNION + // record ProjectSetOpTranspose() implements RRule { + // static final RelRN left = RelRN.scan("Left", "Common_Type"); + // static final RelRN right = RelRN.scan("Right", "Common_Type"); + + // @Override + // public RelRN before() { + // RelRN projTmp = left.union(true, right); + // return projTmp.project(projTmp.proj("proj", "Proj_Type")); + // } + + // @Override + // public RelRN after() { + // RelRN projA = left.project(left.proj("proj", "Proj_Type")); + // RelRN projB = right.project(right.proj("proj", "Proj_Type")); + // return projA.union(true, projB); + // } + // } + + + /* TBD: Already optimized by calcite? */ + // record ProjectRemove() implements RRule { + // static final RelRN source = RelRN.scan("Source", "Source_Type"); + + // @Override + // public RelRN before() { + // return source.project(source.field(0)); + // } + + // @Override + // public RelRN after() { + // return source; + // } + // } + + record UnionMerge() implements RRule { + static final RelRN a = RelRN.scan("A", "Common_Type"); + static final RelRN b = RelRN.scan("B", "Common_Type"); + static final RelRN c = RelRN.scan("C", "Common_Type"); + + @Override + public RelRN before() { + return a.union(false, b).union(false, c); + } + + @Override + public RelRN after() { + return null; + } + } + + record SemiJoinFilterTranspose() implements RRule { + static final RelRN left = RelRN.scan("left", "Left_Type"); + static final RelRN right = RelRN.scan("right", "Right_Type"); + static final RexRN pred = left.pred("pred"); + static final RexRN joinCond = left.joinPred("join", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, joinCond, right).filter(pred); + } + + @Override + public RelRN after() { + return left.filter(pred).join(JoinRelType.SEMI, joinCond, right); + } + } + + record SemiJoinJoinTranspose() implements RRule { + static final RelRN r = RelRN.scan("R", "R_Type"); + static final RelRN s = RelRN.scan("S", "S_Type"); + static final RelRN t = RelRN.scan("T", "T_Type"); + static final RexRN semiCond = r.joinPred("semi", s); + static final RexRN joinCond = r.joinPred("join", t); + + @Override + public RelRN before() { + return r.join(JoinRelType.INNER, joinCond, t).join(JoinRelType.SEMI, semiCond, s); + } + + @Override + public RelRN after() { + return r.join(JoinRelType.SEMI, semiCond, s).join(JoinRelType.INNER, joinCond, t); + } + } + + record SemiJoinProjectTranspose() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + static final RexRN proj = left.proj("proj", "Project_Type"); + static final RexRN semiCond = left.joinPred("semi", right); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, semiCond, right).project(proj); + } + + @Override + public RelRN after() { + return left.project(proj).join(JoinRelType.SEMI, semiCond, right); + } + } + + record SemiJoinRemove() implements RRule { + static final RelRN left = RelRN.scan("Left", "Left_Type"); + static final RelRN right = RelRN.scan("Right", "Right_Type"); + + @Override + public RelRN before() { + return left.join(JoinRelType.SEMI, RexRN.trueLiteral(), right); + } + + @Override + public RelRN after() { + return left; + } + } + +// record UnionMerge() implements RRule { +// +// } + +// record UnionRemove() implements RRule { +// +// } +} + +/* + * Semantically identical cases: + * FilterExpandIsNotDistinctFrom + * FilterScan + * JoinReduceExpression + * ProjectReduceExpression + * ProjectTableScan + */ diff --git a/src/main/java/org/qed/RelJSONShuttle.java b/src/main/java/org/qed/RelJSONShuttle.java deleted file mode 100644 index 1472f09..0000000 --- a/src/main/java/org/qed/RelJSONShuttle.java +++ /dev/null @@ -1,364 +0,0 @@ -package org.qed; - -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.*; -import kala.collection.Map; -import kala.collection.Seq; -import kala.collection.Set; -import kala.collection.immutable.ImmutableSeq; -import kala.control.Result; -import org.apache.calcite.rel.RelNode; -import org.apache.calcite.rel.core.TableScan; -import org.apache.calcite.rel.logical.*; -import org.apache.calcite.rel.type.*; -import org.apache.calcite.rex.*; -import org.apache.calcite.sql.SqlOperator; -import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.util.ImmutableBitSet; - -import java.io.IOException; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.List; - -public record RelJSONShuttle(Env env) { - private final static ObjectMapper mapper = new ObjectMapper(); - - private static ArrayNode array(Seq objs) { - return new ArrayNode(mapper.getNodeFactory(), objs.asJava()); - } - - - private static Result, String> array(JsonNode jsonNode, String field) { - var arr = jsonNode.get(field); - if (arr == null || !arr.isArray()) { - return Result.err(String.format("Missing array field %s in:\n%s", field, jsonNode.toPrettyString())); - } - return Result.ok(ImmutableSeq.from(arr.elements())); - } - - private static ObjectNode object(Map fields) { - return new ObjectNode(mapper.getNodeFactory(), fields.asJava()); - } - - private static Result object(JsonNode jsonNode, String field) { - var obj = jsonNode.get(field); - if (obj == null) { - return Result.err(String.format("Missing object field %s in:\n%s", field, jsonNode.toPrettyString())); - } - return Result.ok(obj); - } - - private static BooleanNode bool(boolean b) { - return b ? BooleanNode.TRUE : BooleanNode.FALSE; - } - - private static T unwrap(Result res) throws Exception { - if (res.isErr()) { - throw new Exception(res.getErr()); - } - return res.get(); - } - - public static void main(String[] args) throws IOException { - var res = RelJSONShuttle.deserializeFromJson(Paths.get("ElevatedRules/filterProjectTranspose.json")); - if (res.isErr()) { - System.out.println(res.getErr()); - } else { - res.get().forEach(r -> System.out.println(r.explain())); - } - } - - public static void serializeToJson(List relNodes, Path path) throws IOException { - var shuttle = new RelJSONShuttle(Env.empty()); - var helps = array(Seq.from(relNodes).map(rel -> new TextNode(rel.explain()))); - var queries = array(Seq.from(relNodes).map(shuttle::serialize)); - - var tables = shuttle.env.tables(); - var schemas = array(tables.map(table -> object(Map.of( - "name", new TextNode(table.getName()), - "fields", array(table.getColumnNames().map(TextNode::new)), - "types", array(table.getColumnTypes().map(type -> new TextNode(type.toString()))), - "nullable", array(table.getColumnTypes().map(RelDataType::isNullable).map(RelJSONShuttle::bool)), - "key", array(Seq.from(table.getKeys().map(key -> array(Seq.from(key).map(IntNode::new))))), - "guaranteed", array(table.getConstraints() - .map(check -> new RexJSONVisitor(shuttle.env.advanced(table.getColumnNames().size())).serialize(check)).toImmutableSeq()) - )))); - - var main = object(Map.of("schemas", schemas, "queries", queries, "help", helps)); - mapper.writerWithDefaultPrettyPrinter().writeValue(path.toFile(), main); - } - - public static Result, String> deserializeFromJson(Path path) throws IOException { - var node = mapper.readTree(path.toFile()); - var env = Env.empty(); - var tables = array(node, "schemas").flatMap(schemas -> { - var collected = ImmutableSeq.empty(); - for (var schema : schemas) { - try { - var tys = unwrap(array(schema, "types")); - var nbs = unwrap(array(schema, "nullable")); - var nm = unwrap(object(schema, "name")); - var fds = unwrap(array(schema, "fields")).map(JsonNode::asText); - var kys = unwrap(array(schema, "key")); - var kgs = Set.from(kys.map(kg -> ImmutableBitSet.of(Seq.from(kg.elements()).map(JsonNode::asInt)))); - if (tys.size() != nbs.size()) { - return Result.err("Expecting corresponding types and nullabilities"); - } - var sts = tys.zip(nbs).map(tn -> (RelDataType) RelType.fromString(tn.component1().asText(), - tn.component2().asBoolean())); - collected = collected.appended(new QedTable(nm.asText(), fds, sts, kgs, Set.empty())); - } catch (Exception e) { - return Result.err( - String.format("Broken table schemas: %s in\n%s", e.getMessage(), schema.toPrettyString())); - } - } - return Result.ok(collected); - }); - if (tables.isErr()) { - return Result.err(tables.getErr()); - } - env.tables().appendAll(tables.get()); - var queries = array(node, "queries"); - if (queries.isErr()) { - return Result.err(queries.getErr()); - } - var shuttle = new RelJSONShuttle(env); - return queries.get().map(q -> { - var builder = RuleBuilder.create(); - tables.get().forEach(builder::addTable); - return shuttle.deserialize(builder, q); - }).foldLeft(Result.ok(ImmutableSeq.empty()), (qs, qb) -> qs.flatMap(s -> qb.map(b -> s.appended(b.build())))); - } - - public JsonNode serialize(RelNode rel) { - return switch (rel) { - case TableScan scan -> - object(Map.of("scan", new IntNode(env.resolve(scan.getTable().unwrap(QedTable.class))))); - case LogicalValues values -> { - var visitor = new RexJSONVisitor(env); - var schema = array(Seq.from(values.getRowType().getFieldList()) - .map(field -> new TextNode(field.getType().toString()))); - var records = array(Seq.from(values.getTuples()) - .map(tuple -> array(Seq.from(tuple).map(visitor::serialize)))); - yield object(Map.of("values", object(Map.of("schema", schema, "content", records)))); - } - case LogicalFilter filter -> { - var visitor = new RexJSONVisitor(env.advanced(filter.getInput().getRowType().getFieldCount()) - .recorded(filter.getVariablesSet())); - yield object(Map.of("filter", - object(Map.of("condition", visitor.serialize(filter.getCondition()), "source", - serialize(filter.getInput()))))); - } - case LogicalProject project -> { - var visitor = new RexJSONVisitor(env.advanced(project.getInput().getRowType().getFieldCount()) - .recorded(project.getVariablesSet())); - var targets = array(Seq.from(project.getProjects()).map(visitor::serialize)); - yield object( - Map.of("project", object(Map.of("target", targets, "source", serialize(project.getInput()))))); - } - case LogicalJoin join -> { - var left = join.getLeft(); - var right = join.getRight(); - var visitor = new RexJSONVisitor( - env.advanced(left.getRowType().getFieldCount() + right.getRowType().getFieldCount()) - .recorded(join.getVariablesSet())); - yield object(Map.of("join", - object(Map.of("kind", new TextNode(join.getJoinType().toString()), "condition", - visitor.serialize(join.getCondition()), "left", serialize(left), "right", - serialize(right))))); - } - case LogicalCorrelate correlate -> { - var rightShuttle = new RelJSONShuttle(env.advanced(correlate.getLeft().getRowType().getFieldCount()) - .recorded(correlate.getVariablesSet()).advanced(0)); - yield object(Map.of("correlate", - array(Seq.of(serialize(correlate.getLeft()), rightShuttle.serialize(correlate.getRight()))))); - } - case LogicalAggregate aggregate -> { - var groupCount = aggregate.getGroupCount(); - var level = env.base(); - var types = Seq.from(aggregate.getInput().getRowType().getFieldList()) - .map(type -> new TextNode(type.getType().toString())); - var keyCols = array(Seq.from(aggregate.getGroupSet()) - .map(key -> object(Map.of("column", new IntNode(level + key), "type", types.get(key))))); - var keys = object(Map.of("project", - object(Map.of("target", keyCols, "source", serialize(aggregate.getInput()))))); - var conditions = array(Seq.from(aggregate.getGroupSet()).mapIndexed((i, key) -> { - var type = types.get(key); - var leftCol = object(Map.of("column", new IntNode(level + i), "type", type)); - var rightCol = object(Map.of("column", new IntNode(level + groupCount + key), "type", type)); - return object( - Map.of("operator", new TextNode("<=>"), "operand", array(Seq.of(leftCol, rightCol)), "type", - new TextNode("BOOLEAN"))); - })); - var condition = object(Map.of("operator", new TextNode("AND"), "operand", conditions, "type", - new TextNode("BOOLEAN"))); - var aggs = array(Seq.from(aggregate.getAggCallList()).map(call -> object( - Map.of("operator", new TextNode(call.getAggregation().getName()), "operand", - array(Seq.from(call.getArgList()).map(target -> object( - Map.of("column", new IntNode(level + groupCount + target), "type", - types.get(target))))), "distinct", bool(call.isDistinct()), - "ignoreNulls", bool(call.ignoreNulls()), "type", - new TextNode(call.getType().toString()))))); - var aggregated = object(Map.of("aggregate", object(Map.of("function", aggs, "source", - object(Map.of("filter", object(Map.of("condition", condition, "source", - new RelJSONShuttle(env.lifted(groupCount)).serialize(aggregate.getInput()))))))))); - yield object(Map.of("distinct", object(Map.of("correlate", array(Seq.of(keys, aggregated)))))); - } - case LogicalUnion union -> { - var result = object(Map.of("union", array(Seq.from(union.getInputs()).map(this::serialize)))); - yield union.all ? result : object(Map.of("distinct", result)); - } - case LogicalIntersect intersect when !intersect.all -> - object(Map.of("intersect", array(Seq.from(intersect.getInputs()).map(this::serialize)))); - case LogicalMinus minus when !minus.all -> - object(Map.of("except", array(Seq.from(minus.getInputs()).map(this::serialize)))); - case LogicalSort sort -> { - var types = Seq.from(sort.getInput().getRowType().getFieldList()) - .map(type -> new TextNode(type.getType().toString())); - var collations = array(Seq.from(sort.collation.getFieldCollations()).map(collation -> { - var index = collation.getFieldIndex(); - return array(Seq.of(new IntNode(index), types.get(index), new TextNode(collation.shortString()))); - })); - var args = object(Map.of("collation", collations, "source", serialize(sort.getInput()))); - var visitor = new RexJSONVisitor(env.advanced(sort.getInput().getRowType().getFieldCount())); - if (sort.offset != null) { - args.set("offset", visitor.serialize(sort.offset)); - } - if (sort.fetch != null) { - args.set("limit", visitor.serialize(sort.fetch)); - } - yield object(Map.of("sort", args)); - } - default -> throw new RuntimeException("Not implemented: " + rel.getRelTypeName()); - }; - } - - public Result deserialize(RuleBuilder builder, JsonNode jsonNode) { - var entry = jsonNode.fields().next(); - var kind = entry.getKey(); - var content = entry.getValue(); - return switch (kind) { - case String k when k.equals("scan") -> { - if (content.isInt() && 0 <= content.asInt() && content.asInt() < env.tables().size()) { - builder.scan(env.tables().get(content.asInt()).getName()); - yield Result.ok(builder); - } - yield Result.err(String.format("Missing table with index %s", content.toPrettyString())); - } - case String k when k.equals("values") -> { - try { - var et = unwrap(array(content, "schema")); - var rt = new RelRecordType(StructKind.FULLY_QUALIFIED, et.mapIndexed( - (i, t) -> (RelDataTypeField) new RelDataTypeFieldImpl(String.format("VALUES-%s", i), i, - RelType.fromString(t.asText(), true))).asJava()); - var vs = unwrap(array(content, "content")); - var vals = ImmutableSeq.>empty(); - for (var v : vs) { - var val = ImmutableSeq.empty(); - if (!v.isArray()) { - yield Result.err("Expecting tuple (JSON list) as value"); - } - for (var jl : Seq.from(v.elements())) { - var l = unwrap(new RexJSONVisitor(env).deserialize(builder, jl)); - if (l instanceof RexLiteral) { - val = val.appended((RexLiteral) l); - } else { - yield Result.err("Expecting literal expression"); - } - } - vals = vals.appended(val.asJava()); - } - builder.values(vals.asJava(), rt); - yield Result.ok(builder); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("filter") -> { - try { - var cond = unwrap(object(content, "condition")); - var source = unwrap(object(content, "source")); - var bs = unwrap(deserialize(builder, source)); - var c = unwrap(new RexJSONVisitor(env).deserialize(builder, cond)); - bs.filter(c); - yield Result.ok(bs); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("project") -> { - try { - var target = unwrap(array(content, "target")); - var source = unwrap(object(content, "source")); - var bs = unwrap(deserialize(builder, source)); - var ps = target.mapChecked(t -> unwrap(new RexJSONVisitor(env).deserialize(builder, t))); - bs.project(ps); - yield Result.ok(bs); - } catch (Exception e) { - yield Result.err(e.getMessage()); - } - } - case String k when k.equals("join") -> Result.err("Not implemented yet"); - case String k when k.equals("correlate") -> Result.err("Not implemented yet"); - default -> Result.err(String.format("Unrecognized node:\n%s", jsonNode.toPrettyString())); - }; - } - - public record RexJSONVisitor(Env env) { - public JsonNode serialize(RexNode rex) { - return switch (rex) { - case RexInputRef inputRef -> - object(Map.of("column", new IntNode(inputRef.getIndex() + env.base()), "type", - new TextNode(inputRef.getType().toString()))); - case RexLiteral literal -> object(Map.of("operator", - new TextNode(literal.getValue() == null ? "NULL" : literal.getValue().toString()), "operand", - array(Seq.empty()), "type", new TextNode(literal.getType().toString()))); - case RexSubQuery subQuery -> - object(Map.of("operator", new TextNode(subQuery.getOperator().toString()), "operand", - array(Seq.from(subQuery.getOperands()).map(this::serialize)), "query", - new RelJSONShuttle(env.advanced(0)).serialize(subQuery.rel), "type", - new TextNode(subQuery.getType().toString()))); - case RexCall call -> object(Map.of("operator", new TextNode(call.getOperator().toString()), "operand", - array(Seq.from(call.getOperands()).map(this::serialize)), "type", - new TextNode(call.getType().toString()))); - case RexFieldAccess fieldAccess -> object(Map.of("column", new IntNode( - fieldAccess.getField().getIndex() + - env.resolve(((RexCorrelVariable) fieldAccess.getReferenceExpr()).id)), "type", - new TextNode(fieldAccess.getType().toString()))); - default -> throw new RuntimeException("Not implemented: " + rex.getKind()); - }; - } - - public Result deserialize(RuleBuilder builder, JsonNode jsonNode) { - if (jsonNode.has("column") && jsonNode.get("column").isInt()) { - // WARNING: THIS IS WRONG! NO ENVIRONMENT CONSIDERED! - return Result.ok(builder.field(jsonNode.get("column").asInt())); - } else if (jsonNode.has("operator") && jsonNode.get("operator").isTextual()) { - var op = jsonNode.get("operator").asText(); - try { - var args = unwrap(array(jsonNode, "operand")); - var ty = RelType.fromString(unwrap(object(jsonNode, "type")).asText(), true); - if (args.isEmpty()) { - return Result.ok(RexLiteral.fromJdbcString(ty, ty.getSqlTypeName(), op)); - } else { - var fields = args.mapChecked(expr -> unwrap(deserialize(builder, expr))); - for (var refl : Seq.from(SqlStdOperatorTable.class.getDeclaredFields()) - .filter(f -> java.lang.reflect.Modifier.isPublic(f.getModifiers()) && - java.lang.reflect.Modifier.isStatic(f.getModifiers()))) { - var mist = refl.get(null); - if (mist instanceof SqlOperator sqlOperator && sqlOperator.getName().equals(op)) { - return Result.ok(builder.call(sqlOperator, fields)); - } - } - return Result.ok(builder.call(builder.genericProjectionOp(op, ty), fields)); - } - } catch (Exception e) { - return Result.err(e.getMessage()); - } - } - return Result.err(String.format("Unrecognized node:\n%s", jsonNode.toPrettyString())); - } - } -} \ No newline at end of file diff --git a/src/main/java/org/qed/RelRN.java b/src/main/java/org/qed/RelRN.java new file mode 100644 index 0000000..0967127 --- /dev/null +++ b/src/main/java/org/qed/RelRN.java @@ -0,0 +1,224 @@ +package org.qed; + +import kala.collection.Seq; +import kala.collection.Set; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.util.ImmutableBitSet; + +import java.util.Arrays; +import java.util.stream.IntStream; + +public interface RelRN { + static Scan scan(String id, RelType.VarType ty, boolean unique) { + return new Scan(id, ty, unique); + } + + static Scan scan(String id, String typeName) { + return scan(id, RexRN.varType(typeName, true), false); + } + + RelNode semantics(); + + default RexRN field(int ordinal) { + return new RexRN.Field(ordinal, this); + } + + default Seq fields(int... ordinals) { + return Seq.from(Arrays.stream(ordinals).iterator()).map(this::field); + } + + default Seq fields() { + return fields(IntStream.range(0, semantics().getRowType().getFieldCount()).toArray()); + } + + default RexRN joinField(int ordinal, RelRN right) { + return new RexRN.JoinField(ordinal, this, right); + } + + default Seq joinFields(RelRN right, int... ordinals) { + return Seq.from(Arrays.stream(ordinals).iterator()).map(i -> joinField(i, right)); + } + + default Seq joinFields(RelRN right) { + return joinFields(right, IntStream.range(0, + semantics().getRowType().getFieldCount() + right.semantics().getRowType().getFieldCount()).toArray()); + } + + default RexRN.Pred pred(SqlOperator op) { + return new RexRN.Pred(op, fields()); + } + + default RexRN.Pred pred(String name) { + return pred(RuleBuilder.create().genericPredicateOp(name, true)); + } + + default RexRN.Pred joinPred(SqlOperator op, RelRN right) { + return new RexRN.Pred(op, joinFields(right)); + } + + default RexRN.Pred joinPred(String name, RelRN right) { + return joinPred(RuleBuilder.create().genericPredicateOp(name, true), right); + } + + default RexRN.Proj proj(SqlOperator op) { + return new RexRN.Proj(op, fields()); + } + + default RexRN.Proj proj(String name, String type_name) { + return proj(RuleBuilder.create().genericProjectionOp(name, new RelType.VarType(type_name, true))); + } + + default Filter filter(RexRN cond) { + return new Filter(cond, this); + } + + default Filter filter(String name) { + return filter(pred(name)); + } + + default Project project(RexRN proj) { + return new Project(proj, this); + } + + default Project project(String name, String type_name) { + return project(proj(name, type_name)); + } + + default Join join(Join.JoinType ty, RexRN cond, RelRN right) { + return new Join(ty, cond, this, right); + } + + default Join join(JoinRelType ty, RexRN cond, RelRN right) { + return join(new Join.JoinType.ConcreteJoinType(ty), cond, right); + } + + default Join join(JoinRelType ty, String name, RelRN right) {return join(ty, joinPred(name, right), right);} + + default Union union(boolean all, RelRN... sources) { + return new Union(all, Seq.of(this).appendedAll(sources)); + } + + default Intersect intersect(boolean all, RelRN... sources) { + return new Intersect(all, Seq.of(this).appendedAll(sources)); + } + + default Minus minus (boolean all, RelRN... sources) { + return new Minus(all, Seq.of(this).appendedAll(sources)); + } + + default Empty empty() { + return new Empty(this); + } + + default Aggregate aggregate(Seq groupSet, Seq aggCalls) { + return new Aggregate(this, groupSet, aggCalls); + } + + record Scan(String name, RelType.VarType ty, boolean unique) implements RelRN { + + @Override + public RelNode semantics() { + var table = new QedTable(name, Seq.of("col-" + name), Seq.of(ty), unique ? + Set.of(ImmutableBitSet.of(0)) : Set.empty(), Set.empty()); + return RuleBuilder.create().addTable(table).scan(name).build(); + } + } + + record Filter(RexRN cond, RelRN source) implements RelRN { + @Override + public RelNode semantics() { + return RuleBuilder.create().push(source.semantics()).filter(cond.semantics()).build(); + } + } + + record Project(RexRN map, RelRN source) implements RelRN { + @Override + public RelNode semantics() { + return RuleBuilder.create().push(source.semantics()).project(map.semantics()).build(); + } + } + + record Join(Join.JoinType ty, RexRN cond, RelRN left, RelRN right) implements RelRN { + @Override + public RelNode semantics() { + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).join(ty.semantics(), + cond.semantics()).build(); + } + + @Override + public RexRN field(int ordinal) { + return new RexRN.JoinField(ordinal, left, right); + } + + public interface JoinType { + JoinRelType semantics(); + + record ConcreteJoinType(JoinRelType type) implements JoinType { + @Override + public JoinRelType semantics() { + return type; + } + } + + record MetaJoinType(String name) implements JoinType { + @Override + public JoinRelType semantics() { + return JoinRelType.INNER; + } + } + } + + } + + record Union(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).union(all, sources.size()).build(); + } + } + + record Intersect(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).intersect(all, sources.size()).build(); + } + } + + record Minus(boolean all, Seq sources) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().pushAll(sources.map(RelRN::semantics)).minus(all, sources.size()).build(); + } + } + + record Empty(RelRN sourceType) implements RelRN { + + @Override + public RelNode semantics() { + return RuleBuilder.create().values(sourceType.semantics().getRowType()).build(); + } + } + + record AggCall(String name, boolean distinct, RelType type, Seq operands) { + } + + record Aggregate(org.qed.RelRN source, Seq groupSet, Seq aggCalls) implements org.qed.RelRN { + @Override + public RelNode semantics() { + var builder = RuleBuilder.create(); + builder.push(source.semantics()); + var groupKey = builder.groupKey(groupSet.map(RexRN::semantics)); + var calls = aggCalls.map(agg -> { + var aggFunc = builder.genericAggregateOp(agg.name(), agg.type()); + return builder.aggregateCall(aggFunc, agg.distinct(), null, agg.name(), agg.operands().map(RexRN::semantics).asJava()); + }); + return builder.aggregate(groupKey, calls).build(); + } + } + +} \ No newline at end of file diff --git a/src/main/java/org/qed/RelType.java b/src/main/java/org/qed/RelType.java index 19e0d6a..d2b4328 100644 --- a/src/main/java/org/qed/RelType.java +++ b/src/main/java/org/qed/RelType.java @@ -7,6 +7,15 @@ import org.apache.calcite.sql.type.SqlTypeName; public sealed interface RelType extends RelDataType { + static RelType fromString(String name, boolean nullable) { + for (var tn : SqlTypeName.values()) { + if (tn.getName().equals(name)) { + return new BaseType(tn, nullable); + } + } + return new VarType(name, nullable); + } + final class VarType extends RelDataTypeImpl implements RelType { private final String name; private final boolean nullable; @@ -32,6 +41,10 @@ protected void generateTypeString(StringBuilder sb, boolean withDetail) { public boolean isNullable() { return nullable; } + + public String getName() { + return "INTEGER"; + } } final class BaseType extends BasicSqlType implements RelType { @@ -39,13 +52,4 @@ public BaseType(SqlTypeName typeName, boolean nullable) { super(RelDataTypeSystem.DEFAULT, typeName, nullable); } } - - static RelType fromString(String name, boolean nullable) { - for (var tn : SqlTypeName.values()) { - if (tn.getName().equals(name)) { - return new BaseType(tn, nullable); - } - } - return new VarType(name, nullable); - } } diff --git a/src/main/java/org/qed/RexRN.java b/src/main/java/org/qed/RexRN.java new file mode 100644 index 0000000..81d174f --- /dev/null +++ b/src/main/java/org/qed/RexRN.java @@ -0,0 +1,123 @@ +package org.qed; + +import kala.collection.Seq; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlOperator; + +public interface RexRN { + + static RelType.VarType varType(String id, boolean nullable) { + return new RelType.VarType(id, nullable); + } + + static And and(RexRN... sources) { + return new And(Seq.from(sources)); + } + + static False falseLiteral() { + return new False(); + } + + static True trueLiteral() { + return new True(); + } + + RexNode semantics(); + + default Pred pred(SqlOperator op) { + return new Pred(op, Seq.of(this)); + } + + default Pred pred(String name) { + return pred(RuleBuilder.create().genericPredicateOp(name, true)); + } + + default Proj proj(SqlOperator op) { + return new Proj(op, Seq.of(this)); + } + + default Proj proj(String name, String type_name) { + return proj(RuleBuilder.create().genericProjectionOp(name, new RelType.VarType(type_name, true))); + } + + + record Field(int ordinal, RelRN source) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().push(source.semantics()).field(ordinal); + } + } + + record JoinField(int ordinal, RelRN left, RelRN right) implements RexRN { + + @Override + public RexNode semantics() { + var leftCols = left.semantics().getRowType().getFieldCount(); + return RuleBuilder.create().push(left.semantics()).push(right.semantics()).field(2, ordinal < leftCols ? + 0 : 1, ordinal < leftCols ? ordinal : ordinal - leftCols); + } + } + + record Pred(SqlOperator operator, Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + var builder = RuleBuilder.create(); +// builder.genericPredicateOp(name, nullable) + return builder.call(operator, sources.map(RexRN::semantics)); + } + } + + record Proj(SqlOperator operator, Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + var builder = RuleBuilder.create(); +// builder.genericProjectionOp(name, varType(type_name, nullable)) + return builder.call(operator, + sources.map(RexRN::semantics)); + } + } + + record And(Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().and(sources.map(RexRN::semantics)); + } + } + + record False() implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().literal(false); + } + } + + record Not(RexRN source) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().not(source.semantics()); + } + } + + record Or(Seq sources) implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().or(sources.map(RexRN::semantics)); + } + } + + record True() implements RexRN { + + @Override + public RexNode semantics() { + return RuleBuilder.create().literal(true); + } + } + +} diff --git a/src/main/java/org/qed/SchemaGenerator.java b/src/main/java/org/qed/SchemaGenerator.java index bd47b37..484b1ef 100644 --- a/src/main/java/org/qed/SchemaGenerator.java +++ b/src/main/java/org/qed/SchemaGenerator.java @@ -227,7 +227,8 @@ public void addTable(SqlCreateTable createTable) { } } var qedTable = new QedTable(createTable.name.toString(), names.zip( - types.zip(nullabilities).map(type -> new RelType.BaseType(type.component1(), type.component2()))) + types.zip(nullabilities).map(type -> new RelType.BaseType(type.component1(), + type.component2()))) .toImmutableMap(), ImmutableSet.from(keys), ImmutableSet.from(checkConstraints)); tables.put(createTable.name.toString(), qedTable); }