Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4999,6 +4999,12 @@ def HLSLWaveActiveAnyTrue : LangBuiltin<"HLSL_LANG"> {
let Prototype = "bool(bool)";
}

def HLSLWaveActiveBitOr : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_bit_or"];
let Attributes = [NoThrow, Const];
let Prototype = "void (...)";
}

def HLSLWaveActiveCountBits : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_wave_active_count_bits"];
let Attributes = [NoThrow, Const];
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CGHLSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,15 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
Value *Op = EmitScalarExpr(E->getArg(0));
assert(Op->getType()->hasUnsignedIntegerRepresentation() &&
"Intrinsic WaveActiveBitOr operand must have a unsigned integer representation");

Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveBitOrIntrinsic();
return EmitRuntimeCall(
Intrinsic::getOrInsertDeclaration(&CGM.getModule(), ID), {Op});
}
case Builtin::BI__builtin_hlsl_wave_active_count_bits: {
Value *OpExpr = EmitScalarExpr(E->getArg(0));
Intrinsic::ID ID = CGM.getHLSLRuntime().getWaveActiveCountBitsIntrinsic();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class CGHLSLRuntime {
GENERATE_HLSL_INTRINSIC_FUNCTION(Dot4AddU8Packed, dot4add_u8packed)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAllTrue, wave_all)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveAnyTrue, wave_any)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveBitOr, wave_reduce_or)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveActiveCountBits, wave_active_countbits)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveIsFirstLane, wave_is_first_lane)
GENERATE_HLSL_INTRINSIC_FUNCTION(WaveGetLaneCount, wave_get_lane_count)
Expand Down
33 changes: 33 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_alias_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -2498,6 +2498,39 @@ __attribute__((convergent)) double3 WaveReadLaneAt(double3, uint32_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_read_lane_at)
__attribute__((convergent)) double4 WaveReadLaneAt(double4, uint32_t);

//===----------------------------------------------------------------------===//
// WaveActiveBitOr builtins
//===----------------------------------------------------------------------===//

// \brief Returns the value of the expression for the given lane index within
// the specified wave.

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint WaveActiveBitOr(uint);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint2 WaveActiveBitOr(uint2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint3 WaveActiveBitOr(uint3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint4 WaveActiveBitOr(uint4);

_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint64_t WaveActiveBitOr(uint64_t);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint64_t2 WaveActiveBitOr(uint64_t2);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint64_t3 WaveActiveBitOr(uint64_t3);
_HLSL_AVAILABILITY(shadermodel, 6.0)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_wave_active_bit_or)
__attribute__((convergent)) uint64_t4 WaveActiveBitOr(uint64_t4);

//===----------------------------------------------------------------------===//
// WaveActiveMax builtins
//===----------------------------------------------------------------------===//
Expand Down
23 changes: 23 additions & 0 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3211,6 +3211,29 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
TheCall->setType(ArgTyExpr);
break;
}
case Builtin::BI__builtin_hlsl_wave_active_bit_or: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;

if (CheckWaveActive(&SemaRef, TheCall))
return true;

// Ensure the expr type is interpretable as a uint or vector<uint>
ExprResult Expr = TheCall->getArg(0);
QualType ArgTyExpr = Expr.get()->getType();
auto *VTy = ArgTyExpr->getAs<VectorType>();
if (!(ArgTyExpr->isIntegerType() ||
(VTy && VTy->getElementType()->isIntegerType()))) {
SemaRef.Diag(TheCall->getArg(0)->getBeginLoc(),
diag::err_builtin_invalid_arg_type)
<< ArgTyExpr << SemaRef.Context.UnsignedIntTy << 1 << 0 << 0;
return true;
}

