@@ -27,7 +27,7 @@ ms_deform_attn_forward(
     const at::Tensor &attn_weight,
     const int im2col_step)
 {
-    if (value.type().is_cuda())
+    if (value.device().is_cuda())
     {
 #ifdef WITH_CUDA
         return ms_deform_attn_cuda_forward(
@@ -49,7 +49,7 @@ ms_deform_attn_backward(
     const at::Tensor &grad_output,
     const int im2col_step)
 {
-    if (value.type().is_cuda())
+    if (value.device().is_cuda())
     {
 #ifdef WITH_CUDA
         return ms_deform_attn_cuda_backward(
@@ -61,4 +61,4 @@ ms_deform_attn_backward(
     AT_ERROR("Not implemented on the CPU");
 }
 
-} // namespace groundingdino
+} // namespace groundingdino
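Note (not part of the diff): Tensor::type() is deprecated in the C++ frontend, which is why the dispatch check above moves to value.device().is_cuda(). A minimal sketch of the same device-based dispatch, with a function name invented here purely for illustration:

    #include <torch/extension.h>

    // Hypothetical example, not from this repository: dispatch on the tensor's
    // device the way the patched ms_deform_attn_forward does.
    at::Tensor add_one(const at::Tensor &x)
    {
        if (x.device().is_cuda())   // x.is_cuda() is an equivalent shorthand
        {
            return x + 1;           // stand-in for the CUDA-specific path
        }
        return x.cpu() + 1;         // stand-in for the CPU fallback
    }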
(Second file in the diff: the CUDA implementation containing ms_deform_attn_cuda_forward and ms_deform_attn_cuda_backward.)
@@ -15,6 +15,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <c10/cuda/CUDAGuard.h>
 
 namespace groundingdino {

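Note (not part of the diff): the newly added <c10/cuda/CUDAGuard.h> provides c10::cuda::CUDAGuard, an RAII guard that switches the current CUDA device for the enclosing scope. This hunk only adds the include; a hedged sketch of how such a guard is typically placed at the top of a CUDA op (function name invented here):

    #include <ATen/ATen.h>
    #include <c10/cuda/CUDAGuard.h>

    at::Tensor example_cuda_op(const at::Tensor &value)
    {
        // Make later kernel launches and allocations target the device that owns
        // `value`; the previous device is restored when the guard goes out of scope.
        c10::cuda::CUDAGuard device_guard(value.device());
        return at::zeros_like(value);
    }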
@@ -26,17 +27,17 @@ at::Tensor ms_deform_attn_cuda_forward(
     const at::Tensor &attn_weight,
     const int im2col_step)
 {

[Review comment on this hunk] In line 29, the "{" must be erased.

-    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
+    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+    TORCH_CHECK(value.device().type() == at::kCUDA, "value must be a CUDA tensor");
+    TORCH_CHECK(spatial_shapes.device().type() == at::kCUDA, "spatial_shapes must be a CUDA tensor");
+    TORCH_CHECK(level_start_index.device().type() == at::kCUDA, "level_start_index must be a CUDA tensor");
+    TORCH_CHECK(sampling_loc.device().type() == at::kCUDA, "sampling_loc must be a CUDA tensor");
+    TORCH_CHECK(attn_weight.device().type() == at::kCUDA, "attn_weight must be a CUDA tensor");
 
     const int batch = value.size(0);
     const int spatial_size = value.size(1);
@@ -50,7 +51,7 @@ at::Tensor ms_deform_attn_cuda_forward(
 
     const int im2col_step_ = std::min(batch, im2col_step);
 
-    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+    TORCH_CHECK(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
 
     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
 
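A side note on this hunk (not part of the diff): TORCH_CHECK does not do printf-style substitution, so the "%d" placeholders are emitted literally in the error message. TORCH_CHECK concatenates any extra arguments instead, so the values could be streamed in directly:

    TORCH_CHECK(batch % im2col_step_ == 0,
                "batch(", batch, ") must divide im2col_step(", im2col_step_, ")");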
@@ -59,19 +60,19 @@ at::Tensor ms_deform_attn_cuda_forward(
     auto per_value_size = spatial_size * num_heads * channels;
     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;

     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto columns = output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.scalar_type(), "ms_deform_attn_forward_cuda", ([&] {
             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
-                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data<int64_t>(),
-                level_start_index.data<int64_t>(),
-                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data_ptr<int64_t>(),
+                level_start_index.data_ptr<int64_t>(),
+                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                columns.data<scalar_t>());
-
+                columns.data_ptr<scalar_t>());
         }));
     }

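Note (not part of the diff): AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda once per supported dtype (float, double, at::Half) and binds scalar_t inside it, so the kernel above is now also compiled for fp16 inputs. A minimal self-contained sketch of the same dispatch pattern on a contiguous CPU tensor, with a helper name invented here:

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>

    // Hypothetical helper: fill a contiguous CPU tensor with ones; scalar_t is
    // supplied by the dispatch macro based on t.scalar_type().
    void fill_with_one(at::Tensor &t)
    {
        AT_DISPATCH_FLOATING_TYPES_AND_HALF(t.scalar_type(), "fill_with_one", ([&] {
            scalar_t *p = t.data_ptr<scalar_t>();
            for (int64_t i = 0; i < t.numel(); ++i)
                p[i] = static_cast<scalar_t>(1);
        }));
    }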
@@ -80,7 +81,6 @@ at::Tensor ms_deform_attn_cuda_forward(
     return output;
 }
 
-
 std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     const at::Tensor &value,
     const at::Tensor &spatial_shapes,
@@ -90,20 +90,19 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     const at::Tensor &grad_output,
     const int im2col_step)
 {
-
-    AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
-    AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
-    AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
-    AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
-    AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
-    AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
-
-    AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
-    AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
-    AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
-    AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
-    AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
-    AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+    TORCH_CHECK(value.is_contiguous(), "value tensor has to be contiguous");
+    TORCH_CHECK(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+    TORCH_CHECK(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+    TORCH_CHECK(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+    TORCH_CHECK(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+    TORCH_CHECK(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+    TORCH_CHECK(value.device().type() == at::kCUDA, "value must be a CUDA tensor");
+    TORCH_CHECK(spatial_shapes.device().type() == at::kCUDA, "spatial_shapes must be a CUDA tensor");
+    TORCH_CHECK(level_start_index.device().type() == at::kCUDA, "level_start_index must be a CUDA tensor");
+    TORCH_CHECK(sampling_loc.device().type() == at::kCUDA, "sampling_loc must be a CUDA tensor");
+    TORCH_CHECK(attn_weight.device().type() == at::kCUDA, "attn_weight must be a CUDA tensor");
+    TORCH_CHECK(grad_output.device().type() == at::kCUDA, "grad_output must be a CUDA tensor");
 
     const int batch = value.size(0);
     const int spatial_size = value.size(1);
@@ -117,7 +116,7 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
 
     const int im2col_step_ = std::min(batch, im2col_step);
 
-    AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+    TORCH_CHECK(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
 
     auto grad_value = at::zeros_like(value);
     auto grad_sampling_loc = at::zeros_like(sampling_loc);
@@ -132,19 +131,18 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     for (int n = 0; n < batch/im2col_step_; ++n)
     {
         auto grad_output_g = grad_output_n.select(0, n);
-        AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+        AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.scalar_type(), "ms_deform_attn_backward_cuda", ([&] {
             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
-                grad_output_g.data<scalar_t>(),
-                value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                spatial_shapes.data<int64_t>(),
-                level_start_index.data<int64_t>(),
-                sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
+                grad_output_g.data_ptr<scalar_t>(),
+                value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                spatial_shapes.data_ptr<int64_t>(),
+                level_start_index.data_ptr<int64_t>(),
+                sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
-                grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
-                grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
-                grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
-
+                grad_value.data_ptr<scalar_t>() + n * im2col_step_ * per_value_size,
+                grad_sampling_loc.data_ptr<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
+                grad_attn_weight.data_ptr<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
         }));
     }

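Note (not part of the diff): data<T>() is simply the deprecated spelling of data_ptr<T>(); both return a raw pointer into the tensor's storage, which for CUDA tensors is device memory. That is why the per-chunk offsets above are plain pointer arithmetic handed to the kernel launcher. A small illustration (assumes a CUDA build with a visible GPU):

    #include <ATen/ATen.h>

    void pointer_offset_demo()
    {
        auto value = at::zeros({8, 16}, at::device(at::kCUDA).dtype(at::kFloat));
        float *base = value.data_ptr<float>();   // raw device pointer, not host-dereferenceable
        float *row1 = base + 1 * value.size(1);  // element offset, analogous to n * im2col_step_ * per_value_size
        (void)row1;
    }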
@@ -153,4 +151,4 @@ std::vector<at::Tensor> ms_deform_attn_cuda_backward(
     };
 }
 
-} // namespace groundingdino
+} // namespace groundingdino