diff --git a/finetuning/training_configs/few_shot/svamp.yaml b/finetuning/training_configs/few_shot/svamp.yaml new file mode 100755 index 00000000..5cd0a3f5 --- /dev/null +++ b/finetuning/training_configs/few_shot/svamp.yaml @@ -0,0 +1,60 @@ +seed_everything: 333 +trainer: + default_root_dir: &exp_name results/debug-tmp + # progress_bar_refresh_rate: 1 + num_sanity_val_steps: 0 + log_every_n_steps: 1 + logger+: + - class_path: finetuning.lightning_modules.patches.patched_loggers.PatchedWandbLogger + init_args: + entity: yale-lily + project: unified-codegen + save_dir: *exp_name + name: *exp_name + log_model: False + save_code: True + offline: False + # offline: True + callbacks+: + - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar + init_args: + refresh_rate: 1 + + accelerator: gpu + devices: 2 + # strategy: deepspeed_stage_2 + strategy: ddp_find_unused_parameters_false + precision: 16 + +model: + class_path: lightning_modules.models.seq2seq_model.Seq2SeqModel + init_args: + transformer_model_name: default-will-cause-error + executor_cls: execution.executors.MathExecutor + max_gen_len: 256 + sampling_temp: 0.001 + # sampling_temp_at_k: 0.8 + # pass_at_k: 50 + # max_generation_batches: 5 + gradient_ckpt: false + save_raw_generation_results: true + # print_eval_every_n_batches: 1 + +data: + class_path: lightning_modules.datasets.base_datamodule.FewShotNL2CodeDataModule + init_args: + transformer_model_name: default-will-cause-error + dataset_cls: FewShotMathQADataset + batch_size: 1 + val_batch_size: 4 + ## prompting settings + prompting_init_args: + exemplar_file_path: prompt_files/svamp-idiomatic_code-annotated-4_exemplars.jsonl + num_exemplars: 4 + fixed_exemplars: true + exemplar_selection_method: first + add_instruction: true + use_chat_format: false + # val_max_instances: 64 + val_set_init_args: + file_path: data/svamp/svamp_test.jsonl \ No newline at end of file diff --git a/preprocessing/preprocess_svamp.py b/preprocessing/preprocess_svamp.py 
"""Preprocessing script for SVAMP.

A typical example of SVAMP looks like this:
{
    "ID": "chal-1",
    "Body": "Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack",
    "Question": "How much do you have to pay to buy each pack?",
    "Equation": "( 76.0 - 25.0 )",
    "Answer": 51.0,
    "Type": "Subtraction"
},

And after preprocessing, we want it to look like this:
{
    "question": "Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack. How much do you have to pay to buy each pack?",
    "answer": 51.0,
    "annotated_code": "<annotated program; only present on prompt exemplars>",
    "metadata": {
        "ID": "chal-1",
        "Body": "Each pack of dvds costs 76 dollars. If there is a discount of 25 dollars on each pack",
        "Question": "How much do you have to pay to buy each pack?",
        "Equation": "( 76.0 - 25.0 )",
        "Answer": 51.0,
        "Type": "Subtraction"
    },
}

"""

import json

from typing import Dict, List, Any, Tuple


# Input / output locations, relative to the repository root.
SVAMP_SOURCE_PATH = "data/svamp/SVAMP.json"
PROMPT_OUTPUT_PATH = "prompt_files/svamp-idiomatic_code-annotated-4_exemplars.jsonl"
TEST_OUTPUT_PATH = "data/svamp/svamp_test.jsonl"

# A body that already ends with one of these does not get an extra ".".
_SENTENCE_ENDINGS = (".", "!", "?")

# Hand-written program annotations, keyed by SVAMP example ID. Examples whose
# ID appears here become prompt exemplars; all others go to the test set.
ANNOTATION_DICT = {
    "chal-10": "n_customers_left = 9\nn_customers_now = 12\nn_customers_start = n_customers_now + n_customers_left\nanswer = n_customers_start",
    "chal-11": "n_birds = 3\nn_storks = 6\nn_more_bird = 2\nn_more_stork_than_bird = n_storks - (n_birds + n_more_bird)\nanswer = n_more_stork_than_bird",
    "chal-12": "n_tables = 11\nn_chairs_per_table = 13\nn_chairs = n_tables * n_chairs_per_table\nanswer = n_chairs",
    "chal-23": "group_size = 18\nn_total_bananas = 180\nn_groups = n_total_bananas / group_size\nanswer = n_groups",
}


