diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
index 31a86b44df..dff1970665 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray.go
@@ -79,10 +79,19 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	cfg := GetConfig()
 	headReplicas := int32(1)
 	headNodeRayStartParams := make(map[string]string)
 	if rayJob.RayCluster.HeadGroupSpec != nil && rayJob.RayCluster.HeadGroupSpec.RayStartParams != nil {
 		headNodeRayStartParams = rayJob.RayCluster.HeadGroupSpec.RayStartParams
 	} else if headNode := cfg.Defaults.HeadNode; len(headNode.StartParameters) > 0 {
 		headNodeRayStartParams = headNode.StartParameters
 	}
+
+	// Resource overrides for the head node, if specified in the task's RayJob spec.
+	headGroupResources := &v1.ResourceRequirements{}
+	if rayJob.RayCluster.HeadGroupSpec != nil {
+		var err error
+		headGroupResources, err = flytek8s.ToK8sResourceRequirements(rayJob.RayCluster.HeadGroupSpec.Resources)
+		if err != nil {
+			return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources[%v], Err: [%v]", rayJob.RayCluster.HeadGroupSpec.Resources, err.Error())
+		}
+	}
 
 	if _, exist := headNodeRayStartParams[IncludeDashboard]; !exist {
@@ -101,11 +110,15 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 		headNodeRayStartParams[DisableUsageStatsStartParameter] = "true"
 	}
 
+	if rayJob.RayCluster.Namespace != "" {
+		objectMeta.Namespace = rayJob.RayCluster.Namespace
+	}
+
 	enableIngress := true
 	rayClusterSpec := rayv1alpha1.RayClusterSpec{
 		HeadGroupSpec: rayv1alpha1.HeadGroupSpec{
-			Template:       buildHeadPodTemplate(&container, podSpec, objectMeta, taskCtx),
+			Template:       buildHeadPodTemplate(&container, podSpec, objectMeta, taskCtx, headGroupResources),
 			ServiceType:    v1.ServiceType(cfg.ServiceType),
 			Replicas:       &headReplicas,
 			EnableIngress:  &enableIngress,
 			RayStartParams: headNodeRayStartParams,
@@ -114,7 +127,12 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	}
 
 	for _, spec := range rayJob.RayCluster.WorkerGroupSpec {
-		workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx)
+		workerGroupResources, err := flytek8s.ToK8sResourceRequirements(spec.Resources)
+		if err != nil {
+			return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification on Resources[%v], Err: [%v]", spec.Resources, err.Error())
+		}
+
+		workerPodTemplate := buildWorkerPodTemplate(&container, podSpec, objectMeta, taskCtx, workerGroupResources)
 
 		minReplicas := spec.Replicas
 		maxReplicas := spec.Replicas
@@ -153,7 +171,10 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 		rayClusterSpec.WorkerGroupSpecs = append(rayClusterSpec.WorkerGroupSpecs, workerNodeSpec)
 	}
 
