之前项目中涉及到一个场景:
用户每天可以在App上签到领取金币,金币达到一定数量后可以兑换成人民币。
场景看起来很简单,但仔细想想,需要考虑的细节其实很多,这里不一一列举,主要说一点:如果用户用Fiddler之类的抓包工具抓到了领取金币的接口请求,然后写一个脚本循环调用,那岂不是赚大发了?也许有人心里会想,谁会为了那点金币去安装各种抓包工具,还专门写脚本循环调用接口……其实只是你没遇到过而已,这种情况真的非常多。所以作为程序员,我们开发这个功能时,如果不能杜绝这种刷金币的情况,就会给公司带来很大的损失,后果也会非常严重。
突然感觉程序员这条路真的不容易走:一不小心漏掉了某些细节,开发的功能模块出现漏洞、给公司造成损失,瞬间就玩完了。
下面分享我之前实践过的方案:把Redis分布式锁封装成注解来调用,侵入性非常小,只需一个注解就能拦截重复刷单请求。下面的代码经受过生产环境的考验,项目上线后没有出现过重复刷单请求。需要注意的是,这把锁拦截的是同一用户的并发重复请求;前一个请求执行完毕、锁释放之后再次发起的请求仍然会被执行,因此“每天只能签到一次”之类的规则仍需要在业务逻辑中校验。
一、首先开发限制用户重复提交的注解,之后在需要进行限制的方法上直接使用@RedisLimitLock注解即可。
/**
 * Marks a method whose concurrent invocations for the same caller are guarded by a short-lived
 * Redis lock (applied by RedisLimitLockAspect), so duplicate submissions cannot run in parallel.
 *
 * @version 1.0
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RedisLimitLock {

    /**
     * Optional free-form value; defaults to the empty string and is not read by the lock aspect.
     */
    String value() default "";
}
二、开发限制重复提交的切面
/**
 * Aspect that serialises concurrent calls to methods annotated with {@code @RedisLimitLock}.
 *
 * <p>For each intercepted call, the caller's {@code token} is concatenated with the method name
 * to form a Redis lock key. The token is taken from a method argument named "token", or else from
 * the "token" request parameter. Only the call that acquires the lock runs the target method;
 * concurrent duplicates are logged and skipped. The lock is released when the method finishes
 * and in any case expires after {@code TIMEOUT} seconds.
 *
 * <p>This only blocks <em>concurrent</em> duplicates. A request that arrives after the previous
 * one has released the lock acquires it again, so rules such as "sign in once per day" must still
 * be enforced by the business logic.
 */
@Aspect
@Component
public class RedisLimitLockAspect {
    /** Name of the method argument / request parameter that identifies the caller. */
    private static final String TOKEN = "token";
    /** Lock expiry in seconds, so a lock whose holder dies is not kept forever. */
    private static final int TIMEOUT = 3;
    private static final Logger LOGGER = LoggerFactory.getLogger(RedisLimitLockAspect.class);
    // Per-thread request id and lock key. They are only populated by doBefore (the @Before/@After
    // variant) and consumed by doAfter / afterThrowing.
    private static final ThreadLocal<String> requestUUID = new ThreadLocal<>();
    private static final ThreadLocal<String> threadLockKey = new ThreadLocal<>();

    /**
     * Pointcut matching every method annotated with {@code RedisLimitLock}.
     */
    @Pointcut("@annotation(com.hstrivl.aspect.annotation.RedisLimitLock)")
    public void controllerAspect() {
    }

