 #
 # Provide Arrow compatible classes for serializing to pyarrow.
 #
+from distutils.version import LooseVersion
 
 import numpy as np
 import pyarrow as pa
 
 from text_extensions_for_pandas.array.span import SpanArray
-from text_extensions_for_pandas.array.token_span import TokenSpanArray
+from text_extensions_for_pandas.array.token_span import TokenSpanArray, _EMPTY_SPAN_ARRAY_SINGLETON
 from text_extensions_for_pandas.array.tensor import TensorArray
 from text_extensions_for_pandas.array.string_table import StringTable
 
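Aside: a minimal sketch of how the newly imported LooseVersion is used by the version gates added below (assumes only that pyarrow exposes __version__). Unlike a plain string comparison, LooseVersion compares release components numerically, so "10.0.0" correctly sorts above "2.0.0".

from distutils.version import LooseVersion
import pyarrow as pa

# Guard Arrow serialization behind a minimum PyArrow version.
if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
    raise NotImplementedError("Arrow serialization requires PyArrow >= 2.0.0")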
@@ -108,6 +109,9 @@ def span_to_arrow(char_span: SpanArray) -> pa.ExtensionArray:
     :param char_span: A SpanArray to be converted
     :return: pyarrow.ExtensionArray containing Span data
     """
+    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+        raise NotImplementedError("Arrow serialization for SpanArray is not supported with "
+                                  "PyArrow versions < 2.0.0")
     # Create array for begins, ends
     begins_array = pa.array(char_span.begin)
     ends_array = pa.array(char_span.end)
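For context, a toy sketch of packing parallel begin/end arrays into a single Arrow struct array. This is illustrative only; the exact ArrowSpanType storage layout is defined elsewhere in this module and is not shown in the diff.

import pyarrow as pa

begins = pa.array([0, 6, 12])
ends = pa.array([5, 11, 17])
# Combine the parallel columns into one struct-typed storage array
storage = pa.StructArray.from_arrays([begins, ends], names=["begin", "end"])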
@@ -130,9 +134,14 @@ def arrow_to_span(extension_array: pa.ExtensionArray) -> SpanArray:
     Convert a pyarrow.ExtensionArray with type ArrowSpanType to
     a SpanArray.
 
+    .. NOTE:: Only supported with PyArrow >= 2.0.0
+
     :param extension_array: pyarrow.ExtensionArray with type ArrowSpanType
     :return: SpanArray
     """
+    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+        raise NotImplementedError("Arrow serialization for SpanArray is not supported with "
+                                  "PyArrow versions < 2.0.0")
     if isinstance(extension_array, pa.ChunkedArray):
         if extension_array.num_chunks > 1:
             raise ValueError("Only pyarrow.Array with a single chunk is supported")
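An illustration of the single-chunk guard pattern above, using only standard pyarrow calls (toy data; the names are not from this module):

import pyarrow as pa

chunked = pa.chunked_array([pa.array([1, 2, 3])])  # one chunk in this example
if isinstance(chunked, pa.ChunkedArray):
    if chunked.num_chunks > 1:
        raise ValueError("Only pyarrow.Array with a single chunk is supported")
    arr = chunked.chunk(0)  # unwrap to a plain pyarrow.Array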
@@ -175,18 +184,28 @@ def token_span_to_arrow(token_span: TokenSpanArray) -> pa.ExtensionArray:
     :param token_span: A TokenSpanArray to be converted
     :return: pyarrow.ExtensionArray containing TokenSpan data
     """
+    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+        raise NotImplementedError("Arrow serialization for TokenSpanArray is not supported with "
+                                  "PyArrow versions < 2.0.0")
     # Create arrays for begins/ends
     token_begins_array = pa.array(token_span.begin_token)
     token_ends_array = pa.array(token_span.end_token)
 
+    # Filter out any empty SpanArrays
+    non_null_tokens = token_span.tokens[~token_span.isna()]
+    assert len(non_null_tokens) > 0
+
     # Get either single document as a list or use a list of all if multiple docs
-    assert len(token_span.tokens) > 0
-    if all([token is token_span.tokens[0] for token in token_span.tokens]):
-        tokens_arrays = [token_span.tokens[0]]
-        tokens_indices = pa.array([0] * len(token_span.tokens))
+    if all([token is non_null_tokens[0] for token in non_null_tokens]):
+        tokens_arrays = [non_null_tokens[0]]
+        tokens_indices = pa.array([0] * len(token_span.tokens), mask=token_span.isna())
     else:
-        tokens_arrays = token_span.tokens
-        tokens_indices = pa.array(range(len(tokens_arrays)))
+        raise NotImplementedError("TokenSpan Multi-doc serialization not yet implemented due to "
+                                  "ArrowNotImplementedError: Concat with dictionary unification NYI")
+        tokens_arrays = non_null_tokens
+        tokens_indices = np.zeros_like(token_span.tokens)
+        tokens_indices[~token_span.isna()] = range(len(tokens_arrays))
+        tokens_indices = pa.array(tokens_indices, mask=token_span.isna())
 
     # Convert each token SpanArray to Arrow and get as raw storage
     arrow_tokens_arrays = [span_to_arrow(sa).storage for sa in tokens_arrays]
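A standalone illustration of the masked-index pattern introduced above: positions that are null in the TokenSpanArray receive a null index rather than pointing at a tokens array (toy data; only numpy and pyarrow APIs shown):

import numpy as np
import pyarrow as pa

isna = np.array([False, True, False])          # null mask for three elements
indices = np.zeros(len(isna), dtype=np.int64)  # every element references tokens array 0
tokens_indices = pa.array(indices, mask=isna)  # -> [0, null, 0]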
@@ -217,6 +236,9 @@ def arrow_to_token_span(extension_array: pa.ExtensionArray) -> TokenSpanArray:
     :param extension_array: pyarrow.ExtensionArray with type ArrowTokenSpanType
     :return: TokenSpanArray
     """
+    if LooseVersion(pa.__version__) < LooseVersion("2.0.0"):
+        raise NotImplementedError("Arrow serialization for TokenSpanArray is not supported with "
+                                  "PyArrow versions < 2.0.0")
     if isinstance(extension_array, pa.ChunkedArray):
         if extension_array.num_chunks > 1:
             raise ValueError("Only pyarrow.Array with a single chunk is supported")
@@ -252,7 +274,8 @@ def arrow_to_token_span(extension_array: pa.ExtensionArray) -> TokenSpanArray:
         tokens_arrays.append(tokens_array)
 
     # Map the token indices to the actual token SpanArray for each element in the TokenSpanArray
-    tokens = [tokens_arrays[i.as_py()] for i in tokens_indices]
+    tokens = [_EMPTY_SPAN_ARRAY_SINGLETON if i is None else tokens_arrays[i]
+              for i in tokens_indices.to_pylist()]
 
     # Zero-copy convert arrays to numpy
     token_begins = token_begins_array.to_numpy()
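For reference, a hedged round-trip sketch using the two span converters touched in this diff. The SpanArray constructor arguments (target text plus begin/end offsets) are an assumption and are not confirmed by the diff itself:

import numpy as np
from text_extensions_for_pandas.array.span import SpanArray

# Assumed constructor signature: SpanArray(target_text, begins, ends)
spans = SpanArray("Hello world", np.array([0, 6]), np.array([5, 11]))

arrow_spans = span_to_arrow(spans)     # SpanArray -> pyarrow.ExtensionArray
restored = arrow_to_span(arrow_spans)  # pyarrow.ExtensionArray -> SpanArray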