diff --git a/src/main/java/org/qed/Generated/CalciteGenerator.java b/src/main/java/org/qed/Generated/CalciteGenerator.java index 32691cd..4a74dc1 100644 --- a/src/main/java/org/qed/Generated/CalciteGenerator.java +++ b/src/main/java/org/qed/Generated/CalciteGenerator.java @@ -116,12 +116,12 @@ public Env onMatchAnd(Env env, RexRN.And and) { String andSymbol = "and_" + env.varId.getAndIncrement(); // Store the current expression as this AND node's symbol current_env = current_env.symbol(andSymbol, current_env.current()); - + // Process each child source in the AND condition for (var source : and.sources()) { current_env = onMatch(current_env, source); } - + return current_env; } @@ -129,11 +129,11 @@ public Env onMatchAnd(Env env, RexRN.And and) { public Env onMatchUnion(Env env, RelRN.Union union) { // Get the all flag from the union boolean all = union.all(); - + // Process each source in the union var current_env = env; var skeletons = Seq.empty(); - + // Process all sources in the sequence for (var source : union.sources()) { var next_env = current_env.next(); @@ -141,7 +141,7 @@ public Env onMatchUnion(Env env, RelRN.Union union) { skeletons = skeletons.appended(source_env.skeleton()); current_env = source_env; } - + // Build the input skeletons string for the operand StringBuilder inputsBuilder = new StringBuilder(); for (int i = 0; i < skeletons.size(); i++) { @@ -150,7 +150,7 @@ public Env onMatchUnion(Env env, RelRN.Union union) { } inputsBuilder.append(skeletons.get(i).toString()); } - + // Create the union operand with the appropriate class based on the all flag String operatorClass = all ? "LogicalUnionAll" : "LogicalUnion"; return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); @@ -160,11 +160,11 @@ public Env onMatchUnion(Env env, RelRN.Union union) { public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { // Get the all flag from the intersect boolean all = intersect.all(); - + // Process each source in the intersect var current_env = env; var skeletons = Seq.empty(); - + // Process all sources in the sequence for (var source : intersect.sources()) { var next_env = current_env.next(); @@ -172,7 +172,7 @@ public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { skeletons = skeletons.appended(source_env.skeleton()); current_env = source_env; } - + // Build the input skeletons string for the operand StringBuilder inputsBuilder = new StringBuilder(); for (int i = 0; i < skeletons.size(); i++) { @@ -181,7 +181,7 @@ public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { } inputsBuilder.append(skeletons.get(i).toString()); } - + // Create the intersect operand with the appropriate class based on the all flag String operatorClass = all ? "LogicalIntersectAll" : "LogicalIntersect"; return current_env.grow("operand(" + operatorClass + ".class).inputs(" + inputsBuilder.toString() + ")"); @@ -191,11 +191,11 @@ public Env onMatchIntersect(Env env, RelRN.Intersect intersect) { public Env onMatchMinus(Env env, RelRN.Minus minus) { // Get the all flag from the minus boolean all = minus.all(); - + // Process each source in the minus var current_env = env; var skeletons = Seq.empty(); - + // Process all sources in the sequence for (var source : minus.sources()) { var next_env = current_env.next(); @@ -203,7 +203,7 @@ public Env onMatchMinus(Env env, RelRN.Minus minus) { skeletons = skeletons.appended(source_env.skeleton()); current_env = source_env; } - + // Build the input skeletons string for the operand StringBuilder inputsBuilder = new StringBuilder(); for (int i = 0; i < skeletons.size(); i++) { @@ -212,7 +212,7 @@ public Env onMatchMinus(Env env, RelRN.Minus minus) { } inputsBuilder.append(skeletons.get(i).toString()); } - + // Create the minus operand return current_env.grow("operand(LogicalMinus.class).inputs(" + inputsBuilder.toString() + ")"); } @@ -221,7 +221,7 @@ public Env onMatchMinus(Env env, RelRN.Minus minus) { public Env onMatchField(Env env, RexRN.Field field) { // Generate a unique symbolic name for this field String fieldSymbol = "field_" + env.varId.getAndIncrement(); - + // Store the field expression in the environment's symbol table return env.symbol(fieldSymbol, env.current()); } @@ -230,7 +230,7 @@ public Env onMatchField(Env env, RexRN.Field field) { public Env onMatchTrue(Env env, RexRN literal) { // Create a unique symbol name for this true literal String trueSymbol = "true_" + env.varId.getAndIncrement(); - + // Store the current expression as this true literal's symbol return env.symbol(trueSymbol, env.current()); } @@ -239,7 +239,7 @@ public Env onMatchTrue(Env env, RexRN literal) { public Env onMatchFalse(Env env, RexRN literal) { // Create a unique symbol name for this false literal String falseSymbol = "false_" + env.varId.getAndIncrement(); - + // Store the current expression as this false literal's symbol return env.symbol(falseSymbol, env.current()); } @@ -290,28 +290,28 @@ public Env transformPred(Env env, RexRN.Pred pred) { pred.sources().get(1) instanceof RexRN.JoinField joinField2 && joinField1.ordinal() == 1 && // First arg is ordinal 1 (left -> right in swap) joinField2.ordinal() == 0; // Second arg is ordinal 0 (right -> left in swap) - + if (isExactJoinCommuteSwapPattern) { // This is the JoinCommute swapped predicate - transform with swapped arguments var currentEnv = env; var transformedArgs = Seq.empty(); - + // Process arguments in reverse order to swap them back var sources = pred.sources(); var reversedSources = Seq.of(sources.get(1), sources.get(0)); - + for (var arg : reversedSources) { currentEnv = transform(currentEnv, arg); transformedArgs = transformedArgs.appended(currentEnv.current()); currentEnv = currentEnv.focus(env.current()); } - + // Create a new predicate call with transformed arguments String argsString = transformedArgs.joinToString(", "); - + // Extract operator from the original join condition String operatorCall = "((org.apache.calcite.rex.RexCall) ((LogicalJoin) call.rel(0)).getCondition()).getOperator()"; - + return currentEnv.focus(env.current() + ".call(" + operatorCall + ", " + argsString + ")"); } else { return env.focus(env.symbols().get(pred.operator().getName())); @@ -327,30 +327,30 @@ public Env transformJoinField(Env env, RexRN.JoinField joinField) { var envWithOrigJoin = origJoinDecl.getValue(); var conditionDecl = envWithOrigJoin.declare("(org.apache.calcite.rex.RexCall) " + origJoinDecl.getKey() + ".getCondition()"); var envWithCondition = conditionDecl.getValue(); - + if (joinField.ordinal() == 0) { - // Ordinal 0 = Left table in original join - // Extract the left operand field index from original condition + // Ordinal 0 = Left table in original join + // Extract the left operand field index from original condition var leftFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(0)).getIndex()"); var envWithLeftField = leftFieldDecl.getValue(); - + // In swapped join: Left table is now at input 1 // Use field(2, 1, leftFieldIndex) syntax return envWithLeftField.focus(env.current() + ".field(2, 1, " + leftFieldDecl.getKey() + ")"); - } + } else if (joinField.ordinal() == 1) { - // Ordinal 1 = Right table in original join + // Ordinal 1 = Right table in original join // Extract the right operand field index from original condition var rightFieldDecl = envWithCondition.declare("((org.apache.calcite.rex.RexInputRef) " + conditionDecl.getKey() + ".getOperands().get(1)).getIndex()"); var envWithRightField = rightFieldDecl.getValue(); - + // Right table field index needs to be adjusted since it was originally after left table var leftColCountDecl = envWithRightField.declare("call.rel(1).getRowType().getFieldCount()"); var envWithLeftCount = leftColCountDecl.getValue(); var adjustedRightFieldDecl = envWithLeftCount.declare(rightFieldDecl.getKey() + " - " + leftColCountDecl.getKey()); var envWithAdjustedRightField = adjustedRightFieldDecl.getValue(); - - // In swapped join: Right table is now at input 0 + + // In swapped join: Right table is now at input 0 // Use field(2, 0, adjustedRightFieldIndex) syntax return envWithAdjustedRightField.focus(env.current() + ".field(2, 0, " + adjustedRightFieldDecl.getKey() + ")"); } else { @@ -391,16 +391,16 @@ public Env transformAnd(Env env, RexRN.And and) { public Env transformUnion(Env env, RelRN.Union union) { // Get the all flag from the union boolean all = union.all(); - + // The number of sources int sourceCount = union.sources().size(); - + // Transform each source var current_env = env; for (var source : union.sources()) { current_env = transform(current_env, source); } - + // Use the union method with the all flag and source count // This matches the Calcite RelBuilder.union(boolean all, int n) signature return current_env.focus(current_env.current() + ".union(" + all + ", " + sourceCount + ")"); @@ -410,16 +410,16 @@ public Env transformUnion(Env env, RelRN.Union union) { public Env transformIntersect(Env env, RelRN.Intersect intersect) { // Get the all flag from the intersect boolean all = intersect.all(); - + // The number of sources int sourceCount = intersect.sources().size(); - + // Transform each source var current_env = env; for (var source : intersect.sources()) { current_env = transform(current_env, source); } - + // Use the intersect method with the all flag and source count // This matches the expected Calcite RelBuilder.intersect(boolean all, int n) signature String methodName = all ? "intersectAll" : "intersect"; @@ -430,16 +430,16 @@ public Env transformIntersect(Env env, RelRN.Intersect intersect) { public Env transformMinus(Env env, RelRN.Minus minus) { // Get the all flag from the minus boolean all = minus.all(); - + // The number of sources int sourceCount = minus.sources().size(); - + // Transform each source var current_env = env; for (var source : minus.sources()) { current_env = transform(current_env, source); } - + // Use the minus method with the all flag and source count // This matches the expected Calcite RelBuilder.minus(boolean all, int n) signature return current_env.focus(current_env.current() + ".minus(" + all + ", " + sourceCount + ")"); @@ -449,7 +449,7 @@ public Env transformMinus(Env env, RelRN.Minus minus) { public Env transformField(Env env, RexRN.Field field) { // In Calcite, field references are typically created with a "field" method // We'll need to pass some identifier for the field - use toString() if no specific field accessor is available - + // Assuming field has a method that returns some kind of identifier or name // If not, we may need to adjust this implementation return env.focus(env.current() + ".field(" + field + ")"); @@ -459,13 +459,13 @@ public Env transformField(Env env, RexRN.Field field) { public Env transformProj(Env env, RexRN.Proj proj) { // In Calcite, projections are typically created using the operator name // This is similar to your transformPred implementation - + // Look up the symbol from the matching phase if (!env.symbols().containsKey(proj.operator().getName())) { - throw new RuntimeException("Operator symbol not found: " + proj.operator().getName() + + throw new RuntimeException("Operator symbol not found: " + proj.operator().getName() + ". Make sure onMatchProj is properly implemented."); } - + // Return an environment focused on the expression for this projection return env.focus(env.symbols().get(proj.operator().getName())); } @@ -475,10 +475,10 @@ public Env transformProject(Env env, RelRN.Project project) { // First transform the source relation var source_transform = transform(env, project.source()); var source_expression = source_transform.current(); - + // Then transform the projection map var map_transform = transform(source_transform, project.map()); - + // Combine the source and projection using the project operation // This creates a projection on top of the source relation return map_transform.focus(source_expression + ".project(" + map_transform.current() + ")"); @@ -486,14 +486,14 @@ public Env transformProject(Env env, RelRN.Project project) { @Override public Env transformTrue(Env env, RexRN literal) { - // In Calcite, true literals are typically represented using the + // In Calcite, true literals are typically represented using the // rexBuilder.makeLiteral(true) method or just "TRUE" return env.focus(env.current() + ".literal(true)"); } @Override public Env transformFalse(Env env, RexRN literal) { - // In Calcite, false literals are represented using the + // In Calcite, false literals are represented using the // rexBuilder.makeLiteral(false) method or just "FALSE" return env.focus(env.current() + ".literal(false)"); } @@ -505,43 +505,43 @@ public Env transformEmpty(Env env, RelRN.Empty empty) { return env.focus(env.current() + ".empty()"); } - + @Override public Env transformCustom(Env env, RelRN custom) { return switch (custom) { case org.qed.Generated.RRuleInstances.JoinCommute.ProjectionRelRN projection -> { // Transform the source first - this builds the join var sourceEnv = transform(env, projection.source()); - + // Get original table column counts var leftTableDecl = sourceEnv.declare("call.rel(1)"); var envWithLeftTable = leftTableDecl.getValue(); var rightTableDecl = envWithLeftTable.declare("call.rel(2)"); var envWithRightTable = rightTableDecl.getValue(); - + var leftColCountDecl = envWithRightTable.declare(leftTableDecl.getKey() + ".getRowType().getFieldCount()"); var envWithLeftCount = leftColCountDecl.getValue(); var rightColCountDecl = envWithLeftCount.declare(rightTableDecl.getKey() + ".getRowType().getFieldCount()"); var envWithRightCount = rightColCountDecl.getValue(); - + // Create the projection indices as a List var projectionIndicesDecl = envWithRightCount.declare( "java.util.stream.IntStream.concat(" + // Left columns: rightColCount + 0, rightColCount + 1, ..., rightColCount + leftColCount - 1 - "java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " + + "java.util.stream.IntStream.range(" + rightColCountDecl.getKey() + ", " + rightColCountDecl.getKey() + " + " + leftColCountDecl.getKey() + "), " + // Right columns: 0, 1, ..., rightColCount - 1 "java.util.stream.IntStream.range(0, " + rightColCountDecl.getKey() + ")" + ").boxed().collect(java.util.stream.Collectors.toList())" ); var envWithProjectionIndices = projectionIndicesDecl.getValue(); - + // Convert List to field references using RelBuilder.fields() var fieldRefsDecl = envWithProjectionIndices.declare( sourceEnv.current() + ".fields(" + projectionIndicesDecl.getKey() + ")" ); var envWithFieldRefs = fieldRefsDecl.getValue(); - + // Apply projection using the field references list yield envWithFieldRefs.focus(sourceEnv.current() + ".project(" + fieldRefsDecl.getKey() + ")"); } @@ -582,4 +582,30 @@ public Env grow(String requirement) { return new Env(varId, rel, current, vn + " -> " + vn + "." + requirement, statements, symbols); } } + + @Override + public Env onMatchAggregate(Env env, RelRN.Aggregate aggregate) { + var sourceMatch = onMatch(env.next(), aggregate.source()); + return sourceMatch.grow("operand(LogicalAggregate.class).oneInput(" + sourceMatch.skeleton() + ")"); + } + + @Override + public Env transformAggregate(Env env, RelRN.Aggregate aggregate) { + var sourceTransform = transform(env, aggregate.source()); + String builderWithSource = sourceTransform.current(); + String originalAgg = "((LogicalAggregate) call.rel(0))"; + var groupSetDecl = sourceTransform.declare(originalAgg + ".getGroupSet()"); + var envWithGroupSet = groupSetDecl.getValue(); + var groupKeyDecl = envWithGroupSet.declare( + builderWithSource + ".groupKey(" + groupSetDecl.getKey() + ")" + ); + var envWithGroupKey = groupKeyDecl.getValue(); + var aggCallsDecl = envWithGroupKey.declare(originalAgg + ".getAggCallList()"); + var envWithAggCalls = aggCallsDecl.getValue(); + return envWithAggCalls.focus( + builderWithSource + ".aggregate(" + + groupKeyDecl.getKey() + ", " + + aggCallsDecl.getKey() + ")" + ); + } } diff --git a/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java new file mode 100644 index 0000000..ab35583 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/AggregateFilterTranspose.java @@ -0,0 +1,30 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RelType; + +public record AggregateFilterTranspose() implements RRule { + static final RelRN source_col1 = RelRN.scan("Source1", "Int_Type"); + static final RelRN source_col2 = RelRN.scan("Source2", "Int_Type"); + static final RelRN source = source_col1.join(JoinRelType.INNER, RexRN.trueLiteral(), source_col2); + static final RexRN filterCondition = source.field(0).pred("cond"); + + @Override + public RelRN before() { + RelRN filteredSource = source.filter(filterCondition); + return new RelRN.Aggregate( + filteredSource, + Seq.of(filteredSource.field(0)), + Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(filteredSource.field(1)))) + ); + } + @Override + public RelRN after() { + RelRN.Aggregate aggregateOnSource = new RelRN.Aggregate(source, Seq.of(source.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))))); + return aggregateOnSource.filter(aggregateOnSource.field(0).pred("cond")); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java new file mode 100644 index 0000000..48c92d9 --- /dev/null +++ b/src/main/java/org/qed/Generated/RRuleInstances/FilterAggregateTranspose.java @@ -0,0 +1,29 @@ +package org.qed.Generated.RRuleInstances; + +import kala.collection.Seq; +import org.apache.calcite.rel.core.JoinRelType; +import org.qed.RelRN; +import org.qed.RexRN; +import org.qed.RRule; +import org.qed.RelType; + +public record FilterAggregateTranspose() implements RRule { + static final RelRN source_col1 = RelRN.scan("Source1", "Int_Type"); + static final RelRN source_col2 = RelRN.scan("Source2", "Int_Type"); + static final RelRN source = source_col1.join(JoinRelType.INNER, RexRN.trueLiteral(), source_col2); + static final RelRN.Aggregate aggregate = new RelRN.Aggregate(source, Seq.of(source.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(source.field(1))))); + static final RexRN pushedCondition = aggregate.field(0).pred("pushed_cond"); + static final RexRN remainingCondition = aggregate.field(1).pred("remaining_cond"); + + @Override + public RelRN before() { + return aggregate.filter(RexRN.and(pushedCondition, remainingCondition)); + } + + @Override + public RelRN after() { + RelRN filteredSource = source.filter(source.field(0).pred("pushed_cond")); + RelRN.Aggregate newAggregate = new RelRN.Aggregate(filteredSource, Seq.of(filteredSource.field(0)), Seq.of(new RelRN.AggCall("SUM", false, RelType.fromString("INTEGER", true), Seq.of(filteredSource.field(1))))); + return newAggregate.filter(newAggregate.field(1).pred("remaining_cond")); + } +} \ No newline at end of file diff --git a/src/main/java/org/qed/RelRN.java b/src/main/java/org/qed/RelRN.java index 638985b..0967127 100644 --- a/src/main/java/org/qed/RelRN.java +++ b/src/main/java/org/qed/RelRN.java @@ -207,7 +207,7 @@ public RelNode semantics() { record AggCall(String name, boolean distinct, RelType type, Seq operands) { } - record Aggregate(RelRN source, Seq groupSet, Seq aggCalls) implements RelRN { + record Aggregate(org.qed.RelRN source, Seq groupSet, Seq aggCalls) implements org.qed.RelRN { @Override public RelNode semantics() { var builder = RuleBuilder.create();