diff --git a/end-to-end-use-cases/coding/text2sql/README.md b/end-to-end-use-cases/coding/text2sql/README.md index 1198a2d88..25da0a741 100644 --- a/end-to-end-use-cases/coding/text2sql/README.md +++ b/end-to-end-use-cases/coding/text2sql/README.md @@ -1,30 +1,39 @@ -## Text2SQL: Natural Language to SQL Interface +# Improving Llama Text2SQL performance with CoT Fine-tuning -This project provides a set of scripts to convert natural language queries into SQL statements using Meta's Llama model. The goal is to enable users to interact with databases using natural language inputs, making it easier for non-technical users to access and analyze data. +This recipe is step by step guide to improve Llama performance on Text2SQL measured with the popular [BIRD](https://bird-bench.github.io) benchmark. We generate a synthetic Chain of Thought(CoT) dataset and fine-tune Llama models on it. -For detailed instructions on setting up the environment, creating a database, and executing natural language queries using the Text2SQL interface, please refer to the quickstart.ipynb notebook. +Results: -### Structure: +| Fine-tuning Combination | Accuracy | +|-----------------------------|-------------------------------| +| baseline | 39.47% | +| CoT, PEFT | 43.35% | +| CoT, FFT | 42.44% (3 epochs) | +| CoT, FFT | 43.87% (10 epochs) | -- quickstart.ipynb: A Quick Demo of Text2SQL Using Llama 3.3. This Jupyter Notebook includes examples of how to use the interface to execute natural language queries on the sample data. It uses Llama 3.3 to answer questions about a SQLite database using LangChain and the Llama cloud provider Together.ai. -- nba.txt: A text file containing NBA roster information, which is used as sample data for demonstration purposes. -- txt2csv.py: A script that converts text data into a CSV format. This script is used to preprocess the input data before it is fed into csv2db.py. -- csv2db.py: A script that imports data from a CSV file into a SQLite database. This script is used to populate the database with sample data. -- nba_roster.db: A SQLite database file created from the nba.txt data, used to test the Text2SQL interface. +The complete steps are: -### Detailed steps on running the notebook: +1. Pre-processing the [BIRD](https://bird-bench.github.io) TRAIN datset by converting text, schema, external knowledge, and SQL statements into the conversation format. -- Before getting started, please make sure to setup Together.ai and get an API key from [here](https://www.together.ai/). +2. Using Llama-3.3-70B to add CoT to the conversation format dataset. -- First, please install the requirements from [here](https://github.com/meta-llama/llama-cookbook/blob/main/end-to-end-use-cases/coding/text2sql/requirements.txt) by running inside the folder: +3. Fine-tuning Llama-3.1-8B on the CoT dataset from step 2. -``` -git clone https://github.com/meta-llama/llama-cookbook.git -cd llama-cookbook/end-to-end-use-cases/coding/text2sql/ -pip install -r requirements.txt -``` +4. Running the BIRD DEV eval benchmark on the fine-tuned models and compare it with out of the model. -### Contributing -Contributions are welcome! If you'd like to add new features or improve existing ones, please submit a pull request. We encourage contributions in the following areas: -- Adding support for additional databases -- Developing new interfaces or applications that use the Text2SQL interface \ No newline at end of file +## Folder Structure + +- quickstart folder: contains a notebook to ask Llama 3.3 to convert natural language queries into SQL queries. +- data folder: contains scripts to download the BIRD TRAIN and DEV datasets; +- fine-tune folder: contains scripts to generate CoT dataset based on the BIRD TRAIN set and to supervised fine-tune Llama models using the dataset, with different SFT options (quantization or not, full fine-tuning or parameter-efficient fine-tuning); +- eval folder: contains scripts to evaluate Llama models (original and fine-tuned) on the BIRD dataset. + +We also experimented with supervised fine-tuning (SFT) without CoT which resulted in slightly lower accuracy. + +## Next Steps + +1. Hyper-parameter tuning of the current SFT scripts. +2. Try GRPO reinforcement learning to further improve the accuracy. +3. Fine-tune Llama 3.3 70B and Llama 4 models. +4. Try agentic workflow. +5. Expand the eval to support other enterprise databases. diff --git a/end-to-end-use-cases/coding/text2sql/data/download_dev_unzip.sh b/end-to-end-use-cases/coding/text2sql/data/download_dev_unzip.sh new file mode 100644 index 000000000..606696d89 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/data/download_dev_unzip.sh @@ -0,0 +1,9 @@ +wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip +unzip dev.zip +rm dev.zip +rm -rf __MACOSX +cd dev_20240627 +unzip dev_databases.zip +rm dev_databases.zip +rm -rf __MACOSX +cd .. diff --git a/end-to-end-use-cases/coding/text2sql/data/download_train_unzip.sh b/end-to-end-use-cases/coding/text2sql/data/download_train_unzip.sh new file mode 100644 index 000000000..ea785ca50 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/data/download_train_unzip.sh @@ -0,0 +1,9 @@ +wget https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip +UNZIP_DISABLE_ZIPBOMB_DETECTION=TRUE unzip train.zip +rm train.zip +rm -rf __MACOSX +cd train +unzip train_databases.zip +rm train_databases.zip +rm -rf __MACOSX +cd .. diff --git a/end-to-end-use-cases/coding/text2sql/eval/README.md b/end-to-end-use-cases/coding/text2sql/eval/README.md new file mode 100644 index 000000000..be213ceb6 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/eval/README.md @@ -0,0 +1,85 @@ +# Llama Text2SQL Evaluation + +We have updated and simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) to 3 simple steps for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com), as well as Llama 3.1 8B on Hugging Face and its fine-tuned models. + +## Evaluation Results + +Below are the results of the Llama models we have evaluated on the BIRD DEV dataset: + +| Model | Llama API Accuracy | +|------------------------|--------------------| +| Llama 3.1 8b | 39.47% (*) | +| Llama 3.3 70b | 54.11% | +| Llama 4 Scout | 44.39% | +| Llama 4 Maverick | 44.00% | + +- Since Llama API does not have Llama 3.1 8b model, we use Hugging Face weights and vllm to run locally. + +## Quick Start with Llama Models via Llama API + +Follow the steps below to evaluate Llama 3 & 4 models on Text2SQL using the BIRD benchmark: + +1. Run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation: + +``` +conda create -n llama-text2sql python=3.10 +conda activate llama-text2sql +git clone https://github.com/meta-llama/llama-cookbook +git checkout text2sql # to be removed after the PR merge +cd llama-cookbook/end-to-end-use-cases/coding/text2sql/eval +pip install -r requirements.txt +``` + +2. Get the DEV dataset: +``` +cd ../data +sh download_dev_unzip.sh +cd ../eval +``` + +3. Open `llama_eval.sh` and set `YOUR_API_KEY` to your [Llama API](https://llama.developer.meta.com/) key then uncomment a line that starts with `model=` to specify the Llama model to use for the text2sql eval. + +4. Run the evaluation script `sh llama_eval.sh`, which will use the BIRD DEV dataset (1534 examples in total) with external knowledge turned on to run the Llama model on each text question and compare the generated SQL with the gold SQL. + +If your API key or model name is incorrect, the script will exit with an authentication or model not supported error. + +After the script completes, you'll see the accuracy of the Llama model on the BIRD DEV text2sql. For example, the total accuracy is about 54.24% with `YOUR_API_KEY` set to your Llama API key and `model='Llama-3.3-70B-Instruct'` + +To compare your evaluated accuracy of your selected Llama model with other results in the BIRD Dev leaderboard, click [here](https://bird-bench.github.io/). + +## Evaluation with Llama Models on Hugging Face or Fine-tuned + +We use vllm OpenAI compatible server to run Llama 3.1 8B on Hugging Face (steps below) or its fine-tuned models (steps [here](../fine-tuning/#evaluating-the-fine-tuned-model) for eval: + +1. Uncomment the last two lines in requirements.txt then run `pip install -r requirements.txt`: +``` +# vllm==0.9.2 +# openai==1.90.0 +``` + +2. Uncomment in `llama_eval.sh`: +``` +YOUR_API_KEY='huggingface' +model='meta-llama/Llama-3.1-8B-Instruct' +``` + +3. Start the vllm server: +``` +vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 8192 --max-num-seqs 64 +``` +or if you want to speed up the inference and eval and have multiple GPUs, you can set `--tensor-parallel-size` to the number of your available GPUs, e.g.: +``` +vllm serve meta-llama/Llama-3.1-8B-Instruct --tensor-parallel-size 8 --max-num-batched-tokens 8192 --max-num-seqs 64 +``` + +then run `sh llama_eval.sh`. + +## Evaluation Process + +1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries. + +2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below. + +3. **Result Comparison**: The results from executing the generated SQL are compared ([source code](text2sql_eval.py#L29)) with the results from the ground truth SQL to determine correctness. + +4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty levels (simple, moderate, challenging). diff --git a/end-to-end-use-cases/coding/text2sql/eval/create_bird_eval_dataset.py b/end-to-end-use-cases/coding/text2sql/eval/create_bird_eval_dataset.py new file mode 100644 index 000000000..892ed4620 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/eval/create_bird_eval_dataset.py @@ -0,0 +1,161 @@ +import argparse +import json +import os +import sqlite3 + +import pandas as pd + +# from datasets import Dataset +from tqdm import tqdm + + +def new_directory(path): + if not os.path.exists(path): + os.makedirs(path) + + +def nice_look_table(column_names: list, values: list): + rows = [] + # Determine the maximum width of each column + widths = [ + max(len(str(value[i])) for value in values + [column_names]) + for i in range(len(column_names)) + ] + + # Print the column names + header = "".join( + f"{column.rjust(width)} " for column, width in zip(column_names, widths) + ) + # print(header) + # Print the values + for value in values: + row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths)) + rows.append(row) + rows = "\n".join(rows) + final_output = header + "\n" + rows + return final_output + + +def generate_schema_prompt(db_path, num_rows=None): + # extract create ddls + """ + :param root_place: + :param db_name: + :return: + """ + full_schema_prompt_list = [] + conn = sqlite3.connect(db_path) + # Create a cursor object + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = cursor.fetchall() + schemas = {} + for table in tables: + if table == "sqlite_sequence": + continue + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format( + table[0] + ) + ) + create_prompt = cursor.fetchone()[0] + schemas[table[0]] = create_prompt + if num_rows: + cur_table = table[0] + if cur_table in ["order", "by", "group"]: + cur_table = "`{}`".format(cur_table) + + cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows)) + column_names = [description[0] for description in cursor.description] + values = cursor.fetchall() + rows_prompt = nice_look_table(column_names=column_names, values=values) + verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format( + num_rows, cur_table, num_rows, rows_prompt + ) + schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt) + + for k, v in schemas.items(): + full_schema_prompt_list.append(v) + + schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list) + + return schema_prompt + + +def generate_comment_prompt(question, knowledge=None): + knowledge_prompt = "-- External Knowledge: {}".format(knowledge) + question_prompt = "-- Question: {}".format(question) + + result_prompt = knowledge_prompt + "\n\n" + question_prompt + + return result_prompt + + +def generate_combined_prompts_one(db_path, question, knowledge=None): + schema_prompt = generate_schema_prompt(db_path, num_rows=None) + comment_prompt = generate_comment_prompt(question, knowledge) + + combined_prompts = schema_prompt + "\n\n" + comment_prompt + + return combined_prompts + + +def create_conversation(sample): + return { + "messages": [ + {"role": "system", "content": sample["messages"][0]["content"]}, + {"role": "user", "content": sample["messages"][1]["content"]}, + {"role": "assistant", "content": sample["messages"][2]["content"]}, + ] + } + + +def create_bird_eval_dataset(input_json, db_root_path): + SYSTEM_PROMPT = ( + "You are a text to SQL query translator. Using the SQLite DB Schema and the " + "External Knowledge, translate the following text question into a SQLite SQL " + "select statement." + ) + data = [] + + for i, item in tqdm(enumerate(input_json)): + print(f"processing #{i+1}") + db_id = item["db_id"] + question = item["question"] + external_knowledge = item["evidence"] + SQL = item["SQL"] + db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite" + print(f"{db_path=}") + prompt = generate_combined_prompts_one( + db_path, + question, + knowledge=external_knowledge, + ) + + data.append( + { + "prompt": SYSTEM_PROMPT + "\n\n" + prompt, + "gold_sql": SQL, + "db_id": db_id, + } + ) + + df = pd.DataFrame(data) + df.to_csv("bird_dev_set_eval.csv", index=False) + print(f"Dataset saved as bird_dev_set_eval.csv with {len(df)} rows") + + +if __name__ == "__main__": + args_parser = argparse.ArgumentParser() + args_parser.add_argument("--input_json", type=str, required=True) + args_parser.add_argument("--db_root_path", type=str, required=True) + args = args_parser.parse_args() + + input_json = json.load(open(args.input_json, "r")) + db_root_path = args.db_root_path + + create_bird_eval_dataset(input_json, db_root_path) + +# follow steps 1 and 2 here https://github.com/meta-llama/llama-cookbook/tree/text2sql/end-to-end-use-cases/coding/text2sql/eval#quick-start-with-llama-models-via-llama-api +# then run: +# python3 create_bird_eval_dataset.py --input_json ../data/dev_20240627/dev.json --db_root_path ../data/dev_20240627/dev_databases diff --git a/end-to-end-use-cases/coding/text2sql/eval/llama_eval.sh b/end-to-end-use-cases/coding/text2sql/eval/llama_eval.sh new file mode 100644 index 000000000..0a4130f1f --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/eval/llama_eval.sh @@ -0,0 +1,49 @@ +# Set to "true" to enable debug mode with detailed prints +DEBUG_MODE="false" + +eval_path='../data/dev_20240627/dev.json' +db_root_path='../data/dev_20240627/dev_databases/' +ground_truth_path='../data/' + +# Llama models on Llama API +# YOUR_API_KEY='YOUR_LLAMA_API_KEY' +# model='Llama-3.3-8B-Instruct' +#model='Llama-3.3-70B-Instruct' +#model='Llama-4-Maverick-17B-128E-Instruct-FP8' +#model='Llama-4-Scout-17B-16E-Instruct-FP8' + +# Llama model on Hugging Face Hub https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct +# YOUR_API_KEY='huggingface' +# model='meta-llama/Llama-3.1-8B-Instruct' + +# Fine-tuned Llama models locally +YOUR_API_KEY='finetuned' +model='../fine-tuning/llama31-8b-text2sql-fft-nonquantized-cot-epochs-3' + +data_output_path="./output/$model/" + +echo "Text2SQL using $model" +python3 -u llama_text2sql.py --db_root_path ${db_root_path} --api_key ${YOUR_API_KEY} \ +--model ${model} --eval_path ${eval_path} --data_output_path ${data_output_path} + +# Check if llama_text2sql.py exited successfully +if [ $? -eq 0 ]; then + echo "llama_text2sql.py completed successfully. Proceeding with evaluation..." + + # Add --debug flag if DEBUG_MODE is true + if [ "$DEBUG_MODE" = "true" ]; then + python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \ + --ground_truth_path ${ground_truth_path} \ + --diff_json_path ${eval_path} --debug + else + python3 -u text2sql_eval.py --db_root_path ${db_root_path} --predicted_sql_path ${data_output_path} \ + --ground_truth_path ${ground_truth_path} \ + --diff_json_path ${eval_path} + fi + + echo "Done evaluating $model." + +else + echo "Error: llama_text2sql.py failed with exit code $?. Skipping evaluation." + exit 1 +fi diff --git a/end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py b/end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py new file mode 100644 index 000000000..a20d4e5d9 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/eval/llama_text2sql.py @@ -0,0 +1,458 @@ +import argparse +import concurrent.futures +import json +import os +import re +import sqlite3 +from typing import Dict + +from llama_api_client import LlamaAPIClient +from tqdm import tqdm + +MAX_NEW_TOKENS = 10240 # If API has max tokens (vs max new tokens), we calculate it +TIMEOUT = 60 # Timeout in seconds for each API call + + +def local_llama(client, api_key, prompts, model, max_workers=8): + """ + Process multiple prompts in parallel using the vllm server. + + Args: + client: OpenAI client + prompts: List of prompts to process + model: Model name + max_workers: Maximum number of parallel workers + + Returns: + List of results in the same order as prompts + """ + + SYSTEM_PROMPT = ( + ( + "You are a text to SQL query translator. Using the SQLite DB Schema " + "and the External Knowledge, translate the following text question " + "into a SQLite SQL select statement." + ) + if api_key == "huggingface" + else ( + "You are a text to SQL query translator. Using the SQLite DB Schema " + "and the External Knowledge, generate the step-by-step reasoning and " + "then the final SQLite SQL select statement from the text question." + ) + ) + + def process_single_prompt(prompt): + messages = [ + {"content": SYSTEM_PROMPT, "role": "system"}, + {"role": "user", "content": prompt}, + ] + try: + chat_response = client.chat.completions.create( + model=model, + messages=messages, + timeout=TIMEOUT, + temperature=0, + ) + answer = chat_response.choices[0].message.content.strip() + + pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL) + matches = pattern.findall(answer) + if not matches: + result = answer + else: + result = matches[0] + + return result + except Exception as e: + print(f"Error processing prompt: {e}") + return f"error:{e}" + + print( + f"local_llama: Processing {len(prompts)} prompts with {model=} " + f"using {max_workers} workers" + ) + results = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all tasks and create a map of futures to their indices + future_to_index = { + executor.submit(process_single_prompt, prompt): i + for i, prompt in enumerate(prompts) + } + + # Initialize results list with None values + results = [None] * len(prompts) + + # Process completed futures as they complete + for future in tqdm( + concurrent.futures.as_completed(future_to_index), + total=len(prompts), + desc="Processing prompts", + ): + index = future_to_index[future] + try: + results[index] = future.result() + except Exception as e: + print(f"Error processing prompt at index {index}: {e}") + results[index] = f"error:{e}" + + return results + + +def new_directory(path): + if not os.path.exists(path): + os.makedirs(path) + + +def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]: + """ + Read an sqlite file, and return the CREATE commands for each of the tables in the database. + """ + asdf = "database" if bench_root == "spider" else "databases" + with sqlite3.connect( + f"file:{bench_root}/{asdf}/{db_name}/{db_name}.sqlite?mode=ro", uri=True + ) as conn: + # conn.text_factory = bytes + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") + tables = cursor.fetchall() + schemas = {} + for table in tables: + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format( + table[0] + ) + ) + schemas[table[0]] = cursor.fetchone()[0] + + return schemas + + +def nice_look_table(column_names: list, values: list): + rows = [] + # Determine the maximum width of each column + widths = [ + max(len(str(value[i])) for value in values + [column_names]) + for i in range(len(column_names)) + ] + + header = "".join( + f"{column.rjust(width)} " for column, width in zip(column_names, widths) + ) + for value in values: + row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths)) + rows.append(row) + rows = "\n".join(rows) + final_output = header + "\n" + rows + return final_output + + +def generate_schema_prompt(db_path, num_rows=None): + # extract create ddls + """ + :param root_place: + :param db_name: + :return: + """ + full_schema_prompt_list = [] + conn = sqlite3.connect(db_path) + # Create a cursor object + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = cursor.fetchall() + schemas = {} + for table in tables: + if table == "sqlite_sequence": + continue + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format( + table[0] + ) + ) + create_prompt = cursor.fetchone()[0] + schemas[table[0]] = create_prompt + if num_rows: + cur_table = table[0] + if cur_table in ["order", "by", "group"]: + cur_table = "`{}`".format(cur_table) + + cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows)) + column_names = [description[0] for description in cursor.description] + values = cursor.fetchall() + rows_prompt = nice_look_table(column_names=column_names, values=values) + verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format( + num_rows, cur_table, num_rows, rows_prompt + ) + schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt) + + for k, v in schemas.items(): + full_schema_prompt_list.append(v) + + schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list) + + return schema_prompt + + +def generate_comment_prompt(question, knowledge=None): + knowledge_prompt = "-- External Knowledge: {}".format(knowledge) + question_prompt = "-- Question: {}".format(question) + + result_prompt = knowledge_prompt + "\n\n" + question_prompt + + return result_prompt + + +def generate_combined_prompts_one(db_path, question, knowledge=None): + schema_prompt = generate_schema_prompt(db_path, num_rows=None) + comment_prompt = generate_comment_prompt(question, knowledge) + + combined_prompts = schema_prompt + "\n\n" + comment_prompt + + return combined_prompts + + +def cloud_llama(client, api_key, model, prompts): + """ + Process multiple prompts sequentially using the cloud API, showing progress with tqdm. + + Args: + client: LlamaAPIClient + api_key: API key + model: Model name + prompts: List of prompts to process (or a single prompt as string) + + Returns: + List of results if prompts is a list, or a single result if prompts is a string + """ + SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement." + + # Handle the case where a single prompt is passed + single_prompt = False + if isinstance(prompts, str): + prompts = [prompts] + single_prompt = True + + results = [] + + # Process each prompt sequentially with tqdm progress bar + for prompt in tqdm(prompts, desc="Processing prompts", unit="prompt"): + try: + messages = [ + {"content": SYSTEM_PROMPT, "role": "system"}, + {"role": "user", "content": prompt}, + ] + final_max_tokens = len(messages) + MAX_NEW_TOKENS + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=0, + max_completion_tokens=final_max_tokens, + ) + answer = response.completion_message.content.text + + pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL) + matches = pattern.findall(answer) + if matches != []: + result = matches[0] + else: + result = answer + except Exception as e: + result = "error:{}".format(e) + print(f"{result=}") + + results.append(result) + + # Return a single result if input was a single prompt + if single_prompt: + return results[0] + return results + + +def batch_collect_response_from_llama( + db_path_list, question_list, api_key, model, knowledge_list=None, batch_size=8 +): + """ + Process multiple questions in parallel using the vllm server. + + Args: + db_path_list: List of database paths + question_list: List of questions + api_key: API key + model: Model name + knowledge_list: List of knowledge strings (optional) + batch_size: Number of parallel requests + + Returns: + List of SQL responses + """ + if api_key in ["huggingface", "finetuned"]: + from openai import OpenAI + + openai_api_key = "EMPTY" + openai_api_base = "http://localhost:8000/v1" + + client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + else: + client = LlamaAPIClient() + + # Generate all prompts first + prompts = [] + for i, question in enumerate(question_list): + if knowledge_list: + cur_prompt = generate_combined_prompts_one( + db_path=db_path_list[i], question=question, knowledge=knowledge_list[i] + ) + else: + cur_prompt = generate_combined_prompts_one( + db_path=db_path_list[i], question=question + ) + prompts.append(cur_prompt) + + print(f"Generated {len(prompts)} prompts for Llama processing") + + if api_key in [ + "huggingface", + "finetuned", + ]: + # Process prompts in parallel; running vllm on multiple GPUs for best eval performance + results = local_llama( + client=client, + api_key=api_key, + prompts=prompts, + model=model, + max_workers=batch_size, + ) + else: + results = cloud_llama( + client=client, + api_key=api_key, + model=model, + prompts=prompts, + ) + + # Format results + response_list = [] + for i, result in enumerate(results): + if isinstance(result, str): + sql = result + else: + sql = "SELECT" + result["choices"][0]["text"] + + db_id = db_path_list[i].split("/")[-1].split(".sqlite")[0] + sql = ( + sql + "\t----- bird -----\t" + db_id + ) # to avoid unpredicted \t appearing in codex results + response_list.append(sql) + + return response_list + + +def question_package(data_json, knowledge=False): + question_list = [] + for data in data_json: + question_list.append(data["question"]) + + return question_list + + +def knowledge_package(data_json, knowledge=False): + knowledge_list = [] + for data in data_json: + knowledge_list.append(data["evidence"]) + + return knowledge_list + + +def decouple_question_schema(datasets, db_root_path): + question_list = [] + db_path_list = [] + knowledge_list = [] + for i, data in enumerate(datasets): + question_list.append(data["question"]) + cur_db_path = db_root_path + data["db_id"] + "/" + data["db_id"] + ".sqlite" + db_path_list.append(cur_db_path) + knowledge_list.append(data["evidence"]) + + return question_list, db_path_list, knowledge_list + + +def generate_sql_file(sql_lst, output_path=None): + result = {} + for i, sql in enumerate(sql_lst): + result[i] = sql + + if output_path: + directory_path = os.path.dirname(output_path) + new_directory(directory_path) + json.dump(result, open(output_path, "w"), indent=4) + + return result + + +if __name__ == "__main__": + args_parser = argparse.ArgumentParser() + args_parser.add_argument("--eval_path", type=str, default="") + args_parser.add_argument("--mode", type=str, default="dev") + args_parser.add_argument("--test_path", type=str, default="") + args_parser.add_argument("--use_knowledge", type=str, default="True") + args_parser.add_argument("--db_root_path", type=str, default="") + args_parser.add_argument("--api_key", type=str, required=True) + args_parser.add_argument("--model", type=str, required=True) + args_parser.add_argument("--data_output_path", type=str) + args_parser.add_argument( + "--batch_size", + type=int, + default=8, + help="Number of parallel requests for batch processing", + ) + args = args_parser.parse_args() + + if args.api_key not in ["huggingface", "finetuned"]: + os.environ["LLAMA_API_KEY"] = args.api_key + + try: + # test if the Llama API key is valid + client = LlamaAPIClient() + client.chat.completions.create( + model=args.model, + messages=[{"role": "user", "content": "125*125 is?"}], + temperature=0, + ) + except Exception as exception: + print(f"{exception=}") + exit(1) + + eval_data = json.load(open(args.eval_path, "r")) + + question_list, db_path_list, knowledge_list = decouple_question_schema( + datasets=eval_data, db_root_path=args.db_root_path + ) + assert len(question_list) == len(db_path_list) == len(knowledge_list) + + if args.use_knowledge == "True": + responses = batch_collect_response_from_llama( + db_path_list=db_path_list, + question_list=question_list, + api_key=args.api_key, + model=args.model, + knowledge_list=knowledge_list, + batch_size=args.batch_size, + ) + else: + responses = batch_collect_response_from_llama( + db_path_list=db_path_list, + question_list=question_list, + api_key=args.api_key, + model=args.model, + knowledge_list=None, + batch_size=args.batch_size, + ) + + output_name = args.data_output_path + "predict_" + args.mode + ".json" + + generate_sql_file(sql_lst=responses, output_path=output_name) + + print("successfully collect results from {}".format(args.model)) diff --git a/end-to-end-use-cases/coding/text2sql/eval/requirements.txt b/end-to-end-use-cases/coding/text2sql/eval/requirements.txt new file mode 100644 index 000000000..cb22a1b7a --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/eval/requirements.txt @@ -0,0 +1,7 @@ +llama_api_client==0.1.2 +func_timeout==4.3.5 +tqdm==4.67.1 + +# uncomment to run vllm for eval with Llama 3.1 8B on HF and its fine-tuned models +# vllm==0.9.2 +# openai==1.90.0 diff --git a/end-to-end-use-cases/coding/text2sql/eval/text2sql_eval.py b/end-to-end-use-cases/coding/text2sql/eval/text2sql_eval.py new file mode 100644 index 000000000..22c30f314 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/eval/text2sql_eval.py @@ -0,0 +1,243 @@ +import argparse +import json +import multiprocessing as mp +import sqlite3 +import sys + +from func_timeout import func_timeout, FunctionTimedOut +from tqdm import tqdm + + +def load_json(dir): + with open(dir, "r") as j: + contents = json.loads(j.read()) + return contents + + +def result_callback(result): + exec_result.append(result) + + +def execute_sql(predicted_sql, ground_truth, db_path, debug=False): + conn = sqlite3.connect(db_path) + # Connect to the database + cursor = conn.cursor() + cursor.execute(predicted_sql) + predicted_res = cursor.fetchall() + cursor.execute(ground_truth) + ground_truth_res = cursor.fetchall() + res = 0 + if set(predicted_res) == set(ground_truth_res): + res = 1 + elif debug: + print( + f"\n\n==== INCORRECT SQL GENERATED ====\n{predicted_sql=}\n{predicted_res=}\n{ground_truth=}\n{ground_truth_res=}\n======\n\n" + ) + + return res + + +def execute_model( + predicted_sql, ground_truth, db_place, idx, meta_time_out, debug=False +): + try: + res = func_timeout( + meta_time_out, + execute_sql, + args=(predicted_sql, ground_truth, db_place, debug), + ) + except KeyboardInterrupt: + sys.exit(0) + except FunctionTimedOut: + result = [(f"timeout",)] + res = 0 + except Exception as e: + result = [(f"{e}",)] # possibly len(query) > 512 or not executable + res = 0 + result = {"sql_idx": idx, "res": res} + return result + + +def package_sqls(sql_path, db_root_path, mode="gpt", data_mode="dev"): + clean_sqls = [] + db_path_list = [] + if mode == "gpt": + sql_data = json.load(open(sql_path + "predict_" + data_mode + ".json", "r")) + for idx, sql_str in sql_data.items(): + if type(sql_str) == str: + sql, db_name = sql_str.split("\t----- bird -----\t") + else: + sql, db_name = " ", "financial" + clean_sqls.append(sql) + + db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite") + + elif mode == "gt": # ground truth + items = json.load(open(db_root_path + "/../dev.json")) + + for item in items: + sql = item["SQL"] + db_name = item["db_id"] + clean_sqls.append(sql) + db_path_list.append(db_root_path + db_name + "/" + db_name + ".sqlite") + + return clean_sqls, db_path_list + + +def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0, debug=False): + pool = mp.Pool(processes=num_cpus) + + # Create a progress bar if not in debug mode + if not debug: + pbar = tqdm(total=len(sqls), desc="Evaluating SQL queries") + + for i, sql_pair in enumerate(sqls): + predicted_sql, ground_truth = sql_pair + pool.apply_async( + execute_model, + args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out, debug), + callback=lambda result: result_callback_with_progress( + result, not debug, pbar + ), + ) + pool.close() + pool.join() + + # Close the progress bar if not in debug mode + if not debug: + pbar.close() + + +def result_callback_with_progress(result, use_progress, pbar=None): + exec_result.append(result) + if use_progress and pbar: + pbar.update(1) + + +def sort_results(list_of_dicts): + return sorted(list_of_dicts, key=lambda x: x["sql_idx"]) + + +def compute_acc_by_diff(exec_results, diff_json_path): + num_queries = len(exec_results) + results = [res["res"] for res in exec_results] + contents = load_json(diff_json_path) + + simple_results, moderate_results, challenging_results = [], [], [] + + for i, content in enumerate(contents): + if content["difficulty"] == "simple": + simple_results.append(exec_results[i]) + + if content["difficulty"] == "moderate": + moderate_results.append(exec_results[i]) + + if content["difficulty"] == "challenging": + challenging_results.append(exec_results[i]) + + simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results) + moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results) + challenging_acc = ( + 0 + if len(challenging_results) == 0 + else sum([res["res"] for res in challenging_results]) / len(challenging_results) + ) + all_acc = sum(results) / num_queries + count_lists = [ + len(simple_results), + len(moderate_results), + len(challenging_results), + num_queries, + ] + return ( + simple_acc * 100, + moderate_acc * 100, + challenging_acc * 100, + all_acc * 100, + count_lists, + ) + + +def print_data(score_lists, count_lists, debug=False): + levels = ["simple", "moderate", "challenging", "total"] + + if debug: + print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) + print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists)) + print( + "====================================== ACCURACY =====================================" + ) + else: + print("\nEvaluation Results:") + print("-" * 40) + + print( + "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists) + ) + + +if __name__ == "__main__": + args_parser = argparse.ArgumentParser() + args_parser.add_argument( + "--predicted_sql_path", type=str, required=True, default="" + ) + args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") + args_parser.add_argument("--data_mode", type=str, default="dev") + args_parser.add_argument("--db_root_path", type=str, required=True, default="") + args_parser.add_argument("--num_cpus", type=int, default=1) + args_parser.add_argument("--meta_time_out", type=float, default=30.0) + args_parser.add_argument("--mode_gt", type=str, default="gt") + args_parser.add_argument("--mode_predict", type=str, default="gpt") + args_parser.add_argument("--difficulty", type=str, default="simple") + args_parser.add_argument("--diff_json_path", type=str, default="") + args_parser.add_argument( + "--debug", action="store_true", help="Enable debug mode with detailed prints" + ) + args = args_parser.parse_args() + exec_result = [] + + if args.debug: + print("Debug mode enabled - showing detailed output") + + # Show loading progress if not in debug mode + if not args.debug: + print("Loading SQL queries and database paths...") + + pred_queries, db_paths = package_sqls( + args.predicted_sql_path, + args.db_root_path, + mode=args.mode_predict, + data_mode=args.data_mode, + ) + # generate gt sqls: + gt_queries, db_paths_gt = package_sqls( + args.ground_truth_path, args.db_root_path, mode="gt", data_mode=args.data_mode + ) + + query_pairs = list(zip(pred_queries, gt_queries)) + + if args.debug: + print(f"Executing {len(query_pairs)} SQL query pairs...") + + run_sqls_parallel( + query_pairs, + db_places=db_paths, + num_cpus=args.num_cpus, + meta_time_out=args.meta_time_out, + debug=args.debug, + ) + exec_result = sort_results(exec_result) + + if args.debug: + print("Evaluating statistics...") + + simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff( + exec_result, args.diff_json_path + ) + score_lists = [simple_acc, moderate_acc, challenging_acc, acc] + print_data(score_lists, count_lists, debug=args.debug) + + if args.debug: + print( + "===========================================================================================" + ) diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/README.md b/end-to-end-use-cases/coding/text2sql/fine-tuning/README.md new file mode 100644 index 000000000..a6d02796a --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/README.md @@ -0,0 +1,219 @@ +# Enhancing Text-to-SQL with CoT: A Fine-Tuning Approach with Llama + +This folder contains scripts to: + +* generate a dataset from the BIRD TRAIN set (with no CoT info) for supervised fine-tuning (SFT); +* generate a dataset from the BIRD TRAIN set (with CoT info by Llama 3.3 70B) for SFT; +* SFT the Llama 3.1 8B model with the generated datasets with different fine-tuning combinations: with or without CoT, using quantization or not, full fine-tuning (FFT) or parameter-efficient fine-tuning (PEFT). + +**Note:** CoT stands for Chain of Thought and we will use "CoT" and "reasoning" interchangeably here, although generally, reasoning encompasses a broader concept than CoT. + +## Eval Results of the Fine-tuned Models + +The eval results of SFT Llama 3.1 8B with different options (epochs is 3, with an additional 10 for the two FFT models) are summarized below: + +| Fine-tuning Combination | Accuracy | +|-----------------------------|-------------------------------| +| baseline | 39.47% | +| CoT, PEFT | 43.35% | +| CoT, FFT | 42.44% (3 epochs) | +| CoT, FFT | 43.87% (10 epochs) | + + +Using Quantization+PEFT on CoT dataset only dropped the accuracy from 43.35% to 42.89%. + +## Quick Start with Fine-tuning Llama 3.1 8B + +1. If you have already run the eval folder's Quick Start Step 1's commands [here](../eval/README.md#quick-start-with-llama-models-via-llama-api) to "create a new Conda environment and install all the required packages for Text2SQL evaluation", just run: + +``` +cd llama-cookbook/end-to-end-use-cases/coding/text2sql/fine-tuning +pip install -r requirements.txt +``` + +Otherwise, run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation and fine-tuning: + +``` +conda create -n llama-text2sql python=3.10 +conda activate llama-text2sql +git clone https://github.com/meta-llama/llama-cookbook +git checkout text2sql # to be removed after the PR merge +cd llama-cookbook/end-to-end-use-cases/coding/text2sql/fine-tuning +pip install -r requirements.txt +``` + +2. Get the TRAIN dataset: + +``` +cd ../data +sh download_train_unzip.sh +cd ../fine-tuning +``` + +3. Create a CoT reasoning dataset from the TRAIN dataset: + +``` +python create_reasoning_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases +``` + +See the section "About Creating the CoT Dataset" below for more details. + +4. Run one of the commands below to fine-tune the Llama 3.1 8B model with the generated dataset (about 50-70GB GPU memory required): + +``` +python trl_sft.py --quantized false --peft false --cot true +python trl_sft.py --quantized false --peft true --cot true +python trl_sft.py --quantized true --peft true --cot true +``` + +See the section "About fine-tuning" below for more details. + +## Evaluating the fine-tuned model + +1. Set the `model` value in `llama_eval.sh` to be one of the fine-tuned model folders above, e.g. + +``` +YOUR_API_KEY='finetuned' +model='fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot' +``` + +2. Start the vllm server by running +``` +vllm serve fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot --tensor-parallel-size 1 --max-num-batched-tokens 8192 --max-num-seqs 64 +``` +or if you want to speed up the inference and eval and have multiple GPUs, you can set `--tensor-parallel-size` to the number of your available GPUs, e.g.: + +``` +vllm serve fine_tuning/llama31-8b-text2sql-fft-nonquantized-cot --tensor-parallel-size 8 --max-num-batched-tokens 8192 --max-num-seqs 64 +``` + +3. If you haven't downloaded the DEV dataset, download it and unzip it first: + +``` +cd ../data +sh download_dev_unzip.sh +cd ../eval +``` + +Then run `sh llama_eval.sh`. + +**Note:** If your fine-tuned model is PEFT based, you may need to run `python merge_peft.py` after modifying its `peft_model_path` and `output_dir` and set the merged folder path after `vllm serve`. + +## About Creating the CoT Dataset + +We use the BIRD TRAIN dataset to prepare for supervised fine-tuning with reasoning info in the dataset. The goal is to see if we can improve the accuracy of the fine-tuned model by adding the reasoning info in the dataset. + +The script `create_reasoning_dataset.py` is used to create a reasoning dataset from the TRAIN dataset by asking Llama 3.3 70B to generate the reasoning for each text question and its corresponding gold SQL. The intent is to use the reasoning dataset to fine-tune the Llama model to improve the accuracy of the generated SQL. + +To run the script, use the following command: +``` +python create_reasoning_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases +``` + +This will create a `text2sql_cot_dataset` dataset and `train_text2sql_cot_dataset.json` in the conversation format ready for fine-tuning. Each example in the dataset is generated from the code snippet below: + +``` +prompt = f""" +-- DB Schema: {db_schema} +-- External Knowledge: {external_knowledge} +-- Text Question: {question} +""" +cot = { + "messages": [ + { + "role": "system", + "content": "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question.", + }, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": reasoning}, + ] +} +``` + +The prompt for Llama 3.3 70B to generate the `reasoning` above is: +``` +You are a text to SQL query translator. Based on the DB Schema and External Knowledge, given the Text Question Input and its Gold SQL Output below, generate the step-by-step reasoning to infer the Gold SQL Output from the Text Question Input. + +-- DB Schema: {db_schema} +-- External Knowledge: {external_knowledge} +-- Text Question Input: {question} +-- Gold SQL Output: {gold_SQL} + +Your response should be as follows:\n\n +Let me think through this step by step:\n\n1. First, I need to consider...\n2. Then...\n3. Next...\n...\n\nFinally, the SQL statement for the text question is: +```sql ...```\n + +""" +``` + +## About fine-tuning + +Run one of the commands below: + +``` +python trl_sft.py --quantized false --peft false --cot true +python trl_sft.py --quantized false --peft true --cot true +python trl_sft.py --quantized true --peft true --cot true +``` + +After the fine-tuning completes, you'll see the fine-tuned model saved in one of the following folders, as specified in `output_dir` of `SFTConfig` in `trl_sft.py`: + +``` +llama31-8b-text2sql-fft-nonquantized-cot +llama31-8b-text2sql-peft-nonquantized-cot +llama31-8b-text2sql-peft-quantized-cot +``` + +The train loss chart should look like this: +![](train_loss_cot.png) + + +## Fine-tuning with Llama 3.3 70B + +If you have 8xH100 GPUs, you can use [torchtune](https://github.com/pytorch/torchtune) to fine-tune Llama 3.3 70B and then evaluate the fine-tuned model. Note that "active development on torchtune" has been stopped ([detail](https://github.com/pytorch/torchtune/issues/2883)), but "Torchtune will continue to receive critical bug fixes and security patches during 2025", so here we just show torchtune as a method to fine-tune the larger Llama 3.3 70B on multiple GPUs. + +``` +pip install torch torchvision torchao +pip install torchtune +tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --output-dir /tmp/Llama-3.3-70B-Instruct +git clone https://github.com/pytorch/torchtune +cd torchtune/tree/main/recipes/configs +``` + +Modify `llama3_3/70B_lora.yaml` as follows: + +``` +output_dir: /tmp/torchtune/llama3_3_70B/lora + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.chat_dataset + source: json + conversation_column: messages + conversation_style: openai + data_files: train_text2sql_cot_dataset_array.json + #split: train +seed: null +shuffle: True + +# Validation +run_val_every_n_steps: null # Change to an integer to enable validation every N steps +dataset_val: + _component_: torchtune.datasets.chat_dataset + source: json + conversation_column: messages + conversation_style: openai + data_files: test_text2sql_cot_dataset_array.json + #split: validation +batch_size_val: ${batch_size} +``` + +Then run: + +``` +tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_3/70B_lora +``` + +After the fine-tuning is done, cd to `text2sql/fine-tuning` folder, set `peft_model_path` as `/tmp/torchtune/llama3_3_70B/lora` and `output_dir` as `llama3_3_70B/lora`, then run `vllm serve llama3_3_70B/lora --tensor-parallel-size 8 --max-num-batched-tokens 8192 --max-num-seqs 64`. + +Finally, in the `eval/llama_eval.sh`, set `model='llama3_3_70B/lora'`, and run `sh llama_eval.sh`. The accuracy of the fine-tuned Llama 3.3 70B should be around 57.24%, compared with the original 54.11% for off-the-shelf Llama 3.3 70B as shown in the [eval README](../eval#evaluation-results). diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/create_reasoning_dataset.py b/end-to-end-use-cases/coding/text2sql/fine-tuning/create_reasoning_dataset.py new file mode 100644 index 000000000..33724afee --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/create_reasoning_dataset.py @@ -0,0 +1,238 @@ +import argparse +import json +import os +import re +import sqlite3 + +from datasets import Dataset, load_from_disk +from langchain_together import ChatTogether +from llama_api_client import LlamaAPIClient + +if ( + os.environ.get("LLAMA_API_KEY", "") == "" + and os.environ.get("TOGETHER_API_KEY", "") == "" +): + print( + "Please set the environment variable LLAMA_API_KEY or TOGETHER_API_KEY to your API key." + ) + exit(1) + + +if os.environ.get("LLAMA_API_KEY", "") != "": # Llama model on Llama API + try: + client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"]) + + response = client.chat.completions.create( + model="Llama-3.3-70B-Instruct", + messages=[{"role": "user", "content": "125*125 is?"}], + temperature=0, + ) + answer = response.completion_message.content.text + except Exception as exception: + print(f"Invalid LLAMA_API_KEY {exception=}") + +if os.environ.get("TOGETHER_API_KEY", "") != "": # Llama model on together + llm = ChatTogether( + model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + temperature=0, + ) + try: + answer = llm.invoke("125*125 is?").content + except Exception as exception: + print(f"Invalid TOGETHER_API_KEY - {exception=}") + exit(1) + + +def llama(prompt, model="Llama-3.3-70B-Instruct"): + + if os.environ["LLAMA_API_KEY"] != "": + client = LlamaAPIClient(api_key=os.environ["LLAMA_API_KEY"]) + response = client.chat.completions.create( + model=model, messages=[{"role": "user", "content": prompt}], temperature=0 + ) + return response.completion_message.content.text + else: + llm = ChatTogether( + model="meta-llama/Llama-3.3-70B-Instruct-Turbo", + temperature=0, + ) + answer = llm.invoke(prompt).content + return answer + + +def new_directory(path): + if not os.path.exists(path): + os.makedirs(path) + + +def nice_look_table(column_names: list, values: list): + rows = [] + # Determine the maximum width of each column + widths = [ + max(len(str(value[i])) for value in values + [column_names]) + for i in range(len(column_names)) + ] + + # Print the column names + header = "".join( + f"{column.rjust(width)} " for column, width in zip(column_names, widths) + ) + # Print the values + for value in values: + row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths)) + rows.append(row) + rows = "\n".join(rows) + final_output = header + "\n" + rows + return final_output + + +def generate_schema_prompt(db_path, num_rows=None): + # extract create ddls + """ + :param root_place: + :param db_name: + :return: + """ + full_schema_prompt_list = [] + conn = sqlite3.connect(db_path) + # Create a cursor object + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = cursor.fetchall() + schemas = {} + for table in tables: + if table == "sqlite_sequence": + continue + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format( + table[0] + ) + ) + create_prompt = cursor.fetchone()[0] + schemas[table[0]] = create_prompt + if num_rows: + cur_table = table[0] + if cur_table in ["order", "by", "group"]: + cur_table = "`{}`".format(cur_table) + + cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows)) + column_names = [description[0] for description in cursor.description] + values = cursor.fetchall() + rows_prompt = nice_look_table(column_names=column_names, values=values) + verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format( + num_rows, cur_table, num_rows, rows_prompt + ) + schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt) + + for k, v in schemas.items(): + full_schema_prompt_list.append(v) + + schema_prompt = "\n\n".join(full_schema_prompt_list) + + return schema_prompt + + +def create_conversation(sample): + return { + "messages": [ + {"role": "system", "content": sample["messages"][0]["content"]}, + {"role": "user", "content": sample["messages"][1]["content"]}, + {"role": "assistant", "content": sample["messages"][2]["content"]}, + ] + } + + +def create_cot_dataset(input_json, db_root_path): + cot_list = [] + diff = 0 + for i, item in enumerate(input_json): + print(f"processing #{i+1}") + + db_id = item["db_id"] + question = item["question"] + external_knowledge = item["evidence"] + gold_SQL = item["SQL"].strip() + db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite" + # print(f"{db_path=}") + db_schema = generate_schema_prompt(db_path) + + prompt_to_generate_reasoning = """ + You are a text to SQL query translator. Based on the DB Schema and External Knowledge, given the Text Question Input and its Gold SQL Output below, generate the step-by-step reasoning to infer the Gold SQL Output from the Text Question Input. + + -- DB Schema: {db_schema} + -- External Knowledge: {external_knowledge} + -- Text Question Input: {question} + -- Gold SQL Output: {gold_SQL} + + Your response should be as follows:\n\n + Let me think through this step by step:\n\n1. First, I need to consider...\n2. Then...\n3. Next...\n...\n\nFinally, the SQL statement for the text question is: + ```sql ...```\n + + """ + + prompt_to_generate_reasoning = ( + prompt_to_generate_reasoning.replace("{db_schema}", db_schema) + .replace("{external_knowledge}", external_knowledge) + .replace("{question}", question) + .replace("{gold_SQL}", gold_SQL) + ) + reasoning = llama(prompt_to_generate_reasoning) + + pattern = re.compile(r"```sql\n*(.*?)```", re.DOTALL) + matches = pattern.findall(reasoning) + if matches != []: + gene_SQL = matches[0].replace("\n", "").strip() + gene_SQL = re.sub(r"\s{2,}", " ", gene_SQL) + else: + gene_SQL = reasoning + + print(f"{diff=}\n{gold_SQL=}\n{gene_SQL=}") + if gold_SQL != gene_SQL: + diff += 1 + continue + + # use the reasoning generated above to generate an example for the reasoning dataset used for fine-tuning + prompt = f""" + -- DB Schema: {db_schema} + -- External Knowledge: {external_knowledge} + -- Text Question: {question} +""" + cot = { + "messages": [ + { + "role": "system", + "content": "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question.", + }, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": reasoning}, + ] + } + cot_list.append(cot) + + print(f"{diff=}, total: {len(input_json)}") + dataset_dict = {key: [d[key] for d in cot_list] for key in cot_list[0]} + hf_dataset = Dataset.from_dict(dataset_dict) + hf_dataset.save_to_disk("text2sql_cot_dataset") + + dataset = load_from_disk("text2sql_cot_dataset") + dataset = dataset.map( + create_conversation, remove_columns=dataset.features, batched=False + ) + dataset = dataset.train_test_split(test_size=0.3) + + dataset["train"].to_json("train_text2sql_cot_dataset.json", orient="records") + dataset["test"].to_json("test_text2sql_cot_dataset.json", orient="records") + + +if __name__ == "__main__": + args_parser = argparse.ArgumentParser() + args_parser.add_argument("--input_json", type=str, required=True) + args_parser.add_argument("--db_root_path", type=str, required=True) + args = args_parser.parse_args() + + input_json = json.load(open(args.input_json, "r")) + db_root_path = args.db_root_path + + create_cot_dataset(input_json, db_root_path) + +# python create_reasoning_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/create_sft_dataset.py b/end-to-end-use-cases/coding/text2sql/fine-tuning/create_sft_dataset.py new file mode 100644 index 000000000..b2597af87 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/create_sft_dataset.py @@ -0,0 +1,163 @@ +import argparse +import json +import os +import sqlite3 + +from datasets import Dataset +from tqdm import tqdm + + +def new_directory(path): + if not os.path.exists(path): + os.makedirs(path) + + +def nice_look_table(column_names: list, values: list): + rows = [] + # Determine the maximum width of each column + widths = [ + max(len(str(value[i])) for value in values + [column_names]) + for i in range(len(column_names)) + ] + + # Print the column names + header = "".join( + f"{column.rjust(width)} " for column, width in zip(column_names, widths) + ) + # print(header) + # Print the values + for value in values: + row = "".join(f"{str(v).rjust(width)} " for v, width in zip(value, widths)) + rows.append(row) + rows = "\n".join(rows) + final_output = header + "\n" + rows + return final_output + + +def generate_schema_prompt(db_path, num_rows=None): + # extract create ddls + """ + :param root_place: + :param db_name: + :return: + """ + full_schema_prompt_list = [] + conn = sqlite3.connect(db_path) + # Create a cursor object + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = cursor.fetchall() + schemas = {} + for table in tables: + if table == "sqlite_sequence": + continue + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format( + table[0] + ) + ) + create_prompt = cursor.fetchone()[0] + schemas[table[0]] = create_prompt + if num_rows: + cur_table = table[0] + if cur_table in ["order", "by", "group"]: + cur_table = "`{}`".format(cur_table) + + cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows)) + column_names = [description[0] for description in cursor.description] + values = cursor.fetchall() + rows_prompt = nice_look_table(column_names=column_names, values=values) + verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format( + num_rows, cur_table, num_rows, rows_prompt + ) + schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt) + + for k, v in schemas.items(): + full_schema_prompt_list.append(v) + + schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list) + + return schema_prompt + + +def generate_comment_prompt(question, knowledge=None): + knowledge_prompt = "-- External Knowledge: {}".format(knowledge) + question_prompt = "-- Question: {}".format(question) + + result_prompt = knowledge_prompt + "\n\n" + question_prompt + + return result_prompt + + +def generate_combined_prompts_one(db_path, question, knowledge=None): + schema_prompt = generate_schema_prompt(db_path, num_rows=None) + comment_prompt = generate_comment_prompt(question, knowledge) + + combined_prompts = schema_prompt + "\n\n" + comment_prompt + + return combined_prompts + + +def create_conversation(sample): + return { + "messages": [ + {"role": "system", "content": sample["messages"][0]["content"]}, + {"role": "user", "content": sample["messages"][1]["content"]}, + {"role": "assistant", "content": sample["messages"][2]["content"]}, + ] + } + + +def create_sft_dataset(input_json, db_root_path): + ds = [] + SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement." + + for i, item in tqdm(enumerate(input_json)): + print(f"processing #{i+1}") + db_id = item["db_id"] + question = item["question"] + external_knowledge = item["evidence"] + SQL = item["SQL"] + db_path = db_root_path + "/" + db_id + "/" + db_id + ".sqlite" + print(f"{db_path=}") + prompt = generate_combined_prompts_one( + db_path, + question, + knowledge=external_knowledge, + ) + + example = { + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": SQL}, + ] + } + + ds.append(example) + + dataset_dict = {key: [d[key] for d in ds] for key in ds[0]} + dataset = Dataset.from_dict(dataset_dict) + # dataset.save_to_disk(f"text2sql_sft_dataset") + + dataset = dataset.map( + create_conversation, remove_columns=dataset.features, batched=False + ) + dataset = dataset.train_test_split(test_size=0.3) + + dataset["train"].to_json("train_text2sql_sft_dataset.json", orient="records") + dataset["test"].to_json("test_text2sql_sft_dataset.json", orient="records") + + +if __name__ == "__main__": + args_parser = argparse.ArgumentParser() + args_parser.add_argument("--input_json", type=str, required=True) + args_parser.add_argument("--db_root_path", type=str, required=True) + args = args_parser.parse_args() + + input_json = json.load(open(args.input_json, "r")) + db_root_path = args.db_root_path + + create_sft_dataset(input_json, db_root_path) + +# python create_sft_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/README.md b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/README.md new file mode 100644 index 000000000..d0a7fd7cf --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/README.md @@ -0,0 +1,41 @@ +# GRPO Fine-tuning for Text2SQL + +This folder contains scripts to reinforcemen fine-tuning Llama models for the Text2SQL task using GRPO. + +## Quick start + +1. Download the BIRD train and dev datasets, if you haven't already: + +``` +git clone https://github.com/meta-llama/llama-cookbook +git checkout text2sql +cd llama-cookbook/end-to-end-use-cases/coding/text2sql/data +sh download_dev_unzip.sh +sh download_train_unzip.sh +``` + +2. (Optional) Set the following environment variable, so the reward of using LLM as a judge (via Llama 3.3 70b hosted on Together.ai) can be calculated: + +``` +pip install together +export TOGETHER_API_KEY= +``` + +If you don't want to use the using LLM as a judge reward, you can comment out this [line](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo_text2sql.py#L594) when calling GRPOTrainer and change the reward weights [here](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo-llama323b-text2sql.yaml#L32) to [1.0, 3.0, 1.0] + +3. Install the required libraries in a conda or virtual environment: + +``` +cd ../fine-tuning/grpo +pip install -r requirements.txt +``` + +4. Run the training script, assuming you have 6 GPUs to use for the training (if not, modify the `--num_processes` and `--gpu_ids`): + +``` +accelerate launch --num_processes 6 --gpu_ids 2,3,4,5,6,7 --config_file deepspeed_zero3.yaml grpo_text2sql.py --config grpo-llama323b-text2sql.yaml +``` + +You can modify the grpo-llama323b-text2sql.yaml file and tune `num_generations`, `learning_rate`, `reward_weights` and other parameters. + +5. To evaluate a saved checkpoint, follow the steps [here](https://github.com/meta-llama/llama-cookbook/tree/text2sql/end-to-end-use-cases/coding/text2sql/eval#evaluation-with-llama-models-on-hugging-face-or-fine-tuned). diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/deepspeed_zero3.yaml b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/deepspeed_zero3.yaml new file mode 100644 index 000000000..b5a1201f8 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/deepspeed_zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo-llama323b-text2sql.yaml b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo-llama323b-text2sql.yaml new file mode 100644 index 000000000..e8eace639 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo-llama323b-text2sql.yaml @@ -0,0 +1,72 @@ +# Model arguments +model_name_or_path: meta-llama/Llama-3.2-3B-Instruct +model_revision: main +torch_dtype: bfloat16 +attn_implementation: flash_attention_2 +bf16: true +tf32: true +output_dir: runs/llama-3.2-3b-grpo-text2sql-4rewards-6gpu + +# Lora Arguments +# No LoRA is used here + +# Training arguments +max_steps: 750 # 1000 #500 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +learning_rate: 5.0e-7 # 1.0e-6 # 5.0e-7 # 1.0e-6 as in the deepseek math paper 5-e7 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05 +lr_scheduler_type: cosine +warmup_ratio: 0.03 +# GRPO specific parameters +beta: 0.001 # 0.04 as in the deepseek math paper 0.001 from https://hijkzzz.notion.site/unraveling-rlhf-and-its-variants-engineering-insights#147d9a33ecc9806090f3d5c749d31f05 +max_prompt_length: 512 # 256 +max_completion_length: 1024 +num_generations: 8 # 6 # 8 +use_vllm: true + +# Reward function weights +# Order: [format_reward_func, execution_reward_func, ensemble_n_gram_reward_func] +reward_weights: [1.0, 3.0, 1.0, 1.0] +# **Recommended Weight Strategy** +# Current Setting: `[1.0, 3.0, 1.0]`** +# * **Format reward (1.0)**: Standard weight since format correctness is binary but essential +# * **Execution reward (3.0)**: **Highest weight** - SQL execution correctness is most important for text2sql +# * **N-gram similarity (1.0)**: Standard weight for syntactic similarity + +# **Alternative Weight Strategies** +# **Conservative approach: `[2.0, 4.0, 1.0]`** +# * Emphasizes both format and execution correctness +# * Lower weight on similarity metrics +# **Balanced approach: `[1.5, 2.0, 1.5]`** +# * More balanced across all three metrics +# * Good for early training stages +# **Similarity-focused: `[1.0, 2.0, 2.0]`** +# * Higher weight on N-gram similarity +# * Useful if execution often fails initially +# final_reward = format_reward*1.0 + execution_reward*3.0 + ngram_reward*1.0 + +vllm_device: "cuda:0" # use vLLM for generation and DeepSpeed for distributed training. +# Set the num_processes to the number of GPUs you have - +# the last one will be used with vLLM for Generation. +# if you have 6 GPUs, set vllm_device to "cuda:5" (or 5?) and +# num_processes to 5 (or 6? in which case, 6th GPU will be used +# for both generation and training + +vllm_gpu_memory_utilization: 0.5 + +# Logging arguments +logging_strategy: steps +logging_steps: 2 +report_to: +- tensorboard +save_strategy: "steps" +save_steps: 50 +seed: 42 + +# Hugging Face Hub +push_to_hub: false + # hub_model_id: llama-3-1-8b-math-orca-qlora-10k-ep1 # if not defined same as output_dir +hub_strategy: every_save diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo_text2sql.py b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo_text2sql.py new file mode 100644 index 000000000..f09f8ec9c --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/grpo_text2sql.py @@ -0,0 +1,660 @@ +import json +import logging + +import os +import random +import re +import sqlite3 +import sys +from dataclasses import dataclass +from datetime import datetime +from typing import List + +from datasets import Dataset +from func_timeout import func_timeout, FunctionTimedOut +from together import Together +from tqdm import tqdm +from transformers import AutoTokenizer +from transformers.trainer_utils import get_last_checkpoint +from trl import get_peft_config, GRPOConfig, GRPOTrainer, ModelConfig, TrlParser + +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +TRAIN_JSON = "../../data/train/train.json" +DB_ROOT_PATH = "../../data/train/train_databases/" +LOG_REWARD_FILE_NAME = "text2sql_grpo_rewards5.log" +COMPLETION_SAMPLE_TXT_FILE_NAME = "completion_samples5.txt" + + +def load_json(dir): + with open(dir, "r") as j: + contents = json.loads(j.read()) + return contents + + +def execute_sql(predicted_sql, ground_truth_dbid): + ground_truth, db_name = ground_truth_dbid.split("\t----- bird -----\t") + + # print(f"\n==== execute_sql ====\n{predicted_sql=}\n{ground_truth=}") + + db_path = DB_ROOT_PATH + db_name + "/" + db_name + ".sqlite" + conn = sqlite3.connect(db_path) + # Connect to the database + cursor = conn.cursor() + cursor.execute(predicted_sql) + predicted_res = cursor.fetchall() + cursor.execute(ground_truth) + ground_truth_res = cursor.fetchall() + res = 0 + + if set(predicted_res) == set(ground_truth_res): + res = 1 + print("execution result same") + else: + print("execution result different") + conn.close() + + return res + + +@dataclass +class ScriptArguments: + tokenizer_name_or_path: str = None + + +######################## +# Setup logging +######################## +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +handler = logging.StreamHandler() +handler.setFormatter( + logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") +) +logger.addHandler(handler) + +######################## +# Helper functions +######################## + + +def log_reward(reason, completion, gt): + import os + + os.makedirs("logs", exist_ok=True) + log_file = os.path.join("logs", LOG_REWARD_FILE_NAME) + with open(log_file, "a") as f: + f.write("\n\n==============\n") + f.write(f">>>{reason=}\n>>>{completion=}\n>>>{gt=}\n") + + +def extract_answer(text): + """ + Extracts the final SQL statement answer from the raw text. + """ + try: + match = re.search(r"#### (\-?[\d\.,$]+)", text) + if match: + matched_string = match.group(1) + + # Remove any characters that would cause a ValueError, + # such as dollar signs ($) and commas (,) + cleaned_string = re.sub(r"[$,]", "", matched_string) + + return float(cleaned_string) + + match = re.search( + r"(?:The final answer is|The answer is):?\s*(\-?[\d\.,$]+)", + text, + re.IGNORECASE, + ) + if match: + matched_string = match.group(1) + cleaned_string = re.sub(r"[$,]", "", matched_string) + return float(cleaned_string) + + except (ValueError, AttributeError): + print(f"Error extracting answer from text: {match.group(1)}") + pass + return None + + +def format_reward_func(completions, answer, **kwargs): + """ + Format: ...... + Args: + completions (list[str]): Generated outputs + answer (list[str]): Expected answers + + Returns: + list[float]: Reward scores + """ + rewards = [] + + for completion, gt in zip(completions, answer): + + try: + if random.random() < 0.1: # 1% chance to write samples into a file + os.makedirs("completion_samples", exist_ok=True) + log_file = os.path.join( + "completion_samples", COMPLETION_SAMPLE_TXT_FILE_NAME + ) + with open(log_file, "a") as f: + f.write(f"\n\n==============\n") + f.write(completion) + + # Check if the format is correct + regex = r"([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\s*([\s\S]*?)<\/answer>$" + + match = re.search(regex, completion, re.DOTALL) + # if the format is not correct, reward is 0 + if match is None or len(match.groups()) != 2: + rewards.append(0.0) + log_reward("format_reward 0", completion, gt) + else: + rewards.append(1.0) + log_reward("format_reward 1", completion, gt) + except Exception as e: + rewards.append(0.0) + log_reward(f"format_reward 0 - exception {e=}", completion, gt) + return rewards + + +def execution_reward_func(completions, answer, **kwargs): + """ + Evaluates completions based on SQL statement execution result + + Args: + completions (list[str]): Generated outputs + answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json) + + Returns: + list[float]: Reward scores + """ + rewards = [] + for completion, gt in zip(completions, answer): + try: + # gt = extract_answer(gt) + match = re.search(r"(.*?)<\/answer>", completion) + if match is None: + rewards.append(0.0) + log_reward("execution_reward 0 - no answer tag found", completion, gt) + continue + # Extract the "answer" part from the completion + predicted_sql = match.group(1).strip() + + reason = "execution result different" + # execute the sql_generated and gt and compare the results + try: + res = func_timeout( + 30.0, + execute_sql, + args=(predicted_sql, gt), + ) + except KeyboardInterrupt: + sys.exit(0) + except FunctionTimedOut: + print("FunctionTimedOut") + reason = "execution timeout" + res = 0 + except Exception as e: + print("Exception", e) + reason = f"execution exception {e}" + res = 0 + + if res == 1: + # reason = "execution result same" + rewards.append(1.0) + log_reward("execution_reward 1", completion, gt) + else: + rewards.append(0.0) + log_reward( + f"execution_reward 0 {reason=}, {predicted_sql=}", + completion, + gt, + ) + + except Exception as e: + # If evaluation fails, reward is 0 + rewards.append(0.0) + log_reward(f"execution_reward 0 - exception {e=}", completion, gt) + + return rewards + + +def get_ngrams(tokens: List[str], n: int) -> set: + """Generates a set of n-grams from a list of tokens.""" + # Ensure there are enough tokens to create at least one n-gram + if len(tokens) < n: + return set() + return {tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)} + + +def n_gram_jaccard_similarity(candidate_query: str, gold_query: str, n: int) -> float: + """Calculates the n-gram Jaccard similarity for a single n.""" + # Tokenize the SQL queries. Using .lower() for case-insensitivity. + candidate_tokens = candidate_query.lower().split() + gold_tokens = gold_query.lower().split() + + # Get the n-grams for both sets of tokens. + candidate_ngrams = get_ngrams(candidate_tokens, n) + gold_ngrams = get_ngrams(gold_tokens, n) + + # Handle the edge case where one or both sets are empty. + if not candidate_ngrams and not gold_ngrams: + return 1.0 + if not candidate_ngrams or not gold_ngrams: + return 0.0 + + # Calculate Jaccard similarity. + intersection = len(candidate_ngrams.intersection(gold_ngrams)) + union = len(candidate_ngrams.union(gold_ngrams)) + + return intersection / union + + +def ensemble_n_gram_reward_func(completions, answer, **kwargs): + """ + Calculates the averaged ensemble n-gram Jaccard similarity reward. + This function computes the Jaccard similarity for n=1, 2, and 3 + and returns the average score for each sample. + + Args: + completions (list[str]): Generated outputs + answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json) + + Returns: + list[float]: Reward scores + """ + + rewards = [] + questions = kwargs.get("question") + evidences = kwargs.get("evidence") + + for completion, gt, question, evidence in zip( + completions, answer, questions, evidences + ): + try: + match = re.search(r"(.*?)<\/answer>", completion) + if match is None: + rewards.append(0.0) + log_reward("n_gram_reward 0 - no answer tag found", completion, gt) + continue + # Extract the "answer" part from the completion + predicted_sql = match.group(1).strip() + + # Calculate Jaccard similarity for n=1, 2, and 3 + jaccard_1 = n_gram_jaccard_similarity(predicted_sql, gt, n=1) + jaccard_2 = n_gram_jaccard_similarity(predicted_sql, gt, n=2) + jaccard_3 = n_gram_jaccard_similarity(predicted_sql, gt, n=3) + + # Average the scores to get the final ensemble reward + average_jaccard = (jaccard_1 + jaccard_2 + jaccard_3) / 3.0 + print(f"{average_jaccard=}") + rewards.append(average_jaccard) + except Exception as e: + rewards.append(0.0) + log_reward(f"n_gram_reward 0 - exception {e=}", completion, gt) + + return rewards + + +def llm_as_a_judge_reward_func(completions, answer, **kwargs): + """ + Use Llama 3.3 70b as a judge to evaluate the quality of the generated SQL statements by comparing them to the ground truth answers. + + Args: + completions (list[str]): Generated outputs + answer (list[str]): Ground truth answers (from content with "role":"assistant" in train_text2sql_sft_dataset.json) + + Returns: + list[float]: Reward scores + """ + + rewards = [] + + client = Together() + PROMPT_TEMPLATE = """ +You are an experienced database expert. Your task is to evaluate a generated SQL query by comparing it +to the ground truth (gold) query and then assign a score between 0.0 and 2.0. A higher score indicates +the predicted query is more correct, while a score of 0.0 means it is completely incorrect. + +Follow these evaluation rules strictly: + +1. SELECT Clause: +• Only select columns that are mentioned in the user’s question. +• Do not include unnecessary columns or values. + +2. Aggregation (MAX/MIN): +• Always perform JOINs before applying MAX() or MIN(). + +3. ORDER BY with Distinct Values: +• Use a GROUP BY before an ORDER BY ASC|DESC to ensure +distinct values. + +4. Handling NULLs: +• If a column may contain NULL values (indicated by "None" in value examples +or explicitly mentioned), include a JOIN or a WHERE IS NOT NULL +clause. + +5. FROM/JOIN Clauses: +• Only include the tables essential for answering the question. + +6. Strictly Follow Hints: +• Adhere to all hints provided with the question. + +7. Thorough Question Analysis: +• Ensure all conditions and requirements mentioned in the question are ad- +dressed. + +8. DISTINCT Keyword: +• Use SELECT DISTINCTwhen the question requires unique values (e.g., IDs, URLs) +or when column statistics (Value Statics) indicate its necessity. + +9. Column Selection: +• Carefully analyze column descriptions and hints to choose the correct column +when similar columns exist across tables. + +10. String Concatenation: +• Do not use any string concatenation methods (e.g., || ’ ’ ||) in the SELECT +clause. + +11. JOIN Preference: +• Prefer using INNER JOINover nested SELECT statements. + +12. Date Processing: +• Use STRFTIME()for any date manipulations (e.g., STRFTIME(’%Y’, SOMETIME)to +extract the year). + +You are provided with the following inputs: +• Question: {QUESTION} +• Hint: {HINT} +• Gold Query: {GOLD_QUERY} +• Predicted Query: {PREDICTED_QUERY} + +Based on the above, return a single numeric score between 0.0 and 2.0 that reflects how +correct the predicted query is compared to the gold query. Respond with only the score and +no additional explanation. +""" + + questions = kwargs.get("question") + evidences = kwargs.get("evidence") + for completion, gt, question, evidence in zip( + completions, answer, questions, evidences + ): + try: + match = re.search(r"(.*?)<\/answer>", completion) + if match is None: + rewards.append(0.0) + log_reward( + "llm_as_a_judge_reward_func 0 - no answer tag found", completion, gt + ) + continue + # Extract the "answer" part from the completion + predicted_sql = match.group(1).strip() + prompt = PROMPT_TEMPLATE.format( + QUESTION=question, + HINT=evidence, + GOLD_QUERY=gt, + PREDICTED_QUERY=predicted_sql, + ) + response = client.chat.completions.create( + model="meta-llama/Llama-3.3-70B-Instruct-Turbo", + messages=[{"role": "user", "content": prompt}], + temperature=0, + ) + reward = float(response.choices[0].message.content) + print(f"llm_as_a_judge_reward_func>>> {reward=}") + rewards.append(reward) + except Exception as e: + rewards.append(0.0) + log_reward(f"llm_as_a_judge_reward_func 0 - exception {e=}", completion, gt) + + return rewards + + +def get_checkpoint(training_args: GRPOConfig): + last_checkpoint = None + if os.path.isdir(training_args.output_dir): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + return last_checkpoint + + +def generate_schema_prompt(db_path, num_rows=None): + # extract create ddls + """ + :param root_place: + :param db_name: + :return: + """ + full_schema_prompt_list = [] + conn = sqlite3.connect(db_path) + # Create a cursor object + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = cursor.fetchall() + schemas = {} + for table in tables: + if table == "sqlite_sequence": + continue + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format( + table[0] + ) + ) + create_prompt = cursor.fetchone()[0] + schemas[table[0]] = create_prompt + if num_rows: + cur_table = table[0] + if cur_table in ["order", "by", "group"]: + cur_table = "`{}`".format(cur_table) + + cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows)) + column_names = [description[0] for description in cursor.description] + values = cursor.fetchall() + # Format the rows as a simple table representation + rows_prompt = "\n".join( + "\t".join(str(val) for val in row) for row in values + ) + verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format( + num_rows, cur_table, num_rows, rows_prompt + ) + schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt) + + for k, v in schemas.items(): + full_schema_prompt_list.append(v) + + schema_prompt = "-- DB Schema: " + "\n\n".join(full_schema_prompt_list) + + return schema_prompt + + +def generate_comment_prompt(question, knowledge=None): + knowledge_prompt = "-- External Knowledge: {}".format(knowledge) + question_prompt = "-- Question: {}".format(question) + + result_prompt = knowledge_prompt + "\n\n" + question_prompt + + return result_prompt + + +def generate_combined_prompts_one(db_path, question, knowledge=None): + schema_prompt = generate_schema_prompt(db_path, num_rows=None) + comment_prompt = generate_comment_prompt(question, knowledge) + + combined_prompts = schema_prompt + "\n\n" + comment_prompt + + return combined_prompts + + +def grpo_function( + model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig +): + + logger.info(f"Model parameters {model_args}") + logger.info(f"Training/evaluation parameters {training_args}") + + tokenizer = AutoTokenizer.from_pretrained( + ( + script_args.tokenizer_name_or_path + if script_args.tokenizer_name_or_path + else model_args.model_name_or_path + ), + revision=model_args.model_revision, + trust_remote_code=model_args.trust_remote_code, + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + ds = [] + SYSTEM_PROMPT = "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement." + + input_json = json.load(open(TRAIN_JSON, "r")) + + for i, item in tqdm(enumerate(input_json)): + print(f"processing #{i+1}") + db_id = item["db_id"] + question = item["question"] + external_knowledge = item["evidence"] + SQL = item["SQL"] + db_path = DB_ROOT_PATH + "/" + db_id + "/" + db_id + ".sqlite" + prompt = generate_combined_prompts_one( + db_path, + question, + knowledge=external_knowledge, + ) + + example = { + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + {"role": "assistant", "content": SQL + "\t----- bird -----\t" + db_id}, + ], + "question": question, + "evidence": external_knowledge, + } + + ds.append(example) + + dataset_dict = {key: [d[key] for d in ds] for key in ds[0]} + dataset = Dataset.from_dict(dataset_dict) + + def generate_r1_prompt( + system_prompt, user_prompt, ground_truth, question, evidence + ): + r1_prefix = [ + { + "role": "system", + "content": """You are great at reasoning and translating natural language question to SQLite SQL query. Given DB Schema, External Knowledge, and Question, your task is to first generate step-by-step reasoning, then apply the resoning to generate the SQLite select statement as the accurate translation of the Question. Enclose the step-by-step reasoning within the tags, and the final SQL statement within the tags, i.e. reasoning steps final SQL .""", + }, + {"role": "user", "content": user_prompt}, + ] + + return { + "prompt": tokenizer.apply_chat_template( + r1_prefix, tokenize=False, continue_final_message=True + ), + "answer": ground_truth, + "question": question, + "evidence": evidence, + } + + # convert our dataset to the r1 prompt + dataset = dataset.map( + lambda x: generate_r1_prompt( + x["messages"][0]["content"], + x["messages"][1]["content"], + x["messages"][2]["content"], + x["question"], + x["evidence"], + ), + remove_columns=dataset.column_names, # Remove original columns to avoid conflicts with apply_chat_template + ) + + # split the dataset into train and test + train_test_split = dataset.train_test_split(test_size=0.3) + + train_dataset = train_test_split["train"] + eval_dataset = train_test_split["test"] + print("len(train_dataset)", len(train_dataset)) + print(train_dataset[0]) + print("len(eval_dataset)", len(eval_dataset)) + print(eval_dataset[0]) + + ######################### + # Instantiate DPO trainer + ######################### + + trainer = GRPOTrainer( + model=model_args.model_name_or_path, + reward_funcs=[ + format_reward_func, + execution_reward_func, + ensemble_n_gram_reward_func, + llm_as_a_judge_reward_func, + ], + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=get_peft_config(model_args), + ) + + trainer.tokenizer = tokenizer + + ############### + # Training loop + ############### + # Check for last checkpoint + last_checkpoint = get_checkpoint(training_args) + # by default training_args.resume_from_checkpoint is None + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.") + + # Train the model + logger.info( + f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for {training_args.num_train_epochs} epochs***' + ) + train_result = trainer.train(resume_from_checkpoint=last_checkpoint) + # Log and save metrics + metrics = train_result.metrics + metrics["train_samples"] = len(train_dataset) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + logger.info("*** Training complete ***") + + ################################## + # Save model and create model card + ################################## + + logger.info("*** Save model ***") + trainer.model.config.use_cache = True + trainer.save_model(training_args.output_dir) + logger.info(f"Model saved to {training_args.output_dir}") + training_args.distributed_state.wait_for_everyone() # wait for all processes to load + + tokenizer.save_pretrained(training_args.output_dir) + logger.info(f"Tokenizer saved to {training_args.output_dir}") + + # Save everything else on main process + # if trainer.accelerator.is_main_process: + # trainer.create_model_card({"tags": ["rl", "grpo", "tutorial", "philschmid"]}) + # push to hub if needed + # if training_args.push_to_hub is True: + # logger.info("Pushing to hub...") + # trainer.push_to_hub() + + logger.info("*** Training complete! ***") + + +def main(): + parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig)) + model_args, script_args, training_args = parser.parse_args_and_config() + + # Run the main training loop + grpo_function(model_args, script_args, training_args) + + +if __name__ == "__main__": + main() diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/requirements.txt b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/requirements.txt new file mode 100644 index 000000000..91127851d --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/grpo/requirements.txt @@ -0,0 +1,14 @@ +torch==2.5.1 +tensorboard==2.19.0 +setuptools==70.3.0 +flash-attn==2.7.4.post1 +transformers==4.48.1 +datasets==3.1.0 +accelerate==1.3.0 +hf-transfer==0.1.9 +deepspeed==0.15.4 +trl==0.14.0 +peft==0.15.2 +vllm==0.7.0 +func_timeout==4.3.5 +together==1.5.26 diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/merge_peft.py b/end-to-end-use-cases/coding/text2sql/fine-tuning/merge_peft.py new file mode 100644 index 000000000..5764bf0e3 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/merge_peft.py @@ -0,0 +1,45 @@ +import torch +from peft import PeftModel +from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig + +peft_model_path = "../fine-tuning/final_test/llama31-8b-text2sql-peft-nonquantized-cot" +output_dir = ( + "../fine-tuning/final_test/llama31-8b-text2sql-peft-nonquantized-cot_merged" +) +# === Load Base Model and Tokenizer === +print("Loading base model and tokenizer...") +base_model_id = "meta-llama/Llama-3.1-8B-Instruct" +tokenizer = AutoTokenizer.from_pretrained(base_model_id) + +# Configure quantization if needed +quantization_config = None +use_quantized = False +if use_quantized: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + +# Load model +base_model = AutoModelForCausalLM.from_pretrained( + base_model_id, + device_map="auto", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, +) +base_model.resize_token_embeddings(128257) + +# === Load PEFT Adapter and Merge === +print("Loading PEFT adapter and merging...") +# peft_config = PeftConfig.from_pretrained(peft_model_path) +model = PeftModel.from_pretrained(base_model, peft_model_path) +model = model.merge_and_unload() # This merges the adapter weights into the base model + +# === Save the Merged Model === +print(f"Saving merged model to {output_dir} ...") +model.save_pretrained(output_dir) +tokenizer.save_pretrained(output_dir) + +print("Done! The merged model is ready for vLLM serving.") diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/requirements.txt b/end-to-end-use-cases/coding/text2sql/fine-tuning/requirements.txt new file mode 100644 index 000000000..d08af229b --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/requirements.txt @@ -0,0 +1,19 @@ +llama_api_client==0.1.2 +func_timeout==4.3.5 +tqdm==4.67.1 +vllm==0.9.2 +openai==1.90.0 +langchain-together==0.3.0 +sqlparse==0.5.3 +tensorboard==2.19.0 +liger_kernel==0.6.1 +setuptools==78.1.1 +deepspeed==0.17.3 +transformers==4.54.0 +datasets==4.0.0 +accelerate==1.9.0 +bitsandbytes==0.46.1 +trl==0.19.1 +peft==0.16.0 +lighteval==0.10.0 +hf_transfer==0.1.9 diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/train_loss.png b/end-to-end-use-cases/coding/text2sql/fine-tuning/train_loss.png new file mode 100644 index 000000000..a0fbadf0b Binary files /dev/null and b/end-to-end-use-cases/coding/text2sql/fine-tuning/train_loss.png differ diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/train_loss_cot.png b/end-to-end-use-cases/coding/text2sql/fine-tuning/train_loss_cot.png new file mode 100644 index 000000000..854d40bf8 Binary files /dev/null and b/end-to-end-use-cases/coding/text2sql/fine-tuning/train_loss_cot.png differ diff --git a/end-to-end-use-cases/coding/text2sql/fine-tuning/trl_sft.py b/end-to-end-use-cases/coding/text2sql/fine-tuning/trl_sft.py new file mode 100644 index 000000000..81ddf64c2 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/fine-tuning/trl_sft.py @@ -0,0 +1,221 @@ +# Unified script supporting multiple fine-tuning configurations + +import argparse +import sys + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import SFTConfig, SFTTrainer + +# Parse command line arguments +parser = argparse.ArgumentParser( + description="Unified fine-tuning script for Llama 3.1 8B" +) +parser.add_argument( + "--quantized", + type=str, + choices=["true", "false"], + required=True, + help="Whether to use quantization (true/false)", +) +parser.add_argument( + "--peft", + type=str, + choices=["true", "false"], + required=True, + help="Whether to use PEFT (true/false)", +) +parser.add_argument( + "--cot", + type=str, + choices=["true", "false"], + required=True, + help="Whether to use Chain-of-Thought dataset (true/false)", +) +args = parser.parse_args() + +# Convert string arguments to boolean +use_quantized = args.quantized.lower() == "true" +use_peft = args.peft.lower() == "true" +use_cot = args.cot.lower() == "true" + +# Check for unsupported combination +if not use_peft and use_quantized: + print( + "ERROR: Full Fine-Tuning (peft=false) with Quantization (quantized=true) is NOT RECOMMENDED!" + ) + print("This combination can lead to:") + print("- Gradient precision loss due to quantization") + print("- Training instability") + print("- Suboptimal convergence") + print("\nRecommended combinations:") + print( + "1. --peft=true --quantized=true (PEFT + Quantized - Most memory efficient)" + ) + print("2. --peft=true --quantized=false (PEFT + Non-quantized - Good balance)") + print( + "3. --peft=false --quantized=false (FFT + Non-quantized - Maximum performance)" + ) + sys.exit(1) + +print(f"Configuration: PEFT={use_peft}, Quantized={use_quantized}, CoT={use_cot}") + +# Import additional modules based on configuration +if use_quantized: + from transformers import BitsAndBytesConfig +if use_peft: + from peft import LoraConfig + +# Dataset configuration based on CoT parameter +if use_cot: + FT_DATASET = "train_text2sql_cot_dataset.json" + print("Using Chain-of-Thought reasoning dataset") +else: + FT_DATASET = "train_text2sql_sft_dataset.json" + print("Using standard SFT dataset") + +dataset = load_dataset("json", data_files=FT_DATASET, split="train") + +model_id = "meta-llama/Llama-3.1-8B-Instruct" + +# Configure quantization if needed +quantization_config = None +if use_quantized: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, + ) + +# Load model +model = AutoModelForCausalLM.from_pretrained( + model_id, + device_map="auto", + torch_dtype=torch.bfloat16, + quantization_config=quantization_config, +) + +# Load tokenizer +tokenizer = AutoTokenizer.from_pretrained(model_id) +tokenizer.padding_side = "right" + +if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + model.resize_token_embeddings(len(tokenizer)) + +# Configure PEFT if needed +peft_config = None +if use_peft: + peft_config = LoraConfig( + lora_alpha=128, + lora_dropout=0.05, + r=256, + bias="none", + target_modules="all-linear", + task_type="CAUSAL_LM", + ) + +# Configure training arguments based on combination using newer TRL API +cot_suffix = "cot" if use_cot else "nocot" + +if use_peft and use_quantized: + # PEFT + Quantized: Use SFTConfig (newer API) + print("Using PEFT + Quantized configuration") + args = SFTConfig( + output_dir=f"llama31-8b-text2sql-peft-quantized-{cot_suffix}", + num_train_epochs=3, + per_device_train_batch_size=3, + gradient_accumulation_steps=2, + gradient_checkpointing=True, + optim="adamw_torch_fused", + logging_steps=10, + save_strategy="epoch", + learning_rate=2e-4, + bf16=True, + tf32=True, + max_grad_norm=0.3, + warmup_ratio=0.03, + lr_scheduler_type="constant", + push_to_hub=True, + report_to="tensorboard", + max_seq_length=4096, + packing=True, + ) + +elif use_peft and not use_quantized: + # PEFT + Non-quantized: Use SFTConfig (newer API) + print("Using PEFT + Non-quantized configuration") + args = SFTConfig( + output_dir=f"llama31-8b-text2sql-peft-nonquantized-{cot_suffix}", + num_train_epochs=3, + per_device_train_batch_size=2, # Slightly reduced for non-quantized + gradient_accumulation_steps=4, # Increased to maintain effective batch size + gradient_checkpointing=True, + optim="adamw_torch_fused", + logging_steps=10, + save_strategy="epoch", + learning_rate=2e-4, + bf16=True, + tf32=True, + max_grad_norm=0.3, + warmup_ratio=0.03, + lr_scheduler_type="constant", + push_to_hub=True, + report_to="tensorboard", + max_seq_length=4096, + packing=True, + ) + +else: # not use_peft and not use_quantized + # FFT + Non-quantized: Use SFTConfig (newer API) + print("Using Full Fine-Tuning + Non-quantized configuration") + args = SFTConfig( + output_dir=f"llama31-8b-text2sql-fft-nonquantized-{cot_suffix}", + num_train_epochs=1, # Reduced epochs for full fine-tuning + per_device_train_batch_size=1, # Reduced batch size for full model training + gradient_accumulation_steps=8, # Increased to maintain effective batch size + gradient_checkpointing=True, + optim="adamw_torch_fused", + logging_steps=10, + save_strategy="epoch", + learning_rate=5e-6, # Lower learning rate for full fine-tuning + bf16=True, + tf32=True, + max_grad_norm=1.0, # Standard gradient clipping for full fine-tuning + warmup_ratio=0.1, # Warmup ratio for full fine-tuning + lr_scheduler_type="cosine", # Cosine scheduler for full fine-tuning + push_to_hub=True, + report_to="tensorboard", + dataloader_pin_memory=False, # Disable pin memory to save GPU memory + remove_unused_columns=False, # Keep all columns + max_seq_length=4096, + packing=True, + ) + +# Create trainer with consistent newer API +trainer = SFTTrainer( + model=model, + args=args, + train_dataset=dataset, + processing_class=tokenizer, + peft_config=peft_config, +) + +# Print memory requirements estimate +print("\nEstimated GPU Memory Requirements:") +if use_peft and use_quantized: + print("- PEFT + Quantized: ~12-16 GB") +elif use_peft and not use_quantized: + print("- PEFT + Non-quantized: ~20-25 GB") +else: # FFT + Non-quantized + print("- Full Fine-Tuning + Non-quantized: ~70-90 GB") + +print("\nStarting training...") +trainer.train() + +print("Training completed. Saving model...") +trainer.save_model() + +print("Model saved successfully!") diff --git a/end-to-end-use-cases/coding/text2sql/quickstart.ipynb b/end-to-end-use-cases/coding/text2sql/quickstart.ipynb deleted file mode 100644 index 39c7e75cd..000000000 --- a/end-to-end-use-cases/coding/text2sql/quickstart.ipynb +++ /dev/null @@ -1,334 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "e8cba0b6", - "metadata": {}, - "source": [ - "\"Open \n", - "\n", - "## Quick Demo of Text2SQL Using Llama 3.3\n", - "\n", - "This demo shows how to use Llama 3.3 to answer questions about a SQLite DB. \n", - "\n", - "We'll use LangChain and the Llama cloud provider [Together.ai](https://api.together.ai/), where you can easily get a free API key (or you can use any other Llama cloud provider or even Ollama running Llama locally)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33fb3190-59fb-4edd-82dd-f20f6eab3e47", - "metadata": {}, - "outputs": [], - "source": [ - "!pip install --upgrade -r requirements.txt" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "fa4562d3", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from langchain_together import ChatTogether\n", - "\n", - "os.environ['TOGETHER_API_KEY'] = 'your_api_key'\n", - "\n", - "llm = ChatTogether(\n", - " model=\"meta-llama/Llama-3.3-70B-Instruct-Turbo\",\n", - " temperature=0,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "6d421ae7", - "metadata": {}, - "source": [ - "To recreate the `nba_roster.db` file, run the two commands below:\n", - "- `python txt2csv.py` to convert the `nba.txt` file to `nba_roster.csv`. The `nba.txt` file was created by scraping the NBA roster info from the web.\n", - "- `python csv2db.py` to convert `nba_roster.csv` to `nba_roster.db`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "56f0360e-fca3-49a8-9a70-0416f84e15fc", - "metadata": {}, - "outputs": [], - "source": [ - "# uncomment if you don't want to create the db yourself\n", - "#! wget https://github.com/meta-llama/llama-recipes/raw/3649841b426999fdc61c30a9fc8721106bec769b/recipes/use_cases/coding/text2sql/nba_roster.db" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "3bb99f39-cd7a-4db6-91dd-02f3bf80347c", - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_community.utilities import SQLDatabase\n", - "\n", - "# Note: to run in Colab, you need to upload the nba_roster.db file in the repo to the Colab folder first.\n", - "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info=0)\n", - "\n", - "def get_schema():\n", - " return db.get_table_info()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "8d793ce7-324b-4861-926c-54973d7c9b43", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Based on the table schema below, write a SQL query that would answer the user's question; just return the SQL query and nothing else.\n", - "\n", - "Scheme:\n", - "\n", - "CREATE TABLE nba_roster (\n", - "\t\"Team\" TEXT, \n", - "\t\"NAME\" TEXT, \n", - "\t\"Jersey\" TEXT, \n", - "\t\"POS\" TEXT, \n", - "\t\"AGE\" INTEGER, \n", - "\t\"HT\" TEXT, \n", - "\t\"WT\" TEXT, \n", - "\t\"COLLEGE\" TEXT, \n", - "\t\"SALARY\" TEXT\n", - ")\n", - "\n", - "Question: What team is Stephen Curry on?\n", - "\n", - "SQL Query:\n" - ] - } - ], - "source": [ - "question = \"What team is Stephen Curry on?\"\n", - "prompt = f\"\"\"Based on the table schema below, write a SQL query that would answer the user's question; just return the SQL query and nothing else.\n", - "\n", - "Scheme:\n", - "{get_schema()}\n", - "\n", - "Question: {question}\n", - "\n", - "SQL Query:\"\"\"\n", - "\n", - "print(prompt)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "70776558", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SELECT Team FROM nba_roster WHERE NAME = 'Stephen Curry'\n" - ] - } - ], - "source": [ - "answer = llm.invoke(prompt).content\n", - "print(answer)" - ] - }, - { - "cell_type": "markdown", - "id": "afcf423a", - "metadata": {}, - "source": [ - "***Note:*** If you don't have the \"just return the SQL query and nothing else\" in the prompt above, you'll likely get more text other than the SQL query back in the answer, making some extra post-processing necessary before running the db query below." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "62472ce6-794b-4a61-b88c-a1e031e28e4e", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\"[('Golden State Warriors',)]\"" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# note this is a dangerous operation and for demo purpose only; in production app you'll need to safe-guard any DB operation\n", - "result = db.run(answer)\n", - "result" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "39ed4bc3", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "I don't have enough information to determine whose salary you are referring to. Could you please provide more context or specify the person you are asking about?\n" - ] - } - ], - "source": [ - "# how about a follow up question\n", - "follow_up = \"What's his salary?\"\n", - "print(llm.invoke(follow_up).content)" - ] - }, - { - "cell_type": "markdown", - "id": "98b2c523", - "metadata": {}, - "source": [ - "Since we did not pass any context along with the follow-up to Llama, it doesn't know the answer. Let's try to fix it by adding context to the follow-up prompt." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "0c305278-29d2-4e88-9b3d-ad67c94ce0f2", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Based on the table schema, question, SQL query, and SQL response below, write a new SQL response; be concise, just output the SQL response.\n", - "\n", - "Scheme:\n", - "\n", - "CREATE TABLE nba_roster (\n", - "\t\"Team\" TEXT, \n", - "\t\"NAME\" TEXT, \n", - "\t\"Jersey\" TEXT, \n", - "\t\"POS\" TEXT, \n", - "\t\"AGE\" INTEGER, \n", - "\t\"HT\" TEXT, \n", - "\t\"WT\" TEXT, \n", - "\t\"COLLEGE\" TEXT, \n", - "\t\"SALARY\" TEXT\n", - ")\n", - "\n", - "Question: What's his salary?\n", - "SQL Query: What team is Stephen Curry on?\n", - "SQL Result: [('Golden State Warriors',)]\n", - "\n", - "New SQL Response:\n", - "\n" - ] - } - ], - "source": [ - "prompt = f\"\"\"Based on the table schema, question, SQL query, and SQL response below, write a new SQL response; be concise, just output the SQL response.\n", - "\n", - "Scheme:\n", - "{get_schema()}\n", - "\n", - "Question: {follow_up}\n", - "SQL Query: {question}\n", - "SQL Result: {result}\n", - "\n", - "New SQL Response:\n", - "\"\"\"\n", - "print(prompt)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "03739b96-e607-4fa9-bc5c-df118198dc7f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "SELECT SALARY FROM nba_roster WHERE NAME = \"Stephen Curry\"\n" - ] - } - ], - "source": [ - "new_answer = llm.invoke(prompt).content\n", - "print(new_answer)" - ] - }, - { - "cell_type": "markdown", - "id": "c782abb6-3b44-45be-8694-70fc29b82523", - "metadata": {}, - "source": [ - "Because we have \"be concise, just output the SQL response\", Llama 3 is able to just generate the SQL statement; otherwise output parsing will be needed." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "6ecfca53-be7e-4668-bad1-5ca7571817d7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\"[('$51,915,615',)]\"" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "db.run(new_answer)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9d79bbb1-e91d-4b56-b6ef-98c94ff414d0", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/end-to-end-use-cases/coding/text2sql/quickstart/README.md b/end-to-end-use-cases/coding/text2sql/quickstart/README.md new file mode 100644 index 000000000..feb7db10a --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/quickstart/README.md @@ -0,0 +1,13 @@ +# Quickstart with Text2SQL + +The scripts and notebook in this folder let you get familiar with how to interact with a database using natural language inputs by asking Llama to convert natural language queries into SQL queries. + +For detailed instructions on setting up the environment, creating a database, and executing natural language queries using the Text2SQL interface, please refer to the [quickstart.ipynb](quickstart.ipynb) notebook. + +## Structure: + +- quickstart.ipynb: A Quick Demo of Text2SQL Using Llama 3.3. This Jupyter Notebook includes examples of how to use the interface to execute natural language queries on the sample data. It uses Llama 3.3 to answer questions about a SQLite database using LangChain and the Llama cloud provider Together.ai. +- nba.txt: A text file containing NBA roster information, which is used as sample data for demonstration purposes. +- txt2csv.py: A script that converts text data into a CSV format. This script is used to preprocess the input data before it is fed into csv2db.py. +- csv2db.py: A script that imports data from a CSV file into a SQLite database. This script is used to populate the database with sample data. +- nba_roster.db: A SQLite database file created from the nba.txt data, used to test the Text2SQL interface. diff --git a/end-to-end-use-cases/coding/text2sql/csv2db.py b/end-to-end-use-cases/coding/text2sql/quickstart/csv2db.py similarity index 100% rename from end-to-end-use-cases/coding/text2sql/csv2db.py rename to end-to-end-use-cases/coding/text2sql/quickstart/csv2db.py diff --git a/end-to-end-use-cases/coding/text2sql/nba.txt b/end-to-end-use-cases/coding/text2sql/quickstart/nba.txt similarity index 100% rename from end-to-end-use-cases/coding/text2sql/nba.txt rename to end-to-end-use-cases/coding/text2sql/quickstart/nba.txt diff --git a/end-to-end-use-cases/coding/text2sql/nba_roster.db b/end-to-end-use-cases/coding/text2sql/quickstart/nba_roster.db similarity index 100% rename from end-to-end-use-cases/coding/text2sql/nba_roster.db rename to end-to-end-use-cases/coding/text2sql/quickstart/nba_roster.db diff --git a/end-to-end-use-cases/coding/text2sql/quickstart/quickstart.ipynb b/end-to-end-use-cases/coding/text2sql/quickstart/quickstart.ipynb new file mode 100644 index 000000000..e9a0a1c97 --- /dev/null +++ b/end-to-end-use-cases/coding/text2sql/quickstart/quickstart.ipynb @@ -0,0 +1,226 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e8cba0b6", + "metadata": {}, + "source": [ + "## Quick Demo of Text2SQL Using Llama 3.3\n", + "\n", + "This demo shows how to use Llama 3.3 to answer questions about a SQLite DB. \n", + "\n", + "We'll use LangChain and the Llama cloud provider [Together.ai](https://api.together.ai/), where you can easily get a free API key (or you can use any other Llama cloud provider or even Ollama running Llama locally)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "33fb3190-59fb-4edd-82dd-f20f6eab3e47", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install --upgrade -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa4562d3", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from langchain_together import ChatTogether\n", + "\n", + "os.environ['TOGETHER_API_KEY'] = 'your_api_key'\n", + "\n", + "llm = ChatTogether(\n", + " model=\"meta-llama/Llama-3.3-70B-Instruct-Turbo\",\n", + " temperature=0,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6d421ae7", + "metadata": {}, + "source": [ + "To recreate the `nba_roster.db` file, run the two commands below:\n", + "- `python txt2csv.py` to convert the `nba.txt` file to `nba_roster.csv`. The `nba.txt` file was created by scraping the NBA roster info from the web.\n", + "- `python csv2db.py` to convert `nba_roster.csv` to `nba_roster.db`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56f0360e-fca3-49a8-9a70-0416f84e15fc", + "metadata": {}, + "outputs": [], + "source": [ + "# uncomment if you don't want to create the db yourself\n", + "#! wget https://github.com/meta-llama/llama-recipes/raw/3649841b426999fdc61c30a9fc8721106bec769b/recipes/use_cases/coding/text2sql/nba_roster.db" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bb99f39-cd7a-4db6-91dd-02f3bf80347c", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.utilities import SQLDatabase\n", + "\n", + "# Note: to run in Colab, you need to upload the nba_roster.db file in the repo to the Colab folder first.\n", + "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info=0)\n", + "\n", + "def get_schema():\n", + " return db.get_table_info()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d793ce7-324b-4861-926c-54973d7c9b43", + "metadata": {}, + "outputs": [], + "source": [ + "question = \"What team is Stephen Curry on?\"\n", + "prompt = f\"\"\"Based on the table schema below, write a SQL query that would answer the user's question; just return the SQL query and nothing else.\n", + "\n", + "Scheme:\n", + "{get_schema()}\n", + "\n", + "Question: {question}\n", + "\n", + "SQL Query:\"\"\"\n", + "\n", + "print(prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70776558", + "metadata": {}, + "outputs": [], + "source": [ + "answer = llm.invoke(prompt).content\n", + "print(answer)" + ] + }, + { + "cell_type": "markdown", + "id": "afcf423a", + "metadata": {}, + "source": [ + "***Note:*** If you don't have the \"just return the SQL query and nothing else\" in the prompt above, you'll likely get more text other than the SQL query back in the answer, making some extra post-processing necessary before running the db query below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62472ce6-794b-4a61-b88c-a1e031e28e4e", + "metadata": {}, + "outputs": [], + "source": [ + "# note this is a dangerous operation and for demo purpose only; in production app you'll need to safe-guard any DB operation\n", + "result = db.run(answer)\n", + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39ed4bc3", + "metadata": {}, + "outputs": [], + "source": [ + "# how about a follow up question\n", + "follow_up = \"What's his salary?\"\n", + "print(llm.invoke(follow_up).content)" + ] + }, + { + "cell_type": "markdown", + "id": "98b2c523", + "metadata": {}, + "source": [ + "Since we did not pass any context along with the follow-up to Llama, it doesn't know the answer. Let's try to fix it by adding context to the follow-up prompt." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c305278-29d2-4e88-9b3d-ad67c94ce0f2", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = f\"\"\"Based on the table schema, question, SQL query, and SQL response below, write a new SQL response; be concise, just output the SQL response.\n", + "\n", + "Scheme:\n", + "{get_schema()}\n", + "\n", + "Question: {follow_up}\n", + "SQL Query: {question}\n", + "SQL Result: {result}\n", + "\n", + "New SQL Response:\n", + "\"\"\"\n", + "print(prompt)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03739b96-e607-4fa9-bc5c-df118198dc7f", + "metadata": {}, + "outputs": [], + "source": [ + "new_answer = llm.invoke(prompt).content\n", + "print(new_answer)" + ] + }, + { + "cell_type": "markdown", + "id": "c782abb6-3b44-45be-8694-70fc29b82523", + "metadata": {}, + "source": [ + "Because we have \"be concise, just output the SQL response\", Llama 3 is able to just generate the SQL statement; otherwise output parsing will be needed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ecfca53-be7e-4668-bad1-5ca7571817d7", + "metadata": {}, + "outputs": [], + "source": [ + "db.run(new_answer)" + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "559664fd-826e-43fa-86b4-78edb9bbf80e", + "isAdHoc": false, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/end-to-end-use-cases/coding/text2sql/requirements.txt b/end-to-end-use-cases/coding/text2sql/quickstart/requirements.txt similarity index 100% rename from end-to-end-use-cases/coding/text2sql/requirements.txt rename to end-to-end-use-cases/coding/text2sql/quickstart/requirements.txt diff --git a/end-to-end-use-cases/coding/text2sql/txt2csv.py b/end-to-end-use-cases/coding/text2sql/quickstart/txt2csv.py similarity index 100% rename from end-to-end-use-cases/coding/text2sql/txt2csv.py rename to end-to-end-use-cases/coding/text2sql/quickstart/txt2csv.py