diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 5ed4b5d4d..5c13968bd 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -627,15 +627,36 @@ protected Aggregate newAggregate(AggregateRel rel) { new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter); List groupings = new ArrayList<>(rel.getGroupingsCount()); - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - groupings.add( - Aggregate.Grouping.builder() - .expressions( - grouping.getGroupingExpressionsList().stream() - .map(protoExprConverter::from) - .collect(java.util.stream.Collectors.toList())) - .build()); + + // Groupings are set using the AggregateRel grouping_expression mechanism + if (!rel.getGroupingExpressionsList().isEmpty()) { + List allGroupingExpressions = + rel.getGroupingExpressionsList().stream() + .map(protoExprConverter::from) + .collect(java.util.stream.Collectors.toList()); + + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + List references = grouping.getExpressionReferencesList(); + List groupExpressions = new ArrayList<>(); + for (int ref : references) { + groupExpressions.add(allGroupingExpressions.get(ref)); + } + groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build()); + } + + } else { + // Groupings are set using the deprecated Grouping grouping_expressions mechanism + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + groupings.add( + Aggregate.Grouping.builder() + .expressions( + grouping.getGroupingExpressionsList().stream() + .map(protoExprConverter::from) + .collect(java.util.stream.Collectors.toList())) + .build()); + } } + List measures = new ArrayList<>(rel.getMeasuresCount()); for (AggregateRel.Measure measure : rel.getMeasuresList()) { measures.add( diff --git a/core/src/test/java/io/substrait/relation/AggregateRelTest.java b/core/src/test/java/io/substrait/relation/AggregateRelTest.java new file mode 100644 index 000000000..4d0b604aa --- /dev/null +++ b/core/src/test/java/io/substrait/relation/AggregateRelTest.java @@ -0,0 +1,161 @@ +package io.substrait.relation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.extension.ExtensionLookup; +import io.substrait.extension.ImmutableExtensionLookup; +import io.substrait.proto.AggregateRel; +import io.substrait.proto.Expression; +import io.substrait.proto.Plan; +import io.substrait.proto.ReadRel; +import io.substrait.proto.Rel; +import org.junit.jupiter.api.Test; + +class AggregateRelTest extends TestBase { + + protected static final Plan plan = Plan.newBuilder().build(); + protected static final ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder().from(plan).build(); + protected static final io.substrait.proto.NamedStruct namedStruct = createSchema(); + + public static io.substrait.proto.NamedStruct createSchema() { + + io.substrait.proto.Type i32Type = + io.substrait.proto.Type.newBuilder() + .setI32(io.substrait.proto.Type.I32.getDefaultInstance()) + .build(); + + // Build a NamedStruct schema with two fields: col1, col2 + io.substrait.proto.Type.Struct structType = + io.substrait.proto.Type.Struct.newBuilder().addTypes(i32Type).addTypes(i32Type).build(); + + return io.substrait.proto.NamedStruct.newBuilder() + .setStruct(structType) + .addNames("col1") + .addNames("col2") + .build(); + } + + public static io.substrait.proto.Expression createFieldReference(int col) { + // Build a ReferenceSegment that refers to struct field col + Expression.ReferenceSegment seg1 = + Expression.ReferenceSegment.newBuilder() + .setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField(col).build()) + .build(); + + // Build a FieldReference that uses the directReference and a rootReference + Expression.FieldReference fieldRef1 = + Expression.FieldReference.newBuilder() + .setDirectReference(seg1) + .setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()) + .build(); + + // Wrap the FieldReference in an Expression.selection + return Expression.newBuilder().setSelection(fieldRef1).build(); + } + + @Test + public void testDeprecatedGroupingExpressionConversion() { + Expression col1Ref = createFieldReference(0); + Expression col2Ref = createFieldReference(1); + + AggregateRel.Grouping grouping = + AggregateRel.Grouping.newBuilder() + .addGroupingExpressions(col1Ref) // deprecated proto form + .addGroupingExpressions(col2Ref) + .build(); + + // Build an input ReadRel + ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(Rel.newBuilder().setRead(readProto)) + .addGroupings(grouping) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(1, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + } + + @Test + public void testAggregateWithSingleGrouping() { + Expression col1Ref = createFieldReference(0); + Expression col2Ref = createFieldReference(1); + + AggregateRel.Grouping grouping = + AggregateRel.Grouping.newBuilder() + .addExpressionReferences(0) + .addExpressionReferences(1) + .build(); + + // Build an input ReadRel + ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(Rel.newBuilder().setRead(readProto)) + .addGroupingExpressions(col1Ref) + .addGroupingExpressions(col2Ref) + .addGroupings(grouping) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(1, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + } + + @Test + public void testAggregateWithMultipleGroupings() { + Expression col1Ref = createFieldReference(0); + Expression col2Ref = createFieldReference(1); + + AggregateRel.Grouping grouping1 = + AggregateRel.Grouping.newBuilder() + .addExpressionReferences(0) // new proto form + .addExpressionReferences(1) + .build(); + + AggregateRel.Grouping grouping2 = + AggregateRel.Grouping.newBuilder().addExpressionReferences(1).build(); + + // Build an input ReadRel + ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(Rel.newBuilder().setRead(readProto)) + .addGroupingExpressions(col1Ref) + .addGroupingExpressions(col2Ref) + .addGroupings(grouping1) + .addGroupings(grouping2) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(2, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + assertEquals(1, agg.getGroupings().get(1).getExpressions().size()); + } +}