BERT Model Parameter Statistics

BERT Model Parameter Count Analysis

Using the BERT model from Hugging Face transformers, this post counts and analyzes the model's parameters.

Loading the model from Hugging Face

    import torch
    from transformers import BertTokenizer, BertModel

    # Load the pretrained Chinese BERT base model.
    bertModel = BertModel.from_pretrained("bert-base-chinese", output_hidden_states=True, output_attentions=True)
    # Sum the number of elements over all parameter tensors.
    total = sum(p.numel() for p in bertModel.parameters())
    print("total param:", total)

    Output:
    total param: 102267648

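The code above counts the model's total number of parameters; the output is 102267648.

Below, the BERT parameter count is broken down into three parts.

1. Embedding layer

BERT has three kinds of embeddings: word embedding, position embedding, and sentence embedding (the segment embedding, named token_type_embeddings in the parameter listing at the end of this post).
In bert-base-chinese, the vocabulary size is 21128, the embedding dimension is 768, and the maximum sequence length L is 512.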

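Word embedding parameters: 21128*768
Position embedding parameters: 512*768
Sentence embedding parameters: 2*768
The embedding block ends with a Layer Norm layer, which has 768+768 parameters: the $\alpha$ and $\beta$ vectors in the LN formula.

Total parameters in the embedding layer:
21128*768+512*768+2*768+768+768 = 16622592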
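
The embedding arithmetic above can be checked with a few lines of plain Python. This is only a sketch: the function name `embedding_params` and its argument names are made up for illustration and are not part of transformers.

```python
def embedding_params(vocab_size, max_positions, type_vocab_size, hidden):
    """Count the parameters of the embedding block, Layer Norm included."""
    word = vocab_size * hidden          # word embedding table
    position = max_positions * hidden   # position embedding table
    segment = type_vocab_size * hidden  # sentence (segment) embedding table
    layer_norm = hidden + hidden        # the two Layer Norm vectors
    return word + position + segment + layer_norm


# Values for bert-base-chinese, as given in the text.
print(embedding_params(21128, 512, 2, 768))  # 16622592
```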

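2. Self-attention layer

There are 12 self-attention layers; each consists of two parts: multi-head attention and a Layer Norm layer.

Multi-head attention has three projection matrices, Q, K, and V, plus one output matrix applied after concatenation. Each of Q, K, and V has 768*12*64+768 parameters: the first 768 is the embedding dimension, 12 is the number of heads, 64 is the dimension of each head, and the trailing 768 is the bias. The outputs of the heads are concatenated and then passed through an additional 768*768+768 projection.
Layer Norm parameters: 768+768
Parameters in one self-attention layer:
(768*12*64+768)*3+768*768+768+768+768 = 2363904
For all 12 layers: 2363904*12 = 28366848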
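
As a quick check of the per-layer attention count, here is a minimal sketch in plain Python. The helper `attention_params` is hypothetical; it simply mirrors the formula in the text, with the bias size written as heads times head dimension (12*64 = 768).

```python
def attention_params(hidden, num_heads, head_dim):
    """Count the parameters of one self-attention block, Layer Norm included."""
    inner = num_heads * head_dim          # 12 * 64 = 768 for BERT base
    qkv = 3 * (hidden * inner + inner)    # Q, K, V weights and biases
    output = inner * hidden + hidden      # projection after concatenating heads
    layer_norm = hidden + hidden          # the two Layer Norm vectors
    return qkv + output + layer_norm


per_layer = attention_params(768, 12, 64)
print(per_layer)       # 2363904
print(per_layer * 12)  # 28366848
```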

3. Feedforward layer

There are 12 feedforward layers; each consists of two parts: the feedforward network and a Layer Norm layer.

The feedforward network has the form $W_2(W_1X+b_1)+b_2$, with two linear layers: $W_1$ maps 768 -> 768*4 and $W_2$ maps 768*4 -> 768. $W_1$ has 768*768*4+768*4 parameters and $W_2$ has 768*4*768+768 parameters.
Layer Norm parameters: 768+768
Parameters in one feedforward layer:
(768*768*4+768*4)+(768*4*768+768)+768+768 = 4723968
For all 12 layers: 4723968*12 = 56687616
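
The feedforward count can be verified the same way. The sketch below assumes only the sizes stated in the text (768 and 768*4 = 3072); the name `ffn_params` is illustrative.

```python
def ffn_params(hidden, intermediate):
    """Count the parameters of one feedforward block, Layer Norm included."""
    w1 = hidden * intermediate + intermediate  # W_1 and b_1: 768 -> 3072
    w2 = intermediate * hidden + hidden        # W_2 and b_2: 3072 -> 768
    layer_norm = hidden + hidden               # the two Layer Norm vectors
    return w1 + w2 + layer_norm


per_layer = ffn_params(768, 768 * 4)
print(per_layer)       # 4723968
print(per_layer * 12)  # 56687616
```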

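Total parameters

embedding: 16622592
self-attention: 28366848
feedforward: 56687616
After the feedforward layers there is also a pooler layer of size 768*768, with 768*768+768 parameters (weights + bias). It takes the vector of the first special token, [CLS], which is then used to compute the loss of BERT's NSP task.

total = 16622592+28366848+56687616+768*768+768 = 102267648
This matches the count reported by PyTorch.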
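
Putting the pieces together, the whole breakdown can be written as one closed-form function. This is a sketch under the sizes stated in this post; `BertShape` and `count_bert_params` are hypothetical names and do not correspond to any transformers class.

```python
from dataclasses import dataclass


@dataclass
class BertShape:
    # Defaults are the bert-base-chinese sizes used in this post.
    vocab_size: int = 21128
    hidden: int = 768
    layers: int = 12
    intermediate: int = 3072
    max_positions: int = 512
    type_vocab_size: int = 2


def count_bert_params(s):
    """Return the parameter count of each part as a dict."""
    h = s.hidden
    layer_norm = 2 * h
    embeddings = (s.vocab_size + s.max_positions + s.type_vocab_size) * h + layer_norm
    attention = 4 * (h * h + h) + layer_norm  # Q, K, V, output projection
    feedforward = (h * s.intermediate + s.intermediate) + (s.intermediate * h + h) + layer_norm
    pooler = h * h + h
    return {
        "embeddings": embeddings,
        "self_attention": s.layers * attention,
        "feedforward": s.layers * feedforward,
        "pooler": pooler,
    }


counts = count_bert_params(BertShape())
print(counts)
print(sum(counts.values()))  # 102267648
```

Changing a field such as `vocab_size` shows how much of the total comes from the embedding table alone.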

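If anything above is unclear, look at the parameters of each layer in the BERT model.

The parameters of every layer in the model are listed below:

for name, param in bertModel.named_parameters():
    print(name)
    print(param.shape)

# Output: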

embeddings.word_embeddings.weight
torch.Size([21128, 768])
embeddings.position_embeddings.weight
torch.Size([512, 768])
embeddings.token_type_embeddings.weight
torch.Size([2, 768])
embeddings.LayerNorm.weight
torch.Size([768])
embeddings.LayerNorm.bias
torch.Size([768])
encoder.layer.0.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.0.attention.self.query.bias
torch.Size([768])
encoder.layer.0.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.0.attention.self.key.bias
torch.Size([768])
encoder.layer.0.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.0.attention.self.value.bias
torch.Size([768])
encoder.layer.0.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.0.attention.output.dense.bias
torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.0.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.0.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.0.intermediate.dense.bias
torch.Size([3072])
encoder.layer.0.output.dense.weight
torch.Size([768, 3072])
encoder.layer.0.output.dense.bias
torch.Size([768])
encoder.layer.0.output.LayerNorm.weight
torch.Size([768])
encoder.layer.0.output.LayerNorm.bias
torch.Size([768])
encoder.layer.1.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.1.attention.self.query.bias
torch.Size([768])
encoder.layer.1.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.1.attention.self.key.bias
torch.Size([768])
encoder.layer.1.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.1.attention.self.value.bias
torch.Size([768])
encoder.layer.1.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.1.attention.output.dense.bias
torch.Size([768])
encoder.layer.1.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.1.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.1.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.1.intermediate.dense.bias
torch.Size([3072])
encoder.layer.1.output.dense.weight
torch.Size([768, 3072])
encoder.layer.1.output.dense.bias
torch.Size([768])
encoder.layer.1.output.LayerNorm.weight
torch.Size([768])
encoder.layer.1.output.LayerNorm.bias
torch.Size([768])
encoder.layer.2.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.2.attention.self.query.bias
torch.Size([768])
encoder.layer.2.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.2.attention.self.key.bias
torch.Size([768])
encoder.layer.2.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.2.attention.self.value.bias
torch.Size([768])
encoder.layer.2.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.2.attention.output.dense.bias
torch.Size([768])
encoder.layer.2.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.2.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.2.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.2.intermediate.dense.bias
torch.Size([3072])
encoder.layer.2.output.dense.weight
torch.Size([768, 3072])
encoder.layer.2.output.dense.bias
torch.Size([768])
encoder.layer.2.output.LayerNorm.weight
torch.Size([768])
encoder.layer.2.output.LayerNorm.bias
torch.Size([768])
encoder.layer.3.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.3.attention.self.query.bias
torch.Size([768])
encoder.layer.3.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.3.attention.self.key.bias
torch.Size([768])
encoder.layer.3.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.3.attention.self.value.bias
torch.Size([768])
encoder.layer.3.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.3.attention.output.dense.bias
torch.Size([768])
encoder.layer.3.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.3.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.3.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.3.intermediate.dense.bias
torch.Size([3072])
encoder.layer.3.output.dense.weight
torch.Size([768, 3072])
encoder.layer.3.output.dense.bias
torch.Size([768])
encoder.layer.3.output.LayerNorm.weight
torch.Size([768])
encoder.layer.3.output.LayerNorm.bias
torch.Size([768])
encoder.layer.4.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.4.attention.self.query.bias
torch.Size([768])
encoder.layer.4.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.4.attention.self.key.bias
torch.Size([768])
encoder.layer.4.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.4.attention.self.value.bias
torch.Size([768])
encoder.layer.4.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.4.attention.output.dense.bias
torch.Size([768])
encoder.layer.4.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.4.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.4.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.4.intermediate.dense.bias
torch.Size([3072])
encoder.layer.4.output.dense.weight
torch.Size([768, 3072])
encoder.layer.4.output.dense.bias
torch.Size([768])
encoder.layer.4.output.LayerNorm.weight
torch.Size([768])
encoder.layer.4.output.LayerNorm.bias
torch.Size([768])
encoder.layer.5.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.5.attention.self.query.bias
torch.Size([768])
encoder.layer.5.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.5.attention.self.key.bias
torch.Size([768])
encoder.layer.5.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.5.attention.self.value.bias
torch.Size([768])
encoder.layer.5.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.5.attention.output.dense.bias
torch.Size([768])
encoder.layer.5.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.5.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.5.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.5.intermediate.dense.bias
torch.Size([3072])
encoder.layer.5.output.dense.weight
torch.Size([768, 3072])
encoder.layer.5.output.dense.bias
torch.Size([768])
encoder.layer.5.output.LayerNorm.weight
torch.Size([768])
encoder.layer.5.output.LayerNorm.bias
torch.Size([768])
encoder.layer.6.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.6.attention.self.query.bias
torch.Size([768])
encoder.layer.6.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.6.attention.self.key.bias
torch.Size([768])
encoder.layer.6.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.6.attention.self.value.bias
torch.Size([768])
encoder.layer.6.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.6.attention.output.dense.bias
torch.Size([768])
encoder.layer.6.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.6.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.6.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.6.intermediate.dense.bias
torch.Size([3072])
encoder.layer.6.output.dense.weight
torch.Size([768, 3072])
encoder.layer.6.output.dense.bias
torch.Size([768])
encoder.layer.6.output.LayerNorm.weight
torch.Size([768])
encoder.layer.6.output.LayerNorm.bias
torch.Size([768])
encoder.layer.7.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.7.attention.self.query.bias
torch.Size([768])
encoder.layer.7.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.7.attention.self.key.bias
torch.Size([768])
encoder.layer.7.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.7.attention.self.value.bias
torch.Size([768])
encoder.layer.7.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.7.attention.output.dense.bias
torch.Size([768])
encoder.layer.7.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.7.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.7.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.7.intermediate.dense.bias
torch.Size([3072])
encoder.layer.7.output.dense.weight
torch.Size([768, 3072])
encoder.layer.7.output.dense.bias
torch.Size([768])
encoder.layer.7.output.LayerNorm.weight
torch.Size([768])
encoder.layer.7.output.LayerNorm.bias
torch.Size([768])
encoder.layer.8.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.8.attention.self.query.bias
torch.Size([768])
encoder.layer.8.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.8.attention.self.key.bias
torch.Size([768])
encoder.layer.8.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.8.attention.self.value.bias
torch.Size([768])
encoder.layer.8.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.8.attention.output.dense.bias
torch.Size([768])
encoder.layer.8.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.8.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.8.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.8.intermediate.dense.bias
torch.Size([3072])
encoder.layer.8.output.dense.weight
torch.Size([768, 3072])
encoder.layer.8.output.dense.bias
torch.Size([768])
encoder.layer.8.output.LayerNorm.weight
torch.Size([768])
encoder.layer.8.output.LayerNorm.bias
torch.Size([768])
encoder.layer.9.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.9.attention.self.query.bias
torch.Size([768])
encoder.layer.9.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.9.attention.self.key.bias
torch.Size([768])
encoder.layer.9.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.9.attention.self.value.bias
torch.Size([768])
encoder.layer.9.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.9.attention.output.dense.bias
torch.Size([768])
encoder.layer.9.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.9.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.9.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.9.intermediate.dense.bias
torch.Size([3072])
encoder.layer.9.output.dense.weight
torch.Size([768, 3072])
encoder.layer.9.output.dense.bias
torch.Size([768])
encoder.layer.9.output.LayerNorm.weight
torch.Size([768])
encoder.layer.9.output.LayerNorm.bias
torch.Size([768])
encoder.layer.10.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.10.attention.self.query.bias
torch.Size([768])
encoder.layer.10.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.10.attention.self.key.bias
torch.Size([768])
encoder.layer.10.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.10.attention.self.value.bias
torch.Size([768])
encoder.layer.10.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.10.attention.output.dense.bias
torch.Size([768])
encoder.layer.10.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.10.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.10.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.10.intermediate.dense.bias
torch.Size([3072])
encoder.layer.10.output.dense.weight
torch.Size([768, 3072])
encoder.layer.10.output.dense.bias
torch.Size([768])
encoder.layer.10.output.LayerNorm.weight
torch.Size([768])
encoder.layer.10.output.LayerNorm.bias
torch.Size([768])
encoder.layer.11.attention.self.query.weight
torch.Size([768, 768])
encoder.layer.11.attention.self.query.bias
torch.Size([768])
encoder.layer.11.attention.self.key.weight
torch.Size([768, 768])
encoder.layer.11.attention.self.key.bias
torch.Size([768])
encoder.layer.11.attention.self.value.weight
torch.Size([768, 768])
encoder.layer.11.attention.self.value.bias
torch.Size([768])
encoder.layer.11.attention.output.dense.weight
torch.Size([768, 768])
encoder.layer.11.attention.output.dense.bias
torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.weight
torch.Size([768])
encoder.layer.11.attention.output.LayerNorm.bias
torch.Size([768])
encoder.layer.11.intermediate.dense.weight
torch.Size([3072, 768])
encoder.layer.11.intermediate.dense.bias
torch.Size([3072])
encoder.layer.11.output.dense.weight
torch.Size([768, 3072])
encoder.layer.11.output.dense.bias
torch.Size([768])
encoder.layer.11.output.LayerNorm.weight
torch.Size([768])
encoder.layer.11.output.LayerNorm.bias
torch.Size([768])
pooler.dense.weight
torch.Size([768, 768])
pooler.dense.bias
torch.Size([768])
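
The listing above can also be summed per part instead of being read line by line. The sketch below groups (name, shape) pairs by their name prefix. `group_param_counts` is a hypothetical helper, and the rule "everything under an encoder layer that is not `.attention.` counts as feedforward" is an assumption that fits the names shown above.

```python
from collections import defaultdict
from math import prod


def group_param_counts(named_shapes):
    """Sum parameter counts per part from (name, shape) pairs."""
    totals = defaultdict(int)
    for name, shape in named_shapes:
        if name.startswith("embeddings."):
            group = "embeddings"
        elif name.startswith("pooler."):
            group = "pooler"
        elif ".attention." in name:
            group = "self_attention"
        else:
            group = "feedforward"  # intermediate.* and output.* of each layer
        totals[group] += prod(shape)
    return dict(totals)


# Layer 0 entries copied from the listing above.
layer0 = [
    ("encoder.layer.0.attention.self.query.weight", (768, 768)),
    ("encoder.layer.0.attention.self.query.bias", (768,)),
    ("encoder.layer.0.attention.self.key.weight", (768, 768)),
    ("encoder.layer.0.attention.self.key.bias", (768,)),
    ("encoder.layer.0.attention.self.value.weight", (768, 768)),
    ("encoder.layer.0.attention.self.value.bias", (768,)),
    ("encoder.layer.0.attention.output.dense.weight", (768, 768)),
    ("encoder.layer.0.attention.output.dense.bias", (768,)),
    ("encoder.layer.0.attention.output.LayerNorm.weight", (768,)),
    ("encoder.layer.0.attention.output.LayerNorm.bias", (768,)),
    ("encoder.layer.0.intermediate.dense.weight", (3072, 768)),
    ("encoder.layer.0.intermediate.dense.bias", (3072,)),
    ("encoder.layer.0.output.dense.weight", (768, 3072)),
    ("encoder.layer.0.output.dense.bias", (768,)),
    ("encoder.layer.0.output.LayerNorm.weight", (768,)),
    ("encoder.layer.0.output.LayerNorm.bias", (768,)),
]
print(group_param_counts(layer0))
# {'self_attention': 2363904, 'feedforward': 4723968}
```

With the model loaded as at the start of the post, passing `((n, tuple(p.shape)) for n, p in bertModel.named_parameters())` should give the per-part subtotals of the breakdown for all 12 layers.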
