Skip to content

Commit 143784e

Browse files
committed
Fix HDF5 dataset
1. Don't provide parameters for superclass __new__ method (object.__new__). 2. Add function calls where required. 3. Fix interaction with iterators.
1 parent f320968 commit 143784e

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

pylearn2/datasets/hdf5.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import tables
1616
except ImportError:
1717
tables = None
18+
import numpy as np
1819
import warnings
1920
from os.path import isfile
2021
from pylearn2.compat import OrderedDict
@@ -86,9 +87,7 @@ def __new__(cls, filename, X=None, topo_view=None, y=None, load_all=False,
8687
return HDF5DatasetDeprecated(filename, X, topo_view, y, load_all,
8788
cache_size, **kwargs)
8889
else:
89-
return super(HDF5Dataset, cls).__new__(
90-
cls, filename, sources, spaces, aliases, load_all, cache_size,
91-
use_h5py, **kwargs)
90+
return super(HDF5Dataset, cls).__new__(cls)
9291

9392
def __init__(self, filename, sources, spaces, aliases=None, load_all=False,
9493
cache_size=None, use_h5py='auto', **kwargs):
@@ -204,7 +203,7 @@ def iterator(self, mode=None, data_specs=None, batch_size=None,
204203
provided when the dataset object has been created will be used.
205204
"""
206205
if data_specs is None:
207-
data_specs = (self._get_sources, self._get_spaces)
206+
data_specs = (self._get_spaces(), self._get_sources())
208207

209208
[mode, batch_size, num_batches, rng, data_specs] = self._init_iterator(
210209
mode, batch_size, num_batches, rng, data_specs)
@@ -240,7 +239,7 @@ def _get_spaces(self):
240239
-------
241240
A Space or a list of Spaces.
242241
"""
243-
space = [self.spaces[s] for s in self._get_sources]
242+
space = [self.spaces[s] for s in self._get_sources()]
244243
return space[0] if len(space) == 1 else tuple(space)
245244

246245
def get_data_specs(self, source_or_alias=None):
@@ -310,16 +309,16 @@ def get(self, sources, indexes):
310309
sources[s], *e.args))
311310
if (isinstance(indexes, (slice, py_integer_types)) or
312311
len(indexes) == 1):
313-
rval.append(sdata[indexes])
312+
val = sdata[indexes]
314313
else:
315314
warnings.warn('Accessing non sequential elements of an '
316315
'HDF5 file will be at best VERY slow. Avoid '
317316
'using iteration schemes that access '
318317
'random/shuffled data with hdf5 datasets!!')
319-
val = []
320-
[val.append(sdata[idx]) for idx in indexes]
321-
rval.append(val)
322-
return tuple(rval)
318+
val = [sdata[idx] for idx in indexes]
319+
val = tuple(tuple(row) for row in val)
320+
rval.append(val)
321+
return [np.array(val) for val in rval]
323322

324323
@wraps(Dataset.get_num_examples, assigned=(), updated=())
325324
def get_num_examples(self, source_or_alias=None):

0 commit comments

Comments
 (0)