Spring Boot 中使用 @ControllerAdvice 进行统一异常处理

代码如下：

 

 

/**
 * Global controller advice with two responsibilities:
 *
 * 1. Translate exceptions raised by controllers into [BizResponse] payloads
 *    (the `@ExceptionHandler` methods below).
 * 2. Intercept every response body in [beforeBodyWrite], serialize it, and persist a
 *    request/response log entry through [sysAdminRequestLogService].
 *
 * The exception handlers hand the error code and stack trace over to [beforeBodyWrite]
 * through the servlet request attributes named by `ERROR_CODE_TAG` and `EXCEPTION_TAG`.
 */
@ControllerAdvice
open class ExceptionHandler : ResponseBodyAdvice<Any> {

    @Autowired
    lateinit var sysAdminRequestLogService: SysAdminRequestLogService

    @Autowired
    private lateinit var serverProperties: ServerProperties

    @Autowired
    private lateinit var jsonConverter: MappingJackson2HttpMessageConverter

    /**
     * Serializes the outgoing body with the caller's time zone, stores a request log entry,
     * and returns the body re-parsed from that JSON.
     *
     * Requests whose URI contains "swagger" or "api-docs" are passed through untouched.
     *
     * @param body the value returned by the controller / exception handler; may be null
     * @param request the current request; when it is null the body is returned unchanged
     * @return the body to write, or null when [body] is null
     */
    override fun beforeBodyWrite(body: Any?, returnType: MethodParameter?, selectedContentType: MediaType?,
                                 selectedConverterType: Class<out HttpMessageConverter<*>>?,
                                 request: ServerHttpRequest?, response: ServerHttpResponse?): Any? {
        if (request == null) {
            return body
        }

        // Skip API-documentation endpoints: they must be neither logged nor re-serialized.
        val uri = request.uri.toString()
        if (uri.contains("swagger") || uri.contains("api-docs")) {
            LOGGER.trace(" uri[${request.uri}], and result[$body]")
            return body
        }

        // NOTE(review): this mutates the converter's shared ObjectMapper on every request, so
        // concurrent requests carrying different time zones may interfere with each other.
        // It is kept because the final response write appears to rely on it - confirm.
        val mapper = jsonConverter.objectMapper
        mapper.setTimeZone(parseTimezone(request))
        val result = mapper.writeValueAsString(body)

        // Request logging needs the underlying servlet request; skip it for other request types.
        val servletRequest = (request as? ServletServerHttpRequest)?.servletRequest
        if (servletRequest != null) {
            constructHttpLogMessage(servletRequest, result)?.let {
                sysAdminRequestLogService.save(constructHttpRequestLog(it))
            }
        }

        // Round-trip: the serialized JSON is parsed back into an instance of the body's runtime class.
        return body?.let { JSONObject.parseObject(result, it.javaClass) }
    }

    /** This advice applies to every return type and every message converter. */
    override fun supports(returnType: MethodParameter?, converterType: Class<out HttpMessageConverter<*>>?): Boolean {
        return true
    }

    /**
     * Handles bean-validation failures on request bodies.
     *
     * The response message is every validation error's default message, each followed by ";".
     * The error code and stack trace are stored on the request so the request log records them.
     */
    @ExceptionHandler(MethodArgumentNotValidException::class)
    @ResponseBody
    @Throws(Exception::class)
    fun parameterInvalidException(e: MethodArgumentNotValidException, request: HttpServletRequest): BizResponse {
        LOGGER.error(" parameterInvalidException :", e)
        val response = BizResponse(ErrorCode.PARAM_ERROR.code, ErrorCode.PARAM_ERROR.msg)
        request.setAttribute(ERROR_CODE_TAG, response.code)
        request.setAttribute(EXCEPTION_TAG, getStackTrace(e))

        val bindingResult = e.bindingResult
        if (bindingResult.hasErrors()) {
            response.msg = bindingResult.allErrors.joinToString(separator = "") { "${it.defaultMessage};" }
        }
        return response
    }

