Skip to content

Commit

Permalink
fix: skip processing started actions (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
anjmao authored Feb 11, 2022
1 parent 62d1753 commit 3ce89cb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 22 deletions.
42 changes: 38 additions & 4 deletions actions/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"reflect"
"sync"
"time"

"github.com/cenkalti/backoff/v4"
Expand Down Expand Up @@ -40,9 +41,10 @@ func NewService(
helmClient helm.Client,
) Service {
return &service{
log: log,
cfg: cfg,
castaiClient: castaiClient,
log: log,
cfg: cfg,
castaiClient: castaiClient,
startedActions: map[string]struct{}{},
actionHandlers: map[reflect.Type]ActionHandler{
reflect.TypeOf(&castai.ActionDeleteNode{}): newDeleteNodeHandler(log, clientset),
reflect.TypeOf(&castai.ActionDrainNode{}): newDrainNodeHandler(log, clientset),
Expand All @@ -62,6 +64,10 @@ type service struct {
castaiClient castai.Client

actionHandlers map[reflect.Type]ActionHandler

startedActionsWg sync.WaitGroup
startedActions map[string]struct{}
startedActionsMu sync.Mutex
}

func (s *service) Run(ctx context.Context) error {
Expand Down Expand Up @@ -116,7 +122,14 @@ func (s *service) pollActions(ctx context.Context) ([]*castai.ClusterAction, err

func (s *service) handleActions(ctx context.Context, actions []*castai.ClusterAction) {
for _, action := range actions {
if !s.startProcessing(action.ID) {
s.log.Debugf("action is already processing, id=%s", action.ID)
continue
}

go func(action *castai.ClusterAction) {
defer s.finishProcessing(action.ID)

var err error
handleErr := s.handleAction(ctx, action)
ackErr := s.ackAction(ctx, action, handleErr)
Expand All @@ -133,6 +146,27 @@ func (s *service) handleActions(ctx context.Context, actions []*castai.ClusterAc
}
}

// finishProcessing removes actionID from the in-flight set and releases the
// wait-group slot that startProcessing acquired for it. It must be called
// exactly once per successful startProcessing call.
func (s *service) finishProcessing(actionID string) {
	s.startedActionsMu.Lock()
	delete(s.startedActions, actionID)
	s.startedActionsWg.Done()
	s.startedActionsMu.Unlock()
}

// startProcessing registers actionID as in flight. It returns false when the
// action is already being processed, telling the caller to skip a duplicate;
// on true, the caller is responsible for a matching finishProcessing call.
// NOTE(review): startedActionsWg.Add here may run concurrently with an
// external Wait; confirm callers only Wait after polling has stopped.
func (s *service) startProcessing(actionID string) bool {
	s.startedActionsMu.Lock()
	defer s.startedActionsMu.Unlock()

	_, running := s.startedActions[actionID]
	if running {
		return false
	}

	s.startedActions[actionID] = struct{}{}
	s.startedActionsWg.Add(1)
	return true
}

func (s *service) handleAction(ctx context.Context, action *castai.ClusterAction) (err error) {
data := action.Data()
actionType := reflect.TypeOf(data)
Expand All @@ -143,7 +177,7 @@ func (s *service) handleAction(ctx context.Context, action *castai.ClusterAction
}

if err := handler.Handle(ctx, data); err != nil {
return err
return fmt.Errorf("handling action %v: %w", actionType, err)
}
return nil
}
Expand Down
34 changes: 23 additions & 11 deletions actions/actions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ func TestMain(m *testing.M) {
}

func TestActions(t *testing.T) {
r := require.New(t)

log := logrus.New()
log.SetLevel(logrus.DebugLevel)
cfg := Config{
Expand All @@ -34,9 +32,9 @@ func TestActions(t *testing.T) {
ClusterID: uuid.New().String(),
}

newTestService := func(handler ActionHandler, client castai.Client) Service {
svc := NewService(log, cfg, nil, client, nil)
handlers := svc.(*service).actionHandlers
newTestService := func(handler ActionHandler, client castai.Client) *service {
svc := NewService(log, cfg, nil, client, nil).(*service)
handlers := svc.actionHandlers
// Patch handlers with a mock one.
for k := range handlers {
handlers[k] = handler
Expand All @@ -45,6 +43,8 @@ func TestActions(t *testing.T) {
}

t.Run("poll, handle and ack", func(t *testing.T) {
r := require.New(t)

apiActions := []*castai.ClusterAction{
{
ID: "a1",
Expand All @@ -69,11 +69,13 @@ func TestActions(t *testing.T) {
},
}
client := mock.NewMockAPIClient(apiActions)
handler := &mockAgentActionHandler{}
handler := &mockAgentActionHandler{handleDelay: 2 * time.Millisecond}
svc := newTestService(handler, client)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer func() {
cancel()
svc.startedActionsWg.Wait()

r.Len(client.Acks, 3)
ids := make([]string, len(client.Acks))
for i, ack := range client.Acks {
Expand All @@ -84,23 +86,29 @@ func TestActions(t *testing.T) {
r.Equal("a2", ids[1])
r.Equal("a3", ids[2])
}()
svc.Run(ctx)
r.NoError(svc.Run(ctx))
})

t.Run("continue polling on api error", func(t *testing.T) {
r := require.New(t)

client := mock.NewMockAPIClient([]*castai.ClusterAction{})
client.GetActionsErr = errors.New("ups")
handler := &mockAgentActionHandler{err: errors.New("ups")}
svc := newTestService(handler, client)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer func() {
cancel()
svc.startedActionsWg.Wait()

r.Len(client.Acks, 0)
}()
svc.Run(ctx)
r.NoError(svc.Run(ctx))
})

t.Run("ack with error when action handler failed", func(t *testing.T) {
r := require.New(t)

apiActions := []*castai.ClusterAction{
{
ID: "a1",
Expand All @@ -116,19 +124,23 @@ func TestActions(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer func() {
cancel()
svc.startedActionsWg.Wait()

r.Empty(client.Actions)
r.Len(client.Acks, 1)
r.Equal("a1", client.Acks[0].ActionID)
r.Equal("ups", *client.Acks[0].Err)
r.Equal("handling action *castai.ActionPatchNode: ups", *client.Acks[0].Err)
}()
svc.Run(ctx)
r.NoError(svc.Run(ctx))
})
}

type mockAgentActionHandler struct {
err error
err error
handleDelay time.Duration
}

func (m *mockAgentActionHandler) Handle(ctx context.Context, data interface{}) error {
time.Sleep(m.handleDelay)
return m.err
}
10 changes: 3 additions & 7 deletions actions/delete_node_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,13 @@ func (h *deleteNodeHandler) Handle(ctx context.Context, data interface{}) error
log := h.log.WithField("node_name", req.NodeName)
log.Info("deleting kubernetes node")

node, err := h.clientset.CoreV1().Nodes().Get(ctx, req.NodeName, metav1.GetOptions{})
if err != nil {
b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(h.cfg.deleteRetryWait), h.cfg.deleteRetries), ctx)
return backoff.Retry(func() error {
err := h.clientset.CoreV1().Nodes().Delete(ctx, req.NodeName, metav1.DeleteOptions{})
if apierrors.IsNotFound(err) {
log.Info("node not found, skipping delete")
return nil
}
return err
}

b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(h.cfg.deleteRetryWait), h.cfg.deleteRetries), ctx)
return backoff.Retry(func() error {
return h.clientset.CoreV1().Nodes().Delete(ctx, node.Name, metav1.DeleteOptions{})
}, b)
}

0 comments on commit 3ce89cb

Please sign in to comment.