diff --git a/README.md b/README.md index 5485645..9d4e872 100644 --- a/README.md +++ b/README.md @@ -50,14 +50,16 @@ There are 3 kinds of proxy solutions, they are HTTP(S)_PROXY / DNS-based proxy / An HTTP(S)_PROXY listening on `:8080` is set by default if you run sower as client mode. ### DNS-based proxy -You can set the `serve_ip` field in the `dns` section in the configuration file to start the DNS-based proxy. You should also set the value of `serve_ip` as your default DNS in OS. +**DNS-based** solution is not recommanded now, for incompatible with TLS 1.3 esni extention. + +You can set the `dns_serve_ip` field in the configuration file to start the DNS-based proxy. You should also set the value of `dns_upstream` as your default DNS in OS. If you want to enjoy the full experience provided by the sower, you can take sower as your private DNS on a long-running server and set it as your default DNS in your router. ### port-forward -The port-forward can be only setted in configuration file, you can set it in section `client.router.port_mapping`, eg: +The port-forward can be only setted in configuration file, you can set it in section `client.port_mapping`, eg: ``` toml -[client.router.port_mapping] +[client.port_mapping] ":2222"="aa.bb.cc:22" ``` diff --git a/conf/conf.go b/conf/conf.go index 9ac2d77..e4f52ae 100644 --- a/conf/conf.go +++ b/conf/conf.go @@ -3,37 +3,25 @@ package conf import ( "flag" "os" - "time" toml "github.com/pelletier/go-toml" - "github.com/wweir/sower/util" "github.com/wweir/utils/log" ) type client struct { - Address string `toml:"address"` - - HTTPProxy struct { - Address string `toml:"address"` - } `toml:"http_proxy"` - - DNS struct { - ServeIP string `toml:"serve_ip"` - Upstream string `toml:"upstream"` - FlushCmd string `toml:"flush_cmd"` - } `toml:"dns"` + Address string `toml:"address"` + HTTPProxy string `toml:"http_proxy"` + DNSServeIP string `toml:"dns_serve_ip"` + DNSUpstream string `toml:"dns_upstream"` + PortForward map[string]string `toml:"port_forward"` Router struct { - PortMapping map[string]string `toml:"port_mapping"` - DetectLevel int `toml:"detect_level"` - DetectTimeout string `toml:"detect_timeout"` - - ProxyList []string `toml:"proxy_list"` - DirectList []string `toml:"direct_list"` - DynamicList map[string]int `toml:"dynamic_list"` // toml has a bug with dot - directRules *util.Node - proxyRules *util.Node - dynamicRules *util.Node + DetectLevel int `toml:"detect_level"` + DetectTimeout string `toml:"detect_timeout"` + + ProxyList []string `toml:"proxy_list"` + DirectList []string `toml:"direct_list"` + DynamicList []string `toml:"dynamic_list"` } `toml:"router"` } type server struct { @@ -48,26 +36,24 @@ var ( Server = server{} Client = client{} - conf = struct { - file string - Server *server `toml:"server"` - Client *client `toml:"client"` - }{"", &Server, &Client} + Conf = struct { + file string + Password string `toml:"password"` + Server *server `toml:"server"` + Client *client `toml:"client"` + }{"", "", &Server, &Client} flushCh = make(chan struct{}) - Password string installCmd string uninstallFlag bool ) func init() { - flag.StringVar(&Password, "password", "", "password") + flag.StringVar(&Conf.Password, "password", "", "password") flag.StringVar(&Server.Upstream, "s", "", "upstream http service, eg: 127.0.0.1:8080") flag.StringVar(&Server.CertFile, "s_cert", "", "tls cert file, gen cert from letsencrypt if empty") flag.StringVar(&Server.KeyFile, "s_key", "", "tls key file, gen cert from letsencrypt if empty") flag.StringVar(&Client.Address, "c", "", "remote server domain, eg: aa.bb.cc, socks5h://127.0.0.1:1080") - flag.StringVar(&Client.HTTPProxy.Address, "http_proxy", ":8080", "http proxy, empty to disable") - flag.StringVar(&Client.DNS.ServeIP, "dns_ip", "", "upstream dns, eg: 127.0.0.1, disable dns proxy if empty") - flag.StringVar(&Client.DNS.Upstream, "dns_upstream", "", "dns relay server ip, dynamic detect if empty") + flag.StringVar(&Client.HTTPProxy, "http_proxy", ":8080", "http proxy, empty to disable") flag.IntVar(&Client.Router.DetectLevel, "level", 2, "dynamic rule detect level: 0~4") flag.StringVar(&Client.Router.DetectTimeout, "timeout", "300ms", "dynamic rule detect timeout") flag.BoolVar(&uninstallFlag, "uninstall", false, "uninstall service") @@ -85,23 +71,14 @@ func init() { os.Exit(0) } - var err error - defer func() { - if timeout, err = time.ParseDuration(Client.Router.DetectTimeout); err != nil { - log.Fatalw("parse dynamic detect timeout", "val", Client.Router.DetectTimeout, "err", err) - } - - log.Infow("start", "version", version, "date", date, "conf", &conf) - passwordData = []byte(Password) - }() - - if conf.file == "" { + defer log.Infow("starting", "config", &Conf) + if Conf.file == "" { return } for i := range loadConfigFns { - if err = loadConfigFns[i].fn(); err != nil { - log.Fatalw("load config", "config", conf.file, "step", loadConfigFns[i].step, "err", err) + if err := loadConfigFns[i].fn(); err != nil { + log.Fatalw("load config", "config", Conf.file, "step", loadConfigFns[i].step, "err", err) } } @@ -113,45 +90,40 @@ var loadConfigFns = []struct { step string fn func() error }{{"load_config", func() error { - f, err := os.OpenFile(conf.file, os.O_RDONLY, 0644) + f, err := os.OpenFile(Conf.file, os.O_RDONLY, 0644) if err != nil { return err } defer f.Close() - Client.Router.DynamicList = map[string]int{} - return toml.NewDecoder(f).Decode(&conf) - -}}, {"load_rules", func() error { - Client.Router.directRules = util.NewNodeFromRules(Client.Router.DirectList...) - Client.Router.proxyRules = util.NewNodeFromRules(Client.Router.ProxyList...) - return nil - -}}, {"flush_dns", func() error { - if Client.DNS.FlushCmd != "" { - return execute(Client.DNS.FlushCmd) - } - return nil + return toml.NewDecoder(f).Decode(&Conf) }}} +func PersistRule(domain string) { + Client.Router.DynamicList = append(Client.Router.DynamicList, domain) + select { + case flushCh <- struct{}{}: + default: + } +} func flushConfDaemon() { for range flushCh { // safe write file - if conf.file != "" { - f, err := os.OpenFile(conf.file+"~", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if Conf.file != "" { + f, err := os.OpenFile(Conf.file+"~", os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { log.Errorw("flush config", "step", "flush", "err", err) continue } - if err := toml.NewEncoder(f).ArraysWithOneElementPerLine(true).Encode(conf); err != nil { + if err := toml.NewEncoder(f).ArraysWithOneElementPerLine(true).Encode(&Conf); err != nil { log.Errorw("flush config", "step", "flush", "err", err) f.Close() continue } f.Close() - if err = os.Rename(conf.file+"~", conf.file); err != nil { + if err = os.Rename(Conf.file+"~", Conf.file); err != nil { log.Errorw("flush config", "step", "flush", "err", err) continue } diff --git a/conf/conf_darwin.go b/conf/conf_darwin.go index 2deeaa1..f82f99f 100644 --- a/conf/conf_darwin.go +++ b/conf/conf_darwin.go @@ -36,9 +36,8 @@ const svcFile = ` ` func Init() { - flag.StringVar(&conf.file, "f", "", "config file, rewrite all other parameters if set") - flag.StringVar(&Client.DNS.FlushCmd, "flush_dns", "pkill mDNSResponder || true", "flush dns command") - flag.StringVar(&installCmd, "install", "", "install service with cmd, eg: '-f /etc/sower/sower.toml'") + flag.StringVar(&Conf.file, "f", "", "config file, rewrite all other parameters if set") + flag.StringVar(&installCmd, "install", "", "install service with cmd, eg: '-f /etc/sower.toml'") } func install() { diff --git a/conf/conf_linux.go b/conf/conf_linux.go index ba5b1a1..a8de9b3 100644 --- a/conf/conf_linux.go +++ b/conf/conf_linux.go @@ -32,9 +32,8 @@ RestartSec=3 Restart=on-failure` func Init() { - flag.StringVar(&conf.file, "f", "", "config file, rewrite all other parameters if set") - flag.StringVar(&Client.DNS.FlushCmd, "flush_dns", "", "flush dns command") - flag.StringVar(&installCmd, "install", "", "install service with cmd, eg: '-f /etc/sower/sower.toml'") + flag.StringVar(&Conf.file, "f", "", "config file, rewrite all other parameters if set") + flag.StringVar(&installCmd, "install", "", "install service with cmd, eg: '-f /etc/sower.toml'") } func install() { execFile, err := filepath.Abs(os.Args[0]) diff --git a/conf/conf_windows.go b/conf/conf_windows.go index 72539f6..a16b808 100644 --- a/conf/conf_windows.go +++ b/conf/conf_windows.go @@ -27,9 +27,8 @@ var execFile, _ = filepath.Abs(os.Args[0]) var execDir, _ = filepath.Abs(filepath.Dir(execFile)) func Init() { - flag.StringVar(&conf.file, "f", filepath.Join(execDir, "sower.toml"), "config file, rewrite all other parameters if set") + flag.StringVar(&Conf.file, "f", filepath.Join(execDir, "sower.toml"), "config file, rewrite all other parameters if set") flag.StringVar(&installCmd, "install", "", "install service with cmd") - flag.StringVar(&Client.DNS.FlushCmd, "flush_dns", "ipconfig /flushdnss", "flush dns command") flag.Parse() switch { @@ -151,7 +150,7 @@ func execute(cmd string) error { defer cancel() var cmds []string - for _, cmd := range strings.Split(Client.DNS.FlushCmd, " ") { + for _, cmd := range strings.Split(cmd, " ") { if cmd == "" { continue } @@ -168,7 +167,7 @@ func execute(cmd string) error { command := exec.CommandContext(ctx, cmds[0], cmds[1:]...) command.SysProcAttr = &syscall.SysProcAttr{HideWindow: true} if out, err := command.CombinedOutput(); err != nil { - return fmt.Errorf("cmd: %s, output: %s, err: %w", Client.DNS.FlushCmd, out, err) + return fmt.Errorf("cmd: %s, output: %s, err: %w", cmd, out, err) } return nil } diff --git a/conf/dynamic_rule.go b/conf/dynamic_rule.go deleted file mode 100644 index e799c33..0000000 --- a/conf/dynamic_rule.go +++ /dev/null @@ -1,166 +0,0 @@ -package conf - -import ( - "crypto/tls" - "net" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/wweir/sower/internal/http" - "github.com/wweir/sower/internal/socks5" - "github.com/wweir/utils/log" - "github.com/wweir/utils/mem" -) - -type dynamic struct { - port http.Port -} - -var cache = mem.New(4 * time.Hour) -var detect = &dynamic{} -var passwordData []byte -var timeout time.Duration -var dynamicCache sync.Map -var dynamicMu = sync.Mutex{} - -// ShouldProxy check if the domain shoule request though proxy -func ShouldProxy(domain string) bool { - // break deadlook, for wildcard - if strings.Count(domain, ".") > 4 { - return false - } - domain = strings.TrimSuffix(domain, ".") - - if domain == Client.Address { - return false - } - if Client.Router.directRules.Match(domain) { - return false - } - if Client.Router.proxyRules.Match(domain) { - return true - } - - cache.Remember(detect, domain) - val, _ := dynamicCache.Load(domain) - return val.(int) >= Client.Router.DetectLevel -} - -func (d *dynamic) Get(key interface{}) (err error) { - domain := key.(string) - domainUnderscore := strings.ReplaceAll(domain, ".", "_") - var score int - - defer func() { - dynamicCache.Store(domain, score) - - if score < conf.Client.Router.DetectLevel { - delete(Client.Router.DynamicList, domainUnderscore) - } else { - Client.Router.DynamicList[domainUnderscore] = score - - // persist when add new domain - select { - case flushCh <- struct{}{}: - default: - } - log.Infow("persist rule", "domain", domain, "score", score) - } - }() - - if val, ok := dynamicCache.Load(domain); ok { - score = val.(int) - } else { - dynamicMu.Lock() - score = Client.Router.DynamicList[domainUnderscore] - dynamicMu.Unlock() - } - - // detect range: [0,conf.Client.Router.DetectLevel) - switch { - case score < -1: - score++ - case score == -1: - score++ - score += d.detect(domain) - case score > conf.Client.Router.DetectLevel: - score-- - case score == conf.Client.Router.DetectLevel: - score-- - score += d.detect(domain) - } - - return nil -} - -// detect and caculate direct connection and proxy connection score -func (d *dynamic) detect(domain string) int { - wg := sync.WaitGroup{} - httpScore, httpsScore := new(int32), new(int32) - for _, ping := range [...]dynamic{{port: http.HTTP}, {port: http.HTTPS}} { - wg.Add(1) - go func(ping dynamic) { - defer wg.Done() - - if err := ping.port.Ping(domain, timeout); err != nil { - return - } - - switch ping.port { - case http.HTTP: - if !atomic.CompareAndSwapInt32(httpScore, 0, -2) { - atomic.AddInt32(httpScore, -1) - } - case http.HTTPS: - if !atomic.CompareAndSwapInt32(httpsScore, 0, -2) { - atomic.AddInt32(httpScore, -1) - } - } - }(ping) - } - for _, ping := range [...]dynamic{{port: http.HTTP}, {port: http.HTTPS}} { - wg.Add(1) - go func(ping dynamic) { - defer wg.Done() - - var conn net.Conn - var err error - if addr, ok := socks5.IsSocks5Schema(Client.Address); ok { - conn, err = net.Dial("tcp", addr) - conn = socks5.ToSocks5(conn, domain, uint16(ping.port)) - - } else { - conn, err = tls.Dial("tcp", net.JoinHostPort(Client.Address, "443"), &tls.Config{}) - if ping.port == http.HTTP { - conn = http.NewTgtConn(conn, passwordData, http.TGT_HTTP, "", 80) - } else { - conn = http.NewTgtConn(conn, passwordData, http.TGT_HTTPS, "", 443) - } - } - if err != nil { - log.Errorw("sower dial", "addr", Client.Address, "err", err) - return - } - - if err := ping.port.PingWithConn(domain, conn, timeout); err != nil { - return - } - - switch ping.port { - case http.HTTP: - if !atomic.CompareAndSwapInt32(httpScore, 0, 2) { - atomic.AddInt32(httpScore, 1) - } - case http.HTTPS: - if !atomic.CompareAndSwapInt32(httpsScore, 0, 2) { - atomic.AddInt32(httpScore, 1) - } - } - }(ping) - } - - wg.Wait() - return int(*httpScore + *httpsScore) -} diff --git a/conf/sower.toml b/conf/sower.toml index b74106a..746e991 100644 --- a/conf/sower.toml +++ b/conf/sower.toml @@ -1,13 +1,13 @@ +password="" # sower password + [client] address = "" # aa.bb.cc, socks5h://127.0.0.1:1080 + http_proxy=":8080" + dns_serve_ip="" # eg: 127.0.0.1. keep empty to disable DNS solution + dns_upstream="" # keep empty to set via dhcp, not effective in any environment - [client.dns] - flush_cmd="" # macOS: pkill mDNSResponder || true, Windows: ipconfig /flushdnss - serve_ip = "127.0.0.1" - upstream = "" # empty to dynamic detect - - [client.http_proxy] - address = ":8080" # empty to disable http_proxy + [client.port_forward] + # eg: ":2222"="aa.bb.cc:22" [client.router] detect_level = 2 # 0~4, the bigger the harder to add @@ -47,8 +47,6 @@ "*.github.*", ] - [client.router.port_mapping] - # ":2222"="aa.bb.cc:22" [server] cert_email = "" # eg: user@aa.bb.cc diff --git a/internal/http/tgt_parser.go b/internal/http/tgt_parser.go deleted file mode 100644 index 4e0eb86..0000000 --- a/internal/http/tgt_parser.go +++ /dev/null @@ -1,147 +0,0 @@ -package http - -import ( - "bufio" - "crypto/md5" - "encoding/binary" - "errors" - "io" - "net" - "net/http" - "strconv" - "strings" - - "github.com/wweir/sower/util" -) - -const ( - TGT_OTHER byte = iota - TGT_HTTP - TGT_HTTPS -) - -// Write Addr -type conn struct { - typ byte - password []byte - domain []byte - port uint16 - init bool - net.Conn -} - -func NewTgtConn(c net.Conn, password []byte, tgtType byte, domain string, port uint16) net.Conn { - return &conn{ - typ: tgtType, - password: password, - domain: []byte(domain), - port: port, - init: true, - Conn: c, - } -} - -// other => type + checksum + port + domain_length ++ domain + data -// http => type + checksum ++ data -// https => type + checksum + port ++ data -type header struct { - Type byte - Checksum byte - Port uint16 - DomainLength uint8 -} - -func (c *conn) Write(b []byte) (n int, err error) { - if c.init { - c.init = false - domainLength := byte(len(c.domain)) - if err := binary.Write(c.Conn, binary.BigEndian, &header{ - Type: c.typ, - Checksum: checksum(c.password, c.port, domainLength), - Port: c.port, - DomainLength: domainLength, - }); err != nil { - return 0, err - } - - switch c.typ { - case TGT_OTHER: - n, err = c.Conn.Write(append(c.domain, b...)) - default: - n, err = c.Conn.Write(b) - } - return n - len(c.domain), err - } - - return c.Conn.Write(b) -} - -// ParseAddr parse target addr from net.Conn -func ParseAddr(conn net.Conn, password []byte) (_ net.Conn, domain string, port uint16, err error) { - teeConn := &util.TeeConn{Conn: conn} - teeConn.StartOrReset() - defer teeConn.Stop() - - head := new(header) - if err = binary.Read(conn, binary.BigEndian, head); err != nil { - return teeConn, "", 0, nil - } - if head.Checksum != checksum(password, head.Port, head.DomainLength) { - return teeConn, "", 0, nil - } - - switch head.Type { - case TGT_OTHER: - buf := make([]byte, int(head.DomainLength)) - if _, err = io.ReadFull(conn, buf); err != nil { - return teeConn, "", 0, err - } - - return teeConn, string(buf), head.Port, nil - - case TGT_HTTP: - teeConn.DropAndRestart() - return ParseHTTP(teeConn) - - case TGT_HTTPS: - teeConn.DropAndRestart() - conn, domain, err = ParseHTTPS(teeConn) - return conn, domain, head.Port, err - - default: - return teeConn, "", 0, errors.New("invalid request") - } -} -func ParseHTTP(teeConn net.Conn) (_ net.Conn, domain string, port uint16, err error) { - resp, err := http.ReadRequest(bufio.NewReader(teeConn)) - if err != nil { - return teeConn, "", 0, err - } - - idx := strings.LastIndex(resp.Host, ":") - if idx == -1 { - return teeConn, resp.Host, 80, nil - } - - p, err := strconv.ParseUint(resp.Host[idx+1:], 10, 16) - if err != nil { - return teeConn, "", 0, err - } - return teeConn, resp.Host[:idx], uint16(p), nil -} -func ParseHTTPS(teeConn net.Conn) (_ net.Conn, domain string, err error) { - if domain, _, err = extractSNI(teeConn); err != nil { - return nil, "", err - } - return teeConn, domain, nil -} - -var errChecksum = errors.New("invalid checksum") - -func checksum(password []byte, port uint16, length uint8) (val byte) { - nums := md5.Sum(append(password, byte(port), length)) - for _, b := range nums { - val += b - } - return val -} diff --git a/internal/http/tgt_parser_test.go b/internal/http/tgt_parser_test.go deleted file mode 100644 index 8e97c18..0000000 --- a/internal/http/tgt_parser_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package http - -import ( - "bufio" - "bytes" - "io/ioutil" - "net" - "net/http" - "testing" -) - -func TestParseAddr1(t *testing.T) { - c1, c2 := net.Pipe() - - go func() { - c1 = NewTgtConn(c1, nil, TGT_HTTP, "", 0) - req, _ := http.NewRequest("GET", "http://wweir.cc", bytes.NewReader([]byte{1, 2, 3})) - req.Write(c1) - }() - - c2, host, port, err := ParseAddr(c2, nil) - - if err != nil || host != "wweir.cc" || port != 80 { - t.Error(err, host, port) - } - - req, err := http.ReadRequest(bufio.NewReader(c2)) - if err != nil { - t.Error(err) - } - - data, err := ioutil.ReadAll(req.Body) - if err != nil || len(data) != 3 || data[0] != 1 { - t.Error(err, data) - } -} - -func TestParseAddr2(t *testing.T) { - c1, c2 := net.Pipe() - - go func() { - c1 = NewTgtConn(c1, nil, TGT_HTTPS, "", 443) - c1.Write(HTTPS.PingMsg("wweir.cc")) - }() - - _, host, port, err := ParseAddr(c2, nil) - - if err != nil || host != "wweir.cc" || port != 443 { - t.Error(err, host, port) - } -} - -func TestParseAddr3(t *testing.T) { - c1, c2 := net.Pipe() - - go func() { - c1 = NewTgtConn(c1, nil, TGT_OTHER, "wweir.cc", 1080) - c1.Write(HTTPS.PingMsg("wweir.cc")) - }() - - _, host, port, err := ParseAddr(c2, nil) - - if err != nil || host != "wweir.cc" || port != 1080 { - t.Error(err, host, port) - } -} diff --git a/internal/http/tgt_sni.go b/internal/http/tgt_sni.go deleted file mode 100644 index 55b10c5..0000000 --- a/internal/http/tgt_sni.go +++ /dev/null @@ -1,232 +0,0 @@ -// Copyright 2016 Google Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package http - -import ( - "encoding/binary" - "errors" - "fmt" - "io" -) - -func extractSNI(r io.Reader) (string, int, error) { - handshake, tlsver, err := handshakeRecord(r) - if err != nil { - return "", 0, fmt.Errorf("reading TLS record: %s", err) - } - - sni, err := parseHello(handshake) - if err != nil { - return "", 0, fmt.Errorf("reading ClientHello: %s", err) - } - if len(sni) == 0 { - // ClientHello did not present an SNI extension. Valid packet, - // no hostname. - return "", tlsver, nil - } - - hostname, err := parseSNI(sni) - if err != nil { - return "", 0, fmt.Errorf("parsing SNI extension: %s", err) - } - return hostname, tlsver, nil -} - -// Extract the indicated hostname, if any, from the given SNI -// extension bytes. -func parseSNI(b []byte) (string, error) { - b, _, err := vector(b, 2) - if err != nil { - return "", err - } - - var ret []byte - for len(b) >= 3 { - typ := b[0] - ret, b, err = vector(b[1:], 2) - if err != nil { - return "", fmt.Errorf("truncated SNI extension") - } - - if typ == sniHostnameID { - return string(ret), nil - } - } - - if len(b) != 0 { - return "", fmt.Errorf("trailing garbage at end of SNI extension") - } - - // No DNS-based SNI present. - return "", nil -} - -const sniExtensionID = 0 -const sniHostnameID = 0 - -// Parse a TLS handshake record as a ClientHello message and extract -// the SNI extension bytes, if any. -func parseHello(b []byte) ([]byte, error) { - if len(b) == 0 { - return nil, errors.New("zero length handshake record") - } - if b[0] != 1 { - return nil, fmt.Errorf("non-ClientHello handshake record type %d", b[0]) - } - - // We're expecting a stricter TLS parser to run after we've - // proxied, so we ignore any trailing bytes that might be present - // (e.g. another handshake message). - b, _, err := vector(b[1:], 3) - if err != nil { - return nil, fmt.Errorf("reading ClientHello: %s", err) - } - - // ClientHello must be at least 34 bytes to reach the first vector - // length byte. The actual minimal size is larger than that, but - // vector() will correctly handle truncated packets. - if len(b) < 34 { - return nil, errors.New("ClientHello packet too short") - } - - if b[0] != 3 { - return nil, fmt.Errorf("ClientHello has unsupported version %d.%d", b[0], b[1]) - } - switch b[1] { - case 1, 2, 3: - // TLS 1.0, TLS 1.1, TLS 1.2 - default: - return nil, fmt.Errorf("TLS record has unsupported version %d.%d", b[0], b[1]) - } - - // Skip over version and random struct - b = b[34:] - - // We don't technically care about SessionID, but we care that the - // framing is well-formed all the way up to the SNI field, so that - // we are sure that we're pulling the same SNI bytes as the - // eventual TLS implementation. - vec, b, err := vector(b, 1) - if err != nil { - return nil, fmt.Errorf("reading ClientHello SessionID: %s", err) - } - if len(vec) > 32 { - return nil, fmt.Errorf("ClientHello SessionID too long (%db)", len(vec)) - } - - // Likewise, we're just checking the bare minimum of framing. - vec, b, err = vector(b, 2) - if err != nil { - return nil, fmt.Errorf("reading ClientHello CipherSuites: %s", err) - } - if len(vec) < 2 || len(vec)%2 != 0 { - return nil, fmt.Errorf("ClientHello CipherSuites invalid length %d", len(vec)) - } - - vec, b, err = vector(b, 1) - if err != nil { - return nil, fmt.Errorf("reading ClientHello CompressionMethods: %s", err) - } - if len(vec) < 1 { - return nil, fmt.Errorf("ClientHello CompressionMethods invalid length %d", len(vec)) - } - - // Finally, we reach the extensions. - if len(b) == 0 { - // No extensions. This is not an error, it just means we have - // no SNI payload. - return nil, nil - } - b, vec, err = vector(b, 2) - if err != nil { - return nil, fmt.Errorf("reading ClientHello extensions: %s", err) - } - if len(vec) != 0 { - return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(vec)) - } - - for len(b) >= 4 { - typ := binary.BigEndian.Uint16(b[:2]) - vec, b, err = vector(b[2:], 2) - if err != nil { - return nil, fmt.Errorf("reading ClientHello extension %d: %s", typ, err) - } - if typ == sniExtensionID { - // Found the SNI extension, return its payload. We don't - // care about anything in the packet beyond this point. - return vec, nil - } - } - - if len(b) != 0 { - return nil, fmt.Errorf("%d bytes of trailing garbage in ClientHello", len(b)) - } - - // Successfully parsed all extensions, but there was no SNI. - return nil, nil -} - -const maxTLSRecordLength = 16384 - -// Read one TLS record, which must be for the handshake protocol, from r. -func handshakeRecord(r io.Reader) ([]byte, int, error) { - var hdr struct { - Type uint8 - Major, Minor uint8 - Length uint16 - } - if err := binary.Read(r, binary.BigEndian, &hdr); err != nil { - return nil, 0, fmt.Errorf("reading TLS record header: %s", err) - } - - if hdr.Type != 22 { - return nil, 0, fmt.Errorf("TLS record is not a handshake") - } - - if hdr.Major != 3 { - return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) - } - switch hdr.Minor { - case 1, 2, 3: - // TLS 1.0, TLS 1.1, TLS 1.2 - default: - return nil, 0, fmt.Errorf("TLS record has unsupported version %d.%d", hdr.Major, hdr.Minor) - } - - if hdr.Length > maxTLSRecordLength { - return nil, 0, fmt.Errorf("TLS record length is greater than %d", maxTLSRecordLength) - } - - ret := make([]byte, hdr.Length) - if _, err := io.ReadFull(r, ret); err != nil { - return nil, 0, err - } - - return ret, int(hdr.Minor), nil -} - -func vector(b []byte, lenBytes int) ([]byte, []byte, error) { - if len(b) < lenBytes { - return nil, nil, errors.New("not enough space in packet for vector") - } - var l int - for _, b := range b[:lenBytes] { - l = (l << 8) + int(b) - } - if len(b) < l+lenBytes { - return nil, nil, errors.New("not enough space in packet for vector") - } - return b[lenBytes : l+lenBytes], b[l+lenBytes:], nil -} diff --git a/main.go b/main.go index 7ae316c..f63a53a 100644 --- a/main.go +++ b/main.go @@ -6,25 +6,41 @@ import ( "github.com/wweir/sower/conf" "github.com/wweir/sower/proxy" + "github.com/wweir/sower/router" ) func main() { - if conf.Server.Upstream != "" { - proxy.StartServer(conf.Server.Upstream, conf.Password, + switch { + case conf.Server.Upstream != "": + proxy.StartServer(conf.Server.Upstream, conf.Conf.Password, conf.Server.CertFile, conf.Server.KeyFile, conf.Server.CertEmail) - } - if conf.Client.Address != "" { - if conf.Client.DNS.ServeIP != "" { - go proxy.StartDNS(conf.Client.DNS.ServeIP, conf.Client.DNS.Upstream) + case conf.Client.Address != "": + route := &router.Route{ + ProxyAddress: conf.Client.Address, + ProxyPassword: conf.Conf.Password, + DetectLevel: conf.Client.Router.DetectLevel, + DetectTimeout: conf.Client.Router.DetectTimeout, + DirectList: conf.Client.Router.DirectList, + ProxyList: conf.Client.Router.ProxyList, + DynamicList: conf.Client.Router.DynamicList, } - proxy.StartClient(conf.Password, conf.Client.Address, conf.Client.HTTPProxy.Address, - conf.Client.DNS.ServeIP, conf.Client.Router.PortMapping) - } + go proxy.StartHTTPProxy(conf.Client.HTTPProxy, conf.Client.Address, + []byte(conf.Conf.Password), route.ShouldProxy) + + if conf.Client.DNSServeIP != "" { + go proxy.StartDNS(conf.Client.DNSServeIP, conf.Client.DNSUpstream, + route.ShouldProxy) + } + + proxy.StartClient(conf.Client.Address, conf.Conf.Password, + conf.Client.DNSServeIP != "", conf.Client.PortForward) - if conf.Server.Upstream == "" && conf.Client.Address == "" { - fmt.Println() - flag.Usage() + default: + if conf.Server.Upstream == "" && conf.Client.Address == "" { + fmt.Println() + flag.Usage() + } } } diff --git a/internal/net/dhcp.go b/proxy/dhcp/dhcp.go similarity index 99% rename from internal/net/dhcp.go rename to proxy/dhcp/dhcp.go index f2131aa..6edc6c0 100644 --- a/internal/net/dhcp.go +++ b/proxy/dhcp/dhcp.go @@ -1,4 +1,4 @@ -package net +package dhcp import ( "math/rand" diff --git a/internal/net/dhcp_test.go b/proxy/dhcp/dhcp_test.go similarity index 50% rename from internal/net/dhcp_test.go rename to proxy/dhcp/dhcp_test.go index 25fbcc2..ede8fbc 100644 --- a/internal/net/dhcp_test.go +++ b/proxy/dhcp/dhcp_test.go @@ -1,13 +1,13 @@ -package net_test +package dhcp_test import ( "fmt" - "github.com/wweir/sower/internal/net" + "github.com/wweir/sower/proxy/dhcp" ) func Example_dns() { - got, err := net.GetDefaultDNSServer() + got, err := dhcp.GetDefaultDNSServer() if err != nil { panic(err) } diff --git a/internal/net/pick_iface_other.go b/proxy/dhcp/pick_iface_other.go similarity index 97% rename from internal/net/pick_iface_other.go rename to proxy/dhcp/pick_iface_other.go index 8b56e3d..89cd606 100644 --- a/internal/net/pick_iface_other.go +++ b/proxy/dhcp/pick_iface_other.go @@ -1,6 +1,6 @@ // +build !windows -package net +package dhcp import ( "errors" diff --git a/internal/net/pick_iface_test.go b/proxy/dhcp/pick_iface_test.go similarity index 50% rename from internal/net/pick_iface_test.go rename to proxy/dhcp/pick_iface_test.go index 9d6c96c..7495bc7 100644 --- a/internal/net/pick_iface_test.go +++ b/proxy/dhcp/pick_iface_test.go @@ -1,13 +1,13 @@ -package net_test +package dhcp_test import ( "fmt" - "github.com/wweir/sower/internal/net" + "github.com/wweir/sower/proxy/dhcp" ) func Example_iface() { - got, err := net.PickInternetInterface() + got, err := dhcp.PickInternetInterface() if err != nil { panic(err) } diff --git a/internal/net/pick_iface_windows.go b/proxy/dhcp/pick_iface_windows.go similarity index 98% rename from internal/net/pick_iface_windows.go rename to proxy/dhcp/pick_iface_windows.go index 2bcf8ea..fa7d6e7 100644 --- a/internal/net/pick_iface_windows.go +++ b/proxy/dhcp/pick_iface_windows.go @@ -1,6 +1,6 @@ // +build windows -package net +package dhcp import ( "bytes" diff --git a/internal/net/util.go b/proxy/dhcp/util.go similarity index 88% rename from internal/net/util.go rename to proxy/dhcp/util.go index d2fdd8d..b64686c 100644 --- a/internal/net/util.go +++ b/proxy/dhcp/util.go @@ -1,4 +1,4 @@ -package net +package dhcp import "net" diff --git a/proxy/dns.go b/proxy/dns.go index df721fb..37e1190 100644 --- a/proxy/dns.go +++ b/proxy/dns.go @@ -6,12 +6,11 @@ import ( "time" "github.com/miekg/dns" - "github.com/wweir/sower/conf" - _net "github.com/wweir/sower/internal/net" + "github.com/wweir/sower/proxy/dhcp" "github.com/wweir/utils/log" ) -func StartDNS(redirectIP, relayServer string) { +func StartDNS(redirectIP, relayServer string, shouldProxy func(string) bool) { serveIP := net.ParseIP(redirectIP) if redirectIP == "" || serveIP.String() != redirectIP { log.Fatalw("invalid listen ip", "ip", redirectIP) @@ -40,7 +39,7 @@ func StartDNS(redirectIP, relayServer string) { domain = domain[:idx] // trim port } - if conf.ShouldProxy(domain) { + if shouldProxy(domain) { w.WriteMsg(localA(r, domain, serveIP)) } else if msg, err := dns.Exchange(r, relayServer); err != nil || msg == nil { @@ -64,7 +63,7 @@ func StartDNS(redirectIP, relayServer string) { func pickRelayAddr(relayServer string) (_ string, err error) { if relayServer == "" { - if relayServer, err = _net.GetDefaultDNSServer(); err != nil { + if relayServer, err = dhcp.GetDefaultDNSServer(); err != nil { return "", err } } diff --git a/proxy/http_proxy.go b/proxy/http_proxy.go index 67f552b..f50e954 100644 --- a/proxy/http_proxy.go +++ b/proxy/http_proxy.go @@ -6,23 +6,20 @@ import ( "io" "net" "net/http" - "strconv" "time" - "github.com/wweir/sower/conf" - _http "github.com/wweir/sower/internal/http" - "github.com/wweir/sower/util" + "github.com/wweir/sower/transport" "github.com/wweir/utils/log" ) -func startHTTPProxy(httpProxyAddr, serverAddr string, password []byte) { +func StartHTTPProxy(httpProxyAddr, serverAddr string, password []byte, shouldProxy func(string) bool) { srv := &http.Server{ Addr: httpProxyAddr, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodConnect { - httpsProxy(w, r, serverAddr, password) + httpsProxy(w, r, serverAddr, password, shouldProxy) } else { - httpProxy(w, r, serverAddr, password) + httpProxy(w, r, serverAddr, password, shouldProxy) } }), // Disable HTTP/2. @@ -33,13 +30,15 @@ func startHTTPProxy(httpProxyAddr, serverAddr string, password []byte) { go log.Fatalw("serve http proxy", "addr", httpProxyAddr, "err", srv.ListenAndServe()) } -func httpProxy(w http.ResponseWriter, r *http.Request, serverAddr string, password []byte) { - host, port := util.ParseHostPort(r.Host, 80) +func httpProxy(w http.ResponseWriter, r *http.Request, + serverAddr string, password []byte, shouldProxy func(string) bool) { + + target, host := withDefaultPort(r.Host, "80") roundTripper := &http.Transport{} - if conf.ShouldProxy(host) { + if shouldProxy(host) { roundTripper.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - return dial(serverAddr, password, _http.TGT_HTTP, host, port) + return transport.Dial(serverAddr, target, password) } } @@ -59,8 +58,10 @@ func httpProxy(w http.ResponseWriter, r *http.Request, serverAddr string, passwo io.Copy(w, resp.Body) } -func httpsProxy(w http.ResponseWriter, r *http.Request, serverAddr string, password []byte) { - host, port := util.ParseHostPort(r.Host, 443) +func httpsProxy(w http.ResponseWriter, r *http.Request, + serverAddr string, password []byte, shouldProxy func(string) bool) { + + target, host := withDefaultPort(r.Host, "443") conn, _, err := w.(http.Hijacker).Hijack() if err != nil { @@ -76,10 +77,10 @@ func httpsProxy(w http.ResponseWriter, r *http.Request, serverAddr string, passw } var rc net.Conn - if conf.ShouldProxy(host) { - rc, err = dial(serverAddr, password, _http.TGT_HTTPS, host, port) + if shouldProxy(host) { + rc, err = transport.Dial(serverAddr, target, password) } else { - rc, err = net.Dial("tcp", net.JoinHostPort(host, strconv.Itoa(int(port)))) + rc, err = net.Dial("tcp", target) } if err != nil { conn.Write([]byte("sower dial " + serverAddr + " fail: " + err.Error())) diff --git a/proxy/parse_tgt.go b/proxy/parse_tgt.go new file mode 100644 index 0000000..3913c12 --- /dev/null +++ b/proxy/parse_tgt.go @@ -0,0 +1,41 @@ +package proxy + +import ( + "bufio" + "crypto/tls" + "net" + "net/http" + + "github.com/wweir/sower/util" +) + +func ParseHTTP(conn net.Conn) (net.Conn, string, error) { + teeConn := &util.TeeConn{Conn: conn} + defer teeConn.Stop() + + resp, err := http.ReadRequest(bufio.NewReader(teeConn)) + if err != nil { + return teeConn, "", err + } + + if _, _, err := net.SplitHostPort(resp.Host); err != nil { + resp.Host = net.JoinHostPort(resp.Host, "80") + } + + return teeConn, resp.Host, nil +} + +func ParseHTTPS(conn net.Conn) (net.Conn, string, error) { + teeConn := &util.TeeConn{Conn: conn} + defer teeConn.Stop() + + var domain string + tls.Server(teeConn, &tls.Config{ + GetConfigForClient: func(hello *tls.ClientHelloInfo) (*tls.Config, error) { + domain = hello.ServerName + return nil, nil + }, + }).Handshake() + + return teeConn, net.JoinHostPort(domain, "443"), nil +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 5ee680f..852a024 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -4,31 +4,20 @@ import ( "crypto/tls" "net" "net/http" - "strconv" + "os" + "path/filepath" - _http "github.com/wweir/sower/internal/http" - "github.com/wweir/sower/internal/socks5" - "github.com/wweir/sower/util" + "github.com/wweir/sower/transport" "github.com/wweir/utils/log" "golang.org/x/crypto/acme/autocert" ) -const configDir = "/etc/sower" - -type head struct { - checksum byte - length byte -} - -func StartClient(password, serverAddr, httpProxy, dnsServeIP string, forwards map[string]string) { +func StartClient(serverAddr, password string, enableDNS bool, forwards map[string]string) { passwordData := []byte(password) - _, isSocks5 := socks5.IsSocks5Schema(serverAddr) - if httpProxy != "" { - go startHTTPProxy(httpProxy, serverAddr, passwordData) - } + relayToRemote := func(lnAddr, target string, + parseFn func(net.Conn) (net.Conn, string, error)) { - relayToRemote := func(tgtType byte, lnAddr string, host string, port uint16) { ln, err := net.Listen("tcp", lnAddr) if err != nil { log.Fatalw("tcp listen", "port", lnAddr, "err", err) @@ -44,26 +33,16 @@ func StartClient(password, serverAddr, httpProxy, dnsServeIP string, forwards ma go func(conn net.Conn) { defer conn.Close() - if isSocks5 { - teeConn := &util.TeeConn{Conn: conn} - teeConn.StartOrReset() - - switch tgtType { - case _http.TGT_HTTP: - conn, host, port, err = _http.ParseHTTP(teeConn) - case _http.TGT_HTTPS: - conn, host, err = _http.ParseHTTPS(teeConn) - } - if err != nil { - log.Errorw("parse socks5 target", "err", err) + if parseFn != nil { + if conn, target, err = parseFn(conn); err != nil { + log.Warnw("parse target", "err", err) return } - teeConn.Stop() } - rc, err := dial(serverAddr, passwordData, tgtType, host, port) + rc, err := transport.Dial(serverAddr, target, passwordData) if err != nil { - log.Errorw("dial", "addr", serverAddr, "err", err) + log.Warnw("dial", "addr", serverAddr, "err", err) return } defer rc.Close() @@ -73,30 +52,33 @@ func StartClient(password, serverAddr, httpProxy, dnsServeIP string, forwards ma } } - if dnsServeIP != "" { - go relayToRemote(_http.TGT_HTTP, dnsServeIP+":http", "", 80) - go relayToRemote(_http.TGT_HTTPS, dnsServeIP+":https", "", 443) + for from, to := range forwards { + go relayToRemote(from, to, nil) } - for from, to := range forwards { - go func(from, to string) { - host, port := util.ParseHostPort(to, 0) - relayToRemote(_http.TGT_OTHER, from, host, port) - }(from, to) + if enableDNS { + go relayToRemote(":80", "", ParseHTTP) + go relayToRemote(":443", "", ParseHTTPS) } select {} } func StartServer(relayTarget, password, certFile, keyFile, email string) { + dir, _ := os.UserCacheDir() + dir = filepath.Join("/", dir, "sower") + log.Infow("certificate cache dir", "dir", dir) + certManager := autocert.Manager{ Prompt: autocert.AcceptTOS, - Cache: autocert.DirCache(configDir), //folder for storing certificates Email: email, + Cache: autocert.DirCache(dir), } + tlsConf := &tls.Config{ GetCertificate: certManager.GetCertificate, MinVersion: tls.VersionTLS12, + NextProtos: []string{"http/1.1", "h2"}, } if certFile != "" && keyFile != "" { if cert, err := tls.LoadX509KeyPair(certFile, keyFile); err != nil { @@ -108,7 +90,7 @@ func StartServer(relayTarget, password, certFile, keyFile, email string) { } // Try to redirect 80 to 443 - go http.ListenAndServe(":http", certManager.HTTPHandler(http.HandlerFunc( + go http.ListenAndServe(":80", certManager.HTTPHandler(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { if host, _, err := net.SplitHostPort(r.Host); err != nil { r.URL.Host = r.Host @@ -119,7 +101,7 @@ func StartServer(relayTarget, password, certFile, keyFile, email string) { http.Redirect(w, r, r.URL.String(), 301) }))) - ln, err := tls.Listen("tcp", ":https", tlsConf) + ln, err := tls.Listen("tcp", ":443", tlsConf) if err != nil { log.Fatalw("tcp listen", "err", err) } @@ -135,20 +117,14 @@ func StartServer(relayTarget, password, certFile, keyFile, email string) { go func(conn net.Conn) { defer conn.Close() - conn, domain, port, err := _http.ParseAddr(conn, passwordData) - if err != nil { - log.Errorw("parse relay target", "err", err) - return - } - - addr := relayTarget - if domain != "" { - addr = net.JoinHostPort(domain, strconv.Itoa(int(port))) + conn, target := transport.ToProxyConn(conn, passwordData) + if target == "" { + target = relayTarget } - rc, err := net.Dial("tcp", addr) + rc, err := net.Dial("tcp", target) if err != nil { - log.Errorw("tcp dial", "addr", addr, "err", err) + log.Errorw("tcp dial", "addr", target, "err", err) return } defer rc.Close() diff --git a/proxy/util.go b/proxy/util.go index 570e7d8..d0b8fd1 100644 --- a/proxy/util.go +++ b/proxy/util.go @@ -1,31 +1,19 @@ package proxy import ( - "crypto/tls" "io" "net" "sync" "sync/atomic" "time" - - "github.com/wweir/sower/internal/http" - "github.com/wweir/sower/internal/socks5" ) -func dial(serverAddr string, password []byte, tgtType byte, domain string, port uint16) (net.Conn, error) { - if addr, ok := socks5.IsSocks5Schema(serverAddr); ok { - conn, err := net.Dial("tcp", addr) - if err != nil { - return nil, err - } - return socks5.ToSocks5(conn, domain, port), nil - } - - conn, err := tls.Dial("tcp", net.JoinHostPort(serverAddr, "443"), &tls.Config{}) +func withDefaultPort(addr string, port string) (address, host string) { + host, _, err := net.SplitHostPort(addr) if err != nil { - return nil, err + return net.JoinHostPort(addr, port), addr } - return http.NewTgtConn(conn, password, tgtType, domain, port), nil + return addr, host } func relay(conn1, conn2 net.Conn) { @@ -36,7 +24,6 @@ func relay(conn1, conn2 net.Conn) { redirect(conn1, conn2, wg, exitFlag) wg.Wait() } - func redirect(dst, src net.Conn, wg *sync.WaitGroup, exitFlag *int32) { io.Copy(dst, src) diff --git a/internal/http/http_ping.go b/router/http_ping.go similarity index 96% rename from internal/http/http_ping.go rename to router/http_ping.go index 7184a21..d134ea7 100644 --- a/internal/http/http_ping.go +++ b/router/http_ping.go @@ -1,4 +1,4 @@ -package http +package router import ( "bytes" @@ -13,8 +13,10 @@ import ( // Port ========================== type Port uint16 -const HTTP Port = 80 -const HTTPS Port = 443 +const ( + HTTP Port = 80 + HTTPS Port = 443 +) // Ping try connect to a http(s) server with domain though the http addr func (p Port) Ping(domain string, timeout time.Duration) error { @@ -96,7 +98,7 @@ func NewClientHelloSNIMsg(domain string) []byte { length := uint16(len(domain)) msg := &clientHelloSNI{ ContentType: 0x16, // Content Type: Handshake (22) - Version: 0x0301, // Version: TLS 1.0 (0x0301) + Version: 0x0301, // Version: TLS 1.0 Length: length + 56, handshakeProtocol: handshakeProtocol{ HandshakeType: 0x01, // Handshake Type: Client Hello (1) diff --git a/router/router.go b/router/router.go new file mode 100644 index 0000000..d4b2a5c --- /dev/null +++ b/router/router.go @@ -0,0 +1,142 @@ +package router + +import ( + "net" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/wweir/sower/transport" + "github.com/wweir/utils/log" + "github.com/wweir/utils/mem" +) + +type Route struct { + port Port + once sync.Once + cache *mem.Cache + + ProxyAddress string + ProxyPassword string + password []byte + + DetectLevel int // dynamic detect proxy level + DetectTimeout string // dynamic detect timeout + timeout time.Duration + + DirectList []string + directRule *Node + ProxyList []string + proxyRule *Node + DynamicList []string + dynamicRule *Node + PersistFn func(string) +} + +// ShouldProxy check if the domain shoule request though proxy +func (r *Route) ShouldProxy(domain string) bool { + r.once.Do(func() { + r.cache = mem.New(4 * time.Hour) + r.password = []byte(r.ProxyPassword) + r.directRule = NewNodeFromRules(r.DirectList...) + r.proxyRule = NewNodeFromRules(r.ProxyList...) + r.dynamicRule = NewNodeFromRules(r.DynamicList...) + + if timeout, err := time.ParseDuration(r.DetectTimeout); err != nil { + r.timeout = 200 * time.Millisecond + log.Warnw("parse detect timeout", "err", err, "default", "t.timeout") + } else { + r.timeout = timeout + } + }) + + // break deadlook, for wildcard + if strings.Count(domain, ".") > 4 { + return false + } + domain = strings.TrimSuffix(domain, ".") + + if domain == r.ProxyAddress { + return false + } + if r.directRule.Match(domain) { + return false + } + if r.proxyRule.Match(domain) { + return true + } + + r.cache.Remember(r, domain) + return r.dynamicRule.Match(domain) +} + +func (r *Route) Get(key interface{}) (err error) { + domain := key.(string) + + if r.detect(domain) > r.DetectLevel { + r.dynamicRule.Add(domain) + if r.PersistFn != nil { + r.PersistFn(domain) + } + } + return nil +} + +// detect and caculate direct connection and proxy connection score +func (r *Route) detect(domain string) int { + wg := sync.WaitGroup{} + httpScore, httpsScore := new(int32), new(int32) + for _, ping := range [...]*Route{{port: HTTP}, {port: HTTPS}} { + wg.Add(1) + go func(ping *Route) { + defer wg.Done() + + if err := ping.port.Ping(domain, r.timeout); err != nil { + return + } + + switch ping.port { + case HTTP: + if !atomic.CompareAndSwapInt32(httpScore, 0, -2) { + atomic.AddInt32(httpScore, -1) + } + case HTTPS: + if !atomic.CompareAndSwapInt32(httpsScore, 0, -2) { + atomic.AddInt32(httpScore, -1) + } + } + }(ping) + } + for _, ping := range [...]*Route{{port: HTTP}, {port: HTTPS}} { + wg.Add(1) + go func(ping *Route) { + defer wg.Done() + + target := net.JoinHostPort(domain, ping.port.String()) + conn, err := transport.Dial(r.ProxyAddress, target, r.password) + if err != nil { + log.Errorw("sower dial", "addr", r.ProxyAddress, "err", err) + return + } + + if err := ping.port.PingWithConn(domain, conn, r.timeout); err != nil { + return + } + + switch ping.port { + case HTTP: + if !atomic.CompareAndSwapInt32(httpScore, 0, 2) { + atomic.AddInt32(httpScore, 1) + } + case HTTPS: + if !atomic.CompareAndSwapInt32(httpsScore, 0, 2) { + atomic.AddInt32(httpScore, 1) + } + } + }(ping) + } + + wg.Wait() + return int(*httpScore + *httpsScore) +} diff --git a/util/suffix_tree.go b/router/suffix_tree.go similarity index 97% rename from util/suffix_tree.go rename to router/suffix_tree.go index ad13835..05fb7af 100644 --- a/util/suffix_tree.go +++ b/router/suffix_tree.go @@ -1,4 +1,4 @@ -package util +package router import ( "strings" @@ -69,6 +69,9 @@ func (n *Node) Match(item string) bool { if n == nil { return false } + + n.RLock() + defer n.RUnlock() return n.matchSecs(strings.Split(n.trim(item), n.sep), false) } diff --git a/util/suffix_tree_test.go b/router/suffix_tree_test.go similarity index 74% rename from util/suffix_tree_test.go rename to router/suffix_tree_test.go index b338b1c..b56e008 100644 --- a/util/suffix_tree_test.go +++ b/router/suffix_tree_test.go @@ -1,7 +1,9 @@ -package util +package router_test import ( "testing" + + "github.com/wweir/sower/router" ) func TestNode_Match(t *testing.T) { @@ -11,18 +13,18 @@ func TestNode_Match(t *testing.T) { } tests := []struct { name string - node *Node + node *router.Node tests []test }{{ "simple", - NewNodeFromRules("a.wweir.cc", "b.wweir.cc"), + router.NewNodeFromRules("a.wweir.cc", "b.wweir.cc"), []test{ {"a.wweir.cc", true}, {"b.wweir.cc", true}, }, }, { "parent", - NewNodeFromRules("wweir.cc", "a.wweir.cc"), + router.NewNodeFromRules("wweir.cc", "a.wweir.cc"), []test{ {"wweir.cc", true}, {"a.wweir.cc", true}, @@ -30,7 +32,7 @@ func TestNode_Match(t *testing.T) { }, }, { "fuzz1", - NewNodeFromRules("wweir.cc", "a.wweir.cc", "*.wweir.cc"), + router.NewNodeFromRules("wweir.cc", "a.wweir.cc", "*.wweir.cc"), []test{ {"wweir.cc", true}, {"a.wweir.cc", true}, @@ -39,7 +41,7 @@ func TestNode_Match(t *testing.T) { }, }, { "fuzz2", - NewNodeFromRules("a.*.cc", "c.wweir.*"), + router.NewNodeFromRules("a.*.cc", "c.wweir.*"), []test{ {"wweir.cc", false}, {"a.wweir.cc", true}, @@ -48,7 +50,7 @@ func TestNode_Match(t *testing.T) { }, }, { "fuzz3", - NewNodeFromRules("*.*.cc", "iamp.*.*"), + router.NewNodeFromRules("*.*.cc", "iamp.*.*"), []test{ {"wweir.cc", false}, {"a.wweir.cc", true}, @@ -57,7 +59,7 @@ func TestNode_Match(t *testing.T) { }, }, { "fuzz4", - NewNodeFromRules("**.cc", "a.**.com", "**.wweir.*"), + router.NewNodeFromRules("**.cc", "a.**.com", "**.wweir.*"), []test{ {"wweir.cc", true}, {"a.wweir.cc", true}, diff --git a/transport/proxy_conn.go b/transport/proxy_conn.go new file mode 100644 index 0000000..5bcd974 --- /dev/null +++ b/transport/proxy_conn.go @@ -0,0 +1,73 @@ +package transport + +import ( + "crypto/md5" + "crypto/tls" + "encoding/binary" + "io" + "net" + "strconv" + + "github.com/wweir/sower/util" +) + +// checksum(>=0x80) + port + target_length + target + data +// data(HTTP/HTTPS, first byte < 0x7F) +type head struct { + Checksum byte + Port uint16 + AddrLen uint8 +} + +func ToProxyConn(conn net.Conn, password []byte) (net.Conn, string) { + teeConn := &util.TeeConn{Conn: conn} + defer teeConn.Stop() + + h := &head{} + if err := binary.Read(teeConn, binary.BigEndian, h); err != nil || h.Checksum < 0x80 { + return teeConn, "" + } + + buf := make([]byte, int(h.AddrLen)) + if _, err := io.ReadFull(teeConn, buf); err != nil { + return teeConn, "" + } + + if h.Checksum != sumChecksum(buf, password) { + return teeConn, "" + } + + teeConn.Reset() + return teeConn, net.JoinHostPort(string(buf), strconv.Itoa(int(h.Port))) +} + +func DialTlsProxyConn(address, tgtHost string, tgtPort uint16, tlsCfg *tls.Config, password []byte) (net.Conn, error) { + conn, err := tls.Dial("tcp", address, tlsCfg) + if err != nil { + return nil, err + } + + h := &head{ + Checksum: sumChecksum([]byte(tgtHost), password), + Port: tgtPort, + AddrLen: uint8(len(tgtHost)), + } + if err := binary.Write(conn, binary.BigEndian, h); err != nil { + conn.Close() + return nil, err + } + + data := []byte(tgtHost) + for n, nn := 0, 0; nn < len(data); nn += n { + if n, err = conn.Write(data[nn:]); err != nil { + conn.Close() + return nil, err + } + } + + return conn, nil +} + +func sumChecksum(target, password []byte) byte { + return md5.Sum(append(target, password...))[0] | 0x80 +} diff --git a/internal/socks5/socks5.go b/transport/socks5.go similarity index 88% rename from internal/socks5/socks5.go rename to transport/socks5.go index a3a1aed..a8c3895 100644 --- a/internal/socks5/socks5.go +++ b/transport/socks5.go @@ -1,10 +1,11 @@ -package socks5 +package transport import ( "encoding/binary" "fmt" "io" "net" + "strconv" "strings" ) @@ -20,13 +21,22 @@ func IsSocks5Schema(addr string) (string, bool) { return addr, false } -func ToSocks5(c net.Conn, domain string, port uint16) net.Conn { +func ToSocks5(c net.Conn, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + p, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + return &conn{ init: make(chan struct{}), Conn: c, - domain: domain, - port: []byte{byte(port >> 8), byte(port)}, - } + domain: host, + port: []byte{byte(p >> 8), byte(p)}, + }, nil } type conn struct { diff --git a/internal/socks5/rfc_def.go b/transport/socks5_rfc.go similarity index 97% rename from internal/socks5/rfc_def.go rename to transport/socks5_rfc.go index 313bfee..6ae46f0 100644 --- a/internal/socks5/rfc_def.go +++ b/transport/socks5_rfc.go @@ -1,4 +1,4 @@ -package socks5 +package transport // https://tools.ietf.org/html/rfc1928 diff --git a/transport/util.go b/transport/util.go new file mode 100644 index 0000000..9be4179 --- /dev/null +++ b/transport/util.go @@ -0,0 +1,36 @@ +package transport + +import ( + "crypto/tls" + "net" + "strconv" +) + +func Dial(address, target string, password []byte) (net.Conn, error) { + if addr, ok := IsSocks5Schema(address); ok { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, err + } + if conn, err = ToSocks5(conn, target); err != nil { + conn.Close() + return nil, err + } + return conn, nil + } + + host, port, err := net.SplitHostPort(target) + if err != nil { + port = "443" + } + p, err := strconv.Atoi(port) + if err != nil { + return nil, err + } + + return DialTlsProxyConn(net.JoinHostPort(address, "443"), host, uint16(p), &tls.Config{ + ServerName: address, + MinVersion: tls.VersionTLS12, + NextProtos: []string{"http/1.1", "h2"}, + }, password) +} diff --git a/util/tee_conn.go b/util/tee_conn.go index 89a2485..8220929 100644 --- a/util/tee_conn.go +++ b/util/tee_conn.go @@ -1,27 +1,28 @@ package util import ( + "io" "net" ) type TeeConn struct { net.Conn - buf []byte - offset int - tee bool // read + buf []byte + offset int + stop bool // read + EnableWrite bool } -func (t *TeeConn) StartOrReset() { +func (t *TeeConn) Reread() { t.offset = 0 - t.tee = true } -func (t *TeeConn) DropAndRestart() { +func (t *TeeConn) Reset() { t.buf = []byte{} - t.tee = true + t.offset = 0 } func (t *TeeConn) Stop() { t.offset = 0 - t.tee = false + t.stop = true } func (t *TeeConn) Read(b []byte) (n int, err error) { @@ -33,9 +34,17 @@ func (t *TeeConn) Read(b []byte) (n int, err error) { } n, err = t.Conn.Read(b) - if t.tee { + if !t.stop { t.buf = append(t.buf, b[:n]...) t.offset += n } return n, err } + +func (t *TeeConn) Write(b []byte) (n int, err error) { + if t.stop || t.EnableWrite { + return t.Conn.Write(b) + } + + return 0, io.EOF +} diff --git a/util/util.go b/util/util.go deleted file mode 100644 index 5e6607a..0000000 --- a/util/util.go +++ /dev/null @@ -1,19 +0,0 @@ -package util - -import ( - "net" - "strconv" -) - -func ParseHostPort(addr string, defaultPort uint16) (string, uint16) { - h, p, err := net.SplitHostPort(addr) - if err != nil { - if defaultPort == 0 { - panic("parse port fail with no default, addr: " + addr) - } - return addr, defaultPort - } - - pNum, _ := strconv.ParseUint(p, 10, 16) - return h, uint16(pNum) -}