@@ -21,10 +21,12 @@ import (
21
21
"os"
22
22
"strings"
23
23
"testing"
24
+ "time"
24
25
25
26
. "github.com/onsi/gomega"
26
27
. "github.com/project-codeflare/codeflare-common/support"
27
28
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
29
+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28
30
)
29
31
30
32
func TestRayFinetuneLlmDeepspeedDemo (t * testing.T ) {
@@ -55,19 +57,17 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
55
57
"server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
56
58
"namespace='ray-finetune-llm-deepspeed'" : fmt .Sprintf ("namespace='%s'" , namespace .Name ),
57
59
"head_cpus=16" : "head_cpus=2" ,
58
- "head_gpus =1" : "head_gpus =0" ,
60
+ "head_extended_resource_requests =1" : "head_extended_resource_requests =0" ,
59
61
"num_workers=7" : "num_workers=1" ,
60
- "min_cpus =16" : "min_cpus =4" ,
61
- "max_cpus =16" : "max_cpus =4" ,
62
- "min_memory =128" : "min_memory=48 " ,
63
- "max_memory =256" : "max_memory=48 " ,
62
+ "worker_cpu_requests =16" : "worker_cpu_requests =4" ,
63
+ "worker_cpu_limits =16" : "worker_cpu_limits =4" ,
64
+ "worker_memory_requests =128" : "worker_memory_requests=60 " ,
65
+ "worker_memory_limits =256" : "worker_memory_limits=60 " ,
64
66
"head_memory=128" : "head_memory=48" ,
65
- "num_gpus=1" : fmt .Sprintf ("worker_extended_resource_requests={'nvidia.com/gpu': %d},\\ n\" ,\n \t \" write_to_file=True,\\ n\" ,\n \t \" verify_tls=False" , numGpus ),
66
- "image='quay.io/rhoai/ray:2.23.0-py39-cu121'" : fmt .Sprintf ("image='%s'" , GetRayImage ()),
67
- "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
68
- "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
69
- "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
70
- "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json" : "--ds-config=./zero_3_llama_2_7b.json \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" ,
67
+ "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
68
+ "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
69
+ "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
70
+ "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json" : "--ds-config=./zero_3_llama_2_7b.json \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" ,
71
71
"'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
72
72
"'working_dir': './'" : "'working_dir': '/opt/app-root/src'" ,
73
73
"client.stop_job(submission_id)" : "finished = False\\ n\" ,\n \t \" while not finished:\\ n\" ,\n \t \" time.sleep(1)\\ n\" ,\n \t \" status = client.get_job_status(submission_id)\\ n\" ,\n \t \" finished = (status == \\ \" SUCCEEDED\\ \" )\\ n\" ,\n \t \" if finished:\\ n\" ,\n \t \" print(\\ \" Job completed Successfully !\\ \" )\\ n\" ,\n \t \" else:\\ n\" ,\n \t \" print(\\ \" Job failed !\\ \" )\\ n\" ,\n \t \" time.sleep(10)\\ n" ,
@@ -111,6 +111,47 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
111
111
),
112
112
)
113
113
114
+ time .Sleep (30 * time .Second )
115
+
116
+ // Fetch created raycluster
117
+ rayClusterName := "ray"
118
+ rayCluster , err := test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Get (test .Ctx (), rayClusterName , metav1.GetOptions {})
119
+ test .Expect (err ).ToNot (HaveOccurred ())
120
+
121
+ // Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
122
+ dashboardUrl := GetDashboardUrl (test , namespace , rayCluster )
123
+ rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , SkipTlsVerification : true }
124
+ rayClient , err := NewRayClusterClient (rayClusterClientConfig , test .Config ().BearerToken )
125
+ if err != nil {
126
+ test .T ().Errorf ("%s" , err )
127
+ }
128
+
129
+ jobID := GetTestJobId (test , rayClient , dashboardUrl .Host )
130
+ test .Expect (jobID ).ToNot (Equal (nil ))
131
+
132
+ // Wait for the job to be succeeded or failed
133
+ var rayJobStatus string
134
+ fmt .Printf ("Waiting for job to be Succeeded...\n " )
135
+ test .Eventually (func () string {
136
+ resp , err := rayClient .GetJobDetails (jobID )
137
+ test .Expect (err ).ToNot (HaveOccurred ())
138
+ rayJobStatusVal := resp .Status
139
+ if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
140
+ fmt .Printf ("JobStatus : %s\n " , rayJobStatusVal )
141
+ rayJobStatus = rayJobStatusVal
142
+ WriteRayJobAPILogs (test , rayClient , jobID )
143
+ return rayJobStatus
144
+ }
145
+ if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
146
+ fmt .Printf ("JobStatus : %s...\n " , rayJobStatusVal )
147
+ rayJobStatus = rayJobStatusVal
148
+ }
149
+ return rayJobStatus
150
+ }, TestTimeoutDouble , 3 * time .Second ).Should (Or (Equal ("SUCCEEDED" ), Equal ("FAILED" )), "Job did not complete within the expected time" )
151
+ // Store job logs in output directory
152
+ WriteRayJobAPILogs (test , rayClient , jobID )
153
+ test .Expect (rayJobStatus ).To (Equal ("SUCCEEDED" ), "RayJob failed !" )
154
+
114
155
// Make sure the RayCluster finishes and is deleted
115
156
test .Eventually (RayClusters (test , namespace .Name ), TestTimeoutMedium ).
116
157
Should (HaveLen (0 ))
0 commit comments