
Commit 69194ee

Reset extraction
1 parent d61375c commit 69194ee

File tree

1 file changed: +107 −96 lines changed


elk/extraction/extraction.py

Lines changed: 107 additions & 96 deletions
@@ -5,12 +5,14 @@
 from dataclasses import InitVar, dataclass
 from itertools import islice
 from typing import Any, Iterable, Literal
+from warnings import filterwarnings

 import torch
 from datasets import (
+    Array2D,
     Array3D,
-    ClassLabel,
     DatasetDict,
+    DownloadMode,
     Features,
     Sequence,
     SplitDict,
@@ -20,23 +22,23 @@
 )
 from simple_parsing import Serializable, field
 from torch import Tensor
-from transformers import AutoConfig, AutoTokenizer, GPT2TokenizerFast
+from transformers import AutoConfig, PreTrainedModel
 from transformers.modeling_outputs import Seq2SeqLMOutput

 from ..promptsource import DatasetTemplates
 from ..utils import (
     assert_type,
-    convert_span,
     float32_to_int16,
+    infer_label_column,
+    infer_num_classes,
     instantiate_model,
+    instantiate_tokenizer,
     is_autoregressive,
     select_train_val_splits,
     select_usable_devices,
 )
-from .balanced_sampler import BalancedSampler
 from .generator import _GeneratorBuilder
 from .prompt_loading import PromptConfig, load_prompts
-from ..rwkv_lm.rwkv_hf import RWKVConfig


 @dataclass
@@ -58,6 +60,7 @@ class Extract(Serializable):
     layers: tuple[int, ...] = ()
     layer_stride: InitVar[int] = 1
     token_loc: Literal["first", "last", "mean"] = "last"
+    use_encoder_states: bool = False

     def __post_init__(self, layer_stride: int):
         if self.layers and layer_stride > 1:
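The new `use_encoder_states` flag pairs with the encoder-swapping logic added in `extract_hiddens` below: for an encoder-decoder model, `model.get_encoder()` replaces the full model so hidden states come from the encoder stack alone. A minimal sketch of that swap, assuming a public T5 checkpoint purely for illustration (not taken from this commit):

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
assert model.config.is_encoder_decoder

encoder = model.get_encoder()  # a PreTrainedModel holding only the encoder stack
out = encoder(
    input_ids=torch.tensor([[37, 19, 3, 9, 794, 1]]),  # arbitrary token ids
    output_hidden_states=True,
)
print(len(out.hidden_states))  # input embeddings + one entry per encoder layer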
@@ -85,7 +88,7 @@ def explode(self) -> list["Extract"]:
         return copies


-@torch.no_grad()
+@torch.inference_mode()
 def extract_hiddens(
     cfg: "Extract",
     *,
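On the decorator swap just above: `torch.inference_mode()` behaves like `torch.no_grad()` but is stricter and cheaper, since tensors created inside it are marked as inference tensors and can never re-enter autograd, which lets PyTorch skip version-counter and view-tracking bookkeeping. A small sketch of the practical difference (not from this repo):

import torch

x = torch.ones(3)

with torch.no_grad():
    y = x * 2  # plain tensor with requires_grad=False

with torch.inference_mode():
    z = x * 2  # inference tensor

y.requires_grad_(True)  # allowed: no_grad outputs can be reused with autograd
try:
    z.requires_grad_(True)  # raises RuntimeError: inference tensors stay excluded
except RuntimeError as err:
    print(err)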
@@ -99,135 +102,135 @@ def extract_hiddens(

     # Silence datasets logging messages from all but the first process
     if rank != 0:
+        filterwarnings("ignore")
         logging.disable(logging.CRITICAL)

-    ds_names = cfg.prompts.datasets
+    p_cfg = cfg.prompts
+    ds_names = p_cfg.datasets
     assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

-    prompt_ds = load_prompts(
-        ds_names[0],
-        split_type=split_type,
-        stream=cfg.prompts.stream,
-        rank=rank,
-        world_size=world_size,
-    )  # this dataset is already sharded, but hasn't been truncated to max_examples
-
     model = instantiate_model(
         cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
     ).to(device)
-    tokenizer = None
+    tokenizer = instantiate_tokenizer(
+        cfg.model, truncation_side="left", verbose=rank == 0
+    )

-    if cfg.model.startswith("RWKV"):
-        tokenizer = GPT2TokenizerFast(tokenizer_file='/home/kyle/repos/elk/elk/rwkv_lm/20B_tokenizer.json')
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(
-            cfg.model, truncation_side="left", verbose=False
-        )
+    is_enc_dec = model.config.is_encoder_decoder
+    if is_enc_dec and cfg.use_encoder_states:
+        assert hasattr(model, "get_encoder") and callable(model.get_encoder)
+        model = assert_type(PreTrainedModel, model.get_encoder())
+        is_enc_dec = False

-    has_lm_preds = is_autoregressive(model.config)
+    has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states)
     if has_lm_preds and rank == 0:
         print("Model has language model head, will store predictions.")

+    prompt_ds = load_prompts(
+        ds_names[0],
+        label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
+        num_classes=p_cfg.num_classes,
+        split_type=split_type,
+        stream=p_cfg.stream,
+        rank=rank,
+        world_size=world_size,
+    )
+
     # Iterating over questions
     layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))

-    global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
+    global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1]
     # break `max_examples` among the processes roughly equally
     max_examples = global_max_examples // world_size
     # the last process gets the remainder (which is usually small)
     if rank == world_size - 1:
         max_examples += global_max_examples % world_size

-    for example in islice(BalancedSampler(prompt_ds), max_examples):
+    for example in islice(prompt_ds, max_examples):
         num_variants = len(example["prompts"])
+        num_choices = len(example["prompts"][0])
+
         hidden_dict = {
             f"hidden_{layer_idx}": torch.empty(
                 num_variants,
-                2,  # contrast pair
+                num_choices,
                 model.config.hidden_size,
                 device=device,
                 dtype=torch.int16,
             )
             for layer_idx in layer_indices
         }
-        lm_preds = torch.empty(
+        lm_logits = torch.empty(
             num_variants,
-            2,  # contrast pair
+            num_choices,
             device=device,
             dtype=torch.float32,
         )
-        text_inputs = []
+        text_questions = []

         # Iterate over variants
         for i, record in enumerate(example["prompts"]):
-            variant_inputs = []
+            variant_questions = []

             # Iterate over answers
             for j, choice in enumerate(record):
-                text = choice["text"]
-
-                # TODO: Do something smarter than "rindex" here. Really we want to
-                # get the span of the answer directly from Jinja, but that doesn't
-                # seem possible. This approach may fail for complex templates.
-                answer_start = text.rindex(choice["answer"])
+                text = choice["question"]

                 # Only feed question, not the answer, to the encoder for enc-dec models
-                if model.config.is_encoder_decoder:
-                    # TODO: Maybe make this more generic for complex templates?
-                    text = text[:answer_start].rstrip()
-                    target = choice["answer"]
-                else:
-                    target = None
-
-                # Record the EXACT string we fed to the model
-                variant_inputs.append(text)
-                # inputs = None
-                # if cfg.model.startswith("RWKV"):
-                #     inputs = tokenizer(
-                #         text,
-                #         return_offsets_mapping=True,
-                #         text_target=target,  # type: ignore[arg-type]
-                #         truncation=True,
-                #     )
-                # else:
-                inputs = tokenizer(
+                target = choice["answer"] if is_enc_dec else None
+
+                # Record the EXACT question we fed to the model
+                variant_questions.append(text)
+                encoding = tokenizer(
                     text,
-                    return_offsets_mapping=True,
+                    add_special_tokens=False,
                     return_tensors="pt",
                     text_target=target,  # type: ignore[arg-type]
                     truncation=True,
-                )
+                ).to(device)
+                input_ids = assert_type(Tensor, encoding.input_ids)
+
+                if is_enc_dec:
+                    answer = assert_type(Tensor, encoding.labels)
+                else:
+                    encoding2 = tokenizer(
+                        choice["answer"],
+                        add_special_tokens=False,
+                        return_tensors="pt",
+                    ).to(device)
+                    answer = assert_type(Tensor, encoding2.input_ids)

-                # The offset_mapping is a sorted list of (start, end) tuples. We locate
-                # the start of the answer in the tokenized sequence with binary search.
-                offsets = inputs.pop("offset_mapping") if cfg.model.startswith("RWKV") else inputs.pop("offset_mapping").squeeze().tolist()
-                inputs = inputs if cfg.model.startswith("RWKV") else inputs.to(device)
+                    input_ids = torch.cat([input_ids, answer], dim=-1)
+                    if max_len := tokenizer.model_max_length:
+                        input_ids = input_ids[..., -max_len:]

-                # Run the forward pass
-                outputs = model(**inputs) if cfg.model.startswith("RWKV") else model(**inputs, output_hidden_states=True)
+                # Make sure we only pass the arguments that the model expects
+                inputs = dict(input_ids=input_ids)
+                if is_enc_dec:
+                    inputs["labels"] = answer
+
+                with torch.autocast("cuda", enabled=torch.cuda.is_available()):
+                    outputs = model(**inputs, output_hidden_states=True)

                 # Compute the log probability of the answer tokens if available
                 if has_lm_preds:
-                    start, end = convert_span(
-                        offsets, (answer_start, answer_start + len(choice["answer"]))
-                    )
-                    log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(
-                        dim=-1
-                    )
-                    tokens = inputs.input_ids[..., start:end, None]
-                    lm_preds[i, j] = log_p.gather(-1, tokens).sum()
+                    answer_len = answer.shape[-1]
+
+                    log_p = outputs.logits[..., -answer_len:, :].log_softmax(dim=-1)
+                    tokens = answer[..., None]
+                    lm_logits[i, j] = log_p.gather(-1, tokens).sum()

                 elif isinstance(outputs, Seq2SeqLMOutput):
                     # The cross entropy loss is averaged over tokens, so we need to
                     # multiply by the length to get the total log probability.
-                    length = inputs.labels.shape[-1]
-                    lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length
+                    length = encoding.labels.shape[-1]
+                    lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length

-                hiddens = outputs if cfg.model.startswith("RWKV") else (
+                hiddens = (
                     outputs.get("decoder_hidden_states") or outputs["hidden_states"]
                 )
                 # First element of list is the input embeddings
-                hiddens = hiddens if cfg.model.startswith("RWKV") else hiddens[1:]
+                hiddens = hiddens[1:]

                 # Throw out layers we don't care about
                 hiddens = [hiddens[i] for i in layer_indices]
245248
for layer_idx, hidden in zip(layer_indices, hiddens):
246249
hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)
247250

248-
text_inputs.append(variant_inputs)
251+
text_questions.append(variant_questions)
249252

250253
out_record: dict[str, Any] = dict(
251254
label=example["label"],
252255
variant_ids=example["template_names"],
253-
text_inputs=text_inputs,
256+
text_questions=text_questions,
254257
**hidden_dict,
255258
)
256259
if has_lm_preds:
257-
# We only need the probability of the positive example since this is binary
258-
out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1]
260+
out_record["model_logits"] = lm_logits
259261

260262
yield out_record
261263

@@ -266,7 +268,11 @@ def _extraction_worker(**kwargs):
266268

267269

268270
def extract(
269-
cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
271+
cfg: "Extract",
272+
*,
273+
disable_cache: bool = False,
274+
num_gpus: int = -1,
275+
min_gpu_mem: int | None = None,
270276
) -> DatasetDict:
271277
"""Extract hidden states from a model and return a `DatasetDict` containing them."""
272278

@@ -292,15 +298,18 @@ def get_splits() -> SplitDict:
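The reworked `extract` signature makes everything after `cfg` keyword-only and adds `disable_cache`, which the final hunk forwards to the dataset builders. A hedged usage sketch; the `Extract` and `PromptConfig` field values below are illustrative stand-ins, not taken from this commit:

# Hypothetical call site; field names follow the dataclasses referenced in this diff.
cfg = Extract(
    model="gpt2",
    prompts=PromptConfig(datasets=["imdb"]),
)
ds = extract(cfg, disable_cache=True, num_gpus=1)  # keyword-only after `cfg`
print(ds)  # DatasetDict with the extracted hidden states per split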
292298
dataset_name=available_splits.dataset_name,
293299
)
294300

295-
model_cfg = None
296-
if cfg.model.startswith("RWKV"):
297-
model_cfg = RWKVConfig()
298-
else:
299-
model_cfg = AutoConfig.from_pretrained(cfg.model)
301+
model_cfg = AutoConfig.from_pretrained(cfg.model)
300302

301303
ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
302304
info = get_dataset_config_info(ds_name, config_name or None)
303305

306+
ds_features = assert_type(Features, info.features)
307+
label_col = (
308+
cfg.prompts.label_columns[0]
309+
if cfg.prompts.label_columns
310+
else infer_label_column(ds_features)
311+
)
312+
num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col])
304313
num_variants = cfg.prompts.num_variants
305314
if num_variants < 0:
306315
prompter = DatasetTemplates(ds_name, config_name)
@@ -309,7 +318,7 @@ def get_splits() -> SplitDict:
309318
layer_cols = {
310319
f"hidden_{layer}": Array3D(
311320
dtype="int16",
312-
shape=(num_variants, 2, model_cfg.hidden_size),
321+
shape=(num_variants, num_classes, model_cfg.hidden_size),
313322
)
314323
for layer in cfg.layers or range(model_cfg.num_hidden_layers)
315324
}
@@ -318,21 +327,20 @@ def get_splits() -> SplitDict:
318327
Value(dtype="string"),
319328
length=num_variants,
320329
),
321-
"label": ClassLabel(names=["neg", "pos"]),
322-
"text_inputs": Sequence(
330+
"label": Value(dtype="int64"),
331+
"text_questions": Sequence(
323332
Sequence(
324333
Value(dtype="string"),
325-
length=2,
326334
),
327335
length=num_variants,
328336
),
329337
}
330338

331-
# Only add model_preds if the model is an autoregressive model
332-
if is_autoregressive(model_cfg):
333-
other_cols["model_preds"] = Sequence(
334-
Value(dtype="float32"),
335-
length=num_variants,
339+
# Only add model_logits if the model is an autoregressive model
340+
if is_autoregressive(model_cfg, not cfg.use_encoder_states):
341+
other_cols["model_logits"] = Array2D(
342+
shape=(num_variants, num_classes),
343+
dtype="float32",
336344
)
337345

338346
devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
@@ -361,7 +369,10 @@ def get_splits() -> SplitDict:
361369

362370
ds = dict()
363371
for split, builder in builders.items():
364-
builder.download_and_prepare(num_proc=len(devices))
372+
builder.download_and_prepare(
373+
download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
374+
num_proc=len(devices),
375+
)
365376
ds[split] = builder.as_dataset(split=split)
366377

367-
return DatasetDict(ds)
378+
return DatasetDict(ds)
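`DownloadMode.FORCE_REDOWNLOAD` is the standard `datasets` switch for ignoring a previously prepared cache, which is how the new `disable_cache` flag takes effect; passing `None` keeps the default reuse-if-exists behaviour. A hedged illustration with a stock builder rather than the repo's `_GeneratorBuilder`:

from datasets import DownloadMode, load_dataset_builder

builder = load_dataset_builder("imdb")  # any hub dataset works for the illustration
builder.download_and_prepare(download_mode=DownloadMode.FORCE_REDOWNLOAD)
train = builder.as_dataset(split="train")
print(len(train))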
