Skip to content

Commit

Permalink
feat(training):enable pytorch elastic training fashion based on torch…
Browse files Browse the repository at this point in the history
…-elastic
  • Loading branch information
wanziyu committed Aug 15, 2022
1 parent 171c0d7 commit 8b92522
Show file tree
Hide file tree
Showing 23 changed files with 1,038 additions and 6 deletions.
3 changes: 3 additions & 0 deletions apis/training/v1alpha1/pytorchjob_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,9 @@ func SetDefaults_PyTorchJob(job *PyTorchJob) {
// Set default replicas and restart policy.
if rType == PyTorchReplicaTypeWorker {
setDefaults_PyTorchJobWorkerReplicas(spec)
if job.Spec.EnableElastic && job.Spec.ElasticPolicy != nil {
setDefaults_PyTorchJobPort(&spec.Template.Spec)
}
}
if rType == PyTorchReplicaTypeMaster {
setDefaults_PyTorchJobMasterReplicas(spec)
Expand Down
23 changes: 23 additions & 0 deletions apis/training/v1alpha1/pytorchjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ type PyTorchJobSpec struct {
// CacheBackend is used to configure the cache engine for job
// +optional
CacheBackend *cachev1alpha1.CacheBackendSpec `json:"cacheBackend"`

// EnableElastic decides whether torch elastic is enabled for job.
// +optional
EnableElastic bool `json:"enableElastic"`

// ElasticPolicy is used to configure the torch elastic-based elastic scaling support for distributed training job.
// +optional
ElasticPolicy *ElasticPolicy `json:"elasticPolicy,omitempty"`
}

type ElasticPolicy struct {
// minReplicas is the lower limit for the number of replicas to which the training job
// can scale down. It defaults to null.
MinReplicas *int32 `json:"minReplicas,omitempty"`

// upper limit for the number of pods that can be set by the autoscaler; cannot be smaller than MinReplicas, defaults to null.
MaxReplicas *int32 `json:"maxReplicas,omitempty"`

RDZVBackend string `json:"rdzvBackend,omitempty"`
RdzvEndpoint string `json:"rdzvEndpoint"`

// Number of workers per node; supported values: [auto, cpu, gpu, int].
NProcPerNode *int32 `json:"nProcPerNode,omitempty"`
}

// PyTorchJobStatus defines the observed state of PyTorchJob
Expand Down
35 changes: 35 additions & 0 deletions apis/training/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

18 changes: 18 additions & 0 deletions config/crd/bases/training.kubedl.io_elasticdljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3111,6 +3111,24 @@ spec:
- type
type: object
type: array
elasticScaling:
additionalProperties:
properties:
continue:
type: boolean
currentReplicas:
format: int32
type: integer
lastReplicas:
format: int32
type: integer
message:
type: string
startTime:
format: date-time
type: string
type: object
type: object
lastReconcileTime:
format: date-time
type: string
Expand Down
18 changes: 18 additions & 0 deletions config/crd/bases/training.kubedl.io_marsjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3133,6 +3133,24 @@ spec:
- type
type: object
type: array
elasticScaling:
additionalProperties:
properties:
continue:
type: boolean
currentReplicas:
format: int32
type: integer
lastReplicas:
format: int32
type: integer
message:
type: string
startTime:
format: date-time
type: string
type: object
type: object
lastReconcileTime:
format: date-time
type: string
Expand Down
18 changes: 18 additions & 0 deletions config/crd/bases/training.kubedl.io_mpijobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6156,6 +6156,24 @@ spec:
- type
type: object
type: array
elasticScaling:
additionalProperties:
properties:
continue:
type: boolean
currentReplicas:
format: int32
type: integer
lastReplicas:
format: int32
type: integer
message:
type: string
startTime:
format: date-time
type: string
type: object
type: object
lastReconcileTime:
format: date-time
type: string
Expand Down
38 changes: 38 additions & 0 deletions config/crd/bases/training.kubedl.io_pytorchjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,26 @@ spec:
required:
- schedule
type: object
elasticPolicy:
properties:
maxReplicas:
format: int32
type: integer
minReplicas:
format: int32
type: integer
nProcPerNode:
format: int32
type: integer
rdzvBackend:
type: string
rdzvEndpoint:
type: string
required:
- rdzvEndpoint
type: object
enableElastic:
type: boolean
modelVersion:
properties:
createdBy:
Expand Down Expand Up @@ -3198,6 +3218,24 @@ spec:
- type
type: object
type: array
elasticScaling:
additionalProperties:
properties:
continue:
type: boolean
currentReplicas:
format: int32
type: integer
lastReplicas:
format: int32
type: integer
message:
type: string
startTime:
format: date-time
type: string
type: object
type: object
lastReconcileTime:
format: date-time
type: string
Expand Down
18 changes: 18 additions & 0 deletions config/crd/bases/training.kubedl.io_tfjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3200,6 +3200,24 @@ spec:
- type
type: object
type: array
elasticScaling:
additionalProperties:
properties:
continue:
type: boolean
currentReplicas:
format: int32
type: integer
lastReplicas:
format: int32
type: integer
message:
type: string
startTime:
format: date-time
type: string
type: object
type: object
lastReconcileTime:
format: date-time
type: string
Expand Down
18 changes: 18 additions & 0 deletions config/crd/bases/training.kubedl.io_xdljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3117,6 +3117,24 @@ spec:
- type
type: object
type: array
elasticScaling:
additionalProperties:
properties:
continue:
type: boolean
currentReplicas:
format: int32
type: integer
lastReplicas:
format: int32
type: integer
message:
type: string
startTime:
format: date-time
type: string
type: object
type: object
lastReconcileTime:
format: date-time
type: string
Expand Down
18 changes: 18 additions & 0 deletions config/crd/bases/training.kubedl.io_xgboostjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3111,6 +3111,24 @@ spec:
- type
type: object
type: array
elasticScaling:
additionalProperties:
properties:
continue:
type: boolean
currentReplicas:
format: int32
type: integer
lastReplicas:
format: int32
type: integer
message:
type: string
startTime:
format: date-time
type: string
type: object
type: object
lastReconcileTime:
format: date-time
type: string
Expand Down
36 changes: 36 additions & 0 deletions controllers/pytorch/pytorchjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,40 @@ func (r *PytorchJobReconciler) SetClusterSpec(ctx context.Context, job interface
}
}

desiredReplicas, err := computeDesiredReplicas(pytorchJob)
if err != nil {
return err
}

// Set default value if minReplicas and maxReplicas are not set
var minReplicas, maxReplicas int32
if pytorchJob.Spec.ElasticPolicy.MinReplicas != nil {
minReplicas = *pytorchJob.Spec.ElasticPolicy.MinReplicas
} else {
minReplicas = desiredReplicas
}

if pytorchJob.Spec.ElasticPolicy.MaxReplicas != nil {
maxReplicas = *pytorchJob.Spec.ElasticPolicy.MaxReplicas
} else {
maxReplicas = desiredReplicas
}

var procPerNode int32
if pytorchJob.Spec.ElasticPolicy.NProcPerNode != nil {
procPerNode = *pytorchJob.Spec.ElasticPolicy.NProcPerNode
} else {
procPerNode = int32(1)
}

//Generate torch elastic env args.
launchElasticArgs := []string{
"--rdzv_backend=" + pytorchJob.Spec.ElasticPolicy.RDZVBackend,
"--rdzv_endpoint=" + pytorchJob.Spec.ElasticPolicy.RdzvEndpoint,
"--rdzv_id=" + pytorchJob.Name,
"--nproc_per_node=" + strconv.Itoa(int(procPerNode)),
"--nnodes=" + strconv.Itoa(int(minReplicas)) + ":" + strconv.Itoa(int(maxReplicas))}

for i := range podTemplate.Spec.Containers {
if len(podTemplate.Spec.Containers[i].Env) == 0 {
podTemplate.Spec.Containers[i].Env = make([]corev1.EnvVar, 0)
Expand All @@ -285,6 +319,8 @@ func (r *PytorchJobReconciler) SetClusterSpec(ctx context.Context, job interface
Name: "PYTHONUNBUFFERED",
Value: "0",
})
podTemplate.Spec.Containers[i].Args = append(launchElasticArgs, podTemplate.Spec.Containers[i].Args...)

if enableElasticScaling && rtype != "aimaster" {
// Job enables elastic scaling select value of AnnotationWorldSize as its
// WORLD_SIZE env value via field-path, the annotated value will be mutated
Expand Down
16 changes: 15 additions & 1 deletion controllers/pytorch/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,23 @@ limitations under the License.

package pytorch

import training "github.com/alibaba/kubedl/apis/training/v1alpha1"
import (
"fmt"
training "github.com/alibaba/kubedl/apis/training/v1alpha1"
v1 "github.com/alibaba/kubedl/pkg/job_controller/api/v1"
)

func ContainMasterSpec(job *training.PyTorchJob) bool {
_, ok := job.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeMaster]
return ok
}

// computeDesiredReplicas retrieve user's replica setting in specs
func computeDesiredReplicas(elasticJob *training.PyTorchJob) (int32, error) {
workerSpecs, exist := elasticJob.Spec.PyTorchReplicaSpecs[v1.ReplicaType(training.PyTorchReplicaTypeMaster)]
if !exist {
return 0, fmt.Errorf("elasticJob %v doesn't have %s", elasticJob, training.PyTorchReplicaTypeMaster)
}

return *workerSpecs.Replicas, nil
}
23 changes: 23 additions & 0 deletions controllers/torchelastic/elastic_controller.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package torchelastic

import (
"github.com/alibaba/kubedl/controllers/torchelastic/job"
ctrl "sigs.k8s.io/controller-runtime"
)

const (
controllerName = "ElasticScalingController"
)

func SetupWithManager(mgr ctrl.Manager) error {
// New torch elastic controller.
// period represents the time elastic scaling loop repeats.
// count represents the length of training metrics collection for each scale replicas.
torchElasticController := job.NewTorchElasticController(mgr, 30, 5)

if err := torchElasticController.SetupWithManager(mgr); err != nil {
return err
}
return nil

}
Loading

0 comments on commit 8b92522

Please sign in to comment.