@@ -563,6 +563,11 @@ class InnerLoopVectorizer {
563563 Value *VectorTripCount, BasicBlock *MiddleBlock,
564564 VPTransformState &State);
565565
566+ void fixupEarlyExitIVUsers (PHINode *OrigPhi, const InductionDescriptor &II,
567+ BasicBlock *VectorEarlyExitBB,
568+ BasicBlock *MiddleBlock, VPlan &Plan,
569+ VPTransformState &State);
570+
566571 /// Iteratively sink the scalarized operands of a predicated instruction into
567572 /// the block that was created for it.
568573 void sinkScalarOperands (Instruction *PredInst);
@@ -2838,6 +2843,23 @@ BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton(
28382843 return LoopVectorPreHeader;
28392844}
28402845
2846+ static bool isValueIncomingFromBlock (BasicBlock *ExitingBB, Value *V,
2847+ Instruction *UI) {
2848+ PHINode *PHI = dyn_cast<PHINode>(UI);
2849+ assert (PHI && " Expected LCSSA form" );
2850+
2851+ // If this loop has an uncountable early exit then there could be
2852+ // different users of OrigPhi with either:
2853+ // 1. Multiple users, because each exiting block (countable or
2854+ // uncountable) jumps to the same exit block, or ..
2855+ // 2. A single user with an incoming value from a countable or
2856+ // uncountable exiting block.
2857+ // In both cases there is no guarantee this came from a countable exiting
2858+ // block, i.e. the latch.
2859+ int Index = PHI->getBasicBlockIndex (ExitingBB);
2860+ return Index != -1 && PHI->getIncomingValue (Index) == V;
2861+ }
2862+
28412863// Fix up external users of the induction variable. At this point, we are
28422864// in LCSSA form, with all external PHIs that use the IV having one input value,
28432865// coming from the remainder loop. We need those PHIs to also have a correct
@@ -2853,19 +2875,20 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28532875 // We allow both, but they, obviously, have different values.
28542876
28552877 DenseMap<Value *, Value *> MissingVals;
2878+ BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch ();
28562879
28572880 Value *EndValue = cast<PHINode>(OrigPhi->getIncomingValueForBlock (
28582881 OrigLoop->getLoopPreheader ()))
28592882 ->getIncomingValueForBlock (MiddleBlock);
28602883
28612884 // An external user of the last iteration's value should see the value that
28622885 // the remainder loop uses to initialize its own IV.
2863- Value *PostInc = OrigPhi->getIncomingValueForBlock (OrigLoop-> getLoopLatch () );
2886+ Value *PostInc = OrigPhi->getIncomingValueForBlock (OrigLoopLatch );
28642887 for (User *U : PostInc->users ()) {
28652888 Instruction *UI = cast<Instruction>(U);
28662889 if (!OrigLoop->contains (UI)) {
2867- assert (isa<PHINode>(UI) && " Expected LCSSA form " );
2868- MissingVals[UI ] = EndValue;
2890+ if ( isValueIncomingFromBlock (OrigLoopLatch, PostInc, UI))
2891+ MissingVals[cast<PHINode>(UI) ] = EndValue;
28692892 }
28702893 }
28712894
@@ -2875,7 +2898,9 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
28752898 for (User *U : OrigPhi->users ()) {
28762899 auto *UI = cast<Instruction>(U);
28772900 if (!OrigLoop->contains (UI)) {
2878- assert (isa<PHINode>(UI) && " Expected LCSSA form" );
2901+ if (!isValueIncomingFromBlock (OrigLoopLatch, OrigPhi, UI))
2902+ continue ;
2903+
28792904 IRBuilder<> B (MiddleBlock->getTerminator ());
28802905
28812906 // Fast-math-flags propagate from the original induction instruction.
@@ -2905,18 +2930,6 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
29052930 }
29062931 }
29072932
2908- assert ((MissingVals.empty () ||
2909- all_of (MissingVals,
2910- [MiddleBlock, this ](const std::pair<Value *, Value *> &P) {
2911- return all_of (
2912- predecessors (cast<Instruction>(P.first )->getParent ()),
2913- [MiddleBlock, this ](BasicBlock *Pred) {
2914- return Pred == MiddleBlock ||
2915- Pred == OrigLoop->getLoopLatch ();
2916- });
2917- })) &&
2918- " Expected escaping values from latch/middle.block only" );
2919-
29202933 for (auto &I : MissingVals) {
29212934 PHINode *PHI = cast<PHINode>(I.first );
29222935 // One corner case we have to handle is two IVs "chasing" each-other,
@@ -2929,6 +2942,102 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi,
29292942 }
29302943}
29312944
2945+ void InnerLoopVectorizer::fixupEarlyExitIVUsers (PHINode *OrigPhi,
2946+ const InductionDescriptor &II,
2947+ BasicBlock *VectorEarlyExitBB,
2948+ BasicBlock *MiddleBlock,
2949+ VPlan &Plan,
2950+ VPTransformState &State) {
2951+ // There are two kinds of external IV usages - those that use the value
2952+ // computed in the last iteration (the PHI) and those that use the penultimate
2953+ // value (the value that feeds into the phi from the loop latch).
2954+ // We allow both, but they, obviously, have different values.
2955+ DenseMap<Value *, Value *> MissingVals;
2956+ BasicBlock *OrigLoopLatch = OrigLoop->getLoopLatch ();
2957+ BasicBlock *EarlyExitingBB = Legal->getUncountableEarlyExitingBlock ();
2958+ Value *PostInc = OrigPhi->getIncomingValueForBlock (OrigLoopLatch);
2959+
2960+ // Obtain the canonical IV, since we have to use the most recent value
2961+ // before exiting the loop early. This is unlike fixupIVUsers, which has
2962+ // the luxury of using the end value in the middle block.
2963+ VPBasicBlock *EntryVPBB = Plan.getVectorLoopRegion ()->getEntryBasicBlock ();
2964+ // NOTE: We cannot call Plan.getCanonicalIV() here because the original
2965+ // recipe created whilst building plans is no longer valid.
2966+ VPHeaderPHIRecipe *CanonicalIVR =
2967+ cast<VPHeaderPHIRecipe>(&*EntryVPBB->begin ());
2968+ Value *CanonicalIV = State.get (CanonicalIVR->getVPSingleValue (), true );
2969+
2970+ // Search for the mask that drove us to exit early.
2971+ VPBasicBlock *EarlyExitVPBB = Plan.getVectorLoopRegion ()->getEarlyExit ();
2972+ VPBasicBlock *MiddleSplitVPBB =
2973+ cast<VPBasicBlock>(EarlyExitVPBB->getSinglePredecessor ());
2974+ VPInstruction *BranchOnCond =
2975+ cast<VPInstruction>(MiddleSplitVPBB->getTerminator ());
2976+ assert (BranchOnCond->getOpcode () == VPInstruction::BranchOnCond &&
2977+ " Expected middle.split block terminator to be a branch-on-cond" );
2978+ VPInstruction *ScalarEarlyExitCond =
2979+ cast<VPInstruction>(BranchOnCond->getOperand (0 ));
2980+ assert (
2981+ ScalarEarlyExitCond->getOpcode () == VPInstruction::AnyOf &&
2982+ " Expected middle.split block terminator branch condition to be any-of" );
2983+ VPValue *VectorEarlyExitCond = ScalarEarlyExitCond->getOperand (0 );
2984+ // Finally get the mask that led us into the early exit block.
2985+ Value *EarlyExitMask = State.get (VectorEarlyExitCond);
2986+
2987+ // Calculate the IV step.
2988+ VPValue *StepVPV = Plan.getSCEVExpansion (II.getStep ());
2989+ assert (StepVPV && " step must have been expanded during VPlan execution" );
2990+ Value *Step = StepVPV->isLiveIn () ? StepVPV->getLiveInIRValue ()
2991+ : State.get (StepVPV, VPLane (0 ));
2992+
2993+ auto FixUpPhi = [&](Instruction *UI, bool PostInc) -> Value * {
2994+ IRBuilder<> B (VectorEarlyExitBB->getTerminator ());
2995+ assert (isa<PHINode>(UI) && " Expected LCSSA form" );
2996+
2997+ // Fast-math-flags propagate from the original induction instruction.
2998+ if (isa_and_nonnull<FPMathOperator>(II.getInductionBinOp ()))
2999+ B.setFastMathFlags (II.getInductionBinOp ()->getFastMathFlags ());
3000+
3001+ Type *CtzType = CanonicalIV->getType ();
3002+ Value *Ctz = B.CreateCountTrailingZeroElems (CtzType, EarlyExitMask);
3003+ Ctz = B.CreateAdd (Ctz, cast<PHINode>(CanonicalIV));
3004+ if (PostInc)
3005+ Ctz = B.CreateAdd (Ctz, ConstantInt::get (CtzType, 1 ));
3006+
3007+ Value *Escape = emitTransformedIndex (B, Ctz, II.getStartValue (), Step,
3008+ II.getKind (), II.getInductionBinOp ());
3009+ Escape->setName (" ind.early.escape" );
3010+ return Escape;
3011+ };
3012+
3013+ for (User *U : PostInc->users ()) {
3014+ auto *UI = cast<Instruction>(U);
3015+ if (!OrigLoop->contains (UI)) {
3016+ if (isValueIncomingFromBlock (EarlyExitingBB, PostInc, UI))
3017+ MissingVals[UI] = FixUpPhi (UI, true );
3018+ }
3019+ }
3020+
3021+ for (User *U : OrigPhi->users ()) {
3022+ auto *UI = cast<Instruction>(U);
3023+ if (!OrigLoop->contains (UI)) {
3024+ if (isValueIncomingFromBlock (EarlyExitingBB, OrigPhi, UI))
3025+ MissingVals[UI] = FixUpPhi (UI, false );
3026+ }
3027+ }
3028+
3029+ for (auto &I : MissingVals) {
3030+ PHINode *PHI = cast<PHINode>(I.first );
3031+ // One corner case we have to handle is two IVs "chasing" each-other,
3032+ // that is %IV2 = phi [...], [ %IV1, %latch ]
3033+ // In this case, if IV1 has an external use, we need to avoid adding both
3034+ // "last value of IV1" and "penultimate value of IV2". So, verify that we
3035+ // don't already have an incoming value for the middle block.
3036+ if (PHI->getBasicBlockIndex (VectorEarlyExitBB) == -1 )
3037+ PHI->addIncoming (I.second , VectorEarlyExitBB);
3038+ }
3039+ }
3040+
29323041namespace {
29333042
29343043struct CSEDenseMapInfo {
@@ -3062,6 +3171,13 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
30623171 OuterLoop->addBasicBlockToLoop (MiddleSplitBB, *LI);
30633172 PredVPBB = PredVPBB->getSinglePredecessor ();
30643173 }
3174+
3175+ BasicBlock *OrigEarlyExitBB = Legal->getUncountableEarlyExitBlock ();
3176+ if (Loop *EEL = LI->getLoopFor (OrigEarlyExitBB)) {
3177+ BasicBlock *VectorEarlyExitBB =
3178+ State.CFG .VPBB2IRBB [VectorRegion->getEarlyExit ()];
3179+ EEL->addBasicBlockToLoop (VectorEarlyExitBB, *LI);
3180+ }
30653181 }
30663182
30673183 // After vectorization, the exit blocks of the original loop will have
@@ -3091,6 +3207,15 @@ void InnerLoopVectorizer::fixVectorizedLoop(VPTransformState &State) {
30913207 getOrCreateVectorTripCount (nullptr ), LoopMiddleBlock, State);
30923208 }
30933209
3210+ if (Legal->hasUncountableEarlyExit ()) {
3211+ VPBasicBlock *VectorEarlyExitVPBB =
3212+ cast<VPBasicBlock>(VectorRegion->getEarlyExit ());
3213+ BasicBlock *VectorEarlyExitBB = State.CFG .VPBB2IRBB [VectorEarlyExitVPBB];
3214+ for (const auto &Entry : Legal->getInductionVars ())
3215+ fixupEarlyExitIVUsers (Entry.first , Entry.second , VectorEarlyExitBB,
3216+ LoopMiddleBlock, Plan, State);
3217+ }
3218+
30943219 for (Instruction *PI : PredicatedInstructions)
30953220 sinkScalarOperands (&*PI);
30963221
@@ -8974,6 +9099,9 @@ static void addScalarResumePhis(VPRecipeBuilder &Builder, VPlan &Plan) {
89749099 auto *VectorPhiR = cast<VPHeaderPHIRecipe>(Builder.getRecipe (ScalarPhiI));
89759100 if (!isa<VPFirstOrderRecurrencePHIRecipe, VPReductionPHIRecipe>(VectorPhiR))
89769101 continue ;
9102+ assert (!Plan.getVectorLoopRegion ()->getEarlyExit () &&
9103+ " Cannot handle "
9104+ " first-order recurrences with uncountable early exits" );
89779105 // The backedge value provides the value to resume coming out of a loop,
89789106 // which for FORs is a vector whose last element needs to be extracted. The
89799107 // start value provides the value if the loop is bypassed.
@@ -9032,8 +9160,7 @@ static SetVector<VPIRInstruction *> collectUsersInExitBlocks(
90329160 auto *P = dyn_cast<PHINode>(U);
90339161 return P && Inductions.contains (P);
90349162 }))) {
9035- if (ExitVPBB->getSinglePredecessor () == MiddleVPBB)
9036- continue ;
9163+ V = VPValue::getNull ();
90379164 }
90389165 ExitUsersToFix.insert (ExitIRI);
90399166 ExitIRI->addOperand (V);
@@ -9061,18 +9188,30 @@ addUsersInExitBlocks(VPlan &Plan,
90619188 for (const auto &[Idx, Op] : enumerate(ExitIRI->operands ())) {
90629189 // Pass live-in values used by exit phis directly through to their users
90639190 // in the exit block.
9064- if (Op->isLiveIn ())
9191+ if (Op->isLiveIn () || Op-> isNull () )
90659192 continue ;
90669193
90679194 // Values leaving via the middle block are extracted from the end of the
90689195 // vector; values from any other predecessor take the first active lane.
9069- if (ExitIRI->getParent ()->getSinglePredecessor () != MiddleVPBB)
9070- return false ;
9071-
90729196 LLVMContext &Ctx = ExitIRI->getInstruction ().getContext ();
9073- VPValue *Ext = B.createNaryOp (VPInstruction::ExtractFromEnd,
9074- {Op, Plan.getOrAddLiveIn (ConstantInt::get (
9075- IntegerType::get (Ctx, 32 ), 1 ))});
9197+ VPValue *Ext;
9198+ VPBasicBlock *PredVPBB =
9199+ cast<VPBasicBlock>(ExitIRI->getParent ()->getPredecessors ()[Idx]);
9200+ if (PredVPBB != MiddleVPBB) {
9201+ VPBasicBlock *VectorEarlyExitVPBB =
9202+ Plan.getVectorLoopRegion ()->getEarlyExit ();
9203+ VPBuilder B2 (VectorEarlyExitVPBB,
9204+ VectorEarlyExitVPBB->getFirstNonPhi ());
9205+ assert (ExitIRI->getParent ()->getNumPredecessors () <= 2 );
9206+ VPValue *EarlyExitMask =
9207+ Plan.getVectorLoopRegion ()->getVectorEarlyExitCond ();
9208+ Ext = B2.createNaryOp (VPInstruction::ExtractFirstActive,
9209+ {Op, EarlyExitMask});
9210+ } else {
9211+ Ext = B.createNaryOp (VPInstruction::ExtractFromEnd,
9212+ {Op, Plan.getOrAddLiveIn (ConstantInt::get (
9213+ IntegerType::get (Ctx, 32 ), 1 ))});
9214+ }
90769215 ExitIRI->setOperand (Idx, Ext);
90779216 }
90789217 }
0 commit comments