@@ -482,11 +482,8 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
482482 cb_order o ;
483483 int cA , cB , cC ;
484484 int err ;
485- gpudata * * A_datas = NULL , * * B_datas = NULL , * * C_datas = NULL ;
486- size_t * A_offsets = NULL , * B_offsets = NULL , * C_offsets = NULL ;
487- size_t i ;
488485
489- if (A -> typecode != GA_FLOAT && A -> typecode != GA_DOUBLE )
486+ if (A -> typecode != GA_FLOAT && A -> typecode != GA_DOUBLE && A -> typecode != GA_HALF )
490487 return error_set (ctx -> err , GA_INVALID_ERROR , "Unsupported dtype" );
491488
492489 if (A -> nd != 3 || B -> nd != 3 || C -> nd != 3 )
@@ -625,50 +622,90 @@ int GpuArray_rgemmBatch_3d(cb_transpose transA, cb_transpose transB, double alph
625622 if (err != GA_NO_ERROR )
626623 goto cleanup ;
627624
628- A_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
629- B_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
630- C_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
631-
632- A_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
633- B_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
634- C_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
635-
636- for (i = 0 ; i < batchCount ; i ++ ) {
637- A_datas [i ] = Ap -> data ;
638- B_datas [i ] = Bp -> data ;
639- C_datas [i ] = Cp -> data ;
640- A_offsets [i ] = (Ap -> offset + i * Ap -> strides [0 ]) / elsize ;
641- B_offsets [i ] = (Bp -> offset + i * Bp -> strides [0 ]) / elsize ;
642- C_offsets [i ] = (Cp -> offset + i * Cp -> strides [0 ]) / elsize ;
643- }
644-
645625 switch (C -> typecode ) {
646626 case GA_HALF :
647- err = gpublas_hgemmBatch (o , transA , transB , m , n , k , (float )alpha ,
648- A_datas , A_offsets , lda ,
649- B_datas , B_offsets , ldb ,
650- (float )beta ,
651- C_datas , C_offsets , ldc , batchCount , 0 );
627+ err = gpublas_hgemm3D (o , transA , transB , m , n , k , (float )alpha ,
628+ Ap -> data , Ap -> offset /elsize , lda , Ap -> strides [0 ]/elsize ,
629+ Bp -> data , Bp -> offset /elsize , ldb , Bp -> strides [0 ]/elsize ,
630+ (float )beta ,
631+ Cp -> data , Cp -> offset /elsize , ldc , Cp -> strides [0 ]/elsize ,
632+ batchCount , 0 );
652633 break ;
653634 case GA_FLOAT :
654- err = gpublas_sgemmBatch (o , transA , transB , m , n , k , (float )alpha ,
655- A_datas , A_offsets , lda ,
656- B_datas , B_offsets , ldb ,
657- (float )beta ,
658- C_datas , C_offsets , ldc , batchCount , 0 );
635+ err = gpublas_sgemm3D (o , transA , transB , m , n , k , (float )alpha ,
636+ Ap -> data , Ap -> offset /elsize , lda , Ap -> strides [0 ]/elsize ,
637+ Bp -> data , Bp -> offset /elsize , ldb , Bp -> strides [0 ]/elsize ,
638+ (float )beta ,
639+ Cp -> data , Cp -> offset /elsize , ldc , Cp -> strides [0 ]/elsize ,
640+ batchCount , 0 );
659641 break ;
660642 case GA_DOUBLE :
661- err = gpublas_dgemmBatch (o , transA , transB , m , n , k , (double )alpha ,
662- A_datas , A_offsets , lda ,
663- B_datas , B_offsets , ldb ,
664- (double )beta ,
665- C_datas , C_offsets , ldc , batchCount , 0 );
643+ err = gpublas_dgemm3D (o , transA , transB , m , n , k , (double )alpha ,
644+ Ap -> data , Ap -> offset /elsize , lda , Ap -> strides [0 ]/elsize ,
645+ Bp -> data , Bp -> offset /elsize , ldb , Bp -> strides [0 ]/elsize ,
646+ (double )beta ,
647+ Cp -> data , Cp -> offset /elsize , ldc , Cp -> strides [0 ]/elsize ,
648+ batchCount , 0 );
666649 break ;
667650 }
668651
652+ if (err == GA_DEVSUP_ERROR ) {
653+ gpudata * * A_datas = NULL , * * B_datas = NULL , * * C_datas = NULL ;
654+ size_t * A_offsets = NULL , * B_offsets = NULL , * C_offsets = NULL ;
655+ size_t i ;
656+
657+ A_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
658+ B_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
659+ C_datas = (gpudata * * )malloc (batchCount * sizeof (gpudata * ));
660+
661+ A_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
662+ B_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
663+ C_offsets = (size_t * )malloc (batchCount * sizeof (size_t ));
664+
665+ if (A_datas == NULL || B_datas == NULL || C_datas == NULL ||
666+ A_offsets == NULL || B_offsets == NULL || C_offsets ) {
667+ err = error_sys (ctx -> err , "malloc" );
668+ goto old_cleanup ;
669+ }
670+
671+ for (i = 0 ; i < batchCount ; i ++ ) {
672+ A_datas [i ] = Ap -> data ;
673+ B_datas [i ] = Bp -> data ;
674+ C_datas [i ] = Cp -> data ;
675+ A_offsets [i ] = (Ap -> offset + i * Ap -> strides [0 ]) / elsize ;
676+ B_offsets [i ] = (Bp -> offset + i * Bp -> strides [0 ]) / elsize ;
677+ C_offsets [i ] = (Cp -> offset + i * Cp -> strides [0 ]) / elsize ;
678+ }
679+
680+ switch (C -> typecode ) {
681+ case GA_HALF :
682+ err = gpublas_hgemmBatch (o , transA , transB , m , n , k , (float )alpha ,
683+ A_datas , A_offsets , lda ,
684+ B_datas , B_offsets , ldb ,
685+ (float )beta ,
686+ C_datas , C_offsets , ldc , batchCount , 0 );
687+ break ;
688+ case GA_FLOAT :
689+ err = gpublas_sgemmBatch (o , transA , transB , m , n , k , (float )alpha ,
690+ A_datas , A_offsets , lda ,
691+ B_datas , B_offsets , ldb ,
692+ (float )beta ,
693+ C_datas , C_offsets , ldc , batchCount , 0 );
694+ break ;
695+ case GA_DOUBLE :
696+ err = gpublas_dgemmBatch (o , transA , transB , m , n , k , (double )alpha ,
697+ A_datas , A_offsets , lda ,
698+ B_datas , B_offsets , ldb ,
699+ (double )beta ,
700+ C_datas , C_offsets , ldc , batchCount , 0 );
701+ break ;
702+ }
703+ old_cleanup :
704+ free (A_datas ); free (B_datas ); free (C_datas );
705+ free (A_offsets ); free (B_offsets ); free (C_offsets );
706+ }
707+
669708 cleanup :
670- free (A_datas ); free (B_datas ); free (C_datas );
671- free (A_offsets ); free (B_offsets ); free (C_offsets );
672709 if (Ap == & copyA )
673710 GpuArray_clear (& copyA );
674711 if (Bp == & copyB )
0 commit comments