Skip to content

Commit

Permalink
Add more tests (#6)
Browse files Browse the repository at this point in the history
* topic manager

- renamed to topic manager
- extracted from pubsub.go
- added tests

* msg router test

* test cfg defaults

+ fix pointer reciever

* makefile: fix test-cov

* test read config

* fix gen priv key + test
  • Loading branch information
amirylm authored Nov 27, 2023
1 parent 1d08658 commit 13b84f5
Show file tree
Hide file tree
Showing 12 changed files with 333 additions and 100 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ test-pkg:
@go test -v -race -timeout=${TEST_TIMEOUT} ${TEST_PKG}

test-cov:
@go test -v -race -timeout=${TEST_TIMEOUT} -coverprofile cover.out `go list ./... | grep -v -E "cmd|scripts|resources|examples|proto"`
@go test -timeout=${TEST_TIMEOUT} -coverprofile cover.out `go list ./... | grep -v -E "cmd|scripts|resources|examples|proto"`

test-open-cov:
@make test-cov
Expand Down
2 changes: 1 addition & 1 deletion commons/config_pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type MsgValidationConfig struct {
Concurrency int `json:"concurrency,omitempty" yaml:"concurrency,omitempty"`
}

func (mvc MsgValidationConfig) Defaults(other *MsgValidationConfig) MsgValidationConfig {
func (mvc *MsgValidationConfig) Defaults(other *MsgValidationConfig) *MsgValidationConfig {
if mvc.Timeout.Milliseconds() == 0 {
mvc.Timeout = time.Second * 5
}
Expand Down
48 changes: 48 additions & 0 deletions commons/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package commons

import (
"testing"
"time"

"github.com/stretchr/testify/require"
)

func TestDefaults(t *testing.T) {
t.Run("Config", func(t *testing.T) {
cfg := &Config{}
cfg.Defaults()
require.Equal(t, 15*time.Second, cfg.DialTimeout)
require.Equal(t, 1024, cfg.MsgRouterQueueSize)
require.Equal(t, 10, cfg.MsgRouterWorkers)
require.Empty(t, cfg.ListenAddrs)
require.Empty(t, cfg.PSK)
})

t.Run("ConnManagerConfig", func(t *testing.T) {
cfg := &ConnManagerConfig{}
cfg.Defaults()
require.Equal(t, 5, cfg.LowWaterMark)
require.Equal(t, 25, cfg.HighWaterMark)
require.Equal(t, time.Minute, cfg.GracePeriod)
})

t.Run("DiscoveryConfig", func(t *testing.T) {
cfg := &DiscoveryConfig{}
cfg.Defaults()
require.Equal(t, ModeServer, cfg.Mode)
require.Equal(t, "p2pmq", cfg.ProtocolPrefix)
})

t.Run("MsgValidationConfig", func(t *testing.T) {
cfg := &MsgValidationConfig{}
cfg.Defaults(nil)
require.Equal(t, time.Second*5, cfg.Timeout)
require.Equal(t, 10, cfg.Concurrency)

cfg.Concurrency = 4
cfg2 := &MsgValidationConfig{}
cfg2.Defaults(cfg)
require.Equal(t, time.Second*5, cfg2.Timeout)
require.Equal(t, 4, cfg2.Concurrency)
})
}
24 changes: 24 additions & 0 deletions commons/io_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package commons

import (
"path"
"testing"

"github.com/stretchr/testify/require"
)

func TestReadConfig(t *testing.T) {
t.Run("happy path yaml", func(t *testing.T) {
cfgPath := "resources/config/default.p2pmq.yaml"
cfg, err := ReadConfig(path.Join("..", cfgPath))
require.NoError(t, err)
require.Len(t, cfg.ListenAddrs, 1)
require.Equal(t, "/ip4/0.0.0.0/tcp/5101", cfg.ListenAddrs[0])
})

t.Run("non existing json file", func(t *testing.T) {
cfgPath := "resources/config/non-existing.p2pmq.json"
_, err := ReadConfig(path.Join("..", cfgPath))
require.Error(t, err)
})
}
9 changes: 5 additions & 4 deletions commons/netkey.go → commons/net.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ func GetOrGeneratePrivateKey(privKeyB64 string) (sk crypto.PrivKey, encodedB64 s
return nil, encodedB64, err
}
// TODO: pool base64 encoders
encodedB64 = base64.StdEncoding.EncodeToString(encoded)
encodedB64 = base64.RawStdEncoding.EncodeToString(encoded)
return sk, encodedB64, nil
}
encoded, err := base64.StdEncoding.DecodeString(encodedB64)
encodedB64 = privKeyB64
encoded, err := base64.RawStdEncoding.DecodeString(encodedB64)
if err != nil {
return nil, privKeyB64, err
return nil, encodedB64, fmt.Errorf("failed to decode private key with base64: %w", err)
}
sk, err = crypto.UnmarshalPrivateKey(encoded)
if err != nil {
return nil, privKeyB64, err
return nil, encodedB64, fmt.Errorf("failed to unmarshal private key: %w", err)
}
return sk, encodedB64, nil
}
26 changes: 26 additions & 0 deletions commons/net_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package commons

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestGetOrGeneratePrivateKey(t *testing.T) {
sk, skb64, err := GetOrGeneratePrivateKey("")
require.NoError(t, err)
require.NotNil(t, sk)
require.NotEmpty(t, skb64)

sk2, sk2b64, err := GetOrGeneratePrivateKey(skb64)
require.NoError(t, err)
require.NotNil(t, sk2)
require.Equal(t, skb64, sk2b64)
require.True(t, sk.Equals(sk2))

t.Run("bad input", func(t *testing.T) {
sk, _, err := GetOrGeneratePrivateKey("bad input")
require.Error(t, err)
require.Nil(t, sk)
})
}
7 changes: 4 additions & 3 deletions core/ctrl.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ type Controller struct {
mdnsSvc mdns.Service
pubsub *pubsub.PubSub

psManager pubsubManager
denylist pubsub.Blacklist
subFilter pubsub.SubscriptionFilter
topicManager *topicManager
denylist pubsub.Blacklist
subFilter pubsub.SubscriptionFilter

valRouter MsgRouter[pubsub.ValidationResult]
msgRouter MsgRouter[error]
Expand All @@ -54,6 +54,7 @@ func NewController(
cfg: cfg,
valRouter: valRouter,
msgRouter: msgRouter,
topicManager: newTopicManager(),
}
err := d.setup(ctx, cfg)

Expand Down
2 changes: 2 additions & 0 deletions core/ctrl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func TestController_Sanity(t *testing.T) {

for topic, counter := range msgHitMap {
count := int(counter.Load()) / n // per node
// add 1 to account for the first message sent by the node
count += 1
require.GreaterOrEqual(t, count, rounds, "should get %d messages on topic %s", rounds, topic)
}
}
100 changes: 9 additions & 91 deletions core/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import (
"context"
"fmt"
"regexp"
"sync"
"sync/atomic"
"time"

"github.com/amirylm/p2pmq/commons"
Expand Down Expand Up @@ -71,7 +69,7 @@ func (c *Controller) setupPubsubRouter(ctx context.Context, cfg commons.Config)
}
c.pubsub = ps
c.denylist = denylist
c.psManager.topics = make(map[string]*topicWrapper)
c.topicManager.topics = make(map[string]*topicWrapper)

return nil
}
Expand All @@ -86,7 +84,7 @@ func (c *Controller) Publish(ctx context.Context, topicName string, data []byte)
}

