Skip to content

Commit 0a13b15

Browse files
gord02 authored and vbarua committed
fix: add tests, fixed bugs in code, and added support for both forms of proto in RelProtoConverter
1 parent 7861549 commit 0a13b15

File tree

3 files changed

+187
-14
lines changed

3 files changed

+187
-14
lines changed

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -628,26 +628,23 @@ protected Aggregate newAggregate(AggregateRel rel) {
628628

629629
List<Aggregate.Grouping> groupings = new ArrayList<>(rel.getGroupingsCount());
630630

631-
// the deprecated form of Grouping is not used
631+
// new proto form is used
632632
if (!rel.getGroupingExpressionsList().isEmpty()) {
633+
633634
List<io.substrait.proto.Expression> allGroupingKeys = rel.getGroupingExpressionsList();
634635

635-
// for every grouping object on aggregate, it has a list of references into the
636-
// aggregate's expressionList for the specific sorting set
637-
for (AggregateRel.Grouping grouping : rel.getGroupingsList()) {
638-
List<io.substrait.proto.Expression> groupingKeys = new ArrayList<>();
639-
for (int key : grouping.getExpressionReferencesList()) {
640-
groupingKeys.add(allGroupingKeys.get(key));
641-
}
642-
groupings.add(
636+
for (int i = 0; i < rel.getGroupingsList().size(); i++) {
637+
// NOTE(review): every grouping receives ALL grouping expressions; grouping.getExpressionReferencesList() is not consulted here
638+
Aggregate.Grouping group =
643639
Aggregate.Grouping.builder()
644640
.expressions(
645-
groupingKeys.stream()
641+
allGroupingKeys.stream()
646642
.map(protoExprConverter::from)
647-
.collect(Collectors.toList()))
648-
.build());
643+
.collect(java.util.stream.Collectors.toList()))
644+
.build();
645+
groupings.add(group);
649646
}
650-
Aggregate.builder().input(input).groupings(groupings);
647+
651648
} else {
652649
// using the deprecated form of Grouping and Aggregate
653650
for (AggregateRel.Grouping grouping : rel.getGroupingsList()) {

core/src/main/java/io/substrait/relation/RelProtoConverter.java

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,11 @@
4747
import io.substrait.relation.physical.NestedLoopJoin;
4848
import io.substrait.type.proto.TypeProtoConverter;
4949
import io.substrait.util.EmptyVisitationContext;
50+
import java.util.ArrayList;
5051
import java.util.Collection;
52+
import java.util.HashMap;
5153
import java.util.List;
54+
import java.util.Map;
5255
import java.util.stream.Collectors;
5356
import java.util.stream.IntStream;
5457

@@ -117,12 +120,45 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel
117120

118121
@Override
119122
public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException {
123+
124+
List<io.substrait.proto.Expression> groupingExpressions = new ArrayList<>();
125+
Map<Expression, Integer> map = new HashMap<>();
126+
int i = 0; // unique reference values for each expression
127+
128+
List<AggregateRel.Grouping> newGroupings = new ArrayList<>();
129+
130+
for (Aggregate.Grouping gp : aggregate.getGroupings()) {
131+
// every grouping has an expression_reference list
132+
List<Integer> expr_refs = new ArrayList<>();
133+
134+
for (Expression e : gp.getExpressions()) {
135+
int ref;
136+
if (!map.containsKey(e)) {
137+
groupingExpressions.add(this.toProto(e)); // put unique expressions into full list
138+
ref = i;
139+
map.put(e, i++);
140+
} else {
141+
ref = map.get(e);
142+
}
143+
expr_refs.add(ref);
144+
}
145+
146+
newGroupings.add(
147+
AggregateRel.Grouping.newBuilder()
148+
.addAllExpressionReferences(expr_refs)
149+
.addAllGroupingExpressions(
150+
gp.getExpressions().stream().map(this::toProto).collect(Collectors.toList()))
151+
.build());
152+
}
153+
120154
AggregateRel.Builder builder =
121155
AggregateRel.newBuilder()
122156
.setInput(toProto(aggregate.getInput()))
123157
.setCommon(common(aggregate))
124158
.addAllGroupings(
125-
aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList()))
159+
newGroupings) // adding groupings with the expression references and grouping
160+
// expressions set
161+
.addAllGroupingExpressions(groupingExpressions) // new grouping_expression attribute
126162
.addAllMeasures(
127163
aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList()));
128164

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
package io.substrait.relation;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertTrue;
5+
6+
import io.substrait.TestBase;
7+
import io.substrait.extension.ExtensionLookup;
8+
import io.substrait.extension.ImmutableExtensionLookup;
9+
import io.substrait.proto.AggregateRel;
10+
import io.substrait.proto.Expression;
11+
import io.substrait.proto.Plan;
12+
import io.substrait.proto.Rel;
13+
import org.junit.jupiter.api.Test;
14+
15+
class AggregateRelTest extends TestBase {
16+
17+
protected static final Plan plan = Plan.newBuilder().build();
18+
protected static final ExtensionLookup functionLookup =
19+
ImmutableExtensionLookup.builder().from(plan).build();
20+
protected static final io.substrait.proto.NamedStruct namedStruct = createSchema();
21+
22+
public static io.substrait.proto.NamedStruct createSchema() {
23+
24+
io.substrait.proto.Type i32Type =
25+
io.substrait.proto.Type.newBuilder()
26+
.setI32(io.substrait.proto.Type.I32.getDefaultInstance())
27+
.build();
28+
29+
// Build a NamedStruct schema with two fields: col1, col2
30+
io.substrait.proto.Type.Struct structType =
31+
io.substrait.proto.Type.Struct.newBuilder().addTypes(i32Type).addTypes(i32Type).build();
32+
33+
return io.substrait.proto.NamedStruct.newBuilder()
34+
.setStruct(structType)
35+
.addNames("col1")
36+
.addNames("col2")
37+
.build();
38+
}
39+
40+
public static io.substrait.proto.Expression createExpression(int col) {
41+
// Build a ReferenceSegment that refers to struct field col
42+
Expression.ReferenceSegment seg1 =
43+
Expression.ReferenceSegment.newBuilder()
44+
.setStructField(
45+
Expression.ReferenceSegment.StructField.newBuilder().setField(col).build())
46+
.build();
47+
48+
// Build a FieldReference that uses the directReference and a rootReference
49+
Expression.FieldReference fieldRef1 =
50+
Expression.FieldReference.newBuilder()
51+
.setDirectReference(seg1)
52+
.setRootReference(Expression.FieldReference.RootReference.getDefaultInstance())
53+
.build();
54+
55+
// Wrap the FieldReference in an Expression.selection
56+
return Expression.newBuilder().setSelection(fieldRef1).build();
57+
}
58+
59+
@Test
60+
public void testDeprecatedGroupingExpressionsAreMapped() {
61+
Expression col1Ref = createExpression(0);
62+
Expression col2Ref = createExpression(1);
63+
64+
AggregateRel.Grouping grouping =
65+
AggregateRel.Grouping.newBuilder()
66+
.addGroupingExpressions(col1Ref) // deprecated proto form
67+
.addGroupingExpressions(col2Ref)
68+
.build();
69+
70+
// Build an input ReadRel
71+
io.substrait.proto.ReadRel readProto =
72+
io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build();
73+
74+
// Build the AggregateRel using only the deprecated per-grouping grouping_expressions
75+
AggregateRel aggrProto =
76+
AggregateRel.newBuilder()
77+
.setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto))
78+
.addGroupings(grouping)
79+
.build();
80+
81+
Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build();
82+
ProtoRelConverter converter = new ProtoRelConverter(functionLookup);
83+
io.substrait.relation.Rel resultRel = converter.from(relProto);
84+
85+
assertTrue(resultRel instanceof Aggregate);
86+
Aggregate agg = (Aggregate) resultRel;
87+
assertEquals(1, agg.getGroupings().size());
88+
assertEquals(2, agg.getGroupings().get(0).getExpressions().size());
89+
90+
// Convert the relation back to proto; the output carries both the deprecated and the new grouping form
91+
RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector);
92+
Rel newProto = relToProtoConverter.toProto(resultRel);
93+
94+
assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount());
95+
assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size());
96+
assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size());
97+
}
98+
99+
@Test
100+
public void testNewAggregateProtoForm() {
101+
Expression col1Ref = createExpression(0);
102+
Expression col2Ref = createExpression(1);
103+
104+
AggregateRel.Grouping grouping =
105+
AggregateRel.Grouping.newBuilder()
106+
.addExpressionReferences(0) // new proto form
107+
.addExpressionReferences(1)
108+
.build();
109+
110+
// Build an input ReadRel
111+
io.substrait.proto.ReadRel readProto =
112+
io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build();
113+
114+
// Build the AggregateRel with the new grouping_expressions field
115+
AggregateRel aggrProto =
116+
AggregateRel.newBuilder()
117+
.setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto))
118+
.addGroupingExpressions(col1Ref)
119+
.addGroupingExpressions(col2Ref)
120+
.addGroupings(grouping)
121+
.build();
122+
123+
Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build();
124+
ProtoRelConverter converter = new ProtoRelConverter(functionLookup);
125+
io.substrait.relation.Rel resultRel = converter.from(relProto);
126+
127+
assertTrue(resultRel instanceof Aggregate);
128+
Aggregate agg = (Aggregate) resultRel;
129+
assertEquals(1, agg.getGroupings().size());
130+
assertEquals(2, agg.getGroupings().get(0).getExpressions().size());
131+
132+
// Convert the relation back to proto; the output carries both the deprecated and the new grouping form
133+
RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector);
134+
Rel newProto = relToProtoConverter.toProto(resultRel);
135+
136+
assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount());
137+
assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size());
138+
assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size());
139+
}
140+
}

0 commit comments

Comments
 (0)