
Commit 183b80c

chpt3

1 parent b5c3210

2 files changed: 36 additions & 26 deletions

test/torchtext_unittest/prototype/test_generate.py

Lines changed: 5 additions & 5 deletions
@@ -16,10 +16,10 @@ def setUp(self) -> None:
         self.inputs = self.transform(
             [
                 "summarize: studies have shown that owning a dog is good for you",
-                # "translate English to German: That is good.",
-                # "cola sentence: The course is jumping well.",
-                # "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
-                # "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
+                "translate English to German: That is good.",
+                "cola sentence: The course is jumping well.",
+                "stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
+                "summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
             ]
         )
         torch.manual_seed(0)
@@ -55,7 +55,7 @@ def test_warns_when_no_max_len_provided(self, mock) -> None:
     def test_beam_search(self) -> None:
         generation_model = GenerationUtil(self.model)

-        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=100)
+        tokens = generation_model.generate(self.inputs, num_beams=3, max_len=30)

         generated_text = self.transform.decode(tokens.tolist())

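Note: the updated test drives all five prompts through one batched generate() call. A minimal, runnable stand-in for the output contract it assumes (a right-padded (batch, max_len) LongTensor; the zeros tensor below is a dummy, not real model output):

import torch

num_prompts, max_len = 5, 30  # five prompts, max_len=30 as in the test
tokens = torch.zeros(num_prompts, max_len, dtype=torch.long)  # dummy generate() output

# Equal-length rows are what lets transform.decode(tokens.tolist())
# treat the whole batch uniformly.
assert all(len(row) == max_len for row in tokens.tolist())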
torchtext/prototype/generate.py

Lines changed: 31 additions & 21 deletions
@@ -115,26 +115,33 @@ def beam_search(
     ) -> torch.Tensor:

         # Is this right?
-        T = max_len
+        # T = max_len
         N = vocab_size

+        emissions = model_kwargs["encoder_outputs"].get("encoder_output")
+
         def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_states, timestep):
-            # `emissions_ptr` should always be the same (from encoder output)
-            # N is not needed
-            # T is not needed
+            # Currently, can only handle decoding one at a time
+            i = T  # Access the current seq in inputs
+            new_model_kwargs = model_kwargs.copy()

             if timestep == 0:
-                prev_step_token_idxs = input_ids
+                prev_step_token_idxs = [input_ids[i]]
                 prev_step_model_states = [
                     create_emitting_model_state(
                         Seq2SeqModelState(
                             timestep=0,
                             hidden_states=None,
-                            sequence=input_ids,
+                            sequence=input_ids[i].unsqueeze(0),
                             lm_scores=None
                         )
                     )
                 ]
+
+                new_model_kwargs["encoder_outputs"]["encoder_output"] = emissions[i, :, :].unsqueeze(0)
+
+                # import pdb
+                # pdb.set_trace()

             out_probs, model_states = [], []
             for idx, model_state_ptr in zip(prev_step_token_idxs, prev_step_model_states):
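The i = T reassignment above is the crux of this hunk: the decoder's T argument no longer carries a length but the index of the current input sequence, so update_func can slice out that sequence's encoder output. A toy, runnable illustration with invented shapes (not the commit's code):

import torch

batch, src_len, hidden = 3, 7, 4
emissions = torch.randn(batch, src_len, hidden)  # stand-in for the encoder output

def update_func(emissions_ptr, N, T, timestep):
    i = T  # T smuggles in the sequence index, not a length
    return emissions[i, :, :].unsqueeze(0)  # slice one sequence, keep a batch dim of 1

for i in range(batch):
    assert update_func(None, None, i, 0).shape == (1, src_len, hidden)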
@@ -145,10 +152,10 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
                 prev_model_state = get_obj_from_emitting_model_state(model_state_ptr)

                 # Create new decoder token ids
-                new_input_ids = torch.cat([prev_model_state.sequence[:, -1], idx], dim=-1)
+                new_input_ids = torch.cat([prev_model_state.sequence, idx.unsqueeze(0)], dim=-1)

                 # Forward pass
-                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids.unsqueeze(dim=0), **model_kwargs)
+                model_inputs = self.model.prepare_inputs_for_generation(new_input_ids, **new_model_kwargs)
                 if self.is_huggingface_model:
                     model_inputs["return_dict"] = True
                     model_inputs["output_hidden_states"] = True
@@ -166,7 +173,7 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
                         Seq2SeqModelState(
                             timestep=timestep,
                             hidden_states=outputs["decoder_hidden_states"],
-                            sequence=new_input_ids.unsqueeze(dim=0),
+                            sequence=new_input_ids,
                             lm_scores=lm_scores
                         )
                     )
@@ -179,7 +186,7 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
             beam_size_token=self.model.config.vocab_size,
             beam_threshold=50,
             lm_weight=0.0,
-            eos_score=1.0,
+            eos_score=0.5,
             log_add=True,
         )
@@ -191,21 +198,24 @@ def update_func(emissions_ptr, N, T, prev_step_token_idxs, prev_step_model_state
             max_output_length=max_len
         )

-        emissions = model_kwargs["encoder_outputs"].get("encoder_output")
+        all_final_tokens = []
+        for i in range(len(input_ids)):

-        decoder.decode_step(emissions.data_ptr(), T, N)
-        hyps = decoder.get_all_final_hypothesis()
+            decoder.decode_step(emissions.data_ptr(), i, N)
+            hyps = decoder.get_all_final_hypothesis()

-        token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
-        max_tokens = max(token_scores, key=lambda x: x[1])
+            token_scores = [(hyp.tokens, hyp.score) for hyp in hyps]
+            max_tokens = max(token_scores, key=lambda x: x[1])

-        filtered = list(filter(lambda x: x != -1, max_tokens[0]))
-        final_tokens = [0] + filtered
-
-        import pdb
-        pdb.set_trace()
+            filtered = list(filter(lambda x: x != -1, max_tokens[0]))
+            final_tokens = filtered
+
+            while len(final_tokens) < max_len:
+                final_tokens += [0]
+
+            all_final_tokens.append(torch.Tensor(final_tokens).to(torch.long))

-        return torch.Tensor(final_tokens).to(torch.long)
+        return torch.stack(all_final_tokens, dim=0)

     def generate(
         self,
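To see the new return path in isolation: it keeps the best-scoring hypothesis per input, strips the decoder's -1 sentinels, right-pads with 0 up to max_len, and stacks the rows into one (batch, max_len) LongTensor. A runnable sketch with made-up token ids (not the commit's code):

import torch

max_len = 8
best_hyps = [  # stand-ins for max(token_scores, key=lambda x: x[1])[0]
    [0, 7, 12, 99, 1, -1, -1],
    [0, 4, 1, -1, -1, -1, -1],
]

all_final_tokens = []
for hyp in best_hyps:
    final_tokens = [t for t in hyp if t != -1]  # drop -1 sentinels
    while len(final_tokens) < max_len:  # right-pad to a fixed width
        final_tokens += [0]
    all_final_tokens.append(torch.tensor(final_tokens, dtype=torch.long))

out = torch.stack(all_final_tokens, dim=0)
print(out.shape)  # torch.Size([2, 8])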
