Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Commit e7f3888

Browse files
author
byhsu
committed
Allow 0 workers in PyTorch plugin & add ObjectMeta to PyTorchJob
1 parent 76a80ec commit e7f3888

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package pytorch
22

33
import (
44
"context"
5-
"fmt"
65
"time"
76

87
"github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/plugins"
@@ -69,9 +68,6 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
6968
common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName)
7069

7170
workers := pytorchTaskExtraArgs.GetWorkers()
72-
if workers == 0 {
73-
return nil, fmt.Errorf("number of worker should be more then 0")
74-
}
7571

7672
var jobSpec kubeflowv1.PyTorchJobSpec
7773

@@ -115,23 +111,27 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx
115111
},
116112
RestartPolicy: commonOp.RestartPolicyNever,
117113
},
118-
kubeflowv1.PyTorchJobReplicaTypeWorker: {
119-
Replicas: &workers,
120-
Template: v1.PodTemplateSpec{
121-
ObjectMeta: *objectMeta,
122-
Spec: *podSpec,
123-
},
124-
RestartPolicy: commonOp.RestartPolicyNever,
125-
},
126114
},
127115
}
116+
117+
if workers > 0 {
118+
jobSpec.PyTorchReplicaSpecs[kubeflowv1.MPIJobReplicaTypeWorker] = &commonOp.ReplicaSpec{
119+
Replicas: &workers,
120+
Template: v1.PodTemplateSpec{
121+
ObjectMeta: *objectMeta,
122+
Spec: *podSpec,
123+
},
124+
RestartPolicy: commonOp.RestartPolicyNever,
125+
}
126+
}
128127
}
129128
job := &kubeflowv1.PyTorchJob{
130129
TypeMeta: metav1.TypeMeta{
131130
Kind: kubeflowv1.PytorchJobKind,
132131
APIVersion: kubeflowv1.SchemeGroupVersion.String(),
133132
},
134-
Spec: jobSpec,
133+
Spec: jobSpec,
134+
ObjectMeta: *objectMeta,
135135
}
136136

137137
return job, nil

go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ func TestReplicaCounts(t *testing.T) {
425425
contains []commonOp.ReplicaType
426426
notContains []commonOp.ReplicaType
427427
}{
428-
{"NoWorkers", 0, true, nil, nil},
428+
{"NoWorkers", 0, false, []commonOp.ReplicaType{kubeflowv1.PyTorchJobReplicaTypeMaster}, nil},
429429
{"Works", 1, false, []commonOp.ReplicaType{kubeflowv1.PyTorchJobReplicaTypeMaster, kubeflowv1.PyTorchJobReplicaTypeWorker}, []commonOp.ReplicaType{}},
430430
} {
431431
t.Run(test.name, func(t *testing.T) {

0 commit comments

Comments
 (0)