巧谈数据接口安全-具体实现

巧谈数据接口安全-具体实现

上一节针对数据接口的安全性提出了解决方式, 这节通过具体的代码来实现

项目环境

项目接口采用SSM的框架, 基于SpringMVC拦截器来实现。 当然, 也可以采用Spring的AOP来实现, 这种方式更通用一些

上代码

本项目采用 SpringMVC拦截器的方式来实现

获取参数

目前接口最通用的提交方式是application/json, 这种方式在拦截器中如果要获取POST方式的参数, 需要就其进行处理。

工具类 对得到的参数进行处理

public class ParamUtils {
    private static final WeakHashMap PARAMS = new WeakHashMap<>();
    public static final String KEY = "params";
    private ParamUtils() {
    }

    public static ParamUtils getInstance() {
        return ParamUtils.Holder.INSTANCE;
    }

    static class Holder {
        private static final ParamUtils INSTANCE = new ParamUtils();
    }

    /**
     * 将从post中获取到的传递 存放到map中
     * @param params
     */
    public void set(String params) {
        PARAMS.put(KEY, params);
    }

    /**
     * 得到json字符串
     * @return
     */
    public String get() {
        return PARAMS.get(KEY);
    }

    /**
     * 将json字符串转换成map对象
     * @param json
     * @return
     */
    public Map json2Map(String json) {
        return (Map) JSON.parse(json);
    }

    /**
     * 签名验证
     * @param map
     * @return
     */
    public String getSign(Map map) {
        return getSign2Map(map);
    }

    private String getSign2Map(Map map) {
        StringBuffer sb = new StringBuffer();
        ArrayList list = new ArrayList(map.keySet());
        Collections.sort(list);

        for(String key:list) {
            Object value = map.get(key);
            if(!key.equalsIgnoreCase("sign"))
                sb.append(key).append("=").append(map.get(key)).append("&");
        }
        sb.deleteCharAt(sb.length() - 1);
        return DigestUtil.getInstance().md5(sb.toString());
    }

    public String getSign(String json){
        Map map = json2Map(json);
        return getSign2Map(map);
    }

    /**
     * 从get请求中获取到参数 封装成Map对象
     * @param request
     * @return
     */
    public Map getParam2Get(HttpServletRequest request) {
        Map map = new HashMap<>();
        Map parameterMap = request.getParameterMap();

        if(parameterMap != null && !parameterMap.isEmpty()) {
            for(Map.Entry entry : parameterMap.entrySet()) {
                String[] value = entry.getValue();
                if(value != null && value.length > 0) {
                    map.put(entry.getKey(), value[0]);
                }
            }
        }

        return map;
    }
}

Filter 获取非GET方式传递的数据

通过实现Filter的方式, 对非GET的request进行重写, 得到提交参数

public class SecurityFilter implements Filter {
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {

    }

    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        SecurityHttpServletRequestWrapper wrap = null;
        if (request instanceof HttpServletRequest) {
            HttpServletRequest httpServletRequest = (HttpServletRequest) request;
            if (!"get".equals(httpServletRequest.getMethod().toUpperCase()) 
                && httpServletRequest.getHeader("Accept").contains("application/json")) {
                wrap = new SecurityHttpServletRequestWrapper(httpServletRequest);
                ParamUtils.getInstance().set(wrap.getJson());
            }
        }
        if (null != wrap) {
            chain.doFilter(wrap, response);
        } else {
            chain.doFilter(request, response);
        }
    }

    @Override
    public void destroy() {

    }
}


class SecurityHttpServletRequestWrapper extends HttpServletRequestWrapper {

    private String json;

    public SecurityHttpServletRequestWrapper(HttpServletRequest request) throws IOException {
        super(request);
        ServletInputStream stream = this.getRequest().getInputStream();
        json = IOUtils.toString(stream, "UTF-8");
    }

