diff --git a/go.mod b/go.mod index 8a9f4b067f..b2953e40b1 100644 --- a/go.mod +++ b/go.mod @@ -20,7 +20,6 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.4.2 - github.com/kamilsk/breaker v1.2.1 github.com/mattn/go-sqlite3 v1.14.16 github.com/naoina/toml v0.1.2-0.20170918210437-9fafd6967416 github.com/pkg/errors v0.9.1 diff --git a/go.sum b/go.sum index a340958e70..2a2657a0c5 100644 --- a/go.sum +++ b/go.sum @@ -359,8 +359,6 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= -github.com/kamilsk/breaker v1.2.1 h1:rOQ2AizoWUsNDg/0x2dtH/zjZjz8neFTBo+Y2IwpgO0= -github.com/kamilsk/breaker v1.2.1/go.mod h1:anrqSwLso3GOznuRshGash/NhQ7olWwTekQ42d4jO8g= github.com/kataras/golog v0.0.10/go.mod h1:yJ8YKCmyL+nWjERB90Qwn+bdyBZsaQwU3bTVFgkFIp8= github.com/kataras/iris/v12 v12.1.8/go.mod h1:LMYy4VlP67TQ3Zgriz8RE2h2kMZV2SgMYbq3UhfoFmE= github.com/kataras/neffos v0.0.14/go.mod h1:8lqADm8PnbeFfL7CLXh1WHw53dG27MC3pgi2R1rmoTE= diff --git a/go/common/stopcontrol/stop_control.go b/go/common/stopcontrol/stop_control.go index dc0988afb4..5a44536012 100644 --- a/go/common/stopcontrol/stop_control.go +++ b/go/common/stopcontrol/stop_control.go @@ -1,22 +1,35 @@ package stopcontrol -import "sync/atomic" +import ( + "sync" + "sync/atomic" +) // StopControl allows for any instance to thread-safely check if the status is stopping or not type StopControl struct { - stop *int32 + stop *int32 + stopChan chan interface{} + closer sync.Once } func New() *StopControl { return &StopControl{ - stop: new(int32), + stop: new(int32), + stopChan: make(chan interface{}), } } func (s *StopControl) Stop() { - atomic.StoreInt32(s.stop, 1) + s.closer.Do(func() { + atomic.StoreInt32(s.stop, 1) + close(s.stopChan) + }) } func (s *StopControl) IsStopping() bool { return atomic.LoadInt32(s.stop) == 1 } + +func (s *StopControl) Done() chan interface{} { + return s.stopChan +} diff --git a/go/common/stopcontrol/stop_control_test.go b/go/common/stopcontrol/stop_control_test.go new file mode 100644 index 0000000000..a94ea66124 --- /dev/null +++ b/go/common/stopcontrol/stop_control_test.go @@ -0,0 +1,60 @@ +package stopcontrol + +import ( + "testing" + "time" +) + +func TestStopControl_Stop(t *testing.T) { + sc := New() + sc.Stop() + + if !sc.IsStopping() { + t.Error("Expected IsStopping to return true after Stop, but got false") + } + + // Ensure it's safe to call Stop multiple times + func() { + defer func() { + if r := recover(); r != nil { + t.Error("Expected no panic when calling Stop multiple times") + } + }() + sc.Stop() + }() +} + +func TestStopControl_IsStopping(t *testing.T) { + sc := New() + + if sc.IsStopping() { + t.Error("Expected IsStopping to return false initially, but got true") + } + + sc.Stop() + + if !sc.IsStopping() { + t.Error("Expected IsStopping to return true after Stop, but got false") + } +} + +func TestStopControl_Done(t *testing.T) { + sc := New() + + select { + case <-sc.Done(): + t.Error("Expected Done channel to be blocking initially") + case <-time.After(50 * time.Millisecond): // Allow a small delay to check the non-blocking state + } + + sc.Stop() + + select { + case _, ok := <-sc.Done(): + if ok { + t.Error("Expected Done channel to be closed after Stop") + } + case <-time.After(50 * time.Millisecond): + t.Error("Expected Done channel to be closed immediately after Stop") + } +} diff --git a/go/host/enclave/guardian.go b/go/host/enclave/guardian.go index 5a2cb456c8..1617e650fd 100644 --- a/go/host/enclave/guardian.go +++ b/go/host/enclave/guardian.go @@ -8,12 +8,12 @@ import ( "sync/atomic" "time" + "github.com/obscuronet/go-obscuro/go/common/stopcontrol" + gethcommon "github.com/ethereum/go-ethereum/common" "github.com/obscuronet/go-obscuro/go/common/gethutil" - "github.com/kamilsk/breaker" - "github.com/ethereum/go-ethereum/core/types" gethlog "github.com/ethereum/go-ethereum/log" "github.com/obscuronet/go-obscuro/go/common" @@ -66,12 +66,12 @@ type Guardian struct { l1StartHash gethcommon.Hash running atomic.Bool - hostInterrupter breaker.Interface // host hostInterrupter so we can stop quickly + hostInterrupter *stopcontrol.StopControl // host hostInterrupter so we can stop quickly logger gethlog.Logger } -func NewGuardian(cfg *config.HostConfig, hostData host.Identity, serviceLocator guardianServiceLocator, enclaveClient common.Enclave, db *db.DB, interrupter breaker.Interface, logger gethlog.Logger) *Guardian { +func NewGuardian(cfg *config.HostConfig, hostData host.Identity, serviceLocator guardianServiceLocator, enclaveClient common.Enclave, db *db.DB, interrupter *stopcontrol.StopControl, logger gethlog.Logger) *Guardian { return &Guardian{ hostData: hostData, state: NewStateTracker(logger), diff --git a/go/host/host.go b/go/host/host.go index d91c51d662..0be21f3450 100644 --- a/go/host/host.go +++ b/go/host/host.go @@ -3,9 +3,6 @@ package host import ( "encoding/json" "fmt" - "os" - - "github.com/kamilsk/breaker" "github.com/obscuronet/go-obscuro/go/host/l2" @@ -50,7 +47,6 @@ type host struct { logger gethlog.Logger metricRegistry gethmetrics.Registry - interrupter breaker.Interface enclaveConfig *common.ObscuroEnclaveInfo } @@ -77,13 +73,8 @@ func NewHost(config *config.HostConfig, hostServices *ServicesRegistry, p2p P2PH stopControl: stopcontrol.New(), } - host.interrupter = breaker.Multiplex( - breaker.BreakBySignal( - os.Kill, - os.Interrupt, - ), - ) - enclGuardian := enclave.NewGuardian(config, hostIdentity, hostServices, enclaveClient, database, host.interrupter, logger) + + enclGuardian := enclave.NewGuardian(config, hostIdentity, hostServices, enclaveClient, database, host.stopControl, logger) enclService := enclave.NewService(hostIdentity, hostServices, enclGuardian, logger) l2Repo := l2.NewBatchRepository(config, hostServices, database, logger) subsService := events.NewLogEventManager(hostServices, logger) @@ -116,13 +107,6 @@ func (h *host) Start() error { return responses.ToInternalError(fmt.Errorf("requested Start with the host stopping")) } - h.interrupter = breaker.Multiplex( - breaker.BreakBySignal( - os.Kill, - os.Interrupt, - ), - ) - h.validateConfig() // start all registered services @@ -180,7 +164,6 @@ func (h *host) Stop() error { h.stopControl.Stop() h.logger.Info("Host received a stop command. Attempting shutdown...") - h.interrupter.Close() // stop all registered services for name, service := range h.services.All() {