Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions diopi_test/python/conformance/diopi_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6079,3 +6079,12 @@ def rotary_emb_v2(query, key, cos, sin, dim):
ret = func(query.context(), query, key, cos, sin, dim)
check_returncode(ret)
return query, key

def spmm(row_ptr, col_ind, value, input) -> Tensor:
M = row_ptr.size().data[-1]-1
N = input.size().data[-1]
out = Tensor(list([M, N]),dtype=Dtype.float32)
func = check_function("diopiSpMM")
ret = func(input.context(), out, row_ptr, col_ind, value, input)
check_returncode(ret)
return out
1 change: 1 addition & 0 deletions diopi_test/python/mytest.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells":[{"cell_type":"code","execution_count":3,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["tensor([[ 0.0000, -0.0000, -0.0000, ..., -0.0000, 0.0000, 0.0000],\n"," [-0.0000, 0.0000, -0.0000, ..., 0.0000, 0.0000, 0.0000],\n"," [-0.0000, 0.0000, -0.0000, ..., 0.0000, -0.0000, 0.0000],\n"," ...,\n"," [ 0.0000, 23.0290, 0.0000, ..., 0.0000, 0.0000, 0.0000],\n"," [-3.7184, -0.0000, -0.0000, ..., -0.0000, -0.0000, 0.0000],\n"," [ 0.0000, 0.0000, -0.0000, ..., -0.0000, -4.7012, 0.0000]])\n","tensor(crow_indices=tensor([ 0, 426, 850, ..., 1676665,\n"," 1677043, 1677436]),\n"," col_indices=tensor([ 25, 40, 72, ..., 4071, 4075, 4094]),\n"," values=tensor([11.1403, -1.7448, 5.9189, ..., 0.0816, 10.1160,\n"," -4.7012]), size=(4096, 4096), nnz=1677436,\n"," layout=torch.sparse_csr)\n","[[-0.12721 -0.49293 -1.36716 ... 0.11863 0.81915 0.1262 ]\n"," [-2.30494 -0.50631 -0.23826 ... -0.88652 0.44469 0.43873]\n"," [ 1.19146 0.83548 -1.03986 ... -0.18479 -1.14953 0.54477]\n"," ...\n"," [ 0.43145 1.52025 0.337 ... 0.12734 1.17637 -0.83798]\n"," [-1.14607 0.24201 -0.35273 ... -1.18817 -1.57963 0.49221]\n"," [-0.75149 -1.61606 1.13707 ... 0.69304 0.95235 -1.07856]]\n","[4096, 128]\n"]}],"source":["import torch\n","import diopilib\n","from conformance.diopi_functions import spmm\n","from conformance.diopi_runtime import Tensor\n","import numpy as np\n","from torch.nn.functional import dropout\n","# sparse matrix:\n","# 0.5, 0, 1\n","# 0, 0, 2\n","# 1, 3, 0.6\n","\n","M, K, N = 4096, 4096, 128\n","\n","a = torch.randn((M,K),dtype=torch.float32)\n","a = dropout(a, p=0.9)\n","sparse_a = a.to_sparse_csr()\n","print(a)\n","print(sparse_a)\n","b = np.random.randn(K,N).astype(np.float32)\n","print(b)\n","\n","input = Tensor.from_numpy(b)\n","row_ptr = Tensor.from_numpy(sparse_a.crow_indices().numpy().astype(np.int32))\n","col_ind = Tensor.from_numpy(sparse_a.col_indices().numpy().astype(np.int32))\n","values = Tensor.from_numpy(sparse_a.values().numpy().astype(np.float32))\n","print(list(input.size().data))\n","c = spmm(row_ptr, col_ind, values, input)\n","c_ref = a @ torch.from_numpy(b)\n","# c = rbrmsr_spmm(sparse_a.crow_indices(), sparse_a.col_indices(), sparse_a.values(), b)\n"]},{"cell_type":"code","execution_count":4,"metadata":{},"outputs":[{"name":"stdout","output_type":"stream","text":["ours: tensor([[ 189.8907, -14.9602, 114.9101, ..., 84.5701, -9.3818,\n"," 367.1490],\n"," [ -49.5185, 155.6731, -420.4368, ..., 107.9797, 367.4091,\n"," -8.2983],\n"," [ 100.9162, 10.6236, 311.1810, ..., -59.7648, 264.6500,\n"," -59.7523],\n"," ...,\n"," [-231.3855, 114.1076, -277.9331, ..., -188.6128, 57.9538,\n"," -185.5934],\n"," [ 57.2971, 128.9449, -58.4205, ..., -96.2650, -76.2732,\n"," 210.6950],\n"," [-113.7153, 34.8387, 90.0678, ..., -45.6139, 31.5837,\n"," -257.7583]])\n","ref: tensor([[ 189.8907, -14.9602, 114.9101, ..., 84.5701, -9.3818,\n"," 367.1489],\n"," [ -49.5186, 155.6732, -420.4370, ..., 107.9796, 367.4092,\n"," -8.2982],\n"," [ 100.9163, 10.6236, 311.1810, ..., -59.7648, 264.6500,\n"," -59.7523],\n"," ...,\n"," [-231.3856, 114.1077, -277.9332, ..., -188.6128, 57.9537,\n"," -185.5935],\n"," [ 57.2971, 128.9448, -58.4205, ..., -96.2650, -76.2731,\n"," 210.6951],\n"," [-113.7153, 34.8386, 90.0678, ..., -45.6140, 31.5837,\n"," -257.7583]])\n"]}],"source":["c_ours = torch.from_numpy(c.numpy())\n","print(\"ours:\", c_ours)\n","print(\"ref: \", c_ref)\n","assert torch.allclose(c_ours, c_ref, rtol=1e-03, atol=1e-03)"]}],"metadata":{"kernelspec":{"display_name":"eda","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.9.19"}},"nbformat":4,"nbformat_minor":2}
2 changes: 1 addition & 1 deletion impl/torch/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ if (DYLOAD)
set(IMPL_SRC wrap_func.cpp)
endif()

