Implementing Lock-Free Queues and Stacks in Java

Reference: "Implementing Lock-Free Queues".

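      This article is about lock-free queues, but a point from "Java Concurrency in Practice" is worth repeating first: in a concurrent environment, correctness comes first and performance second. Consider a lock-free design only when you have evidence that performance really needs to improve under current conditions and locking cannot deliver it. The difficulty of a lock-free implementation rises quickly as the requirements grow, and it is much harder to get right than a lock-based one.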

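     The foundation of lock-free programming is the CAS (Compare And Swap) operation, which is provided by the hardware. In Java, the atomic classes in java.util.concurrent.atomic wrap this mechanism.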

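     Take i++ as an example. It consists of three operations: read the value of i, compute i + 1, and write the new value back. If these three steps are not atomic as a whole, the result can be wrong. With CAS, it can be written as follows: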

        AtomicInteger i = new AtomicInteger(0);
        int oldValue;
        int newValue;
        do{
             oldValue = i.get();
             newValue = oldValue + 1;
        }while (!i.compareAndSet(oldValue,newValue));

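Each loop iteration reads the value, adds 1, and then compares: if the old value has not changed in the meantime, the update is applied; otherwise the next attempt begins. In concurrent code, shared state almost always becomes inconsistent because a write depends on a read, as in i++ above, where the new value of i depends on the old one. Shared data that cannot be modified, or updates that do not depend on the previous value (such as simply setting a value), do not have this problem. To make a read-dependent update correct, the read and the write must be atomic as a whole, for example by protecting them with a lock.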
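The paragraph above names a lock as one way to make the read and the write atomic together. A minimal sketch comparing that approach with the CAS loop shown earlier (the names `CounterDemo`, `LockedCounter`, `CasCounter` and `hammer` are hypothetical, chosen for illustration): both counters end at the exact total, because neither lets another thread slip in between the read and the write.

```java
import java.util.concurrent.atomic.AtomicInteger;

public class CounterDemo {
    // Lock-based: synchronized makes read + add + write one atomic unit.
    static final class LockedCounter {
        private int value;
        synchronized void increment() { value++; }
        synchronized int get() { return value; }
    }

    // Lock-free: compute from what was read, validate on write, retry on conflict.
    static final class CasCounter {
        private final AtomicInteger value = new AtomicInteger();
        void increment() {
            int old;
            do {
                old = value.get();
            } while (!value.compareAndSet(old, old + 1));
        }
        int get() { return value.get(); }
    }

    // Starts `threads` threads that each run `task` `perThread` times, then waits for all of them.
    static void hammer(Runnable task, int threads, int perThread) throws InterruptedException {
        Thread[] ts = new Thread[threads];
        for (int i = 0; i < threads; i++) {
            ts[i] = new Thread(() -> {
                for (int j = 0; j < perThread; j++) {
                    task.run();
                }
            });
            ts[i].start();
        }
        for (Thread t : ts) {
            t.join();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        LockedCounter locked = new LockedCounter();
        CasCounter cas = new CasCounter();
        hammer(locked::increment, 8, 100_000);
        hammer(cas::increment, 8, 100_000);
        System.out.println(locked.get() + " " + cas.get()); // 800000 800000
    }
}
```

The lock version blocks other threads while one is inside `increment`; the CAS version never blocks, but a thread may have to repeat its read and computation several times under contention.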

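A lock-free approach cannot prevent multiple threads from modifying a shared resource at the same time; instead, it validates at write time. The write is done with CAS, which compares the current value with the old value supplied by the thread. If the two are equal, no other thread has modified the resource since this thread last read it, so the value this thread computed is still valid, and it is stored.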
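The same "read, compute, validate on write" pattern works for any value held in an `AtomicReference`, not only integers. A small sketch, assuming a hypothetical helper named `update` (the JDK's own `AtomicReference.updateAndGet`, available since Java 8, follows the same idea):

```java
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.UnaryOperator;

public final class CasUpdate {
    // Applies f to the current value and installs the result only if no other
    // thread replaced the value in between. Returns the value that was installed.
    static <T> T update(AtomicReference<T> ref, UnaryOperator<T> f) {
        while (true) {
            T seen = ref.get();          // read
            T proposed = f.apply(seen);  // compute from what was read
            // validate on write: compareAndSet compares references (==), not equals()
            if (ref.compareAndSet(seen, proposed)) {
                return proposed;
            }
            // another thread won the race; the computation is stale, so start over
        }
    }

    public static void main(String[] args) {
        AtomicReference<String> ref = new AtomicReference<>("a");
        System.out.println(update(ref, s -> s + "b")); // ab
    }
}
```

Because `f` may run more than once under contention, it should be free of side effects.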

Figure 1

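For example, thread 1 wants to add 5 to a value and thread 2 wants to add 10, as shown in Figure 2. The value starts at 5 and both threads read it; thread 1 reads first and also writes first. Without any protection, the final result would be thread 2's update, 15, which is clearly not what we want. With CAS, thread 1 compares the current value with the old value it read; both are 5, so it updates the value to 10. When thread 2 compares, it finds the value is no longer the 5 it read, so its update fails. Thread 2 then goes back to the start and tries again, repeating until its update succeeds.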


Figure 2
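The interleaving in Figure 2 can be replayed step by step on a single thread with `AtomicInteger.compareAndSet`. This is only an illustration (the class name `Figure2Replay` is hypothetical); real threads would interleave nondeterministically, but the CAS outcomes at each step are the same:

```java
import java.util.concurrent.atomic.AtomicInteger;

public class Figure2Replay {
    // Replays the Figure 2 scenario and reports each CAS result plus the final value.
    static String replay() {
        AtomicInteger shared = new AtomicInteger(5);

        int seenBy1 = shared.get(); // thread 1 reads 5
        int seenBy2 = shared.get(); // thread 2 also reads 5

        // Thread 1 writes first: current value equals what it read, so CAS succeeds (5 -> 10).
        boolean ok1 = shared.compareAndSet(seenBy1, seenBy1 + 5);
        // Thread 2's CAS fails: the value is now 10, not the 5 it read.
        // A plain write here would have produced 15 and lost thread 1's update.
        boolean ok2 = shared.compareAndSet(seenBy2, seenBy2 + 10);

        // Thread 2 starts over: re-read, recompute, retry (10 -> 20).
        int seenAgain = shared.get();
        boolean ok2Retry = shared.compareAndSet(seenAgain, seenAgain + 10);

        return ok1 + " " + ok2 + " " + ok2Retry + " " + shared.get();
    }

    public static void main(String[] args) {
        System.out.println(replay()); // true false true 20
    }
}
```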

Below is the implementation of the lock-free queue.

package cn.yuanye.concurrence.LockFreeCollection;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * A lock-free queue based on a linked list.
 */
public class LockFreeQueue<V> {

    // the queue node
    private static class Node<V> {
        public V value = null;
        public AtomicReference<Node<V>> next = null;

        public Node(V value, Node<V> next) {
            this.value = value;
            this.next = new AtomicReference<Node<V>>(next);
        }
    }

    private AtomicReference<Node<V>> head = null; // queue head
    private AtomicReference<Node<V>> tail = null; // queue tail
    private AtomicInteger queueSize = new AtomicInteger(0); // size of the queue

    public LockFreeQueue() {
        Node<V> dummy = new Node<V>(null, null); // init a dummy node
        // init head and tail, both referencing the same dummy node
        head = new AtomicReference<Node<V>>(dummy);
        tail = new AtomicReference<Node<V>>(dummy);
    }

    /**
     * Add a value to the end of the queue.
     * <p>
     * This method is based on CAS operations and is thread safe.
     * When it returns, the value has been linked into the queue.
     *
     * @param value the value to be added to the queue
     */
    public void enQueue(V value) {
        Node<V> newNode = new Node<V>(value, null);
        Node<V> oldTail = null;
        while (true) {
            oldTail = tail.get();
            AtomicReference<Node<V>> nextNode = oldTail.next;
            if (nextNode.compareAndSet(null, newNode)) {
                break; // linked the new node after the last node
            } else {
                // tail is lagging behind: help move it forward, then retry
                tail.compareAndSet(oldTail, oldTail.next.get());
            }
        }
        queueSize.getAndIncrement();
        // swing tail to the new node; if this fails, another thread already moved it
        tail.compareAndSet(oldTail, newNode);
    }

    /**
     * Get a value from the queue.
     * <p>
     * This method is based on CAS operations and is thread safe.
     *
     * @return the value at the head of the queue, or null when the queue is empty
     */
    public V deQueue() {
        while (true) {
            Node<V> oldHead = head.get();
            Node<V> oldTail = tail.get();
            Node<V> next = oldHead.next.get();
            if (next == null) {
                return null; // queue is empty
            }
            if (oldHead == oldTail) {
                // tail is lagging behind: move it to the last node, then retry
                tail.compareAndSet(oldTail, next);
                continue;
            }
            if (head.compareAndSet(oldHead, next)) {
                queueSize.getAndDecrement();
                return next.value; // next now serves as the new dummy node
            }
        }
    }

    /**
     * Get the size of the queue.
     * <p>
     * The result may already be stale when used in a concurrent environment.
     *
     * @return size of the queue
     */
    public int size() {
        return queueSize.get();
    }

    /**
     * Check whether the queue is empty.
     * <p>
     * The result may already be stale when used in a concurrent environment.
     *
     * @return true if the queue is empty
     */
    public boolean isEmpty() {
        return queueSize.get() == 0;
    }
}

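When the queue is initialized, head and tail point to the same node. This node stores no data; it exists only so that dequeue and enqueue operations never act on the same node.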

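If the two operations did act on the same node, things would go wrong. Suppose thread A is dequeuing and thread B is enqueuing while the list holds only one node: both threads reference that same node. After thread A takes the node off the list, thread B does not know about it. B checks the node's next reference; since no other thread has enqueued anything, B successfully links its new node there and believes it has added the node to the queue. In reality, B has appended the new node to a node that is no longer in the queue.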
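The role of the dummy node can be seen in a single-threaded trace. The sketch below (class name `SentinelTrace` hypothetical) strips the retry loops out of the queue above and keeps only the pointer moves: enqueue touches only `tail` and the last node's `next`, dequeue touches only `head`, and the node that is dequeued becomes the new dummy. It is an illustration, not a thread-safe queue.

```java
import java.util.concurrent.atomic.AtomicReference;

public class SentinelTrace {
    static final class Node<V> {
        final V value;
        final AtomicReference<Node<V>> next = new AtomicReference<>(null);
        Node(V value) { this.value = value; }
    }

    final AtomicReference<Node<String>> head;
    final AtomicReference<Node<String>> tail;

    SentinelTrace() {
        Node<String> dummy = new Node<>(null); // stores no data
        head = new AtomicReference<>(dummy);
        tail = new AtomicReference<>(dummy);
    }

    // Single-threaded simplification of enQueue: only the last node's next and tail change.
    void enqueue(String v) {
        Node<String> node = new Node<>(v);
        Node<String> last = tail.get();
        last.next.compareAndSet(null, node);
        tail.compareAndSet(last, node);
    }

    // Single-threaded simplification of deQueue: only head changes.
    String dequeue() {
        Node<String> first = head.get();      // always the current dummy
        Node<String> next = first.next.get(); // first real element, if any
        if (next == null) {
            return null;                      // empty queue
        }
        head.compareAndSet(first, next);      // next becomes the new dummy
        return next.value;
    }

    public static void main(String[] args) {
        SentinelTrace q = new SentinelTrace();
        System.out.println(q.head.get() == q.tail.get()); // true: both at the dummy
        q.enqueue("a");
        System.out.println(q.head.get() == q.tail.get()); // false
        System.out.println(q.dequeue());                  // a
        System.out.println(q.head.get() == q.tail.get()); // true: the "a" node is now the dummy
        System.out.println(q.dequeue());                  // null
    }
}
```

Because the queue always holds at least the dummy, a dequeuer never removes the node that an enqueuer is about to link onto.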


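Below is the stack implementation. It is considerably simpler than the queue, because push and pop both operate on a single reference, head, so there is no tail to keep consistent.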

package cn.yuanye.concurrence.LockFreeCollection;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
 * A lock-free, thread-safe stack based on a linked list.
 */
public class LockFreeStack<V> {

    private AtomicReference<Node<V>> head = new AtomicReference<Node<V>>(null);
    private AtomicInteger size = new AtomicInteger(0);

    /**
     * Stack node, consisting of a value and a link to the next node.
     *
     * @param <V> type of the value
     */
    private static class Node<V> {
        /** the value of this node, not null */
        public V value;
        /** link to the next node; null if this is the last node */
        AtomicReference<Node<V>> next;

        public Node(V value, Node<V> next) {
            this.value = value;
            this.next = new AtomicReference<Node<V>>(next);
        }
    }

    /**
     * Pop a value from the stack.
     * <p>
     * This method is based on CAS operations and is thread safe.
     * When used concurrently, only one thread gets the top value at a time;
     * the other threads retry on the following values until the stack is empty.
     *
     * @return a non-null value if the stack is not empty, or null when the stack is empty
     */
    public V pop() {
        Node<V> oldHead = null;
        Node<V> next = null;
        do {
            oldHead = head.get();
            if (oldHead == null) {
                return null; // empty stack
            }
            next = oldHead.next.get();
        } while (!head.compareAndSet(oldHead, next));
        size.getAndDecrement();
        return oldHead.value;
    }

    /**
     * Push a value onto the stack.
     * <p>
     * This method is based on CAS operations and is thread safe.
     * When used concurrently, only one thread succeeds at a time;
     * the others retry until they succeed.
     *
     * @param value value to push onto the stack, not null
     * @throws NullPointerException if value is null
     */
    public void push(V value) {
        if (value == null) {
            throw new NullPointerException("value is null");
        }
        Node<V> newNode = new Node<V>(value, null);
        Node<V> oldHead;
        do {
            oldHead = head.get();
            newNode.next.set(oldHead);
        } while (!head.compareAndSet(oldHead, newNode));
        size.getAndIncrement();
    }

    /**
     * Get the size of the stack.
     * <p>
     * The result may already be stale when used in a concurrent environment.
     *
     * @return size of the stack
     */
    public int size() {
        return size.get();
    }

    /**
     * Check whether the stack is empty.
     * <p>
     * The result may already be stale when used in a concurrent environment.
     *
     * @return true if the stack is empty
     */
    public boolean isEmpty() {
        return head.get() == null;
    }
}
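One design note on the stack above: a node's `next` is written only before the node is published by the CAS on `head`, and never changes afterwards, so it does not strictly need to be an `AtomicReference`. A leaner variant is sketched below under that assumption (class name `LeanStack` hypothetical; the `size` counter is omitted). It uses a plain field: a successful `compareAndSet` has the memory effects of a volatile write, so a thread that later reads the node through `head.get()` also sees its `next`.

```java
import java.util.concurrent.atomic.AtomicReference;

public class LeanStack<V> {
    private static final class Node<V> {
        final V value;
        Node<V> next; // plain field: written only before the node is published
        Node(V value) { this.value = value; }
    }

    private final AtomicReference<Node<V>> head = new AtomicReference<>();

    public void push(V value) {
        if (value == null) {
            throw new NullPointerException("value is null");
        }
        Node<V> node = new Node<>(value);
        Node<V> oldHead;
        do {
            oldHead = head.get();
            node.next = oldHead; // safe: node is not yet visible to other threads
        } while (!head.compareAndSet(oldHead, node));
    }

    public V pop() {
        Node<V> oldHead;
        do {
            oldHead = head.get();
            if (oldHead == null) {
                return null; // empty stack
            }
        } while (!head.compareAndSet(oldHead, oldHead.next));
        return oldHead.value;
    }

    public boolean isEmpty() {
        return head.get() == null;
    }
}
```

Pushing 1, 2, 3 and then popping three times yields 3, 2, 1, and a fourth pop returns null.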

Here is some rough test code.

Tests for the queue.

package test.cn.yuanye.concurrence.LockFreeCollection;

import cn.yuanye.concurrence.LockFreeCollection.LockFreeQueue;
import org.junit.Test;

import static org.junit.Assert.*;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Created by Administrator on 14-1-12.
 */
public class LockFreeQueueTest {

    static class Producer implements Runnable {
        private static int count = 0;
        private int id = ++count;
        private CountDownLatch start ;
        private CountDownLatch end;
        private int  n = 0;
        private LockFreeQueue<String> queue;

        public Producer(CountDownLatch start, CountDownLatch end, int n, LockFreeQueue<String> queue){
            this.start = start;
            this.end = end;
            this.n = n;
            this.queue = queue;
        }


        @Override
        public void run() {
            try {
                start.await();
            } catch (InterruptedException e) {
                return;
            }
            for (int i = 0; i < n; i++) {
                queue.enQueue(id +" : " + i);
            }
            end.countDown();
        }
    }

    static class Consumer implements Runnable {

        private CountDownLatch start ;
        private CountDownLatch end;
        private AtomicInteger count;
        private LockFreeQueue<String> queue;

        public Consumer(CountDownLatch start, CountDownLatch end, AtomicInteger count, LockFreeQueue<String> queue){
            this.start = start;
            this.end = end;
            this.count = count;
            this.queue = queue;
        }
        @Override
        public void run() {
            try {
                start.await();
            } catch (InterruptedException e) {
                return;
            }

            while (queue.deQueue() != null) {
                count.getAndIncrement();
            }
            end.countDown();
        }
    }

    @Test
    public void deQueuetest() throws InterruptedException {
        final int testTimes = 1000;
        final int nThread = 100;
        final int nProduct = 5000;
        CountDownLatch start;
        CountDownLatch end;
        LockFreeQueue<String> queue = new LockFreeQueue<String>();
        AtomicInteger count = new AtomicInteger(0);

        for (int t = 0; t < testTimes;  t++) {
            //init the product
            for(int i = 0 ; i < nProduct ; i ++ ){
                queue.enQueue(i +"");
            }
            count.set(0);
            start = new CountDownLatch(1);
            end = new CountDownLatch(nThread);

            for(int i = 0; i < nThread ;i++){
                new Thread(new Consumer(start,end,count,queue)).start();
            }
            start.countDown();
            end.await();

            if(nProduct != count.get()){
                fail("should be " + nProduct + " actual is " + count.get());
            }
        }
    }

    @Test
    public void enQueuetest() {
        final int testTimes = 1000;
        final int nThread = 10;
        final int nP = 500;
        CountDownLatch start;
        CountDownLatch end;
        LockFreeQueue<String> queue = new LockFreeQueue<String>();

        for (int t = 0; t < testTimes;  t++) {

            while (queue.deQueue() != null) {} // clear the queue
            start = new CountDownLatch(1);
            end = new CountDownLatch(nThread);

            for (int i = 0; i < nThread; i++) {
                new Thread(new Producer(start,end,nP,queue)).start();
            }
            start.countDown();
            try {
                end.await();
            } catch (InterruptedException e) {
                return;
            }

            if (queue.size() != nThread * nP) {
                fail("times " + t + " should be " + nThread * nP + " but actual is  " + queue.size());
            }
        }

    }

    @Test
    public void integratedtest(){
        final int nProducer = 20;
        final int nCosumer = 20;
        final int testTimes = 1000;
        final int productsPerProducer = 500;
        final AtomicInteger count = new AtomicInteger(0);
        LockFreeQueue<String> queue = new LockFreeQueue<String>();
        CountDownLatch start;
        CountDownLatch end;


        for(int i = 0 ; i < testTimes ; i++){
            count.set(0);
            while(queue.deQueue() != null){} //clear the queue
            start = new CountDownLatch(1);
            end = new CountDownLatch(nCosumer + nProducer);

            for(int j = 0 ; j < nCosumer ; j++){
                new Thread(new Consumer(start,end,count,queue)).start();
            }

            for(int j = 0 ; j < nProducer ; j++){
                new Thread(new Producer(start,end,productsPerProducer,queue)).start();
            }

            start.countDown();
            try {
                end.await();
            } catch (InterruptedException e) {  }

            if(count.get() != nProducer * productsPerProducer - queue.size()){
                fail("count.get()= " + count.get() + " queue.size()= " + queue.size());
            }

        }


    }
}


Tests for the stack.

package test.cn.yuanye.concurrence.LockFreeCollection;

import cn.yuanye.concurrence.LockFreeCollection.LockFreeStack;
import org.junit.Test;
import org.junit.Before;
import org.junit.After;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

import static org.junit.Assert.fail;

/**
 * LockFreeStack Tester.
 *
 * @author
 * @version 1.0
 * @since Jan 14, 2014
 */
public class LockFreeStackTest {
    private LockFreeStack<Integer> stack = new LockFreeStack<Integer>();
    private CountDownLatch start;
    private CountDownLatch end;

    static class Poper extends Thread {
        private LockFreeStack<Integer> stack;
        CountDownLatch start;
        CountDownLatch end;
        AtomicInteger count;

        public Poper(LockFreeStack<Integer> stack, AtomicInteger count, CountDownLatch start, CountDownLatch end) {
            this.start = start;
            this.end = end;
            this.count = count;
            this.stack = stack;
        }

        @Override
        public void run() {
            try {
                start.await();
            } catch (InterruptedException e) {
            }
            while (stack.pop() != null) {
                count.getAndIncrement();
            }
            end.countDown();
        }
    }

    static class Pusher extends Thread {
        private LockFreeStack<Integer> stack;
        private int nProduct;
        private CountDownLatch start;
        private CountDownLatch end;

        public Pusher(LockFreeStack<Integer> stack, int n, CountDownLatch start, CountDownLatch end) {
            this.stack = stack;
            this.nProduct = n;
            this.start = start;
            this.end = end;
        }

        @Override
        public void run() {
            try {
                start.await();
            } catch (InterruptedException e) {
            }
            for (int i = 0; i < nProduct; i++) {
                stack.push(i);
            }
            end.countDown();
        }
    }

    @Before
    public void before() throws Exception {
    }

    @After
    public void after() throws Exception {
    }

    /**
     * Method: pop()
     */
    @Test
    public void testPop() throws Exception {
        AtomicInteger count = new AtomicInteger(0);
        final int testTimes = 10000;
        final int stackSize = 10000;
        final int nThread = 10;
        for (int i = 0; i < testTimes; i++) {
            // init the stack
            int j = stack.size();
            while (j < stackSize) {
                stack.push(j++);
            }
            start = new CountDownLatch(1);
            end = new CountDownLatch(nThread);
            count.set(0);
            for (int t = 0; t < nThread; t++) {
                new Poper(stack, count, start, end).start();
            }
            start.countDown();
            end.await();
            if (stackSize != count.get()) {
                fail("times : " + i + " stackSize = " + stackSize + " pop count " + count.get());
            }
        }
    }

    /**
     * Method: push(V value)
     */
    @Test
    public void testPush() throws Exception {
        final int nThread = 20;
        final int testTime = 10000;
        final int nProducePerThread = 100;
        for (int i = 0; i < testTime; i++) {
            start = new CountDownLatch(1);
            end = new CountDownLatch(nThread);
            while (stack.pop() != null); // clear the stack
            for (int t = 0; t < nThread; t++) {
                new Pusher(stack, nProducePerThread, start, end).start();
            }
            start.countDown();
            end.await();
            if (stack.size() != nProducePerThread * nThread) {
                fail("stack.size = " + stack.size() + " should be " + nProducePerThread * nThread);
            }
        }
    }

    @Test
    public void testPopPush() throws Exception {
        final int testTimes = 10000;
        final int nPoper = 20;
        final int nPusher = 20;
        final int nProduct = 100;
        AtomicInteger count = new AtomicInteger(0);
        for (int i = 0; i < testTimes; i++) {
            count.set(0);
            while (stack.pop() != null); // clear the stack
            start = new CountDownLatch(1);
            end = new CountDownLatch(nPoper + nPusher);
            for (int t = 0; t < nPusher; t++) {
                new Pusher(stack, nProduct, start, end).start();
            }
            for (int t = 0; t < nPoper; t++) {
                new Poper(stack, count, start, end).start();
            }
            start.countDown();
            end.await();
            if (count.get() + stack.size() != nProduct * nPusher) {
                fail("times " + i + " count " + count.get() + " stack.size " + stack.size()
                        + " total should be " + nProduct * nPusher);
            }
        }
    }
}



