Skip to content

Commit

Permalink
refactor: reduce memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
wweir committed Nov 15, 2021
1 parent 8c9a0bc commit 0f029f8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 49 deletions.
26 changes: 14 additions & 12 deletions cmd/sower/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -96,8 +97,10 @@ func init() {
func main() {
proxtDial := GenProxyDial(conf.Remote.Type, conf.Remote.Addr, conf.Remote.Password)
r := router.NewRouter(conf.DNS.Serve, conf.DNS.Fallback, conf.Router.Country.MMDB, proxtDial)
r.SetRules(conf.Router.Block.Rules, conf.Router.Direct.Rules, conf.Router.Proxy.Rules,
conf.Router.Country.Rules)
r.SetBlockRules(conf.Router.Block.Rules)
r.SetDirectRules(conf.Router.Direct.Rules)
r.SetProxyRules(conf.Router.Proxy.Rules)
r.SetCountryCIDRs(conf.Router.Country.Rules)

go func() {
if conf.DNS.Disable {
Expand Down Expand Up @@ -140,16 +143,14 @@ func main() {
}()

start := time.Now()
conf.Router.Block.Rules = append(conf.Router.Block.Rules,
loadRules(proxtDial, conf.Router.Block.File, conf.Router.Block.FilePrefix)...)
conf.Router.Direct.Rules = append(conf.Router.Direct.Rules,
loadRules(proxtDial, conf.Router.Direct.File, conf.Router.Direct.FilePrefix)...)
conf.Router.Proxy.Rules = append(conf.Router.Proxy.Rules,
loadRules(proxtDial, conf.Router.Proxy.File, conf.Router.Proxy.FilePrefix)...)
conf.Router.Country.Rules = append(conf.Router.Country.Rules,
loadRules(proxtDial, conf.Router.Country.File, conf.Router.Country.FilePrefix)...)
r.SetRules(conf.Router.Block.Rules, conf.Router.Direct.Rules, conf.Router.Proxy.Rules,
conf.Router.Country.Rules)
r.SetBlockRules(append(conf.Router.Block.Rules,
loadRules(proxtDial, conf.Router.Block.File, conf.Router.Block.FilePrefix)...))
r.SetDirectRules(append(conf.Router.Direct.Rules,
loadRules(proxtDial, conf.Router.Direct.File, conf.Router.Direct.FilePrefix)...))
r.SetProxyRules(append(conf.Router.Proxy.Rules,
loadRules(proxtDial, conf.Router.Proxy.File, conf.Router.Proxy.FilePrefix)...))
r.SetCountryCIDRs(append(conf.Router.Country.Rules,
loadRules(proxtDial, conf.Router.Country.File, conf.Router.Country.FilePrefix)...))

log.Info().
Dur("spend", time.Since(start)).
Expand All @@ -158,6 +159,7 @@ func main() {
Int("proxyRule", len(conf.Router.Proxy.Rules)).
Int("countryRule", len(conf.Router.Country.Rules)).
Msg("Loaded rules, proxy started")
runtime.GC()
select {}
}

Expand Down
10 changes: 7 additions & 3 deletions router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,16 @@ func NewRouter(serveIP, fallbackDNS, mmdbFile string, proxyDial ProxyDialFn) *Ro
return &r
}

func (r *Router) SetRules(blockList, directList, proxyList, directCIDRs []string) {

func (r *Router) SetBlockRules(blockList []string) {
r.blockRule = util.NewNodeFromRules(blockList...)
}
func (r *Router) SetDirectRules(directList []string) {
r.directRule = util.NewNodeFromRules(directList...)
}
func (r *Router) SetProxyRules(proxyList []string) {
r.proxyRule = util.NewNodeFromRules(proxyList...)

}
func (r *Router) SetCountryCIDRs(directCIDRs []string) {
r.country.cidrs = make([]*net.IPNet, 0, len(directCIDRs))
for _, cidr := range directCIDRs {
_, ipnet, err := net.ParseCIDR(cidr)
Expand Down
69 changes: 35 additions & 34 deletions util/suffix_tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ func NewNodeFromRules(rules ...string) *Node {
}

func (n *node) lite() *node {
if n == nil {
return nil
}

lite := &node{
secs: make([]string, 0, len(n.secs)),
subNodes: make([]*node, 0, len(n.subNodes)),
Expand All @@ -41,6 +45,9 @@ func (n *Node) String() string {
return n.string("", " ")
}
func (n *node) string(prefix, indent string) (out string) {
if n == nil {
return
}
for key, val := range n.subNodes {
out += prefix + n.secs[key] + "\n" + val.string(prefix+indent, indent)
}
Expand All @@ -60,34 +67,39 @@ func (n *node) add(secs []string) {
case 0:
case 1:
sec := secs[length-1]
subNode := &node{secs: []string{""}, subNodes: []*node{{}}}
switch sec {
case "", "*", "**":
n.secs = append([]string{sec}, n.secs...)
n.subNodes = append([]*node{subNode}, n.subNodes...)
n.subNodes = append([]*node{nil}, n.subNodes...)
default:
n.secs = append(n.secs, sec)
n.subNodes = append(n.subNodes, subNode)
n.subNodes = append(n.subNodes, nil)
}
default:
sec := secs[length-1]
if sec == "**" {
if sec == "**" { // ** is only allowed in the last sec
sec = "*"
}

subNode, ok := n.find(sec)
if !ok {
subNode = &node{}
idx := n.index(sec)
if idx == -1 {
switch sec {
case "", "*", "**":
idx = 0
n.secs = append([]string{sec}, n.secs...)
n.subNodes = append([]*node{subNode}, n.subNodes...)
n.subNodes = append([]*node{{}}, n.subNodes...)
default:
idx = len(n.secs)
n.secs = append(n.secs, sec)
n.subNodes = append(n.subNodes, subNode)
n.subNodes = append(n.subNodes, &node{})
}

} else if n.subNodes[idx] == nil {
n.subNodes[idx] = &node{}
n.subNodes[idx].add([]string{""})
}
subNode.add(secs[:length-1])

n.subNodes[idx].add(secs[:length-1])
}
}

Expand All @@ -102,45 +114,34 @@ func (n *Node) Match(item string) bool {
func (n *node) matchSecs(secs []string, fuzzNode bool) bool {
length := len(secs)
if length == 0 {
if len(n.secs) == 0 {
return true
}
if _, ok := n.find(""); ok {
return true
}
if _, ok := n.find("**"); ok {
if n == nil {
return true
}
if _, ok := n.find("*"); ok {
return !fuzzNode
}
return false
return n.index("") != -1
}

if n, ok := n.find(secs[length-1]); ok {
if n.matchSecs(secs[:length-1], false) {
if idx := n.index(secs[length-1]); idx >= 0 {
if n.subNodes[idx].matchSecs(secs[:length-1], false) {
return true
}
}
if n, ok := n.find("*"); ok {
if n.matchSecs(secs[:length-1], true) {
if idx := n.index("*"); idx >= 0 {
if n.subNodes[idx].matchSecs(secs[:length-1], true) {
return true
}
}
if _, ok := n.find("**"); ok {
return true
}

return false
return n.index("**") >= 0
}
func (n *node) find(sec string) (*node, bool) {

// index return the sec index in node, or -1 if not found
func (n *node) index(sec string) int {
if n == nil {
return nil, false
return -1
}
for s := range n.secs {
if n.secs[s] == sec {
return n.subNodes[s], true
return s
}
}
return nil, false
return -1
}

0 comments on commit 0f029f8

Please sign in to comment.