Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 69 additions & 17 deletions python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ def unify_schemas(
all_columns = set()
for schema in schemas:
for col_name in schema.names:
# Check for duplicate field names in this schema
if schema.names.count(col_name) > 1:
# This is broken for Pandas blocks and broken with the logic here
raise ValueError(
f"Schema {schema} has multiple fields with the same name: {col_name}"
)
col_type = schema.field(col_name).type
if pa.types.is_list(col_type) and pa.types.is_null(col_type.value_type):
cols_with_null_list.add(col_name)
Expand All @@ -197,20 +203,16 @@ def unify_schemas(

columns_with_objects = set()
columns_with_tensor_array = set()
columns_with_struct = set()
for col_name in all_columns:
for s in schemas:
indices = s.get_all_field_indices(col_name)
if len(indices) > 1:
# This is broken for Pandas blocks and broken with the logic here
raise ValueError(
f"Schema {s} has multiple fields with the same name: {col_name}"
)
elif len(indices) == 0:
continue
if isinstance(s.field(col_name).type, ArrowPythonObjectType):
columns_with_objects.add(col_name)
if isinstance(s.field(col_name).type, arrow_tensor_types):
columns_with_tensor_array.add(col_name)
if col_name in s.names:
if isinstance(s.field(col_name).type, ArrowPythonObjectType):
columns_with_objects.add(col_name)
if isinstance(s.field(col_name).type, arrow_tensor_types):
columns_with_tensor_array.add(col_name)
if isinstance(s.field(col_name).type, pa.StructType):
columns_with_struct.add(col_name)

if len(columns_with_objects.intersection(columns_with_tensor_array)) > 0:
# This is supportable if we use object type, but it will be expensive
Expand All @@ -222,10 +224,18 @@ def unify_schemas(
tensor_array_types = [
s.field(col_name).type
for s in schemas
if isinstance(s.field(col_name).type, arrow_tensor_types)
if col_name in s.names
and isinstance(s.field(col_name).type, arrow_tensor_types)
]

if ArrowTensorType._need_variable_shaped_tensor_array(tensor_array_types):
# Check if we have missing tensor fields (some schemas don't have this field)
has_missing_fields = len(tensor_array_types) < len(schemas)

# Convert to variable-shaped if needed or if we have missing fields
if (
ArrowTensorType._need_variable_shaped_tensor_array(tensor_array_types)
or has_missing_fields
):
if isinstance(tensor_array_types[0], ArrowVariableShapedTensorType):
new_type = tensor_array_types[0]
elif isinstance(tensor_array_types[0], arrow_fixed_shape_tensor_types):
Expand All @@ -243,6 +253,23 @@ def unify_schemas(
for col_name in columns_with_objects:
schema_field_overrides[col_name] = ArrowPythonObjectType()

for col_name in columns_with_struct:
field_types = [s.field(col_name).type for s in schemas]

# Unify struct schemas
struct_schemas = []
for t in field_types:
if t is not None and pa.types.is_struct(t):
struct_schemas.append(pa.schema(list(t)))
else:
struct_schemas.append(pa.schema([]))

unified_struct_schema = unify_schemas(
struct_schemas, promote_types=promote_types
)

schema_field_overrides[col_name] = pa.struct(list(unified_struct_schema))

if cols_with_null_list:
# For each opaque list column, iterate through all schemas until we find
# a valid value_type that can be used to override the column types in
Expand All @@ -260,9 +287,10 @@ def unify_schemas(
# Go through all schemas and update the types of columns from the above loop.
for schema in schemas:
for col_name, col_new_type in schema_field_overrides.items():
var_shaped_col = schema.field(col_name).with_type(col_new_type)
col_idx = schema.get_field_index(col_name)
schema = schema.set(col_idx, var_shaped_col)
if col_name in schema.names:
var_shaped_col = schema.field(col_name).with_type(col_new_type)
col_idx = schema.get_field_index(col_name)
schema = schema.set(col_idx, var_shaped_col)
schemas_to_unify.append(schema)
else:
schemas_to_unify = schemas
Expand Down Expand Up @@ -362,6 +390,12 @@ def _backfill_missing_fields(
"""
import pyarrow as pa

from ray.air.util.tensor_extensions.arrow import (
ArrowTensorType,
ArrowVariableShapedTensorType,
get_arrow_extension_tensor_types,
)

# Flatten chunked arrays into a single array if necessary
if isinstance(column, pa.ChunkedArray):
column = pa.concat_arrays(column.chunks)
Expand All @@ -381,6 +415,8 @@ def _backfill_missing_fields(
if column.type == unified_struct_type:
return column

tensor_types = get_arrow_extension_tensor_types()

aligned_fields = []

# Iterate over the fields in the unified struct type schema
Expand All @@ -398,6 +434,22 @@ def _backfill_missing_fields(
unified_struct_type=field_type,
block_length=block_length,
)

# Handle tensor extension type mismatches
elif isinstance(field_type, tensor_types) and isinstance(
current_array.type, tensor_types
):
# Convert to variable-shaped if needed
if ArrowTensorType._need_variable_shaped_tensor_array(
[current_array.type, field_type]
) and not isinstance(current_array.type, ArrowVariableShapedTensorType):
# Only convert if it's not already a variable-shaped tensor array
current_array = current_array.to_variable_shaped_tensor_array()

# The schema should already be unified by unify_schemas, so types
# should be compatible. If not, let the error propagate up.
# No explicit casting needed - PyArrow will handle type compatibility
# during struct creation or raise appropriate errors.
aligned_fields.append(current_array)
else:
# If the field is missing, fill with nulls
Expand Down
Loading