    @Override
    public ServletInputStream getInputStream() {
        byte[] buffer = null;
        try {
            buffer = json.toString().getBytes("UTF-8");
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
        }
        final ByteArrayInputStream bais = new ByteArrayInputStream(buffer);
        ServletInputStream newStream = new ServletInputStream() {

            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return false;
            }

            @Override
            public void setReadListener(ReadListener readListener) {

            }

            @Override
            public int read() throws IOException {
                return bais.read();
            }
        };
        return newStream;
    }

    public String getJson() {
        return json;
    }

    public void setJson(String json) {
        this.json = json;
    }

}

web.xml中配置


    securityFilter
    com.sanq.product.security.filters.SecurityFilter


    securityFilter
    /*

自定义过滤注解

在项目中 并不是所有的接口都需要进行全部验证的, 比如 登录注册这种的就不需要验证Token, 这里就需要将验证过滤掉

@IgnoreSecurity 过滤掉Token验证

@Target(ElementType.METHOD) 
@Retention(RetentionPolicy.RUNTIME) 
@Documented
public @interface IgnoreSecurity {
}

@Security 过滤掉所有的验证 (这里名称起的不是很好, 大家随意替换)

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Security {
}

拦截匹配

到这里, 我们就应该在拦截器中开始我们的拦截验证了, 我们在这里需要验证一下几点:

  1. IP
  2. Token
  3. TimeStamp
  4. Sign
  5. Client 客户端

首先我们先将需要验证的参数名称写入一个枚举类中, 方便我们查看

SecurityFieldEnum

public enum SecurityFieldEnum {
  //需要验证的参数
  TOKEN("token"),
  TIMESTAMP("timestamp"),
  SIGN("sign"),
  CLIENT("client"),
  APP("APP");

  private String mName;

  SecurityFieldEnum(String name) {
      this.mName = name;
  }

  public String getName() {
      return mName;
  }
}

开始写拦截器的代码, 为了可以实现通用, 这里定义为abstract

SecurityInterceptor


public abstract class SecurityInterceptor implements HandlerInterceptor {

  @Override
  public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {


    if(handler instanceof HandlerMethod) {
        //验证ip是否在黑名单中
        if(checkIp(request, GlobalUtil.getIpAddr(request))) {
            return false;
        }
        HandlerMethod hm = (HandlerMethod) handler;
        Security security = hm.getMethodAnnotation(Security.class);

        if (security != null) {
            return true;
        }

        IgnoreSecurity s = hm.getMethodAnnotation(IgnoreSecurity.class);

        Map objectMap;
        if (request.getMethod().equalsIgnoreCase("get"))
            objectMap = ParamUtils.getInstance().getParam2Get(request);
        else
            objectMap = ParamUtils.getInstance().json2Map(ParamUtils.getInstance().get());


        if (objectMap != null && !objectMap.isEmpty()) {
            Object o = null;

            if (s == null) {
                o = objectMap.get(SecurityFieldEnum.TOKEN.getName());
                if (o == null)
                    throw new NoParamsException(String.format("参数%s不存在", SecurityFieldEnum.TOKEN.getName()));

                if (!checkToken(request, (String) o)) {
                    throw new TokenException(String.format("%s已过期,请重新登录", SecurityFieldEnum.TOKEN.getName()));
                }
            }

            o = objectMap.get(SecurityFieldEnum.TIMESTAMP.getName());
            if (o == null)
                throw new NoParamsException(String.format("参数%s不存在", SecurityFieldEnum.TIMESTAMP.getName()));

            Long timestamp = StringUtil.toLong(o);

            if (LocalDateUtils.nowTime().getTime() - timestamp >= 60 * 1000)
                throw new NoParamsException(String.format("参数%s已过期", SecurityFieldEnum.TIMESTAMP.getName()));


            o = objectMap.get(SecurityFieldEnum.CLIENT.getName());
            if (o == null || SecurityFieldEnum.APP.getName().equals((String) o)) {  //这里获取CLIENT是为了防止有些项目无法进行sign验证, 也可以去掉该验证
                o = objectMap.get(SecurityFieldEnum.SIGN.getName());
                if (o == null)
                    throw new NoParamsException(String.format("参数%s不存在", SecurityFieldEnum.SIGN.getName()));

                String sign = (String) o;

                String paramsSign = ParamUtils.getInstance().getSign(objectMap);
                if (!sign.equals(paramsSign)) {
                    throw new NoParamsException(String.format("参数%s验证不正确", SecurityFieldEnum.SIGN.getName()));
                }

            }
            return true;
        }
    }
    return false;
  }

  public abstract boolean checkToken(HttpServletRequest request, String token);
  public abstract boolean checkIp(HttpServletRequest request, String ip);

  @Override
  public void postHandle(HttpServletRequest request, HttpServletResponse response, Object handler, ModelAndView modelAndView) throws Exception {

  }

  @Override
  public void afterCompletion(HttpServletRequest request, HttpServletResponse response, Object handler, Exception ex) throws Exception {

  }
}

这一大串的判断, 谁来给我优化下啊 -_-~~

错误信息引导

在拦截器中我们throw了一些异常, 在这里我们需要包装错误返回给用户

这里使用@ControllerAdvice配合@ExceptionHandler来实现

ExceptionAspect

@ControllerAdvice
@ResponseBody
public class ExceptionAspect {

  /**
    * 500 - Token is invaild
    */
  @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
  @ExceptionHandler(TokenException.class)
  public Response handleTokenException(TokenException e) {
      e.printStackTrace();
      return new Response().failure(e.getMsg(), ResultCode.NO_TOKEN);
  }

  /***
    * 参数有误
    */
  @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
  @ExceptionHandler(NoParamsException.class)
  public Response handleTokenException(NoParamsException e) {
      e.printStackTrace();
      return new Response().failure(e.getMsg(), ResultCode.PARAM_ERROR);
  }
}

