diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f874fa5eaa..e3491d7f11 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,6 +36,28 @@ foreach(test_bin ${OpenBLAS_Tests}) target_link_libraries(${test_bin} ${OpenBLAS_LIBNAME}) endforeach() +if (BUILD_BFLOAT16) + add_executable(test_bgemm compare_sgemm_bgemm.c) + target_compile_definitions(test_bgemm PUBLIC -DIBFLOAT16 -DOBFLOAT16) + target_link_libraries(test_bgemm ${OpenBLAS_LIBNAME}) + add_executable(test_bgemv compare_sgemv_bgemv.c) + target_compile_definitions(test_bgemv PUBLIC -DIBFLOAT16 -DOBFLOAT16) + target_link_libraries(test_bgemv ${OpenBLAS_LIBNAME}) + add_executable(test_sbgemm compare_sgemm_sbgemm.c) + target_compile_definitions(test_sbgemm PUBLIC -DIBFLOAT16) + target_link_libraries(test_sbgemm ${OpenBLAS_LIBNAME}) + add_executable(test_sbgemv compare_sgemv_sbgemv.c) + target_compile_definitions(test_sbgemv PUBLIC -DIBFLOAT16) + target_link_libraries(test_sbgemv ${OpenBLAS_LIBNAME}) +endif() + +if (BUILD_HFLOAT16) + add_executable(test_shgemm compare_sgemm_shgemm.c) + target_link_libraries(test_shgemm ${OpenBLAS_LIBNAME}) + add_executable(test_shgemv compare_sgemv_shgemv.c) + target_link_libraries(test_shgemv ${OpenBLAS_LIBNAME}) +endif() + # $1 exec, $2 input, $3 output_result if(WIN32) FILE(WRITE ${CMAKE_CURRENT_BINARY_DIR}/test_helper.ps1 @@ -94,3 +116,21 @@ add_test(NAME "${float_type}blas3_3m" endif() endif() endforeach() + +if (BUILD_BFLOAT16) + add_test(NAME "bgemm" + COMMAND $) + add_test(NAME "bgemv" + COMMAND $) + add_test(NAME "sbgemm" + COMMAND $) + add_test(NAME "sbgemv" + COMMAND $) +endif() + +if (BUILD_HFLOAT16) + add_test(NAME "shgemm" + COMMAND $) + add_test(NAME "shgemv" + COMMAND $) +endif() diff --git a/test/Makefile b/test/Makefile index 15e45302cb..8b69976f76 100644 --- a/test/Makefile +++ b/test/Makefile @@ -234,6 +234,9 @@ ifeq ($(BUILD_BFLOAT16),1) BF3= test_bgemm B3 = test_sbgemm endif +ifeq ($(BUILD_HFLOAT16),1) +H3 = test_shgemm +endif ifeq ($(BUILD_SINGLE),1) S3=sblat3 endif @@ -257,9 +260,9 @@ endif ifeq ($(SUPPORT_GEMM3M),1) -level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) level3_3m +level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) level3_3m else -level3: $(BF3) $(B3) $(S3) $(D3) $(C3) $(Z3) +level3: $(BF3) $(B3) $(H3) $(S3) $(D3) $(C3) $(Z3) endif ifneq ($(CROSS), 1) @@ -454,6 +457,9 @@ test_sbgemv : compare_sgemv_sbgemv.c ../$(LIBNAME) endif ifeq ($(BUILD_HFLOAT16),1) +test_shgemm : compare_sgemm_shgemm.c test_helpers.h ../$(LIBNAME) + $(CC) $(CLDFLAGS) -DIHFLOAT16 -o test_shgemm compare_sgemm_shgemm.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) + test_shgemv : compare_sgemv_shgemv.c ../$(LIBNAME) $(CC) $(CLDFLAGS) -o test_shgemv compare_sgemv_shgemv.c ../$(LIBNAME) $(EXTRALIB) $(CEXTRALIB) endif @@ -475,7 +481,7 @@ clean: @rm -f *.$(SUFFIX) *.$(PSUFFIX) gmon.$(SUFFIX)ut *.SUMM *.cxml *.exe *.pdb *.dwf \ sblat1 dblat1 cblat1 zblat1 \ sblat2 dblat2 cblat2 zblat2 \ - test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemv sblat3 dblat3 cblat3 zblat3 \ + test_bgemm test_bgemv test_sbgemm test_sbgemv test_shgemm test_shgemv sblat3 dblat3 cblat3 zblat3 \ sblat1p dblat1p cblat1p zblat1p \ sblat2p dblat2p cblat2p zblat2p \ sblat3p dblat3p cblat3p zblat3p \ diff --git a/test/compare_sgemm_shgemm.c b/test/compare_sgemm_shgemm.c new file mode 100644 index 0000000000..7a97a06697 --- /dev/null +++ b/test/compare_sgemm_shgemm.c @@ -0,0 +1,234 @@ +/*************************************************************************** +Copyright (c) 2020,2025 The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ +#include +#include +#include "../common.h" + +#include "test_helpers.h" + +#define SGEMM BLASFUNC(sgemm) +#define SHGEMM BLASFUNC(shgemm) +#define SHGEMM_LARGEST 256 + +int +main (int argc, char *argv[]) +{ + blasint m, n, k; + int i, j, l; + blasint x, y; + int ret = 0; + int rret = 0; + int loop = SHGEMM_LARGEST; + char transA = 'N', transB = 'N'; + float alpha = 1.0, beta = 0.0; + int xvals[6]={3,24,55,71,SHGEMM_LARGEST/2,SHGEMM_LARGEST}; + + for (x = 0; x <= loop; x++) + { + if ((x > 100) && (x != SHGEMM_LARGEST)) continue; + m = k = n = x; + float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); + float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); + float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); + _Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16)); + _Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16)); + float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || + (DD == NULL) || (CC == NULL)) + return 1; + + for (j = 0; j < m; j++) + { + for (i = 0; i < k; i++) + { + A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + AA[j * k + i] = (_Float16) A[j * k + i]; + } + } + for (j = 0; j < n; j++) + { + for (i = 0; i < k; i++) + { + B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + BB[j * k + i] = (_Float16) B[j * k + i]; + } + } + for (y = 0; y < 4; y++) + { + if ((y == 0) || (y == 2)) { + transA = 'N'; + } else { + transA = 'T'; + } + if ((y == 0) || (y == 1)) { + transB = 'N'; + } else { + transB = 'T'; + } + + memset(CC, 0, m * n * sizeof(FLOAT)); + memset(DD, 0, m * n * sizeof(FLOAT)); + memset(C, 0, m * n * sizeof(FLOAT)); + + SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, + &m, B, &k, &beta, C, &m); + SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA, + &m, (_Float16*)BB, &k, &beta, CC, &m); + + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + { + for (l = 0; l < k; l++) + if (transA == 'N' && transB == 'N') + { + DD[i * m + j] += + (float) AA[l * m + j] * (float)BB[l + k * i]; + } else if (transA == 'T' && transB == 'N') + { + DD[i * m + j] += + (float)AA[k * j + l] * (float)BB[l + k * i]; + } else if (transA == 'N' && transB == 'T') + { + DD[i * m + j] += + (float)AA[l * m + j] * (float)BB[i + l * n]; + } else if (transA == 'T' && transB == 'T') + { + DD[i * m + j] += + (float)AA[k * j + l] * (float)BB[i + l * n]; + } + if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) { + fprintf(stderr,"CC %f C %f \n",(float)CC[i*m+j],C[i*m+j]); + ret++; + } + if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) { + fprintf(stderr,"CC %f DD %f \n",(float)CC[i*m+j],(float)DD[i*m+j]); + ret++; + } + } + } + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(DD); + free(CC); + } + if (ret != 0) { + fprintf(stderr, "SHGEMM FAILURES: %d!!!\n", ret); + return 1; + } + + + for (loop = 0; loop<6; loop++) { + x=xvals[loop]; + for (alpha=0.;alpha<=1.;alpha+=0.5) + { + for (beta = 0.0; beta <=1.; beta+=0.5) { + + m = k = n = x; + float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); + float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); + float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); + _Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16)); + _Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16)); + float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); + if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || + (CC == NULL)) + return 1; + + for (j = 0; j < m; j++) + { + for (i = 0; i < k; i++) + { + A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + AA[j * k + i] = (_Float16) A[j * k + i]; + } + } + for (j = 0; j < n; j++) + { + for (i = 0; i < k; i++) + { + B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + BB[j * k + i] = (_Float16) B[j * k + i]; + } + } + + for (y = 0; y < 4; y++) + { + if ((y == 0) || (y == 2)) { + transA = 'N'; + } else { + transA = 'T'; + } + if ((y == 0) || (y == 1)) { + transB = 'N'; + } else { + transB = 'T'; + } + + memset(CC, 0, m * n * sizeof(FLOAT)); + memset(C, 0, m * n * sizeof(FLOAT)); + + SGEMM (&transA, &transB, &m, &n, &k, &alpha, A, + &m, B, &k, &beta, C, &m); + SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA, + &m, (_Float16*)BB, &k, &beta, CC, &m); + + for (i = 0; i < n; i++) + for (j = 0; j < m; j++) + { + if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) { + ret++; + } + } + } + free(A); + free(B); + free(C); + free(AA); + free(BB); + free(CC); + + if (ret != 0) { +/* + * fprintf(stderr, "SHGEMM FAILURES FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret); + */ + rret++; + ret=0; +/* } else { + fprintf(stderr, "SHGEMM SUCCEEDED FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret); +*/ + } + } + + } + } + if (rret > 0) return(1); + return(0); +}