Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions src/main/java/org/qed/Generated/MySQLGenerator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package org.qed.Generated;

import org.qed.Generated.RRuleInstances.JoinCommute;
import org.qed.RelRN;
import org.qed.RexRN;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class MySQLGenerator {

private int subqueryCounter = 0;
private final String tableName;
private final List<String> columnNames;

public MySQLGenerator(String tableName, List<String> columnNames) {
this.tableName = tableName;
this.columnNames = columnNames;
}

private static class FlattenedSQLParts {
String fromClause = "";
List<String> projections = new ArrayList<>();
List<String> conditions = new ArrayList<>();
}

public String translate(String name, RelRN before, RelRN after) {
String beforeSQL;
String afterSQL;

if (name.equals("JoinCommute")) {
subqueryCounter = 0;
beforeSQL = transformNested(before, true, false, new AtomicInteger(0));
subqueryCounter = 0;
afterSQL = transformNested(before, true, true, new AtomicInteger(0));
} else {
subqueryCounter = 0;
beforeSQL = transformNested(before, true, false, new AtomicInteger(0));
afterSQL = transformFlatten(after);
}

return "INSERT INTO query_rewrite.rewrite_rules\n" +
" (pattern, replacement) VALUES(\n" +
" '" + beforeSQL + "',\n" +
" '" + afterSQL + "'\n" +
");";
}

private String transformNested(RelRN node, boolean isRoot, boolean swapJoinSides, AtomicInteger filterIndex) {
if (node instanceof RelRN.Scan) {
return "SELECT * FROM " + tableName;
} else if (node instanceof RelRN.Project project) {
String cols = String.join(", ", columnNames);
if (project.source() instanceof RelRN.Scan) {
return "SELECT " + cols + " FROM " + tableName;
}
String innerSQL = transformNested(project.source(), false, swapJoinSides, filterIndex);
String alias = "t" + (subqueryCounter++);
return "SELECT " + cols + " FROM (" + innerSQL + ") AS " + alias;
} else if (node instanceof RelRN.Filter filter) {
String innerSQL = transformNested(filter.source(), false, swapJoinSides, filterIndex);
int currentIndex = filterIndex.getAndIncrement();
String condition = (currentIndex < columnNames.size())
? columnNames.get(currentIndex) + " = ?"
: columnNames.get(0) + " = ?";

if (isRoot) {
return innerSQL + " WHERE " + condition;
} else {
String alias = "t" + (subqueryCounter++);
return "SELECT * FROM (" + innerSQL + " WHERE " + condition + ") AS " + alias;
}
} else if (node instanceof RelRN.Join join) {
String leftAlias = "t0";
String rightAlias = "t1";

RelRN firstNode = swapJoinSides ? join.right() : join.left();
String firstAlias = swapJoinSides ? rightAlias : leftAlias;
RelRN secondNode = swapJoinSides ? join.left() : join.right();
String secondAlias = swapJoinSides ? leftAlias : rightAlias;

String firstSQL = "(" + transformNested(firstNode, false, swapJoinSides, filterIndex) + ")";
String secondSQL = "(" + transformNested(secondNode, false, swapJoinSides, filterIndex) + ")";

String joinCond = renderJoinCondition(join.cond(), leftAlias, rightAlias, swapJoinSides);

String joinExpr =
firstSQL + " AS " + firstAlias +
" " + join.ty().semantics().name() + " JOIN " +
secondSQL + " AS " + secondAlias +
" ON " + joinCond;

if (isRoot) {
return "SELECT * FROM " + joinExpr;
} else {
String alias = "t" + (subqueryCounter++);
return "SELECT * FROM (" + joinExpr + ") AS " + alias;
}

} else if (node instanceof JoinCommute.ProjectionRelRN projRN) {
return transformNested(projRN.source(), isRoot, swapJoinSides, filterIndex);
} else {
throw new UnsupportedOperationException("Unsupported RelRN: " + node);
}
}

private String renderJoinCondition(RexRN cond, String leftAlias, String rightAlias, boolean swap) {
if (cond instanceof RexRN.Pred p) {
if (p.sources().get(0) instanceof RexRN.JoinField jf) {
String colName = columnNames.get(jf.ordinal());
String first = swap ? rightAlias : leftAlias;
String second = swap ? leftAlias : rightAlias;
return first + "." + colName + " = " + second + "." + colName;
}
}
throw new UnsupportedOperationException("Unsupported join condition: " + cond);
}

public String transformFlatten(RelRN node) {
FlattenedSQLParts parts = new FlattenedSQLParts();
collectFlattenedParts(node, parts);
String selectClause = parts.projections.isEmpty() ? "SELECT *" : "SELECT " + String.join(", ", parts.projections);
String whereClause = parts.conditions.isEmpty() ? "" : " WHERE " + String.join(" AND ", parts.conditions);
return selectClause + " FROM " + parts.fromClause + whereClause;
}

private void collectFlattenedParts(RelRN node, FlattenedSQLParts parts) {
switch (node) {
case RelRN.Scan scan -> parts.fromClause = tableName;
case RelRN.Project project -> {
collectFlattenedParts(project.source(), parts);
parts.projections.addAll(columnNames);
}
case RelRN.Filter filter -> {
collectFlattenedParts(filter.source(), parts);
collectPredConditions(filter.cond(), parts.conditions);
}
default -> throw new UnsupportedOperationException("Unsupported RelRN for flatten: " + node);
}
}

private void collectPredConditions(RexRN pred, List<String> conditions) {
if (pred instanceof RexRN.Pred) {
int currentConditions = conditions.size();
if (currentConditions < columnNames.size()) {
conditions.add(columnNames.get(currentConditions) + " = ?");
}
} else if (pred instanceof RexRN.And and) {
for (RexRN child : and.sources()) {
collectPredConditions(child, conditions);
}
}
}
}
43 changes: 43 additions & 0 deletions src/main/java/org/qed/Generated/MySQLTester.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.qed.Generated;

