SM4加密解密的C++代码

SM4的加密,使用了ECB模式,PKCS7填充,实现了对十六进制数据和字符串的加密解密,加密结果和解密参数支持十六进制字符串和base64两种方式。
本代码源于文章:https://blog.csdn.net/nicai_hualuo/article/details/121626931 中提供的算法,进行了完善而已。注意HexToBin和HexToDec,加入了对小写十六进制的支持。

#include <algorithm>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>

using namespace std;

namespace sm4
{
	// Standard base64 alphabet (RFC 4648, section 4).
	static const std::string base64_chars =
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
		"abcdefghijklmnopqrstuvwxyz"
		"0123456789+/";

	// True if `c` is one of the 64 base64 alphabet characters.
	// (Checked against the table above rather than isalnum(), so the result
	// does not depend on the current C locale.)
	static inline bool is_base64(unsigned char c)
	{
		return base64_chars.find(static_cast<char>(c)) != std::string::npos;
	}

	// Value of a single hexadecimal digit ('0'-'9', 'A'-'F', 'a'-'f').
	// Returns -1 for any other character so callers can reject bad input
	// instead of indexing a table out of bounds.
	int HexToDec(char str)
	{
		if (str >= '0' && str <= '9')
			return str - '0';
		if (str >= 'A' && str <= 'F')
			return str - 'A' + 10;
		if (str >= 'a' && str <= 'f')
			return str - 'a' + 10;
		return -1;
	}

	namespace
	{
		const char kHexDigits[] = "0123456789ABCDEF";

		// SM4 S-box (GB/T 32907-2016), indexed by the input byte.
		const std::uint8_t kSbox[256] = {
			0xD6, 0x90, 0xE9, 0xFE, 0xCC, 0xE1, 0x3D, 0xB7, 0x16, 0xB6, 0x14, 0xC2, 0x28, 0xFB, 0x2C, 0x05,
			0x2B, 0x67, 0x9A, 0x76, 0x2A, 0xBE, 0x04, 0xC3, 0xAA, 0x44, 0x13, 0x26, 0x49, 0x86, 0x06, 0x99,
			0x9C, 0x42, 0x50, 0xF4, 0x91, 0xEF, 0x98, 0x7A, 0x33, 0x54, 0x0B, 0x43, 0xED, 0xCF, 0xAC, 0x62,
			0xE4, 0xB3, 0x1C, 0xA9, 0xC9, 0x08, 0xE8, 0x95, 0x80, 0xDF, 0x94, 0xFA, 0x75, 0x8F, 0x3F, 0xA6,
			0x47, 0x07, 0xA7, 0xFC, 0xF3, 0x73, 0x17, 0xBA, 0x83, 0x59, 0x3C, 0x19, 0xE6, 0x85, 0x4F, 0xA8,
			0x68, 0x6B, 0x81, 0xB2, 0x71, 0x64, 0xDA, 0x8B, 0xF8, 0xEB, 0x0F, 0x4B, 0x70, 0x56, 0x9D, 0x35,
			0x1E, 0x24, 0x0E, 0x5E, 0x63, 0x58, 0xD1, 0xA2, 0x25, 0x22, 0x7C, 0x3B, 0x01, 0x21, 0x78, 0x87,
			0xD4, 0x00, 0x46, 0x57, 0x9F, 0xD3, 0x27, 0x52, 0x4C, 0x36, 0x02, 0xE7, 0xA0, 0xC4, 0xC8, 0x9E,
			0xEA, 0xBF, 0x8A, 0xD2, 0x40, 0xC7, 0x38, 0xB5, 0xA3, 0xF7, 0xF2, 0xCE, 0xF9, 0x61, 0x15, 0xA1,
			0xE0, 0xAE, 0x5D, 0xA4, 0x9B, 0x34, 0x1A, 0x55, 0xAD, 0x93, 0x32, 0x30, 0xF5, 0x8C, 0xB1, 0xE3,
			0x1D, 0xF6, 0xE2, 0x2E, 0x82, 0x66, 0xCA, 0x60, 0xC0, 0x29, 0x23, 0xAB, 0x0D, 0x53, 0x4E, 0x6F,
			0xD5, 0xDB, 0x37, 0x45, 0xDE, 0xFD, 0x8E, 0x2F, 0x03, 0xFF, 0x6A, 0x72, 0x6D, 0x6C, 0x5B, 0x51,
			0x8D, 0x1B, 0xAF, 0x92, 0xBB, 0xDD, 0xBC, 0x7F, 0x11, 0xD9, 0x5C, 0x41, 0x1F, 0x10, 0x5A, 0xD8,
			0x0A, 0xC1, 0x31, 0x88, 0xA5, 0xCD, 0x7B, 0xBD, 0x2D, 0x74, 0xD0, 0x12, 0xB8, 0xE5, 0xB4, 0xB0,
			0x89, 0x69, 0x97, 0x4A, 0x0C, 0x96, 0x77, 0x7E, 0x65, 0xB9, 0xF1, 0x09, 0xC5, 0x6E, 0xC6, 0x84,
			0x18, 0xF0, 0x7D, 0xEC, 0x3A, 0xDC, 0x4D, 0x20, 0x79, 0xEE, 0x5F, 0x3E, 0xD7, 0xCB, 0x39, 0x48};

