diff --git a/ring.go b/ring.go index 0ff3f75b1..e8d0897c2 100644 --- a/ring.go +++ b/ring.go @@ -22,6 +22,12 @@ import ( var errRingShardsDown = errors.New("redis: all ring shards are down") +// defaultHeartbeatFn is the default function used to check the shard liveness +var defaultHeartbeatFn = func(ctx context.Context, client *Client) bool { + err := client.Ping(ctx).Err() + return err == nil || err == pool.ErrPoolTimeout +} + //------------------------------------------------------------------------------ type ConsistentHash interface { @@ -54,10 +60,14 @@ type RingOptions struct { // ClientName will execute the `CLIENT SETNAME ClientName` command for each conn. ClientName string - // Frequency of PING commands sent to check shards availability. + // Frequency of executing HeartbeatFn to check shards availability. // Shard is considered down after 3 subsequent failed checks. HeartbeatFrequency time.Duration + // A function used to check the shard liveness + // if not set, defaults to defaultHeartbeatFn + HeartbeatFn func(ctx context.Context, client *Client) bool + // NewConsistentHash returns a consistent hash that is used // to distribute keys across the shards. // @@ -124,6 +134,10 @@ func (opt *RingOptions) init() { opt.HeartbeatFrequency = 500 * time.Millisecond } + if opt.HeartbeatFn == nil { + opt.HeartbeatFn = defaultHeartbeatFn + } + if opt.NewConsistentHash == nil { opt.NewConsistentHash = newRendezvous } @@ -423,8 +437,7 @@ func (c *ringSharding) Heartbeat(ctx context.Context, frequency time.Duration) { var rebalance bool for _, shard := range c.List() { - err := shard.Client.Ping(ctx).Err() - isUp := err == nil || err == pool.ErrPoolTimeout + isUp := c.opt.HeartbeatFn(ctx, shard.Client) if shard.Vote(isUp) { internal.Logger.Printf(ctx, "ring shard state changed: %s", shard) rebalance = true