Skip to content

Commit f2d010d

Browse files
authored
Merge pull request #5512 from quic/topic/ssyrk_direct_sme1
Support for SME1 based ssyrk_direct kernel for cblas_ssyrk level 3 API
2 parents 65af1b1 + 43d38d3 commit f2d010d

File tree

9 files changed

+353
-0
lines changed

9 files changed

+353
-0
lines changed

common_level3.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,27 @@ void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
8989
float * A, BLASLONG strideA,
9090
float * B, BLASLONG strideB);
9191

92+
void ssyrk_direct_alpha_betaUN(BLASLONG N, BLASLONG K,
93+
float alpha,
94+
float * A, BLASLONG strideA,
95+
float beta,
96+
float * C, BLASLONG strideC);
97+
void ssyrk_direct_alpha_betaUT(BLASLONG N, BLASLONG K,
98+
float alpha,
99+
float * A, BLASLONG strideA,
100+
float beta,
101+
float * C, BLASLONG strideC);
102+
void ssyrk_direct_alpha_betaLN(BLASLONG N, BLASLONG K,
103+
float alpha,
104+
float * A, BLASLONG strideA,
105+
float beta,
106+
float * C, BLASLONG strideC);
107+
void ssyrk_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
108+
float alpha,
109+
float * A, BLASLONG strideA,
110+
float beta,
111+
float * C, BLASLONG strideC);
112+
92113
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
93114

94115
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,10 @@ int (*shgemv_t) (BLASLONG, BLASLONG, float, hfloat16 *, BLASLONG, hfloat16 *, BL
264264
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
265265
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
266266
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
267+
void (*ssyrk_direct_alpha_betaUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
268+
void (*ssyrk_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
269+
void (*ssyrk_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
270+
void (*ssyrk_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
267271
#endif
268272

269273

common_s.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@
5656
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
5757
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
5858
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
59+
#define SSYRK_DIRECT_ALPHA_BETA_UN ssyrk_direct_alpha_betaUN
60+
#define SSYRK_DIRECT_ALPHA_BETA_UT ssyrk_direct_alpha_betaUT
61+
#define SSYRK_DIRECT_ALPHA_BETA_LN ssyrk_direct_alpha_betaLN
62+
#define SSYRK_DIRECT_ALPHA_BETA_LT ssyrk_direct_alpha_betaLT
5963

6064
#define SGEMM_ONCOPY sgemm_oncopy
6165
#define SGEMM_OTCOPY sgemm_otcopy
@@ -232,6 +236,10 @@
232236
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
233237
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
234238
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
239+
#define SSYRK_DIRECT_ALPHA_BETA_UN gotoblas -> ssyrk_direct_alpha_betaUN
240+
#define SSYRK_DIRECT_ALPHA_BETA_UT gotoblas -> ssyrk_direct_alpha_betaUT
241+
#define SSYRK_DIRECT_ALPHA_BETA_LN gotoblas -> ssyrk_direct_alpha_betaLN
242+
#define SSYRK_DIRECT_ALPHA_BETA_LT gotoblas -> ssyrk_direct_alpha_betaLT
235243
#endif
236244

237245
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy

interface/syrk.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,23 @@ double NNK;
338338
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
339339
return;
340340
}
341+
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
342+
#if defined(ARCH_ARM64) && (defined(USE_SSYRK_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
343+
#if defined(DYNAMIC_ARCH)
344+
if (support_sme1())
345+
#endif
346+
if (args.n == 0) return;
347+
if (order == CblasRowMajor && n == ldc) {
348+
if (Trans == CblasNoTrans && k == lda) {
349+
(Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UN : SSYRK_DIRECT_ALPHA_BETA_LN)(n, k, alpha, a, lda, beta, c, ldc);
350+
return;
351+
} else if (Trans == CblasTrans && n == lda){
352+
(Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UT : SSYRK_DIRECT_ALPHA_BETA_LT)(n, k, alpha, a, lda, beta, c, ldc);
353+
return;
354+
}
355+
}
356+
#endif
357+
#endif
341358

342359
#endif
343360

kernel/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
245245
if (ARM64)
246246
set(USE_DIRECT_STRMM true)
247247
endif()
248+
set(USE_DIRECT_SSYRK false)
249+
if (ARM64)
250+
set(USE_DIRECT_SSYRK true)
251+
endif()
248252
set(USE_DIRECT_SGEMM false)
249253
if (X86_64 OR ARM64)
250254
set(USE_DIRECT_SGEMM true)
@@ -297,6 +301,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
297301
endif ()
298302
endif ()
299303

304+
if (USE_DIRECT_SSYRK)
305+
if (ARM64)
306+
set (SSYRKDIRECTKERNEL_ALPHA_BETA ssyrk_direct_alpha_beta_arm64_sme1.c)
307+
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaUN" false "" "" false SINGLE)
308+
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaUT" false "" "" false SINGLE)
309+
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaLN" false "" "" false SINGLE)
310+
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaLT" false "" "" false SINGLE)
311+
endif ()
312+
endif()
313+
300314
foreach (float_type SINGLE DOUBLE)
301315
string(SUBSTRING ${float_type} 0 1 float_char)
302316
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})

kernel/Makefile.L3

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ USE_TRMM = 1
5454
USE_DIRECT_SGEMM = 1
5555
USE_DIRECT_SSYMM = 1
5656
USE_DIRECT_STRMM = 1
57+
USE_DIRECT_SSYRK = 1
5758
endif
5859

5960
ifeq ($(ARCH), riscv64)
@@ -161,6 +162,16 @@ endif
161162
endif
162163
endif
163164

165+
ifdef USE_DIRECT_SSYRK
166+
ifndef SSYRKDIRECTKERNEL_ALPHA_BETA
167+
ifeq ($(ARCH), arm64)
168+
ifeq ($(TARGET_CORE), ARMV9SME)
169+
HAVE_SME = 1
170+
endif
171+
SSYRKDIRECTKERNEL_ALPHA_BETA = ssyrk_direct_alpha_beta_arm64_sme1.c
172+
endif
173+
endif
174+
endif
164175

165176
ifeq ($(BUILD_BFLOAT16), 1)
166177
ifndef BGEMMKERNEL
@@ -261,6 +272,14 @@ SKERNELOBJS += \
261272
endif
262273
endif
263274

275+
ifdef USE_DIRECT_SSYRK
276+
ifeq ($(ARCH), arm64)
277+
SKERNELOBJS += \
278+
ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) \
279+
ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX)
280+
endif
281+
endif
282+
264283
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
265284
DKERNELOBJS += \
266285
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1158,6 +1177,21 @@ $(KDIR)xgemm_kernel_r$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMD
11581177
$(KDIR)xgemm_kernel_b$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMDEPEND)
11591178
$(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX -DCC $< -o $@
11601179

1180+
ifdef USE_DIRECT_SSYRK
1181+
ifeq ($(ARCH), arm64)
1182+
$(KDIR)ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
1183+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -UTRANSA $< -o $@
1184+
1185+
$(KDIR)ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
1186+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -DTRANSA $< -o $@
1187+
1188+
$(KDIR)ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
1189+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -UTRANSA $< -o $@
1190+
1191+
$(KDIR)ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
1192+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -DTRANSA $< -o $@
1193+
endif
1194+
endif
11611195

11621196
ifdef USE_TRMM
11631197
$(KDIR)strmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMKERNEL)

0 commit comments

Comments
 (0)