ThansmittableThreadLocal

8/3/2022 多线程

# 前述

SpringSecurity微服务鉴权方式 (opens new window)一文中,我们提到过通过使用 ThansmittableThreadLocal 来实现从当前线程中获取登录用户的有关信息,今天我们来详细说一说。

GitHub地址 (opens new window)

# InheritableThreadLocal的问题

JDKInheritableThreadLocal类可以完成父线程到子线程的值传递。但对于使用线程池等会池化复用线程的执行组件的情况,线程由线程池创建好,并且线程是池化起来反复使用的;这时父子线程关系的ThreadLocal值传递已经没有意义,应用需要的实际上是把 任务提交给线程池时ThreadLocal值传递到 任务执行时

@SpringBootTest
class ProcurementApplicationTests {

    @Test
    void contextLoads() throws InterruptedException {
        //单一线程池
        ExecutorService executorService = Executors.newSingleThreadExecutor();
        //InheritableThreadLocal存储
        InheritableThreadLocal<String> username = new InheritableThreadLocal<>();
        for (int i = 0; i < 10; i++) {
            username.set("username—" + i);
            Thread.sleep(3000);
            CompletableFuture.runAsync(() -> System.out.println(username.get()), executorService);
        }
    }
}
// 结果输出为
// username—0
// username—0
// username—0
// username—0
// username—0
// username—0
// username—0
// username—0
// username—0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

# TransmittableThreadLocal功能

  • TransmittableThreadLocal(TTL):在使用线程池等会池化复用线程的执行组件情况下,提供ThreadLocal值的传递功能,解决异步执行时上下文传递的问题。

  • TransmittableThreadLocal类继承并加强InheritableThreadLocal类,解决上述的问题

  • 整个TransmittableThreadLocal库的核心功能(用户API与框架/中间件的集成API、线程池ExecutorService/ForkJoinPool/TimerTask及其线程工厂的Wrapper),只有 *~1000 SLOC代码行*,非常精小。

@SpringBootTest
class ProcurementApplicationTests {

    @Test
    void contextLoads() throws InterruptedException {
        //单一线程池
        ExecutorService executorService = Executors.newSingleThreadExecutor();
        //需要使用TtlExecutors对线程池包装一下
        executorService= TtlExecutors.getTtlExecutorService(executorService);
        //TransmittableThreadLocal创建
        TransmittableThreadLocal<String> username = new TransmittableThreadLocal<>();
        for (int i = 0; i < 10; i++) {
            username.set("username —" + i);
            Thread.sleep(3000);
            CompletableFuture.runAsync(() -> System.out.println(username.get()), executorService);
        }
    }
}

// 结果输出
// username —0
// username —1
// username —2
// username —3
// username —4
// username —5
// username —6
// username —7
// username —8
// username —9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30

# 原理

TransmittableThreadLocal继承InheritableThreadLocal,使用方式也类似。

public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {}
1

# 相比InheritableThreadLocal

主要添加了以下两个方法:

1、copy方法

用于定制 任务提交给线程池时ThreadLocal值传递到 任务执行时 的拷贝行为,缺省是简单的赋值传递

2、protectedbeforeExecute/afterExecute方法 执行任务(Runnable/Callable)的前/后的生命周期回调,缺省是空操作。

# 部分源码

1、TransimittableThreadLocal中添加holder属性。这个属性的作用就是被标记为具备线程传递资格的对象都会被添加到这个对象中

public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {

// 1. holder本身是一个InheritableThreadLocal对象
// 2. 这个holder对象的value是WeakHashMap<TransmittableThreadLocal<Object>, ?>
// 2.1 WeekHashMap的value总是null,且不可能被使用。
// 2.2 WeekHasshMap支持value=null
private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
   new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
        @Override
           protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                 return new WeakHashMap<>();
             }

                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20

# 修饰线程池

省去每次RunnableCallable传入线程池时的修饰,这个逻辑可以在线程池中完成。

通过工具类com.alibaba.ttl.threadpool.TtlExecutors完成,有下面的方法:

  • getTtlExecutor:修饰接口Executor
  • getTtlExecutorService:修饰接口ExecutorService
  • getTtlScheduledExecutorService:修饰接口ScheduledExecutorService
public final class TtlExecutors {

    @Nullable
    @Contract(value = "null -> null; !null -> !null", pure = true)
    public static ExecutorService getTtlExecutorService(@Nullable ExecutorService executorService) {
        if (TtlAgent.isTtlAgentLoaded() || executorService == null || executorService instanceof TtlEnhanced) {
            return executorService;
        }
        return new ExecutorServiceTtlWrapper(executorService, true);
    }
}
1
2
3
4
5
6
7
8
9
10
11

# ExecutorServiceTtlWrapper&ExecutorTtlWrapper

无论是ExecutorTtlWrapper方法还是ExecutorServiceTtlWrapper方法,最终都会将线程对象包装成TtlCallable或者TtlRunnable,用于在真正执行run方法前做一些业务逻辑。

class ExecutorServiceTtlWrapper extends ExecutorTtlWrapper implements ExecutorService, TtlEnhanced {

