diff --git a/server.go b/server.go index 26419831d..bb71de677 100644 --- a/server.go +++ b/server.go @@ -74,9 +74,18 @@ func (s *Server) RegisterService(name string, desc *ServiceDesc) { } func (s *Server) Serve(ctx context.Context, l net.Listener) error { - s.addListener(l) + s.mu.Lock() + s.addListenerLocked(l) defer s.closeListener(l) + select { + case <-s.done: + s.mu.Unlock() + return ErrServerClosed + default: + } + s.mu.Unlock() + var ( backoff time.Duration handshaker = s.config.handshaker @@ -188,9 +197,7 @@ func (s *Server) Close() error { return err } -func (s *Server) addListener(l net.Listener) { - s.mu.Lock() - defer s.mu.Unlock() +func (s *Server) addListenerLocked(l net.Listener) { s.listeners[l] = struct{}{} } diff --git a/server_test.go b/server_test.go index cf34986d6..4a1561df8 100644 --- a/server_test.go +++ b/server_test.go @@ -298,6 +298,36 @@ func TestServerClose(t *testing.T) { checkServerShutdown(t, server) } +func TestImmediateServerShutdown(t *testing.T) { + var ( + ctx = context.Background() + server = mustServer(t)(NewServer()) + addr, listener = newTestListener(t) + errs = make(chan error, 1) + _, cleanup = newTestClient(t, addr) + ) + defer cleanup() + defer listener.Close() + go func() { + time.Sleep(1 * time.Millisecond) + errs <- server.Serve(ctx, listener) + }() + + registerTestingService(server, &testingServer{}) + + if err := server.Shutdown(ctx); err != nil { + t.Fatal(err) + } + select { + case err := <-errs: + if err != ErrServerClosed { + t.Fatal(err) + } + case <-time.After(2 * time.Second): + t.Fatal("retreiving error from server.Shutdown() timed out") + } +} + func TestOversizeCall(t *testing.T) { var ( ctx = context.Background()