Skip to content

Commit

Permalink
Merge pull request #186 from mfreeman451/fix/node_recovery_bug
Browse files Browse the repository at this point in the history
Fix/node recovery bug
  • Loading branch information
mfreeman451 authored Feb 6, 2025
2 parents 4346581 + d72cce5 commit 5175cc3
Show file tree
Hide file tree
Showing 8 changed files with 769 additions and 330 deletions.
119 changes: 119 additions & 0 deletions pkg/cloud/node_recovery_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
package cloud

import (
"context"
"testing"
"time"

"github.com/mfreeman451/serviceradar/pkg/cloud/alerts"
"github.com/mfreeman451/serviceradar/pkg/db"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
)

// TestNodeRecoveryManager_ProcessRecovery drives processRecovery through a
// table of scenarios using mocked database and alert services:
//
//   - successful_recovery: the node is stored as unhealthy, so the status is
//     updated, an alert is sent and the transaction is committed.
//   - already_healthy: the node is already healthy, so nothing is updated and
//     no alert is sent.
//   - db_error: the status lookup fails and the error is returned wrapped
//     with the "get node status" prefix.
func TestNodeRecoveryManager_ProcessRecovery(t *testing.T) {
	tests := []struct {
		name          string
		nodeID        string
		currentStatus *db.NodeStatus // status returned by the mocked GetNodeStatus
		dbError       error          // error returned by the mocked GetNodeStatus
		expectAlert   bool
		expectedError string // substring expected in the returned error; empty means success
	}{
		{
			name:   "successful_recovery",
			nodeID: "test-node",
			currentStatus: &db.NodeStatus{
				NodeID:    "test-node",
				IsHealthy: false,
				LastSeen:  time.Now().Add(-time.Hour),
			},
			expectAlert: true,
		},
		{
			name:   "already_healthy",
			nodeID: "test-node",
			currentStatus: &db.NodeStatus{
				NodeID:    "test-node",
				IsHealthy: true,
				LastSeen:  time.Now(),
			},
			expectAlert: false,
		},
		{
			name:          "db_error",
			nodeID:        "test-node",
			dbError:       db.ErrDatabaseError,
			expectedError: "get node status",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ctrl := gomock.NewController(t)
			defer ctrl.Finish()

			mockDB := db.NewMockService(ctrl)
			mockAlerter := alerts.NewMockAlertService(ctrl)
			mockTx := db.NewMockTransaction(ctrl)

			// Mock Begin() call
			mockDB.EXPECT().Begin().Return(mockTx, nil)

			// Mock Rollback() as it's in a defer
			mockTx.EXPECT().Rollback().Return(nil).AnyTimes()

			// Mock GetNodeStatus
			mockDB.EXPECT().GetNodeStatus(tt.nodeID).Return(tt.currentStatus, tt.dbError)

			// Only an unhealthy node proceeds to the update / alert / commit steps.
			if tt.currentStatus != nil && !tt.currentStatus.IsHealthy {
				mockDB.EXPECT().UpdateNodeStatus(gomock.Any()).Return(nil)

				if tt.expectAlert {
					mockAlerter.EXPECT().Alert(gomock.Any(), gomock.Any()).Return(nil)
					// Mock the successful commit
					mockTx.EXPECT().Commit().Return(nil)
				}
			}

			mgr := &NodeRecoveryManager{
				db:          mockDB,
				alerter:     mockAlerter,
				getHostname: func() string { return "test-host" },
			}

			err := mgr.processRecovery(context.Background(), tt.nodeID, time.Now())

			if tt.expectedError != "" {
				// Check for a non-nil error first: calling err.Error() on a nil
				// error would panic instead of reporting a test failure.
				if assert.Error(t, err) {
					assert.Contains(t, err.Error(), tt.expectedError)
				}

				return
			}

			assert.NoError(t, err)
		})
	}
}

