package com.ysl.zkclient.queue; import com.ysl.zkclient.ZKClient;
import com.ysl.zkclient.exception.ZKNoNodeException;
import com.ysl.zkclient.utils.ExceptionUtil;
import org.apache.zookeeper.CreateMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import java.io.Serializable;
import java.util.List; /**
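* A distributed FIFO queue whose elements are stored as persistent sequential
* ZooKeeper nodes under a common root path.
* @param <T> element type; must be {@link Serializable}
*/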
public class ZKDistributedQueue<T extends Serializable> {

    private static final Logger LOG = LoggerFactory.getLogger(ZKDistributedQueue.class);

    /** Name prefix of every element node; the sequential create mode appends a counter to it. */
    private static final String ELEMENT_NAME = "node";

    /** Client used for every ZooKeeper operation of this queue. */
    private final ZKClient client;

    /** Parent path under which all element nodes are created. */
    private final String rootPath;

    /**
     * Creates a handle on the distributed queue rooted at {@code rootPath}.
     *
     * @param client   ZooKeeper client
     * @param rootPath root path of the queue; it must already exist
     * @throws ZKNoNodeException if {@code rootPath} does not exist
     */
    public ZKDistributedQueue(ZKClient client, String rootPath) {
        this.client = client;
        this.rootPath = rootPath;
        if (!client.exists(rootPath)) {
            throw new ZKNoNodeException("the root path is not exists, please create path first [" + rootPath + "]");
        }
    }

    /**
     * Appends an element to the tail of the queue by creating a persistent sequential node
     * {@code rootPath/node-<counter>} holding {@code node}.
     *
     * @param node element to enqueue
     * @return always {@code true}; failures are rethrown as unchecked exceptions
     */
    public boolean offer(T node) {
        try {
            client.create(rootPath + "/" + ELEMENT_NAME + "-", node, CreateMode.PERSISTENT_SEQUENTIAL);
        } catch (Exception e) {
            throw ExceptionUtil.convertToRuntimeException(e);
        }
        return true;
    }

    /**
     * Removes and returns the head (oldest element) of the queue.
     *
     * <p>If another consumer removes the current head first, this method retries with the new head.
     *
     * @return the head element, or {@code null} if the queue is empty
     */
    public T poll() {
        while (true) {
            Node<T> head = getFirstNode();
            if (head == null) {
                return null;
            }
            try {
                if (client.delete(head.getName())) {
                    return head.getData();
                }
                // delete returned false: the node was presumably taken by another consumer; retry.
            } catch (ZKNoNodeException e) {
                // The node vanished between reading and deleting it, i.e. another consumer took it;
                // retry with the new head (same handling as in getFirstNode()).
            } catch (Exception e) {
                throw ExceptionUtil.convertToRuntimeException(e);
            }
        }
    }

    /**
     * Misspelled alias of {@link #poll()}, kept so existing callers keep compiling.
     *
     * @return the head element, or {@code null} if the queue is empty
     * @deprecated use {@link #poll()} instead
     */
    @Deprecated
    public T pool() {
        return poll();
    }

    /**
     * Reads the current head of the queue without removing it.
     *
     * @return the head node (full path plus the data read from it), or {@code null} if the queue is empty
     */
    @SuppressWarnings("unchecked") // offer() only ever stores T values under rootPath
    private Node<T> getFirstNode() {
        try {
            while (true) {
                // NOTE(review): meaning of the boolean argument (watch?) is not visible here.
                List<String> children = client.getChild(rootPath, true);
                if (children == null || children.isEmpty()) {
                    return null;
                }
                String path = rootPath + "/" + getNodeName(children);
                try {
                    return new Node<T>(path, (T) client.getData(path));
                } catch (ZKNoNodeException e) {
                    // The node was removed by another consumer after we listed it; list again.
                }
            }
        } catch (Exception e) {
            throw ExceptionUtil.convertToRuntimeException(e);
        }
    }

    /**
     * Returns the lexicographically smallest child name.
     *
     * <p>ZooKeeper appends a fixed-width, zero-padded counter to sequential node names, so among
     * children sharing the {@code node-} prefix the smallest name is the oldest element.
     * Assumes nothing else creates children under {@code rootPath}.
     *
     * @param children non-empty list of child names
     * @return the smallest child name
     */
    private String getNodeName(List<String> children) {
        String smallest = children.get(0);
        for (String candidate : children) {
            if (candidate.compareTo(smallest) < 0) {
                smallest = candidate;
            }
        }
        return smallest;
    }

    /**
     * Reports whether the queue currently has no elements.
     *
     * @return {@code true} if {@code rootPath} has no children
     */
    public boolean isEmpty() {
        List<String> children = client.getChild(rootPath, true);
        // Same null handling as getFirstNode(), avoiding an NPE on a null child list.
        return children == null || children.isEmpty();
    }

    /**
     * Returns the head of the queue without removing it.
     *
     * @return the head element, or {@code null} if the queue is empty
     */
    public T peek() {
        Node<T> head = getFirstNode();
        return head == null ? null : head.getData();
    }

    /**
     * Immutable pair of an element node's full path and the data read from it.
     * Static so it holds no reference to the enclosing queue; its own type parameter avoids shadowing {@code T}.
     */
    private static final class Node<E> {
        private final String name;
        private final E data;

        Node(String name, E data) {
            this.name = name;
            this.data = data;
        }

        /** @return full ZooKeeper path of the element node */
        String getName() {
            return name;
        }

        /** @return data stored in the element node */
        E getData() {
            return data;
        }
    }
}
// Test code (测试代码)
/**
 * Tests the distributed queue with 21 concurrent producers followed by 20 concurrent consumers.
 * Verifies that the 20 oldest elements are each delivered exactly once. It then checks that the
 * remaining element is the newest one and that peek/poll agree on it.
 *
 * @throws Exception if creating the root path or waiting for the worker threads fails
 */
@Test
public void testDistributedQueue() throws Exception {
    final String rootPath = "/zk/queue";
    // The queue constructor requires the root path to exist.
    zkClient.createRecursive(rootPath, null, CreateMode.PERSISTENT);

    // Elements in enqueue order. Producers hold offerLock across offer() + add(), so this list's
    // order matches the order in which the sequential nodes were created.
    final List<String> offered = new ArrayList<String>();
    final Object offerLock = new Object();
    // Elements returned by the consumers; their arrival order here is arbitrary.
    final List<String> polled = new ArrayList<String>();

    for (int i = 0; i < 21; i++) {
        new Thread(new Runnable() {
            @Override
            public void run() {
                ZKDistributedQueue<String> queue = new ZKDistributedQueue<String>(zkClient, rootPath);
                String name = Thread.currentThread().getName();
                synchronized (offerLock) {
                    queue.offer(name);
                    offered.add(name);
                }
            }
        }).start();
    }
    // Wait for all producers.
    int size1 = TestUtil.waitUntil(21, new Callable<Integer>() {
        @Override
        public Integer call() throws Exception {
            synchronized (offerLock) {
                return offered.size();
            }
        }
    }, TimeUnit.SECONDS, 100);
    assertThat(size1).isEqualTo(21);
    System.out.println(zkClient.getChildren(rootPath));

    for (int i = 0; i < 20; i++) {
        new Thread(new Runnable() {
            @Override
            public void run() {
                ZKDistributedQueue<String> queue = new ZKDistributedQueue<String>(zkClient, rootPath);
                String element = queue.poll();
                synchronized (polled) {
                    polled.add(element);
                }
            }
        }).start();
    }
    // Wait for all consumers.
    int size2 = TestUtil.waitUntil(20, new Callable<Integer>() {
        @Override
        public Integer call() throws Exception {
            synchronized (polled) {
                return polled.size();
            }
        }
    }, TimeUnit.SECONDS, 100);
    assertThat(size2).isEqualTo(20);

    List<String> oldest;
    String newest;
    synchronized (offerLock) {
        oldest = new ArrayList<String>(offered.subList(0, 20));
        newest = offered.get(20);
    }
    // FIFO with racing consumers: the 20 removed elements must be exactly the 20 oldest ones.
    // Thread names are unique, so containing all of them in a list of size 20 also rules out
    // duplicates and nulls.
    synchronized (polled) {
        assertThat(polled.containsAll(oldest)).isTrue();
    }

    // The single remaining element must be the newest; peek must not remove it, poll must.
    ZKDistributedQueue<String> queue = new ZKDistributedQueue<String>(zkClient, rootPath);
    String head = queue.peek();
    assertThat(head).isEqualTo(newest);
    assertThat(queue.poll()).isEqualTo(head);
    assertThat(queue.isEmpty()).isTrue();
}