pytorch 1.1 fixes #10

Open · wants to merge 2 commits into master

4 changes: 2 additions & 2 deletions pytorch_superpixpool/suppixpool_cuda.cpp
@@ -35,10 +35,10 @@ std::vector<at::Tensor> suppixpool_max_forward(
     const int batch_size = img.size(0);
     const int channels_size = img.size(1);

-    at::Tensor output = at::zeros(torch::CUDA(at::kInt), {batch_size, channels_size, K});
+    at::Tensor output = torch::zeros({batch_size, channels_size, K}, torch::CUDA(at::kInt));
     output = output.type_as(img);
     // torch::set_requires_grad(output, true);
-    at::Tensor outIdx = -at::ones(torch::CUDA(at::kInt), {batch_size, channels_size, K});
+    at::Tensor outIdx = -torch::ones({batch_size, channels_size, K}, torch::CUDA(at::kInt));
     return suppixpool_max_cuda_forward(img, spx_labels, output, outIdx, K);
     // return {output, outIdx};
     // return {img, spx_labels};
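
A note on the fix above: from the PyTorch 1.0/1.1 C++ API onward, tensor factory functions take the size list first and the type/options second (torch::zeros(sizes, options)), while the old ATen ordering at::zeros(type, sizes) no longer compiles. The Python API follows the same sizes-first convention; a minimal sketch of the equivalent allocations, with toy sizes standing in for batch_size, channels_size and K, assuming a CUDA device is available:

    import torch

    # Toy stand-ins for the batch_size, channels_size and K used above.
    batch_size, channels_size, K = 1, 2, 8

    # Sizes first, dtype/device second: mirrors the new C++ calls.
    output = torch.zeros((batch_size, channels_size, K), dtype=torch.int32, device="cuda")
    outIdx = -torch.ones((batch_size, channels_size, K), dtype=torch.int32, device="cuda")
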
5 changes: 4 additions & 1 deletion pytorch_superpixpool/suppixpool_cuda_kernel.cu
@@ -268,7 +268,10 @@ std::vector<at::Tensor> suppixpool_max_cuda_forward(
             output.data<scalar_t>()
         );
     }));
-    return {output, outIdx};
+    return {
+        output,
+        outIdx
+    };
 }

 std::vector<at::Tensor> suppixpool_max_cuda_backward(
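
This hunk is formatting only; the return statement is merely split across lines. For orientation, here is a hypothetical pure-Python reference of what suppixpool_max_cuda_forward appears to compute, judging from the tensor names and shapes: the channel-wise max over the pixels of each superpixel (output) plus the flat index of the winning pixel (outIdx). This sketch is an inference, not code from this repo:

    import torch

    def suppixpool_max_reference(img, spx, K):
        # img: (B, C, H, W) features; spx: (B, H, W) integer labels in [0, K).
        B, C, H, W = img.shape
        output = img.new_full((B, C, K), float("-inf"))
        outIdx = torch.full((B, C, K), -1, dtype=torch.long, device=img.device)
        flat = img.reshape(B, C, H * W)
        labels = spx.reshape(B, H * W)
        for b in range(B):
            for k in range(K):
                members = (labels[b] == k).nonzero().flatten()
                if members.numel() == 0:
                    continue  # empty superpixel: keep the -inf / -1 sentinels
                vals, arg = flat[b][:, members].max(dim=1)
                output[b, :, k] = vals          # per-channel max value
                outIdx[b, :, k] = members[arg]  # flat pixel index of each max
        return output, outIdx
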
3 changes: 2 additions & 1 deletion pytorch_superpixpool/suppixpool_layer.py
@@ -16,6 +16,7 @@ def forward(ctx, img, spx):
        # print("number of -1: ", indices.eq(-1).sum())
        # print indices
        # assert np.all(indices.cpu().numpy()>=0)
+
        ctx.save_for_backward(indices, img, spx, K)
        return outputs

@@ -44,6 +45,6 @@ def __init__(self):
    def forward(self, pooled, spx):
        outShape = pooled.size()[0:2]+spx.size()[-2:]
        out = pooled.new_zeros(outShape)
-        for batch in xrange(pooled.size()[0]):
+        for batch in range(pooled.size()[0]):
            out[batch, :, :, :] = pooled[batch, :, spx[batch,:,:]]
        return out
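
The xrange to range swap is the Python 3 half of this PR, since xrange was removed in Python 3. The loop body relies on advanced indexing: pooled[batch] is a (C, K) tensor, and indexing its last dimension with the (H, W) label map broadcasts to a (C, H, W) gather, so every pixel receives the pooled feature of its superpixel. A toy sketch of that indexing, with hypothetical small shapes:

    import torch

    # Hypothetical shapes: B batches, C channels, K superpixels, HxW image.
    B, C, K, H, W = 1, 2, 4, 2, 2
    pooled = torch.arange(B * C * K, dtype=torch.float32).reshape(B, C, K)
    spx = torch.tensor([[[0, 1], [2, 3]]])  # (B, H, W) labels in [0, K)

    out = pooled.new_zeros(B, C, H, W)
    for b in range(B):
        # pooled[b] is (C, K); the (H, W) label map indexes its last
        # dimension, broadcasting to a (C, H, W) gather.
        out[b] = pooled[b, :, spx[b]]

    print(out.shape)  # torch.Size([1, 2, 2, 2])
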
33 changes: 20 additions & 13 deletions pytorch_superpixpool/test_GPUpool.py
@@ -5,32 +5,39 @@
 import numpy as np
 import time
 from skimage.segmentation import slic
+from torch.autograd import Variable

-if __name__ == "__main__":
-
-    GPU = torch.device("cuda:0")
+if __name__ == "__main__":
+
+    GPU = torch.device('cuda')
     batch_size = 1
-    n_channels = 16
-    xSize = 256
-    ySize = 512
+    n_channels = 2
+    xSize = 4
+    ySize = 4

-    X = torch.randn((batch_size,n_channels,xSize,ySize), dtype=torch.float32, device=GPU)
+    X = torch.randn((batch_size,n_channels,xSize,ySize), dtype=torch.float32, device=GPU,requires_grad=True)
     spx = np.array([np.arange(xSize*ySize).reshape(xSize,ySize)]*batch_size)
     # spx = np.zeros((batch_size, xSize, ySize))
     spx = torch.from_numpy(spx)
     spx = spx.to(GPU)

+    # X + X
     print ("INPUT ARRAY ----------------- \n", X)
     pool = SupPixPool()
     pld = pool(X, spx)
+
+
+
     print ("POOLED ARRAY ----------------- \n", pld)
     print ("Shape of pooled array: ", pld.size())
-    unpool = SupPixUnpool()
-    unpld = unpool(pld, spx)
-    print ("Unpooling back to original: ", np.all(unpld == X))
+    # unpool = SupPixUnpool()
+    # unpld = unpool(pld, spx)
+    # print(unpld.shape, X.shape)
+    #print ("Unpooling back to original: ", np.all(unpld.detach().cpu().numpy() == X.detach().cpu().numpy()))

-    res = torch.autograd.gradcheck(pool, (X, spx), raise_exception=False)
-    resUnpool = torch.autograd.gradcheck(unpool, (pld, spx), raise_exception=False)
+    res = torch.autograd.gradcheck(pool, (X.double(), spx), raise_exception=True)
+    # resUnpool = torch.autograd.gradcheck(unpool, (pld, spx), raise_exception=False)

-    print ("Gradients of pooling are {}.".format("correct" if res else "wrong")) # res should be True if the gradients are correct.
-    print ("Gradients of unpooling are {}.".format("correct" if resUnpool else "wrong"))
+    # print ("Gradients of pooling are {}.".format("correct" if res else "wrong")) # res should be True if the gradients are correct.
+    # print ("Gradients of unpooling are {}.".format("correct" if resUnpool else "wrong"))