// TestNodeRecoveryManager_SendRecoveryAlert checks the alert that
// sendRecoveryAlert hands to the alert service: it must be an Info-level
// "Node Recovered" alert carrying the node ID and the hostname reported by the
// manager's getHostname hook.
func TestNodeRecoveryManager_SendRecoveryAlert(t *testing.T) {
	const (
		nodeID   = "test-node"
		hostname = "test-host"
	)

	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	alerter := alerts.NewMockAlertService(ctrl)

	// Inspect the alert at the moment it reaches the alert service.
	verifyAlert := func(_ context.Context, got *alerts.WebhookAlert) error {
		assert.Equal(t, alerts.Info, got.Level)
		assert.Equal(t, "Node Recovered", got.Title)
		assert.Equal(t, nodeID, got.NodeID)
		assert.Equal(t, hostname, got.Details["hostname"])

		return nil
	}
	alerter.EXPECT().Alert(gomock.Any(), gomock.Any()).DoAndReturn(verifyAlert)

	mgr := &NodeRecoveryManager{
		alerter:     alerter,
		getHostname: func() string { return hostname },
	}

	assert.NoError(t, mgr.sendRecoveryAlert(context.Background(), nodeID, time.Now()))
}
167 changes: 87 additions & 80 deletions pkg/cloud/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ const (
nodeHistoryLimit = 1000
nodeDiscoveryTimeout = 30 * time.Second
nodeNeverReportedTimeout = 30 * time.Second
pollerTimeout = 30 * time.Second
defaultDBPath = "/var/lib/serviceradar/serviceradar.db"
statusUnknown = "unknown"
sweepService = "sweep"
Expand Down Expand Up @@ -418,19 +417,13 @@ func (s *Server) checkInitialStates(ctx context.Context) {
// Add ORDER BY clause
query += "ORDER BY last_seen DESC"

//nolint:rowserrcheck // rows.Close() is deferred
rows, err := s.db.Query(query, args...)
if err != nil {
log.Printf("Error querying nodes: %v", err)

return
}
defer func(rows *sql.Rows) {
err := rows.Close()
if err != nil {
log.Printf("Error closing rows: %v", err)
}
}(rows)
defer db.CloseRows(rows)

for rows.Next() {
var nodeID string
Expand Down Expand Up @@ -678,23 +671,21 @@ func (s *Server) updateNodeDownStatus(nodeID string, lastSeen time.Time) error {
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}

// Create a flag to track if we need to rollback
needsRollback := true
defer func() {
if needsRollback {
if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
log.Printf("Error rolling back transaction: %v", rbErr)
}
defer func(tx db.Transaction) {
err = tx.Rollback()
if err != nil {
log.Printf("Error rolling back transaction: %v", err)
}
}()
}(tx)

if err := s.performNodeUpdate(tx, nodeID, lastSeen); err != nil {
return err
sqlTx, err := db.ToTx(tx)
if err != nil {
return fmt.Errorf("invalid transaction: %w", err)
}

// Mark that we don't need rollback before committing
needsRollback = false
if err := s.performNodeUpdate(sqlTx, nodeID, lastSeen); err != nil {
return err
}

return tx.Commit()
}
Expand Down Expand Up @@ -798,12 +789,7 @@ func (s *Server) checkNeverReportedNodes(ctx context.Context) error {
if err != nil {
return fmt.Errorf("error querying unreported nodes: %w", err)
}
defer func(rows *sql.Rows) {
err := rows.Close()
if err != nil {
log.Printf("Error closing rows: %v", err)
}
}(rows)
defer db.CloseRows(rows)

var unreportedNodes []string

Expand Down Expand Up @@ -856,7 +842,6 @@ func (s *Server) checkNeverReportedPollers(ctx context.Context) {

var unreportedNodes []string

//nolint:rowserrcheck // rows.Close() is deferred
rows, err := s.db.Query(`
SELECT node_id
FROM nodes
Expand All @@ -866,13 +851,7 @@ func (s *Server) checkNeverReportedPollers(ctx context.Context) {
log.Printf("Error querying unreported nodes: %v", err)
return
}

defer func(rows *sql.Rows) {
err := rows.Close()
if err != nil {
log.Printf("Error closing rows: %v", err)
}
}(rows)
defer db.CloseRows(rows)

for rows.Next() {
var nodeID string
Expand Down Expand Up @@ -1031,12 +1010,7 @@ func (s *Server) checkNodeStates(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to query nodes: %w", err)
}
defer func(rows *sql.Rows) {
err := rows.Close()
if err != nil {
log.Printf("Error closing rows: %v", err)
}
}(rows)
defer db.CloseRows(rows)

threshold := time.Now().Add(-s.alertThreshold)

Expand Down Expand Up @@ -1086,56 +1060,89 @@ func (s *Server) evaluateNodeHealth(ctx context.Context, nodeID string, lastSeen
return nil
}

func (s *Server) getNodeStatus(nodeID string) (*db.NodeStatus, error) {
var status db.NodeStatus
// NodeRecoveryManager handles node recovery state transitions.
// Its dependencies are held as fields so they can be replaced individually
// (the tests in this package construct it with mocks and a fixed hostname).
type NodeRecoveryManager struct {
// db supplies the transaction and the node-status read/update calls.
db db.Service
// alerter receives the "Node Recovered" alert.
alerter alerts.AlertService
// getHostname returns the value stored under "hostname" in the alert details.
getHostname func() string
}

err := s.db.QueryRow(`
SELECT node_id, is_healthy, first_seen, last_seen
FROM nodes
WHERE node_id = ?`, nodeID).Scan(&status.NodeID, &status.IsHealthy, &status.FirstSeen, &status.LastSeen)
if err != nil {
return nil, fmt.Errorf("failed to get node status: %w", err)
func newNodeRecoveryManager(d db.Service, alerter alerts.AlertService) *NodeRecoveryManager {
return &NodeRecoveryManager{
db: d,
alerter: alerter,
getHostname: func() string {
hostname, err := os.Hostname()
if err != nil {
return statusUnknown
}
return hostname
},
}

return &status, nil
}

// handlePotentialRecovery simplified to coordinate the recovery process.
func (s *Server) handlePotentialRecovery(ctx context.Context, nodeID string, lastSeen time.Time) error {
// First check if the node is actually healthy via status
currentStatus, err := s.getNodeStatus(nodeID)
mgr := newNodeRecoveryManager(s.db, s.webhooks[0])
return mgr.processRecovery(ctx, nodeID, lastSeen)
}

// processRecovery handles the recovery state transition.
func (m *NodeRecoveryManager) processRecovery(ctx context.Context, nodeID string, lastSeen time.Time) error {
tx, err := m.db.Begin()
if err != nil {
return fmt.Errorf("begin transaction: %w", err)
}
defer func(tx db.Transaction) {
err = tx.Rollback()
if err != nil {
log.Printf("Error rolling back transaction: %v", err)
}
}(tx)

status, err := m.db.GetNodeStatus(nodeID)
if err != nil {
return fmt.Errorf("failed to get node status: %w", err)
return fmt.Errorf("get node status: %w", err)
}

// If node is already marked healthy, no need to do anything
if currentStatus.IsHealthy {
return nil
if status.IsHealthy {
return nil // Node is already healthy
}

// Update node status to healthy
if err := s.updateNodeStatus(nodeID, true, lastSeen); err != nil {
return fmt.Errorf("failed to update node status: %w", err)
// Update node status
status.IsHealthy = true
status.LastSeen = lastSeen

if err := m.db.UpdateNodeStatus(status); err != nil {
return fmt.Errorf("update node status: %w", err)
}

// Create recovery alert
if err := m.sendRecoveryAlert(ctx, nodeID, lastSeen); err != nil {
return fmt.Errorf("send recovery alert: %w", err)
}

if err := tx.Commit(); err != nil {
return fmt.Errorf("commit transaction: %w", err)
}

return nil
}

// sendRecoveryAlert handles alert creation and sending.
func (m *NodeRecoveryManager) sendRecoveryAlert(ctx context.Context, nodeID string, lastSeen time.Time) error {
alert := &alerts.WebhookAlert{
Level: alerts.Info,
Title: "Node Recovered",
Message: fmt.Sprintf("Node '%s' is back online", nodeID),
NodeID: nodeID,
Timestamp: lastSeen.UTC().Format(time.RFC3339),
Details: map[string]any{
"hostname": getHostname(),
"hostname": m.getHostname(),
"recovery_time": lastSeen.Format(time.RFC3339),
},
}

// Send the alert
if err := s.sendAlert(ctx, alert); err != nil {
log.Printf("Failed to send recovery alert: %v", err)
}

return nil
return m.alerter.Alert(ctx, alert)
}

func (s *Server) handleNodeDown(ctx context.Context, nodeID string, lastSeen time.Time) error {
Expand Down Expand Up @@ -1176,18 +1183,20 @@ func (s *Server) updateNodeStatus(nodeID string, isHealthy bool, timestamp time.
if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err)
}

needsRollback := true
defer func() {
if needsRollback {
if rbErr := tx.Rollback(); rbErr != nil && !errors.Is(rbErr, sql.ErrTxDone) {
log.Printf("Error rolling back transaction: %v", rbErr)
}
defer func(tx db.Transaction) {
err = tx.Rollback()
if err != nil {
log.Printf("Error rolling back transaction: %v", err)
}
}()
}(tx)

sqlTx, err := db.ToTx(tx)
if err != nil {
return fmt.Errorf("invalid transaction: %w", err)
}

// Update node status
if err := s.updateNodeInTx(tx, nodeID, isHealthy, timestamp); err != nil {
if err := s.updateNodeInTx(sqlTx, nodeID, isHealthy, timestamp); err != nil {
return err
}

Expand All @@ -1199,8 +1208,6 @@ func (s *Server) updateNodeStatus(nodeID string, isHealthy bool, timestamp time.
return fmt.Errorf("failed to insert history: %w", err)
}

needsRollback = false

return tx.Commit()
}

Expand Down Expand Up @@ -1363,7 +1370,7 @@ func (s *Server) getLastDowntime(nodeID string) time.Time {
func getHostname() string {
hostname, err := os.Hostname()
if err != nil {
return "unknown"
return statusUnknown
}

return hostname
Expand Down
Loading

0 comments on commit 5175cc3

Please sign in to comment.