Skip to content

Commit dfd50eb

Browse files
authored
Merge branch 'main' into gordon.hamilton/aggregateGroupingNewSubstraitForm
2 parents f854ed1 + c2eb5f7 commit dfd50eb

File tree

22 files changed (+371 -88 lines changed)

build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ allprojects {
5252
googleJavaFormat()
5353
removeUnusedImports()
5454
trimTrailingWhitespace()
55-
removeWildcardImports()
55+
forbidWildcardImports()
5656
}
5757
}
5858
}

core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import io.substrait.expression.Expression;
44
import io.substrait.expression.proto.ProtoExpressionConverter;
5+
import io.substrait.extension.DefaultExtensionCatalog;
56
import io.substrait.extension.ExtensionCollector;
67
import io.substrait.extension.ExtensionLookup;
78
import io.substrait.extension.ImmutableExtensionLookup;
@@ -23,7 +24,7 @@ public class ProtoExtendedExpressionConverter {
2324
new ExtensionCollector(), SimpleExtension.ExtensionCollection.builder().build());
2425

2526
public ProtoExtendedExpressionConverter() {
26-
this(SimpleExtension.loadDefaults());
27+
this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
2728
}
2829

2930
public ProtoExtendedExpressionConverter(SimpleExtension.ExtensionCollection extensionCollection) {

core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
package io.substrait.extension;
22

3+
import java.util.Arrays;
4+
import java.util.List;
5+
import java.util.stream.Collectors;
6+
37
public class DefaultExtensionCatalog {
48
public static final String FUNCTIONS_AGGREGATE_APPROX = "/functions_aggregate_approx.yaml";
59
public static final String FUNCTIONS_AGGREGATE_GENERIC = "/functions_aggregate_generic.yaml";
@@ -14,4 +18,28 @@ public class DefaultExtensionCatalog {
1418
public static final String FUNCTIONS_ROUNDING_DECIMAL = "/functions_rounding_decimal.yaml";
1519
public static final String FUNCTIONS_SET = "/functions_set.yaml";
1620
public static final String FUNCTIONS_STRING = "/functions_string.yaml";
21+
22+
public static final SimpleExtension.ExtensionCollection DEFAULT_COLLECTION =
23+
loadDefaultCollection();
24+
25+
private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
26+
List<String> defaultFiles =
27+
Arrays.asList(
28+
"boolean",
29+
"aggregate_generic",
30+
"aggregate_approx",
31+
"arithmetic_decimal",
32+
"arithmetic",
33+
"comparison",
34+
"datetime",
35+
"logarithmic",
36+
"rounding",
37+
"rounding_decimal",
38+
"string")
39+
.stream()
40+
.map(c -> String.format("/functions_%s.yaml", c))
41+
.collect(Collectors.toList());
42+
43+
return SimpleExtension.load(defaultFiles);
44+
}
1745
}

core/src/main/java/io/substrait/extension/SimpleExtension.java

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.io.IOException;
2323
import java.io.InputStream;
2424
import java.io.UncheckedIOException;
25-
import java.util.Arrays;
2625
import java.util.List;
2726
import java.util.Map;
2827
import java.util.Optional;
@@ -701,27 +700,6 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) {
701700
}
702701
}
703702

