Commit 27537ca

Merge pull request #6046 from FlorentinD/maxtrials-21
Always run all concrete configs
2 parents 677510f + b177345 commit 27537ca

13 files changed: +154 -36 lines changed

doc/modules/ROOT/pages/machine-learning/linkprediction-pipelines/config.adoc

Lines changed: 4 additions & 5 deletions
@@ -550,17 +550,16 @@ include::partial$/machine-learning/linkprediction-pipeline/pipelineInfoResult.ad
 [source, cypher, role=noplay]
 ----
 CALL gds.alpha.pipeline.linkPrediction.configureAutoTuning('pipe', {
-  maxTrials: 5
-})
-YIELD autoTuningConfig
+  maxTrials: 2
+}) YIELD autoTuningConfig
 ----

 .Results
 [opts="header",cols="1"]
 |===
 | autoTuningConfig
-| {maxTrials=5}
+| {maxTrials=2}
 |===

-We now reconfigured the auto-tuning to try out at most 5 model candidates during xref:machine-learning/linkprediction-pipelines/training.adoc[training].
+We now reconfigured the auto-tuning to try out at most 2 model candidates during xref:machine-learning/linkprediction-pipelines/training.adoc[training].
 --

doc/modules/ROOT/pages/machine-learning/linkprediction-pipelines/training.adoc

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ RETURN
 [opts="header", cols="6, 2, 2, 2, 6"]
 |===
 | winningModel | avgTrainScore | outerTrainScore | testScore | validationScores
-| {maxDepth=2147483647, minLeafSize=1, criterion=GINI, minSplitSize=2, numberOfDecisionTrees=10, methodName=RandomForest, numberOfSamplesRatio=1.0} | 0.779365079365079 | 0.788888888888889 | 0.766666666666667 | [0.3333333333333333, 0.6388888888888888, 0.3333333333333333, 0.3333333333333333, 0.3333333333333333]
+| {maxDepth=2147483647, minLeafSize=1, criterion=GINI, minSplitSize=2, numberOfDecisionTrees=10, methodName=RandomForest, numberOfSamplesRatio=1.0} | 0.779365079365079 | 0.788888888888889 | 0.766666666666667 | [0.3333333333333333, 0.6388888888888888, 0.3333333333333333, 0.3333333333333333]
 |===

 We can see the RandomForest model configuration with `numberOfDecisionTrees = 10` (and defaults filled for remaining parameters) was selected, and has a score of `0.77` on the test set.

doc/modules/ROOT/pages/machine-learning/node-property-prediction/nodeclassification-pipelines/config.adoc

Lines changed: 3 additions & 4 deletions
@@ -468,16 +468,15 @@ include::partial$/machine-learning/node-property-prediction/pipelineInfoResult.a
 [source, cypher, role=noplay]
 ----
 CALL gds.alpha.pipeline.nodeClassification.configureAutoTuning('pipe', {
-  maxTrials: 5
-})
-YIELD autoTuningConfig
+  maxTrials: 2
+}) YIELD autoTuningConfig
 ----

 .Results
 [opts="header",cols="1"]
 |===
 | autoTuningConfig
-| {maxTrials=5}
+| {maxTrials=2}
 |===

 We now reconfigured the auto-tuning to try out at most 100 model candidates during xref::machine-learning/node-property-prediction/nodeclassification-pipelines/training.adoc[training].

doc/modules/ROOT/pages/machine-learning/node-property-prediction/nodeclassification-pipelines/training.adoc

Lines changed: 2 additions & 2 deletions
@@ -245,7 +245,7 @@ YIELD requiredMemory
 [opts="header"]
 |===
 | requiredMemory
-| +"[1264 KiB ... 1338 KiB]"+
+| +"[1186 KiB ... 1260 KiB]"+
 |===
 --

@@ -283,7 +283,7 @@ RETURN
 [opts="header", cols="8, 2, 2, 2, 8"]
 |===
 | winningModel | avgTrainScore | outerTrainScore | testScore | validationScores
-| {maxEpochs=100, minEpochs=1, penalty=0.0, patience=1, methodName=LogisticRegression, batchSize=100, tolerance=0.001, learningRate=0.001} | 0.999999989939394 | 0.9999999912121211 | 0.999999985 | [0.4909090835454547, 0.07272727163636365, 0.4909090835454547, 0.4909090835454547, 0.4909090835454547]
+| {maxEpochs=100, minEpochs=1, penalty=0.0, patience=1, methodName=LogisticRegression, batchSize=100, tolerance=0.001, learningRate=0.001} | 0.999999989939394 | 0.9999999912121211 | 0.999999985 | [0.4909090835454547, 0.07272727163636365, 0.4909090835454547, 0.4909090835454547]
 |===

 Here we can observe that the model candidate with penalty `0.0625` performed the best in the training phase, with an `F1_WEIGHTED` score nearing 1 over the train graph as well as on the test graph.

ml/ml-algo/src/main/java/org/neo4j/gds/ml/models/automl/RandomSearch.java

Lines changed: 9 additions & 5 deletions
@@ -35,16 +35,18 @@ public class RandomSearch implements HyperParameterOptimizer {
     private final List<TunableTrainerConfig> concreteConfigs;
     private final List<TunableTrainerConfig> tunableConfigs;
     private final int totalNumberOfTrials;
+
+    private final int numberOfConcreteTrials;
     private final SplittableRandom random;
     private int numberOfFinishedTrials;

-    public RandomSearch(Map<TrainingMethod, List<TunableTrainerConfig>> parameterSpace, int totalNumberOfTrials, long randomSeed) {
-        this(parameterSpace, totalNumberOfTrials, Optional.of(randomSeed));
+    public RandomSearch(Map<TrainingMethod, List<TunableTrainerConfig>> parameterSpace, int maxTrials, long randomSeed) {
+        this(parameterSpace, maxTrials, Optional.of(randomSeed));
     }

     public RandomSearch(
         Map<TrainingMethod, List<TunableTrainerConfig>> parameterSpace,
-        int totalNumberOfTrials,
+        int maxTrials,
         Optional<Long> randomSeed
     ) {
         this.concreteConfigs = parameterSpace.values().stream()
@@ -55,15 +57,17 @@ public RandomSearch(
             .flatMap(List::stream)
             .filter(tunableTrainerConfig -> !tunableTrainerConfig.isConcrete())
             .collect(Collectors.toList());
-        this.totalNumberOfTrials = totalNumberOfTrials;
+        this.numberOfConcreteTrials = this.concreteConfigs.size();
+        this.totalNumberOfTrials = maxTrials + numberOfConcreteTrials;
         this.random = randomSeed.map(SplittableRandom::new).orElseGet(SplittableRandom::new);
         this.numberOfFinishedTrials = 0;
     }


     @Override
     public boolean hasNext() {
-        return numberOfFinishedTrials < totalNumberOfTrials;
+        //There's a next trial to run if 1.there are more concrete trials or 2.there are actually tunable configs, and we haven't reached total number of allowed trials
+        return (numberOfFinishedTrials < numberOfConcreteTrials) || (numberOfFinishedTrials < totalNumberOfTrials && !tunableConfigs.isEmpty());
     }

     @Override
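
The trial budget introduced by the `hasNext()` change above can be modelled in isolation. What follows is a minimal sketch under stated assumptions, not the GDS implementation: `TrialBudgetSketch`, `finishTrial()` and `countTrials()` are hypothetical names invented for illustration, and the config lists are reduced to plain counts.

```java
/** Hypothetical stand-in for the counting logic in RandomSearch; configs are reduced to counts. */
final class TrialBudgetSketch {
    private final int numberOfConcreteTrials;
    private final int totalNumberOfTrials;
    private final boolean hasTunableConfigs;
    private int numberOfFinishedTrials = 0;

    TrialBudgetSketch(int concreteConfigCount, int tunableConfigCount, int maxTrials) {
        this.numberOfConcreteTrials = concreteConfigCount;
        // maxTrials now budgets only the tunable trials; concrete configs come on top.
        this.totalNumberOfTrials = maxTrials + concreteConfigCount;
        this.hasTunableConfigs = tunableConfigCount > 0;
    }

    boolean hasNext() {
        // Same condition as in the diff: concrete trials always run,
        // tunable trials run only if tunable configs exist and the budget is not used up.
        return numberOfFinishedTrials < numberOfConcreteTrials
            || (numberOfFinishedTrials < totalNumberOfTrials && hasTunableConfigs);
    }

    void finishTrial() {
        numberOfFinishedTrials++;
    }

    /** Runs the loop to exhaustion and returns how many trials were produced. */
    int countTrials() {
        int trials = 0;
        while (hasNext()) {
            finishTrial();
            trials++;
        }
        return trials;
    }
}
```

With 4 concrete configs, 1 tunable config and `maxTrials = 2`, `countTrials()` yields 6, which mirrors the `runsAllConcreteConfigsRegardlessOfMaxTrials` test in this commit. With no tunable configs at all, only the concrete trials run, whatever `maxTrials` is.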

ml/ml-algo/src/test/java/org/neo4j/gds/ml/models/automl/RandomSearchTest.java

Lines changed: 89 additions & 0 deletions
@@ -156,6 +156,95 @@ void shouldProduceConcreteConfigsFirst() {
         );
         assertThat(randomSearch.hasNext()).isTrue();
         assertThat(randomSearch.next()).isInstanceOf(RandomForestTrainerConfig.class);
+        assertThat(randomSearch.hasNext()).isTrue();
+        assertThat(randomSearch.next()).isInstanceOf(LogisticRegressionTrainConfig.class);
+        assertThat(randomSearch.hasNext()).isFalse();
+    }
+
+    @Test
+    void runsAllConcreteConfigsRegardlessOfMaxTrials() {
+        var maxTrials = 2;
+        var randomSearch = new RandomSearch(
+            Map.of(
+                TrainingMethod.LogisticRegression,
+                List.of(
+                    TunableTrainerConfig.of(
+                        Map.of(
+                            "penalty", Map.of("range", List.of(1e-4, 1e4))
+                        ),
+                        TrainingMethod.LogisticRegression
+                    ),
+                    TunableTrainerConfig.of(
+                        Map.of(),
+                        TrainingMethod.RandomForestClassification
+                    ),
+                    TunableTrainerConfig.of(
+                        Map.of(),
+                        TrainingMethod.RandomForestClassification
+                    ),
+                    TunableTrainerConfig.of(
+                        Map.of(),
+                        TrainingMethod.RandomForestClassification
+                    ),
+                    TunableTrainerConfig.of(
+                        Map.of(),
+                        TrainingMethod.RandomForestClassification
+                    )
+                )
+            ),
+            maxTrials,
+            System.currentTimeMillis()
+        );
+        // first all the concrete configs
+        for (int i = 0; i < 4; i++) {
+            assertThat(randomSearch.hasNext()).isTrue();
+            assertThat(randomSearch.next()).isInstanceOf(RandomForestTrainerConfig.class);
+        }
+        // then maxTrials (2) auto tuning configs
+        assertThat(randomSearch.hasNext()).isTrue();
+        assertThat(randomSearch.next()).isInstanceOf(LogisticRegressionTrainConfig.class);
+        assertThat(randomSearch.hasNext()).isTrue();
+        assertThat(randomSearch.next()).isInstanceOf(LogisticRegressionTrainConfig.class);
+        // then no more
+        assertThat(randomSearch.hasNext()).isFalse();
+    }
+
+    @Test
+    void foo() {
+        var maxTrials = 5;
+        var randomSearch = new RandomSearch(
+            Map.of(
+                TrainingMethod.LogisticRegression,
+                List.of(
+                    TunableTrainerConfig.of(
+                        Map.of(
+                            "penalty", Map.of("range", List.of(1e-4, 1e4))
+                        ),
+                        TrainingMethod.LogisticRegression
+                    ),
+                    TunableTrainerConfig.of(
+                        Map.of(),
+                        TrainingMethod.RandomForestClassification
+                    ),
+                    TunableTrainerConfig.of(
+                        Map.of(),
+                        TrainingMethod.RandomForestClassification
+                    ),
+                    TunableTrainerConfig.of(
+                        Map.of(),
+                        TrainingMethod.RandomForestClassification
+                    )
+                )
+            ),
+            maxTrials,
+            System.currentTimeMillis()
+        );
+        // first all the concrete configs
+        for (int i = 0; i < 8; i++) {
+            assertThat(randomSearch.hasNext()).isTrue();
+            randomSearch.next();
+        }
+        // then no more
         assertThat(randomSearch.hasNext()).isFalse();
     }
 }

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/TrainingPipeline.java

Lines changed: 11 additions & 4 deletions
@@ -114,14 +114,21 @@ public Map<TrainingMethod, List<TunableTrainerConfig>> trainingParameterSpace()
         return trainingParameterSpace;
     }

-    private boolean hasOnlyConcreteTrainerConfigs() {
-        return trainingParameterSpace().values().stream().flatMap(List::stream).allMatch(TunableTrainerConfig::isConcrete);
+    private int concreteTrainerConfigsCount() {
+        return (int) trainingParameterSpace()
+            .values()
+            .stream()
+            .flatMap(List::stream)
+            .filter(TunableTrainerConfig::isConcrete)
+            .count();
     }

     public int numberOfModelSelectionTrials() {
-        return hasOnlyConcreteTrainerConfigs()
+        int concreteTrainerConfigsCount = concreteTrainerConfigsCount();
+
+        return concreteTrainerConfigsCount == numberOfTrainerConfigs()
             ? numberOfTrainerConfigs()
-            : autoTuningConfig().maxTrials();
+            : autoTuningConfig().maxTrials() + concreteTrainerConfigsCount;
     }

     public void addTrainerConfig(TunableTrainerConfig trainingConfig) {
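
The revised `numberOfModelSelectionTrials()` above reduces to a small piece of arithmetic. The sketch below restates it as a pure function, assuming the three inputs are passed in as plain integers. `ModelSelectionTrialsSketch` and its parameter names are hypothetical and do not exist in GDS.

```java
/** Hypothetical restatement of the trial-count rule from TrainingPipeline as a pure function. */
final class ModelSelectionTrialsSketch {
    private ModelSelectionTrialsSketch() {}

    static int numberOfModelSelectionTrials(int concreteConfigCount, int totalConfigCount, int maxTrials) {
        // If every config is concrete there is nothing to tune, so maxTrials is ignored.
        if (concreteConfigCount == totalConfigCount) {
            return totalConfigCount;
        }
        // Otherwise all concrete configs run, plus up to maxTrials sampled from the tunable ones.
        return maxTrials + concreteConfigCount;
    }
}
```

For example, 3 concrete configs out of 4 with `maxTrials = 5` give 8 trials, while 3 concrete configs out of 3 give 3 trials regardless of `maxTrials`.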

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/linkPipeline/train/LinkPredictionTrain.java

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ private void findBestModelCandidate(
     ) {
         var modelCandidates = new RandomSearch(
             pipeline.trainingParameterSpace(),
-            pipeline.numberOfModelSelectionTrials(),
+            pipeline.autoTuningConfig().maxTrials(),
             config.randomSeed()
         );

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/classification/train/NodeClassificationTrain.java

Lines changed: 1 addition & 1 deletion
@@ -316,7 +316,7 @@ private void findBestModelCandidate(ReadOnlyHugeLongArray trainNodeIds, Training

         var modelCandidates = new RandomSearch(
             pipeline.trainingParameterSpace(),
-            pipeline.numberOfModelSelectionTrials(),
+            pipeline.autoTuningConfig().maxTrials(),
             trainConfig.randomSeed()
         );

pipeline/src/main/java/org/neo4j/gds/ml/pipeline/nodePipeline/regression/NodeRegressionTrain.java

Lines changed: 1 addition & 1 deletion
@@ -164,7 +164,7 @@ private void findBestModelCandidate(

         var modelCandidates = new RandomSearch(
             pipeline.trainingParameterSpace(),
-            pipeline.numberOfModelSelectionTrials(),
+            pipeline.autoTuningConfig().maxTrials(),
             trainConfig.randomSeed()
         );
