ThreadLocal源码分析

在理解Handler、Looper之前,先来说说ThreadLocal这个类,听名字好像是一个本地线程的意思,实际上它并不是一个Thread,而是提供一个与线程有关的局部变量功能,每个线程之间的数据互不影响。我们知道使用Handler的时候,每个线程都需要有一个looper对象,那么andorid中是怎么保存这个对象的呢,使用的就是ThreadLocal。

首先我们来看看主线程中looper是怎么初始化的。
在应用启动时,会线调用Looper.prepareMainLooper()方法,在这个方法里面会去初始化主线程需要用的looper对象

static final ThreadLocal sThreadLocal = new ThreadLocal();
public static void prepareMainLooper() {
    prepare(false);
    synchronized (Looper.class) {
        if (sMainLooper != null) {
            throw new IllegalStateException("The main Looper has already been prepared.");
        }
        sMainLooper = myLooper();
    }
}
private static void prepare(boolean quitAllowed) {
    if (sThreadLocal.get() != null) {
        throw new RuntimeException("Only one Looper may be created per thread");
    }
    //调用ThreadLocal的set()方法来保存一个looper对象
    sThreadLocal.set(new Looper(quitAllowed));
}

我们看到在Looper类中会在它被加载的时候将ThreadLocal对象创建出来,它是一个静态的变量。在我们初始化主线程的looper的时候,实际上就是直接new Looper()然后将其放在了ThreadLocal中的。

下面我们将从set()方法作为入口来具体分析ThreadLocal是怎么实现的。

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

ThreadLocal使用了泛型,该泛型就是需要存储的数据类型。在set()方法内部,首先取得了当前的线程,然后在线程对象中获取了一个threadLocals对象,在Thread类中有这么一个字段定义ThreadLocal.ThreadLocalMap threadLocals = null;,这个threadLocals对象是一个ThreadLocalMap类型的数据,默认在Thread线程中空的,看源码发现它是ThreadLocal的一个静态内部类。如果threadLocals不为null,那么久调用ThreadLocalMap.set(ThreadLocal key, Object value)方法,否则就初始化threadLocals,它的初始化很简单,直接在当前的线程对象中给threadLocals创建了一个ThreadLocal对象,同时将值保存进ThreadLocalMap中作为第一个值。

void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

到这里发现数据的具体存储还是在ThreadLocalMap这个类中。
在说ThreadLocalMap之前先看看ThreadLocal里面的一个小东西。

//该hash值可以唯一确定一个threadlocal对象,每创建一个threadlocal对象,该hash值都是唯一的
private final int threadLocalHashCode = nextHashCode();
//原子类,保证多线程下唯一
private static AtomicInteger nextHashCode =
    new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;
private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

在ThreadLocal中还定义了如下的hash值,它在ThreadLocalMap中使用,可以唯一的确定一个threadlocal对象。

我们接着看看ThreadLocalMap中的构造方法和set方法。

//Entry是继承自WeakReference的软引用,ThreadLocal作为key对它软引用,
//同时也是一个key-value的键值对
static class Entry extends WeakReference> {
    Object value;
    Entry(ThreadLocal k, Object v) {
        super(k);
        value = v;
    }
}

private static final int INITIAL_CAPACITY = 16;
private Entry[] table;
private int size = 0;

ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
    //创建默认大小为16的Entry数组
    table = new Entry[INITIAL_CAPACITY];
    //通过hash计算出index
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    setThreshold(INITIAL_CAPACITY);
}

firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1) 可能会让人迷惑,实际上它和firstKey.threadLocalHashCode%INITIAL_CAPACITY的计算结果是一样的,保证index的值在0到INITIAL_CAPACITY之间,不包含INITIAL_CAPACITY。但是这个前提是INITIAL_CAPACITY的值必须为2n...。
2n的二进制的表示为1000...,那么2n-1的二进制表示为0111...。是不是感觉好像发现了什么?和2n-1求与刚好是将余数部分给取出来,使用这种方式来计算index的速度要比直接使用%要快,但是是使用这个方式的前提就是INITIAL_CAPACITY的值必须为2n

