国密算法 SM2 公钥加密 数字签名 密钥交换 全网最高效的开源python代码

此前发布过SM2、SM3、SM4、ZUC等文章,以及开源的完整python代码。近些天看到一篇电子科大兰同学的硕士毕业论文(兰修文. ECC计算算法的优化及其在SM2实现中的运用[D]. 成都: 电子科技大学, 2019),文中采用预计算加速SM2椭圆曲线基点点乘,将这个思路用python代码实现后,实测比起原来的SM2又有4-5倍的提升。现把全网最快(也是功能实现最全)的SM2完整python代码分享出来(小弟口出狂言,若班门弄斧,还请大佬勿怪O(∩_∩)O)。愿大家同心协力推动国密算法普及,为国家网络安全添砖加瓦!

介绍其他国密算法的链接如下:

上一篇SM2:国密算法 SM2 公钥加密 非对称加密 数字签名 密钥协商 python实现完整代码_qq_43339242的博客-CSDN博客_python sm2

SM3:国密算法 SM3 消息摘要 杂凑算法 哈希函数 散列函数 python实现完整代码_qq_43339242的博客-CSDN博客_国密sm3

SM4:国密算法 SM4 对称加密 分组密码 python实现完整代码_qq_43339242的博客-CSDN博客_python国密算法库

ZUC:国密算法 ZUC流密码 祖冲之密码 python代码完整实现_qq_43339242的博客-CSDN博客_国密算法代码

对上述几个算法和实现不了解的,建议点进去看看。下面这篇文章是对上述的汇总:

国密算法 SM2公钥密码 SM3杂凑算法 SM4分组密码 python代码完整实现_qq_43339242的博客-CSDN博客_python sm2

所有代码托管在码云:hggm - 国密算法 SM2 SM3 SM4 python实现完整代码: 国密算法 SM2公钥密码 SM3杂凑算法 SM4分组密码 python代码完整实现 效率高于所有公开的python国密算法库 (gitee.com)

 SM2代码如下:

import random
import math
from hggm.SM3 import digest as sm3
from Crypto.PublicKey import ECC
from Crypto.Math.Numbers import Integer
from Crypto.Util._raw_api import VoidPointer, SmartPointer


# 转换为bytes,第二参数为字节数(可不填)
def to_byte(x, size=None):
    if isinstance(x, int):
        if size is None:  # 计算合适的字节数
            size = 0
            tmp = x >> 64
            while tmp:
                size += 8
                tmp >>= 64
            tmp = x >> (size << 3)
            while tmp:
                size += 1
                tmp >>= 8
        elif x >> (size << 3):  # 指定的字节数不够则截取低位
            x &= (1 << (size << 3)) - 1
        return x.to_bytes(size, byteorder='big')
    elif isinstance(x, str):
        x = x.encode()
        if size is not None and len(x) > size:  # 超过指定长度
            x = x[:size]  # 截取左侧字符
        return x
    elif isinstance(x, bytes):
        if size is not None and len(x) > size:  # 超过指定长度
            x = x[:size]  # 截取左侧字节
        return x
    elif isinstance(x, tuple):  # 坐标形式(x, y)
        if size is None:
            size = PARA_SIZE
        return to_byte(x[0], size) + to_byte(x[1], size)
    elif isinstance(x, Integer):
        return to_byte(int(x), size)
    elif isinstance(x, ECC.EccPoint):
        x, y = x.xy
        return to_byte(int(x), PARA_SIZE) + to_byte(int(y), PARA_SIZE)
    return bytes(x)


SM2_p = 0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF
SM2_a = 0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC
SM2_b = 0x28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93
SM2_n = 0xFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123
SM2_Gx = 0x32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7
SM2_Gy = 0xBC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0
PARA_SIZE = 32  # 参数长度(字节)
HASH_SIZE = 32  # sm3输出256位(32字节)
KEY_LEN = 128  # 默认密钥位数
SM2_w_l_1 = math.ceil(math.ceil(math.log(SM2_n, 2)) / 2) - 1  # w * 2


