代码地址:https://gitee.com/kkmy/kw-microservices.git
(又是一年1024,分享一下之前搭的OAuth2服务)
<dependency>
    <groupId>org.springframework.cloud</groupId>
    <artifactId>spring-cloud-starter-oauth2</artifactId>
    <version>2.2.5.RELEASE</version>
</dependency>
这里使用了策略模式:根据请求中传入的用户类型参数,选择对应的策略类,调用对应系统服务的接口获取用户信息。
package pers.kw.config.security;
import com.alibaba.fastjson.JSON;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.stereotype.Component;
import pers.kw.common.spring.utils.SpringUtils;
import pers.kw.config.oauth.context.MyParamValue;
import pers.kw.config.oauth.context.MyParamValueThreadLocal;
import pers.kw.contants.AuthParamName;
import pers.kw.enums.AuthUserTypeEnum;
import pers.kw.service.UserDetailStrategy;
import java.util.ArrayList;
import java.util.List;
/**
* 自定义UserDetailService
*/
@Component
public class MyUserDetailsService implements UserDetailsService {
private static final Logger log = LoggerFactory.getLogger(MyUserDetailsService.class);
private static final List<GrantedAuthority> authorities = new ArrayList<>(2);
@Override
public UserDetails loadUserByUsername(String userName) throws UsernameNotFoundException {
log.info("自定义UserDetailsService处理start...");
MyParamValue paramValue = MyParamValueThreadLocal.getCurrent();
log.info("获取自定义参数信息:{}", JSON.toJSONString(paramValue));
String userType = paramValue.getAuthParameter(AuthParamName.USER_TYPE);
if (StringUtils.isBlank(userType)) {
throw new OAuth2Exception(AuthParamName.USER_TYPE + "不能为空");
}
if (!AuthUserTypeEnum.userTypeSet.contains(userType)) {
throw new OAuth2Exception(AuthParamName.USER_TYPE + "错误");
}
AuthUserTypeEnum userTypeEnum = AuthUserTypeEnum.getEnumObjByCode(userType);
if (userTypeEnum == null) {
log.info("oauth服务,用户认证策略配置错误,{}:{}", AuthParamName.USER_TYPE, userType);
throw new OAuth2Exception("认证系统异常");
}
try {
UserDetailStrategy userDetailStrategy = (UserDetailStrategy) SpringUtils.getBean(Class.forName(userTypeEnum.getUserStrategy()));
return userDetailStrategy.getUserInfoByMobile(userName,authorities);
} catch (ClassNotFoundException e) {
log.error("oauth服务,用户认证策略配置获取异常", e);
throw new OAuth2Exception("认证系统异常");
}
}
}
MyOauthAuthenticationFilter
package pers.kw.config.oauth;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.GenericFilterBean;
import pers.kw.config.oauth.context.MyParamValue;
import pers.kw.config.oauth.context.MyParamValueThreadLocal;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
/**
 * Exposes the raw request parameters of the OAuth2 token endpoint ({@code /oauth/token}, GET or
 * POST) to downstream authentication code through {@link MyParamValueThreadLocal}. This is how
 * custom parameters can be added to the standard OAuth2 endpoint.
 */
@Component
public class MyOauthAuthenticationFilter extends GenericFilterBean implements ApplicationContextAware {
    private static final Logger log = LoggerFactory.getLogger(MyOauthAuthenticationFilter.class);
    private static final String URL = "/oauth/token";

    private ApplicationContext applicationContext;
    private final RequestMatcher requestMatcher;

    public MyOauthAuthenticationFilter() {
        this.requestMatcher = new OrRequestMatcher(
                new AntPathRequestMatcher(URL, "GET"),
                new AntPathRequestMatcher(URL, "POST"));
    }

    /**
     * For token-endpoint requests, stores the request parameters in the thread-local for the
     * duration of the downstream chain. Other requests pass through untouched.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
            throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        if (!requestMatcher.matches(request)) {
            filterChain.doFilter(request, response);
            return;
        }
        // Save the custom parameters in the current thread for the UserDetailsService to read.
        MyParamValue paramValue = new MyParamValue();
        paramValue.setAuthParameters(request.getParameterMap());
        MyParamValueThreadLocal.set(paramValue);
        try {
            filterChain.doFilter(request, response);
        } finally {
            // Always clear, even when the chain throws. Container threads are typically pooled, so
            // a leftover value would otherwise be seen by the next request on this thread.
            MyParamValueThreadLocal.remove();
        }
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }
}
这里的 HTTP 响应码一定要设置为 200,真实的 OAuth2 错误码放在响应体 ExceptionResponse 的 status 中返回。若沿用 OAuth2 返回的非 200 响应码,在微服务之间调用时,调用方无法正常反序列化返回值。
// Excerpt: the HTTP status is always 200; the actual OAuth2 status code is passed
// inside the body via ExceptionResponse.fail(status, ...).
return new ResponseEntity<>(
ExceptionResponse.fail(status,
e.getMessage())
, headers,
HttpStatus.OK);
package pers.kw.config.oauth;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.security.access.AccessDeniedException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.common.DefaultThrowableAnalyzer;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InsufficientScopeException;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.common.exceptions.OAuth2Exception;
import org.springframework.security.oauth2.provider.error.WebResponseExceptionTranslator;
import org.springframework.security.web.util.ThrowableAnalyzer;
import org.springframework.web.HttpRequestMethodNotSupportedException;
import pers.kw.protocol.ExceptionResponse;
import java.io.IOException;
/**
 * Translates exceptions raised by the OAuth2 endpoints into an {@link ExceptionResponse} body.
 *
 * <p>The OAuth2 status code is carried inside the body. The HTTP status is always 200, so that
 * callers in other microservices can deserialize the response.
 */
public class MyWebResponseExceptionTranslator implements WebResponseExceptionTranslator {
    private static final Logger log = LoggerFactory.getLogger(MyWebResponseExceptionTranslator.class);
    private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();