		// System parameters FK and fixed parameters CK of the key schedule.
		const std::uint32_t kFK[4] = {0xA3B1BAC6, 0x56AA3350, 0x677D9197, 0xB27022DC};
		const std::uint32_t kCK[32] = {
			0x00070E15, 0x1C232A31, 0x383F464D, 0x545B6269,
			0x70777E85, 0x8C939AA1, 0xA8AFB6BD, 0xC4CBD2D9,
			0xE0E7EEF5, 0xFC030A11, 0x181F262D, 0x343B4249,
			0x50575E65, 0x6C737A81, 0x888F969D, 0xA4ABB2B9,
			0xC0C7CED5, 0xDCE3EAF1, 0xF8FF060D, 0x141B2229,
			0x30373E45, 0x4C535A61, 0x686F767D, 0x848B9299,
			0xA0A7AEB5, 0xBCC3CAD1, 0xD8DFE6ED, 0xF4FB0209,
			0x10171E25, 0x2C333A41, 0x484F565D, 0x646B7279};

		// Parses the 8 hex digits starting at `pos` into a big-endian 32-bit word.
		// Throws std::invalid_argument if fewer than 8 characters remain or a
		// character is not a hex digit.
		std::uint32_t hex_to_word(const std::string &hex, std::size_t pos)
		{
			if (hex.size() < pos + 8)
				throw std::invalid_argument("sm4: hex input too short");
			std::uint32_t word = 0;
			for (std::size_t k = pos; k < pos + 8; ++k)
			{
				const int digit = HexToDec(hex[k]);
				if (digit < 0)
					throw std::invalid_argument("sm4: invalid hex digit");
				word = (word << 4) | static_cast<std::uint32_t>(digit);
			}
			return word;
		}

		// Appends `word` to `out` as 8 uppercase hex digits.
		void append_word_hex(std::string &out, std::uint32_t word)
		{
			for (int shift = 28; shift >= 0; shift -= 4)
				out += kHexDigits[(word >> shift) & 0xF];
		}

		// 32-bit left rotation; callers only pass 0 < n < 32.
		std::uint32_t rotate_left(std::uint32_t x, unsigned n)
		{
			return (x << n) | (x >> (32 - n));
		}

		// Non-linear transform tau: the S-box applied to each byte of the word.
		std::uint32_t tau(std::uint32_t a)
		{
			std::uint32_t out = 0;
			for (int shift = 24; shift >= 0; shift -= 8)
				out |= static_cast<std::uint32_t>(kSbox[(a >> shift) & 0xFF]) << shift;
			return out;
		}

