From 1213a2bfc4b32e19f6262b5fdfad5816a29213b2 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Mon, 8 Jul 2024 12:51:41 +0200 Subject: [PATCH] Wait for cleanup on certexchange server shutdown --- certexchange/server.go | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/certexchange/server.go b/certexchange/server.go index 23c24d4d..edc2b40e 100644 --- a/certexchange/server.go +++ b/certexchange/server.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "runtime/debug" + "sync" "time" "github.com/filecoin-project/go-f3/certstore" @@ -28,7 +29,8 @@ type Server struct { Host host.Host Store *certstore.Store - cancel context.CancelFunc + runningLk sync.RWMutex + stopFunc context.CancelFunc } func (s *Server) withDeadline(ctx context.Context) (context.Context, context.CancelFunc) { @@ -111,9 +113,25 @@ func (s *Server) handleRequest(ctx context.Context, stream network.Stream) (_err // Start the server. func (s *Server) Start() error { + s.runningLk.Lock() + defer s.runningLk.Unlock() + if s.stopFunc != nil { + return fmt.Errorf("certificate exchange already running") + } + ctx, cancel := context.WithCancel(context.Background()) - s.cancel = cancel + s.stopFunc = cancel s.Host.SetStreamHandler(FetchProtocolName(s.NetworkName), func(stream network.Stream) { + s.runningLk.RLock() + defer s.runningLk.RUnlock() + if s.stopFunc == nil { + _ = stream.Reset() + return + } + + // Kill the stream if/when we shutdown the server. + defer context.AfterFunc(ctx, func() { _ = stream.Reset() })() + ctx, cancel := s.withDeadline(ctx) defer cancel() @@ -129,7 +147,23 @@ func (s *Server) Start() error { // Stop the server. func (s *Server) Stop() error { + // Ask the handlers to cancel/stop. + s.runningLk.RLock() + cancel := s.stopFunc + s.runningLk.RUnlock() + if cancel == nil { + return nil + } + cancel() + + // Wait and finish shutdown. + s.runningLk.Lock() + defer s.runningLk.Unlock() + if s.stopFunc == nil { + return nil + } + s.stopFunc = nil s.Host.RemoveStreamHandler(FetchProtocolName(s.NetworkName)) - s.cancel() + return nil }