Skip to content

Commit 781e61f

Browse files
authored
concat: Handle mixed Tensor types for structs (#54386)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? <!-- Please give a short summary of the change and the problem this solves. --> `concat`: Handle mixed Tensor types for structs **unify_schemas** - Handle duplicate column names in schema. - For structs, invoke `unify_schemas` on itself. - For tensors, handle missing fields. **concat** - For structs, `_align_struct_fields` is invoked to handle missing fields and aligned schemas. Here handle Tensors type mismatch in `_backfill_missing_fields`. **Tests** - Added test fixtures to existing test. No logic changes. ``` test_arrow_concat_empty test_arrow_concat_single_block test_arrow_concat_basic test_arrow_concat_null_promotion test_arrow_concat_tensor_extension_uniform test_arrow_concat_tensor_extension_variable_shaped test_arrow_concat_tensor_extension_uniform_and_variable_shaped test_arrow_concat_tensor_extension_uniform_but_different test_arrow_concat_with_objects test_struct_with_different_field_names test_nested_structs test_struct_with_null_values test_struct_with_mismatched_lengths test_struct_with_empty_arrays test_arrow_concat_object_with_tensor_fails test_unify_schemas test_unify_schemas_type_promotion test_arrow_block_select test_arrow_block_slice_copy test_arrow_block_slice_copy_empty ``` - Test `concat` of tables with structs & tensors coverage. ``` test_struct_with_arrow_variable_shaped_tensor_type test_mixed_tensor_types_same_dtype test_mixed_tensor_types_fixed_shape_different test_mixed_tensor_types_variable_shaped test_mixed_tensor_types_in_struct test_nested_struct_with_mixed_tensor_types test_multiple_tensor_fields_in_struct test_struct_with_incompatible_tensor_dtypes_fails test_struct_with_additional_fields test_struct_with_null_tensor_values ``` - Test `unify_schema` coverage. ``` test_unify_schemas_null_typed_lists test_unify_schemas_object_types test_unify_schemas_duplicate_fields test_unify_schemas_incompatible_tensor_dtypes test_unify_schemas_objects_and_tensors test_unify_schemas_missing_tensor_fields test_unify_schemas_nested_struct_tensors test_unify_schemas_edge_cases test_unify_schemas_mixed_tensor_types ``` ## Related issue number "Closes #54186" ## Checks - [ ] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [ ] I've run `scripts/format.sh` to lint the changes in this PR. - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [ ] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [ ] Unit tests - [ ] Release tests - [ ] This PR is not tested :( --------- Signed-off-by: Srinath Krishnamachari <[email protected]>
1 parent 9392c25 commit 781e61f

File tree

2 files changed

+2726
-505
lines changed

2 files changed

+2726
-505
lines changed

python/ray/data/_internal/arrow_ops/transform_pyarrow.py

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,12 @@ def unify_schemas(
182182
all_columns = set()
183183
for schema in schemas:
184184
for col_name in schema.names:
185+
# Check for duplicate field names in this schema
186+
if schema.names.count(col_name) > 1:
187+
# This is broken for Pandas blocks and broken with the logic here
188+
raise ValueError(
189+
f"Schema {schema} has multiple fields with the same name: {col_name}"
190+
)
185191
col_type = schema.field(col_name).type
186192
if pa.types.is_list(col_type) and pa.types.is_null(col_type.value_type):
187193
cols_with_null_list.add(col_name)
@@ -197,20 +203,16 @@ def unify_schemas(
197203

198204
columns_with_objects = set()
199205
columns_with_tensor_array = set()
206+
columns_with_struct = set()
200207
for col_name in all_columns:
201208
for s in schemas:
202-
indices = s.get_all_field_indices(col_name)
203-
if len(indices) > 1:
204-
# This is broken for Pandas blocks and broken with the logic here
205-
raise ValueError(
206-
f"Schema {s} has multiple fields with the same name: {col_name}"
207-
)
208-
elif len(indices) == 0:
209-
continue
210-
if isinstance(s.field(col_name).type, ArrowPythonObjectType):
211-
columns_with_objects.add(col_name)
212-
if isinstance(s.field(col_name).type, arrow_tensor_types):
213-
columns_with_tensor_array.add(col_name)
209+
if col_name in s.names:
210+
if isinstance(s.field(col_name).type, ArrowPythonObjectType):
211+
columns_with_objects.add(col_name)
212+
if isinstance(s.field(col_name).type, arrow_tensor_types):
213+
columns_with_tensor_array.add(col_name)
214+
if isinstance(s.field(col_name).type, pa.StructType):
215+
columns_with_struct.add(col_name)
214216

215217
if len(columns_with_objects.intersection(columns_with_tensor_array)) > 0:
216218
# This is supportable if we use object type, but it will be expensive
@@ -222,10 +224,18 @@ def unify_schemas(
222224
tensor_array_types = [
223225
s.field(col_name).type
224226
for s in schemas
225-
if isinstance(s.field(col_name).type, arrow_tensor_types)
227+
if col_name in s.names
228+
and isinstance(s.field(col_name).type, arrow_tensor_types)
226229
]
227230

228-
if ArrowTensorType._need_variable_shaped_tensor_array(tensor_array_types):
231+
# Check if we have missing tensor fields (some schemas don't have this field)
232+
has_missing_fields = len(tensor_array_types) < len(schemas)
233+
234+
# Convert to variable-shaped if needed or if we have missing fields
235+
if (
236+
ArrowTensorType._need_variable_shaped_tensor_array(tensor_array_types)
237+
or has_missing_fields
238+
):
229239
if isinstance(tensor_array_types[0], ArrowVariableShapedTensorType):
230240
new_type = tensor_array_types[0]
231241
elif isinstance(tensor_array_types[0], arrow_fixed_shape_tensor_types):
@@ -243,6 +253,23 @@ def unify_schemas(
243253
for col_name in columns_with_objects:
244254
schema_field_overrides[col_name] = ArrowPythonObjectType()
245255

256+
for col_name in columns_with_struct:
257+
field_types = [s.field(col_name).type for s in schemas]
258+
259+
# Unify struct schemas
260+
struct_schemas = []
261+
for t in field_types:
262+
if t is not None and pa.types.is_struct(t):
263+
struct_schemas.append(pa.schema(list(t)))
264+
else:
265+
struct_schemas.append(pa.schema([]))
266+
267+
unified_struct_schema = unify_schemas(
268+
struct_schemas, promote_types=promote_types
269+
)
270+
271+
schema_field_overrides[col_name] = pa.struct(list(unified_struct_schema))
272+
246273
if cols_with_null_list:
247274
# For each opaque list column, iterate through all schemas until we find
248275
# a valid value_type that can be used to override the column types in
@@ -260,9 +287,10 @@ def unify_schemas(
260287
# Go through all schemas and update the types of columns from the above loop.
261288
for schema in schemas:
262289
for col_name, col_new_type in schema_field_overrides.items():
263-
var_shaped_col = schema.field(col_name).with_type(col_new_type)
264-
col_idx = schema.get_field_index(col_name)
265-
schema = schema.set(col_idx, var_shaped_col)
290+
if col_name in schema.names:
291+
var_shaped_col = schema.field(col_name).with_type(col_new_type)
292+
col_idx = schema.get_field_index(col_name)
293+
schema = schema.set(col_idx, var_shaped_col)
266294
schemas_to_unify.append(schema)
267295
else:
268296
schemas_to_unify = schemas
@@ -362,6 +390,12 @@ def _backfill_missing_fields(
362390
"""
363391
import pyarrow as pa
364392

393+
from ray.air.util.tensor_extensions.arrow import (
394+
ArrowTensorType,
395+
ArrowVariableShapedTensorType,
396+
get_arrow_extension_tensor_types,
397+
)
398+
365399
# Flatten chunked arrays into a single array if necessary
366400
if isinstance(column, pa.ChunkedArray):
367401
column = pa.concat_arrays(column.chunks)
@@ -381,6 +415,8 @@ def _backfill_missing_fields(
381415
if column.type == unified_struct_type:
382416
return column
383417

418+
tensor_types = get_arrow_extension_tensor_types()
419+
384420
aligned_fields = []
385421

386422
# Iterate over the fields in the unified struct type schema
@@ -398,6 +434,22 @@ def _backfill_missing_fields(
398434
unified_struct_type=field_type,
399435
block_length=block_length,
400436
)
437+
438+
# Handle tensor extension type mismatches
439+
elif isinstance(field_type, tensor_types) and isinstance(
440+
current_array.type, tensor_types
441+
):
442+
# Convert to variable-shaped if needed
443+
if ArrowTensorType._need_variable_shaped_tensor_array(
444+
[current_array.type, field_type]
445+
) and not isinstance(current_array.type, ArrowVariableShapedTensorType):
446+
# Only convert if it's not already a variable-shaped tensor array
447+
current_array = current_array.to_variable_shaped_tensor_array()
448+
449+
# The schema should already be unified by unify_schemas, so types
450+
# should be compatible. If not, let the error propagate up.
451+
# No explicit casting needed - PyArrow will handle type compatibility
452+
# during struct creation or raise appropriate errors.
401453
aligned_fields.append(current_array)
402454
else:
403455
# If the field is missing, fill with nulls

0 commit comments

Comments
 (0)