@@ -141,6 +141,86 @@ func TestRayClusterAutoscalerWithFakeGPU(t *testing.T) {
141141 }
142142}
143143
144+ func TestRayClusterAutoscalerWithFakeSingleHostTPU (t * testing.T ) {
145+ for _ , tc := range tests {
146+ t .Run (tc .name , func (t * testing.T ) {
147+ test := With (t )
148+ g := gomega .NewWithT (t )
149+
150+ // Create a namespace
151+ namespace := test .NewTestNamespace ()
152+
153+ // Scripts for creating and terminating detached actors to trigger autoscaling
154+ scriptsAC := newConfigMap (namespace .Name , files (test , "create_detached_actor.py" , "terminate_detached_actor.py" ))
155+ scripts , err := test .Client ().Core ().CoreV1 ().ConfigMaps (namespace .Name ).Apply (test .Ctx (), scriptsAC , TestApplyOptions )
156+ g .Expect (err ).NotTo (gomega .HaveOccurred ())
157+ LogWithTimestamp (test .T (), "Created ConfigMap %s/%s successfully" , scripts .Namespace , scripts .Name )
158+
159+ groupName := "tpu-group"
160+ workerPodTemplate := tc .WorkerPodTemplateGetter ()
161+ // Set required TPU-specific Pod fields.
162+ workerPodTemplate .Spec .NodeSelector = map [string ]string {
163+ "cloud.google.com/gke-tpu-accelerator" : "tpu-v6e-slice" ,
164+ "cloud.google.com/gke-tpu-topology" : "1x1" ,
165+ }
166+
167+ rayClusterSpecAC := rayv1ac .RayClusterSpec ().
168+ WithEnableInTreeAutoscaling (true ).
169+ WithRayVersion (GetRayVersion ()).
170+ WithHeadGroupSpec (rayv1ac .HeadGroupSpec ().
171+ WithRayStartParams (map [string ]string {"num-cpus" : "0" }).
172+ WithTemplate (tc .HeadPodTemplateGetter ())).
173+ WithWorkerGroupSpecs (rayv1ac .WorkerGroupSpec ().
174+ WithReplicas (0 ).
175+ WithMinReplicas (0 ).
176+ WithMaxReplicas (3 ).
177+ WithNumOfHosts (1 ).
178+ WithGroupName (groupName ).
179+ WithRayStartParams (map [string ]string {"num-cpus" : "1" , "resources" : `'{"TPU":4}'` }).
180+ WithTemplate (workerPodTemplate ))
181+ rayClusterAC := rayv1ac .RayCluster ("ray-cluster" , namespace .Name ).
182+ WithSpec (apply (rayClusterSpecAC , mountConfigMap [rayv1ac.RayClusterSpecApplyConfiguration ](scripts , "/home/ray/test_scripts" )))
183+
184+ rayCluster , err := test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Apply (test .Ctx (), rayClusterAC , TestApplyOptions )
185+ g .Expect (err ).NotTo (gomega .HaveOccurred ())
186+ LogWithTimestamp (test .T (), "Created RayCluster %s/%s successfully" , rayCluster .Namespace , rayCluster .Name )
187+
188+ // Wait for RayCluster to become ready and verify the number of available worker replicas.
189+ g .Eventually (RayCluster (test , rayCluster .Namespace , rayCluster .Name ), TestTimeoutMedium ).
190+ Should (gomega .WithTransform (RayClusterState , gomega .Equal (rayv1 .Ready )))
191+ g .Expect (GetRayCluster (test , rayCluster .Namespace , rayCluster .Name )).To (gomega .WithTransform (RayClusterDesiredWorkerReplicas , gomega .Equal (int32 (0 ))))
192+
193+ headPod , err := GetHeadPod (test , rayCluster )
194+ g .Expect (err ).NotTo (gomega .HaveOccurred ())
195+ LogWithTimestamp (test .T (), "Found head pod %s/%s" , headPod .Namespace , headPod .Name )
196+
197+ // Create a detached TPU actor, and a TPU worker Pod should be created.
198+ ExecPodCmd (test , headPod , common .RayHeadContainer , []string {"python" , "/home/ray/test_scripts/create_detached_actor.py" , "tpu_actor" , "--custom-resources=TPU=4" })
199+ g .Eventually (RayCluster (test , rayCluster .Namespace , rayCluster .Name ), TestTimeoutMedium ).
200+ Should (gomega .WithTransform (RayClusterDesiredWorkerReplicas , gomega .Equal (int32 (1 ))))
201+ // We don't use real TPU resources of Kubernetes here, therefore we can't test the RayClusterDesiredTPU.
202+ // We check the TPU worker group's number of Pods instead.
203+ g .Expect (GetGroupPods (test , rayCluster , groupName )).To (gomega .HaveLen (1 ))
204+ LogWithTimestamp (test .T (), "Created TPU worker of group %s" , groupName )
205+
206+ // Terminate the TPU actor to remove the allocated resource request.
207+ ExecPodCmd (test , headPod , common .RayHeadContainer , []string {"python" , "/home/ray/test_scripts/terminate_detached_actor.py" , "tpu_actor" })
208+
209+ // Set maxReplicas of the TPU worker group replica to 0 to force scale-down.
210+ // It's impossible to wait on idle timeout since the required TPU nodeSelectors prevent scheduling.
211+ rayCluster , err = test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Get (test .Ctx (), rayCluster .Name , metav1.GetOptions {})
212+ g .Expect (err ).NotTo (gomega .HaveOccurred ())
213+ rayCluster .Spec .WorkerGroupSpecs [0 ].MaxReplicas = ptr .To (int32 (0 ))
214+ rayCluster , err = test .Client ().Ray ().RayV1 ().RayClusters (namespace .Name ).Update (test .Ctx (), rayCluster , metav1.UpdateOptions {})
215+ g .Expect (err ).NotTo (gomega .HaveOccurred ())
216+ LogWithTimestamp (test .T (), "Updated RayCluster %s/%s successfully" , rayCluster .Namespace , rayCluster .Name )
217+
218+ // Validate that the TPU slice is scaled down.
219+ g .Eventually (WorkerPods (test , rayCluster ), TestTimeoutMedium ).Should (gomega .BeEmpty ())
220+ })
221+ }
222+ }
223+
144224func TestRayClusterAutoscalerWithCustomResource (t * testing.T ) {
145225 for _ , tc := range tests {
146226 t .Run (tc .name , func (t * testing.T ) {
@@ -187,7 +267,7 @@ func TestRayClusterAutoscalerWithCustomResource(t *testing.T) {
187267 LogWithTimestamp (test .T (), "Found head pod %s/%s" , headPod .Namespace , headPod .Name )
188268
189269 // Create a detached custom resource actor, and a worker in the "custom-resource-group" should be created.
190- ExecPodCmd (test , headPod , common .RayHeadContainer , []string {"python" , "/home/ray/test_scripts/create_detached_actor.py" , "custom_resource_actor" , "--num- custom-resources=1" })
270+ ExecPodCmd (test , headPod , common .RayHeadContainer , []string {"python" , "/home/ray/test_scripts/create_detached_actor.py" , "custom_resource_actor" , "--custom-resources=CustomResource =1" })
191271 g .Eventually (RayCluster (test , rayCluster .Namespace , rayCluster .Name ), TestTimeoutMedium ).
192272 Should (gomega .WithTransform (RayClusterDesiredWorkerReplicas , gomega .Equal (int32 (1 ))))
193273 g .Expect (GetGroupPods (test , rayCluster , groupName )).To (gomega .HaveLen (1 ))
0 commit comments