From 2c8b0df1b2ee88b24a201c84bfc517d50674938f Mon Sep 17 00:00:00 2001 From: tok-kkk Date: Tue, 3 Aug 2021 13:40:05 +1000 Subject: [PATCH] fix the mutex in dht --- dht/table.go | 18 ++++++------------ dht/table_test.go | 40 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/dht/table.go b/dht/table.go index 0c8862e..d152339 100644 --- a/dht/table.go +++ b/dht/table.go @@ -67,7 +67,7 @@ type Table interface { type InMemTable struct { self id.Signatory - sortedMu *sync.Mutex + sortedMu *sync.RWMutex sorted []id.Signatory addrsBySignatoryMu *sync.Mutex @@ -86,7 +86,7 @@ func NewInMemTable(self id.Signatory) *InMemTable { return &InMemTable{ self: self, - sortedMu: new(sync.Mutex), + sortedMu: new(sync.RWMutex), sorted: []id.Signatory{}, addrsBySignatoryMu: new(sync.Mutex), @@ -166,8 +166,8 @@ func (table *InMemTable) PeerAddress(peerID id.Signatory) (wire.Address, bool) { // Peers returns the n closest peer IDs. func (table *InMemTable) Peers(n int) []id.Signatory { - table.sortedMu.Lock() - defer table.sortedMu.Unlock() + table.sortedMu.RLock() + defer table.sortedMu.RUnlock() if n <= 0 { // For values of n that are less than, or equal to, zero, return an @@ -183,9 +183,9 @@ func (table *InMemTable) Peers(n int) []id.Signatory { // RandomPeers returns n random peer IDs func (table *InMemTable) RandomPeers(n int) []id.Signatory { - table.sortedMu.Lock() + table.sortedMu.RLock() + defer table.sortedMu.RUnlock() m := len(table.sorted) - table.sortedMu.Unlock() if n <= 0 { // For values of n that are less than, or equal to, zero, return an @@ -195,8 +195,6 @@ func (table *InMemTable) RandomPeers(n int) []id.Signatory { } if n >= m { sigs := make([]id.Signatory, m) - table.sortedMu.Lock() - defer table.sortedMu.Unlock() copy(sigs, table.sorted) return sigs } @@ -208,8 +206,6 @@ func (table *InMemTable) RandomPeers(n int) []id.Signatory { if m <= 10000 || n >= m/50.0 { shuffled := make([]id.Signatory, n) indexPerm := rand.Perm(m) - table.sortedMu.Lock() - defer table.sortedMu.Unlock() for i := 0; i < n; i++ { shuffled[i] = table.sorted[indexPerm[i]] } @@ -219,8 +215,6 @@ func (table *InMemTable) RandomPeers(n int) []id.Signatory { // Otherwise, use Floyd's sampling algorithm to select n random elements set := make(map[int]struct{}, n) randomSelection := make([]id.Signatory, 0, n) - table.sortedMu.Lock() - defer table.sortedMu.Unlock() for i := m - n; i < m; i++ { index := table.randObj.Intn(i) if _, ok := set[index]; !ok { diff --git a/dht/table_test.go b/dht/table_test.go index 3347111..40e2919 100644 --- a/dht/table_test.go +++ b/dht/table_test.go @@ -2,7 +2,7 @@ package dht_test import ( "fmt" - "github.com/renproject/aw/wire" + "log" "math/rand" "strconv" "testing/quick" @@ -10,6 +10,7 @@ import ( "github.com/renproject/aw/dht" "github.com/renproject/aw/dht/dhtutil" + "github.com/renproject/aw/wire" "github.com/renproject/id" . "github.com/onsi/ginkgo" @@ -204,6 +205,43 @@ var _ = Describe("DHT", func() { } } }) + + It("should work while deleting peers from the table", func() { + table, _ := initDHT() + numAddrs := rand.Intn(100) + numRandAddrs := rand.Intn(numAddrs) + + // Insert `numAddrs` random addresses into the store. + deletedPeers := make([]id.Signatory, 0, 50) + for i := 0; i < numAddrs; i++ { + privKey := id.NewPrivKey() + sig := privKey.Signatory() + addr := wire.NewUnsignedAddress(wire.TCP, "172.16.254.1:3000", uint64(time.Now().UnixNano())) + table.AddPeer(sig, addr) + if i < numAddrs/2 { + deletedPeers = append(deletedPeers, sig) + } + } + + done := make(chan struct{}, 1) + go func() { + defer close(done) + + for i := range deletedPeers{ + table.DeletePeer(deletedPeers[i]) + } + }() + + total := time.Duration(0) + for i := 0; i <50 ; i ++ { + start := time.Now() + table.RandomPeers(numRandAddrs) + duration := time.Now().Sub(start) + total += duration + } + log.Printf("RandomPeers takes %v on average", total/50) + <- done + }) }) Context("when querying the number of addresses", func() {