    /**
     * Finds the most specific known exception in the cause chain and renders it.
     *
     * <p>The lookup order is: OAuth2Exception, AuthenticationException, AccessDeniedException,
     * HttpRequestMethodNotSupportedException. Anything else becomes a 500-coded server error.
     */
    @Override
    public ResponseEntity<ExceptionResponse> translate(Exception e) throws Exception {
        log.error("OAuth2异常处理:", e);
        // Try to extract a SpringSecurityException from the stacktrace
        Throwable[] causeChain = throwableAnalyzer.determineCauseChain(e);
        Exception ase = (OAuth2Exception) throwableAnalyzer.getFirstThrowableOfType(OAuth2Exception.class, causeChain);
        if (ase != null) {
            if (ase instanceof InvalidGrantException) {
                // NOTE(review): every InvalidGrantException is reported as "密码错误"; confirm that
                // other grant failures (e.g. invalid refresh tokens) should get the same message.
                log.info("ase:{}", ase.getMessage());
                return handleOAuth2Exception((OAuth2Exception) ase, "密码错误");
            }
            return handleOAuth2Exception((OAuth2Exception) ase);
        }
        ase = (AuthenticationException) throwableAnalyzer.getFirstThrowableOfType(AuthenticationException.class,
                causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new UnauthorizedException(e.getMessage(), e));
        }
        ase = (AccessDeniedException) throwableAnalyzer
                .getFirstThrowableOfType(AccessDeniedException.class, causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new ForbiddenException(ase.getMessage(), ase));
        }
        ase = (HttpRequestMethodNotSupportedException) throwableAnalyzer.getFirstThrowableOfType(
                HttpRequestMethodNotSupportedException.class, causeChain);
        if (ase != null) {
            return handleOAuth2Exception(new MethodNotAllowed(ase.getMessage(), ase));
        }
        return handleOAuth2Exception(new ServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase(), e));
    }

    /**
     * Builds the response for {@code e}, using {@code msg} as the user-facing message.
     *
     * <p>The OAuth2 status code goes into the body. The HTTP status is forced to 200 (see the
     * class comment). Caching is disabled, and a Bearer {@code WWW-Authenticate} header is added
     * for 401s and insufficient-scope errors.
     */
    private ResponseEntity<ExceptionResponse> handleOAuth2Exception(OAuth2Exception e, String msg) {
        int status = e.getHttpErrorCode();
        HttpHeaders headers = new HttpHeaders();
        headers.set("Cache-Control", "no-store");
        headers.set("Pragma", "no-cache");
        if (status == HttpStatus.UNAUTHORIZED.value() || (e instanceof InsufficientScopeException)) {
            headers.set("WWW-Authenticate", String.format("%s %s", OAuth2AccessToken.BEARER_TYPE, e.getSummary()));
        }
        return new ResponseEntity<>(ExceptionResponse.fail(status, msg), headers, HttpStatus.OK);
    }

    /** Builds the response for {@code e} using the exception's own message. */
    private ResponseEntity<ExceptionResponse> handleOAuth2Exception(OAuth2Exception e) {
        return handleOAuth2Exception(e, e.getMessage());
    }

    public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
        this.throwableAnalyzer = throwableAnalyzer;
    }

    /** 403 / access_denied. */
    private static class ForbiddenException extends OAuth2Exception {
        public ForbiddenException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "access_denied";
        }

        @Override
        public int getHttpErrorCode() {
            return 403;
        }
    }

    /** 500 / server_error. */
    private static class ServerErrorException extends OAuth2Exception {
        public ServerErrorException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "server_error";
        }

        @Override
        public int getHttpErrorCode() {
            return 500;
        }
    }

    /** 401 / unauthorized. */
    private static class UnauthorizedException extends OAuth2Exception {
        public UnauthorizedException(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "unauthorized";
        }

        @Override
        public int getHttpErrorCode() {
            return 401;
        }
    }

    /** 405 / method_not_allowed. */
    private static class MethodNotAllowed extends OAuth2Exception {
        public MethodNotAllowed(String msg, Throwable t) {
            super(msg, t);
        }

        @Override
        public String getOAuth2ErrorCode() {
            return "method_not_allowed";
        }

        @Override
        public int getHttpErrorCode() {
            return 405;
        }
    }
}