1- import torch .nn as nn
2- from msd_pytorch .conv import Conv2dInPlaceModule
3- from msd_pytorch .conv_relu import ConvRelu2dInPlaceModule
4- from msd_pytorch .stitch import stitchLazy , StitchCopyModule , StitchBuffer
1+ import torch
52from math import sqrt
63import numpy as np
4+ from msd_pytorch .msd_block import MSDBlock2d
75
86
97def units_in_front (c_in , width , layer_depth ):
@@ -59,57 +57,7 @@ def init_convolution_weights(conv_weight, c_in, c_out, width, depth):
5957 conv_weight .normal_ (0 , std_dev )
6058
6159
class MSDLayerModule(nn.Module):
    """A single hidden layer of the MSD module.

    Its primary responsibility is to define the `forward()` method:
    apply this layer's dilated convolution + ReLU in place into the
    shared stitch buffer. It is used by `MSDModule` and is *not*
    responsible for buffer management or weight initialization.
    """

    def __init__(self, buffer, c_in, layer_depth, width, dilation):
        """Initialize the hidden layer.

        :param buffer: a StitchBuffer object storing the L and G buffers.
        :param c_in: the number of input channels of the MSD module.
        :param layer_depth:
            the depth of this layer in the MSD module (zero-based: the
            first hidden layer has index zero).
        :param width: the width of the MSD module.
        :param dilation:
            integer dilation factor for the convolutions in this layer.
        :returns: a module for the MSD hidden layer.
        :rtype: MSDLayerModule
        """
        super(MSDLayerModule, self).__init__()
        self.buffer = buffer
        self.in_front = units_in_front(c_in, width, layer_depth)
        self.width = width
        # The convolution's output tensor is not known yet; it is
        # re-pointed at the right buffer slice on every forward pass,
        # so pass None for now.
        self.convolution = ConvRelu2dInPlaceModule(
            None, self.in_front, width, kernel_size=3, dilation=dilation
        )

    def forward(self, input):
        # Aim the in-place convolution at this layer's channel slice of L.
        target = self.buffer.L.narrow(1, self.in_front, self.width)
        self.convolution.output = target
        result = self.convolution(input)
        # Lazily stitch the result into the shared L/G buffers.
        return stitchLazy(result, self.buffer.L, self.buffer.G, self.in_front)
111-
112- class MSDFinalLayer (nn .Module ):
60+ class MSDFinalLayer (torch .nn .Module ):
11361 """Documentation for MSDFinalLayer
11462
11563 Implements the final 1x1 multiplication and bias addition for all
@@ -122,7 +70,7 @@ def __init__(self, c_in, c_out):
12270 super (MSDFinalLayer , self ).__init__ ()
12371 self .c_in = c_in
12472 self .c_out = c_out
125- self .linear = nn .Conv1d (c_in , c_out , 1 )
73+ self .linear = torch . nn .Conv1d (c_in , c_out , 1 )
12674 self .reset_parameters ()
12775
12876 def forward (self , input ):
@@ -140,7 +88,7 @@ def reset_parameters(self):
14088 self .linear .bias .data .zero_ ()
14189
14290
class MSDModule(torch.nn.Module):
    """A mixed-scale dense (MSD) network module.

    Stacks ``depth`` dilated convolutions (``MSDBlock2d``), whose
    outputs are all concatenated with the input, followed by a 1x1
    linear mapping (``MSDFinalLayer``) down to ``c_out`` channels.
    """

    def __init__(
        self, c_in, c_out, depth, width, dilations=None
    ):
        """Create a new MSD module.

        :param c_in: number of input channels.
        :param c_out: number of output channels.
        :param depth: number of hidden layers; must be at least 1.
        :param width: channels per hidden layer; must be at least 1.
        :param dilations:
            A list of dilations to use. Default is ``[1, 2, ..., 10]``. A
            good alternative is ``[1, 2, 4, 8]``. The dilations are
            repeated when the depth of the module exceeds the length of
            the list.

        :returns: an MSD module
        :rtype: MSDModule

        """
        super(MSDModule, self).__init__()
        # Use a None sentinel instead of a mutable default argument,
        # which would be shared across every call of __init__.
        if dilations is None:
            dilations = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

        # Validate parameters *before* they are used below (the original
        # code built self.dilations from `depth` prior to checking it,
        # and an empty `dilations` raised an opaque ZeroDivisionError).
        if depth < 1:
            raise ValueError(f"Depth must be at least 1. Got: {depth}.")
        if width < 1:
            raise ValueError(f"Width must be at least 1. Got: {width}.")
        if not dilations:
            raise ValueError("Dilations must be a non-empty sequence.")

        self.c_in = c_in
        self.c_out = c_out
        self.depth = depth
        self.width = width
        # Repeat the dilation pattern so every hidden layer gets one.
        self.dilations = [dilations[i % len(dilations)] for i in range(depth)]

        self.msd_block = MSDBlock2d(self.c_in, self.dilations, self.width)
        # Each hidden layer adds `width` channels on top of the input,
        # so the final 1x1 layer sees c_in + width * depth channels.
        self.final_layer = MSDFinalLayer(c_in=c_in + width * depth, c_out=c_out)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize hidden-layer weights and zero all biases."""
        # Initialize weights for hidden layers:
        for w in self.msd_block.weights:
            init_convolution_weights(
                w.data, self.c_in, self.c_out, self.width, self.depth
            )

        self.msd_block.bias.data.zero_()
        self.final_layer.reset_parameters()

    def forward(self, input):
        output = self.msd_block(input)
        output = self.final_layer(output)
        return output
0 commit comments