Paper Link: https://arxiv.org/abs/2411.07404
First, clone the repo:
git clone [email protected]:kdu4108/context-vs-prior-finetuning.git
Create an environment and install dependencies via virtualenv/pip:
python3.11 -m venv env
source env/bin/activate
pip install -r requirements.txt
Here are the steps to reproduce the experiments in the paper. The key steps are (1) accessing or regenerating the datasets and (2) running the main entry point, `main.py`.
Run all cells in `preprocessing/preprocess_fakepedia_train_and_dev.ipynb`.
Run `python preprocessing/generate_arithmetic.py`.
The main entry point to run a single experiment is `main.py`. The most important arguments to this script are:
- `DATASET_NAME` (positional argument): determines which dataset to run the experiment on. Must exactly match the name of a dataset defined in `preprocessing/datasets.py`.
- `--SUBSPLIT`: the subsplit of the dataset to use.
- `--MODEL_ID`: the model name on Hugging Face.
- `--CONTEXT_WEIGHT_FORMAT`: the context weight format to use for training examples, e.g., `float` or `instruction`.
- `--EVALS`: the list of evals to run the model on. Must be a list of dicts containing at least `dataset_name`, `subsplit`, `k_demonstrations`, and `context_weight_format`; optionally also include `do_steering` (see the sketch below).
- `--PROJECTION-PATH`: path to a saved projection for training.
The remaining arguments (visible via `python main.py --help`) are either dataset-specific (e.g., specify the `--QUERY_ID` if running an experiment with `DATASET_NAME="YagoECQ"`), allow control over other experiment details (e.g., which query types to use, the model's batch size for inference, how to sample entities, etc.), or are steering-specific hyperparameters (e.g., steering values).
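For reference, the structure that `--EVALS` expects looks roughly like the following (a sketch based on the keys described above and the values from the example command below; how `main.py` parses the string is defined in the script itself):

```python
# Sketch of the list-of-dicts structure expected by --EVALS.
evals = [
    {
        "dataset_name": "BaseFakepedia",    # must match a dataset in preprocessing/datasets.py
        "subsplit": "nodup_relpid",         # which subsplit to evaluate on
        "k_demonstrations": 0,              # number of in-context demonstrations
        "context_weight_format": "float",   # e.g., "float" or "instruction"
        "do_steering": False,               # optional: whether to also run steered evals
    }
]
```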
An example command to finetune a model and evaluate on the same dataset is:
python main.py BaseFakepedia -M meta-llama/Meta-Llama-3.1-8B-Instruct -S 3 -TS 2048 -TSS 1000 -P -CWF float -O -EV '{"dataset_name": "BaseFakepedia", "subsplit": "nodup_relpid", "k_demonstrations": 0, "context_weight_format": "float", "do_steering": False}'
Run the following command to submit all experiments for a given model (`llama`, `mistral`, or `gemma`):
python generate_run_scripts.py --model-id llama --add-default --add-steering --add-oos-datasets --add-training && bash run_scripts.sh
The scripts for the interpretability analysis are in the `analysis` directory.
The analysis is mainly built on the nnsight and nnpatch libraries.
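As a rough illustration of the nnsight pattern the analysis relies on (a minimal sketch, not the notebooks' exact code; the prompt is a placeholder and the layer index follows the Llama projection listed further below):

```python
from nnsight import LanguageModel

# Load the model through nnsight so intermediate activations can be inspected.
model = LanguageModel("meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto")

# Trace one forward pass and save the residual stream after layer 16
# (the layer the Llama projection below was trained for).
with model.trace("The Eiffel Tower is located in the city of"):
    hidden = model.model.layers[16].output[0].save()

print(hidden.shape)  # [1, num_tokens, 4096]; on older nnsight versions use hidden.value
```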
Check the notebooks `notebooks/analysis_llama.ipynb`, `notebooks/analysis_mistral.ipynb`, and `notebooks/analysis_gemma.ipynb` for the analysis in Sections 5 and 6 of the paper. To regenerate the plots, first generate the orthogonal projections using the notebooks above, then run all experiments, and finally run the `analysis/plots_das.ipynb` notebook.
You can also use the existing projections, which are hosted on Hugging Face:
- Meta-Llama-3.1-8B-Instruct-L16: Projection for the Meta-Llama-3.1 family of models. Layer 16. Recommended steering values: prior=6, context=-6.
- gemma-2-9b-it-L27: Projection for the gemma-2-9b-it family of models. Layer 27. Recommended steering values: prior=-100, context=150.
- Mistral-7B-Instruct-v0.3-L16: Projection for the Mistral family of models. Layer 16. Recommended steering values: prior=5, context=-5.
Check the `analysis/demo_steering.ipynb` notebook for a demo of how to use the steering hook.
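The notebook is the reference for the repo's own steering hook; the sketch below only illustrates the general idea with a plain PyTorch forward hook and the recommended Llama steering value from the list above (the hook and variable names here are assumptions, not the repo's implementation):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")

# LowRankOrthogonalProjection comes from this repo's code (see the snippet below);
# adjust the import to where it lives in your checkout.
proj = LowRankOrthogonalProjection.from_pretrained(
    "jkminder/CTXPRIOR-Projection-Meta-Llama-3.1-8B-Instruct-L16"
)
direction = proj.weight.squeeze(-1)  # [4096]
alpha = 6.0  # recommended "prior" value for this projection; -6.0 steers toward the context

def steering_hook(module, inputs, output):
    # Llama decoder layers return a tuple; output[0] holds the hidden states.
    hidden = output[0]
    hidden = hidden + alpha * direction.to(hidden.dtype).to(hidden.device)
    return (hidden,) + output[1:]

# Attach the hook at layer 16 (the layer this projection was trained for),
# run generation as usual, then remove the hook.
handle = model.model.layers[16].register_forward_hook(steering_hook)
# ... model.generate(...) here ...
handle.remove()
```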
If you want the basis vector of the subspace, use the following snippet to get it:
# LowRankOrthogonalProjection is provided by this repo's code; adjust the import to your setup.
# Load the pretrained projection from Hugging Face.
proj = LowRankOrthogonalProjection.from_pretrained("jkminder/CTXPRIOR-Projection-Meta-Llama-3.1-8B-Instruct-L16")
# The projection weight is the basis vector of the 1D subspace.
u = proj.weight
u.shape # [4096, 1]
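For instance, one way to use this basis vector (a sketch, not code from the repo) is to decompose a residual-stream activation from the projection's layer into its components inside and outside the subspace:

```python
import torch

# h: a residual-stream activation from the projection's layer, shape [4096]
# (random here just to keep the sketch self-contained).
h = torch.randn(4096)

u = proj.weight.squeeze(-1).float()  # basis vector, shape [4096]
u = u / u.norm()                     # normalize before projecting
coord = h @ u                        # scalar coordinate of h along the subspace
h_parallel = coord * u               # component of h inside the 1D subspace
h_orthogonal = h - h_parallel        # component orthogonal to the subspace
```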