Skip to content

Commit 43d38d3

Browse files
committed
Support for SME1 based ssyrk_direct kernel for cblas_ssyrk level 3 API
1 parent 644ea07 commit 43d38d3

File tree

9 files changed

+353
-0
lines changed

9 files changed

+353
-0
lines changed

common_level3.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,27 @@ void strmm_direct_LTLN(BLASLONG M, BLASLONG N,
8989
float * A, BLASLONG strideA,
9090
float * B, BLASLONG strideB);
9191

92+
void ssyrk_direct_alpha_betaUN(BLASLONG N, BLASLONG K,
93+
float alpha,
94+
float * A, BLASLONG strideA,
95+
float beta,
96+
float * C, BLASLONG strideC);
97+
void ssyrk_direct_alpha_betaUT(BLASLONG N, BLASLONG K,
98+
float alpha,
99+
float * A, BLASLONG strideA,
100+
float beta,
101+
float * C, BLASLONG strideC);
102+
void ssyrk_direct_alpha_betaLN(BLASLONG N, BLASLONG K,
103+
float alpha,
104+
float * A, BLASLONG strideA,
105+
float beta,
106+
float * C, BLASLONG strideC);
107+
void ssyrk_direct_alpha_betaLT(BLASLONG N, BLASLONG K,
108+
float alpha,
109+
float * A, BLASLONG strideA,
110+
float beta,
111+
float * C, BLASLONG strideC);
112+
92113
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
93114

94115
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,10 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
263263
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
264264
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
265265
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
266+
void (*ssyrk_direct_alpha_betaUN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
267+
void (*ssyrk_direct_alpha_betaUT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
268+
void (*ssyrk_direct_alpha_betaLN) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
269+
void (*ssyrk_direct_alpha_betaLT) (BLASLONG, BLASLONG, float, float *, BLASLONG, float, float *, BLASLONG);
266270
#endif
267271

268272

common_s.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@
5656
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
5757
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
5858
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
59+
#define SSYRK_DIRECT_ALPHA_BETA_UN ssyrk_direct_alpha_betaUN
60+
#define SSYRK_DIRECT_ALPHA_BETA_UT ssyrk_direct_alpha_betaUT
61+
#define SSYRK_DIRECT_ALPHA_BETA_LN ssyrk_direct_alpha_betaLN
62+
#define SSYRK_DIRECT_ALPHA_BETA_LT ssyrk_direct_alpha_betaLT
5963

6064
#define SGEMM_ONCOPY sgemm_oncopy
6165
#define SGEMM_OTCOPY sgemm_otcopy
@@ -232,6 +236,10 @@
232236
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
233237
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
234238
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
239+
#define SSYRK_DIRECT_ALPHA_BETA_UN gotoblas -> ssyrk_direct_alpha_betaUN
240+
#define SSYRK_DIRECT_ALPHA_BETA_UT gotoblas -> ssyrk_direct_alpha_betaUT
241+
#define SSYRK_DIRECT_ALPHA_BETA_LN gotoblas -> ssyrk_direct_alpha_betaLN
242+
#define SSYRK_DIRECT_ALPHA_BETA_LT gotoblas -> ssyrk_direct_alpha_betaLT
235243
#endif
236244

237245
#define SGEMM_ONCOPY gotoblas -> sgemm_oncopy

interface/syrk.c

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,23 @@ double NNK;
338338
BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME));
339339
return;
340340
}
341+
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
342+
#if defined(ARCH_ARM64) && (defined(USE_SSYRK_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
343+
#if defined(DYNAMIC_ARCH)
344+
if (support_sme1())
345+
#endif
346+
if (args.n == 0) return;
347+
if (order == CblasRowMajor && n == ldc) {
348+
if (Trans == CblasNoTrans && k == lda) {
349+
(Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UN : SSYRK_DIRECT_ALPHA_BETA_LN)(n, k, alpha, a, lda, beta, c, ldc);
350+
return;
351+
} else if (Trans == CblasTrans && n == lda){
352+
(Uplo == CblasUpper ? SSYRK_DIRECT_ALPHA_BETA_UT : SSYRK_DIRECT_ALPHA_BETA_LT)(n, k, alpha, a, lda, beta, c, ldc);
353+
return;
354+
}
355+
}
356+
#endif
357+
#endif
341358

