diff --git a/pkg/webhook/server/generic/operationjob/operationjob_validating_handler.go b/pkg/webhook/server/generic/operationjob/operationjob_validating_handler.go index 81676ac3..8c7e0d44 100644 --- a/pkg/webhook/server/generic/operationjob/operationjob_validating_handler.go +++ b/pkg/webhook/server/generic/operationjob/operationjob_validating_handler.go @@ -26,6 +26,7 @@ import ( "k8s.io/apimachinery/pkg/api/equality" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/runtime/inject" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -127,29 +128,14 @@ func (h *ValidatingHandler) validateOpsTarget(instance, old *appsv1alpha1.Operat func (h *ValidatingHandler) validatePartition(instance, old *appsv1alpha1.OperationJob, fldPath *field.Path) field.ErrorList { var allErrors field.ErrorList - var currPartition, oldPartition int32 - var currTotalReplicas, oldTotalReplicas int32 - - currTotalReplicas = int32(len(instance.Spec.Targets)) - if instance.Spec.Partition == nil { - currPartition = currTotalReplicas - } else { - currPartition = minInt32(*instance.Spec.Partition, currTotalReplicas) - } - - oldTotalReplicas = int32(len(old.Spec.Targets)) - if old.Spec.Partition == nil { - oldPartition = oldTotalReplicas - } else { - oldPartition = minInt32(*old.Spec.Partition, oldTotalReplicas) - } + oldPartition := ptr.Deref(old.Spec.Partition, 0) + curPartition := ptr.Deref(instance.Spec.Partition, 0) - if currPartition < 0 { - allErrors = append(allErrors, field.Invalid(fldPath.Child("partition"), currPartition, "should not be negative")) - } else if currPartition < oldPartition { - allErrors = append(allErrors, field.Invalid(fldPath.Child("partition"), currPartition, fmt.Sprintf("should not be decreased (from %d to %d)", oldPartition, currPartition))) + if curPartition < 0 { + allErrors = append(allErrors, field.Invalid(fldPath, curPartition, "should not be negative")) + } else if oldPartition > curPartition { + allErrors = append(allErrors, field.Invalid(fldPath, curPartition, fmt.Sprintf("should not be decreased. (from %d to %d)", oldPartition, curPartition))) } - return allErrors } @@ -165,10 +151,3 @@ func (h *ValidatingHandler) validateTTLAndActiveDeadline(instance *appsv1alpha1. } return allErrors } - -func minInt32(a, b int32) int32 { - if a < b { - return a - } - return b -}