为了介绍线程间传递 ThreadLocal 对象这个事情,请先耐心一些跟我一起来看看我是怎么遇到线程间传递 ThreadLocal 对象这个需求的。

一起看这么个场景,大致上是下面这样,是 clojure 的代码。但是请不要担心,它非常短也非常简单:

;; 定义一个 clojure 动态绑定的变量,实际就是个 Java ThreadLocal 对象
(defonce ^:dynamic *utc* false)
;; 处理某 Http 请求的函数, 参数 req 内是用户传来的 Key -> Value 格式的参数
(defn some-http-handler-function [req]
  ;; 首先需要从 req 内读取 API 版本信息 所有请求都会有这个参数
  ;; 如果 API 版本是 1.1 则认为使用的是 utc 时间,所以给 *utc* 变量绑定一个 true 值
  (binding [*utc* (= (:api-version req) "1.1")]
    ;; 绑定好 *utc* 后,开始处理 req 请求,process 返回值作为 Http 请求的结果
    (process req))

some-http-handler-function 是一个处理 Http 请求的函数,这个 Http 请求的 content-type 是 JSON。Server 收到请求后会将 body 内的 JSON 参数转换为 Key -> Value 的 map 当做 req 参数传入 some-http-handler-function 函数做处理。

我们的 Http 接口要求请求中必须带着一些时间参数,并且时间参数须是符合 iso8601 格式的字符串,形如 2017-09-23T12:15:42.972Z。我们的 API 分为很多版本,1.1 版本之前的 API 没有对用户提供的这个时间参数的时区做约定,所以我们默认以当前服务器所在时区来解析用户传来的时间参数。1.1 版本之后的 API,我们跟用户约定时间参数必须是 UTC 时间,服务器也就直接以 UTC 时间来解析时间参数,从而不再有时区差异问题。

因为相同的 some-http-handler-function 函数,要兼容处理 1.1 版本 API 请求和老版本 API 请求,所以要在请求中带着 API 版本信息,并在收到请求后,根据 req 中的 API 版本信息,判断是否使用 UTC 时间。因为处理请求的函数有多层嵌套,比如 (process req) 可能调用 (do-process req)(do-after-do-process req)等等,又不想让所有嵌套的函数调用都带着是否使用的 UTC 时间的参数,就简单的将是否使用 UTC 时间这个事情记录在一个动态绑定的 *utc* 参数上。

如果不了解 clojure,可以将 *utc* 简单理解为一个全局的 Java ThreadLocal 对象,同一个线程存入一个值后,比如将 *utc* 这个字符串绑定为 true 后,在该线程后续调用的所有函数、方法内,都能直接拿到 *utc* 参数的值,从而不用在所有该线程调用的函数、方法内都带着 *utc* 参数。让代码简单一些。

最初 (process req) 是个同步的调用,绑定完 *utc* 参数的主线程会完成 process 函数内所有逻辑,完成 Http 请求的处理并负责将结果发给用户。但后来为了隔离,将 process 函数内增加了线程池,会从线程池找一个空闲线程来实际处理 Http 请求,当前主线程会 Block 住等待 process 执行结果。由于 Java ThreadLocal 对象是不能在线程之间传递的,所以主线程虽然绑定了 *utc* 参数,但是 process 内的业务线程并不知道这个事情,于是出现即使在处理 1.1 版本 API 的请求时,所有时间参数也均采用服务器本地时区来解析,而不是 UTC 来解析,造成了 Bug。

这就是我遇到的线程之间传递 ThreadLocal 对象需求的来源。我们的解决办法很简单粗暴,就是将 *utc* 这个参数直接放在 process 函数的参数之中,等到 process 内的线程实际运行时,重新为线程池内实际执行 process 工作的业务线程绑定 *utc* 参数。

当时就在想如果有机制能在一些特定情况之下,让 ThreadLocal 对象绑定的 Value 在不同线程之间能共享,对上面这种场景处理就会比较方便。主线程将处理请求的任务交给业务线程之后,即使两个线程共享 *utc* 参数,但因为都不修改这个参数值,所以并不会引起问题。

下面我们看看这种 ThreadLocal 对象绑定值如何在不同线程之间传递。

父子进程之间传递 ThreadLocal 对象

这个实际上 JDK 自己有实现,是 InheritableThreadLocal 类。慢慢介绍一下它的原理。

Thread 类内实际有两个 ThreadLocal.ThreadLocalMap,一个是给 ThreadLocal 使用的,还一个是给 InheritableThreadLocal 使用的。给 InheritableThreadLocal 使用的 ThreadLocalMap 特殊一些,在一个线程 fork 一个新的子线程的时候,父线程会检查自己 Thread 内为 InheritableThreadLocal 使用的 ThreadLocalMap 是否为空,不为空则将其拷贝给子线程的为 InheritableThreadLocal 使用的 ThreadLocalMap。Thread 的 init 可以看到这块逻辑:

if (parent.inheritableThreadLocals != null)
    this.inheritableThreadLocals =
        ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

ThreadLocal.createInheritedMap 会调用 ThreadLocalMap 的拷贝构造函数实现大致为:

private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
  int len = parentTable.length;
  setThreshold(len);
  table = new Entry[len];
  for each Entry e in parentTable {
        if e != null { 
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
          if (key != null) {
            Object value = key.childValue(e.value);
                setKeyValueToThisThreadLocalMap(key, value);
          }
       }
   }
}