342359
#endif
343360

kernel/CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
241241
if (ARM64)
242242
set(USE_DIRECT_STRMM true)
243243
endif()
244+
set(USE_DIRECT_SSYRK false)
245+
if (ARM64)
246+
set(USE_DIRECT_SSYRK true)
247+
endif()
244248
set(USE_DIRECT_SGEMM false)
245249
if (X86_64 OR ARM64)
246250
set(USE_DIRECT_SGEMM true)
@@ -293,6 +297,16 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
293297
endif ()
294298
endif ()
295299

300+
if (USE_DIRECT_SSYRK)
301+
if (ARM64)
302+
set (SSYRKDIRECTKERNEL_ALPHA_BETA ssyrk_direct_alpha_beta_arm64_sme1.c)
303+
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaUN" false "" "" false SINGLE)
304+
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaUT" false "" "" false SINGLE)
305+
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaLN" false "" "" false SINGLE)
306+
GenerateNamedObjects("${KERNELDIR}/${SSYRKDIRECTKERNEL_ALPHA_BETA}" "" "syrk_direct_alpha_betaLT" false "" "" false SINGLE)
307+
endif ()
308+
endif()
309+
296310
foreach (float_type SINGLE DOUBLE)
297311
string(SUBSTRING ${float_type} 0 1 float_char)
298312
GenerateNamedObjects("${KERNELDIR}/${${float_char}GEMMKERNEL}" "" "gemm_kernel" false "" "" false ${float_type})

kernel/Makefile.L3

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ USE_TRMM = 1
5454
USE_DIRECT_SGEMM = 1
5555
USE_DIRECT_SSYMM = 1
5656
USE_DIRECT_STRMM = 1
57+
USE_DIRECT_SSYRK = 1
5758
endif
5859

5960
ifeq ($(ARCH), riscv64)
@@ -161,6 +162,16 @@ endif
161162
endif
162163
endif
163164

165+
ifdef USE_DIRECT_SSYRK
166+
ifndef SSYRKDIRECTKERNEL_ALPHA_BETA
167+
ifeq ($(ARCH), arm64)
168+
ifeq ($(TARGET_CORE), ARMV9SME)
169+
HAVE_SME = 1
170+
endif
171+
SSYRKDIRECTKERNEL_ALPHA_BETA = ssyrk_direct_alpha_beta_arm64_sme1.c
172+
endif
173+
endif
174+
endif
164175

165176
ifeq ($(BUILD_BFLOAT16), 1)
166177
ifndef BGEMMKERNEL
@@ -261,6 +272,14 @@ SKERNELOBJS += \
261272
endif
262273
endif
263274

275+
ifdef USE_DIRECT_SSYRK
276+
ifeq ($(ARCH), arm64)
277+
SKERNELOBJS += \
278+
ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) \
279+
ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX)
280+
endif
281+
endif
282+
264283
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
265284
DKERNELOBJS += \
266285
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1158,6 +1177,21 @@ $(KDIR)xgemm_kernel_r$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMD
11581177
$(KDIR)xgemm_kernel_b$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(XGEMMKERNEL) $(XGEMMDEPEND)
11591178
$(CC) $(CFLAGS) -c -DXDOUBLE -DCOMPLEX -DCC $< -o $@
11601179

1180+
ifdef USE_DIRECT_SSYRK
1181+
ifeq ($(ARCH), arm64)
1182+
$(KDIR)ssyrk_direct_alpha_betaUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
1183+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -UTRANSA $< -o $@
1184+
1185+
$(KDIR)ssyrk_direct_alpha_betaUT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
1186+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DUPPER -DTRANSA $< -o $@
1187+
1188+
$(KDIR)ssyrk_direct_alpha_betaLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
1189+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -UTRANSA $< -o $@
1190+
1191+
$(KDIR)ssyrk_direct_alpha_betaLT$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SSYRKDIRECTKERNEL_ALPHA_BETA)
1192+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UUPPER -DTRANSA $< -o $@
1193+
endif
1194+
endif
11611195

11621196
ifdef USE_TRMM
11631197
$(KDIR)strmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMKERNEL)

0 commit comments

Comments
 (0)