2021-03-08 Counting Model Parameters in the Speech-Transformer Project

Load the trained and saved model checkpoint

import torch

pthfile = r'/home/user1/Project/Speech-Transformer/egs/aishell/exp/train_m1_n6_in80_elayer6_head8_k64_v64_model512_inner2048_drop0.1_pe5000_emb512_dlayer6_share1_ls0.1_epoch150_shuffle1_bs16_bf30000_mli800_mlo150_k0.2_warm4000/final.pth.tar'
net = torch.load(pthfile, map_location=torch.device('cpu'))  # the checkpoint was saved on a GPU; map it to CPU when no GPU is available

print(type(net))  # a dict
print(len(net))   # 22, i.e. 22 key-value pairs
Output:

<class 'dict'>
22

List all of the keys

for k in net.keys():
    print(k) 
# the 22 keys:
# LFR_m  LFR_n    d_input    n_layers_enc    n_head    d_k    d_v
# d_model    d_inner    dropout    pe_maxlen    sos_id    eos_id
# vocab_size    d_word_vec    n_layers_dec    tgt_emb_prj_weight_sharing
# state_dict    optim_dict    epoch    tr_loss    cv_loss
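
Most of these are scalar hyperparameters and can be dumped directly, while state_dict and optim_dict are large nested containers. A small sketch to separate the two kinds (plain Python; tr_loss and cv_loss may be tensors or arrays depending on how they were saved):

for k, v in net.items():
    if isinstance(v, (int, float, str)):
        print(f"{k} = {v}")                # scalar hyperparameters such as d_model, n_head, epoch
    else:
        print(f"{k}: {type(v).__name__}")  # nested objects such as state_dict and optim_dict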

Print the epoch at which this checkpoint was saved

print(net["epoch"])
>>> 118

Method 1: counting parameters from state_dict

Sum the number of elements of every tensor stored in state_dict:

psum = 0
for key, value in net["state_dict"].items():
    print(key)
    print(value.size())
    print(value.numel())
    psum += value.numel()
    # print(key, value.size(), sep=" ")
print(psum)
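
Before walking through the per-entry output below, note that the same total can be computed in one line with a generator expression:

print(sum(v.numel() for v in net["state_dict"].values()))  # 53635584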

Output:

encoder.linear_in.weight
torch.Size([512, 80])
40960
encoder.linear_in.bias
torch.Size([512])
512
encoder.layer_norm_in.weight
torch.Size([512])
512
encoder.layer_norm_in.bias
torch.Size([512])
512
encoder.positional_encoding.pe
torch.Size([1, 5000, 512])
2560000
encoder.layer_stack.0.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.0.slf_attn.w_qs.bias
torch.Size([512])
512
encoder.layer_stack.0.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
encoder.layer_stack.0.slf_attn.w_ks.bias
torch.Size([512])
512
encoder.layer_stack.0.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.0.slf_attn.w_vs.bias
torch.Size([512])
512
encoder.layer_stack.0.slf_attn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.0.slf_attn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.0.slf_attn.fc.weight
torch.Size([512, 512])
262144
encoder.layer_stack.0.slf_attn.fc.bias
torch.Size([512])
512
encoder.layer_stack.0.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
encoder.layer_stack.0.pos_ffn.w_1.bias
torch.Size([2048])
2048
encoder.layer_stack.0.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
encoder.layer_stack.0.pos_ffn.w_2.bias
torch.Size([512])
512
encoder.layer_stack.0.pos_ffn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.0.pos_ffn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.1.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.1.slf_attn.w_qs.bias
torch.Size([512])
512
encoder.layer_stack.1.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
encoder.layer_stack.1.slf_attn.w_ks.bias
torch.Size([512])
512
encoder.layer_stack.1.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.1.slf_attn.w_vs.bias
torch.Size([512])
512
encoder.layer_stack.1.slf_attn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.1.slf_attn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.1.slf_attn.fc.weight
torch.Size([512, 512])
262144
encoder.layer_stack.1.slf_attn.fc.bias
torch.Size([512])
512
encoder.layer_stack.1.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
encoder.layer_stack.1.pos_ffn.w_1.bias
torch.Size([2048])
2048
encoder.layer_stack.1.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
encoder.layer_stack.1.pos_ffn.w_2.bias
torch.Size([512])
512
encoder.layer_stack.1.pos_ffn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.1.pos_ffn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.2.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.2.slf_attn.w_qs.bias
torch.Size([512])
512
encoder.layer_stack.2.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
encoder.layer_stack.2.slf_attn.w_ks.bias
torch.Size([512])
512
encoder.layer_stack.2.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.2.slf_attn.w_vs.bias
torch.Size([512])
512
encoder.layer_stack.2.slf_attn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.2.slf_attn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.2.slf_attn.fc.weight
torch.Size([512, 512])
262144
encoder.layer_stack.2.slf_attn.fc.bias
torch.Size([512])
512
encoder.layer_stack.2.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
encoder.layer_stack.2.pos_ffn.w_1.bias
torch.Size([2048])
2048
encoder.layer_stack.2.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
encoder.layer_stack.2.pos_ffn.w_2.bias
torch.Size([512])
512
encoder.layer_stack.2.pos_ffn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.2.pos_ffn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.3.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.3.slf_attn.w_qs.bias
torch.Size([512])
512
encoder.layer_stack.3.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
encoder.layer_stack.3.slf_attn.w_ks.bias
torch.Size([512])
512
encoder.layer_stack.3.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.3.slf_attn.w_vs.bias
torch.Size([512])
512
encoder.layer_stack.3.slf_attn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.3.slf_attn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.3.slf_attn.fc.weight
torch.Size([512, 512])
262144
encoder.layer_stack.3.slf_attn.fc.bias
torch.Size([512])
512
encoder.layer_stack.3.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
encoder.layer_stack.3.pos_ffn.w_1.bias
torch.Size([2048])
2048
encoder.layer_stack.3.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
encoder.layer_stack.3.pos_ffn.w_2.bias
torch.Size([512])
512
encoder.layer_stack.3.pos_ffn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.3.pos_ffn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.4.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.4.slf_attn.w_qs.bias
torch.Size([512])
512
encoder.layer_stack.4.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
encoder.layer_stack.4.slf_attn.w_ks.bias
torch.Size([512])
512
encoder.layer_stack.4.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.4.slf_attn.w_vs.bias
torch.Size([512])
512
encoder.layer_stack.4.slf_attn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.4.slf_attn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.4.slf_attn.fc.weight
torch.Size([512, 512])
262144
encoder.layer_stack.4.slf_attn.fc.bias
torch.Size([512])
512
encoder.layer_stack.4.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
encoder.layer_stack.4.pos_ffn.w_1.bias
torch.Size([2048])
2048
encoder.layer_stack.4.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
encoder.layer_stack.4.pos_ffn.w_2.bias
torch.Size([512])
512
encoder.layer_stack.4.pos_ffn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.4.pos_ffn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.5.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.5.slf_attn.w_qs.bias
torch.Size([512])
512
encoder.layer_stack.5.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
encoder.layer_stack.5.slf_attn.w_ks.bias
torch.Size([512])
512
encoder.layer_stack.5.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
encoder.layer_stack.5.slf_attn.w_vs.bias
torch.Size([512])
512
encoder.layer_stack.5.slf_attn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.5.slf_attn.layer_norm.bias
torch.Size([512])
512
encoder.layer_stack.5.slf_attn.fc.weight
torch.Size([512, 512])
262144
encoder.layer_stack.5.slf_attn.fc.bias
torch.Size([512])
512
encoder.layer_stack.5.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
encoder.layer_stack.5.pos_ffn.w_1.bias
torch.Size([2048])
2048
encoder.layer_stack.5.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
encoder.layer_stack.5.pos_ffn.w_2.bias
torch.Size([512])
512
encoder.layer_stack.5.pos_ffn.layer_norm.weight
torch.Size([512])
512
encoder.layer_stack.5.pos_ffn.layer_norm.bias
torch.Size([512])
512
decoder.tgt_word_emb.weight
torch.Size([4233, 512])
2167296
decoder.positional_encoding.pe
torch.Size([1, 5000, 512])
2560000
decoder.layer_stack.0.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.0.slf_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.0.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.0.slf_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.0.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.0.slf_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.0.slf_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.0.slf_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.0.slf_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.0.slf_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.0.enc_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.0.enc_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.0.enc_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.0.enc_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.0.enc_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.0.enc_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.0.enc_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.0.enc_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.0.enc_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.0.enc_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.0.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
decoder.layer_stack.0.pos_ffn.w_1.bias
torch.Size([2048])
2048
decoder.layer_stack.0.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
decoder.layer_stack.0.pos_ffn.w_2.bias
torch.Size([512])
512
decoder.layer_stack.0.pos_ffn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.0.pos_ffn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.1.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.1.slf_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.1.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.1.slf_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.1.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.1.slf_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.1.slf_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.1.slf_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.1.slf_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.1.slf_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.1.enc_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.1.enc_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.1.enc_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.1.enc_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.1.enc_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.1.enc_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.1.enc_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.1.enc_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.1.enc_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.1.enc_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.1.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
decoder.layer_stack.1.pos_ffn.w_1.bias
torch.Size([2048])
2048
decoder.layer_stack.1.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
decoder.layer_stack.1.pos_ffn.w_2.bias
torch.Size([512])
512
decoder.layer_stack.1.pos_ffn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.1.pos_ffn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.2.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.2.slf_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.2.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.2.slf_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.2.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.2.slf_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.2.slf_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.2.slf_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.2.slf_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.2.slf_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.2.enc_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.2.enc_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.2.enc_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.2.enc_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.2.enc_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.2.enc_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.2.enc_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.2.enc_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.2.enc_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.2.enc_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.2.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
decoder.layer_stack.2.pos_ffn.w_1.bias
torch.Size([2048])
2048
decoder.layer_stack.2.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
decoder.layer_stack.2.pos_ffn.w_2.bias
torch.Size([512])
512
decoder.layer_stack.2.pos_ffn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.2.pos_ffn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.3.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.3.slf_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.3.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.3.slf_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.3.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.3.slf_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.3.slf_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.3.slf_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.3.slf_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.3.slf_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.3.enc_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.3.enc_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.3.enc_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.3.enc_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.3.enc_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.3.enc_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.3.enc_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.3.enc_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.3.enc_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.3.enc_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.3.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
decoder.layer_stack.3.pos_ffn.w_1.bias
torch.Size([2048])
2048
decoder.layer_stack.3.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
decoder.layer_stack.3.pos_ffn.w_2.bias
torch.Size([512])
512
decoder.layer_stack.3.pos_ffn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.3.pos_ffn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.4.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.4.slf_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.4.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.4.slf_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.4.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.4.slf_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.4.slf_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.4.slf_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.4.slf_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.4.slf_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.4.enc_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.4.enc_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.4.enc_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.4.enc_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.4.enc_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.4.enc_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.4.enc_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.4.enc_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.4.enc_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.4.enc_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.4.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
decoder.layer_stack.4.pos_ffn.w_1.bias
torch.Size([2048])
2048
decoder.layer_stack.4.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
decoder.layer_stack.4.pos_ffn.w_2.bias
torch.Size([512])
512
decoder.layer_stack.4.pos_ffn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.4.pos_ffn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.5.slf_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.5.slf_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.5.slf_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.5.slf_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.5.slf_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.5.slf_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.5.slf_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.5.slf_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.5.slf_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.5.slf_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.5.enc_attn.w_qs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.5.enc_attn.w_qs.bias
torch.Size([512])
512
decoder.layer_stack.5.enc_attn.w_ks.weight
torch.Size([512, 512])
262144
decoder.layer_stack.5.enc_attn.w_ks.bias
torch.Size([512])
512
decoder.layer_stack.5.enc_attn.w_vs.weight
torch.Size([512, 512])
262144
decoder.layer_stack.5.enc_attn.w_vs.bias
torch.Size([512])
512
decoder.layer_stack.5.enc_attn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.5.enc_attn.layer_norm.bias
torch.Size([512])
512
decoder.layer_stack.5.enc_attn.fc.weight
torch.Size([512, 512])
262144
decoder.layer_stack.5.enc_attn.fc.bias
torch.Size([512])
512
decoder.layer_stack.5.pos_ffn.w_1.weight
torch.Size([2048, 512])
1048576
decoder.layer_stack.5.pos_ffn.w_1.bias
torch.Size([2048])
2048
decoder.layer_stack.5.pos_ffn.w_2.weight
torch.Size([512, 2048])
1048576
decoder.layer_stack.5.pos_ffn.w_2.bias
torch.Size([512])
512
decoder.layer_stack.5.pos_ffn.layer_norm.weight
torch.Size([512])
512
decoder.layer_stack.5.pos_ffn.layer_norm.bias
torch.Size([512])
512
decoder.tgt_word_prj.weight
torch.Size([4233, 512])
2167296
53635584
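
As a sanity check, this total can be reproduced from the shapes listed above (d_model 512, d_inner 2048, six layers per side, 80-dim input features, vocabulary of 4233, positional tables of length 5000):

d_model, d_inner, n_layers, vocab = 512, 2048, 6, 4233
attn = 4 * (d_model * d_model + d_model) + 2 * d_model                              # w_qs, w_ks, w_vs, fc + layer_norm
ffn = (d_inner * d_model + d_inner) + (d_model * d_inner + d_model) + 2 * d_model   # w_1, w_2 + layer_norm
enc = (d_model * 80 + d_model) + 2 * d_model + 5000 * d_model + n_layers * (attn + ffn)
dec = vocab * d_model + 5000 * d_model + n_layers * (2 * attn + ffn) + vocab * d_model
print(enc + dec)  # 53635584; decoder layers carry two attention blocks (slf_attn and enc_attn)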

Print the top-level entries of optim_dict

for key, value in net["optim_dict"].items():
    print(key)
    print(type(value))
Output:

state
<class 'dict'>
param_groups
<class 'list'>

Print param_groups inside optim_dict

groups = net["optim_dict"]["param_groups"]
print(groups)
print(len(groups))

Output:

[{'lr': 2.9553222446278096e-05, 'betas': (0.9, 0.98), 'eps': 1e-09, 'weight_decay': 0, 'amsgrad': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256]}]
1
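
So there is a single parameter group covering all 257 parameter tensors. To read just the current learning rate without dumping the whole group:

print(net["optim_dict"]["param_groups"][0]["lr"])  # 2.9553222446278096e-05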

Print state inside optim_dict

state = net["optim_dict"]["state"]
print(len(state)) 
for key, value in state.items():
    print(key, type(value), sep=" ")

Output:

257
0 <class 'dict'>
1 <class 'dict'>
2 <class 'dict'>
...
256 <class 'dict'>
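
The betas, eps, and amsgrad fields above identify the optimizer as Adam, so each of the 257 state entries should be a dict holding the step count and the two running moments. A quick probe, assuming the standard PyTorch Adam state keys:

s0 = state[0]
print(s0.keys())             # for standard PyTorch Adam: dict_keys(['step', 'exp_avg', 'exp_avg_sq'])
print(s0["exp_avg"].size())  # each moment tensor has the same shape as its parameter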

Method 2: counting parameters via model.parameters()

Load state_dict into the model (this assumes a model instance already constructed with matching hyperparameters; a construction sketch follows the snippet below)

model.load_state_dict(net['state_dict'])
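
All the hyperparameters needed to build that model instance are stored in the same checkpoint. A minimal construction sketch, assuming the Encoder/Decoder/Transformer classes from the repo's transformer module; the argument lists below follow the checkpoint keys and may not match the actual signatures exactly, so treat this as a guide rather than a drop-in:

from transformer import Encoder, Decoder, Transformer  # hypothetical module path inside the repo; adjust as needed

encoder = Encoder(net["d_input"] * net["LFR_m"], net["n_layers_enc"], net["n_head"],
                  net["d_k"], net["d_v"], net["d_model"], net["d_inner"],
                  dropout=net["dropout"], pe_maxlen=net["pe_maxlen"])
decoder = Decoder(net["sos_id"], net["eos_id"], net["vocab_size"],
                  net["d_word_vec"], net["n_layers_dec"], net["n_head"],
                  net["d_k"], net["d_v"], net["d_model"], net["d_inner"],
                  tgt_emb_prj_weight_sharing=net["tgt_emb_prj_weight_sharing"],
                  pe_maxlen=net["pe_maxlen"])
model = Transformer(encoder, decoder)  # after this, the load_state_dict call above applies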

Count the number of parameters

num_params = 0
for param in model.parameters():
    num_params += param.numel()
print(num_params / 1e6)
print(num_params)

Output:

46.348288
46348288
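
The same figure can be obtained with a generator expression, optionally restricted to tensors that require gradients:

num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_trainable)  # 46348288, assuming every parameter of this model is trainable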

Summary:

Method 2 gives the exact count of trainable model parameters; Method 1 also sums entries of state_dict that are not trainable parameters.
Method 1: 53635584 entries
Method 2: 46348288 parameters
Difference: 7287296
Comparing the two listings, the difference is exactly the following three entries:

encoder.positional_encoding.pe
torch.Size([1, 5000, 512])
2560000
decoder.tgt_word_emb.weight
torch.Size([4233, 512])
2167296
decoder.positional_encoding.pe
torch.Size([1, 5000, 512])
2560000
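
This is expected: the two positional_encoding.pe tables are registered buffers rather than trainable nn.Parameter tensors, so model.parameters() skips them, and because tgt_emb_prj_weight_sharing is 1 the decoder embedding and the output projection share a single weight tensor, which parameters() yields only once even though state_dict lists it under both names. The arithmetic confirms it:

print(2560000 + 2167296 + 2560000)  # 7287296, the three extra entries
print(53635584 - 46348288)          # 7287296, Method 1 minus Method 2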