    /**
     * Around advice: runs the target method only if this caller's lock could be acquired.
     *
     * @param proceedingJoinPoint the intercepted invocation
     * @return the target method's result, or {@code null} when another request holds the lock
     * @throws Throwable anything thrown by the target method; it is propagated unchanged, and the
     *     lock is still released
     */
    @Around("controllerAspect()")
    public Object doAround(ProceedingJoinPoint proceedingJoinPoint) throws Throwable {
        Object[] parameterValues = proceedingJoinPoint.getArgs();
        String[] parameterNames = parameterNamesOf(proceedingJoinPoint);
        String methodName = proceedingJoinPoint.getSignature().getName();
        // Same caller + same method => same lock. NOTE(review): with no token at all, every
        // anonymous caller shares one key ("" or "null" + methodName).
        String lockKey = resolveToken(parameterNames, parameterValues) + methodName;
        // Unique per call, so the release script only ever deletes a lock this call owns.
        String requestId = UUID.randomUUID().toString();
        LOGGER.info("lockKey:{}", lockKey);
        if (!RedisUtil.tryGetDistributedLock(lockKey, requestId, TIMEOUT)) {
            // A request with the same key is in flight: skip the target method.
            // NOTE(review): this logs every argument value, including the caller token.
            LOGGER.info("get redis lock failed! paramerNames:{},parameterValues:{}",
                    Arrays.toString(parameterNames), Arrays.toString(parameterValues));
            return null;
        }
        LOGGER.info("get redis lock success");
        try {
            return proceedingJoinPoint.proceed();
        } finally {
            RedisUtil.releaseDistributedLock(lockKey, requestId);
        }
    }

    /**
     * Optional before advice (not registered; its annotation is commented out). Acquires the lock
     * and records the request id / key in thread-locals for doAfter / afterThrowing.
     *
     * <p>Before advice cannot stop the target method without throwing, so a failed acquisition is
     * only logged here and the method still runs.
     *
     * @param joinPoint the intercepted invocation
     */
    //@Before("controllerAspect()")
    public void doBefore(JoinPoint joinPoint) {
        String methodName = joinPoint.getSignature().getName();
        String lockKey = resolveToken(parameterNamesOf(joinPoint), joinPoint.getArgs()) + methodName;
        String requestId = UUID.randomUUID().toString();
        LOGGER.info("thread before-{}", Thread.currentThread().getName());
        if (RedisUtil.tryGetDistributedLock(lockKey, requestId, TIMEOUT)) {
            requestUUID.set(requestId);
            threadLockKey.set(lockKey);
            LOGGER.info("get redis lock success");
        } else {
            LOGGER.info("get redis lock failed!");
        }
    }

    /**
     * Optional after advice (not registered; its annotation is commented out). Releases the lock
     * recorded by doBefore on this thread, if any, and clears the thread-locals.
     *
     * @param joinPoint the intercepted invocation (unused)
     */
    //@After("controllerAspect()")
    public void doAfter(JoinPoint joinPoint) {
        String requestId = requestUUID.get();
        String lockKey = threadLockKey.get();
        // Clear first, so a pooled thread never carries these values into its next request.
        requestUUID.remove();
        threadLockKey.remove();
        if (StringUtils.isEmpty(lockKey) || StringUtils.isEmpty(requestId)) {
            return; // doBefore did not acquire a lock on this thread
        }
        if (!RedisUtil.releaseDistributedLock(lockKey, requestId)) {
            LOGGER.info("release redis lock failed!");
        } else {
            LOGGER.info("release redis success");
        }
    }

    /**
     * After-throwing advice: releases a lock recorded in the thread-locals by doBefore, if any,
     * and clears them. When only doAround is active, the thread-locals are empty and this does
     * nothing.
     *
     * @param point the intercepted invocation (unused)
     * @param e     the exception thrown by the target method (unused)
     */
    @AfterThrowing(pointcut = "controllerAspect()", throwing = "e")
    public void afterThrowing(JoinPoint point, Exception e) {
        String requestId = requestUUID.get();
        String lockKey = threadLockKey.get();
        requestUUID.remove();
        threadLockKey.remove();
        if (!StringUtils.isEmpty(lockKey) && !StringUtils.isEmpty(requestId)) {
            RedisUtil.releaseDistributedLock(lockKey, requestId);
        }
    }

    /**
     * Returns the intercepted method's parameter names, or an empty array when the signature
     * provides none (getParameterNames() may return null, e.g. without debug/-parameters info).
     */
    private static String[] parameterNamesOf(JoinPoint joinPoint) {
        String[] names = ((CodeSignature) joinPoint.getStaticPart().getSignature()).getParameterNames();
        return names == null ? new String[0] : names;
    }

