Skip to content

Commit 10cb9ec

Browse files
authored
Merge pull request #447 from abergeron/hgemmstridedbatch
Use the new GemmStridedBatched functions from cublas for the 3d gemm.
2 parents 177bb37 + 7c14e97 commit 10cb9ec

File tree

12 files changed

+544
-40
lines changed

12 files changed

+544
-40
lines changed

src/gpuarray/blas.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ GPUARRAY_PUBLIC int GpuArray_rger(double alpha, GpuArray *X, GpuArray *Y,
3434
GPUARRAY_PUBLIC int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB,
3535
double alpha, GpuArray *A, GpuArray *B,
3636
double beta, GpuArray *C, int nocopy);
37+
#define GpuArray_hgemmBatch_3d GpuArray_rgemmBatch_3d
3738
#define GpuArray_sgemmBatch_3d GpuArray_rgemmBatch_3d
3839
#define GpuArray_dgemmBatch_3d GpuArray_rgemmBatch_3d
3940

src/gpuarray/buffer_blas.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,30 @@ GPUARRAY_PUBLIC int gpublas_hgemmBatch(
115115
float beta, gpudata **C, size_t *offC, size_t ldc,
116116
size_t batchCount, int flags);
117117

118+
GPUARRAY_PUBLIC int gpublas_hgemm3D(
119+
cb_order order, cb_transpose transA, cb_transpose transB,
120+
size_t M, size_t N, size_t K, float alpha,
121+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
122+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
123+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
124+
size_t batchCount, int flags);
125+
126+
GPUARRAY_PUBLIC int gpublas_sgemm3D(
127+
cb_order order, cb_transpose transA, cb_transpose transB,
128+
size_t M, size_t N, size_t K, float alpha,
129+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
130+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
131+
float beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
132+
size_t batchCount, int flags);
133+
134+
GPUARRAY_PUBLIC int gpublas_dgemm3D(
135+
cb_order order, cb_transpose transA, cb_transpose transB,
136+
size_t M, size_t N, size_t K, double alpha,
137+
gpudata *A, size_t offA, size_t lda, ssize_t strideA,
138+
gpudata *B, size_t offB, size_t ldb, ssize_t strideB,
139+
double beta, gpudata *C, size_t offC, size_t ldc, ssize_t strideC,
140+
size_t batchCount, int flags);
141+
118142
GPUARRAY_PUBLIC int gpublas_sgemmBatch(
119143
cb_order order, cb_transpose transA, cb_transpose transB,
120144
size_t M, size_t N, size_t K, float alpha,

src/gpuarray/util.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,90 @@ GPUARRAY_PUBLIC void gpuarray_elemwise_collapse(unsigned int n,
9898
unsigned int *nd,
9999
size_t *dim, ssize_t **strs);
100100

101+
102+
/* Opaque 16-bit half-precision storage type: raw IEEE 754 binary16 bits. */
typedef struct _ga_half_t { uint16_t h; } ga_half_t;

/* code strongly inspired from
   https://github.com/numpy/numpy/blob/master/numpy/core/src/npymath/halffloat.c#L246 */

/*
 * Convert a single-precision float to IEEE 754 binary16 (half precision).
 *
 * Rounding is round-to-nearest on the truncated significand; NaN payloads
 * are propagated (truncated) and kept distinct from infinity; values whose
 * magnitude exceeds the half range become signed infinity; tiny values
 * flush to signed zero or a subnormal half as appropriate.
 */
static inline ga_half_t ga_float2half(float f) {
  union {
    float f;
    uint32_t bits;
  } in;                       /* type-pun the float through a union (legal in C) */
  ga_half_t out;

  uint32_t f_exp, f_sig;
  uint16_t h_sgn, h_exp, h_sig;

  in.f = f;

  /* Move the sign bit from float position 31 to half position 15. */
  h_sgn = (uint16_t)((in.bits & 0x80000000u) >> 16);
  f_exp = in.bits & 0x7f800000u;

  /* Exponent overflow/NaN converts to signed inf/NaN. */
  if (f_exp >= 0x47800000u) {
    if (f_exp == 0x7f800000u) {
      f_sig = in.bits & 0x007fffffu;
      if (f_sig != 0) {
        /* NaN - propagate the flag in the significand... */
        out.h = (uint16_t)(0x7c00u + (f_sig >> 13));
        /* ...but make sure it stays a NaN (does not collapse to inf) */
        if (out.h == 0x7c00u)
          out.h++;
        out.h += h_sgn;
        return out;
      }
      /* signed inf */
      out.h = (uint16_t)(h_sgn + 0x7c00u);
      return out;
    }
    /* finite but too large for half precision: signed inf */
    out.h = (uint16_t)(h_sgn + 0x7c00u);
    return out;
  }

  if (f_exp <= 0x38000000u) {
    /*
     * Signed zeros, subnormal floats, and floats with small
     * exponents all convert to signed zero halfs.
     */
    if (f_exp < 0x33000000u) {
      out.h = h_sgn;
      return out;
    }
    /* Build the subnormal half significand (implicit leading 1 restored). */
    f_exp >>= 23;
    f_sig = 0x00800000u + (in.bits & 0x007fffffu);
    f_sig >>= (113 - f_exp);
    /* Handle rounding by adding 1 to the bit beyond half precision. */
    f_sig += 0x00001000u;
    h_sig = (uint16_t)(f_sig >> 13);
    /*
     * If the rounding causes a bit to spill into the exponent field, it
     * increments the exponent from zero to one with a zero significand,
     * which is the correct result.
     */
    out.h = (uint16_t)(h_sgn + h_sig);
    return out;
  }

  /* Regular case with no overflow or underflow. */
  h_exp = (uint16_t)((f_exp - 0x38000000u) >> 13);
  /* Handle rounding by adding 1 to the bit beyond half precision. */
  f_sig = in.bits & 0x007fffffu;
  f_sig += 0x00001000u;
  h_sig = (uint16_t)(f_sig >> 13);
  /* A rounding carry into the exponent field is again the correct result. */
  out.h = (uint16_t)(h_sgn + h_exp + h_sig);
  return out;
}
184+
101185
#ifdef __cplusplus
102186
}
103187
#endif

