diff --git a/redisson/src/main/java/org/redisson/connection/balancer/BaseLoadBalancer.java b/redisson/src/main/java/org/redisson/connection/balancer/BaseLoadBalancer.java index 8b793f9ba..4d03b6f9a 100644 --- a/redisson/src/main/java/org/redisson/connection/balancer/BaseLoadBalancer.java +++ b/redisson/src/main/java/org/redisson/connection/balancer/BaseLoadBalancer.java @@ -40,6 +40,10 @@ public abstract class BaseLoadBalancer implements LoadBalancer { } protected List filter(List entries) { + return filter(entries, pattern); + } + + protected final List filter(List entries, Pattern pattern) { if (pattern == null) { return entries; } diff --git a/redisson/src/main/java/org/redisson/connection/balancer/CommandsLoadBalancer.java b/redisson/src/main/java/org/redisson/connection/balancer/CommandsLoadBalancer.java index 71aa0427d..5a65c1dd9 100644 --- a/redisson/src/main/java/org/redisson/connection/balancer/CommandsLoadBalancer.java +++ b/redisson/src/main/java/org/redisson/connection/balancer/CommandsLoadBalancer.java @@ -18,10 +18,11 @@ package org.redisson.connection.balancer; import org.redisson.client.protocol.RedisCommand; import org.redisson.connection.ClientConnectionsEntry; import org.redisson.misc.RedisURI; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; -import java.util.List; -import java.util.Locale; -import java.util.Set; +import java.util.*; +import java.util.regex.Pattern; import java.util.stream.Collectors; /** @@ -32,12 +33,19 @@ import java.util.stream.Collectors; */ public class CommandsLoadBalancer extends RoundRobinLoadBalancer implements LoadBalancer { + private static final Logger log = LoggerFactory.getLogger(CommandsLoadBalancer.class); + + private final Map> commandsMap = new HashMap<>(); + private Set commands; private RedisURI address; @Override public ClientConnectionsEntry getEntry(List clientsCopy, RedisCommand redisCommand) { - if (commands.contains(redisCommand.getName().toLowerCase(Locale.ENGLISH))) { + String name = redisCommand.getName().toLowerCase(Locale.ENGLISH); + + if (commands != null + && commands.contains(name)) { return clientsCopy.stream() .filter(c -> address.equals(c.getClient().getAddr())) .findAny() @@ -45,6 +53,16 @@ public class CommandsLoadBalancer extends RoundRobinLoadBalancer implements Load return getEntry(clientsCopy); }); } + + for (Map.Entry> e : commandsMap.entrySet()) { + if (e.getValue().contains(name)) { + List s = filter(clientsCopy, e.getKey()); + if (!s.isEmpty()) { + return getEntry(s); + } + } + } + return getEntry(clientsCopy); } @@ -53,7 +71,9 @@ public class CommandsLoadBalancer extends RoundRobinLoadBalancer implements Load * * @param address Redis node address */ + @Deprecated public void setAddress(String address) { + log.warn("address setting is deprecated. Use commandsMap setting instead."); this.address = new RedisURI(address); } @@ -63,9 +83,36 @@ public class CommandsLoadBalancer extends RoundRobinLoadBalancer implements Load * * @param commands commands list */ + @Deprecated public void setCommands(List commands) { + log.warn("commands setting is deprecated. Use commandsMap setting instead."); this.commands = commands.stream() .map(c -> c.toLowerCase(Locale.ENGLISH)) .collect(Collectors.toSet()); } + + /** + * Defines command names mapped per host name regular expression. + *

+ * YAML definition example: + *

+     *      loadBalancer: !<org.redisson.connection.balancer.CommandsLoadBalancer>
+     *       commandsMap:
+     *           "slavehost1.*" : ["get", "hget"]
+     *           "slavehost2.*" : ["mget", "publish"]
+     * 
+ * + * @param value a map where the key is a host name regular expression, + * and the value is an array of command names + * that should be executed. + */ + public void setCommandsMap(Map> value) { + for (Map.Entry> e : value.entrySet()) { + Set cc = e.getValue().stream() + .map(c -> c.toLowerCase(Locale.ENGLISH)) + .collect(Collectors.toSet()); + this.commandsMap.put(Pattern.compile(e.getKey()), cc); + } + } + }