这么一来创建子线程之后,父线程的给 InheritableThreadLocal 使用的 ThreadLocalMap 就在子线程中有了一个副本。默认情况下父子线程的 ThreadLocalMap 内的 key 都指向同一个 InheritableThreadLocal 对象,Value 也指向同一个 Value。从上面能看到子线程存储 ThreadLocalMap 的 Value 实际上是 key.childValue(e.value) 就是说能在使用 InheritableThreadLocal 的时候覆盖 childValue 方法从而根据父线程的 Value 提供子线程的 Value。

但是对于开篇那个问题的场景,InheritableThreadLocal 无法使用。在那个场景中,相当于是一个 Thread A 发配一个 Runnable 或者 Callable 到一个线程池中,让线程池内的线程去执行 Runnable/Callable。Thread A 和线程池内的线程并没有父子关系,所以 Thread A 不能将绑定的 InheritableThreadLocal 值传递给线程池内负责实际执行 Runnable/Callable的线程。

用 Runnable/Callable 传递 ThreadLocal

Thread A 传递给线程池内线程的只有 Runnable/Callable,所以如果想要实现将 ThreadA 的 ThreadLocal 值传递给线程池内实际负责执行 Runnable/Callable 的 Thread,就一定是需要在 Runnalbe/Callable 做文章,传递 Runnable/Callable 的时候将 Thread A 的 ThreadLocal 值存到 Runnable/Callable 中,之后线程池线程,比如是 Thread 1,在运行 Runnable/Callable 时从中读出这些 ThreadLocal 值并存入Thread 1 的 ThreadLocalMap 中,实现 ThreadLocal 对象值的传递。这也是 GitHub - alibaba/transmittable-thread-local 库最基础的实现原理。

感觉很多东西都是这样,仔细想想原理感觉不难,但难的是能想到这个点子。之前遇到开篇问题的时候只是想绕过的笨办法,没想想怎么从 ThreadLocal 原理上来解决这个问题。下面记录一下这个库的部分实现,但是这里不准备直接记录 transmittable thread local 是怎么实现,而是看看能不能在现有了解东西的基础上,一步一步推测出来它怎么实现。

transmittable-thread-local

通过上面的描述,应该已对 transmittable thread local 的实现原理有个大致的猜想。首先是需要有个专门的 Runnable 或 Callable,用于读取原 Thread 的 ThreadLocal 对象及其值并存在 Runnable/Callable中,在执行 run 或者 call 的时候将存在 Runnable/Callable 中的 ThreadLocal 对象和值读出来,存入调用 run 或 call 的线程。

读取原 Thread 上所有的(或者说会发生这种线程间传递的) ThreadLocal 对象及其值比较麻烦,ThreadLocal 和 InheritableThreadLocal 都没有开放内部的 ThreadLocalMap,不能直接读取。所以要么自己完全实现一套 ThreadLocalMap 机制,像 Netty 的 FastThreadLocal 那样;要么就是自己实现 ThreadLocal 的子类,在每次调用 ThreadLocal 的 set/get/remove 等接口的时候,为 Thread 记录到底绑定了哪些需要发生线程间传递的 ThreadLocal 对象。后者更简单和更可靠一些,所以可能选择后者更稳妥,transmittable-thread-local 这个库也是这么做的。

通过 Runnable/Callable 传递 ThreadLocal 对象及其值的方法是有了,父子线程之间传递可以复用 InheritableThreadLocal 的实现,所以新的 TransmittableThreadLocal 对象需要继承 InheritableThreadLocal 从而获取它的父子线程间传递 ThreadLocal 对象及其值的能力。到目前为止,大致的类图如下:

transmittable-thread-local 实现

使用的时候必须使用使用 TransmittableThreadLocal,创建 Runnable 或 Callable 也必须使用 TtlCallable 或者 TtlRunnable。TransmittableThreadLocal 覆盖实现了 ThreadLocal 的 set、get、remove,实际存储 ThreadLocal 值的工作还是 ThreadLocal 父类完成,TransmittableThreadLocal 只是为每个使用它的 Thread 单独记录一份存储了哪些 TransmittableThreadLocal 对象。拿 set 来说就是这个样子:

public final void set(T value) {
  super.set(value);
  if (null == value) removeValue();
  else addValue();
}

addValue() 和 removeValue 都是将当前 TransmittableThreadLocal 对象存入 TransmittableThreadLocal 内一个 static 的并且是 InheritableThreadLocal 的 WeakHashMap 中。

private static InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>> holder =
    new InheritableThreadLocal<Map<TransmittableThreadLocal<?>, ?>>() {
      @Override
      protected Map<TransmittableThreadLocal<?>, ?> initialValue() {
        return new WeakHashMap<TransmittableThreadLocal<?>, Object>();
      }
      @Override
      protected Map<TransmittableThreadLocal<?>, ?> childValue(Map<TransmittableThreadLocal<?>, ?> parentValue) {
        return new WeakHashMap<TransmittableThreadLocal<?>, Object>(parentValue);
      }
  };
private void addValue() {
    if (!holder.get().containsKey(this)) {
        holder.get().put(this, null); // WeakHashMap supports null value.
    }
}

相当于是和 Thread 绑定的所有 TransmittableThreadLocal 对象都保存在这个 holder 中,TransmittableThreadLocal 对应的 Value 则还是保存在 ThreadLocal 的 ThreadLocalMap 中。holder 只是为了记录当前 Thread 绑定了哪些 TransmittableThreadLocal 对象,好在 TtlRunnable 或 TtlCallable 构造的时候通过 holder 取出这些 TransmittableThreadLocal 存入 TtlRunnable 或 TtlCallable。

剩下的更多细节感觉就不用再多记录了,有兴趣的话可以下载这个库的代码看看。上面原理虽然看上去简单,但是这个库还在易用性上做了很多增强,还是值得看看值得使用一下的。