Skip to content

Commit 81de12f

Browse files
Updated ray-finetune-demo test with latest changes
1 parent 0eb5dc2 commit 81de12f

File tree

2 files changed

+60
-11
lines changed

2 files changed

+60
-11
lines changed

tests/odh/ray_finetune_llm_deepspeed_test.go

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ import (
2121
"os"
2222
"strings"
2323
"testing"
24+
"time"
2425

2526
. "github.com/onsi/gomega"
2627
. "github.com/project-codeflare/codeflare-common/support"
2728
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
29+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2830
)
2931

3032
func TestRayFinetuneLlmDeepspeedDemo(t *testing.T) {
@@ -55,19 +57,17 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
5557
"server = ''": fmt.Sprintf("server = '%s'", GetOpenShiftApiUrl(test)),
5658
"namespace='ray-finetune-llm-deepspeed'": fmt.Sprintf("namespace='%s'", namespace.Name),
5759
"head_cpus=16": "head_cpus=2",
58-
"head_gpus=1": "head_gpus=0",
60+
"head_extended_resource_requests=1": "head_extended_resource_requests=0",
5961
"num_workers=7": "num_workers=1",
60-
"min_cpus=16": "min_cpus=4",
61-
"max_cpus=16": "max_cpus=4",
62-
"min_memory=128": "min_memory=48",
63-
"max_memory=256": "max_memory=48",
62+
"worker_cpu_requests=16": "worker_cpu_requests=4",
63+
"worker_cpu_limits=16": "worker_cpu_limits=4",
64+
"worker_memory_requests=128": "worker_memory_requests=60",
65+
"worker_memory_limits=256": "worker_memory_limits=60",
6466
"head_memory=128": "head_memory=48",
65-
"num_gpus=1": fmt.Sprintf("worker_extended_resource_requests={'nvidia.com/gpu': %d},\\n\",\n\t\" write_to_file=True,\\n\",\n\t\" verify_tls=False", numGpus),
66-
"image='quay.io/rhoai/ray:2.23.0-py39-cu121'": fmt.Sprintf("image='%s'", GetRayImage()),
67-
"client = cluster.job_client": "ray_dashboard = cluster.cluster_dashboard_uri()\\n\",\n\t\"header = {\\\"Authorization\\\": \\\"Bearer " + userToken + "\\\"}\\n\",\n\t\"client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\n",
68-
"--num-devices=8": fmt.Sprintf("--num-devices=%d", numGpus),
69-
"--num-epochs=3": fmt.Sprintf("--num-epochs=%d", 1),
70-
"--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json": "--ds-config=./zero_3_llama_2_7b.json \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test",
67+
"client = cluster.job_client": "ray_dashboard = cluster.cluster_dashboard_uri()\\n\",\n\t\"header = {\\\"Authorization\\\": \\\"Bearer " + userToken + "\\\"}\\n\",\n\t\"client = RayJobClient(address=ray_dashboard, headers=header, verify=False)\\n",
68+
"--num-devices=8": fmt.Sprintf("--num-devices=%d", numGpus),
69+
"--num-epochs=3": fmt.Sprintf("--num-epochs=%d", 1),
70+
"--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json": "--ds-config=./zero_3_llama_2_7b.json \\\"\\n\",\n\t\" \\\"--lora-config=./lora.json \\\"\\n\",\n\t\" \\\"--as-test",
7171
"'pip': 'requirements.txt'": "'pip': '/opt/app-root/src/requirements.txt'",
7272
"'working_dir': './'": "'working_dir': '/opt/app-root/src'",
7373
"client.stop_job(submission_id)": "finished = False\\n\",\n\t\"while not finished:\\n\",\n\t\" time.sleep(1)\\n\",\n\t\" status = client.get_job_status(submission_id)\\n\",\n\t\" finished = (status == \\\"SUCCEEDED\\\")\\n\",\n\t\"if finished:\\n\",\n\t\" print(\\\"Job completed Successfully !\\\")\\n\",\n\t\"else:\\n\",\n\t\" print(\\\"Job failed !\\\")\\n\",\n\t\"time.sleep(10)\\n",
@@ -111,6 +111,47 @@ func rayFinetuneLlmDeepspeed(t *testing.T, numGpus int) {
111111
),
112112
)
113113

