3131import org .neo4j .procedure .UserFunction ;
3232import org .neo4j .values .storable .Values ;
3333
34- import java .util .Collections ;
3534import java .util .Comparator ;
3635import java .util .HashSet ;
3736import java .util .List ;
3837import java .util .Map ;
38+ import java .util .function .Predicate ;
3939
4040import static org .neo4j .graphalgo .impl .similarity .SimilarityVectorAggregator .CATEGORY_KEY ;
4141import static org .neo4j .graphalgo .impl .similarity .SimilarityVectorAggregator .WEIGHT_KEY ;
4242import static org .neo4j .graphalgo .impl .utils .NumberUtils .getDoubleValue ;
4343
4444public class SimilaritiesFunc {
4545
46+ public static final Predicate <Number > IS_NULL = Predicate .isEqual (null );
47+ public static final Comparator <Number > NUMBER_COMPARATOR = new NumberComparator ();
48+
4649 @ UserFunction ("gds.alpha.similarity.jaccard" )
4750 @ Description ("RETURN gds.alpha.similarity.jaccard(vector1, vector2) - Given two collection vectors, calculate Jaccard similarity" )
4851 public double jaccardSimilarity (@ Name ("vector1" ) List <Number > vector1 , @ Name ("vector2" ) List <Number > vector2 ) {
@@ -190,12 +193,10 @@ public double overlapSimilarity(@Name("vector1") List<Number> vector1, @Name("ve
190193 * @return The Jaccard score: the size of the intersection divided by the size of the union of the input lists, or 1 when the union is empty.
191194 */
192195 private double jaccard (List <Number > vector1 , List <Number > vector2 ) {
193- Comparator <Number > numberComparator = new NumberComparator ();
194- List <Number > nullList = Collections .singletonList (null );
195- vector1 .removeAll (nullList );
196- vector2 .removeAll (nullList );
197- vector1 .sort (numberComparator );
198- vector2 .sort (numberComparator );
196+ vector1 .removeIf (IS_NULL );
197+ vector2 .removeIf (IS_NULL );
198+ vector1 .sort (NUMBER_COMPARATOR );
199+ vector2 .sort (NUMBER_COMPARATOR );
199200
200201 int index1 = 0 ;
201202 int index2 = 0 ;
@@ -206,7 +207,7 @@ private double jaccard(List<Number> vector1, List<Number> vector2) {
206207 while (index1 < vector1 .size () && index2 < vector2 .size ()) {
207208 Number val1 = vector1 .get (index1 );
208209 Number val2 = vector2 .get (index2 );
209- int compare = numberComparator .compare (val1 , val2 );
210+ int compare = NUMBER_COMPARATOR .compare (val1 , val2 );
210211
211212 if (compare == 0 ) {
212213 intersection ++;
@@ -225,7 +226,7 @@ private double jaccard(List<Number> vector1, List<Number> vector2) {
225226 // the remainder, if any, is never shared so add to the union
226227 union += (vector1 .size () - index1 ) + (vector2 .size () - index2 );
227228
228- return intersection / union ;
229+ return union == 0 ? 1 : intersection / union ;
229230 }
230231
231232 static class NumberComparator implements Comparator <Number > {
0 commit comments