# 使用Crypto.PublicKey.ECC库生成SM2曲线
def gen_ECC_curve(p=SM2_p, b=SM2_b, n=SM2_n, Gx=SM2_Gx, Gy=SM2_Gy):
    # 如果未采用SM2标准规定的参数,则需要用pre_kG函数重新预计算kG,将输出矩阵作为全局变量kG_tuple,否则无法正确计算
    if 'sm2' not in ECC._curves or ECC._curves['sm2'].p != p:
        ec_context = VoidPointer()
        res = ECC._ec_lib.ec_ws_new_context(ec_context.address_of(), to_byte(p, PARA_SIZE), to_byte(b, PARA_SIZE),
                                            to_byte(n, PARA_SIZE), PARA_SIZE, random.getrandbits(64))
        if res:
            raise ImportError("Error %d initializing SM2 context" % res)
        context = SmartPointer(ec_context.get(), ECC._ec_lib.ec_free_context)
        ECC._curves['sm2'] = ECC._Curve(p, b, n, Gx, Gy, None, 256, p, context, '', '')


# 创建SM2点
def SM2_Point(xy):
    return ECC.EccPoint(xy[0], xy[1], 'sm2')


# 将字节转换为int
def to_int(byte):
    return int.from_bytes(byte, byteorder='big')


# 将列表元素转换为bytes并连接
def join_bytes(*data_list):
    return b''.join(to_byte(i) for i in data_list)


# 求最大公约数
def gcd(a, b):
    return a if b == 0 else gcd(b, a % b)


# 求乘法逆元过程中的辅助递归函数
def get_(a, b):
    if b == 0:
        return 1, 0
    x1, y1 = get_(b, a % b)
    x, y = y1, x1 - a // b * y1
    return x, y


# 求乘法逆元
def get_inverse(a, p):
    # return pow(a, p-2, p) # 效率较低、n倍点的时候两种计算方法结果会有不同
    if gcd(a, p) == 1:
        x, y = get_(a, p)
        return x % p
    return 1


# 密钥派生函数(从一个共享的秘密比特串中派生出密钥数据)
# SM2第3部分 5.4.3
# Z为bytes类型
# klen表示要获得的密钥数据的比特长度(8的倍数),int类型
# 输出为bytes类型
def KDF(Z, klen):
    ksize = klen >> 3
    K = bytearray()
    for ct in range(1, math.ceil(ksize / HASH_SIZE) + 1):
        K.extend(sm3(Z + to_byte(ct, 4)))
    return K[:ksize]


# 计算比特位数
def get_bit_num(x):
    if isinstance(x, int):
        num = 0
        tmp = x >> 64
        while tmp:
            num += 64
            tmp >>= 64
        tmp = x >> num >> 8
        while tmp:
            num += 8
            tmp >>= 8
        x >>= num
        while x:
            num += 1
            x >>= 1
        return num
    elif isinstance(x, str):
        return len(x.encode()) << 3
    elif isinstance(x, bytes):
        return len(x) << 3
    return 0


# 预计算kG(输出txt文本,内容为32行255列的椭圆曲线点矩阵)
def pre_kG():
    mySM2 = SM2()
    kG_point = [k * mySM2.G for k in range(1, 256)]
    with open('SM2_kG.txt', 'w') as f:
        print('[[' + ','.join(map(lambda p: '(0x%x,0x%x)' % p.xy, kG_point)) + '],', file=f)
        for i in range(31):
            for p in kG_point:
                p *= 256
            print('[' + ','.join(map(lambda p: '(0x%x,0x%x)' % p.xy, kG_point)) + '],\n', end='', file=f)
        print(']', file=f)


gen_ECC_curve()
# 计算好的kG点坐标(来自pre_kG函数输出的文本内容)
kG_tuple = []  # 这里本来不是空着的!PS:说我字数太多,所以[]里的大量内容就删掉了,可以在mian函数里运行pre_kG(),并把生成的文本内容复制进[]里面。当然,更建议直接去gitee下载完整源代码(●'◡'●)
# 将计算好的kG点坐标转换成SM2点
kG_points = [list(map(SM2_Point, plist)) for plist in kG_tuple]


# 采用预计算好的数据快速计算kG
def kG(k):
    P, i = None, 0
    while k:
        last_byte = k & 0xff
        if last_byte:
            if P is None:
                P = SM2_Point(kG_tuple[i][last_byte - 1])
            else:
                P += kG_points[i][last_byte - 1]
        k >>= 8
        i += 1
    return P


