Skip to content

Commit 2419ec0

Browse files
Added test for llama-3.1-8b model finetuning demo
1 parent 81de12f commit 2419ec0

File tree

6 files changed

+23
-19
lines changed

6 files changed

+23
-19
lines changed

go.mod

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ toolchain go1.21.5
77
require (
88
github.com/kubeflow/training-operator v1.7.0
99
github.com/onsi/gomega v1.31.1
10-
github.com/project-codeflare/codeflare-common v0.0.0-20240809123324-d44e319ba556
10+
github.com/project-codeflare/codeflare-common v0.0.0-20240827080155-9234d23ff47d
1111
github.com/prometheus/client_golang v1.18.0
1212
github.com/prometheus/common v0.45.0
1313
github.com/ray-project/kuberay/ray-operator v1.1.0-alpha.0
@@ -16,8 +16,6 @@ require (
1616
sigs.k8s.io/kueue v0.6.2
1717
)
1818

19-
replace github.com/project-codeflare/codeflare-common => /home/abdhumal/abhidev/RedHatDev/codeflare-common
20-
2119
require (
2220
github.com/aymerick/douceur v0.2.0 // indirect
2321
github.com/beorn7/perks v1.0.1 // indirect

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -363,8 +363,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
363363
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
364364
github.com/project-codeflare/appwrapper v0.8.0 h1:vWHNtXUtHutN2EzYb6rryLdESnb8iDXsCokXOuNYXvg=
365365
github.com/project-codeflare/appwrapper v0.8.0/go.mod h1:FMQ2lI3fz6LakUVXgN1FTdpsc3BBkNIZZgtMmM9J5UM=
366-
github.com/project-codeflare/codeflare-common v0.0.0-20240809123324-d44e319ba556 h1:4SI3d63CNZ+7sKQ1JEqLmNzGSgVXqz3aT3+aDXRgo18=
367-
github.com/project-codeflare/codeflare-common v0.0.0-20240809123324-d44e319ba556/go.mod h1:unKTw+XoMANTES3WieG016im7rxZ7IR2/ph++L5Vp1Y=
366+
github.com/project-codeflare/codeflare-common v0.0.0-20240827080155-9234d23ff47d h1:hbfF20rw/NHvXNXYLuxPjCnBS5Lotvt6rU0S9DLs0HU=
367+
github.com/project-codeflare/codeflare-common v0.0.0-20240827080155-9234d23ff47d/go.mod h1:unKTw+XoMANTES3WieG016im7rxZ7IR2/ph++L5Vp1Y=
368368
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
369369
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
370370
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=

tests/odh/mnist_ray_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func mnistRay(t *testing.T, numGpus int) {
138138

139139
// Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
140140
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
141-
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, SkipTlsVerification: true}
141+
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
142142
rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
143143
if err != nil {
144144
test.T().Errorf("%s", err)

tests/odh/mnist_raytune_hpo_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func mnistRayTuneHpo(t *testing.T, numGpus int) {
138138

139139
// Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
140140
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
141-
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, SkipTlsVerification: true}
141+
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
142142
rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
143143
if err != nil {
144144
test.T().Errorf("%s", err)

tests/odh/ray_finetune_llm_deepspeed_test.go

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,14 @@ import (
2929
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
3030
)
3131

32-
func TestRayFinetuneLlmDeepspeedDemo(t *testing.T) {
33-
rayFinetuneLlmDeepspeed(t, 1)
32+
func TestRayFinetuneLlmDeepspeedDemoLlama_2_7b(t *testing.T) {
33+
rayFinetuneLlmDeepspeed(t, 1, "zero_3_llama_2_7b.json")
34+
}
35+
func TestRayFinetuneLlmDeepspeedDemoLlama_31_8b(t *testing.T) {
36+
rayFinetuneLlmDeepspeed(t, 1, "zero_3_offload_optim_param.json")
3437
}
3538

36-
func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
39+
func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int, modelConfigFile string) {
3740
test := With(t)
3841

3942
// Create a namespace
@@ -51,7 +54,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
5154
// list changes required in llm-deepspeed-finetune-demo.ipynb file and update those
5255
requiredChangesInNotebook := map[string]string{
5356
"import os": "import os,time,sys",
54-
"import sys": "!cp /opt/app-root/notebooks/* ./",
57+
"import sys": "!cp /opt/app-root/notebooks/* ./\\n\",\n\t\"!ls",
5558
"from codeflare_sdk.cluster.auth import TokenAuthentication": "from codeflare_sdk.cluster.auth import TokenAuthentication\\n\",\n\t\"from codeflare_sdk.job import RayJobClient",
5659
"token = ''": fmt.Sprintf("token = '%s'", userToken),
5760
"server = ''": fmt.Sprintf("server = '%s'", GetOpenShiftApiUrl(test)),
@@ -61,23 +64,26 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
6164
"num_workers=7": "num_workers=1",
6265
"worker_cpu_requests=16": "worker_cpu_requests=4",
6366
"worker_cpu_limits=16": "worker_cpu_limits=4",
64-
"worker_memory_requests=128": "worker_memory_requests=60",
65-
"worker_memory_limits=256": "worker_memory_limits=60",
67+
"worker_memory_requests=128": "worker_memory_requests=64",
68+
"worker_memory_limits=256": "worker_memory_limits=128",
6669
"head_memory=128": "head_memory=48",
6770
"client = cluster.job_client": "ray_dashboard = cluster.cluster_dashboard_uri()\\n\",\n\t\"header = {\\\"Authorization\\\": \\\"Bearer " + userToken + "\\\"}\\n\",\n\t\"client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\n",
6871
"--num-devices=8": fmt.Sprintf("--num-devices=%d", numGpus),
6972
"--num-epochs=3": fmt.Sprintf("--num-epochs=%d", 1),
70-
"--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json": "--ds-config=./zero_3_llama_2_7b.json \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test",
71-
"'pip': 'requirements.txt'": "'pip': '/opt/app-root/src/requirements.txt'",
72-
"'working_dir': './'": "'working_dir': '/opt/app-root/src'",
73-
"client.stop_job(submission_id)": "finished = False\\n\",\n\t\"while not finished:\\n\",\n\t\" time.sleep(1)\\n\",\n\t\" status = client.get_job_status(submission_id)\\n\",\n\t\" finished = (status == \\\"SUCCEEDED\\\")\\n\",\n\t\"if finished:\\n\",\n\t\" print(\\\"Job completed Successfully !\\\")\\n\",\n\t\"else:\\n\",\n\t\" print(\\\"Job failed !\\\")\\n\",\n\t\"time.sleep(10)\\n",
73+
"--ds-config=./deepspeed_configs/zero_3_offload_optim+param.json": fmt.Sprintf("--ds-config=./%s \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test", modelConfigFile),
74+
"--batch-size-per-device=32": "--batch-size-per-device=6",
75+
"--eval-batch-size-per-device=32": "--eval-batch-size-per-device=6",
76+
"'pip': 'requirements.txt'": "'pip': '/opt/app-root/src/requirements.txt'",
77+
"'working_dir': './'": "'working_dir': '/opt/app-root/src'",
78+
"client.stop_job(submission_id)": "finished = False\\n\",\n\t\"while not finished:\\n\",\n\t\" time.sleep(1)\\n\",\n\t\" status = client.get_job_status(submission_id)\\n\",\n\t\" finished = (status == \\\"SUCCEEDED\\\")\\n\",\n\t\"if finished:\\n\",\n\t\" print(\\\"Job completed Successfully !\\\")\\n\",\n\t\"else:\\n\",\n\t\" print(\\\"Job failed !\\\")\\n\",\n\t\"time.sleep(10)\\n",
7479
}
7580

7681
updatedNotebookContent := string(ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/ray_finetune_llm_deepspeed.ipynb"))
7782
for oldValue, newValue := range requiredChangesInNotebook {
7883
updatedNotebookContent = strings.Replace(updatedNotebookContent, oldValue, newValue, -1)
7984
}
8085
updatedNotebook := []byte(updatedNotebookContent)
86+
os.WriteFile("demo.ipynb", updatedNotebook, 0644)
8187

8288
// Test configuration
8389
jupyterNotebookConfigMapFileName := "ray_finetune_llm_deepspeed.ipynb"
@@ -87,7 +93,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
8793
"requirements.txt": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/requirements.txt"),
8894
"create_dataset.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/create_dataset.py"),
8995
"lora.json": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/lora_configs/lora.json"),
90-
"zero_3_llama_2_7b.json": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/zero_3_llama_2_7b.json"),
96+
modelConfigFile: ReadFileExt(test, fmt.Sprintf(workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/%s", modelConfigFile)),
9197
"utils.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/utils.py"),
9298
}
9399

@@ -120,7 +126,7 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
120126

121127
// Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
122128
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
123-
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, SkipTlsVerification: true}
129+
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, InsecureSkipVerify: true}
124130
rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
125131
if err != nil {
126132
test.T().Errorf("%s", err)

0 commit comments

Comments
 (0)