diff --git a/redisson/src/main/java/org/redisson/RedissonRateLimiter.java b/redisson/src/main/java/org/redisson/RedissonRateLimiter.java index b7953560c..33a6246ae 100644 --- a/redisson/src/main/java/org/redisson/RedissonRateLimiter.java +++ b/redisson/src/main/java/org/redisson/RedissonRateLimiter.java @@ -253,12 +253,18 @@ public class RedissonRateLimiter extends RedissonExpirable implements RRateLimit @Override public RFuture setRateAsync(RateType type, long rate, long rateInterval, RateIntervalUnit unit) { - return commandExecutor.evalWriteAsync(getRawName(), LongCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN, - "redis.call('hset', KEYS[1], 'rate', ARGV[1]);" + return commandExecutor.evalWriteAsync(getRawName(), LongCodec.INSTANCE, RedisCommands.EVAL_BOOLEAN, + "local valueName = KEYS[2];" + + "local permitsName = KEYS[4];" + + "if ARGV[3] == '1' then " + + " valueName = KEYS[3];" + + " permitsName = KEYS[5];" + + "end " + +"redis.call('hset', KEYS[1], 'rate', ARGV[1]);" + "redis.call('hset', KEYS[1], 'interval', ARGV[2]);" + "redis.call('hset', KEYS[1], 'type', ARGV[3]);" - + "redis.call('del', KEYS[2], KEYS[3]);", - Arrays.asList(getRawName(), getValueName(), getPermitsName()), rate, unit.toMillis(rateInterval), type.ordinal()); + + "redis.call('del', valueName, permitsName);", + Arrays.asList(getRawName(), getValueName(), getClientValueName(), getPermitsName(), getClientPermitsName()), rate, unit.toMillis(rateInterval), type.ordinal()); } private static final RedisCommand HGETALL = new RedisCommand("HGETALL", new MapEntriesDecoder(new MultiDecoder() { diff --git a/redisson/src/test/java/org/redisson/RedissonRateLimiterTest.java b/redisson/src/test/java/org/redisson/RedissonRateLimiterTest.java index e2b8431cf..bd0913bd0 100644 --- a/redisson/src/test/java/org/redisson/RedissonRateLimiterTest.java +++ b/redisson/src/test/java/org/redisson/RedissonRateLimiterTest.java @@ -2,10 +2,7 @@ package org.redisson; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.redisson.api.RRateLimiter; -import org.redisson.api.RScoredSortedSet; -import org.redisson.api.RateIntervalUnit; -import org.redisson.api.RateType; +import org.redisson.api.*; import java.time.Duration; import java.util.ArrayList; @@ -314,5 +311,45 @@ public class RedissonRateLimiterTest extends BaseTest { count++; } } + + @Test + public void testChangeRate() { + /* Test case -- PRE_CLIENT */ + RRateLimiter rr = redisson.getRateLimiter("test_change_rate"); + rr.setRate(RateType.PER_CLIENT, 10, 1, RateIntervalUnit.SECONDS); + assertThat(rr.getConfig().getRate()).isEqualTo(10); + //check value in Redis + rr.acquire(1); + String valueKey = redisson.getKeys().getKeysStream().filter(k -> k.contains("value:")).findAny().get(); + Long value = redisson.getAtomicLong(valueKey).get(); + assertThat(value).isEqualTo(9); + + //change to 20/s + rr.setRate(RateType.PER_CLIENT, 20, 1, RateIntervalUnit.SECONDS); + assertThat(rr.getConfig().getRate()).isEqualTo(20); + //check value in Redis + rr.acquire(1); + value = redisson.getAtomicLong(valueKey).get(); + assertThat(value).isEqualTo(19); + + /* Test case -- OVERALL */ + rr.setRate(RateType.OVERALL, 10, 1, RateIntervalUnit.SECONDS); + assertThat(rr.getConfig().getRate()).isEqualTo(10); + //check value in Redis + rr.acquire(1); + valueKey = redisson.getKeys().getKeysStream().filter(k -> k.endsWith("value")).findAny().get(); + value = redisson.getAtomicLong(valueKey).get(); + assertThat(value).isEqualTo(9); + + rr.setRate(RateType.OVERALL, 20, 1, RateIntervalUnit.SECONDS); + assertThat(rr.getConfig().getRate()).isEqualTo(20); + //check value in Redis + rr.acquire(1); + value = redisson.getAtomicLong(valueKey).get(); + assertThat(value).isEqualTo(19); + + //clean all keys in test + redisson.getKeys().deleteByPattern("*test_change_rate*"); + } }