It looks like CUDA 9 deprecates __shfl and __any. I was able to compile using the following quick-and-dirty patch:
--- LookupTable.cu 2019-05-18 11:03:38.615935768 +0200
+++ LookupTable.cu 2019-05-18 11:08:42.189278728 +0200
@@ -6,54 +6,54 @@
#include <thrust/execution_policy.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/transform_reduce.h>
#if CUDA_VERSION >= 7000
#include <thrust/system/cuda/execution_policy.h>
#endif
#include <thrust/unique.h>
#include "THCHalf.h"
#include "THCHalfAutoNumerics.cuh"
#include "THCTensorSort.cuh"
-
+#define FULL_MASK 0xffffffff
const int WARP_SIZE = 32;
__device__ __forceinline__ bool warpHasCollision(int val)
{
// Compare our value to the values stored in the next 16 lanes,
// wrapping around at 32. If any pair of values is the same than
// there is a collision in the warp.
bool dup = 0;
const int laneId = threadIdx.x % 32;
#if __CUDA_ARCH__ >= 300
#pragma unroll
for (int i = 1; i <= 16; i++)
{
- dup |= (__shfl(val, (laneId + i) % 32) == val);
+ dup |= (__shfl_sync(FULL_MASK, val, (laneId + i) % 32) == val);
}
#else
volatile __shared__ int values[128];
values[threadIdx.x] = val;
const int offset = threadIdx.x - laneId;
#pragma unroll
for (int i = 1; i <= 16; i++)
{
dup |= (values[offset + ((laneId + i) % 32)] == val);
}
#endif
- return __any(dup) != 0;
+ return __any_sync(FULL_MASK, dup) != 0;
}
template <typename Dtype>
__global__ void cunn_LookupTable_accGradParametersKernelByFeature(
long *input, Dtype *gradOutput, Dtype *gradWeight, Dtype scale, ptrdiff_t numel,
long stride, int paddingValue) {
const int featureDim = blockIdx.x * 4 + threadIdx.x / 32;
if (featureDim >= stride) {
return;
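
For reference, here is a minimal, self-contained sketch of the same idea: wrap the warp intrinsics in macros that select the *_sync variants on CUDA 9+ and fall back to the legacy names on older toolkits. The WARP_SHFL/WARP_ANY macro names and warpHasCollisionCompat are hypothetical helpers, not part of the patch above, and the pre-sm_30 shared-memory fallback is omitted for brevity.

#include <cuda.h>

#define FULL_MASK 0xffffffff

// On CUDA 9+ the legacy warp intrinsics are deprecated in favour of the
// *_sync variants that take an explicit lane mask.
#if CUDA_VERSION >= 9000
#define WARP_SHFL(val, srcLane) __shfl_sync(FULL_MASK, (val), (srcLane))
#define WARP_ANY(pred)          __any_sync(FULL_MASK, (pred))
#else
#define WARP_SHFL(val, srcLane) __shfl((val), (srcLane))
#define WARP_ANY(pred)          __any((pred))
#endif

// Same collision test as warpHasCollision above: compare each lane's value
// against the next 16 lanes (wrapping at 32) and reduce with a warp-wide any.
__device__ __forceinline__ bool warpHasCollisionCompat(int val)
{
    bool dup = false;
    const int laneId = threadIdx.x % 32;
#pragma unroll
    for (int i = 1; i <= 16; i++)
    {
        dup |= (WARP_SHFL(val, (laneId + i) % 32) == val);
    }
    return WARP_ANY(dup) != 0;
}

With a helper like this, callers such as cunn_LookupTable_accGradParametersKernelByFeature could stay unchanged across toolkit versions instead of being patched in place.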