diff --git a/layer4/listener.go b/layer4/listener.go index d2e4ae5..85d5193 100644 --- a/layer4/listener.go +++ b/layer4/listener.go @@ -66,6 +66,7 @@ func (lw *ListenerWrapper) WrapListener(l net.Listener) net.Listener { Listener: l, logger: lw.logger, compiledRoute: lw.compiledRoute, + done: make(chan struct{}), connChan: connChan, wg: new(sync.WaitGroup), } @@ -116,9 +117,9 @@ type listener struct { compiledRoute Handler closed atomic.Bool + done chan struct{} // closed when there is a non-recoverable error and all handle goroutines are done connChan chan net.Conn - err error // count running handles wg *sync.WaitGroup @@ -135,7 +136,6 @@ func (l *listener) loop() { conn, err := l.Listener.Accept() // listener closed if l.closed.Load() { - l.err = net.ErrClosed break } @@ -145,7 +145,6 @@ func (l *listener) loop() { continue } if err != nil { - l.err = err break } @@ -158,6 +157,7 @@ func (l *listener) loop() { l.wg.Wait() close(l.connChan) }() + close(l.done) for conn := range l.connChan { _ = conn.Close() } @@ -198,11 +198,15 @@ func (l *listener) handle(conn net.Conn) { } func (l *listener) Accept() (net.Conn, error) { - for conn := range l.connChan { - return conn, nil + select { + case conn, ok := <-l.connChan: + if ok { + return conn, nil + } + return nil, net.ErrClosed + case <-l.done: + return nil, net.ErrClosed } - return nil, l.err - } func (l *listener) pipeConnection(conn *Connection) error {