-
Notifications
You must be signed in to change notification settings - Fork 15k
[HLSL][DXIL][SPIRV] Added WaveActiveBitOr HLSL intrinsic #165156
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
|
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-clang Author: Sietze Riemersma (KungFuDonkey) ChangesAdds the WaveActiveBitOr intrinsic. This intrinsic required a bit more work than the last intrinsics that I have done. There are some peculiarities, which I verified with dxcompiler:
Followed the checklist:
Full diff: https://github.com/llvm/llvm-project/pull/165156.diff 16 Files Affected:
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 792e2e07ec594..3eff68e5b7b8f 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4999,6 +4999,12 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool(bool)";
}
+def HLSLWaveActiveBitOr : LangBuiltin<"HLSL_LANG"> {
+ let Spellings = ["__builtin_hlsl_wave_active_bit_or"];
+ let Attributes = [NoThrow, Const];
+ let Prototype = "void (...)";
+}
+
def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
let Attributes = [NoThrow, Const];
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 4f2f5a761f197..2668f223aaa8f 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -690,6 +690,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
+ case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
+ Value *Op = EmitScalarExpr(E->getArg(0));
+ assert(Op->getType()->hasUnsignedIntegerRepresentation() &&
+ "Intrinsic WaveActiveBitOr operand must have a unsigned integer representation");
+
+ Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBitOrIntrinsic();
+ return EmitRuntimeCall(
+ Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
+ }
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index 7c6c2850fd4d4..b7fb15f10b6c5 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -113,6 +113,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
+ GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBitOr, wave_reduce_or)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
diff --git a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
index d973371312701..2075f557a9a0a 100644
--- a/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
+++ b/clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
@@ -2498,6 +2498,36 @@ __attribute__((convergent)) double3 WaveReadLaneAt(double3, uint32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) double4 WaveReadLaneAt(double4, uint32_t);
+//===----------------------------------------------------------------------===//
+// WaveActiveBitOr builtins
+//===----------------------------------------------------------------------===//
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint WaveActiveBitOr(uint);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint2 WaveActiveBitOr(uint2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint3 WaveActiveBitOr(uint3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint4 WaveActiveBitOr(uint4);
+
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint64_t WaveActiveBitOr(uint64_t);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint64_t2 WaveActiveBitOr(uint64_t2);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint64_t3 WaveActiveBitOr(uint64_t3);
+_HLSL_AVAILABILITY(shadermodel, 6.0)
+_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
+__attribute__((convergent)) uint64_t4 WaveActiveBitOr(uint64_t4);
+
//===----------------------------------------------------------------------===//
// WaveActiveMax builtins
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index f34706677b59f..5969173b15fab 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3211,6 +3211,29 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
TheCall->setType(ArgTyExpr);
break;
}
+ case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
+ if (SemaRef.checkArgCount(TheCall, 1))
+ return true;
+
+ // Ensure input expr type is a scalar/vector and the same as the return type
+ if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0))
+ return true;
+ if (CheckWaveActive(&SemaRef, TheCall))
+ return true;
+
+ // Ensure expression parameter type can be interpreted as a uint
+ ExprResult Expr = TheCall->getArg(0);
+ QualType ArgTyExpr = Expr.get()->getType();
+ if (!ArgTyExpr->isIntegerType()) {
+ SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
+ diag::err_typecheck_convert_incompatible)
+ << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
+ return true;
+ }
+
+ TheCall->setType(ArgTyExpr);
+ break;
+ }
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builitns will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
diff --git a/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
new file mode 100644
index 0000000000000..5e030208679b9
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
@@ -0,0 +1,30 @@
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
+// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
+// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV
+
+// Test basic lowering to runtime function call.
+
+// CHECK-LABEL: test_uint
+uint test_uint(uint expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i32([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.or.i32([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.or.i32([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.or.i32([[TY]]) #[[#attr:]]
+
+// CHECK-LABEL: test_uint64_t
+uint64_t test_uint64_t(uint64_t expr) {
+ // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i64([[TY]] %[[#]])
+ // CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.or.i64([[TY]] %[[#]])
+ // CHECK: ret [[TY]] %[[RET]]
+ return WaveActiveBitOr(expr);
+}
+
+// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.or.i64([[TY]]) #[[#attr:]]
+// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.or.i64([[TY]]) #[[#attr:]]
diff --git a/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
new file mode 100644
index 0000000000000..8eb62c7d96b4e
--- /dev/null
+++ b/clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
@@ -0,0 +1,38 @@
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify
+
+uint test_too_few_arg() {
+ return __builtin_hlsl_wave_active_bit_or();
+ // expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
+}
+
+uint2 test_too_many_arg(uint2 p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0, p0);
+ // expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
+}
+
+bool test_expr_bool_type_check(bool p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0);
+ // expected-error@-1 {{invalid operand of type 'bool'}}
+}
+
+float test_expr_float_type_check(float p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0);
+ // expected-error@-1 {{invalid operand of type 'float'}}
+}
+
+bool2 test_expr_bool_vec_type_check(bool2 p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0);
+ // expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
+}
+
+float2 test_expr_float_type_check(float2 p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0);
+ // expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}}
+}
+
+struct S { float f; };
+
+S test_expr_struct_type_check(S p0) {
+ return __builtin_hlsl_wave_active_bit_or(p0);
+ // expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
+}
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 3b7077c52db21..22f5644c1103e 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -151,6 +151,7 @@ def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
+def int_dx_wave_reduce_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_anyint_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
index 49a182be98acd..8747c46c97293 100644
--- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td
+++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td
@@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
+ def int_spv_wave_reduce_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_anyint_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
@@ -136,7 +137,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_sclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_spv_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
- // Create resource handle given the binding information. Returns a
+ // Create resource handle given the binding information. Returns a
// type appropriate for the kind of resource given the set id, binding id,
// array size of the binding, as well as an index and an indicator
// whether that index may be non-uniform.
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 44c48305f2832..537eba6afd0a2 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -316,6 +316,10 @@ defvar WaveOpKind_Product = 1;
defvar WaveOpKind_Min = 2;
defvar WaveOpKind_Max = 3;
+defvar WaveBitOpKind_And = 0;
+defvar WaveBitOpKind_Or = 1;
+defvar WaveBitOpKind_Xor = 2;
+
defvar SignedOpKind_Signed = 0;
defvar SignedOpKind_Unsigned = 1;
@@ -1069,6 +1073,24 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
let attributes = [Attributes<DXIL1_0, []>];
}
+def WaveActiveBit : DXILOp<120, waveActiveBit> {
+ let Doc = "returns the result of the operation across waves";
+ let intrinsics = [
+ IntrinSelect<int_dx_wave_reduce_or,
+ [
+ IntrinArgIndex<0>, IntrinArgI8<WaveBitOpKind_Or>,
+ ]>,
+ ];
+
+ let arguments = [OverloadTy, Int8Ty];
+ let result = OverloadTy;
+ let overloads = [
+ Overloads<DXIL1_0, [Int32Ty, Int64Ty]>
+ ];
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+ let attributes = [Attributes<DXIL1_0, []>];
+}
+
def WaveAllBitCount : DXILOp<135, waveAllOp> {
let Doc = "returns the count of bits set to 1 across the wave";
let intrinsics = [IntrinSelect<int_dx_wave_active_countbits>];
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index e7e7f2ce66ae8..7c47e107f51b3 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -90,6 +90,7 @@ static bool checkWaveOps(Intrinsic::ID IID) {
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_active_countbits:
// Wave Active Op Variants
+ case Intrinsic::dx_wave_reduce_or:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_wave_reduce_max:
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 68fd3e0bc74c7..50b7e9fd957f5 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -54,6 +54,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_saturate:
case Intrinsic::dx_splitdouble:
case Intrinsic::dx_wave_readlane:
+ case Intrinsic::dx_wave_reduce_or:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_umax:
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index a0cff4d82b500..3429e98de0548 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -219,6 +219,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
+ bool selectWaveReduceOr(Register ResVReg, const SPIRVType *ResType,
+ MachineInstr &I) const;
+
bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, bool IsUnsigned) const;
@@ -2427,6 +2430,32 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return Result;
}
+bool SPIRVInstructionSelector::selectWaveReduceOr(
+ Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+
+ assert(I.getNumOperands() == 3);
+ assert(I.getOperand(2).isReg());
+ MachineBasicBlock &BB = *I.getParent();
+ Register InputRegister = I.getOperand(2).getReg();
+ SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);
+
+ if (!InputType)
+ report_fatal_error("Input Type could not be determined.");
+
+ SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);
+
+ auto Opcode = SPIRV::OpGroupNonUniformBitwiseOr;
+
+ return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
+ .addDef(ResVReg)
+ .addUse(GR.getSPIRVTypeID(ResType))
+ .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
+ !STI.isShader()))
+ .addImm(SPIRV::GroupOperation::Reduce)
+ .addUse(I.getOperand(2).getReg())
+ .constrainAllUses(TII, TRI, RBI);
+}
+
bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
@@ -3427,6 +3456,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
case Intrinsic::spv_wave_is_first_lane:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
+ case Intrinsic::spv_wave_reduce_or:
+ return selectWaveReduceOr(ResVReg, ResType, I);
case Intrinsic::spv_wave_reduce_umax:
return selectWaveReduceMax(ResVReg, ResType, I, /*IsUnsigned*/ true);
case Intrinsic::spv_wave_reduce_max:
diff --git a/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll b/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
index 7a876f67615cd..140152aa80c07 100644
--- a/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
+++ b/llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
@@ -41,6 +41,13 @@ entry:
ret i1 %ret
}
+define noundef i32 @wave_bit_or(i32 %x) {
+entry:
+ ; CHECK: Function wave_bit_or : [[WAVE_FLAG]]
+ %ret = call i32 @llvm.dx.wave.reduce.or(i32 %x)
+ ret i32 %ret
+}
+
define noundef i1 @wave_readlane(i1 %x, i32 %idx) {
entry:
; CHECK: Function wave_readlane : [[WAVE_FLAG]]
diff --git a/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll b/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll
new file mode 100644
index 0000000000000..386a263ec245d
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll
@@ -0,0 +1,19 @@
+; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s
+
+define noundef i32 @wave_bitor_simple(i32 noundef %p1) {
+entry:
+; CHECK: call i32 @dx.op.waveActiveBit.i32(i32 120, i32 %p1, i8 1){{$}}
+ %ret = call i32 @llvm.dx.wave.reduce.or.i32(i32 %p1)
+ ret i32 %ret
+}
+
+declare i32 @llvm.dx.wave.reduce.or.i32(i32)
+
+define noundef i64 @wave_bitor_simple64(i64 noundef %p1) {
+entry:
+; CHECK: call i64 @dx.op.waveActiveBit.i64(i32 120, i64 %p1, i8 1){{$}}
+ %ret = call i64 @llvm.dx.wave.reduce.or.i64(i64 %p1)
+ ret i64 %ret
+}
+
+declare i64 @llvm.dx.wave.reduce.or.i64(i64)
diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll
new file mode 100644
index 0000000000000..2df3a335bee69
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll
@@ -0,0 +1,30 @@
+; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %}
+
+; Test lowering to spir-v backend for various types and scalar/vector
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint64:]] = OpTypeInt 64 0
+; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3
+
+; CHECK-LABEL: Begin function test_uint
+; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]]
+define i32 @test_uint(i32 %iexpr) {
+entry:
+; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint]] %[[#scope]] Reduce %[[#iexpr]]
+ %0 = call i32 @llvm.spv.wave.reduce.or.i32(i32 %iexpr)
+ ret i32 %0
+}
+
+declare i32 @llvm.spv.wave.reduce.or.i32(i32)
+
+; CHECK-LABEL: Begin function test_uint64
+; CHECK: %[[#iexpr64:]] = OpFunctionParameter %[[#uint64]]
+define i64 @test_uint64(i64 %iexpr64) {
+entry:
+; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint64]] %[[#scope]] Reduce %[[#iexpr64]]
+ %0 = call i64 @llvm.spv.wave.reduce.or.i64(i64 %iexpr64)
+ ret i64 %0
+}
+
+declare i64 @llvm.spv.wave.reduce.or.i64(i64)
\ No newline at end of file
|
clang/lib/Sema/SemaHLSL.cpp
Outdated
| return true; | ||
|
|
||
| // Ensure input expr type is a scalar/vector and the same as the return type | ||
| if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BI__builtin_hlsl_wave_active_max does CheckWaveActive first. Why are yo doing CheckAnyScalarOrVector.?For predictability of Error orders can this be changed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I am still new to this, so take my PR with a grain of salt.
I copied this from WaveActiveSum, which has that order. I'll change it for this function.
I guess the Scalar or Vector type makes no sense as we only allow uint32 and uint64 here
Will update the code
| SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(), | ||
| diag::err_typecheck_convert_incompatible) | ||
| << ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is unique about BI__builtin_hlsl_wave_active_bit_or that you need err_typecheck_convert_incompatible when BI__builtin_hlsl_wave_active_max does not do this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess its not a convert, I wanted to enforce only uint/uint64 usage
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check if new push is more like what you'd expect
| let intrinsics = [ | ||
| IntrinSelect<int_dx_wave_reduce_or, | ||
| [ | ||
| IntrinArgIndex<0>, IntrinArgI8<WaveBitOpKind_Or>, | ||
| ]>, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the formatting seems off
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its the same as DXILOp<119>, so what is wrong?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can run tablegen files through clang-format. There is no ci for this so suspect maybe thats why the WaveActiveOp look the way they do. If it doesn't change this then feel free to ignore my comment.
You can test this locally with the following command:git-clang-format --diff origin/main HEAD --extensions cpp,h -- clang/lib/CodeGen/CGHLSLBuiltins.cpp clang/lib/CodeGen/CGHLSLRuntime.h clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h clang/lib/Sema/SemaHLSL.cpp llvm/lib/Target/DirectX/DXILShaderFlags.cpp llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp --diff_from_common_commit
View the diff from clang-format here.diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 4c1872308..ad6d31814 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -735,7 +735,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
Value *Op = EmitScalarExpr(E->getArg(0));
assert(Op->getType()->hasUnsignedIntegerRepresentation() &&
- "Intrinsic WaveActiveBitOr operand must have a unsigned integer representation");
+ "Intrinsic WaveActiveBitOr operand must have a unsigned integer "
+ "representation");
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBitOrIntrinsic();
return EmitRuntimeCall(
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 9da372914..d39d5843d 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -3306,7 +3306,7 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
QualType ArgTyExpr = Expr.get()->getType();
auto *VTy = ArgTyExpr->getAs<VectorType>();
if (!(ArgTyExpr->isIntegerType() ||
- (VTy && VTy->getElementType()->isIntegerType()))) {
+ (VTy && VTy->getElementType()->isIntegerType()))) {
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
diag::err_builtin_invalid_arg_type)
<< ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index ddc9bb3c5..18f29498b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -2432,8 +2432,9 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return Result;
}
-bool SPIRVInstructionSelector::selectWaveReduceOr(
- Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {
+bool SPIRVInstructionSelector::selectWaveReduceOr(Register ResVReg,
+ const SPIRVType *ResType,
+ MachineInstr &I) const {
assert(I.getNumOperands() == 3);
assert(I.getOperand(2).isReg());
|
| .addUse(GR.getSPIRVTypeID(ResType)) | ||
| .addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII, | ||
| !STI.isShader())) | ||
| .addImm(SPIRV::GroupOperation::Reduce) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see this is why you are adding reduce to the name.
https://godbolt.org/z/Pe7hcEcfr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, should I still rename wave_reduce_or?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah OpGroupBroadcast, OpGroupReduce, and OpGroupGather are spirv implementation details. I don’t think it makes sense to take the language of spirv and apply it broadly across all the llvm intrinsics needed by this feature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm i am a little conflicted because it seems like previous implementers are already adding reduce into the intrinsic names. This might be a question for @Keenuts.
| case Intrinsic::dx_saturate: | ||
| case Intrinsic::dx_splitdouble: | ||
| case Intrinsic::dx_wave_readlane: | ||
| case Intrinsic::dx_wave_reduce_or: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you add a line here then you should be testing this intrinsic gets scalarized
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That happens in WaveActiveBitOr.ll? I am unsure what I would have to add to make a test 'scalar', as the result of the dx or spriv is always a scalar value as it is the same value across the wave.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
run these command and you should find examples
git grep -n "function(scalarizer<load-store>)" -- llvm/test/CodeGen/DirectX/
git grep -n "\-scalarizer" -- llvm/test/CodeGen/DirectX/
|
|
||
| // CHECK-LABEL: test_uint | ||
| uint test_uint(uint expr) { | ||
| // CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i32([[TY]] %[[#]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see clang/test/CodeGenHLSL/builtins/dot.hlsl that should let you use CHECK for most of your intrinsics so we don't have to do so many seperate SPIRV vs DX checks.
should look something like this for the first one
// DXCHECK: %name = call @llvm.[[ICF:dx]].<intrinsic_name>.(...
// SPVCHECK: %name = call @llvm.[[ICF:spv]].<intrinsic_name>.(...
successive checks
// DXCHECK: %name = call @llvm.[[ICF]].<intrinsic_name>.(...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is still a separate DX and SPV check? Or do you mean:
// CHECK: %[[RET:.*]] = call [[TY:.*]] @llvm.[[ICF]].wave.reduce.or.i32([[TY]] %[[#]])
Because that wouldn't work as SPV needs spir_func in the call as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
my recomendation is to look at how we did these tests
git grep -n "\[\ICF:dx\]\]"
clang/test/CodeGenHLSL/builtins/dot.hlsl:23:// DXCHECK: %hlsl.dot = call i32 @llvm.[[ICF:dx]].sdot.v2i32(<2 x i32>
clang/test/CodeGenHLSL/builtins/isinf.hlsl:19:// DXCHECK: %hlsl.isinf = call i1 @llvm.[[ICF:dx]].isinf.f32(
clang/test/CodeGenHLSL/builtins/isnan.hlsl:19:// DXCHECK: %hlsl.isnan = call i1 @llvm.[[ICF:dx]].isnan.f32(
git grep -n "\[\ICF:spv\]\]"
clang/test/CodeGenHLSL/builtins/dot.hlsl:24:// SPVCHECK: %hlsl.dot = call i32 @llvm.[[ICF:spv]].sdot.v2i32(<2 x i32>
clang/test/CodeGenHLSL/builtins/isinf.hlsl:20:// SPVCHECK: %hlsl.isinf = call i1 @llvm.[[ICF:spv]].isinf.f32(
clang/test/CodeGenHLSL/builtins/isnan.hlsl:20:// SPVCHECK: %hlsl.isnan = call i1 @llvm.[[ICF:spv]].isnan.f32(
git grep -n "\[\[FN_TYP.*\]\]"
clang/test/CodeGenHLSL/builtins/isinf.hlsl:17:// DXCHECK: define hidden [[FN_TYPE:]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:18:// SPVCHECK: define hidden [[FN_TYPE:spir_func ]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:24:// CHECK: define hidden [[FN_TYPE]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:30:// CHECK: define hidden [[FN_TYPE]]noundef <2 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:36:// NATIVE_HALF: define hidden [[FN_TYPE]]noundef <3 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:42:// NATIVE_HALF: define hidden [[FN_TYPE]]noundef <4 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:49:// CHECK: define hidden [[FN_TYPE]]noundef <2 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:54:// CHECK: define hidden [[FN_TYPE]]noundef <3 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:59:// CHECK: define hidden [[FN_TYPE]]noundef <4 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:17:// DXCHECK: define hidden [[FN_TYPE:]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:18:// SPVCHECK: define hidden [[FN_TYPE:spir_func ]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:24:// CHECK: define hidden [[FN_TYPE]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:30:// CHECK: define hidden [[FN_TYPE]]noundef <2 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:36:// NATIVE_HALF: define hidden [[FN_TYPE]]noundef <3 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:42:// NATIVE_HALF: define hidden [[FN_TYPE]]noundef <4 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:49:// CHECK: define hidden [[FN_TYPE]]noundef <2 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:54:// CHECK: define hidden [[FN_TYPE]]noundef <3 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:59:// CHECK: define hidden [[FN_TYPE]]noundef <4 x i1> @
Adds the WaveActiveBitOr intrinsic from issue #99167. This intrinsic required a bit more work than the last intrinsics that I have done.
There are some peculiarities, which I verified with dxcompiler:
Followed the checklist:
offload test: llvm/offload-test-suite#487