Pytorch计算模型的参数量

比如说我计算TransUNet(Transformer + UNet)的参数量
(计算之前需要安装thop库 pip install thop)

from thop import profile, clever_format
flops, params = profile(net, inputs=(inputs,))
macs, params = clever_format([flops, params], "%.3f") # 格式化输出
print('flops':, macs) # 计算量
print('params:',params) # 模型参数量

最后输出:
flops:128.677G
params:93.192M

可以发现在UNet中加入transformer之后,参数量还是增加不少的(DUNet的参数量是:19.219M)
Pytorch计算模型的参数量_第1张图片

你可能感兴趣的:(Python技巧,python,机器学习,人工智能,深度学习)