Skip to content

Commit

Permalink
Merge pull request #103177 from arkbriar/support_cancelable_exec_stream
Browse files Browse the repository at this point in the history
Support cancelable SPDY executor stream

Kubernetes-commit: 3cf75a2f760b8093f7c97f26b4b2b059f3777bec
  • Loading branch information
k8s-publishing-bot committed Nov 3, 2022
2 parents 19b2e89 + 0563dec commit bc6266d
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 30 deletions.
60 changes: 50 additions & 10 deletions tools/remotecommand/remotecommand.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package remotecommand

import (
"context"
"fmt"
"io"
"net/http"
Expand All @@ -27,7 +28,7 @@ import (
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/remotecommand"
restclient "k8s.io/client-go/rest"
spdy "k8s.io/client-go/transport/spdy"
"k8s.io/client-go/transport/spdy"
)

// StreamOptions holds information pertaining to the current streaming session:
Expand All @@ -43,11 +44,16 @@ type StreamOptions struct {

// Executor is an interface for transporting shell-style streams.
type Executor interface {
// Stream initiates the transport of the standard shell streams. It will transport any
// non-nil stream to a remote system, and return an error if a problem occurs. If tty
// is set, the stderr stream is not used (raw TTY manages stdout and stderr over the
// stdout stream).
// Deprecated: use StreamWithContext instead to avoid possible resource leaks.
// See https://github.com/kubernetes/kubernetes/pull/103177 for details.
Stream(options StreamOptions) error

// StreamWithContext initiates the transport of the standard shell streams. It will
// transport any non-nil stream to a remote system, and return an error if a problem
// occurs. If tty is set, the stderr stream is not used (raw TTY manages stdout and
// stderr over the stdout stream).
// The context controls the entire lifetime of stream execution.
StreamWithContext(ctx context.Context, options StreamOptions) error
}

type streamCreator interface {
Expand Down Expand Up @@ -106,9 +112,14 @@ func NewSPDYExecutorForProtocols(transport http.RoundTripper, upgrader spdy.Upgr
// Stream opens a protocol streamer to the server and streams until a client closes
// the connection or the server disconnects.
func (e *streamExecutor) Stream(options StreamOptions) error {
req, err := http.NewRequest(e.method, e.url.String(), nil)
return e.StreamWithContext(context.Background(), options)
}

// newConnectionAndStream creates a new SPDY connection and a stream protocol handler upon it.
func (e *streamExecutor) newConnectionAndStream(ctx context.Context, options StreamOptions) (httpstream.Connection, streamProtocolHandler, error) {
req, err := http.NewRequestWithContext(ctx, e.method, e.url.String(), nil)
if err != nil {
return fmt.Errorf("error creating request: %v", err)
return nil, nil, fmt.Errorf("error creating request: %v", err)
}

conn, protocol, err := spdy.Negotiate(
Expand All @@ -118,9 +129,8 @@ func (e *streamExecutor) Stream(options StreamOptions) error {
e.protocols...,
)
if err != nil {
return err
return nil, nil, err
}
defer conn.Close()

var streamer streamProtocolHandler

Expand All @@ -138,5 +148,35 @@ func (e *streamExecutor) Stream(options StreamOptions) error {
streamer = newStreamProtocolV1(options)
}

return streamer.stream(conn)
return conn, streamer, nil
}

// StreamWithContext opens a protocol streamer to the server and streams until a client closes
// the connection or the server disconnects or the context is done.
func (e *streamExecutor) StreamWithContext(ctx context.Context, options StreamOptions) error {
conn, streamer, err := e.newConnectionAndStream(ctx, options)
if err != nil {
return err
}
defer conn.Close()

panicChan := make(chan any, 1)
errorChan := make(chan error, 1)
go func() {
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()
errorChan <- streamer.stream(conn)
}()

select {
case p := <-panicChan:
panic(p)
case err := <-errorChan:
return err
case <-ctx.Done():
return ctx.Err()
}
}
142 changes: 122 additions & 20 deletions tools/remotecommand/remotecommand_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,17 @@ limitations under the License.
package remotecommand

import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

v1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -28,12 +36,6 @@ import (
remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/rest"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)

type AttachFunc func(in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan TerminalSize) error
Expand All @@ -50,6 +52,17 @@ type streamAndReply struct {
replySent <-chan struct{}
}

type fakeEmptyDataPty struct {
}

func (s *fakeEmptyDataPty) Read(p []byte) (int, error) {
return len(p), nil
}

func (s *fakeEmptyDataPty) Write(p []byte) (int, error) {
return len(p), nil
}

type fakeMassiveDataPty struct{}

func (s *fakeMassiveDataPty) Read(p []byte) (int, error) {
Expand Down Expand Up @@ -107,6 +120,7 @@ func writeMassiveData(stdStream io.Writer) struct{} { // write to stdin or stdou

func TestSPDYExecutorStream(t *testing.T) {
tests := []struct {
timeout time.Duration
name string
options StreamOptions
expectError string
Expand All @@ -130,23 +144,40 @@ func TestSPDYExecutorStream(t *testing.T) {
expectError: "",
attacher: fakeMassiveDataAttacher,
},
{
timeout: 500 * time.Millisecond,
name: "timeoutTest",
options: StreamOptions{
Stdin: &fakeMassiveDataPty{},
Stderr: &fakeMassiveDataPty{},
},
expectError: context.DeadlineExceeded.Error(),
attacher: fakeMassiveDataAttacher,
},
}

for _, test := range tests {
server := newTestHTTPServer(test.attacher, &test.options)
t.Run(test.name, func(t *testing.T) {
server := newTestHTTPServer(test.attacher, &test.options)
defer server.Close()

err := attach2Server(server.URL, test.options)
gotError := ""
if err != nil {
gotError = err.Error()
}
if test.expectError != gotError {
t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError)
}
ctx, cancel := context.Background(), func() {}
if test.timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, test.timeout)
}
defer cancel()

server.Close()
}
err := attach2Server(ctx, server.URL, test.options)

gotError := ""
if err != nil {
gotError = err.Error()
}
if test.expectError != gotError {
t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError)
}
})
}
}

func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server {
Expand All @@ -170,16 +201,16 @@ func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server {
return server
}

func attach2Server(rawURL string, options StreamOptions) error {
func attach2Server(ctx context.Context, rawURL string, options StreamOptions) error {
uri, _ := url.Parse(rawURL)
exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
if err != nil {
return err
}

e := make(chan error)
e := make(chan error, 1)
go func(e chan error) {
e <- exec.Stream(options)
e <- exec.StreamWithContext(ctx, options)
}(e)
select {
case err := <-e:
Expand Down Expand Up @@ -263,3 +294,74 @@ func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) err
return err
}
}

// writeDetector provides a helper method to block until the underlying writer written.
type writeDetector struct {
written chan bool
closed bool
io.Writer
}

func newWriterDetector(w io.Writer) *writeDetector {
return &writeDetector{
written: make(chan bool),
Writer: w,
}
}

func (w *writeDetector) BlockUntilWritten() {
<-w.written
}

func (w *writeDetector) Write(p []byte) (n int, err error) {
if !w.closed {
close(w.written)
w.closed = true
}
return w.Writer.Write(p)
}

// `Executor.StreamWithContext` starts a goroutine in the background to do the streaming
// and expects the deferred close of the connection leads to the exit of the goroutine on cancellation.
// This test verifies that works.
func TestStreamExitsAfterConnectionIsClosed(t *testing.T) {
writeDetector := newWriterDetector(&fakeEmptyDataPty{})
options := StreamOptions{
Stdin: &fakeEmptyDataPty{},
Stdout: writeDetector,
}
server := newTestHTTPServer(fakeMassiveDataAttacher, &options)

ctx, cancelFn := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancelFn()

uri, _ := url.Parse(server.URL)
exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
if err != nil {
t.Fatal(err)
}
streamExec := exec.(*streamExecutor)

conn, streamer, err := streamExec.newConnectionAndStream(ctx, options)
if err != nil {
t.Fatal(err)
}

errorChan := make(chan error)
go func() {
errorChan <- streamer.stream(conn)
}()

// Wait until stream goroutine starts.
writeDetector.BlockUntilWritten()

// Close the connection
conn.Close()

select {
case <-time.After(1 * time.Second):
t.Fatalf("expect stream to be closed after connection is closed.")
case <-errorChan:
return
}
}

0 comments on commit bc6266d

Please sign in to comment.