-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Artur Troian <[email protected]>
- Loading branch information
Showing
3 changed files
with
386 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
package pubsub | ||
|
||
import ( | ||
"errors" | ||
|
||
"github.com/boz/go-lifecycle" | ||
) | ||
|
||
// ErrNotRunning is the error with message "not running" | ||
var ErrNotRunning = errors.New("not running") | ||
|
||
// Event interface | ||
type Event interface{} | ||
|
||
type Publisher interface { | ||
Publish(Event) error | ||
} | ||
|
||
// Bus is an async event bus that allows subscriptions to behave as a bus themselves. | ||
// When an event is published, it is sent to all subscribers asynchronously - a subscriber | ||
// cannot block other subscribers. | ||
// | ||
// NOTE: this should probably be in util/event or something (not in provider/event) | ||
type Bus interface { | ||
Publisher | ||
Subscribe() (Subscriber, error) | ||
Close() | ||
Done() <-chan struct{} | ||
} | ||
|
||
// Subscriber emits events it sees on the channel returned by Events(). | ||
// A Clone() of a subscriber will emit all events that have not been emitted | ||
// from the cloned subscriber. This is important so that events are not missed | ||
// when adding subscribers for sub-components (see `provider/bidengine/{service,order}.go`) | ||
type Subscriber interface { | ||
Events() <-chan Event | ||
Clone() (Subscriber, error) | ||
Close() | ||
Done() <-chan struct{} | ||
} | ||
|
||
type bus struct { | ||
subscriptions map[*bus]bool | ||
|
||
evbuf []Event | ||
|
||
eventch chan Event | ||
parentch chan *bus | ||
|
||
pubch chan Event | ||
subch chan chan<- Subscriber | ||
unsubch chan *bus | ||
|
||
lc lifecycle.Lifecycle | ||
} | ||
|
||
// NewBus runs a new bus and returns bus details | ||
func NewBus() Bus { | ||
bus := &bus{ | ||
subscriptions: make(map[*bus]bool), | ||
pubch: make(chan Event), | ||
subch: make(chan chan<- Subscriber), | ||
unsubch: make(chan *bus), | ||
lc: lifecycle.New(), | ||
} | ||
|
||
go bus.run() | ||
|
||
return bus | ||
} | ||
|
||
func (b *bus) Publish(ev Event) error { | ||
select { | ||
case b.pubch <- ev: | ||
return nil | ||
case <-b.lc.ShuttingDown(): | ||
return ErrNotRunning | ||
} | ||
} | ||
|
||
func (b *bus) Subscribe() (Subscriber, error) { | ||
ch := make(chan Subscriber, 1) | ||
|
||
select { | ||
case b.subch <- ch: | ||
return <-ch, nil | ||
case <-b.lc.ShuttingDown(): | ||
return nil, ErrNotRunning | ||
} | ||
} | ||
|
||
func (b *bus) Clone() (Subscriber, error) { | ||
return b.Subscribe() | ||
} | ||
|
||
func (b *bus) Events() <-chan Event { | ||
return b.eventch | ||
} | ||
|
||
func (b *bus) Close() { | ||
b.lc.Shutdown(nil) | ||
} | ||
|
||
func (b *bus) Done() <-chan struct{} { | ||
return b.lc.Done() | ||
} | ||
|
||
func (b *bus) run() { | ||
defer b.lc.ShutdownCompleted() | ||
|
||
var outch chan<- Event | ||
var curev Event | ||
|
||
loop: | ||
for { | ||
|
||
if b.eventch != nil && len(b.evbuf) > 0 { | ||
// If we're emitting events (Subscriber mode) and there | ||
// are events to emit, set up the output channel and output | ||
// event accordingly. | ||
outch = b.eventch | ||
curev = b.evbuf[0] | ||
} else { | ||
// otherwise block the output (sending to a nil channel always blocks) | ||
outch = nil | ||
} | ||
|
||
select { | ||
case err := <-b.lc.ShutdownRequest(): | ||
b.lc.ShutdownInitiated(err) | ||
break loop | ||
|
||
case outch <- curev: | ||
// Event was emitted. Shrink current event buffer. | ||
b.evbuf = b.evbuf[1:] | ||
|
||
case ev := <-b.pubch: | ||
// publish event | ||
|
||
// Buffer event. | ||
if b.eventch != nil { | ||
b.evbuf = append(b.evbuf, ev) | ||
} | ||
|
||
// Publish to children. | ||
for sub := range b.subscriptions { | ||
if err := sub.Publish(ev); err != nil && !errors.Is(err, ErrNotRunning) { | ||
panic(err) | ||
} | ||
} | ||
|
||
case ch := <-b.subch: | ||
// new subscription | ||
|
||
sub := newSubscriber(b) | ||
b.subscriptions[sub] = true | ||
|
||
ch <- sub | ||
|
||
case sub := <-b.unsubch: | ||
// subscription closed | ||
delete(b.subscriptions, sub) | ||
} | ||
} | ||
|
||
for sub := range b.subscriptions { | ||
sub.lc.ShutdownAsync(nil) | ||
} | ||
|
||
for len(b.subscriptions) > 0 { | ||
sub := <-b.unsubch | ||
delete(b.subscriptions, sub) | ||
} | ||
|
||
if b.parentch != nil { | ||
b.parentch <- b | ||
} | ||
} | ||
|
||
func newSubscriber(parent *bus) *bus { | ||
// Re-use bus struct, but populate output channel (eventch) | ||
// to enable subscriber mode. | ||
|
||
evbuf := make([]Event, len(parent.evbuf)) | ||
copy(evbuf, parent.evbuf) | ||
|
||
sub := &bus{ | ||
eventch: make(chan Event), | ||
parentch: parent.unsubch, | ||
evbuf: evbuf, | ||
|
||
subscriptions: make(map[*bus]bool), | ||
pubch: make(chan Event), | ||
subch: make(chan chan<- Subscriber), | ||
unsubch: make(chan *bus), | ||
lc: lifecycle.New(), | ||
} | ||
|
||
go sub.run() | ||
|
||
return sub | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
package pubsub_test | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/cometbft/cometbft/crypto/ed25519" | ||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
|
||
"pkg.akt.dev/go/util/pubsub" | ||
) | ||
|
||
func TestBus(t *testing.T) { | ||
bus := pubsub.NewBus() | ||
defer bus.Close() | ||
|
||
did := ed25519.GenPrivKey().PubKey().Address() | ||
|
||
ev := newEvent(did) | ||
|
||
assert.NoError(t, bus.Publish(ev)) | ||
|
||
sub1, err := bus.Subscribe() | ||
require.NoError(t, err) | ||
|
||
sub2, err := bus.Subscribe() | ||
require.NoError(t, err) | ||
|
||
assert.NoError(t, bus.Publish(ev)) | ||
|
||
select { | ||
case newEv := <-sub1.Events(): | ||
assert.Equal(t, ev, newEv) | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "time out") | ||
} | ||
|
||
select { | ||
case newEv := <-sub2.Events(): | ||
assert.Equal(t, ev, newEv) | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "time out") | ||
} | ||
|
||
sub2.Close() | ||
|
||
select { | ||
case <-sub2.Done(): | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "time out") | ||
} | ||
|
||
assert.NoError(t, bus.Publish(ev)) | ||
|
||
select { | ||
case newEv := <-sub1.Events(): | ||
assert.Equal(t, ev, newEv) | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "time out") | ||
} | ||
|
||
select { | ||
case <-sub2.Events(): | ||
require.Fail(t, "spurious event") | ||
case <-pubsub.AfterThreadStart(t): | ||
} | ||
|
||
bus.Close() | ||
|
||
select { | ||
case <-sub1.Done(): | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "time out") | ||
} | ||
|
||
assert.Equal(t, pubsub.ErrNotRunning, bus.Publish(ev)) | ||
|
||
} | ||
|
||
func TestClone(t *testing.T) { | ||
bus := pubsub.NewBus() | ||
defer bus.Close() | ||
|
||
did1 := ed25519.GenPrivKey().PubKey().Address() | ||
ev1 := newEvent(did1) | ||
|
||
did2 := ed25519.GenPrivKey().PubKey().Address() | ||
ev2 := newEvent(did2) | ||
|
||
assert.NoError(t, bus.Publish(ev1)) | ||
|
||
sub1, err := bus.Subscribe() | ||
require.NoError(t, err) | ||
|
||
select { | ||
case <-sub1.Events(): | ||
require.Fail(t, "spurious event") | ||
case <-pubsub.AfterThreadStart(t): | ||
} | ||
|
||
assert.NoError(t, bus.Publish(ev1)) | ||
assert.NoError(t, bus.Publish(ev2)) | ||
|
||
// allow event propagation | ||
pubsub.SleepForThreadStart(t) | ||
|
||
// clone subscription | ||
sub2, err := sub1.Clone() | ||
require.NoError(t, err) | ||
|
||
// both subscriptions should receive both events | ||
|
||
for i, pev := range []pubsub.Event{ev1, ev2} { | ||
select { | ||
case ev := <-sub1.Events(): | ||
assert.Equal(t, pev, ev, "sub1 event %v", i+1) | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "timeout sub1 event %v", i+1) | ||
} | ||
|
||
select { | ||
case ev := <-sub2.Events(): | ||
assert.Equal(t, pev, ev, "sub2 event %v", i+1) | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "timeout sub2 event %v", i+1) | ||
} | ||
} | ||
|
||
// sub1 should close sub2 | ||
sub1.Close() | ||
|
||
select { | ||
case <-sub2.Done(): | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "time out closing sub2") | ||
} | ||
|
||
select { | ||
case <-sub1.Done(): | ||
case <-pubsub.AfterThreadStart(t): | ||
require.Fail(t, "time out closing sub1") | ||
} | ||
|
||
} | ||
|
||
type testEvent []byte | ||
|
||
func newEvent(addr []byte) testEvent { | ||
return testEvent(addr) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package pubsub | ||
|
||
import ( | ||
"os" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
const ( | ||
defaultDelayThreadStart = time.Millisecond * 6 | ||
) | ||
|
||
// AfterThreadStart waits for the duration of delay thread start | ||
func AfterThreadStart(t *testing.T) <-chan time.Time { | ||
return time.After(delayThreadStart(t)) | ||
} | ||
|
||
// SleepForThreadStart pass go routine for the duration of delay thread start | ||
func SleepForThreadStart(t *testing.T) { | ||
time.Sleep(delayThreadStart(t)) | ||
} | ||
|
||
func delayThreadStart(t *testing.T) time.Duration { | ||
if val := os.Getenv("TEST_DELAY_THREAD_START"); val != "" { | ||
d, err := time.ParseDuration(val) | ||
require.NoError(t, err) | ||
|
||
return d | ||
} | ||
|
||
return defaultDelayThreadStart | ||
} |