package com.tx.tcm.oauth.security.service;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
/**
 * Extension of Spring Security's {@link UserDetailsService} that looks users up by name
 * AND user type, so that several kinds of users can log in through the same endpoint.
 */
public interface UserDetailsServices extends UserDetailsService {

    /**
     * Loads the user with the given name within the given user type.
     *
     * @param username the name of the user to load
     * @param type     the user type the name belongs to
     * @return the matching user as {@link UserDetails}
     * @throws UsernameNotFoundException declared for implementations that cannot find the user
     */
    UserDetails loadUserByUsername(String username, String type) throws UsernameNotFoundException;
}
package com.tx.tcm.oauth.security.service;
import com.tx.tcm.oauth.system.service.SysUserService;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.stereotype.Service;
import javax.annotation.Resource;
/**
 * {@link UserDetailsServices} implementation that resolves users by name and user type
 * through {@link SysUserService}.
 */
@Service
public class UserDetailsServicesImpl implements UserDetailsServices {
    // NOTE(review): neither constant is referenced in this class; presumably they are the
    // user-type values clients send as "type" — confirm, or remove them.
    private static final String USER = "user";
    private static final String USERCOPE = "userCope";

    @Resource
    SysUserService sysUserService;

    /**
     * Loads a user by name and user type.
     *
     * @param username the user name
     * @param type     the user type
     * @return the user returned by {@code SysUserService#selectSysUserByUsernameType}
     * @throws UsernameNotFoundException if the service finds no such user
     */
    @Override
    public UserDetails loadUserByUsername(String username, String type) throws UsernameNotFoundException {
        UserDetails user = sysUserService.selectSysUserByUsernameType(username, type);
        if (user == null) {
            // A UserDetailsService must never return null; report "not found" instead.
            throw new UsernameNotFoundException("User not found: " + username + " (type: " + type + ")");
        }
        return user;
    }

