From e01b798113699f16db7c7776511bc838c866c69f Mon Sep 17 00:00:00 2001 From: luciechoi Date: Thu, 30 Oct 2025 03:36:25 +0000 Subject: [PATCH 1/2] Fix indvar pass to skip on unfolding predicates on control convergence operations --- llvm/lib/Transforms/Scalar/IndVarSimplify.cpp | 31 ++++ .../skip-predication-convergence.ll | 98 ++++++++++++ .../skip-predictaion-nested-convergence.ll | 139 ++++++++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll create mode 100644 llvm/test/Transforms/IndVarSimplify/skip-predictaion-nested-convergence.ll diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 7ebcc219efc15..421aad8872f9a 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -1859,6 +1859,37 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { } } + // If the loop body uses a convergence token defined within the loop, skip + // predication. This is to avoid changing the convergence behavior of the + // loop. + SmallVector blocks = ExitingBlocks; + SmallVector tokens = {}; + size_t index = 0; // Assume Exiting Blocks are sorted. + while (index < blocks.size()) { + BasicBlock *BB = blocks[index]; + index++; + const auto exitingBlockName = BB->getName(); + for (Instruction &I : *BB) { + // Check if the instruction uses any convergence tokens. + if (auto *CB = dyn_cast(&I); + CB && !isa(&I)) { + auto token = CB->getConvergenceControlToken(); + if (token && llvm::is_contained(tokens, token)) { + return false; + } + } + if (isa(&I)) { + tokens.push_back(cast(&I)); + } + } + + for (BasicBlock *Succ : successors(BB)) { + const auto succName = Succ->getName(); + if (Succ != L->getLoopLatch() && !llvm::is_contained(blocks, Succ)) + blocks.push_back(Succ); + } + } + bool Changed = false; // Finally, do the actual predication for all predicatable blocks. A couple // of notes here: diff --git a/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll new file mode 100644 index 0000000000000..12fca6778f15e --- /dev/null +++ b/llvm/test/Transforms/IndVarSimplify/skip-predication-convergence.ll @@ -0,0 +1,98 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s + +; Loop with body using loop convergence token should be skipped by IndVarSimplify. + +%"class.hlsl::RWStructuredBuffer" = type { target("spirv.VulkanBuffer", [0 x i32], 12, 1), target("spirv.VulkanBuffer", i32, 12, 1) } + +@_ZL3Out = internal global %"class.hlsl::RWStructuredBuffer" poison, align 8 +@.str = private unnamed_addr constant [4 x i8] c"Out\00", align 1 + +declare token @llvm.experimental.convergence.entry() #0 + +define void @loop() local_unnamed_addr #1 { +; CHECK-LABEL: @loop( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: [[TMP1:%.*]] = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str) +; CHECK-NEXT: [[TMP2:%.*]] = tail call target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], i32 0, i32 0) +; CHECK-NEXT: store target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], ptr @_ZL3Out, align 8 +; CHECK-NEXT: store target("spirv.VulkanBuffer", i32, 12, 1) [[TMP2]], ptr getelementptr inbounds nuw (i8, ptr @_ZL3Out, i64 8), align 8 +; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0) +; CHECK-NEXT: br label [[FOR_COND_I:%.*]] +; CHECK: for.cond.i: +; CHECK-NEXT: [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC_I:%.*]], [[FOR_BODY_I:%.*]] ] +; CHECK-NEXT: [[TMP4:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ] +; CHECK-NEXT: [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8 +; CHECK-NEXT: br i1 [[CMP_I]], label [[FOR_BODY_I]], label [[_Z4LOOPDV3_J_EXIT_LOOPEXIT:%.*]] +; CHECK: for.body.i: +; CHECK-NEXT: [[CMP1_I:%.*]] = icmp eq i32 [[I_0_I]], [[TMP3]] +; CHECK-NEXT: [[INC_I]] = add nuw nsw i32 [[I_0_I]], 1 +; CHECK-NEXT: br i1 [[CMP1_I]], label [[IF_THEN_I:%.*]], label [[FOR_COND_I]] +; CHECK: _Z4loopDv3_j.exit.loopexit: +; CHECK-NEXT: br label [[_Z4LOOPDV3_J_EXIT:%.*]] +; CHECK: if.then.i: +; CHECK-NEXT: [[HLSL_WAVE_ACTIVE_MAX2_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[TMP3]]) [ "convergencectrl"(token [[TMP4]]) ] +; CHECK-NEXT: [[TMP5:%.*]] = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], i32 [[TMP3]]) +; CHECK-NEXT: store i32 [[HLSL_WAVE_ACTIVE_MAX2_I]], ptr addrspace(11) [[TMP5]], align 4 +; CHECK-NEXT: br label [[_Z4LOOPDV3_J_EXIT]] +; CHECK: _Z4loopDv3_j.exit: +; CHECK-NEXT: ret void +; +entry: + %0 = tail call token @llvm.experimental.convergence.entry() + %1 = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str) + %2 = tail call target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, i32 0, i32 0) + store target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, ptr @_ZL3Out, align 8 + store target("spirv.VulkanBuffer", i32, 12, 1) %2, ptr getelementptr inbounds nuw (i8, ptr @_ZL3Out, i64 8), align 8 + %3 = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0) + br label %for.cond.i + +; Loop: +for.cond.i: ; preds = %for.body.i, %entry + %i.0.i = phi i32 [ 0, %entry ], [ %inc.i, %for.body.i ] + %4 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + %cmp.i = icmp ult i32 %i.0.i, 8 + br i1 %cmp.i, label %for.body.i, label %_Z4loopDv3_j.exit.loopexit + +for.body.i: ; preds = %for.cond.i + %cmp1.i = icmp eq i32 %i.0.i, %3 + %inc.i = add nuw nsw i32 %i.0.i, 1 + br i1 %cmp1.i, label %if.then.i, label %for.cond.i + +; Exit blocks +_Z4loopDv3_j.exit.loopexit: ; preds = %for.cond.i + br label %_Z4loopDv3_j.exit + +if.then.i: ; preds = %for.body.i + %hlsl.wave.active.max2.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %3) [ "convergencectrl"(token %4) ] + %5 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, i32 %3) + store i32 %hlsl.wave.active.max2.i, ptr addrspace(11) %5, align 4 + br label %_Z4loopDv3_j.exit + +_Z4loopDv3_j.exit: ; preds = %_Z4loopDv3_j.exit.loopexit, %if.then.i + ret void +} + +; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none) +declare i32 @llvm.spv.thread.id.in.group.i32(i32) #2 + +; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare token @llvm.experimental.convergence.loop() #0 + +; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32, i32, i32, i32, ptr) #4 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1), i32, i32) #4 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1), i32) #4 + +attributes #0 = { convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) } +attributes #1 = { convergent noinline norecurse "frame-pointer"="all" "hlsl.numthreads"="8,1,1" "hlsl.shader"="compute" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #2 = { mustprogress nofree nosync nounwind willreturn memory(none) } +attributes #4 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) } diff --git a/llvm/test/Transforms/IndVarSimplify/skip-predictaion-nested-convergence.ll b/llvm/test/Transforms/IndVarSimplify/skip-predictaion-nested-convergence.ll new file mode 100644 index 0000000000000..22f25b1428556 --- /dev/null +++ b/llvm/test/Transforms/IndVarSimplify/skip-predictaion-nested-convergence.ll @@ -0,0 +1,139 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -passes=indvars -indvars-predicate-loops=1 -S | FileCheck %s + +; Nested loops with body using loop convergence token should be skipped by IndVarSimplify. + +%"class.hlsl::RWStructuredBuffer" = type { target("spirv.VulkanBuffer", [0 x i32], 12, 1), target("spirv.VulkanBuffer", i32, 12, 1) } + +@_ZL3Out = internal global %"class.hlsl::RWStructuredBuffer" poison, align 8 +@.str = private unnamed_addr constant [4 x i8] c"Out\00", align 1 + +; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare token @llvm.experimental.convergence.entry() #0 + +define void @nested() local_unnamed_addr #1 { +; CHECK-LABEL: @nested( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = tail call token @llvm.experimental.convergence.entry() +; CHECK-NEXT: [[TMP1:%.*]] = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str) +; CHECK-NEXT: [[TMP2:%.*]] = tail call target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], i32 0, i32 0) +; CHECK-NEXT: store target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], ptr @_ZL3Out, align 8 +; CHECK-NEXT: store target("spirv.VulkanBuffer", i32, 12, 1) [[TMP2]], ptr getelementptr inbounds nuw (i8, ptr @_ZL3Out, i64 8), align 8 +; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0) +; CHECK-NEXT: [[TMP4:%.*]] = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 1) +; CHECK-NEXT: [[MUL_I:%.*]] = shl nuw nsw i32 [[TMP3]], 3 +; CHECK-NEXT: [[ADD_I:%.*]] = add nuw nsw i32 [[MUL_I]], [[TMP4]] +; CHECK-NEXT: br label [[FOR_COND_I:%.*]] +; CHECK: for.cond.i: +; CHECK-NEXT: [[I_0_I:%.*]] = phi i32 [ 0, [[ENTRY:%.*]] ], [ [[INC10_I:%.*]], [[CLEANUP_I:%.*]] ] +; CHECK-NEXT: [[TMP5:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP0]]) ] +; CHECK-NEXT: [[CMP_I:%.*]] = icmp ult i32 [[I_0_I]], 8 +; CHECK-NEXT: br i1 [[CMP_I]], label [[FOR_COND1_I_PREHEADER:%.*]], label [[_Z4NESTEDDV3_J_EXIT:%.*]] +; CHECK: for.cond1.i.preheader: +; CHECK-NEXT: [[CMP5_I:%.*]] = icmp eq i32 [[I_0_I]], [[TMP3]] +; CHECK-NEXT: br label [[FOR_COND1_I:%.*]] +; CHECK: for.cond1.i: +; CHECK-NEXT: [[J_0_I:%.*]] = phi i32 [ [[INC_I:%.*]], [[FOR_BODY4_I:%.*]] ], [ 0, [[FOR_COND1_I_PREHEADER]] ] +; CHECK-NEXT: [[TMP6:%.*]] = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token [[TMP5]]) ] +; CHECK-NEXT: [[CMP2_I:%.*]] = icmp ult i32 [[J_0_I]], 8 +; CHECK-NEXT: br i1 [[CMP2_I]], label [[FOR_BODY4_I]], label [[CLEANUP_I_LOOPEXIT:%.*]] +; CHECK: for.body4.i: +; CHECK-NEXT: [[CMP6_I:%.*]] = icmp eq i32 [[J_0_I]], [[TMP4]] +; CHECK-NEXT: [[OR_COND:%.*]] = select i1 [[CMP5_I]], i1 [[CMP6_I]], i1 false +; CHECK-NEXT: [[INC_I]] = add nuw nsw i32 [[J_0_I]], 1 +; CHECK-NEXT: br i1 [[OR_COND]], label [[IF_THEN_I:%.*]], label [[FOR_COND1_I]] +; CHECK: cleanup.i.loopexit: +; CHECK-NEXT: br label [[CLEANUP_I]] +; CHECK: if.then.i: +; CHECK-NEXT: [[HLSL_WAVE_ACTIVE_MAX7_I:%.*]] = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 [[ADD_I]]) [ "convergencectrl"(token [[TMP6]]) ] +; CHECK-NEXT: [[TMP7:%.*]] = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) [[TMP1]], i32 [[ADD_I]]) +; CHECK-NEXT: store i32 [[HLSL_WAVE_ACTIVE_MAX7_I]], ptr addrspace(11) [[TMP7]], align 4 +; CHECK-NEXT: br label [[CLEANUP_I]] +; CHECK: cleanup.i: +; CHECK-NEXT: [[INC10_I]] = add nuw nsw i32 [[I_0_I]], 1 +; CHECK-NEXT: br label [[FOR_COND_I]] +; CHECK: _Z4nestedDv3_j.exit: +; CHECK-NEXT: ret void +; +entry: + %0 = tail call token @llvm.experimental.convergence.entry() + %1 = tail call target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str) + %2 = tail call target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, i32 0, i32 0) + store target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, ptr @_ZL3Out, align 8 + store target("spirv.VulkanBuffer", i32, 12, 1) %2, ptr getelementptr inbounds nuw (i8, ptr @_ZL3Out, i64 8), align 8 + %3 = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 0) + %4 = tail call i32 @llvm.spv.thread.id.in.group.i32(i32 1) + %mul.i = shl nuw nsw i32 %3, 3 + %add.i = add nuw nsw i32 %mul.i, %4 + br label %for.cond.i + +for.cond.i: ; preds = %cleanup.i, %entry + %i.0.i = phi i32 [ 0, %entry ], [ %inc10.i, %cleanup.i ] + %5 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %0) ] + %cmp.i = icmp ult i32 %i.0.i, 8 + br i1 %cmp.i, label %for.cond1.i.preheader, label %_Z4nestedDv3_j.exit + +; Preheader: +for.cond1.i.preheader: ; preds = %for.cond.i + %cmp5.i = icmp eq i32 %i.0.i, %3 + br label %for.cond1.i + +; Loop: +for.cond1.i: ; preds = %for.body4.i, %for.cond1.i.preheader + %j.0.i = phi i32 [ %inc.i, %for.body4.i ], [ 0, %for.cond1.i.preheader ] + %6 = call token @llvm.experimental.convergence.loop() [ "convergencectrl"(token %5) ] + %cmp2.i = icmp ult i32 %j.0.i, 8 + br i1 %cmp2.i, label %for.body4.i, label %cleanup.i.loopexit + +for.body4.i: ; preds = %for.cond1.i + %cmp6.i = icmp eq i32 %j.0.i, %4 + %or.cond = select i1 %cmp5.i, i1 %cmp6.i, i1 false + %inc.i = add nuw nsw i32 %j.0.i, 1 + br i1 %or.cond, label %if.then.i, label %for.cond1.i + +; Exit blocks +cleanup.i.loopexit: ; preds = %for.cond1.i + br label %cleanup.i + +if.then.i: ; preds = %for.body4.i + %hlsl.wave.active.max7.i = call spir_func i32 @llvm.spv.wave.reduce.umax.i32(i32 %add.i) [ "convergencectrl"(token %6) ] + %7 = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1) %1, i32 %add.i) + store i32 %hlsl.wave.active.max7.i, ptr addrspace(11) %7, align 4 + br label %cleanup.i + +cleanup.i: ; preds = %cleanup.i.loopexit, %if.then.i + %inc10.i = add nuw nsw i32 %i.0.i, 1 + br label %for.cond.i + +_Z4nestedDv3_j.exit: ; preds = %for.cond.i + ret void +} + +; Function Attrs: mustprogress nofree nosync nounwind willreturn memory(none) +declare i32 @llvm.spv.thread.id.in.group.i32(i32) #2 + +; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare token @llvm.experimental.convergence.loop() #0 + +; Function Attrs: convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare i32 @llvm.spv.wave.reduce.umax.i32(i32) #0 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare target("spirv.VulkanBuffer", [0 x i32], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0i32_12_1t(i32, i32, i32, i32, ptr) #4 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare target("spirv.VulkanBuffer", i32, 12, 1) @llvm.spv.resource.counterhandlefromimplicitbinding.tspirv.VulkanBuffer_i32_12_1t.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1), i32, i32) #4 + +; Function Attrs: mustprogress nocallback nofree nosync nounwind willreturn memory(none) +declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0i32_12_1t(target("spirv.VulkanBuffer", [0 x i32], 12, 1), i32) #4 + +; Function Attrs: nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) +declare void @llvm.experimental.noalias.scope.decl(metadata) #5 + +attributes #0 = { convergent mustprogress nocallback nofree nosync nounwind willreturn memory(none) } +attributes #1 = { convergent noinline norecurse "frame-pointer"="all" "hlsl.numthreads"="8,8,1" "hlsl.shader"="compute" "no-infs-fp-math"="true" "no-nans-fp-math"="true" "no-signed-zeros-fp-math"="true" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #2 = { mustprogress nofree nosync nounwind willreturn memory(none) } +attributes #3 = { mustprogress nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) } +attributes #4 = { mustprogress nocallback nofree nosync nounwind willreturn memory(none) } +attributes #5 = { nocallback nofree nosync nounwind willreturn memory(inaccessiblemem: readwrite) } +attributes #6 = { nounwind } From 4fc8d878d3459192aa8cbf5f8caf136d77598aa2 Mon Sep 17 00:00:00 2001 From: luciechoi Date: Thu, 30 Oct 2025 21:56:29 +0000 Subject: [PATCH 2/2] Use helper --- llvm/lib/Transforms/Scalar/IndVarSimplify.cpp | 38 +++++-------------- 1 file changed, 10 insertions(+), 28 deletions(-) diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp index 421aad8872f9a..3a093b33ddfa7 100644 --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -31,6 +31,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemorySSA.h" @@ -1859,35 +1860,16 @@ bool IndVarSimplify::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { } } - // If the loop body uses a convergence token defined within the loop, skip - // predication. This is to avoid changing the convergence behavior of the - // loop. - SmallVector blocks = ExitingBlocks; - SmallVector tokens = {}; - size_t index = 0; // Assume Exiting Blocks are sorted. - while (index < blocks.size()) { - BasicBlock *BB = blocks[index]; - index++; - const auto exitingBlockName = BB->getName(); - for (Instruction &I : *BB) { - // Check if the instruction uses any convergence tokens. - if (auto *CB = dyn_cast(&I); - CB && !isa(&I)) { - auto token = CB->getConvergenceControlToken(); - if (token && llvm::is_contained(tokens, token)) { - return false; - } - } - if (isa(&I)) { - tokens.push_back(cast(&I)); - } - } + CodeMetrics Metrics; + SmallPtrSet EphValues; + for (BasicBlock *BB : L->blocks()) { + Metrics.analyzeBasicBlock(BB, *TTI, EphValues, /* PrepareForLTO= */ false, + L); + } - for (BasicBlock *Succ : successors(BB)) { - const auto succName = Succ->getName(); - if (Succ != L->getLoopLatch() && !llvm::is_contained(blocks, Succ)) - blocks.push_back(Succ); - } + if (Metrics.Convergence == ConvergenceKind::ExtendedLoop) { + // Do not predicate loops with extended convergence. + return false; } bool Changed = false;