2019独角兽企业重金招聘Python工程师标准>>>
SpringMVC 实现自定义的session 共享(同步)机制
思路
这个问题是针对线上多台服务器(例如多个tomcat,集群)负载均衡而言,如果只有一个服务器运行(提供服务),则不存在这个问题,请直接略过.
为什么多个tomcat服务会存在session 不同步的问题
假设我们使用服务器端的session记录登录鉴权信息(没有使用redis). 比如用户登录时,登录接口命中的是服务器A,那么服务器A中就会记录用户的登录信息; 接着用户修改资料(比如上传图片,修改昵称等),保存资料的接口命中的是服务器B, 而服务器B中并没有记录登录信息,所以直接报错“未登录”,并跳转到登录页面.
用户明明已经登录过了,可是系统莫名其妙地又让用户去登录. 这就是问题所在: 用户的登录信息只存储在了一台服务器上, 而用户的各种操作(接口访问)可能被负载到任意一台服务器上. 而http session是保存在各自内存中的,各tomcat服务之间并不会共享.
流程
如何让所有服务都能读取到用户的登录信息呢? 我们需要把登录信息存储到一个所有服务器都能访问的地方. 这里我们使用redis, 使用其他分布式缓存或存储(如Memcached, zookeeper)也可以.
方案
包装 HttpServletRequest(继承 HttpServletRequestWrapper), 重写它的 getSession(boolean), getSession()方法, 让它们返回一个会把属性同步到redis的自定义 HttpSession.
具体方案
- 继承 javax.servlet.http.HttpServletRequestWrapper, 重写它的 getSession(boolean), getSession() 方法;
- 实现 HttpSession, 重写HttpSession的三个核心方法: a. getAttribute; b. setAttribute; c. removeAttribute;
- 在这三个方法中,除了对原始的HttpSession 操作外,还会同时对redis进行操作.
看下 setAttribute 的重写实现:
/**
 * Stores an attribute in the local session (when one exists) and mirrors it to Redis.
 * <p>
 * The Redis key is {@code sessionId + attributeName}. The session id is the local
 * session's id when a local session exists, otherwise the {@code JSESSIONID} field.
 * If neither is available, the Redis write is skipped. Otherwise the key would be the
 * literal {@code "null" + s}, shared by every client without a session id.
 *
 * @param s attribute name
 * @param o attribute value
 */
@Override
public void setAttribute(String s, Object o) {
    final String redisKeyPrefix;
    if (this.httpSession != null) {
        this.httpSession.setAttribute(s, o);
        redisKeyPrefix = this.httpSession.getId();
    } else {
        redisKeyPrefix = this.JSESSIONID;
    }
    if (redisKeyPrefix == null) {
        return;
    }
    RedisCacheUtil.setSessionAttribute(redisKeyPrefix + s, o);
}
注意
- 存储到redis 中的时候,redis 的key一定要包含原始sessionId,这样才能区分是哪个会话;
- redis 中的value 实际都是String,所以在setAttribute 中存储到redis 时,要对存储的值进行序列化; 同理, 在getAttribute 中,对从redis中获取的value,要进行反序列化.
代码
CustomSharedHttpSession 实现HttpSession
package oa.web.responsibility.impl.custom;
import com.common.util.RedisCacheUtil;
import javax.servlet.ServletContext;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionContext;
import java.util.Enumeration;
/**
 * An {@link HttpSession} wrapper that mirrors session attributes into Redis so that
 * several servlet containers behind a load balancer can share login state.
 * <p>
 * Redis keys are {@code sessionId + attributeName}. The session id is the wrapped
 * session's id when a local session exists, otherwise the {@code JSESSIONID} passed to
 * the constructor (presumably the id sent by the client; confirm in
 * HttpSessionSyncShareFilter).
 * <p>
 * Only the attribute methods consult Redis: {@code getAttribute}, {@code setAttribute},
 * {@code removeAttribute} and their deprecated {@code *Value} aliases. All other methods
 * delegate to the wrapped session and throw {@link NullPointerException} when there is
 * none. The exception is {@link #getId()}, which falls back to {@code JSESSIONID}.
 * <p>
 * see oa/web/responsibility/impl/custom/HttpSessionSyncShareFilter.java
 */
