理解ThreadLocal

原文链接:理解ThreadLocal

ThreadLocal是一个线程级别的局部变量,并非“本地线程”。ThreadLocal为每个使用该变量的线程提供了一个独立的变量副本,每个线程修改副本时不影响其它线程对象的副本。

下面是线程局部变量(ThreadLocal variables)的关键点:

1.一个线程局部变量(ThreadLocal variables)为每个线程方便地提供了一个单独的变量。
2.ThreadLocal实例通常作为静态的私有的(private static)字段出现在一个类中,这个类用来关联一个线程
3.当多个线程访问ThreadLocal实例时,每个线程维护ThreadLocal提供的独立的变量副本。

变量的作用域

class Foo {
    public static int sn;  // 全局变量
    int i;  // 成员变量, 对所有方法可见

    void bar() {
        int j;  // 局部变量,对当前方法可见
    }
}

以上代码中,sn, i, j有各自不同的作用域。可以满足不同的使用场景。

线程级别的局部变量

下面这段代码的功能是:若干个线程对一个整形变量进行累加, 检查累加结果是否为预期值(总的循环次数)。

public class UseThreadLocal {

    static int count;

    //synchronized
    static void increment() {
        count++;
    }

    /**
     * 若干个线程对一个整形变量进行累加, 检查累加结果是否为预期值(循环次数)
     * @param args
     * @throws InterruptedException
     */
    public static void main(String[] args) throws InterruptedException {

        int[] loops = { 1000000, 1000, 10000, 1000000, 1000000 };
        int num = loops.length;
        Thread[] ts = new Thread[num];

        long now = System.currentTimeMillis();
        for (int i = 0; i < ts.length; i++) {
            ts[i] = new Thread(new MyRunnable(loops[i]));
            ts[i].start();
        }

        for (Thread thread : ts) {
            thread.join();
        }
        long cost = System.currentTimeMillis() - now;

        int sumOfLoops = 0;
        for (int i = 0; i < loops.length; i++) {
            sumOfLoops += loops[i];
        }

        System.out.println("cost=" + cost + " ms");
        System.out.println("sumOfLoops=" + sumOfLoops + ",count="
                + UseThreadLocal.count);
        System.out.println("sumOfLoops==count? "
                + (sumOfLoops == UseThreadLocal.count));
    }

    static class MyRunnable implements Runnable {

        int loop;

        public MyRunnable(int loop) {
            super();
            this.loop = loop;
        }

        @Override
        public void run() {
            for (int i = 0; i < loop; i++) {
                UseThreadLocal.increment();
            }
        }

    }
}

请注意其中的static void increment()方法。显然当我们去掉该方法的synchronized时,很可能得不到正确的预期值。而加上synchronized后,保证可以得到正确的预期值。

去掉synchronized,输出如下

cost=9 ms
sumOfLoops=3011000,count=2788360
sumOfLoops==count? false

加上synchronized,输出如下

cost=406 ms
sumOfLoops=3011000,count=3011000
sumOfLoops==count? true

除了关心预期值是否正确外,我们还应注意到性能上的差异。再次执行时间分别是9ms和406ms,可见同步方法的代价是巨大的

假设这里有一种变量,是线程私有的、局部的,可以在当前线程上下文中自由操作(不需要同步),那我们就可以去掉synchronized并且保证程序结果是正确的。当然,可以直接在MyRunnable里面添加一个变量myCount,比如

static class MyRunnable implements Runnable {

    int loop;
    int myCount;

    public MyRunnable(int loop) {
        super();
        this.loop = loop;
    }

    @Override
    public void run() {
        for (int i = 0; i < loop; i++) {
            myCount++;
        }
    }
}

线程结束后再对每个线程的myCount求和。下面是用的另一种方式,即借助ThreadLocal来避免同步方法。

public class UseThreadLocal {

    /**
     * 计数器
     *
     */
    static class Counter {

        static int count;

        private static ThreadLocal<Integer> countLocal = new ThreadLocal<Integer>() {
            protected Integer initialValue() {
                return 0;
            }
        };

        void increment() {
            countLocal.set(countLocal.get() + 1);
        }

        int get() {
            return countLocal.get();
        }

        synchronized void addToCount() {
            count += get();
        }

        int getCount() {
            return count;
        }
    }

