diff --git a/conf/conf.go b/conf/conf.go index 5e939ef..af26ee6 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -53,7 +53,7 @@ func Init() (*Client, *Server, string) { flag.StringVar(&conf.Server.KeyFile, "s_key", "", "tls key file, gen cert from letsencrypt if empty") flag.StringVar(&conf.Client.Address, "c", "", "remote server domain, eg: aa.bb.cc, socks5h://127.0.0.1:1080") flag.StringVar(&conf.Client.HTTPProxy, "http_proxy", ":8080", "http proxy, empty to disable") - flag.IntVar(&conf.Client.Router.DetectLevel, "level", 2, "dynamic rule detect level: 0~4") + flag.IntVar(&conf.Client.Router.DetectLevel, "level", 0, "dynamic rule detect level: -4~4") if !flag.Parsed() { flag.Parse() diff --git a/conf/sower.toml b/conf/sower.toml index e8660c1..0a85f0f 100644 --- a/conf/sower.toml +++ b/conf/sower.toml @@ -10,7 +10,7 @@ password="" # sower password # eg: ":2222"="aa.bb.cc:22" [client.router] - detect_level = 1 # 0~4, the bigger the harder to add + detect_level = 0 # [-4, 4], the bigger the harder to add direct_list = [ "**.in-addr.arpa", "imap.*.*", diff --git a/main.go b/main.go index a9699b7..225c841 100644 --- a/main.go +++ b/main.go @@ -30,20 +30,18 @@ func main() { if client.HTTPProxy != "" { go proxy.StartHTTPProxy(client.HTTPProxy, client.Address, - []byte(password), route.ShouldProxy) + []byte(password), route.GenProxyCheck(true)) } enableDNSSolution := client.DNSServeIP != "" if enableDNSSolution { - if client.DNSUpstream != "" { - transport.SetDNS(client.DNSUpstream) - } + transport.SetDNS(nil, client.DNSUpstream) go proxy.StartDNS(client.DNSServeIP, client.DNSUpstream, - route.ShouldProxy) + route.GenProxyCheck(false)) } proxy.StartClient(client.Address, password, enableDNSSolution, - client.PortForward, route.ShouldProxy) + client.PortForward, route.GenProxyCheck(true)) default: fmt.Println() diff --git a/proxy/proxy.go b/proxy/proxy.go index 37fbbdf..b64e692 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -40,14 +40,15 @@ func StartClient(serverAddr, password string, enableDNS bool, } } - rc, err := transport.Dial(target, func(host string) (string, []byte) { - if shouldProxy(host) { + rc, err := transport.Dial(target, func(domain string) (string, []byte) { + if shouldProxy(domain) { return serverAddr, passwordData } return "", nil }) if err != nil { - log.Warnw("dial", "addr", serverAddr, "err", err) + host, _, _ := net.SplitHostPort(target) + log.Warnw("dial", "addr", target, "proxy", shouldProxy(host), "err", err) return } defer rc.Close() diff --git a/router/http_ping.go b/router/http_ping.go index 58f1512..01ad5d8 100644 --- a/router/http_ping.go +++ b/router/http_ping.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "encoding/binary" + "io" "net" "strconv" "time" @@ -18,19 +19,14 @@ const ( ) // Ping try connect to a http(s) server with domain though the http addr -func (p Port) Ping(domain string, timeout time.Duration) error { - conn, err := net.DialTimeout("tcp", net.JoinHostPort(domain, p.String()), timeout) +func (p Port) Ping(domain string, dial func(string) (net.Conn, error)) error { + conn, err := dial(net.JoinHostPort(domain, p.String())) if err != nil { return err } defer conn.Close() - return p.PingWithConn(domain, conn, timeout) -} - -// PingWithConn try connect to a http(s) server with domain though the http addr -func (p Port) PingWithConn(domain string, conn net.Conn, timeout time.Duration) error { - conn.SetDeadline(time.Now().Add(timeout)) + conn.SetDeadline(time.Now().Add(5 * time.Second)) if _, err := conn.Write(p.PingMsg(domain)); err != nil { return err } @@ -38,7 +34,9 @@ func (p Port) PingWithConn(domain string, conn net.Conn, timeout time.Duration) // err -> nil: read something succ // err -> io.EOF: no such domain or connection refused // err -> timeout: tcp package has been dropped - _, err := conn.Read(make([]byte, 10)) + if _, err = conn.Read(make([]byte, 10)); err == io.EOF { + err = nil + } return err } diff --git a/router/http_ping_test.go b/router/http_ping_test.go index 9eb1881..9f1ae31 100644 --- a/router/http_ping_test.go +++ b/router/http_ping_test.go @@ -4,7 +4,9 @@ package router_test import ( + "crypto/tls" "net" + "strconv" "testing" "time" @@ -12,80 +14,34 @@ import ( "github.com/wweir/sower/transport" ) -func TestPort_Ping(t *testing.T) { - type args struct { - domain string - timeout time.Duration - } - tests := []struct { - name string - p router.Port - args args - wantErr bool - }{ - { - "baidu_80", - router.HTTP, - args{"baidu.com", time.Second}, - false, - }, - { - "baidu_443", - router.HTTPS, - args{"baidu.com", time.Second}, - false, - }, - { - "google_80", - router.HTTP, - args{"google.com", time.Second}, - true, - }, - { - "google_443", - router.HTTPS, - args{"google.com", time.Second}, - true, - }, - { - "mail_80", - router.HTTP, - args{"smtp.163.com", time.Second}, - true, - }, - { - "mail_443", - router.HTTPS, - args{"smtp.163.com", time.Second}, - true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if err := tt.p.Ping(tt.args.domain, tt.args.timeout); (err != nil) != tt.wantErr { - t.Errorf("Port.Ping() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - var ( - proxyAddr string = "" - password []byte = []byte("") + proxyAddr string = "socks5://127.0.0.1:1080" + password []byte = nil ) -func TestPort_PingWithConn(t *testing.T) { - conn, err := transport.Dial(":443", func(string) (string, []byte) { - return proxyAddr, password - }) - if err != nil { - t.Errorf("dial remote %s", err) +func TestPort_Ping(t *testing.T) { + direct := func(addr string) (net.Conn, error) { + return net.DialTimeout("tcp", addr, 5*time.Second) + } + proxy := func(addr string) (net.Conn, error) { + conn, err := tls.Dial("tcp", proxyAddr, &tls.Config{}) + if err != nil { + return nil, err + } + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + p, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + return transport.ToProxyConn(conn, host, uint16(p), password) } type args struct { - domain string - conn net.Conn - timeout time.Duration + domain string + dial func(string) (net.Conn, error) } tests := []struct { name string @@ -93,35 +49,23 @@ func TestPort_PingWithConn(t *testing.T) { args args wantErr bool }{ - { - "google_80", - router.HTTP, - args{"google.com", conn, 3 * time.Second}, - false, - }, - { - "google_443", - router.HTTPS, - args{"baidu.com", conn, 3 * time.Second}, - false, - }, - { - "mail_80", - router.HTTP, - args{"smtp.163.com", conn, 3 * time.Second}, - true, - }, - { - "mail_443", - router.HTTPS, - args{"smtp.163.com", conn, 3 * time.Second}, - true, - }, + {"", router.HTTP, args{"baidu.com", direct}, false}, + {"", router.HTTP, args{"baidu.com", proxy}, false}, + {"", router.HTTPS, args{"baidu.com", direct}, false}, + {"", router.HTTPS, args{"baidu.com", proxy}, false}, + {"", router.HTTP, args{"google.com", direct}, true}, + {"", router.HTTP, args{"google.com", proxy}, false}, + {"", router.HTTPS, args{"google.com", direct}, true}, + {"", router.HTTPS, args{"google.com", proxy}, false}, + {"", router.HTTP, args{"smtp.163.com", direct}, true}, + {"", router.HTTP, args{"smtp.163.com", proxy}, true}, + {"", router.HTTPS, args{"smtp.163.com", direct}, true}, + {"", router.HTTPS, args{"smtp.163.com", proxy}, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := tt.p.PingWithConn(tt.args.domain, tt.args.conn, tt.args.timeout); (err != nil) != tt.wantErr { - t.Errorf("Port.PingWithConn() error = %v, wantErr %v", err, tt.wantErr) + if err := tt.p.Ping(tt.args.domain, tt.args.dial); (err != nil) != tt.wantErr { + t.Errorf("Port.Ping() error = %v, wantErr %v", err, tt.wantErr) } }) } diff --git a/router/router.go b/router/router.go index 244649f..755c808 100644 --- a/router/router.go +++ b/router/router.go @@ -31,7 +31,7 @@ type Route struct { } // ShouldProxy check if the domain shoule request though proxy -func (r *Route) ShouldProxy(domain string) bool { +func (r *Route) GenProxyCheck(sync bool) func(string) bool { r.once.Do(func() { r.cache = mem.New(4 * time.Hour) r.password = []byte(r.ProxyPassword) @@ -39,25 +39,36 @@ func (r *Route) ShouldProxy(domain string) bool { r.proxyRule = util.NewNodeFromRules(r.ProxyList...) }) - // break deadlook, for wildcard - if strings.Count(domain, ".") > 4 { - return false - } - domain = strings.TrimSuffix(domain, ".") - - if domain == r.ProxyAddress { - return false - } - - if r.proxyRule.Match(domain) { + detect := func(domain string) bool { + go r.cache.Remember(r, domain) return true } - if r.directRule.Match(domain) { - return false + if sync { + detect = func(domain string) bool { + r.cache.Remember(r, domain) + if r.proxyRule.Match(domain) { + return true + } + return !r.directRule.Match(domain) + } } - go r.cache.Remember(r, domain) - return true + return func(domain string) bool { + domain = strings.TrimSuffix(domain, ".") + // break deadlook, for wildcard + if sepCount := strings.Count(domain, "."); sepCount == 0 || sepCount >= 5 { + return false + } + + if r.proxyRule.Match(domain) { + return true + } + if r.directRule.Match(domain) { + return false + } + + return detect(domain) + } } // Get implement for cache @@ -93,19 +104,16 @@ func (r *Route) detect(domain string) (http, https int) { go func(shouldProxy bool, port Port) { defer wg.Done() - target := net.JoinHostPort(domain, port.String()) - conn, err := transport.Dial(target, func(string) (string, []byte) { - if shouldProxy { - return r.ProxyAddress, r.password - } - return "", nil - }) - if err != nil { - log.Warnw("sower dial", "proxy", shouldProxy, "address", target, "err", err) - return - } - - if err := port.PingWithConn(domain, conn, 5*time.Second); err != nil { + if err := port.Ping(domain, func(domain string) (net.Conn, error) { + return transport.Dial(domain, + func(domain string) (proxyAddr string, password []byte) { + if shouldProxy { + return r.ProxyAddress, r.password + } + return "", nil + }) + }); err != nil { + log.Warnw("sower dial", "proxy", shouldProxy, "host", domain, "port", port, "err", err) return } diff --git a/transport/proxy_conn.go b/transport/proxy_conn.go index 2b06a06..9848300 100644 --- a/transport/proxy_conn.go +++ b/transport/proxy_conn.go @@ -2,7 +2,6 @@ package transport import ( "crypto/md5" - "crypto/tls" "encoding/binary" "io" "net" @@ -12,7 +11,7 @@ import ( ) // checksum(>=0x80) + port + target_length + target + data -// data(HTTP/HTTPS, first byte < 0x7F) +// data(HTTP, first byte < 0x7F) type head struct { Checksum byte Port uint16 @@ -41,7 +40,7 @@ func ParseProxyConn(conn net.Conn, password []byte) (net.Conn, string) { return teeConn, net.JoinHostPort(string(buf), strconv.Itoa(int(h.Port))) } -func ToProxyConn(conn net.Conn, tgtHost string, tgtPort uint16, tlsCfg *tls.Config, password []byte) (net.Conn, error) { +func ToProxyConn(conn net.Conn, tgtHost string, tgtPort uint16, password []byte) (net.Conn, error) { h := &head{ Checksum: sumChecksum([]byte(tgtHost), password), Port: tgtPort, diff --git a/transport/util.go b/transport/util.go index a2309e5..eac0e9d 100644 --- a/transport/util.go +++ b/transport/util.go @@ -7,17 +7,39 @@ import ( "strconv" "github.com/wweir/sower/dhcp" - "github.com/wweir/sower/util" + "github.com/wweir/utils/log" ) var ( - dnsAddr string - preSet bool + persistDNS string + dnsAddr string + resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, dnsAddr) + }, + } ) -func SetDNS(dnsIP string) { - dnsAddr = net.JoinHostPort(dnsIP, "53") - preSet = true +func SetDNS(err error, dnsIP string) { + if dnsIP != "" { + persistDNS = dnsIP + dnsAddr = net.JoinHostPort(dnsIP, "53") + return + } else if persistDNS != "" { + return + } + + if e, ok := err.(*net.DNSError); !ok /*nil*/ || !e.IsNotFound { + if dnsIP, err = dhcp.GetDefaultDNSServer(); err != nil { + dnsIP, err = dhcp.GetDefaultDNSServer() // retry + } + if err != nil { + log.Errorw("get dns via dhcp", "err", err, "current_dns", dnsAddr) + } else { + dnsAddr = net.JoinHostPort(dnsIP, "53") + } + } } // Dial dial targetAddr with possiable proxy address @@ -29,21 +51,12 @@ func Dial(targetAddr string, dialAddr func(domain string) (proxyAddr string, pas address, password := dialAddr(host) if address == "" { - ips, err := (&net.Resolver{ - PreferGo: true, - Dial: func(ctx context.Context, network, address string) (net.Conn, error) { - return (&net.Dialer{}).DialContext(ctx, network, dnsAddr) - }, - }).LookupIPAddr(context.Background(), host) + ips, err := resolver.LookupIPAddr(context.Background(), host) + if err != nil { //retry + ips, err = resolver.LookupIPAddr(context.Background(), host) + } if err != nil { - if !preSet { - if e, ok := err.(*net.DNSError); !ok || !e.IsNotFound { - if ip, err := dhcp.GetDefaultDNSServer(); err == nil { - dnsAddr = net.JoinHostPort(ip, "53") - } - } - } - + SetDNS(err, "") return nil, err } @@ -67,11 +80,10 @@ func Dial(targetAddr string, dialAddr func(domain string) (proxyAddr string, pas return conn, nil } - address, _ = util.WithDefaultPort(address, "443") - // tls.Config is same as golang http pkg default behavior - conn, err := tls.Dial("tcp", address, &tls.Config{}) + conn, err := tls.DialWithDialer(&net.Dialer{Resolver: resolver}, + "tcp", net.JoinHostPort(address, "443"), &tls.Config{}) if err != nil { return nil, err } - return ToProxyConn(conn, host, uint16(p), &tls.Config{}, password) + return ToProxyConn(conn, host, uint16(p), password) }