    /**
     * Resolves the caller token. It is the first argument whose name equals "token" (ignoring
     * case); if that is missing or empty, it is the "token" parameter of the current HTTP request,
     * when there is one.
     *
     * @return the token; "" or {@code null} when none could be found
     */
    private static String resolveToken(String[] parameterNames, Object[] parameterValues) {
        String token = "";
        for (int i = 0; i < parameterNames.length && i < parameterValues.length; i++) {
            if (TOKEN.equalsIgnoreCase(parameterNames[i])) {
                Object value = parameterValues[i];
                // toString() instead of a (String) cast, so a non-String token argument still works.
                token = value == null ? null : value.toString();
                break;
            }
        }
        if (StringUtils.isEmpty(token)) {
            // Not bound as an argument (or empty): fall back to the raw request parameter. Outside
            // an HTTP request there are no servlet attributes, so keep what we have.
            Object attributes = RequestContextHolder.getRequestAttributes();
            if (attributes instanceof ServletRequestAttributes) {
                token = ((ServletRequestAttributes) attributes).getRequest().getParameter(TOKEN);
            }
        }
        return token;
    }
}
三、在对应Controller里method上添加限制重复提交的注解@RedisLimitLock
/**
* Daily sign-in endpoint (POST /userNormalSign.html), guarded by {@code @RedisLimitLock}.
* <p>The lock aspect looks for a method argument named {@code token} (case-insensitive), falling
* back to the "token" request parameter, and combines it with the method name to form the lock
* key. The parameter name {@code token} is therefore load-bearing: do not rename it.
* @param response the HTTP response
* @param request the HTTP request
* @param token caller token used to build the Redis lock key
* @throws IOException declared by the signature; presumably raised when the (omitted) business
*     logic writes to the response — confirm
*/
@RedisLimitLock
@RequestMapping(value = "/userNormalSign.html", method = RequestMethod.POST)
public void userNormalSign(HttpServletResponse response, HttpServletRequest request, String token) throws IOException {
// business logic goes here
}
四、RedisUtil及JedisUtil工具类提供给大家参考:
/**
 * RedisUtil: static helpers for common Redis commands, plus a simple distributed lock.
 *
 * <p>Every helper borrows a connection via {@code JedisUtil.getJedis()} and closes it with
 * try-with-resources. {@code Jedis.close()} is exactly what {@code JedisUtil.returnResource}
 * calls, so a command that throws can no longer leak a pooled connection. Binary keys for
 * serialized objects are encoded explicitly as UTF-8 rather than with the JVM's platform default
 * charset.
 */
public class RedisUtil {

    /**
     * Encodes a key for the binary (byte[]) Jedis commands. Jedis encodes String keys as UTF-8,
     * so this makes e.g. removeString and removeObject address the same key on any platform.
     */
    private static byte[] keyBytes(String key) {
        return key.getBytes(java.nio.charset.StandardCharsets.UTF_8);
    }

