Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/main/java/org/qed/Generated/MySQL/FilterMerge1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
INSERT INTO query_rewrite.rewrite_rules
(pattern, replacement) VALUES(
'SELECT * FROM (SELECT * FROM testdb.users WHERE id = ?) AS t0 WHERE status = ?',
'SELECT * FROM testdb.users WHERE id = ? AND status = ?'
);
5 changes: 5 additions & 0 deletions src/main/java/org/qed/Generated/MySQL/FilterMerge2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
INSERT INTO query_rewrite.rewrite_rules
(pattern, replacement) VALUES(
'SELECT * FROM (SELECT * FROM testdb.users WHERE status = ?) AS t0 WHERE id = ?',
'SELECT * FROM testdb.users WHERE status = ? AND id = ?'
);
5 changes: 5 additions & 0 deletions src/main/java/org/qed/Generated/MySQL/JoinCommute1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
INSERT INTO query_rewrite.rewrite_rules
(pattern, replacement) VALUES(
'(SELECT * FROM testdb.users) AS t0 INNER JOIN (SELECT * FROM testdb.users) AS t1 ON t0.id = t1.id',
'(SELECT * FROM testdb.users) AS t1 INNER JOIN (SELECT * FROM testdb.users) AS t0 ON t1.id = t0.id'
);
5 changes: 5 additions & 0 deletions src/main/java/org/qed/Generated/MySQL/JoinCommute2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
INSERT INTO query_rewrite.rewrite_rules
(pattern, replacement) VALUES(
'(SELECT * FROM testdb.users) AS t0 INNER JOIN (SELECT * FROM testdb.users) AS t1 ON t0.status = t1.status',
'(SELECT * FROM testdb.users) AS t1 INNER JOIN (SELECT * FROM testdb.users) AS t0 ON t1.status = t0.status'
);
5 changes: 5 additions & 0 deletions src/main/java/org/qed/Generated/MySQL/ProjectMerge1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
INSERT INTO query_rewrite.rewrite_rules
(pattern, replacement) VALUES(
'SELECT id, status FROM (SELECT id, status FROM testdb.users) AS t0',
'SELECT id, status FROM testdb.users'
);
5 changes: 5 additions & 0 deletions src/main/java/org/qed/Generated/MySQL/ProjectMerge2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
INSERT INTO query_rewrite.rewrite_rules
(pattern, replacement) VALUES(
'SELECT status, id FROM (SELECT status, id FROM testdb.users) AS t0',
'SELECT status, id FROM testdb.users'
);
147 changes: 147 additions & 0 deletions src/main/java/org/qed/Generated/MySQLGenerator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package org.qed.Generated;

