并发队列实现练习

代码:

package conSet;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * 并发单向队列简单实现
 * 
 * @author dingchd
 * 
 * @param <T>
 */
public class NoBlockQueue<T> {
	private Node<T> header;
	private AtomicReference<Node<T>> tail;

	private AtomicInteger size;

	public NoBlockQueue() {
		header = new Node<T>();
		tail = new AtomicReference<Node<T>>(header);
		size = new AtomicInteger(0);
	}

	/**
	 * 存元素的过程分两步骤:原子更新尾节点的next、原子更新尾节点 如果第二部更新失败 则原子还原尾节点的next
	 * 
	 * @return
	 */
	public void add(T t) {
		// 创建一个节点
		Node<T> node = new Node<T>();
		node.value = t;

		Node<T> curTail = null;
		for (;;) {
			curTail = tail.get();

			if (curTail.next.get() == null) {
				if (casNext(curTail, null, node)) {
					if (casTail(curTail, node)) {
						size.incrementAndGet();
						return;
					} else {
						curTail.next.getAndSet(null);
					}
				}
			}
		}
	}

	/**
	 * 取元素分两部:原子更新header的next、第一个元素为尾节点,则将尾节点原子更新到header 如果第二部失败,则原则还原第一步
	 * 
	 * @return
	 */
	public T poll() {
		Node<T> first = null;
		T value = null;
		for (;;) {
			first = header.next.get();
			Node<T> curTail = tail.get();

			// 队列空
			if (curTail == header && first == null) {
				break;
			}

			// 中间状态
			if ((first != null && curTail == header)
					|| (first == null && curTail != header)) {
				continue;
			}

			if (first != null) {
				// 如果tail指向第一个元素,则取队首后将tail更新至header
				if (curTail == first) {
					if (casHeaderNext(first, null)) {
						if (casTail(curTail, header)) {
							value = first.value;
							break;
						} else {
							header.next.getAndSet(first);
						}
					}
				} else {
					Node<T> second = first.next.get();

					// 如果second为null,则说明当前获得的first已经被其他线程取走
					if (second != null) {
						if (casHeaderNext(first, second)) {
							value = first.value;
							break;
						}
					}
				}
			}
		}

		if (value != null) {
			size.decrementAndGet();
		}

		return value;
	}

	public boolean isEmpty() {
		return tail.get().value == null;
	}

	public T top() {
		Node<T> first = header.next.get();
		return first == null ? null : first.value;
	}

	public int size() {
		return size.get();
	}

	private final boolean casHeaderNext(Node<T> before, Node<T> after) {
		return header.next.compareAndSet(before, after);
	}

	private final boolean casTail(Node<T> before, Node<T> after) {
		return tail.compareAndSet(before, after);
	}

	private final boolean casNext(Node<T> node, Node<T> before, Node<T> after) {
		return node.next.compareAndSet(before, after);
	}

	static class Node<T> {
		T value;
		AtomicReference<Node<T>> next = new AtomicReference<Node<T>>();
	}
}

 
测试代码:

package conSet;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.ConcurrentLinkedQueue;

public class NoBlockQueueTest2 {
	public static int SIZE = 10000;
	public static int C_NUM = 10;

	/**
	 * @param args
	 */
	public static void main(String[] args) {
		for (int i = 0; i < 10000; i++) {
			test();
		}
	}

	public static void test() {
		NoBlockQueue<String> queue = new NoBlockQueue<String>();

		Queue<String> input = new ConcurrentLinkedQueue<String>();
		Queue<String> output = new ConcurrentLinkedQueue<String>();

		for (int i = 0; i < C_NUM; i++) {
			Runnable mp = new MP(queue, input);
			new Thread(mp).start();
		}

		List<Thread> list = new ArrayList<Thread>();
		for (int i = 0; i < C_NUM; i++) {
			Runnable mc = new MC(queue, output);
			Thread t = new Thread(mc);
			t.start();
			list.add(t);
		}

		for (Thread t : list) {
			try {
				t.join();
			} catch (InterruptedException e) {
				e.printStackTrace();
			}
		}

		ArrayList<String> sort1 = new ArrayList<String>();
		ArrayList<String> sort2 = new ArrayList<String>();
		while (!input.isEmpty()) {
			sort1.add(input.poll());
		}
		while (!output.isEmpty()) {
			sort2.add(output.poll());
		}

		Collections.sort(sort1);
		Collections.sort(sort2);

		if (sort1.size() != sort2.size()) {
			throw new RuntimeException("test error,size not equal");
		}

		for (int i = 0; i < sort1.size(); i++) {
			String left = sort1.get(i);
			String right = sort2.get(i);
			if (!left.equals(right)) {
				throw new RuntimeException("test error,data wrong");
			}
		}

		System.out.println("test ok size=" + queue.size());
	}

	static class MP implements Runnable {
		NoBlockQueue<String> queue;
		Queue<String> input;

		public MP(NoBlockQueue<String> queue, Queue<String> input) {
			super();
			this.queue = queue;
			this.input = input;
		}

		public void run() {
			for (int i = 0; i < NoBlockQueueTest2.SIZE; i++) {
				String s = UUID.randomUUID().toString();
				input.add(s);
				queue.add(s);
			}
		}
	}

	static class MC implements Runnable {
		NoBlockQueue<String> queue;
		Queue<String> output;

		public MC(NoBlockQueue<String> queue, Queue<String> output) {
			super();
			this.queue = queue;
			this.output = output;
		}

		public void run() {
			final int count = NoBlockQueueTest2.C_NUM * NoBlockQueueTest2.SIZE;
			for (;;) {
				String s = queue.poll();
				if (s != null) {
					output.add(s);
				} else {
					if (output.size() == count) {
						break;
					}
				}
			}
		}
	}
}

 
因为没有实现remove和itr功能,因此复杂度甚微,经过10000次的不断测试,尚未发现测试失败

 

你可能感兴趣的:(java,并发集合)