 from dataclasses import InitVar, dataclass
 from itertools import islice
 from typing import Any, Iterable, Literal
+from warnings import filterwarnings
 
 import torch
 from datasets import (
+    Array2D,
     Array3D,
-    ClassLabel,
     DatasetDict,
+    DownloadMode,
     Features,
     Sequence,
     SplitDict,
     SplitInfo,
     Value,
     get_dataset_config_info,
 )
 from simple_parsing import Serializable, field
 from torch import Tensor
-from transformers import AutoConfig, AutoTokenizer, GPT2TokenizerFast
+from transformers import AutoConfig, PreTrainedModel
 from transformers.modeling_outputs import Seq2SeqLMOutput
 
 from ..promptsource import DatasetTemplates
 from ..utils import (
     assert_type,
-    convert_span,
     float32_to_int16,
+    infer_label_column,
+    infer_num_classes,
     instantiate_model,
+    instantiate_tokenizer,
     is_autoregressive,
     select_train_val_splits,
     select_usable_devices,
 )
-from .balanced_sampler import BalancedSampler
 from .generator import _GeneratorBuilder
 from .prompt_loading import PromptConfig, load_prompts
-from ..rwkv_lm.rwkv_hf import RWKVConfig
 
 
 @dataclass
@@ -58,6 +60,7 @@ class Extract(Serializable):
     layers: tuple[int, ...] = ()
     layer_stride: InitVar[int] = 1
     token_loc: Literal["first", "last", "mean"] = "last"
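+    # If True, run only the encoder of an encoder-decoder model and take the
+    # hidden states from it (see the get_encoder() swap in extract_hiddens)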
+    use_encoder_states: bool = False
 
     def __post_init__(self, layer_stride: int):
         if self.layers and layer_stride > 1:
@@ -85,7 +88,7 @@ def explode(self) -> list["Extract"]:
         return copies
 
 
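+# inference_mode() is a stricter no_grad(): tensors created inside can never be
+# used by autograd later, letting PyTorch skip version-counter bookkeeping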
-@torch.no_grad()
+@torch.inference_mode()
 def extract_hiddens(
     cfg: "Extract",
     *,
@@ -99,135 +102,135 @@ def extract_hiddens(
 
     # Silence datasets logging messages from all but the first process
     if rank != 0:
+        filterwarnings("ignore")
         logging.disable(logging.CRITICAL)
 
-    ds_names = cfg.prompts.datasets
+    p_cfg = cfg.prompts
+    ds_names = p_cfg.datasets
     assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."
 
-    prompt_ds = load_prompts(
-        ds_names[0],
-        split_type=split_type,
-        stream=cfg.prompts.stream,
-        rank=rank,
-        world_size=world_size,
-    )  # this dataset is already sharded, but hasn't been truncated to max_examples
-
     model = instantiate_model(
         cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
     ).to(device)
-    tokenizer = None
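+    # truncation_side="left" drops the beginning of over-long prompts, so the
+    # end of the sequence (where the answer goes) is preserved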
+    tokenizer = instantiate_tokenizer(
+        cfg.model, truncation_side="left", verbose=rank == 0
+    )
 
-    if cfg.model.startswith("RWKV"):
-        tokenizer = GPT2TokenizerFast(tokenizer_file='/home/kyle/repos/elk/elk/rwkv_lm/20B_tokenizer.json')
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(
-            cfg.model, truncation_side="left", verbose=False
-        )
+    is_enc_dec = model.config.is_encoder_decoder
+    if is_enc_dec and cfg.use_encoder_states:
+        assert hasattr(model, "get_encoder") and callable(model.get_encoder)
+        model = assert_type(PreTrainedModel, model.get_encoder())
+        is_enc_dec = False
 
-    has_lm_preds = is_autoregressive(model.config)
+    has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states)
     if has_lm_preds and rank == 0:
         print("Model has language model head, will store predictions.")
 
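+    # load_prompts returns a dataset already sharded across processes; it is
+    # truncated to max_examples by the islice below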
+    prompt_ds = load_prompts(
+        ds_names[0],
+        label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
+        num_classes=p_cfg.num_classes,
+        split_type=split_type,
+        stream=p_cfg.stream,
+        rank=rank,
+        world_size=world_size,
+    )
+
     # Iterating over questions
     layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))
 
-    global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
+    global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1]
     # break `max_examples` among the processes roughly equally
     max_examples = global_max_examples // world_size
     # the last process gets the remainder (which is usually small)
     if rank == world_size - 1:
         max_examples += global_max_examples % world_size
 
-    for example in islice(BalancedSampler(prompt_ds), max_examples):
+    for example in islice(prompt_ds, max_examples):
         num_variants = len(example["prompts"])
+        num_choices = len(example["prompts"][0])
+
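+        # Pre-allocate a (num_variants, num_choices, hidden_size) buffer per layer,
+        # stored as int16 to halve the memory footprint of the hiddens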
         hidden_dict = {
             f"hidden_{layer_idx}": torch.empty(
                 num_variants,
-                2,  # contrast pair
+                num_choices,
                 model.config.hidden_size,
                 device=device,
                 dtype=torch.int16,
             )
             for layer_idx in layer_indices
         }
-        lm_preds = torch.empty(
+        lm_logits = torch.empty(
             num_variants,
-            2,  # contrast pair
+            num_choices,
             device=device,
             dtype=torch.float32,
         )
-        text_inputs = []
+        text_questions = []
 
         # Iterate over variants
         for i, record in enumerate(example["prompts"]):
-            variant_inputs = []
+            variant_questions = []
 
             # Iterate over answers
             for j, choice in enumerate(record):
-                text = choice["text"]
-
-                # TODO: Do something smarter than "rindex" here. Really we want to
-                # get the span of the answer directly from Jinja, but that doesn't
-                # seem possible. This approach may fail for complex templates.
-                answer_start = text.rindex(choice["answer"])
+                text = choice["question"]
 
                 # Only feed question, not the answer, to the encoder for enc-dec models
-                if model.config.is_encoder_decoder:
-                    # TODO: Maybe make this more generic for complex templates?
-                    text = text[:answer_start].rstrip()
-                    target = choice["answer"]
-                else:
-                    target = None
-
-                # Record the EXACT string we fed to the model
-                variant_inputs.append(text)
-                # inputs = None
-                # if cfg.model.startswith("RWKV"):
-                #     inputs = tokenizer(
-                #         text,
-                #         return_offsets_mapping=True,
-                #         text_target=target,  # type: ignore[arg-type]
-                #         truncation=True,
-                #     )
-                # else:
-                inputs = tokenizer(
+                target = choice["answer"] if is_enc_dec else None
+
+                # Record the EXACT question we fed to the model
+                variant_questions.append(text)
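+                # add_special_tokens=False keeps special tokens out of both encodings
+                # so the question and answer token IDs can be concatenated directly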
+                encoding = tokenizer(
                     text,
-                    return_offsets_mapping=True,
+                    add_special_tokens=False,
                     return_tensors="pt",
                     text_target=target,  # type: ignore[arg-type]
                     truncation=True,
-                )
+                ).to(device)
+                input_ids = assert_type(Tensor, encoding.input_ids)
+
+                if is_enc_dec:
+                    answer = assert_type(Tensor, encoding.labels)
+                else:
+                    encoding2 = tokenizer(
+                        choice["answer"],
+                        add_special_tokens=False,
+                        return_tensors="pt",
+                    ).to(device)
+                    answer = assert_type(Tensor, encoding2.input_ids)
 
-                # The offset_mapping is a sorted list of (start, end) tuples. We locate
-                # the start of the answer in the tokenized sequence with binary search.
-                offsets = inputs.pop("offset_mapping") if cfg.model.startswith("RWKV") else inputs.pop("offset_mapping").squeeze().tolist()
-                inputs = inputs if cfg.model.startswith("RWKV") else inputs.to(device)
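+                # Decoder-only path: append the answer tokens to the question and
+                # truncate from the left so the answer always survives truncation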
+                input_ids = torch.cat([input_ids, answer], dim=-1)
+                if max_len := tokenizer.model_max_length:
+                    input_ids = input_ids[..., -max_len:]
 
-                # Run the forward pass
-                outputs = model(**inputs) if cfg.model.startswith("RWKV") else model(**inputs, output_hidden_states=True)
+                # Make sure we only pass the arguments that the model expects
+                inputs = dict(input_ids=input_ids)
+                if is_enc_dec:
+                    inputs["labels"] = answer
+
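+                # Run the forward pass in mixed precision when a GPU is available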
+                with torch.autocast("cuda", enabled=torch.cuda.is_available()):
+                    outputs = model(**inputs, output_hidden_states=True)
 
                 # Compute the log probability of the answer tokens if available
                 if has_lm_preds:
-                    start, end = convert_span(
-                        offsets, (answer_start, answer_start + len(choice["answer"]))
-                    )
-                    log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(
-                        dim=-1
-                    )
-                    tokens = inputs.input_ids[..., start:end, None]
-                    lm_preds[i, j] = log_p.gather(-1, tokens).sum()
+                    answer_len = answer.shape[-1]
+
+                    log_p = outputs.logits[..., -answer_len:, :].log_softmax(dim=-1)
+                    tokens = answer[..., None]
+                    lm_logits[i, j] = log_p.gather(-1, tokens).sum()
 
                 elif isinstance(outputs, Seq2SeqLMOutput):
                     # The cross entropy loss is averaged over tokens, so we need to
                     # multiply by the length to get the total log probability.
-                    length = inputs.labels.shape[-1]
-                    lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length
+                    length = encoding.labels.shape[-1]
+                    lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length
 
-                hiddens = outputs if cfg.model.startswith("RWKV") else (
+                hiddens = (
                     outputs.get("decoder_hidden_states") or outputs["hidden_states"]
                 )
                 # First element of list is the input embeddings
-                hiddens = hiddens if cfg.model.startswith("RWKV") else hiddens[1:]
+                hiddens = hiddens[1:]
 
                 # Throw out layers we don't care about
                 hiddens = [hiddens[i] for i in layer_indices]
@@ -245,17 +248,16 @@ def extract_hiddens(
             for layer_idx, hidden in zip(layer_indices, hiddens):
                 hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)
 
-            text_inputs.append(variant_inputs)
+            text_questions.append(variant_questions)
 
         out_record: dict[str, Any] = dict(
             label=example["label"],
             variant_ids=example["template_names"],
-            text_inputs=text_inputs,
+            text_questions=text_questions,
             **hidden_dict,
         )
         if has_lm_preds:
-            # We only need the probability of the positive example since this is binary
-            out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1]
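+            # Store the raw per-choice log-prob sums; downstream consumers can
+            # softmax over the choice dimension if they want probabilities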
+            out_record["model_logits"] = lm_logits
 
         yield out_record
 
@@ -266,7 +268,11 @@ def _extraction_worker(**kwargs):
 
 
 def extract(
-    cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
+    cfg: "Extract",
+    *,
+    disable_cache: bool = False,
+    num_gpus: int = -1,
+    min_gpu_mem: int | None = None,
 ) -> DatasetDict:
     """Extract hidden states from a model and return a `DatasetDict` containing them."""
 
@@ -292,15 +298,18 @@ def get_splits() -> SplitDict:
             dataset_name=available_splits.dataset_name,
         )
 
-    model_cfg = None
-    if cfg.model.startswith("RWKV"):
-        model_cfg = RWKVConfig()
-    else:
-        model_cfg = AutoConfig.from_pretrained(cfg.model)
+    model_cfg = AutoConfig.from_pretrained(cfg.model)
 
     ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
     info = get_dataset_config_info(ds_name, config_name or None)
 
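+    # Infer the label column and class count up front so the feature schema below
+    # matches what the extraction workers will emit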
+    ds_features = assert_type(Features, info.features)
+    label_col = (
+        cfg.prompts.label_columns[0]
+        if cfg.prompts.label_columns
+        else infer_label_column(ds_features)
+    )
+    num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col])
     num_variants = cfg.prompts.num_variants
     if num_variants < 0:
         prompter = DatasetTemplates(ds_name, config_name)
@@ -309,7 +318,7 @@ def get_splits() -> SplitDict:
     layer_cols = {
         f"hidden_{layer}": Array3D(
             dtype="int16",
-            shape=(num_variants, 2, model_cfg.hidden_size),
+            shape=(num_variants, num_classes, model_cfg.hidden_size),
         )
         for layer in cfg.layers or range(model_cfg.num_hidden_layers)
     }
@@ -318,21 +327,20 @@ def get_splits() -> SplitDict:
             Value(dtype="string"),
             length=num_variants,
         ),
-        "label": ClassLabel(names=["neg", "pos"]),
-        "text_inputs": Sequence(
+        "label": Value(dtype="int64"),
+        "text_questions": Sequence(
             Sequence(
                 Value(dtype="string"),
-                length=2,
             ),
             length=num_variants,
         ),
     }
 
-    # Only add model_preds if the model is an autoregressive model
-    if is_autoregressive(model_cfg):
-        other_cols["model_preds"] = Sequence(
-            Value(dtype="float32"),
-            length=num_variants,
+    # Only add model_logits if the model is an autoregressive model
+    if is_autoregressive(model_cfg, not cfg.use_encoder_states):
+        other_cols["model_logits"] = Array2D(
+            shape=(num_variants, num_classes),
+            dtype="float32",
         )
 
     devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
@@ -361,7 +369,10 @@ def get_splits() -> SplitDict:
 
     ds = dict()
     for split, builder in builders.items():
-        builder.download_and_prepare(num_proc=len(devices))
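+        # FORCE_REDOWNLOAD tells `datasets` to ignore any cached copy when the
+        # caller passes disable_cache; None keeps the default reuse-if-cached mode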
+        builder.download_and_prepare(
+            download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
+            num_proc=len(devices),
+        )
         ds[split] = builder.as_dataset(split=split)
 
-    return DatasetDict(ds)
+    return DatasetDict(ds)