import org.qed.*;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

public class MySQLTester {

public static String genPath = "src/main/java/org/qed/Generated/mysql";

public static String tableName = "testdb.users";
public static List<String> columnNames = List.of("id", "status");

public static void main(String[] args) {
var filterRule = new org.qed.Generated.RRuleInstances.FilterMerge();
new MySQLTester().serializeWithNumericSuffix(filterRule, genPath);

var projectRule = new org.qed.Generated.RRuleInstances.ProjectMerge();
new MySQLTester().serializeWithNumericSuffix(projectRule, genPath);

var joinCommute = new org.qed.Generated.RRuleInstances.JoinCommute();
new MySQLTester().serializeWithNumericSuffix(joinCommute, genPath);
}

public void serializeWithNumericSuffix(RRule rule, String path) {
serialize(rule, path, tableName, columnNames, 1);
serialize(rule, path, tableName, List.of(columnNames.get(1), columnNames.get(0)), 2);
}

private void serialize(RRule rule, String path, String tableName, List<String> colNames, int fileIndex) {
var generator = new MySQLGenerator(tableName, colNames);
var codeGen = generator.translate(rule.name(), rule.before(), rule.after());
try {
Files.createDirectories(Path.of(path));
String fileName = rule.name() + fileIndex + ".sql";
Files.write(Path.of(path, fileName), codeGen.getBytes());
} catch (IOException ioe) {
System.err.println(ioe.getMessage());
}
}
}
135 changes: 135 additions & 0 deletions src/main/java/org/qed/Generated/ProxySQLGenerator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
package org.qed.Generated;

import org.qed.Generated.RRuleInstances.JoinCommute;
import org.qed.RelRN;
import org.qed.RexRN;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