-	serviceAccountName := flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
+	serviceAccountName := rayJob.RayCluster.K8SServiceAccount
+	if serviceAccountName == "" {
+		serviceAccountName = flytek8s.GetServiceAccountNameFromTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())
+	}
 	rayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName = serviceAccountName
 
 	for index := range rayClusterSpec.WorkerGroupSpecs {
@@ -180,12 +201,16 @@ func (rayJobResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsC
 	return &rayJobObject, nil
 }
 
-func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
+func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMeta *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, resources *v1.ResourceRequirements) v1.PodTemplateSpec {
 	// Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L97
 	// They should always be the same, so we could hard code here.
 	primaryContainer := container.DeepCopy()
 	primaryContainer.Name = "ray-head"
 
+	if len(resources.Requests) > 0 || len(resources.Limits) > 0 {
+		primaryContainer.Resources = *resources
+	}
+
 	envs := []v1.EnvVar{
 		{
 			Name: "MY_POD_IP",
@@ -232,7 +257,7 @@ func buildHeadPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMe
 	return podTemplateSpec
 }
 
-func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext) v1.PodTemplateSpec {
+func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, objectMetadata *metav1.ObjectMeta, taskCtx pluginsCore.TaskExecutionContext, resources *v1.ResourceRequirements) v1.PodTemplateSpec {
 	// Some configs are copy from https://github.com/ray-project/kuberay/blob/b72e6bdcd9b8c77a9dc6b5da8560910f3a0c3ffd/apiserver/pkg/util/cluster.go#L185
 	// They should always be the same, so we could hard code here.
 	initContainers := []v1.Container{
@@ -252,6 +277,10 @@ func buildWorkerPodTemplate(container *v1.Container, podSpec *v1.PodSpec, object
 
 	primaryContainer.Args = []string{}
 
+	if len(resources.Requests) > 0 || len(resources.Limits) > 0 {
+		primaryContainer.Resources = *resources
+	}
+
 	envs := []v1.EnvVar{
 		{
 			Name: "RAY_DISABLE_DOCKER_CPU_WARNING",
diff --git a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
index 5e7b82d55c..7456328834 100644
--- a/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
+++ b/flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
@@ -33,6 +33,8 @@ import (
 
 const testImage = "image://"
 const serviceAccount = "ray_sa"
+const serviceAccountOverride = "ray_sa_override"
+const namespaceOverride = "ray_namespace_override"
 
 var (
 	dummyEnvVars = []*core.KeyValuePair{
@@ -43,6 +45,52 @@ var (
 		"test-args",
 	}
 
+	headResourceOverride = core.Resources{
+		Requests: []*core.Resources_ResourceEntry{
+			{
+				Name:  core.Resources_CPU,
+				Value: "1000m",
+			},
+			{
+				Name:  core.Resources_MEMORY,
+				Value: "2Gi",
+			},
+		},
+		Limits: []*core.Resources_ResourceEntry{
+			{
+				Name:  core.Resources_CPU,
+				Value: "2000m",
+			},
+			{
+				Name:  core.Resources_MEMORY,
+				Value: "4Gi",
+			},
+		},
+	}
+
+	workerResourceOverride = core.Resources{
+		Requests: []*core.Resources_ResourceEntry{
+			{
+				Name:  core.Resources_CPU,
+				Value: "5",
+			},
+			{
+				Name:  core.Resources_MEMORY,
+				Value: "10G",
+			},
+		},
+		Limits: []*core.Resources_ResourceEntry{
+			{
+				Name:  core.Resources_CPU,
+				Value: "10",
+			},
+			{
+				Name:  core.Resources_MEMORY,
+				Value: "20G",
+			},
+		},
+	}
+
 	resourceRequirements = &corev1.ResourceRequirements{
 		Limits: corev1.ResourceList{
 			corev1.ResourceCPU: resource.MustParse("1000m"),
@@ -68,6 +116,17 @@ func dummyRayCustomObj() *plugins.RayJob {
 	}
 }
 
+func dummyRayCustomObjWithOverrides() *plugins.RayJob {
+	return &plugins.RayJob{
+		RayCluster: &plugins.RayCluster{
+			K8SServiceAccount: serviceAccountOverride,
+			Namespace:         namespaceOverride,
+			HeadGroupSpec:     &plugins.HeadGroupSpec{RayStartParams: map[string]string{"num-cpus": "1"}, Resources: &headResourceOverride},
+			WorkerGroupSpec:   []*plugins.WorkerGroupSpec{{GroupName: workerGroupName, Replicas: 3, Resources: &workerResourceOverride}},
+		},
+	}
+}
+
 func dummyRayTaskTemplate(id string, rayJobObj *plugins.RayJob) *core.TaskTemplate {
 
 	ptObjJSON, err := utils.MarshalToString(rayJobObj)
@@ -192,6 +251,52 @@ func TestBuildResourceRay(t *testing.T) {
 	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations, map[string]string{"annotation-1": "val1"})
 	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels, map[string]string{"label-1": "val1"})
 	assert.Equal(t, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations, toleration)
 }
 
+func TestBuildResourceRayWithOverrides(t *testing.T) {
+	rayJobResourceHandler := rayJobResourceHandler{}
+	taskTemplate := dummyRayTaskTemplate("ray-id", dummyRayCustomObjWithOverrides())
+	expectedHeadResources, _ := flytek8s.ToK8sResourceRequirements(&headResourceOverride)
+	expectedWorkerResources, _ := flytek8s.ToK8sResourceRequirements(&workerResourceOverride)
+	toleration := []corev1.Toleration{{
+		Key:      "storage",
+		Value:    "dedicated",
+		Operator: corev1.TolerationOpExists,
+		Effect:   corev1.TaintEffectNoSchedule,
+	}}
+	err := config.SetK8sPluginConfig(&config.K8sPluginConfig{DefaultTolerations: toleration})
+	assert.Nil(t, err)
+
+	RayResource, err := rayJobResourceHandler.BuildResource(context.TODO(), dummyRayTaskContext(taskTemplate))
+	assert.Nil(t, err)
+
+	assert.NotNil(t, RayResource)
+	ray, ok := RayResource.(*rayv1alpha1.RayJob)
+	assert.True(t, ok)
+
+	headReplica := int32(1)
+	assert.Equal(t, namespaceOverride, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.ObjectMeta.Namespace)
+	assert.Equal(t, &headReplica, ray.Spec.RayClusterSpec.HeadGroupSpec.Replicas)
+	assert.Equal(t, serviceAccountOverride, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.ServiceAccountName)
+	assert.Equal(t, map[string]string{"dashboard-host": "0.0.0.0", "include-dashboard": "true", "node-ip-address": "$MY_POD_IP", "num-cpus": "1"},
+		ray.Spec.RayClusterSpec.HeadGroupSpec.RayStartParams)
+	assert.Equal(t, map[string]string{"annotation-1": "val1"}, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Annotations)
+	assert.Equal(t, map[string]string{"label-1": "val1"}, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Labels)
+	assert.Equal(t, toleration, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Tolerations)
+	assert.Equal(t, *expectedHeadResources, ray.Spec.RayClusterSpec.HeadGroupSpec.Template.Spec.Containers[0].Resources)
+
+	workerReplica := int32(3)
+	assert.Equal(t, namespaceOverride, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.ObjectMeta.Namespace)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Replicas)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MinReplicas)
+	assert.Equal(t, &workerReplica, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].MaxReplicas)
+	assert.Equal(t, workerGroupName, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].GroupName)
+	assert.Equal(t, serviceAccountOverride, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.ServiceAccountName)
+	assert.Equal(t, map[string]string{"node-ip-address": "$MY_POD_IP"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].RayStartParams)
+	assert.Equal(t, map[string]string{"annotation-1": "val1"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Annotations)
+	assert.Equal(t, map[string]string{"label-1": "val1"}, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Labels)
+	assert.Equal(t, toleration, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Tolerations)
+	assert.Equal(t, *expectedWorkerResources, ray.Spec.RayClusterSpec.WorkerGroupSpecs[0].Template.Spec.Containers[0].Resources)
+}
+
 func TestDefaultStartParameters(t *testing.T) {