Skip to content

Commit d6993c0

Browse files
jjaderbergvnickolov
authored andcommitted
Make KNN metric parsin case insensitive
1 parent 6ccb41d commit d6993c0

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

algo/src/main/java/org/neo4j/gds/similarity/knn/KnnNodePropertySpecParser.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import java.util.Collections;
2929
import java.util.HashMap;
3030
import java.util.List;
31+
import java.util.Locale;
3132
import java.util.Map;
3233
import java.util.stream.Collectors;
3334

@@ -113,7 +114,7 @@ private static List<KnnNodePropertySpec> fromMap(Map<String, String> userInput)
113114
SimilarityMetric similarityMetric;
114115
if (value != null) {
115116
try {
116-
similarityMetric = SimilarityMetric.valueOf(value);
117+
similarityMetric = SimilarityMetric.valueOf(value.toUpperCase(Locale.ENGLISH));
117118
knnNodeProperties.add(new KnnNodePropertySpec(key, similarityMetric));
118119
} catch (IllegalArgumentException ex) {
119120
throw new IllegalArgumentException(

algo/src/test/java/org/neo4j/gds/similarity/knn/KnnNodePropertySpecParserTest.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
package org.neo4j.gds.similarity.knn;
2121

2222
import org.junit.jupiter.api.Test;
23+
import org.junit.jupiter.params.ParameterizedTest;
24+
import org.junit.jupiter.params.provider.ValueSource;
2325
import org.neo4j.gds.similarity.knn.metrics.SimilarityMetric;
2426

2527
import java.util.List;
28+
import java.util.Locale;
2629
import java.util.Map;
2730

2831
import static org.assertj.core.api.Assertions.assertThat;
@@ -103,6 +106,16 @@ void shouldParseMap() {
103106
);
104107
}
105108

109+
@ParameterizedTest
110+
@ValueSource(strings = {"cosine", "euCLIDean", "jACcaRd", "OVERLAP", "Pearson"})
111+
void shouldAcceptMetricsRegardlessOfCase(String metric) {
112+
var input = "property";
113+
assertThat(KnnNodePropertySpecParser.parse(Map.of(input, metric)))
114+
.singleElement()
115+
.extracting(spec -> spec.metric().name())
116+
.isEqualTo(metric.toUpperCase(Locale.ENGLISH));
117+
}
118+
106119
@Test
107120
void shouldRefuseToParseEmptyList() {
108121
var input = List.of();

0 commit comments

Comments
 (0)