		// T = L(tau(.)), used by the encryption/decryption rounds.
		std::uint32_t mix_round(std::uint32_t a)
		{
			const std::uint32_t b = tau(a);
			return b ^ rotate_left(b, 2) ^ rotate_left(b, 10) ^ rotate_left(b, 18) ^ rotate_left(b, 24);
		}

		// T' = L'(tau(.)), used by the key schedule.
		std::uint32_t mix_key(std::uint32_t a)
		{
			const std::uint32_t b = tau(a);
			return b ^ rotate_left(b, 13) ^ rotate_left(b, 23);
		}

		// Expands a 128-bit key given as exactly 32 hex digits into the 32
		// round keys. Throws std::invalid_argument on a malformed key.
		void expand_key(const std::string &key, std::uint32_t rk[32])
		{
			if (key.size() != 32)
				throw std::invalid_argument("sm4: key must be exactly 32 hex digits");
			std::uint32_t k[4];
			for (std::size_t i = 0; i < 4; ++i)
				k[i] = hex_to_word(key, 8 * i) ^ kFK[i];
			for (std::size_t i = 0; i < 32; ++i)
			{
				const std::uint32_t next = k[0] ^ mix_key(k[1] ^ k[2] ^ k[3] ^ kCK[i]);
				rk[i] = next;
				k[0] = k[1];
				k[1] = k[2];
				k[2] = k[3];
				k[3] = next;
			}
		}

		// Encrypts (or, with decrypt=true, decrypts) the block given by the first
		// 32 hex digits of `hex`; extra characters are ignored. Decryption is the
		// same 32 rounds with the round keys in reverse order. Returns 32
		// uppercase hex digits.
		std::string crypt_block(const std::string &hex, const std::uint32_t rk[32], bool decrypt)
		{
			std::uint32_t x[4];
			for (std::size_t i = 0; i < 4; ++i)
				x[i] = hex_to_word(hex, 8 * i);
			for (std::size_t r = 0; r < 32; ++r)
			{
				const std::uint32_t round_key = rk[decrypt ? 31 - r : r];
				const std::uint32_t next = x[0] ^ mix_round(x[1] ^ x[2] ^ x[3] ^ round_key);
				x[0] = x[1];
				x[1] = x[2];
				x[2] = x[3];
				x[3] = next;
			}
			// Output is the reverse of the last four state words: (X35, X34, X33, X32).
			std::string out;
			out.reserve(32);
			for (std::size_t i = 4; i-- > 0;)
				append_word_hex(out, x[i]);
			return out;
		}

		// Removes padding from the last 16-byte block of a decrypted message.
		// A final byte of 1..16 is PKCS#7 padding and that many bytes are
		// dropped. A final byte of 0 is treated as legacy zero padding, and the
		// trailing NULs of the last block are dropped. Any other value is not
		// valid padding (wrong key or corrupt data), so the data is left as-is.
		void strip_padding(std::string &plain)
		{
			if (plain.size() < 16)
				return;
			const unsigned char pad = static_cast<unsigned char>(plain.back());
			if (pad >= 1 && pad <= 16)
			{
				plain.resize(plain.size() - pad);
			}
			else if (pad == 0)
			{
				const std::size_t block_start = plain.size() - 16;
				std::size_t end = plain.size();
				while (end > block_start && plain[end - 1] == '\0')
					--end;
				plain.resize(end);
			}
		}
	} // namespace

	// Encodes arbitrary bytes (embedded NULs included) as padded base64.
	std::string base64_encode(const std::string &text)
	{
		auto byte_at = [&text](std::size_t k) -> std::uint32_t
		{ return static_cast<unsigned char>(text[k]); };

		std::string out;
		out.reserve((text.size() + 2) / 3 * 4);
		const std::size_t n = text.size();
		std::size_t i = 0;
		for (; i + 2 < n; i += 3)
		{
			const std::uint32_t group = (byte_at(i) << 16) | (byte_at(i + 1) << 8) | byte_at(i + 2);
			out += base64_chars[(group >> 18) & 0x3F];
			out += base64_chars[(group >> 12) & 0x3F];
			out += base64_chars[(group >> 6) & 0x3F];
			out += base64_chars[group & 0x3F];
		}

		const std::size_t rest = n - i; // 0, 1 or 2 trailing bytes
		if (rest != 0)
		{
			std::uint32_t group = byte_at(i) << 16;
			if (rest == 2)
				group |= byte_at(i + 1) << 8;
			out += base64_chars[(group >> 18) & 0x3F];
			out += base64_chars[(group >> 12) & 0x3F];
			out += (rest == 2) ? base64_chars[(group >> 6) & 0x3F] : '=';
			out += '=';
		}
		return out;
	}

