diff --git a/tools/remotecommand/remotecommand.go b/tools/remotecommand/remotecommand.go index cb39faf7f1..662a3cb4ac 100644 --- a/tools/remotecommand/remotecommand.go +++ b/tools/remotecommand/remotecommand.go @@ -17,6 +17,7 @@ limitations under the License. package remotecommand import ( + "context" "fmt" "io" "net/http" @@ -27,7 +28,7 @@ import ( "k8s.io/apimachinery/pkg/util/httpstream" "k8s.io/apimachinery/pkg/util/remotecommand" restclient "k8s.io/client-go/rest" - spdy "k8s.io/client-go/transport/spdy" + "k8s.io/client-go/transport/spdy" ) // StreamOptions holds information pertaining to the current streaming session: @@ -43,11 +44,16 @@ type StreamOptions struct { // Executor is an interface for transporting shell-style streams. type Executor interface { - // Stream initiates the transport of the standard shell streams. It will transport any - // non-nil stream to a remote system, and return an error if a problem occurs. If tty - // is set, the stderr stream is not used (raw TTY manages stdout and stderr over the - // stdout stream). + // Deprecated: use StreamWithContext instead to avoid possible resource leaks. + // See https://github.com/kubernetes/kubernetes/pull/103177 for details. Stream(options StreamOptions) error + + // StreamWithContext initiates the transport of the standard shell streams. It will + // transport any non-nil stream to a remote system, and return an error if a problem + // occurs. If tty is set, the stderr stream is not used (raw TTY manages stdout and + // stderr over the stdout stream). + // The context controls the entire lifetime of stream execution. + StreamWithContext(ctx context.Context, options StreamOptions) error } type streamCreator interface { @@ -106,9 +112,14 @@ func NewSPDYExecutorForProtocols(transport http.RoundTripper, upgrader spdy.Upgr // Stream opens a protocol streamer to the server and streams until a client closes // the connection or the server disconnects. func (e *streamExecutor) Stream(options StreamOptions) error { - req, err := http.NewRequest(e.method, e.url.String(), nil) + return e.StreamWithContext(context.Background(), options) +} + +// newConnectionAndStream creates a new SPDY connection and a stream protocol handler upon it. +func (e *streamExecutor) newConnectionAndStream(ctx context.Context, options StreamOptions) (httpstream.Connection, streamProtocolHandler, error) { + req, err := http.NewRequestWithContext(ctx, e.method, e.url.String(), nil) if err != nil { - return fmt.Errorf("error creating request: %v", err) + return nil, nil, fmt.Errorf("error creating request: %v", err) } conn, protocol, err := spdy.Negotiate( @@ -118,9 +129,8 @@ func (e *streamExecutor) Stream(options StreamOptions) error { e.protocols..., ) if err != nil { - return err + return nil, nil, err } - defer conn.Close() var streamer streamProtocolHandler @@ -138,5 +148,35 @@ func (e *streamExecutor) Stream(options StreamOptions) error { streamer = newStreamProtocolV1(options) } - return streamer.stream(conn) + return conn, streamer, nil +} + +// StreamWithContext opens a protocol streamer to the server and streams until a client closes +// the connection or the server disconnects or the context is done. +func (e *streamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error { + conn, streamer, err := e.newConnectionAndStream(ctx, options) + if err != nil { + return err + } + defer conn.Close() + + panicChan := make(chan any, 1) + errorChan := make(chan error, 1) + go func() { + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + errorChan <- streamer.stream(conn) + }() + + select { + case p := <-panicChan: + panic(p) + case err := <-errorChan: + return err + case <-ctx.Done(): + return ctx.Err() + } } diff --git a/tools/remotecommand/remotecommand_test.go b/tools/remotecommand/remotecommand_test.go index 7eec4565ed..9144a14526 100644 --- a/tools/remotecommand/remotecommand_test.go +++ b/tools/remotecommand/remotecommand_test.go @@ -17,9 +17,17 @@ limitations under the License. package remotecommand import ( + "context" "encoding/json" "errors" "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + v1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -28,12 +36,6 @@ import ( remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/rest" - "net/http" - "net/http/httptest" - "net/url" - "strings" - "testing" - "time" ) type AttachFunc func(in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan TerminalSize) error @@ -50,6 +52,17 @@ type streamAndReply struct { replySent <-chan struct{} } +type fakeEmptyDataPty struct { +} + +func (s *fakeEmptyDataPty) Read(p []byte) (int, error) { + return len(p), nil +} + +func (s *fakeEmptyDataPty) Write(p []byte) (int, error) { + return len(p), nil +} + type fakeMassiveDataPty struct{} func (s *fakeMassiveDataPty) Read(p []byte) (int, error) { @@ -107,6 +120,7 @@ func writeMassiveData(stdStream io.Writer) struct{} { // write to stdin or stdou func TestSPDYExecutorStream(t *testing.T) { tests := []struct { + timeout time.Duration name string options StreamOptions expectError string @@ -130,23 +144,40 @@ func TestSPDYExecutorStream(t *testing.T) { expectError: "", attacher: fakeMassiveDataAttacher, }, + { + timeout: 500 * time.Millisecond, + name: "timeoutTest", + options: StreamOptions{ + Stdin: &fakeMassiveDataPty{}, + Stderr: &fakeMassiveDataPty{}, + }, + expectError: context.DeadlineExceeded.Error(), + attacher: fakeMassiveDataAttacher, + }, } for _, test := range tests { - server := newTestHTTPServer(test.attacher, &test.options) + t.Run(test.name, func(t *testing.T) { + server := newTestHTTPServer(test.attacher, &test.options) + defer server.Close() - err := attach2Server(server.URL, test.options) - gotError := "" - if err != nil { - gotError = err.Error() - } - if test.expectError != gotError { - t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError) - } + ctx, cancel := context.Background(), func() {} + if test.timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, test.timeout) + } + defer cancel() - server.Close() - } + err := attach2Server(ctx, server.URL, test.options) + gotError := "" + if err != nil { + gotError = err.Error() + } + if test.expectError != gotError { + t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError) + } + }) + } } func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { @@ -170,16 +201,16 @@ func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server { return server } -func attach2Server(rawURL string, options StreamOptions) error { +func attach2Server(ctx context.Context, rawURL string, options StreamOptions) error { uri, _ := url.Parse(rawURL) exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri) if err != nil { return err } - e := make(chan error) + e := make(chan error, 1) go func(e chan error) { - e <- exec.Stream(options) + e <- exec.StreamWithContext(ctx, options) }(e) select { case err := <-e: @@ -263,3 +294,74 @@ func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) err return err } } + +// writeDetector provides a helper method to block until the underlying writer written. +type writeDetector struct { + written chan bool + closed bool + io.Writer +} + +func newWriterDetector(w io.Writer) *writeDetector { + return &writeDetector{ + written: make(chan bool), + Writer: w, + } +} + +func (w *writeDetector) BlockUntilWritten() { + <-w.written +} + +func (w *writeDetector) Write(p []byte) (n int, err error) { + if !w.closed { + close(w.written) + w.closed = true + } + return w.Writer.Write(p) +} + +// `Executor.StreamWithContext` starts a goroutine in the background to do the streaming +// and expects the deferred close of the connection leads to the exit of the goroutine on cancellation. +// This test verifies that works. +func TestStreamExitsAfterConnectionIsClosed(t *testing.T) { + writeDetector := newWriterDetector(&fakeEmptyDataPty{}) + options := StreamOptions{ + Stdin: &fakeEmptyDataPty{}, + Stdout: writeDetector, + } + server := newTestHTTPServer(fakeMassiveDataAttacher, &options) + + ctx, cancelFn := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancelFn() + + uri, _ := url.Parse(server.URL) + exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri) + if err != nil { + t.Fatal(err) + } + streamExec := exec.(*streamExecutor) + + conn, streamer, err := streamExec.newConnectionAndStream(ctx, options) + if err != nil { + t.Fatal(err) + } + + errorChan := make(chan error) + go func() { + errorChan <- streamer.stream(conn) + }() + + // Wait until stream goroutine starts. + writeDetector.BlockUntilWritten() + + // Close the connection + conn.Close() + + select { + case <-time.After(1 * time.Second): + t.Fatalf("expect stream to be closed after connection is closed.") + case <-errorChan: + return + } +}