Skip to content

Commit a742005

Browse files
committed
update search
1 parent c0da818 commit a742005

File tree

3 files changed

+75
-55
lines changed

3 files changed

+75
-55
lines changed

docs/marathon/beam_search.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,15 @@ documentation_of: ./marathon/beam_search.cpp
99
- 状態 `BeamSearchState` と、軽量化した状態 `LiteBeamSearchState` の両方を定義して使う
1010
- 軽量化した状態は、参照する `BeamSearchState` のインデックス +α の情報を持たせる
1111
- 両方ともスコア計算の実装が必須
12-
- 以下の 2 つの関数を定義してビームサーチを実行する
12+
- 以下の関数を定義する
1313
- `add_next_lite_states`: 状態を受けとり、次の状態を軽量化したものを返す関数
1414
- `to_next_state`: 軽量化した状態を受けとり、それを元の (軽量化していない) 状態に変換する関数
15+
- 概ね次のようにしてビームサーチをする
16+
- 初期状態を登録
17+
- ターン数だけ次を繰り返す
18+
- 最初のターン以外: `reconstruct_states` で、軽量化した状態を元の状態に変換
19+
- `add_lite_states` で、軽量化した状態を登録
20+
- どのターンに対しても登録可能
1521

16-
AHC040 で使用実績あり。
17-
[https://atcoder.jp/contests/ahc040/submissions/61300845](https://atcoder.jp/contests/ahc040/submissions/61300845)
22+
AHC049 で使用実績あり。
23+
[https://atcoder.jp/contests/ahc049/submissions/67053066](https://atcoder.jp/contests/ahc049/submissions/67053066)

marathon/beam_search.cpp

Lines changed: 52 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1+
#include <memory>
2+
13
template <typename BeamSearchState, typename LiteBeamSearchState>
24
class BeamSearchBase {
35
public:
46
BeamSearchBase() = default;
57
BeamSearchBase(int beam_width) : beam_width_(beam_width) {}
68
void set_beam_width(int beam_width) { beam_width_ = beam_width; }
79
void register_state(const BeamSearchState &state) { states_.emplace_back(state); }
8-
virtual void register_lite_state(LiteBeamSearchState &state) = 0;
9-
virtual void search(
10-
const function<void(BeamSearchState &)> &add_next_lite_states,
11-
const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) = 0;
10+
virtual void register_lite_state(LiteBeamSearchState &lite_state, const BeamSearchState& state) = 0;
11+
virtual void reconstruct_states(const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) = 0;
12+
virtual void add_lite_states(const function<void(BeamSearchState &)> &add_next_lite_states) = 0;
1213
const BeamSearchState &get_best_state() const {
1314
if (states_.empty()) {
1415
throw runtime_error("No states are registered.");
@@ -32,26 +33,29 @@ template <class BeamSearchState, class LiteBeamSearchState>
3233
class BeamSearch : public BeamSearchBase<BeamSearchState, LiteBeamSearchState> {
3334
public:
3435
using BeamSearchBase<BeamSearchState, LiteBeamSearchState>::BeamSearchBase;
35-
void register_lite_state(LiteBeamSearchState &state) override {
36-
state.state_idx = this->current_state_idx_;
37-
lite_states_.emplace_back(state);
36+
void register_lite_state(LiteBeamSearchState &lite_state, const BeamSearchState& state) override {
37+
lite_state.state_ref = make_shared<BeamSearchState>(state);
38+
lite_states_.emplace_back(lite_state);
3839
}
39-
void search(
40-
const function<void(BeamSearchState &)> &add_next_lite_states,
41-
const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
42-
lite_states_.clear();
43-
for (this->current_state_idx_ = 0; this->current_state_idx_ < this->states_.size();
44-
this->current_state_idx_++) {
45-
add_next_lite_states(this->states_[this->current_state_idx_]);
46-
}
40+
41+
void reconstruct_states(const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
4742
const int num_select = min((int)lite_states_.size(), this->beam_width_);
4843
nth_element(lite_states_.begin(), lite_states_.begin() + num_select, lite_states_.end());
4944
vector<BeamSearchState> next_states(num_select);
5045
for (int i = 0; i < num_select; i++) {
51-
next_states[i] = to_next_state(lite_states_[i], this->states_[lite_states_[i].state_idx]);
46+
next_states[i] = to_next_state(lite_states_[i], *lite_states_[i].state_ref);
5247
}
5348
this->states_ = move(next_states);
5449
}
50+
51+
void add_lite_states(const function<void(BeamSearchState &)> &add_next_lite_states) override {
52+
lite_states_.clear();
53+
for (this->current_state_idx_ = 0; this->current_state_idx_ < this->states_.size();
54+
this->current_state_idx_++) {
55+
add_next_lite_states(this->states_[this->current_state_idx_]);
56+
}
57+
}
58+
5559
int num_lite_states() const override { return lite_states_.size(); }
5660

5761
private:
@@ -62,25 +66,12 @@ template <class BeamSearchState, class LiteBeamSearchState>
6266
class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearchState> {
6367
public:
6468
using BeamSearchBase<BeamSearchState, LiteBeamSearchState>::BeamSearchBase;
65-
void register_lite_state(LiteBeamSearchState &state) override {
66-
state.state_idx = this->current_state_idx_;
67-
lite_states_.emplace(state);
68-
// remove the worst (biggest) state.
69-
while ((int)lite_states_.size() > this->beam_width_) {
70-
lite_states_.pop();
71-
}
69+
void register_lite_state(LiteBeamSearchState &lite_state, const BeamSearchState& state) override {
70+
lite_state.state_ref = make_shared<BeamSearchState>(state);
71+
lite_states_.emplace(lite_state);
7272
}
7373

74-
void search(
75-
const function<void(BeamSearchState &)> &add_next_lite_states,
76-
const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
77-
// NOTE: alternative of clear() because STL doesn't have priority_queue::clear().
78-
lite_states_ = priority_queue<LiteBeamSearchState>();
79-
hash_values_.clear();
80-
for (this->current_state_idx_ = 0; this->current_state_idx_ < this->states_.size();
81-
this->current_state_idx_++) {
82-
add_next_lite_states(this->states_[this->current_state_idx_]);
83-
}
74+
void reconstruct_states(const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
8475
const int num_select = min((int)lite_states_.size(), this->beam_width_);
8576
vector<LiteBeamSearchState> lite_states_vec(lite_states_.size());
8677
for (int i = (int)lite_states_vec.size() - 1; i >= 0; i--) {
@@ -89,37 +80,48 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
8980
}
9081
vector<BeamSearchState> next_states(num_select);
9182
int num_next_states = 0;
92-
for (int i = 0; i < (int)lite_states_vec.size(); i++) {
83+
for (int i = 0; i < num_select; i++) {
9384
auto &lite_state = lite_states_vec[i];
9485
if (hash_values_.count(lite_state.hash_value)) {
9586
continue;
9687
}
9788
hash_values_.insert(lite_state.hash_value);
98-
next_states[num_next_states++] = to_next_state(lite_state, this->states_[lite_state.state_idx]);
89+
next_states[num_next_states++] = to_next_state(lite_state, *lite_state.state_ref);
9990
}
10091
next_states.resize(num_next_states);
10192
this->states_ = move(next_states);
10293
}
94+
95+
void add_lite_states(const function<void(BeamSearchState &)> &add_next_lite_states) override {
96+
// NOTE: alternative of clear() because STL doesn't have priority_queue::clear().
97+
lite_states_ = priority_queue<LiteBeamSearchState>();
98+
hash_values_.clear();
99+
for (this->current_state_idx_ = 0; this->current_state_idx_ < this->states_.size();
100+
this->current_state_idx_++) {
101+
add_next_lite_states(this->states_[this->current_state_idx_]);
102+
}
103+
}
104+
103105
int num_lite_states() const override { return lite_states_.size(); }
104106

105107
private:
106108
priority_queue<LiteBeamSearchState> lite_states_;
107109
set<ll> hash_values_;
108110
};
109111

110-
template <typename score_t>
111-
struct BeamSearchStateBase {
112-
score_t score;
113-
ll hash_value;
114-
BeamSearchStateBase() : score(0) {}
115-
virtual bool operator<(const BeamSearchStateBase &state) const = 0;
116-
};
112+
// template <typename score_t>
113+
// struct BeamSearchState {
114+
// score_t score;
115+
// ll hash_value;
116+
// BeamSearchStateBase() : score(0), hash_value(0) {}
117+
// ~BeamSearchStateBase() = default;
118+
// };
117119

118-
template <typename score_t>
119-
struct LiteBeamSearchStateBase {
120-
int state_idx;
121-
score_t score;
122-
ll hash_value;
123-
LiteBeamSearchStateBase() : state_idx(-1), score(0), hash_value(0) {}
124-
virtual bool operator<(const LiteBeamSearchStateBase &state) const = 0;
125-
};
120+
// template <typename score_t>
121+
// struct LiteBeamSearchState {
122+
// shared_ptr<BeamSearchState<score_t>> state_ref;
123+
// score_t score;
124+
// ll hash_value;
125+
// LiteBeamSearchState() : state_ref(nullptr), score(0), hash_value(0) {}
126+
// ~LiteBeamSearchState() = default;
127+
// };

marathon/chokudai_search.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class ChokudaiSearch {
99
search_time_ = search_time;
1010
num_iter_ = 0;
1111
time_check_iter_ = time_check_iter;
12-
max_num_states_.resize(max_turns, 1 << 30);
12+
max_num_states_.resize(max_turns + 1, 1 << 30);
1313
}
1414
void set_max_num_states(int max_num_state) {
1515
fill(max_num_states_.begin(), max_num_states_.end(), max_num_state);
@@ -24,6 +24,13 @@ class ChokudaiSearch {
2424
states_[turns].pop_max();
2525
}
2626
}
27+
size_t size(int turns) const {
28+
return states_[turns].size();
29+
}
30+
const State& worst_state(int turns) const {
31+
assert(states_[turns].size() > 0);
32+
return states_[turns].top_max();
33+
}
2734
void search(Timer &timer, const auto &add_next_states) {
2835
assert(num_iter_ >= 0);
2936
const double start_time = timer.getTime();
@@ -41,7 +48,12 @@ class ChokudaiSearch {
4148
TIME_OVER:;
4249
fprintf(stderr, "chokudai search: num_iter = %d\n", num_iter_);
4350
}
44-
const State &get_best_state() const { return states_.back().top_min(); }
51+
const State &get_best_state() const {
52+
if (states_.back().empty()) {
53+
throw runtime_error("No states are registered.");
54+
}
55+
return states_.back().top_min();
56+
}
4557

4658
private:
4759
double search_time_;

0 commit comments

Comments
 (0)