JPA AES 引发的 String 二进制数据 DATA LOSS 问题

引言

系统要求高安全性,试题的数据需要加密,防止在数据库层面进行数据泄漏。

但试题的加密还与密码的加密不同,用户密码的加密可以使用Hash算法,无法解密,不光在数据库中,即便是我们的程序也无法获悉用户的密码。

试题的加密要求数据库层面是加密数据,从数据库中查询数据后,再对数据进行解密,发送给前台。接口是经过精密的认证与鉴权的,保证安全。

最终对比各种方案,决定采用converter的实现方案。

实现

图解

image.png

如上图所示:当数据写入前,对数据进行加密;数据查询后,对数据进行解密。

所有的加密解密工作,均由自定义converter完成。

错误示例

照着Demo写了一个使用AES加解密的converter

/**
 * 加密解决转换器
 * 对象字段数据 与 数据表列之间转换
 */
public class EncryptConverter implements AttributeConverter {

    private static final Logger logger = LoggerFactory.getLogger(EncryptConverter.class);

    /**
     * AES 密钥
     */
    private static final byte[] VALUE = "XQRhrQnGNFJf1WaSGOOJEjNhDjRPMG5N".getBytes(StandardCharsets.UTF_8);

    /**
     * 加密/解密 算法
     */
    private static final String ALGORITHM = "AES";

    /**
     * 加密/解密 密钥
     */
    private static final Key KEY = new SecretKeySpec(VALUE, ALGORITHM);

    /**
     * 加密过程
     * 从对象字段数据 到 数据库列
     */
    @Override
    public String convertToDatabaseColumn(String data) {
        String result;

        try {
            logger.debug("获取算法");
            Cipher cipher = Cipher.getInstance(ALGORITHM);

            logger.debug("设置加密模式与加密密钥");
            cipher.init(Cipher.ENCRYPT_MODE, KEY);

            logger.debug("获取原始内容");
            byte[] rawData = data.getBytes(StandardCharsets.UTF_8);

            logger.debug("加密");
            byte[] encryptedData = cipher.doFinal(rawData);

            logger.debug("加密后的字节数组编码为字符串");
            result = new String(encryptedData);
        } catch (IllegalBlockSizeException | BadPaddingException | NoSuchPaddingException | NoSuchAlgorithmException | InvalidKeyException e) {
            throw new RuntimeException("encrypt error!", e);
        }

        return result;
    }

    @Override
    public String convertToEntityAttribute(String data) {
        String result;

        try {
            logger.debug("获取算法");
            Cipher cipher = Cipher.getInstance(ALGORITHM);

            logger.debug("设置解密模式与解密密钥");
            cipher.init(Cipher.DECRYPT_MODE, KEY);

            logger.debug("获取加密字节数组");
            byte[] encryptedData = data.getBytes(StandardCharsets.UTF_8);

            logger.debug("解密为原始内容");
            byte[] rawData = cipher.doFinal(encryptedData);

            logger.debug("解密");
            result = new String(rawData);
        } catch (IllegalBlockSizeException | BadPaddingException | NoSuchPaddingException | NoSuchAlgorithmException | InvalidKeyException e) {
            throw new RuntimeException("decrypt error!", e);
        }

        return result;
    }
}

再对要加密的字段添加@Convert注解,指明converter为之前编写的EncryptConverter.class

@Entity
public class Information {

    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    private Long id;

    @Convert(converter = EncryptConverter.class)
    private String content;
}

跑一个单元测试看看是否生效。

@Test
void encrypt() {
    Information information = new Information();
    information.setContent("测试内容");
    informationRepository.save(information);

    Optional optional = informationRepository.findById(information.getId());
    System.out.println(optional);
}

出错了,突然发现事情并不简单。

image.png

因为异常处理得比较好,定位错误十分迅速:

