Skip to content

Commit b366e92

Browse files
authored
Adapt DIOPI code for the NPU device (#2973)
1 parent 00e92ab commit b366e92

File tree

6 files changed

+139
-39
lines changed

6 files changed

+139
-39
lines changed

mmcv/ops/csrc/pytorch/focal_loss.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
#include <diopi/diopirt.h>
66
#include <diopi/functions.h>
77
#include <diopi/functions_mmcv.h>
8+
#include <torch/csrc/utils/pybind.h>
89

910
#include "csrc_dipu/diopirt/diopirt_impl.h"
11+
#include "csrc_dipu/runtime/device/deviceapis.h"
12+
#include "csrc_dipu/utils/helpfunc.hpp"
1013

14+
using dipu::VENDOR_TYPE;
1115
using dipu::diopi_helper::toDiopiScalar;
1216
using dipu::diopi_helper::toDiopiTensorHandle;
1317
#endif
@@ -57,9 +61,16 @@ void sigmoid_focal_loss_forward_diopi(Tensor input, Tensor target,
5761
auto weight_p = toDiopiTensorHandle(weight);
5862
auto output_p = toDiopiTensorHandle(output);
5963
if (reinterpret_cast<void *>(diopiSigmoidFocalLossMmcv) != nullptr) {
60-
auto ret = diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
61-
weight_p, gamma, alpha);
62-
if (ret == diopiSuccess) return;
64+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
65+
pybind11::gil_scoped_release no_gil;
66+
auto ret = diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
67+
weight_p, gamma, alpha);
68+
if (ret == diopiSuccess) return;
69+
} else {
70+
auto ret = diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
71+
weight_p, gamma, alpha);
72+
if (ret == diopiSuccess) return;
73+
}
6374
}
6475
LOG(WARNING)
6576
<< "Fallback to cpu: mmcv ext op sigmoid_focal_loss_forward_impl";
@@ -90,9 +101,16 @@ void sigmoid_focal_loss_backward_diopi(Tensor input, Tensor target,
90101
auto weight_p = toDiopiTensorHandle(weight);
91102
auto grad_input_p = toDiopiTensorHandle(grad_input);
92103
if (reinterpret_cast<void *>(diopiSigmoidFocalLossBackwardMmcv) != nullptr) {
93-
auto ret = diopiSigmoidFocalLossBackwardMmcv(
94-
ch, grad_input_p, input_p, target_p, weight_p, gamma, alpha);
95-
if (ret == diopiSuccess) return;
104+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
105+
pybind11::gil_scoped_release no_gil;
106+
auto ret = diopiSigmoidFocalLossBackwardMmcv(
107+
ch, grad_input_p, input_p, target_p, weight_p, gamma, alpha);
108+
if (ret == diopiSuccess) return;
109+
} else {
110+
auto ret = diopiSigmoidFocalLossBackwardMmcv(
111+
ch, grad_input_p, input_p, target_p, weight_p, gamma, alpha);
112+
if (ret == diopiSuccess) return;
113+
}
96114
}
97115
LOG(WARNING)
98116
<< "Fallback to cpu: mmcv ext op sigmoid_focal_loss_forward_impl";

mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
#include <diopi/diopirt.h>
66
#include <diopi/functions.h>
77
#include <diopi/functions_mmcv.h>
8+
#include <torch/csrc/utils/pybind.h>
89

910
#include "csrc_dipu/diopirt/diopirt_impl.h"
11+
#include "csrc_dipu/runtime/device/deviceapis.h"
12+
#include "csrc_dipu/utils/helpfunc.hpp"
1013

