@@ -29,11 +29,14 @@ import (
2929 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3030)
3131
32- func TestRayFinetuneLlmDeepspeedDemo (t * testing.T ) {
33- rayFinetuneLlmDeepspeed (t , 1 )
32+ func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b (t * testing.T ) {
33+ rayFinetuneLlmDeepspeed (t , 1 , "zero_3_llama_2_7b.json" )
34+ }
35+ func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b (t * testing.T ) {
36+ rayFinetuneLlmDeepspeed (t , 1 , "zero_3_offload_optim_param.json" )
3437}
3538
36- func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int ) {
39+ func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int , modelConfigFile string ) {
3740 test := With (t )
3841
3942 // Create a namespace
@@ -51,7 +54,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
5154 // list changes required in llm-deepspeed-finetune-demo.ipynb file and update those
5255 requiredChangesInNotebook := map [string ]string {
5356 "import os" : "import os,time,sys" ,
54- "import sys" : "!cp /opt/app-root/notebooks/* ./" ,
57+ "import sys" : "!cp /opt/app-root/notebooks/* ./\\ n \" , \n \t \" !ls " ,
5558 "from codeflare_sdk.cluster.auth import TokenAuthentication" : "from codeflare_sdk.cluster.auth import TokenAuthentication\\ n\" ,\n \t \" from codeflare_sdk.job import RayJobClient" ,
5659 "token = ''" : fmt .Sprintf ("token = '%s'" , userToken ),
5760 "server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
@@ -61,23 +64,26 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
6164 "num_workers=7" : "num_workers=1" ,
6265 "worker_cpu_requests=16" : "worker_cpu_requests=4" ,
6366 "worker_cpu_limits=16" : "worker_cpu_limits=4" ,
64- "worker_memory_requests=128" : "worker_memory_requests=60 " ,
65- "worker_memory_limits=256" : "worker_memory_limits=60 " ,
67+ "worker_memory_requests=128" : "worker_memory_requests=64 " ,
68+ "worker_memory_limits=256" : "worker_memory_limits=128 " ,
6669 "head_memory=128" : "head_memory=48" ,
6770 "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
6871 "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
6972 "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
70- "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json" : "--ds-config=./zero_3_llama_2_7b.json \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" ,
71- "'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
72- "'working_dir': './'" : "'working_dir': '/opt/app-root/src'" ,
73- "client.stop_job(submission_id)" : "finished = False\\ n\" ,\n \t \" while not finished:\\ n\" ,\n \t \" time.sleep(1)\\ n\" ,\n \t \" status = client.get_job_status(submission_id)\\ n\" ,\n \t \" finished = (status == \\ \" SUCCEEDED\\ \" )\\ n\" ,\n \t \" if finished:\\ n\" ,\n \t \" print(\\ \" Job completed Successfully !\\ \" )\\ n\" ,\n \t \" else:\\ n\" ,\n \t \" print(\\ \" Job failed !\\ \" )\\ n\" ,\n \t \" time.sleep(10)\\ n" ,
73+ "--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json" : fmt .Sprintf ("--ds-config=./%s \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" , modelConfigFile ),
74+ "--batch-size-per-device=32" : "--batch-size-per-device=6" ,
75+ "--eval-batch-size-per-device=32" : "--eval-batch-size-per-device=6" ,
76+ "'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
77+ "'working_dir': './'" : "'working_dir': '/opt/app-root/src'" ,
78+ "client.stop_job(submission_id)" : "finished = False\\ n\" ,\n \t \" while not finished:\\ n\" ,\n \t \" time.sleep(1)\\ n\" ,\n \t \" status = client.get_job_status(submission_id)\\ n\" ,\n \t \" finished = (status == \\ \" SUCCEEDED\\ \" )\\ n\" ,\n \t \" if finished:\\ n\" ,\n \t \" print(\\ \" Job completed Successfully !\\ \" )\\ n\" ,\n \t \" else:\\ n\" ,\n \t \" print(\\ \" Job failed !\\ \" )\\ n\" ,\n \t \" time.sleep(10)\\ n" ,
7479 }
7580
7681 updatedNotebookContent := string (ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/ray_finetune_llm_deepspeed.ipynb" ))
7782 for oldValue , newValue := range requiredChangesInNotebook {
7883 updatedNotebookContent = strings .Replace (updatedNotebookContent , oldValue , newValue , - 1 )
7984 }
8085 updatedNotebook := []byte (updatedNotebookContent )
86+ os .WriteFile ("demo.ipynb" , updatedNotebook , 0644 )
8187
8288 // Test configuration
8389 jupyterNotebookConfigMapFileName := "ray_finetune_llm_deepspeed.ipynb"
@@ -87,7 +93,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
8793 "requirements.txt" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/requirements.txt" ),
8894 "create_dataset.py" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/create_dataset.py" ),
8995 "lora.json" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/lora_configs/lora.json" ),
90- "zero_3_llama_2_7b.json" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/zero_3_llama_2_7b.json" ),
96+ modelConfigFile : ReadFileExt (test , fmt . Sprintf ( workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/%s" , modelConfigFile ) ),
9197 "utils.py" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/utils.py" ),
9298 }
9399
@@ -120,7 +126,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
120126
121127 // Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
122128 dashboardUrl := GetDashboardUrl (test , namespace , rayCluster )
123- rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , SkipTlsVerification : true }
129+ rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , InsecureSkipVerify : true }
124130 rayClient , err := NewRayClusterClient (rayClusterClientConfig , test .Config ().BearerToken )
125131 if err != nil {
126132 test .T ().Errorf ("%s" , err )
0 commit comments