// JPA异常:使用AttributeConverter时发生错误。
org.springframework.orm.jpa.JpaSystemException: Error attempting to apply AttributeConverter; nested exception is javax.persistence.PersistenceException: Error attempting to apply AttributeConverter.
// 由这个错误引起:持久化错误:使用AttributeConverter时发生错误。
Caused by: javax.persistence.PersistenceException: Error attempting to apply AttributeConverter.
// 由这个错误引起:运行错误:加密错误!
Caused by: java.lang.RuntimeException: decrypt error!
// 由这个错误引起:非法的块大小异常,当解密时输入长度必须是16的倍数。
Caused by: javax.crypto.IllegalBlockSizeException: Input length must be multiple of 16 when decrypting with padded cipher.

同时,加密后的数据也有问题:

ag�U�)q\�|����

调试

经过调试与错误排查,发现问题出在字节数组与字符串的转换上。

logger.debug("加密");
byte[] encryptedData = cipher.doFinal(rawData);

logger.debug("加密后的字节数组编码为字符串");
result = new String(encryptedData);

通读String类源码,解决该问题。

String类内部实现中,字符串与字节数组之间的转换,是通过编码与解码实现的。

getBytes方法,将字符串转换为字节数组,即字符到二进制的编码。

public byte[] getBytes(Charset charset) {
    if (charset == null) throw new NullPointerException();
    return StringCoding.encode(charset, value, 0, value.length);
}

String类默认采用UTF-8编码,编码时调用UTF_8类内部的encode方法,将字符数组编码为字节数组。

encode源码,建议阅读。

public int encode(char[] sa, int sp, int len, byte[] da) {
    int sl = sp + len;
    int dp = 0;
    int dlASCII = dp + Math.min(len, da.length);

    // ASCII only optimized loop
    while (dp < dlASCII && sa[sp] < '\u0080')
        da[dp++] = (byte) sa[sp++];

    while (sp < sl) {
        char c = sa[sp++];
        if (c < 0x80) {
            // Have at most seven bits
            da[dp++] = (byte)c;
        } else if (c < 0x800) {
            // 2 bytes, 11 bits
            da[dp++] = (byte)(0xc0 | (c >> 6));
            da[dp++] = (byte)(0x80 | (c & 0x3f));
        } else if (Character.isSurrogate(c)) {
            if (sgp == null)
                sgp = new Surrogate.Parser();
            int uc = sgp.parse(c, sa, sp - 1, sl);
            if (uc < 0) {
                if (malformedInputAction() != CodingErrorAction.REPLACE)
                    return -1;
                da[dp++] = repl;
            } else {
                da[dp++] = (byte)(0xf0 | ((uc >> 18)));
                da[dp++] = (byte)(0x80 | ((uc >> 12) & 0x3f));
                da[dp++] = (byte)(0x80 | ((uc >>  6) & 0x3f));
                da[dp++] = (byte)(0x80 | (uc & 0x3f));
                sp++;  // 2 chars
            }
        } else {
            // 3 bytes, 16 bits
            da[dp++] = (byte)(0xe0 | ((c >> 12)));
            da[dp++] = (byte)(0x80 | ((c >>  6) & 0x3f));
            da[dp++] = (byte)(0x80 | (c & 0x3f));
        }
    }
    return dp;
}

String类中的bytes构造函数,将字节数组转换为字符串,即二进制到字符的解码。

public String(byte bytes[]) {
    this(bytes, 0, bytes.length);
}

与编码类似,解码时调用UTF_8类内部的decode方法,将字节数组解码为字符数组。

decode源码,建议阅读。

