|
5 | 5 | #include <diopi/diopirt.h> |
6 | 6 | #include <diopi/functions.h> |
7 | 7 | #include <diopi/functions_mmcv.h> |
| 8 | +#include <torch/csrc/utils/pybind.h> |
8 | 9 |
|
9 | 10 | #include "csrc_dipu/diopirt/diopirt_impl.h" |
| 11 | +#include "csrc_dipu/runtime/device/deviceapis.h" |
| 12 | +#include "csrc_dipu/utils/helpfunc.hpp" |
10 | 13 |
|
| 14 | +using dipu::VENDOR_TYPE; |
11 | 15 | using dipu::diopi_helper::toDiopiScalar; |
12 | 16 | using dipu::diopi_helper::toDiopiTensorHandle; |
13 | 17 | #endif |
@@ -273,11 +277,20 @@ void modulated_deform_conv_forward_diopi( |
273 | 277 | auto output_p = toDiopiTensorHandle(output); |
274 | 278 | auto columns_p = toDiopiTensorHandle(columns); |
275 | 279 | if (reinterpret_cast<void*>(diopiModulatedDeformConvMmcv) != nullptr) { |
276 | | - auto ret = diopiModulatedDeformConvMmcv( |
277 | | - ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p, |
278 | | - mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, |
279 | | - dilation_h, dilation_w, group, deformable_group, with_bias); |
280 | | - if (ret == diopiSuccess) return; |
| 280 | + if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) { |
| 281 | + pybind11::gil_scoped_release no_gil; |
| 282 | + auto ret = diopiModulatedDeformConvMmcv( |
| 283 | + ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p, |
| 284 | + mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, |
| 285 | + dilation_h, dilation_w, group, deformable_group, with_bias); |
| 286 | + if (ret == diopiSuccess) return; |
| 287 | + } else { |
| 288 | + auto ret = diopiModulatedDeformConvMmcv( |
| 289 | + ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p, |
| 290 | + mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w, |
| 291 | + dilation_h, dilation_w, group, deformable_group, with_bias); |
| 292 | + if (ret == diopiSuccess) return; |
| 293 | + } |
281 | 294 | } |
282 | 295 | LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward"; |
283 | 296 | auto input_cpu = input.cpu(); |
@@ -331,12 +344,24 @@ void modulated_deform_conv_backward_diopi( |
331 | 344 |
|
332 | 345 | if (reinterpret_cast<void*>(diopiModulatedDeformConvBackwardMmcv) != |
333 | 346 | nullptr) { |
334 | | - auto ret = diopiModulatedDeformConvBackwardMmcv( |
335 | | - ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p, |
336 | | - grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p, |
337 | | - columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, |
338 | | - pad_w, dilation_h, dilation_w, group, deformable_group, with_bias); |
339 | | - if (ret == diopiSuccess) return; |
| 347 | + if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) { |
| 348 | + pybind11::gil_scoped_release no_gil; |
| 349 | + auto ret = diopiModulatedDeformConvBackwardMmcv( |
| 350 | + ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p, |
| 351 | + grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p, |
| 352 | + columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w, |
| 353 | + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, |
| 354 | + with_bias); |
| 355 | + if (ret == diopiSuccess) return; |
| 356 | + } else { |
| 357 | + auto ret = diopiModulatedDeformConvBackwardMmcv( |
| 358 | + ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p, |
| 359 | + grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p, |
| 360 | + columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w, |
| 361 | + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, |
| 362 | + with_bias); |
| 363 | + if (ret == diopiSuccess) return; |
| 364 | + } |
340 | 365 | } |
341 | 366 | LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward"; |
342 | 367 | auto input_cpu = input.cpu(); |
|
0 commit comments