func (c *Controller) Leave(topicName string) error {
tw := c.psManager.getTopicWrapper(topicName)
tw := c.topicManager.getTopicWrapper(topicName)
state := tw.state.Load()
switch state {
case topicStateJoined, topicStateErr:
Expand All @@ -102,7 +100,7 @@ func (c *Controller) Leave(topicName string) error {
}

func (c *Controller) Unsubscribe(topicName string) error {
tw := c.psManager.getTopicWrapper(topicName)
tw := c.topicManager.getTopicWrapper(topicName)
if tw.state.Load() == topicStateUnknown {
return nil // TODO: topic not found?
}
Expand Down Expand Up @@ -160,14 +158,14 @@ func (c *Controller) listenSubscription(ctx context.Context, sub *pubsub.Subscri
}

func (c *Controller) tryJoin(topicName string) (*pubsub.Topic, error) {
topicW := c.psManager.getTopicWrapper(topicName)
topicW := c.topicManager.getTopicWrapper(topicName)
if topicW != nil {
if topicW.state.Load() == topicStateJoining {
return nil, fmt.Errorf("already tring to join topic %s", topicName)
}
return topicW.topic, nil
}
c.psManager.joiningTopic(topicName)
c.topicManager.joiningTopic(topicName)
opts := []pubsub.TopicOpt{}
cfg, ok := c.cfg.Pubsub.GetTopicConfig(topicName)
if ok {
Expand All @@ -180,10 +178,10 @@ func (c *Controller) tryJoin(topicName string) (*pubsub.Topic, error) {
if err != nil {
return nil, err
}
c.psManager.upgradeTopic(topicName, topic)
c.topicManager.upgradeTopic(topicName, topic)

if cfg.MsgValidator != nil || c.cfg.Pubsub.MsgValidator != nil {
msgValConfig := commons.MsgValidationConfig{}.Defaults(c.cfg.Pubsub.MsgValidator)
msgValConfig := (&commons.MsgValidationConfig{}).Defaults(c.cfg.Pubsub.MsgValidator)
if cfg.MsgValidator != nil {
msgValConfig = msgValConfig.Defaults(cfg.MsgValidator)
}
Expand All @@ -202,7 +200,7 @@ func (c *Controller) tryJoin(topicName string) (*pubsub.Topic, error) {

func (c *Controller) trySubscribe(topic *pubsub.Topic) (sub *pubsub.Subscription, err error) {
topicName := topic.String()
sub = c.psManager.getSub(topicName)
sub = c.topicManager.getSub(topicName)
if sub != nil {
return nil, nil
}
Expand All @@ -217,7 +215,7 @@ func (c *Controller) trySubscribe(topic *pubsub.Topic) (sub *pubsub.Subscription
if err != nil {
return nil, err
}
c.psManager.addSub(topicName, sub)
c.topicManager.addSub(topicName, sub)
return sub, nil
}

Expand All @@ -240,86 +238,6 @@ func (c *Controller) inspectPeerScores(map[peer.ID]*pubsub.PeerScoreSnapshot) {
// TODO
}

type pubsubManager struct {
lock sync.RWMutex
topics map[string]*topicWrapper
}

type topicWrapper struct {
state atomic.Int32
topic *pubsub.Topic
sub *pubsub.Subscription
}

const (
topicStateUnknown = int32(0)
topicStateJoining = int32(1)
topicStateJoined = int32(2)
topicStateErr = int32(10)
)

func (pm *pubsubManager) joiningTopic(name string) {
pm.lock.Lock()
defer pm.lock.Unlock()

tw := &topicWrapper{}
tw.state.Store(topicStateJoining)
pm.topics[name] = tw
}

func (pm *pubsubManager) upgradeTopic(name string, topic *pubsub.Topic) bool {
pm.lock.Lock()
defer pm.lock.Unlock()

tw, ok := pm.topics[name]
if !ok {
return false
}
if !tw.state.CompareAndSwap(topicStateJoining, topicStateJoined) {
return false
}
tw.topic = topic
pm.topics[name] = tw

return true
}

func (pm *pubsubManager) getTopicWrapper(topic string) *topicWrapper {
pm.lock.RLock()
defer pm.lock.RUnlock()

t, ok := pm.topics[topic]
if !ok {
return nil
}
return t
}

func (pm *pubsubManager) addSub(name string, sub *pubsub.Subscription) bool {
pm.lock.Lock()
defer pm.lock.Unlock()

tw, ok := pm.topics[name]
if ok {
// TODO: enable multiple subscriptions per topic
return false
}
tw.sub = sub

return true
}

func (pm *pubsubManager) getSub(topic string) *pubsub.Subscription {
pm.lock.RLock()
defer pm.lock.RUnlock()

tw, ok := pm.topics[topic]
if !ok {
return nil
}
return tw.sub
}

// psTracer helps to trace pubsub events, implements pubsublibp2p.EventTracer
type psTracer struct {
lggr *zap.SugaredLogger
Expand Down
Loading

0 comments on commit 13b84f5

Please sign in to comment.