Pytorch 统计网络参数个数

1 代码

def count_parameters(model):  # 传入的是模型实例对象
    params = [p.numel() for p in model.parameters() if p.requires_grad]
    for item in params:
        print(f'{item:>16}')   # 参数大于16的展示
    print(f'________\n{sum(params):>16}')  # 大于16的进行统计,可以自行修改
count_parameters(net)

2 输出

Pytorch 统计网络参数个数_第1张图片

你可能感兴趣的:(模型部件,pytorch,人工智能,python,神经网络,计算机视觉)