Skip to content

Commit 76a46e9

Browse files
Reused the example fine-tune demo files and added a job status assertion
1 parent 30510df commit 76a46e9

File tree

12 files changed: 115 additions and 920 deletions.

examples/ray-finetune-llm-deepspeed/create_dataset.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,11 @@
22
import json
33
import os
44

5-
dataset = load_dataset("gsm8k", "main", cache_dir="../../datasets")
5+
datasets_dir="../../datasets"
6+
if os.path.exists(datasets_dir):
7+
dataset = load_dataset("gsm8k", "main", cache_dir=datasets_dir)
8+
else:
9+
dataset = load_dataset("gsm8k", "main")
610

711
dataset_splits = {"train": dataset["train"], "test": dataset["test"]}
812

examples/ray-finetune-llm-deepspeed/ray_finetune_llm_deepspeed.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -592,6 +592,9 @@ def parse_args():
592592

593593
parser.add_argument("--lora", action="store_true", default=False,
594594
help="If passed, will enable parameter efficient fine-tuning with LoRA.")
595+
596+
parser.add_argument("--lora-config", type=str, default="./lora_configs/lora.json",
597+
help="Lora config json to use.")
595598

596599
parser.add_argument("--num-epochs", type=int, default=1,
597600
help="Number of epochs to train for.")
@@ -660,7 +663,7 @@ def main():
660663

661664
# Add LoRA config if needed
662665
if args.lora:
663-
with open("./lora_configs/lora.json", "r") as json_file:
666+
with open(args.lora_config, "r") as json_file:
664667
lora_config = json.load(json_file)
665668
config["lora_config"] = lora_config
666669

