@@ -29,11 +29,14 @@ import (
29
29
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
30
30
)
31
31
32
- func TestRayFinetuneLlmDeepspeedDemo (t * testing.T ) {
33
- rayFinetuneLlmDeepspeed (t , 1 )
32
+ func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b (t * testing.T ) {
33
+ rayFinetuneLlmDeepspeed (t , 1 , "zero_3_llama_2_7b.json" )
34
+ }
35
+ func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b (t * testing.T ) {
36
+ rayFinetuneLlmDeepspeed (t , 1 , "zero_3_offload_optim_param.json" )
34
37
}
35
38
36
- func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int ) {
39
+ func rayFinetuneLlmDeepspeed (t * testing.T , numGpus int , modelConfigFile string ) {
37
40
test := With (t )
38
41
39
42
// Create a namespace
@@ -51,7 +54,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
51
54
// list changes required in llm-deepspeed-finetune-demo.ipynb file and update those
52
55
requiredChangesInNotebook := map [string ]string {
53
56
"import os" : "import os,time,sys" ,
54
- "import sys" : "!cp /opt/app-root/notebooks/* ./" ,
57
+ "import sys" : "!cp /opt/app-root/notebooks/* ./\\ n \" , \n \t \" !ls " ,
55
58
"from codeflare_sdk.cluster.auth import TokenAuthentication" : "from codeflare_sdk.cluster.auth import TokenAuthentication\\ n\" ,\n \t \" from codeflare_sdk.job import RayJobClient" ,
56
59
"token = ''" : fmt .Sprintf ("token = '%s'" , userToken ),
57
60
"server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
@@ -61,23 +64,26 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
61
64
"num_workers=7" : "num_workers=1" ,
62
65
"worker_cpu_requests=16" : "worker_cpu_requests=4" ,
63
66
"worker_cpu_limits=16" : "worker_cpu_limits=4" ,
64
- "worker_memory_requests=128" : "worker_memory_requests=60 " ,
65
- "worker_memory_limits=256" : "worker_memory_limits=60 " ,
67
+ "worker_memory_requests=128" : "worker_memory_requests=64 " ,
68
+ "worker_memory_limits=256" : "worker_memory_limits=128 " ,
66
69
"head_memory=128" : "head_memory=48" ,
67
70
"client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
68
71
"--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
69
72
"--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
70
- "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json" : "--ds-config=./zero_3_llama_2_7b.json \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" ,
71
- "'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
72
- "'working_dir': './'" : "'working_dir': '/opt/app-root/src'" ,
73
- "client.stop_job(submission_id)" : "finished = False\\ n\" ,\n \t \" while not finished:\\ n\" ,\n \t \" time.sleep(1)\\ n\" ,\n \t \" status = client.get_job_status(submission_id)\\ n\" ,\n \t \" finished = (status == \\ \" SUCCEEDED\\ \" )\\ n\" ,\n \t \" if finished:\\ n\" ,\n \t \" print(\\ \" Job completed Successfully !\\ \" )\\ n\" ,\n \t \" else:\\ n\" ,\n \t \" print(\\ \" Job failed !\\ \" )\\ n\" ,\n \t \" time.sleep(10)\\ n" ,
73
+ "--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json" : fmt .Sprintf ("--ds-config=./%s \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" , modelConfigFile ),
74
+ "--batch-size-per-device=32" : "--batch-size-per-device=6" ,
75
+ "--eval-batch-size-per-device=32" : "--eval-batch-size-per-device=6" ,
76
+ "'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
77
+ "'working_dir': './'" : "'working_dir': '/opt/app-root/src'" ,
78
+ "client.stop_job(submission_id)" : "finished = False\\ n\" ,\n \t \" while not finished:\\ n\" ,\n \t \" time.sleep(1)\\ n\" ,\n \t \" status = client.get_job_status(submission_id)\\ n\" ,\n \t \" finished = (status == \\ \" SUCCEEDED\\ \" )\\ n\" ,\n \t \" if finished:\\ n\" ,\n \t \" print(\\ \" Job completed Successfully !\\ \" )\\ n\" ,\n \t \" else:\\ n\" ,\n \t \" print(\\ \" Job failed !\\ \" )\\ n\" ,\n \t \" time.sleep(10)\\ n" ,
74
79
}
75
80
76
81
updatedNotebookContent := string (ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/ray_finetune_llm_deepspeed.ipynb" ))
77
82
for oldValue , newValue := range requiredChangesInNotebook {
78
83
updatedNotebookContent = strings .Replace (updatedNotebookContent , oldValue , newValue , - 1 )
79
84
}
80
85
updatedNotebook := []byte (updatedNotebookContent )
86
+ os .WriteFile ("demo.ipynb" , updatedNotebook , 0644 )
81
87
82
88
// Test configuration
83
89
jupyterNotebookConfigMapFileName := "ray_finetune_llm_deepspeed.ipynb"
@@ -87,7 +93,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
87
93
"requirements.txt" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/requirements.txt" ),
88
94
"create_dataset.py" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/create_dataset.py" ),
89
95
"lora.json" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/lora_configs/lora.json" ),
90
- "zero_3_llama_2_7b.json" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/zero_3_llama_2_7b.json" ),
96
+ modelConfigFile : ReadFileExt (test , fmt . Sprintf ( workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/%s" , modelConfigFile ) ),
91
97
"utils.py" : ReadFileExt (test , workingDirectory + "/../../examples/ray-finetune-llm-deepspeed/utils.py" ),
92
98
}
93
99
@@ -120,7 +126,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
120
126
121
127
// Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
122
128
dashboardUrl := GetDashboardUrl (test , namespace , rayCluster )
123
- rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , SkipTlsVerification : true }
129
+ rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , InsecureSkipVerify : true }
124
130
rayClient , err := NewRayClusterClient (rayClusterClientConfig , test .Config ().BearerToken )
125
131
if err != nil {
126
132
test .T ().Errorf ("%s" , err )
0 commit comments