diff --git a/apis/training/v1alpha1/pytorchjob_defaults.go b/apis/training/v1alpha1/pytorchjob_defaults.go index 652badb4..405c89a8 100644 --- a/apis/training/v1alpha1/pytorchjob_defaults.go +++ b/apis/training/v1alpha1/pytorchjob_defaults.go @@ -125,6 +125,9 @@ func SetDefaults_PyTorchJob(job *PyTorchJob) { // Set default replicas and restart policy. if rType == PyTorchReplicaTypeWorker { setDefaults_PyTorchJobWorkerReplicas(spec) + if job.Spec.EnableElastic && job.Spec.ElasticPolicy != nil { + setDefaults_PyTorchJobPort(&spec.Template.Spec) + } } if rType == PyTorchReplicaTypeMaster { setDefaults_PyTorchJobMasterReplicas(spec) diff --git a/apis/training/v1alpha1/pytorchjob_types.go b/apis/training/v1alpha1/pytorchjob_types.go index 4b3f2787..c0bececc 100644 --- a/apis/training/v1alpha1/pytorchjob_types.go +++ b/apis/training/v1alpha1/pytorchjob_types.go @@ -46,6 +46,29 @@ type PyTorchJobSpec struct { // CacheBackend is used to configure the cache engine for job // +optional CacheBackend *cachev1alpha1.CacheBackendSpec `json:"cacheBackend"` + + // EnableElastic decides whether torch elastic is enabled for job. + // +optional + EnableElastic bool `json:"enableElastic,omitempty"` + + // ElasticPolicy is used to configure the torch elastic-based elastic scaling support for distributed training job. + // +optional + ElasticPolicy *ElasticPolicy `json:"elasticPolicy,omitempty"` +} + +type ElasticPolicy struct { + // minReplicas is the lower limit for the number of replicas to which the training job + // can scale down. It defaults to null. + MinReplicas *int32 `json:"minReplicas,omitempty"` + + // upper limit for the number of pods that can be set by the autoscaler; cannot be smaller than MinReplicas, defaults to null. + MaxReplicas *int32 `json:"maxReplicas,omitempty"` + + RDZVBackend string `json:"rdzvBackend"` + RdzvEndpoint string `json:"rdzvEndpoint"` + + // Number of workers per node; supported values: [auto, cpu, gpu, int]. + NProcPerNode *int32 `json:"nProcPerNode,omitempty"` } // PyTorchJobStatus defines the observed state of PyTorchJob diff --git a/apis/training/v1alpha1/zz_generated.deepcopy.go b/apis/training/v1alpha1/zz_generated.deepcopy.go index 5785441b..5b2cb55e 100644 --- a/apis/training/v1alpha1/zz_generated.deepcopy.go +++ b/apis/training/v1alpha1/zz_generated.deepcopy.go @@ -118,6 +118,36 @@ func (in *ElasticDLJobSpec) DeepCopy() *ElasticDLJobSpec { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ElasticPolicy) DeepCopyInto(out *ElasticPolicy) { + *out = *in + if in.MinReplicas != nil { + in, out := &in.MinReplicas, &out.MinReplicas + *out = new(int32) + **out = **in + } + if in.MaxReplicas != nil { + in, out := &in.MaxReplicas, &out.MaxReplicas + *out = new(int32) + **out = **in + } + if in.NProcPerNode != nil { + in, out := &in.NProcPerNode, &out.NProcPerNode + *out = new(int32) + **out = **in + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ElasticPolicy. +func (in *ElasticPolicy) DeepCopy() *ElasticPolicy { + if in == nil { + return nil + } + out := new(ElasticPolicy) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
func (in *LegacyV1Alpha1) DeepCopyInto(out *LegacyV1Alpha1) { *out = *in @@ -558,6 +588,11 @@ func (in *PyTorchJobSpec) DeepCopyInto(out *PyTorchJobSpec) { *out = new(cachev1alpha1.CacheBackendSpec) (*in).DeepCopyInto(*out) } + if in.ElasticPolicy != nil { + in, out := &in.ElasticPolicy, &out.ElasticPolicy + *out = new(ElasticPolicy) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new PyTorchJobSpec. diff --git a/config/crd/bases/training.kubedl.io_elasticdljobs.yaml b/config/crd/bases/training.kubedl.io_elasticdljobs.yaml index effd16fb..4ba30e4e 100644 --- a/config/crd/bases/training.kubedl.io_elasticdljobs.yaml +++ b/config/crd/bases/training.kubedl.io_elasticdljobs.yaml @@ -3111,6 +3111,26 @@ spec: - type type: object type: array + elasticScaling: + additionalProperties: + properties: + continue: + type: boolean + currentReplicas: + format: int32 + type: integer + elasticCondition: + type: string + lastReplicas: + format: int32 + type: integer + message: + type: string + startTime: + format: date-time + type: string + type: object + type: object lastReconcileTime: format: date-time type: string diff --git a/config/crd/bases/training.kubedl.io_marsjobs.yaml b/config/crd/bases/training.kubedl.io_marsjobs.yaml index 73fbf03c..79fadc23 100644 --- a/config/crd/bases/training.kubedl.io_marsjobs.yaml +++ b/config/crd/bases/training.kubedl.io_marsjobs.yaml @@ -3133,6 +3133,26 @@ spec: - type type: object type: array + elasticScaling: + additionalProperties: + properties: + continue: + type: boolean + currentReplicas: + format: int32 + type: integer + elasticCondition: + type: string + lastReplicas: + format: int32 + type: integer + message: + type: string + startTime: + format: date-time + type: string + type: object + type: object lastReconcileTime: format: date-time type: string diff --git a/config/crd/bases/training.kubedl.io_mpijobs.yaml b/config/crd/bases/training.kubedl.io_mpijobs.yaml index 878a6b6c..e283f7dc 100644 --- a/config/crd/bases/training.kubedl.io_mpijobs.yaml +++ b/config/crd/bases/training.kubedl.io_mpijobs.yaml @@ -6156,6 +6156,26 @@ spec: - type type: object type: array + elasticScaling: + additionalProperties: + properties: + continue: + type: boolean + currentReplicas: + format: int32 + type: integer + elasticCondition: + type: string + lastReplicas: + format: int32 + type: integer + message: + type: string + startTime: + format: date-time + type: string + type: object + type: object lastReconcileTime: format: date-time type: string diff --git a/config/crd/bases/training.kubedl.io_pytorchjobs.yaml b/config/crd/bases/training.kubedl.io_pytorchjobs.yaml index 45b1c8dc..3238995d 100644 --- a/config/crd/bases/training.kubedl.io_pytorchjobs.yaml +++ b/config/crd/bases/training.kubedl.io_pytorchjobs.yaml @@ -112,6 +112,27 @@ spec: required: - schedule type: object + elasticPolicy: + properties: + maxReplicas: + format: int32 + type: integer + minReplicas: + format: int32 + type: integer + nProcPerNode: + format: int32 + type: integer + rdzvBackend: + type: string + rdzvEndpoint: + type: string + required: + - rdzvBackend + - rdzvEndpoint + type: object + enableElastic: + type: boolean modelVersion: properties: createdBy: @@ -3198,6 +3219,26 @@ spec: - type type: object type: array + elasticScaling: + additionalProperties: + properties: + continue: + type: boolean + currentReplicas: + format: int32 + type: integer + elasticCondition: + type: string + lastReplicas: + format: int32 + type: integer + message: + 
type: string + startTime: + format: date-time + type: string + type: object + type: object lastReconcileTime: format: date-time type: string diff --git a/config/crd/bases/training.kubedl.io_tfjobs.yaml b/config/crd/bases/training.kubedl.io_tfjobs.yaml index fa9a5f1b..5bd61575 100644 --- a/config/crd/bases/training.kubedl.io_tfjobs.yaml +++ b/config/crd/bases/training.kubedl.io_tfjobs.yaml @@ -3200,6 +3200,26 @@ spec: - type type: object type: array + elasticScaling: + additionalProperties: + properties: + continue: + type: boolean + currentReplicas: + format: int32 + type: integer + elasticCondition: + type: string + lastReplicas: + format: int32 + type: integer + message: + type: string + startTime: + format: date-time + type: string + type: object + type: object lastReconcileTime: format: date-time type: string diff --git a/config/crd/bases/training.kubedl.io_xdljobs.yaml b/config/crd/bases/training.kubedl.io_xdljobs.yaml index ff3a405b..a9bb6cd0 100644 --- a/config/crd/bases/training.kubedl.io_xdljobs.yaml +++ b/config/crd/bases/training.kubedl.io_xdljobs.yaml @@ -3117,6 +3117,26 @@ spec: - type type: object type: array + elasticScaling: + additionalProperties: + properties: + continue: + type: boolean + currentReplicas: + format: int32 + type: integer + elasticCondition: + type: string + lastReplicas: + format: int32 + type: integer + message: + type: string + startTime: + format: date-time + type: string + type: object + type: object lastReconcileTime: format: date-time type: string diff --git a/config/crd/bases/training.kubedl.io_xgboostjobs.yaml b/config/crd/bases/training.kubedl.io_xgboostjobs.yaml index b248367b..c74e55e5 100644 --- a/config/crd/bases/training.kubedl.io_xgboostjobs.yaml +++ b/config/crd/bases/training.kubedl.io_xgboostjobs.yaml @@ -3111,6 +3111,26 @@ spec: - type type: object type: array + elasticScaling: + additionalProperties: + properties: + continue: + type: boolean + currentReplicas: + format: int32 + type: integer + elasticCondition: + type: string + lastReplicas: + format: int32 + type: integer + message: + type: string + startTime: + format: date-time + type: string + type: object + type: object lastReconcileTime: format: date-time type: string diff --git a/controllers/pytorch/elastic_scale.go b/controllers/pytorch/elastic_scale.go index 06364e23..bce51efc 100644 --- a/controllers/pytorch/elastic_scale.go +++ b/controllers/pytorch/elastic_scale.go @@ -35,6 +35,7 @@ const ( AnnotationCheckpointRequestedVersion = v1.KubeDLPrefix + "/ckpt-requested-version" AnnotationCheckpointCompletedVersion = v1.KubeDLPrefix + "/ckpt-completed-version" AnnotationReadyToStartWorker = v1.KubeDLPrefix + "/ready-to-start-worker" + AnnotationReadyToRestartWorker = v1.KubeDLPrefix + "/ready-to-restart-worker" AnnotationImmediatelyStartWorker = v1.KubeDLPrefix + "/immediately-start-worker" AnnotationWorldSize = v1.KubeDLPrefix + "/world-size" ) diff --git a/controllers/pytorch/pytorchjob_controller.go b/controllers/pytorch/pytorchjob_controller.go index f11bb9ce..ac51a030 100644 --- a/controllers/pytorch/pytorchjob_controller.go +++ b/controllers/pytorch/pytorchjob_controller.go @@ -265,6 +265,40 @@ func (r *PytorchJobReconciler) SetClusterSpec(ctx context.Context, job interface } } + desiredReplicas, err := computeDesiredReplicas(pytorchJob) + if err != nil { + return err + } + + // Set default value if minReplicas and maxReplicas are not set + var minReplicas, maxReplicas int32 + if pytorchJob.Spec.ElasticPolicy.MinReplicas != nil { + minReplicas = 
*pytorchJob.Spec.ElasticPolicy.MinReplicas + } else { + minReplicas = desiredReplicas + } + + if pytorchJob.Spec.ElasticPolicy.MaxReplicas != nil { + maxReplicas = *pytorchJob.Spec.ElasticPolicy.MaxReplicas + } else { + maxReplicas = desiredReplicas + } + + var procPerNode int32 + if pytorchJob.Spec.ElasticPolicy.NProcPerNode != nil { + procPerNode = *pytorchJob.Spec.ElasticPolicy.NProcPerNode + } else { + procPerNode = int32(1) + } + + //Generate torch elastic env args. + launchElasticArgs := []string{ + "--rdzv_backend=" + pytorchJob.Spec.ElasticPolicy.RDZVBackend, + "--rdzv_endpoint=" + pytorchJob.Spec.ElasticPolicy.RdzvEndpoint, + "--rdzv_id=" + pytorchJob.Name, + "--nproc_per_node=" + strconv.Itoa(int(procPerNode)), + "--nnodes=" + strconv.Itoa(int(minReplicas)) + ":" + strconv.Itoa(int(maxReplicas))} + for i := range podTemplate.Spec.Containers { if len(podTemplate.Spec.Containers[i].Env) == 0 { podTemplate.Spec.Containers[i].Env = make([]corev1.EnvVar, 0) @@ -285,6 +319,11 @@ func (r *PytorchJobReconciler) SetClusterSpec(ctx context.Context, job interface Name: "PYTHONUNBUFFERED", Value: "0", }) + + if pytorchJob.Spec.EnableElastic && pytorchJob.Spec.ElasticPolicy != nil { + podTemplate.Spec.Containers[i].Args = append(launchElasticArgs, podTemplate.Spec.Containers[i].Args...) + } + if enableElasticScaling && rtype != "aimaster" { // Job enables elastic scaling select value of AnnotationWorldSize as its // WORLD_SIZE env value via field-path, the annotated value will be mutated diff --git a/controllers/pytorch/util.go b/controllers/pytorch/util.go index 6880fb19..12de2221 100644 --- a/controllers/pytorch/util.go +++ b/controllers/pytorch/util.go @@ -16,9 +16,23 @@ limitations under the License. package pytorch -import training "github.com/alibaba/kubedl/apis/training/v1alpha1" +import ( + "fmt" + training "github.com/alibaba/kubedl/apis/training/v1alpha1" + v1 "github.com/alibaba/kubedl/pkg/job_controller/api/v1" +) func ContainMasterSpec(job *training.PyTorchJob) bool { _, ok := job.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeMaster] return ok } + +// computeDesiredReplicas retrieve user's replica setting in specs +func computeDesiredReplicas(elasticJob *training.PyTorchJob) (int32, error) { + workerSpecs, exist := elasticJob.Spec.PyTorchReplicaSpecs[v1.ReplicaType(training.PyTorchReplicaTypeMaster)] + if !exist { + return 0, fmt.Errorf("elasticJob %v doesn't have %s", elasticJob, training.PyTorchReplicaTypeMaster) + } + + return *workerSpecs.Replicas, nil +} diff --git a/controllers/torchelastic/elastic.go b/controllers/torchelastic/elastic.go new file mode 100644 index 00000000..d1c7d90f --- /dev/null +++ b/controllers/torchelastic/elastic.go @@ -0,0 +1,223 @@ +/* +Copyright 2022 The Alibaba Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package torchelastic + +import ( + "context" + training "github.com/alibaba/kubedl/apis/training/v1alpha1" + apiv1 "github.com/alibaba/kubedl/pkg/job_controller/api/v1" + logger "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" + "reflect" +) + +func (ts *TorchElasticController) start(ctx context.Context, cancel context.CancelFunc, name, namespace string) { + sharedPytorchJob := &training.PyTorchJob{} + jobName := name + jobNamespace := namespace + + // Create metrics for each torch elastic job. + ts.locker.Lock() + if _, ok := ts.metrics[jobName]; !ok { + ts.metrics[jobName] = make(map[int32][]MetricObservation) + } + ts.locker.Unlock() + + err := ts.Client.Get(ctx, types.NamespacedName{Namespace: jobNamespace, Name: jobName}, sharedPytorchJob) + if err != nil { + logger.Infof("try to get job %s from namespace %s but it has been deleted", jobName, jobNamespace) + // cancel the elastic scaling process context of the deleted job. + defer cancel() + return + } + + pytorchJob := sharedPytorchJob.DeepCopy() + if pytorchJob.Spec.ElasticPolicy.MaxReplicas == nil || pytorchJob.Spec.ElasticPolicy.MinReplicas == nil { + logger.Infof("pytorch job %s does not configure the max or min replicas", pytorchJob.Name) + defer cancel() + delete(ts.torchElasticJobs, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)) + return + } + + if pytorchJob.Status.ElasticStatus == nil { + initializeElasticStatuses(pytorchJob, training.PyTorchReplicaTypeWorker) + if err := ts.UpdateJobStatusInApiServer(pytorchJob, &pytorchJob.Status); err != nil { + if errors.IsConflict(err) { + // retry later when update operation violates with etcd concurrency control. + log.Info("fail to update pytorch job") + } + } + return + } + + jobStatus := pytorchJob.Status.DeepCopy() + oldStatus := jobStatus.DeepCopy() + if pytorchJob.Status.CompletionTime != nil || pytorchJob.DeletionTimestamp != nil { + logger.Infof("job %s has been completed or deleted and does not need to do elastic scaling", pytorchJob.Name) + defer cancel() + delete(ts.torchElasticJobs, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)) + delete(ts.metrics, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)) + return + } + + currentReplicas := *pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas + + // Wait for all pods running and judge whether there exists pending or failed pods. + hasPendingPod, hasFailedPod := ts.waitForAllPodsRunning(pytorchJob) + + // If job has pending pods and current replicas are more than min replicas, return to the last replicas. + if hasPendingPod && currentReplicas > *pytorchJob.Spec.ElasticPolicy.MinReplicas { + lastReplicas := jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].LastReplicas + *pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas = lastReplicas + // Return to the last replicas. + if err := ts.Client.Update(ctx, pytorchJob); err != nil { + log.Info("fail to update replicas of pytorch job") + } + + updateElasticStatusForPendingJob(pytorchJob, lastReplicas, training.PyTorchReplicaTypeWorker) + if err := ts.UpdateJobStatusInApiServer(pytorchJob, &pytorchJob.Status); err != nil { + if errors.IsConflict(err) { + // retry later when update operation violates with etcd concurrency control. + log.Info("fail to update pytorch job") + } + } + return + + // If job has pending pods and current replicas equals to the min replicas, cancel the elastic scaling process context. 
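+	// A failed pod likewise cancels the scaling loop rather than retrying, since changing the replica count cannot recover an already-failed worker.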
+ } else if (hasPendingPod && currentReplicas == *pytorchJob.Spec.ElasticPolicy.MinReplicas) || hasFailedPod { + defer cancel() + logger.Info("pods did not reach the running state at min replicas or job is failed, so the elastic scaling controller shutdown") + delete(ts.torchElasticJobs, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)) + return + } + + if !hasPendingPod && jobStatus.ElasticStatus != nil && jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].Continue == false { + // If job metrics have reached the max, restart stale pods. + if jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].ElasticCondition == apiv1.ElasticMaxMetric { + pods, err := ts.GetPodsForJob(pytorchJob) + if err != nil { + logger.Warnf("Get Pods For Job error %v", err) + } + // Restart stale torch elastic pods. + complete := ts.restartStalePytorchPods(pods, pytorchJob) + if !complete { + logger.Info("restart pods does not complete") + return + } + logger.Info("restart pods has completed") + jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].ElasticCondition = apiv1.ElasticStop + if err = ts.UpdateJobStatusInApiServer(pytorchJob, jobStatus); err != nil { + if errors.IsConflict(err) { + // retry later when update operation violates with etcd concurrency control. + logger.Info("fail to update pytorch job status") + return + } + } + return + // If current replicas reach the defined max replicas or elastic condition is stopped, return directly. + } else if jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].ElasticCondition == apiv1.ElasticStop || jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].ElasticCondition == apiv1.ElasticMaxReplica { + log.Info("Pytorch job does not need to be scaled") + return + } + } + + // Read training logs from pytorch pods and save the observation. + observation, err := read(ts.client, jobNamespace, GetDefaultWorkerName(jobName)) + if err != nil { + logger.Infof("fail to read training logs: %v", err) + return + } + + ts.locker.Lock() + defer ts.locker.Unlock() + + // Create metrics for current replicas. + if _, ok := ts.metrics[jobName][currentReplicas]; !ok { + ts.metrics[jobName][currentReplicas] = make([]MetricObservation, 0) + } + ts.metrics[jobName][currentReplicas] = append(ts.metrics[jobName][currentReplicas], observation) + currentLength := len(ts.metrics[jobName][currentReplicas]) + logger.Infof("Current metric length: %d", currentLength) + + // If current metrics have reached the metric count, judge the next scaling replicas. 
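+	// The branch below compares per-replica latency at the current and the previous replica count (IsSatisfyElasticContinue) and either doubles the workers, rolls back to the last replicas, or stops once maxReplicas is reached.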
+ if currentLength >= ts.metricCount { + if currentReplicas > *pytorchJob.Spec.ElasticPolicy.MinReplicas && currentReplicas <= *pytorchJob.Spec.ElasticPolicy.MaxReplicas { + lastReplicas := jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].LastReplicas + + if ts.IsSatisfyElasticContinue(jobName, currentReplicas, lastReplicas) { + if currentReplicas == *pytorchJob.Spec.ElasticPolicy.MaxReplicas { + updateElasticStatusForMaxReplicaJob(pytorchJob, training.PyTorchReplicaTypeWorker) + ts.metrics[jobName][currentReplicas] = make([]MetricObservation, 0) + } else { + newReplicas := computeNewReplicas(currentReplicas) + *pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas = newReplicas + if err := ts.Client.Update(ctx, pytorchJob); err != nil { + log.Info("fail to update pytorch job") + } + + updateElasticStatusForContinueJob(pytorchJob, currentReplicas, newReplicas, training.PyTorchReplicaTypeWorker) + if _, ok := ts.metrics[jobName][newReplicas]; !ok { + ts.metrics[jobName][newReplicas] = make([]MetricObservation, 0) + } + } + + } else { + *pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas = lastReplicas + if err := ts.Client.Update(ctx, pytorchJob); err != nil { + log.Info("fail to update pytorch job") + } + + updateElasticStatusForMaxMetricJob(pytorchJob, currentReplicas, lastReplicas, training.PyTorchReplicaTypeWorker) + ts.metrics[jobName][lastReplicas] = make([]MetricObservation, 0) + ts.metrics[jobName][currentReplicas] = make([]MetricObservation, 0) + } + + } else if currentReplicas == *pytorchJob.Spec.ElasticPolicy.MinReplicas && currentReplicas < *pytorchJob.Spec.ElasticPolicy.MaxReplicas { + newReplicas := computeNewReplicas(currentReplicas) + *pytorchJob.Spec.PyTorchReplicaSpecs[training.PyTorchReplicaTypeWorker].Replicas = newReplicas + if err := ts.Client.Update(ctx, pytorchJob); err != nil { + log.Info("fail to update pytorch job") + } + + updateElasticStatusForContinueJob(pytorchJob, currentReplicas, newReplicas, training.PyTorchReplicaTypeWorker) + if _, ok := ts.metrics[jobName][newReplicas]; !ok { + ts.metrics[jobName][newReplicas] = make([]MetricObservation, 0) + } + + } else if currentReplicas == *pytorchJob.Spec.ElasticPolicy.MaxReplicas { + updateElasticStatusForMaxReplicaJob(pytorchJob, training.PyTorchReplicaTypeWorker) + if _, ok := ts.metrics[jobName][currentReplicas]; ok { + ts.metrics[jobName][currentReplicas] = make([]MetricObservation, 0) + } + + } + } + + // No need to update the job status if the status hasn't changed since last time. + if !reflect.DeepEqual(*oldStatus, pytorchJob.Status) { + if err = ts.UpdateJobStatusInApiServer(pytorchJob, &pytorchJob.Status); err != nil { + if errors.IsConflict(err) { + // retry later when update operation violates with etcd concurrency control. + logger.Info("fail to update pytorch job status") + return + } + } + } + + return +} diff --git a/controllers/torchelastic/elastic_controller.go b/controllers/torchelastic/elastic_controller.go new file mode 100644 index 00000000..c65204b5 --- /dev/null +++ b/controllers/torchelastic/elastic_controller.go @@ -0,0 +1,34 @@ +/* +Copyright 2022 The Alibaba Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package torchelastic + +import ( + ctrl "sigs.k8s.io/controller-runtime" +) + +func SetupWithManager(mgr ctrl.Manager) error { + // New torch elastic controller. + // period represents the time elastic scaling loop repeats. + // count represents the length of training metrics collection for each replica. + torchElasticController := NewTorchElasticController(mgr, 30, 5) + + if err := torchElasticController.SetupWithManager(mgr); err != nil { + return err + } + return nil + +} diff --git a/controllers/torchelastic/job.go b/controllers/torchelastic/job.go new file mode 100644 index 00000000..3817a7a9 --- /dev/null +++ b/controllers/torchelastic/job.go @@ -0,0 +1,126 @@ +/* +Copyright 2022 The Alibaba Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package torchelastic + +import ( + "context" + "fmt" + training "github.com/alibaba/kubedl/apis/training/v1alpha1" + apiv1 "github.com/alibaba/kubedl/pkg/job_controller/api/v1" + commonutil "github.com/alibaba/kubedl/pkg/util" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" +) + +func makeElasticJobName(name, namespace string) string { + return name + "-" + namespace +} + +// UpdateJobStatusInApiServer updates the job status in API server +func (ts *TorchElasticController) UpdateJobStatusInApiServer(job interface{}, jobStatus *apiv1.JobStatus) error { + torchElasticJob, ok := job.(*training.PyTorchJob) + if !ok { + return fmt.Errorf("%+v is not a type of PytorchJob", torchElasticJob) + } + var jobCpy *training.PyTorchJob + // Job status passed in differs with status in job, update in basis of the passed in one. + jobCpy = torchElasticJob.DeepCopy() + jobCpy.Status = *jobStatus.DeepCopy() + return ts.Status().Update(context.Background(), jobCpy) +} + +// initializeReplicaStatuses initializes the ElasticStatuses for replica. 
+func initializeElasticStatuses(pytorchJob *training.PyTorchJob, rtype apiv1.ReplicaType) { + jobStatus := &pytorchJob.Status + if jobStatus.ElasticStatus == nil { + jobStatus.ElasticStatus = make(map[apiv1.ReplicaType]*apiv1.ElasticScalingStatus) + } + + jobStatus.ElasticStatus[rtype] = &apiv1.ElasticScalingStatus{ElasticCondition: apiv1.ElasticStart} + jobStatus.ElasticStatus[rtype].CurrentReplicas = *pytorchJob.Spec.PyTorchReplicaSpecs[rtype].Replicas + jobStatus.ElasticStatus[rtype].Continue = true + now := metav1.Now() + jobStatus.ElasticStatus[rtype].LastUpdateTime = &now +} + +func updateElasticStatusForPendingJob(pytorchJob *training.PyTorchJob, lastReplicas int32, rtype apiv1.ReplicaType) { + jobStatus := &pytorchJob.Status + jobStatus.ElasticStatus[rtype].Continue = false + jobStatus.ElasticStatus[rtype].LastReplicas = jobStatus.ElasticStatus[training.PyTorchReplicaTypeWorker].CurrentReplicas + jobStatus.ElasticStatus[rtype].CurrentReplicas = lastReplicas + jobStatus.ElasticStatus[rtype].Message = "There exists pending pods, return to the last replicas" + now := metav1.Now() + jobStatus.ElasticStatus[rtype].LastUpdateTime = &now + jobStatus.ElasticStatus[rtype].ElasticCondition = apiv1.ElasticStop +} + +func updateElasticStatusForContinueJob(pytorchJob *training.PyTorchJob, currentReplicas, newReplicas int32, rtype apiv1.ReplicaType) { + jobStatus := &pytorchJob.Status + jobStatus.ElasticStatus[rtype].LastReplicas = currentReplicas + jobStatus.ElasticStatus[rtype].CurrentReplicas = newReplicas + jobStatus.ElasticStatus[rtype].Message = "Pytorch job continues to be scaled" + now := metav1.Now() + jobStatus.ElasticStatus[rtype].LastUpdateTime = &now + jobStatus.ElasticStatus[rtype].Continue = true + jobStatus.ElasticStatus[rtype].ElasticCondition = apiv1.ElasticContinue +} + +func updateElasticStatusForMaxReplicaJob(pytorchJob *training.PyTorchJob, rtype apiv1.ReplicaType) { + jobStatus := &pytorchJob.Status + jobStatus.ElasticStatus[rtype].Message = "Pytorch job has reached the max replicas" + jobStatus.ElasticStatus[rtype].Continue = false + jobStatus.ElasticStatus[rtype].ElasticCondition = apiv1.ElasticMaxReplica +} + +func updateElasticStatusForMaxMetricJob(pytorchJob *training.PyTorchJob, currentReplicas, lastReplicas int32, rtype apiv1.ReplicaType) { + jobStatus := &pytorchJob.Status + jobStatus.ElasticStatus[rtype].CurrentReplicas = lastReplicas + jobStatus.ElasticStatus[rtype].LastReplicas = currentReplicas + jobStatus.ElasticStatus[rtype].Message = "Pytorch job has reached the max metrics" + now := metav1.Now() + jobStatus.ElasticStatus[rtype].LastUpdateTime = &now + jobStatus.ElasticStatus[rtype].Continue = false + jobStatus.ElasticStatus[rtype].ElasticCondition = apiv1.ElasticMaxMetric +} + +func (ts *TorchElasticController) IsSatisfyElasticContinue(jobName string, currentReplicas, lastReplicas int32) bool { + currentLength := ts.metricCount + currentLatency := ts.metrics[jobName][currentReplicas][currentLength-1].Latency + lastReplicaLatency := ts.metrics[jobName][lastReplicas][currentLength-1].Latency + //Decide whether the elastic scaling can continue by the ratio of batch training latency and replicas. + return lastReplicaLatency/float64(lastReplicas) > currentLatency/float64(currentReplicas) +} + +func computeNewReplicas(currentReplicas int32) int32 { + // Double the replicas in the next elastic scaling loop. 
+ return currentReplicas * 2 +} + +func (ts *TorchElasticController) GetPodsForJob(job *training.PyTorchJob) ([]*v1.Pod, error) { + selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{ + MatchLabels: ts.GenLabels(job.Name), + }) + // List all pods to include those that don't match the selector anymore + // but have a ControllerRef pointing to this controller. + podList := &v1.PodList{} + err = ts.Client.List(context.Background(), podList, client.MatchingLabelsSelector{Selector: selector}) + if err != nil { + return nil, err + } + return commonutil.ToPodPointerList(podList.Items), nil +} diff --git a/controllers/torchelastic/job_elastic_controller.go b/controllers/torchelastic/job_elastic_controller.go new file mode 100644 index 00000000..2a88db2b --- /dev/null +++ b/controllers/torchelastic/job_elastic_controller.go @@ -0,0 +1,37 @@ +/* +Copyright 2022 The Alibaba Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package torchelastic + +import ( + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + logf "sigs.k8s.io/controller-runtime/pkg/log" +) + +// ElasticController implementations +type ElasticController interface { + SetupWithManager(mgr ctrl.Manager) error +} + +var _ ElasticController = &TorchElasticController{} + +type newJobElasticController func(mgr ctrl.Manager, period, count int) ElasticController + +var ( + log = logf.Log.WithName("job-elastic-controller") + jobElasticCtrlMap = make(map[runtime.Object]newJobElasticController) +) diff --git a/controllers/torchelastic/log_util.go b/controllers/torchelastic/log_util.go new file mode 100644 index 00000000..37c86899 --- /dev/null +++ b/controllers/torchelastic/log_util.go @@ -0,0 +1,105 @@ +package torchelastic + +import ( + "bufio" + "context" + "fmt" + logger "github.com/sirupsen/logrus" + "io" + v1 "k8s.io/api/core/v1" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/kubernetes/scheme" + + "regexp" + "strconv" + "strings" +) + +// Construct a request for getting the logs for a pod and retrieves the logs. +func readRawLogs(client kubernetes.Interface, namespace, podID string, logOptions *v1.PodLogOptions) (string, error) { + readCloser, err := openStream(client, namespace, podID, logOptions) + if err != nil { + return err.Error(), nil + } + + defer func(readCloser io.ReadCloser) { + err := readCloser.Close() + if err != nil { + return + } + }(readCloser) + + reader := bufio.NewReader(readCloser) + line, err := reader.ReadString('\n') + if err != nil { + + return err.Error(), nil + } + + return line, nil +} + +func openStream(client kubernetes.Interface, namespace, podID string, logOptions *v1.PodLogOptions) (io.ReadCloser, error) { + return client.CoreV1().RESTClient().Get(). + Namespace(namespace). + Name(podID). + Resource("pods"). + SubResource("log"). 
+ VersionedParams(logOptions, scheme.ParameterCodec).Stream(context.TODO()) +} + +func podRunning(pod *v1.Pod) bool { + return pod.Status.Phase == v1.PodRunning +} + +func GetDefaultWorkerName(pytorchJobName string) string { + return pytorchJobName + "-" + "worker" + "-" + "0" +} + +func read(client *kubernetes.Clientset, namespace, name string) (MetricObservation, error) { + lines := int64(1) + opts := &v1.PodLogOptions{ + TailLines: &lines, + Follow: true, + } + + //Read raw pod log. + detail, err := readRawLogs(client, namespace, name, opts) + if err != nil { + return MetricObservation{}, err + } + //Extract training metrics from raw log. + rawLog := strings.Split(detail, "\t") + epochRule := regexp.MustCompile(`[0-9]{1,2}`) + batchRule := regexp.MustCompile(`[0-9]{2,4}`) + trainRule := regexp.MustCompile(`[0-9]{1,2}.[0-9]{3}`) + accRule := regexp.MustCompile(`[0-9]{1,2}.[0-9]{1,2}`) + matchTrain, err := regexp.MatchString(`Epoch`, rawLog[0]) + + if err != nil { + return MetricObservation{}, err + } + // If current log is a training log. + if matchTrain { + epochNum, _ := strconv.Atoi(epochRule.FindStringSubmatch(rawLog[0])[0]) + batchNum, _ := strconv.Atoi(batchRule.FindStringSubmatch(rawLog[0])[0]) + trainTime, _ := strconv.ParseFloat(trainRule.FindStringSubmatch(rawLog[1])[0], 64) + accuracy, _ := strconv.ParseFloat(accRule.FindStringSubmatch(rawLog[5])[0], 64) + + observation := MetricObservation{ + Accuracy: accuracy, + Epoch: int32(epochNum), + Latency: trainTime, + Batch: int32(batchNum), + } + // drop the inaccurate train data + if trainTime > 1 { + return MetricObservation{}, fmt.Errorf("drop the inaccurate train data") + } + + logger.Infof("epoch: %d batch: %d train_time: %f accuracy: %f", epochNum, batchNum, trainTime, accuracy) + return observation, nil + } + + return MetricObservation{}, fmt.Errorf("current log is not a training log") +} diff --git a/controllers/torchelastic/pod.go b/controllers/torchelastic/pod.go new file mode 100644 index 00000000..70d7141d --- /dev/null +++ b/controllers/torchelastic/pod.go @@ -0,0 +1,269 @@ +/* +Copyright 2022 The Alibaba Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package torchelastic + +import ( + "context" + trainingv1alpha1 "github.com/alibaba/kubedl/apis/training/v1alpha1" + "github.com/alibaba/kubedl/controllers/pytorch" + v1 "github.com/alibaba/kubedl/pkg/job_controller/api/v1" + "github.com/alibaba/kubedl/pkg/util/concurrent" + "github.com/alibaba/kubedl/pkg/util/k8sutil" + patchutil "github.com/alibaba/kubedl/pkg/util/patch" + kruisev1alpha1 "github.com/openkruise/kruise/apis/apps/v1alpha1" + logger "github.com/sirupsen/logrus" + corev1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/utils/pointer" + "strconv" + "strings" +) + +func (ts *TorchElasticController) recreatePodContainers(job *trainingv1alpha1.PyTorchJob, pod *corev1.Pod, generation string) error { + crr := kruisev1alpha1.ContainerRecreateRequest{ + ObjectMeta: metav1.ObjectMeta{ + Name: pod.Name, + Namespace: pod.Namespace, + Labels: map[string]string{ + v1.LabelGeneration: generation, + }, + OwnerReferences: []metav1.OwnerReference{ + { + APIVersion: "v1", + Kind: "Pod", + Name: pod.Name, + UID: pod.UID, + Controller: pointer.BoolPtr(false), + BlockOwnerDeletion: pointer.BoolPtr(true), + }, + { + APIVersion: job.APIVersion, + Kind: job.Kind, + Name: job.Name, + UID: job.UID, + Controller: pointer.BoolPtr(false), + BlockOwnerDeletion: pointer.BoolPtr(true), + }, + }, + }, + Spec: kruisev1alpha1.ContainerRecreateRequestSpec{ + PodName: pod.Name, + Strategy: &kruisev1alpha1.ContainerRecreateRequestStrategy{OrderedRecreate: false}, + }, + } + + for ci := range pod.Spec.Containers { + container := &pod.Spec.Containers[ci] + crr.Spec.Containers = append(crr.Spec.Containers, kruisev1alpha1.ContainerRecreateRequestContainer{Name: container.Name}) + } + return ts.Client.Create(context.Background(), &crr) +} + +func (ts *TorchElasticController) restartStaleWorker(job *trainingv1alpha1.PyTorchJob, pod *corev1.Pod, worldSize, generation int64) (completed bool, err error) { + expectedWorldSize := strconv.FormatInt(worldSize, 10) + expectedGeneration := strconv.FormatInt(generation, 10) + podKey := pod.Namespace + "/" + pod.Name + + if job.Annotations[pytorch.AnnotationReadyToRestartWorker] == "true" && !k8sutil.IsPodActive(pod) { + err = ts.Client.Delete(context.Background(), pod) + return err == nil, err + } + + if pod.Labels[v1.LabelGeneration] == expectedGeneration { + return true, nil + } + + log.Info("refresh stale pod to latest generation", "pod", podKey, "generation", generation) + + completed, err = ts.restartWorkerInKruiseProtocol(job, pod, expectedWorldSize, expectedGeneration) + if !completed { + return false, err + } + + // Finally, incremental generation for current worker and mark refreshment done. 
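+	// Pods already carrying the expected generation label were skipped earlier in this function, so the patch below is applied at most once per worker per scaling round.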
+ patch := patchutil.NewStrategicPatch() + patch.InsertLabel(v1.LabelGeneration, expectedGeneration) + err = ts.Client.Patch(context.Background(), pod, patch) + if err != nil { + return false, err + } + logger.Infof("succeed to refresh pod to generation: %v", generation) + return true, nil +} + +func (ts *TorchElasticController) restartWorkerInKruiseProtocol(job *trainingv1alpha1.PyTorchJob, pod *corev1.Pod, expectedWorldSize, expectedGeneration string) (completed bool, err error) { + podKey := pod.Namespace + "/" + pod.Name + crr := kruisev1alpha1.ContainerRecreateRequest{} + if curWorldSize, ok := pod.Annotations[pytorch.AnnotationWorldSize]; !ok || curWorldSize != expectedWorldSize { + log.Info("update latest world size of pytorch", + "key", podKey, "current world size", curWorldSize, "target world size", expectedWorldSize) + patch := patchutil.NewStrategicPatch() + patch.InsertAnnotation(pytorch.AnnotationWorldSize, expectedWorldSize) + if err = ts.Client.Patch(context.Background(), pod, patch); err != nil { + log.Error(err, "failed to refresh world-size of stale worker", "pod", podKey, "world size", expectedWorldSize) + return false, err + } + return false, nil + } + + if err = ts.Client.Get(context.Background(), types.NamespacedName{Name: pod.Name, Namespace: pod.Namespace}, &crr); err != nil { + if errors.IsNotFound(err) { + logger.Info("Not found ContainerRecreateRequest") + return false, ts.recreatePodContainers(job, pod, expectedGeneration) + } + + log.Error(err, "failed to get latest container-recreate-request for stale worker", + "pod", podKey) + return false, err + } + // crr created in previous round, clean it. + if crr.Labels[v1.LabelGeneration] != expectedGeneration { + if err = ts.Client.Delete(context.Background(), &crr); err != nil { + return false, err + } + return false, ts.recreatePodContainers(job, pod, expectedGeneration) + } + + if crr.Status.Phase == kruisev1alpha1.ContainerRecreateRequestFailed { + logger.Infof("failed to restart containers of pod %s/%s, fallback to recreate pod", pod.Namespace, pod.Name) + err = ts.Client.Delete(context.Background(), pod) + return err == nil, err + } + + recreateDone := crr.Status.Phase == kruisev1alpha1.ContainerRecreateRequestCompleted || crr.Status.Phase == kruisev1alpha1.ContainerRecreateRequestSucceeded + if !recreateDone { + logger.Error("container recreate request has not completed yet", "pod", podKey) + return false, nil + } + + // Finalize container-recreate-request object once it completes, because elastic scaling is repeatable + // and 'crr' request will be re-initiated. 
+ defer ts.Client.Delete(context.Background(), &crr) + + logger.Info("ContainerRecreateSucceed", "succeed to recreate containers in stale worker: %s", podKey) + return true, nil +} + +func FilterRunningPods(pods []*corev1.Pod) []*corev1.Pod { + var result []*corev1.Pod + for _, p := range pods { + if podRunning(p) { + result = append(result, p) + } else { + deletionTimeStamp := "N/A" + if p.DeletionTimestamp != nil { + deletionTimeStamp = p.DeletionTimestamp.String() + } + logger.Infof("Ignoring inactive pod %v/%v in state %v, deletion time %s", + p.Namespace, p.Name, p.Status.Phase, deletionTimeStamp) + } + } + return result +} + +func (ts *TorchElasticController) restartStalePytorchPods(pods []*corev1.Pod, pytorchJob *trainingv1alpha1.PyTorchJob) (completed bool) { + + runningPods := FilterRunningPods(pods) + _, stalePods := k8sutil.FilterStalePodsByReplicaType(runningPods, pytorchJob.Generation, strings.ToLower(string(v1.JobReplicaTypeAIMaster))) + staleWorkers := stalePods[strings.ToLower(string(trainingv1alpha1.PyTorchReplicaTypeWorker))] + totalReplicas := len(stalePods) + workerNums := len(staleWorkers) + logger.Infof("worker nums: %d", workerNums) + + if pytorchJob.Annotations[pytorch.AnnotationReadyToRestartWorker] == "false" { + log.Info("PytorchJob does not need to restart workers") + return false + } + + tickets := 100 // max semaphore tickets limited. + if len(staleWorkers) < 100 { + tickets = len(staleWorkers) + } + sema := concurrent.NewSemaphore(tickets) + for _, pod := range staleWorkers { + sema.Acquire() + + go func(worker *corev1.Pod) { + defer sema.Release() + if completed, err := ts.restartStaleWorker(pytorchJob, worker, int64(totalReplicas), pytorchJob.Generation); err != nil { + logger.Warnf("Restart worker %s failed becasue error %v", worker.Name, err) + } else if completed { + workerNums-- + } + }(pod) + } + // block until all semaphore is released. + sema.Wait() + if workerNums != 0 { + log.Info("refresh stale workers has not completed yet", "key", pytorchJob.Namespace+"/"+pytorchJob.Name) + return false + } + + if len(stalePods) == 0 || workerNums == 0 { + log.Info("all pods are in latest generation, mark ready-to-start-worker as false") + patch := patchutil.NewMergePatch() + patch.InsertAnnotation(pytorch.AnnotationReadyToRestartWorker, "false") + + if err := ts.Client.Patch(context.Background(), pytorchJob, patch); err != nil { + logger.Infof("fail to patch pytorchJob: %v", err) + return false + } + logger.Infof("pytorch job %s/%s elastic scaling successfully finished, total replicas: %v", pytorchJob.Namespace, pytorchJob.Name, totalReplicas) + completed = true + } + + return completed + +} + +func (ts *TorchElasticController) waitForAllPodsRunning(pytorchJob *trainingv1alpha1.PyTorchJob) (hasPendingPod, hasFailedPod bool) { + pods, err := ts.GetPodsForJob(pytorchJob) + if err != nil { + logger.Warnf("Get Pods For Job error %v", err) + } + + // Wait for all pods running with timeout seconds. 
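+	// PollImmediate re-checks pod phases every `interval` (5s) for up to podReadyTimeout (1 minute); hitting the timeout is not fatal here, the pending/failed checks below decide how the caller reacts.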
+ waitErr := wait.PollImmediate(interval, podReadyTimeout, func() (bool, error) { + for _, pod := range pods { + if isRunning := podRunning(pod); !isRunning { + return false, nil + } + } + return true, nil + }) + + if waitErr != nil { + logger.Info("pods did not reach the running state") + } + + for _, pod := range pods { + if pod.Status.Phase == corev1.PodPending { + hasPendingPod = true + break + } + } + for _, pod := range pods { + if pod.Status.Phase == corev1.PodFailed { + hasFailedPod = true + break + } + } + return +} diff --git a/controllers/torchelastic/torchelastic_controller.go b/controllers/torchelastic/torchelastic_controller.go new file mode 100644 index 00000000..d1c52070 --- /dev/null +++ b/controllers/torchelastic/torchelastic_controller.go @@ -0,0 +1,216 @@ +/* +Copyright 2022 The Alibaba Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package torchelastic + +import ( + "context" + training "github.com/alibaba/kubedl/apis/training/v1alpha1" + apiv1 "github.com/alibaba/kubedl/pkg/job_controller/api/v1" + "github.com/alibaba/kubedl/pkg/util/concurrent" + logger "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/wait" + "k8s.io/client-go/kubernetes" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/controller" + "sigs.k8s.io/controller-runtime/pkg/event" + "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/predicate" + "sigs.k8s.io/controller-runtime/pkg/source" + "strings" + "sync" + "time" +) + +const ( + controllerName = "TorchElasticController" + interval = 5 * time.Second + podReadyTimeout = 1 * time.Minute +) + +type name string +type namespace string + +func init() { + jobElasticCtrlMap[&training.PyTorchJob{}] = NewTorchElasticController +} + +func NewTorchElasticController(mgr ctrl.Manager, period, count int) ElasticController { + metrics := make(map[string]map[int32][]MetricObservation) + torchJobs := make(map[string]TorchElasticJob) + return &TorchElasticController{ + period: period, + metricCount: count, + client: kubernetes.NewForConfigOrDie(mgr.GetConfig()), + metrics: metrics, + torchElasticJobs: torchJobs, + Client: mgr.GetClient(), + } +} + +type TorchElasticController struct { + period int + metricCount int + client *kubernetes.Clientset + client.Client + locker sync.Mutex + // metrics stores observations collected from running pods + metrics map[string]map[int32][]MetricObservation + // torchElasticJobs stores torch-elastic jobs infos. + torchElasticJobs map[string]TorchElasticJob +} + +// TorchElasticJob represents one elastic job. +type TorchElasticJob struct { + Name string + Namespace string + ctx context.Context + cancelFunc context.CancelFunc +} + +// MetricObservation represents one metric set collected from training pods. 
+type MetricObservation struct { + Epoch int32 `json:"epoch,omitempty"` + Batch int32 `json:"batch,omitempty"` + Accuracy float64 `json:"accuracy,omitempty"` + Latency float64 `json:"latency,omitempty"` +} + +func (ts *TorchElasticController) Reconcile(_ context.Context, req ctrl.Request) (ctrl.Result, error) { + pytorchJob := training.PyTorchJob{} + err := ts.Client.Get(context.Background(), types.NamespacedName{ + Namespace: req.Namespace, + Name: req.Name, + }, &pytorchJob) + + if err != nil { + if errors.IsNotFound(err) { + log.Info("try to fetch pytorch job but it has been deleted.", "key", req.String()) + return ctrl.Result{}, nil + } + return ctrl.Result{}, err + } + return ctrl.Result{}, nil +} + +func (ts *TorchElasticController) SetupWithManager(mgr ctrl.Manager) error { + c, err := controller.New(controllerName, mgr, controller.Options{Reconciler: ts}) + if err != nil { + return err + } + // Watch events with pod events-handler. + if err = c.Watch(&source.Kind{Type: &training.PyTorchJob{}}, &handler.EnqueueRequestForObject{}, predicate.Funcs{ + CreateFunc: onOwnerCreateFunc(ts), + DeleteFunc: onOwnerDeleteFunc(ts), + }); err != nil { + return err + } + + ctx := context.Background() + go wait.UntilWithContext(ctx, ts.startElasticForAllJobs, time.Duration(ts.period)*(time.Second)) + log.Info("Start Elastic Scaling Controller Loop") + + ctx.Done() + log.Info("Shutting down Elastic Scaling Controller Loop") + return nil +} + +func onOwnerCreateFunc(ts *TorchElasticController) func(e event.CreateEvent) bool { + return func(e event.CreateEvent) bool { + pytorchJob, ok := e.Object.(*training.PyTorchJob) + if !ok { + return true + } + if !pytorchJob.Spec.EnableElastic && pytorchJob.Spec.ElasticPolicy == nil { + return true + } + ctx, cancel := context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, name("job"), pytorchJob.Name) + ctx = context.WithValue(ctx, namespace("namespace"), pytorchJob.Namespace) + logger.Info("Create torch elastic job: ", pytorchJob.Name, " in namespace: ", pytorchJob.Namespace) + ts.torchElasticJobs[makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)] = TorchElasticJob{ + Name: pytorchJob.Name, + Namespace: pytorchJob.Namespace, + ctx: ctx, + cancelFunc: cancel, + } + return true + } +} + +func onOwnerDeleteFunc(ts *TorchElasticController) func(e event.DeleteEvent) bool { + return func(e event.DeleteEvent) bool { + pytorchJob, ok := e.Object.(*training.PyTorchJob) + if !ok { + return true + } + if !pytorchJob.Spec.EnableElastic && pytorchJob.Spec.ElasticPolicy == nil { + return true + } + + logger.Infof("Deleting elastic scaling for pytorch job %s from namespace %s", pytorchJob.Name, pytorchJob.Namespace) + // Delete job infos saved in Torch Elastic controller. + if _, ok := ts.torchElasticJobs[makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)]; ok { + cancel := ts.torchElasticJobs[makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)].cancelFunc + defer cancel() + delete(ts.torchElasticJobs, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)) + delete(ts.metrics, makeElasticJobName(pytorchJob.Name, pytorchJob.Namespace)) + } + + return true + } +} + +// Start elastic scaling loop for all torch elastic jobs. +func (ts *TorchElasticController) startElasticForAllJobs(ctx context.Context) { + tickets := 100 // max semaphore tickets limited. 
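+	// Use one ticket per tracked job when fewer than 100 jobs exist; otherwise cap the number of concurrent per-job scaling goroutines at 100.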
+ if len(ts.torchElasticJobs) < 100 { + tickets = len(ts.torchElasticJobs) + } + sema := concurrent.NewSemaphore(tickets) + for _, torchJob := range ts.torchElasticJobs { + sema.Acquire() + + go func(job TorchElasticJob) { + defer sema.Release() + //Start elastic scaling for each torch elastic job. + ts.start(job.ctx, job.cancelFunc, job.Name, job.Namespace) + }(torchJob) + } + // block until all semaphore is released. + sema.Wait() +} + +func (ts *TorchElasticController) GenLabels(jobName string) map[string]string { + labelGroupName := apiv1.GroupNameLabel + labelJobName := apiv1.JobNameLabel + groupName := ts.GetGroupNameLabelValue() + return map[string]string{ + labelGroupName: groupName, + labelJobName: strings.Replace(jobName, "/", "-", -1), + } +} + +func (ts *TorchElasticController) GetGroupNameLabelValue() string { + return training.SchemeGroupVersion.Group +} + +func (ts *TorchElasticController) ControllerName() string { + return controllerName +} diff --git a/docs/tutorial/torchelasticjob.md b/docs/tutorial/torchelasticjob.md new file mode 100644 index 00000000..9ba967b2 --- /dev/null +++ b/docs/tutorial/torchelasticjob.md @@ -0,0 +1,322 @@ +# Run a Torch Elastic Job with KubeDL Operator + +This tutorial walks you through an example to run a Torch Elastic Job. [Torch Elastic](https://pytorch.org/elastic/0.2.0/) enables distributed PyTorch training jobs to be executed in a fault-tolerant and elastic manner. +Torch Elastic can be used in the following cases: +- Fault Tolerance: jobs that run on infrastructure where nodes get replaced frequently, either due to flaky hardware or by design. Or mission critical production grade jobs that need to be run with resilience to failures. +- Dynamic Capacity Management: jobs that run on leased capacity that can be taken away at any time (e.g. AWS spot instances) or shared pools where the pool size can change dynamically based on demand. + +## Requirements + +#### 1. Deploy KubeDL +Follow the [installation tutorial](https://github.com/alibaba/kubedl#getting-started) in README and deploy `kubedl` operator to cluster. + +#### 2. Apply Pytorch Job CRD + +`PytorchJob` CRD(CustomResourceDefinition) manifest file describes the structure of a pytorch job spec. Run the following to apply the CRD: + +```bash +kubectl apply -f https://raw.githubusercontent.com/alibaba/kubedl/master/config/crd/bases/training.kubedl.io_pytorchjobs.yaml +``` + +## Run Torch Elastic Job on Kubernetes with KubeDL + +Run `torch elastic` job on kubernetes natively. +### 1. Deploy an etcd server. +Create a new namespace `elastic-job`. +```bash +kubectl create ns elastic-job +``` +Create an etcd server and a service `etcd-service` with port 2379. 
+```yaml +apiVersion: v1 +kind: Service +metadata: + name: etcd-service + namespace: elastic-job +spec: + ports: + - name: etcd-client-port + port: 2379 + protocol: TCP + targetPort: 2379 + selector: + app: etcd + +--- +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + name: etcd + namespace: elastic-job +spec: + containers: + - command: + - /usr/local/bin/etcd + - --data-dir + - /var/lib/etcd + - --enable-v2 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://0.0.0.0:2379 + - --initial-cluster-state + - new + image: k8s.gcr.io/etcd:3.5.1-0 + name: etcd + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + restartPolicy: Always + +``` + +```bash +kubectl apply -f etcd.yaml +``` +Get the etcd server endpoint: +```bash +$ kubectl get svc -n elastic-job + +NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE +etcd-service ClusterIP 10.96.170.111 2379/TCP 3h15m +``` + +### 2. Create a Torch Elastic Job and checkpoint persistent volume +Create a PV and PVC YAML spec that describes the storage dictionary of checkpoint model. +```yaml +apiVersion: v1 +kind: PersistentVolume +metadata: + name: pv-torch-checkpoint + namespace: elastic-job +spec: + capacity: + storage: 5Gi + volumeMode: Filesystem + accessModes: + - ReadWriteMany + persistentVolumeReclaimPolicy: Retain + ... + +--- +kind: PersistentVolumeClaim +apiVersion: v1 +metadata: + name: pvc-torch-checkpoint + namespace: elastic-job +spec: + accessModes: + - ReadWriteMany + resources: + requests: + storage: 5Gi + volumeName: pv-torch-checkpoint + ... +``` +```bash +kubectl create -f pv.yaml +``` + + +Create a YAML spec that describes the specifications of a Torch Elastic Job such as the ElasticPolicy, master, worker and volumes like below + +```yaml +apiVersion: training.kubedl.io/v1alpha1 +kind: "PyTorchJob" +metadata: + name: "resnet" + namespace: elastic-job +spec: + enableElastic: true + elasticPolicy: + rdzvBackend: etcd + rdzvEndpoint: "etcd-service:2379" + minReplicas: 1 + maxReplicas: 3 + nProcPerNode: 1 + pytorchReplicaSpecs: + Master: + replicas: 1 + restartPolicy: ExitCode + template: + spec: + containers: + - name: pytorch + image: kubedl/pytorch-dist-example + imagePullPolicy: Always + Worker: + replicas: 1 + restartPolicy: OnFailure + template: + spec: + volumes: + - name: checkpoint + persistentVolumeClaim: + claimName: pvc-torch-checkpoint + containers: + - name: pytorch + image: wanziyu/imagenet:1.1 + imagePullPolicy: Always + args: + - "/workspace/examples/imagenet/main.py" + - "--arch=resnet50" + - "--epochs=20" + - "--batch-size=64" + - "--print-freq=50" + # number of data loader workers (NOT trainers) + # zero means load the data on the same process as the trainer + # this is set so that the container does not OOM since + # pytorch data loaders use shm + - "--workers=0" + - "/workspace/data/tiny-imagenet-200" + - "--checkpoint-file=/mnt/blob/data/checkpoint.pth.tar" + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - name: checkpoint + mountPath: "/mnt/blob/data" +``` + +The `spec.enableElastic` field describes whether user enables the KubeDL torch elastic controller or not. When `enableElastic` field is true and `elasticPolicy` is not empty, the elastic scaling process for this job will be started. + +The `spec.elasticPolicy` field specifies the elastic policy including rdzv_backend, rdzv_endpoint, minimum replicas and maximum replicas. +The `rdzvEndpoint` can be set to the etcd service. 
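+For reference, the controller changes in this PR prepend rendezvous flags derived from `elasticPolicy` to the job's container arguments. For the example job above the injected flags would look roughly like the following sketch:
+```bash
+--rdzv_backend=etcd --rdzv_endpoint=etcd-service:2379 --rdzv_id=resnet --nproc_per_node=1 --nnodes=1:3
+```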
The `minReplicas` and `maxReplicas` should be set to the desired min and max num nodes (max should not exceed your cluster capacity). + +### 3. Submit the Torch Elastic Training Job +```bash +kubectl create -f example/torchelastic/torchelastic-resnet.yaml +``` +Check the initial torch elastic training pod status: +```bash +$ kubectl get pod -n elastic-job + +NAME READY STATUS RESTARTS AGE +etcd 1/1 Running 0 3h56m +resnet-master-0 1/1 Running 0 12s +resnet-worker-0 1/1 Running 0 7s +``` + +Check the initial service status: +```bash +$ kubectl get svc -n elastic-job + +NAME TYPE CLUSTER-IP EXTERNAL-IP PORT(S) AGE +etcd-service ClusterIP 10.96.170.111 2379/TCP 3h57m +resnet-master-0 ClusterIP None 23456/TCP 79s +resnet-worker-0 ClusterIP None 23456/TCP 74s +``` +Check the pytorchJob status: +```bash +$ kubectl describe pytorchjob resnet -n elastic-job + +... +Status: + Conditions: + Last Transition Time: 2022-09-11T12:08:29Z + Last Update Time: 2022-09-11T12:08:29Z + Message: PyTorchJob resnet is running. + Reason: JobRunning + Status: True + Type: Running + Elastic Scaling: + Worker: + Continue: true + Current Replicas: 1 + Elastic Condition: Start + Start Time: 2022-09-11T12:08:54Z + Replica Statuses: + Master: + Active: 1 + Worker: + Active: 1 +``` +The `Elastic Scaling` field describes the current elastic scaling status of torch elastic training job. The `Elastic Condition` field indicates whether the elastic scaling workflow continues. + +### 3. Watch the elastic scaling process for torch elastic job +The elastic scaling controller continuously collects the real-time training metrics and decides whether job replicas(ranging from min to max replicas) can be further increased. +If the following conditions satisfy, job will return to the last replicas and the controller will stop the further scaling. + +**1. There exists pending pods when scaling in new pods.** +```bash +$ kubectl get pod -n elastic-job + +etcd 1/1 Running 0 6h +resnet-master-0 1/1 Running 0 6m53s +resnet-worker-0 1/1 Running 0 6m47s +resnet-worker-1 1/1 Running 0 3m44s +resnet-worker-2 0/1 Pending 0 14s +``` +The elastic scaling status is: +```bash + ... + Elastic Scaling: + Worker: + Continue: false + Current Replicas: 2 + Elastic Condition: Stop + Last Replicas: 3 + Message: There exists pending pods, return to the last replicas + Start Time: 2022-09-11T14:13:55Z +``` +The `Message` shows the reason that the elastic scaling process stops. + +**2. The training metrics have reached the best values and the job does not need to be further scaled.** + +```bash + ... + Elastic Scaling: + Worker: + Continue: false + Current Replicas: 2 + Elastic Condition: ReachMaxMetric + Last Replicas: 4 + Message: Pytorch job has reached the max metrics + Start Time: 2022-09-11T12:15:24Z +``` +Meanwhile, the elastic scaling process will be stopped. + + + +If the elastic scaling process can continue and job replicas can be further increased, the elastic status is as below. + +```bash +$ kubectl get pod -n elastic-job + +etcd 1/1 Running 0 5h57m +resnet-master-0 1/1 Running 0 3m25s +resnet-worker-0 1/1 Running 0 3m19s +resnet-worker-1 1/1 Running 0 16s +``` +```bash + ... + Elastic Scaling: + Worker: + Continue: true + Current Replicas: 2 + Elastic Condition: Continue + Last Replicas: 1 + Message: Pytorch job continues to be scaled + Start Time: 2022-09-11T12:11:54Z +``` +Currently, the scaling algorithm is based on the real-time batch training latency collected from running pod logs. The logs of distributed training pods are like as follows. 
+```bash +Epoch: [17][ 0/1563] Time 5.969 ( 5.969) Data 0.214 ( 0.214) Loss 2.1565e+00 (2.1565e+00) Acc@1 53.12 ( 53.12) Acc@5 75.00 ( 75.00) +Epoch: [17][ 50/1563] Time 0.258 ( 0.385) Data 0.155 ( 0.170) Loss 2.5284e+00 (2.5905e+00) Acc@1 42.19 ( 39.71) Acc@5 64.06 ( 65.81) +Epoch: [17][ 100/1563] Time 0.260 ( 0.323) Data 0.158 ( 0.164) Loss 2.4015e+00 (2.6419e+00) Acc@1 45.31 ( 38.58) Acc@5 70.31 ( 64.96) +Epoch: [17][ 150/1563] Time 0.256 ( 0.302) Data 0.153 ( 0.161) Loss 2.9381e+00 (2.6560e+00) Acc@1 34.38 ( 38.15) Acc@5 64.06 ( 64.82) +Epoch: [17][ 200/1563] Time 0.296 ( 0.295) Data 0.189 ( 0.163) Loss 2.5786e+00 (2.6778e+00) Acc@1 35.94 ( 37.52) Acc@5 68.75 ( 64.30) +Epoch: [17][ 250/1563] Time 0.313 ( 0.291) Data 0.202 ( 0.165) Loss 2.6223e+00 (2.6837e+00) Acc@1 39.06 ( 37.52) Acc@5 62.50 ( 64.20) +Epoch: [17][ 300/1563] Time 0.263 ( 0.286) Data 0.159 ( 0.164) Loss 2.7830e+00 (2.7005e+00) Acc@1 40.62 ( 37.18) Acc@5 57.81 ( 63.83) +Epoch: [17][ 350/1563] Time 0.267 ( 0.284) Data 0.163 ( 0.164) Loss 2.8693e+00 (2.7060e+00) Acc@1 39.06 ( 37.34) Acc@5 57.81 ( 63.62) +Epoch: [17][ 400/1563] Time 0.259 ( 0.281) Data 0.155 ( 0.163) Loss 3.0643e+00 (2.7000e+00) Acc@1 28.12 ( 37.36) Acc@5 50.00 ( 63.68) +Epoch: [17][ 450/1563] Time 0.294 ( 0.280) Data 0.189 ( 0.164) Loss 2.4482e+00 (2.7056e+00) Acc@1 43.75 ( 37.21) Acc@5 70.31 ( 63.57) +``` +If you want to change the training log format in your Python code, +revise the regular expression defined in `log_util.go` so that it extracts the training metrics you specify. \ No newline at end of file diff --git a/example/pytorch/torchelastic/etcd.yaml b/example/pytorch/torchelastic/etcd.yaml new file mode 100644 index 00000000..bd787ece --- /dev/null +++ b/example/pytorch/torchelastic/etcd.yaml @@ -0,0 +1,45 @@ +apiVersion: v1 +kind: Service +metadata: + name: etcd-service + namespace: elastic-job +spec: + ports: + - name: etcd-client-port + port: 2379 + protocol: TCP + targetPort: 2379 + selector: + app: etcd + +--- +apiVersion: v1 +kind: Pod +metadata: + labels: + app: etcd + name: etcd + namespace: elastic-job +spec: + containers: + - command: + - /usr/local/bin/etcd + - --data-dir + - /var/lib/etcd + - --enable-v2 + - --listen-client-urls + - http://0.0.0.0:2379 + - --advertise-client-urls + - http://0.0.0.0:2379 + - --initial-cluster-state + - new + image: k8s.gcr.io/etcd:3.5.1-0 + name: etcd + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + restartPolicy: Always diff --git a/example/pytorch/torchelastic/torchelastic.yaml b/example/pytorch/torchelastic/torchelastic.yaml new file mode 100644 index 00000000..9a3c5b18 --- /dev/null +++ b/example/pytorch/torchelastic/torchelastic.yaml @@ -0,0 +1,55 @@ +apiVersion: training.kubedl.io/v1alpha1 +kind: "PyTorchJob" +metadata: + name: "resnet" + namespace: elastic-job +spec: + enableElastic: true + elasticPolicy: + rdzvBackend: etcd + rdzvEndpoint: "etcd-service:2379" + minReplicas: 1 + maxReplicas: 3 + nProcPerNode: 1 + pytorchReplicaSpecs: + Master: + replicas: 1 + restartPolicy: ExitCode + template: + spec: + containers: + - name: pytorch + image: kubedl/pytorch-dist-example + imagePullPolicy: Always + Worker: + replicas: 1 + restartPolicy: OnFailure + template: + spec: + volumes: + - name: checkpoint + persistentVolumeClaim: + claimName: pvc-torch-checkpoint + containers: + - name: pytorch + image: wanziyu/imagenet:1.1 + imagePullPolicy: Always + args: + - "/workspace/examples/imagenet/main.py" + - 
"--arch=resnet50" + - "--epochs=20" + - "--batch-size=64" + - "--print-freq=50" + # number of data loader workers (NOT trainers) + # zero means load the data on the same process as the trainer + # this is set so that the container does not OOM since + # pytorch data loaders use shm + - "--workers=0" + - "/workspace/data/tiny-imagenet-200" + - "--checkpoint-file=/mnt/blob/data/checkpoint.pth.tar" + resources: + limits: + nvidia.com/gpu: 1 + volumeMounts: + - name: checkpoint + mountPath: "/mnt/blob/data" \ No newline at end of file diff --git a/main.go b/main.go index 80472b38..d059ae94 100644 --- a/main.go +++ b/main.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "github.com/alibaba/kubedl/controllers/torchelastic" "os" "k8s.io/apimachinery/pkg/util/net" @@ -117,6 +118,11 @@ func main() { os.Exit(1) } + if err = torchelastic.SetupWithManager(mgr); err != nil { + setupLog.Error(err, "unable to setup elastic scaling controllers") + os.Exit(1) + } + // Start monitoring for default registry. metrics.StartMonitoringForDefaultRegistry(metricsAddr) diff --git a/pkg/job_controller/api/v1/types.go b/pkg/job_controller/api/v1/types.go index 2ac1acb7..cf920dc1 100644 --- a/pkg/job_controller/api/v1/types.go +++ b/pkg/job_controller/api/v1/types.go @@ -41,6 +41,10 @@ type JobStatus struct { // It is represented in RFC3339 form and is in UTC. CompletionTime *metav1.Time `json:"completionTime,omitempty"` + // Represents the elastic scaling status for training jobs, + // specifies the status of current elastic scaling. + ElasticStatus map[ReplicaType]*ElasticScalingStatus `json:"elasticScaling,omitempty"` + // Represents last time when the job was reconciled. It is not guaranteed to // be set in happens-before order across separate operations. // It is represented in RFC3339 form and is in UTC. @@ -57,6 +61,28 @@ type JobStatus struct { // own set of ReplicaTypes. type ReplicaType string +// ElasticScalingStatus represents the current elastic scaling status of the training job. +// +k8s:deepcopy-gen=true +type ElasticScalingStatus struct { + // Type of elastic scaling condition. + ElasticCondition ElasticConditionType `json:"elasticCondition,omitempty"` + + // Continue represents whether the job needs to continue scaling. + Continue bool `json:"continue,omitempty"` + + // The number of current scaling pod replicas. + CurrentReplicas int32 `json:"currentReplicas,omitempty"` + + // The number of last scaling pod replicas. + LastReplicas int32 `json:"lastReplicas,omitempty"` + + // The time this elastic scaling loop was started. + LastUpdateTime *metav1.Time `json:"startTime,omitempty"` + + // A human readable message indicating details about the transition. + Message string `json:"message,omitempty"` +} + // ReplicaStatus represents the current observed state of the replica. type ReplicaStatus struct { // The number of actively running pods. @@ -131,6 +157,23 @@ type JobCondition struct { LastTransitionTime metav1.Time `json:"lastTransitionTime,omitempty"` } +// ElasticConditionType defines all kinds of elastic scaling conditions. +type ElasticConditionType string + +const ( + ElasticJobPending ElasticConditionType = "HasPendingPod" + // ElasticStart means the elastic scaling has been started. + ElasticStart ElasticConditionType = "Start" + // ElasticStop means the elastic scaling has been stopped. + ElasticStop ElasticConditionType = "Stop" + // ElasticContinue means the elastic scaling continues. 
+ ElasticContinue ElasticConditionType = "Continue" + // ElasticMaxMetric means the training metrics have reached the max. + ElasticMaxMetric ElasticConditionType = "ReachMaxMetric" + // ElasticMaxReplica means the replicas have reached the maxReplicas. + ElasticMaxReplica ElasticConditionType = "ReachMaxReplicas" +) + // JobConditionType defines all kinds of types of JobStatus. type JobConditionType string diff --git a/pkg/job_controller/job.go b/pkg/job_controller/job.go index c6319b66..da7e2093 100644 --- a/pkg/job_controller/job.go +++ b/pkg/job_controller/job.go @@ -317,10 +317,15 @@ func (jc *JobController) ReconcileJobs(job interface{}, replicas map[apiv1.Repli continue } - // Service is in need only for Master - if jc.Controller.GetAPIGroupVersionKind().Kind == training.PyTorchJobKind && - rtype != training.PyTorchReplicaTypeMaster { - continue + if jc.Controller.GetAPIGroupVersionKind().Kind == training.PyTorchJobKind { + pytorchJob, ok := job.(*training.PyTorchJob) + if !ok { + log.Warnf("job is not a PyTorchJob, got type %T", job) + } + // Service is needed only for the pytorch Master unless torch elastic is enabled. + if ok && !pytorchJob.Spec.EnableElastic && rtype != training.PyTorchReplicaTypeMaster { + continue + } } err = jc.ReconcileServices(ctx, metaObject, services, rtype, spec) diff --git a/pkg/job_controller/service.go b/pkg/job_controller/service.go index f5d652b2..7ac7af6d 100644 --- a/pkg/job_controller/service.go +++ b/pkg/job_controller/service.go @@ -162,11 +162,28 @@ func (jc *JobController) FilterServicesForReplicaType(services []*v1.Service, re return result, nil } +// calculateServiceSliceSize compares the maximum observed service replica index with the desired replicas and returns the larger size. +func calculateServiceSliceSize(services []*v1.Service, replicas int) int { + size := 0 + for _, svc := range services { + if _, ok := svc.Labels[apiv1.ReplicaIndexLabel]; !ok { + continue + } + index, err := strconv.Atoi(svc.Labels[apiv1.ReplicaIndexLabel]) + if err != nil { + continue + } + size = maxInt(size, index) + } + // size comes from a zero-based index, so add 1 to get the real slice size. + return maxInt(size+1, replicas) +} + // GetServiceSlices returns a slice, which element is the slice of service. // Assume the return object is serviceSlices, then serviceSlices[i] is an // array of pointers to services corresponding to Services for replica i. func (jc *JobController) GetServiceSlices(services []*v1.Service, replicas int, logger *log.Entry) [][]*v1.Service { - serviceSlices := make([][]*v1.Service, replicas) + serviceSlices := make([][]*v1.Service, calculateServiceSliceSize(services, replicas)) for _, service := range services { if _, ok := service.Labels[apiv1.ReplicaIndexLabel]; !ok { logger.Warning("The service do not have the index label.") diff --git a/pkg/job_controller/util.go b/pkg/job_controller/util.go index 61229548..239fd31f 100644 --- a/pkg/job_controller/util.go +++ b/pkg/job_controller/util.go @@ -75,3 +75,10 @@ func ReplicaTypes(specs map[v1.ReplicaType]*v1.ReplicaSpec) []v1.ReplicaType { } return replicas } + +func maxInt(x, y int) int { + if x < y { + return y + } + return x +}
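The tutorial above states that the scaling decision is driven by the per-batch `Time` values parsed from the training logs and that the regular expression lives in `log_util.go`. The snippet below is a minimal, self-contained sketch, not the actual `log_util.go` implementation: the pattern and helper names (`batchTimeRegex`, `parseBatchLatency`) are illustrative assumptions, included only to make the log-parsing idea concrete for the log format shown earlier.
```go
package main

import (
	"fmt"
	"regexp"
	"strconv"
)

// batchTimeRegex is an illustrative pattern (NOT the one shipped in log_util.go); it
// captures the instantaneous and running-average per-batch latency from lines such as:
//   Epoch: [17][ 450/1563] Time 0.294 ( 0.280) Data 0.189 ( 0.164) ...
var batchTimeRegex = regexp.MustCompile(`Epoch:\s*\[\d+\]\[\s*\d+/\d+\]\s+Time\s+([0-9.]+)\s+\(\s*([0-9.]+)\)`)

// parseBatchLatency returns the current and running-average batch latency in seconds;
// ok is false when the line does not match the expected training-log format.
func parseBatchLatency(line string) (current, average float64, ok bool) {
	m := batchTimeRegex.FindStringSubmatch(line)
	if m == nil {
		return 0, 0, false
	}
	cur, err1 := strconv.ParseFloat(m[1], 64)
	avg, err2 := strconv.ParseFloat(m[2], 64)
	if err1 != nil || err2 != nil {
		return 0, 0, false
	}
	return cur, avg, true
}

func main() {
	line := "Epoch: [17][ 450/1563] Time 0.294 ( 0.280) Data 0.189 ( 0.164) Loss 2.4482e+00 (2.7056e+00)"
	if cur, avg, ok := parseBatchLatency(line); ok {
		// A scaling policy could compare the running-average latency before and
		// after adding workers to decide whether further scale-out still helps.
		fmt.Printf("current batch latency: %.3fs, running average: %.3fs\n", cur, avg)
	}
}
```
If your training script prints a different progress line, the same approach applies: adjust the pattern so its capture groups return whichever metric the scaling policy should track.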