From 1d66f6a111793e66877d4ce8427c2e2e6b9e29df Mon Sep 17 00:00:00 2001 From: Ben Johnson Date: Tue, 21 Jan 2014 10:00:59 -0700 Subject: [PATCH] Snapshot test coverage. --- Makefile | 11 ++-- command.go | 1 - server.go | 121 ++++++++++++++-------------------- server_test.go | 6 -- snapshot.go | 47 +++++++------ snapshot_recovery_request.go | 6 -- snapshot_recovery_response.go | 13 +--- snapshot_request.go | 6 -- snapshot_response.go | 6 -- snapshot_test.go | 82 +++++++++++++++++++++++ statemachine.go | 6 -- statemachine_test.go | 19 ++++++ 12 files changed, 184 insertions(+), 140 deletions(-) create mode 100644 snapshot_test.go create mode 100644 statemachine_test.go diff --git a/Makefile b/Makefile index 95d6d57..ddb4e5c 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,11 @@ -all: test +COVERPROFILE=/tmp/c.out -coverage: - gocov test github.com/goraft/raft | gocov-html > coverage.html - open coverage.html +default: test + +cover: + go test -coverprofile=$(COVERPROFILE) . + go tool cover -html=$(COVERPROFILE) + rm $(COVERPROFILE) dependencies: go get -d . diff --git a/command.go b/command.go index 14341fa..5a92d6d 100644 --- a/command.go +++ b/command.go @@ -69,7 +69,6 @@ func RegisterCommand(command Command) { panic(fmt.Sprintf("raft: Cannot register nil")) } else if commandTypes[command.CommandName()] != nil { panic(fmt.Sprintf("raft: Duplicate registration: %s", command.CommandName())) - return } commandTypes[command.CommandName()] = command } diff --git a/server.go b/server.go index bb3efd6..8d787ed 100644 --- a/server.go +++ b/server.go @@ -1102,55 +1102,46 @@ func (s *server) RemovePeer(name string) error { //-------------------------------------- func (s *server) TakeSnapshot() error { - //TODO put a snapshot mutex + // TODO: put a snapshot mutex s.debugln("take Snapshot") + + // Exit if the server is currently creating a snapshot. if s.currentSnapshot != nil { return errors.New("handling snapshot") } + // Exit if there are no logs yet in the system. lastIndex, lastTerm := s.log.commitInfo() - + path := s.SnapshotPath(lastIndex, lastTerm) if lastIndex == 0 { return errors.New("No logs") } - path := s.SnapshotPath(lastIndex, lastTerm) - var state []byte var err error - if s.stateMachine != nil { state, err = s.stateMachine.Save() - if err != nil { return err } - } else { state = []byte{0} } - peers := make([]*Peer, len(s.peers)+1) - - i := 0 + // Clone the list of peers. + peers := make([]*Peer, 0, len(s.peers)+1) for _, peer := range s.peers { - peers[i] = peer.clone() - i++ - } - - peers[i] = &Peer{ - Name: s.Name(), - ConnectionString: s.connectionString, + peers = append(peers, peer.clone()) } + peers = append(peers, &Peer{Name: s.Name(), ConnectionString: s.connectionString}) + // Attach current snapshot and save it to disk. s.currentSnapshot = &Snapshot{lastIndex, lastTerm, peers, state, path} - s.saveSnapshot() - // We keep some log entries after the snapshot - // We do not want to send the whole snapshot - // to the slightly slow machines - if lastIndex-s.log.startIndex > NumberOfLogEntriesAfterSnapshot { + // We keep some log entries after the snapshot. + // We do not want to send the whole snapshot to the slightly slow machines + if lastIndex - s.log.startIndex > NumberOfLogEntriesAfterSnapshot { compactIndex := lastIndex - NumberOfLogEntriesAfterSnapshot compactTerm := s.log.getEntry(compactIndex).Term s.log.compact(compactIndex, compactTerm) @@ -1161,25 +1152,25 @@ func (s *server) TakeSnapshot() error { // Retrieves the log path for the server. func (s *server) saveSnapshot() error { - if s.currentSnapshot == nil { return errors.New("no snapshot to save") } - err := s.currentSnapshot.save() - - if err != nil { + // Write snapshot to disk. + if err := s.currentSnapshot.save(); err != nil { return err } + // Swap the current and last snapshots. tmp := s.lastSnapshot s.lastSnapshot = s.currentSnapshot - // delete the previous snapshot if there is any change + // Delete the previous snapshot if there is any change if tmp != nil && !(tmp.LastIndex == s.lastSnapshot.LastIndex && tmp.LastTerm == s.lastSnapshot.LastTerm) { tmp.remove() } s.currentSnapshot = nil + return nil } @@ -1195,18 +1186,15 @@ func (s *server) RequestSnapshot(req *SnapshotRequest) *SnapshotResponse { } func (s *server) processSnapshotRequest(req *SnapshotRequest) *SnapshotResponse { - // If the follower’s log contains an entry at the snapshot’s last index with a term - // that matches the snapshot’s last term - // Then the follower already has all the information found in the snapshot - // and can reply false - + // that matches the snapshot’s last term, then the follower already has all the + // information found in the snapshot and can reply false. entry := s.log.getEntry(req.LastIndex) - if entry != nil && entry.Term == req.LastTerm { return newSnapshotResponse(false) } + // Update state. s.setState(Snapshotting) return newSnapshotResponse(true) @@ -1219,29 +1207,26 @@ func (s *server) SnapshotRecoveryRequest(req *SnapshotRecoveryRequest) *Snapshot } func (s *server) processSnapshotRecoveryRequest(req *SnapshotRecoveryRequest) *SnapshotRecoveryResponse { + // Recover state sent from request. + if err := s.stateMachine.Recovery(req.State); err != nil { + return newSnapshotRecoveryResponse(req.LastTerm, false, req.LastIndex) + } - s.stateMachine.Recovery(req.State) - - // clear the peer map + // Recover the cluster configuration. s.peers = make(map[string]*Peer) - - // recovery the cluster configuration for _, peer := range req.Peers { s.AddPeer(peer.Name, peer.ConnectionString) } - //update term and index + // Update log state. s.currentTerm = req.LastTerm - s.log.updateCommitIndex(req.LastIndex) - snapshotPath := s.SnapshotPath(req.LastIndex, req.LastTerm) - - s.currentSnapshot = &Snapshot{req.LastIndex, req.LastTerm, req.Peers, req.State, snapshotPath} - + // Create local snapshot. + s.currentSnapshot = &Snapshot{req.LastIndex, req.LastTerm, req.Peers, req.State, s.SnapshotPath(req.LastIndex, req.LastTerm)} s.saveSnapshot() - // clear the previous log entries + // Clear the previous log entries. s.log.compact(req.LastIndex, req.LastTerm) return newSnapshotRecoveryResponse(req.LastTerm, true, req.LastIndex) @@ -1250,79 +1235,75 @@ func (s *server) processSnapshotRecoveryRequest(req *SnapshotRecoveryRequest) *S // Load a snapshot at restart func (s *server) LoadSnapshot() error { + // Open snapshot/ directory. dir, err := os.OpenFile(path.Join(s.path, "snapshot"), os.O_RDONLY, 0) if err != nil { - return err } + // Retrieve a list of all snapshots. filenames, err := dir.Readdirnames(-1) - if err != nil { dir.Close() panic(err) } - dir.Close() + if len(filenames) == 0 { return errors.New("no snapshot") } - // not sure how many snapshot we should keep + // Grab the latest snapshot. sort.Strings(filenames) snapshotPath := path.Join(s.path, "snapshot", filenames[len(filenames)-1]) - // should not fail + // Read snapshot data. file, err := os.OpenFile(snapshotPath, os.O_RDONLY, 0) - defer file.Close() if err != nil { - panic(err) + return err } + defer file.Close() - // TODO check checksum first - - var snapshotBytes []byte + // Check checksum. var checksum uint32 - n, err := fmt.Fscanf(file, "%08x\n", &checksum) - if err != nil { return err - } - - if n != 1 { + } else if n != 1 { return errors.New("Bad snapshot file") } - snapshotBytes, _ = ioutil.ReadAll(file) - s.debugln(string(snapshotBytes)) + // Load remaining snapshot contents. + b, err := ioutil.ReadAll(file) + if err != nil { + return err + } // Generate checksum. - byteChecksum := crc32.ChecksumIEEE(snapshotBytes) - + byteChecksum := crc32.ChecksumIEEE(b) if uint32(checksum) != byteChecksum { s.debugln(checksum, " ", byteChecksum) return errors.New("bad snapshot file") } - err = json.Unmarshal(snapshotBytes, &s.lastSnapshot) - - if err != nil { + // Decode snapshot. + if err = json.Unmarshal(b, &s.lastSnapshot); err != nil { s.debugln("unmarshal error: ", err) return err } - err = s.stateMachine.Recovery(s.lastSnapshot.State) - - if err != nil { + // Recover snapshot into state machine. + if err = s.stateMachine.Recovery(s.lastSnapshot.State); err != nil { s.debugln("recovery error: ", err) return err } + // Recover cluster configuration. for _, peer := range s.lastSnapshot.Peers { s.AddPeer(peer.Name, peer.ConnectionString) } + // Update log state. s.log.startTerm = s.lastSnapshot.LastTerm s.log.startIndex = s.lastSnapshot.LastIndex s.log.updateCommitIndex(s.lastSnapshot.LastIndex) diff --git a/server_test.go b/server_test.go index f8be91e..c350512 100644 --- a/server_test.go +++ b/server_test.go @@ -10,12 +10,6 @@ import ( "time" ) -//------------------------------------------------------------------------------ -// -// Tests -// -//------------------------------------------------------------------------------ - //-------------------------------------- // Request Vote //-------------------------------------- diff --git a/snapshot.go b/snapshot.go index 4f416f7..083a003 100644 --- a/snapshot.go +++ b/snapshot.go @@ -1,64 +1,61 @@ package raft import ( - //"bytes" "encoding/json" "fmt" "hash/crc32" "os" ) -//------------------------------------------------------------------------------ -// -// Typedefs -// -//------------------------------------------------------------------------------ - -// the in memory SnapShot struct -// TODO add cluster configuration +// Snapshot represents an in-memory representation of the current state of the system. type Snapshot struct { LastIndex uint64 `json:"lastIndex"` LastTerm uint64 `json:"lastTerm"` - // cluster configuration. + + // Cluster configuration. Peers []*Peer `json:"peers"` State []byte `json:"state"` Path string `json:"path"` } -// Save the snapshot to a file +// save writes the snapshot to file. func (ss *Snapshot) save() error { - // Write machine state to temporary buffer. - - // open file + // Open the file for writing. file, err := os.OpenFile(ss.Path, os.O_CREATE|os.O_WRONLY, 0600) - if err != nil { return err } - defer file.Close() + // Serialize to JSON. b, err := json.Marshal(ss) + if err != nil { + return err + } - // Generate checksum. + // Generate checksum and write it to disk. checksum := crc32.ChecksumIEEE(b) - - // Write snapshot with checksum. if _, err = fmt.Fprintf(file, "%08x\n", checksum); err != nil { return err } + // Write the snapshot to disk. if _, err = file.Write(b); err != nil { return err } - // force the change writting to disk - file.Sync() - return err + // Ensure that the snapshot has been flushed to disk before continuing. + if err := file.Sync(); err != nil { + return err + } + + return nil } -// remove the file of the snapshot +// remove deletes the snapshot file. func (ss *Snapshot) remove() error { - err := os.Remove(ss.Path) - return err + if err := os.Remove(ss.Path); err != nil { + return err + } + return nil } diff --git a/snapshot_recovery_request.go b/snapshot_recovery_request.go index 2c96957..406ea34 100644 --- a/snapshot_recovery_request.go +++ b/snapshot_recovery_request.go @@ -16,12 +16,6 @@ type SnapshotRecoveryRequest struct { State []byte } -//------------------------------------------------------------------------------ -// -// Constructors -// -//------------------------------------------------------------------------------ - // Creates a new Snapshot request. func newSnapshotRecoveryRequest(leaderName string, snapshot *Snapshot) *SnapshotRecoveryRequest { return &SnapshotRecoveryRequest{ diff --git a/snapshot_recovery_response.go b/snapshot_recovery_response.go index 7e09f86..e55deee 100644 --- a/snapshot_recovery_response.go +++ b/snapshot_recovery_response.go @@ -14,12 +14,6 @@ type SnapshotRecoveryResponse struct { CommitIndex uint64 } -//------------------------------------------------------------------------------ -// -// Constructors -// -//------------------------------------------------------------------------------ - // Creates a new Snapshot response. func newSnapshotRecoveryResponse(term uint64, success bool, commitIndex uint64) *SnapshotRecoveryResponse { return &SnapshotRecoveryResponse{ @@ -29,8 +23,8 @@ func newSnapshotRecoveryResponse(term uint64, success bool, commitIndex uint64) } } -// Encodes the SnapshotRecoveryResponse to a buffer. Returns the number of bytes -// written and any error that may have occurred. +// Encode writes the response to a writer. +// Returns the number of bytes written and any error that occurs. func (req *SnapshotRecoveryResponse) Encode(w io.Writer) (int, error) { pb := &protobuf.ProtoSnapshotRecoveryResponse{ Term: proto.Uint64(req.Term), @@ -45,8 +39,7 @@ func (req *SnapshotRecoveryResponse) Encode(w io.Writer) (int, error) { return w.Write(p) } -// Decodes the SnapshotRecoveryResponse from a buffer. Returns the number of bytes read and -// any error that occurs. +// Decodes the SnapshotRecoveryResponse from a buffer. func (req *SnapshotRecoveryResponse) Decode(r io.Reader) (int, error) { data, err := ioutil.ReadAll(r) diff --git a/snapshot_request.go b/snapshot_request.go index 8166ef4..7de2761 100644 --- a/snapshot_request.go +++ b/snapshot_request.go @@ -14,12 +14,6 @@ type SnapshotRequest struct { LastTerm uint64 } -//------------------------------------------------------------------------------ -// -// Constructors -// -//------------------------------------------------------------------------------ - // Creates a new Snapshot request. func newSnapshotRequest(leaderName string, snapshot *Snapshot) *SnapshotRequest { return &SnapshotRequest{ diff --git a/snapshot_response.go b/snapshot_response.go index a26244b..3e1d2c4 100644 --- a/snapshot_response.go +++ b/snapshot_response.go @@ -12,12 +12,6 @@ type SnapshotResponse struct { Success bool `json:"success"` } -//------------------------------------------------------------------------------ -// -// Constructors -// -//------------------------------------------------------------------------------ - // Creates a new Snapshot response. func newSnapshotResponse(success bool) *SnapshotResponse { return &SnapshotResponse{ diff --git a/snapshot_test.go b/snapshot_test.go new file mode 100644 index 0000000..99ca05a --- /dev/null +++ b/snapshot_test.go @@ -0,0 +1,82 @@ +package raft + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// Ensure that a snapshot occurs when there are existing logs. +func TestSnapshot(t *testing.T) { + runServerWithMockStateMachine(Leader, func (s Server, m *mock.Mock) { + m.On("Save").Return([]byte("foo"), nil) + m.On("Recovery", []byte("foo")).Return(nil) + + s.Do(&testCommand1{}) + err := s.TakeSnapshot() + assert.NoError(t, err) + assert.Equal(t, s.(*server).lastSnapshot.LastIndex, uint64(3)) + + // Repeat to make sure new snapshot gets created. + s.Do(&testCommand1{}) + err = s.TakeSnapshot() + assert.NoError(t, err) + assert.Equal(t, s.(*server).lastSnapshot.LastIndex, uint64(4)) + + // Restart server. + s.Stop() + s.Start() + + // Recover from snapshot. + err = s.LoadSnapshot() + assert.NoError(t, err) + }) +} + +// Ensure that snapshotting fails if there are no log entries yet. +func TestSnapshotWithNoLog(t *testing.T) { + runServerWithMockStateMachine(Leader, func (s Server, m *mock.Mock) { + err := s.TakeSnapshot() + assert.Equal(t, err, errors.New("No logs")) + }) +} + +// Ensure that a snapshot request can be sent and received. +func TestSnapshotRequest(t *testing.T) { + runServerWithMockStateMachine(Follower, func (s Server, m *mock.Mock) { + m.On("Recovery", []byte("bar")).Return(nil) + + // Send snapshot request. + resp := s.RequestSnapshot(&SnapshotRequest{LastIndex: 5, LastTerm: 1}) + assert.Equal(t, resp.Success, true) + assert.Equal(t, s.State(), Snapshotting) + + // Send recovery request. + resp2 := s.SnapshotRecoveryRequest(&SnapshotRecoveryRequest{ + LeaderName: "1", + LastIndex: 5, + LastTerm: 2, + Peers: make([]*Peer, 0), + State: []byte("bar"), + }) + assert.Equal(t, resp2.Success, true) + }) +} + +func runServerWithMockStateMachine(state string, fn func (s Server, m *mock.Mock)) { + var m mockStateMachine + s := newTestServer("1", &testTransporter{}) + s.(*server).stateMachine = &m + if err := s.Start(); err != nil { + panic("server start error: " + err.Error()) + } + if state == Leader { + if _, err := s.Do(&DefaultJoinCommand{Name: s.Name()}); err != nil { + panic("unable to join server to self: " + err.Error()) + } + } + defer s.Stop() + fn(s, &m.Mock) +} diff --git a/statemachine.go b/statemachine.go index a0a22e8..7d6ee79 100644 --- a/statemachine.go +++ b/statemachine.go @@ -1,11 +1,5 @@ package raft -//------------------------------------------------------------------------------ -// -// Typedefs -// -//------------------------------------------------------------------------------ - // StateMachine is the interface for allowing the host application to save and // recovery the state machine. This makes it possible to make snapshots // and compact the log. diff --git a/statemachine_test.go b/statemachine_test.go new file mode 100644 index 0000000..853d6cd --- /dev/null +++ b/statemachine_test.go @@ -0,0 +1,19 @@ +package raft + +import ( + "github.com/stretchr/testify/mock" +) + +type mockStateMachine struct { + mock.Mock +} + +func (m *mockStateMachine) Save() ([]byte, error) { + args := m.Called() + return args.Get(0).([]byte), args.Error(1) +} + +func (m *mockStateMachine) Recovery(b []byte) (error) { + args := m.Called(b) + return args.Error(0) +}