diff --git a/.golangci.yml b/.golangci.yml index b4ec05e..dfa86c5 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -118,6 +118,9 @@ issues: - noctx - path: '_test\.go$' text: "unnamedResult:" + - path: '.*mxresolv.*' + linters: + - gosec run: diff --git a/mxresolv/dns_test.go b/mxresolv/dns_test.go new file mode 100644 index 0000000..a464ce5 --- /dev/null +++ b/mxresolv/dns_test.go @@ -0,0 +1,43 @@ +package mxresolv_test + +import ( + "net" + "sync" + + "github.com/foxcpp/go-mockdns" +) + +type MockDNS struct { + Server *mockdns.Server + mu sync.Mutex +} + +func SpawnMockDNS(zones map[string]mockdns.Zone) (*MockDNS, error) { + server, err := mockdns.NewServerWithLogger(zones, nullLogger{}, false) + if err != nil { + return nil, err + } + return &MockDNS{ + Server: server, + }, nil +} + +func (f *MockDNS) Stop() { + _ = f.Server.Close() +} + +func (f *MockDNS) Patch(r *net.Resolver) { + f.mu.Lock() + defer f.mu.Unlock() + f.Server.PatchNet(r) +} + +func (f *MockDNS) UnPatch(r *net.Resolver) { + f.mu.Lock() + defer f.mu.Unlock() + mockdns.UnpatchNet(r) +} + +type nullLogger struct{} + +func (l nullLogger) Printf(_ string, _ ...interface{}) {} diff --git a/mxresolv/mxresolv.go b/mxresolv/mxresolv.go index 680acad..cdb3256 100644 --- a/mxresolv/mxresolv.go +++ b/mxresolv/mxresolv.go @@ -6,6 +6,7 @@ import ( "net" "sort" "strings" + "time" "unicode" _ "unsafe" // For go:linkname @@ -23,21 +24,23 @@ const ( var ( errNullMXRecord = errors.New("domain accepts no mail") errNoValidMXHosts = errors.New("no valid MX hosts") - lookupResultCache *collections.LRUCache - // It is modified only in tests to make them deterministic. - shuffle = true + // defaultSeed allows the seed function to be patched in tests using SetDeterministic() + defaultRand = newRand - // DefaultResolver is exposed mainly to be patched in tests to access a - // mock DNS server github.com/foxcpp/go-mockdns. - DefaultResolver = net.DefaultResolver + // Resolver is exposed to be patched in tests + Resolver = net.DefaultResolver ) func init() { lookupResultCache = collections.NewLRUCache(cacheSize) } +func newRand() *rand.Rand { + return rand.New(rand.NewSource(time.Now().UnixNano())) +} + // Lookup performs a DNS lookup of MX records for the specified hostname. It // returns a prioritised list of MX hostnames, where hostnames with the same // priority are shuffled. If the second returned value is true, then the host @@ -57,7 +60,7 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli if err != nil { return nil, false, errors.Wrap(err, "invalid hostname") } - mxRecords, err := lookupMX(DefaultResolver, ctx, asciiHostname) + mxRecords, err := lookupMX(Resolver, ctx, asciiHostname) if err != nil { var timeouter interface{ Timeout() bool } if errors.As(err, &timeouter) && timeouter.Timeout() { @@ -65,7 +68,7 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli } var netDNSError *net.DNSError if errors.As(err, &netDNSError) && netDNSError.Err == "no such host" { - if _, err := DefaultResolver.LookupIPAddr(ctx, asciiHostname); err != nil { + if _, err := Resolver.LookupIPAddr(ctx, asciiHostname); err != nil { return cacheAndReturn(hostname, nil, nil, false, errors.WithStack(err)) } return cacheAndReturn(hostname, []string{asciiHostname}, nil, true, nil) @@ -105,21 +108,45 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli return cacheAndReturn(hostname, mxHosts, mxRecords, false, nil) } +// SetDeterministic sets rand to deterministic seed for testing, and is not Thread-Safe +func SetDeterministic() func() { + r := rand.New(rand.NewSource(1)) + defaultRand = func() *rand.Rand { return r } + return func() { + defaultRand = newRand + } +} + +// ResetCache clears the cache for use in tests, and is not Thread-Safe +func ResetCache() { + lookupResultCache = collections.NewLRUCache(1000) +} + func shuffleMXRecords(mxRecords []*net.MX) []string { - // Shuffle records within preference groups unless disabled in tests. - if shuffle { - mxRecordCount := len(mxRecords) - 1 - groupBegin := 0 - for i := 1; i <= mxRecordCount; i++ { - if mxRecords[i].Pref != mxRecords[groupBegin].Pref || i == mxRecordCount { - groupSlice := mxRecords[groupBegin:i] - rand.Shuffle(len(groupSlice), func(i, j int) { - groupSlice[i], groupSlice[j] = groupSlice[j], groupSlice[i] - }) - groupBegin = i - } + r := defaultRand() + + // Shuffle the hosts within the preference groups + begin := 0 + for i := 0; i <= len(mxRecords); i++ { + // If we are on the last record shuffle the last preference group + if i == len(mxRecords) { + group := mxRecords[begin:i] + r.Shuffle(len(group), func(i, j int) { + group[i], group[j] = group[j], group[i] + }) + break + } + + // After finding the end of a preference group, shuffle it + if mxRecords[begin].Pref != mxRecords[i].Pref { + group := mxRecords[begin:i] + r.Shuffle(len(group), func(i, j int) { + group[i], group[j] = group[j], group[i] + }) + begin = i } } + // Make a hostname list, but skip non-ASCII names, that cause issues. mxHosts := make([]string, 0, len(mxRecords)) for _, mxRecord := range mxRecords { diff --git a/mxresolv/mxresolv_test.go b/mxresolv/mxresolv_test.go index 7a6f3bb..c12a654 100644 --- a/mxresolv/mxresolv_test.go +++ b/mxresolv/mxresolv_test.go @@ -1,22 +1,116 @@ -package mxresolv +package mxresolv_test import ( "context" + "fmt" "math" + "math/rand" "net" + "os" "regexp" "sort" "testing" + "time" + "github.com/foxcpp/go-mockdns" "github.com/mailgun/holster/v4/clock" - "github.com/mailgun/holster/v4/collections" "github.com/mailgun/holster/v4/errors" + "github.com/mailgun/holster/v4/mxresolv" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestMain(m *testing.M) { + zones := map[string]mockdns.Zone{ + "test-a.definbox.com.": { + A: []string{"192.168.19.2"}, + }, + "test-cname.definbox.com.": { + CNAME: "definbox.com.", + }, + "definbox.com.": { + MX: []net.MX{ + {Host: "mxa.ninomail.com.", Pref: 10}, + {Host: "mxb.ninomail.com.", Pref: 10}, + }, + }, + "prefer.example.com.": { + MX: []net.MX{ + {Host: "mxa.example.com.", Pref: 20}, + {Host: "mxb.example.com.", Pref: 1}, + }, + }, + "prefer3.example.com.": { + MX: []net.MX{ + {Host: "mxa.example.com.", Pref: 1}, + {Host: "mxb.example.com.", Pref: 1}, + {Host: "mxc.example.com.", Pref: 2}, + }, + }, + "test-unicode.definbox.com.": { + MX: []net.MX{ + {Host: "mxa.definbox.com.", Pref: 1}, + {Host: "ex\\228mple.com.", Pref: 2}, + {Host: "mxb.definbox.com.", Pref: 3}, + }, + }, + "test-underscore.definbox.com.": { + MX: []net.MX{ + {Host: "foo_bar.definbox.com.", Pref: 1}, + }, + }, + "xn--test--xweh4bya7b6j.definbox.com.": { + MX: []net.MX{ + {Host: "xn--test---mofb0ab4b8camvcmn8gxd.definbox.com.", Pref: 10}, + }, + }, + "test-mx-ipv4.definbox.com.": { + MX: []net.MX{ + {Host: "34.150.176.225.", Pref: 10}, + }, + }, + "test-mx-ipv6.definbox.com.": { + MX: []net.MX{ + {Host: "::ffff:2296:b0e1.", Pref: 10}, + }, + }, + "example.com.": { + MX: []net.MX{ + {Host: ".", Pref: 0}, + }, + }, + "test-mx-zero.definbox.com.": { + MX: []net.MX{ + {Host: "0.0.0.0.", Pref: 0}, + }, + }, + "test-mx.definbox.com.": { + MX: []net.MX{ + {Host: "mxg.definbox.com.", Pref: 3}, + {Host: "mxa.definbox.com.", Pref: 1}, + {Host: "mxe.definbox.com.", Pref: 1}, + {Host: "mxi.definbox.com.", Pref: 1}, + {Host: "mxd.definbox.com.", Pref: 3}, + {Host: "mxc.definbox.com.", Pref: 2}, + {Host: "mxb.definbox.com.", Pref: 3}, + {Host: "mxf.definbox.com.", Pref: 3}, + {Host: "mxh.definbox.com.", Pref: 3}, + }, + }, + } + server, err := SpawnMockDNS(zones) + if err != nil { + panic(err) + } + + server.Patch(mxresolv.Resolver) + exitVal := m.Run() + server.UnPatch(mxresolv.Resolver) + server.Stop() + os.Exit(exitVal) +} + func TestLookup(t *testing.T) { - defer disableShuffle()() for _, tc := range []struct { desc string inDomainName string @@ -26,9 +120,10 @@ func TestLookup(t *testing.T) { desc: "MX record preference is respected", inDomainName: "test-mx.definbox.com", outMXHosts: []string{ - /* 1 */ "mxa.definbox.com", "mxe.definbox.com", "mxi.definbox.com", + /* 1 */ "mxa.definbox.com", "mxi.definbox.com", "mxe.definbox.com", /* 2 */ "mxc.definbox.com", - /* 3 */ "mxb.definbox.com", "mxd.definbox.com", "mxf.definbox.com", "mxg.definbox.com", "mxh.definbox.com"}, + /* 3 */ "mxb.definbox.com", "mxf.definbox.com", "mxh.definbox.com", "mxd.definbox.com", "mxg.definbox.com", + }, outImplicitMX: false, }, { inDomainName: "test-a.definbox.com", @@ -38,6 +133,10 @@ func TestLookup(t *testing.T) { inDomainName: "test-cname.definbox.com", outMXHosts: []string{"mxa.ninomail.com", "mxb.ninomail.com"}, outImplicitMX: false, + }, { + inDomainName: "definbox.com", + outMXHosts: []string{"mxa.ninomail.com", "mxb.ninomail.com"}, + outImplicitMX: false, }, { desc: "If an MX host returned by the resolver contains non ASCII " + "characters then it is silently dropped from the returned list", @@ -67,27 +166,78 @@ func TestLookup(t *testing.T) { outImplicitMX: false, }} { t.Run(tc.inDomainName, func(t *testing.T) { + defer mxresolv.SetDeterministic()() + // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) - mxHosts, explictMX, err := Lookup(ctx, tc.inDomainName) - cancel() + defer cancel() + mxHosts, explictMX, err := mxresolv.Lookup(ctx, tc.inDomainName) // Then assert.NoError(t, err) assert.Equal(t, tc.outMXHosts, mxHosts) assert.Equal(t, tc.outImplicitMX, explictMX) - - // The second lookup returns the cached result, that only shows on the - // coverage report. - mxHosts, explictMX, err = Lookup(ctx, tc.inDomainName) - assert.NoError(t, err) - assert.Equal(t, tc.outMXHosts, mxHosts) - assert.Equal(t, tc.outImplicitMX, explictMX) }) } } +func TestLookupRegression(t *testing.T) { + defer mxresolv.SetDeterministic()() + mxresolv.ResetCache() + + // When + ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) + defer cancel() + + mxHosts, explictMX, err := mxresolv.Lookup(ctx, "test-mx.definbox.com") + // Then + assert.NoError(t, err) + assert.Equal(t, []string{ + "mxa.definbox.com", "mxi.definbox.com", "mxe.definbox.com", "mxc.definbox.com", + "mxb.definbox.com", "mxf.definbox.com", "mxh.definbox.com", "mxd.definbox.com", + "mxg.definbox.com", + }, mxHosts) + assert.Equal(t, false, explictMX) + + // The second lookup returns the cached result, the cached result is shuffled. + mxHosts, explictMX, err = mxresolv.Lookup(ctx, "test-mx.definbox.com") + assert.NoError(t, err) + assert.Equal(t, []string{ + "mxi.definbox.com", "mxe.definbox.com", "mxa.definbox.com", "mxc.definbox.com", + "mxg.definbox.com", "mxh.definbox.com", "mxd.definbox.com", "mxf.definbox.com", + "mxb.definbox.com", + }, mxHosts) + assert.Equal(t, false, explictMX) + + mxHosts, _, err = mxresolv.Lookup(ctx, "definbox.com") + assert.NoError(t, err) + assert.Equal(t, []string{"mxb.ninomail.com", "mxa.ninomail.com"}, mxHosts) + + // Should always prefer mxb over mxa since mxb has a lower pref than mxa + for i := 0; i < 100; i++ { + mxHosts, _, err = mxresolv.Lookup(ctx, "prefer.example.com") + assert.NoError(t, err) + assert.Equal(t, []string{"mxb.example.com", "mxa.example.com"}, mxHosts) + } + + // Should randomly order mxa and mxb while mxc should always be last + mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com") + assert.NoError(t, err) + assert.Equal(t, []string{"mxb.example.com", "mxa.example.com", "mxc.example.com"}, mxHosts) + + mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com") + assert.NoError(t, err) + assert.Equal(t, []string{"mxa.example.com", "mxb.example.com", "mxc.example.com"}, mxHosts) + + // 'mxc.example.com' should always be last as it has a different priority than the other two. + for i := 0; i < 100; i++ { + mxHosts, _, err = mxresolv.Lookup(ctx, "prefer3.example.com") + assert.NoError(t, err) + assert.Equal(t, "mxc.example.com", mxHosts[2]) + } + +} + func TestLookupError(t *testing.T) { - defer disableShuffle()() for _, tc := range []struct { desc string inDomainName string @@ -101,11 +251,10 @@ func TestLookupError(t *testing.T) { inDomainName: "", outError: "lookup : no such host", }, - // TODO: fix https://github.com/mailgun/holster/issues/155: - // { - // inDomainName: "kaboom", - // outError: "lookup kaboom.*: no such host", - // }, + { + inDomainName: "kaboom", + outError: "lookup kaboom.*: no such host", + }, { // MX 0 . inDomainName: "example.com", @@ -120,8 +269,8 @@ func TestLookupError(t *testing.T) { t.Run(tc.inDomainName, func(t *testing.T) { // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) - _, _, err := Lookup(ctx, tc.inDomainName) - cancel() + defer cancel() + _, _, err := mxresolv.Lookup(ctx, tc.inDomainName) // Then require.Error(t, err) @@ -136,7 +285,7 @@ func TestLookupError(t *testing.T) { // The second lookup returns the cached result, that only shows on the // coverage report. - _, _, err = Lookup(ctx, tc.inDomainName) + _, _, err = mxresolv.Lookup(ctx, tc.inDomainName) assert.Regexp(t, regexp.MustCompile(tc.outError), err.Error()) }) } @@ -149,17 +298,15 @@ func TestLookupError(t *testing.T) { // 1: mxa.definbox.com, mxe.definbox.com, mxi.definbox.com // 2: mxc.definbox.com // 3: mxb.definbox.com, mxd.definbox.com, mxf.definbox.com, mxg.definbox.com, mxh.definbox.com -// -// Warning: since the data set is pretty small subsequent shuffles can produce -// the same result causing the test to fail. func TestLookupShuffle(t *testing.T) { + defer mxresolv.SetDeterministic()() + // When ctx, cancel := context.WithTimeout(context.Background(), 3*clock.Second) defer cancel() - shuffle1, _, err := Lookup(ctx, "test-mx.definbox.com") + shuffle1, _, err := mxresolv.Lookup(ctx, "test-mx.definbox.com") assert.NoError(t, err) - resetCache() - shuffle2, _, err := Lookup(ctx, "test-mx.definbox.com") + shuffle2, _, err := mxresolv.Lookup(ctx, "test-mx.definbox.com") assert.NoError(t, err) // Then @@ -176,49 +323,31 @@ func TestLookupShuffle(t *testing.T) { sort.Strings(shuffle1[4:]) sort.Strings(shuffle2[4:]) - assert.Equal(t, []string{"mxb.definbox.com", "mxd.definbox.com", "mxf.definbox.com", "mxg.definbox.com", "mxh.definbox.com"}, shuffle1[4:]) + assert.Equal(t, []string{"mxb.definbox.com", "mxd.definbox.com", "mxf.definbox.com", + "mxg.definbox.com", "mxh.definbox.com"}, shuffle1[4:]) assert.Equal(t, shuffle1[4:], shuffle2[4:]) } -func TestShuffle(t *testing.T) { - in := []*net.MX{ - {Host: "mxa.definbox.com", Pref: 1}, - {Host: "mxe.definbox.com", Pref: 1}, - {Host: "mxi.definbox.com", Pref: 1}, - {Host: "mxc.definbox.com", Pref: 2}, - {Host: "mxb.definbox.com", Pref: 3}, - {Host: "mxd.definbox.com", Pref: 3}, - {Host: "mxf.definbox.com", Pref: 3}, - {Host: "mxg.definbox.com", Pref: 3}, - {Host: "mxh.definbox.com", Pref: 3}, - } - out := shuffleMXRecords(in) - assert.Equal(t, 9, len(out)) - - // This is a regression test, previous implementation of shuffleMXRecords() would - // only return 1 MX record if there were 2 MX records with the same preference number. - in = []*net.MX{ - {Host: "mxa.definbox.com", Pref: 5}, - {Host: "mxe.definbox.com", Pref: 5}, - } - out = shuffleMXRecords(in) - assert.Equal(t, 2, len(out)) +func TestDistribution(t *testing.T) { + mxresolv.ResetCache() - in = []*net.MX{ - {Host: "mxa.definbox.com", Pref: 5}, + // 2 host distribution should be uniform + dist := make(map[string]int, 2) + for i := 0; i < 1000; i++ { + s, _, _ := mxresolv.Lookup(context.Background(), "definbox.com") + _, ok := dist[s[0]] + if ok { + dist[s[0]] += 1 + } else { + dist[s[0]] = 0 + } } - out = shuffleMXRecords(in) - assert.Equal(t, 1, len(out)) - // Should not panic - out = shuffleMXRecords([]*net.MX{}) - assert.Equal(t, 0, len(out)) -} + assertDistribution(t, dist, 35.0) -func TestDistribution(t *testing.T) { - dist := make(map[string]int, 3) + dist = make(map[string]int, 3) for i := 0; i < 1000; i++ { - s, _, _ := Lookup(context.Background(), "test-mx.definbox.com") + s, _, _ := mxresolv.Lookup(context.Background(), "test-mx.definbox.com") _, ok := dist[s[0]] if ok { dist[s[0]] += 1 @@ -226,6 +355,50 @@ func TestDistribution(t *testing.T) { dist[s[0]] = 0 } } + assertDistribution(t, dist, 35.0) + + // This is what a standard distribution looks like when 3 hosts have the same MX priority + // spew.Dump(dist) + // (map[string]int) (len=3) { + // (string) (len=16) "mxa.definbox.com": (int) 324, + // (string) (len=16) "mxe.definbox.com": (int) 359, + // (string) (len=16) "mxi.definbox.com": (int) 314 + // } +} + +// Golang optimizes the allocation so there is no hit to performance or memory usage when calling +// `rand.New()` for each call to `shuffleNew()` over `rand.Shuffle()` which has a mutex. +// +// pkg: github.com/mailgun/holster/v4/mxresolv +// BenchmarkShuffleWithNew +// BenchmarkShuffleWithNew-10 61962 18434 ns/op 5376 B/op 1 allocs/op +// BenchmarkShuffleGlobal +// BenchmarkShuffleGlobal-10 65205 18480 ns/op 0 B/op 0 allocs/op +func BenchmarkShuffleWithNew(b *testing.B) { + for n := b.N; n > 0; n-- { + shuffleNew() + } + b.ReportAllocs() +} + +func BenchmarkShuffleGlobal(b *testing.B) { + for n := b.N; n > 0; n-- { + shuffleGlobal() + } + b.ReportAllocs() +} + +func shuffleNew() { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + r.Shuffle(52, func(i, j int) {}) +} + +func shuffleGlobal() { + rand.Shuffle(52, func(i, j int) {}) +} + +func assertDistribution(t *testing.T, dist map[string]int, expected float64) { + t.Helper() // Calculate the mean of the distribution var sum int @@ -245,25 +418,8 @@ func TestDistribution(t *testing.T) { variance := squaredDifferences / float64(len(dist)) stdDev := math.Sqrt(variance) - // The distribution of random hosts chosen should not exceed 30 - assert.False(t, stdDev > 30.0, "Standard deviation is greater than 30: %.2f", stdDev) - - // For example this is what a standard distribution looks like when 3 hosts have the same MX priority - // spew.Dump(dist) - // (map[string]int) (len=3) { - // (string) (len=16) "mxa.definbox.com": (int) 324, - // (string) (len=16) "mxe.definbox.com": (int) 359, - // (string) (len=16) "mxi.definbox.com": (int) 314 - // } -} - -func disableShuffle() func() { - shuffle = false - return func() { - shuffle = true - } -} + // The distribution of random hosts chosen should not exceed 35 + assert.False(t, stdDev > expected, + fmt.Sprintf("Standard deviation is greater than %f:", expected)+"%.2f", stdDev) -func resetCache() { - lookupResultCache = collections.NewLRUCache(1000) }