Skip to content

Commit 637ffbf

Browse files
authored
feat: handle new grouping mechanism in AggregateRel protos (#521)
AggregateRel proto messages using the AggregateRel grouping_expressions field to configure grouping sets can now be read into POJOs
1 parent 077a60d commit 637ffbf

File tree

2 files changed

+190
-8
lines changed

2 files changed

+190
-8
lines changed

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -627,15 +627,36 @@ protected Aggregate newAggregate(AggregateRel rel) {
627627
new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter);
628628

629629
List<Aggregate.Grouping> groupings = new ArrayList<>(rel.getGroupingsCount());
630-
for (AggregateRel.Grouping grouping : rel.getGroupingsList()) {
631-
groupings.add(
632-
Aggregate.Grouping.builder()
633-
.expressions(
634-
grouping.getGroupingExpressionsList().stream()
635-
.map(protoExprConverter::from)
636-
.collect(java.util.stream.Collectors.toList()))
637-
.build());
630+
631+
// Groupings are set using the AggregateRel grouping_expression mechanism
632+
if (!rel.getGroupingExpressionsList().isEmpty()) {
633+
List<Expression> allGroupingExpressions =
634+
rel.getGroupingExpressionsList().stream()
635+
.map(protoExprConverter::from)
636+
.collect(java.util.stream.Collectors.toList());
637+
638+
for (AggregateRel.Grouping grouping : rel.getGroupingsList()) {
639+
List<Integer> references = grouping.getExpressionReferencesList();
640+
List<Expression> groupExpressions = new ArrayList<>();
641+
for (int ref : references) {
642+
groupExpressions.add(allGroupingExpressions.get(ref));
643+
}
644+
groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build());
645+
}
646+
647+
} else {
648+
// Groupings are set using the deprecated Grouping grouping_expressions mechanism
649+
for (AggregateRel.Grouping grouping : rel.getGroupingsList()) {
650+
groupings.add(
651+
Aggregate.Grouping.builder()
652+
.expressions(
653+
grouping.getGroupingExpressionsList().stream()
654+
.map(protoExprConverter::from)
655+
.collect(java.util.stream.Collectors.toList()))
656+
.build());
657+
}
638658
}
659+
639660
List<Aggregate.Measure> measures = new ArrayList<>(rel.getMeasuresCount());
640661
for (AggregateRel.Measure measure : rel.getMeasuresList()) {
641662
measures.add(
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
package io.substrait.relation;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertTrue;
5+
6+
import io.substrait.TestBase;
7+
import io.substrait.extension.ExtensionLookup;
8+
import io.substrait.extension.ImmutableExtensionLookup;
9+
import io.substrait.proto.AggregateRel;
10+
import io.substrait.proto.Expression;
11+
import io.substrait.proto.Plan;
12+
import io.substrait.proto.ReadRel;
13+
import io.substrait.proto.Rel;
14+
import org.junit.jupiter.api.Test;
15+
16+
class AggregateRelTest extends TestBase {
17+
18+
protected static final Plan plan = Plan.newBuilder().build();
19+
protected static final ExtensionLookup functionLookup =
20+
ImmutableExtensionLookup.builder().from(plan).build();
21+
protected static final io.substrait.proto.NamedStruct namedStruct = createSchema();
22+
23+
public static io.substrait.proto.NamedStruct createSchema() {
24+
25+
io.substrait.proto.Type i32Type =
26+
io.substrait.proto.Type.newBuilder()
27+
.setI32(io.substrait.proto.Type.I32.getDefaultInstance())
28+
.build();
29+
30+
// Build a NamedStruct schema with two fields: col1, col2
31+
io.substrait.proto.Type.Struct structType =
32+
io.substrait.proto.Type.Struct.newBuilder().addTypes(i32Type).addTypes(i32Type).build();
33+
34+
return io.substrait.proto.NamedStruct.newBuilder()
35+
.setStruct(structType)
36+
.addNames("col1")
37+
.addNames("col2")
38+
.build();
39+
}
40+
41+
public static io.substrait.proto.Expression createFieldReference(int col) {
42+
// Build a ReferenceSegment that refers to struct field col
43+
Expression.ReferenceSegment seg1 =
44+
Expression.ReferenceSegment.newBuilder()
45+
.setStructField(
46+
Expression.ReferenceSegment.StructField.newBuilder().setField(col).build())
47+
.build();
48+
49+
// Build a FieldReference that uses the directReference and a rootReference
50+
Expression.FieldReference fieldRef1 =
51+
Expression.FieldReference.newBuilder()
52+
.setDirectReference(seg1)
53+
.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance())
54+
.build();
55+
56+
// Wrap the FieldReference in an Expression.selection
57+
return Expression.newBuilder().setSelection(fieldRef1).build();
58+
}
59+
60+
@Test
61+
public void testDeprecatedGroupingExpressionConversion() {
62+
Expression col1Ref = createFieldReference(0);
63+
Expression col2Ref = createFieldReference(1);
64+
65+
AggregateRel.Grouping grouping =
66+
AggregateRel.Grouping.newBuilder()
67+
.addGroupingExpressions(col1Ref) // deprecated proto form
68+
.addGroupingExpressions(col2Ref)
69+
.build();
70+
71+
// Build an input ReadRel
72+
ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build();
73+
74+
// Build the AggregateRel with the new grouping_expressions field
75+
AggregateRel aggrProto =
76+
AggregateRel.newBuilder()
77+
.setInput(Rel.newBuilder().setRead(readProto))
78+
.addGroupings(grouping)
79+
.build();
80+
81+
Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build();
82+
ProtoRelConverter converter = new ProtoRelConverter(functionLookup);
83+
io.substrait.relation.Rel resultRel = converter.from(relProto);
84+
85+
assertTrue(resultRel instanceof Aggregate);
86+
Aggregate agg = (Aggregate) resultRel;
87+
assertEquals(1, agg.getGroupings().size());
88+
assertEquals(2, agg.getGroupings().get(0).getExpressions().size());
89+
}
90+
91+
@Test
92+
public void testAggregateWithSingleGrouping() {
93+
Expression col1Ref = createFieldReference(0);
94+
Expression col2Ref = createFieldReference(1);
95+
96+
AggregateRel.Grouping grouping =
97+
AggregateRel.Grouping.newBuilder()
98+
.addExpressionReferences(0)
99+
.addExpressionReferences(1)
100+
.build();
101+
102+
// Build an input ReadRel
103+
ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build();
104+
105+
// Build the AggregateRel with the new grouping_expressions field
106+
AggregateRel aggrProto =
107+
AggregateRel.newBuilder()
108+
.setInput(Rel.newBuilder().setRead(readProto))
109+
.addGroupingExpressions(col1Ref)
110+
.addGroupingExpressions(col2Ref)
111+
.addGroupings(grouping)
112+
.build();
113+
114+
Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build();
115+
ProtoRelConverter converter = new ProtoRelConverter(functionLookup);
116+
io.substrait.relation.Rel resultRel = converter.from(relProto);
117+
118+
assertTrue(resultRel instanceof Aggregate);
119+
Aggregate agg = (Aggregate) resultRel;
120+
assertEquals(1, agg.getGroupings().size());
121+
assertEquals(2, agg.getGroupings().get(0).getExpressions().size());
122+
}
123+
124+
@Test
125+
public void testAggregateWithMultipleGroupings() {
126+
Expression col1Ref = createFieldReference(0);
127+
Expression col2Ref = createFieldReference(1);
128+
129+
AggregateRel.Grouping grouping1 =
130+
AggregateRel.Grouping.newBuilder()
131+
.addExpressionReferences(0) // new proto form
132+
.addExpressionReferences(1)
133+
.build();
134+
135+
AggregateRel.Grouping grouping2 =
136+
AggregateRel.Grouping.newBuilder().addExpressionReferences(1).build();
137+
138+
// Build an input ReadRel
139+
ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build();
140+
141+
// Build the AggregateRel with the new grouping_expressions field
142+
AggregateRel aggrProto =
143+
AggregateRel.newBuilder()
144+
.setInput(Rel.newBuilder().setRead(readProto))
145+
.addGroupingExpressions(col1Ref)
146+
.addGroupingExpressions(col2Ref)
147+
.addGroupings(grouping1)
148+
.addGroupings(grouping2)
149+
.build();
150+
151+
Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build();
152+
ProtoRelConverter converter = new ProtoRelConverter(functionLookup);
153+
io.substrait.relation.Rel resultRel = converter.from(relProto);
154+
155+
assertTrue(resultRel instanceof Aggregate);
156+
Aggregate agg = (Aggregate) resultRel;
157+
assertEquals(2, agg.getGroupings().size());
158+
assertEquals(2, agg.getGroupings().get(0).getExpressions().size());
159+
assertEquals(1, agg.getGroupings().get(1).getExpressions().size());
160+
}
161+
}

0 commit comments

Comments
 (0)