diff --git a/tracker/http.go b/tracker/http.go index df7b9694..4158dc14 100644 --- a/tracker/http.go +++ b/tracker/http.go @@ -11,14 +11,16 @@ import ( "github.com/rs/zerolog/hlog" ) +const ( + ICECAST_AUTH_HEADER = "icecast-auth-user" + ICECAST_CLIENTID_FIELD_NAME = "client" +) + func ListenerAdd(ctx context.Context, recorder *Recorder) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("icecast-auth-user", "1") - w.WriteHeader(http.StatusOK) - _ = r.ParseForm() - id := r.FormValue("client") + id := r.FormValue(ICECAST_CLIENTID_FIELD_NAME) if id == "" { // icecast send us no client id somehow, this is broken and // we can't record this listener @@ -33,17 +35,23 @@ func ListenerAdd(ctx context.Context, recorder *Recorder) http.HandlerFunc { return } + // only return OK if we got the required ID from icecast + w.Header().Set(ICECAST_AUTH_HEADER, "1") + w.WriteHeader(http.StatusOK) + go recorder.ListenerAdd(ctx, cid, r) } } func ListenerRemove(ctx context.Context, recorder *Recorder) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { + // always return OK because it doesn't really matter if the + // rest of the request is broken w.WriteHeader(http.StatusOK) _ = r.ParseForm() - id := r.FormValue("client") + id := r.FormValue(ICECAST_CLIENTID_FIELD_NAME) if id == "" { // icecast send us no client id somehow, this is broken and // we can't record this listener diff --git a/tracker/http_test.go b/tracker/http_test.go new file mode 100644 index 00000000..2a0d1b9d --- /dev/null +++ b/tracker/http_test.go @@ -0,0 +1,99 @@ +package tracker + +import ( + "context" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestListenerAddAndRemove(t *testing.T) { + ctx := context.Background() + + recorder := NewRecorder() + dummy := NewServer(ctx, "", recorder) + + srv := httptest.NewServer(dummy.Handler) + defer srv.Close() + client := srv.Client() + + t.Run("join then leave", func(t *testing.T) { + id := ClientID(500) + + // ======================== + // Do a normal join request + resp, err := client.PostForm(srv.URL+"/listener_joined", url.Values{ + ICECAST_CLIENTID_FIELD_NAME: []string{id.String()}, + }) + require.NoError(t, err) + require.NotNil(t, resp) + resp.Body.Close() + + // status should be OK + require.Equal(t, http.StatusOK, resp.StatusCode) + // and we should have the OK header that icecast needs + require.Equal(t, "1", resp.Header.Get(ICECAST_AUTH_HEADER)) + + // we also should have a listener in the recorder + require.Eventually(t, func() bool { + return assert.Equal(t, int64(1), recorder.ListenerAmount()) + }, eventuallyDelay, eventuallyTick) + testListenerLength(t, recorder, 1) + + // ========================= + // Do a normal leave request + resp, err = client.PostForm(srv.URL+"/listener_left", url.Values{ + ICECAST_CLIENTID_FIELD_NAME: []string{id.String()}, + }) + require.NoError(t, err) + require.NotNil(t, resp) + resp.Body.Close() + + // status should be OK again + require.Equal(t, http.StatusOK, resp.StatusCode) + // and the listener should now be gone + require.Eventually(t, func() bool { + return assert.Equal(t, int64(0), recorder.ListenerAmount()) + }, eventuallyDelay, eventuallyTick) + + testListenerLength(t, recorder, 0) + }) + + for _, uri := range []string{"/listener_joined", "/listener_left"} { + t.Run("empty client"+uri, func(t *testing.T) { + // ======================================== + // Try an empty client ID, this should fail + resp, err := client.PostForm(srv.URL+uri, url.Values{ + ICECAST_CLIENTID_FIELD_NAME: []string{}, + }) + require.NoError(t, err) + require.NotNil(t, resp) + resp.Body.Close() + + // status should still be OK + require.Equal(t, http.StatusOK, resp.StatusCode) + // but it should not have the OK header + require.Zero(t, resp.Header.Get(ICECAST_AUTH_HEADER)) + }) + + t.Run("non-integer client"+uri, func(t *testing.T) { + // ======================================== + // Try a non-integer client ID, this should fail + resp, err := client.PostForm(srv.URL+uri, url.Values{ + ICECAST_CLIENTID_FIELD_NAME: []string{"not an integer"}, + }) + require.NoError(t, err) + require.NotNil(t, resp) + resp.Body.Close() + + // status should still be OK + require.Equal(t, http.StatusOK, resp.StatusCode) + // but it should not have the OK header + require.Zero(t, resp.Header.Get(ICECAST_AUTH_HEADER)) + }) + } +} diff --git a/tracker/main.go b/tracker/main.go index 27eb93d0..3ddc6ec1 100644 --- a/tracker/main.go +++ b/tracker/main.go @@ -4,28 +4,19 @@ import ( "context" "time" + radio "github.com/R-a-dio/valkyrie" "github.com/R-a-dio/valkyrie/config" "github.com/rs/zerolog" ) +var UpdateListenersTickrate = time.Second * 10 + func Execute(ctx context.Context, cfg config.Config) error { manager := cfg.Conf().Manager.Client() var recorder = NewRecorder() - go func() { - ticker := time.NewTicker(time.Second * 10) - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - err := manager.UpdateListeners(ctx, recorder.ListenerAmount()) - if err != nil { - zerolog.Ctx(ctx).Error().Err(err).Msg("failed update listeners") - } - } - } - }() + + go PeriodicallyUpdateListeners(ctx, manager, recorder) srv := NewServer(ctx, ":9999", recorder) @@ -41,3 +32,20 @@ func Execute(ctx context.Context, cfg config.Config) error { return err } } + +func PeriodicallyUpdateListeners(ctx context.Context, manager radio.ManagerService, recorder *Recorder) { + ticker := time.NewTicker(UpdateListenersTickrate) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + err := manager.UpdateListeners(ctx, recorder.ListenerAmount()) + if err != nil { + zerolog.Ctx(ctx).Error().Err(err).Msg("failed update listeners") + } + } + } +} diff --git a/tracker/main_test.go b/tracker/main_test.go new file mode 100644 index 00000000..1917054f --- /dev/null +++ b/tracker/main_test.go @@ -0,0 +1,53 @@ +package tracker + +import ( + "context" + "fmt" + "math/rand/v2" + "sync/atomic" + "testing" + "time" + + "github.com/R-a-dio/valkyrie/mocks" + "github.com/stretchr/testify/assert" +) + +func TestPeriodicallyUpdateListeners(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan struct{}) + + recorder := NewRecorder() + var last atomic.Int64 + var count int + + manager := &mocks.ManagerServiceMock{ + UpdateListenersFunc: func(contextMoqParam context.Context, new int64) error { + // we're done after 10 updates + if count++; count > 10 { + close(done) + } + // every 5 updates return an error + if count%5 == 0 { + return fmt.Errorf("that's an error") + } + + // otherwise our new value should equal what we set it to previously + if !assert.Equal(t, last.Load(), new) { + close(done) + } + + adjustment := rand.Int64() + recorder.listenerAmount.Store(adjustment) + last.Store(adjustment) + + return nil + }, + } + + // set the tickrate a bit higher for testing purposes + UpdateListenersTickrate = time.Millisecond * 10 + go PeriodicallyUpdateListeners(ctx, manager, recorder) + + <-done +} diff --git a/tracker/recorder.go b/tracker/recorder.go index c2c71b00..52a3d4a8 100644 --- a/tracker/recorder.go +++ b/tracker/recorder.go @@ -23,6 +23,10 @@ func ParseClientID(s string) (ClientID, error) { return ClientID(id), err } +func (c ClientID) String() string { + return strconv.FormatUint(uint64(c), 10) +} + type Listener struct { span trace.Span diff --git a/tracker/recorder_test.go b/tracker/recorder_test.go index 36efb1f1..67344b5f 100644 --- a/tracker/recorder_test.go +++ b/tracker/recorder_test.go @@ -10,6 +10,11 @@ import ( "github.com/stretchr/testify/assert" ) +const ( + eventuallyTick = time.Millisecond * 150 + eventuallyDelay = time.Second * 5 +) + func TestListenerAddAndRemoval(t *testing.T) { r := NewRecorder() ctx := context.Background() @@ -37,9 +42,6 @@ func TestListenerAddAndRemoval(t *testing.T) { } func TestListenerAddAndRemovalOutOfOrder(t *testing.T) { - eventuallyTick := time.Millisecond * 150 - eventuallyDelay := time.Second * 5 - r := NewRecorder() ctx := context.Background() req := httptest.NewRequest(http.MethodGet, "/test", nil) @@ -56,7 +58,7 @@ func TestListenerAddAndRemovalOutOfOrder(t *testing.T) { assert.Eventually(t, func() bool { // half should have been added normally return assert.Equal(t, count/2, r.ListenerAmount()) && - assert.Len(t, r.listeners, int(count/2)) + testListenerLength(t, r, int(count/2)) }, eventuallyDelay, eventuallyTick) assert.Eventually(t, func() bool { // half should have been removed early @@ -75,15 +77,23 @@ func TestListenerAddAndRemovalOutOfOrder(t *testing.T) { } assert.Eventually(t, func() bool { - r.mu.Lock() - defer r.mu.Unlock() return assert.Zero(t, r.ListenerAmount()) && - assert.Len(t, r.listeners, 0) + testListenerLength(t, r, 0) }, eventuallyDelay, eventuallyTick) assert.Eventually(t, func() bool { - r.mu.Lock() - defer r.mu.Unlock() - return assert.Len(t, r.pendingRemoval, 0) + return testPendingLength(t, r, 0) }, eventuallyDelay, eventuallyTick) } + +func testListenerLength(t *testing.T, r *Recorder, expected int) bool { + r.mu.Lock() + defer r.mu.Unlock() + return assert.Len(t, r.listeners, expected) +} + +func testPendingLength(t *testing.T, r *Recorder, expected int) bool { + r.mu.Lock() + defer r.mu.Unlock() + return assert.Len(t, r.pendingRemoval, expected) +}