    /**
     * 若干个线程对一个整形变量进行累加, 检查累加结果是否为预期值(循环次数)
     * @param args
     * @throws InterruptedException
     */
    public static void main(String[] args) throws InterruptedException {

        int[] loops = { 1000000, 1000, 10000, 1000000, 1000000 };
        int num = loops.length;
        Thread[] ts = new Thread[num];

        Counter wrapper = new Counter();

        long now = System.currentTimeMillis();
        for (int i = 0; i < ts.length; i++) {
            //ts[i] = new Thread(new MyRunnable(loops[i]));
            ts[i] = new Thread(new LocalRunnable(loops[i], wrapper));
            ts[i].start();
        }

        for (Thread thread : ts) {
            thread.join();
        }
        long cost = System.currentTimeMillis() - now;

        int sumOfLoops = 0;
        for (int i = 0; i < loops.length; i++) {
            sumOfLoops += loops[i];
        }

        System.out.println("cost=" + cost + " ms");
        System.out.println("sumOfLoops=" + sumOfLoops + ",count="
                + wrapper.getCount());
        System.out.println("sumOfLoops==count? "
                + (sumOfLoops == wrapper.getCount()));
    }

    static class LocalRunnable implements Runnable {
        int loop;
        Counter wrapper;

        public LocalRunnable(int loop, Counter wrapper) {
            super();
            this.loop = loop;
            this.wrapper = wrapper;
        }

        @Override
        public void run() {
            for (int i = 0; i < loop; i++) {
                wrapper.increment();
            }

            // 这里仍然需要注意同步问题
            wrapper.addToCount();
        }

    }
}

输出结果如下:

cost=81 ms
sumOfLoops=3011000,count=3011000
sumOfLoops==count? true

可以看到,不仅能保证结果正确,性能也比前面的同步方法有所提高。

ThreadLocal到底是什么

早在JDK 1.2的版本中就提供java.lang.ThreadLocal,ThreadLocal为解决多线程程序的并发问题提供了一种新的思路。使用这个工具类可以很简洁地编写出优美的多线程程序。

当使用ThreadLocal维护变量时,ThreadLocal为每个使用该变量的线程提供独立的变量副本,所以每一个线程都可以独立地改变自己的副本,而不会影响其它线程所对应的副本。

从线程的角度看,目标变量就象是线程的本地变量,这也是类名中“Local”所要表达的意思。

ThreadLocal的方法
该类的接口非常简单,只有4个方法。

  • void set(T value)
    设置当前线程的线程局部变量的值
  • T get()
    该方法返回当前线程所对应的线程局部变量
  • void remove()
    将当前线程局部变量的值删除
  • T initialValue()
    返回该线程局部变量的初始值。这个方法是一个延迟调用方法,在线程第1次调用get()或set(Object)时才执行,并且仅执行1次。ThreadLocal中的缺省实现直接返回一个null

如何实现ThreadLocal

下面是一个最简单的ThreadLocal。(JDK的实现也是类似思路)

public class SimpleThreadLocal {
    private Map valueMap = Collections.synchronizedMap(new HashMap());

    public void set(Object newValue) {
        valueMap.put(Thread.currentThread(), newValue);//①键为线程对象,值为本线程的变量副本
    }

    public Object get() {
        Thread currentThread = Thread.currentThread();
        Object o = valueMap.get(currentThread);//②返回本线程对应的变量
        if (o == null && !valueMap.containsKey(currentThread)) {//③如果在Map中不存在,放到Map中保存起来。
            o = initialValue();
            valueMap.put(currentThread, o);
        }
        return o;
    }

    public void remove() {
        valueMap.remove(Thread.currentThread());
    }

    public Object initialValue() {
        return null;
    }
}

说白了,其实ThreadLocal目的不过是将线程跟线程不安全的数据关联起来。前面提到,可以在MyRunnable里面添加一个变量myCount,这是一种直接建立关联的办法。可以认为ThreadLocal是一种间接的方式。对比如下图
理解ThreadLocal_第1张图片

在同步机制中,通过对象的锁机制保证同一时间只有一个线程访问变量。而ThreadLocal则从另一个角度来解决多线程的并发访问。ThreadLocal会为每一个线程提供一个独立的变量副本,从而隔离了多个线程对数据的访问冲突。因为每一个线程都拥有自己的变量副本,从而也就没有必要对该变量进行同步了。ThreadLocal提供了线程安全的共享对象,在编写多线程代码时,可以把不安全的变量封装进ThreadLocal。