public int decode(byte[] sa, int sp, int len, char[] da) {
    final int sl = sp + len;
    int dp = 0;
    int dlASCII = Math.min(len, da.length);
    ByteBuffer bb = null;  // only necessary if malformed

    // ASCII only optimized loop
    while (dp < dlASCII && sa[sp] >= 0)
        da[dp++] = (char) sa[sp++];

    while (sp < sl) {
        int b1 = sa[sp++];
        if (b1 >= 0) {
            // 1 byte, 7 bits: 0xxxxxxx
            da[dp++] = (char) b1;
        } else if ((b1 >> 5) == -2 && (b1 & 0x1e) != 0) {
            // 2 bytes, 11 bits: 110xxxxx 10xxxxxx
            if (sp < sl) {
                int b2 = sa[sp++];
                if (isNotContinuation(b2)) {
                    if (malformedInputAction() != CodingErrorAction.REPLACE)
                        return -1;
                    da[dp++] = replacement().charAt(0);
                    sp--;            // malformedN(bb, 2) always returns 1
                } else {
                    da[dp++] = (char) (((b1 << 6) ^ b2)^
                                   (((byte) 0xC0 << 6) ^
                                    ((byte) 0x80 << 0)));
                }
                continue;
            }
            if (malformedInputAction() != CodingErrorAction.REPLACE)
                return -1;
            da[dp++] = replacement().charAt(0);
            return dp;
        } else if ((b1 >> 4) == -2) {
            // 3 bytes, 16 bits: 1110xxxx 10xxxxxx 10xxxxxx
            if (sp + 1 < sl) {
                int b2 = sa[sp++];
                int b3 = sa[sp++];
                if (isMalformed3(b1, b2, b3)) {
                    if (malformedInputAction() != CodingErrorAction.REPLACE)
                        return -1;
                    da[dp++] = replacement().charAt(0);
                    sp -= 3;
                    bb = getByteBuffer(bb, sa, sp);
                    sp += malformedN(bb, 3).length();
                } else {
                    char c = (char)((b1 << 12) ^
                                      (b2 <<  6) ^
                                      (b3 ^
                                      (((byte) 0xE0 << 12) ^
                                      ((byte) 0x80 <<  6) ^
                                      ((byte) 0x80 <<  0))));
                    if (Character.isSurrogate(c)) {
                        if (malformedInputAction() != CodingErrorAction.REPLACE)
                            return -1;
                        da[dp++] = replacement().charAt(0);
                    } else {
                        da[dp++] = c;
                    }
                }
                continue;
            }
            if (malformedInputAction() != CodingErrorAction.REPLACE)
                return -1;
            if (sp  < sl && isMalformed3_2(b1, sa[sp])) {
                da[dp++] = replacement().charAt(0);
                continue;

            }
            da[dp++] = replacement().charAt(0);
            return dp;
        } else if ((b1 >> 3) == -2) {
            // 4 bytes, 21 bits: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
            if (sp + 2 < sl) {
                int b2 = sa[sp++];
                int b3 = sa[sp++];
                int b4 = sa[sp++];
                int uc = ((b1 << 18) ^
                          (b2 << 12) ^
                          (b3 <<  6) ^
                          (b4 ^
                           (((byte) 0xF0 << 18) ^
                           ((byte) 0x80 << 12) ^
                           ((byte) 0x80 <<  6) ^
                           ((byte) 0x80 <<  0))));
                if (isMalformed4(b2, b3, b4) ||
                    // shortest form check
                    !Character.isSupplementaryCodePoint(uc)) {
                    if (malformedInputAction() != CodingErrorAction.REPLACE)
                        return -1;
                    da[dp++] = replacement().charAt(0);
                    sp -= 4;
                    bb = getByteBuffer(bb, sa, sp);
                    sp += malformedN(bb, 4).length();
                } else {
                    da[dp++] = Character.highSurrogate(uc);
                    da[dp++] = Character.lowSurrogate(uc);
                }
                continue;
            }
            if (malformedInputAction() != CodingErrorAction.REPLACE)
                return -1;
            b1 &= 0xff;
            if (b1 > 0xf4 ||
                sp  < sl && isMalformed4_2(b1, sa[sp] & 0xff)) {
                da[dp++] = replacement().charAt(0);
                continue;
            }
            sp++;
            if (sp  < sl && isMalformed4_3(sa[sp])) {
                da[dp++] = replacement().charAt(0);
                continue;
            }
            da[dp++] = replacement().charAt(0);
            return dp;
        } else {
            if (malformedInputAction() != CodingErrorAction.REPLACE)
                return -1;
            da[dp++] = replacement().charAt(0);
        }
    }
    return dp;
}

DATA LOSS

如果阅读过源码之后,或者对编码比较熟悉,应该明白问题出在哪里了。

解码过程中可能出现解码错误的情况。