tests/odh/notebook.go

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,11 +57,19 @@ type NotebookProps struct {
5757
func createNotebook(test Test, namespace *corev1.Namespace, notebookUserToken, jupyterNotebookConfigMapName, jupyterNotebookConfigMapFileName string, numGpus int) {
5858
// Create PVC for Notebook
5959
notebookPVC := CreatePersistentVolumeClaim(test, namespace.Name, "10Gi", corev1.ReadWriteOnce)
60-
s3BucketName, _ := GetStorageBucketName()
60+
s3BucketName, exists := GetStorageBucketName()
6161
s3AccessKeyId, _ := GetStorageBucketAccessKeyId()
6262
s3SecretAccessKey, _ := GetStorageBucketSecretKey()
6363
s3DefaultRegion, _ := GetStorageBucketDefaultRegion()
6464

65+
if !exists {
66+
println("Storage bucket doesn't exists!")
67+
s3BucketName = "\"\""
68+
s3AccessKeyId = "\"\""
69+
s3SecretAccessKey = "\"\""
70+
s3DefaultRegion = "\"\""
71+
}
72+
6573
// Read the Notebook CR from resources and perform replacements for custom values using go template
6674
notebookProps := NotebookProps{
6775
IngressDomain: GetOpenShiftIngressDomain(test),

tests/odh/ray_finetune_llm_deepspeed_test.go

Lines changed: 83 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,19 @@ limitations under the License.
1717
package odh
1818

1919
import (
20+
"crypto/tls"
21+
"encoding/json"
22+
"fmt"
23+
"io"
24+
"net/http"
25+
"os"
2026
"testing"
27+
"time"
2128

2229
. "github.com/onsi/gomega"
2330
. "github.com/project-codeflare/codeflare-common/support"
2431
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
32+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2533
)
2634

2735
func TestRayFinetuneDemo(t *testing.T) {
@@ -33,20 +41,19 @@ func mnistRayLlmFinetune(t *testing.T, numGpus int) {
3341

3442
// Create a namespace
3543
namespace := test.NewTestNamespace()
44+
var workingDirectory, _ = os.Getwd()
3645

3746
// Test configuration
3847
jupyterNotebookConfigMapFileName := "ray_finetune_llm_deepspeed.ipynb"
39-
40-
// Test configuration
4148
configMap := map[string][]byte{
4249
// MNIST Ray Notebook
4350
jupyterNotebookConfigMapFileName: ReadFile(test, "resources/ray_finetune_demo/ray_finetune_llm_deepspeed.ipynb"),
44-
"ray_finetune_llm_deepspeed.py": ReadFile(test, "resources/ray_finetune_demo/ray_finetune_llm_deepspeed.py"),
51+
"ray_finetune_llm_deepspeed.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/ray_finetune_llm_deepspeed.py"),
4552
"ray_finetune_requirements.txt": ReadRayFinetuneRequirementsTxt(test),
46-
"create_dataset.py": ReadFile(test, "resources/ray_finetune_demo/create_dataset.py"),
47-
"lora.json": ReadFile(test, "resources/ray_finetune_demo/lora.json"),
48-
"zero_3_llama_2_7b.json": ReadFile(test, "resources/ray_finetune_demo/zero_3_llama_2_7b.json"),
49-
"utils.py": ReadFile(test, "resources/ray_finetune_demo/utils.py"),
53+
"create_dataset.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/create_dataset.py"),
54+
"lora.json": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/lora_configs/lora.json"),
55+
"zero_3_llama_2_7b.json": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/deepspeed_configs/zero_3_llama_2_7b.json"),
56+
"utils.py": ReadFileExt(test, workingDirectory+"/../../examples/ray-finetune-llm-deepspeed/utils.py"),
5057
}
5158

5259
config := CreateConfigMap(test, namespace.Name, configMap)
@@ -75,9 +82,77 @@ func mnistRayLlmFinetune(t *testing.T, numGpus int) {
7582
ContainElement(WithTransform(RayClusterState, Equal(rayv1.Ready))),
7683
),
7784
)
85+
time.Sleep(30 * time.Second)
86+
87+
rayClusters, _ := test.Client().Ray().RayV1().RayClusters(namespace.Name).List(test.Ctx(), metav1.ListOptions{})
88+
test.Expect(len(rayClusters.Items)).To(BeNumerically(">", 0))
89+
dashboardName := "ray-dashboard-" + rayClusters.Items[0].Name
90+
fmt.Printf("Raycluster created : %s\n", rayClusters.Items[0].Name)
91+
route := GetRoute(test, namespace.Name, dashboardName)
92+
hostname := route.Status.Ingress[0].Host
93+
94+
// Wait for expected HTTP code
95+
fmt.Printf("Waiting for Route %s/%s to be available...\n", route.Namespace, route.Name)
96+
tr := &http.Transport{
97+
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
98+
Proxy: http.ProxyFromEnvironment,
99+
}
100+
client := &http.Client{Transport: tr}
101+
req, err := http.NewRequest("GET", "https://"+hostname+"/api/jobs/", nil)
102+
if err != nil {
103+
test.T().Fatal(err)
104+
}
105+
req.Header.Add("Authorization", "Bearer "+test.Config().BearerToken)
106+
107+
resp, err := client.Do(req)
108+
test.Expect(err).ToNot(HaveOccurred())
109+
test.Expect(resp.StatusCode).ToNot(Equal(503))
110+
defer resp.Body.Close()
111+
body, err := io.ReadAll(resp.Body)
112+
test.Expect(err).ToNot(HaveOccurred())
113+
114+
var resp_json []map[string]interface{}
115+
err = json.Unmarshal(body, &resp_json)
116+
test.Expect(err).ToNot(HaveOccurred())
117+
if len(resp_json) > 0 {
118+
fmt.Printf("Job is submitted in the raycluster!\nSubmission-ID : %s\n", resp_json[0]["submission_id"])
119+
}
120+
121+
var status string
122+
var prevStatus string
123+
fmt.Printf("Waiting for job to be Succeeded...\n")
124+
for status != "SUCCEEDED" {
125+
resp, err := client.Do(req)
126+
test.Expect(err).ToNot(HaveOccurred())
127+
body, err := io.ReadAll(resp.Body)
128+
test.Expect(err).ToNot(HaveOccurred())
129+
var result []map[string]interface{}
130+
if err := json.Unmarshal(body, &result); err != nil {
131+
time.Sleep(2 * time.Second)
132+
break
133+
}
134+
if status, ok := result[0]["status"].(string); ok {
135+
if prevStatus != status {
136+
fmt.Printf("JobStatus : %s...\n", status)
137+
prevStatus = status
138+
}
139+
if status == "SUCCEEDED" {
140+
prevStatus = status
141+
break
142+
}
143+
prevStatus = status
144+
} else {
145+
test.T().Logf("Status key not found or not a string")
146+
}
147+
time.Sleep(3 * time.Second)
148+
}
149+
if prevStatus != "SUCCEEDED" {
150+
fmt.Printf("Job failed!")
151+
}
152+
test.Expect(prevStatus).To(Equal("SUCCEEDED"))
78153

79154
// Make sure the RayCluster finishes and is deleted
80-
test.Eventually(RayClusters(test, namespace.Name), TestTimeoutGpuProvisioning).
155+
test.Eventually(RayClusters(test, namespace.Name), TestTimeoutMedium).
81156
Should(HaveLen(0))
82157
}
83158

tests/odh/resources/mnist_ray_mini.ipynb

Lines changed: 1 addition & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,5 @@
11
{
22
"cells": [
3-
{
4-
"cell_type": "code",
5-
"execution_count": null,
6-
"id": "df737457",
7-
"metadata": {},
8-
"outputs": [],
9-
"source": [
10-
"%pip install codeflare-sdk -U"
11-
]
12-
},
133
{
144
"cell_type": "code",
155
"execution_count": null,
@@ -53,8 +43,7 @@
5343
"namespace = \"default\"\n",
5444
"openshift_api_url = \"has to be specified\"\n",
5545
"kubernetes_user_bearer_token = \"has to be specified\"\n",
56-
"num_gpus = \"has to be specified\"\n",
57-
"print(\"*\"*8, namespace,openshift_api_url,kubernetes_user_bearer_token, num_gpus)"
46+
"num_gpus = \"has to be specified\""
5847
]
5948
},
6049
{

tests/odh/resources/ray_finetune_demo/create_dataset.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

tests/odh/resources/ray_finetune_demo/lora.json

Lines changed: 0 additions & 11 deletions
This file was deleted.

tests/odh/resources/ray_finetune_demo/ray_finetune_llm_deepspeed.ipynb

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@
135135
"outputs": [],
136136
"source": [
137137
"directory_path = os.path.expanduser(\"~/.codeflare/resources/\")\n",
138-
"outfile = os.path.join(directory_path, \"mnisttest.yaml\")\n",
138+
"outfile = os.path.join(directory_path, \"ray-finetune-test.yaml\")\n",
139139
"cluster_yaml = None\n",
140140
"with open(outfile) as f:\n",
141141
" cluster_yaml = yaml.load(f, yaml.FullLoader)\n",
@@ -201,12 +201,13 @@
201201
" entrypoint=\"python ray_finetune_llm_deepspeed.py \"\n",
202202
" \"--model-name=meta-llama/Llama-2-7b-chat-hf \"\n",
203203
" \"--lora \"\n",
204-
" f\"--num-devices=1 \"\n",
204+
" \"--num-devices=1 \"\n",
205205
" \"--num-epochs=1 \"\n",
206-
" f\"--ds-config=zero_3_llama_2_7b.json \"\n",
207206
" f\"--storage-path=s3://{s3_bucket_name}/ray-finetune-llm-deepspeed3/\"\n",
208207
" \"--batch-size-per-device=32 \"\n",
209208
" \"--eval-batch-size-per-device=32 \"\n",
209+
" \"--ds-config=./zero_3_llama_2_7b.json \"\n",
210+
" \"--lora-config=./lora.json \"\n",
210211
" \"--as-test \",\n",
211212
" runtime_env={\n",
212213
" \"env_vars\": {\n",
@@ -236,6 +237,7 @@
236237
" finished = (status == \"SUCCEEDED\")\n",
237238
"if finished:\n",
238239
" print(\"Job completed Successfully !\")\n",
240+
" sleep(10)\n",
239241
"else:\n",
240242
" print(\"Job failed !\")"
241243
]

0 commit comments

Comments
 (0)