diff --git a/tools/walletextension/ratelimiter/rate_limiter.go b/tools/walletextension/ratelimiter/rate_limiter.go index 05103b46f6..f264108e5b 100644 --- a/tools/walletextension/ratelimiter/rate_limiter.go +++ b/tools/walletextension/ratelimiter/rate_limiter.go @@ -21,6 +21,7 @@ type RequestInterval struct { // RateLimitUser represents a user with a map of current requests. type RateLimitUser struct { CurrentRequests map[uuid.UUID]RequestInterval + mu sync.RWMutex } // zeroUUID is a zero UUID returned when no new request is added. @@ -33,17 +34,21 @@ func (rl *RateLimiter) AddRequest(userID common.Address, interval RequestInterva return zeroUUID } rl.mu.Lock() - defer rl.mu.Unlock() - user, exists := rl.users[userID] if !exists { user = &RateLimitUser{ CurrentRequests: make(map[uuid.UUID]RequestInterval), + mu: sync.RWMutex{}, } rl.users[userID] = user } + rl.mu.Unlock() + id := uuid.New() + user.mu.Lock() user.CurrentRequests[id] = interval + user.mu.Unlock() + return id } @@ -54,18 +59,22 @@ func (rl *RateLimiter) SetRequestEnd(userID common.Address, id uuid.UUID) { return } - if user, userExists := rl.users[userID]; userExists { + rl.mu.RLock() + user, userExists := rl.users[userID] + rl.mu.RUnlock() + + if userExists { + user.mu.Lock() if request, requestExists := user.CurrentRequests[id]; requestExists { - rl.mu.Lock() - defer rl.mu.Unlock() now := time.Now() request.End = &now user.CurrentRequests[id] = request } else { - rl.logger.Info("Request with ID %s not found for user %s.", id, userID.Hex()) + rl.logger.Info("Request with ID not found for user.", "id", id, "user", userID) } + user.mu.Unlock() } else { - rl.logger.Info("User %s not found while trying to update the request.", userID.Hex()) + rl.logger.Info("User not found while trying to update the request.", "user", userID) } } @@ -76,11 +85,13 @@ func (rl *RateLimiter) CountOpenRequests(userID common.Address) int { var count int if user, exists := rl.users[userID]; exists { + user.mu.RLock() for _, interval := range user.CurrentRequests { if interval.End == nil { count++ } } + user.mu.RUnlock() } return count } @@ -202,6 +213,7 @@ func (rl *RateLimiter) PruneRequests() { // delete all the requests that have cutoff := time.Now().Add(-rl.window) for userID, user := range rl.users { + user.mu.Lock() for id, interval := range user.CurrentRequests { if interval.End != nil && interval.End.Before(cutoff) { delete(user.CurrentRequests, id) @@ -210,6 +222,7 @@ func (rl *RateLimiter) PruneRequests() { if len(user.CurrentRequests) == 0 { delete(rl.users, userID) } + user.mu.Unlock() } timeTaken := time.Since(startTime) if timeTaken > 1*time.Second {