概括起来说,对于多线程资源共享的问题,同步机制采用了“以时间换空间”的方式,而ThreadLocal采用了“以空间换时间”的方式。前者仅提供一份变量,让不同的线程排队访问,而后者为每一个线程都提供了一份变量,因此可以同时访问而互不影响。

使用场景

例1 使用线程相关的Connection

public class TopicDao {
    private Connection conn;//①一个非线程安全的变量

    public void addTopic() {
        Statement stat = conn.createStatement();//②引用非线程安全变量
    }
}

由于①处的conn是成员变量,因为addTopic()方法是非线程安全的,必须在使用时创建一个新TopicDao实例。 使用ThreadLocal对conn这个非线程安全的“状态”进行改造:

public class TopicDao {
    //①使用ThreadLocal保存Connection变量
    private static ThreadLocal connThreadLocal = new ThreadLocal();

    public static Connection getConnection() {
        //②如果connThreadLocal没有本线程对应的Connection创建一个新的Connection,
        //并将其保存到线程本地变量中。
        if (connThreadLocal.get() == null) {
            Connection conn = ConnectionManager.getConnection();
            connThreadLocal.set(conn);
            return conn;
        } else {
            return connThreadLocal.get();//③直接返回线程本地变量
        }
    }

    public void addTopic() {
        //④从ThreadLocal中获取线程对应的Connection
        Statement stat = getConnection().createStatement();
    }
}

不同的线程在使用TopicDao时,先判断connThreadLocal.get()是否是null,如果是null,则说明当前线程还没有对应的Connection对象,这时创建一个Connection对象并添加到本地线程变量中;如果不为null,则说明当前的线程已经拥有了Connection对象,直接使用就可以了。这样,就保证了不同的线程使用线程相关的Connection,而不会使用其它线程的Connection。

例2 Handler所在线程必须有一个Looper

Android中一个线程中如果想要使用Handler,必须保证当前线程有一个关联的Looper。Handler的构造方法会调用Looper.myLooper()得到当前线程的Looper,如果为null则抛出RuntimeException。

public Handler(Callback callback, boolean async) {
    if (FIND_POTENTIAL_LEAKS) {
        final Class<? extends Handler> klass = getClass();
        if ((klass.isAnonymousClass() || klass.isMemberClass() || klass.isLocalClass()) &&
                (klass.getModifiers() & Modifier.STATIC) == 0) {
            Log.w(TAG, "The following Handler class should be static or leaks might occur: " +
                klass.getCanonicalName());
        }
    }

    mLooper = Looper.myLooper();
    if (mLooper == null) {
        throw new RuntimeException(
            "Can't create handler inside thread that has not called Looper.prepare()");
    }
    mQueue = mLooper.mQueue;
    mCallback = callback;
    mAsynchronous = async;
}

Looper代码如下。Looper.myLooper()方法中用到了ThreadLocal类。

public class Looper {
    // sThreadLocal.get() will return null unless you've called prepare().
    static final ThreadLocal<Looper> sThreadLocal = new ThreadLocal<Looper>();


     /** Initialize the current thread as a looper.
      * This gives you a chance to create handlers that then reference
      * this looper, before actually starting the loop. Be sure to call
      * {@link #loop()} after calling this method, and end it by calling
      * {@link #quit()}.
      */
    public static void prepare() {
        prepare(true);
    }

    private static void prepare(boolean quitAllowed) {
        if (sThreadLocal.get() != null) {
            throw new RuntimeException("Only one Looper may be created per thread");
        }
        sThreadLocal.set(new Looper(quitAllowed));
    }

    /**
     * Run the message queue in this thread. Be sure to call
     * {@link #quit()} to end the loop.
     */
    public static void loop() {
        final Looper me = myLooper();
        if (me == null) {
            throw new RuntimeException("No Looper; Looper.prepare() wasn't called on this thread.");
        }
        final MessageQueue queue = me.mQueue;

        for (;;) {
            Message msg = queue.next(); // might block
            if (msg == null) {
                // No message indicates that the message queue is quitting.
                return;
            }

            msg.target.dispatchMessage(msg);
            msg.recycle();
        }
    }

    /**
     * Return the Looper object associated with the current thread.  Returns
     * null if the calling thread is not associated with a Looper.
     */
    public static Looper myLooper() {
        return sThreadLocal.get();
    }
}

你可能感兴趣的:(java,threadLocal,handler)