def preprocess_svamp_instance(example: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one raw SVAMP record into the question/answer/metadata format.

    The "Body" and "Question" fields are joined into a single question string.
    A ". " separator is inserted only when the body does not already end with
    sentence-final punctuation, so the result never contains ". ." or "?.".

    Args:
        example: A raw SVAMP record with at least the keys "Body",
            "Question" and "Answer".

    Returns:
        A new dict with keys "question", "answer" and "metadata" (in that
        order), where "metadata" is a shallow copy of ``example``.

    Raises:
        KeyError: If "Body", "Question" or "Answer" is missing.
    """
    # Strip first so trailing whitespace cannot hide the final punctuation.
    body = example["Body"].strip()
    question = example["Question"].strip()

    if not body:
        full_question = question
    elif body.endswith(_SENTENCE_ENDINGS):
        full_question = f"{body} {question}"
    else:
        full_question = f"{body}. {question}"

    return {
        "question": full_question,
        "answer": example["Answer"],
        # Copy so later edits to the output never leak into the caller's record.
        "metadata": dict(example),
    }


def _split_examples(
    examples: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """Split processed examples into (prompt, test), annotating the prompt ones in place."""
    prompt_examples = []
    test_examples = []
    for example in examples:
        example_id = example["metadata"]["ID"]
        if example_id in ANNOTATION_DICT:
            example["annotated_code"] = ANNOTATION_DICT[example_id]
            prompt_examples.append(example)
        else:
            test_examples.append(example)
    return prompt_examples, test_examples


def _write_jsonl(path: str, examples: List[Dict[str, Any]]) -> None:
    """Write ``examples`` to ``path``, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(example) + "\n" for example in examples)


def main() -> None:
    """Read the raw SVAMP file and write the prompt-exemplar and test JSONL files."""
    with open(SVAMP_SOURCE_PATH, "r", encoding="utf-8") as f:
        examples = json.load(f)

    print(f"loaded {len(examples)} examples")

    # preprocess the examples
    processed_examples = [preprocess_svamp_instance(example) for example in examples]

    # split the examples to prompt and test sets (prompt ones get their annotation)
    prompt_examples, test_examples = _split_examples(processed_examples)

    # An annotated ID that is absent from the data would otherwise silently
    # shrink the exemplar file below the size its name advertises.
    found_ids = {example["metadata"]["ID"] for example in prompt_examples}
    missing_ids = sorted(set(ANNOTATION_DICT) - found_ids)
    if missing_ids:
        print(f"WARNING: annotated IDs not found in {SVAMP_SOURCE_PATH}: {missing_ids}")

    # save the prompt and test sets
    print(f"Saving {len(prompt_examples)} prompt examples and {len(test_examples)} test examples")
    _write_jsonl(PROMPT_OUTPUT_PATH, prompt_examples)
    _write_jsonl(TEST_OUTPUT_PATH, test_examples)


if __name__ == "__main__":
    main()
