模型参数量(Params)/模型大小 & Pytorch统计模型参数量

模型参数量大小可以从保存的checkpoint文件直观看出来

total_params = sum(p.numel() for p in model.parameters())
total_params += sum(p.numel() for p in model.buffers())
print(f'{total_params:,} total parameters.')
print(f'{total_params/(1024*1024):.2f}M total parameters.')
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
print(f'{total_trainable_params/(1024*1024):.2f}M training parameters.')

模型参数量(Params)/模型大小 & Pytorch统计模型参数量_第1张图片

有的地方算参数量/模型大小会乘以4,因为模型参数一般都是FP32存储的,FP32是单精度,占4个字节。要看具体的概念定义

如果统计各个部分的参数量

考虑一下是否需要统计buffer

_dict = {}
for _,param in enumerate(model.named_parameters()):
    # print(param[0])
    # print(param[1])
    total_params = param[1].numel()
    # print(f'{total_params:,} total parameters.')
    k = param[0].split('.')[0]
    if k in _dict.keys():
        _dict[k] += total_params
    else:
        _dict[k] = 0
        _dict[k] += total_params
    # print('----------------')
for k,v in _dict.items():
    print(k)
    print(v)
    print("%3.3fM parameters" %  (v / (1024*1024)))
    print('--------')

另一种方法

到时候把dict的item换成你所用模型的

def print_architecture(model):
    name = type(model).__name__
    result = '-------------------%s---------------------\n' % name
    total_num_params = 0
    for i, (name, child) in enumerate(model.named_children()):
        num_params = sum([p.numel() for p in child.parameters()])
        total_num_params += num_params
        for i, (name, grandchild) in enumerate(child.named_children()):
            num_params = sum([p.numel() for p in grandchild.parameters()])
    result += '[Network %s] Total number of parameters : %.3f M\n' % (name, total_num_params / (1024*1024))
    result += '-----------------------------------------------\n'
    print(result)



print(model)
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
_dict = {}
_dict['encoder'] = 0
_dict['decoder'] = 0
_dict['stn_head'] = 0
for _,param in enumerate(model.named_parameters()):
    print(param[0])
    # print(param[1])
    total_params = param[1].numel()
    print(f'{total_params:,} total parameters.')
    k = param[0].split('.')[0]
    if k in _dict.keys():
        _dict[param[0].split('.')[0]] += total_params
    else:
        _dict[k] = 0
        _dict[param[0].split('.')[0]] += total_params
    print('----------------')
for k,v in _dict.items():
    print(k)
    print(v)
    print("%3.3fM parameters\n" %  (v / (1024*1024)))
    print('--------')
print_architecture(model)

常见模型的参数量

模型参数量(Params)/模型大小 & Pytorch统计模型参数量_第2张图片

 模型参数量(Params)/模型大小 & Pytorch统计模型参数量_第3张图片

你可能感兴趣的:(pytorch)