Skip to content

Commit 3ffcb9d

Browse files
committed
track urn<->uri mapping internally
1 parent 1813e60 commit 3ffcb9d

File tree

7 files changed

+155
-62
lines changed

7 files changed

+155
-62
lines changed

core/src/main/java/io/substrait/extension/BidiMap.java

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ public BidiMap(Map<T1, T2> forwardMap) {
1515
}
1616

1717
public BidiMap() {
18-
this.forwardMap = new HashMap<>();
19-
this.reverseMap = new HashMap<>();
18+
this.forwardMap = new HashMap<>();
19+
this.reverseMap = new HashMap<>();
2020
}
2121

2222
public T2 get(T1 t1) {
@@ -28,15 +28,33 @@ public T1 reverseGet(T2 t2) {
2828
}
2929

3030
public void put(T1 t1, T2 t2) {
31+
// Check for conflicting mappings (different values for same key)
32+
T2 existingForward = forwardMap.get(t1);
33+
T1 existingReverse = reverseMap.get(t2);
34+
35+
if (existingForward != null && !existingForward.equals(t2)) {
36+
throw new IllegalArgumentException("Key already exists in map with different value");
37+
}
38+
if (existingReverse != null && !existingReverse.equals(t1)) {
39+
throw new IllegalArgumentException("Key already exists in map with different value");
40+
}
41+
42+
// Allow identical mappings, only add if not already present
3143
forwardMap.put(t1, t2);
3244
reverseMap.put(t2, t1);
3345
}
3446

47+
public void merge(BidiMap<T1, T2> other) {
48+
for (Map.Entry<T1, T2> entry : other.forwardEntrySet()) {
49+
put(entry.getKey(), entry.getValue());
50+
}
51+
}
52+
3553
public Set<Map.Entry<T1, T2>> forwardEntrySet() {
36-
return forwardMap.entrySet();
54+
return forwardMap.entrySet();
3755
}
3856

39-
public Set<Map.Entry<T2, T1>> reverseEntrySet() {
40-
return reverseMap.entrySet();
41-
}
57+
public Set<Map.Entry<T2, T1>> reverseEntrySet() {
58+
return reverseMap.entrySet();
59+
}
4260
}

core/src/main/java/io/substrait/extension/SimpleExtension.java

Lines changed: 95 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ private static void validateUrn(String urn) {
5252
throw new IllegalArgumentException("URN cannot be null or empty");
5353
}
5454
if (!urn.matches("^extension:[^:]+:[^:]+$")) {
55-
throw new IllegalArgumentException("URN must follow format 'extension:<namespace>:<name>', got: " + urn);
55+
throw new IllegalArgumentException(
56+
"URN must follow format 'extension:<namespace>:<name>', got: " + urn);
5657
}
5758
}
5859

