1+ #include < memory>
2+
13template <typename BeamSearchState, typename LiteBeamSearchState>
24class BeamSearchBase {
35 public:
46 BeamSearchBase () = default ;
57 BeamSearchBase (int beam_width) : beam_width_(beam_width) {}
68 void set_beam_width (int beam_width) { beam_width_ = beam_width; }
79 void register_state (const BeamSearchState &state) { states_.emplace_back (state); }
8- virtual void register_lite_state (LiteBeamSearchState &state) = 0;
9- virtual void search (
10- const function<void (BeamSearchState &)> &add_next_lite_states,
11- const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) = 0;
10+ virtual void register_lite_state (LiteBeamSearchState &lite_state, const BeamSearchState& state) = 0;
11+ virtual void reconstruct_states (const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) = 0;
12+ virtual void add_lite_states (const function<void (BeamSearchState &)> &add_next_lite_states) = 0;
1213 const BeamSearchState &get_best_state () const {
1314 if (states_.empty ()) {
1415 throw runtime_error (" No states are registered." );
@@ -32,26 +33,29 @@ template <class BeamSearchState, class LiteBeamSearchState>
3233class BeamSearch : public BeamSearchBase <BeamSearchState, LiteBeamSearchState> {
3334 public:
3435 using BeamSearchBase<BeamSearchState, LiteBeamSearchState>::BeamSearchBase;
35- void register_lite_state (LiteBeamSearchState &state) override {
36- state. state_idx = this -> current_state_idx_ ;
37- lite_states_.emplace_back (state );
36+ void register_lite_state (LiteBeamSearchState &lite_state, const BeamSearchState& state) override {
37+ lite_state. state_ref = make_shared<BeamSearchState>(state) ;
38+ lite_states_.emplace_back (lite_state );
3839 }
39- void search (
40- const function<void (BeamSearchState &)> &add_next_lite_states,
41- const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
42- lite_states_.clear ();
43- for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
44- this ->current_state_idx_ ++) {
45- add_next_lite_states (this ->states_ [this ->current_state_idx_ ]);
46- }
40+
41+ void reconstruct_states (const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
4742 const int num_select = min ((int )lite_states_.size (), this ->beam_width_ );
4843 nth_element (lite_states_.begin (), lite_states_.begin () + num_select, lite_states_.end ());
4944 vector<BeamSearchState> next_states (num_select);
5045 for (int i = 0 ; i < num_select; i++) {
51- next_states[i] = to_next_state (lite_states_[i], this -> states_ [ lite_states_[i].state_idx ] );
46+ next_states[i] = to_next_state (lite_states_[i], * lite_states_[i].state_ref );
5247 }
5348 this ->states_ = move (next_states);
5449 }
50+
51+ void add_lite_states (const function<void (BeamSearchState &)> &add_next_lite_states) override {
52+ lite_states_.clear ();
53+ for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
54+ this ->current_state_idx_ ++) {
55+ add_next_lite_states (this ->states_ [this ->current_state_idx_ ]);
56+ }
57+ }
58+
5559 int num_lite_states () const override { return lite_states_.size (); }
5660
5761 private:
@@ -62,25 +66,12 @@ template <class BeamSearchState, class LiteBeamSearchState>
6266class BeamSearchWithHash : public BeamSearchBase <BeamSearchState, LiteBeamSearchState> {
6367 public:
6468 using BeamSearchBase<BeamSearchState, LiteBeamSearchState>::BeamSearchBase;
65- void register_lite_state (LiteBeamSearchState &state) override {
66- state.state_idx = this ->current_state_idx_ ;
67- lite_states_.emplace (state);
68- // remove the worst (biggest) state.
69- while ((int )lite_states_.size () > this ->beam_width_ ) {
70- lite_states_.pop ();
71- }
69+ void register_lite_state (LiteBeamSearchState &lite_state, const BeamSearchState& state) override {
70+ lite_state.state_ref = make_shared<BeamSearchState>(state);
71+ lite_states_.emplace (lite_state);
7272 }
7373
74- void search (
75- const function<void (BeamSearchState &)> &add_next_lite_states,
76- const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
77- // NOTE: alternative of clear() because STL doesn't have priority_queue::clear().
78- lite_states_ = priority_queue<LiteBeamSearchState>();
79- hash_values_.clear ();
80- for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
81- this ->current_state_idx_ ++) {
82- add_next_lite_states (this ->states_ [this ->current_state_idx_ ]);
83- }
74+ void reconstruct_states (const function<BeamSearchState(LiteBeamSearchState &, BeamSearchState)> &to_next_state) override {
8475 const int num_select = min ((int )lite_states_.size (), this ->beam_width_ );
8576 vector<LiteBeamSearchState> lite_states_vec (lite_states_.size ());
8677 for (int i = (int )lite_states_vec.size () - 1 ; i >= 0 ; i--) {
@@ -89,37 +80,48 @@ class BeamSearchWithHash : public BeamSearchBase<BeamSearchState, LiteBeamSearch
8980 }
9081 vector<BeamSearchState> next_states (num_select);
9182 int num_next_states = 0 ;
92- for (int i = 0 ; i < ( int )lite_states_vec. size () ; i++) {
83+ for (int i = 0 ; i < num_select ; i++) {
9384 auto &lite_state = lite_states_vec[i];
9485 if (hash_values_.count (lite_state.hash_value )) {
9586 continue ;
9687 }
9788 hash_values_.insert (lite_state.hash_value );
98- next_states[num_next_states++] = to_next_state (lite_state, this -> states_ [ lite_state.state_idx ] );
89+ next_states[num_next_states++] = to_next_state (lite_state, * lite_state.state_ref );
9990 }
10091 next_states.resize (num_next_states);
10192 this ->states_ = move (next_states);
10293 }
94+
95+ void add_lite_states (const function<void (BeamSearchState &)> &add_next_lite_states) override {
96+ // NOTE: alternative of clear() because STL doesn't have priority_queue::clear().
97+ lite_states_ = priority_queue<LiteBeamSearchState>();
98+ hash_values_.clear ();
99+ for (this ->current_state_idx_ = 0 ; this ->current_state_idx_ < this ->states_ .size ();
100+ this ->current_state_idx_ ++) {
101+ add_next_lite_states (this ->states_ [this ->current_state_idx_ ]);
102+ }
103+ }
104+
103105 int num_lite_states () const override { return lite_states_.size (); }
104106
105107 private:
106108 priority_queue<LiteBeamSearchState> lite_states_;
107109 set<ll> hash_values_;
108110};
109111
110- template <typename score_t >
111- struct BeamSearchStateBase {
112- score_t score;
113- ll hash_value;
114- BeamSearchStateBase () : score(0 ) {}
115- virtual bool operator <( const BeamSearchStateBase &state) const = 0 ;
116- };
112+ // template <typename score_t>
113+ // struct BeamSearchState {
114+ // score_t score;
115+ // ll hash_value;
116+ // BeamSearchStateBase() : score(0), hash_value (0) {}
117+ // ~ BeamSearchStateBase() = default ;
118+ // };
117119
118- template <typename score_t >
119- struct LiteBeamSearchStateBase {
120- int state_idx ;
121- score_t score;
122- ll hash_value;
123- LiteBeamSearchStateBase () : state_idx(- 1 ), score(0 ), hash_value(0 ) {}
124- virtual bool operator <( const LiteBeamSearchStateBase &state) const = 0 ;
125- };
120+ // template <typename score_t>
121+ // struct LiteBeamSearchState {
122+ // shared_ptr<BeamSearchState<score_t>> state_ref ;
123+ // score_t score;
124+ // ll hash_value;
125+ // LiteBeamSearchState () : state_ref(nullptr ), score(0), hash_value(0) {}
126+ // ~LiteBeamSearchState() = default ;
127+ // };
0 commit comments