Skip to content

Commit 1fcf22f

Browse files
author
Yevgeni Litvin
committed
Add tests: without transform_spec, and with a list of strings where some or all elements are None.
Change the numpy dtype used for string columns to np.object instead of np.unicode_
1 parent 9b2bb69 commit 1fcf22f

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

petastorm/arrow_reader_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def read_next(self, workers_pool, schema, ngram):
6262
column_as_numpy = column_as_pandas
6363

6464
if pa.types.is_string(column.type):
65-
result_dict[column_name] = column_as_numpy.astype(np.unicode_)
65+
result_dict[column_name] = column_as_numpy.astype(np.object)
6666
elif pa.types.is_list(column.type) or pa.types.is_fixed_size_list(column.type):
6767
# Assuming all lists are of the same length, hence we can collate them into a matrix
6868
list_of_lists = column_as_numpy

petastorm/tests/test_parquet_reader.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818
import pandas as pd
19+
import pyarrow as pa
1920
import pytest
2021
from pyarrow import parquet as pq
2122

@@ -223,6 +224,38 @@ def fill_id_with_nones(x):
223224
assert sample.id.dtype.type == null_column_dtype
224225

225226

227+
@pytest.mark.parametrize('np_dtype, pa_dtype, null_value',
228+
((np.float32, pa.float32(), np.nan), (np.object, pa.string(), None)))
229+
@pytest.mark.parametrize('reader_factory', _D)
230+
def test_entire_column_of_typed_nulls(reader_factory, np_dtype, pa_dtype, null_value, tmp_path):
231+
path = tmp_path / "dataset"
232+
schema = pa.schema([pa.field('all_nulls', pa_dtype)])
233+
pq.write_table(pa.Table.from_pydict({"all_nulls": [null_value] * 10}, schema=schema), path)
234+
235+
with reader_factory("file:///" + str(path)) as reader:
236+
sample = next(reader)
237+
assert sample.all_nulls.dtype == np_dtype
238+
if np_dtype == np.float32:
239+
assert np.all(np.isnan(sample.all_nulls))
240+
elif np_dtype == np.object:
241+
assert all(v is None for v in sample.all_nulls)
242+
else:
243+
assert False, "Unexpected np_dtype"
244+
245+
246+
@pytest.mark.parametrize('reader_factory', _D)
247+
def test_column_with_list_of_strings_some_are_null(reader_factory, tmp_path):
248+
path = tmp_path / "dataset"
249+
schema = pa.schema([pa.field('some_nulls', pa.list_(pa.string(), -1))])
250+
pq.write_table(pa.Table.from_pydict({"some_nulls": [['a0', 'a1'], ['b0', None], [None, None]]}, schema=schema),
251+
path)
252+
253+
with reader_factory("file:///" + str(path)) as reader:
254+
sample = next(reader)
255+
assert sample.some_nulls.dtype == np.object
256+
np.testing.assert_equal(sample.some_nulls, [['a0', 'a1'], ['b0', None], [None, None]])
257+
258+
226259
@pytest.mark.parametrize('reader_factory', _D)
227260
def test_transform_spec_returns_all_none_values_in_a_list_field(scalar_dataset, reader_factory):
228261
def fill_id_with_nones(x):

0 commit comments

Comments
 (0)