    /**
     * Single-argument lookup inherited from Spring Security's {@code UserDetailsService}.
     * A user cannot be resolved without its type, so this always reports "not found".
     *
     * @param username the user name
     * @return never returns normally
     * @throws UsernameNotFoundException always
     */
    @Override
    public UserDetails loadUserByUsername(String username) throws UsernameNotFoundException {
        // Returning null here breaks the UserDetailsService contract: a provider using this
        // method reports null as an InternalAuthenticationServiceException, which hides a
        // plain failed login. Throwing UsernameNotFoundException keeps it an ordinary failure.
        throw new UsernameNotFoundException("A user type is required to load user: " + username);
    }
}
从现在开始,所有用到 UserDetailsService 的地方,都要替换成自定义的 UserDetailsServices(实现类为 UserDetailsServicesImpl)。
package com.tx.tcm.oauth.security.handler;
import com.tx.tcm.oauth.security.service.UserDetailsServices;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.InternalAuthenticationServiceException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.authentication.dao.AbstractUserDetailsAuthenticationProvider;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsPasswordService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.crypto.factory.PasswordEncoderFactories;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.util.Assert;
import javax.annotation.Resource;
import java.util.Map;
/**
* 自定义AuthenticationProvider类实现多用户登录
**/
public class CustomAuthenticationProvider extends AbstractUserDetailsAuthenticationProvider {
private static final String USER_NOT_FOUND_PASSWORD = "userNotFoundPassword";
private PasswordEncoder passwordEncoder;
private volatile String userNotFoundEncodedPassword;
@Resource
private UserDetailsServices userDetailsServices;
private UserDetailsPasswordService userDetailsPasswordService;
public CustomAuthenticationProvider() {
this.setPasswordEncoder(PasswordEncoderFactories.createDelegatingPasswordEncoder());
}
protected void additionalAuthenticationChecks(UserDetails userDetails, UsernamePasswordAuthenticationToken authentication) throws AuthenticationException {
if (authentication.getCredentials() == null) {
this.logger.debug("Authentication failed: no credentials provided");
throw new BadCredentialsException(this.messages.getMessage("AbstractUserDetailsAuthenticationProvider.badCredentials", "Bad credentials"));
} else {
String presentedPassword = authentication.getCredentials().toString();
if (!this.passwordEncoder.matches(presentedPassword, userDetails.getPassword())) {
this.logger.debug("Authentication failed: password does not match stored value");
throw new BadCredentialsException(this.messages.getMessage("AbstractUserDetailsAuthenticationProvider.badCredentials", "Bad credentials"));
}
}
}
protected void doAfterPropertiesSet() {
Assert.notNull(this.userDetailsServices, "A UserDetailsService must be set");
}
protected final UserDetails retrieveUser(String username, UsernamePasswordAuthenticationToken authentication) throws AuthenticationException {
this.prepareTimingAttackProtection();
// 自定义添加
Map<String,String> map = (Map<String, String>) authentication.getDetails();
try {
// 自定义添加 type必须和传参是的类型一致;否则会报错
String userType = map.get("type");
UserDetails loadedUser = getUserDetailsServices().loadUserByUsername(username, userType);
if (loadedUser == null) {
throw new InternalAuthenticationServiceException("UserDetailsService returned null, which is an interface contract violation");
} else {
return loadedUser;
}
} catch (UsernameNotFoundException var4) {
this.mitigateAgainstTimingAttack(authentication);
throw var4;
} catch (InternalAuthenticationServiceException var5) {
throw var5;
} catch (Exception var6) {
throw new InternalAuthenticationServiceException(var6.getMessage(), var6);
}
}
protected Authentication createSuccessAuthentication(Object principal, Authentication authentication, UserDetails user) {
boolean upgradeEncoding = this.userDetailsPasswordService != null && this.passwordEncoder.upgradeEncoding(user.getPassword());
if (upgradeEncoding) {
String presentedPassword = authentication.getCredentials().toString();
String newPassword = this.passwordEncoder.encode(presentedPassword);
user = this.userDetailsPasswordService.updatePassword(user, newPassword);
}
return super.createSuccessAuthentication(principal, authentication, user);
}
private void prepareTimingAttackProtection() {
if (this.userNotFoundEncodedPassword == null) {
this.userNotFoundEncodedPassword = this.passwordEncoder.encode("userNotFoundPassword");
}
}
private void mitigateAgainstTimingAttack(UsernamePasswordAuthenticationToken authentication) {
if (authentication.getCredentials() != null) {
String presentedPassword = authentication.getCredentials().toString();
this.passwordEncoder.matches(presentedPassword, this.userNotFoundEncodedPassword);
}
}
public void setPasswordEncoder(PasswordEncoder passwordEncoder) {
Assert.notNull(passwordEncoder, "passwordEncoder cannot be null");
this.passwordEncoder = passwordEncoder;
this.userNotFoundEncodedPassword = null;
}
protected PasswordEncoder getPasswordEncoder() {
return this.passwordEncoder;
}
public void setUserDetailsServices(UserDetailsServices userDetailsServices) {
this.userDetailsServices = userDetailsServices;
}
protected UserDetailsServices getUserDetailsServices() {
return this.userDetailsServices;
}
public void setUserDetailsPasswordService(UserDetailsPasswordService userDetailsPasswordService) {
this.userDetailsPasswordService = userDetailsPasswordService;
}
}
/**
 * Exposes the multi-user-type {@link CustomAuthenticationProvider} as a bean.
 * The provider loads users through {@code userDetailsServices}, does not hide
 * "user not found" exceptions, and checks passwords with {@code bCryptPasswordEncoder()}.
 */
@Bean
public AuthenticationProvider customAuthenticationProvider() {
    CustomAuthenticationProvider provider = new CustomAuthenticationProvider();
    provider.setUserDetailsServices(userDetailsServices);
    provider.setHideUserNotFoundExceptions(false);
    provider.setPasswordEncoder(bCryptPasswordEncoder());
    return provider;
}
/**
 * Registers authentication providers, in order: the custom multi-user-type provider,
 * then a provider backed by {@code userDetailsServices} using the BCrypt encoder.
 */
@Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception {
    // NOTE(review): the second provider uses the single-argument loadUserByUsername(String),
    // which UserDetailsServicesImpl cannot resolve without a user type, so it never
    // authenticates anyone. Confirm this registration is still needed.
    auth.authenticationProvider(customAuthenticationProvider())
            .userDetailsService(userDetailsServices)
            .passwordEncoder(bCryptPasswordEncoder());
}
注意:刷新 token 时使用的 UserDetailsByNameServiceWrapper 内部也用到了 userDetailsService,同样需要替换成自定义的 UserDetailsServices,因此新建了下面的 CustomUserDetailsByNameServiceWrapper 类。
package com.tx.tcm.oauth.security.handler;
import com.tx.tcm.oauth.security.service.UserDetailsServices;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.*;
import org.springframework.util.Assert;
import java.util.Map;
/**
* 刷新token
**/
public class CustomUserDetailsByNameServiceWrapper<T extends Authentication> implements AuthenticationUserDetailsService<T>, InitializingBean {
private UserDetailsServices userDetailsServices = null;
public CustomUserDetailsByNameServiceWrapper() {
}
public CustomUserDetailsByNameServiceWrapper(UserDetailsServices userDetailsServices) {
Assert.notNull(userDetailsServices, "userDetailsService cannot be null.");
this.userDetailsServices = userDetailsServices;
}
public void afterPropertiesSet() {
Assert.notNull(this.userDetailsServices, "UserDetailsService must be set");
}
public UserDetails loadUserDetails(T authentication) throws UsernameNotFoundException {
AbstractAuthenticationToken principal = (AbstractAuthenticationToken) authentication.getPrincipal();
Map<String,String> map = (Map<String, String>) principal.getDetails();
String userType = map.get("type");
// 使用自定义的userDetailsService
return this.userDetailsServices.loadUserByUsername(authentication.getName(), userType);
}
public void setUserDetailsService(UserDetailsServices aUserDetailsService) {
this.userDetailsServices = aUserDetailsService;
}
}
下面的 CustomTokenServices 其实只修改了 refreshAccessToken 方法中 if (this.authenticationManager != null && !authentication.isClientOnly()) 分支里的内容,其他代码都没有改动。(这个类很长,容易让人不想看……把整个类都贴出来我其实也不太情愿 =。=)
package com.tx.tcm.oauth.security.handler;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.*;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.common.exceptions.InvalidScopeException;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.*;
import org.springframework.security.oauth2.provider.token.*;
import org.springframework.security.web.authentication.preauth.PreAuthenticatedAuthenticationToken;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.Assert;
import java.util.Date;
import java.util.Set;
import java.util.UUID;
/**
 * Token services for the authorization server. They issue, refresh, read and revoke tokens
 * through the configured {@link TokenStore}. When refreshing a token with an
 * {@link AuthenticationManager} set, the user is re-authenticated with the refresh request's
 * parameters attached as details (see {@link #refreshAccessToken}).
 */
public class CustomTokenServices implements AuthorizationServerTokenServices, ResourceServerTokenServices, ConsumerTokenServices, InitializingBean {
    /** Default refresh-token lifetime in seconds (30 days), used when no client-specific value exists. */
    private int refreshTokenValiditySeconds = 2592000;
    /** Default access-token lifetime in seconds (12 hours), used when no client-specific value exists. */
    private int accessTokenValiditySeconds = 43200;
    private boolean supportRefreshToken = false;
    private boolean reuseRefreshToken = true;
    private TokenStore tokenStore;
    private ClientDetailsService clientDetailsService;
    private TokenEnhancer accessTokenEnhancer;
    private AuthenticationManager authenticationManager;

    public CustomTokenServices() {
    }

    @Override
    public void afterPropertiesSet() throws Exception {
        Assert.notNull(this.tokenStore, "tokenStore must be set");
    }

    /**
     * Returns the existing access token if it is still valid. Otherwise issues a new one,
     * reusing the old refresh token while it has not expired.
     */
    @Override
    @Transactional
    public OAuth2AccessToken createAccessToken(OAuth2Authentication authentication) throws AuthenticationException {
        OAuth2AccessToken existingAccessToken = this.tokenStore.getAccessToken(authentication);
        OAuth2RefreshToken refreshToken = null;
        if (existingAccessToken != null) {
            if (!existingAccessToken.isExpired()) {
                // Store the still-valid token again against the current authentication.
                this.tokenStore.storeAccessToken(existingAccessToken, authentication);
                return existingAccessToken;
            }
            if (existingAccessToken.getRefreshToken() != null) {
                refreshToken = existingAccessToken.getRefreshToken();
                this.tokenStore.removeRefreshToken(refreshToken);
            }
            this.tokenStore.removeAccessToken(existingAccessToken);
        }
        if (refreshToken == null) {
            refreshToken = this.createRefreshToken(authentication);
        } else if (refreshToken instanceof ExpiringOAuth2RefreshToken) {
            ExpiringOAuth2RefreshToken expiring = (ExpiringOAuth2RefreshToken) refreshToken;
            if (System.currentTimeMillis() > expiring.getExpiration().getTime()) {
                refreshToken = this.createRefreshToken(authentication);
            }
        }
        OAuth2AccessToken accessToken = this.createAccessToken(authentication, refreshToken);
        this.tokenStore.storeAccessToken(accessToken, authentication);
        refreshToken = accessToken.getRefreshToken();
        if (refreshToken != null) {
            this.tokenStore.storeRefreshToken(refreshToken, authentication);
        }
        return accessToken;
    }

    /**
     * Issues a new access token for a refresh token.
     *
     * @throws InvalidGrantException if refresh is disabled, the refresh token is unknown, or it
     *         belongs to another client
     * @throws InvalidTokenException if the refresh token has expired
     */
    @Override
    @Transactional(
            noRollbackFor = {InvalidTokenException.class, InvalidGrantException.class}
    )
    public OAuth2AccessToken refreshAccessToken(String refreshTokenValue, TokenRequest tokenRequest) throws AuthenticationException {
        if (!this.supportRefreshToken) {
            throw new InvalidGrantException("Invalid refresh token: " + refreshTokenValue);
        }
        OAuth2RefreshToken refreshToken = this.tokenStore.readRefreshToken(refreshTokenValue);
        if (refreshToken == null) {
            throw new InvalidGrantException("Invalid refresh token: " + refreshTokenValue);
        }
        OAuth2Authentication authentication = this.tokenStore.readAuthenticationForRefreshToken(refreshToken);
        if (this.authenticationManager != null && !authentication.isClientOnly()) {
            authentication = this.reauthenticateUser(authentication, tokenRequest);
        }
        String clientId = authentication.getOAuth2Request().getClientId();
        if (clientId == null || !clientId.equals(tokenRequest.getClientId())) {
            throw new InvalidGrantException("Wrong client for this refresh token: " + refreshTokenValue);
        }
        this.tokenStore.removeAccessTokenUsingRefreshToken(refreshToken);
        if (this.isExpired(refreshToken)) {
            this.tokenStore.removeRefreshToken(refreshToken);
            throw new InvalidTokenException("Invalid refresh token (expired): " + refreshToken);
        }
        authentication = this.createRefreshedAuthentication(authentication, tokenRequest);
        if (!this.reuseRefreshToken) {
            this.tokenStore.removeRefreshToken(refreshToken);
            refreshToken = this.createRefreshToken(authentication);
        }
        OAuth2AccessToken accessToken = this.createAccessToken(authentication, refreshToken);
        this.tokenStore.storeAccessToken(accessToken, authentication);
        if (!this.reuseRefreshToken) {
            this.tokenStore.storeRefreshToken(accessToken.getRefreshToken(), authentication);
        }
        return accessToken;
    }

    /**
     * Re-authenticates the user of a stored authentication. This is the customized part of the
     * refresh flow.
     * <p>The stored user authentication has lost its details, so a copy of it is given the refresh
     * request's parameters as details. That copy is wrapped in a
     * {@link PreAuthenticatedAuthenticationToken} for the authentication manager. A copy is used so
     * the instance read from the token store is never mutated.
     */
    private OAuth2Authentication reauthenticateUser(OAuth2Authentication authentication, TokenRequest tokenRequest) {
        Object details = tokenRequest.getRequestParameters();
        Authentication storedUser = authentication.getUserAuthentication();
        UsernamePasswordAuthenticationToken userAuthentication = new UsernamePasswordAuthenticationToken(
                storedUser.getPrincipal(), storedUser.getCredentials(), storedUser.getAuthorities());
        userAuthentication.setDetails(details);
        Authentication user = new PreAuthenticatedAuthenticationToken(userAuthentication, "", authentication.getAuthorities());
        user = this.authenticationManager.authenticate(user);
        OAuth2Authentication reauthenticated = new OAuth2Authentication(authentication.getOAuth2Request(), user);
        reauthenticated.setDetails(details);
        return reauthenticated;
    }

    @Override
    public OAuth2AccessToken getAccessToken(OAuth2Authentication authentication) {
        return this.tokenStore.getAccessToken(authentication);
    }

    /**
     * Builds the authentication for a refreshed token, optionally narrowed to the requested scope.
     *
     * @throws InvalidScopeException if the request asks for scopes outside the original ones
     */
    private OAuth2Authentication createRefreshedAuthentication(OAuth2Authentication authentication, TokenRequest request) {
        Set<String> scope = request.getScope();
        OAuth2Request clientAuth = authentication.getOAuth2Request().refresh(request);
        if (scope != null && !scope.isEmpty()) {
            Set<String> originalScope = clientAuth.getScope();
            if (originalScope == null || !originalScope.containsAll(scope)) {
                throw new InvalidScopeException("Unable to narrow the scope of the client authentication to " + scope + ".", originalScope);
            }
            clientAuth = clientAuth.narrowScope(scope);
        }
        return new OAuth2Authentication(clientAuth, authentication.getUserAuthentication());
    }

    /** A refresh token counts as expired when it expires and its expiration is missing or in the past. */
    protected boolean isExpired(OAuth2RefreshToken refreshToken) {
        if (!(refreshToken instanceof ExpiringOAuth2RefreshToken)) {
            return false;
        }
        ExpiringOAuth2RefreshToken expiringToken = (ExpiringOAuth2RefreshToken) refreshToken;
        return expiringToken.getExpiration() == null || System.currentTimeMillis() > expiringToken.getExpiration().getTime();
    }

    @Override
    public OAuth2AccessToken readAccessToken(String accessToken) {
        return this.tokenStore.readAccessToken(accessToken);
    }

    /**
     * Loads the authentication for an access token, removing the token if it has expired.
     *
     * @throws InvalidTokenException if the token is unknown, expired, has no authentication, or its
     *         client is no longer valid
     */
    @Override
    public OAuth2Authentication loadAuthentication(String accessTokenValue) throws AuthenticationException, InvalidTokenException {
        OAuth2AccessToken accessToken = this.tokenStore.readAccessToken(accessTokenValue);
        if (accessToken == null) {
            throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
        }
        if (accessToken.isExpired()) {
            this.tokenStore.removeAccessToken(accessToken);
            throw new InvalidTokenException("Access token expired: " + accessTokenValue);
        }
        OAuth2Authentication result = this.tokenStore.readAuthentication(accessToken);
        if (result == null) {
            throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
        }
        if (this.clientDetailsService != null) {
            String clientId = result.getOAuth2Request().getClientId();
            try {
                this.clientDetailsService.loadClientByClientId(clientId);
            } catch (ClientRegistrationException clientInvalid) {
                throw new InvalidTokenException("Client not valid: " + clientId, clientInvalid);
            }
        }
        return result;
    }