After 9 customers left he still had 12 customers.", "Question": "How many customers did he have at the start?", "Equation": "( 9.0 + 12.0 )", "Answer": 21.0, "Type": "Addition"}, "annotated_code": "n_customers_left = 9\nn_customers_now = 12\nn_customers_start = n_customers_now + n_customers_left\nanswer = n_customers_start"} +{"question": "3 birds were sitting on the fence. 6 more storks and 2 more birds came to join them. How many more storks than birds are sitting on the fence?", "answer": 1.0, "metadata": {"ID": "chal-11", "Body": "3 birds were sitting on the fence. 6 more storks and 2 more birds came to join them.", "Question": "How many more storks than birds are sitting on the fence?", "Equation": "( 6.0 - ( 3.0 + 2.0 ) )", "Answer": 1.0, "Type": "Subtraction"}, "annotated_code": "n_birds = 3\nn_storks = 6\nn_more_bird = 2\nn_more_stork_than_bird = n_storks - (n_birds + n_more_bird)\nanswer = n_more_stork_than_bird"} +{"question": "They decided to hold the party in their backyard. If they have 11 sets of tables and each set has 13 chairs. How many chairs do they have in the backyard?", "answer": 143.0, "metadata": {"ID": "chal-12", "Body": "They decided to hold the party in their backyard. If they have 11 sets of tables and each set has 13 chairs", "Question": "How many chairs do they have in the backyard?", "Equation": "( 11.0 * 13.0 )", "Answer": 143.0, "Type": "Multiplication"}, "annotated_code": "n_tables = 11\nn_chairs_per_table = 13\nn_chairs = n_tables * n_chairs_per_table\nanswer = n_chairs"} +{"question": "The bananas in Philip's collection are organized into groups of size 18. If there are a total of 180 bananas in Philip's banana collection. How many groups are there?", "answer": 10.0, "metadata": {"ID": "chal-23", "Body": "The bananas in Philip's collection are organized into groups of size 18. 
If there are a total of 180 bananas in Philip's banana collection", "Question": "How many groups are there?", "Equation": "( 180.0 / 18.0 )", "Answer": 10.0, "Type": "Common-Division"}, "annotated_code": "group_size = 18\nn_total_bananas = 180\nn_groups = n_total_bananas / group_size\nanswer = n_groups"} diff --git a/svamp-idiomatic_code-annotated-4_exemplars.jsonl b/svamp-idiomatic_code-annotated-4_exemplars.jsonl new file mode 100644 index 00000000..7b6dbfe6 --- /dev/null +++ b/svamp-idiomatic_code-annotated-4_exemplars.jsonl @@ -0,0 +1,4 @@ +{"question": "A waiter had some customers. After 9 customers left he still had 12 customers. How many customers did he have at the start?", "answer": 21.0, "metadata": {"ID": "chal-10", "Body": "A waiter had some customers. After 9 customers left he still had 12 customers.", "Question": "How many customers did he have at the start?", "Equation": "( 9.0 + 12.0 )", "Answer": 21.0, "Type": "Addition"}, "annotated_code": "n_customers_left = 9\nn_customers_now = 12\nn_customers_start = n_customers_now + n_customers_left\nanswer = n_customers_start"} +{"question": "3 birds were sitting on the fence. 6 more storks and 2 more birds came to join them. How many more storks than birds are sitting on the fence?", "answer": 1.0, "metadata": {"ID": "chal-11", "Body": "3 birds were sitting on the fence. 6 more storks and 2 more birds came to join them.", "Question": "How many more storks than birds are sitting on the fence?", "Equation": "( 6.0 - ( 3.0 + 2.0 ) )", "Answer": 1.0, "Type": "Subtraction"}, "annotated_code": "n_birds = 3\nn_storks = 6\nn_more_bird = 2\nn_more_stork_than_bird = n_storks - (n_birds + n_more_bird)\nanswer = n_more_stork_than_bird"} +{"question": "They decided to hold the party in their backyard. If they have 11 sets of tables and each set has 13 chairs. How many chairs do they have in the backyard?", "answer": 143.0, "metadata": {"ID": "chal-12", "Body": "They decided to hold the party in their backyard. 
If they have 11 sets of tables and each set has 13 chairs", "Question": "How many chairs do they have in the backyard?", "Equation": "( 11.0 * 13.0 )", "Answer": 143.0, "Type": "Multiplication"}, "annotated_code": "n_tables = 11\nn_chairs_per_table = 13\nn_chairs = n_tables * n_chairs_per_table\nanswer = n_chairs"} +{"question": "The bananas in Philip's collection are organized into groups of size 18. If there are a total of 180 bananas in Philip's banana collection. How many groups are there?", "answer": 10.0, "metadata": {"ID": "chal-23", "Body": "The bananas in Philip's collection are organized into groups of size 18. If there are a total of 180 bananas in Philip's banana collection", "Question": "How many groups are there?", "Equation": "( 180.0 / 18.0 )", "Answer": 10.0, "Type": "Common-Division"}, "annotated_code": "group_size = 18\nn_total_bananas = 180\nn_groups = n_total_bananas / group_size\nanswer = n_groups"}