	// Decodes base64 text. Decoding stops at the first '=' or at the first
	// character outside the alphabet. A trailing lone character (fewer than
	// 8 bits) is discarded.
	std::string base64_decode(const std::string &encoded_string)
	{
		std::string out;
		out.reserve(encoded_string.size() / 4 * 3 + 2);
		std::uint32_t acc = 0; // bit buffer; only the low `bits` bits are pending
		int bits = 0;
		for (const char ch : encoded_string)
		{
			if (ch == '=' || !is_base64(static_cast<unsigned char>(ch)))
				break;
			acc = (acc << 6) | static_cast<std::uint32_t>(base64_chars.find(ch));
			bits += 6;
			if (bits >= 8)
			{
				bits -= 8;
				out += static_cast<char>((acc >> bits) & 0xFF);
			}
		}
		return out;
	}

	// Converts a hex string to raw bytes, two digits per byte. An odd trailing
	// digit becomes its own byte. Parsing is lenient (strtol), so invalid
	// pairs yield 0.
	std::string HexToStr(std::string str)
	{
		std::string out;
		out.reserve(str.size() / 2 + 1);
		for (std::size_t i = 0; i < str.size(); i += 2)
		{
			const std::string pair = str.substr(i, 2);
			out.push_back(static_cast<char>(std::strtol(pair.c_str(), nullptr, 16)));
		}
		return out;
	}

	// Converts raw bytes (embedded NULs included) to uppercase hex, two digits
	// per byte.
	std::string StrToHex(std::string str)
	{
		std::string out;
		out.reserve(str.size() * 2);
		for (const char ch : str)
		{
			const unsigned char byte = static_cast<unsigned char>(ch);
			out += kHexDigits[byte >> 4];
			out += kHexDigits[byte & 0x0F];
		}
		return out;
	}

	// Pads a short (< 16 byte) chunk up to 16 bytes, PKCS#7 style: each pad
	// byte equals the number of bytes added. Inputs of 16+ bytes are returned
	// unchanged.
	std::string PKCS7(std::string str)
	{
		if (str.size() < 16)
		{
			const std::size_t pad = 16 - str.size();
			str.append(pad, static_cast<char>(pad));
		}
		return str;
	}

	// Converts a string of '0'/'1' characters to uppercase hex. The input is
	// left-padded with '0' to a multiple of 4 bits first.
	std::string BinToHex(std::string str)
	{
		str.insert(0, (4 - str.size() % 4) % 4, '0');
		std::string hex;
		hex.reserve(str.size() / 4);
		for (std::size_t i = 0; i < str.size(); i += 4)
		{
			const int nibble = (str[i] - '0') * 8 + (str[i + 1] - '0') * 4 + (str[i + 2] - '0') * 2 + (str[i + 3] - '0');
			if (nibble < 10)
				hex += std::to_string(nibble);
			else
				hex += static_cast<char>('A' + (nibble - 10));
		}
		return hex;
	}

	// Converts a hex string (either case) to a string of '0'/'1' characters,
	// four per digit. Throws std::invalid_argument on a non-hex character.
	std::string HexToBin(std::string str)
	{
		static const char *const kNibbleBits[16] = {
			"0000", "0001", "0010", "0011", "0100", "0101", "0110", "0111",
			"1000", "1001", "1010", "1011", "1100", "1101", "1110", "1111"};
		std::string bin;
		bin.reserve(str.size() * 4);
		for (const char ch : str)
		{
			const int value = HexToDec(ch);
			if (value < 0)
				throw std::invalid_argument("sm4::HexToBin: invalid hex digit");
			bin += kNibbleBits[value];
		}
		return bin;
	}