    /** Fallback handler: any otherwise unhandled exception becomes a SYS_INTERNAL_ERROR response. */
    @ExceptionHandler(Exception::class)
    @ResponseBody
    @Throws(Exception::class)
    fun sysException(e: Exception, request: HttpServletRequest): BizResponse {
        val response = BizResponse(ErrorCode.SYS_INTERNAL_ERROR.code, ErrorCode.SYS_INTERNAL_ERROR.msg)
        request.setAttribute(EXCEPTION_TAG, getStackTrace(e))
        request.setAttribute(ERROR_CODE_TAG, response.code)

        LOGGER.error("Exception requestUrl[${request.requestURL}] ", e)
        return response
    }

    /** Handles business exceptions by echoing the exception's own code and message. */
    @ExceptionHandler(BizException::class)
    @ResponseBody
    @Throws(RuntimeException::class)
    fun bizeException(e: BizException, request: HttpServletRequest): BizResponse {
        val response = BizResponse(e.code, e.msg)
        request.setAttribute(ERROR_CODE_TAG, response.code)
        request.setAttribute(EXCEPTION_TAG, getStackTrace(e))
        LOGGER.error("BizException requestUrl[${request.requestURL}] ${e.code}:${e.msg}")
        return response
    }

    /**
     * Handles a missing `@RequestParam` parameter.
     */
    @ExceptionHandler(MissingServletRequestParameterException::class)
    @ResponseBody
    @Throws(RuntimeException::class)
    fun bizeException(e: MissingServletRequestParameterException): BizResponse {
        LOGGER.error(" missingRequestParameterException :", e)
        return BizResponse(ErrorCode.PARAM_INCOMPLETE_ERROR.code, ErrorCode.PARAM_INCOMPLETE_ERROR.msg)
    }

    /**
     * Copies [message] into a persistable [SysAdminRequestLog].
     *
     * `responseTime` is the time elapsed since `message.timestamp`.
     */
    private fun constructHttpRequestLog(message: HttpLogMessage): SysAdminRequestLog {
        val loggedAt = checkNotNull(message.timestamp) { "HttpLogMessage.timestamp must be set" }

        val sysAdminRequestLog = SysAdminRequestLog()
        sysAdminRequestLog.operatorId = message.operatorId
        sysAdminRequestLog.errorCode = message.errorCode
        sysAdminRequestLog.method = message.method
        sysAdminRequestLog.requestIp = message.requestIp
        sysAdminRequestLog.module = message.module
        sysAdminRequestLog.requestParmas = message.requestParmas
        sysAdminRequestLog.requestUri = message.requestUri
        sysAdminRequestLog.requestUrl = message.requestUrl
        sysAdminRequestLog.requestUuid = message.requestUuid
        sysAdminRequestLog.response = message.response
        sysAdminRequestLog.stackTrace = message.stackTrace
        sysAdminRequestLog.timestamp = message.timestamp
        sysAdminRequestLog.responseTime = System.currentTimeMillis() - loggedAt
        return sysAdminRequestLog
    }

    /**
     * Collects everything known about the current request and its serialized [response].
     *
     * NOTE(review): `timestamp` is taken here, i.e. when the response is being written, so
     * `responseTime` is always close to zero. Measuring real latency would need a start time
     * captured when the request arrives - confirm the intended meaning.
     */
    private fun constructHttpLogMessage(request: HttpServletRequest, response: String?): HttpLogMessage? {
        val requestUrl = request.requestURL.toString()
        LOGGER.trace(" requestURI:[$requestUrl]")

        val now = System.currentTimeMillis()
        val httpLogMessage = HttpLogMessage()
        httpLogMessage.operatorId = getOperatorId(request)
        httpLogMessage.errorCode = getErrorCode(request)
        httpLogMessage.method = request.method
        httpLogMessage.requestIp = getRealRemoteIpAddress(request)
        httpLogMessage.module = serverProperties.displayName
        httpLogMessage.requestParmas = getRequestBody(request)
        httpLogMessage.requestUri = request.requestURI
        httpLogMessage.requestUrl = requestUrl
        httpLogMessage.requestUuid = getRequestUuid(request)
        httpLogMessage.response = response
        httpLogMessage.stackTrace = getStackTrace(request)
        httpLogMessage.timestamp = now
        httpLogMessage.responseTime = System.currentTimeMillis() - now
        return httpLogMessage
    }

