
Commit c986df1

Author: Allard Hendriksen
Replace MSDModule by MSDModule2d and remove MSDModule2d
In previous versions, MSDModule was no longer used by MSDRegressionModel and MSDSegmentationModel. Its implementation has now been replaced by the MSDModule2d implementation, which allows the stitch functions and modules to be removed.
1 parent 7f835d1 commit c986df1
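
After this commit, MSDModule is the single public module. A minimal usage sketch (the constructor signature is taken from the diff below; the parameter values are illustrative, and the .cuda() call reflects that the custom convolution kernels are CUDA-only, as in the removed self.net.cuda() call and the existing tests):

import torch
from msd_pytorch.msd_module import MSDModule

# c_in/c_out/depth/width values are illustrative.
module = MSDModule(c_in=1, c_out=1, depth=30, width=1, dilations=[1, 2, 4, 8]).cuda()

x = torch.randn(2, 1, 64, 64).cuda()  # (batch, channels, height, width)
y = module(x)                         # MSD preserves spatial size: (2, 1, 64, 64)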

File tree: 5 files changed (+42, −229 lines)


msd_pytorch/msd_block.py

Lines changed: 1 addition & 53 deletions
@@ -1,10 +1,7 @@
 import torch
 import msd_custom_convolutions as cc
-from msd_pytorch.msd_module import MSDFinalLayer, init_convolution_weights
 import numpy as np
 
-IDX_WEIGHT_START = 3
-
 
 class MSDBlockImpl2d(torch.autograd.Function):
     @staticmethod
@@ -90,6 +87,7 @@ def backward(ctx, grad_output):
             )
 
             # Gradient w.r.t weights
+            IDX_WEIGHT_START = 3  # The first weight has index 3 in the forward pass.
             if ctx.needs_input_grad[i + IDX_WEIGHT_START]:
                 sub_grad_weight = torch.zeros_like(sub_weight)
                 cc.conv_relu_backward_k(
@@ -181,53 +179,3 @@ def forward(self, input):
         weights = (self.__getattr__("weight{}".format(i)) for i in range(len(self.weights)))
 
         return MSDBlockImpl2d.apply(input, self.dilations, bias, *weights)
-
-
-class MSDModule2d(torch.nn.Module):
-    def __init__(
-        self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
-    ):
-        """Create a 2-dimensional MSD Module
-
-        :param c_in: # of input channels
-        :param c_out: # of output channels
-        :param depth: # of layers
-        :param width: # the width of the module
-        :param dilations: `list(int)`
-
-        A list of dilations to use. Default is ``[1, 2, ..., 10]``. A
-        good alternative is ``[1, 2, 4, 8]``. The dilations are
-        repeated.
-
-        :returns: an MSD module
-        :rtype: MSDModule2d
-
-        """
-        super(MSDModule2d, self).__init__()
-
-        self.c_in = c_in
-        self.c_out = c_out
-        self.depth = depth
-        self.width = width
-        self.dilations = [dilations[i % len(dilations)] for i in range(depth)]
-
-        self.msd_block = MSDBlock2d(self.c_in, self.dilations, self.width)
-        self.final_layer = MSDFinalLayer(c_in=c_in + width * depth, c_out=c_out)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        # Initialize weights for hidden layers:
-        for w in self.msd_block.weights:
-            init_convolution_weights(
-                w.data, self.c_in, self.c_out, self.width, self.depth
-            )
-
-        self.msd_block.bias.data.zero_()
-        self.final_layer.reset_parameters()
-
-    def forward(self, input):
-        output = self.msd_block(input)
-        output = self.final_layer(output)
-        return output
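
The IDX_WEIGHT_START constant moved into backward() reflects the argument layout of MSDBlockImpl2d.apply(input, dilations, bias, *weights): ctx.needs_input_grad is indexed by forward-argument position, so input is 0, dilations is 1, bias is 2, and the per-layer weights start at index 3. A minimal sketch of the same indexing pattern with a generic autograd Function (ScaledSum and its math are illustrative, not from the repository):

import torch

class ScaledSum(torch.autograd.Function):
    # Same argument layout: tensor, non-tensor config, bias, then weights.
    IDX_WEIGHT_START = 3

    @staticmethod
    def forward(ctx, input, config, bias, *weights):
        ctx.save_for_backward(input, bias, *weights)
        out = input * 0 + bias
        for w in weights:
            out = out + w * input
        return out

    @staticmethod
    def backward(ctx, grad_output):
        input, bias, *weights = ctx.saved_tensors
        grad_input = grad_bias = None
        grad_weights = [None] * len(weights)
        if ctx.needs_input_grad[0]:  # input sits at position 0
            grad_input = grad_output * sum(weights)
        if ctx.needs_input_grad[2]:  # bias sits at position 2
            grad_bias = grad_output.sum()
        for i in range(len(weights)):
            # needs_input_grad is indexed by forward() argument position,
            # so weight i sits at i + IDX_WEIGHT_START.
            if ctx.needs_input_grad[i + ScaledSum.IDX_WEIGHT_START]:
                grad_weights[i] = (grad_output * input).sum()
        # Return one gradient per forward() argument; None for the non-tensor.
        return (grad_input, None, grad_bias, *grad_weights)

# Usage: scalar bias/weights, one gradient entry per forward argument.
x = torch.randn(5, requires_grad=True)
bias = torch.tensor(0.5, requires_grad=True)
w = torch.tensor(2.0, requires_grad=True)
ScaledSum.apply(x, "config-is-ignored", bias, w).sum().backward()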

msd_pytorch/msd_model.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from msd_pytorch.msd_block import MSDModule2d
+from msd_pytorch.msd_module import MSDModule
 from torch.autograd import Variable
 import numpy as np
 import torch as t
@@ -84,7 +84,7 @@ def __init__(
         # network is saved.
         self.scale_in = scaling_module(c_in, c_in)
         self.scale_out = scaling_module(c_out, c_out)
-        self.msd = MSDModule2d(c_in, c_out, depth, width, dilations)
+        self.msd = MSDModule(c_in, c_out, depth, width, dilations)
 
         # It is the task of any subclass to initialize `self.net` and
         # call `init_optimizer` to set the trainable parameters.
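
With this change, the user-facing models named in the commit message pick up the new implementation automatically through the MSDModel base class above. A hedged usage sketch (the MSDRegressionModel constructor parameters are assumed to mirror the MSDModel signature shown in this diff and may differ between versions):

from msd_pytorch import MSDRegressionModel

# Assumed parameters, mirroring MSDModule(c_in, c_out, depth, width, dilations);
# the model builds self.msd = MSDModule(...) internally, as in the diff above.
model = MSDRegressionModel(c_in=1, c_out=1, depth=30, width=1, dilations=[1, 2, 4, 8])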

msd_pytorch/msd_module.py

Lines changed: 23 additions & 91 deletions
@@ -1,9 +1,7 @@
-import torch.nn as nn
-from msd_pytorch.conv import Conv2dInPlaceModule
-from msd_pytorch.conv_relu import ConvRelu2dInPlaceModule
-from msd_pytorch.stitch import stitchLazy, StitchCopyModule, StitchBuffer
+import torch
 from math import sqrt
 import numpy as np
+from msd_pytorch.msd_block import MSDBlock2d
 
 
 def units_in_front(c_in, width, layer_depth):
@@ -59,57 +57,7 @@ def init_convolution_weights(conv_weight, c_in, c_out, width, depth):
     conv_weight.normal_(0, std_dev)
 
 
-class MSDLayerModule(nn.Module):
-    """A hidden layer of the MSD module.
-
-    The primary responsibility of this module is to define the
-    `forward()` method.
-
-    This module is used by the `MSDModule`.
-
-    This module is not responsible for
-
-    * Buffer management
-    * Weight initialization
-    """
-
-    def __init__(self, buffer, c_in, layer_depth, width, dilation):
-        """Initialize the hidden layer.
-
-        :param buffer: a StitchBuffer object for storing the L and G buffers.
-        :param c_in: The number of input channels of the MSD module.
-        :param layer_depth:
-            The depth of this layer in the MSD module. This index is
-            zero-based: the first hidden layer has index zero.
-        :param width: The width of the MSD module.
-        :param dilation:
-            An integer describing the dilation factor for the
-            convolutions in this layer.
-        :returns: A module for the MSD hidden layer.
-        :rtype: MSDLayerModule
-
-        """
-        super(MSDLayerModule, self).__init__()
-
-        in_front = units_in_front(c_in, width, layer_depth)
-        self.buffer, self.in_front, self.width = buffer, in_front, width
-
-        # Set output to None for the Conv2dInPlaceModule for now. We
-        # set it in the forward pass.
-        output = None
-        self.convolution = ConvRelu2dInPlaceModule(
-            output, in_front, width, kernel_size=3, dilation=dilation
-        )
-
-    def forward(self, input):
-        # Set output
-        self.convolution.output = self.buffer.L.narrow(1, self.in_front, self.width)
-        output = self.convolution(input)
-        output = stitchLazy(output, self.buffer.L, self.buffer.G, self.in_front)
-        return output
-
-
-class MSDFinalLayer(nn.Module):
+class MSDFinalLayer(torch.nn.Module):
     """Documentation for MSDFinalLayer
 
     Implements the final 1x1 multiplication and bias addition for all
@@ -122,7 +70,7 @@ def __init__(self, c_in, c_out):
         super(MSDFinalLayer, self).__init__()
         self.c_in = c_in
         self.c_out = c_out
-        self.linear = nn.Conv1d(c_in, c_out, 1)
+        self.linear = torch.nn.Conv1d(c_in, c_out, 1)
         self.reset_parameters()
 
     def forward(self, input):
@@ -140,7 +88,7 @@ def reset_parameters(self):
         self.linear.bias.data.zero_()
 
 
-class MSDModule(nn.Module):
+class MSDModule(torch.nn.Module):
     def __init__(
         self, c_in, c_out, depth, width, dilations=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
     ):
@@ -154,57 +102,41 @@ def __init__(
 
         A list of dilations to use. Default is ``[1, 2, ..., 10]``. A
        good alternative is ``[1, 2, 4, 8]``. The dilations are
-        repeated.
+        repeated when the depth of the module exceeds the length of
+        the list.
 
         :returns: an MSD module
         :rtype: MSDModule
 
         """
         super(MSDModule, self).__init__()
-        #
         self.c_in = c_in
         self.c_out = c_out
         self.depth = depth
         self.width = width
-        self.dilations = dilations
+        self.dilations = [dilations[i % len(dilations)] for i in range(depth)]
 
-        buffer = StitchBuffer()
-        self.buffer = buffer
+        if depth < 1:
+            raise ValueError(f"Depth must be at least 1. Got: {depth}.")
+        if width < 1:
+            raise ValueError(f"Width must be at least 1. Got: {width}.")
 
-        # The first layer copies input into the L stitch buffer
-        stitch_layer = StitchCopyModule(buffer, 0)
+        self.msd_block = MSDBlock2d(self.c_in, self.dilations, self.width)
+        self.final_layer = MSDFinalLayer(c_in=c_in + width * depth, c_out=c_out)
 
-        # Then we have `depth` number of hidden layers:
-        self.hidden_layers = [
-            MSDLayerModule(buffer, c_in, d, width, dilations[d % len(dilations)])
-            for d in range(depth)
-        ]
+        self.reset_parameters()
 
+    def reset_parameters(self):
         # Initialize weights for hidden layers:
-        for m in self.hidden_layers:
+        for w in self.msd_block.weights:
             init_convolution_weights(
-                m.convolution.weight.data, c_in, c_out, width, depth
+                w.data, self.c_in, self.c_out, self.width, self.depth
             )
-            m.convolution.bias.data.zero_()
-
-        in_front = units_in_front(c_in, width, depth)
-        self.c_final = MSDFinalLayer(in_front, c_out)
 
-        self.net = nn.Sequential(stitch_layer, *self.hidden_layers, self.c_final)
-
-        self.net.cuda()
+        self.msd_block.bias.data.zero_()
+        self.final_layer.reset_parameters()
 
     def forward(self, input):
-        self.init_buffers(input.data)
-        return self.net(input)
-
-    def init_buffers(self, input):
-        batch_sz, c_in, *shape = input.shape
-
-        assert c_in == self.c_in, "Unexpected number of input channels"
-
-        # Ensure that the stitch buffer is the correct shape
-        total_units = units_in_front(self.c_in, self.width, self.depth)
-        new_shape = (batch_sz, total_units, *shape)
-
-        self.buffer.like_(input, new_shape)
+        output = self.msd_block(input)
+        output = self.final_layer(output)
+        return output
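
The rewritten constructor pre-expands the dilation list to the depth of the module, so each layer's dilation is fixed at construction time. A quick illustration of the repetition expression from the diff:

dilations = [1, 2, 4, 8]
depth = 10

# Layer i receives dilations[i % len(dilations)]:
expanded = [dilations[i % len(dilations)] for i in range(depth)]
assert expanded == [1, 2, 4, 8, 1, 2, 4, 8, 1, 2]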

msd_pytorch/tests/test_msd_block.py

Lines changed: 0 additions & 62 deletions
@@ -52,65 +52,3 @@ def assert_grads_equal(module, module2d):
 
     assert torch_equal(module2d.final_layer.linear.weight.grad, module.c_final.linear.weight.grad)
     assert torch_equal(module2d.final_layer.linear.bias.grad, module.c_final.linear.bias.grad)
-
-
-def test_compare_msd_module():
-    dtype = torch.float  #
-    device = torch.device("cuda")
-    B = 2       # Batch size
-    C_IN = 3    # Input channels
-    C_OUT = 2   # Output channels
-    H = 13      # Height
-    W = 21      # Width
-    dilations = [1, 5, 3]  # Dilation
-    depth = 10
-    width = 2
-
-    # Input
-    with_grad = dict(requires_grad=True, device=device, dtype=dtype)
-    no_grad = dict(requires_grad=False, device=device, dtype=dtype)
-    x1 = torch.randn(B, C_IN, H, W, **with_grad)
-    x2 = x1.clone()
-    tgt = torch.randn(B, C_OUT, H, W, **no_grad)
-
-    # Models
-    m1 = msd_module.MSDModule(C_IN, C_OUT, depth, width, dilations).to(device)
-    m2 = msd_block.MSDModule2d(C_IN, C_OUT, depth, width, dilations).to(device)
-
-    # Output
-    init_weights_for_testing(m1)
-    copy_weights(m1, m2)
-
-    o1 = m1(x1)
-    o2 = m2(x2)
-    l1 = torch.nn.MSELoss()(o1, tgt)
-    l2 = torch.nn.MSELoss()(o2, tgt)
-    l1.backward(torch.ones_like(o1))
-    l2.backward(torch.ones_like(o2))
-
-    assert torch_equal(o1, o2)
-    assert_grads_equal(m1, m2)
-
-
-def test_grad_check():
-    torch.manual_seed(1)
-
-    dtype = torch.double
-    size = (11, 13)
-    batch_sz = 2
-
-    for depth in [9]:
-        print(f"Depth: {depth}")
-        width = c_in = c_out = batch_sz
-        x = torch.randn(batch_sz, c_in, *size, dtype=dtype).cuda()
-        x.requires_grad = True
-
-        net = msd_block.MSDModule2d(c_in, c_out, depth, width)
-        net.cuda()
-        net.double()
-
-        for p in net.parameters():
-            p.data = torch.randn_like(p.data)
-
-        assert net is not None
-        gradcheck(net, [x], raise_exception=True, atol=1e-4, rtol=1e-3)
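
The comparison test is removed because MSDModule2d no longer exists as a separate class, and the gradcheck test targeted the removed class. An equivalent gradient check against the unified MSDModule would look like the sketch below (untested here; like the removed test, it assumes a CUDA-capable build of the custom convolutions):

import torch
from torch.autograd import gradcheck
from msd_pytorch.msd_module import MSDModule

torch.manual_seed(1)
batch_sz = c_in = c_out = width = 2
depth = 9

# gradcheck requires double precision inputs and parameters.
x = torch.randn(batch_sz, c_in, 11, 13, dtype=torch.double).cuda()
x.requires_grad = True

net = MSDModule(c_in, c_out, depth, width).cuda().double()
for p in net.parameters():
    p.data = torch.randn_like(p.data)

gradcheck(net, [x], raise_exception=True, atol=1e-4, rtol=1e-3)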
