Skip to content

Commit

Permalink
Merge pull request lightningnetwork#9343 from ellemouton/contextGuard
Browse files Browse the repository at this point in the history
fn: expand the ContextGuard and add tests
  • Loading branch information
guggero authored Dec 13, 2024
2 parents 6298f76 + f99cabf commit ff14847
Show file tree
Hide file tree
Showing 3 changed files with 675 additions and 57 deletions.
246 changes: 189 additions & 57 deletions fn/context_guard.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package fn
import (
"context"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -11,103 +12,234 @@ var (
DefaultTimeout = 30 * time.Second
)

// ContextGuard is an embeddable struct that provides a wait group and main quit
// channel that can be used to create guarded contexts.
// ContextGuard is a struct that provides a wait group and main quit channel
// that can be used to create guarded contexts.
type ContextGuard struct {
DefaultTimeout time.Duration
Wg sync.WaitGroup
Quit chan struct{}
mu sync.Mutex
wg sync.WaitGroup

quit chan struct{}
stopped sync.Once

// id is used to generate unique ids for each context that should be
// cancelled when the main quit signal is triggered.
id atomic.Uint32

// cancelFns is a map of cancel functions that can be used to cancel
// any context that should be cancelled when the main quit signal is
// triggered. The key is the id of the context. The mutex must be held
// when accessing this map.
cancelFns map[uint32]context.CancelFunc
}

// NewContextGuard constructs and returns a new instance of ContextGuard.
func NewContextGuard() *ContextGuard {
return &ContextGuard{
DefaultTimeout: DefaultTimeout,
Quit: make(chan struct{}),
quit: make(chan struct{}),
cancelFns: make(map[uint32]context.CancelFunc),
}
}

// WithCtxQuit is used to create a cancellable context that will be cancelled
// if the main quit signal is triggered or after the default timeout occurred.
func (g *ContextGuard) WithCtxQuit() (context.Context, func()) {
return g.WithCtxQuitCustomTimeout(g.DefaultTimeout)
}
// Quit is used to signal the main quit channel, which will cancel all
// non-blocking contexts derived from the ContextGuard.
func (g *ContextGuard) Quit() {
g.stopped.Do(func() {
g.mu.Lock()
defer g.mu.Unlock()

// WithCtxQuitCustomTimeout is used to create a cancellable context that will be
// cancelled if the main quit signal is triggered or after the given timeout
// occurred.
func (g *ContextGuard) WithCtxQuitCustomTimeout(
timeout time.Duration) (context.Context, func()) {
for _, cancel := range g.cancelFns {
cancel()
}

timeoutTimer := time.NewTimer(timeout)
ctx, cancel := context.WithCancel(context.Background())
close(g.quit)
})
}

g.Wg.Add(1)
go func() {
defer timeoutTimer.Stop()
defer cancel()
defer g.Wg.Done()
// Done returns a channel that will be closed when the main quit signal is
// triggered.
func (g *ContextGuard) Done() <-chan struct{} {
return g.quit
}

select {
case <-g.Quit:
// WgAdd is used to add delta to the internal wait group of the ContextGuard.
func (g *ContextGuard) WgAdd(delta int) {
g.wg.Add(delta)
}

case <-timeoutTimer.C:
// WgDone is used to decrement the internal wait group of the ContextGuard.
func (g *ContextGuard) WgDone() {
g.wg.Done()
}

case <-ctx.Done():
}
}()
// WgWait is used to block until the internal wait group of the ContextGuard is
// empty.
func (g *ContextGuard) WgWait() {
g.wg.Wait()
}

return ctx, cancel
// ctxGuardOptions is used to configure the behaviour of the context derived
// via the WithCtx method of the ContextGuard.
type ctxGuardOptions struct {
blocking bool
withTimeout bool
timeout time.Duration
}

// CtxBlocking is used to create a cancellable context that will NOT be
// ContextGuardOption defines the signature of a functional option that can be
// used to configure the behaviour of the context derived via the WithCtx method
// of the ContextGuard.
type ContextGuardOption func(*ctxGuardOptions)

// WithBlockingCG is used to create a cancellable context that will NOT be
// cancelled if the main quit signal is triggered, to block shutdown of
// important tasks. The context will be cancelled if the timeout is reached.
func (g *ContextGuard) CtxBlocking() (context.Context, func()) {
return g.CtxBlockingCustomTimeout(g.DefaultTimeout)
// important tasks.
func WithBlockingCG() ContextGuardOption {
return func(o *ctxGuardOptions) {
o.blocking = true
}
}

// WithCustomTimeoutCG is used to create a cancellable context with a custom
// timeout. Such a context will be cancelled if either the parent context is
// cancelled, the timeout is reached or, if the Blocking option is not provided,
// the main quit signal is triggered.
func WithCustomTimeoutCG(timeout time.Duration) ContextGuardOption {
return func(o *ctxGuardOptions) {
o.withTimeout = true
o.timeout = timeout
}
}

// WithTimeoutCG is used to create a cancellable context with a default timeout.
// Such a context will be cancelled if either the parent context is cancelled,
// the timeout is reached or, if the Blocking option is not provided, the main
// quit signal is triggered.
func WithTimeoutCG() ContextGuardOption {
return func(o *ctxGuardOptions) {
o.withTimeout = true
o.timeout = DefaultTimeout
}
}

// CtxBlockingCustomTimeout is used to create a cancellable context with a
// custom timeout that will NOT be cancelled if the main quit signal is
// triggered, to block shutdown of important tasks. The context will be
// cancelled if the timeout is reached.
func (g *ContextGuard) CtxBlockingCustomTimeout(
timeout time.Duration) (context.Context, func()) {
// Create is used to derive a cancellable context from the parent. Various
// options can be provided to configure the behaviour of the derived context.
func (g *ContextGuard) Create(ctx context.Context,
options ...ContextGuardOption) (context.Context, context.CancelFunc) {

timeoutTimer := time.NewTimer(timeout)
ctx, cancel := context.WithCancel(context.Background())
// Exit early if the parent context has already been cancelled.
select {
case <-ctx.Done():
return ctx, func() {}
default:
}

var opts ctxGuardOptions
for _, o := range options {
o(&opts)
}

g.Wg.Add(1)
g.mu.Lock()
defer g.mu.Unlock()

var cancel context.CancelFunc
if opts.withTimeout {
ctx, cancel = context.WithTimeout(ctx, opts.timeout)
} else {
ctx, cancel = context.WithCancel(ctx)
}

if opts.blocking {
g.ctxBlocking(ctx, cancel)

return ctx, cancel
}

// If the call is non-blocking, then we can exit early if the main quit
// signal has been triggered.
select {
case <-g.quit:
cancel()

return ctx, cancel
default:
}

cancel = g.ctxQuitUnsafe(ctx, cancel)

return ctx, cancel
}

// ctxQuitUnsafe spins off a goroutine that will block until the passed context
// is cancelled or until the quit channel has been signaled after which it will
// call the passed cancel function and decrement the wait group.
//
// NOTE: the caller must hold the ContextGuard's mutex before calling this
// function.
func (g *ContextGuard) ctxQuitUnsafe(ctx context.Context,
cancel context.CancelFunc) context.CancelFunc {

cancel = g.addCancelFnUnsafe(cancel)

g.wg.Add(1)
go func() {
defer timeoutTimer.Stop()
defer cancel()
defer g.Wg.Done()
defer g.wg.Done()

select {
case <-timeoutTimer.C:
case <-g.quit:

case <-ctx.Done():
}
}()

return ctx, cancel
return cancel
}

// WithCtxQuitNoTimeout is used to create a cancellable context that will be
// cancelled if the main quit signal is triggered.
func (g *ContextGuard) WithCtxQuitNoTimeout() (context.Context, func()) {
ctx, cancel := context.WithCancel(context.Background())
// ctxBlocking spins off a goroutine that will block until the passed context
// is cancelled after which it will call the passed cancel function and
// decrement the wait group.
func (g *ContextGuard) ctxBlocking(ctx context.Context,
cancel context.CancelFunc) {

g.Wg.Add(1)
g.wg.Add(1)
go func() {
defer cancel()
defer g.Wg.Done()
defer g.wg.Done()

select {
case <-g.Quit:

case <-ctx.Done():
}
}()
}

return ctx, cancel
// addCancelFnUnsafe adds a context cancel function to the manager and returns a
// call-back which can safely be used to cancel the context.
//
// NOTE: the caller must hold the ContextGuard's mutex before calling this
// function.
func (g *ContextGuard) addCancelFnUnsafe(
cancel context.CancelFunc) context.CancelFunc {

id := g.id.Add(1)
g.cancelFns[id] = cancel

return g.cancelCtxFn(id)
}

// cancelCtxFn returns a call-back that can be used to cancel the context
// associated with the passed id.
func (g *ContextGuard) cancelCtxFn(id uint32) context.CancelFunc {
return func() {
g.mu.Lock()

fn, ok := g.cancelFns[id]
if !ok {
g.mu.Unlock()
return
}
delete(g.cancelFns, id)
g.mu.Unlock()

fn()
}
}
Loading

0 comments on commit ff14847

Please sign in to comment.