Skip to content

Commit

Permalink
fallback and server support proxy too
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Oct 24, 2022
1 parent 0382140 commit b9f9765
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 160 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ jobs:
name: Build
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v1
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: '1.19'
check-latest: true
cache: true
- run: go install github.com/mitchellh/gox@latest
- uses: actions/checkout@v2
- run: PATH=$HOME/go/bin:$PATH ./crossbuild.sh
- uses: actions/upload-artifact@v1
with:
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ jobs:
- name: Get the version
id: get_version
run: echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
- uses: actions/setup-go@v1
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
go-version: 1.19
go-version: '1.19'
check-latest: true
cache: true
- run: go install github.com/mitchellh/gox@latest
- uses: actions/checkout@v2
- run: PATH=$HOME/go/bin:$PATH ./crossbuild.sh
- uses: svenstaro/[email protected]
with:
Expand Down
120 changes: 82 additions & 38 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ package client

import (
"crypto/tls"
"io"
"log"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"

"github.com/wwqgtxx/wstunnel/common"
Expand All @@ -23,7 +23,7 @@ import (
type client struct {
common.ClientImpl
serverWSPath string
listenerConfig config.ListenerConfig
listenerConfig listener.Config
}

func (c *client) Start() {
Expand Down Expand Up @@ -98,10 +98,10 @@ func (c *wsClientImpl) Handle(tcp net.Conn) {
return
}
defer conn.Close()
c.Tunnel(tcp, conn)
conn.TunnelTcp(tcp)
}

func (c *wsClientImpl) Dial(edBuf []byte, inHeader http.Header) (io.Closer, error) {
func (c *wsClientImpl) Dial(edBuf []byte, inHeader http.Header) (common.ClientConn, error) {
header := c.header
if len(inHeader) > 0 {
// copy from inHeader
Expand All @@ -124,24 +124,50 @@ func (c *wsClientImpl) Dial(edBuf []byte, inHeader http.Header) (io.Closer, erro

// force use inHeader's `Sec-WebSocket-Protocol` for Xray's 0rtt ws
if secProtocol := inHeader.Get("Sec-WebSocket-Protocol"); len(secProtocol) > 0 {
header.Set("Sec-WebSocket-Protocol", secProtocol)
if c.ed > 0 {
header.Set("Sec-WebSocket-Protocol", secProtocol)
edBuf = nil
} else {
edBuf, _ = utils.DecodeEd(secProtocol)
}
}
}
if c.ed > 0 && len(edBuf) > 0 {
header.Set("Sec-WebSocket-Protocol", utils.EncodeEd(edBuf))
edBuf = nil
}
log.Println("Dial to", c.Target(), c.Proxy(), "with", header)
ws, resp, err := c.wsDialer.Dial(c.Target(), header)
if resp != nil {
log.Println("Dial", c.Target(), c.Proxy(), "get response", resp.Header)
}
return ws, err
if len(edBuf) > 0 {
err = ws.WriteMessage(websocket.BinaryMessage, edBuf)
if err != nil {
return nil, err
}
}
return &wsClientConn{ws: ws}, err
}

type wsClientConn struct {
ws *websocket.Conn
close sync.Once
}

func (c *wsClientImpl) ToRawConn(conn io.Closer) net.Conn {
ws := conn.(*websocket.Conn)
return ws.UnderlyingConn()
func (c *wsClientConn) Close() {
c.close.Do(func() {
_ = c.ws.Close()
})
}

func (c *wsClientImpl) Tunnel(tcp net.Conn, conn io.Closer) {
tunnel.TunnelTcpWs(tcp, conn.(*websocket.Conn))
func (c *wsClientConn) TunnelTcp(tcp net.Conn) {
tunnel.TcpWs(tcp, c.ws)
}

func (c *wsClientConn) TunnelWs(ws *websocket.Conn) {
// fastpath for direct tunnel underlying ws connection
tunnel.TcpTcp(ws.UnderlyingConn(), c.ws.UnderlyingConn())
}

func (c *tcpClientImpl) Target() string {
Expand All @@ -161,23 +187,37 @@ func (c *tcpClientImpl) Handle(tcp net.Conn) {
return
}
defer conn.Close()
c.Tunnel(tcp, conn)
conn.TunnelTcp(tcp)
}

func (c *tcpClientImpl) Dial(edBuf []byte, inHeader http.Header) (io.Closer, error) {
func (c *tcpClientImpl) Dial(edBuf []byte, inHeader http.Header) (common.ClientConn, error) {
tcp, err := c.netDial("tcp", c.Target())
if err == nil && len(edBuf) > 0 {
_, err = tcp.Write(edBuf)
if err != nil {
return nil, err
}
}
return tcp, err
return &tcpClientConn{tcp: tcp}, err
}

type tcpClientConn struct {
tcp net.Conn
close sync.Once
}

func (c *tcpClientConn) Close() {
c.close.Do(func() {
_ = c.tcp.Close()
})
}

func (c *tcpClientImpl) ToRawConn(conn io.Closer) net.Conn {
return conn.(net.Conn)
func (c *tcpClientConn) TunnelTcp(tcp net.Conn) {
tunnel.TcpTcp(tcp, c.tcp)
}

func (c *tcpClientImpl) Tunnel(tcp net.Conn, conn io.Closer) {
tunnel.TunnelTcpTcp(tcp, conn.(net.Conn))
func (c *tcpClientConn) TunnelWs(ws *websocket.Conn) {
tunnel.TcpWs(c.tcp, ws)
}

func BuildClient(clientConfig config.ClientConfig) {
Expand All @@ -189,9 +229,13 @@ func BuildClient(clientConfig config.ClientConfig) {
serverWSPath := strings.ReplaceAll(clientConfig.ServerWSPath, "{port}", port)

c := &client{
ClientImpl: NewClientImpl(clientConfig),
serverWSPath: serverWSPath,
listenerConfig: clientConfig.ListenerConfig,
ClientImpl: NewClientImpl(clientConfig),
serverWSPath: serverWSPath,
listenerConfig: listener.Config{
ListenerConfig: clientConfig.ListenerConfig,
ProxyConfig: clientConfig.ProxyConfig,
IsWebSocketListener: len(clientConfig.TargetAddress) > 0,
},
}

common.PortToClient[port] = c
Expand All @@ -212,16 +256,16 @@ func parseProxy(proxyString string) (proxyUrl *url.URL, proxyStr string) {
return
}

func NewClientImpl(config config.ClientConfig) common.ClientImpl {
if len(config.TargetAddress) > 0 {
return NewTcpClientImpl(config)
func NewClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
if len(clientConfig.TargetAddress) > 0 {
return NewTcpClientImpl(clientConfig)
} else {
return NewWsClientImpl(config)
return NewWsClientImpl(clientConfig)
}
}

func NewTcpClientImpl(config config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(config.Proxy)
func NewTcpClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)

var netDial NetDialerFunc
tcpDialer := &net.Dialer{
Expand All @@ -243,23 +287,23 @@ func NewTcpClientImpl(config config.ClientConfig) common.ClientImpl {
}

return &tcpClientImpl{
targetAddress: config.TargetAddress,
targetAddress: clientConfig.TargetAddress,
netDial: netDial,
proxy: proxyStr,
}
}

func NewWsClientImpl(config config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(config.Proxy)
func NewWsClientImpl(clientConfig config.ClientConfig) common.ClientImpl {
proxyUrl, proxyStr := parseProxy(clientConfig.Proxy)

proxy := http.ProxyFromEnvironment
if proxyUrl != nil {
proxy = http.ProxyURL(proxyUrl)
}

header := http.Header{}
if len(config.WSHeaders) != 0 {
for key, value := range config.WSHeaders {
if len(clientConfig.WSHeaders) != 0 {
for key, value := range clientConfig.WSHeaders {
header.Add(key, value)
}
}
Expand All @@ -271,22 +315,22 @@ func NewWsClientImpl(config config.ClientConfig) common.ClientImpl {
WriteBufferPool: tunnel.WriteBufferPool,
}
wsDialer.TLSClientConfig = &tls.Config{
ServerName: config.ServerName,
InsecureSkipVerify: config.SkipCertVerify,
ServerName: clientConfig.ServerName,
InsecureSkipVerify: clientConfig.SkipCertVerify,
}
var ed uint32
if u, err := url.Parse(config.WSUrl); err == nil {
if u, err := url.Parse(clientConfig.WSUrl); err == nil {
if q := u.Query(); q.Get("ed") != "" {
Ed, _ := strconv.Atoi(q.Get("ed"))
ed = uint32(Ed)
q.Del("ed")
u.RawQuery = q.Encode()
config.WSUrl = u.String()
clientConfig.WSUrl = u.String()
}
}
return &wsClientImpl{
header: header,
wsUrl: config.WSUrl,
wsUrl: clientConfig.WSUrl,
wsDialer: wsDialer,
ed: ed,
proxy: proxyStr,
Expand Down Expand Up @@ -329,5 +373,5 @@ func StartClients() {
}

func init() {
common.NewClientImpl = NewClientImpl
listener.NewClientImpl = NewClientImpl
}
13 changes: 7 additions & 6 deletions common/common.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
package common

import (
"io"
"net"
"net/http"

"github.com/wwqgtxx/wstunnel/config"
"github.com/gorilla/websocket"
)

var PortToServer = make(map[string]Server)
Expand All @@ -31,9 +30,11 @@ type ClientImpl interface {
Target() string
Proxy() string
Handle(tcp net.Conn)
Dial(edBuf []byte, inHeader http.Header) (io.Closer, error)
ToRawConn(conn io.Closer) net.Conn
Tunnel(tcp net.Conn, conn io.Closer)
Dial(edBuf []byte, inHeader http.Header) (ClientConn, error)
}

var NewClientImpl func(config config.ClientConfig) ClientImpl
type ClientConn interface {
Close()
TunnelTcp(tcp net.Conn)
TunnelWs(ws *websocket.Conn)
}
11 changes: 9 additions & 2 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,36 @@ import (

type ClientConfig struct {
ListenerConfig `yaml:",inline"`
ProxyConfig `yaml:",inline"`
TargetAddress string `yaml:"target-address"`
WSUrl string `yaml:"ws-url"`
WSHeaders map[string]string `yaml:"ws-headers"`
SkipCertVerify bool `yaml:"skip-cert-verify"`
ServerName string `yaml:"servername"`
Proxy string `yaml:"proxy"`
ServerWSPath string `yaml:"server-ws-path"`
}

type ServerConfig struct {
ListenerConfig `yaml:",inline"`
ProxyConfig `yaml:",inline"`
Target []ServerTargetConfig `yaml:"target"`
}

type ListenerConfig struct {
BindAddress string `yaml:"bind-address"`
TLSFallbackAddress string `yaml:"tls-fallback-address"`
SshFallbackAddress string `yaml:"ssh-fallback-address"`
SshFallbackTimeout int `yaml:"ssh-fallback-timeout"`
TLSFallbackAddress string `yaml:"tls-fallback-address"`
WSFallbackAddress string `yaml:"ws-fallback-address"`
UnknownFallbackAddress string `yaml:"unknown-fallback-address"`
}

type ProxyConfig struct {
Proxy string `yaml:"proxy"`
}

type ServerTargetConfig struct {
*ProxyConfig `yaml:",inline"`
TargetAddress string `yaml:"target-address"`
WSPath string `yaml:"ws-path"`
}
Expand Down
Loading

0 comments on commit b9f9765

Please sign in to comment.