diff --git a/pylearn2/datasets/csv_dataset.py b/pylearn2/datasets/csv_dataset.py index d7b1845f11..b2cbc48d15 100644 --- a/pylearn2/datasets/csv_dataset.py +++ b/pylearn2/datasets/csv_dataset.py @@ -125,6 +125,7 @@ def __init__(self, if self.task == 'regression': super(CSVDataset, self).__init__(X=X, y=y, **kwargs) else: + y = y.astype('int32') super(CSVDataset, self).__init__(X=X, y=y, y_labels=np.max(y) + 1, **kwargs) diff --git a/pylearn2/datasets/tests/test_csv_dataset.py b/pylearn2/datasets/tests/test_csv_dataset.py index 30bc204ef1..3a8d9c519d 100644 --- a/pylearn2/datasets/tests/test_csv_dataset.py +++ b/pylearn2/datasets/tests/test_csv_dataset.py @@ -9,7 +9,8 @@ def test_loading_classification(): 'datasets', 'tests', 'test.csv') d = CSVDataset(path=test_path, expect_headers=False) assert(np.array_equal(d.X, np.array([[1., 2., 3.], [4., 5., 6.]]))) - assert(np.array_equal(d.y, np.array([[0.], [1.]]))) + assert(np.array_equal(d.y, np.array([[0], [1]]))) + assert (d.y.dtype == 'int32') def test_loading_regression():