src/gpuarray_array_blas.c

Lines changed: 75 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -482,11 +482,8 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
482482
cb_order o;
483483
int cA, cB, cC;
484484
int err;
485-
gpudata **A_datas = NULL, **B_datas = NULL, **C_datas = NULL;
486-
size_t *A_offsets = NULL, *B_offsets = NULL, *C_offsets = NULL;
487-
size_t i;
488485

489-
if (A->typecode != GA_FLOAT && A->typecode != GA_DOUBLE)
486+
if (A->typecode != GA_FLOAT && A->typecode != GA_DOUBLE && A->typecode != GA_HALF)
490487
return error_set(ctx->err, GA_INVALID_ERROR, "Unsupported dtype");
491488

492489
if (A->nd != 3 || B->nd != 3 || C->nd != 3)
@@ -625,50 +622,90 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
625622
if (err != GA_NO_ERROR)
626623
goto cleanup;
627624

628-
A_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
629-
B_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
630-
C_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
631-
632-
A_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
633-
B_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
634-
C_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
635-
636-
for (i = 0; i < batchCount; i++) {
637-
A_datas[i] = Ap->data;
638-
B_datas[i] = Bp->data;
639-
C_datas[i] = Cp->data;
640-
A_offsets[i] = (Ap->offset + i * Ap->strides[0]) / elsize;
641-
B_offsets[i] = (Bp->offset + i * Bp->strides[0]) / elsize;
642-
C_offsets[i] = (Cp->offset + i * Cp->strides[0]) / elsize;
643-
}
644-
645625
switch (C->typecode) {
646626
case GA_HALF:
647-
err = gpublas_hgemmBatch(o, transA, transB, m, n, k, (float)alpha,
648-
A_datas, A_offsets, lda,
649-
B_datas, B_offsets, ldb,
650-
(float)beta,
651-
C_datas, C_offsets, ldc, batchCount, 0);
627+
err = gpublas_hgemm3D(o, transA, transB, m, n, k, (float)alpha,
628+
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
629+
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
630+
(float)beta,
631+
Cp->data, Cp->offset/elsize, ldc, Cp->strides[0]/elsize,
632+
batchCount, 0);
652633
break;
653634
case GA_FLOAT:
654-
err = gpublas_sgemmBatch(o, transA, transB, m, n, k, (float)alpha,
655-
A_datas, A_offsets, lda,
656-
B_datas, B_offsets, ldb,
657-
(float)beta,
658-
C_datas, C_offsets, ldc, batchCount, 0);
635+
err = gpublas_sgemm3D(o, transA, transB, m, n, k, (float)alpha,
636+
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
637+
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
638+
(float)beta,
639+
Cp->data, Cp->offset/elsize, ldc, Cp->strides[0]/elsize,
640+
batchCount, 0);
659641
break;
660642
case GA_DOUBLE:
661-
err = gpublas_dgemmBatch(o, transA, transB, m, n, k, (double)alpha,
662-
A_datas, A_offsets, lda,
663-
B_datas, B_offsets, ldb,
664-
(double)beta,
665-
C_datas, C_offsets, ldc, batchCount, 0);
643+
err = gpublas_dgemm3D(o, transA, transB, m, n, k, (double)alpha,
644+
Ap->data, Ap->offset/elsize, lda, Ap->strides[0]/elsize,
645+
Bp->data, Bp->offset/elsize, ldb, Bp->strides[0]/elsize,
646+
(double)beta,
647+
Cp->data, Cp->offset/elsize, ldc, Cp->strides[0]/elsize,
648+
batchCount, 0);
666649
break;
667650
}
668651