    /** Returns the stack trace an exception handler stored on the request, if any. */
    private fun getStackTrace(request: HttpServletRequest): String? {
        return request.getAttribute(EXCEPTION_TAG)?.toString()
    }

    /** Renders [e] as its `toString()` followed by one tab-indented line per stack frame. */
    private fun getStackTrace(e: Exception): String {
        return e.stackTrace.joinToString(separator = "", prefix = "$e\n") { "\t $it \n" }
    }

    /** Returns the request UUID header, or null when the client did not send one. */
    private fun getRequestUuid(request: HttpServletRequest): String? {
        return request.getHeader(RequestType.REQUEST_UUID_TAG)
    }

    /**
     * Reads the request body for logging, masking a top-level "password" field.
     *
     * The body is decoded once, using the request's declared character encoding (UTF-8 when
     * none is declared), so multi-byte characters are never split across read chunks.
     * When the body is not a JSON object the raw text is returned as-is.
     *
     * NOTE(review): a servlet input stream can usually be read only once; if the controller's
     * message converter already consumed it, this presumably yields an empty string unless
     * the request is wrapped by a caching wrapper - confirm.
     */
    private fun getRequestBody(request: HttpServletRequest): String {
        var raw = ""
        try {
            val encoding = request.characterEncoding?.let { charset(it) } ?: Charsets.UTF_8
            raw = String(request.inputStream.readBytes(), encoding)
            if (raw.isEmpty()) {
                return ""
            }

            val json = JSONObject.parseObject(raw)
            if (json.containsKey("password")) {
                json["password"] = "*******"
            }
            return json.toString()
        } catch (e: Exception) {
            LOGGER.error(" getRequestBody exception", e)
        }
        return raw
    }

    /**
     * Resolves the client IP: the first non-empty entry of `X-Forwarded-For`, then `X-Real-IP`,
     * then the remote address of the connection.
     */
    private fun getRealRemoteIpAddress(request: HttpServletRequest): String {
        val forwardedFor = request.getHeader("X-Forwarded-For")
                ?.replace("[", "")
                ?.replace("]", "")
                ?.split(",")
                ?.map { it.trim() }
                ?.firstOrNull { it.isNotEmpty() }
        return forwardedFor ?: request.getHeader("X-Real-IP") ?: request.remoteAddr
    }

    /** Returns the error code stored on the request, or 0 when absent or not a number. */
    private fun getErrorCode(request: HttpServletRequest): Int {
        return request.getAttribute(ERROR_CODE_TAG)?.toString()?.toIntOrNull() ?: 0
    }

    /** Returns the operator id stored on the request, or null when absent, empty or not a number. */
    private fun getOperatorId(request: HttpServletRequest): Long? {
        return request.getAttribute(OPERATOR_ID)?.toString()?.toLongOrNull()
    }

    /**
     * Derives the caller's time zone from the `user_timezone` header, which is read as a whole
     * number of hours relative to GMT (e.g. "8" -> "GMT+8", "-5" -> "GMT-5").
     *
     * Falls back to the JVM default time zone when the header is missing or not an integer.
     */
    private fun parseTimezone(request: ServerHttpRequest): TimeZone {
        val hour = (request as? ServletServerHttpRequest)?.servletRequest
                ?.getHeader(TIME_ZONE)
                ?.toIntOrNull()
                ?: return TimeZone.getDefault()

        val zoneId = if (hour >= 0) "GMT+$hour" else "GMT$hour"
        return try {
            StringUtils.parseTimeZoneString(zoneId)
        } catch (e: IllegalArgumentException) {
            LOGGER.warn(" parseTimezone invalid zone[$zoneId]", e)
            TimeZone.getDefault()
        }
    }

    companion object {
        private val LOGGER = LoggerFactory.getLogger(ExceptionHandler::class.java)

        // Request attribute names shared between the exception handlers and the request logger.
        private const val EXCEPTION_TAG = "exception"
        private const val ERROR_CODE_TAG = "error_code"
        private const val OPERATOR_ID = "operator_id"

        // Request header carrying the caller's GMT offset in hours.
        private const val TIME_ZONE = "user_timezone"
    }

}

你可能感兴趣的:(springboot)