Spring Boot 缺陷:org.springframework.boot.autoconfigure.security.oauth2.resource.UserInfoTokenServices.java 第 144 行(多线程并发时 token 与用户信息错配)
private Map
this.logger.debug("Getting user info from: " + path);
}
try {
OAuth2RestOperations restTemplate = this.restTemplate;
if (restTemplate == null) {
BaseOAuth2ProtectedResourceDetails resource = new BaseOAuth2ProtectedResourceDetails();
resource.setClientId(this.clientId);
restTemplate = new OAuth2RestTemplate(resource);
}
OAuth2AccessToken existingToken = restTemplate.getOAuth2ClientContext()
.getAccessToken();
if (existingToken == null || !accessToken.equals(existingToken.getValue())) {
DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(
accessToken);
token.setTokenType(this.tokenType);
// 此处当发生线程切换时,将导致token到user信息解析错乱
restTemplate.getOAuth2ClientContext().setAccessToken(token);
}
return restTemplate.getForEntity(path, Map.class).getBody();
}
catch (Exception ex) {
this.logger.warn("Could not fetch user details: " + ex.getClass() + ", "
+ ex.getMessage());
return Collections.
"Could not fetch user details");
}
}
如何稳定复现这个问题:
自定义一个 UserInfoTokenServices 实现来替代 Spring 默认提供的版本。它与默认实现基本相同,只是在设置 token 之后调用 Thread.yield() 主动让出 CPU,以放大线程切换的时间窗口,并在 token 被其他线程覆盖时打印日志。代码如下:
public class CustomizedUserInfoTokenServices implements ResourceServerTokenServices {
protected final Log logger = LogFactory.getLog(getClass());
private final String userInfoEndpointUrl;
private final String clientId;
private OAuth2RestOperations restTemplate;
private String tokenType = DefaultOAuth2AccessToken.BEARER_TYPE;
private AuthoritiesExtractor authoritiesExtractor = new FixedAuthoritiesExtractor();
private PrincipalExtractor principalExtractor = new FixedPrincipalExtractor();
public CustomizedUserInfoTokenServices(String userInfoEndpointUrl, String clientId) {
this.userInfoEndpointUrl = userInfoEndpointUrl;
this.clientId = clientId;
}
public void setTokenType(String tokenType) {
this.tokenType = tokenType;
}
public void setRestTemplate(OAuth2RestOperations restTemplate) {
this.restTemplate = restTemplate;
}
public void setAuthoritiesExtractor(AuthoritiesExtractor authoritiesExtractor) {
Assert.notNull(authoritiesExtractor, "AuthoritiesExtractor must not be null");
this.authoritiesExtractor = authoritiesExtractor;
}
public void setPrincipalExtractor(PrincipalExtractor principalExtractor) {
Assert.notNull(principalExtractor, "PrincipalExtractor must not be null");
this.principalExtractor = principalExtractor;
}
@Override
public OAuth2Authentication loadAuthentication(String accessToken)
throws AuthenticationException, InvalidTokenException {
Map
if (map.containsKey("error")) {
if (this.logger.isDebugEnabled()) {
this.logger.debug("userinfo returned error: " + map.get("error"));
}
throw new InvalidTokenException(accessToken);
}
return extractAuthentication(map);
}
private OAuth2Authentication extractAuthentication(Map
Object principal = getPrincipal(map);
List
.extractAuthorities(map);
OAuth2Request request = new OAuth2Request(null, this.clientId, null, true, null,
null, null, null, null);
UsernamePasswordAuthenticationToken token = new UsernamePasswordAuthenticationToken(
principal, "N/A", authorities);
token.setDetails(map);
return new OAuth2Authentication(request, token);
}
/**
* Return the principal that should be used for the token. The default implementation
* delegates to the {@link PrincipalExtractor}.
*
* @param map the source map
* @return the principal or {@literal "unknown"}
*/
protected Object getPrincipal(Map
Object principal = this.principalExtractor.extractPrincipal(map);
return (principal != null) ? principal : "unknown";
}
@Override
public OAuth2AccessToken readAccessToken(String accessToken) {
throw new UnsupportedOperationException("Not supported: read access token");
}
@SuppressWarnings({"unchecked"})
private Map
if (this.logger.isDebugEnabled()) {
this.logger.debug("Getting user info from: " + path);
}
try {
OAuth2RestOperations restTemplate = this.restTemplate;
if (restTemplate == null) {
BaseOAuth2ProtectedResourceDetails resource = new BaseOAuth2ProtectedResourceDetails();
resource.setClientId(this.clientId);
restTemplate = new OAuth2RestTemplate(resource);
}
OAuth2AccessToken existingToken = restTemplate.getOAuth2ClientContext()
.getAccessToken();
if (existingToken == null || !accessToken.equals(existingToken.getValue())) {
DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(
accessToken);
token.setTokenType(this.tokenType);
restTemplate.getOAuth2ClientContext().setAccessToken(token);
Thread.yield();
if (!StringUtils.equals(accessToken, restTemplate.getOAuth2ClientContext().getAccessToken().getValue())) {
logger.info("token before: " + accessToken + " token after: " + restTemplate.getOAuth2ClientContext().getAccessToken().getValue());
}
}
return restTemplate.getForEntity(path, Map.class).getBody();
} catch (Exception ex) {
this.logger.warn("Could not fetch user details: " + ex.getClass() + ", "
+ ex.getMessage());
return Collections.
"Could not fetch user details");
}
}
然后对该接口进行压测,会发现上面代码中 "token before: ... token after: ..." 这条日志被打印出来。这说明当前线程刚设置的 token 在使用之前已被其他线程覆盖,随后的请求会携带别人的 token,从而导致 token 与用户信息错配。
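修复方法:自定义一个 OAuth2ClientContext 实现,替换 Spring Security OAuth 默认的 DefaultOAuth2ClientContext,把 accessToken 保存在 ThreadLocal 中,使每个线程只能读取到自己设置的 token(例如通过 new OAuth2RestTemplate(resource, new CustomizedOAuth2ClientContext()) 传入)。该类内容如下: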
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.springframework.security.oauth2.client.OAuth2ClientContext;
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
import org.springframework.security.oauth2.client.token.DefaultAccessTokenRequest;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
/**
* The OAuth 2 security context (for a specific user or client or combination thereof).
*
* @author Zhengfei Yan
*/
public class CustomizedOAuth2ClientContext implements OAuth2ClientContext, Serializable {
private static final long serialVersionUID = 3078781745905248724L;
// make accessToken thread local to avoid thread safe issue
private ThreadLocal
private AccessTokenRequest accessTokenRequest;
private Map
public CustomizedOAuth2ClientContext() {
this(new DefaultAccessTokenRequest());
}
public CustomizedOAuth2ClientContext(AccessTokenRequest accessTokenRequest) {
this.accessTokenRequest = accessTokenRequest;
}
public CustomizedOAuth2ClientContext(OAuth2AccessToken accessToken) {
this.accessToken.set(accessToken);
this.accessTokenRequest = new DefaultAccessTokenRequest();
}
public OAuth2AccessToken getAccessToken() {
return accessToken.get();
}
public void setAccessToken(OAuth2AccessToken accessToken) {
this.accessToken.set(accessToken);
this.accessTokenRequest.setExistingToken(accessToken);
}
public AccessTokenRequest getAccessTokenRequest() {
return accessTokenRequest;
}
public void setPreservedState(String stateKey, Object preservedState) {
state.put(stateKey, preservedState);
}
public Object removePreservedState(String stateKey) {
return state.remove(stateKey);
}
}
参考链接:https://github.com/spring-projects/spring-security-oauth/pull/1489/files