From 7059ce13f13a5fe613582feb650a7189aea4568b Mon Sep 17 00:00:00 2001
From: Peter Waller
Date: Thu, 7 May 2015 19:44:51 +0100
Subject: [PATCH] Switch to gorilla/websocket
---
main.go | 72 ++++++++++++++++++++++++++++++++++++++++++++++-----------
1 file changed, 58 insertions(+), 14 deletions(-)
diff --git a/main.go b/main.go
index 66e8b1b..606dba2 100644
--- a/main.go
+++ b/main.go
@@ -1,6 +1,8 @@
package main
import (
+ "encoding/base64"
+ "fmt"
"io"
"log"
"net/http"
@@ -9,7 +11,7 @@ import (
"regexp"
"github.com/codegangsta/cli"
- "golang.org/x/net/websocket"
+ "github.com/gorilla/websocket"
)
func main() {
@@ -21,8 +23,8 @@ func main() {
app.Flags = []cli.Flag{
cli.StringFlag{
Name: "origin",
- Value: "http://localhost/",
- Usage: "value to use for the origin header",
+ Value: "samehost",
+ Usage: "URL to use for the origin header ('samehost' is special)",
EnvVar: "WSCAT_ORIGIN",
},
cli.StringSliceFlag{
@@ -76,15 +78,25 @@ func ActionMain(c *cli.Context) {
log.Fatalf("usage: wscat ")
}
- url := args.First()
+ urlString := args.First()
- config := &websocket.Config{}
- config.Location = MustParseURL(url)
- config.Origin = MustParseURL(c.String("origin"))
- config.Header = MustParseHeaders(c)
- config.Version = websocket.ProtocolVersionHybi13
+ u := MustParseURL(urlString)
- conn, err := websocket.DialConfig(config)
+ headers := MustParseHeaders(c)
+ origin := c.String("origin")
+ if origin == "samehost" {
+ origin = "//" + u.Host
+ }
+ headers.Set("Origin", origin)
+
+ if u.User != nil {
+ userPassBytes := []byte(u.User.String())
+ token := base64.StdEncoding.EncodeToString(userPassBytes)
+ headers.Set("Authorization", fmt.Sprintf("Basic %v", token))
+ u.User = nil
+ }
+
+ conn, _, err := websocket.DefaultDialer.Dial(u.String(), headers)
if err != nil {
log.Fatalf("Error dialing: %v", err)
}
@@ -93,18 +105,50 @@ func ActionMain(c *cli.Context) {
errc := make(chan error)
go func() {
- _, err := io.Copy(os.Stdout, conn)
- if err != io.EOF && err != nil {
+ // _, err := io.Copy(os.Stdout, conn)
+ var (
+ err error
+ r io.Reader
+ )
+ for {
+ _, r, err = conn.NextReader()
+ if err != nil {
+ break
+ }
+ _, err = io.Copy(os.Stdout, r)
+ if err != nil {
+ break
+ }
+ }
+ if err != io.EOF {
log.Printf("Error copying to stdout: %v", err)
}
errc <- err
}()
go func() {
- _, err = io.Copy(conn, os.Stdin)
- if err != io.EOF && err != nil {
+ var (
+ err error
+ w io.Writer
+ )
+
+ for {
+ w, err = conn.NextWriter(websocket.BinaryMessage)
+ if err != nil {
+ break
+ }
+ _, err = io.Copy(w, os.Stdin)
+ if err != nil {
+ break
+ }
+
+ break
+ }
+
+ if err != nil && err != io.EOF {
log.Printf("Error copying from stdin: %v", err)
}
+
errc <- err
}()