public class CustomSharedHttpSession implements HttpSession {
    /** The container's local session; {@code null} when none exists on this node. */
    protected HttpSession httpSession;
    /**
     * Session id used as the Redis key prefix when there is no local session. It is also
     * used as a secondary lookup prefix when it differs from the local session's id.
     */
    protected String JSESSIONID;

    /** Creates an instance with neither a wrapped session nor a fallback id. */
    public CustomSharedHttpSession() {
        super();
    }

    /**
     * @param httpSession the container's local session, may be {@code null}
     * @param JSESSIONID  fallback session id used for Redis keys, may be {@code null}
     */
    public CustomSharedHttpSession(HttpSession httpSession, String JSESSIONID) {
        this.httpSession = httpSession;
        this.JSESSIONID = JSESSIONID;
    }

    @Override
    public long getCreationTime() {
        return this.httpSession.getCreationTime();
    }

    /**
     * Returns the local session's id, or {@code JSESSIONID} when there is no local
     * session. In that case this is the same id the attribute methods use as the Redis
     * key prefix.
     */
    @Override
    public String getId() {
        return this.httpSession != null ? this.httpSession.getId() : this.JSESSIONID;
    }

    @Override
    public long getLastAccessedTime() {
        return this.httpSession.getLastAccessedTime();
    }

    @Override
    public ServletContext getServletContext() {
        return this.httpSession.getServletContext();
    }

    @Override
    public void setMaxInactiveInterval(int i) {
        this.httpSession.setMaxInactiveInterval(i);
    }

    @Override
    public int getMaxInactiveInterval() {
        return this.httpSession.getMaxInactiveInterval();
    }

    @Override
    public HttpSessionContext getSessionContext() {
        return this.httpSession.getSessionContext();
    }

    /**
     * Looks up an attribute, falling back to Redis when the local session lacks it.
     * <p>
     * Lookup order:
     * <ol>
     * <li>the local session;</li>
     * <li>Redis under the local session id;</li>
     * <li>Redis under {@code JSESSIONID}, if that differs from the local session id.</li>
     * </ol>
     * A value found in Redis is written back to the local session and to Redis under the
     * current id. Misses write nothing.
     *
     * @param s attribute name
     * @return the attribute value, or {@code null} if none was found
     */
    @Override
    public Object getAttribute(String s) {
        if (this.httpSession == null) {
            // No local session: Redis, keyed by the fallback id, is the only source.
            return this.JSESSIONID == null ? null : RedisCacheUtil.getSessionAttribute(this.JSESSIONID + s);
        }
        Object value = this.httpSession.getAttribute(s);
        if (value != null) {
            return value;
        }
        String currentSessionId = this.httpSession.getId();
        value = RedisCacheUtil.getSessionAttribute(currentSessionId + s);
        if (value == null && this.JSESSIONID != null && !currentSessionId.equals(this.JSESSIONID)) {
            // This node's session id differs from the fallback id (presumably the request
            // landed on a node that did not create the session), so try the fallback id.
            value = RedisCacheUtil.getSessionAttribute(this.JSESSIONID + s);
        }
        if (value != null) {
            // Cache locally and re-publish under the current id so later lookups hit
            // directly.
            this.setAttribute(s, value);
        }
        return value;
    }

    /** Deprecated alias of {@link #getAttribute(String)}; routed through it so Redis is consulted too. */
    @Override
    public Object getValue(String s) {
        return this.getAttribute(s);
    }

    /** Names of attributes in the local session only; attributes held solely in Redis are not listed. */
    @Override
    public Enumeration getAttributeNames() {
        return this.httpSession.getAttributeNames();
    }

    /** Names of attributes in the local session only; attributes held solely in Redis are not listed. */
    @Override
    public String[] getValueNames() {
        return this.httpSession.getValueNames();
    }