    /**
     * 在ExecutorServiceTtlWrapper实现submit方法
     */
	@NonNull
    @Override
    public <T> Future<T> submit(@NonNull Callable<T> task) {
        return executorService.submit(TtlCallable.get(task, false, idempotent));
    }
}

class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {

    /**
     * 在ExecutorTtlWrapper实现execute方法
     */
	@Override
    public void execute(@NonNull Runnable command) {
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

# Transmitter类

public static class Transmitter {
      /**
    	* 捕获当前线程中的是所有TransimittableThreadLocal和注册ThreadLocal的值。
    	*/
        @NonNull
        public static Object capture() {
            return new Snapshot(captureTtlValues(), captureThreadLocalValues());
        }

     /**
    * 捕获TransimittableThreadLocal的值,将holder中的所有值都添加到HashMap后返回。
    */
        private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
            HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>();
            for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
                ttl2Value.put(threadLocal, threadLocal.copyValue());
            }
            return ttl2Value;
        }

    /**
    * 捕获注册的ThreadLocal的值,也就是原本线程中的ThreadLocal,可以注册到TTL中,在
    * 进行线程池本地变量传递时也会被传递。
    */
        private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
            final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<>();
            for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                final TtlCopier<Object> copier = entry.getValue();

                threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
            }
            return threadLocal2Value;
        }

        /**
    * 将捕获到的本地变量进行替换子线程的本地变量,并且返回子线程现有的本地变量副本backup。
    * 用于在执行run/call方法之后,将本地变量副本恢复。
    */
        @NonNull
        public static Object replay(@NonNull Object captured) {
            final Snapshot capturedSnapshot = (Snapshot) captured;
            return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
        }

    /**
    * 替换TransmittableThreadLocal
    */
        @NonNull
        private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
            // 创建副本backup
            HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<>();

            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();

                // 对当前线程的本地变量进行副本拷贝
                backup.put(threadLocal, threadLocal.get());

                // 若出现调用线程中不存在某个线程变量,而线程池中线程有,则删除线程池中对应的本地变量
                if (!captured.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

            // 将捕获的TTL值打入线程池获取到的线程TTL中。
            setTtlValuesTo(captured);

            // 调用TTL的beforeExecute方法。默认实现为空
            doExecuteCallback(true);

            return backup;
        }

        private static HashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> captured) {
            final HashMap<ThreadLocal<Object>, Object> backup = new HashMap<>();

            for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                backup.put(threadLocal, threadLocal.get());

                final Object value = entry.getValue();
                if (value == threadLocalClearMark) threadLocal.remove();
                else threadLocal.set(value);
            }

            return backup;
        }

        /**
    * 清除单线线程的所有TTL和TL,并返回清除之气的backup
    */
        @NonNull
        public static Object clear() {
            final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<>();

            final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<>();
            for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                threadLocal2Value.put(threadLocal, threadLocalClearMark);
            }

            return replay(new Snapshot(ttl2Value, threadLocal2Value));
        }

        /**
         * 还原
         */
        public static void restore(@NonNull Object backup) {
            final Snapshot backupSnapshot = (Snapshot) backup;
            restoreTtlValues(backupSnapshot.ttl2Value);
            restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
        }

        private static void restoreTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> backup) {
           // 扩展点,调用TTL的afterExecute
            doExecuteCallback(false);

            for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
                TransmittableThreadLocal<Object> threadLocal = iterator.next();

                // 清除备份中没有的TTL值
                // 恢复后避免额外的TTL值
                if (!backup.containsKey(threadLocal)) {
                    iterator.remove();
                    threadLocal.superRemove();
                }
            }

           // 将本地变量恢复成备份版本
            setTtlValuesTo(backup);
        }

        private static void setTtlValuesTo(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
            for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
                TransmittableThreadLocal<Object> threadLocal = entry.getKey();
                threadLocal.set(entry.getValue());
            }
        }

        private static void restoreThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> backup) {
            for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
                final ThreadLocal<Object> threadLocal = entry.getKey();
                threadLocal.set(entry.getValue());
            }
        }

    /**
   * 快照类,保存TTL和TL
   */
        private static class Snapshot {
            final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
            final HashMap<ThreadLocal<Object>, Object> threadLocal2Value;

            private Snapshot(HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, HashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
                this.ttl2Value = ttl2Value;
                this.threadLocal2Value = threadLocal2Value;
            }
        }
    }
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161

# TtlCallable方法

  • TtlRunnable方法类似
public final class TtlCallable<V> implements Callable<V>, TtlWrapper<Callable<V>>, TtlEnhanced, TtlAttachments 
{

	@Override
    @SuppressFBWarnings("THROWS_METHOD_THROWS_CLAUSE_BASIC_EXCEPTION")
    public V call() throws Exception {
        final Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterCall && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after call!");
        }
		// 调用replay方法将捕获到的当前线程的本地变量,传递给线程池线程的本地变量,
 		// 并且获取到线程池线程覆盖之前的本地变量副本。
        final Object backup = replay(captured);
        try {
            // 线程方法调用
            return callable.call();
        } finally {
            // 使用副本进行恢复。
            restore(backup);
        }
    }
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

# 时序图

ThansmittableThreadLocal