Skip to content

Commit

Permalink
Merge pull request #111 from lesismal/http_conn_pool
Browse files Browse the repository at this point in the history
websocket client: write with mask
  • Loading branch information
lesismal authored Sep 28, 2021
2 parents b209b40 + 92dfb5a commit fe8165a
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 14 deletions.
78 changes: 78 additions & 0 deletions examples/websocket_proxy/app_client/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package main

import (
"fmt"
"log"
"time"

"github.com/gorilla/websocket"
)

var (
proxyServerAddr = "ws://localhost:8888/ws"
)

func main() {
c, _, err := websocket.DefaultDialer.Dial(proxyServerAddr, nil)
if err != nil {
log.Fatalf("Dial failed: %v, %v", proxyServerAddr, err)
}
defer c.Close()

for i := 0; i < 10; i++ {
{
request := fmt.Sprintf("hello %v", i)
err := c.WriteMessage(websocket.BinaryMessage, []byte(request))
if err != nil {
log.Fatalf("write: %v", err)
return
}

receiveType, response, err := c.ReadMessage()
if err != nil {
log.Println("ReadMessage failed:", err)
return
}
if receiveType != websocket.BinaryMessage {
log.Printf("received type(%d) != websocket.BinaryMessage(%d)\n", receiveType, websocket.BinaryMessage)
return

}

if string(response) != request {
log.Printf("'%v' != '%v'", len(response), len(request))
return
}

log.Printf("success echo: [websocket.BinaryMessage], %v", request)
}

{
request := fmt.Sprintf("hello %v", i)
err := c.WriteMessage(websocket.TextMessage, []byte(request))
if err != nil {
log.Fatalf("write: %v", err)
return
}

receiveType, response, err := c.ReadMessage()
if err != nil {
log.Println("ReadMessage failed:", err)
return
}
if receiveType != websocket.TextMessage {
log.Printf("received type(%d) != websocket.TextMessage(%d)\n", receiveType, websocket.TextMessage)
return

}

if string(response) != request {
log.Printf("'%v' != '%v'", len(response), len(request))
return
}

log.Printf("success echo: [websocket.TextMessage], %v", request)
}
time.Sleep(time.Second)
}
}
58 changes: 58 additions & 0 deletions examples/websocket_proxy/app_server/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package main

import (
"log"
"net/http"

"github.com/gorilla/websocket"
)

func msgType(messageType int) string {
switch messageType {
case websocket.BinaryMessage:
return "[websocket.BinaryMessage]"
case websocket.TextMessage:
return "[websocket.TextMessage]"
default:
}
return "[]"
}

func echo(w http.ResponseWriter, r *http.Request) {
upgrader := &websocket.Upgrader{}
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Print("upgrade:", err)
return
}
defer c.Close()
for {
messageType, message, err := c.ReadMessage()
if err != nil {
log.Println("read failed:", err)
break
}

log.Printf("app server onMessage: %v, %v", msgType(messageType), string(message))

err = c.WriteMessage(messageType, message)
if err != nil {
log.Println("write failed:", err)
break
}
}
}

var (
appServerAddr = "localhost:9999"
)

func main() {
mux := &http.ServeMux{}
mux.HandleFunc("/ws", echo)
server := http.Server{
Addr: appServerAddr,
Handler: mux,
}
log.Println("server exit:", server.ListenAndServe())
}
105 changes: 105 additions & 0 deletions examples/websocket_proxy/proxy_server/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package main

import (
"context"
"flag"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"time"

"github.com/lesismal/nbio/nbhttp"
"github.com/lesismal/nbio/nbhttp/websocket"
)

var (
proxyServerAddr = "localhost:8888"

appServerAddr = "ws://localhost:9999/ws"

proxyServer *nbhttp.Server
)

func msgType(messageType websocket.MessageType) string {
switch messageType {
case websocket.BinaryMessage:
return "[websocket.BinaryMessage]"
case websocket.TextMessage:
return "[websocket.TextMessage]"
default:
}
return "[]"
}

func newUpgrader() *websocket.Upgrader {
u := websocket.NewUpgrader()
u.OnMessage(func(c *websocket.Conn, messageType websocket.MessageType, data []byte) {
peer, ok := c.Session().(*websocket.Conn)
if ok {
log.Printf("proxy server onMessage [%v -> %v]: %v, %v", c.RemoteAddr(), peer.RemoteAddr(), msgType(messageType), string(data))
peer.WriteMessage(messageType, data)
}
})
u.OnClose(func(c *websocket.Conn, err error) {
peer, ok := c.Session().(*websocket.Conn)
if ok {
log.Printf("proxy server onClose [%v -> %v]", c.RemoteAddr(), peer.RemoteAddr())
peer.Close()
} else {
log.Printf("proxy server onClose [%v -> %v]", c.RemoteAddr(), "")
}
})
return u
}

