Skip to content

Commit

Permalink
feat(training):update elastic training control flow and APIs
Browse files Browse the repository at this point in the history
Signed-off-by: wanziyu <[email protected]>
  • Loading branch information
wanziyu committed Aug 23, 2022
1 parent 8b92522 commit a7e5fbc
Show file tree
Hide file tree
Showing 20 changed files with 936 additions and 504 deletions.
4 changes: 2 additions & 2 deletions apis/training/v1alpha1/pytorchjob_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ type PyTorchJobSpec struct {

// EnableElastic decides whether torch elastic is enabled for job.
// +optional
EnableElastic bool `json:"enableElastic"`
EnableElastic bool `json:"enableElastic,omitempty"`

// ElasticPolicy is used to configure the torch elastic-based elastic scaling support for distributed training job.
// +optional
Expand All @@ -64,7 +64,7 @@ type ElasticPolicy struct {
// upper limit for the number of pods that can be set by the autoscaler; cannot be smaller than MinReplicas, defaults to null.
MaxReplicas *int32 `json:"maxReplicas,omitempty"`

RDZVBackend string `json:"rdzvBackend,omitempty"`
RDZVBackend string `json:"rdzvBackend"`
RdzvEndpoint string `json:"rdzvEndpoint"`

// Number of workers per node; supported values: [auto, cpu, gpu, int].
Expand Down
2 changes: 2 additions & 0 deletions config/crd/bases/training.kubedl.io_elasticdljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3119,6 +3119,8 @@ spec:
currentReplicas:
format: int32
type: integer
elasticCondition:
type: string
lastReplicas:
format: int32
type: integer
Expand Down
2 changes: 2 additions & 0 deletions config/crd/bases/training.kubedl.io_marsjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3141,6 +3141,8 @@ spec:
currentReplicas:
format: int32
type: integer
elasticCondition:
type: string
lastReplicas:
format: int32
type: integer
Expand Down
2 changes: 2 additions & 0 deletions config/crd/bases/training.kubedl.io_mpijobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6164,6 +6164,8 @@ spec:
currentReplicas:
format: int32
type: integer
elasticCondition:
type: string
lastReplicas:
format: int32
type: integer
Expand Down
3 changes: 3 additions & 0 deletions config/crd/bases/training.kubedl.io_pytorchjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ spec:
rdzvEndpoint:
type: string
required:
- rdzvBackend
- rdzvEndpoint
type: object
enableElastic:
Expand Down Expand Up @@ -3226,6 +3227,8 @@ spec:
currentReplicas:
format: int32
type: integer
elasticCondition:
type: string
lastReplicas:
format: int32
type: integer
Expand Down
2 changes: 2 additions & 0 deletions config/crd/bases/training.kubedl.io_tfjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3208,6 +3208,8 @@ spec:
currentReplicas:
format: int32
type: integer
elasticCondition:
type: string
lastReplicas:
format: int32
type: integer
Expand Down
2 changes: 2 additions & 0 deletions config/crd/bases/training.kubedl.io_xdljobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3125,6 +3125,8 @@ spec:
currentReplicas:
format: int32
type: integer
elasticCondition:
type: string
lastReplicas:
format: int32
type: integer
Expand Down
2 changes: 2 additions & 0 deletions config/crd/bases/training.kubedl.io_xgboostjobs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3119,6 +3119,8 @@ spec:
currentReplicas:
format: int32
type: integer
elasticCondition:
type: string
lastReplicas:
format: int32
type: integer
Expand Down
1 change: 1 addition & 0 deletions controllers/pytorch/elastic_scale.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ const (
AnnotationCheckpointRequestedVersion = v1.KubeDLPrefix + "/ckpt-requested-version"
AnnotationCheckpointCompletedVersion = v1.KubeDLPrefix + "/ckpt-completed-version"
AnnotationReadyToStartWorker = v1.KubeDLPrefix + "/ready-to-start-worker"
AnnotationReadyToRestartWorker = v1.KubeDLPrefix + "/ready-to-restart-worker"
AnnotationImmediatelyStartWorker = v1.KubeDLPrefix + "/immediately-start-worker"
AnnotationWorldSize = v1.KubeDLPrefix + "/world-size"
)
Expand Down
5 changes: 4 additions & 1 deletion controllers/pytorch/pytorchjob_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,10 @@ func (r *PytorchJobReconciler) SetClusterSpec(ctx context.Context, job interface
Name: "PYTHONUNBUFFERED",
Value: "0",
})
podTemplate.Spec.Containers[i].Args = append(launchElasticArgs, podTemplate.Spec.Containers[i].Args...)

if pytorchJob.Spec.EnableElastic && pytorchJob.Spec.ElasticPolicy != nil {
podTemplate.Spec.Containers[i].Args = append(launchElasticArgs, podTemplate.Spec.Containers[i].Args...)
}

if enableElasticScaling && rtype != "aimaster" {
// Job enables elastic scaling select value of AnnotationWorldSize as its
Expand Down
223 changes: 223 additions & 0 deletions controllers/torchelastic/elastic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
/*
Copyright 2022 The Alibaba Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package torchelastic

import (
"context"
training "github.com/alibaba/kubedl/apis/training/v1alpha1"
apiv1 "github.com/alibaba/kubedl/pkg/job_controller/api/v1"
logger "github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"
"reflect"
)

func (ts *TorchElasticController) start(ctx context.Context, cancel context.CancelFunc, name, namespace string) {
sharedPytorchJob := &training.PyTorchJob{}
jobName := name
jobNamespace := namespace

// Create metrics for each torch elastic job.
ts.locker.Lock()
if _, ok := ts.metrics[jobName]; !ok {
ts.metrics[jobName] = make(map[int32][]MetricObservation)
}
ts.locker.Unlock()

err := ts.Client.Get(ctx, types.NamespacedName{Namespace: jobNamespace, Name: jobName}, sharedPytorchJob)
if err != nil {
logger.Infof("try to get job %s from namespace %s but it has been deleted", jobName, jobNamespace)
// cancel the elastic scaling process context of the deleted job.
defer cancel()
return
}

pytorchJob := sharedPytorchJob.DeepCopy()
if pytorchJob.Spec.ElasticPolicy.MaxReplicas == nil || pytorchJob.Spec.ElasticPolicy.MinReplicas == nil {
logger.Infof("pytorch job %s does not configure the max or min replicas", pytorchJob.Name)
defer cancel()
delete(ts.torchElasticJobs, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace))
return
}

if pytorchJob.Status.ElasticStatus == nil {
initializeElasticStatuses(pytorchJob, training.PyTorchReplicaTypeWorker)
if err := ts.UpdateJobStatusInApiServer(pytorchJob, &pytorchJob.Status); err != nil {
if errors.IsConflict(err) {
// retry later when update operation violates with etcd concurrency control.
log.Info("fail to update pytorch job")
}
}
return
}

jobStatus := pytorchJob.Status.DeepCopy()
oldStatus := jobStatus.DeepCopy()
if pytorchJob.Status.CompletionTime != nil || pytorchJob.DeletionTimestamp != nil {
logger.Infof("job %s has been completed or deleted and does not need to do elastic scaling", pytorchJob.Name)
defer cancel()
delete(ts.torchElasticJobs, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace))
delete(ts.metrics, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace))
return
}

currentReplicas := *pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas

// Wait for all pods running and judge whether there exists pending or failed pods.
hasPendingPod, hasFailedPod := ts.waitForAllPodsRunning(pytorchJob)

// If job has pending pods and current replicas are more than min replicas, return to the last replicas.
if hasPendingPod && currentReplicas > *pytorchJob.Spec.ElasticPolicy.MinReplicas {
lastReplicas := jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].LastReplicas
*pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas = lastReplicas
// Return to the last replicas.
if err := ts.Client.Update(ctx, pytorchJob); err != nil {
log.Info("fail to update replicas of pytorch job")
}

updateElasticStatusForPendingJob(pytorchJob, lastReplicas, training.PyTorchReplicaTypeWorker)
if err := ts.UpdateJobStatusInApiServer(pytorchJob, &pytorchJob.Status); err != nil {
if errors.IsConflict(err) {
// retry later when update operation violates with etcd concurrency control.
log.Info("fail to update pytorch job")
}
}
return

// If job has pending pods and current replicas equals to the min replicas, cancel the elastic scaling process context.
} else if (hasPendingPod && currentReplicas == *pytorchJob.Spec.ElasticPolicy.MinReplicas) || hasFailedPod {
defer cancel()
logger.Info("pods did not reach the running state at min replicas or job is failed, so the elastic scaling controller shutdown")
delete(ts.torchElasticJobs, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace))
return
}

if !hasPendingPod && jobStatus.ElasticStatus != nil && jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].Continue == false {
// If job metrics have reached the max, restart stale pods.
if jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].ElasticCondition == apiv1.ElasticMaxMetric {
pods, err := ts.GetPodsForJob(pytorchJob)
if err != nil {
logger.Warnf("Get Pods For Job error %v", err)
}
// Restart stale torch elastic pods.
complete := ts.restartStalePytorchPods(pods, pytorchJob)
if !complete {
logger.Info("restart pods does not complete")
return
}
logger.Info("restart pods has completed")
jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].ElasticCondition = apiv1.ElasticStop
if err = ts.UpdateJobStatusInApiServer(pytorchJob, jobStatus); err != nil {
if errors.IsConflict(err) {
// retry later when update operation violates with etcd concurrency control.
logger.Info("fail to update pytorch job status")
return
}
}
return
// If current replicas reach the defined max replicas or elastic condition is stopped, return directly.
} else if jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].ElasticCondition == apiv1.ElasticStop || jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].ElasticCondition == apiv1.ElasticMaxReplica {
log.Info("Pytorch job does not need to be scaled")
return
}
}

// Read training logs from pytorch pods and save the observation.
observation, err := read(ts.client, jobNamespace, GetDefaultWorkerName(jobName))
if err != nil {
logger.Infof("fail to read training logs: %v", err)
return
}

ts.locker.Lock()
defer ts.locker.Unlock()

// Create metrics for current replicas.
if _, ok := ts.metrics[jobName][currentReplicas]; !ok {
ts.metrics[jobName][currentReplicas] = make([]MetricObservation, 0)
}
ts.metrics[jobName][currentReplicas] = append(ts.metrics[jobName][currentReplicas], observation)
currentLength := len(ts.metrics[jobName][currentReplicas])
logger.Infof("Current metric length: %d", currentLength)

// If current metrics have reached the metric count, judge the next scaling replicas.
if currentLength >= ts.metricCount {
if currentReplicas > *pytorchJob.Spec.ElasticPolicy.MinReplicas && currentReplicas <= *pytorchJob.Spec.ElasticPolicy.MaxReplicas {
lastReplicas := jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].LastReplicas

if ts.IsSatisfyElasticContinue(jobName, currentReplicas, lastReplicas) {
if currentReplicas == *pytorchJob.Spec.ElasticPolicy.MaxReplicas {
updateElasticStatusForMaxReplicaJob(pytorchJob, training.PyTorchReplicaTypeWorker)
ts.metrics[jobName][currentReplicas] = make([]MetricObservation, 0)
} else {
newReplicas := computeNewReplicas(currentReplicas)
*pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas = newReplicas
if err := ts.Client.Update(ctx, pytorchJob); err != nil {
log.Info("fail to update pytorch job")
}

updateElasticStatusForContinueJob(pytorchJob, currentReplicas, newReplicas, training.PyTorchReplicaTypeWorker)
if _, ok := ts.metrics[jobName][newReplicas]; !ok {
ts.metrics[jobName][newReplicas] = make([]MetricObservation, 0)
}
}

} else {
*pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas = lastReplicas
if err := ts.Client.Update(ctx, pytorchJob); err != nil {
log.Info("fail to update pytorch job")
}

updateElasticStatusForMaxMetricJob(pytorchJob, currentReplicas, lastReplicas, training.PyTorchReplicaTypeWorker)
ts.metrics[jobName][lastReplicas] = make([]MetricObservation, 0)
ts.metrics[jobName][currentReplicas] = make([]MetricObservation, 0)
}

} else if currentReplicas == *pytorchJob.Spec.ElasticPolicy.MinReplicas && currentReplicas < *pytorchJob.Spec.ElasticPolicy.MaxReplicas {
newReplicas := computeNewReplicas(currentReplicas)
*pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas = newReplicas
if err := ts.Client.Update(ctx, pytorchJob); err != nil {
log.Info("fail to update pytorch job")
}

updateElasticStatusForContinueJob(pytorchJob, currentReplicas, newReplicas, training.PyTorchReplicaTypeWorker)
if _, ok := ts.metrics[jobName][newReplicas]; !ok {
ts.metrics[jobName][newReplicas] = make([]MetricObservation, 0)
}

} else if currentReplicas == *pytorchJob.Spec.ElasticPolicy.MaxReplicas {
updateElasticStatusForMaxReplicaJob(pytorchJob, training.PyTorchReplicaTypeWorker)
if _, ok := ts.metrics[jobName][currentReplicas]; ok {
ts.metrics[jobName][currentReplicas] = make([]MetricObservation, 0)
}

}
}

// No need to update the job status if the status hasn't changed since last time.
if !reflect.DeepEqual(*oldStatus, pytorchJob.Status) {
if err = ts.UpdateJobStatusInApiServer(pytorchJob, &pytorchJob.Status); err != nil {
if errors.IsConflict(err) {
// retry later when update operation violates with etcd concurrency control.
logger.Info("fail to update pytorch job status")
return
}
}
}

return
}
25 changes: 18 additions & 7 deletions controllers/torchelastic/elastic_controller.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
/*
Copyright 2022 The Alibaba Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package torchelastic

import (
"github.com/alibaba/kubedl/controllers/torchelastic/job"
ctrl "sigs.k8s.io/controller-runtime"
)

const (
controllerName = "ElasticScalingController"
)

func SetupWithManager(mgr ctrl.Manager) error {
// New torch elastic controller.
// period represents the time elastic scaling loop repeats.
// count represents the length of training metrics collection for each scale replicas.
torchElasticController := job.NewTorchElasticController(mgr, 30, 5)
// count represents the length of training metrics collection for each replica.
torchElasticController := NewTorchElasticController(mgr, 30, 5)

if err := torchElasticController.SetupWithManager(mgr); err != nil {
return err
Expand Down
Loading

0 comments on commit a7e5fbc

Please sign in to comment.