ThreadLocal在常见框架中的使用

简介

ThreadLocal能够为当前线程提供存储和读取变量的能力,提供一个静态方法,从而能够让若干模块解耦;也为多线程并发提供一个思路,在ThreadLocal中为当前储存变量,只为当前线程所用,让多线程之间不互相干扰。

本文简单介绍ThreadLocal,列举一些常见框架中的使用场景,从而对它有更好的认识。

ThreadLocal API

ThreadLocal常用的有三个方法set, get, remove,下面用一小段代码来看看这三个方法的使用。

public class ThreadLocalTest {

    final static ThreadLocal threadLocal = new ThreadLocal<>();

    @Test
    public void testThreadLocal() {

        final String valueBeforeSet = threadLocal.get();
        log.info("value before set: {}", valueBeforeSet);

        threadLocal.set("test");
        final String valueAfterSet = threadLocal.get();
        log.info("value after set: {}", valueAfterSet);


        threadLocal.remove();
        final String valueAfterRemove = threadLocal.get();
        log.info("value after remove: {}", valueAfterRemove);

    }
}

输出:

2018-05-21 20:43:41,390{GMT} INFO  value before set: null
2018-05-21 20:43:41,396{GMT} INFO  value after set: test
2018-05-21 20:43:41,396{GMT} INFO  value after remove: null

上面的代码中可以清楚看到getset方法的使用,同时remove也是非常重要的,因为线程池的原因,如果不执行remove操作,这个线程在下次被重复使用的时候,存储在ThreadLocal中的值仍可使用。

在框架中的使用

Spring Security中的使用

如果你使用过Spring Security,那么对SecurityContextHolder肯定不陌生:

SecurityContext securityContext = SecurityContextHolder.getContext();
Authentication authentication = securityContext.getAuthentication();

上面这段代码经常用来获取当前认证过的用户相关信息,而这个方法之所以能够工作,其中之一就用的ThreadLocal
详细代码可以可见github,这里摘取一些通过ThreadLocal来存储认证信息片段。

public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;

        if (request.getAttribute(FILTER_APPLIED) != null) {
            // ensure that filter is only applied once per request
            chain.doFilter(request, response);
            return;
        }

        final boolean debug = logger.isDebugEnabled();

        request.setAttribute(FILTER_APPLIED, Boolean.TRUE);

        if (forceEagerSessionCreation) {
            HttpSession session = request.getSession();

            if (debug && session.isNew()) {
                logger.debug("Eagerly created session: " + session.getId());
            }
        }

        HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response);
        // 下面摘取HttpSessionSecurityContextRepository代码来看看这个从哪里获取
        SecurityContext contextBeforeChainExecution = repo.loadContext(holder);

        try {
            // 调用set方法存储
            SecurityContextHolder.setContext(contextBeforeChainExecution);
            // 继续调用其他filter
            chain.doFilter(holder.getRequest(), holder.getResponse());

        } finally {
            SecurityContext contextAfterChainExecution = SecurityContextHolder.getContext();
            // Crucial removal of SecurityContextHolder contents - do this before anything else.
            // finally方法中一定要remove,防止线程被重复使用,变量仍在
            SecurityContextHolder.clearContext();
            repo.saveContext(contextAfterChainExecution, holder.getRequest(), holder.getResponse());
            request.removeAttribute(FILTER_APPLIED);

            if (debug) {
                logger.debug("SecurityContextHolder now cleared, as request processing completed");
            }
        }
    }

下面在看看上面如何获取SecurityContext,这里摘取HttpSessionSecurityContextRepository源码片段,

public SecurityContext loadContext(HttpRequestResponseHolder requestResponseHolder) {
        // 从requestResponseHolder中获取request和response
        HttpServletRequest request = requestResponseHolder.getRequest();
        HttpServletResponse response = requestResponseHolder.getResponse();
        HttpSession httpSession = request.getSession(false);

        // 下面private方法看看如何从session里面获取context
        SecurityContext context = readSecurityContextFromSession(httpSession);

        if (context == null) {
            if (logger.isDebugEnabled()) {
                logger.debug("No SecurityContext was available from the HttpSession: " + httpSession +". " +
                        "A new one will be created.");
            }
            context = generateNewContext();

        }

        SaveToSessionResponseWrapper wrappedResponse = new SaveToSessionResponseWrapper(response, request, httpSession != null, context);
        requestResponseHolder.setResponse(wrappedResponse);

        if(isServlet3) {
            requestResponseHolder.setRequest(new Servlet3SaveToSessionRequestWrapper(request, wrappedResponse));
        }

        return context;
    }

    /**
     *
     * @param httpSession the session obtained from the request.
     */
