diff --git a/.github/README.md b/.github/README.md
index 35833aa..54f1a20 100644
--- a/.github/README.md
+++ b/.github/README.md
@@ -166,6 +166,12 @@ We include a series of notebook+scripts for fine tuning the models.
* [Vision fine tuning Gemma 3 4B with Unsloth](/notebooks/Gemma3_(4B)-Vision.ipynb)
* [Conversational fine tuning Gemma 3 4B with Unsloth](/notebooks/Gemma3_(4B).ipynb)
+
+## RAG
+
+### Gemma 3n
+* [Retrieval-Augmented Generation with Gemma 3n](/notebooks/Gemma_RAG.ipynb)
+
Before fine-tuning the model, ensure all dependencies are installed:
```bash
diff --git a/notebooks/Gemma_RAG.ipynb b/notebooks/Gemma_RAG.ipynb
new file mode 100644
index 0000000..5cf3619
--- /dev/null
+++ b/notebooks/Gemma_RAG.ipynb
@@ -0,0 +1,511 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "884c076b-6f0b-4c42-b965-65dc046d29c1",
+ "metadata": {},
+ "source": [
+ "# 🧠 Gemma_RAG: Lightweight Retrieval-Augmented Generation with Gemma\n",
+ "\n",
+ "A minimal example of using Retrieval-Augmented Generation (RAG) with Gemma models, integrated with `sentence-transformers`, `FAISS`, and `Streamlit`. This notebook is runnable on **free Colab instances**."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "6f346d53-49db-44ea-9920-ad8ad16e0267",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -Uq sentence-transformers transformers accelerate faiss-cpu timm"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9076a2b3-30ca-47a7-a242-8c8e82e08616",
+ "metadata": {},
+ "source": [
+ "### 📦 Importing Required Libraries\n",
+ "\n",
+ "This cell imports all the libraries needed for the project:\n",
+ "- `os` for accessing environment variables like HF tokens\n",
+ "- `torch` for deep learning with GPU support\n",
+ "- `transformers` for loading the Gemma language model\n",
+ "- `sentence-transformers` for creating semantic embeddings\n",
+ "- `faiss` for fast similarity search\n",
+ "- `numpy` for array manipulation and type casting"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "bab99e08-1edc-4bc3-9386-cb54a00b2342",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import torch\n",
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline\n",
+ "from sentence_transformers import SentenceTransformer\n",
+ "import faiss\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "009ce4ae-f769-4b67-b6b3-a6dccbaf5868",
+ "metadata": {},
+ "source": [
+ "### 🧠 Model Setup and Token Loading\n",
+ "\n",
+ "This cell loads the **authentication token**, sets the **model ID**, and initializes both the **tokenizer** and the **language model** (`Gemma-3n-E4B-it`) from the Hugging Face Hub. These steps are essential to prepare the model for inference (i.e., generating text).\n",
+ "\n",
+ "- `token = os.environ.get(\"HF_TOKEN\")` \n",
+ " Retrieves your Hugging Face token from environment variables. This is used to authenticate access to gated models (like Gemma-3n) securely. By storing the token in the environment, you avoid hardcoding sensitive info in your notebook.\n",
+ "\n",
+ "- `model_id = \"google/gemma-3n-E4B-it\"` \n",
+ " Specifies the exact model you want to use from the Hugging Face Model Hub. In this case, you're using **Gemma-3n-E4B-it**, a 3-billion-parameter instruction-tuned language model developed by Google. This string acts as a reference for downloading both the tokenizer and model weights.\n",
+ "\n",
+ "- `tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)` \n",
+ " Loads the tokenizer that matches the specified Gemma model. The tokenizer transforms raw input text (e.g., `\"What happened?\"`) into token IDs that the model understands. Using `AutoTokenizer` ensures the right tokenizer is chosen automatically based on the model’s config file. The `token=token` part ensures access to the tokenizer files from a private/gated model if necessary.\n",
+ "\n",
+ "- `gemma_model = AutoModelForCausalLM.from_pretrained(model_id, token=token, torch_dtype=torch.bfloat16, device_map={\"\": 0})` \n",
+ " Loads the **Gemma-3n language model weights** for causal language modeling (i.e., left-to-right generation). \n",
+ " - `token=token`: Ensures authenticated access. \n",
+ " - `torch_dtype=torch.bfloat16`: Loads the model using Brain Float 16 precision, which is memory-efficient and optimized for newer GPUs like the A100. \n",
+ " - `device_map={\"\": 0}`: Places the full model on GPU 0 (i.e., `cuda:0`), preventing the runtime error you’d get if tensors are split across `cuda:0` and `cuda:1`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "306ecee2-e6fe-4d08-8eb7-8c913d6a0297",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00, 1.39s/it]\n"
+ ]
+ }
+ ],
+ "source": [
+ "token = os.environ.get(\"HF_TOKEN\")\n",
+ "model_id = \"google/gemma-3n-E4B-it\"\n",
+ "tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)\n",
+ "gemma_model = AutoModelForCausalLM.from_pretrained(model_id,token=token,torch_dtype=torch.bfloat16,device_map={\"\":0})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "64b42da5-dd57-47c2-8f9d-6418c5a18cc6",
+ "metadata": {},
+ "source": [
+ "### 🔁 Creating the Text Generation Pipeline\n",
+ "\n",
+ "This cell creates a **text generation pipeline** using Hugging Face’s `pipeline()` utility. The pipeline wraps the model and tokenizer together and handles the full process of generating natural language output from a prompt.\n",
+ "\n",
+ "- `generator = pipeline(\"text-generation\", ...)` \n",
+ " Initializes a high-level text generation pipeline for causal language models. This abstraction lets you input raw text and get full model-generated outputs without manually handling tokenization or decoding.\n",
+ "\n",
+ "- `model=gemma_model` \n",
+ " Sets the pretrained Gemma model as the core component that will perform text generation.\n",
+ "\n",
+ "- `tokenizer=tokenizer` \n",
+ " Supplies the tokenizer needed to convert input strings into token IDs that the model can understand.\n",
+ "\n",
+ "- `device_map=0` \n",
+ " Assigns the model and data to GPU 0 (`cuda:0`). This is important to avoid device mismatch errors when using multiple GPUs.\n",
+ "\n",
+ "- `torch_dtype=torch.bfloat16` \n",
+ " Sets the numerical precision for model weights and activations to bfloat16, which is memory-efficient and optimized for modern GPUs like the A100.\n",
+ "\n",
+ "- `max_new_tokens=256` \n",
+ " Limits how many tokens the model can generate in response to a prompt. A larger value allows for longer, more detailed outputs.\n",
+ "\n",
+ "> 💡 This pipeline simplifies the generation process so you can just call `generator(prompt)` and receive a coherent answer in return.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "2c228924-8cff-48c7-8594-e623b2ac6f77",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Device set to use cuda:0\n"
+ ]
+ }
+ ],
+ "source": [
+ "generator = pipeline(\n",
+ " \"text-generation\",\n",
+ " model=gemma_model,\n",
+ " tokenizer=tokenizer,\n",
+ " device_map=0,\n",
+ " torch_dtype=torch.bfloat16,\n",
+ " max_new_tokens=256, # Increased max tokens for more detailed responses\n",
+ ")"
+ ]
+ },
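+  {
+   "cell_type": "markdown",
+   "id": "3a6f1c2e-9d41-4b8a-8c3e-0a1b2c3d4e5f",
+   "metadata": {},
+   "source": [
+    "> 🧪 Optional smoke test (a minimal sketch; the prompt is arbitrary and output varies between runs). It feeds a raw string to the pipeline, without the chat template, just to confirm text generation works end to end."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3a6f1c2e-9d41-4b8a-8c3e-0a1b2c3d4e60",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional smoke test: raw-string prompt (no chat template); output varies per run\n",
+    "sample = generator(\"Retrieval-Augmented Generation is\")\n",
+    "print(sample[0][\"generated_text\"])"
+   ]
+  },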
+ {
+ "cell_type": "markdown",
+ "id": "120262a3-aa21-4208-9339-e71d6cadc081",
+ "metadata": {},
+ "source": [
+ "### 📑 Step 2: Text Snippet Retrieval Setup\n",
+ "\n",
+ "In this cell, we define a list of short narrative passages or **context snippets** that describe key events, locations, and interactions between characters (Ethan and Fiona). These text entries will later serve as the **knowledge base** for answering questions using semantic search.\n",
+ "\n",
+ "- `text_snippets = [...]` \n",
+ " This is a Python list that contains multiple text strings. Each string represents a small piece of a story or description.\n",
+ "\n",
+ "These snippets will be:\n",
+ "- Embedded using a sentence transformer model.\n",
+ "- Indexed using FAISS for fast similarity search.\n",
+ "- Used as context when answering user questions via a large language model.\n",
+ "\n",
+ "> 📚 This is a crucial part of the RAG (Retrieval-Augmented Generation) setup, where relevant knowledge is retrieved from this list and passed as input to the language model for grounded, context-aware answers.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "1e16e5ce-a42a-4cc1-a990-a47d8947bcb6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 2. Text Snippet Retrieval Setup\n",
+ "text_snippets = [\n",
+ " \"Fiona thanked Ethan for his unwavering support and promised to cherish their friendship.\",\n",
+ " \"As they ventured deeper into the forest, they encountered a wide array of obstacles.\",\n",
+ " \"Ethan and Fiona crossed treacherous ravines using rickety bridges, relying on each other's strength.\",\n",
+ " \"Overwhelmed with joy, Fiona thanked Ethan and disappeared into the embrace of her family.\",\n",
+ " \"Ethan returned to his cottage, heart full of memories and a smile brighter than ever before.\",\n",
+ " \"The forest was dark and mysterious, filled with ancient trees and hidden paths.\",\n",
+ " \"Ethan always carried a map and compass, ensuring they never lost their way.\",\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "084753ff-fea1-4c2f-8ac1-a1f32d9fd134",
+ "metadata": {},
+ "source": [
+ "### 🔍 Step 3: Enhanced Retrieval Mechanism — Semantic Search with FAISS\n",
+ "\n",
+ "This section sets up the **semantic embedding** and **vector search index** needed to perform efficient and meaningful retrieval of relevant text snippets based on a user's query.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "- `embedding_model = SentenceTransformer(\"all-MiniLM-L6-v2\")` \n",
+ " Loads a lightweight, high-performance sentence embedding model from the Sentence Transformers library. This model converts sentences into dense numerical vectors (embeddings) that capture their semantic meaning.\n",
+ "\n",
+ "- `embeddings_text_snippets = embedding_model.encode(text_snippets)` \n",
+ " Generates vector embeddings for each of the predefined text snippets. These embeddings will later be compared to the query embedding to find the most relevant snippet.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "### ⚙️ FAISS Index Creation\n",
+ "\n",
+ "- `dimension = embeddings_text_snippets.shape[1]` \n",
+ " Extracts the dimensionality of each embedding vector (e.g., 384), which is required to initialize the FAISS index correctly.\n",
+ "\n",
+ "- `index = faiss.IndexFlatL2(dimension)` \n",
+ " Initializes a **FAISS index** that uses L2 distance (Euclidean distance) to compare vectors. This allows for fast and efficient similarity search between embeddings.\n",
+ "\n",
+ "- `index.add(embeddings_text_snippets.astype(np.float32))` \n",
+ " Adds all the text snippet embeddings to the FAISS index after converting them to `float32`, which is the required input format for FAISS.\n",
+ "\n",
+ "> ⚡ This enables real-time semantic search, where a user’s question can be matched to the most semantically similar snippet — even if they use different words.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "3397f9bc-30d3-4bc3-88ed-dd18e2e918eb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 3. Enhanced Retrieval Mechanism: Semantic Search with FAISS\n",
+ "embedding_model = SentenceTransformer(\"all-MiniLM-L6-v2\")\n",
+ "embeddings_text_snippets = embedding_model.encode(text_snippets)\n",
+ "\n",
+ "# FAISS Index Creation\n",
+ "dimension = embeddings_text_snippets.shape[1] # Embedding dimension\n",
+ "index = faiss.IndexFlatL2(dimension) # L2 distance (Euclidean)\n",
+ "index.add(embeddings_text_snippets.astype(np.float32)) # FAISS requires float32"
+ ]
+ },
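+  {
+   "cell_type": "markdown",
+   "id": "3a6f1c2e-9d41-4b8a-8c3e-0a1b2c3d4e61",
+   "metadata": {},
+   "source": [
+    "> 🔎 Optional sanity check (illustrative): the index should contain exactly one vector per snippet, each with the dimensionality reported above."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3a6f1c2e-9d41-4b8a-8c3e-0a1b2c3d4e62",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sanity check: the FAISS index should hold one vector per snippet\n",
+    "print(f\"Embedding dimension: {dimension}\")\n",
+    "print(f\"Vectors in index: {index.ntotal} | snippets: {len(text_snippets)}\")"
+   ]
+  },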
+ {
+ "cell_type": "markdown",
+ "id": "a9f852a5-94ca-449e-9a8c-e04516f6ce08",
+ "metadata": {},
+ "source": [
+ "### 🧠 Step 4: Retrieval Function (Semantic Search)\n",
+ "\n",
+ "This function takes a user query and returns the **most semantically similar snippet** from the previously indexed text corpus using **FAISS-based nearest neighbor search**.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "- `def retrieve_snippet(query, k=1):` \n",
+ " Defines a Python function that accepts a query string and retrieves `k` most similar snippets. By default, `k=1`, meaning it returns only the top match.\n",
+ "\n",
+ "- `query_embedded = embedding_model.encode([query]).astype(np.float32)` \n",
+ " Converts the query string into an embedding vector using the same sentence embedding model used for the snippets. FAISS requires all vectors to be in `float32`, so the type is cast accordingly.\n",
+ "\n",
+ "- `D, I = index.search(query_embedded, k)` \n",
+ " Searches the FAISS index to find the `k` most similar embeddings to the query. \n",
+ " - `D`: distances (lower = more similar) \n",
+ " - `I`: indices of the most similar snippets in the original list\n",
+ "\n",
+ "- `retrieved_indices = I[0]` \n",
+ " Extracts the list of top-k indices from the FAISS result. Since only one query is being processed, we access the first (and only) row of `I`.\n",
+ "\n",
+ "- `retrieved_texts = [text_snippets[i] for i in retrieved_indices]` \n",
+ " Uses the retrieved indices to extract the corresponding text snippets from the original list.\n",
+ "\n",
+ "- `return retrieved_texts[0]` \n",
+ " Returns only the **most relevant snippet**. This snippet will later be used as context for the language model during text generation.\n",
+ "\n",
+ "> 💡 This function powers the semantic retrieval part of RAG — ensuring the model responds using real context instead of hallucinating answers.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "7637263e-6d20-450a-b458-e9e2e66a608b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 4. Retrieval Function (Semantic Search)\n",
+ "def retrieve_snippet(query, k=1): # k is the number of snippets to retrieve\n",
+ " query_embedded = embedding_model.encode([query]).astype(np.float32)\n",
+ " D, I = index.search(query_embedded, k) # D: distances, I: indices\n",
+ " retrieved_indices = I[0]\n",
+ " retrieved_texts = [text_snippets[i] for i in retrieved_indices]\n",
+ " return retrieved_texts[0] # Return only the top snippet\n"
+ ]
+ },
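+  {
+   "cell_type": "markdown",
+   "id": "3a6f1c2e-9d41-4b8a-8c3e-0a1b2c3d4e63",
+   "metadata": {},
+   "source": [
+    "> 🧪 A quick, illustrative check of the retriever on its own. The query below deliberately uses different wording (\"navigation gear\") than the stored snippet (\"map and compass\"), so retrieving the right snippet demonstrates semantic rather than keyword matching."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3a6f1c2e-9d41-4b8a-8c3e-0a1b2c3d4e64",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Test retrieval in isolation; the wording deliberately differs from the snippet\n",
+    "print(retrieve_snippet(\"What navigation gear did Ethan carry?\"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3a6f1c2e-9d41-4b8a-8c3e-0a1b2c3d4e65",
+   "metadata": {},
+   "source": [
+    "### ✍️ Step 5: Answer Generation Function\n",
+    "\n",
+    "`ask_query()` ties retrieval and generation together. It retrieves the most relevant snippet, formats the snippet and question with the tokenizer's **chat template** (`apply_chat_template` with `add_generation_prompt=True`, so the model continues as the assistant), generates up to 128 new tokens with `gemma_model.generate()`, and decodes **only the newly generated tokens** so the printed answer does not echo the prompt."
+   ]
+  },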
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "99147cf9-fff9-4379-b7aa-6888706d9e7b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 5. Create a function to generate the answer based on the retrieved snippet and query\n",
+ "def ask_query(query):\n",
+ " retrieved_text = retrieve_snippet(query)\n",
+ "\n",
+ " # Step 1: Construct chat messages as a list of roles/content\n",
+ " chat = [\n",
+ " {\n",
+ " \"role\": \"system\",\n",
+ " \"content\": \"You are a helpful AI assistant. Answer the question based on the context provided.\",\n",
+ " },\n",
+ " {\n",
+ " \"role\": \"user\",\n",
+ " \"content\": f\"\"\"Context:\n",
+ "{retrieved_text}\n",
+ "\n",
+ "Question: {query}\"\"\",\n",
+ " },\n",
+ " ]\n",
+ "\n",
+ " # Step 2: Use tokenizer's chat template to format this\n",
+ " prompt_ids = tokenizer.apply_chat_template(\n",
+ " chat,\n",
+ " tokenize=True,\n",
+ " add_generation_prompt=True, # add assistant tag to begin model generation\n",
+ " return_tensors=\"pt\"\n",
+ " ).to(gemma_model.device)\n",
+ "\n",
+ " # Step 3: Generate using the raw model\n",
+ " output = gemma_model.generate(prompt_ids, max_new_tokens=128)\n",
+ " response = tokenizer.decode(output[0], skip_special_tokens=True)\n",
+ "\n",
+ " print(f\"Query: {query}\")\n",
+ " print(f\"Context: {retrieved_text}\")\n",
+ " print(f\"Answer: {response}\")\n",
+ " print(\"-\" * 40)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4fdcda7-4f16-46c3-a22e-53d766625ea2",
+ "metadata": {},
+ "source": [
+ "### 🗣️ Step 6: Ask Questions\n",
+ "\n",
+ "This block runs a series of **user-defined natural language queries** through the full Retrieval-Augmented Generation (RAG) pipeline, using the `ask_query()` function. For each question, the pipeline:\n",
+ "\n",
+ "1. **Finds the most semantically similar snippet** using FAISS-based search.\n",
+ "2. **Constructs a prompt** that includes the retrieved snippet as context.\n",
+ "3. **Generates an answer** using the Gemma language model.\n",
+ "\n",
+ "---\n",
+ "\n",
+ "- `query1 = \"Why did Fiona thank Ethan?\"` \n",
+ " A straightforward question to test if the model can connect Fiona’s gratitude to Ethan’s support. \n",
+ " → Passed to `ask_query(query1)` to fetch the answer.\n",
+ "\n",
+ "- `query2 = \"What challenges did Ethan and Fiona face in the forest?\"` \n",
+ " A more complex question that probes the model’s understanding of events and obstacles. \n",
+ " → Answer will depend on the forest-related snippets.\n",
+ "\n",
+ "- `query3 = \"What tools did Ethan use to navigate?\"` \n",
+ " A factual retrieval question. The model should extract and summarize tools like a map or compass.\n",
+ "\n",
+ "- `query4 = \"Describe the forest.\"` \n",
+ " An open-ended descriptive query that should trigger a more vivid narrative response based on stored context.\n",
+ "\n",
+ "> 🧠 These queries showcase how the system can handle **factual, contextual, and descriptive questions** using real context — avoiding hallucinated answers.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "42395c04-f4f3-4eac-a8c2-b2b086016a73",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Query: Why did Fiona thank Ethan?\n",
+ "Context: Fiona thanked Ethan for his unwavering support and promised to cherish their friendship.\n",
+ "Answer: user\n",
+ "You are a helpful AI assistant. Answer the question based on the context provided.\n",
+ "\n",
+ "Context:\n",
+ "Fiona thanked Ethan for his unwavering support and promised to cherish their friendship.\n",
+ "\n",
+ "Question: Why did Fiona thank Ethan?\n",
+ "model\n",
+ "Fiona thanked Ethan for his unwavering support. \n",
+ "\n",
+ "----------------------------------------\n",
+ "Query: What challenges did Ethan and Fiona face in the forest?\n",
+ "Context: Ethan and Fiona crossed treacherous ravines using rickety bridges, relying on each other's strength.\n",
+ "Answer: user\n",
+ "You are a helpful AI assistant. Answer the question based on the context provided.\n",
+ "\n",
+ "Context:\n",
+ "Ethan and Fiona crossed treacherous ravines using rickety bridges, relying on each other's strength.\n",
+ "\n",
+ "Question: What challenges did Ethan and Fiona face in the forest?\n",
+ "model\n",
+ "Based on the context, Ethan and Fiona faced the challenge of crossing treacherous ravines using rickety bridges. This implies a physically dangerous obstacle and a need for careful coordination and reliance on each other. \n",
+ "\n",
+ "So the answer is: **They faced the challenge of crossing treacherous ravines using rickety bridges.**\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "----------------------------------------\n",
+ "Query: What tools did Ethan use to navigate?\n",
+ "Context: Ethan always carried a map and compass, ensuring they never lost their way.\n",
+ "Answer: user\n",
+ "You are a helpful AI assistant. Answer the question based on the context provided.\n",
+ "\n",
+ "Context:\n",
+ "Ethan always carried a map and compass, ensuring they never lost their way.\n",
+ "\n",
+ "Question: What tools did Ethan use to navigate?\n",
+ "model\n",
+ "Ethan used a map and compass to navigate. \n",
+ "\n",
+ "----------------------------------------\n",
+ "Query: Describe the forest.\n",
+ "Context: The forest was dark and mysterious, filled with ancient trees and hidden paths.\n",
+ "Answer: user\n",
+ "You are a helpful AI assistant. Answer the question based on the context provided.\n",
+ "\n",
+ "Context:\n",
+ "The forest was dark and mysterious, filled with ancient trees and hidden paths.\n",
+ "\n",
+ "Question: Describe the forest.\n",
+ "model\n",
+ "The forest is dark and mysterious, filled with ancient trees and hidden paths. \n",
+ "\n",
+ "----------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 6. Ask Questions\n",
+ "query1 = \"Why did Fiona thank Ethan?\"\n",
+ "ask_query(query1)\n",
+ "\n",
+ "query2 = \"What challenges did Ethan and Fiona face in the forest?\"\n",
+ "ask_query(query2)\n",
+ "\n",
+ "query3 = \"What tools did Ethan use to navigate?\"\n",
+ "ask_query(query3)\n",
+ "\n",
+ "query4 = \"Describe the forest.\"\n",
+ "ask_query(query4)"
+ ]
+  }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}