    /**
     * @return the client id of the access token's OAuth2 request
     * @throws InvalidTokenException if the token is unknown or carries no client request
     */
    public String getClientId(String tokenValue) {
        OAuth2Authentication authentication = this.tokenStore.readAuthentication(tokenValue);
        if (authentication == null) {
            throw new InvalidTokenException("Invalid access token: " + tokenValue);
        }
        OAuth2Request clientAuth = authentication.getOAuth2Request();
        if (clientAuth == null) {
            throw new InvalidTokenException("Invalid access token (no client id): " + tokenValue);
        }
        return clientAuth.getClientId();
    }

    /** Removes the access token and its refresh token; returns false if the token is unknown. */
    @Override
    public boolean revokeToken(String tokenValue) {
        OAuth2AccessToken accessToken = this.tokenStore.readAccessToken(tokenValue);
        if (accessToken == null) {
            return false;
        }
        if (accessToken.getRefreshToken() != null) {
            this.tokenStore.removeRefreshToken(accessToken.getRefreshToken());
        }
        this.tokenStore.removeAccessToken(accessToken);
        return true;
    }

    /** Creates a random refresh token, or returns null when refresh tokens are not supported for the client. */
    private OAuth2RefreshToken createRefreshToken(OAuth2Authentication authentication) {
        if (!this.isSupportRefreshToken(authentication.getOAuth2Request())) {
            return null;
        }
        int validitySeconds = this.getRefreshTokenValiditySeconds(authentication.getOAuth2Request());
        String value = UUID.randomUUID().toString();
        if (validitySeconds > 0) {
            return new DefaultExpiringOAuth2RefreshToken(value, new Date(System.currentTimeMillis() + (long) validitySeconds * 1000L));
        }
        return new DefaultOAuth2RefreshToken(value);
    }

    /** Creates a random access token and passes it through the token enhancer if one is set. */
    private OAuth2AccessToken createAccessToken(OAuth2Authentication authentication, OAuth2RefreshToken refreshToken) {
        DefaultOAuth2AccessToken token = new DefaultOAuth2AccessToken(UUID.randomUUID().toString());
        int validitySeconds = this.getAccessTokenValiditySeconds(authentication.getOAuth2Request());
        if (validitySeconds > 0) {
            token.setExpiration(new Date(System.currentTimeMillis() + (long) validitySeconds * 1000L));
        }
        token.setRefreshToken(refreshToken);
        token.setScope(authentication.getOAuth2Request().getScope());
        return this.accessTokenEnhancer != null ? this.accessTokenEnhancer.enhance(token, authentication) : token;
    }

    /** Client-specific access-token lifetime when configured, else the default. */
    protected int getAccessTokenValiditySeconds(OAuth2Request clientAuth) {
        if (this.clientDetailsService != null) {
            ClientDetails client = this.clientDetailsService.loadClientByClientId(clientAuth.getClientId());
            Integer validity = client.getAccessTokenValiditySeconds();
            if (validity != null) {
                return validity;
            }
        }
        return this.accessTokenValiditySeconds;
    }

    /** Client-specific refresh-token lifetime when configured, else the default. */
    protected int getRefreshTokenValiditySeconds(OAuth2Request clientAuth) {
        if (this.clientDetailsService != null) {
            ClientDetails client = this.clientDetailsService.loadClientByClientId(clientAuth.getClientId());
            Integer validity = client.getRefreshTokenValiditySeconds();
            if (validity != null) {
                return validity;
            }
        }
        return this.refreshTokenValiditySeconds;
    }