14+
using dipu::VENDOR_TYPE;
1115
using dipu::diopi_helper::toDiopiScalar;
1216
using dipu::diopi_helper::toDiopiTensorHandle;
1317
#endif
@@ -273,11 +277,20 @@ void modulated_deform_conv_forward_diopi(
273277
auto output_p = toDiopiTensorHandle(output);
274278
auto columns_p = toDiopiTensorHandle(columns);
275279
if (reinterpret_cast<void*>(diopiModulatedDeformConvMmcv) != nullptr) {
276-
auto ret = diopiModulatedDeformConvMmcv(
277-
ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p,
278-
mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
279-
dilation_h, dilation_w, group, deformable_group, with_bias);
280-
if (ret == diopiSuccess) return;
280+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
281+
pybind11::gil_scoped_release no_gil;
282+
auto ret = diopiModulatedDeformConvMmcv(
283+
ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p,
284+
mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
285+
dilation_h, dilation_w, group, deformable_group, with_bias);
286+
if (ret == diopiSuccess) return;
287+
} else {
288+
auto ret = diopiModulatedDeformConvMmcv(
289+
ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p,
290+
mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
291+
dilation_h, dilation_w, group, deformable_group, with_bias);
292+
if (ret == diopiSuccess) return;
293+
}
281294
}
282295
LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
283296
auto input_cpu = input.cpu();
@@ -331,12 +344,24 @@ void modulated_deform_conv_backward_diopi(
331344

332345
if (reinterpret_cast<void*>(diopiModulatedDeformConvBackwardMmcv) !=
333346
nullptr) {
334-
auto ret = diopiModulatedDeformConvBackwardMmcv(
335-
ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
336-
grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
337-
columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w, pad_h,
338-
pad_w, dilation_h, dilation_w, group, deformable_group, with_bias);
339-
if (ret == diopiSuccess) return;
347+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
348+
pybind11::gil_scoped_release no_gil;
349+
auto ret = diopiModulatedDeformConvBackwardMmcv(
350+
ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
351+
grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
352+
columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w,
353+
pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
354+
with_bias);
355+
if (ret == diopiSuccess) return;
356+
} else {
357+
auto ret = diopiModulatedDeformConvBackwardMmcv(
358+
ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
359+
grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
360+
columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w,
361+
pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
362+
with_bias);
363+
if (ret == diopiSuccess) return;
364+
}
340365
}
341366
LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
342367
auto input_cpu = input.cpu();

mmcv/ops/csrc/pytorch/nms.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
#include <diopi/diopirt.h>
66
#include <diopi/functions.h>
77
#include <diopi/functions_mmcv.h>
8+
#include <torch/csrc/utils/pybind.h>
89

910
#include "csrc_dipu/base/basedef.h"
1011
#include "csrc_dipu/diopirt/diopirt_impl.h"
12+
#include "csrc_dipu/runtime/device/deviceapis.h"
13+
#include "csrc_dipu/utils/helpfunc.hpp"
1114