59-
6060
private static ObjectMapper objectMapper(String urn) {
6161
InjectableValues.Std iv = new InjectableValues.Std();
6262
iv.addValue(URN_LOCATOR_KEY, urn);
@@ -551,6 +551,13 @@ public abstract static class ExtensionSignatures {
551551
@JsonProperty("urn")
552552
public abstract String urn();
553553

554+
// URI is not from YAML, but from the loading context
555+
// this only needs to be present temporarily to handle the URI -> URN migration
556+
@Value.Default
557+
public String uri() {
558+
return "";
559+
}
560+
554561
@JsonProperty("scalar_functions")
555562
public abstract List<ScalarFunction> scalars();
556563

@@ -580,6 +587,11 @@ public Stream<SimpleExtension.Function> resolve(String urn) {
580587

581588
@Value.Immutable
582589
public abstract static class ExtensionCollection {
590+
@Value.Default
591+
BidiMap<String, String> uriUrnMap() {
592+
return new BidiMap<>();
593+
}
594+
583595
private final Supplier<Set<String>> urnSupplier =
584596
Util.memoize(
585597
() -> {
@@ -700,7 +712,55 @@ public WindowFunctionVariant getWindowFunction(FunctionAnchor anchor) {
700712
anchor.key(), anchor.urn()));
701713
}
702714

715+
/**
716+
* Gets the URN for a given URI. This is only useful during the URI -> URN migration, and will
717+
* be dropped when the migration is complete.
718+
*
719+
* @param uri The URI to look up
720+
* @return The corresponding URN, or null if not found
721+
*/
722+
public String getUrn(String uri) {
723+
return uriUrnMap().get(uri);
724+
}
725+
726+
/**
727+
* Gets the URI for a given URN. This is only useful during the URI -> URN migration, and will
728+
* be dropped when the migration is complete.
729+
*
730+
* @param urn The URN to look up
731+
* @return The corresponding URI, or null if not found
732+
*/
733+
public String getUri(String urn) {
734+
return uriUrnMap().reverseGet(urn);
735+
}
736+
737+
/**
738+
* Checks if a URI has a corresponding URN mapping. This is only useful during the URI -> URN
739+
* migration, and will be dropped when the migration is complete.
740+
*
741+
* @param uri The URI to check
742+
* @return true if the URI has a URN mapping, false otherwise
743+
*/
744+
public boolean hasUrn(String uri) {
745+
return uriUrnMap().get(uri) != null;
746+
}
747+
748+
/**
749+
* Checks if a URN has a corresponding URI mapping. This is only useful during the URI -> URN
750+
* migration, and will be dropped when the migration is complete.
751+
*
752+
* @param urn The URN to check
753+
* @return true if the URN has a URI mapping, false otherwise
754+
*/
755+
public boolean hasUri(String urn) {
756+
return uriUrnMap().reverseGet(urn) != null;
757+
}
758+
703759
public ExtensionCollection merge(ExtensionCollection extensionCollection) {
760+
BidiMap<String, String> mergedUriUrnMap = new BidiMap<>();
761+
mergedUriUrnMap.merge(uriUrnMap());
762+
mergedUriUrnMap.merge(extensionCollection.uriUrnMap());
763+
704764
return ImmutableSimpleExtension.ExtensionCollection.builder()
705765
.addAllAggregateFunctions(aggregateFunctions())
706766
.addAllAggregateFunctions(extensionCollection.aggregateFunctions())
@@ -710,6 +770,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) {
710770
.addAllWindowFunctions(extensionCollection.windowFunctions())
711771
.addAllTypes(types())
712772
.addAllTypes(extensionCollection.types())
773+
.uriUrnMap(mergedUriUrnMap)
713774
.build();
714775
}
715776
}
@@ -740,12 +801,12 @@ private static ExtensionCollection load(List<String> resourcePaths) {
740801
throw new IllegalArgumentException("Require at least one resource path.");
741802
}
742803

743-
List<ExtensionCollection> extensions =
804+
List<ExtensionCollection> extensions =
744805
resourcePaths.stream()
745806
.map(
746807
path -> {
747808
try (InputStream stream = ExtensionCollection.class.getResourceAsStream(path)) {
748-
return load(stream);
809+
return load(path, stream);
749810
} catch (IOException e) {
750811
throw new UncheckedIOException(e);
751812
}
@@ -758,40 +819,48 @@ private static ExtensionCollection load(List<String> resourcePaths) {
758819
return complete;
759820
}
760821

761-
public static ExtensionCollection load(String content) {
822+
public static ExtensionCollection load(String uri, String content) {
762823
try {
763-
// Parse with basic YAML mapper first to extract URN (if present)
824+
// Parse with basic YAML mapper first to extract URN
764825
ObjectMapper basicYamlMapper = new ObjectMapper(new YAMLFactory());
765826
com.fasterxml.jackson.databind.JsonNode rootNode = basicYamlMapper.readTree(content);
766-
767-
// URN is required
768827
com.fasterxml.jackson.databind.JsonNode urnNode = rootNode.get("urn");
769828
if (urnNode == null) {
770829
throw new IllegalArgumentException("Extension YAML file must contain a 'urn' field");
771830
}
772831
String urn = urnNode.asText();
773832
validateUrn(urn);
774833

775-
// Then parse with URN-aware mapper
776-
ExtensionSignatures doc = objectMapper(urn).readValue(content, ExtensionSignatures.class);
777-
return buildExtensionCollection(urn, doc);
834+
ExtensionSignatures docWithoutUri =
835+
objectMapper(urn).readValue(content, ExtensionSignatures.class);
836+
837+
ExtensionSignatures doc =
838+
ImmutableSimpleExtension.ExtensionSignatures.builder()
839+
.from(docWithoutUri)
840+
.uri(uri)
841+
.build();
842+
843+
return buildExtensionCollection(uri, doc);
778844
} catch (IOException e) {
779845
throw new IllegalStateException(e);
780846
}
781847
}
782848

783-
public static ExtensionCollection load(InputStream stream) {
784-
try (Scanner scanner = new Scanner(stream)) {
785-
scanner.useDelimiter("\\A");
786-
String content = scanner.next();
787-
return load(content);
788-
}
849+
public static ExtensionCollection load(String uri, InputStream stream) {
850+
try (Scanner scanner = new Scanner(stream)) {
851+
scanner.useDelimiter("\\A");
852+
String content = scanner.next();
853+
return load(uri, content);
789854
}
790-
855+
}
791856

792857
public static ExtensionCollection buildExtensionCollection(
793-
String urn, ExtensionSignatures extensionSignatures) {
858+
String uri, ExtensionSignatures extensionSignatures) {
859+
String urn = extensionSignatures.urn();
794860
validateUrn(urn);
861+
if (uri == null || uri == "") {
862+
throw new IllegalArgumentException("URI cannot be null or empty");
863+
}
795864
List<ScalarFunctionVariant> scalarFunctionVariants =
796865
extensionSignatures.scalars().stream()
797866
.flatMap(t -> t.resolve(urn))
@@ -824,23 +893,23 @@ public static ExtensionCollection buildExtensionCollection(
824893
Stream.concat(windowFunctionVariants, windowAggFunctionVariants)
825894
.collect(Collectors.toList());
826895

896+
BidiMap<String, String> uriUrnMap = new BidiMap<>();
897+
uriUrnMap.put(uri, urn);
898+
827899
ImmutableSimpleExtension.ExtensionCollection collection =
828900
ImmutableSimpleExtension.ExtensionCollection.builder()
829901
.scalarFunctions(scalarFunctionVariants)
830902
.aggregateFunctions(aggregateFunctionVariants)
831903
.windowFunctions(allWindowFunctionVariants)
832904
.addAllTypes(extensionSignatures.types())
905+
.uriUrnMap(uriUrnMap)
833906
.build();
907+
834908
LOGGER.atDebug().log(
835909
"Loaded {} aggregate functions and {} scalar functions from {}.",
836910
collection.aggregateFunctions().size(),
837-
collection.scalarFunctions().size(), extensionSignatures.urn());
911+
collection.scalarFunctions().size(),
912+
extensionSignatures.urn());
838913
return collection;
839914
}
840-
841-
public static ExtensionCollection buildExtensionCollection(
842-
ExtensionSignatures extensionSignatures) {
843-
String urn = extensionSignatures.urn();
844-
return buildExtensionCollection(urn, extensionSignatures);
845-
}
846915
}

core/src/test/java/io/substrait/extension/TypeExtensionTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ public class TypeExtensionTest {
3232
final SimpleExtension.ExtensionCollection extensionCollection;
3333

3434
{
35-
InputStream inputStream =
36-
this.getClass().getResourceAsStream("/extensions/custom_extensions.yaml");
37-
extensionCollection = SimpleExtension.load(inputStream);
35+
String path = "/extensions/custom_extensions.yaml";
36+
InputStream inputStream = this.getClass().getResourceAsStream(path);
37+
extensionCollection = SimpleExtension.load(path, inputStream);
3838
}
3939

4040
final SubstraitBuilder b = new SubstraitBuilder(extensionCollection);

core/src/test/java/io/substrait/extension/UrnValidationTest.java

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,38 +7,43 @@ public class UrnValidationTest {
77

88
@Test
99
public void testMissingUrnThrowsException() {
10-
String yamlWithoutUrn = """
11-
%YAML 1.2
12-
---
13-
scalar_functions:
14-
- name: test
15-
""";
16-
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load(yamlWithoutUrn));
10+
String yamlWithoutUrn = "%YAML 1.2\n" +
11+
"---\n" +
12+
"scalar_functions:\n" +
13+
" - name: test\n";
14+
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load("test://uri", yamlWithoutUrn));
1715
assertTrue(exception.getMessage().contains("Extension YAML file must contain a 'urn' field"));
1816
}
1917

2018
@Test
2119
public void testInvalidUrnFormatThrowsException() {
22-
String yamlWithInvalidUrn = """
23-
%YAML 1.2
24-
---
25-
urn: invalid:format
26-
scalar_functions:
27-
- name: test
28-
""";
29-
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load(yamlWithInvalidUrn));
20+
String yamlWithInvalidUrn = "%YAML 1.2\n" +
21+
"---\n" +
22+
"urn: invalid:format\n" +
23+
"scalar_functions:\n" +
24+
" - name: test\n";
25+
IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> SimpleExtension.load("test://uri", yamlWithInvalidUrn));
3026
assertTrue(exception.getMessage().contains("URN must follow format 'extension:<namespace>:<name>'"));
3127
}
3228

3329
@Test
3430
public void testValidUrnWorks() {
35-
String yamlWithValidUrn = """
36-
%YAML 1.2
37-
---
38-
urn: extension:test:valid
39-
scalar_functions:
40-
- name: test
41-
""";
42-
assertDoesNotThrow(() -> SimpleExtension.load(yamlWithValidUrn));
31+
String yamlWithValidUrn = "%YAML 1.2\n" +
32+
"---\n" +
33+
"urn: extension:test:valid\n" +
34+
"scalar_functions:\n" +
35+
" - name: test\n";
36+
assertDoesNotThrow(() -> SimpleExtension.load("test://uri", yamlWithValidUrn));
37+
}
38+
39+
@Test
40+
public void testUriUrnMapIsPopulated() {
41+
String yamlWithValidUrn = "%YAML 1.2\n" +
42+
"---\n" +
43+
"urn: extension:test:valid\n" +
44+
"scalar_functions:\n" +
45+
" - name: test\n";
46+
SimpleExtension.ExtensionCollection collection = SimpleExtension.load("test://uri", yamlWithValidUrn);
47+
assertEquals("extension:test:valid", collection.getUrn("test://uri"));
4348
}
4449
}

isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public class CustomFunctionTest extends PlanTestBase {
5656

5757
// Load custom extension into an ExtensionCollection
5858
static final SimpleExtension.ExtensionCollection extensionCollection =
59-
SimpleExtension.load(FUNCTIONS_CUSTOM);
59+
SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM);
6060

6161
final SubstraitBuilder b = new SubstraitBuilder(extensionCollection);
6262

isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ public class RelCopyOnWriteVisitorTest extends PlanTestBase {
3131
SimpleExtension.FunctionAnchor.of(
3232
"extension:io.substrait:functions_aggregate_approx", "approx_count_distinct:any");
3333
public static SimpleExtension.FunctionAnchor COUNT =
34-
SimpleExtension.FunctionAnchor.of("extension:io.substrait:functions_aggregate_generic", "count:any");
34+
SimpleExtension.FunctionAnchor.of(
35+
"extension:io.substrait:functions_aggregate_generic", "count:any");
3536

3637
private static final String COUNT_DISTINCT_SUBBQUERY =
3738
"select\n"
3839
+ " count(distinct l.l_orderkey),\n"
39-
+ " count(distinct l.l_orderkey) + 1,\n"
40+
+ " count(distinct l.l_orderkey) + 1,\n"
4041
+ " sum(l.l_extendedprice * (1 - l.l_discount)) as revenue,\n"
4142
+ " o.o_orderdate,\n"
4243
+ " count(distinct o.o_shippriority)\n"

spark/src/main/scala/io/substrait/spark/SparkExtension.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ object SparkExtension {
2929
final val file = "/spark.yml"
3030

3131
private val SparkImpls: SimpleExtension.ExtensionCollection =
32-
SimpleExtension.load(getClass.getResourceAsStream(file))
32+
SimpleExtension.load(file, getClass.getResourceAsStream(file))
3333

3434
private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection =
3535
SimpleExtension.loadDefaults()

0 commit comments

Comments
 (0)