Java ThreadLocal 源码解析

前言

ThreadLocal 是 Java 语言中的一个类,可以使用它为每个线程存储数据。这些数据只能被当前线程访问,而其他线程无法访问。这个类可以用于避免多次传递、线程间数据隔离、事务操作等场景。

本次源码分析基于 JDK 21.0.1。

ThreadLocal 使用简介

基本操作

使用 ThreadLocal 时,可以将数据存储在一个特殊的对象中,这个对象会被自动关联到当前线程。例如,可以使用以下代码创建一个 ThreadLocal 对象,其中存储了一个整数值:

ThreadLocal<Integer> threadLocalValue = new ThreadLocal<>();
threadLocalValue.set(1);
Integer result = threadLocalValue.get();

如果想要在创建 ThreadLocal 对象时就设置初始值,可以使用 withInitial() 方法,并通过 lambda 表达式传入一个 Supplier 对象,例如:

ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> 1);

如果想要删除 ThreadLocal 中的值,可以调用 remove() 方法。例如:

threadLocal.remove();

多线程下 ThreadLocal 使用

以下代码演示了 ThreadLocal 的使用,代码首先创建了 NUM_THREADS 个线程,然后在每个线程内创建了 ThreadLocal。随后,每个线程分别对线程私有的 ThreadLocal 自增 NUM_THREADS 次,并对共享的 sharedValue 自增 NUM_THREADS 次。

import java.util.concurrent.atomic.AtomicInteger;

public class Main {
    private static final int NUM_THREADS = 3;
    private static final int NUM_INCREMENTS = 5;

    public static void main(String[] args) {
        AtomicInteger sharedValue = new AtomicInteger(0);

        for (int i = 0; i < NUM_THREADS; i++) {
            new Thread(() -> {
                ThreadLocal<Integer> threadLocalValue = ThreadLocal.withInitial(() -> 0);

                for (int j = 0; j < NUM_INCREMENTS; j++) {
                    int localValue = threadLocalValue.get();
                    localValue++;
                    threadLocalValue.set(localValue);

                    int currentValue = sharedValue.get();
                    currentValue++;
                    sharedValue.set(currentValue);
                }

                System.out.println("Thread " + Thread.currentThread().getId() + ": Thread-local value = " + threadLocalValue.get() + ", Shared value = " + sharedValue.get());
            }).start();
        }
    }
}

ThreadLocal 源码解析

初始化

使用无参构造器时仅创建一个空的 ThreadLocal 对象:

    public ThreadLocal() {
    }

使用 withInitial 设置 ThreadLocal 初值时,返回的是 SuppliedThreadLocal 类型:

	// supplier 为传入的 lambda 表达式
    public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
    	// 创建并返回了一个 SuppliedThreadLocal
        return new SuppliedThreadLocal<>(supplier);
    }

传入的 Supplier 定义如下:

@FunctionalInterface
public interface Supplier<T> {
    T get();
}

其中 SuppliedThreadLocal 是 ThreadLocal 的静态内部类,它继承了 ThreadLocal 并重写了 initialValue() 方法:

    static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

        private final Supplier<? extends T> supplier;

		// 将赋初值的 lambda 表达式设置为 supplier 成员变量
        SuppliedThreadLocal(Supplier<? extends T> supplier) {
            this.supplier = Objects.requireNonNull(supplier);
        }

        @Override
        protected T initialValue() {
            return supplier.get();
        }
    }

后续第一次调用 get() 时,会调用 SuppliedThreadLocal 重写的 initialValue() 方法,该方法调用了传入的 Supplier 表达式返回 ThreadLocal 初值。

set()

set() 用于设置 ThreadLocal 的值,其实现如下。

    public void set(T value) {
    	// 为了设置 ThreadLocal 的值,传入了当前线程
        set(Thread.currentThread(), value);
        if (TRACE_VTHREAD_LOCALS) {
            dumpStackIfVirtualThread();
        }
    }

    private void set(Thread t, T value) {
    	// 获取和当前 ThreadLocal 关联的哈希表
        ThreadLocalMap map = getMap(t);
        if (map == ThreadLocalMap.NOT_SUPPORTED) {
            throw new UnsupportedOperationException();
        }
        if (map != null) {
        	// map 已经初始化,则直接设置值
            map.set(this, value);
        } else {
        	// lazy 初始化 ThreadLocalMap
            createMap(t, value);
        }
    }

首先看 getMap(t),它获取了和当前 ThreadLocal 关联的哈希表:

    ThreadLocalMap getMap(Thread t) {
    	// 从线程对象获取 ThreadLocalMap,由此可以看出每个对象一个 ThreadLocalMap
        return t.threadLocals;
    }

t.threadLocals 是 Thread 对象的成员,其类型为 ThreadLocal.ThreadLocalMap

public class Thread {
	...
	
	ThreadLocal.ThreadLocalMap threadLocals;
	
	...
}

ThreadLocalMap 是 ThreadLocal 类的内部类,它用于存储线程本地变量。 ThreadLocalMap 是 Thread 对象的成员变量,这说明每个线程都有一个 ThreadLocalMap 对象,而 ThreadLocalMap 保存了当前线程拥有的所有 ThreadLocal 对象和对应的变量副本。

回到set() 方法,由 set() 方法可以看出 ThreadLocalMap 是延迟到第一次使用的时候创建的。创建 ThreadLocalMap 的代码如下:

    void createMap(Thread t, T firstValue) {
    	// 创建 ThreadLocal 并将关联的线程和赋予的值传入
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

ThreadLocalMap 是一个专门保存 ThreadLocal 的哈希表,其构造器的实现如下:

        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        	// 创建哈希表的底层数组
            table = new Entry[INITIAL_CAPACITY];
            // 哈希值取余定位
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            // 创建一个 entry 放到槽位中
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            // 设置哈希表扩容的大小门槛,为总容量的 2/3
            setThreshold(INITIAL_CAPACITY);
        }

        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }

get()

get() 方法用于获取 ThreadLocal 的值,其实现如下。

    public T get() {
    	// 根据 Thread 获取值
        return get(Thread.currentThread());
    }

	// 1. 根据 Thread 获取 ThreadLocalMap
	// 2. 从 ThreadLocalMap 获取 entry,并将 entry 的 value 作为结果返回
	// 3. 如果 map 为 null,说明未初始化,调用 setInitialValue 进行初始化
    private T get(Thread t) {
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T) e.value;
                return result;
            }
        }
        return setInitialValue(t);
    }

    private T setInitialValue(Thread t) {
    	// 如果使用无参构造器,返回的是 null
    	// 如果使用了 ThreadLocal.withInitial 创建 ThreadLocal,返回的是 lambda 表达式的结果
        T value = initialValue();
        // 获取 ThreadLocalMap,如果是第一次访问则进行初始化
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            map.set(this, value);
        } else {
            createMap(t, value);
        }
        if (this instanceof TerminatingThreadLocal<?> ttl) {
            TerminatingThreadLocal.register(ttl);
        }
        if (TRACE_VTHREAD_LOCALS) {
            dumpStackIfVirtualThread();
        }
        return value;
    }

总结

ThreadLocal 可以用于保存线程私有的数据,其源码具有下关键点:

  • ThreadLocalMap 的创建是懒加载的;
  • ThreadLocal 的实现是通过将一个 ThreadLocalMap 作为 Thread 对象的成员实现的;
  • 各个线程的全部 ThreadLocal 都保存在 ThreadLocalMap 中。

你可能感兴趣的:(Java,java,python,开发语言)