diff --git a/src/main/java/org/redisson/RedissonCountDownLatch.java b/src/main/java/org/redisson/RedissonCountDownLatch.java index bfbdb94c0..58d41af69 100644 --- a/src/main/java/org/redisson/RedissonCountDownLatch.java +++ b/src/main/java/org/redisson/RedissonCountDownLatch.java @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.redisson.core.RCountDownLatch; +import org.redisson.misc.internal.ThreadLocalSemaphore; import com.lambdaworks.redis.RedisConnection; import com.lambdaworks.redis.pubsub.RedisPubSubAdapter; @@ -39,7 +40,7 @@ public class RedissonCountDownLatch implements RCountDownLatch { private final AtomicBoolean subscribeOnce = new AtomicBoolean(); - private final Semaphore msg = new Semaphore(1); + private final ThreadLocalSemaphore msg = new ThreadLocalSemaphore(); RedissonCountDownLatch(RedisPubSubConnection pubSubConnection, RedisConnection connection, String name) { this.connection = connection; @@ -60,7 +61,9 @@ public class RedissonCountDownLatch implements RCountDownLatch { @Override public void message(Object channel, Object message) { if (message.equals(unlockMessage)) { - msg.release(); + for (Semaphore s : msg.getAll()) { + s.release(); + } } } @@ -79,8 +82,9 @@ public class RedissonCountDownLatch implements RCountDownLatch { public void await() throws InterruptedException { while (getCount() > 0) { // waiting for message - msg.acquire(); + msg.get().acquire(); } + msg.remove(); } @@ -89,14 +93,17 @@ public class RedissonCountDownLatch implements RCountDownLatch { time = unit.toMillis(time); while (getCount() > 0) { if (time <= 0) { + msg.remove(); return false; } long current = System.currentTimeMillis(); // waiting for message - msg.tryAcquire(time, TimeUnit.MILLISECONDS); + msg.get().tryAcquire(time, TimeUnit.MILLISECONDS); long elapsed = System.currentTimeMillis() - current; time -= elapsed; } + + msg.remove(); return true; } @@ -104,8 +111,8 @@ public class RedissonCountDownLatch implements RCountDownLatch { public void countDown() { Long val = connection.decr(name); if (val == 0) { - connection.del(name); connection.publish(getChannelName(), unlockMessage); + connection.del(name); } } @@ -114,16 +121,16 @@ public class RedissonCountDownLatch implements RCountDownLatch { } @Override - public int getCount() { - Integer val = (Integer) connection.get(name); + public long getCount() { + Number val = (Number) connection.get(name); if (val == null) { return 0; } - return val; + return val.longValue(); } @Override - public boolean trySetCount(int count) { + public boolean trySetCount(long count) { return connection.setnx(name, count); } diff --git a/src/main/java/org/redisson/core/RCountDownLatch.java b/src/main/java/org/redisson/core/RCountDownLatch.java index c3805d748..3cc163f02 100644 --- a/src/main/java/org/redisson/core/RCountDownLatch.java +++ b/src/main/java/org/redisson/core/RCountDownLatch.java @@ -25,8 +25,8 @@ public interface RCountDownLatch { void countDown(); - int getCount(); + long getCount(); - boolean trySetCount(int count); + boolean trySetCount(long count); } diff --git a/src/main/java/org/redisson/misc/internal/ThreadLocalSemaphore.java b/src/main/java/org/redisson/misc/internal/ThreadLocalSemaphore.java new file mode 100644 index 000000000..b29aae338 --- /dev/null +++ b/src/main/java/org/redisson/misc/internal/ThreadLocalSemaphore.java @@ -0,0 +1,37 @@ +package org.redisson.misc.internal; + +import java.util.Collection; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Semaphore; + +public class ThreadLocalSemaphore { + + private final ThreadLocal semaphore; + private final Set allValues = Collections.newSetFromMap(new ConcurrentHashMap()); + + public ThreadLocalSemaphore() { + semaphore = new ThreadLocal() { + @Override protected Semaphore initialValue() { + Semaphore value = new Semaphore(1); + allValues.add(value); + return value; + } + }; + } + + public Semaphore get() { + return semaphore.get(); + } + + public void remove() { + allValues.remove(get()); + semaphore.remove(); + } + + public Collection getAll() { + return allValues; + } + +} diff --git a/src/test/java/org/redisson/RedissonCountDownLatchConcurrentTest.java b/src/test/java/org/redisson/RedissonCountDownLatchConcurrentTest.java new file mode 100644 index 000000000..26d133d41 --- /dev/null +++ b/src/test/java/org/redisson/RedissonCountDownLatchConcurrentTest.java @@ -0,0 +1,56 @@ +package org.redisson; + +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import org.junit.Assert; +import org.junit.Test; +import org.redisson.core.RCountDownLatch; + +public class RedissonCountDownLatchConcurrentTest { + + @Test + public void testSingleCountDownAwait_SingleInstance() throws InterruptedException { + int iterations = Runtime.getRuntime().availableProcessors()*2; + + Redisson redisson = Redisson.create(); + final RCountDownLatch latch = redisson.getCountDownLatch("latch"); + latch.trySetCount(iterations); + + ScheduledExecutorService executor = Executors.newScheduledThreadPool(iterations); + for (int i = 0; i < iterations; i++) { + executor.execute(new Runnable() { + @Override + public void run() { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Assert.fail(); + } + latch.countDown(); + } + }); + } + + executor = Executors.newScheduledThreadPool(iterations); + for (int i = 0; i < iterations; i++) { + executor.execute(new Runnable() { + @Override + public void run() { + try { + latch.await(); + } catch (InterruptedException e) { + Assert.fail(); + } + } + }); + } + + executor.shutdown(); + executor.awaitTermination(10, TimeUnit.SECONDS); + + redisson.shutdown(); + } + +} diff --git a/src/test/java/org/redisson/RedissonCountDownLatchTest.java b/src/test/java/org/redisson/RedissonCountDownLatchTest.java new file mode 100644 index 000000000..03b6425a2 --- /dev/null +++ b/src/test/java/org/redisson/RedissonCountDownLatchTest.java @@ -0,0 +1,42 @@ +package org.redisson; + +import org.junit.Test; +import org.redisson.core.RCountDownLatch; + +public class RedissonCountDownLatchTest { + + @Test + public void testCountDown() throws InterruptedException { + Redisson redisson = Redisson.create(); + RCountDownLatch latch = redisson.getCountDownLatch("latch"); + latch.trySetCount(1); + latch.countDown(); + latch.await(); + latch.countDown(); + latch.await(); + latch.countDown(); + latch.await(); + + RCountDownLatch latch1 = redisson.getCountDownLatch("latch1"); + latch1.trySetCount(1); + latch1.countDown(); + latch1.countDown(); + latch1.await(); + + RCountDownLatch latch2 = redisson.getCountDownLatch("latch2"); + latch2.trySetCount(1); + latch2.countDown(); + latch2.await(); + latch2.await(); + + RCountDownLatch latch3 = redisson.getCountDownLatch("latch3"); + latch3.await(); + + RCountDownLatch latch4 = redisson.getCountDownLatch("latch4"); + latch4.countDown(); + latch4.await(); + + redisson.shutdown(); + } + +}