From 59d5c388f45602f8564803ff115210abdc842c19 Mon Sep 17 00:00:00 2001 From: Ljubisa Gacevic Date: Wed, 8 Jan 2025 19:29:44 +0100 Subject: [PATCH] feat: add graceful shutdown --- cmd/beekeeper/cmd/stamper.go | 6 +++++- go.mod | 1 + go.sum | 2 ++ pkg/scheduler/scheduler.go | 23 +++++++++++++++++------ pkg/stamper/stamper.go | 27 ++++++++++++++++++--------- 5 files changed, 43 insertions(+), 16 deletions(-) diff --git a/cmd/beekeeper/cmd/stamper.go b/cmd/beekeeper/cmd/stamper.go index ee183e03..47ee3852 100644 --- a/cmd/beekeeper/cmd/stamper.go +++ b/cmd/beekeeper/cmd/stamper.go @@ -109,7 +109,11 @@ func (c *command) initStamperDilute() *cobra.Command { diluteExecutor.Start(ctx, func(ctx context.Context) error { return c.stamper.Dilute(ctx, c.globalConfig.GetFloat64(optionUsageThreshold), c.globalConfig.GetUint16(optionDiutionDepth)) }) - defer diluteExecutor.Stop() + defer func() { + if err := diluteExecutor.Close(); err != nil { + c.log.Errorf("failed to close dilution periodic executor: %v", err) + } + }() <-ctx.Done() diff --git a/go.mod b/go.mod index d17c3c89..57b8b925 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( k8s.io/api v0.30.3 k8s.io/apimachinery v0.30.3 k8s.io/client-go v0.30.3 + resenje.org/x v0.6.0 ) require ( diff --git a/go.sum b/go.sum index 72b06b42..b7790119 100644 --- a/go.sum +++ b/go.sum @@ -521,6 +521,8 @@ k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 h1:pUdcCO1Lk/tbT5ztQWOBi5HBgbBP1 k8s.io/utils v0.0.0-20240711033017-18e509b52bc8/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= lukechampine.com/blake3 v1.2.1 h1:YuqqRuaqsGV71BV/nm9xlI0MKUv4QC54jQnBChWbGnI= lukechampine.com/blake3 v1.2.1/go.mod h1:0OFRp7fBtAylGVCO40o87sbupkyIGgbpv1+M1k1LM6k= +resenje.org/x v0.6.0 h1:afn9E4XhglF4y9Kq0VH5tdSyjnsVKxiYgB6HFj7ebss= +resenje.org/x v0.6.0/go.mod h1:qgwe4MCzh57EkkMDurg24ug7HHfZtAjtBkmCihNmOpM= rsc.io/tmplfunc v0.0.3 h1:53XFQh69AfOa8Tw0Jm7t+GV7KZhOi6jzsCzTtKbMvzU= rsc.io/tmplfunc v0.0.3/go.mod h1:AG3sTPzElb1Io3Yg4voV9AGZJuleGAwaVRxL9M49PhA= sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index f1a086e7..15ae12f8 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -5,13 +5,14 @@ import ( "time" "github.com/ethersphere/beekeeper/pkg/logging" + "resenje.org/x/shutdown" ) type PeriodicExecutor struct { ticker *time.Ticker interval time.Duration log logging.Logger - stopChan chan struct{} + shutdown *shutdown.Graceful } func NewPeriodicExecutor(interval time.Duration, log logging.Logger) *PeriodicExecutor { @@ -19,20 +20,28 @@ func NewPeriodicExecutor(interval time.Duration, log logging.Logger) *PeriodicEx ticker: time.NewTicker(interval), interval: interval, log: log, - stopChan: make(chan struct{}), + shutdown: shutdown.NewGraceful(), } } func (pe *PeriodicExecutor) Start(ctx context.Context, task func(ctx context.Context) error) { + pe.shutdown.Add(1) go func() { + defer pe.shutdown.Done() + ctx = pe.shutdown.Context(ctx) + + if err := task(ctx); err != nil { + pe.log.Errorf("Task execution failed: %v", err) + } + for { select { case <-pe.ticker.C: - pe.log.Tracef("Executing task") + pe.log.Tracef("Executing task after %s interval", pe.interval) if err := task(ctx); err != nil { pe.log.Errorf("Task execution failed: %v", err) } - case <-pe.stopChan: + case <-pe.shutdown.Quit(): return case <-ctx.Done(): return @@ -41,7 +50,9 @@ func (pe *PeriodicExecutor) Start(ctx context.Context, task func(ctx context.Con }() } -func (pe *PeriodicExecutor) Stop() { +func (pe *PeriodicExecutor) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() pe.ticker.Stop() - close(pe.stopChan) + return pe.shutdown.Shutdown(ctx) } diff --git a/pkg/stamper/stamper.go b/pkg/stamper/stamper.go index 7c7bd37a..d5b109dc 100644 --- a/pkg/stamper/stamper.go +++ b/pkg/stamper/stamper.go @@ -30,8 +30,12 @@ type ClientConfig struct { } type StamperClient struct { - *ClientConfig - httpClient http.Client + log logging.Logger + namespace string + k8sClient *k8s.Client + labelSelector string + inCluster bool + httpClient http.Client } func NewStamperClient(cfg *ClientConfig) *StamperClient { @@ -50,8 +54,12 @@ func NewStamperClient(cfg *ClientConfig) *StamperClient { } return &StamperClient{ - httpClient: *httpClient, - ClientConfig: cfg, + httpClient: *httpClient, + log: cfg.Log, + namespace: cfg.Namespace, + k8sClient: cfg.K8sClient, + labelSelector: cfg.LabelSelector, + inCluster: cfg.InCluster, } } @@ -62,6 +70,7 @@ func (s *StamperClient) Create(ctx context.Context, amount uint64, depth uint8) // Dilute implements Client. func (s *StamperClient) Dilute(ctx context.Context, usageThreshold float64, dilutionDepth uint16) error { + s.log.WithFields(map[string]interface{}{"usageThreshold": usageThreshold, "dilutionDepth": dilutionDepth}).Infof("diluting namespace %s", s.namespace) nodes, err := s.getNamespaceNodes(ctx) if err != nil { return fmt.Errorf("get namespace nodes: %w", err) @@ -97,11 +106,11 @@ func (s *StamperClient) Topup(ctx context.Context, ttlThreshold time.Duration, t } func (sc *StamperClient) getNamespaceNodes(ctx context.Context) (nodes []Node, err error) { - if sc.Namespace == "" { + if sc.namespace == "" { return nil, fmt.Errorf("namespace not provided") } - if sc.InCluster { + if sc.inCluster { nodes, err = sc.getServiceNodes(ctx) } else { nodes, err = sc.getIngressNodes(ctx) @@ -115,7 +124,7 @@ func (sc *StamperClient) getNamespaceNodes(ctx context.Context) (nodes []Node, e } func (sc *StamperClient) getServiceNodes(ctx context.Context) ([]Node, error) { - svcNodes, err := sc.K8sClient.Service.GetNodes(ctx, sc.Namespace, sc.LabelSelector) + svcNodes, err := sc.k8sClient.Service.GetNodes(ctx, sc.namespace, sc.labelSelector) if err != nil { return nil, fmt.Errorf("list api services: %w", err) } @@ -138,12 +147,12 @@ func (sc *StamperClient) getServiceNodes(ctx context.Context) ([]Node, error) { } func (sc *StamperClient) getIngressNodes(ctx context.Context) ([]Node, error) { - ingressNodes, err := sc.K8sClient.Ingress.GetNodes(ctx, sc.Namespace, sc.LabelSelector) + ingressNodes, err := sc.k8sClient.Ingress.GetNodes(ctx, sc.namespace, sc.labelSelector) if err != nil { return nil, fmt.Errorf("list ingress api nodes hosts: %w", err) } - ingressRouteNodes, err := sc.K8sClient.IngressRoute.GetNodes(ctx, sc.Namespace, sc.LabelSelector) + ingressRouteNodes, err := sc.k8sClient.IngressRoute.GetNodes(ctx, sc.namespace, sc.labelSelector) if err != nil { return nil, fmt.Errorf("list ingress route api nodes hosts: %w", err) }