diff --git a/llvm/lib/SYCLPostLink/ComputeModuleRuntimeInfo.cpp b/llvm/lib/SYCLPostLink/ComputeModuleRuntimeInfo.cpp index 39548ab2db82d..c4bd7129fefe4 100644 --- a/llvm/lib/SYCLPostLink/ComputeModuleRuntimeInfo.cpp +++ b/llvm/lib/SYCLPostLink/ComputeModuleRuntimeInfo.cpp @@ -66,11 +66,12 @@ bool isModuleUsingTsan(const Module &M) { // Optional. // Otherwise, it returns an Optional containing a list of reached // SPIR kernel function's names. -static std::optional> -traverseCGToFindSPIRKernels(const Function *StartingFunction) { +static std::optional> traverseCGToFindSPIRKernels( + const std::vector &StartingFunctionVec) { std::queue FunctionsToVisit; std::unordered_set VisitedFunctions; - FunctionsToVisit.push(StartingFunction); + for (const Function *FPtr : StartingFunctionVec) + FunctionsToVisit.push(FPtr); std::vector KernelNames; while (!FunctionsToVisit.empty()) { @@ -106,13 +107,20 @@ traverseCGToFindSPIRKernels(const Function *StartingFunction) { return {std::move(KernelNames)}; } -static std::vector getKernelNamesUsingAssert(const Module &M) { - auto *DevicelibAssertFailFunction = M.getFunction("__devicelib_assert_fail"); - if (!DevicelibAssertFailFunction) +static std::vector +getKernelNamesUsingSpecialFunctions(const Module &M, + const std::vector &FNames) { + std::vector SpecialFunctionVec; + for (const auto Fn : FNames) { + Function *FPtr = M.getFunction(Fn); + if (FPtr) + SpecialFunctionVec.push_back(FPtr); + } + + if (SpecialFunctionVec.size() == 0) return {}; - auto TraverseResult = - traverseCGToFindSPIRKernels(DevicelibAssertFailFunction); + auto TraverseResult = traverseCGToFindSPIRKernels(SpecialFunctionVec); if (TraverseResult.has_value()) return std::move(*TraverseResult); @@ -442,7 +450,9 @@ PropSetRegTy computeModuleProperties(const Module &M, PropSet.add(PropSetRegTy::SYCL_MISC_PROP, "optLevel", OptLevel); } { - std::vector FuncNames = getKernelNamesUsingAssert(M); + std::vector AssertFuncNames{"__devicelib_assert_fail"}; + std::vector FuncNames = + getKernelNamesUsingSpecialFunctions(M, AssertFuncNames); for (const StringRef &FName : FuncNames) PropSet.add(PropSetRegTy::SYCL_ASSERT_USED, FName, true); }