	// Rotates the bit pattern of a hex string left by `len` bits. `len` is
	// taken modulo the bit width, and negative values rotate right.
	std::string LeftShift(std::string str, int len)
	{
		std::string bits = HexToBin(str);
		if (bits.empty())
			return bits;
		const long width = static_cast<long>(bits.size());
		long amount = len % width;
		if (amount < 0)
			amount += width;
		std::rotate(bits.begin(), bits.begin() + amount, bits.end());
		return BinToHex(bits);
	}

	// Bitwise XOR of two hex strings, returned as uppercase hex with the
	// length of str1. Only the first str1.size() digits of str2 are used.
	// Throws std::invalid_argument if str2 is shorter or a digit is invalid.
	std::string XOR(std::string str1, std::string str2)
	{
		if (str2.size() < str1.size())
			throw std::invalid_argument("sm4::XOR: second operand too short");
		std::string out;
		out.reserve(str1.size());
		for (std::size_t i = 0; i < str1.size(); ++i)
		{
			const int a = HexToDec(str1[i]);
			const int b = HexToDec(str2[i]);
			if (a < 0 || b < 0)
				throw std::invalid_argument("sm4::XOR: invalid hex digit");
			out += kHexDigits[a ^ b];
		}
		return out;
	}

	// Non-linear transform tau on the first 8 hex digits of `str`; applies the
	// S-box to each byte. Returns 8 uppercase hex digits.
	std::string NLTransform(std::string str)
	{
		std::string out;
		append_word_hex(out, tau(hex_to_word(str, 0)));
		return out;
	}

	// Linear transform L (encryption rounds): B ^ B<<<2 ^ B<<<10 ^ B<<<18 ^ B<<<24.
	std::string LTransform(std::string str)
	{
		return XOR(XOR(XOR(XOR(str, LeftShift(str, 2)), LeftShift(str, 10)), LeftShift(str, 18)), LeftShift(str, 24));
	}

	// Linear transform L' (key schedule): B ^ B<<<13 ^ B<<<23.
	std::string L2Transform(std::string str)
	{
		return XOR(XOR(str, LeftShift(str, 13)), LeftShift(str, 23));
	}

	// Composite transform T = L(tau(.)) from the encryption rounds (hex API).
	std::string T(std::string str)
	{
		return LTransform(NLTransform(str));
	}

	// Composite transform T' = L'(tau(.)) from the key schedule (hex API).
	std::string T2(std::string str)
	{
		return L2Transform(NLTransform(str));
	}

	// Key expansion. `MK` must be exactly 32 hex digits (128-bit key). Returns
	// the 32 round keys concatenated as 256 uppercase hex digits.
	// Throws std::invalid_argument on a malformed key.
	std::string KeyExtension(std::string MK)
	{
		std::uint32_t rk[32];
		expand_key(MK, rk);
		std::string rks;
		rks.reserve(256);
		for (const std::uint32_t word : rk)
			append_word_hex(rks, word);
		return rks;
	}

	// Encrypts one 128-bit block. `hex32` holds the block as hex (only the
	// first 32 digits are used); `key` is 32 hex digits. Returns 32 uppercase
	// hex digits. Throws std::invalid_argument on malformed input.
	std::string sm4_encode(const std::string &hex32, const std::string &key)
	{
		std::uint32_t rk[32];
		expand_key(key, rk);
		return crypt_block(hex32, rk, false);
	}

	// Decrypts one 128-bit block; same conventions as sm4_encode.
	std::string sm4_decode(const std::string &hex32, const std::string &key)
	{
		std::uint32_t rk[32];
		expand_key(key, rk);
		return crypt_block(hex32, rk, true);
	}