func onWebsocket(w http.ResponseWriter, r *http.Request) {
upgrader := newUpgrader()
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
panic(err)
}
srcConn := conn.(*websocket.Conn)

dialer := &websocket.Dialer{
Engine: proxyServer.Engine,
Upgrader: newUpgrader(),
DialTimeout: time.Second * 3,
}
dstConn, _, err := dialer.Dial(appServerAddr, nil)
if err != nil {
log.Printf("Dial failed: %v, %v", appServerAddr, err)
srcConn.Close()
return
}

srcConn.SetSession(dstConn)
dstConn.SetSession(srcConn)
}

func main() {
flag.Parse()
mux := &http.ServeMux{}
mux.HandleFunc("/ws", onWebsocket)

proxyServer = nbhttp.NewServer(nbhttp.Config{
Network: "tcp",
Addrs: []string{proxyServerAddr},

SupportClient: true,
}, mux, nil)

err := proxyServer.Start()
if err != nil {
fmt.Printf("nbio.Start failed: %v\n", err)
return
}

interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
<-interrupt
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
proxyServer.Shutdown(ctx)
}
50 changes: 38 additions & 12 deletions nbhttp/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"math/rand"
"net"
"sync"

Expand Down Expand Up @@ -38,12 +39,18 @@ const (
PongMessage MessageType = 10
)

const (
maskBit = 1 << 7
)

// Conn .
type Conn struct {
net.Conn

mux sync.Mutex

isClient bool

onCloseCalled bool
remoteCompressionEnabled bool
enableWriteCompression bool
Expand Down Expand Up @@ -204,27 +211,46 @@ func (c *Conn) WriteFrame(messageType MessageType, sendOpcode, fin bool, data []
func (c *Conn) writeFrame(messageType MessageType, sendOpcode, fin bool, data []byte, compress bool) error {
var (
buf []byte
offset = 2
byte1 byte
maskLen int
headLen int
bodyLen = len(data)
)

if c.isClient {
byte1 |= maskBit
maskLen = 4
}

if bodyLen < 126 {
buf = mempool.Malloc(len(data) + 2)
headLen = 2 + maskLen
buf = mempool.Malloc(len(data) + headLen)
buf[0] = 0
buf[1] = byte(bodyLen)
buf[1] = (byte1 | byte(bodyLen))
} else if bodyLen <= 65535 {
buf = mempool.Malloc(len(data) + 4)
binary.LittleEndian.PutUint16(buf, 0)
buf[1] = 126
headLen = 4 + maskLen
buf = mempool.Malloc(len(data) + headLen)
buf[0] = 0
buf[1] = (byte1 | 126)
binary.BigEndian.PutUint16(buf[2:4], uint16(bodyLen))
offset = 4
} else {
buf = mempool.Malloc(len(data) + 10)
binary.LittleEndian.PutUint16(buf, 0)
buf[1] = 127
headLen = 10 + maskLen
buf = mempool.Malloc(len(data) + headLen)
buf[0] = 0
buf[1] = (byte1 | 127)
binary.BigEndian.PutUint64(buf[2:10], uint64(bodyLen))
offset = 10
}
copy(buf[offset:], data)

if c.isClient {
u32 := rand.Uint32()
maskKey := []byte{byte(u32), byte(u32 >> 8), byte(u32 >> 16), byte(u32 >> 24)}
copy(buf[headLen-4:headLen], maskKey)
for i := 0; i < len(data); i++ {
buf[headLen+i] = (data[i] ^ maskKey[i%4])
}
} else {
copy(buf[headLen:], data)
}

// opcode
if sendOpcode {
Expand Down
1 change: 1 addition & 0 deletions nbhttp/websocket/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}

wsConn = newConn(upgrader, conn, resp.Header.Get(secWebsocketProtoHeaderField), remoteCompressionEnabled)
wsConn.isClient = true
wsConn.Engine = d.Engine
wsConn.OnClose(upgrader.onClose)

Expand Down
4 changes: 2 additions & 2 deletions nbhttp/websocket/upgrader.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,9 +523,9 @@ func (u *Upgrader) nextFrame() (opcode MessageType, body []byte, ok, fin, res1,
if l >= total {
body = u.buffer[headLen:total]
if masked {
mask := u.buffer[headLen-4 : headLen]
maskKey := u.buffer[headLen-4 : headLen]
for i := 0; i < len(body); i++ {
body[i] ^= mask[i%4]
body[i] ^= maskKey[i%4]
}
}

Expand Down

0 comments on commit fe8165a

Please sign in to comment.