JDK1.8源码分析:线程本地变量ThreadLocal的使用与实现原理

一、概述

  • 在Java多线程编程当中,对于被多个线程的共享变量,一般的方式是通过加锁,如使用synchronized关键字或者Java并发包的ReentrantLock加锁来实现线程安全,或者该变量在Java并发包存在线程安全的版本实现,如整数Integer对应的AtomicInteger,HashMap对应的ConcurrentHashMap等,则使用对应的线程安全版本的实现。
  • 除了以上两种方式之外,Java还提供了另外一种方式就是使用线程本地变量ThreadLocal来对共享变量进行包装。ThreadLocal是基于空间换时间的思路来设计的,即通过使用ThreadLocal对共享变量进行包装,使得每个线程都包含这个共享变量的一个副本,每个线程都对自己的共享变量副本进行操作,这样就实现了这个共享变量对每个线程的独立性,这样就不需要通过加锁来实现线程安全。
  • 不过由于每个线程都包含了这个共享变量的一个副本,所以会额外占用一定的内存空间,并且会随着线程数量的增加而增大,特别是如果这个共享变量会占用比较多空间,如用于存放数据的字典结构HashMap,则空间会增加更多。所以在选择是否使用ThreadLocal时,需要对该共享变量的空间占用进行一个衡量。

