diff --git a/.github/README.md b/.github/README.md index 35833aa..54f1a20 100644 --- a/.github/README.md +++ b/.github/README.md @@ -166,6 +166,12 @@ We include a series of notebook+scripts for fine tuning the models. * [Vision fine tuning Gemma 3 4B with Unsloth](/notebooks/Gemma3_(4B)-Vision.ipynb) Open In Colab * [Conversational fine tuning Gemma 3 4B with Unsloth](/notebooks/Gemma3_(4B).ipynb) Open In Colab +## RAG + +### Gemma 3n +* [Retrieval-Augmented Generation with Gemma 3n](/notebooks/Gemma_RAG.ipynb) + + Before fine-tuning the model, ensure all dependencies are installed: ```bash diff --git a/notebooks/Gemma_RAG.ipynb b/notebooks/Gemma_RAG.ipynb new file mode 100644 index 0000000..5cf3619 --- /dev/null +++ b/notebooks/Gemma_RAG.ipynb @@ -0,0 +1,487 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "884c076b-6f0b-4c42-b965-65dc046d29c1", + "metadata": {}, + "source": [ + "# 🧠 Gemma_RAG: Lightweight Retrieval-Augmented Generation with Gemma\n", + "\n", + "A minimal example of using Retrieval-Augmented Generation (RAG) with Gemma models, integrated with `sentence-transformers`, `FAISS`, and `Streamlit`. This notebook is runnable on **free Colab instances**." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6f346d53-49db-44ea-9920-ad8ad16e0267", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -Uq sentence-transformers transformers accelerate faiss-cpu timm" + ] + }, + { + "cell_type": "markdown", + "id": "9076a2b3-30ca-47a7-a242-8c8e82e08616", + "metadata": {}, + "source": [ + "### 📦 Importing Required Libraries\n", + "\n", + "This cell imports all the libraries needed for the project:\n", + "- `os` for accessing environment variables like HF tokens\n", + "- `torch` for deep learning with GPU support\n", + "- `transformers` for loading the Gemma language model\n", + "- `sentence-transformers` for creating semantic embeddings\n", + "- `faiss` for fast similarity search\n", + "- `numpy` for array manipulation and type casting" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "bab99e08-1edc-4bc3-9386-cb54a00b2342", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import torch\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n", + "from sentence_transformers import SentenceTransformer\n", + "import faiss\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "id": "009ce4ae-f769-4b67-b6b3-a6dccbaf5868", + "metadata": {}, + "source": [ + "### 🧠 Model Setup and Token Loading\n", + "\n", + "This cell loads the **authentication token**, sets the **model ID**, and initializes both the **tokenizer** and the **language model** (`Gemma-3n-E4B-it`) from the Hugging Face Hub. These steps are essential to prepare the model for inference (i.e., generating text).\n", + "\n", + "- `token = os.environ.get(\"HF_TOKEN\")` \n", + " Retrieves your Hugging Face token from environment variables. This is used to authenticate access to gated models (like Gemma-3n) securely. By storing the token in the environment, you avoid hardcoding sensitive info in your notebook.\n", + "\n", + "- `model_id = \"google/gemma-3n-E4B-it\"` \n", + " Specifies the exact model you want to use from the Hugging Face Model Hub. In this case, you're using **Gemma-3n-E4B-it**, a 3-billion-parameter instruction-tuned language model developed by Google. 
This string acts as a reference for downloading both the tokenizer and model weights.\n", + "\n", + "- `tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)` \n", + " Loads the tokenizer that matches the specified Gemma model. The tokenizer transforms raw input text (e.g., `\"What happened?\"`) into token IDs that the model understands. Using `AutoTokenizer` ensures the right tokenizer is chosen automatically based on the model’s config file. The `token=token` part ensures access to the tokenizer files from a private/gated model if necessary.\n", + "\n", + "- `gemma_model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=torch.bfloat16, device_map={\"\": 0})` \n", + " Loads the **Gemma-3n language model weights** for causal language modeling (i.e., left-to-right generation). \n", + " - `token=token`: Ensures authenticated access. \n", + " - `torch_dtype=torch.bfloat16`: Loads the model using Brain Float 16 precision, which is memory-efficient and optimized for newer GPUs like the A100. \n", + " - `device_map={\"\": 0}`: Places the full model on GPU 0 (i.e., `cuda:0`), preventing the runtime error you’d get if tensors are split across `cuda:0` and `cuda:1`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "306ecee2-e6fe-4d08-8eb7-8c913d6a0297", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00, 1.39s/it]\n" + ] + } + ], + "source": [ + "token = os.environ.get(\"HF_TOKEN\")\n", + "model_id = \"google/gemma-3n-E4B-it\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)\n", + "gemma_model = AutoModelForCausalLM.from_pretrained(model_id,token=token,torch_dtype=torch.bfloat16,device_map={\"\":0})" + ] + }, + { + "cell_type": "markdown", + "id": "64b42da5-dd57-47c2-8f9d-6418c5a18cc6", + "metadata": {}, + "source": [ + "### 🔁 Creating the Text Generation Pipeline\n", + "\n", + "This cell creates a **text generation pipeline** using Hugging Face’s `pipeline()` utility. The pipeline wraps the model and tokenizer together and handles the full process of generating natural language output from a prompt.\n", + "\n", + "- `generator = pipeline(\"text-generation\", ...)` \n", + " Initializes a high-level text generation pipeline for causal language models. This abstraction lets you input raw text and get full model-generated outputs without manually handling tokenization or decoding.\n", + "\n", + "- `model=gemma_model` \n", + " Sets the pretrained Gemma model as the core component that will perform text generation.\n", + "\n", + "- `tokenizer=tokenizer` \n", + " Supplies the tokenizer needed to convert input strings into token IDs that the model can understand.\n", + "\n", + "- `device_map=0` \n", + " Assigns the model and data to GPU 0 (`cuda:0`). This is important to avoid device mismatch errors when using multiple GPUs.\n", + "\n", + "- `torch_dtype=torch.bfloat16` \n", + " Sets the numerical precision for model weights and activations to bfloat16, which is memory-efficient and optimized for modern GPUs like the A100.\n", + "\n", + "- `max_new_tokens=256` \n", + " Limits how many tokens the model can generate in response to a prompt. 
A larger value allows for longer, more detailed outputs.\n", + "\n", + "> 💡 This pipeline simplifies the generation process so you can just call `generator(prompt)` and receive a coherent answer in return.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2c228924-8cff-48c7-8594-e623b2ac6f77", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Device set to use cuda:0\n" + ] + } + ], + "source": [ + "generator = pipeline(\n", + " \"text-generation\",\n", + " model=gemma_model,\n", + " tokenizer=tokenizer,\n", + " device_map=0,\n", + " torch_dtype=torch.bfloat16,\n", + " max_new_tokens=256, # Increased max tokens for more detailed responses\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "120262a3-aa21-4208-9339-e71d6cadc081", + "metadata": {}, + "source": [ + "### 📑 Step 2: Text Snippet Retrieval Setup\n", + "\n", + "In this cell, we define a list of short narrative passages or **context snippets** that describe key events, locations, and interactions between characters (Ethan and Fiona). These text entries will later serve as the **knowledge base** for answering questions using semantic search.\n", + "\n", + "- `text_snippets = [...]` \n", + " This is a Python list that contains multiple text strings. Each string represents a small piece of a story or description.\n", + "\n", + "These snippets will be:\n", + "- Embedded using a sentence transformer model.\n", + "- Indexed using FAISS for fast similarity search.\n", + "- Used as context when answering user questions via a large language model.\n", + "\n", + "> 📚 This is a crucial part of the RAG (Retrieval-Augmented Generation) setup, where relevant knowledge is retrieved from this list and passed as input to the language model for grounded, context-aware answers.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "1e16e5ce-a42a-4cc1-a990-a47d8947bcb6", + "metadata": {}, + "outputs": [], + "source": [ + "# 2. Text Snippet Retrieval Setup\n", + "text_snippets = [\n", + " \"Fiona thanked Ethan for his unwavering support and promised to cherish their friendship.\",\n", + " \"As they ventured deeper into the forest, they encountered a wide array of obstacles.\",\n", + " \"Ethan and Fiona crossed treacherous ravines using rickety bridges, relying on each other's strength.\",\n", + " \"Overwhelmed with joy, Fiona thanked Ethan and disappeared into the embrace of her family.\",\n", + " \"Ethan returned to his cottage, heart full of memories and a smile brighter than ever before.\",\n", + " \"The forest was dark and mysterious, filled with ancient trees and hidden paths.\",\n", + " \"Ethan always carried a map and compass, ensuring they never lost their way.\",\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "084753ff-fea1-4c2f-8ac1-a1f32d9fd134", + "metadata": {}, + "source": [ + "### 🔍 Step 3: Enhanced Retrieval Mechanism — Semantic Search with FAISS\n", + "\n", + "This section sets up the **semantic embedding** and **vector search index** needed to perform efficient and meaningful retrieval of relevant text snippets based on a user's query.\n", + "\n", + "---\n", + "\n", + "- `embedding_model = SentenceTransformer(\"all-MiniLM-L6-v2\")` \n", + " Loads a lightweight, high-performance sentence embedding model from the Sentence Transformers library. 
This model converts sentences into dense numerical vectors (embeddings) that capture their semantic meaning.\n", + "\n", + "- `embeddings_text_snippets = embedding_model.encode(text_snippets)` \n", + " Generates vector embeddings for each of the predefined text snippets. These embeddings will later be compared to the query embedding to find the most relevant snippet.\n", + "\n", + "---\n", + "\n", + "### ⚙️ FAISS Index Creation\n", + "\n", + "- `dimension = embeddings_text_snippets.shape[1]` \n", + " Extracts the dimensionality of each embedding vector (e.g., 384), which is required to initialize the FAISS index correctly.\n", + "\n", + "- `index = faiss.IndexFlatL2(dimension)` \n", + " Initializes a **FAISS index** that uses L2 distance (Euclidean distance) to compare vectors. This allows for fast and efficient similarity search between embeddings.\n", + "\n", + "- `index.add(embeddings_text_snippets.astype(np.float32))` \n", + " Adds all the text snippet embeddings to the FAISS index after converting them to `float32`, which is the required input format for FAISS.\n", + "\n", + "> ⚡ This enables real-time semantic search, where a user’s question can be matched to the most semantically similar snippet — even if they use different words.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "3397f9bc-30d3-4bc3-88ed-dd18e2e918eb", + "metadata": {}, + "outputs": [], + "source": [ + "# 3. Enhanced Retrieval Mechanism: Semantic Search with FAISS\n", + "embedding_model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n", + "embeddings_text_snippets = embedding_model.encode(text_snippets)\n", + "\n", + "# FAISS Index Creation\n", + "dimension = embeddings_text_snippets.shape[1] # Embedding dimension\n", + "index = faiss.IndexFlatL2(dimension) # L2 distance (Euclidean)\n", + "index.add(embeddings_text_snippets.astype(np.float32)) # FAISS requires float32" + ] + }, + { + "cell_type": "markdown", + "id": "a9f852a5-94ca-449e-9a8c-e04516f6ce08", + "metadata": {}, + "source": [ + "### 🧠 Step 4: Retrieval Function (Semantic Search)\n", + "\n", + "This function takes a user query and returns the **most semantically similar snippet** from the previously indexed text corpus using **FAISS-based nearest neighbor search**.\n", + "\n", + "---\n", + "\n", + "- `def retrieve_snippet(query, k=1):` \n", + " Defines a Python function that accepts a query string and retrieves `k` most similar snippets. By default, `k=1`, meaning it returns only the top match.\n", + "\n", + "- `query_embedded = embedding_model.encode([query]).astype(np.float32)` \n", + " Converts the query string into an embedding vector using the same sentence embedding model used for the snippets. FAISS requires all vectors to be in `float32`, so the type is cast accordingly.\n", + "\n", + "- `D, I = index.search(query_embedded, k)` \n", + " Searches the FAISS index to find the `k` most similar embeddings to the query. \n", + " - `D`: distances (lower = more similar) \n", + " - `I`: indices of the most similar snippets in the original list\n", + "\n", + "- `retrieved_indices = I[0]` \n", + " Extracts the list of top-k indices from the FAISS result. Since only one query is being processed, we access the first (and only) row of `I`.\n", + "\n", + "- `retrieved_texts = [text_snippets[i] for i in retrieved_indices]` \n", + " Uses the retrieved indices to extract the corresponding text snippets from the original list.\n", + "\n", + "- `return retrieved_texts[0]` \n", + " Returns only the **most relevant snippet**. 
This snippet will later be used as context for the language model during text generation.\n", + "\n", + "> 💡 This function powers the semantic retrieval part of RAG — ensuring the model responds using real context instead of hallucinating answers.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7637263e-6d20-450a-b458-e9e2e66a608b", + "metadata": {}, + "outputs": [], + "source": [ + "# 4. Retrieval Function (Semantic Search)\n", + "def retrieve_snippet(query, k=1): # k is the number of snippets to retrieve\n", + " query_embedded = embedding_model.encode([query]).astype(np.float32)\n", + " D, I = index.search(query_embedded, k) # D: distances, I: indices\n", + " retrieved_indices = I[0]\n", + " retrieved_texts = [text_snippets[i] for i in retrieved_indices]\n", + " return retrieved_texts[0] # Return only the top snippet\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "99147cf9-fff9-4379-b7aa-6888706d9e7b", + "metadata": {}, + "outputs": [], + "source": [ + "# 5. Create a function to generate the answer based on the retrieved snippet and query\n", + "def ask_query(query):\n", + " retrieved_text = retrieve_snippet(query)\n", + "\n", + " # Step 1: Construct chat messages as a list of roles/content\n", + " chat = [\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a helpful AI assistant. Answer the question based on the context provided.\",\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"\"\"Context:\n", + "{retrieved_text}\n", + "\n", + "Question: {query}\"\"\",\n", + " },\n", + " ]\n", + "\n", + " # Step 2: Use tokenizer's chat template to format this\n", + " prompt_ids = tokenizer.apply_chat_template(\n", + " chat,\n", + " tokenize=True,\n", + " add_generation_prompt=True, # add assistant tag to begin model generation\n", + " return_tensors=\"pt\"\n", + " ).to(gemma_model.device)\n", + "\n", + " # Step 3: Generate using the raw model\n", + " output = gemma_model.generate(prompt_ids, max_new_tokens=128)\n", + " response = tokenizer.decode(output[0], skip_special_tokens=True)\n", + "\n", + " print(f\"Query: {query}\")\n", + " print(f\"Context: {retrieved_text}\")\n", + " print(f\"Answer: {response}\")\n", + " print(\"-\" * 40)\n" + ] + }, + { + "cell_type": "markdown", + "id": "d4fdcda7-4f16-46c3-a22e-53d766625ea2", + "metadata": {}, + "source": [ + "### 🗣️ Step 6: Ask Questions\n", + "\n", + "This block runs a series of **user-defined natural language queries** through the full Retrieval-Augmented Generation (RAG) pipeline, using the `ask_query()` function. For each question, the pipeline:\n", + "\n", + "1. **Finds the most semantically similar snippet** using FAISS-based search.\n", + "2. **Constructs a prompt** that includes the retrieved snippet as context.\n", + "3. **Generates an answer** using the Gemma language model.\n", + "\n", + "---\n", + "\n", + "- `query1 = \"Why did Fiona thank Ethan?\"` \n", + " A straightforward question to test if the model can connect Fiona’s gratitude to Ethan’s support. \n", + " → Passed to `ask_query(query1)` to fetch the answer.\n", + "\n", + "- `query2 = \"What challenges did Ethan and Fiona face in the forest?\"` \n", + " A more complex question that probes the model’s understanding of events and obstacles. \n", + " → Answer will depend on the forest-related snippets.\n", + "\n", + "- `query3 = \"What tools did Ethan use to navigate?\"` \n", + " A factual retrieval question. 
The model should extract and summarize tools like a map or compass.\n", + "\n", + "- `query4 = \"Describe the forest.\"` \n", + " An open-ended descriptive query that should trigger a more vivid narrative response based on stored context.\n", + "\n", + "> 🧠 These queries showcase how the system can handle **factual, contextual, and descriptive questions** using real context — avoiding hallucinated answers.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "42395c04-f4f3-4eac-a8c2-b2b086016a73", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Query: Why did Fiona thank Ethan?\n", + "Context: Fiona thanked Ethan for his unwavering support and promised to cherish their friendship.\n", + "Answer: user\n", + "You are a helpful AI assistant. Answer the question based on the context provided.\n", + "\n", + "Context:\n", + "Fiona thanked Ethan for his unwavering support and promised to cherish their friendship.\n", + "\n", + "Question: Why did Fiona thank Ethan?\n", + "model\n", + "Fiona thanked Ethan for his unwavering support. \n", + "\n", + "----------------------------------------\n", + "Query: What challenges did Ethan and Fiona face in the forest?\n", + "Context: Ethan and Fiona crossed treacherous ravines using rickety bridges, relying on each other's strength.\n", + "Answer: user\n", + "You are a helpful AI assistant. Answer the question based on the context provided.\n", + "\n", + "Context:\n", + "Ethan and Fiona crossed treacherous ravines using rickety bridges, relying on each other's strength.\n", + "\n", + "Question: What challenges did Ethan and Fiona face in the forest?\n", + "model\n", + "Based on the context, Ethan and Fiona faced the challenge of crossing treacherous ravines using rickety bridges. This implies a physically dangerous obstacle and a need for careful coordination and reliance on each other. \n", + "\n", + "So the answer is: **They faced the challenge of crossing treacherous ravines using rickety bridges.**\n", + "\n", + "\n", + "\n", + "\n", + "----------------------------------------\n", + "Query: What tools did Ethan use to navigate?\n", + "Context: Ethan always carried a map and compass, ensuring they never lost their way.\n", + "Answer: user\n", + "You are a helpful AI assistant. Answer the question based on the context provided.\n", + "\n", + "Context:\n", + "Ethan always carried a map and compass, ensuring they never lost their way.\n", + "\n", + "Question: What tools did Ethan use to navigate?\n", + "model\n", + "Ethan used a map and compass to navigate. \n", + "\n", + "----------------------------------------\n", + "Query: Describe the forest.\n", + "Context: The forest was dark and mysterious, filled with ancient trees and hidden paths.\n", + "Answer: user\n", + "You are a helpful AI assistant. Answer the question based on the context provided.\n", + "\n", + "Context:\n", + "The forest was dark and mysterious, filled with ancient trees and hidden paths.\n", + "\n", + "Question: Describe the forest.\n", + "model\n", + "The forest is dark and mysterious, filled with ancient trees and hidden paths. \n", + "\n", + "----------------------------------------\n" + ] + } + ], + "source": [ + "# 6. 
Ask Questions\n", + "query1 = \"Why did Fiona thank Ethan?\"\n", + "ask_query(query1)\n", + "\n", + "query2 = \"What challenges did Ethan and Fiona face in the forest?\"\n", + "ask_query(query2)\n", + "\n", + "query3 = \"What tools did Ethan use to navigate?\"\n", + "ask_query(query3)\n", + "\n", + "query4 = \"Describe the forest.\"\n", + "ask_query(query4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3ee915c8-a7c4-46cc-a49c-edc4de517a5e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}