    /**
     * Stores an object, serialized with SerializeUtil, under {@code key} with no expiry.
     */
    public static void setObjectValue(String key, Object value) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.set(keyBytes(key), new SerializeUtil().serialize(value));
        }
    }

    /**
     * Stores a string under {@code key} with no expiry.
     */
    public static void setStringValue(String key, String value) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.set(key, value);
        }
    }

    /**
     * Returns whether {@code key} exists.
     */
    public static boolean isHaveRedisKey(String key) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            return jedis.exists(key);
        }
    }

    /**
     * Stores a serialized object under {@code key}, expiring after {@code time} seconds.
     */
    public static void setExpireObject(String key, int time, Object value) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.setex(keyBytes(key), time, new SerializeUtil().serialize(value));
        }
    }

    /**
     * Stores a string under {@code key}, expiring after {@code time} seconds.
     */
    public static void setExpireString(String key, int time, String value) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.setex(key, time, value);
        }
    }

    /**
     * Reads the object stored under {@code key} and deserializes it with SerializeUtil.
     * NOTE(review): SerializeUtil presumably uses Java native deserialization; only read keys
     * written by this application — confirm.
     */
    public static Object getObjectValue(String key) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            return new SerializeUtil().unserialize(jedis.get(keyBytes(key)));
        }
    }

    /**
     * Returns the string stored under {@code key}, or {@code null} if it is absent.
     */
    public static String getStringValue(String key) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            return jedis.get(key);
        }
    }

    /**
     * Deletes {@code key} (string form of the key).
     */
    public static void removeString(String key) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.del(key);
        }
    }

    /**
     * Deletes {@code key} (binary form of the key, as used by the object helpers).
     */
    public static void removeObject(String key) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.del(keyBytes(key));
        }
    }

    /**
     * Appends a serialized object to the tail of the list {@code key} (RPUSH).
     */
    public static void setQueueValue(String key, Object value) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.rpush(keyBytes(key), new SerializeUtil().serialize(value));
        }
    }

    /**
     * Removes and deserializes the head of the list {@code key} (LPOP).
     */
    public static Object getQueueValue(String key) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            return new SerializeUtil().unserialize(jedis.lpop(keyBytes(key)));
        }
    }

    /**
     * Sets {@code key} to expire after {@code seconds} seconds.
     */
    public static void setExpireKey(String key, int seconds) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.expire(key, seconds);
        }
    }

    /**
     * Appends {@code value} to an existing string key, leaving its TTL unchanged. If the key does
     * not exist, creates it with a {@code time}-second expiry. Not atomic: the existence check and
     * the write are separate commands.
     */
    public static void setAppendValue(String key, int time, String value) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            if (jedis.exists(key)) {
                jedis.append(key, value);
            } else {
                jedis.setex(key, time, value);
            }
        }
    }

    /**
     * Returns the remaining time to live of {@code key} in whole seconds. Returns -1 if the key
     * has no expiry and -2 if it does not exist (the PTTL reply codes).
     */
    public static int getKeyExpire(String key) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            long millis = jedis.pttl(key);
            if (millis < 0) {
                // Pass the -1 / -2 status codes through; dividing them by 1000 would yield 0.
                return (int) millis;
            }
            return (int) Math.min(millis / 1000, Integer.MAX_VALUE);
        }
    }

    /**
     * Atomically increments the integer stored at {@code key} by {@code increment} (INCRBY).
     */
    public static void incrementKey(String key, Integer increment) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.incrBy(key, increment);
        }
    }

    /**
     * Sets every entry of {@code map} as a string key/value, in one pipelined round trip.
     */
    public static void batchStringVal(Map<String, String> map) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            Pipeline pipeline = jedis.pipelined();
            for (Map.Entry<String, String> entry : map.entrySet()) {
                pipeline.set(entry.getKey(), entry.getValue());
            }
            pipeline.sync();
        }
    }

    /**
     * Writes every entry of {@code map} as a string field of hash {@code key}, pipelined.
     */
    public static void batchHashString(String key, Map<String, String> map) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            Pipeline pipeline = jedis.pipelined();
            for (Map.Entry<String, String> entry : map.entrySet()) {
                pipeline.hset(key, entry.getKey(), entry.getValue());
            }
            pipeline.sync();
        }
    }

    /**
     * Writes every entry of {@code map}, with serialized values, as a field of hash {@code key}.
     * The writes are pipelined.
     */
    public static void batchHashObject(String key, Map<String, ?> map) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            SerializeUtil su = new SerializeUtil();
            byte[] hashKey = keyBytes(key);
            Pipeline pipeline = jedis.pipelined();
            for (Map.Entry<String, ?> entry : map.entrySet()) {
                pipeline.hset(hashKey, keyBytes(entry.getKey()), su.serialize(entry.getValue()));
            }
            pipeline.sync();
        }
    }

    /**
     * Returns the string value of {@code field} in hash {@code key}, or {@code null}.
     */
    public static String hGetString(String key, String field) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            return jedis.hget(key, field);
        }
    }

    /**
     * Returns the deserialized value of {@code field} in hash {@code key}.
     */
    public static Object hGetObject(String key, String field) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            return new SerializeUtil().unserialize(jedis.hget(keyBytes(key), keyBytes(field)));
        }
    }

    /**
     * Returns the string values of {@code list}'s keys, in order, with {@code null} for absent
     * keys. Uses a single MGET instead of one GET (and one pooled connection) per key.
     */
    public static List<String> getQueryValues(List<String> list) {
        if (list == null || list.isEmpty()) {
            return new ArrayList<>(); // MGET with no keys is a Redis error
        }
        try (Jedis jedis = JedisUtil.getJedis()) {
            return new ArrayList<>(jedis.mget(list.toArray(new String[0])));
        }
    }

    /** Redis distributed lock */
    private static final String LOCK_SUCCESS = "OK"; // reply of a successful SET
    private static final String SET_IF_NOT_EXIST = "NX"; // NX: set only if absent (XX: only if present)
    private static final String SET_WITH_EXPIRE_TIME = "EX"; // EX: expiry in seconds (PX: milliseconds)
    private static final Long RELEASE_SUCCESS = 1L; // script reply when our lock was deleted
    /** Deletes KEYS[1] only if it still holds ARGV[1], so only the owner can release the lock. */
    private static final String RELEASE_SCRIPT =
            "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";

    /**
     * Tries to acquire a distributed lock with one atomic SET key value NX EX seconds.
     *
     * @param lockKey    lock key
     * @param requestId  owner id stored as the value; needed to release the lock
     * @param expireTime expiry in seconds
     * @return whether the lock was acquired
     */
    public static boolean tryGetDistributedLock(String lockKey, String requestId, int expireTime) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            String result = jedis.set(lockKey, requestId, SET_IF_NOT_EXIST, SET_WITH_EXPIRE_TIME, expireTime);
            return LOCK_SUCCESS.equals(result);
        }
    }

    /**
     * Releases a distributed lock if it is still owned by {@code requestId}. The check and the
     * delete run in one Lua script, so they are atomic.
     *
     * @param lockKey   lock key
     * @param requestId owner id that was passed to tryGetDistributedLock
     * @return whether the lock was deleted
     */
    public static boolean releaseDistributedLock(String lockKey, String requestId) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            Object result = jedis.eval(RELEASE_SCRIPT, Collections.singletonList(lockKey),
                    Collections.singletonList(requestId));
            return RELEASE_SUCCESS.equals(result);
        }
    }

    /**
     * Returns all fields of hash {@code key}, or an empty map if it does not exist.
     *
     * @param key hash key
     * @return field/value map, never {@code null}
     */
    public static Map<String, String> getMapByKey(String key) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            Map<String, String> result = jedis.hgetAll(key);
            return result == null ? new HashMap<>() : result;
        }
    }

    /**
     * Writes {@code map} into hash {@code key} and makes the hash expire at the next local
     * midnight (HMSET followed by EXPIRE; not atomic).
     *
     * @param key hash key (e.g. per-user)
     * @param map fields to write
     */
    public static void setMapValueWithExpire(String key, Map<String, String> map) {
        try (Jedis jedis = JedisUtil.getJedis()) {
            jedis.hmset(key, map);
            jedis.expire(key, getExpireTimeOfDay());
        }
    }

    /**
     * Returns the seconds from now until the start of the next calendar day, in the JVM's default
     * time zone. The result is at least 1, because EXPIRE with a non-positive value deletes the key
     * immediately. java.time is used so that days that are not 24 hours long (DST) are handled
     * correctly.
     *
     * @return seconds until the next local midnight, at least 1
     */
    private static int getExpireTimeOfDay() {
        java.time.ZonedDateTime now = java.time.ZonedDateTime.now();
        java.time.ZonedDateTime nextMidnight = now.toLocalDate().plusDays(1).atStartOfDay(now.getZone());
        long seconds = java.time.Duration.between(now, nextMidnight).getSeconds();
        return (int) Math.max(1L, seconds);
    }
}
JedisUtil工具类:
/**
 * JedisUtil: owns the process-wide JedisPool and hands out connections from it.
 *
 * <p>Host, port and password are read from {@code redis.properties} when the class is
 * initialised. The pool is created lazily on first use.
 */
