diff --git a/.errcheck.exclude b/.errcheck.exclude index b663935..aba5683 100644 --- a/.errcheck.exclude +++ b/.errcheck.exclude @@ -1,4 +1,5 @@ io.Copy +(*os.Process).Kill (net.Conn).Close fmt.Fprintf fmt.Fprintln diff --git a/Makefile b/Makefile index 857f14c..a193dc1 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ test-dep: dep go test -i -v ./... test: test-dep - go test -v ./... + go test -race -v ./... release: $(addsuffix .tar.gz,$(addprefix build/$(EXECUTABLE)-$(VERSION)_,$(subst /,_,$(BUILD_PLATFORMS)))) release: $(addsuffix .tar.gz.sha256,$(addprefix build/$(EXECUTABLE)-$(VERSION)_,$(subst /,_,$(BUILD_PLATFORMS)))) diff --git a/sshproxy.go b/sshproxy.go index a32890d..0b899ce 100644 --- a/sshproxy.go +++ b/sshproxy.go @@ -272,24 +272,21 @@ func (s *Server) handleRequests(reqs <-chan *ssh.Request, channel ssh.Channel, c go func() { done := make(chan error, 1) go func() { - done <- cmd.Wait() - }() - Loop: - for { - select { - case <-time.After(10 * time.Second): - if _, err := channel.SendRequest("ping", false, []byte{}); err != nil { - // Channel is dead, kill process - if err := cmd.Process.Kill(); err != nil { - s.handleError(err, nil) + for { + select { + case <-time.After(10 * time.Second): + if _, err := channel.SendRequest("ping", false, []byte{}); err != nil { + // Channel is dead, attempt to kill process + cmd.Process.Kill() + break } - break Loop + case <-done: + break } - case <-done: - break Loop } - } + }() + done <- cmd.Wait() exitStatusPayload := make([]byte, 4) exitStatus := uint32(1) if cmd.ProcessState != nil { diff --git a/sshproxy_test.go b/sshproxy_test.go new file mode 100644 index 0000000..85c28bb --- /dev/null +++ b/sshproxy_test.go @@ -0,0 +1,62 @@ +package sshproxy_test + +import ( + "net" + "testing" + "time" + + "github.com/balena-io/sshproxy" + "golang.org/x/crypto/ssh" +) + +func TestRace(t *testing.T) { + server, err := sshproxy.New( + "/tmp", + "/bin/bash", + false, + nil, + 3, + nil, + func(err error, tags map[string]string) { + t.Logf("uncaught error: %s", err) + }) + + if err != nil { + t.Fatalf("error calling sshproxy.New :( %s", err) + } + + go func() { + if err := server.Listen("12345"); err != nil { + t.Fatalf("Cannot start server! %s", err) + } + }() + + config := &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + return nil + }, + } + for i := 0; i < 10; i++ { + client, err := ssh.Dial("tcp", "localhost:12345", config) + if err != nil { + t.Errorf("Cannot connect to server :( %s", err) + } + session, err := client.NewSession() + if err != nil { + t.Errorf("Cannot create session :( %s", err) + } + time.Sleep(time.Second) + _, err = session.SendRequest("exec", false, []byte{0, 0, 0, 4, 't', 'e', 's', 't'}) + if err != nil { + t.Errorf("Cannot send exec request :( %q", err) + } + time.Sleep(time.Duration(i*100) * time.Millisecond) + if err := client.Close(); err != nil { + t.Errorf("Error closing client - %s", err) + } + } +}