Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions fbgemm_gpu/include/fbgemm_gpu/utils/cuda_prelude.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ inline int get_device_sm_cnt_() {

namespace fbgemm_gpu {

#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 9000
#if !defined(USE_ROCM) && defined(CUDA_VERSION)
#define FBGEMM_USE_SUBWARP_SHUFFLE
#endif

Expand Down Expand Up @@ -88,7 +88,7 @@ DEVICE_INLINE T shfl_xor(
int laneMask,
int width = kWarpSize,
unsigned shfl_sync_mask = static_cast<unsigned>(kFullWarpMask)) {
#if defined(USE_ROCM) || CUDA_VERSION < 9000
#if defined(USE_ROCM)
return __shfl_xor(val, laneMask, width);
#else
return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width);
Expand All @@ -101,7 +101,7 @@ DEVICE_INLINE T shfl_sync(
int srcLane = 0,
int width = kWarpSize,
unsigned shfl_sync_mask = static_cast<unsigned>(kFullWarpMask)) {
#if defined(USE_ROCM) || CUDA_VERSION < 9000
#if defined(USE_ROCM)
return __shfl(val, srcLane, width);
#else
return __shfl_sync(shfl_sync_mask, val, srcLane, width);
Expand All @@ -114,21 +114,21 @@ DEVICE_INLINE T shfl_down_sync(
unsigned delta,
int width = kWarpSize,
unsigned shfl_sync_mask = static_cast<unsigned>(kFullWarpMask)) {
#if defined(USE_ROCM) || CUDA_VERSION < 9000
#if defined(USE_ROCM)
return __shfl_down(val, delta, width);
#else
return __shfl_down_sync(shfl_sync_mask, val, delta, width);
#endif
}

#if defined(USE_ROCM) || CUDA_VERSION < 9000
#if defined(USE_ROCM)
DEVICE_INLINE uint64_t ballot_sync(
#else
DEVICE_INLINE uint32_t ballot_sync(
#endif
int predicate,
unsigned shfl_sync_mask = static_cast<unsigned>(kFullWarpMask)) {
#if defined(USE_ROCM) || CUDA_VERSION < 9000
#if defined(USE_ROCM)
return __ballot(predicate);
#else
return __ballot_sync(shfl_sync_mask, predicate);
Expand Down
4 changes: 1 addition & 3 deletions fbgemm_gpu/include/fbgemm_gpu/utils/float.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ struct Half4 {
*reinterpret_cast<unsigned int*>(p) = *reinterpret_cast<unsigned int*>(&a);
*reinterpret_cast<unsigned int*>(p + 2) =
*reinterpret_cast<unsigned int*>(&b);
#elif CUDA_VERSION >= 9000
#else

#ifndef __HALF2_TO_UI
// cuda_fp16.hpp doesn't export this
Expand All @@ -64,8 +64,6 @@ struct Half4 {
asm("st.v2.u32 [%0], {%1, %2};"
:
: "l"(p), "r"(__HALF2_TO_UI(a)), "r"(__HALF2_TO_UI(b)));
#else
asm("st.v2.u32 [%0], {%1, %2};" : : "l"(p), "r"(a.x), "r"(b.x));
#endif
}
};
Expand Down
28 changes: 0 additions & 28 deletions fbgemm_gpu/include/fbgemm_gpu/utils/vec4.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -97,15 +97,9 @@ struct Vec4T<float> : public Vec4BaseT<float> {
acc.w = b.y;
#else
Half4 out;
#if CUDA_VERSION >= 9000
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(__HALF2_TO_UI(out.a)), "=r"(__HALF2_TO_UI(out.b))
: "l"(p));
#else
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(out.a.x), "=r"(out.b.x)
: "l"(p));
#endif

float2 a = __half22float2(out.a);
float2 b = __half22float2(out.b);
Expand Down Expand Up @@ -287,15 +281,9 @@ struct Vec4T<at::Half> : public Vec4BaseT<at::Half> {
acc.w = b.y;
#else
Half4 out;
#if CUDA_VERSION >= 9000
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(__HALF2_TO_UI(out.a)), "=r"(__HALF2_TO_UI(out.b))
: "l"(p));
#else
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(out.a.x), "=r"(out.b.x)
: "l"(p));
#endif

float2 a = __half22float2(out.a);
float2 b = __half22float2(out.b);
Expand Down Expand Up @@ -360,22 +348,12 @@ struct Vec4T<at::Half> : public Vec4BaseT<at::Half> {
dst[3] = src[3];
#else
Half4 out;
#if CUDA_VERSION >= 9000
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(__HALF2_TO_UI(out.a)), "=r"(__HALF2_TO_UI(out.b))
: "l"(src));
#else
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(out.a.x), "=r"(out.b.x)
: "l"(src));
#endif
#if CUDA_VERSION >= 9000
asm("st.v2.u32 [%0], {%1, %2};"
:
: "l"(dst), "r"(__HALF2_TO_UI(out.a)), "r"(__HALF2_TO_UI(out.b)));
#else
asm("st.v2.u32 [%0], {%1, %2};" : : "l"(dst), "r"(out.a.x), "r"(out.b.x));
#endif
#endif
}

Expand Down Expand Up @@ -488,15 +466,9 @@ struct Vec4T<at::BFloat16> : public Vec4BaseT<at::BFloat16> {
acc.w = b.y;
#else
Half4 out;
#if CUDA_VERSION >= 9000
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(__HALF2_TO_UI(out.a)), "=r"(__HALF2_TO_UI(out.b))
: "l"(p));
#else
asm("ld.global.v2.u32 {%0, %1}, [%2];"
: "=r"(out.a.x), "=r"(out.b.x)
: "l"(p));
#endif

float2 a = __half22float2(out.a);
float2 b = __half22float2(out.b);
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/include/fbgemm_gpu/utils/vec_quant.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ DEVICE_INLINE T shfl_xor(
const T val,
int laneMask,
int width = kThreadsPerWarp) {
#if defined(__HIP_PLATFORM_AMD__) || CUDA_VERSION < 9000
#if defined(__HIP_PLATFORM_AMD__)
return __shfl_xor(val, laneMask, width);
#else
return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width);
Expand Down
Loading