Skip to content

Commit

Permalink
🔥 transition from native implementation of sorting to sort.SortSlice()
Browse files Browse the repository at this point in the history
  • Loading branch information
egorgasay committed May 26, 2023
1 parent 106bb52 commit 4cef029
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 84 deletions.
1 change: 0 additions & 1 deletion cmd/static/default-config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ Headers = [ "X-Forwarded-Proto:{{HTTP_PROTO}}", "X-Forwarded-For:{{SOURCE_IP}}"
# "IPLocal:Test2",
# "Local:False"
# ]
# [Proxy.HeadersByIP]

# Use https requests to backend instead of http
HTTPSBackend = false
Expand Down
102 changes: 23 additions & 79 deletions internal/proxy/directors.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package proxy

import (
"bytes"
"fmt"
"net"
"net/http"
"net/url"
"sort"
"strconv"

"github.com/rekby/lets-proxy2/internal/contextlabel"
Expand Down Expand Up @@ -202,90 +204,32 @@ func NewDirectorSetHeadersByIP(m map[string]HTTPHeaders) (DirectorSetHeadersByIP
Headers: v,
})
}

return sortByIPNet(res), nil
sortByIPNet(res)
return res, nil
}

// sortByIPNet sorts by CIDR using quicksort algorithm.
func sortByIPNet(d DirectorSetHeadersByIP) DirectorSetHeadersByIP {
ipv4 := make(DirectorSetHeadersByIP, 0, len(d))
ipv6 := make(DirectorSetHeadersByIP, 0, len(d))
for _, item := range d {
if item.IPNet.IP.To4() != nil {
ipv4 = append(ipv4, item)
} else {
ipv6 = append(ipv6, item)
}
}

ipv4 = quickSortByIPNet(ipv4)
ipv6 = quickSortByIPNet(ipv6)

return append(ipv4, ipv6...)
}

// quickSortByIPNet sorts by CIDR using quicksort algorithm.
// The result is sorted by IPNet.
// example:
//
// IN -> DirectorSetHeadersByIP{
// {IPNet: net.IPNet{IP: net.ParseIP("192.168.88.0"), Mask: net.CIDRMask(24, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("192.0.0.0"), Mask: net.CIDRMask(8, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(16, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("192.168.99.0"), Mask: net.CIDRMask(24, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("172.0.0.0"), Mask: net.CIDRMask(8, 32)}},
// },
//
// OUT <- DirectorSetHeadersByIP{
// {IPNet: net.IPNet{IP: net.ParseIP("192.0.0.0"), Mask: net.CIDRMask(8, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("192.168.88.0"), Mask: net.CIDRMask(24, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("192.168.99.0"), Mask: net.CIDRMask(24, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("172.0.0.0"), Mask: net.CIDRMask(8, 32)}},
// {IPNet: net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(16, 32)}},
// },
func quickSortByIPNet(d DirectorSetHeadersByIP) DirectorSetHeadersByIP {
if len(d) <= 1 {
return d
}

mid := len(d) / 2
left := d[:mid]
right := d[mid:]
func sortByIPNet(d DirectorSetHeadersByIP) {
sort.Slice(d, func(i, j int) bool {
left, right := d[i], d[j]

left = quickSortByIPNet(left)
right = quickSortByIPNet(right)

return mergeByIPNet(left, right)
}
maskOnes := func(m net.IPMask) int {
ones, _ := m.Size()
return ones
}

// mergeByIPNet merges two sorted arrays with CIDRs.
// The result is sorted by IPNet.
func mergeByIPNet(left, right DirectorSetHeadersByIP) DirectorSetHeadersByIP {
res := make(DirectorSetHeadersByIP, 0, len(left)+len(right))
for len(left) > 0 || len(right) > 0 {
if len(left) > 0 && len(right) > 0 {
if left[0].IPNet.Contains(right[0].IPNet.IP) {
res = append(res, left[0])
left = left[1:]
} else if right[0].IPNet.Contains(left[0].IPNet.IP) {
res = append(res, right[0])
right = right[1:]
} else {
res = append(res, left[0], right[0])
left = left[1:]
right = right[1:]
}
} else if len(left) > 0 {
res = append(res, left[0])
left = left[1:]
} else if len(right) > 0 {
res = append(res, right[0])
right = right[1:]
switch {
case len(left.IPNet.IP) < len(right.IPNet.IP):
return true
case len(left.IPNet.IP) > len(right.IPNet.IP):
return false
case maskOnes(left.IPNet.Mask) < maskOnes(right.IPNet.Mask):
return true
case maskOnes(left.IPNet.Mask) > maskOnes(right.IPNet.Mask):
return false
default:
return bytes.Compare(left.IPNet.IP, right.IPNet.IP) < 0
}
}
return res
})
}

func (h DirectorSetHeadersByIP) Director(request *http.Request) error {
Expand Down
8 changes: 4 additions & 4 deletions internal/proxy/directors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -386,12 +386,12 @@ func Test_sortByIPNet(t *testing.T) {
},
},
want: DirectorSetHeadersByIP{
{IPNet: net.IPNet{IP: net.ParseIP("172.0.0.0"), Mask: net.CIDRMask(8, 32)}},
{IPNet: net.IPNet{IP: net.ParseIP("192.0.0.0"), Mask: net.CIDRMask(8, 32)}},
{IPNet: net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(16, 32)}},
{IPNet: net.IPNet{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}},
{IPNet: net.IPNet{IP: net.ParseIP("192.168.88.0"), Mask: net.CIDRMask(24, 32)}},
{IPNet: net.IPNet{IP: net.ParseIP("192.168.99.0"), Mask: net.CIDRMask(24, 32)}},
{IPNet: net.IPNet{IP: net.ParseIP("172.0.0.0"), Mask: net.CIDRMask(8, 32)}},
{IPNet: net.IPNet{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(16, 32)}},
},
},
{
Expand Down Expand Up @@ -432,8 +432,8 @@ func Test_sortByIPNet(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := sortByIPNet(tt.args.d)
if !td.CmpDeeply(got, tt.want) {
sortByIPNet(tt.args.d)
if !td.CmpDeeply(tt.args.d, tt.want) {
t.Errorf("sortByIPNet() = %v, want %v", tt.args.d, tt.want)
}
})
Expand Down

0 comments on commit 4cef029

Please sign in to comment.