diff --git a/pkg/ipc/restful/restful.go b/pkg/ipc/restful/restful.go index b268b2a..0de2efa 100644 --- a/pkg/ipc/restful/restful.go +++ b/pkg/ipc/restful/restful.go @@ -162,23 +162,30 @@ func (s *Restful) mux() *http.ServeMux { } outCh := infinity.NewChannel[string]() - errCh := make(chan string) - doneCh := make(chan struct{}) + errCh := make(chan string, 1) + doneCh := make(chan struct{}, 1) go func() { - if err := s.exec(body.Command, outCh, errCh); err != nil { + if err := s.exec(r.Context(), body.Command, outCh, errCh); err != nil { s.log.Warnf("Failed to execute command: %v", err) } - _, _ = fmt.Fprintf(w, "event: done\n") - _, _ = fmt.Fprintf(w, "data: done\n\n") - w.(http.Flusher).Flush() - doneCh <- struct{}{} outCh.Close() close(errCh) }() + defer func() { + select { + case <-r.Context().Done(): + // pass + default: + _, _ = fmt.Fprintf(w, "event: done\n") + _, _ = fmt.Fprintf(w, "data: done\n\n") + w.(http.Flusher).Flush() + } + }() + for { select { case <-doneCh: @@ -295,7 +302,7 @@ func (s *Restful) powerSaveMode(enable bool) { s.opt.PowerSaveMode = enable } -func (s *Restful) exec(command string, outCh *infinity.Channel[string], errCh chan string) error { +func (s *Restful) exec(ctx context.Context, command string, outCh *infinity.Channel[string], errCh chan string) error { s.log.Info("request /exec") conf := &ssh.ClientConfig{ @@ -313,6 +320,10 @@ func (s *Restful) exec(command string, outCh *infinity.Channel[string], errCh ch } defer conn.Close() + context.AfterFunc(ctx, func() { + _ = conn.Close() + }) + session, err := conn.NewSession() if err != nil { errCh <- fmt.Sprintf("new ssh session error: %v", err)