Skip to content

Commit

Permalink
feat: optimize code with 1.18 generic
Browse files Browse the repository at this point in the history
  • Loading branch information
wweir committed May 1, 2022
1 parent 78d36b0 commit 6fc97a3
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 83 deletions.
106 changes: 58 additions & 48 deletions cmd/sower/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package main

import (
"bufio"
"compress/gzip"
"io"
"net"
"net/http"
Expand All @@ -19,6 +20,7 @@ import (
"github.com/miekg/dns"
"github.com/pkg/errors"
"github.com/sower-proxy/deferlog/log"
"github.com/wweir/sower/pkg/suffixtree"
"github.com/wweir/sower/router"
)

Expand Down Expand Up @@ -95,12 +97,12 @@ func init() {
}

func main() {
proxtDial := GenProxyDial(conf.Remote.Type, conf.Remote.Addr, conf.Remote.Password)
r := router.NewRouter(conf.DNS.Serve, conf.DNS.Fallback, conf.Router.Country.MMDB, proxtDial)
r.SetBlockRules(conf.Router.Block.Rules)
r.SetDirectRules(conf.Router.Direct.Rules)
r.SetProxyRules(conf.Router.Proxy.Rules)
r.SetCountryCIDRs(conf.Router.Country.Rules)
proxyDial := GenProxyDial(conf.Remote.Type, conf.Remote.Addr, conf.Remote.Password)
r := router.NewRouter(conf.DNS.Serve, conf.DNS.Fallback, conf.Router.Country.MMDB, proxyDial)
r.BlockRule = suffixtree.NewNodeFromRules(conf.Router.Block.Rules...)
r.DirectRule = suffixtree.NewNodeFromRules(conf.Router.Direct.Rules...)
r.ProxyRule = suffixtree.NewNodeFromRules(conf.Router.Proxy.Rules...)
r.AddCountryCIDRs(conf.Router.Country.Rules...)

go func() {
if conf.DNS.Disable {
Expand Down Expand Up @@ -143,27 +145,30 @@ func main() {
}()

start := time.Now()
r.SetBlockRules(append(conf.Router.Block.Rules,
loadRules(proxtDial, conf.Router.Block.File, conf.Router.Block.FilePrefix)...))
r.SetDirectRules(append(conf.Router.Direct.Rules,
loadRules(proxtDial, conf.Router.Direct.File, conf.Router.Direct.FilePrefix)...))
r.SetProxyRules(append(conf.Router.Proxy.Rules,
loadRules(proxtDial, conf.Router.Proxy.File, conf.Router.Proxy.FilePrefix)...))
r.SetCountryCIDRs(append(conf.Router.Country.Rules,
loadRules(proxtDial, conf.Router.Country.File, conf.Router.Country.FilePrefix)...))
loadRule(r.BlockRule, proxyDial, conf.Router.Block.File, conf.Router.Block.FilePrefix)
loadRule(r.DirectRule, proxyDial, conf.Router.Direct.File, conf.Router.Direct.FilePrefix)
loadRule(r.ProxyRule, proxyDial, conf.Router.Proxy.File, conf.Router.Proxy.FilePrefix)
for line := range featchRuleFile(proxyDial, conf.Router.Country.File) {
r.AddCountryCIDRs(line)
}

log.Info().
Dur("spend", time.Since(start)).
Int("blockRule", len(conf.Router.Block.Rules)).
Int("directRule", len(conf.Router.Direct.Rules)).
Int("proxyRule", len(conf.Router.Proxy.Rules)).
Int("countryRule", len(conf.Router.Country.Rules)).
Dur("took", time.Since(start)).
Uint64("blockRule", r.BlockRule.Count).
Uint64("directRule", r.DirectRule.Count).
Uint64("proxyRule", r.ProxyRule.Count).
Msg("Loaded rules, proxy started")
runtime.GC()
select {}
}

func loadRules(proxyDial router.ProxyDialFn, file, linePrefix string) []string {
func loadRule(rule *suffixtree.Node, proxyDial router.ProxyDialFn, file, linePrefix string) {
for line := range featchRuleFile(proxyDial, file) {
rule.Add(linePrefix + line)
}
rule.GC()
}
func featchRuleFile(proxyDial router.ProxyDialFn, file string) <-chan string {
var loadFn func() (io.ReadCloser, error)
if _, err := url.Parse(file); err == nil {
// load rule file from remote by HTTP
Expand All @@ -178,7 +183,9 @@ func loadRules(proxyDial router.ProxyDialFn, file, linePrefix string) []string {
}

loadFn = func() (io.ReadCloser, error) {
resp, err := client.Get(file)
req, _ := http.NewRequest(http.MethodGet, file, nil)
req.Header.Add("Accept-Encoding", "gzip")
resp, err := client.Do(req)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -209,34 +216,37 @@ func loadRules(proxyDial router.ProxyDialFn, file, linePrefix string) []string {
time.Sleep(i * i * 100 * time.Millisecond)
rc, err = loadFn()
}
if err != nil {
log.Fatal().Err(err).
Str("file", file).
Msg("load config file")
}
defer rc.Close()

// parse rule file into rule tree
var lines []string
br := bufio.NewReader(rc)
for {
line, _, err := br.ReadLine()
if err == io.EOF {
break
} else if err != nil {
log.Error().Err(err).
Str("file", file).
Msg("read line")
return nil
}
log.InfoFatal(err).
Str("file", file).
Msg("fetch rule file")

if strings.TrimSpace(string(line)) == "" {
continue
}
ch := make(chan string, 100)
go func() {
defer rc.Close()
defer close(ch)
gr, err := gzip.NewReader(rc)
log.DebugFatal(err).Msg("gzip reader")
defer gr.Close()

br := bufio.NewReader(gr)
for {
line, _, err := br.ReadLine()
if err == io.EOF {
return
} else if err != nil {
log.Error().Err(err).
Str("file", file).
Msg("read line")
return
}

// use line content as suffix
lines = append(lines, linePrefix+string(line))
}
if strings.TrimSpace(string(line)) == "" {
continue
}

ch <- string(line)
}
}()

return lines
return ch
}
24 changes: 10 additions & 14 deletions pkg/suffixtree/suffix_tree.go
Original file line number Diff line number Diff line change
@@ -1,44 +1,39 @@
package suffixtree

import (
"runtime"
"strings"
)

type Node struct {
*node
sep string
sep string
Count uint64
}
type node struct {
secs []string
subNodes []*node
}

func NewNodeFromRules(rules ...string) *Node {
n := &Node{&node{}, "."}
n := &Node{&node{}, ".", 0}
for i := range rules {
n.Add(rules[i])
}

n.node = n.node.lite()
runtime.GC()
n.GC()
return n
}

func (n *node) lite() *node {
func (n *node) GC() {
if n == nil {
return nil
return
}

lite := &node{
secs: make([]string, 0, len(n.secs)),
subNodes: make([]*node, 0, len(n.subNodes)),
}
lite.secs = append(lite.secs, n.secs...)
n.secs = GCSlice(n.secs)
n.subNodes = GCSlice(n.subNodes)
for i := range n.subNodes {
lite.subNodes = append(lite.subNodes, n.subNodes[i].lite())
n.subNodes[i].GC()
}
return lite
}

func (n *Node) String() string {
Expand All @@ -59,6 +54,7 @@ func (n *Node) trim(item string) string {
}

func (n *Node) Add(item string) {
n.Count++
n.add(strings.Split(n.trim(item), n.sep))
}
func (n *node) add(secs []string) {
Expand Down
6 changes: 6 additions & 0 deletions pkg/suffixtree/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package suffixtree

func GCSlice[T any](arr []T) []T {
old := arr
return append(make([]T, 0, len(old)), old...)
}
6 changes: 3 additions & 3 deletions router/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ func (r *Router) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
domain := req.Question[0].Name
// 1. rule_based( block > direct > proxy )
switch {
case r.blockRule.Match(domain):
case r.BlockRule.Match(domain):
_ = w.WriteMsg(r.dnsFail(req, dns.RcodeNameError))
return

case r.directRule.Match(domain):
case r.DirectRule.Match(domain):

case r.proxyRule.Match(domain):
case r.ProxyRule.Match(domain):
_ = w.WriteMsg(r.dnsProxyA(domain, r.dns.serveIP, req))
return
}
Expand Down
27 changes: 9 additions & 18 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (

type ProxyDialFn func(network, host string, port uint16) (net.Conn, error)
type Router struct {
blockRule *suffixtree.Node
directRule *suffixtree.Node
proxyRule *suffixtree.Node
BlockRule *suffixtree.Node
DirectRule *suffixtree.Node
ProxyRule *suffixtree.Node
ProxyDial ProxyDialFn

dns struct {
Expand Down Expand Up @@ -52,24 +52,15 @@ func NewRouter(serveIP, fallbackDNS, mmdbFile string, proxyDial ProxyDialFn) *Ro
return &r
}

func (r *Router) SetBlockRules(blockList []string) {
r.blockRule = suffixtree.NewNodeFromRules(blockList...)
}
func (r *Router) SetDirectRules(directList []string) {
r.directRule = suffixtree.NewNodeFromRules(directList...)
}
func (r *Router) SetProxyRules(proxyList []string) {
r.proxyRule = suffixtree.NewNodeFromRules(proxyList...)
}
func (r *Router) SetCountryCIDRs(directCIDRs []string) {
r.country.cidrs = make([]*net.IPNet, 0, len(directCIDRs))
for _, cidr := range directCIDRs {
func (r *Router) AddCountryCIDRs(cidrs ...string) {
for _, cidr := range cidrs {
_, ipnet, err := net.ParseCIDR(cidr)
if err != nil {
log.Error().Err(err).Msg("Failed to parse CIDR")
}
r.country.cidrs = append(r.country.cidrs, ipnet)
}
r.country.cidrs = suffixtree.GCSlice(r.country.cidrs)
}

func (r *Router) dialDNSConn() {
Expand Down Expand Up @@ -111,13 +102,13 @@ func (r *Router) RouteHandle(conn net.Conn, domain string, port uint16) (err err
// 2. detect_based( CN IP || access site )
// 3. fallback( proxy )
switch {
case r.blockRule.Match(domain):
case r.BlockRule.Match(domain):
return nil

case r.directRule.Match(domain):
case r.DirectRule.Match(domain):
return r.DirectHandle(conn, addr)

case r.proxyRule.Match(domain):
case r.ProxyRule.Match(domain):
return r.ProxyHandle(conn, domain, port)

case r.localSite(domain), r.isAccess(domain, port):
Expand Down

0 comments on commit 6fc97a3

Please sign in to comment.