114+
time.Sleep(30 * time.Second)
115+
116+
// Fetch created raycluster
117+
rayClusterName := "ray"
118+
rayCluster, err := test.Client().Ray().RayV1().RayClusters(namespace.Name).Get(test.Ctx(), rayClusterName, metav1.GetOptions{})
119+
test.Expect(err).ToNot(HaveOccurred())
120+
121+
// Initialise raycluster client to interact with raycluster to get rayjob details using REST-API
122+
dashboardUrl := GetDashboardUrl(test, namespace, rayCluster)
123+
rayClusterClientConfig := RayClusterClientConfig{Address: dashboardUrl.String(), Client: nil, SkipTlsVerification: true}
124+
rayClient, err := NewRayClusterClient(rayClusterClientConfig, test.Config().BearerToken)
125+
if err != nil {
126+
test.T().Errorf("%s", err)
127+
}
128+
129+
jobID := GetTestJobId(test, rayClient, dashboardUrl.Host)
130+
test.Expect(jobID).ToNot(Equal(nil))
131+
132+
// Wait for the job to be succeeded or failed
133+
var rayJobStatus string
134+
fmt.Printf("Waiting for job to be Succeeded...\n")
135+
test.Eventually(func() string {
136+
resp, err := rayClient.GetJobDetails(jobID)
137+
test.Expect(err).ToNot(HaveOccurred())
138+
rayJobStatusVal := resp.Status
139+
if rayJobStatusVal == "SUCCEEDED" || rayJobStatusVal == "FAILED" {
140+
fmt.Printf("JobStatus : %s\n", rayJobStatusVal)
141+
rayJobStatus = rayJobStatusVal
142+
WriteRayJobAPILogs(test, rayClient, jobID)
143+
return rayJobStatus
144+
}
145+
if rayJobStatus != rayJobStatusVal && rayJobStatusVal != "SUCCEEDED" {
146+
fmt.Printf("JobStatus : %s...\n", rayJobStatusVal)
147+
rayJobStatus = rayJobStatusVal
148+
}
149+
return rayJobStatus
150+
}, TestTimeoutDouble, 3*time.Second).Should(Or(Equal("SUCCEEDED"), Equal("FAILED")), "Job did not complete within the expected time")
151+
// Store job logs in output directory
152+
WriteRayJobAPILogs(test, rayClient, jobID)
153+
test.Expect(rayJobStatus).To(Equal("SUCCEEDED"), "RayJob failed !")
154+
114155
// Make sure the RayCluster finishes and is deleted
115156
test.Eventually(RayClusters(test, namespace.Name), TestTimeoutMedium).
116157
Should(HaveLen(0))

tests/odh/support.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"embed"
2121
"net/http"
2222
"net/url"
23+
"os"
2324

2425
. "github.com/onsi/gomega"
2526
gomega "github.com/onsi/gomega"
@@ -39,6 +40,13 @@ func ReadFile(t support.Test, fileName string) []byte {
3940
return file
4041
}
4142

43+
func ReadFileExt(t support.Test, fileName string) []byte {
44+
t.T().Helper()
45+
file, err := os.ReadFile(fileName)
46+
t.Expect(err).NotTo(gomega.HaveOccurred())
47+
return file
48+
}
49+
4250
func GetDashboardUrl(test support.Test, namespace *v1.Namespace, rayCluster *rayv1.RayCluster) *url.URL {
4351
dashboardName := "ray-dashboard-" + rayCluster.Name
4452
test.T().Logf("Raycluster created : %s\n", rayCluster.Name)

0 commit comments

Comments
 (0)