Skip to content

Commit 53ffaa6

Browse files
Add e2e Test for HyperParameter Optimisation with Ray Tune
1 parent 7bbcefb commit 53ffaa6

File tree

6 files changed

+534
-15
lines changed

6 files changed

+534
-15
lines changed

tests/odh/mnist_ray_test.go

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -189,18 +189,3 @@ func readMnistPy(test Test) []byte {
189189

190190
return ParseTemplate(test, template, props)
191191
}
192-
193-
// TODO: This belongs on codeflare-common/support/ray.go
194-
func rayClusters(t Test, namespace *corev1.Namespace) func(g Gomega) []*rayv1.RayCluster {
195-
return func(g Gomega) []*rayv1.RayCluster {
196-
rcs, err := t.Client().Ray().RayV1().RayClusters(namespace.Name).List(t.Ctx(), metav1.ListOptions{})
197-
g.Expect(err).NotTo(HaveOccurred())
198-
199-
rcsp := []*rayv1.RayCluster{}
200-
for _, v := range rcs.Items {
201-
rcsp = append(rcsp, &v)
202-
}
203-
204-
return rcsp
205-
}
206-
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
Copyright 2023.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package odh
18+
19+
import (
20+
"bytes"
21+
"fmt"
22+
"testing"
23+
24+
. "github.com/onsi/gomega"
25+
. "github.com/project-codeflare/codeflare-common/support"
26+
rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
27+
corev1 "k8s.io/api/core/v1"
28+
"k8s.io/apimachinery/pkg/api/resource"
29+
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
30+
"sigs.k8s.io/kueue/apis/kueue/v1beta1"
31+
)
32+
33+
func TestMnistRayTuneHpoCpu(t *testing.T) {
34+
mnistRayTuneHpo(t, 0)
35+
}
36+
37+
func TestMnistRayTuneHpoGpu(t *testing.T) {
38+
mnistRayTuneHpo(t, 1)
39+
}
40+
41+
func mnistRayTuneHpo(t *testing.T, numGpus int) {
42+
test := With(t)
43+
44+
// Creating a namespace
45+
namespace := test.NewTestNamespace()
46+
47+
// Create Kueue resources
48+
resourceFlavor := CreateKueueResourceFlavor(test, v1beta1.ResourceFlavorSpec{})
49+
defer test.Client().Kueue().KueueV1beta1().ResourceFlavors().Delete(test.Ctx(), resourceFlavor.Name, metav1.DeleteOptions{})
50+
cqSpec := v1beta1.ClusterQueueSpec{
51+
NamespaceSelector: &metav1.LabelSelector{},
52+
ResourceGroups: []v1beta1.ResourceGroup{
53+
{
54+
CoveredResources: []corev1.ResourceName{corev1.ResourceName("cpu"), corev1.ResourceName("memory"), corev1.ResourceName("nvidia.com/gpu")},
55+
Flavors: []v1beta1.FlavorQuotas{
56+
{
57+
Name: v1beta1.ResourceFlavorReference(resourceFlavor.Name),
58+
Resources: []v1beta1.ResourceQuota{
59+
{
60+
Name: corev1.ResourceCPU,
61+
NominalQuota: resource.MustParse("8"),
62+
},
63+
{
64+
Name: corev1.ResourceMemory,
65+
NominalQuota: resource.MustParse("12Gi"),
66+
},
67+
{
68+
Name: corev1.ResourceName("nvidia.com/gpu"),
69+
NominalQuota: resource.MustParse(fmt.Sprint(numGpus)),
70+
},
71+
},
72+
},
73+
},
74+
},
75+
},
76+
}
77+
clusterQueue := CreateKueueClusterQueue(test, cqSpec)
78+
defer test.Client().Kueue().KueueV1beta1().ClusterQueues().Delete(test.Ctx(), clusterQueue.Name, metav1.DeleteOptions{})
79+
localQueue := CreateKueueLocalQueue(test, namespace.Name, clusterQueue.Name)
80+
81+
// Test configuration
82+
jupyterNotebookConfigMapFileName := "mnist_hpo_raytune.ipynb"
83+
mnist_hpo := ReadFile(test, "resources/mnist_hpo.py")
84+
85+
if numGpus > 0 {
86+
mnist_hpo = bytes.Replace(mnist_hpo, []byte("gpu_value=\"has to be specified\""), []byte("gpu_value=\"1\""), 1)
87+
} else {
88+
mnist_hpo = bytes.Replace(mnist_hpo, []byte("gpu_value=\"has to be specified\""), []byte("gpu_value=\"0\""), 1)
89+
}
90+
91+
config := CreateConfigMap(test, namespace.Name, map[string][]byte{
92+
// MNIST Raytune HPO Notebook
93+
jupyterNotebookConfigMapFileName: ReadFile(test, "resources/mnist_hpo_raytune.ipynb"),
94+
"mnist_hpo.py": mnist_hpo,
95+
"hpo_raytune_requirements.txt": ReadFile(test, "resources/hpo_raytune_requirements.txt"),
96+
})
97+
98+
// Define the regular(non-admin) user
99+
userName := GetNotebookUserName(test)
100+
userToken := GetNotebookUserToken(test)
101+
102+
// Create role binding with Namespace specific admin cluster role
103+
CreateUserRoleBindingWithClusterRole(test, userName, namespace.Name, "admin")
104+
105+
// Create Notebook CR
106+
createNotebook(test, namespace, userToken, localQueue.Name, config.Name, jupyterNotebookConfigMapFileName, numGpus)
107+
108+
// Gracefully cleanup Notebook
109+
defer func() {
110+
deleteNotebook(test, namespace)
111+
test.Eventually(listNotebooks(test, namespace), TestTimeoutMedium).Should(HaveLen(0))
112+
}()
113+
114+
// Make sure the RayCluster is created and running
115+
test.Eventually(rayClusters(test, namespace), TestTimeoutLong).
116+
Should(
117+
And(
118+
HaveLen(1),
119+
ContainElement(WithTransform(RayClusterState, Equal(rayv1.Ready))),
120+
),
121+
)
122+
123+
// Make sure the Workload is created and running
124+
test.Eventually(GetKueueWorkloads(test, namespace.Name), TestTimeoutMedium).
125+
Should(
126+
And(
127+
HaveLen(1),
128+
ContainElement(WithTransform(KueueWorkloadAdmitted, BeTrueBecause("Workload failed to be admitted"))),
129+
),
130+
)
131+
132+
// Make sure the RayCluster finishes and is deleted
133+
test.Eventually(rayClusters(test, namespace), TestTimeoutLong).
134+
Should(HaveLen(0))
135+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
torchvision==0.18.0

tests/odh/resources/mnist_hpo.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import os
2+
import tempfile
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
import torch.optim as optim
8+
from filelock import FileLock
9+
from torchvision import datasets, transforms
10+
11+
import ray
12+
from ray import train, tune
13+
from ray.train import Checkpoint
14+
from ray.tune.schedulers import AsyncHyperBandScheduler
15+
16+
# Sample budget for one training pass: train_func stops consuming batches once
# batch_idx * batch_length exceeds this value.
EPOCH_SIZE = 128
17+
# Sample budget for one evaluation pass: test_func stops consuming batches once
# batch_idx * batch_length exceeds this value.
TEST_SIZE = 64
18+
19+
20+
class ConvNet(nn.Module):
    """Small convolutional classifier producing log-probabilities over 10 classes.

    One 3x3 convolution (1 input channel, 3 output channels) is followed by a
    3x3 max-pool, ReLU, and a single linear layer. The linear layer expects 192
    features per sample, which is what a 28x28 single-channel input yields
    (3 channels x 8 x 8 after convolution and pooling).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
        self.fc = nn.Linear(192, 10)

    def forward(self, x):
        """Return per-class log-probabilities for the input batch ``x``."""
        pooled = F.max_pool2d(self.conv1(x), 3)
        # Flatten to (N, 192) so the features match the linear layer's input.
        features = F.relu(pooled).view(-1, 192)
        logits = self.fc(features)
        return F.log_softmax(logits, dim=1)
31+
32+
33+
def train_func(model, optimizer, train_loader, device=None):
    """Run one bounded training pass over ``train_loader``.

    Args:
        model: network to optimise; it is switched to training mode.
        optimizer: optimizer stepping ``model``'s parameters.
        train_loader: iterable yielding ``(data, target)`` batches.
        device: device the batches are moved to; a falsy value selects the CPU.

    The pass ends early once ``batch_idx * len(data)`` exceeds ``EPOCH_SIZE``.
    Nothing is returned.
    """
    if not device:
        device = torch.device("cpu")
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        # Cap the work per pass instead of walking the whole dataset.
        if step * len(inputs) > EPOCH_SIZE:
            break
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = F.nll_loss(model(inputs), labels)
        loss.backward()
        optimizer.step()
45+
46+
47+
def test_func(model, data_loader, device=None):
48+
device = device or torch.device("cpu")
49+
model.eval()
50+
correct = 0
51+
total = 0
52+
with torch.no_grad():
53+
for batch_idx, (data, target) in enumerate(data_loader):
54+
if batch_idx * len(data) > TEST_SIZE:
55+
break
56+
data, target = data.to(device), target.to(device)
57+
outputs = model(data)
58+
_, predicted = torch.max(outputs.data, 1)
59+
total += target.size(0)
60+
correct += (predicted == target).sum().item()
61+
62+
return correct / total
63+
64+
65+
def get_data_loaders(batch_size=128):
    """Build shuffled MNIST train and test ``DataLoader`` objects.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The datasets are stored under
        ``~/data`` and downloaded if missing.
    """
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    mnist_transforms = transforms.Compose([transforms.ToTensor(), normalize])

    def _make_loader(is_train):
        # Same root, download flag and transform for both splits.
        dataset = datasets.MNIST(
            "~/data", train=is_train, download=True, transform=mnist_transforms
        )
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )

    # Several workers may try to download the data at the same time; the file
    # lock serialises them so they do not overwrite each other's files.
    with FileLock(os.path.expanduser("~/data.lock")):
        train_loader = _make_loader(True)
        test_loader = _make_loader(False)
    return train_loader, test_loader
89+
90+
91+
def train_mnist(config):
    """Trainable: repeatedly train, evaluate and report accuracy.

    Args:
        config: mapping with ``"lr"`` and ``"momentum"`` for the SGD optimizer
            and an optional ``"should_checkpoint"`` flag (default ``False``).

    Each iteration reports ``{"mean_accuracy": acc}`` through ``train.report``;
    when checkpointing is enabled the model's state dict is saved as
    ``model.pt`` in a temporary directory and attached to the report.

    The loop has no exit of its own, so the function never returns; stopping
    the trial is left to whatever drives it.
    """
    save_checkpoints = config.get("should_checkpoint", False)
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    train_loader, test_loader = get_data_loaders()
    model = ConvNet().to(device)

    optimizer = optim.SGD(
        model.parameters(), lr=config["lr"], momentum=config["momentum"]
    )

    while True:
        train_func(model, optimizer, train_loader, device)
        metrics = {"mean_accuracy": test_func(model, test_loader, device)}

        if not save_checkpoints:
            train.report(metrics)
            continue

        # Report while the temporary directory still exists, since the
        # checkpoint is built from it.
        with tempfile.TemporaryDirectory() as tempdir:
            torch.save(model.state_dict(), os.path.join(tempdir, "model.pt"))
            train.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
114+
115+
116+
if __name__ == "__main__":
    # The e2e test replaces the next assignment textually before running the
    # script, so its exact spelling (including the missing spaces around "=")
    # must not change.
    gpu_value="has to be specified"
    resources_per_trial = {"cpu": 1, "gpu": gpu_value}

    # AsyncHyperBand stops poorly performing trials early.
    sched = AsyncHyperBandScheduler()

    trainable = tune.with_resources(train_mnist, resources=resources_per_trial)
    tune_config = tune.TuneConfig(
        metric="mean_accuracy",
        mode="max",
        scheduler=sched,
        num_samples=5,
    )
    # Stop criteria: target accuracy and an iteration cap.
    run_config = train.RunConfig(
        name="exp",
        stop={
            "mean_accuracy": 0.98,
            "training_iteration": 5,
        },
    )
    search_space = {
        "lr": tune.loguniform(1e-4, 1e-2),
        "momentum": tune.uniform(0.1, 0.9),
    }

    tuner = tune.Tuner(
        trainable,
        tune_config=tune_config,
        run_config=run_config,
        param_space=search_space,
    )
    results = tuner.fit()

    print("Best hyperparameters config is:", results.get_best_result().config)

    # Fail the script if any trial ended with an error.
    assert not results.errors

0 commit comments

Comments
 (0)