@@ -21,10 +21,12 @@ import (
2121 "os"
2222 "strings"
2323 "testing"
24+ "time"
2425
2526 . "github.com/onsi/gomega"
2627 . "github.com/project-codeflare/codeflare-common/support"
2728 rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
29+ metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2830)
2931
3032func TestRayFinetuneLlmDeepspeedDemo (t * testing.T ) {
@@ -55,19 +57,17 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
5557 "server = ''" : fmt .Sprintf ("server = '%s'" , GetOpenShiftApiUrl (test )),
5658 "namespace='ray-finetune-llm-deepspeed'" : fmt .Sprintf ("namespace='%s'" , namespace .Name ),
5759 "head_cpus=16" : "head_cpus=2" ,
58- "head_gpus =1" : "head_gpus =0" ,
60+ "head_extended_resource_requests =1" : "head_extended_resource_requests =0" ,
5961 "num_workers=7" : "num_workers=1" ,
60- "min_cpus =16" : "min_cpus =4" ,
61- "max_cpus =16" : "max_cpus =4" ,
62- "min_memory =128" : "min_memory=48 " ,
63- "max_memory =256" : "max_memory=48 " ,
62+ "worker_cpu_requests =16" : "worker_cpu_requests =4" ,
63+ "worker_cpu_limits =16" : "worker_cpu_limits =4" ,
64+ "worker_memory_requests =128" : "worker_memory_requests=60 " ,
65+ "worker_memory_limits =256" : "worker_memory_limits=60 " ,
6466 "head_memory=128" : "head_memory=48" ,
65- "num_gpus=1" : fmt .Sprintf ("worker_extended_resource_requests={'nvidia.com/gpu': %d},\\ n\" ,\n \t \" write_to_file=True,\\ n\" ,\n \t \" verify_tls=False" , numGpus ),
66- "image='quay.io/rhoai/ray:2.23.0-py39-cu121'" : fmt .Sprintf ("image='%s'" , GetRayImage ()),
67- "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
68- "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
69- "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
70- "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json" : "--ds-config=./zero_3_llama_2_7b.json \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" ,
67+ "client = cluster.job_client" : "ray_dashboard = cluster.cluster_dashboard_uri()\\ n\" ,\n \t \" header = {\\ \" Authorization\\ \" : \\ \" Bearer " + userToken + "\\ \" }\\ n\" ,\n \t \" client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\ n" ,
68+ "--num-devices=8" : fmt .Sprintf ("--num-devices=%d" , numGpus ),
69+ "--num-epochs=3" : fmt .Sprintf ("--num-epochs=%d" , 1 ),
70+ "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json" : "--ds-config=./zero_3_llama_2_7b.json \\ \" \\ n\" ,\n \t \" \\ \" --lora-config=./lora.json \\ \" \\ n\" ,\n \t \" \\ \" --as-test" ,
7171 "'pip': 'requirements.txt'" : "'pip': '/opt/app-root/src/requirements.txt'" ,
7272 "'working_dir': './'" : "'working_dir': '/opt/app-root/src'" ,
7373 "client.stop_job(submission_id)" : "finished = False\\ n\" ,\n \t \" while not finished:\\ n\" ,\n \t \" time.sleep(1)\\ n\" ,\n \t \" status = client.get_job_status(submission_id)\\ n\" ,\n \t \" finished = (status == \\ \" SUCCEEDED\\ \" )\\ n\" ,\n \t \" if finished:\\ n\" ,\n \t \" print(\\ \" Job completed Successfully !\\ \" )\\ n\" ,\n \t \" else:\\ n\" ,\n \t \" print(\\ \" Job failed !\\ \" )\\ n\" ,\n \t \" time.sleep(10)\\ n" ,
@@ -111,6 +111,47 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
111111 ),
112112 )
113113
114+ time .Sleep (30 * time .Second )
115+
116+ // Fetch created raycluster
117+ rayClusterName := "ray"
118+ rayCluster , err := test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Get (test .Ctx (), rayClusterName , metav1.GetOptions {})
119+ test .Expect (err ).ToNot (HaveOccurred ())
120+
121+ // Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
122+ dashboardUrl := GetDashboardUrl (test , namespace , rayCluster )
123+ rayClusterClientConfig := RayClusterClientConfig {Address : dashboardUrl .String (), Client : nil , SkipTlsVerification : true }
124+ rayClient , err := NewRayClusterClient (rayClusterClientConfig , test .Config ().BearerToken )
125+ if err != nil {
126+ test .T ().Errorf ("%s" , err )
127+ }
128+
129+ jobID := GetTestJobId (test , rayClient , dashboardUrl .Host )
130+ test .Expect (jobID ).ToNot (Equal (nil ))
131+
132+ // Wait for the job to be succeeded or failed
133+ var rayJobStatus string
134+ fmt .Printf ("Waiting for job to be Succeeded...\n " )
135+ test .Eventually (func () string {
136+ resp , err := rayClient .GetJobDetails (jobID )
137+ test .Expect (err ).ToNot (HaveOccurred ())
138+ rayJobStatusVal := resp .Status
139+ if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
140+ fmt .Printf ("JobStatus : %s\n " , rayJobStatusVal )
141+ rayJobStatus = rayJobStatusVal
142+ WriteRayJobAPILogs (test , rayClient , jobID )
143+ return rayJobStatus
144+ }
145+ if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
146+ fmt .Printf ("JobStatus : %s...\n " , rayJobStatusVal )
147+ rayJobStatus = rayJobStatusVal
148+ }
149+ return rayJobStatus
150+ }, TestTimeoutDouble , 3 * time .Second ).Should (Or (Equal ("SUCCEEDED" ), Equal ("FAILED" )), "Job did not complete within the expected time" )
151+ // Store job logs in output directory
152+ WriteRayJobAPILogs (test , rayClient , jobID )
153+ test .Expect (rayJobStatus ).To (Equal ("SUCCEEDED" ), "RayJob failed !" )
154+
114155 // Make sure the RayCluster finishes and is deleted
115156 test .Eventually (RayClusters (test , namespace .Name ), TestTimeoutMedium ).
116157 Should (HaveLen (0 ))
0 commit comments