|
15 | 15 | import tables
|
16 | 16 | except ImportError:
|
17 | 17 | tables = None
|
| 18 | +import numpy as np |
18 | 19 | import warnings
|
19 | 20 | from os.path import isfile
|
20 | 21 | from pylearn2.compat import OrderedDict
|
@@ -86,9 +87,7 @@ def __new__(cls, filename, X=None, topo_view=None, y=None, load_all=False,
|
86 | 87 | return HDF5DatasetDeprecated(filename, X, topo_view, y, load_all,
|
87 | 88 | cache_size, **kwargs)
|
88 | 89 | else:
|
89 |
| - return super(HDF5Dataset, cls).__new__( |
90 |
| - cls, filename, sources, spaces, aliases, load_all, cache_size, |
91 |
| - use_h5py, **kwargs) |
| 90 | + return super(HDF5Dataset, cls).__new__(cls) |
92 | 91 |
|
93 | 92 | def __init__(self, filename, sources, spaces, aliases=None, load_all=False,
|
94 | 93 | cache_size=None, use_h5py='auto', **kwargs):
|
@@ -204,7 +203,7 @@ def iterator(self, mode=None, data_specs=None, batch_size=None,
|
204 | 203 | provided when the dataset object has been created will be used.
|
205 | 204 | """
|
206 | 205 | if data_specs is None:
|
207 |
| - data_specs = (self._get_sources, self._get_spaces) |
| 206 | + data_specs = (self._get_spaces(), self._get_sources()) |
208 | 207 |
|
209 | 208 | [mode, batch_size, num_batches, rng, data_specs] = self._init_iterator(
|
210 | 209 | mode, batch_size, num_batches, rng, data_specs)
|
@@ -240,7 +239,7 @@ def _get_spaces(self):
|
240 | 239 | -------
|
241 | 240 | A Space or a list of Spaces.
|
242 | 241 | """
|
243 |
| - space = [self.spaces[s] for s in self._get_sources] |
| 242 | + space = [self.spaces[s] for s in self._get_sources()] |
244 | 243 | return space[0] if len(space) == 1 else tuple(space)
|
245 | 244 |
|
246 | 245 | def get_data_specs(self, source_or_alias=None):
|
@@ -310,16 +309,16 @@ def get(self, sources, indexes):
|
310 | 309 | sources[s], *e.args))
|
311 | 310 | if (isinstance(indexes, (slice, py_integer_types)) or
|
312 | 311 | len(indexes) == 1):
|
313 |
| - rval.append(sdata[indexes]) |
| 312 | + val = sdata[indexes] |
314 | 313 | else:
|
315 | 314 | warnings.warn('Accessing non sequential elements of an '
|
316 | 315 | 'HDF5 file will be at best VERY slow. Avoid '
|
317 | 316 | 'using iteration schemes that access '
|
318 | 317 | 'random/shuffled data with hdf5 datasets!!')
|
319 |
| - val = [] |
320 |
| - [val.append(sdata[idx]) for idx in indexes] |
321 |
| - rval.append(val) |
322 |
| - return tuple(rval) |
| 318 | + val = [sdata[idx] for idx in indexes] |
| 319 | + val = tuple(tuple(row) for row in val) |
| 320 | + rval.append(val) |
| 321 | + return tuple(np.array(v) for v in rval) |
323 | 322 |
|
324 | 323 | @wraps(Dataset.get_num_examples, assigned=(), updated=())
|
325 | 324 | def get_num_examples(self, source_or_alias=None):
|
|
0 commit comments