# SM2类继承ECC
class SM2:
    # 默认使用SM2推荐曲线参数
    def __init__(self, p=SM2_p, a=SM2_a, b=SM2_b, n=SM2_n, G=(SM2_Gx, SM2_Gy), h=1,  # 余因子h默认为1
                 ID=None, sk=None, pk=None, genkeypair=True):  # genkeypair表示是否自动生成公私钥对
        gen_ECC_curve(p, b, n, G[0], G[1])
        self.p, self.a, self.b, self.n, self.Gx, self.Gy, self.G, self.h = p, a, b, n, G[0], G[1], SM2_Point(G), h
        # 除曲线外的其他参数
        self.ID = ID if type(ID) in (int, str) else ''  # 身份ID(数字或字符串)
        if sk and pk:  # 已提供公私钥对
            try:  # 验证该公私钥对
                if kG(sk) == SM2_Point(pk):  # 通过验证,即使genkeypair=True也不会重新生成
                    self.sk, self.pk = sk, pk  # 私钥(int [1,n-2]),公钥(x, y)
                else:  # 不合格则生成
                    self.sk, self.pk = self.gen_keypair()
            except ValueError:  # 不在曲线上会报错,重新生成
                self.sk, self.pk = self.gen_keypair()
        elif genkeypair:  # 自动生成合格的公私钥对
            self.sk, self.pk = self.gen_keypair()

        # 预先计算用到的常数
        self.Z_tmp = to_byte(a) + to_byte(b) + to_byte(G[0]) + to_byte(G[1])  # Z值的中间部分
        if hasattr(self, 'sk'):  # 签名时
            self.d_1 = get_inverse(1 + self.sk, n)

    # 判断是否在椭圆曲线上
    def on_curve(self, p):
        try:
            return SM2_Point(p)  # 不报错则返回SM2_Point对象
        except ValueError:  # 报错说明不在曲线上
            return False

    # 生成密钥对
    # 返回值:d为私钥,P为公钥
    # SM2第1部分 6.1
    def gen_keypair(self, toTuple=True):
        d = random.randint(1, self.n - 2)
        P = kG(d)
        return d, tuple(map(int, P.xy)) if toTuple else P

    # 计算Z
    # SM2第2部分 5.5
    # ID为数字或字符串,P为公钥(不提供参数时返回自身Z值)
    def get_Z(self, ID=None, P=None):
        save = False
        if P is None:  # 不提供参数
            if hasattr(self, 'Z'):  # 再次计算,返回曾计算好的自身Z值
                return self.Z
            else:  # 首次计算自身Z值
                ID = self.ID
                P = self.pk
                save = True
        entlen = get_bit_num(ID)
        ENTL = to_byte(entlen, 2)
        Z = sm3(join_bytes(ENTL, ID, self.Z_tmp, P))
        if save:  # 保存自身Z值
            self.Z = Z
        return Z

    # 数字签名
    # SM2第2部分 6.1
    # 输入:待签名的消息M、随机数k(不填则自动生成)、输出类型(默认bytes)、对M是否hash(默认是)
    # 输出:r, s(int类型)或拼接后的bytes
    def sign(self, M, k=None, outbytes=True, dohash=True):
        if dohash:
            M_ = join_bytes(self.get_Z(), M)
            e = to_int(sm3(M_))
        else:
            e = to_int(to_byte(M))
        while True:
            if not k:
                k = random.randint(1, self.n - 1)
            x1 = int(kG(k).x)
            r = (e + x1) % self.n
            if r == 0 or r + k == self.n:
                k = 0
                continue
            s = self.d_1 * (k - r * self.sk) % self.n
            if s:
                break
            k = 0
        return to_byte((r, s), PARA_SIZE) if outbytes else (r, s)

    # 数字签名验证
    # SM2第2部分 7.1
    # 输入:收到的消息M′及其数字签名(r′, s′)、签名者的身份标识IDA及公钥PA、对M是否hash(默认是)
    # 输出:True or False
    def verify(self, M, sig, IDA, PA, dohash=True):
        PA_bytes = to_byte(PA)
        PA = self.on_curve(PA)
        if not PA:
            return False  # 对方公钥不在椭圆曲线上
        if isinstance(sig, bytes):
            r = to_int(sig[:PARA_SIZE])
            s = to_int(sig[PARA_SIZE:])
        else:
            r, s = sig
        if not 1 <= r <= self.n - 1 or not 1 <= s <= self.n - 1:
            return False
        if dohash:
            M_ = join_bytes(self.get_Z(IDA, PA_bytes), M)
            e = to_int(sm3(M_))
        else:
            e = to_int(to_byte(M))
        t = (r + s) % self.n
        if t == 0:
            return False
        PA *= t
        PA += kG(s)
        x1 = int(PA.x)
        # x1 = int((kG(s) + t * PA).x)
        R = (e + x1) % self.n
        return R == r

    # A 发起协商
    # SM2第3部分 6.1 A1-A3
    # 返回rA、RA
    def agreement_initiate(self):
        return self.gen_keypair()

    # B 响应协商(option=True时计算选项部分)
    # SM2第3部分 6.1 B1-B9
    def agreement_response(self, RA, PA, IDA, option=False, rB=None, RB=None, klen=None):
        # 参数准备
        PA_bytes = to_byte(PA)
        PA = self.on_curve(PA)
        if not PA:
            return False, '对方公钥不在椭圆曲线上'
        x1 = RA[0]
        RA = self.on_curve(RA)
        if not RA:
            return False, 'RA不在椭圆曲线上'

        if not hasattr(self, 'sk'):
            self.sk, self.pk = self.gen_keypair()
        ZA = self.get_Z(IDA, PA_bytes)
        ZB = self.get_Z()
        # B1-B7
        if not rB:
            rB, RB = self.gen_keypair()
        x2 = RB[0]
        x_2 = SM2_w_l_1 + (x2 & SM2_w_l_1 - 1)
        tB = (self.sk + x_2 * rB) % self.n
        x_1 = SM2_w_l_1 + (x1 & SM2_w_l_1 - 1)
        RA_bytes = to_byte(RA)
        RA *= x_1
        RA += PA
        RA *= self.h * tB
        xV, yV = RA.xy
        # V = (self.h * tB) * (x_1 * RA + PA)
        if (xV, yV) == (0, 0):
            return False, 'V是无穷远点'
        if not klen:
            klen = KEY_LEN
        KB = KDF(join_bytes((xV, yV), ZA, ZB), klen)
        if not option:
            return True, (RB, KB)
        # B8、B10(可选部分)
        tmp = join_bytes(yV, sm3(join_bytes(xV, ZA, ZB, RA_bytes, RB)))
        SB = sm3(join_bytes(2, tmp))
        S2 = sm3(join_bytes(3, tmp))
        return True, (RB, KB, SB, S2)

    # A 协商确认
    # SM2第3部分 6.1 A4-A10
    def agreement_confirm(self, rA, RA, RB, PB, IDB, SB=None, option=False, klen=None):
        # 参数准备
        PB_bytes = to_byte(PB)
        PB = self.on_curve(PB)
        if not PB:
            return False, '对方公钥不在椭圆曲线上'
        x1, x2 = RA[0], RB[0]
        RB = self.on_curve(RB)
        if not RB:
            return False, 'RB不在椭圆曲线上'
        if not hasattr(self, 'sk'):
            self.sk, self.pk = self.gen_keypair()
        ZA = self.get_Z()
        ZB = self.get_Z(IDB, PB_bytes)
        # A4-A8
        x_1 = SM2_w_l_1 + (x1 & SM2_w_l_1 - 1)
        tA = (self.sk + x_1 * rA) % self.n
        x_2 = SM2_w_l_1 + (x2 & SM2_w_l_1 - 1)
        RB_bytes = to_byte(RB)
        RB *= x_2
        RB += PB
        RB *= self.h * tA
        xU, yU = RB.xy
        # U = (self.h * tA) * (x_2 * RB + PB)
        if (xU, yU) == (0, 0):
            return False, 'U是无穷远点'
        if not klen:
            klen = KEY_LEN
        KA = KDF(join_bytes((xU, yU), ZA, ZB), klen)
        if not option or not SB:
            return True, KA
        # A9-A10(可选部分)
        tmp = join_bytes(yU, sm3(join_bytes(xU, ZA, ZB, RA, RB_bytes)))
        S1 = sm3(join_bytes(2, tmp))
        if S1 != SB:
            return False, 'S1 != SB'
        SA = sm3(join_bytes(3, tmp))
        return True, (KA, SA)

    # B 协商确认(可选部分)
    # SM2第3部分 6.1 B10
    def agreement_confirm2(self, S2, SA):
        if S2 != SA:
            return False, 'S2 != SA'
        return True, ''

    # 加密
    # SM2第4部分 6.1
    # 输入:待加密的消息M(bytes或str类型)、对方的公钥PB、随机数k(不填则自动生成)
    # 输出(True, bytes类型密文)或(False, 错误信息)
    def encrypt(self, M, PB, k=None):
        PB = self.on_curve(PB)
        if not PB:
            return False, '对方公钥不在椭圆曲线上'
        M = to_byte(M)
        klen = get_bit_num(M)
        while True:
            if not k:
                k = random.randint(1, self.n - 1)
            PB *= k
            x2, y2 = PB.xy
            # x2, y2 = (k * PB).xy
            t = to_int(KDF(to_byte((x2, y2)), klen))
            if t:
                break
            k = 0  # 若t为全0比特串则继续循环
        C1 = to_byte(kG(k), PARA_SIZE)  # (x1, y1)
        C2 = to_byte(to_int(M) ^ t, klen >> 3)
        C3 = sm3(join_bytes(x2, M, y2))
        return True, join_bytes(C1, C2, C3)

    # 解密
    # SM2第4部分 7.1
    # 输入:密文C(bytes类型)
    # 输出(True, bytes类型明文)或(False, 错误信息)
    def decrypt(self, C):
        x1 = to_int(C[:PARA_SIZE])
        y1 = to_int(C[PARA_SIZE:PARA_SIZE << 1])
        C1 = self.on_curve((x1, y1))
        if not C1:
            return False, 'C1不满足椭圆曲线方程'
        C1 *= self.sk
        x2, y2 = C1.xy
        # x2, y2 = (self.sk * C1).xy
        klen = len(C) - (PARA_SIZE << 1) - HASH_SIZE << 3
        t = to_int(KDF(to_byte((x2, y2)), klen))
        if t == 0:
            return False, 't为全0比特串'
        C2 = C[PARA_SIZE << 1:-HASH_SIZE]
        M = to_byte(to_int(C2) ^ t, klen >> 3)
        u = sm3(join_bytes(x2, M, y2))
        C3 = C[-HASH_SIZE:]
        if u != C3:
            return False, 'u != C3'
        return True, M
