diff --git a/internal/scheduler/gpuresources/gpuresources.go b/internal/scheduler/gpuresources/gpuresources.go index c3759fad..53116a25 100644 --- a/internal/scheduler/gpuresources/gpuresources.go +++ b/internal/scheduler/gpuresources/gpuresources.go @@ -38,7 +38,7 @@ var _ framework.PreFilterPlugin = &GPUFit{} var _ framework.FilterPlugin = &GPUFit{} var _ framework.ScorePlugin = &GPUFit{} var _ framework.ReservePlugin = &GPUFit{} -var _ framework.PostBindPlugin = &GPUFit{} +var _ framework.PreBindPlugin = &GPUFit{} var _ framework.EnqueueExtensions = &GPUFit{} type GPUFit struct { @@ -462,35 +462,38 @@ func (s *GPUFit) Unreserve(ctx context.Context, state fwk.CycleState, pod *v1.Po }, schedulingResult.FinalGPUs, pod.ObjectMeta) } -func (s *GPUFit) PostBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) { +func (s *GPUFit) PreBindPreFlight(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status { if !utils.IsTensorFusionWorker(pod) { - return + return fwk.NewStatus(fwk.Skip, "skip for non tensor-fusion worker") } + return fwk.NewStatus(fwk.Success, "") +} + +func (s *GPUFit) PreBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status { - s.logger.Info("PostBinding pod for GPU resources", "pod", pod.Name, "node", nodeName) gpuSchedulingResult, err := state.Read(CycleStateGPUSchedulingResult) if err != nil { s.logger.Error(err, "failed to read gpu scheduling result", "pod", pod.Name) - return + return fwk.NewStatus(fwk.Error, "failed to read gpu scheduling result: "+err.Error()) } - // write the allocated GPU info to Pod in bindingCycle, before default binder changing the Pod nodeName info + gpuIDs := strings.Join(gpuSchedulingResult.(*GPUSchedulingStateData).FinalGPUs, ",") - s.logger.Info("PostBinding pod for GPU resources", "pod", pod.Name, "node", nodeName, "gpuIDs", gpuIDs) + s.logger.Info("PreBinding pod for GPU resources", "pod", pod.Name, "node", nodeName, "gpuIDs", gpuIDs) - // Patch GPU device IDs annotation patch := []byte(`[{ "op": "add", "path": "/metadata/annotations/` + utils.EscapeJSONPointer(constants.GPUDeviceIDsAnnotation) + `", "value": "` + gpuIDs + `"}]`) - err = s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patch)) - if err != nil { + if err := s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patch)); err != nil { s.logger.Error(err, "failed to patch gpu device ids", "pod", pod.Name) s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeWarning, "GPUDeviceAllocatedFailed", "Attach GPU device ID info failed", "Can not add GPU device IDs: "+gpuIDs) - } else { - s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "GPUDeviceAllocated", - "Attach GPU device ID info", "Attach TensorFusion GPU device IDs to Pod: "+gpuIDs) + return fwk.NewStatus(fwk.Error, "failed to patch gpu device ids: "+err.Error()) } + + s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "GPUDeviceAllocated", + "Attach GPU device ID info", "Attach TensorFusion GPU device IDs to Pod: "+gpuIDs) + return fwk.NewStatus(fwk.Success, "") } func (s *GPUFit) EventsToRegister(_ context.Context) ([]fwk.ClusterEventWithHint, error) { diff --git a/internal/scheduler/gpuresources/gpuresources_test.go b/internal/scheduler/gpuresources/gpuresources_test.go index 5707a640..1aadb336 100644 --- a/internal/scheduler/gpuresources/gpuresources_test.go +++ b/internal/scheduler/gpuresources/gpuresources_test.go @@ -511,8 +511,8 @@ func (s *GPUResourcesSuite) TestReserveAndUnreserve() { s.Len(gpu.Status.RunningApps, 1) } -func (s *GPUResourcesSuite) TestPostBind() { - log.FromContext(s.ctx).Info("Running TestPostBind") +func (s *GPUResourcesSuite) TestPreBind() { + log.FromContext(s.ctx).Info("Running TestPreBind") state := framework.NewCycleState() pod := s.makePod("p1", map[string]string{ @@ -528,7 +528,8 @@ func (s *GPUResourcesSuite) TestPostBind() { reserveStatus := s.plugin.Reserve(s.ctx, state, pod, "node-a") s.Require().True(reserveStatus.IsSuccess()) - s.plugin.PostBind(s.ctx, state, pod, "node-a") + preBindStatus := s.plugin.PreBind(s.ctx, state, pod, "node-a") + s.Require().True(preBindStatus.IsSuccess()) updatedPod := &v1.Pod{} s.NoError(s.client.Get(s.ctx, types.NamespacedName{Name: "p1", Namespace: "ns1"}, updatedPod)) @@ -669,8 +670,27 @@ func (s *GPUResourcesSuite) TestUnreserve_ErrorHandling() { }) } -func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() { - log.FromContext(s.ctx).Info("Running TestPostBind_ErrorHandling") +func (s *GPUResourcesSuite) TestPreBindPreFlight() { + log.FromContext(s.ctx).Info("Running TestPreBindPreFlight") + state := framework.NewCycleState() + + // TensorFusion worker pod should return Success + tfPod := s.makePod("tf-pod", map[string]string{ + constants.GpuCountAnnotation: "1", + constants.TFLOPSRequestAnnotation: "100", + constants.VRAMRequestAnnotation: "10Gi", + }) + status := s.plugin.PreBindPreFlight(s.ctx, state, tfPod, "node-a") + s.Equal(fwk.Success, status.Code()) + + // Non-TensorFusion pod should return Skip + nonTFPod := s.makeNonTensorFusionPod("non-tf-pod", 1) + status = s.plugin.PreBindPreFlight(s.ctx, state, nonTFPod, "node-a") + s.Equal(fwk.Skip, status.Code()) +} + +func (s *GPUResourcesSuite) TestPreBind_ErrorHandling() { + log.FromContext(s.ctx).Info("Running TestPreBind_ErrorHandling") state := framework.NewCycleState() pod := s.makePod("p1", map[string]string{ @@ -680,7 +700,8 @@ func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() { }) // No pre-filter call, so state is empty - s.plugin.PostBind(s.ctx, state, pod, "node-a") + status := s.plugin.PreBind(s.ctx, state, pod, "node-a") + s.Equal(fwk.Error, status.Code()) // Test with a pod that doesn't exist in the client _, preFilterStatus := s.plugin.PreFilter(s.ctx, state, pod, []fwk.NodeInfo{}) @@ -690,7 +711,8 @@ func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() { nonExistentPod := pod.DeepCopy() nonExistentPod.Name = "p-non-existent" - s.plugin.PostBind(s.ctx, state, nonExistentPod, "node-a") + status = s.plugin.PreBind(s.ctx, state, nonExistentPod, "node-a") + s.Equal(fwk.Error, status.Code()) } func (s *GPUResourcesSuite) TestFilter_ErrorHandling() {