diff --git a/balancer/catabalancer/catalyst_balancer.go b/balancer/catabalancer/catalyst_balancer.go index 6e5c0c2c..a107832c 100644 --- a/balancer/catabalancer/catalyst_balancer.go +++ b/balancer/catabalancer/catalyst_balancer.go @@ -9,15 +9,18 @@ import ( "sort" "strconv" "strings" + "sync" "time" + _ "github.com/lib/pq" "github.com/livepeer/catalyst-api/cluster" "github.com/livepeer/catalyst-api/log" "github.com/patrickmn/go-cache" ) const ( - stateCacheKey = "stateCacheKey" + stateCacheKey = "stateCacheKey" + dbQueryTimeout = 10 * time.Second ) type CataBalancer struct { @@ -27,6 +30,7 @@ type CataBalancer struct { ingestStreamTimeout time.Duration nodeStatsDB *sql.DB nodeStatsCache *cache.Cache + cacheMutex sync.Mutex } type stats struct { @@ -136,7 +140,7 @@ func (c *CataBalancer) UpdateMembers(ctx context.Context, members []cluster.Memb } func (c *CataBalancer) GetBestNode(ctx context.Context, redirectPrefixes []string, playbackID, lat, lon, fallbackPrefix string, isStudioReq bool) (string, string, error) { - s, err := c.refreshNodes() + s, err := c.refreshNodes(ctx) if err != nil { return "", "", fmt.Errorf("error refreshing nodes: %w", err) } @@ -291,10 +295,28 @@ func truncateReturned(scoredNodes []ScoredNode, numNodes int) []ScoredNode { return scoredNodes[:numNodes] } -func (c *CataBalancer) refreshNodes() (stats, error) { +func (c *CataBalancer) getCachedStats() (stats, bool) { cachedState, found := c.nodeStatsCache.Get(stateCacheKey) if found { - return *cachedState.(*stats), nil + return *cachedState.(*stats), true + } + return stats{}, false +} + +func (c *CataBalancer) refreshNodes(ctx context.Context) (stats, error) { + cachedState, found := c.getCachedStats() + if found { + return cachedState, nil + } + + c.cacheMutex.Lock() + defer c.cacheMutex.Unlock() + + // Re-check the cache now that cacheMutex is held: several requests can get an initial cache miss at the same time; the first one to acquire the lock will populate + // the cache, and the requests that were waiting behind it on cacheMutex can then use the
new cached data + cachedState, found = c.getCachedStats() + if found { + return cachedState, nil } s := stats{ @@ -307,8 +329,11 @@ func (c *CataBalancer) refreshNodes() (stats, error) { return s, fmt.Errorf("node stats DB was nil") } + queryContext, cancel := context.WithTimeout(ctx, dbQueryTimeout) + defer cancel() + query := "SELECT stats FROM node_stats" - rows, err := c.nodeStatsDB.Query(query) + rows, err := c.nodeStatsDB.QueryContext(queryContext, query) if err != nil { return s, fmt.Errorf("failed to query node stats: %w", err) } @@ -366,7 +391,7 @@ func getPlaybackID(streamID string) string { } func (c *CataBalancer) MistUtilLoadSource(ctx context.Context, streamID, lat, lon string) (string, error) { - s, err := c.refreshNodes() + s, err := c.refreshNodes(ctx) if err != nil { return "", fmt.Errorf("error refreshing nodes: %w", err) } diff --git a/balancer/catabalancer/catalyst_balancer_test.go b/balancer/catabalancer/catalyst_balancer_test.go index e9862d2c..4f5f2a06 100644 --- a/balancer/catabalancer/catalyst_balancer_test.go +++ b/balancer/catabalancer/catalyst_balancer_test.go @@ -391,7 +391,7 @@ func TestStreamTimeout(t *testing.T) { nodeStats := NodeUpdateEvent{NodeID: "node", NodeMetrics: NodeMetrics{Timestamp: time.Now()}} nodeStats.SetStreams([]string{"video+stream"}, []string{"video+ingest"}) setNodeMetrics(t, mock, []NodeUpdateEvent{nodeStats}) - s, err := c.refreshNodes() + s, err := c.refreshNodes(context.Background()) require.NoError(t, err) setNodeMetrics(t, mock, []NodeUpdateEvent{nodeStats})