diff --git a/tools/walletextension/container/walletextension_container.go b/tools/walletextension/container/walletextension_container.go index 1dacaefa11..99bc8b6f63 100644 --- a/tools/walletextension/container/walletextension_container.go +++ b/tools/walletextension/container/walletextension_container.go @@ -159,7 +159,14 @@ func (w *WalletExtensionContainer) Start() error { for { select { case err := <-httpErrChan: - if !errors.Is(err, http.ErrServerClosed) { + if errors.Is(err, http.ErrServerClosed) { + err = w.Stop() // Stop the container when the HTTP server is closed + if err != nil { + fmt.Printf("failed to stop gracefully - %s\n", err) + os.Exit(1) + } + } else { + // for other errors, we just log them w.logger.Error("HTTP server error: %v", err) } case <-w.stopControl.Done(): @@ -173,8 +180,15 @@ func (w *WalletExtensionContainer) Start() error { for { select { case err := <-wsErrChan: - if !errors.Is(err, http.ErrServerClosed) { - w.logger.Error("WebSocket server error: %v", err) + if errors.Is(err, http.ErrServerClosed) { + err = w.Stop() // Stop the container when the WS server is closed + if err != nil { + fmt.Printf("failed to stop gracefully - %s\n", err) + os.Exit(1) + } + } else { + // for other errors, we just log them + w.logger.Error("HTTP server error: %v", err) } case <-w.stopControl.Done(): return // Exit the goroutine when stop signal is received diff --git a/tools/walletextension/main/main.go b/tools/walletextension/main/main.go index b9e9ed73ab..b909f43368 100644 --- a/tools/walletextension/main/main.go +++ b/tools/walletextension/main/main.go @@ -59,19 +59,12 @@ func main() { logger := log.New(log.WalletExtCmp, int(logLvl), config.LogPath) walletExtContainer := container.NewWalletExtensionContainerFromConfig(config, logger) - defer func() { - err := walletExtContainer.Start() - if err != nil { - fmt.Printf("error stopping WE - %s", err) - } - }() - go func() { - err := walletExtContainer.Start() - if err != nil { - fmt.Printf("error in WE - %s", err) - } - }() + // Start the wallet extension. + err := walletExtContainer.Start() + if err != nil { + fmt.Printf("error in WE - %s", err) + } walletExtensionAddr := fmt.Sprintf("%s:%d", common.Localhost, config.WalletExtensionPortHTTP) fmt.Printf("💡 Wallet extension started \n") // Some tests rely on seeing this message. Removed in next PR.