public class JedisUtil {
    // Redis server host
    private static String server_ip = PropertiesHelper.getProperty("server_ip", "redis.properties");
    // Redis port
    private static String port = PropertiesHelper.getProperty("port", "redis.properties");
    // Password; null or empty means no AUTH is sent
    private static String auth = PropertiesHelper.getProperty("auth", "redis.properties");
    /** Connect/read timeout in milliseconds; 2000 ms is the default Jedis itself uses. */
    private static final int TIMEOUT_MS = 2000;
    // volatile so the double-checked initialisation in getPool() publishes the pool safely
    private static volatile JedisPool pool = null;

    /**
     * Returns the shared connection pool, creating it on first call. Concurrent first calls create
     * exactly one pool.
     *
     * @return the shared JedisPool
     */
    public static JedisPool getPool() {
        JedisPool current = pool;
        if (current == null) {
            synchronized (JedisUtil.class) {
                current = pool;
                if (current == null) {
                    current = createPool();
                    pool = current;
                }
            }
        }
        return current;
    }

    /**
     * Builds a new pool from the configured host/port.
     *
     * <p>The password is handed to the pool, which authenticates each connection once when the
     * connection is created, instead of sending AUTH on every borrow.
     */
    private static JedisPool createPool() {
        JedisPoolConfig config = new JedisPoolConfig();
        config.setTestOnBorrow(false);
        config.setTestWhileIdle(true);
        config.setMaxIdle(10); // max idle connections kept in the pool
        config.setMaxTotal(10000);
        String password = (auth == null || auth.isEmpty()) ? null : auth;
        return new JedisPool(config, server_ip, Integer.parseInt(port), TIMEOUT_MS, password);
    }

