diff --git a/mxresolv/mxresolv.go b/mxresolv/mxresolv.go index a061c02..680acad 100644 --- a/mxresolv/mxresolv.go +++ b/mxresolv/mxresolv.go @@ -108,10 +108,10 @@ func Lookup(ctx context.Context, hostname string) (retMxHosts []string, retImpli func shuffleMXRecords(mxRecords []*net.MX) []string { // Shuffle records within preference groups unless disabled in tests. if shuffle { - mxRecordCount := len(mxRecords) + mxRecordCount := len(mxRecords) - 1 groupBegin := 0 - for i := 1; i < mxRecordCount; i++ { - if mxRecords[i].Pref != mxRecords[groupBegin].Pref || i == mxRecordCount-1 { + for i := 1; i <= mxRecordCount; i++ { + if mxRecords[i].Pref != mxRecords[groupBegin].Pref || i == mxRecordCount { groupSlice := mxRecords[groupBegin:i] rand.Shuffle(len(groupSlice), func(i, j int) { groupSlice[i], groupSlice[j] = groupSlice[j], groupSlice[i] diff --git a/mxresolv/mxresolv_test.go b/mxresolv/mxresolv_test.go index 068f811..7a6f3bb 100644 --- a/mxresolv/mxresolv_test.go +++ b/mxresolv/mxresolv_test.go @@ -3,6 +3,7 @@ package mxresolv import ( "context" "math" + "net" "regexp" "sort" "testing" @@ -179,6 +180,41 @@ func TestLookupShuffle(t *testing.T) { assert.Equal(t, shuffle1[4:], shuffle2[4:]) } +func TestShuffle(t *testing.T) { + in := []*net.MX{ + {Host: "mxa.definbox.com", Pref: 1}, + {Host: "mxe.definbox.com", Pref: 1}, + {Host: "mxi.definbox.com", Pref: 1}, + {Host: "mxc.definbox.com", Pref: 2}, + {Host: "mxb.definbox.com", Pref: 3}, + {Host: "mxd.definbox.com", Pref: 3}, + {Host: "mxf.definbox.com", Pref: 3}, + {Host: "mxg.definbox.com", Pref: 3}, + {Host: "mxh.definbox.com", Pref: 3}, + } + out := shuffleMXRecords(in) + assert.Equal(t, 9, len(out)) + + // This is a regression test, previous implementation of shuffleMXRecords() would + // only return 1 MX record if there were 2 MX records with the same preference number. + in = []*net.MX{ + {Host: "mxa.definbox.com", Pref: 5}, + {Host: "mxe.definbox.com", Pref: 5}, + } + out = shuffleMXRecords(in) + assert.Equal(t, 2, len(out)) + + in = []*net.MX{ + {Host: "mxa.definbox.com", Pref: 5}, + } + out = shuffleMXRecords(in) + assert.Equal(t, 1, len(out)) + + // Should not panic + out = shuffleMXRecords([]*net.MX{}) + assert.Equal(t, 0, len(out)) +} + func TestDistribution(t *testing.T) { dist := make(map[string]int, 3) for i := 0; i < 1000; i++ {