From 6b84b19dc828659d345e8342c5996e371c64b596 Mon Sep 17 00:00:00 2001 From: 0x5457 <0x5457@protonmail.com> Date: Tue, 9 Dec 2025 10:51:51 +0800 Subject: [PATCH] fix(scheduler): use PreBind instead of PostBind for GPU device IDs annotation PostBind doesn't return status, so patch failures cannot trigger scheduler retry. Switch to PreBind to ensure GPU device IDs annotation is applied before pod binding, enabling automatic retry when patch fails. This is critical because hypervisor component depends on gpu-ids annotation. --- .../scheduler/gpuresources/gpuresources.go | 29 ++++++++------- .../gpuresources/gpuresources_test.go | 36 +++++++++++++++---- 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/internal/scheduler/gpuresources/gpuresources.go b/internal/scheduler/gpuresources/gpuresources.go index c3759fad..53116a25 100644 --- a/internal/scheduler/gpuresources/gpuresources.go +++ b/internal/scheduler/gpuresources/gpuresources.go @@ -38,7 +38,7 @@ var _ framework.PreFilterPlugin = &GPUFit{} var _ framework.FilterPlugin = &GPUFit{} var _ framework.ScorePlugin = &GPUFit{} var _ framework.ReservePlugin = &GPUFit{} -var _ framework.PostBindPlugin = &GPUFit{} +var _ framework.PreBindPlugin = &GPUFit{} var _ framework.EnqueueExtensions = &GPUFit{} type GPUFit struct { @@ -462,35 +462,38 @@ func (s *GPUFit) Unreserve(ctx context.Context, state fwk.CycleState, pod *v1.Po }, schedulingResult.FinalGPUs, pod.ObjectMeta) } -func (s *GPUFit) PostBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) { +func (s *GPUFit) PreBindPreFlight(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status { if !utils.IsTensorFusionWorker(pod) { - return + return fwk.NewStatus(fwk.Skip, "skip for non tensor-fusion worker") } + return fwk.NewStatus(fwk.Success, "") +} + +func (s *GPUFit) PreBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status { - s.logger.Info("PostBinding pod for GPU resources", "pod", pod.Name, "node", nodeName) gpuSchedulingResult, err := state.Read(CycleStateGPUSchedulingResult) if err != nil { s.logger.Error(err, "failed to read gpu scheduling result", "pod", pod.Name) - return + return fwk.NewStatus(fwk.Error, "failed to read gpu scheduling result: "+err.Error()) } - // write the allocated GPU info to Pod in bindingCycle, before default binder changing the Pod nodeName info + gpuIDs := strings.Join(gpuSchedulingResult.(*GPUSchedulingStateData).FinalGPUs, ",") - s.logger.Info("PostBinding pod for GPU resources", "pod", pod.Name, "node", nodeName, "gpuIDs", gpuIDs) + s.logger.Info("PreBinding pod for GPU resources", "pod", pod.Name, "node", nodeName, "gpuIDs", gpuIDs) - // Patch GPU device IDs annotation patch := []byte(`[{ "op": "add", "path": "/metadata/annotations/` + utils.EscapeJSONPointer(constants.GPUDeviceIDsAnnotation) + `", "value": "` + gpuIDs + `"}]`) - err = s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patch)) - if err != nil { + if err := s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patch)); err != nil { s.logger.Error(err, "failed to patch gpu device ids", "pod", pod.Name) s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeWarning, "GPUDeviceAllocatedFailed", "Attach GPU device ID info failed", "Can not add GPU device IDs: "+gpuIDs) - } else { - s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "GPUDeviceAllocated", - "Attach GPU device ID info", "Attach TensorFusion GPU device IDs to Pod: "+gpuIDs) + return fwk.NewStatus(fwk.Error, "failed to patch gpu device ids: "+err.Error()) } + + s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "GPUDeviceAllocated", + "Attach GPU device ID info", "Attach TensorFusion GPU device IDs to Pod: "+gpuIDs) + return fwk.NewStatus(fwk.Success, "") } func (s *GPUFit) EventsToRegister(_ context.Context) ([]fwk.ClusterEventWithHint, error) { diff --git a/internal/scheduler/gpuresources/gpuresources_test.go b/internal/scheduler/gpuresources/gpuresources_test.go index 5707a640..1aadb336 100644 --- a/internal/scheduler/gpuresources/gpuresources_test.go +++ b/internal/scheduler/gpuresources/gpuresources_test.go @@ -511,8 +511,8 @@ func (s *GPUResourcesSuite) TestReserveAndUnreserve() { s.Len(gpu.Status.RunningApps, 1) } -func (s *GPUResourcesSuite) TestPostBind() { - log.FromContext(s.ctx).Info("Running TestPostBind") +func (s *GPUResourcesSuite) TestPreBind() { + log.FromContext(s.ctx).Info("Running TestPreBind") state := framework.NewCycleState() pod := s.makePod("p1", map[string]string{ @@ -528,7 +528,8 @@ func (s *GPUResourcesSuite) TestPostBind() { reserveStatus := s.plugin.Reserve(s.ctx, state, pod, "node-a") s.Require().True(reserveStatus.IsSuccess()) - s.plugin.PostBind(s.ctx, state, pod, "node-a") + preBindStatus := s.plugin.PreBind(s.ctx, state, pod, "node-a") + s.Require().True(preBindStatus.IsSuccess()) updatedPod := &v1.Pod{} s.NoError(s.client.Get(s.ctx, types.NamespacedName{Name: "p1", Namespace: "ns1"}, updatedPod)) @@ -669,8 +670,27 @@ func (s *GPUResourcesSuite) TestUnreserve_ErrorHandling() { }) } -func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() { - log.FromContext(s.ctx).Info("Running TestPostBind_ErrorHandling") +func (s *GPUResourcesSuite) TestPreBindPreFlight() { + log.FromContext(s.ctx).Info("Running TestPreBindPreFlight") + state := framework.NewCycleState() + + // TensorFusion worker pod should return Success + tfPod := s.makePod("tf-pod", map[string]string{ + constants.GpuCountAnnotation: "1", + constants.TFLOPSRequestAnnotation: "100", + constants.VRAMRequestAnnotation: "10Gi", + }) + status := s.plugin.PreBindPreFlight(s.ctx, state, tfPod, "node-a") + s.Equal(fwk.Success, status.Code()) + + // Non-TensorFusion pod should return Skip + nonTFPod := s.makeNonTensorFusionPod("non-tf-pod", 1) + status = s.plugin.PreBindPreFlight(s.ctx, state, nonTFPod, "node-a") + s.Equal(fwk.Skip, status.Code()) +} + +func (s *GPUResourcesSuite) TestPreBind_ErrorHandling() { + log.FromContext(s.ctx).Info("Running TestPreBind_ErrorHandling") state := framework.NewCycleState() pod := s.makePod("p1", map[string]string{ @@ -680,7 +700,8 @@ func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() { }) // No pre-filter call, so state is empty - s.plugin.PostBind(s.ctx, state, pod, "node-a") + status := s.plugin.PreBind(s.ctx, state, pod, "node-a") + s.Equal(fwk.Error, status.Code()) // Test with a pod that doesn't exist in the client _, preFilterStatus := s.plugin.PreFilter(s.ctx, state, pod, []fwk.NodeInfo{}) @@ -690,7 +711,8 @@ func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() { nonExistentPod := pod.DeepCopy() nonExistentPod.Name = "p-non-existent" - s.plugin.PostBind(s.ctx, state, nonExistentPod, "node-a") + status = s.plugin.PreBind(s.ctx, state, nonExistentPod, "node-a") + s.Equal(fwk.Error, status.Code()) } func (s *GPUResourcesSuite) TestFilter_ErrorHandling() {