    /**
     * Replaces {@code broken} with a freshly built pool and destroys it, so the old pool's
     * connections are not leaked. If another thread has already replaced it, the current pool is
     * returned unchanged.
     */
    private static synchronized JedisPool rebuildPool(JedisPool broken) {
        if (pool == broken) {
            pool = createPool();
            if (broken != null) {
                broken.destroy();
            }
        }
        return pool;
    }

    /**
     * Closes a connection obtained from getJedis(), which hands it back to its pool.
     *
     * @param redis the connection; {@code null} is ignored
     */
    public static void returnResource(Jedis redis) {
        if (redis != null) {
            redis.close();
        }
    }

    /**
     * Borrows a connection from the pool.
     *
     * <p>If borrowing fails, the pool is rebuilt once and the borrow is retried.
     *
     * @return a connection that the caller must close, or {@code null} if Redis is unreachable
     */
    public static Jedis getJedis() {
        JedisPool current = null;
        try {
            current = getPool();
            return current.getResource();
        } catch (Exception e) {
            try {
                return rebuildPool(current).getResource();
            } catch (Exception e1) {
                e1.printStackTrace();
                return null;
            }
        }
    }
}
配置文件获取工具类:
/**
 * PropertiesHelper: reads single values from properties files on the classpath.
 */
public class PropertiesHelper {

    /**
     * Returns the value of property {@code name} from the classpath resource
     * {@code properties}. The file is re-read on every call.
     *
     * @param name       property key
     * @param properties classpath location of the properties file, e.g. "redis.properties"
     * @return the value, or "" if the key is absent or the file cannot be loaded
     */
    public static String getProperty(String name, String properties) {
        Resource resource = new ClassPathResource(properties);
        Properties props;
        try {
            props = PropertiesLoaderUtils.loadProperties(resource);
        } catch (IOException e) {
            System.out.println("读取配置文件" + properties + "失败,原因:配置文件不存在!");
            // Return the empty default rather than dereferencing a Properties that was never loaded.
            return "";
        }
        return props.getProperty(name, "");
    }
}