Skip to content

Commit 9c0248a

Browse files
authored
revert: remove incorrect override of Project output schema
Reverts 77e7f8f
1 parent e1bb551 commit 9c0248a

File tree

3 files changed

+3
-65
lines changed

3 files changed

+3
-65
lines changed

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -620,19 +620,8 @@ protected Expand newExpand(ExpandRel rel) {
620620

621621
protected Aggregate newAggregate(AggregateRel rel) {
622622
Rel input = from(rel.getInput());
623-
Type.Struct inputSchema;
624-
if (input instanceof Project) {
625-
List<Type> types =
626-
((Project) input)
627-
.getExpressions().stream().map(Expression::getType).collect(Collectors.toList());
628-
inputSchema = Type.Struct.builder().fields(types).nullable(false).build();
629-
} else {
630-
inputSchema = input.getRecordType();
631-
}
632-
633623
ProtoExpressionConverter protoExprConverter =
634-
new ProtoExpressionConverter(lookup, extensions, inputSchema, this);
635-
624+
new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
636625
ProtoAggregateFunctionConverter protoAggrFuncConverter =
637626
new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter);
638627

isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -244,60 +244,9 @@ public Rel visit(org.apache.calcite.rel.core.Minus minus) {
244244
return Set.builder().inputs(inputs).setOp(setOp).build();
245245
}
246246

247-
/**
248-
* Pre-processes the input to an Aggregate relation to handle nullability changes introduced by
249-
* ROLLUP/CUBE/GROUPING SETS.
250-
*
251-
* @param aggregate The original Calcite aggregate node.
252-
* @return A Substrait Rel node that is correctly typed to be the input to the Substrait
253-
* Aggregate.
254-
*/
255-
private Rel handleRollupCorrection(org.apache.calcite.rel.core.Aggregate aggregate) {
256-
Rel originalInput = apply(aggregate.getInput());
257-
258-
// Determine the correct final output type for the aggregate, which accounts for nullability.
259-
NamedStruct aggregateOutputType = typeConverter.toNamedStruct(aggregate.getRowType());
260-
List<Integer> groupKeyIndices = aggregate.getGroupSet().asList();
261-
262-
// Create a list of expressions to cast the original input to the correct final type if needed.
263-
List<Expression> castExpressions = new ArrayList<>();
264-
265-
boolean needsCasting = false;
266-
for (int i = 0; i < originalInput.getRecordType().fields().size(); i++) {
267-
Expression fieldReference = FieldReference.newInputRelReference(i, originalInput);
268-
269-
if (groupKeyIndices.contains(i)) {
270-
int groupKeyOutputIndex = groupKeyIndices.indexOf(i);
271-
Type finalType = aggregateOutputType.struct().fields().get(groupKeyOutputIndex);
272-
273-
if (finalType.nullable() && !fieldReference.getType().nullable()) {
274-
needsCasting = true; // Mark that a cast is necessary.
275-
castExpressions.add(
276-
Expression.Cast.builder()
277-
.type(finalType)
278-
.input(fieldReference)
279-
.failureBehavior(Expression.FailureBehavior.RETURN_NULL)
280-
.build());
281-
} else {
282-
castExpressions.add(fieldReference);
283-
}
284-
} else {
285-
castExpressions.add(fieldReference);
286-
}
287-
}
288-
289-
// Only add the extra Project node if a cast was actually needed.
290-
if (needsCasting) {
291-
return Project.builder().input(originalInput).expressions(castExpressions).build();
292-
}
293-
294-
// If no casting was needed, just return the original converted input.
295-
return originalInput;
296-
}
297-
298247
@Override
299248
public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) {
300-
Rel input = handleRollupCorrection(aggregate);
249+
Rel input = apply(aggregate.getInput());
301250
Stream<ImmutableBitSet> sets;
302251
if (aggregate.groupSets != null) {
303252
sets = aggregate.groupSets.stream();

isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
public class TpcdsQueryTest extends PlanTestBase {
1616
private static final Set<Integer> toSubstraitExclusions = Set.of(9, 27, 36, 70, 86);
1717
private static final Set<Integer> fromSubstraitPojoExclusions = Set.of(1, 30, 81);
18-
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 81);
18+
private static final Set<Integer> fromSubstraitProtoExclusions = Set.of(1, 30, 67, 81);
1919

2020
static IntStream testCases() {
2121
return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n));

0 commit comments

Comments
 (0)