// Ensure input expr type is the same as the return type
TheCall->setType(ArgTyExpr);
break;
}
// Note these are llvm builtins that we want to catch invalid intrinsic
// generation. Normal handling of these builitns will occur elsewhere.
case Builtin::BI__builtin_elementwise_bitreverse: {
Expand Down
30 changes: 30 additions & 0 deletions clang/test/CodeGenHLSL/builtins/WaveActiveBitOr.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
// RUN: %clang_cc1 -std=hlsl2021 -finclude-default-header -triple \
// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \
// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV

// Test basic lowering to runtime function call.

// CHECK-LABEL: test_uint
uint test_uint(uint expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i32([[TY]] %[[#]])
Copy link
Member

@farzonl farzonl Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see clang/test/CodeGenHLSL/builtins/dot.hlsl that should let you use CHECK for most of your intrinsics so we don't have to do so many seperate SPIRV vs DX checks.

should look something like this for the first one

// DXCHECK: %name = call @llvm.[[ICF:dx]].<intrinsic_name>.(...
// SPVCHECK: %name = call @llvm.[[ICF:spv]].<intrinsic_name>.(...

successive checks
// DXCHECK: %name = call @llvm.[[ICF]].<intrinsic_name>.(...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is still a separate DX and SPV check? Or do you mean:
// CHECK: %[[RET:.*]] = call [[TY:.*]] @llvm.[[ICF]].wave.reduce.or.i32([[TY]] %[[#]])

Because that wouldn't work as SPV needs spir_func in the call as well

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

my recomendation is to look at how we did these tests

git grep -n "\[\ICF:dx\]\]"
clang/test/CodeGenHLSL/builtins/dot.hlsl:23:// DXCHECK: %hlsl.dot = call i32 @llvm.[[ICF:dx]].sdot.v2i32(<2 x i32>
clang/test/CodeGenHLSL/builtins/isinf.hlsl:19:// DXCHECK: %hlsl.isinf = call i1 @llvm.[[ICF:dx]].isinf.f32(
clang/test/CodeGenHLSL/builtins/isnan.hlsl:19:// DXCHECK: %hlsl.isnan = call i1 @llvm.[[ICF:dx]].isnan.f32(

git grep -n "\[\ICF:spv\]\]"
clang/test/CodeGenHLSL/builtins/dot.hlsl:24:// SPVCHECK: %hlsl.dot = call i32 @llvm.[[ICF:spv]].sdot.v2i32(<2 x i32>
clang/test/CodeGenHLSL/builtins/isinf.hlsl:20:// SPVCHECK: %hlsl.isinf = call i1 @llvm.[[ICF:spv]].isinf.f32(
clang/test/CodeGenHLSL/builtins/isnan.hlsl:20:// SPVCHECK: %hlsl.isnan = call i1 @llvm.[[ICF:spv]].isnan.f32(

git grep -n "\[\[FN_TYP.*\]\]" 
clang/test/CodeGenHLSL/builtins/isinf.hlsl:17:// DXCHECK: define hidden [[FN_TYPE:]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:18:// SPVCHECK: define hidden [[FN_TYPE:spir_func ]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:24:// CHECK: define hidden [[FN_TYPE]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:30:// CHECK: define hidden [[FN_TYPE]]noundef <2 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:36:// NATIVE_HALF: define hidden [[FN_TYPE]]noundef <3 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:42:// NATIVE_HALF: define hidden [[FN_TYPE]]noundef <4 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:49:// CHECK: define hidden [[FN_TYPE]]noundef <2 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:54:// CHECK: define hidden [[FN_TYPE]]noundef <3 x i1> @
clang/test/CodeGenHLSL/builtins/isinf.hlsl:59:// CHECK: define hidden [[FN_TYPE]]noundef <4 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:17:// DXCHECK: define hidden [[FN_TYPE:]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:18:// SPVCHECK: define hidden [[FN_TYPE:spir_func ]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:24:// CHECK: define hidden [[FN_TYPE]]noundef i1 @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:30:// CHECK: define hidden [[FN_TYPE]]noundef <2 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:36:// NATIVE_HALF: define hidden [[FN_TYPE]]noundef <3 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:42:// NATIVE_HALF: define hidden [[FN_TYPE]]noundef <4 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:49:// CHECK: define hidden [[FN_TYPE]]noundef <2 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:54:// CHECK: define hidden [[FN_TYPE]]noundef <3 x i1> @
clang/test/CodeGenHLSL/builtins/isnan.hlsl:59:// CHECK: define hidden [[FN_TYPE]]noundef <4 x i1> @

// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.or.i32([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveBitOr(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.or.i32([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.or.i32([[TY]]) #[[#attr:]]

// CHECK-LABEL: test_uint64_t
uint64_t test_uint64_t(uint64_t expr) {
// CHECK-SPIRV: %[[RET:.*]] = call spir_func [[TY:.*]] @llvm.spv.wave.reduce.or.i64([[TY]] %[[#]])
// CHECK-DXIL: %[[RET:.*]] = call [[TY:.*]] @llvm.dx.wave.reduce.or.i64([[TY]] %[[#]])
// CHECK: ret [[TY]] %[[RET]]
return WaveActiveBitOr(expr);
}

// CHECK-DXIL: declare [[TY]] @llvm.dx.wave.reduce.or.i64([[TY]]) #[[#attr:]]
// CHECK-SPIRV: declare [[TY]] @llvm.spv.wave.reduce.or.i64([[TY]]) #[[#attr:]]
38 changes: 38 additions & 0 deletions clang/test/SemaHLSL/BuiltIns/WaveActiveBitOr-errors.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify

uint test_too_few_arg() {
return __builtin_hlsl_wave_active_bit_or();
// expected-error@-1 {{too few arguments to function call, expected 1, have 0}}
}

uint2 test_too_many_arg(uint2 p0) {
return __builtin_hlsl_wave_active_bit_or(p0, p0);
// expected-error@-1 {{too many arguments to function call, expected 1, have 2}}
}

bool test_expr_bool_type_check(bool p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'bool'}}
}

float test_expr_float_type_check(float p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'float'}}
}

bool2 test_expr_bool_vec_type_check(bool2 p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'bool2' (aka 'vector<bool, 2>')}}
}