    /**
     * Stores an attribute in the local session (when one exists) and mirrors it to Redis.
     * <p>
     * The Redis key is {@code sessionId + attributeName}. The session id is the local
     * session's id when a local session exists, otherwise the {@code JSESSIONID} field.
     * If neither is available, the Redis write is skipped. Otherwise the key would be the
     * literal {@code "null" + s}, shared by every client without a session id.
     *
     * @param s attribute name
     * @param o attribute value
     */
    @Override
    public void setAttribute(String s, Object o) {
        final String redisKeyPrefix;
        if (this.httpSession != null) {
            this.httpSession.setAttribute(s, o);
            redisKeyPrefix = this.httpSession.getId();
        } else {
            redisKeyPrefix = this.JSESSIONID;
        }
        if (redisKeyPrefix == null) {
            return;
        }
        RedisCacheUtil.setSessionAttribute(redisKeyPrefix + s, o);
    }

    /** Deprecated alias of {@link #setAttribute(String, Object)}; routed through it so Redis is updated too. */
    @Override
    public void putValue(String s, Object o) {
        this.setAttribute(s, o);
    }

    /**
     * Removes an attribute from the local session. It is also cleared in Redis under both
     * the local session id and {@code JSESSIONID}, by writing {@code null}. Whether
     * RedisCacheUtil deletes the key or stores a null marker is not visible here.
     *
     * @param s attribute name
     */
    @Override
    public void removeAttribute(String s) {
        if (this.httpSession != null) {
            this.httpSession.removeAttribute(s);
            RedisCacheUtil.setSessionAttribute(this.httpSession.getId() + s, null);
        }
        if (this.JSESSIONID != null) {
            RedisCacheUtil.setSessionAttribute(this.JSESSIONID + s, null);
        }
    }

    /** Deprecated alias of {@link #removeAttribute(String)}; routed through it so Redis is cleared too. */
    @Override
    public void removeValue(String s) {
        this.removeAttribute(s);
    }

    /**
     * Invalidates the local session only.
     * NOTE(review): copies of attributes already written to Redis are not removed here.
     */
    @Override
    public void invalidate() {
        this.httpSession.invalidate();
    }

    @Override
    public boolean isNew() {
        return this.httpSession.isNew();
    }

    /**
     * Custom accessor for the wrapped local session.
     *
     * @return the wrapped session, possibly {@code null}
     */
    public HttpSession getHttpSession() {
        return httpSession;
    }

