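Let's start with a small example.
Command holds three variables: a ThreadLocal (name), an InheritableThreadLocal (key) and a TransmittableThreadLocal (codec).
TtlTest checks whether these three variables, once set by the main thread, can be inherited by other threads.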
import com.alibaba.ttl.TransmittableThreadLocal;

public class Command {
    static final ThreadLocal<String> name = new ThreadLocal<>();
    static final InheritableThreadLocal<String> key = new InheritableThreadLocal<>();
    static final TransmittableThreadLocal<String> codec = new TransmittableThreadLocal<>();

    public static String getName() {
        return name.get();
    }

    public static String getKey() {
        return key.get();
    }

    public static String getCodec() {
        return codec.get();
    }

    public static void setName(String name) {
        Command.name.set(name);
    }

    public static void setKey(String key) {
        Command.key.set(key);
    }

    public static void setCodec(String codec) {
        Command.codec.set(codec);
    }
}
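import com.alibaba.ttl.threadpool.TtlExecutors;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class TtlTest {
    // Its worker thread is created lazily, on the first execute() call made from main
    static ExecutorService executorService = Executors.newFixedThreadPool(1);
    // Its worker thread is created by another thread, before main sets any value
    static ExecutorService asyncExecutorService = Executors.newFixedThreadPool(1);
    // TTL wrapper around asyncExecutorService
    static ExecutorService ttlExecutorService;
    static CountDownLatch asyncCtl = new CountDownLatch(1);

    static {
        new Thread(() -> {
            asyncExecutorService.execute(() -> {
                System.out.println("async pool worker created");
            });
            ttlExecutorService = TtlExecutors.getTtlExecutorService(asyncExecutorService);
            ttlExecutorService.execute(() -> {
                System.out.println("-------------");
            });
            asyncCtl.countDown();
        }).start();
    }

    public static void main(String[] args) throws InterruptedException {
        // Wait until the async pool's worker thread has been created
        asyncCtl.await();
        Command.setName("main-name");
        Command.setKey("main-key");
        Command.setCodec("main-codec");
        System.out.println("main thread reads the values");
        System.out.println(Command.getName() + "," + Command.getKey() + "," + Command.getCodec());

        // Case 1: a brand-new Thread created by main
        CountDownLatch countDownLatch = new CountDownLatch(1);
        Thread thread1 = new Thread(() -> {
            System.out.println("new thread reads the values");
            System.out.println(Command.getName() + "," + Command.getKey() + "," + Command.getCodec());
            countDownLatch.countDown();
        });
        thread1.start();
        countDownLatch.await();

        // Case 2: a pool whose worker is created now, by main
        CountDownLatch countDownLatch2 = new CountDownLatch(1);
        executorService.execute(() -> {
            System.out.println("pool (worker created by main) reads the values");
            System.out.println(Command.getName() + "," + Command.getKey() + "," + Command.getCodec());
            countDownLatch2.countDown();
        });
        countDownLatch2.await();

        // Case 3: a pool whose worker was created earlier, by another thread
        CountDownLatch countDownLatch3 = new CountDownLatch(1);
        asyncExecutorService.execute(() -> {
            System.out.println("pool (worker created by another thread) reads the values");
            System.out.println(Command.getName() + "," + Command.getKey() + "," + Command.getCodec());
            countDownLatch3.countDown();
        });
        countDownLatch3.await();

        // Case 4: the same pool as case 3, but tasks go through the TTL wrapper
        CountDownLatch countDownLatch4 = new CountDownLatch(1);
        ttlExecutorService.execute(() -> {
            System.out.println("TTL-wrapped pool reads the values");
            System.out.println(Command.getName() + "," + Command.getKey() + "," + Command.getCodec());
            countDownLatch4.countDown();
        });
        countDownLatch4.await();

        // Shut the pools down so the JVM can exit (ttlExecutorService delegates to asyncExecutorService)
        executorService.shutdown();
        asyncExecutorService.shutdown();
    }
}
Output (the two warm-up lines printed from the static block are omitted):
main thread reads the values
main-name,main-key,main-codec
new thread reads the values
null,main-key,main-codec
pool (worker created by main) reads the values
null,main-key,main-codec
pool (worker created by another thread) reads the values
null,null,null
TTL-wrapped pool reads the values
null,null,main-codec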
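From the results we can see:
The ThreadLocal (name) is never visible in another thread.
The InheritableThreadLocal (key) and the TransmittableThreadLocal (codec) are both copied into threads that main creates after setting them: the new Thread and the pool whose worker main created lazily.
Neither is visible in the pool whose worker thread had already been created by another thread. Only the TtlExecutors-wrapped pool delivers codec to that pre-existing worker.
So when TransmittableThreadLocal is used with plain Threads, it behaves just like InheritableThreadLocal and cannot deliver its cross-thread transmission.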
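The first three cases need only the JDK, which shows that TTL plays no part in them. The sketch below is illustrative, and the class name InheritanceDemo is hypothetical. It forces one pool's worker thread to exist before any value is set, mirroring asyncExecutorService above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical JDK-only reproduction of cases 1-3 (no TTL involved).
class InheritanceDemo {
    static final ThreadLocal<String> name = new ThreadLocal<>();
    static final InheritableThreadLocal<String> key = new InheritableThreadLocal<>();

    static String read() {
        return name.get() + "," + key.get();
    }

    static List<String> run() {
        ExecutorService lazyPool = Executors.newFixedThreadPool(1);  // worker created on first submit
        ExecutorService earlyPool = Executors.newFixedThreadPool(1); // worker created below, before set()
        Callable<String> reader = InheritanceDemo::read;
        try {
            earlyPool.submit(() -> { }).get(); // forces earlyPool's worker thread to exist now
            name.set("main-name");
            key.set("main-key");

            List<String> results = new ArrayList<>();
            // 1) New Thread constructed after set(): copies inheritableThreadLocals only
            String[] fromThread = new String[1];
            Thread t = new Thread(() -> fromThread[0] = read());
            t.start();
            t.join();
            results.add(fromThread[0]);
            // 2) lazyPool's worker is constructed now, by this thread: same as case 1
            results.add(lazyPool.submit(reader).get());
            // 3) earlyPool's worker was constructed before set(): sees nothing
            results.add(earlyPool.submit(reader).get());
            return results;
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            lazyPool.shutdown();
            earlyPool.shutdown();
            name.remove();
            key.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(run()); // [null,main-key, null,main-key, null,null]
    }
}
```

What matters is when a worker thread is constructed, not when a task is submitted to it.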
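Every Thread has a threadLocals field (a ThreadLocalMap) that stores all of the ThreadLocal variables held by that thread. Each call to a ThreadLocal's get method simply looks the value up in the current thread's threadLocals map: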
public T get() {
    Thread t = Thread.currentThread();
    // get the current thread's threadLocals map
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        // `this` is the ThreadLocal variable itself, used as the map key
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // no entry yet: initialize via initialValue() and return that value
    return setInitialValue();
}
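The last line of get() is the fallback path. When the current thread's map has no entry yet, setInitialValue() calls initialValue(), stores the result and returns it. The sketch below shows this with the JDK's ThreadLocal.withInitial (Java 8+). The class InitialValueDemo and the counter are hypothetical helpers that count how often the initializer runs.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo: each thread has its own map, so each thread triggers the
// initializer once; remove() deletes the entry, so the next get() initializes again.
class InitialValueDemo {
    static List<String> run() {
        AtomicInteger calls = new AtomicInteger();
        ThreadLocal<String> tl = ThreadLocal.withInitial(() -> "init-" + calls.incrementAndGet());

        String first = tl.get();  // no entry in this thread's map -> initializer runs: init-1
        String second = tl.get(); // entry exists now -> same value, initializer not called
        String[] other = new String[1];
        Thread t = new Thread(() -> other[0] = tl.get()); // another thread, another map: init-2
        t.start();
        try {
            t.join();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        tl.remove();                   // delete this thread's entry...
        String afterRemove = tl.get(); // ...so get() initializes again: init-3
        tl.remove();
        return Arrays.asList(first, second, other[0], afterRemove);
    }

    public static void main(String[] args) {
        System.out.println(run()); // [init-1, init-1, init-2, init-3]
    }
}
```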
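Similarly, a thread has an inheritableThreadLocals field that stores its InheritableThreadLocal variables. Unlike threadLocals, inheritableThreadLocals is inherited by child threads.
Creating a thread eventually calls its init method, and init copies the inheritableThreadLocals of the parent thread (that is, the thread doing the creating):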
private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize, AccessControlContext acc,
                  boolean inheritThreadLocals) {
    ......
    Thread parent = currentThread();
    ......
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    ......
}
createInheritedMap just calls the following ThreadLocalMap constructor, which copies every entry of the parent's map:
private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];

    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            @SuppressWarnings("unchecked")
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            if (key != null) {
                // childValue lets an InheritableThreadLocal decide what the child inherits;
                // TTL's holder overrides it to put the inherited values into the child's holder
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}
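childValue() runs in the parent thread while the child Thread object is being constructed, so the child's value is fixed at that moment. The JDK's default childValue returns the parent's reference unchanged. The hypothetical ChildValueDemo below overrides it to hand the child an independent, tagged copy. It then shows that a change the parent makes after new Thread(...) does not reach the child.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical demo of InheritableThreadLocal.childValue and snapshot-at-construction.
class ChildValueDemo {
    static List<String> run() {
        InheritableThreadLocal<List<String>> trace = new InheritableThreadLocal<List<String>>() {
            @Override
            protected List<String> childValue(List<String> parentValue) {
                List<String> copy = new ArrayList<>(parentValue); // independent copy for the child
                copy.add("child");
                return copy;
            }
        };
        trace.set(new ArrayList<>(Arrays.asList("parent")));

        List<String> seenByChild = new ArrayList<>();
        Thread child = new Thread(() -> seenByChild.addAll(trace.get())); // childValue() is called here
        trace.get().add("after-new-Thread"); // parent changes its own list afterwards
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        List<String> result = Arrays.asList(String.join("|", seenByChild), String.join("|", trace.get()));
        trace.remove();
        return result;
    }

    public static void main(String[] args) {
        System.out.println(run()); // [parent|child, parent|after-new-Thread]
    }
}
```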
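TransmittableThreadLocal extends InheritableThreadLocal, so it can be inherited out of the box. During the step shown above,
Object value = key.childValue(e.value);
the child thread also stores the parent thread's TTL values into TransmittableThreadLocal's holder variable.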
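Internally, TransmittableThreadLocal has a holder that records the TransmittableThreadLocal variables a thread holds. holder is itself an InheritableThreadLocal so that it can be passed on to child threads. The Map inside it deduplicates TTL variables, so each one is recorded only once.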
private static InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>> holder =
        new InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>>() {
            @Override
            protected Map<TransmittableThreadLocal<?>, ?> initialValue() {
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>();
            }

            @Override
            protected Map<TransmittableThreadLocal<?>, ?> childValue(Map<TransmittableThreadLocal<?>, ?> parentValue) {
                // when a child thread inherits, its holder becomes a new map filled with parentValue
                return new WeakHashMap<TransmittableThreadLocal<?>, Object>(parentValue);
            }
        };
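The holder idea can be modelled in a few lines. The sketch below is a simplification, and MiniTtl is a hypothetical class, not TTL's real one. Each instance registers itself in a per-thread WeakHashMap on set(), using the keys as a set. Because the registry is itself inheritable, a child starts with its own copy. That copy is what a capture() step would later iterate over.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.WeakHashMap;

// Hypothetical MiniTtl: a stripped-down model of TTL's holder registry.
class MiniTtl<T> extends InheritableThreadLocal<T> {
    static final InheritableThreadLocal<Map<MiniTtl<?>, Object>> holder =
            new InheritableThreadLocal<Map<MiniTtl<?>, Object>>() {
                @Override
                protected Map<MiniTtl<?>, Object> initialValue() {
                    return new WeakHashMap<>();
                }

                @Override
                protected Map<MiniTtl<?>, Object> childValue(Map<MiniTtl<?>, Object> parentValue) {
                    return new WeakHashMap<>(parentValue); // child gets its own copy of the registry
                }
            };

    @Override
    public void set(T value) {
        super.set(value);
        holder.get().put(this, null); // keys act as a set: setting twice registers once
    }

    @Override
    public void remove() {
        holder.get().remove(this);
        super.remove();
    }

    static int registered() {
        return holder.get().size();
    }

    static int[] demo() {
        MiniTtl<String> a = new MiniTtl<>();
        MiniTtl<String> b = new MiniTtl<>();
        a.set("a1");
        a.set("a2");
        b.set("b1");
        int[] counts = new int[4];
        counts[0] = registered();     // 2 in the parent
        Thread child = new Thread(() -> {
            counts[1] = registered(); // 2: inherited copy of the registry
            MiniTtl<String> c = new MiniTtl<>();
            c.set("c1");
            counts[2] = registered(); // 3, in the child only
            c.remove();
        });
        child.start();
        try {
            child.join();
        } catch (InterruptedException e) {
            throw new IllegalStateException(e);
        }
        counts[3] = registered();     // still 2: the parent's registry is separate
        a.remove();
        b.remove();
        return counts;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(demo())); // [2, 2, 3, 2]
    }
}
```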
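So when used with plain threads, TransmittableThreadLocal and InheritableThreadLocal are both inherited only at the moment a thread is created. Once a thread is created in one context and used from another (such as a pool worker created by a different thread), values set later cannot be inherited.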
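TtlRunnable wraps a Runnable. At wrapping time it calls capture() to snapshot the current thread's TransmittableThreadLocal values. Before run, it replays (installs) the values captured in the constructor. When run finishes, it restores the TransmittableThreadLocal values the worker thread had before run. This is how TransmittableThreadLocal values are transmitted across threads. A new TtlRunnable is created for every submitted task, so each one captures the holder of the submitting thread at submission time, which is not necessarily the worker's parent thread. When the task runs, replay backs up the worker's own values and then applies the captured ones.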
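The capture/replay/restore cycle can be sketched with the JDK alone. In the sketch below, CapturingRunnable is hypothetical and tracks just one plain ThreadLocal, whereas real TTL discovers all TTL variables through holder and also runs callbacks. The pool's worker is created before the value is set, as with asyncExecutorService. Only the wrapped task sees the submitter's value, and the worker's previous state comes back afterwards.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical sketch of the TtlRunnable idea, reduced to a single ThreadLocal.
final class CapturingRunnable implements Runnable {
    static final ThreadLocal<String> CODEC = new ThreadLocal<>();

    private final String captured; // snapshot of the submitting thread's value
    private final Runnable task;

    CapturingRunnable(Runnable task) {
        this.captured = CODEC.get(); // "capture": runs in the submitting thread
        this.task = task;
    }

    @Override
    public void run() {
        String backup = CODEC.get(); // "replay": back up the worker's own value...
        CODEC.set(captured);         // ...then install the captured one
        try {
            task.run();
        } finally {
            // "restore": give the worker back its previous value
            if (backup == null) {
                CODEC.remove();
            } else {
                CODEC.set(backup);
            }
        }
    }

    static List<String> demo() {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        Callable<String> read = CODEC::get;
        try {
            pool.submit(() -> { }).get(); // worker thread exists before any value is set
            CODEC.set("main-codec");
            List<String> out = new ArrayList<>();
            out.add(String.valueOf(pool.submit(read).get())); // plain task: null
            String[] seen = new String[1];
            pool.submit(new CapturingRunnable(() -> seen[0] = CODEC.get())).get();
            out.add(seen[0]);                                  // wrapped task: main-codec
            out.add(String.valueOf(pool.submit(read).get())); // after restore: null again
            return out;
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdown();
            CODEC.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(demo()); // [null, main-codec, null]
    }
}
```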
private TtlRunnable(@Nonnull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
    this.capturedRef = new AtomicReference<Object>(capture());
    this.runnable = runnable;
    this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
}

@Override
public void run() {
    Object captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }

    Object backup = replay(captured);
    try {
        runnable.run();
    } finally {
        restore(backup);
    }
}
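The guard at the top of run() deserves a closer look. Because && binds tighter than ||, run() throws when the snapshot has already been released. It also throws when it must release the snapshot and loses the compareAndSet. The hypothetical ReleasingRunnable below isolates just that guard and omits replay/restore. With releaseAfterRun set to true, a second run() is rejected.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical sketch of the release-after-run guard only (no replay/restore).
final class ReleasingRunnable implements Runnable {
    private final AtomicReference<Object> capturedRef;
    private final Runnable task;
    private final boolean releaseAfterRun;

    ReleasingRunnable(Object captured, Runnable task, boolean releaseAfterRun) {
        this.capturedRef = new AtomicReference<>(captured);
        this.task = task;
        this.releaseAfterRun = releaseAfterRun;
    }

    @Override
    public void run() {
        Object captured = capturedRef.get();
        // fail if already released, or if we must release and another run won the CAS
        if (captured == null || releaseAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("captured value reference is released after run!");
        }
        task.run(); // in TtlRunnable, replay/restore of `captured` would wrap this call
    }

    static String demo(boolean releaseAfterRun) {
        AtomicInteger runs = new AtomicInteger();
        ReleasingRunnable r = new ReleasingRunnable("snapshot", runs::incrementAndGet, releaseAfterRun);
        r.run();
        try {
            r.run();
            return "ran " + runs.get() + " times";
        } catch (IllegalStateException e) {
            return "second run rejected after " + runs.get() + " run";
        }
    }

    public static void main(String[] args) {
        System.out.println(demo(true));  // second run rejected after 1 run
        System.out.println(demo(false)); // ran 2 times
    }
}
```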
/**
 * Capture all {@link TransmittableThreadLocal} values in current thread.
 *
 * @return the captured {@link TransmittableThreadLocal} values
 * @since 2.3.0
 */
@Nonnull
public static Object capture() {
    Map<TransmittableThreadLocal<?>, Object> captured = new HashMap<TransmittableThreadLocal<?>, Object>();
    for (TransmittableThreadLocal<?> threadLocal : holder.get().keySet()) {
        captured.put(threadLocal, threadLocal.copyValue());
    }
    return captured;
}
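capture() builds a brand-new HashMap from the holder keys and stores whatever copyValue() returns. The snapshot is therefore unaffected by later set() calls in the submitting thread. As far as we know, TTL's default copy is a plain reference copy (overridable via copy()), so a mutable value would still be shared. The sketch below models this with plain maps. CaptureSnapshotDemo is hypothetical, and String keys stand in for TTL instances.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical model of capture(): a new map holding the current value references.
class CaptureSnapshotDemo {
    static Map<String, Object> capture(Map<String, Object> current) {
        return new HashMap<>(current); // new map, same value references (shallow)
    }

    static List<String> run() {
        Map<String, Object> current = new HashMap<>();
        List<String> tags = new ArrayList<>(Arrays.asList("a"));
        current.put("codec", "v1");
        current.put("tags", tags);

        Map<String, Object> snapshot = capture(current);
        current.put("codec", "v2"); // rebinding after capture: snapshot keeps v1
        tags.add("b");              // mutating a shared object: snapshot sees it

        return Arrays.asList(String.valueOf(snapshot.get("codec")), String.valueOf(snapshot.get("tags")));
    }

    public static void main(String[] args) {
        System.out.println(run()); // [v1, [a, b]]
    }
}
```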
/**
 * Replay the captured {@link TransmittableThreadLocal} values from {@link #capture()},
 * and return the backup {@link TransmittableThreadLocal} values in current thread before replay.
 *
 * @param captured captured {@link TransmittableThreadLocal} values from other thread from {@link #capture()}
 * @return the backup {@link TransmittableThreadLocal} values before replay
 * @see #capture()
 * @since 2.3.0
 */
@Nonnull
public static Object replay(@Nonnull Object captured) {
    @SuppressWarnings("unchecked")
    Map<TransmittableThreadLocal<?>, Object> capturedMap = (Map<TransmittableThreadLocal<?>, Object>) captured;
    Map<TransmittableThreadLocal<?>, Object> backup = new HashMap<TransmittableThreadLocal<?>, Object>();

    for (Iterator<? extends Map.Entry<TransmittableThreadLocal<?>, ?>> iterator = holder.get().entrySet().iterator();
         iterator.hasNext(); ) {
        Map.Entry<TransmittableThreadLocal<?>, ?> next = iterator.next();
        TransmittableThreadLocal<?> threadLocal = next.getKey();

        // backup
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!capturedMap.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // set values to captured TTL
    setTtlValuesTo(capturedMap);

    // call beforeExecute callback
    doExecuteCallback(true);

    return backup;
}
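Ignoring the callbacks, replay is three map operations: back up everything, drop what is not in captured, then set the captured values. The sketch below models them on a plain map standing for the worker thread's TTL values. ReplayModel is hypothetical, and String keys stand in for TTL instances.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of replay(); `worker` stands for the executing thread's TTL values.
class ReplayModel {
    static Map<String, Object> replay(Map<String, Object> worker, Map<String, Object> captured) {
        Map<String, Object> backup = new HashMap<>(worker); // back up every current value
        worker.keySet().retainAll(captured.keySet());      // drop values not in captured
        worker.putAll(captured);                            // like setTtlValuesTo(capturedMap)
        return backup;
    }

    public static void main(String[] args) {
        Map<String, Object> worker = new HashMap<>();
        worker.put("codec", "worker-codec");
        worker.put("stale", "left-by-previous-task");
        Map<String, Object> captured = new HashMap<>();
        captured.put("codec", "main-codec");
        captured.put("traceId", "t-1");

        Map<String, Object> backup = replay(worker, captured);
        System.out.println(worker.equals(captured)); // true
        System.out.println(backup.get("stale"));     // left-by-previous-task
    }
}
```

Passing an empty captured map empties the worker, which is what clear() does with Collections.emptyMap().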
/**
 * Clear all {@link TransmittableThreadLocal} values in current thread,
 * and return the backup {@link TransmittableThreadLocal} values in current thread before clear.
 *
 * @return the backup {@link TransmittableThreadLocal} values before clear
 * @since 2.9.0
 */
@Nonnull
public static Object clear() {
    return replay(Collections.emptyMap());
}
/**
 * Restore the backup {@link TransmittableThreadLocal} values from {@link #replay(Object)}/{@link #clear()}.
 *
 * @param backup the backup {@link TransmittableThreadLocal} values from {@link #replay(Object)}/{@link #clear()}
 * @see #replay(Object)
 * @see #clear()
 * @since 2.3.0
 */
public static void restore(@Nonnull Object backup) {
    @SuppressWarnings("unchecked")
    Map<TransmittableThreadLocal<?>, Object> backupMap = (Map<TransmittableThreadLocal<?>, Object>) backup;
    // call afterExecute callback
    doExecuteCallback(false);

    for (Iterator<? extends Map.Entry<TransmittableThreadLocal<?>, ?>> iterator = holder.get().entrySet().iterator();
         iterator.hasNext(); ) {
        Map.Entry<TransmittableThreadLocal<?>, ?> next = iterator.next();
        TransmittableThreadLocal<?> threadLocal = next.getKey();

        // clear the TTL values that is not in backup
        // avoid the extra TTL values after restore
        if (!backupMap.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // restore TTL values
    setTtlValuesTo(backupMap);
}
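restore mirrors replay. It drops every value that was not in the backup, including values the task itself set, and then puts the backup values back. The hypothetical RestoreModel below uses the same plain-map stand-ins and runs the full round trip. The worker ends up exactly as it was before the task.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of the replay()/restore() round trip (callbacks omitted).
class RestoreModel {
    static Map<String, Object> replay(Map<String, Object> worker, Map<String, Object> captured) {
        Map<String, Object> backup = new HashMap<>(worker);
        worker.keySet().retainAll(captured.keySet());
        worker.putAll(captured);
        return backup;
    }

    static void restore(Map<String, Object> worker, Map<String, Object> backup) {
        worker.keySet().retainAll(backup.keySet()); // drop values that were not there before replay
        worker.putAll(backup);                      // like setTtlValuesTo(backupMap)
    }

    static Map<String, Object> roundTrip() {
        Map<String, Object> worker = new HashMap<>();
        worker.put("codec", "worker-codec");
        Map<String, Object> captured = new HashMap<>();
        captured.put("codec", "main-codec");
        captured.put("traceId", "t-1");

        Map<String, Object> backup = replay(worker, captured);
        worker.put("setByTask", "x"); // the task itself sets another TTL value
        restore(worker, backup);
        return worker;
    }

    public static void main(String[] args) {
        System.out.println(roundTrip()); // {codec=worker-codec}
    }
}
```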
private static void setTtlValuesTo(@Nonnull Map<TransmittableThreadLocal<?>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<?>, Object> entry : ttlValues.entrySet()) {
        @SuppressWarnings("unchecked")
        TransmittableThreadLocal<Object> threadLocal = (TransmittableThreadLocal<Object>) entry.getKey();
        threadLocal.set(entry.getValue());
    }
}
From the explanation above, we can conclude: