Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions internal/scheduler/gpuresources/gpuresources.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var _ framework.PreFilterPlugin = &GPUFit{}
// Compile-time assertions: each line assigns a *GPUFit to a blank variable of
// a scheduler framework interface type, so the build fails if GPUFit stops
// implementing that interface.
// NOTE(review): this diff view lists both PostBindPlugin and PreBindPlugin;
// presumably the PostBindPlugin line is the removed side — confirm.
var _ framework.FilterPlugin = &GPUFit{}
var _ framework.ScorePlugin = &GPUFit{}
var _ framework.ReservePlugin = &GPUFit{}
var _ framework.PostBindPlugin = &GPUFit{}
var _ framework.PreBindPlugin = &GPUFit{}
var _ framework.EnqueueExtensions = &GPUFit{}

type GPUFit struct {
Expand Down Expand Up @@ -462,35 +462,38 @@ func (s *GPUFit) Unreserve(ctx context.Context, state fwk.CycleState, pod *v1.Po
}, schedulingResult.FinalGPUs, pod.ObjectMeta)
}

func (s *GPUFit) PostBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) {
// PreBindPreFlight reports whether PreBind has work to do for pod.
//
// It returns a Skip status when utils.IsTensorFusionWorker(pod) is false and
// a Success status otherwise. ctx, state and nodeName are not used.
func (s *GPUFit) PreBindPreFlight(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status {
if !utils.IsTensorFusionWorker(pod) {
// NOTE(review): the bare return below looks like the removed PostBind line
// carried over by the diff view; the Skip status after it is the new code.
return
return fwk.NewStatus(fwk.Skip, "skip for non tensor-fusion worker")
}
return fwk.NewStatus(fwk.Success, "")
}

// PreBind writes the GPU IDs chosen for pod onto the pod as an annotation.
//
// It reads the value stored under CycleStateGPUSchedulingResult in state,
// joins its FinalGPUs with commas, and sends a JSON patch that adds the
// constants.GPUDeviceIDsAnnotation annotation with that value. It returns an
// Error status if the state read or the patch fails; otherwise it records a
// Normal "GPUDeviceAllocated" event and returns Success.
//
// NOTE(review): this span comes from a diff view, so removed PostBind lines
// (bare returns, "PostBinding" log calls, the old err/else handling) appear
// next to their replacements.
func (s *GPUFit) PreBind(ctx context.Context, state fwk.CycleState, pod *v1.Pod, nodeName string) *fwk.Status {

s.logger.Info("PostBinding pod for GPU resources", "pod", pod.Name, "node", nodeName)
gpuSchedulingResult, err := state.Read(CycleStateGPUSchedulingResult)
if err != nil {
s.logger.Error(err, "failed to read gpu scheduling result", "pod", pod.Name)
return
return fwk.NewStatus(fwk.Error, "failed to read gpu scheduling result: "+err.Error())
}
// Write the allocated GPU info to the Pod in the binding cycle, before the default binder changes the Pod's nodeName.

// NOTE(review): the type assertion is unchecked and panics if the stored
// value is not a *GPUSchedulingStateData — confirm that cannot happen here.
gpuIDs := strings.Join(gpuSchedulingResult.(*GPUSchedulingStateData).FinalGPUs, ",")
s.logger.Info("PostBinding pod for GPU resources", "pod", pod.Name, "node", nodeName, "gpuIDs", gpuIDs)
s.logger.Info("PreBinding pod for GPU resources", "pod", pod.Name, "node", nodeName, "gpuIDs", gpuIDs)

// Patch GPU device IDs annotation.
// NOTE(review): gpuIDs is spliced into the JSON text without escaping;
// presumably GPU IDs never contain quotes or backslashes — verify.
// NOTE(review): a JSON Patch "add" needs /metadata/annotations to exist
// already; presumably worker pods always carry annotations — verify.
patch := []byte(`[{
"op": "add",
"path": "/metadata/annotations/` + utils.EscapeJSONPointer(constants.GPUDeviceIDsAnnotation) + `",
"value": "` + gpuIDs + `"}]`)
// NOTE(review): both Patch calls use s.ctx rather than the ctx argument —
// confirm this is intentional.
err = s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patch))
if err != nil {
if err := s.client.Patch(s.ctx, pod, client.RawPatch(types.JSONPatchType, patch)); err != nil {
s.logger.Error(err, "failed to patch gpu device ids", "pod", pod.Name)
s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeWarning, "GPUDeviceAllocatedFailed",
"Attach GPU device ID info failed", "Can not add GPU device IDs: "+gpuIDs)
} else {
s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "GPUDeviceAllocated",
"Attach GPU device ID info", "Attach TensorFusion GPU device IDs to Pod: "+gpuIDs)
return fwk.NewStatus(fwk.Error, "failed to patch gpu device ids: "+err.Error())
}

s.fh.EventRecorder().Eventf(pod, pod, v1.EventTypeNormal, "GPUDeviceAllocated",
"Attach GPU device ID info", "Attach TensorFusion GPU device IDs to Pod: "+gpuIDs)
return fwk.NewStatus(fwk.Success, "")
}

func (s *GPUFit) EventsToRegister(_ context.Context) ([]fwk.ClusterEventWithHint, error) {
Expand Down
36 changes: 29 additions & 7 deletions internal/scheduler/gpuresources/gpuresources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -511,8 +511,8 @@ func (s *GPUResourcesSuite) TestReserveAndUnreserve() {
s.Len(gpu.Status.RunningApps, 1)
}

func (s *GPUResourcesSuite) TestPostBind() {
log.FromContext(s.ctx).Info("Running TestPostBind")
func (s *GPUResourcesSuite) TestPreBind() {
log.FromContext(s.ctx).Info("Running TestPreBind")
state := framework.NewCycleState()
pod := s.makePod("p1",
map[string]string{
Expand All @@ -528,7 +528,8 @@ func (s *GPUResourcesSuite) TestPostBind() {
reserveStatus := s.plugin.Reserve(s.ctx, state, pod, "node-a")
s.Require().True(reserveStatus.IsSuccess())

s.plugin.PostBind(s.ctx, state, pod, "node-a")
preBindStatus := s.plugin.PreBind(s.ctx, state, pod, "node-a")
s.Require().True(preBindStatus.IsSuccess())

updatedPod := &v1.Pod{}
s.NoError(s.client.Get(s.ctx, types.NamespacedName{Name: "p1", Namespace: "ns1"}, updatedPod))
Expand Down Expand Up @@ -669,8 +670,27 @@ func (s *GPUResourcesSuite) TestUnreserve_ErrorHandling() {
})
}

func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() {
log.FromContext(s.ctx).Info("Running TestPostBind_ErrorHandling")
// TestPreBindPreFlight checks the status code PreBindPreFlight reports for a
// pod built with TensorFusion request annotations and for a pod built by
// makeNonTensorFusionPod.
func (s *GPUResourcesSuite) TestPreBindPreFlight() {
	log.FromContext(s.ctx).Info("Running TestPreBindPreFlight")
	cycleState := framework.NewCycleState()

	// A pod carrying the TensorFusion annotations is expected to get Success.
	workerAnnotations := map[string]string{
		constants.GpuCountAnnotation:      "1",
		constants.TFLOPSRequestAnnotation: "100",
		constants.VRAMRequestAnnotation:   "10Gi",
	}
	workerPod := s.makePod("tf-pod", workerAnnotations)
	workerStatus := s.plugin.PreBindPreFlight(s.ctx, cycleState, workerPod, "node-a")
	s.Equal(fwk.Success, workerStatus.Code())

	// A pod that is not a TensorFusion worker is expected to get Skip.
	plainPod := s.makeNonTensorFusionPod("non-tf-pod", 1)
	plainStatus := s.plugin.PreBindPreFlight(s.ctx, cycleState, plainPod, "node-a")
	s.Equal(fwk.Skip, plainStatus.Code())
}

func (s *GPUResourcesSuite) TestPreBind_ErrorHandling() {
log.FromContext(s.ctx).Info("Running TestPreBind_ErrorHandling")
state := framework.NewCycleState()
pod := s.makePod("p1",
map[string]string{
Expand All @@ -680,7 +700,8 @@ func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() {
})

// No pre-filter call, so state is empty
s.plugin.PostBind(s.ctx, state, pod, "node-a")
status := s.plugin.PreBind(s.ctx, state, pod, "node-a")
s.Equal(fwk.Error, status.Code())

// Test with a pod that doesn't exist in the client
_, preFilterStatus := s.plugin.PreFilter(s.ctx, state, pod, []fwk.NodeInfo{})
Expand All @@ -690,7 +711,8 @@ func (s *GPUResourcesSuite) TestPostBind_ErrorHandling() {

nonExistentPod := pod.DeepCopy()
nonExistentPod.Name = "p-non-existent"
s.plugin.PostBind(s.ctx, state, nonExistentPod, "node-a")
status = s.plugin.PreBind(s.ctx, state, nonExistentPod, "node-a")
s.Equal(fwk.Error, status.Code())
}

func (s *GPUResourcesSuite) TestFilter_ErrorHandling() {
Expand Down