|
34 | 34 | import java.util.Optional;
|
35 | 35 | import java.util.SplittableRandom;
|
36 | 36 | import java.util.concurrent.ExecutorService;
|
37 |
| -import java.util.concurrent.atomic.AtomicInteger; |
38 |
| - |
39 |
| -enum InputCondition { |
40 |
| - NORMAL(0), NAN(1), UNEQUALDIMENSION(2); |
41 |
| - public final int value; |
42 |
| - |
43 |
| - private InputCondition(int value) { |
44 |
| - this.value = value; |
45 |
| - } |
46 |
| - |
47 |
| -} |
48 | 37 |
|
49 | 38 | public class Kmeans extends Algorithm<KmeansResult> {
|
50 | 39 |
|
@@ -109,13 +98,8 @@ public static Kmeans createKmeans(Graph graph, KmeansBaseConfig config, KmeansCo
|
109 | 98 | public KmeansResult compute() {
|
110 | 99 | progressTracker.beginSubTask();
|
111 | 100 |
|
112 |
| - var inputCondition = checkInputValidity(); |
113 |
| - if (inputCondition == InputCondition.NAN) { |
114 |
| - throw new IllegalArgumentException("Input for K-Means should not contain any NaN values"); |
115 |
| - } else if (inputCondition == InputCondition.UNEQUALDIMENSION) { |
116 |
| - throw new IllegalStateException( |
117 |
| - "All property arrays for K-Means should have the same number of dimensions"); |
118 |
| - } |
| 101 | + checkInputValidity(); |
| 102 | + |
119 | 103 |
|
120 | 104 | if (k > graph.nodeCount()) {
|
121 | 105 | // Every node in its own community. Warn and return early.
|
@@ -197,43 +181,35 @@ private static SplittableRandom getSplittableRandom(Optional<Long> randomSeed) {
|
197 | 181 | return randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
|
198 | 182 | }
|
199 | 183 |
|
200 |
| - private InputCondition checkInputValidity() { |
201 |
| - AtomicInteger inputState = new AtomicInteger(InputCondition.NORMAL.value); |
| 184 | + private void checkInputValidity() { |
202 | 185 | ParallelUtil.parallelForEachNode(graph.nodeCount(), concurrency, nodeId -> {
|
203 |
| - if (inputState.get() == InputCondition.NORMAL.value) { |
204 |
| - if (nodePropertyValues.valueType() == ValueType.FLOAT_ARRAY) { |
205 |
| - var value = nodePropertyValues.floatArrayValue(nodeId); |
206 |
| - if (value.length != dimensions) { |
207 |
| - inputState.set(InputCondition.UNEQUALDIMENSION.value); |
208 |
| - } else { |
209 |
| - for (int dimension = 0; dimension < dimensions; ++dimension) { |
210 |
| - if (Float.isNaN(value[dimension])) { |
211 |
| - inputState.set(InputCondition.NAN.value); |
212 |
| - break; |
213 |
| - } |
| 186 | + if (nodePropertyValues.valueType() == ValueType.FLOAT_ARRAY) { |
| 187 | + var value = nodePropertyValues.floatArrayValue(nodeId); |
| 188 | + if (value.length != dimensions) { |
| 189 | + throw new IllegalStateException( |
| 190 | + "All property arrays for K-Means should have the same number of dimensions"); |
| 191 | + } else { |
| 192 | + for (int dimension = 0; dimension < dimensions; ++dimension) { |
| 193 | + if (Float.isNaN(value[dimension])) { |
| 194 | + throw new IllegalArgumentException("Input for K-Means should not contain any NaN values"); |
214 | 195 | }
|
215 | 196 | }
|
| 197 | + } |
| 198 | + } else { |
| 199 | + var value = nodePropertyValues.doubleArrayValue(nodeId); |
| 200 | + if (value.length != dimensions) { |
| 201 | + throw new IllegalStateException( |
| 202 | + "All property arrays for K-Means should have the same number of dimensions"); |
216 | 203 | } else {
|
217 |
| - var value = nodePropertyValues.doubleArrayValue(nodeId); |
218 |
| - if (value.length != dimensions) { |
219 |
| - inputState.set(InputCondition.UNEQUALDIMENSION.value); |
220 |
| - } else { |
221 |
| - for (int dimension = 0; dimension < dimensions; ++dimension) { |
222 |
| - if (Double.isNaN(value[dimension])) { |
223 |
| - inputState.set(InputCondition.NAN.value); |
224 |
| - break; |
225 |
| - } |
| 204 | + for (int dimension = 0; dimension < dimensions; ++dimension) { |
| 205 | + if (Double.isNaN(value[dimension])) { |
| 206 | + throw new IllegalArgumentException("Input for K-Means should not contain any NaN values"); |
| 207 | + |
226 | 208 | }
|
227 | 209 | }
|
228 | 210 | }
|
| 211 | + |
229 | 212 | }
|
230 | 213 | });
|
231 |
| - InputCondition inputCondition = InputCondition.NORMAL; |
232 |
| - if (inputState.get() == InputCondition.UNEQUALDIMENSION.value) { |
233 |
| - inputCondition = InputCondition.UNEQUALDIMENSION; |
234 |
| - } else if (inputState.get() == InputCondition.NAN.value) { |
235 |
| - inputCondition = InputCondition.NAN; |
236 |
| - } |
237 |
| - return inputCondition; |
238 | 214 | }
|
239 | 215 | }
|
0 commit comments