Skip to content

Commit f320968

Browse files
committed
Merge pull request #1534 from woozzu/params
Fix StackedBlocks when layers have no _params
2 parents 6b6974f + cc0fb77 commit f320968

File tree

2 files changed

+40
-3
lines changed

2 files changed

+40
-3
lines changed

pylearn2/blocks.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,15 @@ def __init__(self, layers):
131131
super(StackedBlocks, self).__init__()
132132

133133
self._layers = layers
134-
# Do not duplicate the parameters if some are shared between layers
135-
self._params = set([p for l in self._layers for p in l._params])
134+
self._params = set()
135+
for l in self._layers:
136+
if not hasattr(l, '_params'):
137+
self._params = None
138+
break
139+
else:
140+
# Do not duplicate the parameters if some are shared
141+
# between layers
142+
self._params.update(l._params)
136143

137144
def layers(self):
138145
"""
@@ -242,7 +249,8 @@ def append(self, layer):
242249
layer : WRITEME
243250
"""
244251
self._layers.append(layer)
245-
self._params.update(layer._params)
252+
if self._params is not None:
253+
self._params.update(layer._params)
246254

247255
def get_input_space(self):
248256
"""

pylearn2/tests/test_blocks.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
Unit tests for blocks
3+
"""
4+
5+
from pylearn2.models.autoencoder import Autoencoder
6+
from pylearn2.blocks import Block, StackedBlocks
7+
8+
9+
def test_stackedblocks_with_params():
10+
"""
11+
Test StackedBlocks when all layers have trainable params
12+
"""
13+
14+
aes = [Autoencoder(100, 50, 'tanh', 'tanh'),
15+
Autoencoder(50, 10, 'tanh', 'tanh')]
16+
sb = StackedBlocks(aes)
17+
_params = set([p for l in sb._layers for p in l._params])
18+
19+
assert sb._params == _params
20+
21+
22+
def test_stackedblocks_without_params():
23+
"""
24+
Test StackedBlocks when not all layers have trainable params
25+
"""
26+
27+
sb = StackedBlocks([Block(), Block()])
28+
29+
assert sb._params is None

0 commit comments

Comments
 (0)