704-
public static ExtensionCollection loadDefaults() {
705-
List<String> defaultFiles =
706-
Arrays.asList(
707-
"boolean",
708-
"aggregate_generic",
709-
"aggregate_approx",
710-
"arithmetic_decimal",
711-
"arithmetic",
712-
"comparison",
713-
"datetime",
714-
"logarithmic",
715-
"rounding",
716-
"rounding_decimal",
717-
"string")
718-
.stream()
719-
.map(c -> String.format("/functions_%s.yaml", c))
720-
.collect(Collectors.toList());
721-
722-
return load(defaultFiles);
723-
}
724-
725703
public static ExtensionCollection load(List<String> resourcePaths) {
726704
if (resourcePaths.isEmpty()) {
727705
throw new IllegalArgumentException("Require at least one resource path.");

core/src/main/java/io/substrait/plan/ProtoPlanConverter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package io.substrait.plan;
22

3+
import io.substrait.extension.DefaultExtensionCatalog;
34
import io.substrait.extension.ExtensionLookup;
45
import io.substrait.extension.ImmutableExtensionLookup;
56
import io.substrait.extension.SimpleExtension;
@@ -16,7 +17,7 @@ public class ProtoPlanConverter {
1617
protected final SimpleExtension.ExtensionCollection extensionCollection;
1718

1819
public ProtoPlanConverter() {
19-
this(SimpleExtension.loadDefaults());
20+
this(DefaultExtensionCatalog.DEFAULT_COLLECTION);
2021
}
2122

2223
public ProtoPlanConverter(SimpleExtension.ExtensionCollection extensionCollection) {

core/src/main/java/io/substrait/relation/Aggregate.java

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,42 @@ public abstract class Aggregate extends SingleInputRel implements HasExtension {
2121

2222
@Override
2323
protected Type.Struct deriveRecordType() {
24-
return TypeCreator.REQUIRED.struct(
25-
Stream.concat(
26-
// unique grouping expressions
27-
getGroupings().stream()
28-
.flatMap(g -> g.getExpressions().stream())
29-
.collect(Collectors.toCollection(LinkedHashSet::new))
30-
.stream()
31-
.map(Expression::getType),
32-
33-
// measures
34-
getMeasures().stream().map(t -> t.getFunction().getType())));
24+
// If there's only one grouping set (or none), the nullability rule doesn't apply.
25+
if (getGroupings().size() <= 1) {
26+
final Stream<Type> groupingTypes =
27+
getGroupings().stream()
28+
.flatMap(g -> g.getExpressions().stream())
29+
.map(Expression::getType);
30+
final Stream<Type> measureTypes = getMeasures().stream().map(t -> t.getFunction().getType());
31+
return TypeCreator.REQUIRED.struct(Stream.concat(groupingTypes, measureTypes));
32+
}
33+
34+
final LinkedHashSet<Expression> uniqueGroupingExpressions =
35+
getGroupings().stream()
36+
.flatMap(g -> g.getExpressions().stream())
37+
.collect(Collectors.toCollection(LinkedHashSet::new));
38+
39+
// For each unique grouping expression, determine its final nullability based on the spec.
40+
final Stream<Type> groupingTypes =
41+
uniqueGroupingExpressions.stream()
42+
.map(
43+
expr -> {
44+
// the code below implements the following statement from the spec
45+
// (https://substrait.io/relations/logical_relations/#aggregate-operation):
46+
// "The values for the grouping expression columns that are not
47+
// part of the grouping set for a particular record will be set to null."
48+
final boolean appearsInAllSets =
49+
getGroupings().stream().allMatch(g -> g.getExpressions().contains(expr));
50+
if (appearsInAllSets) {
51+
return expr.getType();
52+
} else {
53+
return TypeCreator.asNullable(expr.getType());
54+
}
55+
});
56+
57+
final Stream<Type> measureTypes = getMeasures().stream().map(t -> t.getFunction().getType());
58+
59+
return TypeCreator.REQUIRED.struct(Stream.concat(groupingTypes, measureTypes));
3560
}
3661

3762
@Override

core/src/main/java/io/substrait/relation/ProtoAggregateFunctionConverter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import io.substrait.expression.FunctionArg;
66
import io.substrait.expression.FunctionOption;
77
import io.substrait.expression.proto.ProtoExpressionConverter;
8+
import io.substrait.extension.DefaultExtensionCatalog;
89
import io.substrait.extension.ExtensionLookup;
910
import io.substrait.extension.SimpleExtension;
1011
import io.substrait.type.proto.ProtoTypeConverter;
@@ -24,7 +25,7 @@ public class ProtoAggregateFunctionConverter {
2425

2526
public ProtoAggregateFunctionConverter(
2627
ExtensionLookup lookup, ProtoExpressionConverter protoExpressionConverter) {
27-
this(lookup, SimpleExtension.loadDefaults(), protoExpressionConverter);
28+
this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION, protoExpressionConverter);
2829
}
2930

3031
public ProtoAggregateFunctionConverter(

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io.substrait.expression.Expression;
44
import io.substrait.expression.proto.ProtoExpressionConverter;
55
import io.substrait.extension.AdvancedExtension;
6+
import io.substrait.extension.DefaultExtensionCatalog;
67
import io.substrait.extension.ExtensionLookup;
78
import io.substrait.extension.SimpleExtension;
89
import io.substrait.hint.Hint;
@@ -55,7 +56,7 @@ public class ProtoRelConverter {
5556
private final ProtoTypeConverter protoTypeConverter;
5657

5758
public ProtoRelConverter(ExtensionLookup lookup) {
58-
this(lookup, SimpleExtension.loadDefaults());
59+
this(lookup, DefaultExtensionCatalog.DEFAULT_COLLECTION);
5960
}
6061

6162
public ProtoRelConverter(ExtensionLookup lookup, SimpleExtension.ExtensionCollection extensions) {

core/src/test/java/io/substrait/TestBase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static org.junit.jupiter.api.Assertions.assertEquals;
44

55
import io.substrait.dsl.SubstraitBuilder;
6+
import io.substrait.extension.DefaultExtensionCatalog;
67
import io.substrait.extension.ExtensionCollector;
78
import io.substrait.extension.SimpleExtension;
89
import io.substrait.relation.ProtoRelConverter;
@@ -13,7 +14,7 @@
1314
public abstract class TestBase {
1415

1516
protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection =
16-
SimpleExtension.loadDefaults();
17+
DefaultExtensionCatalog.DEFAULT_COLLECTION;
1718

1819
protected TypeCreator R = TypeCreator.REQUIRED;
1920

gradle/libs.versions.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ graal = "25.0.0"
66
graal-plugin = "0.11.0"
77
gradle-extensions = "2.0.0"
88
guava = "33.5.0-jre"
9-
httpclient5 = "5.5"
10-
immutables = "2.11.3"
9+
httpclient5 = "5.5.1"
10+
immutables = "2.11.4"
1111
jackson = "2.20.0"
1212
jreleaser = "1.20.0"
1313
json-smart = "2.6.0"
@@ -20,10 +20,10 @@ reflections = "0.9.12"
2020
scala-library = "2.12.20"
2121
scalatest = "3.2.19"
2222
scalatestplus-junit5 = "3.2.19.0"
23-
shadow = "9.1.0"
23+
shadow = "9.2.2"
2424
slf4j = "2.0.17"
2525
spark = "3.4.4"
26-
spotless = "7.2.1"
26+
spotless = "8.0.0"
2727

2828
[libraries]
2929
antlr4 = { module = "org.antlr:antlr4", version.ref = "antlr" }

0 commit comments

Comments
 (0)