    /**
     * Custom mutator for the wrapped local session.
     *
     * @param httpSession the session to wrap, may be {@code null}
     */
    public void setHttpSession(HttpSession httpSession) {
        this.httpSession = httpSession;
    }
}
HttpPutFormContentRequestWrapper: 继承 HttpServletRequestWrapper, 包装并重写 HttpServletRequest
package oa.web.request;
import com.common.util.RedisCacheUtil;
import com.common.util.RequestUtil;
import com.common.util.SystemHWUtil;
import com.common.web.filter.CustomFormHttpMessageConverter;
import com.file.hw.props.GenericReadPropsUtil;
import com.io.hw.json.HWJacksonUtils;
import com.string.widget.util.RegexUtil;
import com.string.widget.util.ValueWidget;
import oa.util.SpringMVCUtil;
import org.apache.log4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import javax.servlet.FilterChain;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Request wrapper with four jobs:
 * <ul>
 * <li>It exposes form parameters read from the request body (via
 * {@code RequestUtil.readFormParameters}) through the {@code getParameter*} methods, in
 * addition to the container's own parameters.</li>
 * <li>It replays the captured body from {@link #getInputStream()} once the original
 * stream has no bytes available.</li>
 * <li>It rewrites {@link #getServletPath()} according to a path-mapping table, similar
 * to nginx forwarding (see https://my.oschina.net/huangweiindex/blog/1789164).</li>
 * <li>It carries per-request scratch state: cached response bodies, a 404 flag and the
 * filter chain.</li>
 * </ul>
 * A new instance is created for every request; the mapping table is static and shared.
 */
public class HttpPutFormContentRequestWrapper extends HttpServletRequestWrapper {
    protected final static Logger logger = Logger.getLogger(HttpPutFormContentRequestWrapper.class);
    public final static org.slf4j.Logger httpClientRestLogger = LoggerFactory.getLogger("rest_log");

    /** Form parameters parsed from the request body; never {@code null}. */
    protected MultiValueMap<String, String> formParameters;
    /** Request body as returned by the form converter; may be {@code null}. */
    protected String requestBody;
    /** Stream handed out by {@link #getInputStream()} when replaying {@link #requestBody}. */
    private ResettableServletInputStream servletStream;

    /**
     * Path mapping {@code <nonexistent path A, real path B>}: requests for A are served by B.
     * There are two sources:
     * <ol>
     * <li>config/pathMapping.json;</li>
     * <li>Redis, via {@code RedisCacheUtil.getServletPathMap()}.</li>
     * </ol>
     * The field is volatile because {@link #initMapping()} publishes a freshly built map.
     */
    private static volatile Map<String, String> handlerMethodPathMap;

    /** Cached response bodies keyed by servlet path (a {@code null} path is stored as ""). */
    private Map<String, Object> responseReturnResultMap = new ConcurrentHashMap<>();
    private FilterChain chain;

    /**
     * Form converter shared by all requests.
     * NOTE(review): the constructor calls {@code readFormParameters(..., formConverter)} and
     * then {@code formConverter.getRequestBody()} on this shared instance. Unless the
     * converter keeps the body per thread, concurrent requests may observe each other's
     * body. Confirm against CustomFormHttpMessageConverter. The original author also asked
     * whether this should be an instance field.
     */
    protected static final CustomFormHttpMessageConverter formConverter = new CustomFormHttpMessageConverter();

    /**
     * Whether this request was flagged as "404 not found". Each request gets its own
     * wrapper, so a plain field is sufficient. A per-instance ThreadLocal would read back
     * false from any thread other than the one that set it.
     */
    private volatile boolean is404NotFound;

    /**
     * Per-request cache {@code <original servlet path, resolved path>}. The mapping lookup,
     * which may query Redis, then runs once per request instead of on every
     * getServletPath() call. This addresses slow SpringMVC dispatch. Cleared by
     * {@link #resetCustom()}.
     * added 2018-06-28
     * see http://i.yhskyc.com/test/1384?testcase=SpringMVC%E8%BF%9B%E5%85%A5%E8%AF%B7%E6%B1%82%E5%B7%A8%E6%85%A2
     */
    protected Map<String, String> servletPathOriginAndTargetMap;

    // Every request creates a new wrapper, so the mapping table is loaded once, globally.
    static {
        initMapping();
    }

    public void set404NotFound(boolean bool) {
        this.is404NotFound = bool;
    }

    public boolean is404NotFound() {
        return this.is404NotFound;
    }

    /**
     * Caches a response body for a servlet path.
     *
     * @param servletPath key; {@code null} is stored as ""
     * @param response    value; must not be {@code null} (ConcurrentHashMap rejects nulls)
     */
    public void put(String servletPath, Object response) {
        responseReturnResultMap.put(pathKey(servletPath), response);
    }

    /**
     * @param servletPath key; {@code null} is looked up as "", matching {@link #put}
     * @return the cached body, or {@code null} if none was stored
     */
    public String getResponseBodyBackup(String servletPath) {
        return (String) this.responseReturnResultMap.get(pathKey(servletPath));
    }

    /** Returns the cached body for the (possibly mapped) servlet path of this request. */
    public String getResponseBodyBackup() {
        return this.getResponseBodyBackup(getServletPath());
    }

    /** Whether a body is cached for the path; {@code null} is treated as "", matching {@link #put}. */
    public boolean hasContains(String servletPath) {
        return this.responseReturnResultMap.containsKey(pathKey(servletPath));
    }

    /** Maps a {@code null} servlet path to "" because ConcurrentHashMap rejects null keys. */
    private static String pathKey(String servletPath) {
        return servletPath == null ? "" : servletPath;
    }

    /**
     * (Re)builds the servlet path mapping table, similar to nginx forwarding.
     * The table starts with one hard-coded entry and adds the entries from the classpath
     * resource "/config/pathMapping.json". It is fully built before being published, so
     * concurrent readers never see a half-filled table.
     * see https://my.oschina.net/huangweiindex/blog/1789164
     */
    public static void initMapping() {
        Map<String, String> mapping = new ConcurrentHashMap<>();
        mapping.put("/agent/afterbuy/list/json", "/agent/afterbuy/listfilter/json");
        // Read additional mappings from the local file "/config/pathMapping.json".
        ClassLoader classLoader = SpringMVCUtil.getApplication().getClassLoader();
        String resourcePath = "/config/pathMapping.json";
        String json = GenericReadPropsUtil.getConfigTxt(classLoader, resourcePath);
        logger.info("config/pathMapping.json :" + json);
        if (!ValueWidget.isNullOrEmpty(json)) {
            json = RegexUtil.sedDeleteComment(json); // delete the comment on the first line
            if (!ValueWidget.isNullOrEmpty(json)) {
                putStringEntries(mapping, HWJacksonUtils.deSerializeMap(json, String.class));
            }
        }
        handlerMethodPathMap = mapping;
    }

    /**
     * Copies the entries of {@code source} whose key and value are both Strings into
     * {@code target}. A single null or non-String entry would otherwise make
     * ConcurrentHashMap.putAll throw and abort the whole merge.
     */
    private static void putStringEntries(Map<String, String> target, Map<?, ?> source) {
        if (source == null) {
            return;
        }
        for (Map.Entry<?, ?> entry : source.entrySet()) {
            if (entry.getKey() instanceof String && entry.getValue() instanceof String) {
                target.put((String) entry.getKey(), (String) entry.getValue());
            }
        }
    }

    public HttpPutFormContentRequestWrapper(HttpServletRequest request) {
        super(request);
        servletStream = new ResettableServletInputStream();
        MultiValueMap<String, String> parameters = RequestUtil.readFormParameters(request, formConverter);
        this.formParameters = parameters != null ? parameters : new LinkedMultiValueMap<String, String>();
        // See the NOTE on formConverter: this reads state from a converter shared by all requests.
        this.requestBody = formConverter.getRequestBody();
    }

    /**
     * Returns the servlet path after applying the path mapping (handlerMethodPathMap).
     * The result is cached per request in servletPathOriginAndTargetMap.
     * see https://my.oschina.net/huangweiindex/blog/1789164
     *
     * @return the mapped target path, or the container's servlet path when unmapped
     */
    @Override
    public String getServletPath() {
        String servletPath = super.getServletPath();
        if (this.servletPathOriginAndTargetMap != null) {
            // Already resolved for this request: answer from the per-request cache.
            if (logger.isDebugEnabled()) {
                logger.debug("servletPath 2 :" + servletPath);
            }
            String targetPath = this.servletPathOriginAndTargetMap.get(servletPath);
            return targetPath != null ? targetPath : servletPath;
        }
        String lookupPath = resolveMappedPath(servletPath);
        if (ValueWidget.isNullOrEmpty(lookupPath)) {
            lookupPath = servletPath;
        } else {
            String msg = "SpringMVC 层实现 Path Mapping,old:" + servletPath + "\tnew:" + lookupPath + " 将被真正调用";
            logger.warn(msg);
            httpClientRestLogger.error(msg);
        }
        // Cache the resolution so later calls in this request skip the lookup.
        servletPathOriginAndTargetMap = new HashMap<>();
        servletPathOriginAndTargetMap.put(servletPath, lookupPath);
        return lookupPath;
    }

    /**
     * Looks up the target path for {@code servletPath}, first in the static table.
     * On a miss it checks the mappings pending in Redis (see PreServletPathMapController).
     * Those mappings are merged into the table and then cleared from Redis.
     *
     * @return the mapped path, or {@code null} when there is no mapping
     */
    private static String resolveMappedPath(String servletPath) {
        Map<String, String> mapping = handlerMethodPathMap;
        if (ValueWidget.isNullOrEmpty(mapping)) {
            return null;
        }
        // <nonexistent path A, real path B>
        String lookupPath = mapping.get(servletPath);
        if (lookupPath == null) {
            // Element types of RedisCacheUtil.getServletPathMap() are not visible here, so keep it raw.
            Map servletPathMap = RedisCacheUtil.getServletPathMap();
            if (!ValueWidget.isNullOrEmpty(servletPathMap)) {
                lookupPath = (String) servletPathMap.get(servletPath);
                putStringEntries(mapping, servletPathMap);
                RedisCacheUtil.clearServletPathMap();
            }
        }
        return lookupPath;
    }

    /** Query-string/container value first; falls back to the first form-body value. */
    @Override
    public String getParameter(String name) {
        String queryStringValue = super.getParameter(name);
        return queryStringValue != null ? queryStringValue : this.formParameters.getFirst(name);
    }

    /** All parameter names (container's, then form body's) mapped to their combined values. */
    @Override
    public Map<String, String[]> getParameterMap() {
        Map<String, String[]> result = new LinkedHashMap<>();
        Enumeration<String> names = this.getParameterNames();
        while (names.hasMoreElements()) {
            String name = names.nextElement();
            result.put(name, this.getParameterValues(name));
        }
        return result;
    }

    /** Union of container parameter names and form-body names, in that order, without duplicates. */
    @Override
    public Enumeration<String> getParameterNames() {
        Set<String> names = new LinkedHashSet<>();
        names.addAll(Collections.list(super.getParameterNames()));
        names.addAll(this.formParameters.keySet());
        return Collections.enumeration(names);
    }

    /** Container values followed by form-body values; {@code null} when neither has the name. */
    @Override
    public String[] getParameterValues(String name) {
        String[] queryStringValues = super.getParameterValues(name);
        List<String> formValues = this.formParameters.get(name);
        if (formValues == null) {
            return queryStringValues;
        }
        if (queryStringValues == null) {
            return formValues.toArray(new String[0]);
        }
        List<String> result = new ArrayList<>(queryStringValues.length + formValues.size());
        Collections.addAll(result, queryStringValues);
        result.addAll(formValues);
        return result.toArray(new String[0]);
    }

    /**
     * Returns the container's stream while it still has bytes available. Otherwise it
     * replays {@link #requestBody}, encoded in the request charset (ISO-8859-1 when unset).
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        ServletInputStream original = super.getInputStream();
        if (original.available() > 0) {
            return original;
        }
        String requestCharset = getRequest().getCharacterEncoding();
        if (ValueWidget.isNullOrEmpty(requestCharset)) {
            requestCharset = SystemHWUtil.CHARSET_ISO88591;
        }
        // getRequestBody() is not guaranteed non-null; replay an empty body instead of throwing NPE.
        String body = this.requestBody != null ? this.requestBody : "";
        servletStream.stream = new ByteArrayInputStream(body.getBytes(requestCharset));
        return servletStream;
    }

    public MultiValueMap<String, String> getFormParameters() {
        return formParameters;
    }

    /** ServletInputStream over an in-memory stream that getInputStream() swaps on each call. */
    private static class ResettableServletInputStream extends ServletInputStream {
        private InputStream stream;

        @Override
        public int read() throws IOException {
            return stream.read();
        }
    }

    public FilterChain getChain() {
        return chain;
    }

    public void setChain(FilterChain chain) {
        this.chain = chain;
    }

    public static CustomFormHttpMessageConverter getFormConverter() {
        return formConverter;
    }

    /** Drops the per-request servlet path cache so the next getServletPath() resolves again. */
    public void resetCustom() {
        this.servletPathOriginAndTargetMap = null;
    }
}
推荐
我的其他开源项目 用于服务器端API 的stub 测试
zookeeper的一个java 图形客户端