Skip to content

Commit e1cb7e4

Browse files
committed
Refined qwen colab
Signed-off-by: Vladimir Suvorov <[email protected]>
1 parent 2cea782 commit e1cb7e4

File tree

1 file changed

+257
-0
lines changed

1 file changed

+257
-0
lines changed
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"Run SFT on Qwen3-0.6B model\n",
8+
"\n",
9+
"This collab can run on the public TPU 5e-1\n",
10+
"\n",
11+
"This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Qwen3-0.6B using the Hugging Face ultrachat_200k dataset with Tunix integration for efficient training.\n",
12+
"\n",
13+
"Dataset Overview\n",
14+
"https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k\n",
15+
"\n",
16+
"Dataset Information:\n",
17+
"\n",
18+
"Name: HuggingFaceH4/ultrachat_200k\n",
19+
"Type: Supervised Fine-Tuning dataset\n",
20+
"Size: ~200k conversations\n",
21+
"Format: Chat conversations with human-AI pairs\n",
22+
"Splits: train_sft, test_sft\n",
23+
"Data columns: ['messages']\n",
24+
"Dataset Structure: Each example contains a 'messages' field with:\n",
25+
"\n",
26+
"role: 'user' or 'assistant'\n",
27+
"content: The actual message text\n",
28+
"Example data format:\n",
29+
"\n",
30+
"{\n",
31+
" \"messages\": [\n",
32+
" {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
33+
" {\"role\": \"assistant\", \"content\": \"The capital of France is Paris.\"}\n",
34+
" ]\n",
35+
"}\n",
36+
"\n",
37+
"Prerequisites\n",
38+
"HuggingFace access token for dataset download\n",
39+
"Sufficient compute resources (TPU/GPU)"
40+
]
41+
},
42+
{
43+
"cell_type": "code",
44+
"execution_count": null,
45+
"metadata": {
46+
"id": "Wr4OOETu8elP"
47+
},
48+
"outputs": [],
49+
"source": [
50+
"### (Optional) Run this if you just have this file and nothing else\n",
51+
"\n",
52+
"# 1. Clone the MaxText repository (from AI‑Hypercomputer)\n",
53+
"!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
54+
"\n",
55+
"# 2. Navigate into the cloned directory\n",
56+
"%cd maxtext"
57+
]
58+
},
59+
{
60+
"cell_type": "code",
61+
"execution_count": null,
62+
"metadata": {
63+
"id": "5KPyOE8e9WbO"
64+
},
65+
"outputs": [],
66+
"source": [
67+
"### (Optional) Do not run this if you already installed the dependencies\n",
68+
"\n",
69+
"# 3. Ensure setup.sh is executable\n",
70+
"!chmod +x setup.sh\n",
71+
"\n",
72+
"# 4. Execute the setup script\n",
73+
"!./setup.sh\n",
74+
"\n",
75+
"# force numpy version\n",
76+
"!pip install --force-reinstall numpy==2.1.2\n",
77+
"#install nest_asyncio\n",
78+
"!pip install nest_asyncio\n",
79+
"\n",
80+
"import nest_asyncio\n",
81+
"nest_asyncio.apply()\n",
82+
"# To fix \"This event loop is already running\" error in Colab\n"
83+
]
84+
},
85+
{
86+
"cell_type": "code",
87+
"execution_count": null,
88+
"metadata": {
89+
"id": "CJnhPxUq_G6a"
90+
},
91+
"outputs": [],
92+
"source": [
93+
"import os\n",
94+
"import sys\n",
95+
"# Set home directory. Change this to your home directory where maxtext is cloned\n",
96+
"MAXTEXT_HOME = os.path.join(\"/content\", \"maxtext\")\n",
97+
"print(f\"Home directory (from Python): {MAXTEXT_HOME}\")\n",
98+
"#MODEL_CHECKPOINT_PATH = \"path/to/scanned/checkpoint\""
99+
]
100+
},
101+
{
102+
"cell_type": "code",
103+
"execution_count": null,
104+
"metadata": {
105+
"id": "CxzKMBQd_U5-"
106+
},
107+
"outputs": [],
108+
"source": [
109+
"from pathlib import Path\n",
110+
"from typing import Optional, Dict, Any\n",
111+
"\n",
112+
"# Find MaxText directory and change working directory to it\n",
113+
"current_dir = Path.cwd()\n",
114+
"if current_dir.name == 'examples':\n",
115+
" # We're in the examples folder, go up one level\n",
116+
" maxtext_path = current_dir.parent.parent\n",
117+
"else:\n",
118+
" # We're in the root, MaxText is a subfolder\n",
119+
" maxtext_path = Path(f'{MAXTEXT_HOME}') / 'src' / 'MaxText'\n",
120+
"\n",
121+
"# Change working directory to MaxText project root\n",
122+
"os.chdir(maxtext_path)\n",
123+
"sys.path.insert(0, str(maxtext_path))\n",
124+
"\n",
125+
"print(f\"✓ Changed working directory to: {os.getcwd()}\")\n",
126+
"print(f\"✓ MaxText project root: {maxtext_path}\")\n",
127+
"print(f\"✓ Added to Python path: {maxtext_path}\")\n",
128+
"import jax\n",
129+
"if not jax.distributed.is_initialized():\n",
130+
" jax.distributed.initialize()\n",
131+
"print(f\"JAX version: {jax.__version__}\")\n",
132+
"print(f\"JAX devices: {jax.devices()}\")\n"
133+
]
134+
},
135+
{
136+
"cell_type": "code",
137+
"execution_count": null,
138+
"metadata": {
139+
"id": "rKS8nVYgAbwE"
140+
},
141+
"outputs": [],
142+
"source": [
143+
"# Hugging Face Authentication Setup\n",
144+
"from huggingface_hub import login\n",
145+
"\n",
146+
"# Set your Hugging Face token here\n",
147+
"HF_TOKEN = \"your_actual_token_here\" # Replace with your actual token\n",
148+
"login(token=HF_TOKEN)\n"
149+
]
150+
},
151+
{
152+
"cell_type": "code",
153+
"execution_count": null,
154+
"metadata": {
155+
"id": "aR0zTWkxAs4t"
156+
},
157+
"outputs": [],
158+
"source": [
159+
"# MaxText imports\n",
160+
"try:\n",
161+
" from MaxText import pyconfig\n",
162+
" from MaxText.sft.sft_trainer import train as sft_train\n",
163+
"\n",
164+
" MAXTEXT_AVAILABLE = True\n",
165+
" print(\"✓ MaxText imports successful\")\n",
166+
"except ImportError as e:\n",
167+
" print(f\"⚠️ MaxText not available: {e}\")\n",
168+
" MAXTEXT_AVAILABLE = False"
169+
]
170+
},
171+
{
172+
"cell_type": "code",
173+
"execution_count": null,
174+
"metadata": {
175+
"id": "In-jdp1AAwrL"
176+
},
177+
"outputs": [],
178+
"source": [
179+
"# Fixed configuration setup for Qwen-0.6B on small TPU\n",
180+
"if MAXTEXT_AVAILABLE:\n",
181+
" config_argv = [\n",
182+
" \"\",\n",
183+
" f\"{MAXTEXT_HOME}/src/MaxText/configs/sft.yml\", # base SFT config\n",
184+
" \"model_name=qwen3-0.6b\",\n",
185+
186+
" \"per_device_batch_size=1\", # minimal to avoid OOM\n",
187+
" \"max_target_length=512\", # shorter context to fit memory\n",
188+
" \"learning_rate=2.0e-5\", # safe small LR\n",
189+
" \"eval_steps=5\",\n",
190+
" \"weight_dtype=bfloat16\",\n",
191+
" \"dtype=bfloat16\",\n",
192+
" \"hf_path=HuggingFaceH4/ultrachat_200k\", # HuggingFace dataset/model if needed\n",
193+
" f\"hf_access_token={HF_TOKEN}\",\n",
194+
" \"base_output_directory=/tmp/maxtext_qwen06\",\n",
195+
" \"run_name=sft_qwen0.6b_test\",\n",
196+
" \"tokenizer_path=Qwen/Qwen3-0.6B\", # Qwen tokenizer\n",
197+
" \"eval_interval=10\",\n",
198+
" \"steps=100\",\n",
199+
" \"profiler=xplane\",\n",
200+
" ]\n",
201+
"\n",
202+
" # Initialize configuration using MaxText's pyconfig\n",
203+
" config = pyconfig.initialize(config_argv)\n",
204+
"\n",
205+
" print(\"✓ Fixed configuration loaded:\")\n",
206+
" print(f\" - Model: {config.model_name}\")\n",
207+
" print(f\" - Dataset: {config.hf_path}\")\n",
208+
" print(f\" - Steps: {config.steps}\")\n",
209+
" print(f\" - Use SFT: {config.use_sft}\")\n",
210+
" print(f\" - Learning Rate: {config.learning_rate}\")\n",
211+
"else:\n",
212+
" print(\"MaxText not available - cannot load configuration\")"
213+
]
214+
},
215+
{
216+
"cell_type": "markdown",
217+
"metadata": {
218+
"id": "EJE1ookSAzz-"
219+
},
220+
"source": []
221+
},
222+
{
223+
"cell_type": "code",
224+
"execution_count": null,
225+
"metadata": {
226+
"id": "mgwpNgQYCJEd"
227+
},
228+
"outputs": [],
229+
"source": [
230+
"# Execute the training using MaxText SFT trainer's train() function\n",
231+
"if MAXTEXT_AVAILABLE:\n",
232+
" print(\"=\"*60)\n",
233+
" print(\"EXECUTING ACTUAL TRAINING\")\n",
234+
" print(\"=\"*60)\n",
235+
"\n",
236+
"\n",
237+
" sft_train(config)\n"
238+
]
239+
}
240+
],
241+
"metadata": {
242+
"accelerator": "TPU",
243+
"colab": {
244+
"gpuType": "V5E1",
245+
"provenance": []
246+
},
247+
"kernelspec": {
248+
"display_name": "Python 3",
249+
"name": "python3"
250+
},
251+
"language_info": {
252+
"name": "python"
253+
}
254+
},
255+
"nbformat": 4,
256+
"nbformat_minor": 0
257+
}

0 commit comments

Comments
 (0)