diff --git a/main.py b/main.py index 7e1f3f8..8b9f321 100644 --- a/main.py +++ b/main.py @@ -164,7 +164,8 @@ def main(): for i in range(len(dataset)): inputs, target = dataset[i] inputs = Variable(inputs.cuda()) - outputs = model(inputs.unsqueeze(0)) + with torch.no_grad(): + outputs = model(inputs.unsqueeze(0)) _, pred = torch.max(outputs, 1) pred = pred.data.cpu().numpy().squeeze().astype(np.uint8) mask = target.numpy().astype(np.uint8)