diff --git a/actions/actions.go b/actions/actions.go index e4a69f48..1af65195 100644 --- a/actions/actions.go +++ b/actions/actions.go @@ -10,7 +10,6 @@ import ( "sync" "time" - "github.com/cenkalti/backoff/v4" "github.com/sirupsen/logrus" "k8s.io/client-go/dynamic" "k8s.io/client-go/kubernetes" @@ -18,6 +17,7 @@ import ( "github.com/castai/cluster-controller/castai" "github.com/castai/cluster-controller/health" "github.com/castai/cluster-controller/helm" + "github.com/castai/cluster-controller/waitext" ) const ( @@ -132,16 +132,19 @@ func (s *service) doWork(ctx context.Context) error { iteration int ) - b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(5*time.Second), 3), ctx) - errR := backoff.Retry(func() error { + boff := waitext.NewConstantBackoff(5 * time.Second) + + errR := waitext.Retry(ctx, boff, 3, func(ctx context.Context) (bool, error) { iteration++ actions, err = s.castAIClient.GetActions(ctx, s.k8sVersion) if err != nil { - s.log.Errorf("polling actions: get action request failed: iteration: %v %v", iteration, err) - return err + return true, err } - return nil - }, b) + return false, nil + }, func(err error) { + s.log.Errorf("polling actions: get action request failed: iteration: %v %v", iteration, err) + }) + if errR != nil { return fmt.Errorf("polling actions: %w", err) } @@ -242,21 +245,16 @@ func (s *service) ackAction(ctx context.Context, action *castai.ClusterAction, h "type": actionType.String(), }).Info("ack action") - return backoff.RetryNotify(func() error { + boff := waitext.NewConstantBackoff(s.cfg.AckRetryWait) + + return waitext.Retry(ctx, boff, s.cfg.AckRetriesCount, func(ctx context.Context) (bool, error) { ctx, cancel := context.WithTimeout(ctx, s.cfg.AckTimeout) defer cancel() - return s.castAIClient.AckAction(ctx, action.ID, &castai.AckClusterActionRequest{ + return true, s.castAIClient.AckAction(ctx, action.ID, &castai.AckClusterActionRequest{ Error: getHandlerError(handleErr), }) - }, backoff.WithContext( - backoff.WithMaxRetries( - backoff.NewConstantBackOff(s.cfg.AckRetryWait), uint64(s.cfg.AckRetriesCount), - ), - ctx, - ), func(err error, duration time.Duration) { - if err != nil { - s.log.Debugf("ack failed, will retry: %v", err) - } + }, func(err error) { + s.log.Debugf("ack failed, will retry: %v", err) }) } diff --git a/actions/approve_csr_handler.go b/actions/approve_csr_handler.go index 72059920..eddeaf9b 100644 --- a/actions/approve_csr_handler.go +++ b/actions/approve_csr_handler.go @@ -7,12 +7,17 @@ import ( "reflect" "time" - "github.com/cenkalti/backoff/v4" "github.com/sirupsen/logrus" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" "github.com/castai/cluster-controller/castai" "github.com/castai/cluster-controller/csr" + "github.com/castai/cluster-controller/waitext" +) + +const ( + approveCSRTimeout = 4 * time.Minute ) func newApproveCSRHandler(log logrus.FieldLogger, clientset kubernetes.Interface) ActionHandler { @@ -53,17 +58,21 @@ func (h *approveCSRHandler) Handle(ctx context.Context, action *castai.ClusterAc return nil } - b := backoff.WithContext( - newApproveCSRExponentialBackoff(), + ctx, cancel := context.WithTimeout(ctx, approveCSRTimeout) + defer cancel() + + b := newApproveCSRExponentialBackoff() + return waitext.Retry( ctx, - ) - return backoff.RetryNotify(func() error { - return h.handle(ctx, log, cert) - }, b, func(err error, duration time.Duration) { - if err != nil { + b, + waitext.Forever, + func(ctx context.Context) (bool, error) { + return true, h.handle(ctx, log, cert) + 
}, + func(err error) { log.Warnf("csr approval failed, will retry: %v", err) - } - }) + }, + ) } func (h *approveCSRHandler) handle(ctx context.Context, log logrus.FieldLogger, cert *csr.Certificate) (reterr error) { @@ -122,25 +131,28 @@ func (h *approveCSRHandler) getInitialNodeCSR(ctx context.Context, log logrus.Fi var cert *csr.Certificate var err error - logRetry := func(err error, _ time.Duration) { - log.Warnf("getting initial csr, will retry: %v", err) - } - b := backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3) - err = backoff.RetryNotify(func() error { - cert, err = poll() - if errors.Is(err, context.DeadlineExceeded) { - return backoff.Permanent(err) - } - return err - }, b, logRetry) + b := waitext.DefaultExponentialBackoff() + err = waitext.Retry( + ctx, + b, + 3, + func(ctx context.Context) (bool, error) { + cert, err = poll() + if errors.Is(err, context.DeadlineExceeded) { + return false, err + } + return true, err + }, + func(err error) { + log.Warnf("getting initial csr, will retry: %v", err) + }, + ) return cert, err } -func newApproveCSRExponentialBackoff() *backoff.ExponentialBackOff { - b := backoff.NewExponentialBackOff() - b.Multiplier = 2 - b.MaxElapsedTime = 4 * time.Minute - b.Reset() +func newApproveCSRExponentialBackoff() wait.Backoff { + b := waitext.DefaultExponentialBackoff() + b.Factor = 2 return b } diff --git a/actions/approve_csr_handler_test.go b/actions/approve_csr_handler_test.go index 7a4ef30b..6d54daa5 100644 --- a/actions/approve_csr_handler_test.go +++ b/actions/approve_csr_handler_test.go @@ -4,11 +4,11 @@ import ( "context" "errors" "fmt" - "github.com/google/uuid" "sync/atomic" "testing" "time" + "github.com/google/uuid" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" certv1 "k8s.io/api/certificates/v1" @@ -231,10 +231,10 @@ func TestApproveCSRExponentialBackoff(t *testing.T) { b := newApproveCSRExponentialBackoff() var sum time.Duration for i := 0; i < 10; i++ { - tmp := b.NextBackOff() + tmp := b.Step() sum += tmp } - r.Truef(100 < sum.Seconds(), "actual elapsed seconds %s", sum.Seconds()) + r.Truef(100 < sum.Seconds(), "actual elapsed seconds %v", sum.Seconds()) } func getCSR() *certv1.CertificateSigningRequest { diff --git a/actions/check_node_deleted.go b/actions/check_node_deleted.go index 7eac072f..c12cb4e5 100644 --- a/actions/check_node_deleted.go +++ b/actions/check_node_deleted.go @@ -7,17 +7,17 @@ import ( "reflect" "time" - "github.com/cenkalti/backoff/v4" "github.com/sirupsen/logrus" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" "github.com/castai/cluster-controller/castai" + "github.com/castai/cluster-controller/waitext" ) type checkNodeDeletedConfig struct { - retries uint64 + retries int retryWait time.Duration } @@ -52,35 +52,44 @@ func (h *checkNodeDeletedHandler) Handle(ctx context.Context, action *castai.Clu }) log.Info("checking if node is deleted") - b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(h.cfg.retryWait), h.cfg.retries), ctx) - return backoff.Retry(func() error { - n, err := h.clientset.CoreV1().Nodes().Get(ctx, req.NodeName, metav1.GetOptions{}) - if apierrors.IsNotFound(err) { - return nil - } + boff := waitext.NewConstantBackoff(h.cfg.retryWait) - if n == nil { - return nil - } + return waitext.Retry( + ctx, + boff, + h.cfg.retries, + func(ctx context.Context) (bool, error) { + n, err := h.clientset.CoreV1().Nodes().Get(ctx, req.NodeName, metav1.GetOptions{}) + if 
apierrors.IsNotFound(err) { + return false, nil + } + + if n == nil { + return false, nil + } - currentNodeID, ok := n.Labels[castai.LabelNodeID] - if !ok { - log.Info("node doesn't have castai node id label") - } - if currentNodeID != "" { - if currentNodeID != req.NodeID { - log.Info("node name was reused. Original node is deleted") - return nil + currentNodeID, ok := n.Labels[castai.LabelNodeID] + if !ok { + log.Info("node doesn't have castai node id label") } - if currentNodeID == req.NodeID { - return backoff.Permanent(errors.New("node is not deleted")) + if currentNodeID != "" { + if currentNodeID != req.NodeID { + log.Info("node name was reused. Original node is deleted") + return false, nil + } + if currentNodeID == req.NodeID { + return false, errors.New("node is not deleted") + } } - } - if n != nil { - return backoff.Permanent(errors.New("node is not deleted")) - } + if n != nil { + return false, errors.New("node is not deleted") + } - return err - }, b) + return true, err + }, + func(err error) { + log.Warnf("node deletion check failed, will retry: %v", err) + }, + ) } diff --git a/actions/check_node_status.go b/actions/check_node_status.go index 790062a7..3f24e731 100644 --- a/actions/check_node_status.go +++ b/actions/check_node_status.go @@ -7,7 +7,6 @@ import ( "reflect" "time" - "github.com/cenkalti/backoff/v4" "github.com/sirupsen/logrus" corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -16,6 +15,7 @@ import ( "k8s.io/client-go/kubernetes" "github.com/castai/cluster-controller/castai" + "github.com/castai/cluster-controller/waitext" ) func newCheckNodeStatusHandler(log logrus.FieldLogger, clientset kubernetes.Interface) ActionHandler { @@ -64,42 +64,51 @@ func (h *checkNodeStatusHandler) checkNodeDeleted(ctx context.Context, log *logr } ctx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) defer cancel() - b := backoff.WithContext(backoff.NewExponentialBackOff(), ctx) - return backoff.Retry(func() error { - n, err := h.clientset.CoreV1().Nodes().Get(ctx, req.NodeName, metav1.GetOptions{}) - if apierrors.IsNotFound(err) { - return nil - } - // If node is nil - deleted - // If label is present and doesn't match - node was reused - deleted - // If label is present and matches - node is not deleted - // If label is not present and node is not nil - node is not deleted (potentially corrupted state) + b := waitext.DefaultExponentialBackoff() + return waitext.Retry( + ctx, + b, + waitext.Forever, + func(ctx context.Context) (bool, error) { + n, err := h.clientset.CoreV1().Nodes().Get(ctx, req.NodeName, metav1.GetOptions{}) + if apierrors.IsNotFound(err) { + return false, nil + } - if n == nil { - return nil - } + // If node is nil - deleted + // If label is present and doesn't match - node was reused - deleted + // If label is present and matches - node is not deleted + // If label is not present and node is not nil - node is not deleted (potentially corrupted state) - currentNodeID, ok := n.Labels[castai.LabelNodeID] - if !ok { - log.Info("node doesn't have castai node id label") - } - if currentNodeID != "" { - if currentNodeID != req.NodeID { - log.Info("node name was reused. 
Original node is deleted") - return nil + if n == nil { + return false, nil } - if currentNodeID == req.NodeID { - return backoff.Permanent(errors.New("node is not deleted")) + + currentNodeID, ok := n.Labels[castai.LabelNodeID] + if !ok { + log.Info("node doesn't have castai node id label") + } + if currentNodeID != "" { + if currentNodeID != req.NodeID { + log.Info("node name was reused. Original node is deleted") + return false, nil + } + if currentNodeID == req.NodeID { + return false, errors.New("node is not deleted") + } } - } - if n != nil { - return backoff.Permanent(errors.New("node is not deleted")) - } + if n != nil { + return false, errors.New("node is not deleted") + } - return err - }, b) + return true, err + }, + func(err error) { + h.log.Warnf("check node %s status failed, will retry: %v", req.NodeName, err) + }, + ) } func (h *checkNodeStatusHandler) checkNodeReady(ctx context.Context, log *logrus.Entry, req *castai.ActionCheckNodeStatus) error { diff --git a/actions/delete_node_handler.go b/actions/delete_node_handler.go index 38e25840..7b299775 100644 --- a/actions/delete_node_handler.go +++ b/actions/delete_node_handler.go @@ -7,7 +7,6 @@ import ( "reflect" "time" - "github.com/cenkalti/backoff/v4" "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" @@ -16,10 +15,11 @@ import ( "k8s.io/client-go/kubernetes" "github.com/castai/cluster-controller/castai" + "github.com/castai/cluster-controller/waitext" ) type deleteNodeConfig struct { - deleteRetries uint64 + deleteRetries int deleteRetryWait time.Duration podsTerminationWait time.Duration } @@ -63,54 +63,67 @@ func (h *deleteNodeHandler) Handle(ctx context.Context, action *castai.ClusterAc }) log.Info("deleting kubernetes node") - b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(h.cfg.deleteRetryWait), h.cfg.deleteRetries), ctx) - err := backoff.Retry(func() error { - current, err := h.clientset.CoreV1().Nodes().Get(ctx, req.NodeName, metav1.GetOptions{}) - if err != nil { - if apierrors.IsNotFound(err) { - log.Info("node not found, skipping delete") - return nil + b := waitext.NewConstantBackoff(h.cfg.deleteRetryWait) + err := waitext.Retry( + ctx, + b, + h.cfg.deleteRetries, + func(ctx context.Context) (bool, error) { + current, err := h.clientset.CoreV1().Nodes().Get(ctx, req.NodeName, metav1.GetOptions{}) + if err != nil { + if apierrors.IsNotFound(err) { + log.Info("node not found, skipping delete") + return false, nil + } + return true, fmt.Errorf("error getting node: %w", err) } - return fmt.Errorf("error getting node: %w", err) - } - if val, ok := current.Labels[castai.LabelNodeID]; ok { - if val != "" && val != req.NodeID { - log.Infof("node id mismatch, expected %q got %q. Skipping delete.", req.NodeID, val) - return errNodeMismatch + if val, ok := current.Labels[castai.LabelNodeID]; ok { + if val != "" && val != req.NodeID { + log.Infof("node id mismatch, expected %q got %q. 
Skipping delete.", req.NodeID, val) + return true, errNodeMismatch + } } - } - err = h.clientset.CoreV1().Nodes().Delete(ctx, current.Name, metav1.DeleteOptions{}) - if apierrors.IsNotFound(err) { - log.Info("node not found, skipping delete") - return nil - } - return err - }, b) + err = h.clientset.CoreV1().Nodes().Delete(ctx, current.Name, metav1.DeleteOptions{}) + if apierrors.IsNotFound(err) { + log.Info("node not found, skipping delete") + return false, nil + } + return true, err + }, + func(err error) { + h.log.Warnf("error deleting kubernetes node, will retry: %v", err) + }, + ) if errors.Is(err, errNodeMismatch) { return nil } - if err != nil { return fmt.Errorf("error removing node %w", err) } - podsListing := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(h.cfg.podsTerminationWait), h.cfg.deleteRetries), ctx) + podsListingBackoff := waitext.NewConstantBackoff(h.cfg.podsTerminationWait) var pods []v1.Pod - err = backoff.Retry(func() error { - podList, err := h.clientset.CoreV1().Pods(metav1.NamespaceAll).List(ctx, metav1.ListOptions{ - FieldSelector: fields.SelectorFromSet(fields.Set{"spec.nodeName": req.NodeName}).String(), - }) - if err != nil { - return err - } - pods = podList.Items - return nil - - }, podsListing) - + err = waitext.Retry( + ctx, + podsListingBackoff, + h.cfg.deleteRetries, + func(ctx context.Context) (bool, error) { + podList, err := h.clientset.CoreV1().Pods(metav1.NamespaceAll).List(ctx, metav1.ListOptions{ + FieldSelector: fields.SelectorFromSet(fields.Set{"spec.nodeName": req.NodeName}).String(), + }) + if err != nil { + return true, err + } + pods = podList.Items + return false, nil + }, + func(err error) { + h.log.Warnf("error listing pods, will retry: %v", err) + }, + ) if err != nil { return fmt.Errorf("listing node pods %w", err) } @@ -128,17 +141,25 @@ func (h *deleteNodeHandler) Handle(ctx context.Context, action *castai.ClusterAc } // Cleanup of pods for which node has been removed. It should take a few seconds but added retry in case of network errors. 
- podsWait := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(h.cfg.podsTerminationWait), h.cfg.deleteRetries), ctx) - return backoff.Retry(func() error { - pods, err := h.clientset.CoreV1().Pods(metav1.NamespaceAll).List(ctx, metav1.ListOptions{ - FieldSelector: fields.SelectorFromSet(fields.Set{"spec.nodeName": req.NodeName}).String(), - }) - if err != nil { - return fmt.Errorf("unable to list pods for node %q err: %w", req.NodeName, err) - } - if len(pods.Items) > 0 { - return fmt.Errorf("waiting for %d pods to be terminated on node %v", len(pods.Items), req.NodeName) - } - return nil - }, podsWait) + podsWaitBackoff := waitext.NewConstantBackoff(h.cfg.podsTerminationWait) + return waitext.Retry( + ctx, + podsWaitBackoff, + h.cfg.deleteRetries, + func(ctx context.Context) (bool, error) { + pods, err := h.clientset.CoreV1().Pods(metav1.NamespaceAll).List(ctx, metav1.ListOptions{ + FieldSelector: fields.SelectorFromSet(fields.Set{"spec.nodeName": req.NodeName}).String(), + }) + if err != nil { + return true, fmt.Errorf("unable to list pods for node %q err: %w", req.NodeName, err) + } + if len(pods.Items) > 0 { + return true, fmt.Errorf("waiting for %d pods to be terminated on node %v", len(pods.Items), req.NodeName) + } + return false, nil + }, + func(err error) { + h.log.Warnf("error waiting for pods termination, will retry: %v", err) + }, + ) } diff --git a/actions/drain_node_handler.go b/actions/drain_node_handler.go index d8cb4a34..17dc7820 100644 --- a/actions/drain_node_handler.go +++ b/actions/drain_node_handler.go @@ -8,7 +8,6 @@ import ( "strings" "time" - "github.com/cenkalti/backoff/v4" "github.com/samber/lo" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -23,6 +22,7 @@ import ( "k8s.io/kubectl/pkg/drain" "github.com/castai/cluster-controller/castai" + "github.com/castai/cluster-controller/waitext" ) const ( @@ -32,7 +32,7 @@ const ( type drainNodeConfig struct { podsDeleteTimeout time.Duration - podDeleteRetries uint64 + podDeleteRetries int podDeleteRetryDelay time.Duration podEvictRetryDelay time.Duration podsTerminationWaitRetryDelay time.Duration @@ -147,7 +147,7 @@ func (h *drainNodeHandler) taintNode(ctx context.Context, node *v1.Node) error { return nil } - err := patchNode(ctx, h.clientset, node, func(n *v1.Node) { + err := patchNode(ctx, h.log, h.clientset, node, func(n *v1.Node) { n.Spec.Unschedulable = true }) if err != nil { @@ -226,16 +226,25 @@ func (h *drainNodeHandler) sendPodsRequests(ctx context.Context, pods []v1.Pod, func (h *drainNodeHandler) listNodePodsToEvict(ctx context.Context, log logrus.FieldLogger, node *v1.Node) ([]v1.Pod, error) { var pods *v1.PodList - if err := backoff.Retry(func() error { - p, err := h.clientset.CoreV1().Pods(metav1.NamespaceAll).List(ctx, metav1.ListOptions{ - FieldSelector: fields.SelectorFromSet(fields.Set{"spec.nodeName": node.Name}).String(), - }) - if err != nil { - return err - } - pods = p - return nil - }, defaultBackoff(ctx)); err != nil { + err := waitext.Retry( + ctx, + defaultBackoff(), + defaultMaxRetriesK8SOperation, + func(ctx context.Context) (bool, error) { + p, err := h.clientset.CoreV1().Pods(metav1.NamespaceAll).List(ctx, metav1.ListOptions{ + FieldSelector: fields.SelectorFromSet(fields.Set{"spec.nodeName": node.Name}).String(), + }) + if err != nil { + return true, err + } + pods = p + return false, nil + }, + func(err error) { + log.Warnf("listing pods on node %s: %v", node.Name, err) + }, + ) + if err != nil { return nil, fmt.Errorf("listing node %v pods: %w", 
node.Name, err) } @@ -269,23 +278,31 @@ func (h *drainNodeHandler) listNodePodsToEvict(ctx context.Context, log logrus.F } func (h *drainNodeHandler) waitNodePodsTerminated(ctx context.Context, log logrus.FieldLogger, node *v1.Node) error { - return backoff.Retry(func() error { - pods, err := h.listNodePodsToEvict(ctx, log, node) - if err != nil { - return fmt.Errorf("waiting for node %q pods to be terminated: %w", node.Name, err) - } - if len(pods) > 0 { - return fmt.Errorf("waiting for %d pods to be terminated on node %v", len(pods), node.Name) - } - return nil - }, backoff.WithContext(backoff.NewConstantBackOff(h.cfg.podsTerminationWaitRetryDelay), ctx)) + return waitext.Retry( + ctx, + waitext.NewConstantBackoff(h.cfg.podsTerminationWaitRetryDelay), + waitext.Forever, + func(ctx context.Context) (bool, error) { + pods, err := h.listNodePodsToEvict(ctx, log, node) + if err != nil { + return true, fmt.Errorf("listing %q pods to be terminated: %w", node.Name, err) + } + if len(pods) > 0 { + return true, fmt.Errorf("waiting for %d pods to be terminated on node %v", len(pods), node.Name) + } + return false, nil + }, + func(err error) { + h.log.Warnf("waiting for pod termination on node %v, will retry: %v", node.Name, err) + }, + ) } // evictPod from the k8s node. Error handling is based on eviction api documentation: // https://kubernetes.io/docs/tasks/administer-cluster/safely-drain-node/#the-eviction-api func (h *drainNodeHandler) evictPod(ctx context.Context, pod v1.Pod, groupVersion schema.GroupVersion) error { - b := backoff.WithContext(backoff.NewConstantBackOff(h.cfg.podEvictRetryDelay), ctx) // nolint:gomnd - action := func() error { + b := waitext.NewConstantBackoff(h.cfg.podEvictRetryDelay) + action := func(ctx context.Context) (bool, error) { var err error if groupVersion == policyv1.SchemeGroupVersion { err = h.clientset.PolicyV1().Evictions(pod.Namespace).Evict(ctx, &policyv1.Eviction{ @@ -310,44 +327,51 @@ func (h *drainNodeHandler) evictPod(ctx context.Context, pod v1.Pod, groupVersio if err != nil { // Pod is not found - ignore. if apierrors.IsNotFound(err) { - return nil + return false, nil } // Pod is misconfigured - stop retry. if apierrors.IsInternalError(err) { - return backoff.Permanent(err) + return false, err } } // Other errors - retry. - return err + return true, err } - if err := backoff.Retry(action, b); err != nil { + err := waitext.Retry(ctx, b, waitext.Forever, action, func(err error) { + h.log.Warnf("evict pod %s on node %s in namespace %s, will retry: %v", pod.Name, pod.Spec.NodeName, pod.Namespace, err) + }) + if err != nil { return fmt.Errorf("evicting pod %s in namespace %s: %w", pod.Name, pod.Namespace, err) + } return nil } func (h *drainNodeHandler) deletePod(ctx context.Context, options metav1.DeleteOptions, pod v1.Pod) error { - b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(h.cfg.podDeleteRetryDelay), h.cfg.podDeleteRetries), ctx) // nolint:gomnd - action := func() error { + b := waitext.NewConstantBackoff(h.cfg.podDeleteRetryDelay) + action := func(ctx context.Context) (bool, error) { err := h.clientset.CoreV1().Pods(pod.Namespace).Delete(ctx, pod.Name, options) if err != nil { // Pod is not found - ignore. if apierrors.IsNotFound(err) { - return nil + return false, nil } // Pod is misconfigured - stop retry. if apierrors.IsInternalError(err) { - return backoff.Permanent(err) + return false, err } } // Other errors - retry. 
- return err + return true, err } - if err := backoff.Retry(action, b); err != nil { + err := waitext.Retry(ctx, b, h.cfg.podDeleteRetries, action, func(err error) { + h.log.Warnf("deleting pod %s on node %s in namespace %s, will retry: %v", pod.Name, pod.Spec.NodeName, pod.Namespace, err) + }) + if err != nil { return fmt.Errorf("deleting pod %s in namespace %s: %w", pod.Name, pod.Namespace, err) } return nil diff --git a/actions/kubernetes_helpers.go b/actions/kubernetes_helpers.go index 332e8333..075f40bc 100644 --- a/actions/kubernetes_helpers.go +++ b/actions/kubernetes_helpers.go @@ -6,17 +6,23 @@ import ( "fmt" "time" - "github.com/cenkalti/backoff/v4" "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" k8serrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apitypes "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/strategicpatch" + "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/kubernetes" + + "github.com/castai/cluster-controller/waitext" +) + +const ( + defaultMaxRetriesK8SOperation = 5 ) -func patchNode(ctx context.Context, clientset kubernetes.Interface, node *v1.Node, changeFn func(*v1.Node)) error { +func patchNode(ctx context.Context, log logrus.FieldLogger, clientset kubernetes.Interface, node *v1.Node, changeFn func(*v1.Node)) error { oldData, err := json.Marshal(node) if err != nil { return fmt.Errorf("marshaling old data: %w", err) @@ -34,10 +40,18 @@ func patchNode(ctx context.Context, clientset kubernetes.Interface, node *v1.Nod return fmt.Errorf("creating patch for node: %w", err) } - err = backoff.Retry(func() error { - _, err = clientset.CoreV1().Nodes().Patch(ctx, node.Name, apitypes.StrategicMergePatchType, patch, metav1.PatchOptions{}) - return err - }, defaultBackoff(ctx)) + err = waitext.Retry( + ctx, + defaultBackoff(), + defaultMaxRetriesK8SOperation, + func(ctx context.Context) (bool, error) { + _, err = clientset.CoreV1().Nodes().Patch(ctx, node.Name, apitypes.StrategicMergePatchType, patch, metav1.PatchOptions{}) + return true, err + }, + func(err error) { + log.Warnf("patch node, will retry: %v", err) + }, + ) if err != nil { return fmt.Errorf("patching node: %w", err) } @@ -46,15 +60,24 @@ func patchNode(ctx context.Context, clientset kubernetes.Interface, node *v1.Nod } func patchNodeStatus(ctx context.Context, log logrus.FieldLogger, clientset kubernetes.Interface, name string, patch []byte) error { - err := backoff.Retry(func() error { - _, err := clientset.CoreV1().Nodes().PatchStatus(ctx, name, patch) - if k8serrors.IsForbidden(err) { - // permissions might be of older version that can't patch node/status - log.WithField("node", name).WithError(err).Warn("skip patch node/status") - return nil - } - return err - }, defaultBackoff(ctx)) + err := waitext.Retry( + ctx, + defaultBackoff(), + defaultMaxRetriesK8SOperation, + func(ctx context.Context) (bool, error) { + _, err := clientset.CoreV1().Nodes().PatchStatus(ctx, name, patch) + if k8serrors.IsForbidden(err) { + // permissions might be of older version that can't patch node/status + log.WithField("node", name).WithError(err).Warn("skip patch node/status") + return false, nil + } + return true, err + }, + func(err error) { + log.Warnf("patch node status, will retry: %v", err) + }, + ) + if err != nil { return fmt.Errorf("patch status: %w", err) } @@ -65,27 +88,35 @@ func getNodeForPatching(ctx context.Context, log logrus.FieldLogger, clientset k // on GKE we noticed that sometimes the node is not found, even though it is in the cluster // as 
a result was returned from watch. But subsequent get request returns not found. // This is likely due to clientset's caching that's meant to alleviate API's load. - // So we give enough time for cache to sync. - logRetry := func(err error, _ time.Duration) { - log.Warnf("getting node, will retry: %v", err) - } + // So we give enough time for cache to sync - ~10s max. + var node *v1.Node - b := backoff.NewExponentialBackOff() - b.MaxElapsedTime = 10 * time.Second - err := backoff.RetryNotify(func() error { - var err error - node, err = clientset.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}) - if err != nil { - return err - } - return err - }, b, logRetry) + + boff := waitext.DefaultExponentialBackoff() + + err := waitext.Retry( + ctx, + boff, + 5, + func(ctx context.Context) (bool, error) { + var err error + node, err = clientset.CoreV1().Nodes().Get(ctx, nodeName, metav1.GetOptions{}) + if err != nil { + return true, err + } + return false, nil + }, + func(err error) { + log.Warnf("getting node, will retry: %v", err) + }, + ) if err != nil { return nil, err } return node, nil + } -func defaultBackoff(ctx context.Context) backoff.BackOffContext { - return backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(500*time.Millisecond), 5), ctx) // nolint:gomnd +func defaultBackoff() wait.Backoff { + return waitext.NewConstantBackoff(500 * time.Millisecond) } diff --git a/actions/patch_node_handler.go b/actions/patch_node_handler.go index b8a1dd0e..9ef7c4ca 100644 --- a/actions/patch_node_handler.go +++ b/actions/patch_node_handler.go @@ -80,7 +80,7 @@ func (h *patchNodeHandler) Handle(ctx context.Context, action *castai.ClusterAct "capacity": req.Capacity, }).Infof("patching node, labels=%v, taints=%v, annotations=%v, unschedulable=%v", req.Labels, req.Taints, req.Annotations, unschedulable) - err = patchNode(ctx, h.clientset, node, func(n *v1.Node) { + err = patchNode(ctx, h.log, h.clientset, node, func(n *v1.Node) { n.Labels = patchNodeMapField(n.Labels, req.Labels) n.Annotations = patchNodeMapField(n.Annotations, req.Annotations) n.Spec.Taints = patchTaints(n.Spec.Taints, req.Taints) diff --git a/castai/logexporter.go b/castai/logexporter.go index ec91d094..e3ddc585 100644 --- a/castai/logexporter.go +++ b/castai/logexporter.go @@ -5,8 +5,9 @@ import ( "sync" "time" - "github.com/cenkalti/backoff/v4" "github.com/sirupsen/logrus" + + "github.com/castai/cluster-controller/waitext" ) const ( @@ -83,10 +84,12 @@ func (e *LogExporter) sendLogEvent(log *logrus.Entry) { Fields: log.Data, } - b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3), ctx) - err := backoff.Retry(func() error { - return e.sender.SendLog(ctx, logEntry) - }, b) + b := waitext.DefaultExponentialBackoff() + err := waitext.Retry(ctx, b, 3, func(ctx context.Context) (bool, error) { + return true, e.sender.SendLog(ctx, logEntry) + }, func(err error) { + e.logger.Debugf("failed to send logs, will retry: %s", err) + }) if err != nil { e.logger.Debugf("sending logs: %v", err) diff --git a/go.mod b/go.mod index 2c082195..db1456a9 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.21 require ( github.com/bombsimon/logrusr/v4 v4.0.0 - github.com/cenkalti/backoff/v4 v4.2.1 github.com/evanphx/json-patch v5.7.0+incompatible github.com/go-resty/resty/v2 v2.5.0 github.com/golang/mock v1.6.0 diff --git a/go.sum b/go.sum index 1a04a559..d63568a0 100644 --- a/go.sum +++ b/go.sum @@ -94,8 +94,6 @@ github.com/bugsnag/osext v0.0.0-20130617224835-0dd3f918b21b h1:otBG+dV+YK+Soembj 
github.com/bugsnag/osext v0.0.0-20130617224835-0dd3f918b21b/go.mod h1:obH5gd0BsqsP2LwDJ9aOkm/6J86V6lyAXCoQWGw3K50= github.com/bugsnag/panicwrap v0.0.0-20151223152923-e2c28503fcd0 h1:nvj0OLI3YqYXer/kZD8Ri1aaunCxIEsOst1BVJswV0o= github.com/bugsnag/panicwrap v0.0.0-20151223152923-e2c28503fcd0/go.mod h1:D/8v3kj0zr8ZAKg1AQ6crr+5VwKN5eIywRkfhyM/+dE= -github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= -github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= diff --git a/helm/chart_loader.go b/helm/chart_loader.go index c156ac92..172f9d24 100644 --- a/helm/chart_loader.go +++ b/helm/chart_loader.go @@ -5,11 +5,12 @@ package helm import ( "context" "fmt" + "io" "net/http" "strings" "time" - "github.com/cenkalti/backoff/v4" + "github.com/sirupsen/logrus" "helm.sh/helm/v3/pkg/chart" "helm.sh/helm/v3/pkg/chart/loader" "helm.sh/helm/v3/pkg/cli" @@ -17,50 +18,70 @@ import ( "helm.sh/helm/v3/pkg/repo" "github.com/castai/cluster-controller/castai" + "github.com/castai/cluster-controller/waitext" +) + +const ( + defaultOperationRetries = 5 ) type ChartLoader interface { Load(ctx context.Context, c *castai.ChartSource) (*chart.Chart, error) } -func NewChartLoader() ChartLoader { - return &remoteChartLoader{} +func NewChartLoader(log logrus.FieldLogger) ChartLoader { + return &remoteChartLoader{log: log} } // remoteChartLoader fetches chart from remote source by given url. type remoteChartLoader struct { + log logrus.FieldLogger } func (cl *remoteChartLoader) Load(ctx context.Context, c *castai.ChartSource) (*chart.Chart, error) { var res *chart.Chart - err := backoff.Retry(func() error { - var archiveURL string - if strings.HasSuffix(c.RepoURL, ".tgz") { - archiveURL = c.RepoURL - } else { - index, err := cl.downloadHelmIndex(c.RepoURL) + + err := waitext.Retry( + ctx, + waitext.NewConstantBackoff(1*time.Second), + defaultOperationRetries, + func(ctx context.Context) (bool, error) { + var archiveURL string + if strings.HasSuffix(c.RepoURL, ".tgz") { + archiveURL = c.RepoURL + } else { + index, err := cl.downloadHelmIndex(c.RepoURL) + if err != nil { + return true, err + } + archiveURL, err = cl.chartURL(index, c.Name, c.Version) + if err != nil { + return true, err + } + } + + archiveResp, err := cl.fetchArchive(ctx, archiveURL) if err != nil { - return err + return true, err } - archiveURL, err = cl.chartURL(index, c.Name, c.Version) + defer func(Body io.ReadCloser) { + err := Body.Close() + if err != nil { + cl.log.Warnf("loading chart from archive - failed to close response body: %v", err) + } + }(archiveResp.Body) + + ch, err := loader.LoadArchive(archiveResp.Body) if err != nil { - return err + return true, fmt.Errorf("loading chart from archive: %w", err) } - } - - archiveResp, err := cl.fetchArchive(ctx, archiveURL) - if err != nil { - return err - } - defer archiveResp.Body.Close() - - ch, err := loader.LoadArchive(archiveResp.Body) - if err != nil { - return fmt.Errorf("loading chart from archive: %w", err) - } - res = ch - return nil - }, defaultBackoff(ctx)) + res = ch + return false, nil + }, + func(err error) { + cl.log.Warnf("error loading chart from archive, will retry: %v", err) + }, + ) if err != nil { return nil, err } @@ -86,10 
+107,6 @@ func (cl *remoteChartLoader) fetchArchive(ctx context.Context, archiveURL string return archiveResp, nil } -func defaultBackoff(ctx context.Context) backoff.BackOffContext { - return backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(1*time.Second), 5), ctx) -} - func (cl *remoteChartLoader) downloadHelmIndex(repoURL string) (*repo.IndexFile, error) { r, err := repo.NewChartRepository(&repo.Entry{URL: repoURL}, getter.All(&cli.EnvSettings{})) if err != nil { diff --git a/helm/chart_loader_test.go b/helm/chart_loader_test.go index 1f7a731b..b21e9b55 100644 --- a/helm/chart_loader_test.go +++ b/helm/chart_loader_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "github.com/castai/cluster-controller/castai" @@ -21,7 +22,7 @@ func TestIntegration_ChartLoader(t *testing.T) { Version: "0.4.3", } - loader := NewChartLoader() + loader := NewChartLoader(logrus.New()) c, err := loader.Load(ctx, chart) r.NoError(err) r.Equal(chart.Name, c.Name()) diff --git a/main.go b/main.go index d4a79909..70ccb5d1 100644 --- a/main.go +++ b/main.go @@ -12,7 +12,6 @@ import ( "time" "github.com/bombsimon/logrusr/v4" - "github.com/cenkalti/backoff/v4" "github.com/google/uuid" "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -34,6 +33,7 @@ import ( "github.com/castai/cluster-controller/health" "github.com/castai/cluster-controller/helm" "github.com/castai/cluster-controller/version" + "github.com/castai/cluster-controller/waitext" ) // These should be set via `go build` during a release. @@ -118,7 +118,7 @@ func run( restConfigLeader.RateLimiter = flowcontrol.NewTokenBucketRateLimiter(float32(cfg.KubeClient.QPS), cfg.KubeClient.Burst) restConfigDynamic.RateLimiter = flowcontrol.NewTokenBucketRateLimiter(float32(cfg.KubeClient.QPS), cfg.KubeClient.Burst) - helmClient := helm.NewClient(logger, helm.NewChartLoader(), restconfig) + helmClient := helm.NewClient(logger, helm.NewChartLoader(logger), restconfig) clientset, err := kubernetes.NewForConfig(restconfig) if err != nil { @@ -335,29 +335,29 @@ func retrieveKubeConfig(log logrus.FieldLogger) (*rest.Config, error) { type kubeRetryTransport struct { log logrus.FieldLogger next http.RoundTripper - maxRetries uint64 + maxRetries int retryInterval time.Duration } func (rt *kubeRetryTransport) RoundTrip(req *http.Request) (*http.Response, error) { var resp *http.Response - err := backoff.RetryNotify(func() error { + + boff := waitext.NewConstantBackoff(rt.retryInterval) + + err := waitext.Retry(context.Background(), boff, rt.maxRetries, func(_ context.Context) (bool, error) { var err error resp, err = rt.next.RoundTrip(req) if err != nil { // Previously client-go contained logic to retry connection refused errors. 
See https://github.com/kubernetes/kubernetes/pull/88267/files if net.IsConnectionRefused(err) { - return err + return true, err } - return backoff.Permanent(err) + return false, err } - return nil - }, backoff.WithMaxRetries(backoff.NewConstantBackOff(rt.retryInterval), rt.maxRetries), - func(err error, duration time.Duration) { - if err != nil { - rt.log.Warnf("kube api server connection refused, will retry: %v", err) - } - }) + return false, nil + }, func(err error) { + rt.log.Warnf("kube api server connection refused, will retry: %v", err) + }) return resp, err } diff --git a/waitext/doc.go b/waitext/doc.go new file mode 100644 index 00000000..6c75f9bb --- /dev/null +++ b/waitext/doc.go @@ -0,0 +1,2 @@ +// Package waitext implements behavior similar to https://github.com/cenkalti/backoff on top of k8s.io/apimachinery/pkg/util/wait. +package waitext diff --git a/waitext/extensions.go b/waitext/extensions.go new file mode 100644 index 00000000..861610c6 --- /dev/null +++ b/waitext/extensions.go @@ -0,0 +1,101 @@ +package waitext + +import ( + "context" + "fmt" + "math" + "time" + + "k8s.io/apimachinery/pkg/util/wait" +) + +const ( + defaultInitialInterval = 500 * time.Millisecond + defaultRandomizationFactor = 0.5 + defaultMultiplier = 1.5 + defaultMaxInterval = 60 * time.Second + + // Forever should be used to simulate infinite retries or backoff increase. + // Usually it's wise to have a context with timeout to avoid an infinite loop. + Forever = math.MaxInt32 +) + +// DefaultExponentialBackoff creates an exponential backoff with sensible default values. +// Defaults should match ExponentialBackoff in github.com/cenkalti/backoff +func DefaultExponentialBackoff() wait.Backoff { + return wait.Backoff{ + Duration: defaultInitialInterval, + Factor: defaultMultiplier, + Jitter: defaultRandomizationFactor, + Cap: defaultMaxInterval, + Steps: Forever, + } +} + +// NewConstantBackoff creates a backoff that steps at constant intervals. +// This backoff will run "forever", use WithMaxRetries or a context to put a hard cap. +// This works similar to ConstantBackOff in github.com/cenkalti/backoff +func NewConstantBackoff(interval time.Duration) wait.Backoff { + return wait.Backoff{ + Duration: interval, + Steps: Forever, + } +} + +// Retry executes an operation with retries following these semantics: +// +// - The operation is executed at least once (even if context is cancelled) +// +// - If operation returns nil error, assumption is that it succeeded +// +// - If operation returns non-nil error, then the first boolean return value decides whether to retry or not +// +// The operation will not be retried anymore if +// +// - retries reaches 0 +// +// - the context is cancelled +// +// The end result is: +// +// - nil if operation was successful at least once +// - last encountered error from operation if retries are exhausted +// - a multi-error if context is cancelled that contains - the ctx.Err(), context.Cause() and last encountered error from the operation +// +// If retryNotify is passed, it is called when making retries. 
+// Caveat: this function is similar to wait.ExponentialBackoff but has some important behavior differences like at-least-one execution and retryable errors +func Retry(ctx context.Context, backoff wait.Backoff, retries int, operation func(context.Context) (bool, error), retryNotify func(error)) error { + var lastErr error + var shouldRetry bool + + shouldRetry, lastErr = operation(ctx) + + // No retry needed + if lastErr == nil || !shouldRetry { + return lastErr + } + + for retries > 0 { + // Notify about expected retry + if retryNotify != nil { + retryNotify(lastErr) + } + + waitInterval := backoff.Step() + select { + case <-ctx.Done(): + return fmt.Errorf("context finished with err (%w); cause (%w); last encountered error from operation (%w)", ctx.Err(), context.Cause(ctx), lastErr) + case <-time.After(waitInterval): + } + + shouldRetry, lastErr = operation(ctx) + retries-- + + // We are done + if lastErr == nil || !shouldRetry { + break + } + } + + return lastErr +} diff --git a/waitext/extensions_test.go b/waitext/extensions_test.go new file mode 100644 index 00000000..24d24383 --- /dev/null +++ b/waitext/extensions_test.go @@ -0,0 +1,184 @@ +package waitext + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + "k8s.io/apimachinery/pkg/util/wait" +) + +func TestNewConstantBackoff(t *testing.T) { + r := require.New(t) + expectedSleepDuration := 10 * time.Second + backoff := NewConstantBackoff(expectedSleepDuration) + + for i := 0; i < 10; i++ { + r.Equal(expectedSleepDuration, backoff.Step()) + } +} + +func TestDefaultExponentialBackoff(t *testing.T) { + r := require.New(t) + + val := DefaultExponentialBackoff() + + r.Equal(defaultInitialInterval, val.Duration) + r.Equal(defaultMultiplier, val.Factor) + r.Equal(defaultRandomizationFactor, val.Jitter) + r.Equal(defaultMaxInterval, val.Cap) +} + +func TestRetry(t *testing.T) { + r := require.New(t) + + t.Run("Retrying logic tests", func(t *testing.T) { + t.Run("Called at least once, even if retries or steps is 0", func(t *testing.T) { + called := false + err := Retry(context.Background(), wait.Backoff{Steps: 0}, 0, func(_ context.Context) (bool, error) { + called = true + return false, nil + }, nil) + + r.NoError(err) + r.True(called) + }) + + t.Run("Respects backoff and retry count", func(t *testing.T) { + retries := 4 + expectedTotalExecutions := 1 + retries + backoff := DefaultExponentialBackoff() + backoff.Duration = 10 * time.Millisecond + backoff.Factor = 2 + backoff.Jitter = 0 + + // There is no "initial" wait so 0 index simulates zero. 
+ // The rest are calculated as interval * factor^(ix) without jitter for simplicity + expectedWaitTimes := []time.Duration{ + time.Millisecond, + 10 * time.Millisecond, + 20 * time.Millisecond, + 40 * time.Millisecond, + 80 * time.Millisecond, + } + indexWaitTimes := 0 + + actualExecutions := 0 + lastExec := time.Now() + err := Retry(context.Background(), backoff, retries, func(_ context.Context) (bool, error) { + actualExecutions++ + now := time.Now() + waitTime := now.Sub(lastExec) + lastExec = now + + t.Log("wait time", waitTime) + + // We give some tolerance as we can't be precise to the nanosecond here + r.InDelta(expectedWaitTimes[indexWaitTimes], waitTime, float64(2*time.Millisecond)) + indexWaitTimes++ + + return true, errors.New("dummy") + }, nil) + + r.Error(err) + r.Equal(expectedTotalExecutions, actualExecutions) + }) + + t.Run("Returns last encountered error", func(t *testing.T) { + timesCalled := 0 + expectedErrMessage := "boom 3" + + err := Retry(context.Background(), NewConstantBackoff(10*time.Millisecond), 2, + func(ctx context.Context) (bool, error) { + timesCalled++ + return true, fmt.Errorf("boom %d", timesCalled) + }, nil) + + r.Equal(expectedErrMessage, err.Error()) + }) + + t.Run("Does not retry if false is returned as first parameter", func(t *testing.T) { + expectedErr := errors.New("dummy") + called := false + err := Retry(context.Background(), NewConstantBackoff(10*time.Millisecond), 10, + func(ctx context.Context) (bool, error) { + r.False(called) + called = true + return false, expectedErr + }, nil) + + r.ErrorIs(err, expectedErr) + }) + }) + + t.Run("Notify callback tests", func(t *testing.T) { + t.Run("Notify is passed and called", func(t *testing.T) { + err := Retry( + context.Background(), + NewConstantBackoff(10*time.Millisecond), + 2, + func(_ context.Context) (bool, error) { + return true, errors.New("dummy") + }, + func(err error) { + r.Error(err) + }, + ) + r.Error(err) + }) + + t.Run("Notify is not passed, no panic", func(t *testing.T) { + err := Retry( + context.Background(), + NewConstantBackoff(10*time.Millisecond), + 2, + func(_ context.Context) (bool, error) { + return true, errors.New("dummy") + }, + nil, + ) + r.Error(err) + }) + }) + + t.Run("Context tests", func(t *testing.T) { + t.Run("On context cancel, stops", func(t *testing.T) { + ctx, cancel := context.WithCancelCause(context.Background()) + + innerError := errors.New("from operation") + cancelCause := errors.New("cancel cause err") + var overallReturnedErr error + + done := make(chan bool) + go func() { + overallReturnedErr = Retry(ctx, NewConstantBackoff(100*time.Millisecond), 1000, func(ctx context.Context) (bool, error) { + return true, innerError + }, nil) + done <- true + }() + + cancel(cancelCause) + <-done + r.ErrorIs(overallReturnedErr, context.Canceled, "Expected context cancelled to be propagated") + r.ErrorIs(overallReturnedErr, innerError, "Expected inner error by operation be propagated") + r.ErrorIs(overallReturnedErr, cancelCause, "Expected cancel cause error to be propagated") + }) + + t.Run("Operation is called at least once, even if context is cancelled", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + called := false + err := Retry(ctx, NewConstantBackoff(10*time.Millisecond), 1, func(ctx context.Context) (bool, error) { + called = true + return true, errors.New("dummy") + }, nil) + + r.ErrorIs(err, context.Canceled) + r.True(called) + }) + }) +}
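For reviewers, a minimal, self-contained sketch of how the new retry helper introduced in waitext/extensions.go is meant to be called. The signatures used here (waitext.Retry, waitext.NewConstantBackoff, the notify callback) are taken from this diff; the operation body, interval, and retry count are illustrative assumptions only, not code from the repository.

package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/castai/cluster-controller/waitext"
)

func main() {
	// A context with a timeout bounds the overall retry loop, as recommended
	// in the waitext docs when waitext.Forever is used.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	attempts := 0

	// Constant 2s backoff with up to 3 retries, i.e. at most 4 executions.
	err := waitext.Retry(
		ctx,
		waitext.NewConstantBackoff(2*time.Second),
		3,
		func(ctx context.Context) (bool, error) {
			attempts++
			if attempts < 3 {
				// The first return value decides whether the error is retried.
				return true, errors.New("transient failure")
			}
			return false, nil
		},
		func(err error) {
			// Optional notify callback, invoked before each retry.
			fmt.Printf("operation failed, will retry: %v\n", err)
		},
	)
	if err != nil {
		fmt.Println("gave up:", err)
		return
	}
	fmt.Println("succeeded after", attempts, "attempts")
}

Returning false as the first value stops further retries even when the error is non-nil, which is how the handlers in this diff mark permanent failures (the replacement for the old backoff.Permanent wrapping).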