Skip to content

Commit

Permalink
Merge pull request #1 from pwaller/gorilla-websocket
Browse files Browse the repository at this point in the history
Switch to gorilla/websocket
  • Loading branch information
pwaller committed May 7, 2015
2 parents 5abde5d + 7059ce1 commit 23da23b
Showing 1 changed file with 58 additions and 14 deletions.
72 changes: 58 additions & 14 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"encoding/base64"
"fmt"
"io"
"log"
"net/http"
Expand All @@ -9,7 +11,7 @@ import (
"regexp"

"github.com/codegangsta/cli"
"golang.org/x/net/websocket"
"github.com/gorilla/websocket"
)

func main() {
Expand All @@ -21,8 +23,8 @@ func main() {
app.Flags = []cli.Flag{
cli.StringFlag{
Name: "origin",
Value: "http://localhost/",
Usage: "value to use for the origin header",
Value: "samehost",
Usage: "URL to use for the origin header ('samehost' is special)",
EnvVar: "WSCAT_ORIGIN",
},
cli.StringSliceFlag{
Expand Down Expand Up @@ -76,15 +78,25 @@ func ActionMain(c *cli.Context) {
log.Fatalf("usage: wscat <url>")
}

url := args.First()
urlString := args.First()

config := &websocket.Config{}
config.Location = MustParseURL(url)
config.Origin = MustParseURL(c.String("origin"))
config.Header = MustParseHeaders(c)
config.Version = websocket.ProtocolVersionHybi13
u := MustParseURL(urlString)

conn, err := websocket.DialConfig(config)
headers := MustParseHeaders(c)
origin := c.String("origin")
if origin == "samehost" {
origin = "//" + u.Host
}
headers.Set("Origin", origin)

if u.User != nil {
userPassBytes := []byte(u.User.String())
token := base64.StdEncoding.EncodeToString(userPassBytes)
headers.Set("Authorization", fmt.Sprintf("Basic %v", token))
u.User = nil
}

conn, _, err := websocket.DefaultDialer.Dial(u.String(), headers)
if err != nil {
log.Fatalf("Error dialing: %v", err)
}
Expand All @@ -93,18 +105,50 @@ func ActionMain(c *cli.Context) {
errc := make(chan error)

go func() {
_, err := io.Copy(os.Stdout, conn)
if err != io.EOF && err != nil {
// _, err := io.Copy(os.Stdout, conn)
var (
err error
r io.Reader
)
for {
_, r, err = conn.NextReader()
if err != nil {
break
}
_, err = io.Copy(os.Stdout, r)
if err != nil {
break
}
}
if err != io.EOF {
log.Printf("Error copying to stdout: %v", err)
}
errc <- err
}()

go func() {
_, err = io.Copy(conn, os.Stdin)
if err != io.EOF && err != nil {
var (
err error
w io.Writer
)

for {
w, err = conn.NextWriter(websocket.BinaryMessage)
if err != nil {
break
}
_, err = io.Copy(w, os.Stdin)
if err != nil {
break
}

break
}

if err != nil && err != io.EOF {
log.Printf("Error copying from stdin: %v", err)
}

errc <- err
}()

Expand Down

0 comments on commit 23da23b

Please sign in to comment.