Skip to content

Commit 03e4879

Browse files
Added test for llama-3.1-8b model finetuning demo
1 parent 81de12f commit 03e4879

File tree

5 files changed

+20
-16
lines changed

5 files changed

+20
-16
lines changed
File renamed without changes.

go.mod

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ require (
16 16   sigs.k8s.io/kueue v0.6.2
17 17   )
18 18
19    - replace github.com/project-codeflare/codeflare-common => /home/abdhumal/abhidev/RedHatDev/codeflare-common
20    -
21 19   require (
22 20   github.com/aymerick/douceur v0.2.0 // indirect
23 21   github.com/beorn7/perks v1.0.1 // indirect

tests/odh/mnist_ray_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func mnistRay(t *testing.T, numGpus int) {
138 138
139 139   // Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
140 140   dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
141     - rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, SkipTlsVerification: true}
    141 + rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
142 142   rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
143 143   if err != nil {
144 144   test.T().Errorf("%s", err)

tests/odh/mnist_raytune_hpo_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func mnistRayTuneHpo(t *testing.T, numGpus int) {
138 138
139 139   // Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
140 140   dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
141     - rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, SkipTlsVerification: true}
    141 + rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
142 142   rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
143 143   if err != nil {
144 144   test.T().Errorf("%s", err)

tests/odh/ray_finetune_llm_deepspeed_test.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@ import (
29 29   metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
30 30   )
31 31
32    - func TestRayFinetuneLlmDeepspeedDemo(t *testing.T) {
33    - rayFinetuneLlmDeepspeed(t, 1)
   32 + func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b(t *testing.T) {
   33 + rayFinetuneLlmDeepspeed(t, 1, "zero_3_llama_2_7b.json")
   34 + }
   35 + func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b(t *testing.T) {
   36 + rayFinetuneLlmDeepspeed(t, 1, "zero_3_offload_optim_param.json")
34 37   }
35 38
36    - func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
   39 + func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string) {
37 40   test := With(t)
38 41
39 42   // Create a namespace
@@ -51,7 +54,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
51 54   // list changes required in llm-deepspeed-finetune-demo.ipynb file and update those
52 55   requiredChangesInNotebook := map[string]string{
53 56   "import os": "import os,time,sys",
54    - "import sys": "!cp /opt/app-root/notebooks/* ./",
   57 + "import sys": "!cp /opt/app-root/notebooks/* ./\\n\",\n\t\"!ls",
55 58   "from codeflare_sdk.cluster.auth import TokenAuthentication": "from codeflare_sdk.cluster.auth import TokenAuthentication\\n\",\n\t\"from codeflare_sdk.job import RayJobClient",
56 59   "token = ''": fmt.Sprintf("token = '%s'", userToken),
57 60   "server = ''": fmt.Sprintf("server = '%s'", GetOpenShiftApiUrl(test)),
@@ -61,23 +64,26 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
61 64   "num_workers=7": "num_workers=1",
62 65   "worker_cpu_requests=16": "worker_cpu_requests=4",
63 66   "worker_cpu_limits=16": "worker_cpu_limits=4",
64    - "worker_memory_requests=128": "worker_memory_requests=60",
65    - "worker_memory_limits=256": "worker_memory_limits=60",
   67 + "worker_memory_requests=128": "worker_memory_requests=64",
   68 + "worker_memory_limits=256": "worker_memory_limits=128",
66 69   "head_memory=128": "head_memory=48",
67 70   "client = cluster.job_client": "ray_dashboard = cluster.cluster_dashboard_uri()\\n\",\n\t\"header = {\\\"Authorization\\\": \\\"Bearer " + userToken + "\\\"}\\n\",\n\t\"client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\n",
68 71   "--num-devices=8": fmt.Sprintf("--num-devices=%d", numGpus),
69 72   "--num-epochs=3": fmt.Sprintf("--num-epochs=%d", 1),
70    - "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json": "--ds-config=./zero_3_llama_2_7b.json \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test",
71    - "'pip': 'requirements.txt'": "'pip': '/opt/app-root/src/requirements.txt'",
72    - "'working_dir': './'": "'working_dir': '/opt/app-root/src'",
73    - "client.stop_job(submission_id)": "finished = False\\n\",\n\t\"while not finished:\\n\",\n\t\" time.sleep(1)\\n\",\n\t\" status = client.get_job_status(submission_id)\\n\",\n\t\" finished = (status == \\\"SUCCEEDED\\\")\\n\",\n\t\"if finished:\\n\",\n\t\" print(\\\"Job completed Successfully !\\\")\\n\",\n\t\"else:\\n\",\n\t\" print(\\\"Job failed !\\\")\\n\",\n\t\"time.sleep(10)\\n",
   73 + "--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json": fmt.Sprintf("--ds-config=./%s \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test", modelConfigFile),
   74 + "--batch-size-per-device=32": "--batch-size-per-device=6",
   75 + "--eval-batch-size-per-device=32": "--eval-batch-size-per-device=6",
   76 + "'pip': 'requirements.txt'": "'pip': '/opt/app-root/src/requirements.txt'",
   77 + "'working_dir': './'": "'working_dir': '/opt/app-root/src'",
   78 + "client.stop_job(submission_id)": "finished = False\\n\",\n\t\"while not finished:\\n\",\n\t\" time.sleep(1)\\n\",\n\t\" status = client.get_job_status(submission_id)\\n\",\n\t\" finished = (status == \\\"SUCCEEDED\\\")\\n\",\n\t\"if finished:\\n\",\n\t\" print(\\\"Job completed Successfully !\\\")\\n\",\n\t\"else:\\n\",\n\t\" print(\\\"Job failed !\\\")\\n\",\n\t\"time.sleep(10)\\n",
74 79   }
75 80
76 81   updatedNotebookContent := string(ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/ray_finetune_llm_deepspeed.ipynb"))
77 82   for oldValue, newValue := range requiredChangesInNotebook {
78 83   updatedNotebookContent = strings.Replace(updatedNotebookContent, oldValue, newValue, -1)
79 84   }
80 85   updatedNotebook := []byte(updatedNotebookContent)
   86 + os.WriteFile("demo.ipynb", updatedNotebook, 0644)
81 87
82 88   // Test configuration
83 89   jupyterNotebookConfigMapFileName := "ray_finetune_llm_deepspeed.ipynb"
@@ -87,7 +93,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
87 93   "requirements.txt": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/requirements.txt"),
88 94   "create_dataset.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/create_dataset.py"),
89 95   "lora.json": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/lora_configs/lora.json"),
90    - "zero_3_llama_2_7b.json": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/zero_3_llama_2_7b.json"),
   96 + modelConfigFile: ReadFileExt(test, fmt.Sprintf(workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/%s", modelConfigFile)),
91 97   "utils.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/utils.py"),
92 98   }
93 99
@@ -120,7 +126,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
120 126
121 127   // Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
122 128   dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
123     - rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, SkipTlsVerification: true}
    129 + rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
124 130   rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
125 131   if err != nil {
126 132   test.T().Errorf("%s", err)

0 commit comments

Comments
 (0)