1- /* *
2- * @brief ビームサーチ
3- * @docs docs/marathon/beam_search.md
4- */
5-
61template <typename BeamSearchState, typename LiteBeamSearchState>
72class BeamSearchBase {
83 public:
@@ -15,7 +10,7 @@ class BeamSearchBase {
1510 const function<void (BeamSearchState &)> &add_next_lite_states,
1611 const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) = 0;
1712 const BeamSearchState &get_best_state () const {
18- if (states_.empty ()) {
13+ if (states_.empty ()) {
1914 throw runtime_error (" No states are registered." );
2015 }
2116 return *min_element (states_.begin (), states_.end ());
@@ -45,14 +40,14 @@ class BeamSearch : public BeamSearchBase<BeamSearchState, LiteBeamSearchState> {
4540 const function<void (BeamSearchState &)> &add_next_lite_states,
4641 const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
4742 lite_states_.clear ();
48- for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
49- this ->current_state_idx_ ++) {
43+ for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
44+ this ->current_state_idx_ ++) {
5045 add_next_lite_states (this ->states_ [this ->current_state_idx_ ]);
5146 }
5247 const int num_select = min ((int )lite_states_.size (), this ->beam_width_ );
5348 nth_element (lite_states_.begin (), lite_states_.begin () + num_select, lite_states_.end ());
5449 vector<BeamSearchState> next_states (num_select);
55- for (int i = 0 ; i < num_select; i++) {
50+ for (int i = 0 ; i < num_select; i++) {
5651 next_states[i] = to_next_state (lite_states_[i], this ->states_ [lite_states_[i].state_idx ]);
5752 }
5853 this ->states_ = move (next_states);
@@ -71,7 +66,7 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
7166 state.state_idx = this ->current_state_idx_ ;
7267 lite_states_.emplace (state);
7368 // remove the worst (biggest) state.
74- while ((int )lite_states_.size () > this ->beam_width_ ) {
69+ while ((int )lite_states_.size () > this ->beam_width_ ) {
7570 lite_states_.pop ();
7671 }
7772 }
@@ -82,17 +77,21 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
8277 // NOTE: alternative of clear() because STL doesn't have priority_queue::clear().
8378 lite_states_ = priority_queue<LiteBeamSearchState>();
8479 hash_values_.clear ();
85- for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
86- this ->current_state_idx_ ++) {
80+ for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
81+ this ->current_state_idx_ ++) {
8782 add_next_lite_states (this ->states_ [this ->current_state_idx_ ]);
8883 }
8984 const int num_select = min ((int )lite_states_.size (), this ->beam_width_ );
85+ vector<LiteBeamSearchState> lite_states_vec (lite_states_.size ());
86+ for (int i = (int )lite_states_vec.size () - 1 ; i >= 0 ; i--) {
87+ lite_states_vec[i] = lite_states_.top ();
88+ lite_states_.pop ();
89+ }
9090 vector<BeamSearchState> next_states (num_select);
9191 int num_next_states = 0 ;
92- while (lite_states_.size ()) {
93- auto lite_state = lite_states_.top ();
94- lite_states_.pop ();
95- if (hash_values_.count (lite_state.hash_value )) {
92+ for (int i = 0 ; i < (int )lite_states_vec.size (); i++) {
93+ auto &lite_state = lite_states_vec[i];
94+ if (hash_values_.count (lite_state.hash_value )) {
9695 continue ;
9796 }
9897 hash_values_.insert (lite_state.hash_value );
@@ -111,6 +110,7 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
111110template <typename score_t >
112111struct BeamSearchStateBase {
113112 score_t score;
113+ ll hash_value;
114114 BeamSearchStateBase () : score(0 ) {}
115115 virtual bool operator <(const BeamSearchStateBase &state) const = 0 ;
116116};
0 commit comments