1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one
3+ * or more contributor license agreements. See the NOTICE file
4+ * distributed with this work for additional information
5+ * regarding copyright ownership. The ASF licenses this file
6+ * to you under the Apache License, Version 2.0 (the
7+ * "License"); you may not use this file except in compliance
8+ * with the License. You may obtain a copy of the License at
9+ *
10+ * http://www.apache.org/licenses/LICENSE-2.0
11+ *
12+ * Unless required by applicable law or agreed to in writing,
13+ * software distributed under the License is distributed on an
14+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+ * KIND, either express or implied. See the License for the
16+ * specific language governing permissions and limitations
17+ * under the License.
18+ */
19+
20+
21+ package org .apache .sysds .hops .rewrite ;
22+
23+ import org .apache .sysds .common .Types ;
24+ import org .apache .sysds .hops .DataOp ;
25+ import org .apache .sysds .hops .FunctionOp ;
26+ import org .apache .sysds .hops .Hop ;
27+ import org .apache .sysds .hops .LiteralOp ;
28+ import org .apache .sysds .hops .UnaryOp ;
29+ import org .apache .sysds .hops .recompile .Recompiler ;
30+ import org .apache .sysds .parser .StatementBlock ;
31+ import org .apache .sysds .parser .VariableSet ;
32+
33+ import java .util .ArrayList ;
34+ import java .util .List ;
35+
36+ /**
37+ * Rule: Simplify program structure by rewriting relational expressions,
38+ * implemented here: Pushdown of Selections before Join.
39+ */
40+ public class RewriteRaPushdown extends StatementBlockRewriteRule
41+ {
42+ @ Override
43+ public boolean createsSplitDag () {
44+ return false ;
45+ }
46+
47+ @ Override
48+ public List <StatementBlock > rewriteStatementBlock (StatementBlock sb , ProgramRewriteStatus state ) {
49+ ArrayList <StatementBlock > ret = new ArrayList <>();
50+ ret .add (sb );
51+ return ret ;
52+ }
53+
54+ @ Override
55+ public List <StatementBlock > rewriteStatementBlocks (List <StatementBlock > sbs , ProgramRewriteStatus state ) {
56+ if (sbs == null || sbs .size () <= 1 )
57+ return sbs ;
58+
59+ ArrayList <StatementBlock > tmpList = new ArrayList <>(sbs );
60+ boolean changed = false ;
61+
62+ // iterate over all SBs including a FuncOp with FuncName m_raJoin
63+ for (int i : findFunctionSb (tmpList , "m_raJoin" , 0 )){
64+ StatementBlock sb1 = tmpList .get (i );
65+ FunctionOp joinOp = findFunctionOp (sb1 .getHops (), "m_raJoin" );
66+
67+ // iterate over all following SBs including a FuncOp with FuncName m_raSelection
68+ for (int j : findFunctionSb (tmpList , "m_raSelection" , i +1 )){
69+ StatementBlock sb2 = tmpList .get (j );
70+ FunctionOp selOp = findFunctionOp (sb2 .getHops (), "m_raSelection" );
71+
72+ // create deep copy to ensure data consistency
73+ FunctionOp tmpJoinOp = (FunctionOp ) Recompiler .deepCopyHopsDag (joinOp );
74+ FunctionOp tmpSelOp = (FunctionOp ) Recompiler .deepCopyHopsDag (selOp );
75+
76+ if (!checkDataDependency (tmpJoinOp , tmpSelOp )){continue ;}
77+
78+ Hop selColHop = tmpSelOp .getInput (1 );
79+ long selCol = getConstantSelectionCol (selColHop );
80+ if (selCol <= 0 )
81+ continue ;
82+
83+ // collect Variable Sets
84+ VariableSet joinRead = new VariableSet (sb1 .variablesRead ());
85+ VariableSet joinUpdated = new VariableSet (sb1 .variablesUpdated ());
86+ VariableSet selRead = new VariableSet (sb2 .variablesRead ());
87+
88+ // join inputs: [A, colA, B, colB, method]
89+ long colsLeft = tmpJoinOp .getInput (0 ).getDataCharacteristics ().getCols ();
90+ long colsRight = tmpJoinOp .getInput (2 ).getDataCharacteristics ().getCols ();
91+ if (colsLeft <= 0 || colsRight <= 0 )
92+ continue ;
93+
94+ // decide which side of inner join the selection belongs to (A / B)
95+ int selSideIdx ;
96+ if (selCol <= colsLeft ) {
97+ selSideIdx = 0 ;
98+ }
99+ else if (selCol <= colsLeft + colsRight ) {
100+ selSideIdx = 2 ;
101+ LiteralOp adjustedColHop = new LiteralOp (selCol - colsLeft );
102+ adjustedColHop .setName (selColHop .getName ());
103+ HopRewriteUtils .replaceChildReference (tmpSelOp , selColHop , adjustedColHop , 1 );
104+ }
105+ else { continue ; } // invalid column index
106+
107+ // switch funcOps Output Variables
108+ String joinOutVar = tmpJoinOp .getOutputVariableNames ()[0 ];
109+ tmpJoinOp .getOutputVariableNames ()[0 ] = tmpSelOp .getOutputVariableNames ()[0 ];
110+ tmpSelOp .getOutputVariableNames ()[0 ] = joinOutVar ;
111+
112+ // rewire selection to consume the correct join input and adjusted column
113+ Hop newSelInput = tmpJoinOp .getInput ().get (selSideIdx );
114+ HopRewriteUtils .replaceChildReference (tmpSelOp , tmpSelOp .getInput ().get (0 ), newSelInput , 0 );
115+
116+ // let the join take selection output instead of raw input
117+ Hop newJoinInput = HopRewriteUtils .createTransientRead (joinOutVar , tmpSelOp );
118+ HopRewriteUtils .replaceChildReference (tmpJoinOp , newSelInput , newJoinInput , selSideIdx );
119+
120+ //switch StatementBlock-assignments
121+ sb1 .getHops ().remove (joinOp );
122+ sb1 .getHops ().add (tmpSelOp );
123+ sb2 .getHops ().remove (selOp );
124+ sb2 .getHops ().add (tmpJoinOp );
125+
126+ // modify SB- variable sets
127+ VariableSet vs = new VariableSet ();
128+ vs .addVariable (joinOutVar , joinUpdated .getVariable (joinOutVar ));
129+ selRead .removeVariables (vs );
130+ selRead .addVariable (newSelInput .getName (), joinRead .getVariable (newSelInput .getName ()));
131+
132+ // selection now reads the original join inputs plus its own metadata
133+ sb1 .setReadVariables (selRead );
134+ sb1 .setLiveOut (VariableSet .minus (joinUpdated , selRead ));
135+ sb1 .setLiveIn (selRead );
136+ sb1 .setGen (selRead );
137+
138+ // join now consumes the selection output and produces the output
139+ sb2 .setReadVariables (sb1 .liveOut ());
140+ sb2 .setGen (sb1 .liveOut ());
141+ sb2 .setLiveIn (sb1 .liveOut ());
142+
143+ // mark change & increment i by 1 (i+1 = now join-Sb)
144+ changed = true ;
145+ i ++;
146+
147+ LOG .debug ("Applied rewrite: pushed m_raSelection before m_raJoin (blocks lines "
148+ + sb1 .getBeginLine () + "-" + sb1 .getEndLine () + " and "
149+ + sb2 .getBeginLine () + "-" + sb2 .getEndLine () + ")." );
150+ }
151+ }
152+ return changed ? tmpList : sbs ;
153+ }
154+
155+ private List <Integer > findFunctionSb (List <StatementBlock > sbs , String functionName , int startIdx ) {
156+ List <Integer > functionSbs = new ArrayList <>();
157+
158+ for (int i = startIdx ; i < sbs .size (); i ++) {
159+ StatementBlock sb = sbs .get (i );
160+
161+ // easy preconditions
162+ if (!HopRewriteUtils .isLastLevelStatementBlock (sb ) || sb .isSplitDag ()) {
163+ continue ;
164+ }
165+
166+ // find if StatementBlocks have certain FunctionOp, continue if not found
167+ FunctionOp functionOp = findFunctionOp (sb .getHops (), functionName );
168+
169+ // if found, add to list
170+ if (functionOp != null ) { functionSbs .add (i ); }
171+ }
172+
173+ return functionSbs ;
174+ }
175+
176+ private boolean checkDataDependency (FunctionOp fOut , FunctionOp fIn ){
177+ for (String out : fOut .getOutputVariableNames ()) {
178+ for (Hop h : fIn .getInput ()) {
179+ if (h .getName ().equals (out )){
180+ return true ;
181+ }
182+ }
183+ }
184+ return false ;
185+ }
186+
187+ private FunctionOp findFunctionOp (List <Hop > roots , String functionName ) {
188+ if (roots == null )
189+ return null ;
190+ Hop .resetVisitStatus (roots , true );
191+ for (Hop root : roots ) {
192+ if (root instanceof FunctionOp funcOp ) {
193+ if (funcOp .getFunctionName ().equals (functionName ))
194+ { return funcOp ; }
195+ }
196+ }
197+ return null ;
198+ }
199+
200+ private long getConstantSelectionCol (Hop selColHop ) {
201+ if (selColHop instanceof LiteralOp lit )
202+ return HopRewriteUtils .getIntValueSafe (lit );
203+
204+ // Handle casted literals (e.g., type propagation inserted casts)
205+ if (selColHop instanceof UnaryOp uop && uop .getOp () == Types .OpOp1 .CAST_AS_INT
206+ && uop .getInput ().get (0 ) instanceof LiteralOp lit )
207+ return HopRewriteUtils .getIntValueSafe (lit );
208+
209+ // If hop is a dataop whose input is a literal, try to fold
210+ if (selColHop instanceof DataOp dop && !dop .getInput ().isEmpty () && dop .getInput ().get (0 ) instanceof LiteralOp lit )
211+ return HopRewriteUtils .getIntValueSafe (lit );
212+
213+ return -1 ; // unknown at rewrite time
214+ }
215+ }
0 commit comments