|
16 | 16 |
|
17 | 17 | import numpy as np |
18 | 18 | import pandas as pd |
| 19 | +import pyarrow as pa |
19 | 20 | import pytest |
20 | 21 | from pyarrow import parquet as pq |
21 | 22 |
|
@@ -223,6 +224,38 @@ def fill_id_with_nones(x): |
223 | 224 | assert sample.id.dtype.type == null_column_dtype |
224 | 225 |
|
225 | 226 |
|
@pytest.mark.parametrize('np_dtype, pa_dtype, null_value',
                         ((np.float32, pa.float32(), np.nan), (np.object_, pa.string(), None)))
@pytest.mark.parametrize('reader_factory', _D)
def test_entire_column_of_typed_nulls(reader_factory, np_dtype, pa_dtype, null_value, tmp_path):
    """Read back a typed parquet column that contains nothing but nulls.

    A single-column parquet file with ten ``null_value`` entries is written using
    an explicit ``pa_dtype`` schema. The first item yielded by the reader must
    expose that column with dtype ``np_dtype`` and with every element still null
    (NaN for float32, ``None`` for strings).

    :param reader_factory: callable taking a dataset url and returning a reader
        usable as a context manager and as an iterator.
    :param np_dtype: numpy dtype the column is expected to have after reading.
    :param pa_dtype: pyarrow type used to declare the column in the schema.
    :param null_value: the value written into every row of the column.
    :param tmp_path: pytest fixture providing a temporary directory.
    """
    path = tmp_path / "dataset"
    schema = pa.schema([pa.field('all_nulls', pa_dtype)])
    pq.write_table(pa.Table.from_pydict({"all_nulls": [null_value] * 10}, schema=schema), path)

    with reader_factory("file:///" + str(path)) as reader:
        sample = next(reader)
        assert sample.all_nulls.dtype == np_dtype
        if np_dtype == np.float32:
            assert np.all(np.isnan(sample.all_nulls))
        # ``np.object`` was removed from numpy (deprecated in 1.20, gone in 1.24);
        # ``np.object_`` is the supported spelling of the object dtype.
        elif np_dtype == np.object_:
            assert all(v is None for v in sample.all_nulls)
        else:
            # pytest.fail cannot be optimized away the way ``assert False`` can under -O.
            pytest.fail("Unexpected np_dtype")
| 244 | + |
| 245 | + |
@pytest.mark.parametrize('reader_factory', _D)
def test_column_with_list_of_strings_some_are_null(reader_factory, tmp_path):
    """Read back a list-of-strings column in which some list elements are null.

    Three rows are written: one with no nulls, one with a single null and one
    consisting only of nulls. The first item yielded by the reader must expose
    the column with an object dtype and with the same nested values, nulls
    represented as ``None``.

    :param reader_factory: callable taking a dataset url and returning a reader
        usable as a context manager and as an iterator.
    :param tmp_path: pytest fixture providing a temporary directory.
    """
    path = tmp_path / "dataset"
    # Single source of truth for both the written data and the expected result.
    rows = [['a0', 'a1'], ['b0', None], [None, None]]
    schema = pa.schema([pa.field('some_nulls', pa.list_(pa.string(), -1))])
    pq.write_table(pa.Table.from_pydict({"some_nulls": rows}, schema=schema), path)

    with reader_factory("file:///" + str(path)) as reader:
        sample = next(reader)
        # ``np.object`` was removed from numpy (deprecated in 1.20, gone in 1.24);
        # ``np.object_`` is the supported spelling of the object dtype.
        assert sample.some_nulls.dtype == np.object_
        np.testing.assert_equal(sample.some_nulls, rows)
| 257 | + |
| 258 | + |
226 | 259 | @pytest.mark.parametrize('reader_factory', _D) |
227 | 260 | def test_transform_spec_returns_all_none_values_in_a_list_field(scalar_dataset, reader_factory): |
228 | 261 | def fill_id_with_nones(x): |
|
0 commit comments