Skip to content

Commit

Permalink
Switch to cli
Browse files Browse the repository at this point in the history
  • Loading branch information
pwaller committed May 7, 2015
1 parent 8e46446 commit 5abde5d
Showing 1 changed file with 83 additions and 38 deletions.
121 changes: 83 additions & 38 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,67 +1,112 @@
package main

import (
"crypto/tls"
"fmt"
"io"
"log"
"net"
"net/http"
"net/url"
"os"
"regexp"

"github.com/gorilla/websocket"
"github.com/codegangsta/cli"
"golang.org/x/net/websocket"
)

func main() {
if len(os.Args) < 2 {
log.Fatal("usage: wscat <url>")
app := cli.NewApp()
app.Name = "wscat"
app.Usage = "cat, but for websockets"
app.Action = ActionMain

app.Flags = []cli.Flag{
cli.StringFlag{
Name: "origin",
Value: "http://localhost/",
Usage: "value to use for the origin header",
EnvVar: "WSCAT_ORIGIN",
},
cli.StringSliceFlag{
Name: "header, H",
Usage: "headers to pass to the remote",
Value: &cli.StringSlice{},
EnvVar: "WSCAT_HEADER",
},
}

u, err := url.Parse(os.Args[1])
if err != nil {
log.Fatal("Fail to parse:", err)
app.Run(os.Args)
}

var RegexParseHeader = regexp.MustCompile("^\\s*([^\\:]+)\\s*:\\s*(.*)$")

func MustParseHeader(header string) (string, string) {
if !RegexParseHeader.MatchString(header) {
log.Fatalf("Unable to parse header: %v (re: %v)", header,
RegexParseHeader.String())
return "", ""
}

useTLS := false

if _, _, err := net.SplitHostPort(u.Host); err != nil {
// no port specified
switch u.Scheme {
case "ws", "http", "":
u.Host += ":80"
case "wss", "https":
u.Host += ":443"
default:
log.Fatal("Unsupported URL scheme: %q", u.Scheme)
}
parts := RegexParseHeader.FindStringSubmatch(header)
return parts[1], parts[2]
}

func MustParseHeaders(c *cli.Context) http.Header {
headers := http.Header{}

for _, h := range c.StringSlice("header") {
key, value := MustParseHeader(h)
headers.Set(key, value)
}

conn, err := net.Dial("tcp", u.Host)
return headers
}

func MustParseURL(u string) *url.URL {
tgt, err := url.ParseRequestURI(u)
if err != nil {
log.Fatal("Failed to connect:", err)
log.Fatalf("Unable to parse URL: %v: %v", u, err)
}
defer conn.Close()
return tgt
}

func ActionMain(c *cli.Context) {

if useTLS {
conn = tls.Client(conn, nil)
defer conn.Close()
args := c.Args()

if len(args) < 1 {
log.Fatalf("usage: wscat <url>")
}

socket, _, err := websocket.NewClient(conn, u, http.Header{}, 1024, 1024)
url := args.First()

config := &websocket.Config{}
config.Location = MustParseURL(url)
config.Origin = MustParseURL(c.String("origin"))
config.Header = MustParseHeaders(c)
config.Version = websocket.ProtocolVersionHybi13

conn, err := websocket.DialConfig(config)
if err != nil {
log.Fatal("WS request failed:", err)
log.Fatalf("Error dialing: %v", err)
}
defer conn.Close()

errc := make(chan error)

for {
_, reader, err := socket.NextReader()
if err != nil {
log.Fatal("Error:", err)
go func() {
_, err := io.Copy(os.Stdout, conn)
if err != io.EOF && err != nil {
log.Printf("Error copying to stdout: %v", err)
}
_, err = io.Copy(os.Stdout, reader)
if err != nil {
log.Fatal("Error:", err)
errc <- err
}()

go func() {
_, err = io.Copy(conn, os.Stdin)
if err != io.EOF && err != nil {
log.Printf("Error copying from stdin: %v", err)
}
fmt.Fprintln(os.Stdout, "")
}
errc <- err
}()

<-errc
}

0 comments on commit 5abde5d

Please sign in to comment.