float2 test_expr_float_type_check(float2 p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'float2' (aka 'vector<float, 2>')}}
}

struct S { float f; };

S test_expr_struct_type_check(S p0) {
return __builtin_hlsl_wave_active_bit_or(p0);
// expected-error@-1 {{invalid operand of type 'S' where a scalar or vector is required}}
}
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def int_dx_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1
def int_dx_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_getlaneindex : DefaultAttrsIntrinsic<[llvm_i32_ty], [], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_anyint_ty], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_dx_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
Expand Down
3 changes: 2 additions & 1 deletion llvm/include/llvm/IR/IntrinsicsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_wave_active_countbits : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_i1_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_or : DefaultAttrsIntrinsic<[llvm_anyint_ty], [llvm_anyint_ty], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_umax : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_max : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
def int_spv_wave_reduce_sum : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>], [IntrConvergent, IntrNoMem]>;
Expand All @@ -136,7 +137,7 @@ def int_spv_rsqrt : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty]
def int_spv_sclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
def int_spv_nclamp : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;

// Create resource handle given the binding information. Returns a
// Create resource handle given the binding information. Returns a
// type appropriate for the kind of resource given the set id, binding id,
// array size of the binding, as well as an index and an indicator
// whether that index may be non-uniform.
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,10 @@ defvar WaveOpKind_Product = 1;
defvar WaveOpKind_Min = 2;
defvar WaveOpKind_Max = 3;

defvar WaveBitOpKind_And = 0;
defvar WaveBitOpKind_Or = 1;
defvar WaveBitOpKind_Xor = 2;

defvar SignedOpKind_Signed = 0;
defvar SignedOpKind_Unsigned = 1;

Expand Down Expand Up @@ -1069,6 +1073,24 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
let attributes = [Attributes<DXIL1_0, []>];
}

def WaveActiveBit : DXILOp<120, waveActiveBit> {
let Doc = "returns the result of the operation across waves";
let intrinsics = [
IntrinSelect<int_dx_wave_reduce_or,
[
IntrinArgIndex<0>, IntrinArgI8<WaveBitOpKind_Or>,
]>,
Comment on lines +1078 to +1082
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the formatting seems off

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its the same as DXILOp<119>, so what is wrong?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can run tablegen files through clang-format. There is no ci for this so suspect maybe thats why the WaveActiveOp look the way they do. If it doesn't change this then feel free to ignore my comment.

];

let arguments = [OverloadTy, Int8Ty];
let result = OverloadTy;
let overloads = [
Overloads<DXIL1_0, [Int32Ty, Int64Ty]>
];
let stages = [Stages<DXIL1_0, [all_stages]>];
let attributes = [Attributes<DXIL1_0, []>];
}

