diff --git a/apiserver/cmd/main.go b/apiserver/cmd/main.go index fd543ca19a3..b479a45441c 100644 --- a/apiserver/cmd/main.go +++ b/apiserver/cmd/main.go @@ -173,8 +173,8 @@ func startHttpProxy() { KubernetesConfig: kubernetesConfig, Middleware: corsHandler, // Always set, even if it's a no-op } - - topMux, err = apiserversdk.NewMux(muxConfig) + clientManager := manager.NewClientManager() + topMux, err = apiserversdk.NewMux(muxConfig, &clientManager) if err != nil { klog.Fatalf("Failed to create API server mux: %v", err) } diff --git a/apiserver/pkg/http/client.go b/apiserver/pkg/http/client.go index ecdab649528..bd500b133b3 100644 --- a/apiserver/pkg/http/client.go +++ b/apiserver/pkg/http/client.go @@ -25,7 +25,7 @@ type KuberayAPIServerClient struct { // See https://github.com/ray-project/kuberay/pull/3334/files#r2041183495 for details. // // Store http request handling function for unit test purpose. - executeHttpRequest func(httpRequest *http.Request, URL string) ([]byte, *rpcStatus.Status, error) + ExecuteHttpRequest func(httpRequest *http.Request, URL string) ([]byte, *rpcStatus.Status, error) baseURL string retryCfg apiserversdkutil.RetryConfig } @@ -68,13 +68,13 @@ func NewKuberayAPIServerClient(baseURL string, httpClient *http.Client, retryCfg }, retryCfg: retryCfg, } - client.executeHttpRequest = client.executeRequest + client.ExecuteHttpRequest = client.executeRequest return client } // Setter function for setting executeHttpRequest method func (krc *KuberayAPIServerClient) SetExecuteHttpRequest(fn func(httpRequest *http.Request, URL string) ([]byte, *rpcStatus.Status, error)) { - krc.executeHttpRequest = fn + krc.ExecuteHttpRequest = fn } // CreateComputeTemplate creates a new compute template. @@ -94,7 +94,7 @@ func (krc *KuberayAPIServerClient) CreateComputeTemplate(request *api.CreateComp httpRequest.Header.Add("Accept", "application/json") httpRequest.Header.Add("Content-Type", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, createURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, createURL) if err != nil { return nil, status, err } @@ -122,7 +122,7 @@ func (krc *KuberayAPIServerClient) GetComputeTemplate(request *api.GetComputeTem httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -143,7 +143,7 @@ func (krc *KuberayAPIServerClient) GetAllComputeTemplates() (*api.ListAllCompute httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -164,7 +164,7 @@ func (krc *KuberayAPIServerClient) GetAllComputeTemplatesInNamespace(request *ap httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -192,7 +192,7 @@ func (krc *KuberayAPIServerClient) CreateCluster(request *api.CreateClusterReque httpRequest.Header.Add("Accept", "application/json") httpRequest.Header.Add("Content-Type", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, createURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, createURL) if err != nil { return nil, 
status, err } @@ -219,7 +219,7 @@ func (krc *KuberayAPIServerClient) GetCluster(request *api.GetClusterRequest) (* httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -245,7 +245,7 @@ func (krc *KuberayAPIServerClient) ListClusters(request *api.ListClustersRequest httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -271,7 +271,7 @@ func (krc *KuberayAPIServerClient) ListAllClusters(request *api.ListAllClustersR httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -298,7 +298,7 @@ func (krc *KuberayAPIServerClient) CreateRayJob(request *api.CreateRayJobRequest httpRequest.Header.Add("Accept", "application/json") httpRequest.Header.Add("Content-Type", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, createURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, createURL) if err != nil { return nil, status, err } @@ -319,7 +319,7 @@ func (krc *KuberayAPIServerClient) GetRayJob(request *api.GetRayJobRequest) (*ap httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -345,7 +345,7 @@ func (krc *KuberayAPIServerClient) ListRayJobs(request *api.ListRayJobsRequest) httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -370,7 +370,7 @@ func (krc *KuberayAPIServerClient) ListAllRayJobs(request *api.ListAllRayJobsReq httpRequest.URL.RawQuery = q.Encode() httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -403,7 +403,7 @@ func (krc *KuberayAPIServerClient) CreateRayService(request *api.CreateRayServic httpRequest.Header.Add("Accept", "application/json") httpRequest.Header.Add("Content-Type", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, createURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, createURL) if err != nil { return nil, status, err } @@ -430,7 +430,7 @@ func (krc *KuberayAPIServerClient) UpdateRayService(request *api.UpdateRayServic httpRequest.Header.Add("Accept", "application/json") httpRequest.Header.Add("Content-Type", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, updateURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, updateURL) if err != nil { return nil, status, err } @@ -451,7 +451,7 @@ func (krc *KuberayAPIServerClient) GetRayService(request *api.GetRayServiceReque httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, 
err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -476,7 +476,7 @@ func (krc *KuberayAPIServerClient) ListRayServices(request *api.ListRayServicesR httpRequest.URL.RawQuery = q.Encode() httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -502,7 +502,7 @@ func (krc *KuberayAPIServerClient) ListAllRayServices(request *api.ListAllRaySer httpRequest.URL.RawQuery = q.Encode() httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -535,7 +535,7 @@ func (krc *KuberayAPIServerClient) SubmitRayJob(request *api.SubmitRayJobRequest httpRequest.Header.Add("Accept", "application/json") httpRequest.Header.Add("Content-Type", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, createURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, createURL) if err != nil { return nil, status, err } @@ -556,7 +556,7 @@ func (krc *KuberayAPIServerClient) GetRayJobDetails(request *api.GetJobDetailsRe httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -577,7 +577,7 @@ func (krc *KuberayAPIServerClient) GetRayJobLog(request *api.GetJobLogRequest) ( httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -598,7 +598,7 @@ func (krc *KuberayAPIServerClient) ListRayJobsCluster(request *api.ListJobDetail httpRequest.Header.Add("Accept", "application/json") - bodyBytes, status, err := krc.executeHttpRequest(httpRequest, getURL) + bodyBytes, status, err := krc.ExecuteHttpRequest(httpRequest, getURL) if err != nil { return nil, status, err } @@ -621,7 +621,7 @@ func (krc *KuberayAPIServerClient) StopRayJob(request *api.StopRayJobSubmissionR httpRequest.Header.Add("Accept", "application/json") httpRequest.Header.Add("Content-Type", "application/json") - _, status, err := krc.executeHttpRequest(httpRequest, createURL) + _, status, err := krc.ExecuteHttpRequest(httpRequest, createURL) if err != nil { return status, err } @@ -640,7 +640,7 @@ func (krc *KuberayAPIServerClient) doDelete(deleteURL string) (*rpcStatus.Status return nil, fmt.Errorf("failed to create http request for url '%s': %w", deleteURL, err) } httpRequest.Header.Add("Accept", "application/json") - _, status, err := krc.executeHttpRequest(httpRequest, deleteURL) + _, status, err := krc.ExecuteHttpRequest(httpRequest, deleteURL) return status, err } diff --git a/apiserver/pkg/http/client_test.go b/apiserver/pkg/http/client_test.go index c4c62828ee3..b90887559d0 100644 --- a/apiserver/pkg/http/client_test.go +++ b/apiserver/pkg/http/client_test.go @@ -62,7 +62,7 @@ func TestUnmarshalHttpResponseOK(t *testing.T) { } client := NewKuberayAPIServerClient("baseurl", nil /*httpClient*/, retryCfg) - client.executeHttpRequest = func(_ *http.Request, _ string) ([]byte, *rpcStatus.Status, error) { + client.ExecuteHttpRequest = func(_ 
*http.Request, _ string) ([]byte, *rpcStatus.Status, error) { resp := &api.ListClustersResponse{ Clusters: []*api.Cluster{ { @@ -98,7 +98,7 @@ func TestUnmarshalHttpResponseFails(t *testing.T) { } client := NewKuberayAPIServerClient("baseurl", nil /*httpClient*/, retryCfg) - client.executeHttpRequest = func(_ *http.Request, _ string) ([]byte, *rpcStatus.Status, error) { + client.ExecuteHttpRequest = func(_ *http.Request, _ string) ([]byte, *rpcStatus.Status, error) { // Intentionally returning a bad response. return []byte("helloworld"), nil, nil } diff --git a/apiserver/pkg/manager/resource_manager.go b/apiserver/pkg/manager/resource_manager.go index 9b8f96daa5a..bb4f038de7d 100644 --- a/apiserver/pkg/manager/resource_manager.go +++ b/apiserver/pkg/manager/resource_manager.go @@ -58,7 +58,7 @@ func (r *ResourceManager) getKubernetesNamespaceClient() clientv1.NamespaceInter // clusters func (r *ResourceManager) CreateCluster(ctx context.Context, apiCluster *api.Cluster) (*rayv1api.RayCluster, error) { // populate cluster map - computeTemplateDict, err := r.populateComputeTemplate(ctx, apiCluster.ClusterSpec, apiCluster.Namespace) + computeTemplateDict, err := r.PopulateComputeTemplate(ctx, apiCluster.ClusterSpec, apiCluster.Namespace) if err != nil { return nil, util.NewInternalServerError(err, "Failed to populate compute template for (%s/%s)", apiCluster.Namespace, apiCluster.Name) } @@ -82,13 +82,13 @@ func (r *ResourceManager) CreateCluster(ctx context.Context, apiCluster *api.Clu } // Compute template -func (r *ResourceManager) populateComputeTemplate(ctx context.Context, clusterSpec *api.ClusterSpec, nameSpace string) (map[string]*api.ComputeTemplate, error) { +func (r *ResourceManager) PopulateComputeTemplate(ctx context.Context, clusterSpec *api.ClusterSpec, nameSpace string) (map[string]*api.ComputeTemplate, error) { dict := map[string]*api.ComputeTemplate{} // populate head compute template name := clusterSpec.HeadGroupSpec.ComputeTemplate configMap, err := r.GetComputeTemplate(ctx, name, nameSpace) if err != nil { - return nil, err + return nil, fmt.Errorf("cannot get compute template '%s' in namespace '%s': %w", name, nameSpace, err) } computeTemplate := model.FromKubeToAPIComputeTemplate(configMap) dict[name] = computeTemplate @@ -99,7 +99,7 @@ func (r *ResourceManager) populateComputeTemplate(ctx context.Context, clusterSp if _, exist := dict[name]; !exist { configMap, err := r.GetComputeTemplate(ctx, name, nameSpace) if err != nil { - return nil, err + return nil, fmt.Errorf("cannot get compute template '%s' in namespace '%s': %w", name, nameSpace, err) } computeTemplate := model.FromKubeToAPIComputeTemplate(configMap) dict[name] = computeTemplate @@ -160,7 +160,7 @@ func (r *ResourceManager) CreateJob(ctx context.Context, apiJob *api.RayJob) (*r // populate cluster map if apiJob.ClusterSpec != nil { - computeTemplateMap, err = r.populateComputeTemplate(ctx, apiJob.ClusterSpec, apiJob.Namespace) + computeTemplateMap, err = r.PopulateComputeTemplate(ctx, apiJob.ClusterSpec, apiJob.Namespace) if err != nil { return nil, util.NewInternalServerError(err, "Failed to populate compute template for (%s/%s)", apiJob.Namespace, apiJob.JobId) } @@ -227,7 +227,7 @@ func (r *ResourceManager) DeleteJob(ctx context.Context, jobName string, namespa func (r *ResourceManager) CreateService(ctx context.Context, apiService *api.RayService) (*rayv1api.RayService, error) { // populate cluster map - computeTemplateDict, err := r.populateComputeTemplate(ctx,
apiService.ClusterSpec, apiService.Namespace) + computeTemplateDict, err := r.PopulateComputeTemplate(ctx, apiService.ClusterSpec, apiService.Namespace) if err != nil { return nil, util.NewInternalServerError(err, "Failed to populate compute template for (%s/%s)", apiService.Namespace, apiService.Name) } @@ -254,7 +254,7 @@ func (r *ResourceManager) UpdateRayService(ctx context.Context, apiService *api. return nil, util.Wrap(err, fmt.Sprintf("Update service fail, no service named: %s ", name)) } // populate cluster map - computeTemplateDict, err := r.populateComputeTemplate(ctx, apiService.ClusterSpec, apiService.Namespace) + computeTemplateDict, err := r.PopulateComputeTemplate(ctx, apiService.ClusterSpec, apiService.Namespace) if err != nil { return nil, util.NewInternalServerError(err, "Failed to populate compute template for (%s/%s)", apiService.Namespace, apiService.Name) } diff --git a/apiserver/pkg/manager/resource_manager_test.go b/apiserver/pkg/manager/resource_manager_test.go index a988e18deb3..94bd74f9a4c 100644 --- a/apiserver/pkg/manager/resource_manager_test.go +++ b/apiserver/pkg/manager/resource_manager_test.go @@ -72,7 +72,7 @@ func TestPopulateComputeTemplate(t *testing.T) { // Run resourceManager := NewResourceManager(mockClientManager) - computeTemplates, err := resourceManager.populateComputeTemplate(ctx, clusterSpec, namespace) + computeTemplates, err := resourceManager.PopulateComputeTemplate(ctx, clusterSpec, namespace) // Assert require.NoError(t, err) diff --git a/apiserver/test/e2e/apiserversdk/compute_template_e2e_test.go b/apiserver/test/e2e/apiserversdk/compute_template_e2e_test.go new file mode 100644 index 00000000000..2040b5364d7 --- /dev/null +++ b/apiserver/test/e2e/apiserversdk/compute_template_e2e_test.go @@ -0,0 +1,300 @@ +package apiserversdk + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + + "github.com/ray-project/kuberay/apiserver/pkg/util" + api "github.com/ray-project/kuberay/proto/go_client" +) + +// TestComputeTemplateMiddleware tests that the middleware correctly applies compute template +// resources to RayCluster, RayJob, and RayService resources +func TestComputeTemplateMiddleware(t *testing.T) { + tCtx, err := NewEnd2EndTestingContext(t) + require.NoError(t, err, "No error expected when creating testing context") + + // Create a compute template with specific resources and tolerations + templateName := tCtx.GetNextName() + computeTemplate := &api.ComputeTemplate{ + Name: templateName, + Namespace: tCtx.GetNamespaceName(), + Cpu: 2, + Memory: 4, + Gpu: 1, + GpuAccelerator: "nvidia.com/gpu", + ExtendedResources: map[string]uint32{ + "custom.io/special-resource": 2, + }, + Tolerations: []*api.PodToleration{ + { + Key: "ray.io/node-type", + Operator: "Equal", + Value: "worker", + Effect: "NoSchedule", + }, + }, + } + + // Create compute template + _, _, err = tCtx.GetRayAPIServerClient().CreateComputeTemplate(&api.CreateComputeTemplateRequest{ + ComputeTemplate: computeTemplate, + Namespace: tCtx.GetNamespaceName(), + }) + require.NoError(t, err, "No error expected when creating compute template") + + t.Cleanup(func() { + tCtx.DeleteComputeTemplate(t, templateName) + }) + + t.Run("RayCluster with compute template", func(t *testing.T) { + testRayClusterWithComputeTemplate(t, tCtx, templateName, computeTemplate) + }) + + t.Run("RayJob with compute template", func(t *testing.T) { + testRayJobWithComputeTemplate(t, tCtx, templateName, 
computeTemplate) + }) + + t.Run("RayService with compute template", func(t *testing.T) { + testRayServiceWithComputeTemplate(t, tCtx, templateName, computeTemplate) + }) +} + +func testRayClusterWithComputeTemplate(t *testing.T, tCtx *End2EndTestingContext, templateName string, computeTemplate *api.ComputeTemplate) { + clusterName := tCtx.GetNextName() + + // Create RayCluster YAML with compute template references + rayClusterYAML := fmt.Sprintf(` +apiVersion: ray.io/v1 +kind: RayCluster +metadata: + name: %s + namespace: %s +spec: + headGroupSpec: + computeTemplate: %s + template: + spec: + containers: + - name: ray-head + image: %s + workerGroupSpecs: + - groupName: worker-group + computeTemplate: %s + replicas: 1 + minReplicas: 1 + maxReplicas: 3 + template: + spec: + containers: + - name: ray-worker + image: %s +`, clusterName, tCtx.GetNamespaceName(), templateName, tCtx.GetRayImage(), templateName, tCtx.GetRayImage()) + + // Send HTTP POST request to apiserver proxy to create RayCluster + _, err := tCtx.SendYAMLRequest("POST", fmt.Sprintf("/apis/ray.io/v1/namespaces/%s/rayclusters", tCtx.GetNamespaceName()), rayClusterYAML) + require.NoError(t, err, "No error expected when sending YAML request") + + t.Cleanup(func() { + tCtx.DeleteRayCluster(t, clusterName) + }) + + // Verify the actual RayCluster was created with correct resources applied by middleware + actualCluster, err := tCtx.GetRayClusterByName(clusterName) + require.NoError(t, err, "No error expected when getting ray cluster") + + // Verify head group has correct resources and annotations + verifyPodSpecResources(t, &actualCluster.Spec.HeadGroupSpec.Template.Spec, "head", computeTemplate) + verifyComputeTemplateAnnotation(t, actualCluster.Spec.HeadGroupSpec.Template.ObjectMeta, templateName) + + // Verify worker group has correct resources and annotations + require.Len(t, actualCluster.Spec.WorkerGroupSpecs, 1, "Expected one worker group") + verifyPodSpecResources(t, &actualCluster.Spec.WorkerGroupSpecs[0].Template.Spec, "worker", computeTemplate) + verifyComputeTemplateAnnotation(t, actualCluster.Spec.WorkerGroupSpecs[0].Template.ObjectMeta, templateName) +} + +func testRayJobWithComputeTemplate(t *testing.T, tCtx *End2EndTestingContext, templateName string, computeTemplate *api.ComputeTemplate) { + jobName := tCtx.GetNextName() + + // Create RayJob YAML with compute template references + rayJobYAML := fmt.Sprintf(` +apiVersion: ray.io/v1 +kind: RayJob +metadata: + name: %s + namespace: %s +spec: + entrypoint: "python -c \"import ray; ray.init(); print('Hello from Ray Job')\"" + rayClusterSpec: + headGroupSpec: + computeTemplate: %s + rayStartParams: + dashboard-host: "0.0.0.0" + template: + spec: + containers: + - name: ray-head + image: %s + workerGroupSpecs: + - groupName: worker-group + computeTemplate: %s + replicas: 1 + minReplicas: 1 + maxReplicas: 1 + rayStartParams: + node-ip-address: "$MY_POD_IP" + template: + spec: + containers: + - name: ray-worker + image: %s +`, jobName, tCtx.GetNamespaceName(), templateName, tCtx.GetRayImage(), templateName, tCtx.GetRayImage()) + + // Send HTTP POST request to apiserver proxy to create RayJob + _, err := tCtx.SendYAMLRequest("POST", fmt.Sprintf("/apis/ray.io/v1/namespaces/%s/rayjobs", tCtx.GetNamespaceName()), rayJobYAML) + require.NoError(t, err, "No error expected when sending YAML request") + + t.Cleanup(func() { + tCtx.DeleteRayJobByName(t, jobName) + }) + + // Verify the actual RayJob was created with correct resources applied by middleware + actualRayJob, err := 
tCtx.GetRayJobByName(jobName) + require.NoError(t, err, "No error expected when getting ray job") + + // Verify head group has correct resources and annotations + verifyPodSpecResources(t, &actualRayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec, "head", computeTemplate) + verifyComputeTemplateAnnotation(t, actualRayJob.Spec.RayClusterSpec.HeadGroupSpec.Template.ObjectMeta, templateName) + + // Verify worker group has correct resources and annotations + require.Len(t, actualRayJob.Spec.RayClusterSpec.WorkerGroupSpecs, 1, "Expected one worker group") + verifyPodSpecResources(t, &actualRayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec, "worker", computeTemplate) + verifyComputeTemplateAnnotation(t, actualRayJob.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.ObjectMeta, templateName) +} + +func testRayServiceWithComputeTemplate(t *testing.T, tCtx *End2EndTestingContext, templateName string, computeTemplate *api.ComputeTemplate) { + serviceName := tCtx.GetNextName() + + // Create RayService YAML with compute template references + rayServiceYAML := fmt.Sprintf(` +apiVersion: ray.io/v1 +kind: RayService +metadata: + name: %s + namespace: %s +spec: + rayClusterConfig: + headGroupSpec: + computeTemplate: %s + rayStartParams: + dashboard-host: "0.0.0.0" + template: + spec: + containers: + - name: ray-head + image: %s + workerGroupSpecs: + - groupName: worker-group + computeTemplate: %s + replicas: 1 + minReplicas: 1 + maxReplicas: 1 + rayStartParams: + node-ip-address: "$MY_POD_IP" + template: + spec: + containers: + - name: ray-worker + image: %s +`, serviceName, tCtx.GetNamespaceName(), templateName, tCtx.GetRayImage(), templateName, tCtx.GetRayImage()) + + // Send HTTP POST request to apiserver proxy to create RayService + _, err := tCtx.SendYAMLRequest("POST", fmt.Sprintf("/apis/ray.io/v1/namespaces/%s/rayservices", tCtx.GetNamespaceName()), rayServiceYAML) + require.NoError(t, err, "No error expected when sending YAML request") + + t.Cleanup(func() { + tCtx.DeleteRayService(t, serviceName) + }) + + // Verify the actual RayService was created with correct resources applied by middleware + actualRayService, err := tCtx.GetRayServiceByName(serviceName) + require.NoError(t, err, "No error expected when getting ray service") + + // Verify head group has correct resources and annotations + verifyPodSpecResources(t, &actualRayService.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec, "head", computeTemplate) + verifyComputeTemplateAnnotation(t, actualRayService.Spec.RayClusterSpec.HeadGroupSpec.Template.ObjectMeta, templateName) + + // Verify worker group has correct resources and annotations + require.Len(t, actualRayService.Spec.RayClusterSpec.WorkerGroupSpecs, 1, "Expected one worker group") + verifyPodSpecResources(t, &actualRayService.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec, "worker", computeTemplate) + verifyComputeTemplateAnnotation(t, actualRayService.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.ObjectMeta, templateName) +} + +// verifyPodSpecResources verifies that the PodSpec has the expected resources from the compute template +func verifyPodSpecResources(t *testing.T, podSpec *corev1.PodSpec, groupType string, computeTemplate *api.ComputeTemplate) { + require.NotEmpty(t, podSpec.Containers, "Expected at least one container") + + // Find the ray container (ray-head or ray-worker) + expectedContainerName := fmt.Sprintf("ray-%s", groupType) + var rayContainer *corev1.Container + for i := range podSpec.Containers { + if podSpec.Containers[i].Name == 
expectedContainerName { + rayContainer = &podSpec.Containers[i] + break + } + } + + require.NotNil(t, rayContainer, "Expected to find ray container with name %s", expectedContainerName) + + // Verify CPU and memory resources + require.NotNil(t, rayContainer.Resources.Limits, "Expected resource limits") + require.NotNil(t, rayContainer.Resources.Requests, "Expected resource requests") + + cpuLimit := rayContainer.Resources.Limits[corev1.ResourceCPU] + cpuRequest := rayContainer.Resources.Requests[corev1.ResourceCPU] + require.Equal(t, fmt.Sprint(computeTemplate.GetCpu()), cpuLimit.String(), "CPU limit mismatch") + require.Equal(t, fmt.Sprint(computeTemplate.GetCpu()), cpuRequest.String(), "CPU request mismatch") + + // Check Memory + expectedMemory := fmt.Sprintf("%dGi", computeTemplate.GetMemory()) + memoryLimit := rayContainer.Resources.Limits[corev1.ResourceMemory] + memoryRequest := rayContainer.Resources.Requests[corev1.ResourceMemory] + require.Equal(t, expectedMemory, memoryLimit.String(), "Memory limit mismatch") + require.Equal(t, expectedMemory, memoryRequest.String(), "Memory request mismatch") + + // Check GPU + gpuLimit := rayContainer.Resources.Limits["nvidia.com/gpu"] + gpuRequest := rayContainer.Resources.Requests["nvidia.com/gpu"] + require.Equal(t, fmt.Sprint(computeTemplate.GetGpu()), gpuLimit.String(), "GPU limit mismatch") + require.Equal(t, fmt.Sprint(computeTemplate.GetGpu()), gpuRequest.String(), "GPU request mismatch") + + // Check extended resources + for name, val := range computeTemplate.ExtendedResources { + extResourceLimit := rayContainer.Resources.Limits[corev1.ResourceName(name)] + extResourceRequest := rayContainer.Resources.Requests[corev1.ResourceName(name)] + require.Equal(t, fmt.Sprint(val), extResourceLimit.String(), "Extended resource limit mismatch") + require.Equal(t, fmt.Sprint(val), extResourceRequest.String(), "Extended resource request mismatch") + } + + // Verify tolerations are applied to the pod spec + require.Len(t, podSpec.Tolerations, len(computeTemplate.Tolerations), "Toleration count mismatch") + for i, toleration := range podSpec.Tolerations { + expectedToleration := computeTemplate.Tolerations[i] + require.Equal(t, expectedToleration.Key, toleration.Key, "Toleration key mismatch") + require.Equal(t, corev1.TolerationOperator(expectedToleration.Operator), toleration.Operator, "Toleration operator mismatch") + require.Equal(t, expectedToleration.Value, toleration.Value, "Toleration value mismatch") + require.Equal(t, corev1.TaintEffect(expectedToleration.Effect), toleration.Effect, "Toleration effect mismatch") + + } +} + +// verifyComputeTemplateAnnotation verifies that the compute template annotation is set correctly +func verifyComputeTemplateAnnotation(t *testing.T, objMeta metav1.ObjectMeta, expectedTemplateName string) { + require.NotNil(t, objMeta.Annotations, "Expected annotations to be set") + actualTemplateName := objMeta.Annotations[util.RayClusterComputeTemplateAnnotationKey] + require.Equal(t, expectedTemplateName, actualTemplateName, "Expected compute template annotation to be set correctly") +} diff --git a/apiserver/test/e2e/apiserversdk/event_e2e_test.go b/apiserver/test/e2e/apiserversdk/event_e2e_test.go index 6d9f6e01242..4ae08ddd2f3 100644 --- a/apiserver/test/e2e/apiserversdk/event_e2e_test.go +++ b/apiserver/test/e2e/apiserversdk/event_e2e_test.go @@ -37,7 +37,7 @@ func TestGetRayClusterEvent(t
*testing.T) { _, err = rayClient.RayClusters(tCtx.GetNamespaceName()).Create(tCtx.GetCtx(), rayCluster, metav1.CreateOptions{}) require.NoError(t, err) - k8sClient := tCtx.GetK8sHttpClient() + k8sClient := tCtx.GetK8sClient() g := gomega.NewWithT(t) g.Eventually(func() bool { events, err := k8sClient.CoreV1().Events(tCtx.GetNamespaceName()).List(tCtx.GetCtx(), metav1.ListOptions{}) diff --git a/apiserver/test/e2e/apiserversdk/proxy_e2e_test.go b/apiserver/test/e2e/apiserversdk/proxy_e2e_test.go index d0033f2b68e..cb4835a9cae 100644 --- a/apiserver/test/e2e/apiserversdk/proxy_e2e_test.go +++ b/apiserver/test/e2e/apiserversdk/proxy_e2e_test.go @@ -41,7 +41,7 @@ func TestGetRayClusterProxy(t *testing.T) { // Wait for the Ray cluster's head pod to be ready so that we can access the dashboard waitForClusterConditions(t, tCtx, tCtx.GetRayClusterName(), []rayv1.RayClusterConditionType{rayv1.HeadPodReady}) - k8sClient := tCtx.GetK8sHttpClient() + k8sClient := tCtx.GetK8sClient() serviceName := tCtx.GetRayClusterName() + "-head-svc" r := k8sClient.CoreV1().Services(tCtx.GetNamespaceName()).ProxyGet("http", serviceName, "8265", "", nil) resp, err := r.DoRaw(tCtx.GetCtx()) diff --git a/apiserver/test/e2e/apiserversdk/types.go b/apiserver/test/e2e/apiserversdk/types.go index 20f03358ec4..c24a8e6906b 100644 --- a/apiserver/test/e2e/apiserversdk/types.go +++ b/apiserver/test/e2e/apiserversdk/types.go @@ -1,6 +1,7 @@ package apiserversdk import ( + "bytes" "context" "fmt" "net/http" @@ -18,8 +19,11 @@ import ( "k8s.io/client-go/kubernetes" "sigs.k8s.io/controller-runtime/pkg/client/config" - rayv1api "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" - rayv1 "github.com/ray-project/kuberay/ray-operator/pkg/client/clientset/versioned/typed/ray/v1" + kuberayHTTP "github.com/ray-project/kuberay/apiserver/pkg/http" + util "github.com/ray-project/kuberay/apiserversdk/util" + api "github.com/ray-project/kuberay/proto/go_client" + rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" + rayv1client "github.com/ray-project/kuberay/ray-operator/pkg/client/clientset/versioned/typed/ray/v1" ) // GenericEnd2EndTest struct allows for reuse in setting up and running tests @@ -32,14 +36,16 @@ type GenericEnd2EndTest[I proto.Message] struct { // End2EndTestingContext provides a common set of values and methods that // can be used in executing the tests type End2EndTestingContext struct { - ctx context.Context - rayHttpClient rayv1.RayV1Interface - k8sHttpClient *kubernetes.Clientset - k8client *kubernetes.Clientset - apiServerBaseURL string - rayImage string - namespaceName string - clusterName string + ctx context.Context + apiServerHttpClient *http.Client + kuberayAPIServerClient *kuberayHTTP.KuberayAPIServerClient + rayClient rayv1client.RayV1Interface + k8client *kubernetes.Clientset + apiServerBaseURL string + rayImage string + namespaceName string + clusterName string + currentName string } // contextOption is a functional option that allows for building out an instance @@ -54,10 +60,10 @@ func NewEnd2EndTestingContext(t *testing.T) (*End2EndTestingContext, error) { withRayImage(), withBaseURL(), withRayHttpClient(), - withK8sHttpClient(), withK8sClient(), withContext(), withNamespace(), + withAPIServerClient(), ) } @@ -84,20 +90,7 @@ func withRayHttpClient() contextOption { require.NoError(t, err) httpClient := &http.Client{Transport: rt} - testingContext.rayHttpClient, err = rayv1.NewForConfigAndClient(kubernetesConfig, httpClient) - if err != nil { - return err - } - return nil - } -} - -func 
withK8sHttpClient() contextOption { - return func(t *testing.T, testingContext *End2EndTestingContext) error { - kubernetesConfig, err := config.GetConfig() - require.NoError(t, err) - - testingContext.k8sHttpClient, err = kubernetes.NewForConfig(kubernetesConfig) + testingContext.rayClient, err = rayv1client.NewForConfigAndClient(kubernetesConfig, httpClient) if err != nil { return err } @@ -113,12 +106,12 @@ func withContext() contextOption { } func withBaseURL() contextOption { - return func(_ *testing.T, testingContext *End2EndTestingContext) error { - baseURL := os.Getenv("E2E_API_SERVER_URL") - if strings.TrimSpace(baseURL) == "" { - baseURL = "http://localhost:8888" - } - testingContext.apiServerBaseURL = baseURL + return func(t *testing.T, testingContext *End2EndTestingContext) error { + // Use ProxyRoundTripper with Kubernetes API server URL. The + // ProxyRoundTripper will route to the kuberay-apiserver service + kubernetesConfig, err := config.GetConfig() + require.NoError(t, err) + testingContext.apiServerBaseURL = kubernetesConfig.Host return nil } } @@ -178,16 +171,12 @@ func (e2etc *End2EndTestingContext) GetCtx() context.Context { return e2etc.ctx } -func (e2etc *End2EndTestingContext) GetK8sHttpClient() *kubernetes.Clientset { - return e2etc.k8sHttpClient -} - -func (e2etc *End2EndTestingContext) GetRayHttpClient() rayv1.RayV1Interface { - return e2etc.rayHttpClient +func (e2etc *End2EndTestingContext) GetRayHttpClient() rayv1client.RayV1Interface { + return e2etc.rayClient } -func (e2etc *End2EndTestingContext) GetRayClusterByName(clusterName string) (*rayv1api.RayCluster, error) { - return e2etc.rayHttpClient.RayClusters(e2etc.namespaceName).Get(e2etc.ctx, clusterName, metav1.GetOptions{}) +func (e2etc *End2EndTestingContext) GetRayClusterByName(clusterName string) (*rayv1.RayCluster, error) { + return e2etc.rayClient.RayClusters(e2etc.namespaceName).Get(e2etc.ctx, clusterName, metav1.GetOptions{}) } func (e2etc *End2EndTestingContext) GetRayClusterName() string { @@ -201,3 +190,87 @@ func (e2etc *End2EndTestingContext) GetNamespaceName() string { func (e2etc *End2EndTestingContext) GetRayImage() string { return e2etc.rayImage } + +func withAPIServerClient() contextOption { + return func(t *testing.T, testingContext *End2EndTestingContext) error { + kubernetesConfig, err := config.GetConfig() + require.NoError(t, err) + + rt, err := newProxyRoundTripper(kubernetesConfig) + require.NoError(t, err) + httpClient := &http.Client{Transport: rt, Timeout: time.Duration(10) * time.Second} + + testingContext.apiServerHttpClient = httpClient + + retryCfg := util.RetryConfig{ + MaxRetry: util.HTTPClientDefaultMaxRetry, + BackoffFactor: util.HTTPClientDefaultBackoffFactor, + InitBackoff: util.HTTPClientDefaultInitBackoff, + MaxBackoff: util.HTTPClientDefaultMaxBackoff, + OverallTimeout: util.HTTPClientDefaultOverallTimeout, + } + + testingContext.kuberayAPIServerClient = kuberayHTTP.NewKuberayAPIServerClient(testingContext.apiServerBaseURL, testingContext.apiServerHttpClient, retryCfg) + + return nil + } +} + +func (e2etc *End2EndTestingContext) GetRayAPIServerClient() *kuberayHTTP.KuberayAPIServerClient { + return e2etc.kuberayAPIServerClient +} + +func (e2etc *End2EndTestingContext) GetK8sClient() *kubernetes.Clientset { + return e2etc.k8client +} + +func (e2etc *End2EndTestingContext) GetNextName() string { + e2etc.currentName = petnames.Name() + return e2etc.currentName +} + +func (e2etc *End2EndTestingContext) DeleteComputeTemplate(t *testing.T, computeTemplateName string) { + 
deleteComputeTemplateRequest := &api.DeleteComputeTemplateRequest{ + Name: computeTemplateName, + Namespace: e2etc.namespaceName, + } + _, err := e2etc.kuberayAPIServerClient.DeleteComputeTemplate(deleteComputeTemplateRequest) + require.NoErrorf(t, err, "No error expected while deleting a compute template (%s, %s)", computeTemplateName, e2etc.namespaceName) +} + +func (e2etc *End2EndTestingContext) DeleteRayCluster(t *testing.T, clusterName string) { + err := e2etc.rayClient.RayClusters(e2etc.namespaceName).Delete(e2etc.ctx, clusterName, metav1.DeleteOptions{}) + require.NoError(t, err, "No error expected when deleting ray cluster") +} + +func (e2etc *End2EndTestingContext) DeleteRayJobByName(t *testing.T, jobName string) { + err := e2etc.rayClient.RayJobs(e2etc.namespaceName).Delete(e2etc.ctx, jobName, metav1.DeleteOptions{}) + require.NoError(t, err, "No error expected when deleting ray job") +} + +func (e2etc *End2EndTestingContext) DeleteRayService(t *testing.T, serviceName string) { + err := e2etc.rayClient.RayServices(e2etc.namespaceName).Delete(e2etc.ctx, serviceName, metav1.DeleteOptions{}) + require.NoError(t, err, "No error expected when deleting ray service") +} + +func (e2etc *End2EndTestingContext) GetRayJobByName(jobName string) (*rayv1.RayJob, error) { + return e2etc.rayClient.RayJobs(e2etc.namespaceName).Get(e2etc.ctx, jobName, metav1.GetOptions{}) +} + +func (e2etc *End2EndTestingContext) GetRayServiceByName(serviceName string) (*rayv1.RayService, error) { + return e2etc.rayClient.RayServices(e2etc.namespaceName).Get(e2etc.ctx, serviceName, metav1.GetOptions{}) +} + +// SendYAMLRequest sends a YAML request to the apiserver proxy with the specified method, path, and YAML content +func (e2etc *End2EndTestingContext) SendYAMLRequest(method, path, yamlContent string) ([]byte, error) { + url := e2etc.apiServerBaseURL + path + req, err := http.NewRequestWithContext(e2etc.ctx, method, url, bytes.NewBufferString(yamlContent)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/yaml") + + bodyBytes, _, err := e2etc.kuberayAPIServerClient.ExecuteHttpRequest(req, url) + + return bodyBytes, err +} diff --git a/apiserversdk/proxy.go b/apiserversdk/proxy.go index 49c276fd514..0d3d11349b8 100644 --- a/apiserversdk/proxy.go +++ b/apiserversdk/proxy.go @@ -15,6 +15,7 @@ import ( "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" + "github.com/ray-project/kuberay/apiserver/pkg/manager" apiserversdkutil "github.com/ray-project/kuberay/apiserversdk/util" rayutil "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" ) @@ -24,7 +25,7 @@ type MuxConfig struct { Middleware func(http.Handler) http.Handler } -func NewMux(config MuxConfig) (*http.ServeMux, error) { +func NewMux(config MuxConfig, clientManager manager.ClientManagerInterface) (*http.ServeMux, error) { u, err := url.Parse(config.KubernetesConfig.Host) // parse the K8s API server URL from the KubernetesConfig. if err != nil { return nil, fmt.Errorf("failed to parse url %s from config: %w", config.KubernetesConfig.Host, err) } @@ -46,6 +47,16 @@ func NewMux(config MuxConfig) (*http.ServeMux, error) { mux.Handle("GET /api/v1/namespaces/{namespace}/events", withFieldSelector(handler, "involvedObject.apiVersion=ray.io/v1")) // allow querying KubeRay CR events.
k8sClient := kubernetes.NewForConfigOrDie(config.KubernetesConfig) + + // Compute Template middleware + ctMiddleware := apiserversdkutil.NewComputeTemplateMiddleware(clientManager) + mux.Handle("POST /apis/ray.io/v1/namespaces/{namespace}/rayclusters", ctMiddleware(handler)) + mux.Handle("PUT /apis/ray.io/v1/namespaces/{namespace}/rayclusters/{name}", ctMiddleware(handler)) + mux.Handle("POST /apis/ray.io/v1/namespaces/{namespace}/rayjobs", ctMiddleware(handler)) + mux.Handle("PUT /apis/ray.io/v1/namespaces/{namespace}/rayjobs/{name}", ctMiddleware(handler)) + mux.Handle("POST /apis/ray.io/v1/namespaces/{namespace}/rayservices", ctMiddleware(handler)) + mux.Handle("PUT /apis/ray.io/v1/namespaces/{namespace}/rayservices/{name}", ctMiddleware(handler)) + requireKubeRayServiceHandler := requireKubeRayService(handler, k8sClient) // Allow accessing KubeRay dashboards and job submissions. // Note: We also register "/proxy" to avoid the trailing slash redirection diff --git a/apiserversdk/proxy_test.go b/apiserversdk/proxy_test.go index 6aa09904b76..7725057b917 100644 --- a/apiserversdk/proxy_test.go +++ b/apiserversdk/proxy_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -23,6 +24,7 @@ import ( "k8s.io/client-go/rest" "sigs.k8s.io/controller-runtime/pkg/envtest" + "github.com/ray-project/kuberay/apiserver/pkg/manager" apiserverutil "github.com/ray-project/kuberay/apiserversdk/util" rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" rayutil "github.com/ray-project/kuberay/ray-operator/controllers/ray/utils" @@ -56,6 +58,10 @@ var _ = BeforeSuite(func(_ SpecContext) { Expect(err).ToNot(HaveOccurred()) Expect(cfg).ToNot(BeNil()) + ctrl := gomock.NewController(GinkgoT()) + // mock client manager + mockClientManager := manager.NewMockClientManagerInterface(ctrl) + mux, err := NewMux(MuxConfig{ KubernetesConfig: cfg, Middleware: func(handler http.Handler) http.Handler { @@ -64,7 +70,7 @@ var _ = BeforeSuite(func(_ SpecContext) { handler.ServeHTTP(w, r) }) }, - }) + }, mockClientManager) Expect(err).ToNot(HaveOccurred()) Expect(mux).ToNot(BeNil()) diff --git a/apiserversdk/util/template.go b/apiserversdk/util/template.go new file mode 100644 index 00000000000..537d6881231 --- /dev/null +++ b/apiserversdk/util/template.go @@ -0,0 +1,275 @@ +package util + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "k8s.io/klog/v2" + "sigs.k8s.io/yaml" + + "github.com/ray-project/kuberay/apiserver/pkg/manager" + "github.com/ray-project/kuberay/apiserver/pkg/model" + "github.com/ray-project/kuberay/apiserver/pkg/util" + api "github.com/ray-project/kuberay/proto/go_client" +) + +// NewComputeTemplateMiddleware returns middleware that resolves computeTemplate references in RayCluster, RayJob, and RayService write requests and injects the referenced template's resources into the request body. +func NewComputeTemplateMiddleware(clientManager manager.ClientManagerInterface) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + namespace := r.PathValue("namespace") + + // Read request body + bodyBytes, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Convert the request body into a map + contentType := r.Header.Get("Content-Type") + requestMap, err := convertRequestBodyToMap(bodyBytes, contentType) + if err != nil { + klog.Errorf("Failed to convert request body to map: %v", err) + http.Error(w, "Failed to convert
request body to Golang map object", http.StatusBadRequest) + return + } + spec, ok := requestMap["spec"].(map[string]any) + if !ok { + klog.Infof("ComputeTemplate middleware: spec is not a map, skipping compute template processing") + // Continue with original request body without compute template processing + r.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + next.ServeHTTP(w, r) + return + } + + // Process compute templates and apply them to the request + var headGroupMap map[string]any + var workerGroupMaps []any + if rayClusterSpec, ok := spec["rayClusterSpec"].(map[string]any); ok { + // For RayJob, get from spec.rayClusterSpec.headGroupSpec and spec.rayClusterSpec.workerGroupSpecs + if headGroup, ok := rayClusterSpec["headGroupSpec"].(map[string]any); ok { + headGroupMap = headGroup + } + if workerGroups, ok := rayClusterSpec["workerGroupSpecs"].([]any); ok { + workerGroupMaps = workerGroups + } + } else if rayClusterConfig, ok := spec["rayClusterConfig"].(map[string]any); ok { + // For RayService, get from spec.rayClusterConfig.headGroupSpec and spec.rayClusterConfig.workerGroupSpecs + if headGroup, ok := rayClusterConfig["headGroupSpec"].(map[string]any); ok { + headGroupMap = headGroup + } + if workerGroups, ok := rayClusterConfig["workerGroupSpecs"].([]any); ok { + workerGroupMaps = workerGroups + } + } else { + // For RayCluster, get from spec.headGroupSpec and spec.workerGroupSpecs + if headGroup, ok := spec["headGroupSpec"].(map[string]any); ok { + headGroupMap = headGroup + } + if workerGroups, ok := spec["workerGroupSpecs"].([]any); ok { + workerGroupMaps = workerGroups + } + } + + resourceManager := manager.NewResourceManager(clientManager) + + if headGroupMap != nil { + computeTemplate, err := getComputeTemplate(context.Background(), resourceManager, headGroupMap, namespace) + if err != nil { + klog.Errorf("ComputeTemplate middleware: Failed to get compute template for head group: %v", err) + http.Error(w, err.Error(), http.StatusUnprocessableEntity) + return + } + if computeTemplate != nil { + applyComputeTemplateToRequest(computeTemplate, &headGroupMap, "head") + } + } + + // Apply compute templates to worker groups + for i, workerGroupSpec := range workerGroupMaps { + if workerGroupMap, ok := workerGroupSpec.(map[string]any); ok { + computeTemplate, err := getComputeTemplate(context.Background(), resourceManager, workerGroupMap, namespace) + if err != nil { + klog.Errorf("ComputeTemplate middleware: Failed to get compute template for worker group %d: %v", i, err) + http.Error(w, err.Error(), http.StatusUnprocessableEntity) + return + } + if computeTemplate != nil { + klog.Infof("ComputeTemplate middleware: Applying compute template %s to worker group %d", computeTemplate.Name, i) + applyComputeTemplateToRequest(computeTemplate, &workerGroupMap, "worker") + } + } + } + + // Convert the modified requestMap to JSON since K8s API expects JSON format + jsonBytes, err := convertMapToJSON(requestMap) + if err != nil { + klog.Errorf("ComputeTemplate middleware: Failed to convert to JSON: %v", err) + http.Error(w, "Failed to process request", http.StatusInternalServerError) + return + } + + klog.Infof("ComputeTemplate middleware: Successfully processed request, sending to next handler") + // Update Content-Type to application/json and Content-Length header to match the new body size + r.Header.Set("Content-Type", "application/json") + r.ContentLength = int64(len(jsonBytes)) + r.Header.Set("Content-Length", fmt.Sprintf("%d", len(jsonBytes))) + r.Body = 
io.NopCloser(bytes.NewReader(jsonBytes)) + + next.ServeHTTP(w, r) + }) + } +} + +// Convert the request body to map, handling both JSON and YAML formats +func convertRequestBodyToMap(requestBody []byte, contentType string) (map[string]any, error) { + var requestMap map[string]any + + // Check content type to determine format + if strings.Contains(contentType, "application/json") { + if err := json.Unmarshal(requestBody, &requestMap); err != nil { + return nil, fmt.Errorf("failed to unmarshal JSON: %w", err) + } + } else if strings.Contains(contentType, "application/yaml") { + if err := yaml.Unmarshal(requestBody, &requestMap); err != nil { + return nil, fmt.Errorf("failed to unmarshal YAML: %w", err) + } + } else { + return nil, fmt.Errorf("unsupported content type %q: expected application/json or application/yaml", contentType) + } + + return requestMap, nil +} + +// Convert the request map to JSON bytes, since the K8s API expects JSON +func convertMapToJSON(requestMap map[string]any) ([]byte, error) { + jsonBytes, err := json.Marshal(requestMap) + if err != nil { + return nil, fmt.Errorf("failed to marshal request map to JSON: %w", err) + } + return jsonBytes, nil +} + +// Look up the compute template referenced by the group spec's computeTemplate field +func getComputeTemplate(ctx context.Context, resourceManager *manager.ResourceManager, clusterSpecMap map[string]any, nameSpace string) (*api.ComputeTemplate, error) { + name, ok := clusterSpecMap["computeTemplate"].(string) + if !ok { + // No compute template name found, directly return + klog.Infof("ComputeTemplate middleware: No computeTemplate field found in spec") + return nil, nil + } + + configMap, err := resourceManager.GetComputeTemplate(ctx, name, nameSpace) + if err != nil { + return nil, fmt.Errorf("cannot get compute template '%s' in namespace '%s': %w", name, nameSpace, err) + } + computeTemplate := model.FromKubeToAPIComputeTemplate(configMap) + + return computeTemplate, nil +} + +// Apply the computeTemplate to the clusterSpec map.
The clusterSpec map is the map representation +// of a headGroupSpec or workerGroupSpec +func applyComputeTemplateToRequest(computeTemplate *api.ComputeTemplate, clusterSpecMap map[string]any, group string) { + // calculate resources + cpu := fmt.Sprint(computeTemplate.GetCpu()) + memoryUnit := computeTemplate.GetMemoryUnit() + memory := fmt.Sprintf("%d%s", computeTemplate.GetMemory(), memoryUnit) + + if template, ok := clusterSpecMap["template"].(map[string]any); ok { + // Add compute template name to annotation + + metadata, ok := template["metadata"].(map[string]any) + if !ok { + metadata = make(map[string]any) + template["metadata"] = metadata + } + annotations, ok := metadata["annotations"].(map[string]any) + if !ok { + annotations = make(map[string]any) + metadata["annotations"] = annotations + } + annotations[util.RayClusterComputeTemplateAnnotationKey] = computeTemplate.Name + + // apply resources to containers + if spec, ok := template["spec"].(map[string]any); ok { + if containers, ok := spec["containers"].([]any); ok { + for _, container := range containers { + if containerMap, ok := container.(map[string]any); ok { + // Get or create resources section for this container + resources, exists := containerMap["resources"].(map[string]any) + if !exists { + resources = make(map[string]any) + containerMap["resources"] = resources + } + + // Set limits + limits, exists := resources["limits"].(map[string]any) + if !exists { + limits = make(map[string]any) + resources["limits"] = limits + } + limits["cpu"] = cpu + limits["memory"] = memory + + // Set requests + requests, exists := resources["requests"].(map[string]any) + if !exists { + requests = make(map[string]any) + resources["requests"] = requests + } + requests["cpu"] = cpu + requests["memory"] = memory + + // Only apply the following if the container name is "ray-head" for the head group or "ray-worker" + // for the worker group + if containerMap["name"] == fmt.Sprintf("ray-%s", group) { + if gpu := computeTemplate.GetGpu(); gpu != 0 { + accelerator := "nvidia.com/gpu" + if len(computeTemplate.GetGpuAccelerator()) != 0 { + accelerator = computeTemplate.GetGpuAccelerator() + } + limits[accelerator] = gpu + requests[accelerator] = gpu + } + + for k, v := range computeTemplate.GetExtendedResources() { + limits[k] = v + requests[k] = v + } + + } + } + } + } + + if computeTemplate.Tolerations != nil { + // Get existing tolerations + var tolerations []any + if existingTolerations, exists := spec["tolerations"].([]any); exists { + tolerations = existingTolerations + } + + // Add new tolerations from compute template + for _, t := range computeTemplate.Tolerations { + toleration := map[string]any{ + "key": t.Key, + "operator": t.Operator, + "value": t.Value, + "effect": t.Effect, + } + tolerations = append(tolerations, toleration) + } + + spec["tolerations"] = tolerations + } + } + } +}
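A minimal sketch (not part of the diff above) of how the body-rewriting helpers in apiserversdk/util/template.go compose, written as an in-package test so it can reach the unexported convertRequestBodyToMap and applyComputeTemplateToRequest. The ComputeTemplate field names (Name, Cpu, Memory, MemoryUnit) are assumed from the getters used by the middleware; the test name and image are hypothetical.

package util

import (
	"testing"

	"github.com/stretchr/testify/require"

	apiserverutil "github.com/ray-project/kuberay/apiserver/pkg/util"
	api "github.com/ray-project/kuberay/proto/go_client"
)

// TestApplyComputeTemplateToRequestSketch (hypothetical) walks the
// YAML body -> map -> resource-injection flow used by the middleware.
func TestApplyComputeTemplateToRequestSketch(t *testing.T) {
	yamlBody := []byte(`
spec:
  headGroupSpec:
    computeTemplate: my-template
    template:
      spec:
        containers:
        - name: ray-head
          image: rayproject/ray:latest
`)
	// YAML is accepted because the middleware branches on Content-Type.
	requestMap, err := convertRequestBodyToMap(yamlBody, "application/yaml")
	require.NoError(t, err)

	headGroup := requestMap["spec"].(map[string]any)["headGroupSpec"].(map[string]any)

	// Field names assumed from the GetCpu/GetMemory/GetMemoryUnit getters above.
	template := &api.ComputeTemplate{Name: "my-template", Cpu: 2, Memory: 4, MemoryUnit: "Gi"}
	applyComputeTemplateToRequest(template, headGroup, "head")

	// CPU and memory land on the ray-head container's limits (and requests).
	container := headGroup["template"].(map[string]any)["spec"].(map[string]any)["containers"].([]any)[0].(map[string]any)
	limits := container["resources"].(map[string]any)["limits"].(map[string]any)
	require.Equal(t, "2", limits["cpu"])
	require.Equal(t, "4Gi", limits["memory"])

	// The template name is recorded as a pod annotation, matching the e2e assertions.
	annotations := headGroup["template"].(map[string]any)["metadata"].(map[string]any)["annotations"].(map[string]any)
	require.Equal(t, "my-template", annotations[apiserverutil.RayClusterComputeTemplateAnnotationKey])
}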
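Since the NewMux signature change is breaking for SDK consumers, here is a self-contained sketch of the new wiring, mirroring the apiserver/cmd/main.go hunk above. It assumes manager.NewClientManager() initializes its Kubernetes clients from the ambient kubeconfig as in the apiserver binary; the no-op middleware and port are illustrative.

package main

import (
	"log"
	"net/http"

	"sigs.k8s.io/controller-runtime/pkg/client/config"

	"github.com/ray-project/kuberay/apiserver/pkg/manager"
	"github.com/ray-project/kuberay/apiserversdk"
)

func main() {
	kubeConfig, err := config.GetConfig() // kubeconfig or in-cluster config
	if err != nil {
		log.Fatalf("failed to load Kubernetes config: %v", err)
	}

	// NewMux now takes a ClientManagerInterface so the compute template
	// middleware can build a ResourceManager for template lookups.
	clientManager := manager.NewClientManager()
	mux, err := apiserversdk.NewMux(apiserversdk.MuxConfig{
		KubernetesConfig: kubeConfig,
		Middleware:       func(h http.Handler) http.Handler { return h }, // always set, even as a no-op
	}, &clientManager)
	if err != nil {
		log.Fatalf("failed to create API server mux: %v", err)
	}

	log.Fatal(http.ListenAndServe(":8888", mux)) // port is illustrative
}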