diff --git a/cmd/sower/main.go b/cmd/sower/main.go index 06ef7a1..4fbcd3a 100644 --- a/cmd/sower/main.go +++ b/cmd/sower/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "compress/gzip" "io" "net" "net/http" @@ -19,6 +20,7 @@ import ( "github.com/miekg/dns" "github.com/pkg/errors" "github.com/sower-proxy/deferlog/log" + "github.com/wweir/sower/pkg/suffixtree" "github.com/wweir/sower/router" ) @@ -95,12 +97,12 @@ func init() { } func main() { - proxtDial := GenProxyDial(conf.Remote.Type, conf.Remote.Addr, conf.Remote.Password) - r := router.NewRouter(conf.DNS.Serve, conf.DNS.Fallback, conf.Router.Country.MMDB, proxtDial) - r.SetBlockRules(conf.Router.Block.Rules) - r.SetDirectRules(conf.Router.Direct.Rules) - r.SetProxyRules(conf.Router.Proxy.Rules) - r.SetCountryCIDRs(conf.Router.Country.Rules) + proxyDial := GenProxyDial(conf.Remote.Type, conf.Remote.Addr, conf.Remote.Password) + r := router.NewRouter(conf.DNS.Serve, conf.DNS.Fallback, conf.Router.Country.MMDB, proxyDial) + r.BlockRule = suffixtree.NewNodeFromRules(conf.Router.Block.Rules...) + r.DirectRule = suffixtree.NewNodeFromRules(conf.Router.Direct.Rules...) + r.ProxyRule = suffixtree.NewNodeFromRules(conf.Router.Proxy.Rules...) + r.AddCountryCIDRs(conf.Router.Country.Rules...) go func() { if conf.DNS.Disable { @@ -143,27 +145,30 @@ func main() { }() start := time.Now() - r.SetBlockRules(append(conf.Router.Block.Rules, - loadRules(proxtDial, conf.Router.Block.File, conf.Router.Block.FilePrefix)...)) - r.SetDirectRules(append(conf.Router.Direct.Rules, - loadRules(proxtDial, conf.Router.Direct.File, conf.Router.Direct.FilePrefix)...)) - r.SetProxyRules(append(conf.Router.Proxy.Rules, - loadRules(proxtDial, conf.Router.Proxy.File, conf.Router.Proxy.FilePrefix)...)) - r.SetCountryCIDRs(append(conf.Router.Country.Rules, - loadRules(proxtDial, conf.Router.Country.File, conf.Router.Country.FilePrefix)...)) + loadRule(r.BlockRule, proxyDial, conf.Router.Block.File, conf.Router.Block.FilePrefix) + loadRule(r.DirectRule, proxyDial, conf.Router.Direct.File, conf.Router.Direct.FilePrefix) + loadRule(r.ProxyRule, proxyDial, conf.Router.Proxy.File, conf.Router.Proxy.FilePrefix) + for line := range featchRuleFile(proxyDial, conf.Router.Country.File) { + r.AddCountryCIDRs(line) + } log.Info(). - Dur("spend", time.Since(start)). - Int("blockRule", len(conf.Router.Block.Rules)). - Int("directRule", len(conf.Router.Direct.Rules)). - Int("proxyRule", len(conf.Router.Proxy.Rules)). - Int("countryRule", len(conf.Router.Country.Rules)). + Dur("took", time.Since(start)). + Uint64("blockRule", r.BlockRule.Count). + Uint64("directRule", r.DirectRule.Count). + Uint64("proxyRule", r.ProxyRule.Count). Msg("Loaded rules, proxy started") runtime.GC() select {} } -func loadRules(proxyDial router.ProxyDialFn, file, linePrefix string) []string { +func loadRule(rule *suffixtree.Node, proxyDial router.ProxyDialFn, file, linePrefix string) { + for line := range featchRuleFile(proxyDial, file) { + rule.Add(linePrefix + line) + } + rule.GC() +} +func featchRuleFile(proxyDial router.ProxyDialFn, file string) <-chan string { var loadFn func() (io.ReadCloser, error) if _, err := url.Parse(file); err == nil { // load rule file from remote by HTTP @@ -178,7 +183,9 @@ func loadRules(proxyDial router.ProxyDialFn, file, linePrefix string) []string { } loadFn = func() (io.ReadCloser, error) { - resp, err := client.Get(file) + req, _ := http.NewRequest(http.MethodGet, file, nil) + req.Header.Add("Accept-Encoding", "gzip") + resp, err := client.Do(req) if err != nil { return nil, err } @@ -209,34 +216,37 @@ func loadRules(proxyDial router.ProxyDialFn, file, linePrefix string) []string { time.Sleep(i * i * 100 * time.Millisecond) rc, err = loadFn() } - if err != nil { - log.Fatal().Err(err). - Str("file", file). - Msg("load config file") - } - defer rc.Close() - - // parse rule file into rule tree - var lines []string - br := bufio.NewReader(rc) - for { - line, _, err := br.ReadLine() - if err == io.EOF { - break - } else if err != nil { - log.Error().Err(err). - Str("file", file). - Msg("read line") - return nil - } + log.InfoFatal(err). + Str("file", file). + Msg("fetch rule file") - if strings.TrimSpace(string(line)) == "" { - continue - } + ch := make(chan string, 100) + go func() { + defer rc.Close() + defer close(ch) + gr, err := gzip.NewReader(rc) + log.DebugFatal(err).Msg("gzip reader") + defer gr.Close() + + br := bufio.NewReader(gr) + for { + line, _, err := br.ReadLine() + if err == io.EOF { + return + } else if err != nil { + log.Error().Err(err). + Str("file", file). + Msg("read line") + return + } - // use line content as suffix - lines = append(lines, linePrefix+string(line)) - } + if strings.TrimSpace(string(line)) == "" { + continue + } + + ch <- string(line) + } + }() - return lines + return ch } diff --git a/pkg/suffixtree/suffix_tree.go b/pkg/suffixtree/suffix_tree.go index aa02f65..b01912c 100644 --- a/pkg/suffixtree/suffix_tree.go +++ b/pkg/suffixtree/suffix_tree.go @@ -1,13 +1,13 @@ package suffixtree import ( - "runtime" "strings" ) type Node struct { *node - sep string + sep string + Count uint64 } type node struct { secs []string @@ -15,30 +15,25 @@ type node struct { } func NewNodeFromRules(rules ...string) *Node { - n := &Node{&node{}, "."} + n := &Node{&node{}, ".", 0} for i := range rules { n.Add(rules[i]) } - n.node = n.node.lite() - runtime.GC() + n.GC() return n } -func (n *node) lite() *node { +func (n *node) GC() { if n == nil { - return nil + return } - lite := &node{ - secs: make([]string, 0, len(n.secs)), - subNodes: make([]*node, 0, len(n.subNodes)), - } - lite.secs = append(lite.secs, n.secs...) + n.secs = GCSlice(n.secs) + n.subNodes = GCSlice(n.subNodes) for i := range n.subNodes { - lite.subNodes = append(lite.subNodes, n.subNodes[i].lite()) + n.subNodes[i].GC() } - return lite } func (n *Node) String() string { @@ -59,6 +54,7 @@ func (n *Node) trim(item string) string { } func (n *Node) Add(item string) { + n.Count++ n.add(strings.Split(n.trim(item), n.sep)) } func (n *node) add(secs []string) { diff --git a/pkg/suffixtree/util.go b/pkg/suffixtree/util.go new file mode 100644 index 0000000..903066e --- /dev/null +++ b/pkg/suffixtree/util.go @@ -0,0 +1,6 @@ +package suffixtree + +func GCSlice[T any](arr []T) []T { + old := arr + return append(make([]T, 0, len(old)), old...) +} diff --git a/router/dns.go b/router/dns.go index 8650cea..7bd12a9 100644 --- a/router/dns.go +++ b/router/dns.go @@ -17,13 +17,13 @@ func (r *Router) ServeDNS(w dns.ResponseWriter, req *dns.Msg) { domain := req.Question[0].Name // 1. rule_based( block > direct > proxy ) switch { - case r.blockRule.Match(domain): + case r.BlockRule.Match(domain): _ = w.WriteMsg(r.dnsFail(req, dns.RcodeNameError)) return - case r.directRule.Match(domain): + case r.DirectRule.Match(domain): - case r.proxyRule.Match(domain): + case r.ProxyRule.Match(domain): _ = w.WriteMsg(r.dnsProxyA(domain, r.dns.serveIP, req)) return } diff --git a/router/router.go b/router/router.go index 5a52fd7..539abb1 100644 --- a/router/router.go +++ b/router/router.go @@ -17,9 +17,9 @@ import ( type ProxyDialFn func(network, host string, port uint16) (net.Conn, error) type Router struct { - blockRule *suffixtree.Node - directRule *suffixtree.Node - proxyRule *suffixtree.Node + BlockRule *suffixtree.Node + DirectRule *suffixtree.Node + ProxyRule *suffixtree.Node ProxyDial ProxyDialFn dns struct { @@ -52,24 +52,15 @@ func NewRouter(serveIP, fallbackDNS, mmdbFile string, proxyDial ProxyDialFn) *Ro return &r } -func (r *Router) SetBlockRules(blockList []string) { - r.blockRule = suffixtree.NewNodeFromRules(blockList...) -} -func (r *Router) SetDirectRules(directList []string) { - r.directRule = suffixtree.NewNodeFromRules(directList...) -} -func (r *Router) SetProxyRules(proxyList []string) { - r.proxyRule = suffixtree.NewNodeFromRules(proxyList...) -} -func (r *Router) SetCountryCIDRs(directCIDRs []string) { - r.country.cidrs = make([]*net.IPNet, 0, len(directCIDRs)) - for _, cidr := range directCIDRs { +func (r *Router) AddCountryCIDRs(cidrs ...string) { + for _, cidr := range cidrs { _, ipnet, err := net.ParseCIDR(cidr) if err != nil { log.Error().Err(err).Msg("Failed to parse CIDR") } r.country.cidrs = append(r.country.cidrs, ipnet) } + r.country.cidrs = suffixtree.GCSlice(r.country.cidrs) } func (r *Router) dialDNSConn() { @@ -111,13 +102,13 @@ func (r *Router) RouteHandle(conn net.Conn, domain string, port uint16) (err err // 2. detect_based( CN IP || access site ) // 3. fallback( proxy ) switch { - case r.blockRule.Match(domain): + case r.BlockRule.Match(domain): return nil - case r.directRule.Match(domain): + case r.DirectRule.Match(domain): return r.DirectHandle(conn, addr) - case r.proxyRule.Match(domain): + case r.ProxyRule.Match(domain): return r.ProxyHandle(conn, domain, port) case r.localSite(domain), r.isAccess(domain, port):