def WaveAllBitCount : DXILOp<135, waveAllOp> {
let Doc = "returns the count of bits set to 1 across the wave";
let intrinsics = [IntrinSelect<int_dx_wave_active_countbits>];
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DXILShaderFlags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ static bool checkWaveOps(Intrinsic::ID IID) {
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_active_countbits:
// Wave Active Op Variants
case Intrinsic::dx_wave_reduce_or:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_wave_reduce_max:
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_saturate:
case Intrinsic::dx_splitdouble:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_reduce_or:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you add a line here then you should be testing this intrinsic gets scalarized

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That happens in WaveActiveBitOr.ll? I am unsure what I would have to add to make a test 'scalar', as the result of the dx or spriv is always a scalar value as it is the same value across the wave.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

run these command and you should find examples

git grep -n "function(scalarizer<load-store>)" -- llvm/test/CodeGen/DirectX/
git grep -n "\-scalarizer" -- llvm/test/CodeGen/DirectX/

case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_umax:
Expand Down
34 changes: 32 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectDot4AddPackedExpansion(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveReduceOr(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectWaveReduceMax(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I, bool IsUnsigned) const;

Expand Down Expand Up @@ -2012,8 +2015,7 @@ bool SPIRVInstructionSelector::selectAnyOrAll(Register ResVReg,
Register InputRegister = I.getOperand(2).getReg();
SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);

if (!InputType)
report_fatal_error("Input Type could not be determined.");
assert(InputType && "VReg has no type assigned");

bool IsBoolTy = GR.isScalarOrVectorOfType(InputRegister, SPIRV::OpTypeBool);
bool IsVectorTy = InputType->getOpcode() == SPIRV::OpTypeVector;
Expand Down Expand Up @@ -2427,6 +2429,32 @@ bool SPIRVInstructionSelector::selectWaveActiveCountBits(
return Result;
}

bool SPIRVInstructionSelector::selectWaveReduceOr(
Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const {

assert(I.getNumOperands() == 3);
assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();
Register InputRegister = I.getOperand(2).getReg();
SPIRVType *InputType = GR.getSPIRVTypeForVReg(InputRegister);

if (!InputType)
report_fatal_error("Input Type could not be determined.");

SPIRVType *IntTy = GR.getOrCreateSPIRVIntegerType(32, I, TII);

auto Opcode = SPIRV::OpGroupNonUniformBitwiseOr;

return BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(GR.getOrCreateConstInt(SPIRV::Scope::Subgroup, I, IntTy, TII,
!STI.isShader()))
.addImm(SPIRV::GroupOperation::Reduce)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this is why you are adding reduce to the name.
https://godbolt.org/z/Pe7hcEcfr

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, should I still rename wave_reduce_or?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah OpGroupBroadcast, OpGroupReduce, and OpGroupGather are spirv implementation details. I don’t think it makes sense to take the language of spirv and apply it broadly across all the llvm intrinsics needed by this feature.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm i am a little conflicted because it seems like previous implementers are already adding reduce into the intrinsic names. This might be a question for @Keenuts.

.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectWaveReduceMax(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I,
Expand Down Expand Up @@ -3427,6 +3455,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformAny);
case Intrinsic::spv_wave_is_first_lane:
return selectWaveOpInst(ResVReg, ResType, I, SPIRV::OpGroupNonUniformElect);
case Intrinsic::spv_wave_reduce_or:
return selectWaveReduceOr(ResVReg, ResType, I);
case Intrinsic::spv_wave_reduce_umax:
return selectWaveReduceMax(ResVReg, ResType, I, /*IsUnsigned*/ true);
case Intrinsic::spv_wave_reduce_max:
Expand Down
7 changes: 7 additions & 0 deletions llvm/test/CodeGen/DirectX/ShaderFlags/wave-ops.ll
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ entry:
ret i1 %ret
}

define noundef i32 @wave_bit_or(i32 %x) {
entry:
; CHECK: Function wave_bit_or : [[WAVE_FLAG]]
%ret = call i32 @llvm.dx.wave.reduce.or(i32 %x)
ret i32 %ret
}

define noundef i1 @wave_readlane(i1 %x, i32 %idx) {
entry:
; CHECK: Function wave_readlane : [[WAVE_FLAG]]
Expand Down
19 changes: 19 additions & 0 deletions llvm/test/CodeGen/DirectX/WaveActiveBitOr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s

define noundef i32 @wave_bitor_simple(i32 noundef %p1) {
entry:
; CHECK: call i32 @dx.op.waveActiveBit.i32(i32 120, i32 %p1, i8 1){{$}}
%ret = call i32 @llvm.dx.wave.reduce.or.i32(i32 %p1)
ret i32 %ret
}

declare i32 @llvm.dx.wave.reduce.or.i32(i32)

define noundef i64 @wave_bitor_simple64(i64 noundef %p1) {
entry:
; CHECK: call i64 @dx.op.waveActiveBit.i64(i32 120, i64 %p1, i8 1){{$}}
%ret = call i64 @llvm.dx.wave.reduce.or.i64(i64 %p1)
ret i64 %ret
}

declare i64 @llvm.dx.wave.reduce.or.i64(i64)
30 changes: 30 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/WaveActiveBitOr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan %s -o - -filetype=obj | spirv-val --target-env spv1.4 %}

; Test lowering to spir-v backend for various types and scalar/vector

; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#uint64:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#scope:]] = OpConstant %[[#uint]] 3

; CHECK-LABEL: Begin function test_uint
; CHECK: %[[#iexpr:]] = OpFunctionParameter %[[#uint]]
define i32 @test_uint(i32 %iexpr) {
entry:
; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint]] %[[#scope]] Reduce %[[#iexpr]]
%0 = call i32 @llvm.spv.wave.reduce.or.i32(i32 %iexpr)
ret i32 %0
}

declare i32 @llvm.spv.wave.reduce.or.i32(i32)

; CHECK-LABEL: Begin function test_uint64
; CHECK: %[[#iexpr64:]] = OpFunctionParameter %[[#uint64]]
define i64 @test_uint64(i64 %iexpr64) {
entry:
; CHECK: %[[#iret:]] = OpGroupNonUniformBitwiseOr %[[#uint64]] %[[#scope]] Reduce %[[#iexpr64]]
%0 = call i64 @llvm.spv.wave.reduce.or.i64(i64 %iexpr64)
ret i64 %0
}

declare i64 @llvm.spv.wave.reduce.or.i64(i64)
Loading