From 7059ce13f13a5fe613582feb650a7189aea4568b Mon Sep 17 00:00:00 2001 From: Peter Waller Date: Thu, 7 May 2015 19:44:51 +0100 Subject: [PATCH] Switch to gorilla/websocket --- main.go | 72 ++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 58 insertions(+), 14 deletions(-) diff --git a/main.go b/main.go index 66e8b1b..606dba2 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,8 @@ package main import ( + "encoding/base64" + "fmt" "io" "log" "net/http" @@ -9,7 +11,7 @@ import ( "regexp" "github.com/codegangsta/cli" - "golang.org/x/net/websocket" + "github.com/gorilla/websocket" ) func main() { @@ -21,8 +23,8 @@ func main() { app.Flags = []cli.Flag{ cli.StringFlag{ Name: "origin", - Value: "http://localhost/", - Usage: "value to use for the origin header", + Value: "samehost", + Usage: "URL to use for the origin header ('samehost' is special)", EnvVar: "WSCAT_ORIGIN", }, cli.StringSliceFlag{ @@ -76,15 +78,25 @@ func ActionMain(c *cli.Context) { log.Fatalf("usage: wscat ") } - url := args.First() + urlString := args.First() - config := &websocket.Config{} - config.Location = MustParseURL(url) - config.Origin = MustParseURL(c.String("origin")) - config.Header = MustParseHeaders(c) - config.Version = websocket.ProtocolVersionHybi13 + u := MustParseURL(urlString) - conn, err := websocket.DialConfig(config) + headers := MustParseHeaders(c) + origin := c.String("origin") + if origin == "samehost" { + origin = "//" + u.Host + } + headers.Set("Origin", origin) + + if u.User != nil { + userPassBytes := []byte(u.User.String()) + token := base64.StdEncoding.EncodeToString(userPassBytes) + headers.Set("Authorization", fmt.Sprintf("Basic %v", token)) + u.User = nil + } + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), headers) if err != nil { log.Fatalf("Error dialing: %v", err) } @@ -93,18 +105,50 @@ func ActionMain(c *cli.Context) { errc := make(chan error) go func() { - _, err := io.Copy(os.Stdout, conn) - if err != io.EOF && err != nil { + // _, err := io.Copy(os.Stdout, conn) + var ( + err error + r io.Reader + ) + for { + _, r, err = conn.NextReader() + if err != nil { + break + } + _, err = io.Copy(os.Stdout, r) + if err != nil { + break + } + } + if err != io.EOF { log.Printf("Error copying to stdout: %v", err) } errc <- err }() go func() { - _, err = io.Copy(conn, os.Stdin) - if err != io.EOF && err != nil { + var ( + err error + w io.Writer + ) + + for { + w, err = conn.NextWriter(websocket.BinaryMessage) + if err != nil { + break + } + _, err = io.Copy(w, os.Stdin) + if err != nil { + break + } + + break + } + + if err != nil && err != io.EOF { log.Printf("Error copying from stdin: %v", err) } + errc <- err }()