@@ -1347,27 +1347,46 @@ class LoopVectorizationCostModel {
13471347    return  InterleaveInfo.getInterleaveGroup (Instr);
13481348  }
13491349
1350+   // / Calculate in advance whether a scalar epilogue is required when
1351+   // / vectorizing and not vectorizing. If \p Invalidate is true then
1352+   // / invalidate a previous decision.
1353+   void  collectScalarEpilogueRequirements (bool  Invalidate) {
1354+     auto  NeedsScalarEpilogue = [&](bool  IsVectorizing) -> bool  {
1355+       if  (!isScalarEpilogueAllowed ()) {
1356+         LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue"  );
1357+         return  false ;
1358+       }
1359+       //  If we might exit from anywhere but the latch, must run the exiting
1360+       //  iteration in scalar form.
1361+       if  (TheLoop->getExitingBlock () != TheLoop->getLoopLatch ()) {
1362+         LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: not exiting " 
1363+                              " from latch block\n "  );
1364+         return  true ;
1365+       }
1366+       if  (IsVectorizing && InterleaveInfo.requiresScalarEpilogue ()) {
1367+         LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: " 
1368+                              " interleaved group requires scalar epilogue"  );
1369+         return  true ;
1370+       }
1371+       LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue"  );
1372+       return  false ;
1373+     };
1374+ 
1375+     assert ((Invalidate || !RequiresScalarEpilogue) &&
1376+            " Already determined scalar epilogue requirements!"  );
1377+     std::pair<bool , bool > Result;
1378+     Result.first  = NeedsScalarEpilogue (true );
1379+     LLVM_DEBUG (dbgs () << " , when vectorizing\n "  );
1380+     Result.second  = NeedsScalarEpilogue (false );
1381+     LLVM_DEBUG (dbgs () << " , when not vectorizing\n "  );
1382+     RequiresScalarEpilogue = Result;
1383+   }
1384+ 
13501385  // / Returns true if we're required to use a scalar epilogue for at least
13511386  // / the final iteration of the original loop.
13521387  bool  requiresScalarEpilogue (bool  IsVectorizing) const  {
1353-     if  (!isScalarEpilogueAllowed ()) {
1354-       LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue\n "  );
1355-       return  false ;
1356-     }
1357-     //  If we might exit from anywhere but the latch, must run the exiting
1358-     //  iteration in scalar form.
1359-     if  (TheLoop->getExitingBlock () != TheLoop->getLoopLatch ()) {
1360-       LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: not exiting " 
1361-                            " from latch block\n "  );
1362-       return  true ;
1363-     }
1364-     if  (IsVectorizing && InterleaveInfo.requiresScalarEpilogue ()) {
1365-       LLVM_DEBUG (dbgs () << " LV: Loop requires scalar epilogue: " 
1366-                            " interleaved group requires scalar epilogue\n "  );
1367-       return  true ;
1368-     }
1369-     LLVM_DEBUG (dbgs () << " LV: Loop does not require scalar epilogue\n "  );
1370-     return  false ;
1388+     auto  &CachedResult = *RequiresScalarEpilogue;
1389+     return  IsVectorizing ? CachedResult.first  : CachedResult.second ;
13711390  }
13721391
13731392  // / Returns true if we're required to use a scalar epilogue for at least
@@ -1391,6 +1410,15 @@ class LoopVectorizationCostModel {
13911410    return  ScalarEpilogueStatus == CM_ScalarEpilogueAllowed;
13921411  }
13931412
1413+   // / Update the ScalarEpilogueStatus to a new value, potentially triggering a
1414+   // / recalculation of the scalar epilogue requirements.
1415+   void  setScalarEpilogueStatus (ScalarEpilogueLowering Status) {
1416+     bool  Changed = ScalarEpilogueStatus != Status;
1417+     ScalarEpilogueStatus = Status;
1418+     if  (Changed)
1419+       collectScalarEpilogueRequirements (/* Invalidate=*/ true );
1420+   }
1421+ 
13941422  // / Returns the TailFoldingStyle that is best for the current loop.
13951423  TailFoldingStyle getTailFoldingStyle (bool  IVUpdateMayOverflow = true ) const  {
13961424    if  (!ChosenTailFoldingStyle)
@@ -1771,6 +1799,9 @@ class LoopVectorizationCostModel {
17711799
17721800  // / All element types found in the loop.
17731801  SmallPtrSet<Type *, 16 > ElementTypesInLoop;
1802+ 
1803+   // / Cached scalar epilogue requirement: .first applies when vectorizing, .second when not vectorizing; assigned by collectScalarEpilogueRequirements.
1804+   std::optional<std::pair<bool , bool >> RequiresScalarEpilogue;
17741805};
17751806} //  end namespace llvm
17761807
@@ -4058,7 +4089,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
40584089    if  (ScalarEpilogueStatus == CM_ScalarEpilogueNotNeededUsePredicate) {
40594090      LLVM_DEBUG (dbgs () << " LV: Cannot fold tail by masking: vectorize with a " 
40604091                           " scalar epilogue instead.\n "  );
4061-       ScalarEpilogueStatus =  CM_ScalarEpilogueAllowed;
4092+       setScalarEpilogueStatus ( CM_ScalarEpilogueAllowed) ;
40624093      return  computeFeasibleMaxVF (MaxTC, UserVF, false );
40634094    }
40644095    return  FixedScalableVFPair::getNone ();
@@ -4074,6 +4105,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
40744105    //  Note: There is no need to invalidate any cost modeling decisions here, as
40754106    //  none were taken so far.
40764107    InterleaveInfo.invalidateGroupsRequiringScalarEpilogue ();
4108+     collectScalarEpilogueRequirements (/* Invalidate=*/ true );
40774109  }
40784110
40794111  FixedScalableVFPair MaxFactors = computeFeasibleMaxVF (MaxTC, UserVF, true );
@@ -4145,7 +4177,7 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
41454177  if  (ScalarEpilogueStatus == CM_ScalarEpilogueNotNeededUsePredicate) {
41464178    LLVM_DEBUG (dbgs () << " LV: Cannot fold tail by masking: vectorize with a " 
41474179                         " scalar epilogue instead.\n "  );
4148-     ScalarEpilogueStatus =  CM_ScalarEpilogueAllowed;
4180+     setScalarEpilogueStatus ( CM_ScalarEpilogueAllowed) ;
41494181    return  MaxFactors;
41504182  }
41514183
@@ -7058,6 +7090,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
70587090  if  (!OrigLoop->isInnermost ()) {
70597091    //  If the user doesn't provide a vectorization factor, determine a
70607092    //  reasonable one.
7093+     CM.collectScalarEpilogueRequirements (/* Invalidate=*/ false );
70617094    if  (UserVF.isZero ()) {
70627095      VF = determineVPlanVF (TTI, CM);
70637096      LLVM_DEBUG (dbgs () << " LV: VPlan computed VF "   << VF << " .\n "  );
@@ -7102,6 +7135,7 @@ LoopVectorizationPlanner::planInVPlanNativePath(ElementCount UserVF) {
71027135
71037136void  LoopVectorizationPlanner::plan (ElementCount UserVF, unsigned  UserIC) {
71047137  assert (OrigLoop->isInnermost () && " Inner loop expected."  );
7138+   CM.collectScalarEpilogueRequirements (/* Invalidate=*/ false );
71057139  CM.collectValuesToIgnore ();
71067140  CM.collectElementTypesForWidening ();
71077141
@@ -7116,11 +7150,13 @@ void LoopVectorizationPlanner::plan(ElementCount UserVF, unsigned UserIC) {
71167150        dbgs ()
71177151        << " LV: Invalidate all interleaved groups due to fold-tail by masking " 
71187152           " which requires masked-interleaved support.\n "  );
7119-     if  (CM.InterleaveInfo .invalidateGroups ())
7153+     if  (CM.InterleaveInfo .invalidateGroups ()) { 
71207154      //  Invalidating interleave groups also requires invalidating all decisions
71217155      //  based on them, which includes widening decisions and uniform and scalar
71227156      //  values.
71237157      CM.invalidateCostModelingDecisions ();
7158+       CM.collectScalarEpilogueRequirements (/* Invalidate=*/ true );
7159+     }
71247160  }
71257161
71267162  if  (CM.foldTailByMasking ())
0 commit comments