结合我的小说项目运行 查看效果

正式编码已经完成 我们来看下效果

配置

maven引入


    com.sanq.product.x_utils
    util_security
    1.0-SNAPSHOT

web.xml配置


    securityFilter
    com.sanq.product.security.filters.SecurityFilter


    securityFilter
    /*

spring-mvc.xml配置

// 加载扫描错误信息的包


//配置拦截器

    


SecurityInterceptor


public class SecurityInterceptor extends com.sanq.product.security.interceptors.SecurityInterceptor {

  @Resource
  private JedisPoolService jedisPoolService;
  private static final int MAX = 60; 

  @Override
  public boolean checkToken(HttpServletRequest request, String token) {
      return jedisPoolService.exists(Redis.ReplaceKey.getTokenUser(token));
  }

  @Override
  public boolean checkIp(HttpServletRequest request, String ip) {
      //统计ip访问次数的key
      String ipKey = Redis.ReplaceKey.getCheckIpKey(ip);

      //BLOCK_IP_SET: 黑名单的key
      if (jedisPoolService.zrank(Redis.RedisKey.BLOCK_IP_SET, ip)) {
        LogUtil.getInstance(SecurityInterceptor.class).i("ip进入了黑名单");
        return true;
      }

      String ipCountTmp = jedisPoolService.get(ipKey);
      int ipCount = StringUtil.toInteger(ipCountTmp != null ? ipCountTmp : 0);

      if (ipCount > MAX) {
          jedisPoolService.putSet(Redis.RedisKey.BLOCK_IP_SET, 1, ip);
          jedisPoolService.delete(ipKey);
          return true;
      }

      jedisPoolService.incrAtTime(ipKey, MAX);

      return false;
  }
}

完结

到此我们关于接口安全的拦截就已经全部实现完成。 也欢迎大胆的尝试更改。

希望大家都可以写出更加优雅,健壮的程序

扩展

关于js端MD5的加密, 大家可以使用js-md5, 亲测和后端加密后得到的数据是一直的

md5("引入js,调用md5()方法");

你可能感兴趣的:(巧谈数据接口安全-具体实现)