
Commit 5b7df7f

Added errors for pyarrow < 2, support for TokenSpan null values, additional testing
1 parent 59d5361 commit 5b7df7f

File tree: 2 files changed (+65, -10 lines)

text_extensions_for_pandas/array/arrow_conversion.py

Lines changed: 31 additions & 8 deletions
@@ -20,12 +20,13 @@
 #
 # Provide Arrow compatible classes for serializing to pyarrow.
 #
+from distutils.version import LooseVersion

 import numpy as np
 import pyarrow as pa

 from text_extensions_for_pandas.array.span import SpanArray
-from text_extensions_for_pandas.array.token_span import TokenSpanArray
+from text_extensions_for_pandas.array.token_span import TokenSpanArray, _EMPTY_SPAN_ARRAY_SINGLETON
 from text_extensions_for_pandas.array.tensor import TensorArray
 from text_extensions_for_pandas.array.string_table import StringTable

@@ -108,6 +109,9 @@ def span_to_arrow(char_span: SpanArray) -> pa.ExtensionArray:
     :param char_span: A SpanArray to be converted
     :return: pyarrow.ExtensionArray containing Span data
     """
+    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+        raise NotImplementedError("Arrow serialization for SpanArray is not supported with "
+                                  "PyArrow versions < 2.0.0")
     # Create array for begins, ends
     begins_array = pa.array(char_span.begin)
     ends_array = pa.array(char_span.end)
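
The same three-line version gate now opens each converter in this file. As a quick illustration of why the commit compares LooseVersion objects rather than raw version strings, here is a minimal sketch (not part of this commit; the helper name is hypothetical):

from distutils.version import LooseVersion

import pyarrow as pa

# Plain string comparison is lexicographic and would wrongly report
# "10.0.0" < "2.0.0"; LooseVersion compares dotted components numerically.
assert "10.0.0" < "2.0.0"                              # lexicographic, wrong
assert LooseVersion("10.0.0") > LooseVersion("2.0.0")  # numeric, right

def require_pyarrow_2(feature: str) -> None:
    """Hypothetical helper; the commit inlines this check in each function."""
    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
        raise NotImplementedError(
            f"Arrow serialization for {feature} is not supported with "
            "PyArrow versions < 2.0.0")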
@@ -130,9 +134,14 @@ def arrow_to_span(extension_array: pa.ExtensionArray) -> SpanArray:
     Convert a pyarrow.ExtensionArray with type ArrowSpanType to
     a SpanArray.

+    ..NOTE: Only supported with PyArrow >= 2.0.0
+
     :param extension_array: pyarrow.ExtensionArray with type ArrowSpanType
     :return: SpanArray
     """
+    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+        raise NotImplementedError("Arrow serialization for SpanArray is not supported with "
+                                  "PyArrow versions < 2.0.0")
     if isinstance(extension_array, pa.ChunkedArray):
         if extension_array.num_chunks > 1:
             raise ValueError("Only pyarrow.Array with a single chunk is supported")
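
For context on the ChunkedArray branch above: pandas and the Feather reader often hand back a pa.ChunkedArray wrapper even when the data is one contiguous array. A small sketch of the shapes involved (my illustration, not code from the repo):

import pyarrow as pa

# A ChunkedArray is a sequence of Arrays; the converter above only
# accepts a single chunk and raises ValueError otherwise.
chunked = pa.chunked_array([[1, 2], [3, 4]])
print(chunked.num_chunks)          # 2

# Callers can flatten to one contiguous Array before converting:
flat = pa.concat_arrays(chunked.chunks)
print(flat)                        # [1, 2, 3, 4]

single = pa.chunked_array([[1, 2, 3, 4]])
print(single.num_chunks)           # 1
first = single.chunk(0)            # the lone chunk, as a plain pa.Array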
@@ -175,18 +184,28 @@ def token_span_to_arrow(token_span: TokenSpanArray) -> pa.ExtensionArray:
     :param token_span: A TokenSpanArray to be converted
     :return: pyarrow.ExtensionArray containing TokenSpan data
     """
+    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+        raise NotImplementedError("Arrow serialization for TokenSpanArray is not supported with "
+                                  "PyArrow versions < 2.0.0")
     # Create arrays for begins/ends
     token_begins_array = pa.array(token_span.begin_token)
     token_ends_array = pa.array(token_span.end_token)

+    # Filter out any empty SpanArrays
+    non_null_tokens = token_span.tokens[~token_span.isna()]
+    assert len(non_null_tokens) > 0
+
     # Get either single document as a list or use a list of all if multiple docs
-    assert len(token_span.tokens) > 0
-    if all([token is token_span.tokens[0] for token in token_span.tokens]):
-        tokens_arrays = [token_span.tokens[0]]
-        tokens_indices = pa.array([0] * len(token_span.tokens))
+    if all([token is non_null_tokens[0] for token in non_null_tokens]):
+        tokens_arrays = [non_null_tokens[0]]
+        tokens_indices = pa.array([0] * len(token_span.tokens), mask=token_span.isna())
     else:
-        tokens_arrays = token_span.tokens
-        tokens_indices = pa.array(range(len(tokens_arrays)))
+        raise NotImplementedError("TokenSpan Multi-doc serialization not yet implemented due to "
+                                  "ArrowNotImplementedError: Concat with dictionary unification NYI")
+        tokens_arrays = non_null_tokens
+        tokens_indices = np.zeros_like(token_span.tokens)
+        tokens_indices[~token_span.isna()] = range(len(tokens_arrays))
+        tokens_indices = pa.array(tokens_indices, mask=token_span.isna())

     # Convert each token SpanArray to Arrow and get as raw storage
     arrow_tokens_arrays = [span_to_arrow(sa).storage for sa in tokens_arrays]
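
The null-value support added here hinges on the mask= argument of pa.array shown above: positions where the mask is True become Arrow nulls in the resulting array. A minimal sketch of that behavior (illustrative values only, not repo code):

import numpy as np
import pyarrow as pa

# True in the mask marks the corresponding slot as null.
isna = np.array([False, True, False])
indices = pa.array(np.zeros(3, dtype=np.int64), mask=isna)

print(indices.to_pylist())   # [0, None, 0]
print(indices.null_count)    # 1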
@@ -217,6 +236,9 @@ def arrow_to_token_span(extension_array: pa.ExtensionArray) -> TokenSpanArray:
     :param extension_array: pyarrow.ExtensionArray with type ArrowTokenSpanType
     :return: TokenSpanArray
     """
+    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+        raise NotImplementedError("Arrow serialization for TokenSpanArray is not supported with "
+                                  "PyArrow versions < 2.0.0")
     if isinstance(extension_array, pa.ChunkedArray):
         if extension_array.num_chunks > 1:
             raise ValueError("Only pyarrow.Array with a single chunk is supported")
@@ -252,7 +274,8 @@ def arrow_to_token_span(extension_array: pa.ExtensionArray) -> TokenSpanArray:
         tokens_arrays.append(tokens_array)

     # Map the token indices to the actual token SpanArray for each element in the TokenSpanArray
-    tokens = [tokens_arrays[i.as_py()] for i in tokens_indices]
+    tokens = [_EMPTY_SPAN_ARRAY_SINGLETON if i is None else tokens_arrays[i]
+              for i in tokens_indices.to_pylist()]

     # Zero-copy convert arrays to numpy
     token_begins = token_begins_array.to_numpy()
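
On the read path, to_pylist() surfaces those masked slots as Python None, which the new comprehension maps back to the shared empty-SpanArray singleton. A toy version of that mapping (stand-in values, not the real SpanArray objects):

import pyarrow as pa

tokens_arrays = ["tokens_doc_0", "tokens_doc_1"]  # stand-ins for per-document SpanArrays
EMPTY = "EMPTY_SENTINEL"                          # stand-in for _EMPTY_SPAN_ARRAY_SINGLETON

tokens_indices = pa.array([0, None, 1], type=pa.int64())
tokens = [EMPTY if i is None else tokens_arrays[i]
          for i in tokens_indices.to_pylist()]

assert tokens == ["tokens_doc_0", "EMPTY_SENTINEL", "tokens_doc_1"]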

text_extensions_for_pandas/array/test_token_span.py

Lines changed: 34 additions & 2 deletions
@@ -14,13 +14,15 @@
 #

 import pandas as pd
+from distutils.version import LooseVersion
 import os
 import tempfile
 import unittest
 # noinspection PyPackageRequirements
 import pytest

 from pandas.tests.extension import base
+import pyarrow as pa

 from text_extensions_for_pandas.array.test_span import ArrayTestBase
 from text_extensions_for_pandas.array.span import *
@@ -365,6 +367,8 @@ def test_as_frame(self):
         self.assertEqual(len(df), len(arr))


+@pytest.mark.skipif(LooseVersion(pa.__version__) < LooseVersion("2.0.0"),
+                    reason="Nested dictionaries only supported in Arrow >= 2.0.0")
 class TokenSpanArrayIOTests(ArrayTestBase):

     def do_roundtrip(self, df):
@@ -383,7 +387,7 @@ def test_feather(self):
         self.do_roundtrip(df1)

         # More token spans than tokens
-        """ts2 = TokenSpanArray(toks, [0, 1, 2, 3, 0, 2, 0], [1, 2, 3, 4, 2, 4, 4])
+        ts2 = TokenSpanArray(toks, [0, 1, 2, 3, 0, 2, 0], [1, 2, 3, 4, 2, 4, 4])
         df2 = pd.DataFrame({"ts2": ts2})
         self.do_roundtrip(df2)

@@ -404,7 +408,35 @@ def test_feather(self):

         # All columns together, TokenSpan arrays padded as needed
         df = pd.concat([df1, df2, df3, df4], axis=1)
-        self.do_roundtrip(df)"""
+        self.do_roundtrip(df)
+
+    @pytest.mark.skip(reason="ArrowNotImplementedError: Concat with dictionary unification NYI")
+    def test_feather_multi_doc(self):
+        toks = self._make_spans_of_tokens()
+        arr = TokenSpanArray(toks, np.arange(len(toks)), np.arange(len(toks)) + 1)
+        df1 = pd.DataFrame({'TokenSpan': arr})
+
+        toks = SpanArray(
+            "Have at it.", np.array([0, 5, 8]), np.array([4, 7, 11])
+        )
+        arr = TokenSpanArray(toks, np.arange(len(toks)), np.arange(len(toks)) + 1)
+        df2 = pd.DataFrame({'TokenSpan': arr})
+
+        df = pd.concat([df1, df2], ignore_index=True)
+        self.assertFalse(df["TokenSpan"].array.is_single_document)
+        self.do_roundtrip(df)
+
+    @pytest.mark.skip(reason="ArrowNotImplementedError: Writing DictionaryArray with nested dictionary type not yet supported")
+    def test_parquet(self):
+        toks = self._make_spans_of_tokens()
+        arr = TokenSpanArray(toks, np.arange(len(toks)), np.arange(len(toks)) + 1)
+        df = pd.DataFrame({'TokenSpan': arr})
+
+        with tempfile.TemporaryDirectory() as dirpath:
+            filename = os.path.join(dirpath, "token_span_array_test.parquet")
+            df.to_parquet(filename)
+            df_read = pd.read_parquet(filename)
+            pd.testing.assert_frame_equal(df, df_read)


 @pytest.fixture
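
do_roundtrip appears only as unchanged context in the hunks above. For readers following along, a plausible body, assuming it performs the same write/read/compare cycle as the new test_parquet but over Feather (the real implementation lives earlier in test_token_span.py):

import os
import tempfile

import pandas as pd

def do_roundtrip(self, df: pd.DataFrame) -> None:
    # Assumed shape of the helper: write to a temporary Feather file, read
    # it back, and require exact equality of the round-tripped DataFrame.
    with tempfile.TemporaryDirectory() as dirpath:
        filename = os.path.join(dirpath, "token_span_array_test.feather")
        df.to_feather(filename)
        df_read = pd.read_feather(filename)
        pd.testing.assert_frame_equal(df, df_read)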
