diff --git a/redisson/src/main/java/org/redisson/connection/MasterSlaveEntry.java b/redisson/src/main/java/org/redisson/connection/MasterSlaveEntry.java index f8e5c86ce..0f79f01bf 100644 --- a/redisson/src/main/java/org/redisson/connection/MasterSlaveEntry.java +++ b/redisson/src/main/java/org/redisson/connection/MasterSlaveEntry.java @@ -594,7 +594,11 @@ public class MasterSlaveEntry { return slaveBalancer.getConnection(command, client); } - public CompletableFuture nextPubSubConnection() { + public CompletableFuture nextPubSubConnection(ClientConnectionsEntry entry) { + if (entry != null) { + return slaveBalancer.nextPubSubConnection(entry); + } + if (config.getSubscriptionMode() == SubscriptionMode.MASTER) { return pubSubConnectionPool.get(); } diff --git a/redisson/src/main/java/org/redisson/connection/balancer/LoadBalancerManager.java b/redisson/src/main/java/org/redisson/connection/balancer/LoadBalancerManager.java index 3302e33e0..6e2c6984c 100644 --- a/redisson/src/main/java/org/redisson/connection/balancer/LoadBalancerManager.java +++ b/redisson/src/main/java/org/redisson/connection/balancer/LoadBalancerManager.java @@ -253,6 +253,10 @@ public class LoadBalancerManager { return pubSubConnectionPool.get(); } + public CompletableFuture nextPubSubConnection(ClientConnectionsEntry entry) { + return pubSubConnectionPool.get(entry); + } + public boolean contains(InetSocketAddress addr) { return getEntry(addr) != null; } @@ -299,7 +303,7 @@ public class LoadBalancerManager { f.completeExceptionally(exception); return f; } - + public CompletableFuture getConnection(RedisCommand command, RedisClient client) { ClientConnectionsEntry entry = getEntry(client); if (entry != null) { diff --git a/redisson/src/main/java/org/redisson/connection/pool/PubSubConnectionPool.java b/redisson/src/main/java/org/redisson/connection/pool/PubSubConnectionPool.java index fc5776ca1..32f9b8ff6 100644 --- a/redisson/src/main/java/org/redisson/connection/pool/PubSubConnectionPool.java +++ b/redisson/src/main/java/org/redisson/connection/pool/PubSubConnectionPool.java @@ -41,7 +41,11 @@ public class PubSubConnectionPool extends ConnectionPool public CompletableFuture get() { return get(RedisCommands.SUBSCRIBE); } - + + public CompletableFuture get(ClientConnectionsEntry entry) { + return get(RedisCommands.SUBSCRIBE, entry); + } + @Override protected RedisPubSubConnection poll(ClientConnectionsEntry entry, RedisCommand command) { return entry.pollSubscribeConnection(); diff --git a/redisson/src/main/java/org/redisson/pubsub/PublishSubscribeService.java b/redisson/src/main/java/org/redisson/pubsub/PublishSubscribeService.java index f3a270efe..77501b3c0 100644 --- a/redisson/src/main/java/org/redisson/pubsub/PublishSubscribeService.java +++ b/redisson/src/main/java/org/redisson/pubsub/PublishSubscribeService.java @@ -21,6 +21,7 @@ import org.redisson.client.*; import org.redisson.client.codec.Codec; import org.redisson.client.protocol.pubsub.PubSubType; import org.redisson.config.MasterSlaveServersConfig; +import org.redisson.connection.ClientConnectionsEntry; import org.redisson.connection.ConnectionManager; import org.redisson.connection.MasterSlaveEntry; import org.redisson.misc.AsyncSemaphore; @@ -161,7 +162,7 @@ public class PublishSubscribeService { List> futures = new ArrayList<>(); for (MasterSlaveEntry entry : entrySet) { - CompletableFuture future = subscribe(PubSubType.PSUBSCRIBE, codec, channelName, entry, ls); + CompletableFuture future = subscribe(PubSubType.PSUBSCRIBE, codec, channelName, entry, null, ls); futures.add(future); } CompletableFuture future = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); @@ -178,7 +179,7 @@ public class PublishSubscribeService { return promise; } - CompletableFuture f = subscribe(PubSubType.PSUBSCRIBE, codec, channelName, entry, listeners); + CompletableFuture f = subscribe(PubSubType.PSUBSCRIBE, codec, channelName, entry, null, listeners); return f.thenApply(res -> Collections.singletonList(res)); } @@ -188,6 +189,11 @@ public class PublishSubscribeService { || channelName.toString().startsWith("__keyevent@")); } + public CompletableFuture subscribe(MasterSlaveEntry entry, ClientConnectionsEntry clientEntry, + Codec codec, ChannelName channelName, RedisPubSubListener... listeners) { + return subscribe(PubSubType.SUBSCRIBE, codec, channelName, entry, clientEntry, listeners); + } + public CompletableFuture subscribe(Codec codec, ChannelName channelName, RedisPubSubListener... listeners) { MasterSlaveEntry entry = getEntry(channelName); if (entry == null) { @@ -196,7 +202,7 @@ public class PublishSubscribeService { promise.completeExceptionally(ex); return promise; } - return subscribe(PubSubType.SUBSCRIBE, codec, channelName, entry, listeners); + return subscribe(PubSubType.SUBSCRIBE, codec, channelName, entry, null, listeners); } public CompletableFuture ssubscribe(Codec codec, ChannelName channelName, RedisPubSubListener... listeners) { @@ -207,11 +213,11 @@ public class PublishSubscribeService { promise.completeExceptionally(ex); return promise; } - return subscribe(PubSubType.SSUBSCRIBE, codec, channelName, entry, listeners); + return subscribe(PubSubType.SSUBSCRIBE, codec, channelName, entry, null, listeners); } private CompletableFuture subscribe(PubSubType type, Codec codec, ChannelName channelName, - MasterSlaveEntry entry, RedisPubSubListener... listeners) { + MasterSlaveEntry entry, ClientConnectionsEntry clientEntry, RedisPubSubListener... listeners) { CompletableFuture promise = new CompletableFuture<>(); AsyncSemaphore lock = getSemaphore(channelName); int timeout = config.getTimeout() + config.getRetryInterval() * config.getRetryAttempts(); @@ -226,7 +232,7 @@ public class PublishSubscribeService { return; } - subscribeNoTimeout(codec, channelName, entry, promise, type, lock, new AtomicInteger(), listeners); + subscribeNoTimeout(codec, channelName, entry, clientEntry, promise, type, lock, new AtomicInteger(), listeners); timeout(promise); }); return promise; @@ -242,8 +248,8 @@ public class PublishSubscribeService { return promise; } - subscribeNoTimeout(codec, new ChannelName(channelName), entry, promise, - PubSubType.SUBSCRIBE, semaphore, new AtomicInteger(), listeners); + subscribeNoTimeout(codec, new ChannelName(channelName), entry, null, promise, + PubSubType.SUBSCRIBE, semaphore, new AtomicInteger(), listeners); return promise; } @@ -300,12 +306,12 @@ public class PublishSubscribeService { return; } - subscribeNoTimeout(codec, channelName, entry, promise, type, lock, attempts, listeners); + subscribeNoTimeout(codec, channelName, entry, null, promise, type, lock, attempts, listeners); } private void subscribeNoTimeout(Codec codec, ChannelName channelName, MasterSlaveEntry entry, - CompletableFuture promise, PubSubType type, - AsyncSemaphore lock, AtomicInteger attempts, RedisPubSubListener... listeners) { + ClientConnectionsEntry clientEntry, CompletableFuture promise, + PubSubType type, AsyncSemaphore lock, AtomicInteger attempts, RedisPubSubListener... listeners) { PubSubConnectionEntry connEntry = name2PubSubConnection.get(new PubSubKey(channelName, entry)); if (connEntry != null) { addListeners(channelName, promise, type, lock, connEntry, listeners); @@ -325,7 +331,8 @@ public class PublishSubscribeService { if (freeEntry == null) { freePubSubLock.release(); - CompletableFuture connectFuture = connect(codec, channelName, entry, promise, type, lock, listeners); + CompletableFuture connectFuture = connect(codec, channelName, entry, + clientEntry, promise, type, lock, listeners); connectionManager.getServiceManager().newTimeout(t -> { if (!connectFuture.cancel(false) && !connectFuture.isCompletedExceptionally()) { @@ -411,10 +418,11 @@ public class PublishSubscribeService { } private CompletableFuture connect(Codec codec, ChannelName channelName, - MasterSlaveEntry msEntry, CompletableFuture promise, + MasterSlaveEntry msEntry, ClientConnectionsEntry clientEntry, + CompletableFuture promise, PubSubType type, AsyncSemaphore lock, RedisPubSubListener... listeners) { - CompletableFuture connFuture = msEntry.nextPubSubConnection(); + CompletableFuture connFuture = msEntry.nextPubSubConnection(clientEntry); promise.whenComplete((res, e) -> { if (e != null) { connFuture.completeExceptionally(e); @@ -703,7 +711,7 @@ public class PublishSubscribeService { } CompletableFuture subscribeFuture = - subscribe(PubSubType.PSUBSCRIBE, subscribeCodec, channelName, entry, listeners.toArray(new RedisPubSubListener[0])); + subscribe(PubSubType.PSUBSCRIBE, subscribeCodec, channelName, entry, null, listeners.toArray(new RedisPubSubListener[0])); subscribeFuture.whenComplete((res, e) -> { if (e != null) { connectionManager.getServiceManager().newTimeout(task -> {