import org.qed.Generated.RRuleInstances.JoinCommute;
import org.qed.RelRN;
import org.qed.RexRN;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public class MySQLGenerator {

private int subqueryCounter = 0;
private final String tableName;
private final List<String> columnNames;

public MySQLGenerator(String tableName, List<String> columnNames) {
this.tableName = tableName;
this.columnNames = columnNames;
}

private static class FlattenedSQLParts {
String fromClause = "";
List<String> projections = new ArrayList<>();
List<String> conditions = new ArrayList<>();
}

public String translate(String name, RelRN before, RelRN after) {
String beforeSQL;
String afterSQL;

if (name.equals("JoinCommute")) {
subqueryCounter = 0;
beforeSQL = transformNested(before, true, false, new AtomicInteger(0));
subqueryCounter = 0;
afterSQL = transformNested(before, true, true, new AtomicInteger(0));
} else {
subqueryCounter = 0;
beforeSQL = transformNested(before, true, false, new AtomicInteger(0));
afterSQL = transformFlatten(after);
}

return "INSERT INTO query_rewrite.rewrite_rules\n" +
" (pattern, replacement) VALUES(\n" +
" '" + beforeSQL + "',\n" +
" '" + afterSQL + "'\n" +
");";
}

private String transformNested(RelRN node, boolean isRoot, boolean swapJoinSides, AtomicInteger filterIndex) {
if (node instanceof RelRN.Scan) {
return "SELECT * FROM " + tableName;
} else if (node instanceof RelRN.Project project) {
String cols = String.join(", ", columnNames);
if (project.source() instanceof RelRN.Scan) {
return "SELECT " + cols + " FROM " + tableName;
}
String innerSQL = transformNested(project.source(), false, swapJoinSides, filterIndex);
String alias = "t" + (subqueryCounter++);
return "SELECT " + cols + " FROM (" + innerSQL + ") AS " + alias;
} else if (node instanceof RelRN.Filter filter) {
String innerSQL = transformNested(filter.source(), false, swapJoinSides, filterIndex);
int currentIndex = filterIndex.getAndIncrement();
String condition = (currentIndex < columnNames.size())
? columnNames.get(currentIndex) + " = ?"
: columnNames.get(0) + " = ?";

if (isRoot) {
return innerSQL + " WHERE " + condition;
} else {
String alias = "t" + (subqueryCounter++);
return "SELECT * FROM (" + innerSQL + " WHERE " + condition + ") AS " + alias;
}
} else if (node instanceof RelRN.Join join) {
String leftAlias = "t0";
String rightAlias = "t1";

RelRN firstNode = swapJoinSides ? join.right() : join.left();
String firstAlias = swapJoinSides ? rightAlias : leftAlias;
RelRN secondNode = swapJoinSides ? join.left() : join.right();
String secondAlias = swapJoinSides ? leftAlias : rightAlias;

String firstSQL = "(" + transformNested(firstNode, false, swapJoinSides, filterIndex) + ")";
String secondSQL = "(" + transformNested(secondNode, false, swapJoinSides, filterIndex) + ")";

String joinCond = renderJoinCondition(join.cond(), leftAlias, rightAlias, swapJoinSides);

return firstSQL + " AS " + firstAlias +
" " + join.ty().semantics().name() + " JOIN " +
secondSQL + " AS " + secondAlias +
" ON " + joinCond;

} else if (node instanceof JoinCommute.ProjectionRelRN projRN) {
return transformNested(projRN.source(), isRoot, swapJoinSides, filterIndex);
} else {
throw new UnsupportedOperationException("Unsupported RelRN: " + node);
}
}

private String renderJoinCondition(RexRN cond, String leftAlias, String rightAlias, boolean swap) {
if (cond instanceof RexRN.Pred p) {
if (p.sources().get(0) instanceof RexRN.JoinField jf) {
String colName = columnNames.get(jf.ordinal());
String first = swap ? rightAlias : leftAlias;
String second = swap ? leftAlias : rightAlias;
return first + "." + colName + " = " + second + "." + colName;
}
}
throw new UnsupportedOperationException("Unsupported join condition: " + cond);
}

public String transformFlatten(RelRN node) {
FlattenedSQLParts parts = new FlattenedSQLParts();
collectFlattenedParts(node, parts);
String selectClause = parts.projections.isEmpty() ? "SELECT *" : "SELECT " + String.join(", ", parts.projections);
String whereClause = parts.conditions.isEmpty() ? "" : " WHERE " + String.join(" AND ", parts.conditions);
return selectClause + " FROM " + parts.fromClause + whereClause;
}

private void collectFlattenedParts(RelRN node, FlattenedSQLParts parts) {
switch (node) {
case RelRN.Scan scan -> parts.fromClause = tableName;
case RelRN.Project project -> {
collectFlattenedParts(project.source(), parts);
parts.projections.addAll(columnNames);
}
case RelRN.Filter filter -> {
collectFlattenedParts(filter.source(), parts);
collectPredConditions(filter.cond(), parts.conditions);
}
default -> throw new UnsupportedOperationException("Unsupported RelRN for flatten: " + node);
}
}

private void collectPredConditions(RexRN pred, List<String> conditions) {
if (pred instanceof RexRN.Pred) {
int currentConditions = conditions.size();
if (currentConditions < columnNames.size()) {
conditions.add(columnNames.get(currentConditions) + " = ?");
}
} else if (pred instanceof RexRN.And and) {
for (RexRN child : and.sources()) {
collectPredConditions(child, conditions);
}
}
}
}
43 changes: 43 additions & 0 deletions src/main/java/org/qed/Generated/MySQLTester.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package org.qed.Generated;

