pytorch计算参数量

简易代码

def print_model_parm_nums():
    model = models.alexnet()
    total = sum([param.nelement() for param in model.parameters()])
    print('  + Number of params: %.2fM' % (total / 1e6))

你可能感兴趣的:(pytorch计算参数量)