Skip to content

Commit b577920

Browse files
Expand label result with some stats
Co-authored-by: Veselin Nikolov <[email protected]>
1 parent 3052d1b commit b577920

File tree

5 files changed

+62
-17
lines changed

5 files changed

+62
-17
lines changed

algo/src/main/java/org/neo4j/gds/hdbscan/HDBScan.java

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,15 +22,14 @@
2222
import org.neo4j.gds.Algorithm;
2323
import org.neo4j.gds.api.IdMap;
2424
import org.neo4j.gds.api.properties.nodes.NodePropertyValues;
25-
import org.neo4j.gds.collections.ha.HugeLongArray;
2625
import org.neo4j.gds.collections.ha.HugeObjectArray;
2726
import org.neo4j.gds.core.concurrency.Concurrency;
2827
import org.neo4j.gds.core.concurrency.ParallelUtil;
2928
import org.neo4j.gds.core.utils.paged.HugeSerialObjectMergeSort;
3029
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;
3130
import org.neo4j.gds.termination.TerminationFlag;
3231

33-
public class HDBScan extends Algorithm<HugeLongArray> {
32+
public class HDBScan extends Algorithm<Labels> {
3433

3534
private final IdMap nodes;
3635
private final NodePropertyValues nodePropertyValues;
@@ -61,7 +60,7 @@ public HDBScan(
6160
}
6261

6362
@Override
64-
public HugeLongArray compute() {
63+
public Labels compute() {
6564
progressTracker.beginSubTask();
6665
var kdTree = buildKDTree();
6766

algo/src/main/java/org/neo4j/gds/hdbscan/LabellingStep.java

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@ class LabellingStep {
2929
private final CondensedTree condensedTree;
3030
private final long nodeCount;
3131
private final ProgressTracker progressTracker;
32+
private static final long NOISE = -1;
3233

3334
LabellingStep(CondensedTree condensedTree, long nodeCount, ProgressTracker progressTracker) {
3435
this.condensedTree = condensedTree;
@@ -97,36 +98,42 @@ BitSet selectedClusters(HugeDoubleArray stabilities) {
9798
return selectedClusters;
9899
}
99100

100-
HugeLongArray computeLabels(BitSet selectedClusters) {
101+
Labels computeLabels(BitSet selectedClusters) {
101102
progressTracker.beginSubTask();
103+
var treeLabels = HugeLongArray.newArray(nodeCount);
102104
var labels = HugeLongArray.newArray(nodeCount);
103-
labels.fill(-1L);
104-
var nodeCountLabels = HugeLongArray.newArray(nodeCount);
105+
treeLabels.fill(NOISE);
106+
long clusters=0;
105107
var root = condensedTree.root();
106108
var maximumClusterId = condensedTree.maximumClusterId();
107109
for (var p = root; p <= maximumClusterId; p++) {
108110
var adaptedIndex = p - nodeCount;
109111
var parent = condensedTree.parent(p);
110-
long parentLabel = p == root ? -1L : labels.get(parent - nodeCount);
111-
if (parentLabel != -1L) {
112-
labels.set(adaptedIndex, parentLabel);
112+
long parentLabel = p == root ? NOISE : treeLabels.get(parent - nodeCount);
113+
if (parentLabel != NOISE) {
114+
treeLabels.set(adaptedIndex, parentLabel);
113115
} else if (selectedClusters.get(adaptedIndex)) {
114-
labels.set(adaptedIndex, adaptedIndex);
116+
clusters++;
117+
treeLabels.set(adaptedIndex, adaptedIndex);
115118
}
116119
progressTracker.logProgress();
117120
}
118121

122+
long noisePoints=0;
119123
for (var n = 0; n < nodeCount; n++) {
120-
nodeCountLabels.set(n, labels.get(condensedTree.fellOutOf(n) - nodeCount));
124+
long label = treeLabels.get(condensedTree.fellOutOf(n) - nodeCount);
125+
if (label == NOISE) {
126+
noisePoints++;
127+
}
128+
labels.set(n, label);
121129
progressTracker.logProgress();
122130
}
123131
progressTracker.endSubTask();
124132

125-
126-
return nodeCountLabels;
133+
return new Labels(labels,noisePoints,clusters);
127134
}
128135

129-
HugeLongArray labels() {
136+
Labels labels() {
130137
progressTracker.beginSubTask();
131138
var stabilities = computeStabilities();
132139
var selectedClusters = selectedClusters(stabilities);
org/neo4j/gds/hdbscan/Labels.java (new file)

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright (c) "Neo4j"
3+
* Neo4j Sweden AB [http://neo4j.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Neo4j is free software: you can redistribute it and/or modify
8+
* it under the terms of the GNU General Public License as published by
9+
* the Free Software Foundation, either version 3 of the License, or
10+
* (at your option) any later version.
11+
*
12+
* This program is distributed in the hope that it will be useful,
13+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
14+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15+
* GNU General Public License for more details.
16+
*
17+
* You should have received a copy of the GNU General Public License
18+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
19+
*/
20+
package org.neo4j.gds.hdbscan;
21+
22+
import org.neo4j.gds.collections.ha.HugeLongArray;
23+
24+
/**
 * Result of the HDBSCAN labelling step: the per-node labels plus summary counts.
 *
 * @param labels              one label per node; {@code -1} is the value used for noise
 * @param numberOfNoisePoints number of nodes whose label equals the noise value
 * @param numberOfClusters    number of selected clusters that received a label of their own
 */
public record Labels(HugeLongArray labels, long numberOfNoisePoints, long numberOfClusters) {}

algo/src/test/java/org/neo4j/gds/hdbscan/HDBScanE2ETest.java

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -73,7 +73,9 @@ void hdbscan() {
7373
TerminationFlag.RUNNING_TRUE
7474
);
7575

76-
var labelsWithOffset = hdbScan.compute();
76+
var result =hdbScan.compute();
77+
78+
var labelsWithOffset = result.labels();
7779

7880
var labels = new long[10];
7981
for (char letter='a'; letter<='j';++letter){
@@ -83,6 +85,9 @@ void hdbscan() {
8385

8486
var expectedLabels = new long[] {2, 2, 1, 2, 2, 2, 1, 1, 1, 1};
8587

88+
assertThat(result.numberOfClusters()).isEqualTo(2L);
89+
assertThat(result.numberOfNoisePoints()).isEqualTo(0);
90+
8691
assertThat(labels).containsExactly(expectedLabels);
8792
}
8893

algo/src/test/java/org/neo4j/gds/hdbscan/LabellingTest.java

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -176,8 +176,9 @@ void labelling() {
176176

177177
var stabilityStep = new LabellingStep(condensedTree, nodeCount,ProgressTracker.NULL_TRACKER);
178178

179-
var labels = stabilityStep.computeLabels(selectedClusters);
179+
var labelsResult = stabilityStep.computeLabels(selectedClusters);
180180

181+
var labels=labelsResult.labels();
181182
assertThat(labels.size()).isEqualTo(nodeCount);
182183

183184
assertThat(labels.get(0)).isEqualTo(1L);
@@ -190,6 +191,10 @@ void labelling() {
190191
assertThat(labels.get(5)).isEqualTo(4L);
191192
assertThat(labels.get(6)).isEqualTo(4L);
192193

194+
assertThat(labelsResult.numberOfNoisePoints()).isEqualTo(2L);
195+
assertThat(labelsResult.numberOfClusters()).isEqualTo(2L);
196+
197+
193198
}
194199

195200
@Test
@@ -207,10 +212,15 @@ void labellingWhenAllClustersAreSelected() {
207212

208213
var stabilityStep = new LabellingStep(condensedTree, nodeCount,ProgressTracker.NULL_TRACKER);
209214

210-
var labels = stabilityStep.computeLabels(selectedClusters);
215+
var labelsResult = stabilityStep.computeLabels(selectedClusters);
211216

217+
var labels= labelsResult.labels();
212218
assertThat(labels.size()).isEqualTo(nodeCount);
213219
assertThat(labels.toArray()).containsOnly(0L);
220+
221+
assertThat(labelsResult.numberOfClusters()).isEqualTo(1L);
222+
assertThat(labelsResult.numberOfNoisePoints()).isEqualTo(0L);
223+
214224
}
215225

216226
@Test

0 commit comments

Comments
 (0)