【Pytorch API笔记3】用torch.numel()来统计网络的参数量

如何统计网络的大小,可以试一试torch.numel()函数
torch.numel()函数,可以计算出单个tensor元素的个数

一、对单个tensor使用,求tensor元素的个数

x = torch.randn((1, 3, 5, 7))
x.numel()
torch.numel()

输出105

二、求整个网络的参数

  n_p = sum(x.numel() for x in model.parameters())  # number parameters

如下示意图,可以计算网络的参数量
一个线性层,输入维度为1,输出维度为100
这个网络有200个参数,可以用x.numel() 巧妙计算出整个网络所需要的参数量

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 100) # 输入1、输出的维度都是100
    def forward(self, x):
        out = self.linear(x)
        return out
    
net = LinearModel()
n_p = sum(x.numel() for x in net.parameters())  # number parameters
print(n_p)  ##  ------>输出为200 

你可能感兴趣的:(深度框架,Pytorch,pytorch,深度学习,python)