file(GLOB REAL_IMPL_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} functions/functions_mmcv/*.cu functions/functions_ext/*.cu functions/*.cpp helper.cpp build_aten.cpp)
file(GLOB REAL_IMPL_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} functions/functions_sparse/*.cu functions/functions_mmcv/*.cu functions/functions_ext/*.cu functions/*.cpp helper.cpp build_aten.cpp)

# adaptor
set(USE_ADAPTOR ON)
Expand Down
14 changes: 14 additions & 0 deletions impl/torch/functions/functions.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/**

Check notice on line 1 in impl/torch/functions/functions.cpp

View workflow job for this annotation

GitHub Actions / cpp-linter

Run clang-format on impl/torch/functions/functions.cpp

File impl/torch/functions/functions.cpp does not conform to Custom style guidelines. (lines 31, 4387, 4388)
* @file
* @author DeepLink
* @copyright (c) 2023, DeepLink.
Expand Down Expand Up @@ -30,6 +30,7 @@

#include "../helper.hpp"
#include "../vision_kernel.h"
#include "../sparse_kernel.h"

namespace impl {
namespace cuda {
Expand Down Expand Up @@ -4383,5 +4384,18 @@
return diopiSuccess;
}

diopiError_t diopiSpMM(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiTensorHandle_t row_ptr,
diopiTensorHandle_t col_ind, diopiTensorHandle_t value, diopiTensorHandle_t input){
impl::aten::setCurStream(ctx);
auto atInput = impl::aten::buildATen(input);
auto atRowPtr = impl::aten::buildATen(row_ptr);
auto atColInd = impl::aten::buildATen(col_ind);
auto atValue = impl::aten::buildATen(value);
auto atOut = impl::aten::buildATen(out);
sparse::ops::row_balance_row_major_seq_reduce_kernel(atOut, atRowPtr, atColInd, atValue, atInput);

return diopiSuccess;
}

} // namespace cuda
} // namespace impl
2 changes: 1 addition & 1 deletion impl/torch/functions/functions_mmcv/cuda_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ constexpr int THREADS_PER_BLOCK = 512;
inline int GET_BLOCKS(const int N, const int num_threads = THREADS_PER_BLOCK) {
int optimal_block_num = (N + num_threads - 1) / num_threads;
int max_block_num = 4096;
return std::min(optimal_block_num, max_block_num);
return min(optimal_block_num, max_block_num);
}

template <typename T>
Expand Down
85 changes: 85 additions & 0 deletions impl/torch/functions/functions_sparse/spmm.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/library.h>

#include <iostream>
#include <tuple>
#include <vector>

#include "utils.h"

namespace sparse{
namespace ops {


template <typename Index, typename DType>
__global__ void
csrspmm_seqreduce_rowbalance_kernel(const Index nr, const Index feature_size,
const Index rowPtr[], const Index colIdx[],
const DType values[], const DType dnInput[],
DType dnOutput[]) {
Index row_tile = blockDim.y; // 8
Index subwarp_id = threadIdx.y;
Index stride = row_tile * gridDim.x; // 8 * (m/8)
Index row = blockIdx.x * row_tile + subwarp_id;
Index v_id = (blockIdx.y * blockDim.x) + threadIdx.x;
dnInput += v_id;
dnOutput += v_id;
DType val;
// DType res = init(REDUCE::Op);
Index col;
for (; row < nr; row += stride) {
DType res = 0;
Index E_k_idx = -1;
Index start = __ldg(rowPtr + row);
Index end = __ldg(rowPtr + row + 1);

for (Index p = start; p < end; p++) {
col = __ldg(colIdx + p);
val = __guard_load_default_one<DType>(values, p);
res += val * __ldg(dnInput + col * feature_size);
}

dnOutput[row * feature_size] = res;
}
}

at::Tensor row_balance_row_major_seq_reduce_kernel(at::Tensor& out, at::Tensor& row_ptr, at::Tensor& col_ind,
at::Tensor& value, at::Tensor& input){
// assertTensor(row_ptr, at::kInt32);
// assertTensor(col_ind, at::kInt32);
// assertTensor(input, at::kFloat32);
// assertTensor(value, at::kFloat32);
input = input.contiguous();
// int v = row_ptr.size(0) - 1;
// int Ndim_worker = input.size(1);
// int f = Ndim_worker;
// int e = col_ind.size(0);

int Mdim_worker = row_ptr.size(0) - 1;
// int v = Mdim_worker;
int Ndim_worker = input.size(1);
// int f = Ndim_worker;
// int e = col_ind.size(0);
int RefThreadPerBlock = 256;
int Ndim_threadblock = CEIL(Ndim_worker, RefThreadPerBlock);
int Ndim_thread_per_tb = min(Ndim_worker, RefThreadPerBlock);
int Mdim_thread_per_tb = CEIL(RefThreadPerBlock, Ndim_thread_per_tb);
int Mdim_threadblock = CEIL(Mdim_worker, Mdim_thread_per_tb);

dim3 gridDim(Mdim_threadblock, Ndim_threadblock, 1);
dim3 blockDim(Ndim_thread_per_tb, Mdim_thread_per_tb, 1);

// auto out = at::empty({v, f}, options);
csrspmm_seqreduce_rowbalance_kernel<int, float>
<<<gridDim, blockDim>>>(
Mdim_worker, Ndim_worker, row_ptr.data_ptr<int>(),
col_ind.data_ptr<int>(), value.data_ptr<float>(),
input.data_ptr<float>(), out.data_ptr<float>());
return out;
}

}
}
Loading
Loading