自学内容网 自学内容网

多租户 TransmittableThreadLocal 线程安全问题

在一个多租户项目中,用户登录时,会在自定义请求头拦截器AsyncHandlerInterceptor将该用户的userId,cstNo等用户信息设置到TransmittableThreadLocal中,在后续代码中使用.代码如下:

HeaderInterceptor 请求头拦截器


public class HeaderInterceptor implements AsyncHandlerInterceptor
{
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception
    {
        if (!(handler instanceof HandlerMethod))
        {
            return true;
        }

        SecurityContextHolder.setUserId(ServletUtils.getHeader(request, SecurityConstants.DETAILS_USER_ID));
        SecurityContextHolder.setUserName(ServletUtils.getHeader(request, SecurityConstants.DETAILS_USERNAME));
        SecurityContextHolder.setUserKey(ServletUtils.getHeader(request, SecurityConstants.USER_KEY));
        SecurityContextHolder.setCstNoKey(ServletUtils.getHeader(request, SecurityConstants.CST_NO));

        String token = SecurityUtils.getToken();
        if (StringUtils.isNotEmpty(token))
        {
            LoginUser loginUser = AuthUtil.getLoginUser(token);
            if (StringUtils.isNotNull(loginUser))
            {
                AuthUtil.verifyLoginUserExpire(loginUser);
                SecurityContextHolder.set(SecurityConstants.LOGIN_USER, loginUser);
            }
        }
        return true;
    }

    @Override
    public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex)
            throws Exception
    {
        SecurityContextHolder.remove();
    }
}

SecurityContextHolder


public class SecurityContextHolder
{
    private static final TransmittableThreadLocal<Map<String, Object>> THREAD_LOCAL = new TransmittableThreadLocal<>();

    public static void set(String key, Object value)
    {
        Map<String, Object> map = getLocalMap();
        map.put(key, value == null ? StringUtils.EMPTY : value);
    }

    public static String get(String key)
    {
        Map<String, Object> map = getLocalMap();
        return Convert.toStr(map.getOrDefault(key, StringUtils.EMPTY));
    }

    public static <T> T get(String key, Class<T> clazz)
    {
        Map<String, Object> map = getLocalMap();
        return StringUtils.cast(map.getOrDefault(key, null));
    }

    public static Map<String, Object> getLocalMap()
    {
        Map<String, Object> map = THREAD_LOCAL.get();
        if (map == null)
        {
            map = new ConcurrentHashMap<String, Object>();
            THREAD_LOCAL.set(map);
        }
        return map;
    }

    public static void setLocalMap(Map<String, Object> threadLocalMap)
    {
        THREAD_LOCAL.set(threadLocalMap);
    }

    public static void setCstNoKey(String cstNoKey)
    {
        set(SecurityConstants.CST_NO, cstNoKey);
    }
    public static String getCstNo(){
        return Convert.toStr(get(SecurityConstants.CST_NO), null);
    }

    public static Long getUserId()
    {
        return Convert.toLong(get(SecurityConstants.DETAILS_USER_ID), 0L);
    }

    public static void setUserId(String account)
    {
        set(SecurityConstants.DETAILS_USER_ID, account);
    }

    public static String getUserName()
    {
        return get(SecurityConstants.DETAILS_USERNAME);
    }

    public static void setUserName(String username)
    {
        set(SecurityConstants.DETAILS_USERNAME, username);
    }

    public static String getUserKey()
    {
        return get(SecurityConstants.USER_KEY);
    }

    public static void setUserKey(String userKey)
    {
        set(SecurityConstants.USER_KEY, userKey);
    }

    public static String getPermission()
    {
        return get(SecurityConstants.ROLE_PERMISSION);
    }

    public static void setPermission(String permissions)
    {
        set(SecurityConstants.ROLE_PERMISSION, permissions);
    }

    public static void remove()
    {
        THREAD_LOCAL.remove();
    }

}

在业务层通过调用多线程方法时,分别在主线程和子线程中输出客户编号,伪代码如下:

public void listMasterThread() {
 log.info(">>>>>>>>>>>主线程{}>>>>>", SecurityContextHolder.getCstNo());
CompletableFuture.supplyAsync(() -> this.savleThread(),threadPoolConfig.threadPool())

}

public void savleThread(){
 log.info(">>>>>>>>>>>子线程{}>>>>>", SecurityContextHolder.getCstNo());
}

问题复现:
客户编号为85的员工先登录(第一次登录)并调用该方法时,输出如下:

>>>>>>>>>>>主线程85 >>>>>
>>>>>>>>>>>子线程85 >>>>>

接着客户编号82的员工接着登录(第二次登录)并调用时,输出:

>>>>>>>>>>>主线程82 >>>>>
>>>>>>>>>>>子线程85 >>>>>

问题就来了,第二次登录时,主线程获取的客户编号正确,但是子线程获取的客户编号是上一次登录,并且多调用几次之后,子线程的输出也正常了.

下面是TransmittableThreadLocal.get()方法源码,最终调用的是ThreadLocal.get()
在这里插入图片描述
百度了一下TransmittableThreadLocal 用于解决 “在使用线程池会缓存线程的组件情况下传递ThreadLocal”问题的.

到这里就很疑惑,明明用的就是 TransmittableThreadLocal但是为何还会有此问题,经过折腾

解决方法添加TtlExecutors.getTtlExecutor()

// threadPoolConfig.threadPool() 参数为指定的线程池
   TtlExecutors.getTtlExecutor(threadPoolConfig.threadPool());

纠正后的伪代码如下,

public void listMasterThread() {
log.info(">>>>>>>>>>>主线程{}>>>>>", SecurityContextHolder.getCstNo());
TtlExecutors.getTtlExecutor(threadPoolConfig.threadPool());
CompletableFuture.supplyAsync(() -> this.savleThread(),threadPoolConfig.threadPool())

}

public void savleThread(){
 log.info(">>>>>>>>>>>子线程{}>>>>>", SecurityContextHolder.getCstNo());
}

具体解决原理不知!!!


原文地址:https://blog.csdn.net/qq_33776323/article/details/136412531

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!