[ pytorch ] ——基本使用:(5) 模型权重(参数)学习

1. 查看/调用 模型的权重.

import torch
import torch.nn as nn
from torchvision import models

class MyModel(nn.Module):
    """Toy model used to demonstrate listing weights via ``named_parameters()``.

    Layers: a 7x7 stride-2 conv stem (3 -> 64 channels, no bias), BatchNorm,
    ReLU, global average pooling, and a 2048 -> 512 linear layer.

    NOTE: ``fc`` is declared with 2048 input features, while the conv stem only
    produces 64 channels, so ``forward`` cannot run end-to-end with these exact
    sizes. They are kept because this class only serves to list parameter
    shapes; replace ``fc`` with ``nn.Linear(64, 512)`` for a runnable forward.
    """

    def __init__(self):
        super(MyModel, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # For a batched (N, 3, H, W) input, conv1/bn1 yield a 4-D tensor, so pool
        # over both spatial dims; AdaptiveAvgPool1d only accepts 2-D/3-D input.
        self.gap = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Linear(2048, 512)

    def forward(self, input):
        """Run the stem, pool globally, and project with ``fc``.

        Assumes a batched image tensor of shape (N, 3, H, W).
        """
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.gap(x)            # (N, C, 1, 1)
        x = torch.flatten(x, 1)    # (N, C): nn.Linear acts on the last dim
        x = self.fc(x)

        return x

##############################

# Model preparation
model = MyModel()

# Print a two-column table of (parameter name, parameter shape).
print('-----------------------------------------------')
print('|   weight name   |        weight shape       |')
print('-----------------------------------------------')

# Left-justify each cell to its column width (15 for the name, 25 for the
# shape) so rows line up for any name length or tensor rank. A cell longer
# than its column is printed in full rather than truncated.
for key, w_variable in model.named_parameters():
    print('| {} | {} |'.format(key.ljust(15), str(w_variable.shape).ljust(25)))
print('-----------------------------------------------')

[结果]
-----------------------------------------------
|   weight name   |        weight shape       |
-----------------------------------------------
| conv1.weight    | torch.Size([64, 3, 7, 7]) |
| bn1.weight      | torch.Size([64])          |
| bn1.bias        | torch.Size([64])          |
| fc.weight       | torch.Size([512, 2048])   |
| fc.bias         | torch.Size([512])         |
-----------------------------------------------

可以看到, 通过打印各层的 [权重名称] 和 [权重 tensor 的形状], 能够清晰地了解网络中各层权重的形状, 为进一步操作权重提供方便.

 

 

2. 打印模型参数量

转自:https://blog.csdn.net/guilutian0541/article/details/81977850

################
###   模型定义
# -------------
class MyModel(nn.Module):
    """Placeholder model: substitute your own layers and forward pass.

    Args:
        feat_dim: dimension of the backbone (e.g. ResNet) output feature map.
            Defaults to 2048 so the model can be built with no arguments.
    """

    def __init__(self, feat_dim=2048):
        super(MyModel, self).__init__()
        self.feat_dim = feat_dim
        # ... define your layers here ...

    def forward(self, input):
        """Placeholder forward pass; ``input`` is expected to carry feat_dim features."""
        # ... compute the output from ``input`` and return it ...
        raise NotImplementedError('fill in MyModel.forward with your own computation')

# Pass feat_dim explicitly: size of the backbone feature map fed to the model.
net = MyModel(feat_dim=2048)

######################################
# Print the shape and element count of every parameter tensor, then the total.
# Tensor.numel() returns the product of the tensor's dimensions (1 for a 0-d
# tensor), replacing the hand-written multiplication loop.
params = list(net.parameters())
k = 0  # running total of parameter elements across all tensors
for param in params:
    num = param.numel()
    print("该层的结构:" + str(list(param.size())))
    print("该层参数和:" + str(num))
    k += num
print("总参数数量和:" + str(k))
######################################

 

你可能感兴趣的:(Pytorch)