Skip to content

Commit

Permalink
Optimize router logic
Browse files Browse the repository at this point in the history
  • Loading branch information
wweir committed Mar 23, 2020
1 parent ee6335f commit 1574623
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 102 deletions.
7 changes: 4 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,14 @@ func main() {
go proxy.StartHTTPProxy(conf.Conf.Client.HTTPProxy, conf.Conf.Client.Address,
[]byte(conf.Conf.Password), route.ShouldProxy)

if conf.Conf.Client.DNSServeIP != "" {
enableDNSSolution := conf.Conf.Client.DNSServeIP != ""
if enableDNSSolution {
go proxy.StartDNS(conf.Conf.Client.DNSServeIP, conf.Conf.Client.DNSUpstream,
route.ShouldProxy)
}

proxy.StartClient(conf.Conf.Client.Address, conf.Conf.Password,
conf.Conf.Client.DNSServeIP != "", conf.Conf.Client.PortForward)
proxy.StartClient(conf.Conf.Client.Address, conf.Conf.Password, enableDNSSolution,
conf.Conf.Client.PortForward, route.ShouldProxy)

default:
fmt.Println()
Expand Down
22 changes: 10 additions & 12 deletions proxy/http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,23 @@ import (
"time"

"github.com/wweir/sower/transport"
"github.com/wweir/sower/util"
"github.com/wweir/utils/log"
)

// StartHTTPProxy start http reverse proxy.
// The httputil.ReverseProxy do not supply enough support for https request.
func StartHTTPProxy(httpProxyAddr, serverAddr string, password []byte, shouldProxy func(string) bool) {
func StartHTTPProxy(httpProxyAddr, serverAddr string, password []byte,
shouldProxy func(string) bool) {

proxy := httputil.ReverseProxy{
Director: func(r *http.Request) {},
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
addr, _ = withDefaultPort(addr, "80")
return transport.Dial(serverAddr, addr, password)
}},
addr, _ = util.WithDefaultPort(addr, "80")
return transport.Dial(serverAddr, addr, password, shouldProxy)
},
},
}

