From edfaf331952bb9a3e560fd9542447ea23dd0896a Mon Sep 17 00:00:00 2001 From: sparklesea Date: Mon, 1 Jul 2024 09:15:59 +0000 Subject: [PATCH] add one spmm kernel --- .../python/conformance/diopi_functions.py | 9 + diopi_test/python/mytest.ipynb | 1 + impl/torch/CMakeLists.txt | 2 +- impl/torch/functions/functions.cpp | 14 + .../functions/functions_mmcv/cuda_helpers.h | 2 +- impl/torch/functions/functions_sparse/spmm.cu | 85 +++++ impl/torch/functions/functions_sparse/utils.h | 291 ++++++++++++++++++ impl/torch/sparse_kernel.h | 12 + impl/torch/test/CMakeLists.txt | 2 +- proto/include/diopi/functions.h | 15 + 10 files changed, 430 insertions(+), 3 deletions(-) create mode 100644 diopi_test/python/mytest.ipynb create mode 100644 impl/torch/functions/functions_sparse/spmm.cu create mode 100644 impl/torch/functions/functions_sparse/utils.h create mode 100644 impl/torch/sparse_kernel.h diff --git a/diopi_test/python/conformance/diopi_functions.py b/diopi_test/python/conformance/diopi_functions.py index 73b4e407f4..7d9161112f 100644 --- a/diopi_test/python/conformance/diopi_functions.py +++ b/diopi_test/python/conformance/diopi_functions.py @@ -6079,3 +6079,12 @@ def rotary_emb_v2(query, key, cos, sin, dim): ret = func(query.context(), query, key, cos, sin, dim) check_returncode(ret) return query, key + +def spmm(row_ptr, col_ind, value, input) -> Tensor: + M = row_ptr.size().data[-1]-1 + N = input.size().data[-1] + out = Tensor(list([M, N]),dtype=Dtype.float32) + func = check_function("diopiSpMM") + ret = func(input.context(), out, row_ptr, col_ind, value, input) + check_returncode(ret) + return out diff --git a/diopi_test/python/mytest.ipynb b/diopi_test/python/mytest.ipynb new file mode 100644 index 0000000000..98a3b854ee --- /dev/null +++ b/diopi_test/python/mytest.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["tensor([[ 0.0000, -0.0000, -0.0000, ..., -0.0000, 0.0000, 0.0000],\n"," [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000],\n"," [-0.0000, 0.0000, -0.0000, ..., 0.0000, -0.0000, 0.0000],\n"," ...,\n"," [ 0.0000, 23.0290, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n"," [-3.7184, -0.0000, -0.0000, ..., -0.0000, -0.0000, 0.0000],\n"," [ 0.0000, 0.0000, -0.0000, ..., -0.0000, -4.7012, 0.0000]])\n","tensor(crow_indices=tensor([ 0, 426, 850, ..., 1676665,\n"," 1677043, 1677436]),\n"," col_indices=tensor([ 25, 40, 72, ..., 4071, 4075, 4094]),\n"," values=tensor([11.1403, -1.7448, 5.9189, ..., 0.0816, 10.1160,\n"," -4.7012]), size=(4096, 4096), nnz=1677436,\n"," layout=torch.sparse_csr)\n","[[-0.12721 -0.49293 -1.36716 ... 0.11863 0.81915 0.1262 ]\n"," [-2.30494 -0.50631 -0.23826 ... -0.88652 0.44469 0.43873]\n"," [ 1.19146 0.83548 -1.03986 ... -0.18479 -1.14953 0.54477]\n"," ...\n"," [ 0.43145 1.52025 0.337 ... 0.12734 1.17637 -0.83798]\n"," [-1.14607 0.24201 -0.35273 ... -1.18817 -1.57963 0.49221]\n"," [-0.75149 -1.61606 1.13707 ... 
0.69304 0.95235 -1.07856]]\n","[4096, 128]\n"]}],"source":["import torch\n","import diopilib\n","from conformance.diopi_functions import spmm\n","from conformance.diopi_runtime import Tensor\n","import numpy as np\n","from torch.nn.functional import dropout\n","# sparse matrix:\n","# 0.5, 0, 1\n","# 0, 0, 2\n","# 1, 3, 0.6\n","\n","M, K, N = 4096, 4096, 128\n","\n","a = torch.randn((M,K),dtype=torch.float32)\n","a = dropout(a, p=0.9)\n","sparse_a = a.to_sparse_csr()\n","print(a)\n","print(sparse_a)\n","b = np.random.randn(K,N).astype(np.float32)\n","print(b)\n","\n","input = Tensor.from_numpy(b)\n","row_ptr = Tensor.from_numpy(sparse_a.crow_indices().numpy().astype(np.int32))\n","col_ind = Tensor.from_numpy(sparse_a.col_indices().numpy().astype(np.int32))\n","values = Tensor.from_numpy(sparse_a.values().numpy().astype(np.float32))\n","print(list(input.size().data))\n","c = spmm(row_ptr, col_ind, values, input)\n","c_ref = a @ torch.from_numpy(b)\n","# c = rbrmsr_spmm(sparse_a.crow_indices(), sparse_a.col_indices(), sparse_a.values(), b)\n"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["ours: tensor([[ 189.8907, -14.9602, 114.9101, ..., 84.5701, -9.3818,\n"," 367.1490],\n"," [ -49.5185, 155.6731, -420.4368, ..., 107.9797, 367.4091,\n"," -8.2983],\n"," [ 100.9162, 10.6236, 311.1810, ..., -59.7648, 264.6500,\n"," -59.7523],\n"," ...,\n"," [-231.3855, 114.1076, -277.9331, ..., -188.6128, 57.9538,\n"," -185.5934],\n"," [ 57.2971, 128.9449, -58.4205, ..., -96.2650, -76.2732,\n"," 210.6950],\n"," [-113.7153, 34.8387, 90.0678, ..., -45.6139, 31.5837,\n"," -257.7583]])\n","ref: tensor([[ 189.8907, -14.9602, 114.9101, ..., 84.5701, -9.3818,\n"," 367.1489],\n"," [ -49.5186, 155.6732, -420.4370, ..., 107.9796, 367.4092,\n"," -8.2982],\n"," [ 100.9163, 10.6236, 311.1810, ..., -59.7648, 264.6500,\n"," -59.7523],\n"," ...,\n"," [-231.3856, 114.1077, -277.9332, ..., -188.6128, 57.9537,\n"," -185.5935],\n"," [ 57.2971, 128.9448, -58.4205, ..., -96.2650, -76.2731,\n"," 210.6951],\n"," [-113.7153, 34.8386, 90.0678, ..., -45.6140, 31.5837,\n"," -257.7583]])\n"]}],"source":["c_ours = torch.from_numpy(c.numpy())\n","print(\"ours:\", c_ours)\n","print(\"ref: \", c_ref)\n","assert torch.allclose(c_ours, c_ref, rtol=1e-03, atol=1e-03)"]}],"metadata":{"kernelspec":{"display_name":"eda","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.19"}},"nbformat":4,"nbformat_minor":2} diff --git a/impl/torch/CMakeLists.txt b/impl/torch/CMakeLists.txt index 5419defb11..d3e844c100 100644 --- a/impl/torch/CMakeLists.txt +++ b/impl/torch/CMakeLists.txt @@ -33,7 +33,7 @@ if (DYLOAD) set(IMPL_SRC wrap_func.cpp) endif() -file(GLOB REAL_IMPL_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} functions/functions_mmcv/*.cu functions/functions_ext/*.cu functions/*.cpp helper.cpp build_aten.cpp) +file(GLOB REAL_IMPL_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} functions/functions_sparse/*.cu functions/functions_mmcv/*.cu functions/functions_ext/*.cu functions/*.cpp helper.cpp build_aten.cpp) # adaptor set(USE_ADAPTOR ON) diff --git a/impl/torch/functions/functions.cpp b/impl/torch/functions/functions.cpp index 2a17e36424..53fb3bf5aa 100644 --- a/impl/torch/functions/functions.cpp +++ b/impl/torch/functions/functions.cpp @@ -30,6 +30,7 @@ #include "../helper.hpp" #include "../vision_kernel.h" 
+#include "../sparse_kernel.h" namespace impl { namespace cuda { @@ -4383,5 +4384,18 @@ DIOPI_API diopiError_t diopiBatchNormElemt(diopiContextHandle_t ctx, diopiTensor return diopiSuccess; } +diopiError_t diopiSpMM(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t row_ptr, + diopiTensorHandle_t col_ind, diopiTensorHandle_t value, diopiTensorHandle_t input){ + impl::aten::setCurStream(ctx); + auto atInput = impl::aten::buildATen(input); + auto atRowPtr = impl::aten::buildATen(row_ptr); + auto atColInd = impl::aten::buildATen(col_ind); + auto atValue = impl::aten::buildATen(value); + auto atOut = impl::aten::buildATen(out); + sparse::ops::row_balance_row_major_seq_reduce_kernel(atOut, atRowPtr, atColInd, atValue, atInput); + + return diopiSuccess; +} + } // namespace cuda } // namespace impl diff --git a/impl/torch/functions/functions_mmcv/cuda_helpers.h b/impl/torch/functions/functions_mmcv/cuda_helpers.h index ef44ae4418..395978b27e 100644 --- a/impl/torch/functions/functions_mmcv/cuda_helpers.h +++ b/impl/torch/functions/functions_mmcv/cuda_helpers.h @@ -41,7 +41,7 @@ constexpr int THREADS_PER_BLOCK = 512; inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) { int optimal_block_num = (N + num_threads - 1) / num_threads; int max_block_num = 4096; - return std::min(optimal_block_num, max_block_num); + return min(optimal_block_num, max_block_num); } template diff --git a/impl/torch/functions/functions_sparse/spmm.cu b/impl/torch/functions/functions_sparse/spmm.cu new file mode 100644 index 0000000000..65647c4935 --- /dev/null +++ b/impl/torch/functions/functions_sparse/spmm.cu @@ -0,0 +1,85 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "utils.h" + +namespace sparse{ +namespace ops { + + +template +__global__ void +csrspmm_seqreduce_rowbalance_kernel(const Index nr, const Index feature_size, + const Index rowPtr[], const Index colIdx[], + const DType values[], const DType dnInput[], + DType dnOutput[]) { + Index row_tile = blockDim.y; // 8 + Index subwarp_id = threadIdx.y; + Index stride = row_tile * gridDim.x; // 8 * (m/8) + Index row = blockIdx.x * row_tile + subwarp_id; + Index v_id = (blockIdx.y * blockDim.x) + threadIdx.x; + dnInput += v_id; + dnOutput += v_id; + DType val; + // DType res = init(REDUCE::Op); + Index col; + for (; row < nr; row += stride) { + DType res = 0; + Index E_k_idx = -1; + Index start = __ldg(rowPtr + row); + Index end = __ldg(rowPtr + row + 1); + + for (Index p = start; p < end; p++) { + col = __ldg(colIdx + p); + val = __guard_load_default_one(values, p); + res += val * __ldg(dnInput + col * feature_size); + } + + dnOutput[row * feature_size] = res; + } +} + +at::Tensor row_balance_row_major_seq_reduce_kernel(at::Tensor& out, at::Tensor& row_ptr, at::Tensor& col_ind, + at::Tensor& value, at::Tensor& input){ + // assertTensor(row_ptr, at::kInt32); + // assertTensor(col_ind, at::kInt32); + // assertTensor(input, at::kFloat32); + // assertTensor(value, at::kFloat32); + input = input.contiguous(); + // int v = row_ptr.size(0) - 1; + // int Ndim_worker = input.size(1); + // int f = Ndim_worker; + // int e = col_ind.size(0); + + int Mdim_worker = row_ptr.size(0) - 1; + // int v = Mdim_worker; + int Ndim_worker = input.size(1); + // int f = Ndim_worker; + // int e = col_ind.size(0); + int RefThreadPerBlock = 256; + int Ndim_threadblock = CEIL(Ndim_worker, RefThreadPerBlock); + int Ndim_thread_per_tb = min(Ndim_worker, RefThreadPerBlock); + int Mdim_thread_per_tb = 
diff --git a/impl/torch/functions/functions_sparse/utils.h b/impl/torch/functions/functions_sparse/utils.h
new file mode 100644
index 0000000000..04b90a6140
--- /dev/null
+++ b/impl/torch/functions/functions_sparse/utils.h
@@ -0,0 +1,291 @@
+#ifndef UTILS_SPARSE
+#define UTILS_SPARSE
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cusparse.h>
+#include <cstdio>
+#include <cstdlib>
+
+#include "device_atomic_functions.h"
+#include "device_launch_parameters.h"
+
+#define CEIL(x, y) (((x) + (y)-1) / (y))
+
+#define FULLMASK 0xffffffff
+#define MIN(a, b) ((a < b) ? a : b)
+#define MAX(a, b) ((a < b) ? b : a)
+
+enum gespmmAlg_t {
+  GESPMM_ALG_SEQREDUCE_ROWBALANCE = 0,
+  GESPMM_ALG_PARREDUCE_ROWBALANCE,
+  GESPMM_ALG_SEQREDUCE_NNZBALANCE,
+  GESPMM_ALG_PARREDUCE_NNZBALANCE,
+  GESPMM_ALG_DEFAULT
+};
+
+// Warp-level segmented scan; expects `lane_id` to be in scope at the call site.
+#define SEG_SHFL_SCAN(v, tmpv, segid, tmps)                                    \
+  tmpv = __shfl_down_sync(FULLMASK, v, 1);                                     \
+  tmps = __shfl_down_sync(FULLMASK, segid, 1);                                 \
+  if (tmps == segid && lane_id < 31)                                           \
+    v += tmpv;                                                                 \
+  tmpv = __shfl_down_sync(FULLMASK, v, 2);                                     \
+  tmps = __shfl_down_sync(FULLMASK, segid, 2);                                 \
+  if (tmps == segid && lane_id < 30)                                           \
+    v += tmpv;                                                                 \
+  tmpv = __shfl_down_sync(FULLMASK, v, 4);                                     \
+  tmps = __shfl_down_sync(FULLMASK, segid, 4);                                 \
+  if (tmps == segid && lane_id < 28)                                           \
+    v += tmpv;                                                                 \
+  tmpv = __shfl_down_sync(FULLMASK, v, 8);                                     \
+  tmps = __shfl_down_sync(FULLMASK, segid, 8);                                 \
+  if (tmps == segid && lane_id < 24)                                           \
+    v += tmpv;                                                                 \
+  tmpv = __shfl_down_sync(FULLMASK, v, 16);                                    \
+  tmps = __shfl_down_sync(FULLMASK, segid, 16);                                \
+  if (tmps == segid && lane_id < 16)                                           \
+    v += tmpv;
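Note: SEG_SHFL_SCAN is not exercised by the row-balance kernel in this patch; it belongs with the nnz-balance variants listed in gespmmAlg_t. Assuming equal segment ids occupy contiguous lane ranges, the macro leaves each lane holding the sum of its own value and all later values in its segment (a segmented suffix sum). A host-side Python simulation of that contract:

def seg_shfl_scan(v, segid):
    """Model SEG_SHFL_SCAN across n warp lanes (n = 32 on real hardware).

    After the scan, lane i holds the sum of v[j] for all j >= i with
    segid[j] == segid[i], assuming segments are contiguous lane runs.
    """
    n = len(v)
    v = list(v)
    for d in (1, 2, 4, 8, 16):
        # __shfl_down_sync: lane i reads lane i + d, using this round's stale values
        nv = [v[i + d] if i + d < n else v[i] for i in range(n)]
        ns = [segid[i + d] if i + d < n else segid[i] for i in range(n)]
        for i in range(n):
            if ns[i] == segid[i] and i < n - d:  # the `lane_id < 32 - d` guard
                v[i] += nv[i]
    return v

assert seg_shfl_scan([1, 1, 1, 1], [0, 0, 1, 1]) == [2, 1, 2, 1]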
+#define checkCudaError(a)                                                      \
+  do {                                                                         \
+    if (cudaSuccess != (a)) {                                                  \
+      fprintf(stderr, "CUDA runtime error in line %d of file %s: %s\n",        \
+              __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError()));     \
+      exit(EXIT_FAILURE);                                                      \
+    }                                                                          \
+  } while (0)
+
+#define checkCuSparseError(a)                                                  \
+  do {                                                                         \
+    if (CUSPARSE_STATUS_SUCCESS != (a)) {                                      \
+      fprintf(stderr, "cuSPARSE runtime error in line %d of file %s: %s\n",    \
+              __LINE__, __FILE__, cudaGetErrorString(cudaGetLastError()));     \
+      exit(EXIT_FAILURE);                                                      \
+    }                                                                          \
+  } while (0)
+
+__device__ __forceinline__ float sum_reduce(float acc, float x) {
+  return acc + x;
+}
+
+// Returns base[offset], or 1 when no explicit values are stored (e.g. an
+// unweighted graph adjacency matrix).
+template <typename T>
+__device__ __forceinline__ T __guard_load_default_one(const T *base,
+                                                      int offset) {
+  if (base != nullptr)
+    return base[offset];
+  else
+    return static_cast<T>(1);
+}
+
+__device__ __forceinline__ float sum_init() { return 0; }
+
+// Binary-search the CSR row that contains nonzero (edge) id `eid`.
+__device__ __forceinline__ int findRow(const int *S_csrRowPtr, int eid,
+                                       int start, int end) {
+  int low = start, high = end;
+  if (low == high)
+    return low;
+  while (low < high) {
+    int mid = (low + high) >> 1;
+    if (S_csrRowPtr[mid] <= eid)
+      low = mid + 1;
+    else
+      high = mid;
+  }
+  if (S_csrRowPtr[high] == eid)
+    return high;
+  else
+    return high - 1;
+}
+
+template <typename ldType, typename data>
+__device__ __forceinline__ void Load(ldType &tmp, data *array, int offset) {
+  tmp = *(reinterpret_cast<ldType *>(array + offset));
+}
+
+template <typename ldType, typename data>
+__device__ __forceinline__ void Load(data *lhd, data *rhd, int offset) {
+  *(reinterpret_cast<ldType *>(lhd)) =
+      *(reinterpret_cast<ldType *>(rhd + offset));
+}
+
+template <typename ldType, typename data>
+__device__ __forceinline__ void Store(data *lhd, data *rhd, int offset) {
+  *(reinterpret_cast<ldType *>(lhd + offset)) =
+      *(reinterpret_cast<ldType *>(rhd));
+}
+
+template <typename ldType, typename data>
+__device__ __forceinline__ void Load4(ldType *tmp, data *array, int *offset,
+                                      int offset2 = 0) {
+  Load(tmp[0], array, offset[0] + offset2);
+  Load(tmp[1], array, offset[1] + offset2);
+  Load(tmp[2], array, offset[2] + offset2);
+  Load(tmp[3], array, offset[3] + offset2);
+}
+
+template <typename vecData, typename data>
+__device__ __forceinline__ data vecDot2(vecData &lhd, vecData &rhd) {
+  return lhd.x * rhd.x + lhd.y * rhd.y;
+}
+
+template <typename vecData, typename data>
+__device__ __forceinline__ data vecDot4(vecData &lhd, vecData &rhd) {
+  return lhd.x * rhd.x + lhd.y * rhd.y + lhd.z * rhd.z + lhd.w * rhd.w;
+}
+
+template <typename vecData, typename data>
+__device__ __forceinline__ void vec4Dot4(data *cal, vecData *lhd,
+                                         vecData *rhd) {
+  cal[0] += vecDot4<vecData, data>(lhd[0], rhd[0]);
+  cal[1] += vecDot4<vecData, data>(lhd[1], rhd[1]);
+  cal[2] += vecDot4<vecData, data>(lhd[2], rhd[2]);
+  cal[3] += vecDot4<vecData, data>(lhd[3], rhd[3]);
+}
+
+template <typename vecData, typename data>
+__device__ __forceinline__ void vec2Dot4(data *cal, vecData *lhd,
+                                         vecData *rhd) {
+  cal[0] += vecDot2<vecData, data>(lhd[0], rhd[0]);
+  cal[1] += vecDot2<vecData, data>(lhd[1], rhd[1]);
+  cal[2] += vecDot2<vecData, data>(lhd[2], rhd[2]);
+  cal[3] += vecDot2<vecData, data>(lhd[3], rhd[3]);
+}
+
+template <typename data>
+__device__ __forceinline__ void Dot4(data *cal, data *lhd, data *rhd) {
+  cal[0] += lhd[0] * rhd[0];
+  cal[1] += lhd[1] * rhd[1];
+  cal[2] += lhd[2] * rhd[2];
+  cal[3] += lhd[3] * rhd[3];
+}
+
+template <typename data>
+__device__ __forceinline__ void selfMul4(data *lhd, data *rhd) {
+  lhd[0] *= rhd[0];
+  lhd[1] *= rhd[1];
+  lhd[2] *= rhd[2];
+  lhd[3] *= rhd[3];
+}
+
+template <typename data>
+__device__ __forceinline__ void selfMulConst4(data *lhd, data Const) {
+  lhd[0] *= Const;
+  lhd[1] *= Const;
+  lhd[2] *= Const;
+  lhd[3] *= Const;
+}
+
+template <typename data>
+__device__ __forceinline__ void selfAddConst4(data *lhd, data Const) {
+  lhd[0] += Const;
+  lhd[1] += Const;
+  lhd[2] += Const;
+  lhd[3] += Const;
+}
+
+template <typename data>
+__device__ __forceinline__ void AllReduce4(data *multi, int stride,
+                                           int warpSize) {
+  for (; stride > 0; stride >>= 1) {
+    multi[0] += __shfl_xor_sync(0xffffffff, multi[0], stride, warpSize);
+    multi[1] += __shfl_xor_sync(0xffffffff, multi[1], stride, warpSize);
+    multi[2] += __shfl_xor_sync(0xffffffff, multi[2], stride, warpSize);
+    multi[3] += __shfl_xor_sync(0xffffffff, multi[3], stride, warpSize);
+  }
+}
+
+// Takes `multi` by reference so the reduced value is visible to the caller.
+template <typename data>
+__device__ __forceinline__ void AllReduce(data &multi, int stride,
+                                          int warpSize) {
+  for (; stride > 0; stride >>= 1) {
+    multi += __shfl_xor_sync(0xffffffff, multi, stride, warpSize);
+  }
+}
+
+// Finds the first element in seg_offsets strictly greater than elem_id (the
+// n-th, say) and returns n - 1, i.e. the segment that contains elem_id.
+template <typename index_t>
+__device__ __forceinline__ index_t
+binary_search_segment_number(const index_t *seg_offsets, const index_t n_seg,
+                             const index_t n_elem, const index_t elem_id) {
+  index_t lo = 1, hi = n_seg, mid;
+  while (lo < hi) {
+    mid = (lo + hi) >> 1;
+    if (seg_offsets[mid] <= elem_id) {
+      lo = mid + 1;
+    } else {
+      hi = mid;
+    }
+  }
+  return (hi - 1);
+}
+#endif  // UTILS_SPARSE
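Note: findRow and binary_search_segment_number implement the same primitive with slightly different boundary handling: map a flat nonzero index back to its enclosing CSR row or segment, given offsets. The contract, as a bisect-based Python sketch (names illustrative):

import bisect

def segment_of(seg_offsets, elem_id):
    """Return the segment s with seg_offsets[s] <= elem_id < seg_offsets[s + 1]."""
    # bisect_right finds the first offset strictly greater than elem_id;
    # the element then belongs to the segment just before it.
    return bisect.bisect_right(seg_offsets, elem_id) - 1

assert segment_of([0, 2, 2, 5], 0) == 0  # elements 0-1 live in segment 0
assert segment_of([0, 2, 2, 5], 3) == 2  # segment 1 is empty, 2-4 in segment 2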
diff --git a/impl/torch/sparse_kernel.h b/impl/torch/sparse_kernel.h
new file mode 100644
index 0000000000..f128107c21
--- /dev/null
+++ b/impl/torch/sparse_kernel.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace sparse {
+namespace ops {
+
+at::Tensor row_balance_row_major_seq_reduce_kernel(at::Tensor& out, at::Tensor& row_ptr, at::Tensor& col_ind, at::Tensor& value, at::Tensor& input);
+
+}  // namespace ops
+}  // namespace sparse
diff --git a/impl/torch/test/CMakeLists.txt b/impl/torch/test/CMakeLists.txt
index cef1c26495..cd5e9a7465 100644
--- a/impl/torch/test/CMakeLists.txt
+++ b/impl/torch/test/CMakeLists.txt
@@ -40,7 +40,7 @@ add_custom_target(test_code_gen ALL
 set(FUNCTIONS_SRC ${GEN_FILES})
 
 pybind11_add_module(${DIOPI_FUNCTIONS} SHARED ${FUNCTIONS_SRC})
-target_link_libraries(${DIOPI_FUNCTIONS} PRIVATE diopirt ${DEVICEIMPL})
+target_link_libraries(${DIOPI_FUNCTIONS} PRIVATE diopirt -Wl,--no-as-needed ${DEVICEIMPL} -Wl,--as-needed)
 add_dependencies(${DIOPI_FUNCTIONS} test_code_gen)
 
 file(MAKE_DIRECTORY ${DIOPI_TEST_DIR}/python)
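Note on the link line above: the -Wl,--no-as-needed / -Wl,--as-needed bracketing is presumably needed because the generated pybind11 bindings do not reference ${DEVICEIMPL} symbols directly at link time, so a toolchain that defaults to --as-needed would drop the device library from the module's runtime dependencies and the new diopiSpMM entry point would fail to resolve when diopilib is imported.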
diff --git a/proto/include/diopi/functions.h b/proto/include/diopi/functions.h
index 7c978453a8..23a9a7ba9d 100644
--- a/proto/include/diopi/functions.h
+++ b/proto/include/diopi/functions.h
@@ -3573,6 +3573,21 @@ DIOPI_API diopiError_t diopiGetNativeMemoryFormat(diopiContextHandle_t ctx, diop
 DIOPI_API diopiError_t diopiTensorDestructionHook(diopiContextHandle_t ctx, void* ptr);
 // ============================================custom api end========================================
 
+// ============================================sparse api begin======================================
+/**
+ * @brief Row-balance, row-major, sequential-reduce SpMM: out = csr(row_ptr, col_ind, value) @ input.
+ * @param[in] ctx Context environment.
+ * @param[out] out Output tensor.
+ * @param[in] row_ptr A tensor that stores the begin index of each row in col_ind.
+ * @param[in] col_ind A tensor that stores the column indices.
+ * @param[in] value A tensor that stores the non-zero values.
+ * @param[in] input A tensor that stores the dense input matrix.
+ */
+DIOPI_API diopiError_t diopiSpMM(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t row_ptr,
+                                 diopiTensorHandle_t col_ind, diopiTensorHandle_t value, diopiTensorHandle_t input);
+// ============================================sparse api end========================================
+
 #if defined(__cplusplus)
 }
 #endif  // __cplusplus
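End to end, the call chain is: the conformance spmm wrapper (Python) -> diopiSpMM (proto) -> the adapter in impl/torch/functions/functions.cpp -> sparse::ops::row_balance_row_major_seq_reduce_kernel (CUDA). Beyond the notebook, a compact scripted check against scipy.sparse could look like the sketch below; it assumes the built diopilib module is importable and uses only the Tensor APIs already exercised in mytest.ipynb:

import numpy as np
import scipy.sparse as sp

import diopilib  # noqa: F401 -- loads the DIOPI bindings, as in the notebook
from conformance.diopi_functions import spmm
from conformance.diopi_runtime import Tensor

def check_spmm(M=512, K=512, N=64, density=0.1, seed=0):
    rng = np.random.default_rng(seed)
    a = sp.random(M, K, density=density, format="csr", dtype=np.float32, random_state=rng)
    b = rng.standard_normal((K, N)).astype(np.float32)
    out = spmm(
        Tensor.from_numpy(a.indptr.astype(np.int32)),
        Tensor.from_numpy(a.indices.astype(np.int32)),
        Tensor.from_numpy(a.data.astype(np.float32)),
        Tensor.from_numpy(b),
    )
    # same tolerances as the notebook's allclose check
    np.testing.assert_allclose(out.numpy(), a @ b, rtol=1e-3, atol=1e-3)

check_spmm()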