image.png

UTF-8中,一个字符肯定对应着一个编码,但是一个编码,不一定对应着一个字符。

logger.debug("加密");
byte[] encryptedData = cipher.doFinal(rawData);

logger.debug("加密后的字节数组编码为字符串");
result = new String(encryptedData);

所以,加密后的字节数组不符合UTF-8编码规则,解码错误。

然后,数据库中加密后的数据就这样了:

ag�U�)q\�|����

解密的时候,查询这个异常的字符串,再调用getBytes方法将其编码为字节数组。

因为解码的错误,所以这里的编码结果也不正确,所以解密中获取的字节数组与原加密后的字节数组不同。

image.png

如上图所示:左侧为编码时的字节数组各下标内容,右侧为解码时的字节数组各下标内容。

两者不等价,这在Java术语中称为DATA LOSS

所以会有这句建议:

String is not a good container for binary data.

解决

将原需要进行编码的new String(bytes)修改为对二进制数据无损的方法调用即可。

我这里采用Base64方式对字节数组进行编码解码,保证数据无损。

完整代码:

/**
 * 加密解决转换器
 * 对象字段数据 与 数据表列之间转换
 */
public class EncryptConverter implements AttributeConverter {

    private static final Logger logger = LoggerFactory.getLogger(EncryptConverter.class);

    /**
     * AES 密钥
     */
    private static final byte[] VALUE = "XQRhrQnGNFJf1WaSGOOJEjNhDjRPMG5N".getBytes(StandardCharsets.UTF_8);

    /**
     * 加密/解密 算法
     */
    private static final String ALGORITHM = "AES";

    /**
     * 加密/解密 密钥
     */
    private static final Key KEY = new SecretKeySpec(VALUE, ALGORITHM);

    /**
     * 加密过程
     * 从对象字段数据 到 数据库列
     */
    @Override
    public String convertToDatabaseColumn(String data) {
        String result;

        try {
            logger.debug("获取算法");
            Cipher cipher = Cipher.getInstance(ALGORITHM);

            logger.debug("设置加密模式与加密密钥");
            cipher.init(Cipher.ENCRYPT_MODE, KEY);

            logger.debug("获取原始内容");
            byte[] rawData = data.getBytes(StandardCharsets.UTF_8);

            logger.debug("加密");
            byte[] encryptedData = cipher.doFinal(rawData);

            logger.debug("BASE64 将加密后的字节数组编码为字符串");
            result = Base64.getEncoder().encodeToString(encryptedData);
        } catch (IllegalBlockSizeException | BadPaddingException | NoSuchPaddingException | NoSuchAlgorithmException | InvalidKeyException e) {
            throw new RuntimeException("encrypt error!", e);
        }

        return result;
    }

    @Override
    public String convertToEntityAttribute(String data) {
        String result;

        try {
            logger.debug("获取算法");
            Cipher cipher = Cipher.getInstance(ALGORITHM);

            logger.debug("设置解密模式与解密密钥");
            cipher.init(Cipher.DECRYPT_MODE, KEY);

            logger.debug("BASE64 将编码后的字符串解码为加密字节数组");
            byte[] encryptedData = Base64.getDecoder().decode(data);

            logger.debug("解密为原始内容");
            byte[] rawData = cipher.doFinal(encryptedData);

            logger.debug("解密");
            result = new String(rawData);
        } catch (IllegalBlockSizeException | BadPaddingException | NoSuchPaddingException | NoSuchAlgorithmException | InvalidKeyException e) {
            throw new RuntimeException("decrypt error!", e);
        }

        return result;
    }
}

测试

再次运行该单元测试。

@Test
void encrypt() {
    Information information = new Information();
    information.setContent("测试内容");
    informationRepository.save(information);

    Optional optional = informationRepository.findById(information.getId());
    System.out.println(optional);
}

数据表中数据完成加密:

image.png

数据查询后自动解密:

image.png

总结

源码,因为面试而变了味道。

程序,因为资本而变得功利。

你,还记得写代码的初心吗?

你可能感兴趣的:(jpa,aes,加密,解密,string)