ThreadLocal 是 Java 语言中的一个类,可以使用它为每个线程存储数据。这些数据只能被当前线程访问,而其他线程无法访问。这个类可以用于避免多次传递、线程间数据隔离、事务操作等场景。
本次源码分析基于 JDK 21.0.1。
使用 ThreadLocal 时,可以将数据存储在一个特殊的对象中,这个对象会被自动关联到当前线程。例如,可以使用以下代码创建一个 ThreadLocal 对象,其中存储了一个整数值:
ThreadLocal<Integer> threadLocalValue = new ThreadLocal<>();
threadLocalValue.set(1);
Integer result = threadLocalValue.get();
如果想要在创建 ThreadLocal 对象时就设置初始值,可以使用 withInitial()
方法,并通过 lambda 表达式传入一个 Supplier 对象,例如:
ThreadLocal<Integer> threadLocal = ThreadLocal.withInitial(() -> 1);
如果想要删除 ThreadLocal 中的值,可以调用 remove()
方法。例如:
threadLocal.remove();
以下代码演示了 ThreadLocal 的使用,代码首先创建了 NUM_THREADS
个线程,然后在每个线程内创建了 ThreadLocal。随后,每个线程分别对线程私有的 ThreadLocal 自增 NUM_THREADS
次,并对共享的 sharedValue
自增 NUM_THREADS
次。
import java.util.concurrent.atomic.AtomicInteger;
public class Main {
private static final int NUM_THREADS = 3;
private static final int NUM_INCREMENTS = 5;
public static void main(String[] args) {
AtomicInteger sharedValue = new AtomicInteger(0);
for (int i = 0; i < NUM_THREADS; i++) {
new Thread(() -> {
ThreadLocal<Integer> threadLocalValue = ThreadLocal.withInitial(() -> 0);
for (int j = 0; j < NUM_INCREMENTS; j++) {
int localValue = threadLocalValue.get();
localValue++;
threadLocalValue.set(localValue);
int currentValue = sharedValue.get();
currentValue++;
sharedValue.set(currentValue);
}
System.out.println("Thread " + Thread.currentThread().getId() + ": Thread-local value = " + threadLocalValue.get() + ", Shared value = " + sharedValue.get());
}).start();
}
}
}
使用无参构造器时仅创建一个空的 ThreadLocal 对象:
public ThreadLocal() {
}
使用 withInitial
设置 ThreadLocal 初值时,返回的是 SuppliedThreadLocal 类型:
// supplier 为传入的 lambda 表达式
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
// 创建并返回了一个 SuppliedThreadLocal
return new SuppliedThreadLocal<>(supplier);
}
传入的 Supplier 定义如下:
@FunctionalInterface
public interface Supplier<T> {
T get();
}
其中 SuppliedThreadLocal 是 ThreadLocal 的静态内部类,它继承了 ThreadLocal 并重写了 initialValue()
方法:
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
// 将赋初值的 lambda 表达式设置为 supplier 成员变量
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}
@Override
protected T initialValue() {
return supplier.get();
}
}
后续第一次调用 get()
时,会调用 SuppliedThreadLocal 重写的 initialValue()
方法,该方法调用了传入的 Supplier 表达式返回 ThreadLocal 初值。
set()
用于设置 ThreadLocal 的值,其实现如下。
public void set(T value) {
// 为了设置 ThreadLocal 的值,传入了当前线程
set(Thread.currentThread(), value);
if (TRACE_VTHREAD_LOCALS) {
dumpStackIfVirtualThread();
}
}
private void set(Thread t, T value) {
// 获取和当前 ThreadLocal 关联的哈希表
ThreadLocalMap map = getMap(t);
if (map == ThreadLocalMap.NOT_SUPPORTED) {
throw new UnsupportedOperationException();
}
if (map != null) {
// map 已经初始化,则直接设置值
map.set(this, value);
} else {
// lazy 初始化 ThreadLocalMap
createMap(t, value);
}
}
首先看 getMap(t)
,它获取了和当前 ThreadLocal 关联的哈希表:
ThreadLocalMap getMap(Thread t) {
// 从线程对象获取 ThreadLocalMap,由此可以看出每个对象一个 ThreadLocalMap
return t.threadLocals;
}
t.threadLocals
是 Thread 对象的成员,其类型为 ThreadLocal.ThreadLocalMap
:
public class Thread {
...
ThreadLocal.ThreadLocalMap threadLocals;
...
}
ThreadLocalMap 是 ThreadLocal 类的内部类,它用于存储线程本地变量。 ThreadLocalMap 是 Thread 对象的成员变量,这说明每个线程都有一个 ThreadLocalMap 对象,而 ThreadLocalMap 保存了当前线程拥有的所有 ThreadLocal 对象和对应的变量副本。
回到set()
方法,由 set()
方法可以看出 ThreadLocalMap 是延迟到第一次使用的时候创建的。创建 ThreadLocalMap 的代码如下:
void createMap(Thread t, T firstValue) {
// 创建 ThreadLocal 并将关联的线程和赋予的值传入
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
ThreadLocalMap 是一个专门保存 ThreadLocal 的哈希表,其构造器的实现如下:
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 创建哈希表的底层数组
table = new Entry[INITIAL_CAPACITY];
// 哈希值取余定位
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 创建一个 entry 放到槽位中
table[i] = new Entry(firstKey, firstValue);
size = 1;
// 设置哈希表扩容的大小门槛,为总容量的 2/3
setThreshold(INITIAL_CAPACITY);
}
private void setThreshold(int len) {
threshold = len * 2 / 3;
}
get()
方法用于获取 ThreadLocal 的值,其实现如下。
public T get() {
// 根据 Thread 获取值
return get(Thread.currentThread());
}
// 1. 根据 Thread 获取 ThreadLocalMap
// 2. 从 ThreadLocalMap 获取 entry,并将 entry 的 value 作为结果返回
// 3. 如果 map 为 null,说明未初始化,调用 setInitialValue 进行初始化
private T get(Thread t) {
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T) e.value;
return result;
}
}
return setInitialValue(t);
}
private T setInitialValue(Thread t) {
// 如果使用无参构造器,返回的是 null
// 如果使用了 ThreadLocal.withInitial 创建 ThreadLocal,返回的是 lambda 表达式的结果
T value = initialValue();
// 获取 ThreadLocalMap,如果是第一次访问则进行初始化
ThreadLocalMap map = getMap(t);
if (map != null) {
map.set(this, value);
} else {
createMap(t, value);
}
if (this instanceof TerminatingThreadLocal<?> ttl) {
TerminatingThreadLocal.register(ttl);
}
if (TRACE_VTHREAD_LOCALS) {
dumpStackIfVirtualThread();
}
return value;
}
ThreadLocal 可以用于保存线程私有的数据,其源码具有下关键点: