diff --git a/proxy/mount.go b/proxy/mount.go index 09b532f5..866b23f5 100644 --- a/proxy/mount.go +++ b/proxy/mount.go @@ -20,6 +20,7 @@ import ( ) const mountTimeout = time.Second * 5 +const ADJUST_PRIORITY_THRESHOLD = 100000 type Mount struct { logger zerolog.Logger @@ -166,6 +167,22 @@ func mostPriority(sources []*MountSourceClient) *MountSourceClient { }) } +// adjustPriority lowers the priority values in the sources list passed +// by subtracing the current minimum priority from all the other values +func adjustPriority(sources []*MountSourceClient) { + if len(sources) == 0 { + return + } + + slices.SortStableFunc(sources, func(a, b *MountSourceClient) int { + return cmp.Compare(a.Priority, b.Priority) + }) + + for i := range sources { + sources[i].Priority = uint(i) + } +} + // MountSourceClient is a SourceClient with extra fields for mount-specific // bookkeeping type MountSourceClient struct { @@ -229,6 +246,10 @@ func (m *Mount) AddSource(ctx context.Context, source *SourceClient) { m.Sources = append(m.Sources, msc) go m.RunMountSourceClient(ctx, msc) + if msc.Priority > ADJUST_PRIORITY_THRESHOLD { + adjustPriority(m.Sources) + } + // send an event that we connected m.events.eventSourceConnect(ctx, source) // check if this is our first source, if it is we can bump them diff --git a/proxy/mount_test.go b/proxy/mount_test.go index 7c05a6cb..7fcff695 100644 --- a/proxy/mount_test.go +++ b/proxy/mount_test.go @@ -258,3 +258,57 @@ func TestMountMetadataWriterSendMetadata(t *testing.T) { assert.True(t, called, "metadataFn should've been called after going live") assert.Equal(t, meta.Value, calledValue) } + +type adjustPriorityTestCase struct { + name string + sources []*MountSourceClient + expected []uint +} + +func TestMountAdjustPriority(t *testing.T) { + // helper functions to create MountSourceClient in the test cases + prio := func(p uint) *MountSourceClient { + return &MountSourceClient{ + Source: &SourceClient{ + ID: radio.SourceID{xid.New()}, + }, + Priority: p, + } + } + prioSlice := func(ps ...uint) []*MountSourceClient { + var sources = make([]*MountSourceClient, 0, len(ps)) + for _, p := range ps { + sources = append(sources, prio(p)) + } + return sources + } + + prioCase := func(name string, sources []*MountSourceClient, expected []uint) adjustPriorityTestCase { + return adjustPriorityTestCase{ + name: name, + sources: sources, + expected: expected, + } + } + + testCases := []adjustPriorityTestCase{ + {"empty", prioSlice(), nil}, + {"nil", nil, nil}, + prioCase("simple gaps", prioSlice(5, 10, 15, 20), []uint{0, 1, 2, 3}), + prioCase("simple sequential", prioSlice(0, 1, 2, 3, 4, 5), []uint{0, 1, 2, 3, 4, 5}), + prioCase("reversed", prioSlice(10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0), []uint{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + prioCase("random", prioSlice(8, 2, 1, 7, 4, 5, 0, 10, 9, 3, 6), []uint{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}), + prioCase("large gaps", prioSlice(500, 1510, 11215, 122320), []uint{0, 1, 2, 3}), + } + for _, c := range testCases { + t.Run(c.name, func(t *testing.T) { + adjustPriority(c.sources) + + for i, s := range c.sources { + if !assert.Equal(t, c.expected[i], s.Priority) { + t.Log(c.expected[i], s.Priority) + } + } + }) + } +}