15+
using dipu::VENDOR_TYPE;
1216
using dipu::diopi_helper::toDiopiScalar;
1317
using dipu::diopi_helper::toDiopiTensorHandle;
1418
#endif
@@ -45,11 +49,21 @@ Tensor nms_diopi(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
4549
auto scores_p = toDiopiTensorHandle(scores);
4650
bool is_mock_cuda = boxes.device().type() == dipu::DIPU_DEVICE_TYPE;
4751
if (is_mock_cuda && reinterpret_cast<void*>(diopiNmsMmcv) != nullptr) {
48-
auto ret =
49-
diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
50-
if (ret == diopiSuccess) {
51-
auto tensorhandle = reinterpret_cast<Tensor*>(*outhandle);
52-
return *tensorhandle;
52+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
53+
pybind11::gil_scoped_release no_gil;
54+
auto ret =
55+
diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
56+
if (ret == diopiSuccess) {
57+
auto tensorhandle = reinterpret_cast<Tensor*>(*outhandle);
58+
return *tensorhandle;
59+
}
60+
} else {
61+
auto ret =
62+
diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
63+
if (ret == diopiSuccess) {
64+
auto tensorhandle = reinterpret_cast<Tensor*>(*outhandle);
65+
return *tensorhandle;
66+
}
5367
}
5468
}
5569
LOG(WARNING) << "Fallback to cpu: mmcv ext op nms";

mmcv/ops/csrc/pytorch/roi_align.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
#include <diopi/diopirt.h>
66
#include <diopi/functions.h>
77
#include <diopi/functions_mmcv.h>
8+
#include <torch/csrc/utils/pybind.h>
89

910
#include "csrc_dipu/base/basedef.h"
1011
#include "csrc_dipu/diopirt/diopirt_impl.h"
12+
#include "csrc_dipu/runtime/device/deviceapis.h"
13+
#include "csrc_dipu/utils/helpfunc.hpp"
1114

15+
using dipu::VENDOR_TYPE;
1216
using dipu::diopi_helper::toDiopiScalar;
1317
using dipu::diopi_helper::toDiopiTensorHandle;
1418
#endif
@@ -56,10 +60,18 @@ void roi_align_forward_diopi(Tensor input, Tensor rois, Tensor output,
5660
auto argmax_x_p = toDiopiTensorHandle(argmax_x);
5761
bool is_mock_cuda = input.device().type() == dipu::DIPU_DEVICE_TYPE;
5862
if (is_mock_cuda && reinterpret_cast<void *>(diopiRoiAlignMmcv) != nullptr) {
59-
auto ret = diopiRoiAlignMmcv(
60-
ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
61-
aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
62-
if (ret == diopiSuccess) return;
63+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
64+
pybind11::gil_scoped_release no_gil;
65+
auto ret = diopiRoiAlignMmcv(
66+
ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
67+
aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
68+
if (ret == diopiSuccess) return;
69+
} else {
70+
auto ret = diopiRoiAlignMmcv(
71+
ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
72+
aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
73+
if (ret == diopiSuccess) return;
74+
}
6375
}
6476
LOG(WARNING) << "Fallback to cpu: mmcv ext op roi_align_forward";
6577
auto input_cpu = input.cpu();
@@ -96,11 +108,20 @@ void roi_align_backward_diopi(Tensor grad_output, Tensor rois, Tensor argmax_y,
96108
bool is_mock_cuda = grad_output.device().type() == dipu::DIPU_DEVICE_TYPE;
97109
if (is_mock_cuda &&
98110
reinterpret_cast<void *>(diopiRoiAlignBackwardMmcv) != nullptr) {
99-
auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
100-
argmax_y_, argmax_x_, aligned_height,
101-
aligned_width, sampling_ratio,
102-
pool_mode, spatial_scale, aligned);
103-
if (ret == diopiSuccess) return;
111+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
112+
pybind11::gil_scoped_release no_gil;
113+
auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
114+
argmax_y_, argmax_x_, aligned_height,
115+
aligned_width, sampling_ratio,
116+
pool_mode, spatial_scale, aligned);
117+
if (ret == diopiSuccess) return;
118+
} else {
119+
auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
120+
argmax_y_, argmax_x_, aligned_height,
121+
aligned_width, sampling_ratio,
122+
pool_mode, spatial_scale, aligned);
123+
if (ret == diopiSuccess) return;
124+
}
104125
}
105126
LOG(WARNING) << "Fallback to cpu: mmcv ext op roi_align_backward";
106127
auto grad_output_cpu = grad_output.cpu();

mmcv/ops/csrc/pytorch/voxelization.cpp

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,13 @@
55
#include <diopi/diopirt.h>
66
#include <diopi/functions.h>
77
#include <diopi/functions_mmcv.h>
8+
#include <torch/csrc/utils/pybind.h>
89

910
#include "csrc_dipu/diopirt/diopirt_impl.h"
11+
#include "csrc_dipu/runtime/device/deviceapis.h"
12+
#include "csrc_dipu/utils/helpfunc.hpp"
1013

14+
using dipu::VENDOR_TYPE;
1115
using dipu::diopi_helper::toDiopiScalar;
1216
using dipu::diopi_helper::toDiopiTensorHandle;
1317
#endif
@@ -84,11 +88,20 @@ void hard_voxelize_forward_diopi(const at::Tensor &points,
8488
auto num_points_per_voxel_p = toDiopiTensorHandle(num_points_per_voxel);
8589
auto voxel_num_p = toDiopiTensorHandle(voxel_num);
8690
if (reinterpret_cast<void *>(diopiHardVoxelizeMmcv) != nullptr) {
87-
auto ret = diopiHardVoxelizeMmcv(
88-
ch, voxels_p, coors_p, num_points_per_voxel_p, voxel_num_p, points_p,
89-
voxel_size_p, coors_range_p, max_points, max_voxels, NDim,
90-
deterministic);
91-
if (ret == diopiSuccess) return;
91+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
92+
pybind11::gil_scoped_release no_gil;
93+
auto ret = diopiHardVoxelizeMmcv(
94+
ch, voxels_p, coors_p, num_points_per_voxel_p, voxel_num_p, points_p,
95+
voxel_size_p, coors_range_p, max_points, max_voxels, NDim,
96+
deterministic);
97+
if (ret == diopiSuccess) return;
98+
} else {
99+
auto ret = diopiHardVoxelizeMmcv(
100+
ch, voxels_p, coors_p, num_points_per_voxel_p, voxel_num_p, points_p,
101+
voxel_size_p, coors_range_p, max_points, max_voxels, NDim,
102+
deterministic);
103+
if (ret == diopiSuccess) return;
104+
}
92105
}
93106
LOG(WARNING) << "Fallback to cpu: mmcv ext op hard_voxelize_forward";
94107
auto points_cpu = points.cpu();
@@ -146,9 +159,16 @@ void dynamic_voxelize_forward_diopi(const at::Tensor &points,
146159
auto coors_range_p = toDiopiTensorHandle(coors_range);
147160
auto coors_p = toDiopiTensorHandle(coors);
148161
if (reinterpret_cast<void *>(diopiDynamicVoxelizeMmcv) != nullptr) {
149-
auto ret = diopiDynamicVoxelizeMmcv(ch, coors_p, points_p, voxel_size_p,
150-
coors_range_p, NDim);
151-
if (ret == diopiSuccess) return;
162+
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
163+
pybind11::gil_scoped_release no_gil;
164+
auto ret = diopiDynamicVoxelizeMmcv(ch, coors_p, points_p, voxel_size_p,
165+
coors_range_p, NDim);
166+
if (ret == diopiSuccess) return;
167+
} else {
168+
auto ret = diopiDynamicVoxelizeMmcv(ch, coors_p, points_p, voxel_size_p,
169+
coors_range_p, NDim);
170+
if (ret == diopiSuccess) return;
171+
}
152172
}
153173
LOG(WARNING) << "Fallback to cpu: mmcv ext op dynamic_voxelize_forward";
154174
auto points_cpu = points.cpu();

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,12 @@ def get_extensions():
244244
dipu_path = os.getenv('DIPU_PATH')
245245
vendor_include_dirs = os.getenv('VENDOR_INCLUDE_DIRS')
246246
nccl_include_dirs = os.getenv('NCCL_INCLUDE_DIRS')
247+
pytorch_dir = os.getenv('PYTORCH_DIR')
247248
include_dirs.append(dipu_root)
248249
include_dirs.append(diopi_path + '/include')
249250
include_dirs.append(dipu_path + '/dist/include')
250251
include_dirs.append(vendor_include_dirs)
252+
include_dirs.append(pytorch_dir + 'torch/include')
251253
if nccl_include_dirs:
252254
include_dirs.append(nccl_include_dirs)
253255
library_dirs += [dipu_root]

0 commit comments

Comments (0)