RedissonCountDownLatch few bugs fixed, tests added

pull/6/head
Nikita 11 years ago
parent 207d89b34b
commit 6486428f91

@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.redisson.core.RCountDownLatch;
import org.redisson.misc.internal.ThreadLocalSemaphore;
import com.lambdaworks.redis.RedisConnection;
import com.lambdaworks.redis.pubsub.RedisPubSubAdapter;
@ -39,7 +40,7 @@ public class RedissonCountDownLatch implements RCountDownLatch {
private final AtomicBoolean subscribeOnce = new AtomicBoolean();
private final Semaphore msg = new Semaphore(1);
private final ThreadLocalSemaphore msg = new ThreadLocalSemaphore();
RedissonCountDownLatch(RedisPubSubConnection<Object, Object> pubSubConnection, RedisConnection<Object, Object> connection, String name) {
this.connection = connection;
@ -60,7 +61,9 @@ public class RedissonCountDownLatch implements RCountDownLatch {
@Override
public void message(Object channel, Object message) {
if (message.equals(unlockMessage)) {
msg.release();
for (Semaphore s : msg.getAll()) {
s.release();
}
}
}
@ -79,8 +82,9 @@ public class RedissonCountDownLatch implements RCountDownLatch {
public void await() throws InterruptedException {
while (getCount() > 0) {
// waiting for message
msg.acquire();
msg.get().acquire();
}
msg.remove();
}
@ -89,14 +93,17 @@ public class RedissonCountDownLatch implements RCountDownLatch {
time = unit.toMillis(time);
while (getCount() > 0) {
if (time <= 0) {
msg.remove();
return false;
}
long current = System.currentTimeMillis();
// waiting for message
msg.tryAcquire(time, TimeUnit.MILLISECONDS);
msg.get().tryAcquire(time, TimeUnit.MILLISECONDS);
long elapsed = System.currentTimeMillis() - current;
time -= elapsed;
}
msg.remove();
return true;
}
@ -104,8 +111,8 @@ public class RedissonCountDownLatch implements RCountDownLatch {
public void countDown() {
Long val = connection.decr(name);
if (val == 0) {
connection.del(name);
connection.publish(getChannelName(), unlockMessage);
connection.del(name);
}
}
@ -114,16 +121,16 @@ public class RedissonCountDownLatch implements RCountDownLatch {
}
@Override
public int getCount() {
Integer val = (Integer) connection.get(name);
public long getCount() {
Number val = (Number) connection.get(name);
if (val == null) {
return 0;
}
return val;
return val.longValue();
}
@Override
public boolean trySetCount(int count) {
public boolean trySetCount(long count) {
return connection.setnx(name, count);
}

@ -25,8 +25,8 @@ public interface RCountDownLatch {
void countDown();
int getCount();
long getCount();
boolean trySetCount(int count);
boolean trySetCount(long count);
}

@ -0,0 +1,37 @@
package org.redisson.misc.internal;
import java.util.Collection;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Semaphore;
public class ThreadLocalSemaphore {
private final ThreadLocal<Semaphore> semaphore;
private final Set<Semaphore> allValues = Collections.newSetFromMap(new ConcurrentHashMap<Semaphore, Boolean>());
public ThreadLocalSemaphore() {
semaphore = new ThreadLocal<Semaphore>() {
@Override protected Semaphore initialValue() {
Semaphore value = new Semaphore(1);
allValues.add(value);
return value;
}
};
}
public Semaphore get() {
return semaphore.get();
}
public void remove() {
allValues.remove(get());
semaphore.remove();
}
public Collection<Semaphore> getAll() {
return allValues;
}
}

@ -0,0 +1,56 @@
package org.redisson;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import org.junit.Assert;
import org.junit.Test;
import org.redisson.core.RCountDownLatch;
public class RedissonCountDownLatchConcurrentTest {
@Test
public void testSingleCountDownAwait_SingleInstance() throws InterruptedException {
int iterations = Runtime.getRuntime().availableProcessors()*2;
Redisson redisson = Redisson.create();
final RCountDownLatch latch = redisson.getCountDownLatch("latch");
latch.trySetCount(iterations);
ScheduledExecutorService executor = Executors.newScheduledThreadPool(iterations);
for (int i = 0; i < iterations; i++) {
executor.execute(new Runnable() {
@Override
public void run() {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
Assert.fail();
}
latch.countDown();
}
});
}
executor = Executors.newScheduledThreadPool(iterations);
for (int i = 0; i < iterations; i++) {
executor.execute(new Runnable() {
@Override
public void run() {
try {
latch.await();
} catch (InterruptedException e) {
Assert.fail();
}
}
});
}
executor.shutdown();
executor.awaitTermination(10, TimeUnit.SECONDS);
redisson.shutdown();
}
}

@ -0,0 +1,42 @@
package org.redisson;
import org.junit.Test;
import org.redisson.core.RCountDownLatch;
public class RedissonCountDownLatchTest {
@Test
public void testCountDown() throws InterruptedException {
Redisson redisson = Redisson.create();
RCountDownLatch latch = redisson.getCountDownLatch("latch");
latch.trySetCount(1);
latch.countDown();
latch.await();
latch.countDown();
latch.await();
latch.countDown();
latch.await();
RCountDownLatch latch1 = redisson.getCountDownLatch("latch1");
latch1.trySetCount(1);
latch1.countDown();
latch1.countDown();
latch1.await();
RCountDownLatch latch2 = redisson.getCountDownLatch("latch2");
latch2.trySetCount(1);
latch2.countDown();
latch2.await();
latch2.await();
RCountDownLatch latch3 = redisson.getCountDownLatch("latch3");
latch3.await();
RCountDownLatch latch4 = redisson.getCountDownLatch("latch4");
latch4.countDown();
latch4.await();
redisson.shutdown();
}
}
Loading…
Cancel
Save