652+
if (err == GA_DEVSUP_ERROR) {
653+
gpudata **A_datas = NULL, **B_datas = NULL, **C_datas = NULL;
654+
size_t *A_offsets = NULL, *B_offsets = NULL, *C_offsets = NULL;
655+
size_t i;
656+
657+
A_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
658+
B_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
659+
C_datas = (gpudata**)malloc(batchCount * sizeof(gpudata*));
660+
661+
A_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
662+
B_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
663+
C_offsets = (size_t*)malloc(batchCount * sizeof(size_t));
664+
665+
if (A_datas == NULL || B_datas == NULL || C_datas == NULL ||
666+
A_offsets == NULL || B_offsets == NULL || C_offsets == NULL) {
667+
err = error_sys(ctx->err, "malloc");
668+
goto old_cleanup;
669+
}
670+
671+
for (i = 0; i < batchCount; i++) {
672+
A_datas[i] = Ap->data;
673+
B_datas[i] = Bp->data;
674+
C_datas[i] = Cp->data;
675+
A_offsets[i] = (Ap->offset + i * Ap->strides[0]) / elsize;
676+
B_offsets[i] = (Bp->offset + i * Bp->strides[0]) / elsize;
677+
C_offsets[i] = (Cp->offset + i * Cp->strides[0]) / elsize;
678+
}
679+
680+
switch (C->typecode) {
681+
case GA_HALF:
682+
err = gpublas_hgemmBatch(o, transA, transB, m, n, k, (float)alpha,
683+
A_datas, A_offsets, lda,
684+
B_datas, B_offsets, ldb,
685+
(float)beta,
686+
C_datas, C_offsets, ldc, batchCount, 0);
687+
break;
688+
case GA_FLOAT:
689+
err = gpublas_sgemmBatch(o, transA, transB, m, n, k, (float)alpha,
690+
A_datas, A_offsets, lda,
691+
B_datas, B_offsets, ldb,
692+
(float)beta,
693+
C_datas, C_offsets, ldc, batchCount, 0);
694+
break;
695+
case GA_DOUBLE:
696+
err = gpublas_dgemmBatch(o, transA, transB, m, n, k, (double)alpha,
697+
A_datas, A_offsets, lda,
698+
B_datas, B_offsets, ldb,
699+
(double)beta,
700+
C_datas, C_offsets, ldc, batchCount, 0);
701+
break;
702+
}
703+
old_cleanup:
704+
free(A_datas); free(B_datas); free(C_datas);
705+
free(A_offsets); free(B_offsets); free(C_offsets);
706+
}
707+
669708
cleanup:
670-
free(A_datas); free(B_datas); free(C_datas);
671-
free(A_offsets); free(B_offsets); free(C_offsets);
672709
if (Ap == &copyA)
673710
GpuArray_clear(&copyA);
674711
if (Bp == &copyB)

0 commit comments

Comments
 (0)