二、使用方法

  • 通过ThreadLocal来对共享变量进行包装来实现线程安全通常用在对类的静态变量或者被共享的对象的内部属性。

  • 如下示例为通过ThreadLocal来对类的静态变量来进行包装,例子的含义是:静态变量nextWorkId用于生成每个线程的操作序号,即每个线程每进行一次操作,递增产生一个序号来标识这次操作是当前线程的第几次操作。

    public class ThreadLocalDemo2 {
        // 每个线程的第几次操作
        private static ThreadLocal<Integer> nextWorkId = new ThreadLocal<Integer>() {
            @Override
            protected Integer initialValue() {
                return 1;
            }
        };
    
    
        public static void main(String[] args) {
            Thread thread1 = new Thread(new Runnable() {
                @Override
                public void run() {
                    for (int i = 0; i < 10; i++) {
                        // 打印并递增工作id
                        System.out.println("thread1 work id: " + nextWorkId.get());
                        Integer nextId = nextWorkId.get();
                        nextWorkId.set(++nextId);
                        // 每隔一秒
                        try {
                            Thread.sleep(1000);
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    }
                }
            });
    
            Thread thread2 = new Thread(new Runnable() {
                @Override
                public void run() {
                    for (int i = 0; i < 10; i++) {
                        System.out.println("thread2 work id: " + nextWorkId.get());
                        Integer nextId = nextWorkId.get();
                        nextWorkId.set(++nextId);
    
                        try {
                            Thread.sleep(1000);
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    }
                }
            });
    
            thread1.start();
            thread2.start();
        }
    }
    

    执行的打印如下:可以看到两个线程都是执行了10次,序号不会相互影响。

    thread1 work id: 1
    thread2 work id: 1
    thread2 work id: 2
    thread1 work id: 2
    thread2 work id: 3
    thread1 work id: 3
    thread1 work id: 4
    thread2 work id: 4
    thread2 work id: 5
    thread1 work id: 5
    thread2 work id: 6
    thread1 work id: 6
    thread2 work id: 7
    thread1 work id: 7
    thread1 work id: 8
    thread2 work id: 8
    thread1 work id: 9
    thread2 work id: 9
    thread1 work id: 10
    thread2 work id: 10
    
    Process finished with exit code 0
    

三、核心实现

  • 在实现层面,首先是在Thread类中包含一个字典类型的成员变量threadLocals,用于存放该Thread线程对象所包含的所有使用ThreadLocal包装的变量的集合,其中这个字典类型是在ThreadLocal内部定义的一个静态内部类ThreadLocalMap,该字典实现的key是ThreadLocal对象引用,值为该ThreadLocal对象所包装的具体值。由于是每个Thread线程对象都包含这样一个字典集合,所以实现了每个线程都包含对应变量的一份副本。

Thread类的线程本地变量字典threadLocals

  • Thread类的threadLocals定义如下:可以看出类型是ThreadLocal.ThreadLocalMap。

    ThreadLocal.ThreadLocalMap threadLocals = null;
    

ThreadLocal类定义

  • 由以上分析可知,Thread类的线程本地变量字典threadLocals的类型ThreadLocalMap是在ThreadLocal中定义的,ThreadLocalMap的核心定义如下:可以看出与常用的字典结构HashMap类似,也是基于链式哈希表实现的。

    // 每个线程自身独立的,用于存放线程本地变量值的哈希字典表map
    static class ThreadLocalMap {
    
        // 链式哈希表的链表节点定义
        static class Entry extends WeakReference<ThreadLocal<?>> {
            // 实际的值
            Object value;
    
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }
        
        // 链式哈希表对应的数组的初始容量为16
        private static final int INITIAL_CAPACITY = 16;
    
        
        // 链式哈希表的数组实现
        private Entry[] table;
    
        // 元素个数
        private int size = 0;
    
        // 拓容的阀值
        private int threshold;
        
        // 省略其他代码
        
    }
    
1. ThreadLocal的值初始化
  • 当使用ThreadLocal对某个变量进行包装时,一般首先需要对这个变量进行初始化,不过也可以通过调用set方法在之后使用时再设值。对ThreadLocal变量进行初始化主要是通过其initialValue方法来实现的,如下:默认实现为返回null,该方法是protected方法,故可以在创建ThreadLocal对象时,重写这个方法来自定义初始化逻辑。

    protected T initialValue() {
            return null;
        }
    
  • 重写initialValue方法来自定义初始化逻辑:如下初始化Integer类型的nextWorkId的值为1

    // 每个线程的第几次操作
    private static ThreadLocal<Integer> nextWorkId = new ThreadLocal<Integer>() {
        @Override
        protected Integer initialValue() {
            return 1;
        }
    };
    
2. ThreadLocal的get方法:获取线程绑定的值
  • 初始化值或者调用set方法写值之后,则在使用时,一般会通过ThreadLocal的get方法来获取该ThreadLocal所包装的变量对应的值,由于每个线程都是获取到与该线程绑定的值,即从该Thread线程对象所关联的线程本地变量集合threadLocals中获取,所以在get方法的内部实现当中,首先需要获取当前调用这个get方法的线程的对象引用Thread,即通过调用Thread.currentThread()方法获取,然后使用当前的ThreadLocal对象引用作为key,从该Thread线程对象的成员变量threadLocals获取对应的值,具体实现如下:

    // 获取值
    public T get() {
        // 获取当前调用这个get方法的线程引用
        Thread t = Thread.currentThread();
        // 每个线程thread都包含一个类型为ThreadLocalMap的threadLocals,
        // ThreadLocalMap用于存放这个线程所包含的所有ThreadLocal类的对象实例,
        // 即ThreadLocal对象作为key,值为每个线程独立的业务值value
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            // 将当前的ThreadLocal对象引用this作为key,从当前线程的ThreadLocalMap中获取值
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
    
3. ThreadLocal的set方法:设置线程绑定的值
  • set方法主要是往Thread线程对象的threadLocals集合中设置该ThreadLocal对应的值,与get方法实现类似,也是先拿到当前调用这个set方法的线程的对象引用Thread,然后在往该Thread对象引用的threadLocals集合中设置值,其中key为当前的ThreadLocal对象引用,值为通过方法参数传递进来的实际的值,具体实现如下:

    // 设置值
    public void set(T value) {
        // 当前调用该方法的线程
        Thread t = Thread.currentThread();
        // 获取这个线程所绑定的ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        if (map != null)
            // 将当前的ThreadLocal对象引用作为key,实际的值作为value
            map.set(this, value);
        else
            // 如果当前线程还没有填充过ThreadLocal类型的数据,则首先创建threadLocals集合,然后在写值进去
            createMap(t, value);
    }
    

你可能感兴趣的:(Java)