Spring Boot 1.5.21 的 bug:高并发场景下,一个用户的 token 被解析成另一个用户的信息

该 bug 位于 Spring Boot 的 `org.springframework.boot.autoconfigure.security.oauth2.resource.UserInfoTokenServices.java` 第 144 行附近的 `getMap` 方法中:

/**
 * Excerpt of Spring Boot 1.5.x UserInfoTokenServices#getMap, quoted as-is (its layout was
 * damaged by copy/paste) to show where the race happens.
 *
 * Fetches the user-info map for {@code accessToken} from {@code path}. The token is first
 * written into the rest template's OAuth2ClientContext, and the request is then sent without
 * passing the token explicitly, so the outgoing request presumably picks the token up from that
 * context. If that context is shared by concurrent requests, another thread can overwrite it in
 * between, and this call can return another user's details.
 *
 * @param path        user-info endpoint URL
 * @param accessToken raw access-token value to resolve
 * @return the response body as a map, or a single "error" entry if anything fails
 */
private Map getMap(String path, String accessToken) {
  
if (this.logger.isDebugEnabled()) {
     
this.logger.debug("Getting user info from: " + path);
   }
  
try {
      // Uses the injected template (shared by every call on this instance) when present;
      // otherwise builds a fresh, call-local one below.
      OAuth2RestOperations restTemplate =
this.restTemplate;
     
if (restTemplate == null) {
         BaseOAuth2ProtectedResourceDetails resource =
new BaseOAuth2ProtectedResourceDetails();
         resource.setClientId(
this.clientId);
         restTemplate =
new OAuth2RestTemplate(resource);
      }
      // Only replace the context's token when it differs from the one being resolved.
      OAuth2AccessToken existingToken = restTemplate.getOAuth2ClientContext()
            .getAccessToken();
     
if (existingToken == null || !accessToken.equals(existingToken.getValue())) {
         DefaultOAuth2AccessToken token =
new DefaultOAuth2AccessToken(
               accessToken);
         token.setTokenType(
this.tokenType);

         // If a thread switch happens here, another request can overwrite this shared
         // context's token before the call below, so the token gets resolved to the
         // wrong user's information.
         restTemplate.getOAuth2ClientContext().setAccessToken(token);

      }
     
return restTemplate.getForEntity(path, Map.class).getBody();
   }
  
catch (Exception ex) {
     // Any failure is reported to the caller as an "error" entry instead of propagating.
     
this.logger.warn("Could not fetch user details: " + ex.getClass() + ", "
           
+ ex.getMessage());
     
return Collections.singletonMap("error",
           
"Could not fetch user details");
   }
}

如何稳定重现这个问题:

自定义一个 UserInfoTokenServices 实现来代替 Spring 默认提供的版本。在向 OAuth2ClientContext 写入 token 之后,调用 `Thread.yield()` 主动让出 CPU,以放大线程切换的窗口;随后再读回 token 进行比对。代码如下:

public class CustomizedUserInfoTokenServices implements ResourceServerTokenServices {

    protected final Log logger = LogFactory.getLog(getClass());

    private final String userInfoEndpointUrl;

    private final String clientId;

    private OAuth2RestOperations restTemplate;

    private String tokenType = DefaultOAuth2AccessToken.BEARER_TYPE;

    private AuthoritiesExtractor authoritiesExtractor = new FixedAuthoritiesExtractor();

    private PrincipalExtractor principalExtractor = new FixedPrincipalExtractor();

    public CustomizedUserInfoTokenServices(String userInfoEndpointUrl, String clientId) {
        this.userInfoEndpointUrl = userInfoEndpointUrl;
        this.clientId = clientId;
    }

    public void setTokenType(String tokenType) {
        this.tokenType = tokenType;
    }

    public void setRestTemplate(OAuth2RestOperations restTemplate) {
        this.restTemplate = restTemplate;
    }

    public void setAuthoritiesExtractor(AuthoritiesExtractor authoritiesExtractor) {
        Assert.notNull(authoritiesExtractor, "AuthoritiesExtractor must not be null");
        this.authoritiesExtractor = authoritiesExtractor;
    }

    public void setPrincipalExtractor(PrincipalExtractor principalExtractor) {
        Assert.notNull(principalExtractor, "PrincipalExtractor must not be null");
        this.principalExtractor = principalExtractor;
    }

    @Override
    public OAuth2Authentication loadAuthentication(String accessToken)
            throws AuthenticationException, InvalidTokenException {
        Map map = getMap(this.userInfoEndpointUrl, accessToken);
        if (map.containsKey("error")) {
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("userinfo returned error: " + map.get("error"));
            }
            throw new InvalidTokenException(accessToken);
        }
        return extractAuthentication(map);
    }

    private OAuth2Authentication extractAuthentication(Map map) {
        Object principal = getPrincipal(map);
        List authorities = this.authoritiesExtractor
                .extractAuthorities(map);
        OAuth2Request request = new OAuth2Request(null, this.clientId, null, true, null,
                null, null, null, null);
        UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(
                principal, "N/A", authorities);
        token.setDetails(map);
        return new OAuth2Authentication(request, token);
    }

    /**
     * Return the principal that should be used for the token. The default implementation
     * delegates to the {@link PrincipalExtractor}.
     *
     * @param map the source map
     * @return the principal or {@literal "unknown"}
     */
    protected Object getPrincipal(Map map) {
        Object principal = this.principalExtractor.extractPrincipal(map);
        return (principal != null) ? principal : "unknown";
    }

    @Override
    public OAuth2AccessToken readAccessToken(String accessToken) {
        throw new UnsupportedOperationException("Not supported: read access token");
    }

    @SuppressWarnings({"unchecked"})
    private Map getMap(String path, String accessToken) {
        if (this.logger.isDebugEnabled()) {
            this.logger.debug("Getting user info from: " + path);
        }
        try {
            OAuth2RestOperations restTemplate = this.restTemplate;
            if (restTemplate == null) {
                BaseOAuth2ProtectedResourceDetails resource = new BaseOAuth2ProtectedResourceDetails();
                resource.setClientId(this.clientId);
                restTemplate = new OAuth2RestTemplate(resource);
            }
            OAuth2AccessToken existingToken = restTemplate.getOAuth2ClientContext()
                    .getAccessToken();
            if (existingToken == null || !accessToken.equals(existingToken.getValue())) {
                DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(
                        accessToken);
                token.setTokenType(this.tokenType);
                restTemplate.getOAuth2ClientContext().setAccessToken(token);
                Thread.yield();
                if (!StringUtils.equals(accessToken, restTemplate.getOAuth2ClientContext().getAccessToken().getValue())) {
                    logger.info("token before: " + accessToken + " token after: " + restTemplate.getOAuth2ClientContext().getAccessToken().getValue());
                }

            }
            return restTemplate.getForEntity(path, Map.class).getBody();
        } catch (Exception ex) {
            this.logger.warn("Could not fetch user details: " + ex.getClass() + ", "
                    + ex.getMessage());
            return Collections.singletonMap("error",
                    "Could not fetch user details");
        }
    }
 

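然后对接口进行压测,会发现上文中 `token before: ... token after: ...` 这条日志被打印出来。这说明当前线程刚写入的 token 已被其他线程覆盖,后续请求会拿别人的 token 去获取用户信息,从而导致 token 与用户的对应关系错乱。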

 

修复方法:自定义一个 OAuth2ClientContext 实现,替换掉 Spring Security OAuth 默认提供的实现,把 accessToken 保存在 ThreadLocal 中,使每个线程只读写自己的 token。该类内容如下:

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.springframework.security.oauth2.client.OAuth2ClientContext;
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
import org.springframework.security.oauth2.client.token.DefaultAccessTokenRequest;
import org.springframework.security.oauth2.common.OAuth2AccessToken;

/**
 * The OAuth 2 security context (for a specific user or client or combination thereof).
 *
 * @author Zhengfei Yan
 */
public class CustomizedOAuth2ClientContext implements OAuth2ClientContext, Serializable {

    private static final long serialVersionUID = 3078781745905248724L;

    // make accessToken thread local to avoid thread safe issue
    private ThreadLocal accessToken = new ThreadLocal<>();

    private AccessTokenRequest accessTokenRequest;

    private Map state = new HashMap();

    public CustomizedOAuth2ClientContext() {
        this(new DefaultAccessTokenRequest());
    }

    public CustomizedOAuth2ClientContext(AccessTokenRequest accessTokenRequest) {
        this.accessTokenRequest = accessTokenRequest;
    }

    public CustomizedOAuth2ClientContext(OAuth2AccessToken accessToken) {
        this.accessToken.set(accessToken);
        this.accessTokenRequest = new DefaultAccessTokenRequest();
    }

    public OAuth2AccessToken getAccessToken() {
        return accessToken.get();
    }

    public void setAccessToken(OAuth2AccessToken accessToken) {
        this.accessToken.set(accessToken);
        this.accessTokenRequest.setExistingToken(accessToken);
    }

    public AccessTokenRequest getAccessTokenRequest() {
        return accessTokenRequest;
    }

    public void setPreservedState(String stateKey, Object preservedState) {
        state.put(stateKey, preservedState);
    }

    public Object removePreservedState(String stateKey) {
        return state.remove(stateKey);
    }

}

参考链接:https://github.com/spring-projects/spring-security-oauth/pull/1489/files

你可能感兴趣的:(J2EE,java,spring)