Java 中的 InheritableThreadLocal

Java juc 面试 大约 4300 字

使用场景

父子线程中传递数据的方式。

示例代码

main线程中设置了abc,之后再new一个child线程,在child线程中使用inheritableThreadLocal获取main线程设置的值。

public static void main(String[] args) {
    InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();
    inheritableThreadLocal.set("abc");
    new Thread(() -> {
        System.out.println(inheritableThreadLocal.get());
    }, "child").start();
}

注意

如果在new线程之后再去赋值,则子线程初始化时不会拷贝父线程中的inheritableThreadLocal变量,导致子线程调用get()方法获取的都是null(孙线程同理)。

public static void main(String[] args) throws InterruptedException {
    InheritableThreadLocal<String> inheritableThreadLocal = new InheritableThreadLocal<>();

    new Thread(() -> {
        System.out.println(LocalDateTime.now() + " 子线程初始化完成");

        new Thread(() -> {
            System.out.println(LocalDateTime.now() + " 孙线程初始化完成");
            try {
                TimeUnit.MILLISECONDS.sleep(1500);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            System.out.println(LocalDateTime.now() + " 孙线程sleep完成");
            String s2 = inheritableThreadLocal.get();
            System.out.println(LocalDateTime.now() + " 孙线程获取的数据#" + s2);
        }).start();

        try {
            TimeUnit.SECONDS.sleep(3);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        System.out.println(LocalDateTime.now() + " 子线程sleep完成");
        String s = inheritableThreadLocal.get();
        System.out.println(LocalDateTime.now() + " 子线程获取的数据#" + s);


    }).start();

    TimeUnit.MILLISECONDS.sleep(1000);
    System.out.println(LocalDateTime.now() + " 主线程设置value");

    inheritableThreadLocal.set("abc");

}

源码分析

在创建线程时,将从当前线程(new Thread所在的线程)的inheritableThreadLocals中复制一份ThreadLocalMap

更多ThreadLocal相关源码分析可查看之前文章:https://www.zhangbj.com/p/789.html

public class Thread implements Runnable {

    public Thread(Runnable target) {
        this(null, target, "Thread-" + nextThreadNum(), 0);
    }

    public Thread(ThreadGroup group, Runnable target, String name,
                  long stackSize) {
        this(group, target, name, stackSize, null, true);
    }

    private Thread(ThreadGroup g, Runnable target, String name,
                   long stackSize, AccessControlContext acc,
                   boolean inheritThreadLocals) {
        if (name == null) {
            throw new NullPointerException("name cannot be null");
        }

        this.name = name;

        Thread parent = currentThread();

        ...

        // 入参 inheritThreadLocals 为 true,且当前初始化该 Thread 的线程中的 inheritableThreadLocals 成员变量不为 null
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            // 子线程的 inheritableThreadLocals 就根据父线程的 inheritableThreadLocals 变量,创建一个 ThreadLocalMap
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        ...
    }
}    

public class ThreadLocal<T> {
    static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            map.set(this, value);
        } else {
            createMap(t, value);
        }
    }

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }
}

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    // 复写了 ThreadLocal 中的 getMap() 方法,返回线程中的 inheritableThreadLocals 变量
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    // 复写了 ThreadLocal 中的 createMap() 方法,为线程的 inheritableThreadLocals 创建一个 ThreadLocalMap
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}
阅读 477 · 发布于 2021-04-15

————        END        ————

扫描下方二维码关注公众号和小程序↓↓↓

扫描二维码关注我
昵称:
随便看看 换一批