diff --git a/modules/core/src/main/scala/dev/profunktor/redis4cats/config.scala b/modules/core/src/main/scala/dev/profunktor/redis4cats/config.scala index 6b52b995..6816810b 100644 --- a/modules/core/src/main/scala/dev/profunktor/redis4cats/config.scala +++ b/modules/core/src/main/scala/dev/profunktor/redis4cats/config.scala @@ -18,24 +18,32 @@ package dev.profunktor.redis4cats import scala.concurrent.duration._ +import io.lettuce.core.cluster.models.partitions.RedisClusterNode +import io.lettuce.core.cluster.ClusterClientOptions + object config { // Builder-style abstract class instead of case class to allow for bincompat-friendly extension in future versions. sealed abstract class Redis4CatsConfig { val shutdown: ShutdownConfig val topologyViewRefreshStrategy: TopologyViewRefreshStrategy + val nodeFilter: RedisClusterNode => Boolean def withShutdown(shutdown: ShutdownConfig): Redis4CatsConfig def withTopologyViewRefreshStrategy(strategy: TopologyViewRefreshStrategy): Redis4CatsConfig + def withNodeFilter(nodeFilter: RedisClusterNode => Boolean): Redis4CatsConfig } object Redis4CatsConfig { private case class Redis4CatsConfigImpl( shutdown: ShutdownConfig, - topologyViewRefreshStrategy: TopologyViewRefreshStrategy = NoRefresh + topologyViewRefreshStrategy: TopologyViewRefreshStrategy = NoRefresh, + nodeFilter: RedisClusterNode => Boolean = ClusterClientOptions.DEFAULT_NODE_FILTER.test ) extends Redis4CatsConfig { override def withShutdown(_shutdown: ShutdownConfig): Redis4CatsConfig = copy(shutdown = _shutdown) override def withTopologyViewRefreshStrategy(strategy: TopologyViewRefreshStrategy): Redis4CatsConfig = copy(topologyViewRefreshStrategy = strategy) + override def withNodeFilter(_nodeFilter: RedisClusterNode => Boolean): Redis4CatsConfig = + copy(nodeFilter = _nodeFilter) } def apply(): Redis4CatsConfig = Redis4CatsConfigImpl(ShutdownConfig()) } diff --git a/modules/core/src/main/scala/dev/profunktor/redis4cats/connection/RedisClusterClient.scala b/modules/core/src/main/scala/dev/profunktor/redis4cats/connection/RedisClusterClient.scala index a9ccf681..85015074 100644 --- a/modules/core/src/main/scala/dev/profunktor/redis4cats/connection/RedisClusterClient.scala +++ b/modules/core/src/main/scala/dev/profunktor/redis4cats/connection/RedisClusterClient.scala @@ -26,7 +26,7 @@ import dev.profunktor.redis4cats.JavaConversions._ import dev.profunktor.redis4cats.config._ import dev.profunktor.redis4cats.data.NodeId import dev.profunktor.redis4cats.effect._ -import io.lettuce.core.cluster.models.partitions.{ Partitions => JPartitions } +import io.lettuce.core.cluster.models.partitions.{ RedisClusterNode, Partitions => JPartitions } import io.lettuce.core.cluster.{ ClusterClientOptions, ClusterTopologyRefreshOptions, @@ -47,7 +47,7 @@ object RedisClusterClient { Log[F].info(s"Acquire Redis Cluster client") *> RedisExecutor[F] .lift(JClusterClient.create(uri.map(_.underlying).asJava)) - .flatTap(initializeClusterTopology[F](_, config.topologyViewRefreshStrategy)) + .flatTap(initializeClusterTopology[F](_, config.topologyViewRefreshStrategy, config.nodeFilter)) .map(new RedisClusterClient(_) {}) val release: RedisClusterClient => F[Unit] = client => @@ -69,11 +69,18 @@ object RedisClusterClient { private[redis4cats] def initializeClusterTopology[F[_]: Functor: RedisExecutor]( client: JClusterClient, - topologyViewRefreshStrategy: TopologyViewRefreshStrategy + topologyViewRefreshStrategy: TopologyViewRefreshStrategy, + nodeFilter: RedisClusterNode => Boolean ): F[Unit] = RedisExecutor[F].lift { topologyViewRefreshStrategy match { case NoRefresh => + client.setOptions( + ClusterClientOptions + .builder() + .nodeFilter(nodeFilter(_)) + .build() + ) client.getPartitions case Periodic(interval) => client.setOptions( @@ -86,6 +93,7 @@ object RedisClusterClient { .enablePeriodicRefresh(Duration.ofMillis(interval.toMillis)) .build() ) + .nodeFilter(nodeFilter(_)) .build() ) client.getPartitions @@ -101,6 +109,7 @@ object RedisClusterClient { .adaptiveRefreshTriggersTimeout(Duration.ofMillis(timeout.toMillis)) .build() ) + .nodeFilter(nodeFilter(_)) .build() ) client.getPartitions