Skip to content

Commit 8307afe

Browse files
mor-gkmboehm7
authored and committed
[SYSTEMDS-3921] New Rewrite for Relational Selection Pushdown
Closes #2413.
1 parent 39cc18e commit 8307afe

5 files changed

Lines changed: 417 additions & 1 deletion

File tree

src/main/java/org/apache/sysds/hops/OptimizerUtils.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,12 @@ public enum MemoryManager {
335335

336336
public static boolean AUTO_GPU_CACHE_EVICTION = true;
337337

338-
//////////////////////
338+
/**
339+
* Boolean specifying if relational algebra rewrites are allowed (e.g. Selection Pushdowns).
340+
*/
341+
public static boolean ALLOW_RA_REWRITES = false;
342+
343+
//////////////////////
339344
// Optimizer levels //
340345
//////////////////////
341346

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,8 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
117117
_sbRuleSet.add( new RewriteMarkLoopVariablesUpdateInPlace() );
118118
if( LineageCacheConfig.getCompAssRW() )
119119
_sbRuleSet.add( new MarkForLineageReuse() );
120+
if( OptimizerUtils.ALLOW_RA_REWRITES )
121+
_sbRuleSet.add( new RewriteRaPushdown() );
120122
_sbRuleSet.add( new RewriteRemoveTransformEncodeMeta() );
121123
_dagRuleSet.add( new RewriteNonScalarPrint() );
122124
if( OptimizerUtils.ALLOW_JOIN_REORDERING_REWRITE )
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
21+
package org.apache.sysds.hops.rewrite;
22+
23+
import org.apache.sysds.common.Types;
24+
import org.apache.sysds.hops.DataOp;
25+
import org.apache.sysds.hops.FunctionOp;
26+
import org.apache.sysds.hops.Hop;
27+
import org.apache.sysds.hops.LiteralOp;
28+
import org.apache.sysds.hops.UnaryOp;
29+
import org.apache.sysds.hops.recompile.Recompiler;
30+
import org.apache.sysds.parser.StatementBlock;
31+
import org.apache.sysds.parser.VariableSet;
32+
33+
import java.util.ArrayList;
34+
import java.util.List;
35+
36+
/**
37+
* Rule: Simplify program structure by rewriting relational expressions,
38+
* implemented here: Pushdown of Selections before Join.
39+
*/
40+
public class RewriteRaPushdown extends StatementBlockRewriteRule
41+
{
42+
@Override
43+
public boolean createsSplitDag() {
44+
return false;
45+
}
46+
47+
@Override
48+
public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus state) {
49+
ArrayList<StatementBlock> ret = new ArrayList<>();
50+
ret.add(sb);
51+
return ret;
52+
}
53+
54+
@Override
55+
public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus state) {
56+
if (sbs == null || sbs.size() <= 1)
57+
return sbs;
58+
59+
ArrayList<StatementBlock> tmpList = new ArrayList<>(sbs);
60+
boolean changed = false;
61+
62+
// iterate over all SBs including a FuncOp with FuncName m_raJoin
63+
for (int i : findFunctionSb(tmpList, "m_raJoin", 0)){
64+
StatementBlock sb1 = tmpList.get(i);
65+
FunctionOp joinOp = findFunctionOp(sb1.getHops(), "m_raJoin");
66+
67+
// iterate over all following SBs including a FuncOp with FuncName m_raSelection
68+
for (int j : findFunctionSb(tmpList, "m_raSelection", i+1)){
69+
StatementBlock sb2 = tmpList.get(j);
70+
FunctionOp selOp = findFunctionOp(sb2.getHops(), "m_raSelection");
71+
72+
// create deep copy to ensure data consistency
73+
FunctionOp tmpJoinOp = (FunctionOp) Recompiler.deepCopyHopsDag(joinOp);
74+
FunctionOp tmpSelOp = (FunctionOp) Recompiler.deepCopyHopsDag(selOp);
75+
76+
if (!checkDataDependency(tmpJoinOp, tmpSelOp)){continue;}
77+
78+
Hop selColHop = tmpSelOp.getInput(1);
79+
long selCol = getConstantSelectionCol(selColHop);
80+
if (selCol <= 0)
81+
continue;
82+
83+
// collect Variable Sets
84+
VariableSet joinRead = new VariableSet(sb1.variablesRead());
85+
VariableSet joinUpdated = new VariableSet(sb1.variablesUpdated());
86+
VariableSet selRead = new VariableSet(sb2.variablesRead());
87+
88+
// join inputs: [A, colA, B, colB, method]
89+
long colsLeft = tmpJoinOp.getInput(0).getDataCharacteristics().getCols();
90+
long colsRight = tmpJoinOp.getInput(2).getDataCharacteristics().getCols();
91+
if (colsLeft <= 0 || colsRight <= 0)
92+
continue;
93+
94+
// decide which side of inner join the selection belongs to (A / B)
95+
int selSideIdx;
96+
if (selCol <= colsLeft) {
97+
selSideIdx = 0;
98+
}
99+
else if (selCol <= colsLeft + colsRight) {
100+
selSideIdx = 2;
101+
LiteralOp adjustedColHop = new LiteralOp(selCol - colsLeft);
102+
adjustedColHop.setName(selColHop.getName());
103+
HopRewriteUtils.replaceChildReference(tmpSelOp, selColHop, adjustedColHop, 1);
104+
}
105+
else { continue; } // invalid column index
106+
107+
// switch funcOps Output Variables
108+
String joinOutVar = tmpJoinOp.getOutputVariableNames()[0];
109+
tmpJoinOp.getOutputVariableNames()[0] = tmpSelOp.getOutputVariableNames()[0];
110+
tmpSelOp.getOutputVariableNames()[0] = joinOutVar;
111+
112+
// rewire selection to consume the correct join input and adjusted column
113+
Hop newSelInput = tmpJoinOp.getInput().get(selSideIdx);
114+
HopRewriteUtils.replaceChildReference(tmpSelOp, tmpSelOp.getInput().get(0), newSelInput, 0);
115+
116+
// let the join take selection output instead of raw input
117+
Hop newJoinInput = HopRewriteUtils.createTransientRead(joinOutVar, tmpSelOp);
118+
HopRewriteUtils.replaceChildReference(tmpJoinOp, newSelInput, newJoinInput, selSideIdx);
119+
120+
//switch StatementBlock-assignments
121+
sb1.getHops().remove(joinOp);
122+
sb1.getHops().add(tmpSelOp);
123+
sb2.getHops().remove(selOp);
124+
sb2.getHops().add(tmpJoinOp);
125+
126+
// modify SB- variable sets
127+
VariableSet vs = new VariableSet();
128+
vs.addVariable(joinOutVar, joinUpdated.getVariable(joinOutVar));
129+
selRead.removeVariables(vs);
130+
selRead.addVariable(newSelInput.getName(), joinRead.getVariable(newSelInput.getName()));
131+
132+
// selection now reads the original join inputs plus its own metadata
133+
sb1.setReadVariables(selRead);
134+
sb1.setLiveOut(VariableSet.minus(joinUpdated, selRead));
135+
sb1.setLiveIn(selRead);
136+
sb1.setGen(selRead);
137+
138+
// join now consumes the selection output and produces the output
139+
sb2.setReadVariables(sb1.liveOut());
140+
sb2.setGen(sb1.liveOut());
141+
sb2.setLiveIn(sb1.liveOut());
142+
143+
// mark change & increment i by 1 (i+1 = now join-Sb)
144+
changed = true;
145+
i++;
146+
147+
LOG.debug("Applied rewrite: pushed m_raSelection before m_raJoin (blocks lines "
148+
+ sb1.getBeginLine() + "-" + sb1.getEndLine() + " and "
149+
+ sb2.getBeginLine() + "-" + sb2.getEndLine() + ").");
150+
}
151+
}
152+
return changed ? tmpList : sbs;
153+
}
154+
155+
private List<Integer> findFunctionSb(List<StatementBlock> sbs, String functionName, int startIdx) {
156+
List<Integer> functionSbs = new ArrayList<>();
157+
158+
for (int i = startIdx; i < sbs.size(); i++) {
159+
StatementBlock sb = sbs.get(i);
160+
161+
// easy preconditions
162+
if (!HopRewriteUtils.isLastLevelStatementBlock(sb) || sb.isSplitDag()) {
163+
continue;
164+
}
165+
166+
// find if StatementBlocks have certain FunctionOp, continue if not found
167+
FunctionOp functionOp = findFunctionOp(sb.getHops(), functionName);
168+
169+
// if found, add to list
170+
if (functionOp != null) { functionSbs.add(i); }
171+
}
172+
173+
return functionSbs;
174+
}
175+
176+
private boolean checkDataDependency(FunctionOp fOut, FunctionOp fIn){
177+
for (String out : fOut.getOutputVariableNames()) {
178+
for (Hop h : fIn.getInput()) {
179+
if (h.getName().equals(out)){
180+
return true;
181+
}
182+
}
183+
}
184+
return false;
185+
}
186+
187+
private FunctionOp findFunctionOp(List<Hop> roots, String functionName) {
188+
if (roots == null)
189+
return null;
190+
Hop.resetVisitStatus(roots, true);
191+
for (Hop root : roots) {
192+
if (root instanceof FunctionOp funcOp) {
193+
if (funcOp.getFunctionName().equals(functionName))
194+
{ return funcOp; }
195+
}
196+
}
197+
return null;
198+
}
199+
200+
private long getConstantSelectionCol(Hop selColHop) {
201+
if (selColHop instanceof LiteralOp lit)
202+
return HopRewriteUtils.getIntValueSafe(lit);
203+
204+
// Handle casted literals (e.g., type propagation inserted casts)
205+
if (selColHop instanceof UnaryOp uop && uop.getOp() == Types.OpOp1.CAST_AS_INT
206+
&& uop.getInput().get(0) instanceof LiteralOp lit)
207+
return HopRewriteUtils.getIntValueSafe(lit);
208+
209+
// If hop is a dataop whose input is a literal, try to fold
210+
if (selColHop instanceof DataOp dop && !dop.getInput().isEmpty() && dop.getInput().get(0) instanceof LiteralOp lit)
211+
return HopRewriteUtils.getIntValueSafe(lit);
212+
213+
return -1; // unknown at rewrite time
214+
}
215+
}
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.functions.rewrite;
21+
22+
import java.util.HashMap;
23+
24+
import org.apache.sysds.common.Opcodes;
25+
import org.apache.sysds.common.Types;
26+
import org.apache.sysds.hops.OptimizerUtils;
27+
import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
28+
import org.apache.sysds.test.AutomatedTestBase;
29+
import org.apache.sysds.test.TestConfiguration;
30+
import org.apache.sysds.test.TestUtils;
31+
import org.junit.Test;
32+
33+
public class RewritePushdownRaSelectionTest extends AutomatedTestBase
34+
{
35+
private static final String TEST_NAME = "RewritePushdownRaSelection";
36+
private static final String TEST_DIR = "functions/rewrite/";
37+
private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownRaSelectionTest.class.getSimpleName() + "/";
38+
39+
private static final double eps = 1e-8;
40+
41+
@Override
42+
public void setUp() {
43+
TestUtils.clearAssertionInformation();
44+
addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"result"}));
45+
}
46+
47+
@Test
48+
public void testRewritePushdownRaSelectionNoRewrite() {
49+
int col = 1;
50+
String op = Opcodes.EQUAL.toString();
51+
double val = 4.0;
52+
53+
// Expected output matrix
54+
double[][] Y = {
55+
{4,7,8,4,7,8},
56+
{4,7,8,4,5,10},
57+
{4,3,5,4,7,8},
58+
{4,3,5,4,5,10},
59+
};
60+
61+
testRewritePushdownRaSelection(col, op, val, Y, "nested-loop", false);
62+
}
63+
64+
@Test
65+
public void testRewritePushdownRaSelection1() {
66+
int col = 1;
67+
String op = Opcodes.EQUAL.toString();
68+
double val = 4.0;
69+
70+
// Expected output matrix
71+
double[][] Y = {
72+
{4,7,8,4,7,8},
73+
{4,7,8,4,5,10},
74+
{4,3,5,4,7,8},
75+
{4,3,5,4,5,10},
76+
};
77+
78+
testRewritePushdownRaSelection(col, op, val, Y, "sort-merge", true);
79+
}
80+
81+
@Test
82+
public void testRewritePushdownRaSelection2() {
83+
int col = 5;
84+
String op = Opcodes.EQUAL.toString();
85+
double val = 7.0;
86+
87+
// Expected output matrix
88+
double[][] Y = {
89+
{4,7,8,4,7,8},
90+
{4,3,5,4,7,8},
91+
};
92+
93+
testRewritePushdownRaSelection(col, op, val, Y, "sort-merge", true);
94+
}
95+
96+
private void testRewritePushdownRaSelection(int col, String op, double val, double[][] Y,
97+
String method, boolean rewrites) {
98+
99+
//generate actual dataset and variables
100+
double[][] A = {
101+
{1, 2, 3},
102+
{4, 7, 8},
103+
{1, 3, 6},
104+
{4, 3, 5},
105+
{5, 8, 9}
106+
};
107+
double[][] B = {
108+
{1, 2, 9},
109+
{3, 7, 6},
110+
{2, 8, 5},
111+
{4, 7, 8},
112+
{4, 5, 10}
113+
};
114+
int colA = 1;
115+
int colB = 1;
116+
117+
runRewritePushdownRaSelectionTest(A, colA, B, colB, Y, col, op, val, method, rewrites);
118+
}
119+
120+
121+
private void runRewritePushdownRaSelectionTest(double [][] A, int colA, double [][] B, int colB, double [][] Y,
122+
int col, String op, double val, String method, boolean rewrites)
123+
{
124+
Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE);
125+
boolean oldFlag = OptimizerUtils.ALLOW_RA_REWRITES;
126+
127+
try
128+
{
129+
loadTestConfiguration(getTestConfiguration(TEST_NAME));
130+
String HOME = SCRIPT_DIR + TEST_DIR;
131+
132+
fullDMLScriptName = HOME + TEST_NAME + ".dml";
133+
programArgs = new String[]{"-explain", "hops", "-args",
134+
input("A"), String.valueOf(colA), input("B"),
135+
String.valueOf(colB), String.valueOf(col), op, String.valueOf(val), method, output("result") };
136+
writeInputMatrixWithMTD("A", A, true);
137+
writeInputMatrixWithMTD("B", B, true);
138+
139+
OptimizerUtils.ALLOW_RA_REWRITES = rewrites;
140+
141+
// run dmlScript
142+
runTest(null);
143+
144+
//compare matrices
145+
HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromOutputDir("result");
146+
HashMap<CellIndex, Double> expectedOutput = TestUtils.convert2DDoubleArrayToHashMap(Y);
147+
TestUtils.compareMatrices(dmlfile, expectedOutput, eps, "Stat-DML", "Expected");
148+
}
149+
finally {
150+
rtplatform = platformOld;
151+
OptimizerUtils.ALLOW_RA_REWRITES = oldFlag;
152+
}
153+
}
154+
}

0 commit comments

Comments
 (0)