Skip to content

Commit bfab337

Browse files
committed
properly expose uris and urns in POJO plans
1 parent 7d9e184 commit bfab337

File tree

7 files changed

+221
-13
lines changed

7 files changed

+221
-13
lines changed

core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ dependencies {
117117
testImplementation(platform("org.junit:junit-bom:${JUNIT_VERSION}"))
118118
testImplementation("org.junit.jupiter:junit-jupiter")
119119
testRuntimeOnly("org.junit.platform:junit-platform-launcher")
120+
testRuntimeOnly("org.jetbrains.kotlin:kotlin-stdlib:${properties.get("kotlin.version")}")
120121
api("com.google.protobuf:protobuf-java:${PROTOBUF_VERSION}")
121122
implementation("com.fasterxml.jackson.core:jackson-databind:${JACKSON_VERSION}")
122123
implementation("com.fasterxml.jackson.core:jackson-annotations:${JACKSON_VERSION}")

core/src/main/java/io/substrait/extendedexpression/ProtoExtendedExpressionConverter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public ExtendedExpression from(io.substrait.proto.ExtendedExpression extendedExp
3434
// fill in simple extension information through a discovery in the current proto-extended
3535
// expression
3636
ExtensionLookup functionLookup =
37-
ImmutableExtensionLookup.builder().from(extendedExpression).build();
37+
ImmutableExtensionLookup.builder().from(extendedExpression, this.extensionCollection.uriUrnMap()).build();
3838

3939
NamedStruct baseSchemaProto = extendedExpression.getBaseSchema();
4040

core/src/main/java/io/substrait/extension/ImmutableExtensionLookup.java

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import io.substrait.proto.ExtendedExpression;
44
import io.substrait.proto.Plan;
55
import io.substrait.proto.SimpleExtensionDeclaration;
6+
import io.substrait.proto.SimpleExtensionURI;
67
import io.substrait.proto.SimpleExtensionURN;
78
import java.util.Collections;
89
import java.util.HashMap;
@@ -29,24 +30,35 @@ public static class Builder {
2930
private final Map<Integer, SimpleExtension.FunctionAnchor> functionMap = new HashMap<>();
3031
private final Map<Integer, SimpleExtension.TypeAnchor> typeMap = new HashMap<>();
3132

32-
public Builder from(Plan plan) {
33-
return from(plan.getExtensionUrnsList(), plan.getExtensionsList());
33+
public Builder from(Plan plan, BidiMap<String, String> uriUrnMap) {
34+
return from(
35+
plan.getExtensionUrnsList(), plan.getExtensionUrisList(), plan.getExtensionsList(), uriUrnMap);
3436
}
3537

36-
public Builder from(ExtendedExpression extendedExpression) {
38+
public Builder from(ExtendedExpression extendedExpression, BidiMap<String, String> uriUrnMap) {
3739
return from(
38-
extendedExpression.getExtensionUrnsList(), extendedExpression.getExtensionsList());
40+
extendedExpression.getExtensionUrnsList(),
41+
extendedExpression.getExtensionUrisList(),
42+
extendedExpression.getExtensionsList(),
43+
uriUrnMap);
3944
}
4045

4146
private Builder from(
4247
List<SimpleExtensionURN> simpleExtensionURNs,
43-
List<SimpleExtensionDeclaration> simpleExtensionDeclarations) {
48+
List<SimpleExtensionURI> simpleExtensionURIs,
49+
List<SimpleExtensionDeclaration> simpleExtensionDeclarations,
50+
BidiMap<String, String> uriUrnMap) {
4451
Map<Integer, String> urnMap = new HashMap<>();
52+
Map<Integer, String> uriMap = new HashMap<>();
4553
// Handle URN format
4654
for (SimpleExtensionURN extension : simpleExtensionURNs) {
4755
urnMap.put(extension.getExtensionUrnAnchor(), extension.getUrn());
4856
}
4957

58+
for (SimpleExtensionURI extension : simpleExtensionURIs) {
59+
uriMap.put(extension.getExtensionUriAnchor(), extension.getUri());
60+
}
61+
5062
// Add all functions used in plan to the functionMap
5163
for (SimpleExtensionDeclaration extension : simpleExtensionDeclarations) {
5264
if (!extension.hasExtensionFunction()) {
@@ -56,9 +68,22 @@ private Builder from(
5668
int reference = func.getFunctionAnchor();
5769
String urn = urnMap.get(func.getExtensionUrnReference());
5870
if (urn == null) {
59-
throw new IllegalStateException(
60-
"Could not find extension URN for function reference "
61-
+ func.getExtensionUrnReference());
71+
int uriReference = func.getExtensionUriReference();
72+
String uri = uriMap.get(uriReference);
73+
if (uri == null) {
74+
throw new IllegalStateException(
75+
"Could not find extension URN for function reference "
76+
+ func.getExtensionUrnReference()
77+
+ " or extension URI for function reference "
78+
+ func.getExtensionUriReference());
79+
}
80+
// Translate URI to URN using the BidiMap
81+
urn = uriUrnMap.get(uri);
82+
if (urn == null) {
83+
throw new IllegalStateException(
84+
"Could not translate URI '" + uri + "' to URN. "
85+
+ "URI-URN mapping not found in the provided mapping.");
86+
}
6287
}
6388
String name = func.getName();
6489
SimpleExtension.FunctionAnchor anchor = SimpleExtension.FunctionAnchor.of(urn, name);
@@ -74,8 +99,22 @@ private Builder from(
7499
int reference = type.getTypeAnchor();
75100
String urn = urnMap.get(type.getExtensionUrnReference());
76101
if (urn == null) {
77-
throw new IllegalStateException(
78-
"Could not find extension URN for type reference " + type.getExtensionUrnReference());
102+
int uriReference = type.getExtensionUriReference();
103+
String uri = uriMap.get(uriReference);
104+
if (uri == null) {
105+
throw new IllegalStateException(
106+
"Could not find extension URN for type reference "
107+
+ type.getExtensionUrnReference()
108+
+ " or extension URI for type reference "
109+
+ type.getExtensionUriReference());
110+
}
111+
// Translate URI to URN using the BidiMap
112+
urn = uriUrnMap.get(uri);
113+
if (urn == null) {
114+
throw new IllegalStateException(
115+
"Could not translate URI '" + uri + "' to URN. "
116+
+ "URI-URN mapping not found in the provided mapping.");
117+
}
79118
}
80119
String name = type.getName();
81120
SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(urn, name);

core/src/main/java/io/substrait/extension/SimpleExtension.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,7 @@ public Stream<SimpleExtension.Function> resolve(String urn) {
588588
@Value.Immutable
589589
public abstract static class ExtensionCollection {
590590
@Value.Default
591-
BidiMap<String, String> uriUrnMap() {
591+
public BidiMap<String, String> uriUrnMap() {
592592
return new BidiMap<>();
593593
}
594594

core/src/main/java/io/substrait/plan/Plan.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public Version getVersion() {
2121

2222
public abstract Optional<AdvancedExtension> getAdvancedExtension();
2323

24+
public abstract List<String> getExtensionUrns();
25+
26+
public abstract List<String> getExtensionUris();
27+
2428
public static ImmutablePlan.Builder builder() {
2529
return ImmutablePlan.builder();
2630
}

core/src/main/java/io/substrait/plan/ProtoPlanConverter.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ protected ProtoRelConverter getProtoRelConverter(ExtensionLookup functionLookup)
2929
}
3030

3131
public Plan from(io.substrait.proto.Plan plan) {
32-
ExtensionLookup functionLookup = ImmutableExtensionLookup.builder().from(plan).build();
32+
ExtensionLookup functionLookup =
33+
ImmutableExtensionLookup.builder().from(plan, extensionCollection.uriUrnMap()).build();
3334
ProtoRelConverter relConverter = getProtoRelConverter(functionLookup);
3435
List<Plan.Root> roots = new ArrayList<>();
3536
for (PlanRel planRel : plan.getRelationsList()) {
@@ -54,12 +55,23 @@ public Plan from(io.substrait.proto.Plan plan) {
5455
versionBuilder.producer(Optional.of(plan.getVersion().getProducer()));
5556
}
5657

58+
List<String> extensionUrns =
59+
plan.getExtensionUrnsList().stream()
60+
.map(urn -> urn.getUrn())
61+
.collect(java.util.stream.Collectors.toList());
62+
List<String> extensionUris =
63+
extensionUrns.stream()
64+
.map(urn -> extensionCollection.getUri(urn))
65+
.collect(java.util.stream.Collectors.toList());
66+
5767
return Plan.builder()
5868
.roots(roots)
5969
.expectedTypeUrls(plan.getExpectedTypeUrlsList())
6070
.advancedExtension(
6171
Optional.ofNullable(plan.hasAdvancedExtensions() ? plan.getAdvancedExtensions() : null))
6272
.version(versionBuilder.build())
73+
.extensionUrns(extensionUrns)
74+
.extensionUris(extensionUris)
6375
.build();
6476
}
6577
}
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package io.substrait.extension;
2+
3+
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
4+
import static org.junit.jupiter.api.Assertions.assertEquals;
5+
import static org.junit.jupiter.api.Assertions.assertNotNull;
6+
import static org.junit.jupiter.api.Assertions.assertTrue;
7+
8+
import io.substrait.plan.Plan;
9+
import io.substrait.plan.PlanProtoConverter;
10+
import io.substrait.plan.ProtoPlanConverter;
11+
import io.substrait.proto.SimpleExtensionURI;
12+
import io.substrait.proto.SimpleExtensionURN;
13+
import io.substrait.proto.SimpleExtensionDeclaration;
14+
import io.substrait.proto.PlanRel;
15+
16+
import org.junit.jupiter.api.Test;
17+
18+
/**
19+
* Tests describing the desired URI ↔ URN migration behaviour. These are disabled until the runtime
20+
* support is implemented.
21+
*/
22+
public class UriUrnMigrationTest {
23+
24+
private static final String SAMPLE_URI = "https://example.com/extensions/sample.yaml";
25+
private static final String SAMPLE_YAML =
26+
"%YAML 1.2\n"
27+
+ "---\n"
28+
+ "urn: extension:test:sample\n"
29+
+ "scalar_functions:\n"
30+
+ " - name: add\n"
31+
+ " impls:\n"
32+
+ " - args:\n"
33+
+ " - value: i32\n"
34+
+ " - value: i32\n"
35+
+ " return: i32\n";
36+
37+
@Test
38+
void uriOnlyPlanShouldHaveUrn() throws Exception {
39+
SimpleExtension.ExtensionCollection extensions = SimpleExtension.load(SAMPLE_URI, SAMPLE_YAML);
40+
io.substrait.proto.Plan protoPlan =
41+
io.substrait.proto.Plan.newBuilder()
42+
.addExtensionUrns(
43+
SimpleExtensionURN.newBuilder()
44+
.setExtensionUrnAnchor(1)
45+
.setUrn("extension:test:sample")
46+
.build())
47+
.addExtensions(
48+
SimpleExtensionDeclaration.newBuilder()
49+
.setExtensionFunction(
50+
SimpleExtensionDeclaration.ExtensionFunction.newBuilder()
51+
.setFunctionAnchor(1)
52+
.setName("add:i32_i32")
53+
.setExtensionUrnReference(1)
54+
.build())
55+
.build())
56+
.addRelations(
57+
PlanRel.newBuilder()
58+
.setRoot(
59+
io.substrait.proto.RelRoot.newBuilder()
60+
.setInput(
61+
io.substrait.proto.Rel.newBuilder()
62+
.setProject(
63+
io.substrait.proto.ProjectRel.newBuilder()
64+
.setInput(
65+
io.substrait.proto.Rel.newBuilder()
66+
.setRead(
67+
io.substrait.proto.ReadRel.newBuilder()
68+
.setNamedTable(
69+
io.substrait.proto.ReadRel
70+
.NamedTable.newBuilder()
71+
.addNames("dummy")
72+
.build())
73+
.setBaseSchema(
74+
io.substrait.proto.NamedStruct
75+
.newBuilder()
76+
.addNames("col")
77+
.setStruct(
78+
io.substrait.proto.Type
79+
.Struct.newBuilder()
80+
.addTypes(
81+
io.substrait.proto
82+
.Type
83+
.newBuilder()
84+
.setI32(
85+
io.substrait
86+
.proto
87+
.Type
88+
.I32
89+
.newBuilder())
90+
.build())
91+
.build())
92+
.build())
93+
.build())
94+
.build())
95+
.addExpressions(
96+
io.substrait.proto.Expression.newBuilder()
97+
.setScalarFunction(
98+
io.substrait.proto.Expression.ScalarFunction
99+
.newBuilder()
100+
.setFunctionReference(
101+
1) // Uses our add function
102+
.addArguments(
103+
io.substrait.proto.FunctionArgument
104+
.newBuilder()
105+
.setValue(
106+
io.substrait.proto
107+
.Expression.newBuilder()
108+
.setLiteral(
109+
io.substrait.proto
110+
.Expression
111+
.Literal
112+
.newBuilder()
113+
.setI32(1)
114+
.build())
115+
.build())
116+
.build())
117+
.addArguments(
118+
io.substrait.proto.FunctionArgument
119+
.newBuilder()
120+
.setValue(
121+
io.substrait.proto
122+
.Expression.newBuilder()
123+
.setLiteral(
124+
io.substrait.proto
125+
.Expression
126+
.Literal
127+
.newBuilder()
128+
.setI32(2)
129+
.build())
130+
.build())
131+
.build())
132+
.setOutputType(
133+
io.substrait.proto.Type.newBuilder()
134+
.setI32(
135+
io.substrait.proto.Type.I32
136+
.newBuilder())
137+
.build())
138+
.build())
139+
.build())
140+
.build())
141+
.build())
142+
.addNames("result")
143+
.build())
144+
.build())
145+
.build();
146+
147+
Plan planFromProto = new ProtoPlanConverter(extensions).from(protoPlan);
148+
149+
assertTrue(planFromProto.getExtensionUris().size() > 0, "Plan should have URI");
150+
assertTrue(planFromProto.getExtensionUrns().size() > 0, "Plan should have URN");
151+
}
152+
}

0 commit comments

Comments
 (0)