srv := &http.Server{
Expand All @@ -45,8 +49,6 @@ func StartHTTPProxy(httpProxyAddr, serverAddr string, password []byte, shouldPro
func httpsProxy(w http.ResponseWriter, r *http.Request,
serverAddr string, password []byte, shouldProxy func(string) bool) {

target, host := withDefaultPort(r.Host, "443")

conn, _, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
Expand All @@ -60,12 +62,8 @@ func httpsProxy(w http.ResponseWriter, r *http.Request,
return
}

var rc net.Conn
if shouldProxy(host) {
rc, err = transport.Dial(serverAddr, target, password)
} else {
rc, err = net.Dial("tcp", target)
}
target, _ := util.WithDefaultPort(r.Host, "443")
rc, err := transport.Dial(serverAddr, target, password, shouldProxy)
if err != nil {
conn.Write([]byte("sower dial " + serverAddr + " fail: " + err.Error()))
conn.Close()
Expand Down
30 changes: 13 additions & 17 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,19 @@ import (
"crypto/tls"
"net"
"net/http"
"os"
"path/filepath"

"github.com/wweir/sower/transport"
"github.com/wweir/utils/log"
"golang.org/x/crypto/acme/autocert"
)

func StartClient(serverAddr, password string, enableDNS bool, forwards map[string]string) {
passwordData := []byte(password)
func StartClient(serverAddr, password string, enableDNS bool,
forwards map[string]string, shouldProxy func(string) bool) {

passwordData := []byte(password)
relayToRemote := func(lnAddr, target string,
parseFn func(net.Conn) (net.Conn, string, error)) {
parseFn func(net.Conn) (net.Conn, string, error),
shouldProxy func(string) bool) {

ln, err := net.Listen("tcp", lnAddr)
if err != nil {
Expand All @@ -40,7 +40,7 @@ func StartClient(serverAddr, password string, enableDNS bool, forwards map[strin
}
}

rc, err := transport.Dial(serverAddr, target, passwordData)
rc, err := transport.Dial(serverAddr, target, passwordData, shouldProxy)
if err != nil {
log.Warnw("dial", "addr", serverAddr, "err", err)
return
Expand All @@ -53,12 +53,12 @@ func StartClient(serverAddr, password string, enableDNS bool, forwards map[strin
}

for from, to := range forwards {
go relayToRemote(from, to, nil)
go relayToRemote(from, to, nil, func(string) bool { return true })
}

if enableDNS {
go relayToRemote(":80", "", ParseHTTP)
go relayToRemote(":443", "", ParseHTTPS)
go relayToRemote(":80", "", ParseHTTP, shouldProxy)
go relayToRemote(":443", "", ParseHTTPS, shouldProxy)
}

log.Infow("start sower client", "dns solution", enableDNS, "forwards", forwards)
Expand All @@ -67,14 +67,10 @@ func StartClient(serverAddr, password string, enableDNS bool, forwards map[strin
}

func StartServer(relayTarget, password, cacheDir, certFile, keyFile, email string) {
dir, _ := os.UserCacheDir()
dir = filepath.Join("/", dir, "sower")
log.Infow("certificate cache dir", "dir", dir)

certManager := autocert.Manager{
Prompt: autocert.AcceptTOS,
Email: email,
Cache: autocert.DirCache(dir),
Cache: autocert.DirCache(cacheDir),
}

tlsConf := &tls.Config{
Expand All @@ -94,22 +90,22 @@ func StartServer(relayTarget, password, cacheDir, certFile, keyFile, email strin
// Try to redirect 80 to 443
go http.ListenAndServe(":80", certManager.HTTPHandler(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
r.URL.Scheme = "https"
if host, _, err := net.SplitHostPort(r.Host); err != nil {
r.URL.Host = r.Host
} else {
r.URL.Host = host
}
r.URL.Scheme = "https"

http.Redirect(w, r, r.URL.String(), 301)
})))

log.Infow("start sower server", "relay_to", relayTarget)
ln, err := tls.Listen("tcp", ":443", tlsConf)
if err != nil {
log.Fatalw("tcp listen", "err", err)
}

log.Infow("start sower server", "relay_to", relayTarget)

passwordData := []byte(password)
for {
conn, err := ln.Accept()
Expand Down
16 changes: 3 additions & 13 deletions proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ func ParseHTTP(conn net.Conn) (net.Conn, string, error) {
return teeConn, "", err
}

if _, _, err := net.SplitHostPort(resp.Host); err != nil {
resp.Host = net.JoinHostPort(resp.Host, "80")
}

resp.Host, _ = util.WithDefaultPort(resp.Host, "80")
return teeConn, resp.Host, nil
}

Expand All @@ -41,15 +38,8 @@ func ParseHTTPS(conn net.Conn) (net.Conn, string, error) {
},
}).Handshake()

return teeConn, net.JoinHostPort(domain, "443"), nil
}

func withDefaultPort(addr string, port string) (address, host string) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return net.JoinHostPort(addr, port), addr
}
return addr, host
domain, _ = util.WithDefaultPort(domain, "443")
return teeConn, domain, nil
}

func relay(conn1, conn2 net.Conn) {
Expand Down
54 changes: 24 additions & 30 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ import (

// Route implement a router for each request
type Route struct {
port Port
once sync.Once
cache *mem.Cache
shouldProxy bool
port Port
once sync.Once
cache *mem.Cache

ProxyAddress string
ProxyPassword string
Expand Down Expand Up @@ -96,34 +97,19 @@ func (r *Route) Get(key interface{}) (err error) {
func (r *Route) detect(domain string) (http, https int) {
wg := sync.WaitGroup{}
httpScore, httpsScore := new(int32), new(int32)
for _, ping := range [...]*Route{{port: HTTP}, {port: HTTPS}} {
wg.Add(1)
go func(ping *Route) {
defer wg.Done()

if err := ping.port.Ping(domain, r.timeout); err != nil {
return
}

switch ping.port {
case HTTP:
if !atomic.CompareAndSwapInt32(httpScore, 0, -2) {
atomic.AddInt32(httpScore, -1)
}
case HTTPS:
if !atomic.CompareAndSwapInt32(httpsScore, 0, -2) {
atomic.AddInt32(httpScore, -1)
}
}
}(ping)
}
for _, ping := range [...]*Route{{port: HTTP}, {port: HTTPS}} {
for _, ping := range [...]*Route{
{shouldProxy: true, port: HTTP},
{shouldProxy: true, port: HTTPS},
{shouldProxy: false, port: HTTP},
{shouldProxy: false, port: HTTPS},
} {
wg.Add(1)
go func(ping *Route) {
defer wg.Done()

target := net.JoinHostPort(domain, ping.port.String())
conn, err := transport.Dial(r.ProxyAddress, target, r.password)
conn, err := transport.Dial(r.ProxyAddress, target, r.password,
func(string) bool { return r.shouldProxy })
if err != nil {
log.Errorw("sower dial", "addr", r.ProxyAddress, "err", err)
return
Expand All @@ -133,14 +119,22 @@ func (r *Route) detect(domain string) (http, https int) {
return
}

switch ping.port {
case HTTP:
switch {
case r.shouldProxy && ping.port == HTTP:
if !atomic.CompareAndSwapInt32(httpScore, 0, 2) {
atomic.AddInt32(httpScore, 1)
}
case HTTPS:
case r.shouldProxy && ping.port == HTTPS:
if !atomic.CompareAndSwapInt32(httpsScore, 0, 2) {
atomic.AddInt32(httpScore, 1)
atomic.AddInt32(httpsScore, 1)
}
case !r.shouldProxy && ping.port == HTTP:
if !atomic.CompareAndSwapInt32(httpScore, 0, -2) {
atomic.AddInt32(httpScore, -1)
}
case !r.shouldProxy && ping.port == HTTPS:
if !atomic.CompareAndSwapInt32(httpsScore, 0, -2) {
atomic.AddInt32(httpsScore, -1)
}
}
}(ping)
Expand Down
20 changes: 6 additions & 14 deletions transport/socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"io"
"net"
"strconv"
"strings"
)

Expand All @@ -21,28 +20,19 @@ func IsSocks5Schema(addr string) (string, bool) {
return addr, false
}

func ToSocks5(c net.Conn, address string) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
p, err := strconv.Atoi(port)
if err != nil {
return nil, err
}

func ToSocks5(c net.Conn, host string, port uint16) (net.Conn, error) {
return &conn{
init: make(chan struct{}),
Conn: c,
domain: host,
port: []byte{byte(p >> 8), byte(p)},
port: port,
}, nil
}

type conn struct {
init chan struct{}
domain string
port []byte
port uint16
net.Conn
}

Expand Down Expand Up @@ -75,6 +65,8 @@ func (c *conn) Write(b []byte) (n int, err error) {
}
}
{
portBuf := make([]byte, 2)
binary.BigEndian.PutUint16(portBuf, c.port)
req := &request{
req: req{
VER: 5, // socks5
Expand All @@ -83,7 +75,7 @@ func (c *conn) Write(b []byte) (n int, err error) {
ATYP: 3, // DOMAINNAME
},
DST_ADDR: append([]byte{byte(len(c.domain))}, []byte(c.domain)...),
DST_PORT: c.port,
DST_PORT: portBuf,
}

if _, err := c.Conn.Write(req.Bytes()); err != nil {
Expand Down
33 changes: 20 additions & 13 deletions transport/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,38 @@ import (
"crypto/tls"
"net"
"strconv"

"github.com/wweir/sower/util"
)

func Dial(address, target string, password []byte) (net.Conn, error) {
func Dial(address, target string, password []byte, shouldProxy func(string) bool) (net.Conn, error) {
host, port, err := net.SplitHostPort(target)
if err != nil {
return nil, err
}

if !shouldProxy(host) {
return net.Dial("tcp", target)
}

p, err := strconv.Atoi(port)
if err != nil {
return nil, err
}

if addr, ok := IsSocks5Schema(address); ok {
conn, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}
if conn, err = ToSocks5(conn, target); err != nil {
if conn, err = ToSocks5(conn, host, uint16(p)); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}

host, port, err := net.SplitHostPort(target)
if err != nil {
port = "443"
}
p, err := strconv.Atoi(port)
if err != nil {
return nil, err
}

address, _ = util.WithDefaultPort(address, "443")
// tls.Config is same as golang http pkg default behavior
return DialTlsProxyConn(net.JoinHostPort(address, "443"),
host, uint16(p), &tls.Config{}, password)
return DialTlsProxyConn(address, host, uint16(p), &tls.Config{}, password)
}
11 changes: 11 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package util

import "net"

func WithDefaultPort(addr string, port string) (address, host string) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return net.JoinHostPort(addr, port), addr
}
return addr, host
}

0 comments on commit 1574623

Please sign in to comment.