diff --git a/tsshd/main.go b/tsshd/main.go index e90f30a..f4ade88 100644 --- a/tsshd/main.go +++ b/tsshd/main.go @@ -29,6 +29,8 @@ import ( "io" "os" "os/exec" + "os/signal" + "syscall" "time" "github.com/trzsz/go-arg" @@ -99,6 +101,9 @@ func TsshdMain() int { // cleanup on exit defer cleanupOnExit() + // handle exit signals + handleExitSignals() + kcpListener, quicListener, err := initServer(&args) if err != nil { fmt.Println(err) @@ -127,3 +132,18 @@ func TsshdMain() int { return <-exitChan } + +func handleExitSignals() { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, + syscall.SIGTERM, // Default signal for the kill command + syscall.SIGINT, // Ctrl+C signal + syscall.SIGHUP, // Terminal closed (System reboot/shutdown) + ) + + go func() { + <-sigChan + trySendErrorMessage("tsshd has been terminated") + closeAllSessions() + }() +} diff --git a/tsshd/session.go b/tsshd/session.go index 52b71ee..ad94e21 100644 --- a/tsshd/session.go +++ b/tsshd/session.go @@ -48,6 +48,7 @@ type sessionContext struct { stdout io.ReadCloser stderr io.ReadCloser started bool + closed bool } type stderrStream struct { @@ -130,6 +131,10 @@ func (c *sessionContext) Wait() { } func (c *sessionContext) Close() { + if c.closed { + return + } + c.closed = true if err := sendBusMessage("exit", ExitMessage{ ID: c.id, ExitCode: c.cmd.ProcessState.ExitCode(), @@ -447,3 +452,16 @@ func handleChannelAccept(listener net.Listener, channelType string) { }(conn) } } + +func closeAllSessions() { + sessionMutex.Lock() + var sessions []*sessionContext + for _, session := range sessionMap { + sessions = append(sessions, session) + } + sessionMutex.Unlock() + + for _, session := range sessions { + session.Close() + } +}