@@ -115,26 +115,33 @@ def beam_search(
     ) -> torch.Tensor:
 
         # Is this right?
-        T = max_len
+        # T = max_len
         N = vocab_size
 
+        emissions = model_kwargs["encoder_outputs"].get("encoder_output")
+
         def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
-            # `emissions_ptr` should always be the same (from encoder output)
-            # N is not needed
-            # T is not needed
+            # Currently, can only handle decoding one at a time
+            i = T  # Access the current seq in inputs
+            new_model_kwargs = model_kwargs.copy()
 
             if timestep == 0:
-                prev_step_token_idxs = input_ids
+                prev_step_token_idxs = [input_ids[i]]
                 prev_step_model_states = [
                     create_emitting_model_state(
                         Seq2SeqModelState(
                             timestep=0,
                             hidden_states=None,
-                            sequence=input_ids,
+                            sequence=input_ids[i].unsqueeze(0),
                             lm_scores=None
                         )
                     )
                 ]
+
+                new_model_kwargs["encoder_outputs"]["encoder_output"] = emissions[i, :, :].unsqueeze(0)
+
+                # import pdb
+                # pdb.set_trace()
 
             out_probs, model_states = [], []
             for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
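
Note on the slicing added above: `emissions` holds the batched encoder output, and since `update_func`'s signature is fixed by the decoder, the current sequence index is smuggled in through the otherwise-unused `T` argument (`i = T` here, paired with `decode_step(emissions.data_ptr(), i, N)` further down). A minimal shape sketch of the per-sequence slice, with illustrative sizes:

```python
import torch

# Illustrative shapes: (batch, src_len, hidden) encoder output.
emissions = torch.randn(4, 7, 512)
i = 2
single = emissions[i, :, :].unsqueeze(0)  # pick sequence i, keep a batch dim
assert single.shape == (1, 7, 512)
```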
@@ -145,10 +152,10 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                 prev_model_state = get_obj_from_emitting_model_state(model_state_ptr)
 
                 # Create new decoder token ids
-                new_input_ids = torch.cat([prev_model_state.sequence[:, -1], idx], dim=-1)
+                new_input_ids = torch.cat([prev_model_state.sequence, idx.unsqueeze(0)], dim=-1)
 
                 # Forward pass
-                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids.unsqueeze(dim=0), **model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids, **new_model_kwargs)
                 if self.is_huggingface_model:
                     model_inputs["return_dict"] = True
                     model_inputs["output_hidden_states"] = True
@@ -166,7 +173,7 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
                         Seq2SeqModelState(
                             timestep=timestep,
                             hidden_states=outputs["decoder_hidden_states"],
-                            sequence=new_input_ids.unsqueeze(dim=0),
+                            sequence=new_input_ids,
                             lm_scores=lm_scores
                         )
                     )
@@ -179,7 +186,7 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
             beam_size_token=self.model.config.vocab_size,
             beam_threshold=50,
             lm_weight=0.0,
-            eos_score=1.0,
+            eos_score=0.5,
             log_add=True,
         )
 
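
For context on the `eos_score` tweak: in flashlight-text's seq2seq decoder options, `eos_score` is a bonus applied when the end-of-sequence token is proposed, so lowering it from 1.0 to 0.5 should make the beam less eager to terminate early. A hedged sketch of the options object these kwargs feed into (class name and import path assumed from flashlight-text's Python bindings; values are illustrative):

```python
from flashlight.lib.text.decoder import LexiconFreeSeq2SeqDecoderOptions

num_beams, vocab_size = 5, 32128  # illustrative values

options = LexiconFreeSeq2SeqDecoderOptions(
    beam_size=num_beams,
    beam_size_token=vocab_size,
    beam_threshold=50,  # prune hypotheses scoring far below the best
    lm_weight=0.0,      # no external language model mixed in
    eos_score=0.5,      # EOS bonus; smaller values delay stopping
    log_add=True,       # merge duplicate hypotheses with log-add, not max
)
```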
@@ -191,21 +198,24 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
             max_output_length=max_len
         )
 
-        emissions = model_kwargs["encoder_outputs"].get("encoder_output")
+        all_final_tokens = []
+        for i in range(len(input_ids)):
 
-        decoder.decode_step(emissions.data_ptr(), T, N)
-        hyps = decoder.get_all_final_hypothesis()
+            decoder.decode_step(emissions.data_ptr(), i, N)
+            hyps = decoder.get_all_final_hypothesis()
 
-        token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
-        max_tokens = max(token_scores, key=lambda x: x[1])
+            token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
+            max_tokens = max(token_scores, key=lambda x: x[1])
 
-        filtered = list(filter(lambda x: x != -1, max_tokens[0]))
-        final_tokens = [0] + filtered
-
-        import pdb
-        pdb.set_trace()
+            filtered = list(filter(lambda x: x != -1, max_tokens[0]))
+            final_tokens = filtered
+
+            while len(final_tokens) < max_len:
+                final_tokens += [0]
+
+            all_final_tokens.append(torch.Tensor(final_tokens).to(torch.long))
 
-        return torch.Tensor(final_tokens).to(torch.long)
+        return torch.stack(all_final_tokens, dim=0)
 
     def generate(
         self,
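
With the per-sequence loop in place, `beam_search` now returns a zero-padded `(batch_size, max_len)` tensor of token ids instead of a single flattened sequence. A hypothetical call, assuming a `generator` instance of the surrounding class (its full signature is not shown in this diff):

```python
# `generator` and the exact positional arguments are assumptions;
# only the return shape/dtype follow from the diff above.
tokens = generator.beam_search(input_ids, num_beams=5, max_len=20)
print(tokens.shape)  # torch.Size([batch_size, 20]); short rows padded with 0
print(tokens.dtype)  # torch.int64
```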