注意:pycryptodome老版本没有int * EccPoint实现,新版本(3.14.1)缺SHA256和ARC4的链接库,我用的3.10.1版本是没问题的(pip install pycryptodome==3.10.1)

下面来看性能测试结果:

国密算法 SM2 公钥加密 数字签名 密钥交换 全网最高效的开源python代码_第1张图片

机器配置不高,处理器i3-10110U,不过差不了太多,相对速度还是有参考意义的。点乘测试都是针对椭圆曲线基点G进行,算法一、二、三是本实现的未加速版本(纯python实现,参考上一篇写SM2的文章),“加速后”指调用PyCryptodome库进行SM2椭圆曲线k·G运算。PyCryptodome库针对ECC的NIST P-256参数进行了数学优化,所以ECC计算k·G很快。但采用预计算以后,SM2的k·G已经追上ECC,且SM2的签名、验证比ECC签名、验证(即ECDSA算法)更快,可惜的是PyCryptodome库目前仍未实现ECC加密和解密。

用法请参考测试代码,需要的请下载完整代码,再次奉上链接:hggm - 国密算法 SM2 SM3 SM4 python实现完整代码: 国密算法 SM2公钥密码 SM3杂凑算法 SM4分组密码 python代码完整实现 效率高于所有公开的python国密算法库 (gitee.com)

你可能感兴趣的:(国密算法,密码,Python,python)