Skip to content

Commit 89fafae

Browse files
authored
Better Egor solver state handling (#168)
* Add solver final state in optim result * Rename random_seed in seed * Add Egor benchmark * Implement find best result using best_index state Instead of recomputing from the whole history dataset, we just compare new points with current best * Remove dead code (thanks clippy) * Make clippy happy * Remove brittle test
1 parent c87f86e commit 89fafae

File tree

12 files changed

+235
-78
lines changed

12 files changed

+235
-78
lines changed

ego/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,7 @@ criterion = "0.5"
6464
approx = "0.4"
6565
argmin_testfunctions = "0.2"
6666
serial_test = "3.1.0"
67+
68+
[[bench]]
69+
name = "ego"
70+
harness = false

ego/benches/ego.rs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
2+
use egobox_ego::{EgorBuilder, InfillStrategy};
3+
use egobox_moe::{CorrelationSpec, RegressionSpec};
4+
use ndarray::{array, Array2, ArrayView2, Zip};
5+
6+
/// Ackley test function: min f(x)=0 at x=(0, 0, 0)
7+
fn ackley(x: &ArrayView2<f64>) -> Array2<f64> {
8+
let mut y: Array2<f64> = Array2::zeros((x.nrows(), 1));
9+
Zip::from(y.rows_mut())
10+
.and(x.rows())
11+
.par_for_each(|mut yi, xi| yi.assign(&array![argmin_testfunctions::ackley(&xi.to_vec(),)]));
12+
y
13+
}
14+
15+
fn criterion_ego(c: &mut Criterion) {
16+
let xlimits = array![[-32.768, 32.768], [-32.768, 32.768], [-32.768, 32.768]];
17+
let mut group = c.benchmark_group("ego");
18+
group.bench_function("ego ackley", |b| {
19+
b.iter(|| {
20+
black_box(
21+
EgorBuilder::optimize(ackley)
22+
.configure(|config| {
23+
config
24+
.regression_spec(RegressionSpec::CONSTANT)
25+
.correlation_spec(CorrelationSpec::ABSOLUTEEXPONENTIAL)
26+
.infill_strategy(InfillStrategy::WB2S)
27+
.max_iters(10)
28+
.target(5e-1)
29+
.seed(42)
30+
})
31+
.min_within(&xlimits)
32+
.run()
33+
.expect("Minimize failure"),
34+
)
35+
});
36+
});
37+
38+
group.finish();
39+
}
40+
41+
criterion_group!(benches, criterion_ego);
42+
criterion_main!(benches);

ego/examples/g24.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ fn main() {
3434

3535
// We use Egor optimizer as a service
3636
let egor = EgorServiceBuilder::optimize()
37-
.configure(|config| config.n_cstr(2).random_seed(42))
37+
.configure(|config| config.n_cstr(2).seed(42))
3838
.min_within(&xlimits);
3939

4040
let mut y_doe = f_g24(&doe.view());

ego/src/egor.rs

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
197197
y_opt: result.state.get_full_best_cost().unwrap().to_owned(),
198198
x_hist: x_data,
199199
y_hist: y_data,
200+
state: result.state,
200201
}
201202
} else {
202203
let x_data = to_discrete_space(&xtypes, &x_data.view());
@@ -214,6 +215,7 @@ impl<O: GroupFunc, SB: SurrogateBuilder> Egor<O, SB> {
214215
y_opt: result.state.get_full_best_cost().unwrap().to_owned(),
215216
x_hist: x_data,
216217
y_hist: y_data,
218+
state: result.state,
217219
}
218220
};
219221
info!("Optim Result: min f(x)={} at x={}", res.y_opt, res.x_opt);
@@ -280,6 +282,7 @@ mod tests {
280282
.max_iters(20)
281283
.regression_spec(RegressionSpec::ALL)
282284
.correlation_spec(CorrelationSpec::ALL)
285+
.seed(1)
283286
})
284287
.min_within(&array![[0.0, 25.0]])
285288
.run()
@@ -321,27 +324,15 @@ mod tests {
321324
let xlimits = array![[0.0, 25.0]];
322325
let doe = Lhs::new(&xlimits).sample(10);
323326
let res = EgorBuilder::optimize(xsinx)
324-
.configure(|config| {
325-
config
326-
.max_iters(15)
327-
.doe(&doe)
328-
.outdir(outdir)
329-
.random_seed(42)
330-
})
327+
.configure(|config| config.max_iters(15).doe(&doe).outdir(outdir).seed(42))
331328
.min_within(&xlimits)
332329
.run()
333330
.expect("Minimize failure");
334331
let expected = array![18.9];
335332
assert_abs_diff_eq!(expected, res.x_opt, epsilon = 1e-1);
336333

337334
let res = EgorBuilder::optimize(xsinx)
338-
.configure(|config| {
339-
config
340-
.max_iters(5)
341-
.outdir(outdir)
342-
.hot_start(true)
343-
.random_seed(42)
344-
})
335+
.configure(|config| config.max_iters(5).outdir(outdir).hot_start(true).seed(42))
345336
.min_within(&xlimits)
346337
.run()
347338
.expect("Egor should minimize xsinx");
@@ -375,7 +366,7 @@ mod tests {
375366
.regression_spec(RegressionSpec::ALL)
376367
.correlation_spec(CorrelationSpec::ALL)
377368
.target(1e-2)
378-
.random_seed(42)
369+
.seed(42)
379370
})
380371
.min_within(&xlimits)
381372
.run()
@@ -395,7 +386,7 @@ mod tests {
395386
.with_rng(Xoshiro256Plus::seed_from_u64(42))
396387
.sample(10);
397388
let res = EgorBuilder::optimize(rosenb)
398-
.configure(|config| config.doe(&doe).max_iters(20).random_seed(42))
389+
.configure(|config| config.doe(&doe).max_iters(20).seed(42))
399390
.min_within(&xlimits)
400391
.run()
401392
.expect("Minimize failure");
@@ -445,7 +436,7 @@ mod tests {
445436
.doe(&doe)
446437
.max_iters(20)
447438
.cstr_tol(array![2e-6, 1e-6])
448-
.random_seed(42)
439+
.seed(42)
449440
})
450441
.min_within(&xlimits)
451442
.run()
@@ -474,7 +465,7 @@ mod tests {
474465
.doe(&doe)
475466
.target(-5.5030)
476467
.max_iters(30)
477-
.random_seed(42)
468+
.seed(42)
478469
})
479470
.min_within(&xlimits)
480471
.run()
@@ -508,7 +499,7 @@ mod tests {
508499
.max_iters(max_iters)
509500
.target(-15.1)
510501
.infill_strategy(InfillStrategy::EI)
511-
.random_seed(42)
502+
.seed(42)
512503
})
513504
.min_within_mixint_space(&xtypes)
514505
.run()
@@ -530,7 +521,7 @@ mod tests {
530521
.max_iters(max_iters)
531522
.target(-15.1)
532523
.infill_strategy(InfillStrategy::EI)
533-
.random_seed(42)
524+
.seed(42)
534525
})
535526
.min_within_mixint_space(&xtypes)
536527
.run()
@@ -550,7 +541,7 @@ mod tests {
550541
.regression_spec(egobox_moe::RegressionSpec::CONSTANT)
551542
.correlation_spec(egobox_moe::CorrelationSpec::SQUAREDEXPONENTIAL)
552543
.max_iters(max_iters)
553-
.random_seed(42)
544+
.seed(42)
554545
})
555546
.min_within_mixint_space(&xtypes)
556547
.run()
@@ -601,7 +592,7 @@ mod tests {
601592
.regression_spec(egobox_moe::RegressionSpec::CONSTANT)
602593
.correlation_spec(egobox_moe::CorrelationSpec::SQUAREDEXPONENTIAL)
603594
.max_iters(max_iters)
604-
.random_seed(42)
595+
.seed(42)
605596
})
606597
.min_within_mixint_space(&xtypes)
607598
.run()
@@ -632,7 +623,7 @@ mod tests {
632623
let xlimits = as_continuous_limits::<f64>(&xtypes);
633624

634625
EgorBuilder::optimize(mixobj)
635-
.configure(|config| config.outdir(outdir).max_iters(1).random_seed(42))
626+
.configure(|config| config.outdir(outdir).max_iters(1).seed(42))
636627
.min_within_mixint_space(&xtypes)
637628
.run()
638629
.unwrap();
@@ -644,13 +635,7 @@ mod tests {
644635
// Check that with no iteration, obj function is never called
645636
// as the DOE does not need to be evaluated!
646637
EgorBuilder::optimize(|_x| panic!("Should not call objective function!"))
647-
.configure(|config| {
648-
config
649-
.outdir(outdir)
650-
.hot_start(true)
651-
.max_iters(0)
652-
.random_seed(42)
653-
})
638+
.configure(|config| config.outdir(outdir).hot_start(true).max_iters(0).seed(42))
654639
.min_within_mixint_space(&xtypes)
655640
.run()
656641
.unwrap();

ego/src/egor_config.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ impl EgorConfig {
240240

241241
/// Allow to specify a seed for random number generator to allow
242242
/// reproducible runs.
243-
pub fn random_seed(mut self, seed: u64) -> Self {
243+
pub fn seed(mut self, seed: u64) -> Self {
244244
self.seed = Some(seed);
245245
self
246246
}

ego/src/egor_service.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
//! conf.regression_spec(RegressionSpec::ALL)
2121
//! .correlation_spec(CorrelationSpec::ALL)
2222
//! .infill_strategy(InfillStrategy::EI)
23-
//! .random_seed(42)
23+
//! .seed(42)
2424
//! })
2525
//! .min_within(&array![[0., 25.]]);
2626
//!
@@ -156,7 +156,7 @@ mod tests {
156156
conf.regression_spec(RegressionSpec::ALL)
157157
.correlation_spec(CorrelationSpec::ALL)
158158
.infill_strategy(InfillStrategy::EI)
159-
.random_seed(42)
159+
.seed(42)
160160
})
161161
.min_within(&array![[0., 25.]]);
162162

ego/src/egor_solver.rs

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ use crate::egor_config::EgorConfig;
109109
use crate::egor_state::{find_best_result_index, EgorState, MAX_POINT_ADDITION_RETRY};
110110
use crate::errors::{EgoError, Result};
111111

112-
use crate::mixint::*;
112+
use crate::{find_best_result_index_from, mixint::*};
113113

114114
use crate::optimizer::*;
115115

@@ -361,7 +361,7 @@ where
361361
let no_point_added_retries = MAX_POINT_ADDITION_RETRY;
362362

363363
let mut initial_state = state
364-
.data((x_data, y_data))
364+
.data((x_data, y_data.clone()))
365365
.clusterings(clusterings)
366366
.theta_inits(theta_inits)
367367
.sampling(sampling);
@@ -375,6 +375,10 @@ where
375375
.clone()
376376
.unwrap_or(Array1::from_elem(self.config.n_cstr, DEFAULT_CSTR_TOL));
377377
initial_state.target_cost = self.config.target;
378+
379+
let best_index = find_best_result_index(&y_data, &initial_state.cstr_tol);
380+
initial_state.best_index = Some(best_index);
381+
initial_state.last_best_iter = 0;
378382
debug!("INITIAL STATE = {:?}", initial_state);
379383
Ok((initial_state, None))
380384
}
@@ -437,7 +441,7 @@ where
437441

438442
let (x_dat, y_dat, infill_value) = self.next_points(
439443
init,
440-
new_state.get_iter(),
444+
state.get_iter(),
441445
recluster,
442446
&mut clusterings,
443447
&mut theta_inits,
@@ -532,23 +536,35 @@ where
532536
info!("Save doe shape {:?} in {:?}", doe.shape(), filepath);
533537
write_npy(filepath, &doe).expect("Write current doe");
534538
}
535-
let best_index = find_best_result_index(&y_data, &new_state.cstr_tol);
539+
540+
let best_index = find_best_result_index_from(
541+
state.best_index.unwrap(),
542+
y_data.nrows() - add_count as usize,
543+
&y_data,
544+
&new_state.cstr_tol,
545+
);
546+
// let best = find_best_result_index(&y_data, &new_state.cstr_tol);
547+
// assert!(best_index == best);
548+
new_state.best_index = Some(best_index);
536549
info!(
537550
"********* End iteration {}/{} in {:.3}s: Best fun(x)={} at x={}",
538-
new_state.get_iter() + 1,
539-
new_state.get_max_iters(),
551+
state.get_iter() + 1,
552+
state.get_max_iters(),
540553
now.elapsed().as_secs_f64(),
541554
y_data.row(best_index),
542555
x_data.row(best_index)
543556
);
544557
new_state = new_state.data((x_data, y_data.clone()));
558+
545559
Ok((new_state, None))
546560
}
547561

548562
fn terminate(&mut self, state: &EgorState<f64>) -> TerminationStatus {
549563
debug!(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> end iteration");
550564
debug!("Current Cost {:?}", state.get_cost());
551565
debug!("Best cost {:?}", state.get_best_cost());
566+
debug!("Best index {:?}", state.best_index);
567+
debug!("Data {:?}", state.data.as_ref().unwrap());
552568

553569
TerminationStatus::NotTerminated
554570
}

0 commit comments

Comments
 (0)