    /** With a client registry, refresh support follows the client's grant types; otherwise the flag. */
    protected boolean isSupportRefreshToken(OAuth2Request clientAuth) {
        if (this.clientDetailsService != null) {
            ClientDetails client = this.clientDetailsService.loadClientByClientId(clientAuth.getClientId());
            return client.getAuthorizedGrantTypes().contains("refresh_token");
        }
        return this.supportRefreshToken;
    }

    public void setTokenEnhancer(TokenEnhancer accessTokenEnhancer) {
        this.accessTokenEnhancer = accessTokenEnhancer;
    }

    public void setRefreshTokenValiditySeconds(int refreshTokenValiditySeconds) {
        this.refreshTokenValiditySeconds = refreshTokenValiditySeconds;
    }

    public void setAccessTokenValiditySeconds(int accessTokenValiditySeconds) {
        this.accessTokenValiditySeconds = accessTokenValiditySeconds;
    }

    public void setSupportRefreshToken(boolean supportRefreshToken) {
        this.supportRefreshToken = supportRefreshToken;
    }

    public void setReuseRefreshToken(boolean reuseRefreshToken) {
        this.reuseRefreshToken = reuseRefreshToken;
    }

    public void setTokenStore(TokenStore tokenStore) {
        this.tokenStore = tokenStore;
    }

    public void setAuthenticationManager(AuthenticationManager authenticationManager) {
        this.authenticationManager = authenticationManager;
    }

    public void setClientDetailsService(ClientDetailsService clientDetailsService) {
        this.clientDetailsService = clientDetailsService;
    }
}
// Client registry passed to CustomTokenServices in customTokenServices(...).
// NOTE(review): no injection annotation is visible on this field in this snippet. If it stays
// null, CustomTokenServices falls back to its own default validity/refresh settings — confirm it
// is injected.
private ClientDetailsService clientDetailsService;
/**
 * Installs the custom token services built by {@code customTokenServices(endpoints)} on the
 * authorization server endpoints.
 */
@Override
public void configure(AuthorizationServerEndpointsConfigurer endpoints) throws Exception {
    CustomTokenServices tokenServices = customTokenServices(endpoints);
    endpoints.tokenServices(tokenServices);
}
/**
 * Builds the {@link CustomTokenServices} for the authorization server. It uses the endpoints'
 * token store and enhancer, refresh tokens enabled and reused, and the injected client registry.
 * When {@code userDetailsServices} is available, it also sets an authentication manager that
 * reloads the user through the type-aware {@link CustomUserDetailsByNameServiceWrapper} on refresh.
 */
private CustomTokenServices customTokenServices(AuthorizationServerEndpointsConfigurer endpoints) {
    CustomTokenServices tokenServices = new CustomTokenServices();
    tokenServices.setTokenStore(endpoints.getTokenStore());
    tokenServices.setSupportRefreshToken(true);
    tokenServices.setReuseRefreshToken(true);
    tokenServices.setClientDetailsService(clientDetailsService);
    tokenServices.setTokenEnhancer(endpoints.getTokenEnhancer());
    // Plug in the custom CustomUserDetailsByNameServiceWrapper for refresh-token re-authentication.
    if (userDetailsServices != null) {
        PreAuthenticatedAuthenticationProvider provider = new PreAuthenticatedAuthenticationProvider();
        // Diamond instead of a raw type: T is inferred as PreAuthenticatedAuthenticationToken.
        provider.setPreAuthenticatedUserDetailsService(new CustomUserDetailsByNameServiceWrapper<>(userDetailsServices));
        tokenServices.setAuthenticationManager(new ProviderManager(Arrays.asList(provider)));
    }
    return tokenServices;
}