	// Encrypts an arbitrary byte string in ECB mode with PKCS#7 padding and
	// returns the ciphertext as uppercase hex. A full padding block is always
	// appended, including for empty input and for multiples of 16 bytes, as
	// PKCS#7 requires. `key` is 32 hex digits.
	std::string sm4encodestrhex(std::string text, std::string key)
	{
		std::uint32_t rk[32];
		expand_key(key, rk); // once per message, not once per block

		const std::size_t pad = 16 - text.size() % 16; // always 1..16
		text.append(pad, static_cast<char>(pad));

		std::string all_cipher;
		all_cipher.reserve(text.size() * 2);
		for (std::size_t pos = 0; pos < text.size(); pos += 16)
			all_cipher += crypt_block(StrToHex(text.substr(pos, 16)), rk, false);
		return all_cipher;
	}

	// Decrypts hex ciphertext produced by sm4encodestrhex and removes the
	// padding. Returns an empty string if the ciphertext length is not a
	// multiple of 32 hex digits. Embedded NUL bytes in the plaintext are
	// preserved.
	std::string sm4decodehexstr(std::string cipher, std::string key)
	{
		if (cipher.empty() || cipher.size() % 32 != 0)
			return std::string();

		std::uint32_t rk[32];
		expand_key(key, rk);

		std::string all_plain;
		all_plain.reserve(cipher.size() / 2);
		for (std::size_t pos = 0; pos < cipher.size(); pos += 32)
			all_plain += HexToStr(crypt_block(cipher.substr(pos, 32), rk, true));
		strip_padding(all_plain);
		return all_plain;
	}

	// Like sm4encodestrhex, but returns the ciphertext bytes base64-encoded.
	std::string sm4encodestrbase64(std::string text, std::string key)
	{
		const std::string hex = sm4encodestrhex(text, key);
		return base64_encode(HexToStr(hex));
	}

	// Inverse of sm4encodestrbase64.
	std::string sm4decodestrbase64(std::string base64_text, std::string key)
	{
		const std::string raw = base64_decode(base64_text);
		return sm4decodehexstr(StrToHex(raw), key);
	}

}

// int main()
// {
// 	string text, cipher, plain, base64text;

// 	string key = "Myu0T1#vD5VqUEgt";
// 	key = sm4::StrToHex(key); // 转换十六进制使用
// 	cout << "密钥:" << key << endl;

// 	text = "0123456789ABCDEFFEDCBA9876543210"; // 十六进制
// 	cipher = sm4::sm4_encode(text, key);
// 	plain = sm4::sm4_decode(cipher, key);
// 	cout << "十六进制 " << text << " 加密后 " << cipher << " 解密后 " << plain << endl;

// 	text = "012345678912345";
// 	cipher = sm4::sm4encodestrhex(text, key);
// 	plain = sm4::sm4decodehexstr(cipher, key);
// 	cout << "字符串 \"" << text << "\" HEX加密后 " << cipher << " 解密后 \"" << plain << "\"" << endl;

// 	base64text = sm4::sm4encodestrbase64(text, key);
// 	plain = sm4::sm4decodestrbase64(base64text, key);
// 	cout << base64text << endl;
// 	cout << "字符串 \"" << text << "\" BASE64加密后 " << base64text << " 解密后 \"" << plain << "\"" << endl;

// 	text = "0123456789123456";
// 	cipher = sm4::sm4encodestrhex(text, key);
// 	plain = sm4::sm4decodehexstr(cipher, key);
// 	cout << "字符串 \"" << text << "\" HEX加密后 " << cipher << " 解密后 \"" << plain << "\"" << endl;

// 	base64text = sm4::sm4encodestrbase64(text, key);
// 	plain = sm4::sm4decodestrbase64(base64text, key);
// 	cout << base64text << endl;
// 	cout << "字符串 \"" << text << "\" BASE64加密后 " << base64text << " 解密后 \"" << plain << "\"" << endl;

// 	return 0;
// }

你可能感兴趣的:(C/C++,c++,开发语言)