private SecurityContext readSecurityContextFromSession(HttpSession httpSession) {
        final boolean debug = logger.isDebugEnabled();

        if (httpSession == null) {
            if (debug) {
                logger.debug("No HttpSession currently exists");
            }

            return null;
        }

        // Session exists, so try to obtain a context from it.

        // 原来session里面存储着一个特殊的属性key为SPRING_SECURITY_CONTEXT,这个就是context
        Object contextFromSession = httpSession.getAttribute(springSecurityContextKey);

        if (contextFromSession == null) {
            if (debug) {
                logger.debug("HttpSession returned null object for SPRING_SECURITY_CONTEXT");
            }

            return null;
        }

        // We now have the security context object from the session.
        if (!(contextFromSession instanceof SecurityContext)) {
            if (logger.isWarnEnabled()) {
                logger.warn(springSecurityContextKey + " did not contain a SecurityContext but contained: '"
                        + contextFromSession + "'; are you improperly modifying the HttpSession directly "
                        + "(you should always use SecurityContextHolder) or using the HttpSession attribute "
                        + "reserved for this class?");
            }

            return null;
        }

        if (debug) {
            logger.debug("Obtained a valid SecurityContext from " + springSecurityContextKey + ": '" + contextFromSession + "'");
        }

        // Everything OK. The only non-null return from this method.

        return (SecurityContext) contextFromSession;
    }

上面代码我们就看出了ThreadLocalSecurityContextHolderStrategy是如何工作的了。

Spring MVC中的使用

如果你在Spring MVC中启用过过org.springframework.web.context.request.RequestContextListener,那么对RequestContextHolder肯定不陌生,通过RequestContextHolder提供的静态方法,可以获取当前request对象。

public static HttpServletRequest getRequest() {
        ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
        return  attr.getRequest();
    }
public static HttpSession getSession() {
        ServletRequestAttributes attr = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
        return attr.getRequest().getSession(false); // true == allow create
    }

上述方法能够生效只要看看RequestContextListener源码就很好理解了

public void requestInitialized(ServletRequestEvent requestEvent) {
        if (!(requestEvent.getServletRequest() instanceof HttpServletRequest)) {
            throw new IllegalArgumentException(
                    "Request is not an HttpServletRequest: " + requestEvent.getServletRequest());
        }
        HttpServletRequest request = (HttpServletRequest) requestEvent.getServletRequest();
        ServletRequestAttributes attributes = new ServletRequestAttributes(request);
        request.setAttribute(REQUEST_ATTRIBUTES_ATTRIBUTE, attributes);
        LocaleContextHolder.setLocale(request.getLocale());
        //存储
        RequestContextHolder.setRequestAttributes(attributes);
    }

public void requestDestroyed(ServletRequestEvent requestEvent) {
        //销毁
        ServletRequestAttributes attributes =
                (ServletRequestAttributes) requestEvent.getServletRequest().getAttribute(REQUEST_ATTRIBUTES_ATTRIBUTE);
        ServletRequestAttributes threadAttributes =
                (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (threadAttributes != null) {
            // We're assumably within the original request thread...
            if (attributes == null) {
                attributes = threadAttributes;
            }
            LocaleContextHolder.resetLocaleContext();
            RequestContextHolder.resetRequestAttributes();
        }
        if (attributes != null) {
            attributes.requestCompleted();
        }
    }

多数据源切换

在开发中可能会遇到有多个数据源的情况,例如两个Datasource之间的切换,配合SpringAOPAbstractRoutingDataSource,在加上改变ThreadLocal中的值来完成数据源的切换。
实现AbstractRoutingDataSource时,实现determineCurrentLookupKey方法来返回ThreadLocal中的值,通过其他方法
去改变ThreadLocal中的值,例如对某些方法的切面, 自定义注解等等。

log4j中的使用

如果想在·log4j·中打印一些公共的变量。
例如TrackingId表示每一个请求,可以在filter中先通过org.apache.log4j.MDC.put(String key, Object o)插入一个trackingIdkey(例如TACKING_ID)和value,然后配置log4j的格式%X{TACKING_ID}即可在日志中实现插入一个值,即使为null也没有问题。

你可能感兴趣的:(ThreadLocal在常见框架中的使用)