From 3ee086c6987f4177d9ce4cf74045e3bbd81d8fd5 Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Wed, 24 Sep 2025 11:46:39 -0400 Subject: [PATCH 1/9] fix: adding support for both deprecated and new proto representation of Grouping for the AggregateRel --- .../substrait/relation/ProtoRelConverter.java | 43 +++++++++++++++---- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index a055a8f97..493a5dbad 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -626,15 +626,40 @@ protected Aggregate newAggregate(AggregateRel rel) { new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter); List groupings = new ArrayList<>(rel.getGroupingsCount()); - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - groupings.add( - Aggregate.Grouping.builder() - .expressions( - grouping.getGroupingExpressionsList().stream() - .map(protoExprConverter::from) - .collect(java.util.stream.Collectors.toList())) - .build()); - } + +// the deprecated form of Grouping is not used + if (!rel.getGroupingExpressionsList().isEmpty()) { + List allGroupingKeys = rel.getGroupingExpressionsList(); + +// for every grouping object on aggregate, it has a list of references into the aggregate's expressionList for the specific sorting set + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + List groupingKeys = new ArrayList<>(); + for (int key: grouping.getExpressionReferencesList()) { + groupingKeys.add(allGroupingKeys.get(key)); + } + groupings.add( + Aggregate.Grouping.builder() + .expressions( + groupingKeys.stream() + .map(protoExprConverter::from) + .collect(Collectors.toList())) + .build()); + } + Aggregate.builder().input(input).groupings(groupings); + }else{ + // using the deprecated form of Grouping and Aggregate + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + groupings.add( + Aggregate.Grouping.builder() + .expressions( + grouping.getGroupingExpressionsList().stream() + .map(protoExprConverter::from) + .collect(Collectors.toList())) + .build()); + } + } + + List measures = new ArrayList<>(rel.getMeasuresCount()); for (AggregateRel.Measure measure : rel.getMeasuresList()) { measures.add( From e28168f6afaf8ab41f24c3ab9bb28fa75270931f Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Wed, 24 Sep 2025 15:01:06 -0400 Subject: [PATCH 2/9] fix: adding support for both deprecated and new proto of Groupings for the AggregateRel --- .../substrait/relation/ProtoRelConverter.java | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 493a5dbad..574762309 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -627,38 +627,38 @@ protected Aggregate newAggregate(AggregateRel rel) { List groupings = new ArrayList<>(rel.getGroupingsCount()); -// the deprecated form of Grouping is not used - if (!rel.getGroupingExpressionsList().isEmpty()) { - List allGroupingKeys = rel.getGroupingExpressionsList(); - -// for every grouping object on aggregate, it has a list of references into the aggregate's expressionList for the specific sorting set - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - List groupingKeys = new ArrayList<>(); - for (int key: grouping.getExpressionReferencesList()) { - groupingKeys.add(allGroupingKeys.get(key)); - } - groupings.add( - Aggregate.Grouping.builder() - .expressions( - groupingKeys.stream() - .map(protoExprConverter::from) - .collect(Collectors.toList())) - .build()); - } - Aggregate.builder().input(input).groupings(groupings); - }else{ - // using the deprecated form of Grouping and Aggregate - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - groupings.add( - Aggregate.Grouping.builder() - .expressions( - grouping.getGroupingExpressionsList().stream() - .map(protoExprConverter::from) - .collect(Collectors.toList())) - .build()); - } + // the deprecated form of Grouping is not used + if (!rel.getGroupingExpressionsList().isEmpty()) { + List allGroupingKeys = rel.getGroupingExpressionsList(); + + // for every grouping object on aggregate, it has a list of references into the + // aggregate's expressionList for the specific sorting set + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + List groupingKeys = new ArrayList<>(); + for (int key : grouping.getExpressionReferencesList()) { + groupingKeys.add(allGroupingKeys.get(key)); + } + groupings.add( + Aggregate.Grouping.builder() + .expressions( + groupingKeys.stream() + .map(protoExprConverter::from) + .collect(Collectors.toList())) + .build()); } - + Aggregate.builder().input(input).groupings(groupings); + } else { + // using the deprecated form of Grouping and Aggregate + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + groupings.add( + Aggregate.Grouping.builder() + .expressions( + grouping.getGroupingExpressionsList().stream() + .map(protoExprConverter::from) + .collect(Collectors.toList())) + .build()); + } + } List measures = new ArrayList<>(rel.getMeasuresCount()); for (AggregateRel.Measure measure : rel.getMeasuresList()) { From 50ee1bcd6d0f3c5b20374adfbd905f4bf1d39579 Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Mon, 29 Sep 2025 17:05:25 -0400 Subject: [PATCH 3/9] add tests, fixed bugs in code, and added support for both forms of proto in RelProtoConverter --- .../substrait/relation/ProtoRelConverter.java | 28 ++-- .../substrait/relation/RelProtoConverter.java | 40 ++++- .../substrait/relation/AggregateRelTest.java | 139 ++++++++++++++++++ 3 files changed, 186 insertions(+), 21 deletions(-) create mode 100644 core/src/test/java/io/substrait/relation/AggregateRelTest.java diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 574762309..9730a930a 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -627,26 +627,20 @@ protected Aggregate newAggregate(AggregateRel rel) { List groupings = new ArrayList<>(rel.getGroupingsCount()); - // the deprecated form of Grouping is not used + // new proto form is used if (!rel.getGroupingExpressionsList().isEmpty()) { - List allGroupingKeys = rel.getGroupingExpressionsList(); - // for every grouping object on aggregate, it has a list of references into the - // aggregate's expressionList for the specific sorting set - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - List groupingKeys = new ArrayList<>(); - for (int key : grouping.getExpressionReferencesList()) { - groupingKeys.add(allGroupingKeys.get(key)); + List allGroupingKeys = rel.getGroupingExpressionsList(); + + for (int i = 0; i < rel.getGroupingsList().size(); i++) { + // put all groupingExpressions into the group + Aggregate.Grouping group = Aggregate.Grouping.builder() + .expressions(allGroupingKeys.stream() + .map(protoExprConverter::from) + .collect(java.util.stream.Collectors.toList())).build(); + groupings.add(group); } - groupings.add( - Aggregate.Grouping.builder() - .expressions( - groupingKeys.stream() - .map(protoExprConverter::from) - .collect(Collectors.toList())) - .build()); - } - Aggregate.builder().input(input).groupings(groupings); + } else { // using the deprecated form of Grouping and Aggregate for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index a1da96b03..4dd67ba06 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -47,8 +47,11 @@ import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.proto.TypeProtoConverter; import io.substrait.util.EmptyVisitationContext; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -61,7 +64,7 @@ public class RelProtoConverter protected final ExtensionCollector extensionCollector; - public RelProtoConverter(ExtensionCollector extensionCollector) { + public RelProtoConverter(ExtensionCollector extensionCollector) { this.extensionCollector = extensionCollector; this.exprProtoConverter = new ExpressionProtoConverter(extensionCollector, this); this.typeProtoConverter = new TypeProtoConverter(extensionCollector); @@ -117,12 +120,41 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel @Override public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { - AggregateRel.Builder builder = + + List groupingExpressions = new ArrayList<>(); + Map map = new HashMap<>(); + int i = 0;// unique reference values for each expression + + List newGroupings = new ArrayList<>(); + + for(Aggregate.Grouping gp : aggregate.getGroupings()) { + // every grouping has an expression_reference list + List expr_refs = new ArrayList<>(); + + for(Expression e: gp.getExpressions()) { + int ref; + if(!map.containsKey(e)) { + groupingExpressions.add(this.toProto(e)); // put unique expressions into full list + ref = i; + map.put(e, i++); + }else{ + ref = map.get(e); + } + expr_refs.add(ref); + } + + newGroupings.add(AggregateRel.Grouping.newBuilder() + .addAllExpressionReferences(expr_refs) + .addAllGroupingExpressions(gp.getExpressions().stream().map(this::toProto).collect(Collectors.toList())) + .build()); + } + + AggregateRel.Builder builder = AggregateRel.newBuilder() .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) - .addAllGroupings( - aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList())) + .addAllGroupings(newGroupings) // adding groupings with the expression references and grouping expressions set + .addAllGroupingExpressions(groupingExpressions) // new grouping_expression attribute .addAllMeasures( aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList())); diff --git a/core/src/test/java/io/substrait/relation/AggregateRelTest.java b/core/src/test/java/io/substrait/relation/AggregateRelTest.java new file mode 100644 index 000000000..15dfb7d9b --- /dev/null +++ b/core/src/test/java/io/substrait/relation/AggregateRelTest.java @@ -0,0 +1,139 @@ +package io.substrait.relation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.extension.ExtensionLookup; +import io.substrait.extension.ImmutableExtensionLookup; +import io.substrait.proto.AggregateRel; +import io.substrait.proto.Expression; +import io.substrait.proto.Plan; +import io.substrait.proto.Rel; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +class AggregateRelTest extends TestBase { + + protected static final Plan plan = Plan.newBuilder().build(); + protected static final ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder().from(plan).build(); + protected static final io.substrait.proto.NamedStruct namedStruct = createSchema(); + + public static io.substrait.proto.NamedStruct createSchema() { + + io.substrait.proto.Type i32Type = + io.substrait.proto.Type.newBuilder() + .setI32(io.substrait.proto.Type.I32.getDefaultInstance()) + .build(); + + // Build a NamedStruct schema with two fields: col1, col2 + io.substrait.proto.Type.Struct structType = + io.substrait.proto.Type.Struct.newBuilder().addTypes(i32Type).addTypes(i32Type).build(); + + return io.substrait.proto.NamedStruct.newBuilder() + .setStruct(structType) + .addNames("col1") + .addNames("col2") + .build(); + } + + public static io.substrait.proto.Expression createExpression(int col) { + // Build a ReferenceSegment that refers to struct field col + Expression.ReferenceSegment seg1 = + Expression.ReferenceSegment.newBuilder() + .setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField(col).build()) + .build(); + + // Build a FieldReference that uses the directReference and a rootReference + Expression.FieldReference fieldRef1 = + Expression.FieldReference.newBuilder() + .setDirectReference(seg1) + .setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()) + .build(); + + // Wrap the FieldReference in an Expression.selection + return Expression.newBuilder().setSelection(fieldRef1).build(); + } + + + @Test + public void testDeprecatedGroupingExpressionsAreMapped() { + Expression col1Ref = createExpression(0); + Expression col2Ref = createExpression(1); + + AggregateRel.Grouping grouping = + AggregateRel.Grouping.newBuilder() + .addGroupingExpressions(col1Ref) // deprecated proto form + .addGroupingExpressions(col2Ref) + .build(); + + // Build an input ReadRel + io.substrait.proto.ReadRel readProto = + io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto)) + .addGroupings(grouping) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(1, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + } + + @Test + public void testNewAggregateProtoForm() { + Expression col1Ref = createExpression(0); + Expression col2Ref = createExpression(1); + + AggregateRel.Grouping grouping = + AggregateRel.Grouping.newBuilder() + .addExpressionReferences(0) // new proto form + .addExpressionReferences(1) + .build(); + + // Build an input ReadRel + io.substrait.proto.ReadRel readProto = + io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto)) + .addGroupingExpressions(col1Ref) + .addGroupingExpressions(col2Ref) + .addGroupings(grouping) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(1, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + + + // Relation to Proto where both deprecated and new form are implemented + RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector); + Rel newProto = relToProtoConverter.toProto(resultRel); + + assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount()); + assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size()); + assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size()); + } +} From 6ffc6ee5f8a1412ffcd019692b3c9ad80462fa24 Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Mon, 29 Sep 2025 17:05:25 -0400 Subject: [PATCH 4/9] fix: add tests, fixed bugs in code, and added support for both forms of proto in RelProtoConverter --- .../substrait/relation/ProtoRelConverter.java | 23 ++- .../substrait/relation/RelProtoConverter.java | 38 ++++- .../substrait/relation/AggregateRelTest.java | 140 ++++++++++++++++++ 3 files changed, 187 insertions(+), 14 deletions(-) create mode 100644 core/src/test/java/io/substrait/relation/AggregateRelTest.java diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 574762309..9913f0188 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -627,26 +627,23 @@ protected Aggregate newAggregate(AggregateRel rel) { List groupings = new ArrayList<>(rel.getGroupingsCount()); - // the deprecated form of Grouping is not used + // new proto form is used if (!rel.getGroupingExpressionsList().isEmpty()) { + List allGroupingKeys = rel.getGroupingExpressionsList(); - // for every grouping object on aggregate, it has a list of references into the - // aggregate's expressionList for the specific sorting set - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - List groupingKeys = new ArrayList<>(); - for (int key : grouping.getExpressionReferencesList()) { - groupingKeys.add(allGroupingKeys.get(key)); - } - groupings.add( + for (int i = 0; i < rel.getGroupingsList().size(); i++) { + // put all groupingExpressions into the group + Aggregate.Grouping group = Aggregate.Grouping.builder() .expressions( - groupingKeys.stream() + allGroupingKeys.stream() .map(protoExprConverter::from) - .collect(Collectors.toList())) - .build()); + .collect(java.util.stream.Collectors.toList())) + .build(); + groupings.add(group); } - Aggregate.builder().input(input).groupings(groupings); + } else { // using the deprecated form of Grouping and Aggregate for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index a1da96b03..dc7a42a2a 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -47,8 +47,11 @@ import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.proto.TypeProtoConverter; import io.substrait.util.EmptyVisitationContext; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -117,12 +120,45 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel @Override public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { + + List groupingExpressions = new ArrayList<>(); + Map map = new HashMap<>(); + int i = 0; // unique reference values for each expression + + List newGroupings = new ArrayList<>(); + + for (Aggregate.Grouping gp : aggregate.getGroupings()) { + // every grouping has an expression_reference list + List expr_refs = new ArrayList<>(); + + for (Expression e : gp.getExpressions()) { + int ref; + if (!map.containsKey(e)) { + groupingExpressions.add(this.toProto(e)); // put unique expressions into full list + ref = i; + map.put(e, i++); + } else { + ref = map.get(e); + } + expr_refs.add(ref); + } + + newGroupings.add( + AggregateRel.Grouping.newBuilder() + .addAllExpressionReferences(expr_refs) + .addAllGroupingExpressions( + gp.getExpressions().stream().map(this::toProto).collect(Collectors.toList())) + .build()); + } + AggregateRel.Builder builder = AggregateRel.newBuilder() .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) .addAllGroupings( - aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList())) + newGroupings) // adding groupings with the expression references and grouping + // expressions set + .addAllGroupingExpressions(groupingExpressions) // new grouping_expression attribute .addAllMeasures( aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList())); diff --git a/core/src/test/java/io/substrait/relation/AggregateRelTest.java b/core/src/test/java/io/substrait/relation/AggregateRelTest.java new file mode 100644 index 000000000..d71c90b85 --- /dev/null +++ b/core/src/test/java/io/substrait/relation/AggregateRelTest.java @@ -0,0 +1,140 @@ +package io.substrait.relation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.extension.ExtensionLookup; +import io.substrait.extension.ImmutableExtensionLookup; +import io.substrait.proto.AggregateRel; +import io.substrait.proto.Expression; +import io.substrait.proto.Plan; +import io.substrait.proto.Rel; +import org.junit.jupiter.api.Test; + +class AggregateRelTest extends TestBase { + + protected static final Plan plan = Plan.newBuilder().build(); + protected static final ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder().from(plan).build(); + protected static final io.substrait.proto.NamedStruct namedStruct = createSchema(); + + public static io.substrait.proto.NamedStruct createSchema() { + + io.substrait.proto.Type i32Type = + io.substrait.proto.Type.newBuilder() + .setI32(io.substrait.proto.Type.I32.getDefaultInstance()) + .build(); + + // Build a NamedStruct schema with two fields: col1, col2 + io.substrait.proto.Type.Struct structType = + io.substrait.proto.Type.Struct.newBuilder().addTypes(i32Type).addTypes(i32Type).build(); + + return io.substrait.proto.NamedStruct.newBuilder() + .setStruct(structType) + .addNames("col1") + .addNames("col2") + .build(); + } + + public static io.substrait.proto.Expression createExpression(int col) { + // Build a ReferenceSegment that refers to struct field col + Expression.ReferenceSegment seg1 = + Expression.ReferenceSegment.newBuilder() + .setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField(col).build()) + .build(); + + // Build a FieldReference that uses the directReference and a rootReference + Expression.FieldReference fieldRef1 = + Expression.FieldReference.newBuilder() + .setDirectReference(seg1) + .setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()) + .build(); + + // Wrap the FieldReference in an Expression.selection + return Expression.newBuilder().setSelection(fieldRef1).build(); + } + + @Test + public void testDeprecatedGroupingExpressionsAreMapped() { + Expression col1Ref = createExpression(0); + Expression col2Ref = createExpression(1); + + AggregateRel.Grouping grouping = + AggregateRel.Grouping.newBuilder() + .addGroupingExpressions(col1Ref) // deprecated proto form + .addGroupingExpressions(col2Ref) + .build(); + + // Build an input ReadRel + io.substrait.proto.ReadRel readProto = + io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto)) + .addGroupings(grouping) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(1, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + + // Relation to Proto where both deprecated and new form are implemented + RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector); + Rel newProto = relToProtoConverter.toProto(resultRel); + + assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount()); + assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size()); + assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size()); + } + + @Test + public void testNewAggregateProtoForm() { + Expression col1Ref = createExpression(0); + Expression col2Ref = createExpression(1); + + AggregateRel.Grouping grouping = + AggregateRel.Grouping.newBuilder() + .addExpressionReferences(0) // new proto form + .addExpressionReferences(1) + .build(); + + // Build an input ReadRel + io.substrait.proto.ReadRel readProto = + io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto)) + .addGroupingExpressions(col1Ref) + .addGroupingExpressions(col2Ref) + .addGroupings(grouping) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(1, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + + // Relation to Proto where both deprecated and new form are implemented + RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector); + Rel newProto = relToProtoConverter.toProto(resultRel); + + assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount()); + assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size()); + assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size()); + } +} From b793c857ea1be2637b510cbca29554c00b0f9bc6 Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Wed, 24 Sep 2025 11:46:39 -0400 Subject: [PATCH 5/9] fix: adding support for both deprecated and new proto representation of Grouping for the AggregateRel --- .../substrait/relation/ProtoRelConverter.java | 43 +++++++++++++++---- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 5ed4b5d4d..359288302 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -627,15 +627,40 @@ protected Aggregate newAggregate(AggregateRel rel) { new ProtoAggregateFunctionConverter(lookup, extensions, protoExprConverter); List groupings = new ArrayList<>(rel.getGroupingsCount()); - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - groupings.add( - Aggregate.Grouping.builder() - .expressions( - grouping.getGroupingExpressionsList().stream() - .map(protoExprConverter::from) - .collect(java.util.stream.Collectors.toList())) - .build()); - } + +// the deprecated form of Grouping is not used + if (!rel.getGroupingExpressionsList().isEmpty()) { + List allGroupingKeys = rel.getGroupingExpressionsList(); + +// for every grouping object on aggregate, it has a list of references into the aggregate's expressionList for the specific sorting set + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + List groupingKeys = new ArrayList<>(); + for (int key: grouping.getExpressionReferencesList()) { + groupingKeys.add(allGroupingKeys.get(key)); + } + groupings.add( + Aggregate.Grouping.builder() + .expressions( + groupingKeys.stream() + .map(protoExprConverter::from) + .collect(Collectors.toList())) + .build()); + } + Aggregate.builder().input(input).groupings(groupings); + }else{ + // using the deprecated form of Grouping and Aggregate + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + groupings.add( + Aggregate.Grouping.builder() + .expressions( + grouping.getGroupingExpressionsList().stream() + .map(protoExprConverter::from) + .collect(Collectors.toList())) + .build()); + } + } + + List measures = new ArrayList<>(rel.getMeasuresCount()); for (AggregateRel.Measure measure : rel.getMeasuresList()) { measures.add( From 786154910df5d4cd6bad353c36ee33c1485c2c58 Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Wed, 24 Sep 2025 15:01:06 -0400 Subject: [PATCH 6/9] fix: adding support for both deprecated and new proto of Groupings for the AggregateRel --- .../substrait/relation/ProtoRelConverter.java | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 359288302..dda527831 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -628,38 +628,38 @@ protected Aggregate newAggregate(AggregateRel rel) { List groupings = new ArrayList<>(rel.getGroupingsCount()); -// the deprecated form of Grouping is not used - if (!rel.getGroupingExpressionsList().isEmpty()) { - List allGroupingKeys = rel.getGroupingExpressionsList(); - -// for every grouping object on aggregate, it has a list of references into the aggregate's expressionList for the specific sorting set - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - List groupingKeys = new ArrayList<>(); - for (int key: grouping.getExpressionReferencesList()) { - groupingKeys.add(allGroupingKeys.get(key)); - } - groupings.add( - Aggregate.Grouping.builder() - .expressions( - groupingKeys.stream() - .map(protoExprConverter::from) - .collect(Collectors.toList())) - .build()); - } - Aggregate.builder().input(input).groupings(groupings); - }else{ - // using the deprecated form of Grouping and Aggregate - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - groupings.add( - Aggregate.Grouping.builder() - .expressions( - grouping.getGroupingExpressionsList().stream() - .map(protoExprConverter::from) - .collect(Collectors.toList())) - .build()); - } + // the deprecated form of Grouping is not used + if (!rel.getGroupingExpressionsList().isEmpty()) { + List allGroupingKeys = rel.getGroupingExpressionsList(); + + // for every grouping object on aggregate, it has a list of references into the + // aggregate's expressionList for the specific sorting set + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + List groupingKeys = new ArrayList<>(); + for (int key : grouping.getExpressionReferencesList()) { + groupingKeys.add(allGroupingKeys.get(key)); + } + groupings.add( + Aggregate.Grouping.builder() + .expressions( + groupingKeys.stream() + .map(protoExprConverter::from) + .collect(Collectors.toList())) + .build()); } - + Aggregate.builder().input(input).groupings(groupings); + } else { + // using the deprecated form of Grouping and Aggregate + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + groupings.add( + Aggregate.Grouping.builder() + .expressions( + grouping.getGroupingExpressionsList().stream() + .map(protoExprConverter::from) + .collect(Collectors.toList())) + .build()); + } + } List measures = new ArrayList<>(rel.getMeasuresCount()); for (AggregateRel.Measure measure : rel.getMeasuresList()) { From 0a13b155732132f8fb4c24d2c397f074eec25409 Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Mon, 29 Sep 2025 17:05:25 -0400 Subject: [PATCH 7/9] fix: add tests, fixed bugs in code, and added support for both forms of proto in RelProtoConverter --- .../substrait/relation/ProtoRelConverter.java | 23 ++- .../substrait/relation/RelProtoConverter.java | 38 ++++- .../substrait/relation/AggregateRelTest.java | 140 ++++++++++++++++++ 3 files changed, 187 insertions(+), 14 deletions(-) create mode 100644 core/src/test/java/io/substrait/relation/AggregateRelTest.java diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index dda527831..d6c041dad 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -628,26 +628,23 @@ protected Aggregate newAggregate(AggregateRel rel) { List groupings = new ArrayList<>(rel.getGroupingsCount()); - // the deprecated form of Grouping is not used + // new proto form is used if (!rel.getGroupingExpressionsList().isEmpty()) { + List allGroupingKeys = rel.getGroupingExpressionsList(); - // for every grouping object on aggregate, it has a list of references into the - // aggregate's expressionList for the specific sorting set - for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { - List groupingKeys = new ArrayList<>(); - for (int key : grouping.getExpressionReferencesList()) { - groupingKeys.add(allGroupingKeys.get(key)); - } - groupings.add( + for (int i = 0; i < rel.getGroupingsList().size(); i++) { + // put all groupingExpressions into the group + Aggregate.Grouping group = Aggregate.Grouping.builder() .expressions( - groupingKeys.stream() + allGroupingKeys.stream() .map(protoExprConverter::from) - .collect(Collectors.toList())) - .build()); + .collect(java.util.stream.Collectors.toList())) + .build(); + groupings.add(group); } - Aggregate.builder().input(input).groupings(groupings); + } else { // using the deprecated form of Grouping and Aggregate for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index a1da96b03..dc7a42a2a 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -47,8 +47,11 @@ import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.proto.TypeProtoConverter; import io.substrait.util.EmptyVisitationContext; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -117,12 +120,45 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel @Override public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { + + List groupingExpressions = new ArrayList<>(); + Map map = new HashMap<>(); + int i = 0; // unique reference values for each expression + + List newGroupings = new ArrayList<>(); + + for (Aggregate.Grouping gp : aggregate.getGroupings()) { + // every grouping has an expression_reference list + List expr_refs = new ArrayList<>(); + + for (Expression e : gp.getExpressions()) { + int ref; + if (!map.containsKey(e)) { + groupingExpressions.add(this.toProto(e)); // put unique expressions into full list + ref = i; + map.put(e, i++); + } else { + ref = map.get(e); + } + expr_refs.add(ref); + } + + newGroupings.add( + AggregateRel.Grouping.newBuilder() + .addAllExpressionReferences(expr_refs) + .addAllGroupingExpressions( + gp.getExpressions().stream().map(this::toProto).collect(Collectors.toList())) + .build()); + } + AggregateRel.Builder builder = AggregateRel.newBuilder() .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) .addAllGroupings( - aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList())) + newGroupings) // adding groupings with the expression references and grouping + // expressions set + .addAllGroupingExpressions(groupingExpressions) // new grouping_expression attribute .addAllMeasures( aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList())); diff --git a/core/src/test/java/io/substrait/relation/AggregateRelTest.java b/core/src/test/java/io/substrait/relation/AggregateRelTest.java new file mode 100644 index 000000000..d71c90b85 --- /dev/null +++ b/core/src/test/java/io/substrait/relation/AggregateRelTest.java @@ -0,0 +1,140 @@ +package io.substrait.relation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.extension.ExtensionLookup; +import io.substrait.extension.ImmutableExtensionLookup; +import io.substrait.proto.AggregateRel; +import io.substrait.proto.Expression; +import io.substrait.proto.Plan; +import io.substrait.proto.Rel; +import org.junit.jupiter.api.Test; + +class AggregateRelTest extends TestBase { + + protected static final Plan plan = Plan.newBuilder().build(); + protected static final ExtensionLookup functionLookup = + ImmutableExtensionLookup.builder().from(plan).build(); + protected static final io.substrait.proto.NamedStruct namedStruct = createSchema(); + + public static io.substrait.proto.NamedStruct createSchema() { + + io.substrait.proto.Type i32Type = + io.substrait.proto.Type.newBuilder() + .setI32(io.substrait.proto.Type.I32.getDefaultInstance()) + .build(); + + // Build a NamedStruct schema with two fields: col1, col2 + io.substrait.proto.Type.Struct structType = + io.substrait.proto.Type.Struct.newBuilder().addTypes(i32Type).addTypes(i32Type).build(); + + return io.substrait.proto.NamedStruct.newBuilder() + .setStruct(structType) + .addNames("col1") + .addNames("col2") + .build(); + } + + public static io.substrait.proto.Expression createExpression(int col) { + // Build a ReferenceSegment that refers to struct field col + Expression.ReferenceSegment seg1 = + Expression.ReferenceSegment.newBuilder() + .setStructField( + Expression.ReferenceSegment.StructField.newBuilder().setField(col).build()) + .build(); + + // Build a FieldReference that uses the directReference and a rootReference + Expression.FieldReference fieldRef1 = + Expression.FieldReference.newBuilder() + .setDirectReference(seg1) + .setRootReference(Expression.FieldReference.RootReference.getDefaultInstance()) + .build(); + + // Wrap the FieldReference in an Expression.selection + return Expression.newBuilder().setSelection(fieldRef1).build(); + } + + @Test + public void testDeprecatedGroupingExpressionsAreMapped() { + Expression col1Ref = createExpression(0); + Expression col2Ref = createExpression(1); + + AggregateRel.Grouping grouping = + AggregateRel.Grouping.newBuilder() + .addGroupingExpressions(col1Ref) // deprecated proto form + .addGroupingExpressions(col2Ref) + .build(); + + // Build an input ReadRel + io.substrait.proto.ReadRel readProto = + io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto)) + .addGroupings(grouping) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(1, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + + // Relation to Proto where both deprecated and new form are implemented + RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector); + Rel newProto = relToProtoConverter.toProto(resultRel); + + assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount()); + assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size()); + assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size()); + } + + @Test + public void testNewAggregateProtoForm() { + Expression col1Ref = createExpression(0); + Expression col2Ref = createExpression(1); + + AggregateRel.Grouping grouping = + AggregateRel.Grouping.newBuilder() + .addExpressionReferences(0) // new proto form + .addExpressionReferences(1) + .build(); + + // Build an input ReadRel + io.substrait.proto.ReadRel readProto = + io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto)) + .addGroupingExpressions(col1Ref) + .addGroupingExpressions(col2Ref) + .addGroupings(grouping) + .build(); + + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(1, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + + // Relation to Proto where both deprecated and new form are implemented + RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector); + Rel newProto = relToProtoConverter.toProto(resultRel); + + assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount()); + assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size()); + assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size()); + } +} From cc7ad0b776105765ea070c7c6a51c469098291ec Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Thu, 2 Oct 2025 09:50:00 -0400 Subject: [PATCH 8/9] fix: fixed grouping creation, removed rel to proto support and added multi grouping set test --- .../substrait/relation/ProtoRelConverter.java | 20 +++--- .../substrait/relation/RelProtoConverter.java | 38 +--------- .../substrait/relation/AggregateRelTest.java | 71 ++++++++++++------- 3 files changed, 55 insertions(+), 74 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index 9913f0188..7759a7f97 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -632,18 +632,14 @@ protected Aggregate newAggregate(AggregateRel rel) { List allGroupingKeys = rel.getGroupingExpressionsList(); - for (int i = 0; i < rel.getGroupingsList().size(); i++) { - // put all groupingExpressions into the group - Aggregate.Grouping group = - Aggregate.Grouping.builder() - .expressions( - allGroupingKeys.stream() - .map(protoExprConverter::from) - .collect(java.util.stream.Collectors.toList())) - .build(); - groupings.add(group); + for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { + List references = grouping.getExpressionReferencesList(); + List groupExpressions = new ArrayList<>(); + for (int ref : references) { + groupExpressions.add(protoExprConverter.from(allGroupingKeys.get(ref))); + } + groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build()); } - } else { // using the deprecated form of Grouping and Aggregate for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { @@ -652,7 +648,7 @@ protected Aggregate newAggregate(AggregateRel rel) { .expressions( grouping.getGroupingExpressionsList().stream() .map(protoExprConverter::from) - .collect(Collectors.toList())) + .collect(java.util.stream.Collectors.toList())) .build()); } } diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index dc7a42a2a..a1da96b03 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -47,11 +47,8 @@ import io.substrait.relation.physical.NestedLoopJoin; import io.substrait.type.proto.TypeProtoConverter; import io.substrait.util.EmptyVisitationContext; -import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -120,45 +117,12 @@ private io.substrait.proto.Expression.FieldReference toProto(FieldReference fiel @Override public Rel visit(Aggregate aggregate, EmptyVisitationContext context) throws RuntimeException { - - List groupingExpressions = new ArrayList<>(); - Map map = new HashMap<>(); - int i = 0; // unique reference values for each expression - - List newGroupings = new ArrayList<>(); - - for (Aggregate.Grouping gp : aggregate.getGroupings()) { - // every grouping has an expression_reference list - List expr_refs = new ArrayList<>(); - - for (Expression e : gp.getExpressions()) { - int ref; - if (!map.containsKey(e)) { - groupingExpressions.add(this.toProto(e)); // put unique expressions into full list - ref = i; - map.put(e, i++); - } else { - ref = map.get(e); - } - expr_refs.add(ref); - } - - newGroupings.add( - AggregateRel.Grouping.newBuilder() - .addAllExpressionReferences(expr_refs) - .addAllGroupingExpressions( - gp.getExpressions().stream().map(this::toProto).collect(Collectors.toList())) - .build()); - } - AggregateRel.Builder builder = AggregateRel.newBuilder() .setInput(toProto(aggregate.getInput())) .setCommon(common(aggregate)) .addAllGroupings( - newGroupings) // adding groupings with the expression references and grouping - // expressions set - .addAllGroupingExpressions(groupingExpressions) // new grouping_expression attribute + aggregate.getGroupings().stream().map(this::toProto).collect(Collectors.toList())) .addAllMeasures( aggregate.getMeasures().stream().map(this::toProto).collect(Collectors.toList())); diff --git a/core/src/test/java/io/substrait/relation/AggregateRelTest.java b/core/src/test/java/io/substrait/relation/AggregateRelTest.java index d71c90b85..b18623b33 100644 --- a/core/src/test/java/io/substrait/relation/AggregateRelTest.java +++ b/core/src/test/java/io/substrait/relation/AggregateRelTest.java @@ -9,6 +9,7 @@ import io.substrait.proto.AggregateRel; import io.substrait.proto.Expression; import io.substrait.proto.Plan; +import io.substrait.proto.ReadRel; import io.substrait.proto.Rel; import org.junit.jupiter.api.Test; @@ -37,7 +38,7 @@ public static io.substrait.proto.NamedStruct createSchema() { .build(); } - public static io.substrait.proto.Expression createExpression(int col) { + public static io.substrait.proto.Expression createFieldReference(int col) { // Build a ReferenceSegment that refers to struct field col Expression.ReferenceSegment seg1 = Expression.ReferenceSegment.newBuilder() @@ -58,8 +59,8 @@ public static io.substrait.proto.Expression createExpression(int col) { @Test public void testDeprecatedGroupingExpressionsAreMapped() { - Expression col1Ref = createExpression(0); - Expression col2Ref = createExpression(1); + Expression col1Ref = createFieldReference(0); + Expression col2Ref = createFieldReference(1); AggregateRel.Grouping grouping = AggregateRel.Grouping.newBuilder() @@ -68,13 +69,12 @@ public void testDeprecatedGroupingExpressionsAreMapped() { .build(); // Build an input ReadRel - io.substrait.proto.ReadRel readProto = - io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); // Build the AggregateRel with the new grouping_expressions field AggregateRel aggrProto = AggregateRel.newBuilder() - .setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto)) + .setInput(Rel.newBuilder().setRead(readProto)) .addGroupings(grouping) .build(); @@ -86,20 +86,12 @@ public void testDeprecatedGroupingExpressionsAreMapped() { Aggregate agg = (Aggregate) resultRel; assertEquals(1, agg.getGroupings().size()); assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); - - // Relation to Proto where both deprecated and new form are implemented - RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector); - Rel newProto = relToProtoConverter.toProto(resultRel); - - assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount()); - assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size()); - assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size()); } @Test public void testNewAggregateProtoForm() { - Expression col1Ref = createExpression(0); - Expression col2Ref = createExpression(1); + Expression col1Ref = createFieldReference(0); + Expression col2Ref = createFieldReference(1); AggregateRel.Grouping grouping = AggregateRel.Grouping.newBuilder() @@ -108,13 +100,12 @@ public void testNewAggregateProtoForm() { .build(); // Build an input ReadRel - io.substrait.proto.ReadRel readProto = - io.substrait.proto.ReadRel.newBuilder().setBaseSchema(namedStruct).build(); + ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); // Build the AggregateRel with the new grouping_expressions field AggregateRel aggrProto = AggregateRel.newBuilder() - .setInput(io.substrait.proto.Rel.newBuilder().setRead(readProto)) + .setInput(Rel.newBuilder().setRead(readProto)) .addGroupingExpressions(col1Ref) .addGroupingExpressions(col2Ref) .addGroupings(grouping) @@ -128,13 +119,43 @@ public void testNewAggregateProtoForm() { Aggregate agg = (Aggregate) resultRel; assertEquals(1, agg.getGroupings().size()); assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + } + + @Test + public void testNewAggregateProtoFormMultipleGroupings() { + Expression col1Ref = createFieldReference(0); + Expression col2Ref = createFieldReference(1); + + AggregateRel.Grouping grouping1 = + AggregateRel.Grouping.newBuilder() + .addExpressionReferences(0) // new proto form + .addExpressionReferences(1) + .build(); + + AggregateRel.Grouping grouping2 = + AggregateRel.Grouping.newBuilder().addExpressionReferences(1).build(); + + // Build an input ReadRel + ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); - // Relation to Proto where both deprecated and new form are implemented - RelProtoConverter relToProtoConverter = new RelProtoConverter(functionCollector); - Rel newProto = relToProtoConverter.toProto(resultRel); + // Build the AggregateRel with the new grouping_expressions field + AggregateRel aggrProto = + AggregateRel.newBuilder() + .setInput(Rel.newBuilder().setRead(readProto)) + .addGroupingExpressions(col1Ref) + .addGroupingExpressions(col2Ref) + .addGroupings(grouping1) + .addGroupings(grouping2) + .build(); - assertEquals(2, newProto.getAggregate().getGroupings(0).getExpressionReferencesCount()); - assertEquals(2, newProto.getAggregate().getGroupings(0).getGroupingExpressionsList().size()); - assertEquals(2, newProto.getAggregate().getGroupingExpressionsList().size()); + Rel relProto = Rel.newBuilder().setAggregate(aggrProto).build(); + ProtoRelConverter converter = new ProtoRelConverter(functionLookup); + io.substrait.relation.Rel resultRel = converter.from(relProto); + + assertTrue(resultRel instanceof Aggregate); + Aggregate agg = (Aggregate) resultRel; + assertEquals(2, agg.getGroupings().size()); + assertEquals(2, agg.getGroupings().get(0).getExpressions().size()); + assertEquals(1, agg.getGroupings().get(1).getExpressions().size()); } } From d3cf639c58cae9d15f6c200402d99ae1fae16e9f Mon Sep 17 00:00:00 2001 From: Gordon Hamilton Date: Fri, 3 Oct 2025 10:21:51 -0400 Subject: [PATCH 9/9] fix: improving naming conventions and minor pre-converting outside of for loop --- .../io/substrait/relation/ProtoRelConverter.java | 12 +++++++----- .../java/io/substrait/relation/AggregateRelTest.java | 9 ++++----- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index c03ca7565..5c13968bd 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -628,22 +628,24 @@ protected Aggregate newAggregate(AggregateRel rel) { List groupings = new ArrayList<>(rel.getGroupingsCount()); - // new proto form is used + // Groupings are set using the AggregateRel grouping_expression mechanism if (!rel.getGroupingExpressionsList().isEmpty()) { - - List allGroupingKeys = rel.getGroupingExpressionsList(); + List allGroupingExpressions = + rel.getGroupingExpressionsList().stream() + .map(protoExprConverter::from) + .collect(java.util.stream.Collectors.toList()); for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { List references = grouping.getExpressionReferencesList(); List groupExpressions = new ArrayList<>(); for (int ref : references) { - groupExpressions.add(protoExprConverter.from(allGroupingKeys.get(ref))); + groupExpressions.add(allGroupingExpressions.get(ref)); } groupings.add(Aggregate.Grouping.builder().addAllExpressions(groupExpressions).build()); } } else { - // using the deprecated form of Grouping and Aggregate + // Groupings are set using the deprecated Grouping grouping_expressions mechanism for (AggregateRel.Grouping grouping : rel.getGroupingsList()) { groupings.add( Aggregate.Grouping.builder() diff --git a/core/src/test/java/io/substrait/relation/AggregateRelTest.java b/core/src/test/java/io/substrait/relation/AggregateRelTest.java index 33fefd593..4d0b604aa 100644 --- a/core/src/test/java/io/substrait/relation/AggregateRelTest.java +++ b/core/src/test/java/io/substrait/relation/AggregateRelTest.java @@ -58,7 +58,7 @@ public static io.substrait.proto.Expression createFieldReference(int col) { } @Test - public void testDeprecatedGroupingExpressionsAreMapped() { + public void testDeprecatedGroupingExpressionConversion() { Expression col1Ref = createFieldReference(0); Expression col2Ref = createFieldReference(1); @@ -71,7 +71,6 @@ public void testDeprecatedGroupingExpressionsAreMapped() { // Build an input ReadRel ReadRel readProto = ReadRel.newBuilder().setBaseSchema(namedStruct).build(); - // Build the AggregateRel with the new grouping_expressions field AggregateRel aggrProto = AggregateRel.newBuilder() @@ -90,13 +89,13 @@ public void testDeprecatedGroupingExpressionsAreMapped() { } @Test - public void testNewAggregateProtoForm() { + public void testAggregateWithSingleGrouping() { Expression col1Ref = createFieldReference(0); Expression col2Ref = createFieldReference(1); AggregateRel.Grouping grouping = AggregateRel.Grouping.newBuilder() - .addExpressionReferences(0) // new proto form + .addExpressionReferences(0) .addExpressionReferences(1) .build(); @@ -123,7 +122,7 @@ public void testNewAggregateProtoForm() { } @Test - public void testNewAggregateProtoFormMultipleGroupings() { + public void testAggregateWithMultipleGroupings() { Expression col1Ref = createFieldReference(0); Expression col2Ref = createFieldReference(1);