import org.qed.*;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

public class MySQLTester {

public static String genPath = "src/main/java/org/qed/Generated/MySQL";

public static String tableName = "testdb.users";
public static List<String> columnNames = List.of("id", "status");

public static void main(String[] args) {
var filterRule = new org.qed.Generated.RRuleInstances.FilterMerge();
new MySQLTester().serializeWithNumericSuffix(filterRule, genPath);

var projectRule = new org.qed.Generated.RRuleInstances.ProjectMerge();
new MySQLTester().serializeWithNumericSuffix(projectRule, genPath);

var joinCommute = new org.qed.Generated.RRuleInstances.JoinCommute();
new MySQLTester().serializeWithNumericSuffix(joinCommute, genPath);
}

public void serializeWithNumericSuffix(RRule rule, String path) {
serialize(rule, path, tableName, columnNames, 1);
serialize(rule, path, tableName, List.of(columnNames.get(1), columnNames.get(0)), 2);
}

private void serialize(RRule rule, String path, String tableName, List<String> colNames, int fileIndex) {
var generator = new MySQLGenerator(tableName, colNames);
var codeGen = generator.translate(rule.name(), rule.before(), rule.after());
try {
Files.createDirectories(Path.of(path));
String fileName = rule.name() + fileIndex + ".sql";
Files.write(Path.of(path, fileName), codeGen.getBytes());
} catch (IOException ioe) {
System.err.println(ioe.getMessage());
}
}
}
66 changes: 66 additions & 0 deletions src/main/java/org/qed/Generated/Tests-MySQL/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import mysql.connector
from pathlib import Path

MYSQL_USER = "root"
MYSQL_PASSWORD = "wkaiz"
MYSQL_DATABASE = "query_rewrite"
SQL_FILE = Path("C:\\Users\\wesle\\OneDrive\\Desktop\\parser\\src\\main\\java\\org\\qed\\Generated\\MySQL\\FilterMerge2.sql") # path to your SQL file

TEST_QUERY = """
SELECT * FROM (SELECT * FROM testdb.users WHERE status = 'active') AS t0
WHERE id = 1;
"""

conn = mysql.connector.connect(
host="localhost",
user=MYSQL_USER,
password=MYSQL_PASSWORD,
database=MYSQL_DATABASE
)
cursor = conn.cursor()

with SQL_FILE.open("r", encoding="utf-8") as f:
sql_commands = f.read()

for cmd in sql_commands.split(";"):
cmd = cmd.strip()
if cmd:
cursor.execute(cmd + ";")

conn.commit()
print(f"{SQL_FILE} executed successfully.")

cursor.execute("""
DELETE FROM rewrite_rules
WHERE id < (
SELECT max_id FROM (SELECT MAX(id) AS max_id FROM rewrite_rules) AS t
);

""")
conn.commit()
print("Deleted all rules except the last one.")

cursor.execute("CALL flush_rewrite_rules();")
conn.commit()
print("Flushed rewrite rules.")

cursor.execute("SELECT * FROM rewrite_rules;")
print("Current rules in table:")
for row in cursor.fetchall():
print(row)

cursor.execute(TEST_QUERY)
results = cursor.fetchall()

print("\nTest query results:")
for row in results:
print(row)

cursor.execute("SHOW WARNINGS;")
warnings = cursor.fetchall()
print("\nWarnings (should indicate rewrite):")
for w in warnings:
print(w)

cursor.close()
conn.close()