public class ProxySQLGenerator {

private final Map<RexRN, Integer> predicateToGroupIndex = new HashMap<>();
private final AtomicInteger groupCounter = new AtomicInteger(1);
private boolean isReduceTrueRule = false;

public String translate(int ruleId, String name, RelRN before, RelRN after) {
predicateToGroupIndex.clear();
groupCounter.set(1);
this.isReduceTrueRule = false;

String matchPattern = generateMatchPattern(before);
String replacePattern = generateReplacePattern(after);

return String.format(
"""
INSERT INTO mysql_query_rules (rule_id, active, match_pattern, replace_pattern)
VALUES (
%d, 1,
'^%s',
'%s'
);""",
ruleId, matchPattern, replacePattern
);
}

private String generateMatchPattern(RelRN node) {
return switch (node) {
case RelRN.Filter filter -> {
if (filter.cond() instanceof RexRN.True) {
this.isReduceTrueRule = true;
groupCounter.addAndGet(2);
yield "SELECT (.*) FROM (.*) WHERE TRUE";
}

if (filter.cond() instanceof RexRN.False) {
groupCounter.addAndGet(2);
yield "SELECT (.*) FROM (.*) WHERE FALSE";
}

String sourcePattern = generateMatchPattern(filter.source());
String conditionRegex = "(.*) = (.*)";
if (filter.source() instanceof RelRN.Scan) {
int conditionGroupStart = groupCounter.get();
groupCounter.addAndGet(2);
predicateToGroupIndex.put(filter.cond(), conditionGroupStart);
yield sourcePattern + " WHERE " + conditionRegex;
} else {
groupCounter.getAndIncrement();
int conditionGroupStart = groupCounter.get();
groupCounter.addAndGet(2);
predicateToGroupIndex.put(filter.cond(), conditionGroupStart);
yield String.format("SELECT \\* FROM \\(%s\\) AS (.*) WHERE %s", sourcePattern, conditionRegex);
}
}
case RelRN.Project project -> {
if (project.source() instanceof RelRN.Project innerProject && innerProject.source() instanceof RelRN.Scan) {
yield "SELECT (.*) FROM \\(SELECT (.*) FROM (.*)\\) AS (.*)";
}
throw new UnsupportedOperationException("This generator only supports the specific Project(Project(Scan)) pattern.");
}
case RelRN.Join join -> {
if (join.left() instanceof RelRN.Scan && join.right() instanceof RelRN.Scan) {
yield "SELECT \\* FROM (.*) AS (.*?) INNER JOIN (.*) AS (.*?) ON (.*?)\\.(.*?) = (.*?)\\.(.*)";
}
throw new UnsupportedOperationException("This generator only supports simple Scan-Join-Scan patterns.");
}
case RelRN.Scan scan -> {
groupCounter.getAndIncrement();
yield "SELECT \\* FROM (.*)";
}
default -> throw new UnsupportedOperationException("Unsupported RelRN for match pattern: " + node.getClass().getSimpleName());
};
}

private String generateReplacePattern(RelRN node) {
return switch (node) {
case RelRN.Empty empty -> {
yield "SELECT \\1 FROM \\2 LIMIT 0";
}
case JoinCommute.ProjectionRelRN proj -> {
if (proj.source() instanceof RelRN.Join) {
yield "SELECT * FROM \\3 AS \\4 INNER JOIN \\1 AS \\2 ON \\7.\\8 = \\5.\\6";
}
throw new UnsupportedOperationException("Unsupported 'after' pattern for JoinCommute.");
}
case RelRN.Filter filter -> {
String fromClause = generateReplacePattern(filter.source());
String whereClause = buildWhereClause(filter.cond());
yield String.format("%s WHERE %s", fromClause, whereClause);
}
case RelRN.Project project -> {
if (project.source() instanceof RelRN.Scan) {
yield "SELECT \\1 FROM \\3";
}
throw new UnsupportedOperationException("Unsupported 'after' pattern for ProjectMerge.");
}
case RelRN.Scan scan -> {
if (this.isReduceTrueRule) {
yield "SELECT \\1 FROM \\2";
} else {
yield "SELECT * FROM \\1";
}
}
default -> throw new UnsupportedOperationException("Unsupported RelRN for replace pattern: " + node.getClass().getSimpleName());
};
}

private String buildWhereClause(RexRN condition) {
return switch (condition) {
case RexRN.And andNode -> andNode.sources().stream()
.map(this::buildWhereClause)
.collect(Collectors.joining(" AND "));
case RexRN.Pred pred -> {
Integer groupIndex = predicateToGroupIndex.get(pred);
if (groupIndex == null) {
throw new IllegalStateException("Predicate from 'after' tree not found in 'before' tree: " + pred);
}
yield String.format("\\%d = \\%d", groupIndex, groupIndex + 1);
}
default -> throw new UnsupportedOperationException("Unsupported RexRN for WHERE clause: " + condition.getClass().getSimpleName());
};
}
}
57 changes: 57 additions & 0 deletions src/main/java/org/qed/Generated/ProxySQLTester.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package org.qed.Generated;

import org.qed.RRule;
import org.qed.Generated.RRuleInstances.FilterMerge;
import org.qed.Generated.RRuleInstances.ProjectMerge;
import org.qed.Generated.RRuleInstances.JoinCommute;
import org.qed.Generated.RRuleInstances.FilterReduceFalse;
import org.qed.Generated.RRuleInstances.FilterReduceTrue; // Import the new rule

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

public class ProxySQLTester {

public static final String OUTPUT_PATH = "src/main/java/org/qed/Generated/proxysql";

private int nextRuleId = 10;

public static void main(String[] args) {
var tester = new ProxySQLTester();

List<RRule> rulesToGenerate = List.of(
new FilterMerge(),
new ProjectMerge(),
new JoinCommute(),
new FilterReduceFalse(),
new FilterReduceTrue()
);

for (RRule rule : rulesToGenerate) {
tester.generateRuleFile(rule);
System.out.println();
}
}

public void generateRuleFile(RRule rule) {
int currentRuleId = this.nextRuleId;
this.nextRuleId += 10;

var generator = new ProxySQLGenerator();
String ruleSql = generator.translate(currentRuleId, rule.name(), rule.before(), rule.after());

try {
Path outputDir = Path.of(OUTPUT_PATH);
Files.createDirectories(outputDir);

String fileName = rule.name() + ".sql";
Path filePath = outputDir.resolve(fileName);
Files.writeString(filePath, ruleSql);

} catch (IOException | UnsupportedOperationException e) {
e.printStackTrace();
}
}
}
Loading