private void set(ThreadLocal key, Object value) {
    ab = table;
    int len = tab.length;
    //不同的threadlocal对象,可能计算出来的index会一样
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal k = e.get();
        //如果在index的地方找到了相同的key,就直接覆盖
        if (k == key) {
            e.value = value;
            return;
        }
        //如果发现有entry但是key被回收了,则覆盖
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    //如果找了一圈还是没有找到entry,那么就直接创建一个entry添加进去
    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

这里需要明白一个地方:如果在当前index中找到了一样的key,就直接覆盖,如果找到了entry但是key被回收了那么就替换数据,如果key不一样的话就在下一个index继续刚刚的判断。只要在某一个index中没有找到entry对象,则直接创建一个新的entry插入。

private void replaceStaleEntry(ThreadLocal key, Object value,int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        Entry e;
        //staleSlot是一个旧数据,key被回收了,我们称它为旧数据吧
        //从staleSlot往前找到另外一个旧数据的index
        int slotToExpunge = staleSlot;
        for (int i = prevIndex(staleSlot, len);
             (e = tab[i]) != null;//这里只要为null就退出循环
             i = prevIndex(i, len))
            if (e.get() == null)
                slotToExpunge = i;
        
        //从staleSlot往后清除旧数据
        for (int i = nextIndex(staleSlot, len);
             (e = tab[i]) != null;
             i = nextIndex(i, len)) {
            ThreadLocal k = e.get();
            //如果找到了key相同的地方就替换数据
            if (k == key) {
                e.value = value;
                //交换staleSlot和i这两个位置的数据,此时tab[i]是一个旧数据
                tab[i] = tab[staleSlot];
                //staleSlot位置是一个新的数据
                tab[staleSlot] = e;

                //如果staleSlot前面不存在其他的旧数据,就记录下i这个就数据(它和staleSlot交换了数据,所以i这里变为了旧数据)
                if (slotToExpunge == staleSlot)
                    slotToExpunge = i;
                //清除旧数据
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                return;
            }
            //如果找到旧数据,并且staleSlot前面没有旧数据,记录当前i
            if (k == null && slotToExpunge == staleSlot)
                slotToExpunge = i;
        }
        //方便GC回收它
        tab[staleSlot].value = null;
        //使用一个新的entry替换掉staleSlot位置的旧数据
        tab[staleSlot] = new Entry(key, value);
        //清除其他位置的旧数据,staleSlot被新数据给替换了
        if (slotToExpunge != staleSlot)
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    }

当key被回收的时候,会走到该方法中来。staleSlot是一个key被回收的数据,我们无法外面获取它,所以需要处理掉这些旧数据。这里替换数据分为两种情况:

1、如果在staleSlot后面找到了相同的key,则在找到的地方覆盖value同时和staleSlot交换位置
2、如果没有找到key,就在staleSlot重新创建新的entry覆盖旧数据

private int expungeStaleEntry(int staleSlot) {
        Entry[] tab = table;
        int len = tab.length;
        //清除旧数据
        tab[staleSlot].value = null;
        tab[staleSlot] = null;
        size--;

        Entry e;
        int i;
        for (i = nextIndex(staleSlot, len);
             (e = tab[i]) != null;
             i = nextIndex(i, len)) {
            ThreadLocal k = e.get();
            if (k == null) {
                e.value = null;
                tab[i] = null;
                size--;
            } else {
                int h = k.threadLocalHashCode & (len - 1);
                if (h != i) {
                    //重新放置entry数据
                    tab[i] = null;
                    while (tab[h] != null)
                        h = nextIndex(h, len);
                    tab[h] = e;
                }
            }
        }
        return i;
    }
//清除staleSlot后面的旧数据,每调用一次expungeStaleEntry(),从该方法返回值继续清除后面的旧数据
private boolean cleanSomeSlots(int i, int n) {
        boolean removed = false;
        Entry[] tab = table;
        int len = tab.length;
        do {
            i = nextIndex(i, len);
            Entry e = tab[i];
            if (e != null && e.get() == null) {
                n = len;
                removed = true;
                i = expungeStaleEntry(i);
            }
        } while ( (n >>>= 1) != 0);
        return removed;
    }

在前面构造方法中调用了setThreshold(INITIAL_CAPACITY)这个方法,设置一个数组大小的阈值,如果数组中的数据个数超过了它那么就调用rehash()行扩展处理。

private void rehash() {
        expungeStaleEntries();
        if (size >= threshold - threshold / 4)
            resize();
    }

//这里会便利所有的元素来进行旧数据的清除处理
private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}
//重新计算大小,扩容处理在这里面处理
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    //在原来的基础上扩大2倍
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal k = e.get();
            if (k == null) {
                e.value = null; // GC回收
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                //如果为null则将原数组中的数据添加进来
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

set()方法的分析到上面就结束了,有了上面的经验,get()方法分析起来就更简单了。简单说一下get()方法。具体的说明看注释

public T get() {
    Thread t = Thread.currentThread();
    //获取线程关联的LocalThreadMap
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            //如果entry不为null就返回value值
            T result = (T)e.value;
            return result;
        }
    }
    //返回初始化的值,如果不覆写initialValue()这里的返回值就是null
    return setInitialValue();
}

private T setInitialValue() {
    //ThreadLocal有一个默认的initialValue()方法返回null
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        //初始化threadlocalmap
        createMap(t, value);
    return value;
}

最终还是会调用ThreadLocalMap的getEntry()方法

private Entry getEntry(ThreadLocal key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    //匹配到可以就返回
    if (e != null && e.get() == key)
        return e;
    else
        //key被回收的话获取数据调用此方法
        return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal key, int i, Entry e) {
        Entry[] tab = table;
        int len = tab.length;
        
        //向右查找,找到就返回,找不到就返回null
        while (e != null) {
            ThreadLocal k = e.get();
            if (k == key)
                return e;
            if (k == null)
                expungeStaleEntry(i);
            else
                i = nextIndex(i, len);
            e = tab[i];
        }
        return null;
    }

内存泄漏

ThreadLocalMap中对key使用了软引用,当threadlocal对象在外面没有被使用的时候,gc就有可能会回收它,这样就导致了value值被ThreadLocalMap强引用无法释放调用造成内存泄漏,除非是ThreadLocalMap关联的thread线程被回收。只要线程还存活,就真的是内存泄漏了。尤其是在线程池中线程被重复使用,如果ThreadLocal使用不当就很容易造成内存泄漏了。
所以为了避免出现这种情况,在我们使用ThreadLocal的时候,如果不在需要使用threadlocal了,一定要先调用一次remove方法来清除数据。
当然在我们每次set和get的时候也会去处理一些旧数据,但是只要不去调用这些get和set方法,就不会触发去清理旧数据。

你可能感兴趣的:(ThreadLocal源码分析)