pytorch中指定或者冻结一层网络的权值(weight)

在实际开发中,我们有时需要指定一层神经网络的参数或者冻结他们。在torch中,我们都可以轻易实现。本文以简单的cnn为例,讲述指定一层网络的权值和冻结一层网络的权值。

一.指定权值

如果有这样一个神经网络

import torch.nn as nn
import torch.nn.functional as F
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5) # out_channels 6 filters
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=12,kernel_size=5) # 可以用 out_features
        
        self.fc1 = nn.Linear(in_features= 12*4*4,out_features=120)
        self.fc2 = nn.Linear(in_features=120,out_features=60)
        self.out = nn.Linear(in_features=60,out_features=10)
        
        
    def forward(self,t):
        # input layer
        t = t 
        
        # conv layer 1
        t = self.conv1(t)
        t = F.relu(t)
        t = F.max_pool2d(t,kernel_size=2,stride=2)
        
        # conv layer 2
        t = self.conv2(t)
        t = F.relu(t)
        t = F.max_pool2d(t,kernel_size=2,stride=2)

        # linear layers
        t = t.reshape(-1,12*4*4)
         
        t = self.fc1(t)
        t = F.relu(t)
        
        t = self.fc2(t)
        t = F.relu(t)
        
        t = self.out(t)
        t = F.softmax(t, axis = 1)
        return t 

由于一层神经网络,如conv1,为Network类的一个属性(attribute),而conv1是nn.Conv2d的一个实例(instance)。在nn.Conv2d中转到定义往下翻,可以看到他的属性:

class Conv2d(_ConvNd):
# 定义里面往下翻
"""
Attributes:
        weight (Tensor): the learnable weights of the module of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor):   the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``,
            then the values of these weights are
            sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
"""

我们先在notebook中看conv1

network.conv1

可以看到,他属于nn.Conv2d类

Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))

再引用属性

network.conv1.weight
Output exceeds the size limit. Open the full output data in a text editor
Parameter containing:
tensor([[[[-0.0017,  0.0833,  0.1076, -0.1180, -0.0225],
          [ 0.1380,  0.0366,  0.0529,  0.0363, -0.1543],
          [ 0.1068,  0.1220, -0.1522,  0.0381,  0.0512],
          [ 0.1212, -0.0746, -0.1923, -0.1805,  0.0160],
          [ 0.0523,  0.0680, -0.1754,  0.1557, -0.0219]]],

        [[[ 0.1577,  0.1780, -0.0353,  0.0785, -0.1649],
          [-0.0598, -0.1900, -0.1578,  0.0462,  0.1965],
          [ 0.0830, -0.0344,  0.0268, -0.0406, -0.0303],
...
        [[[ 0.1930, -0.1314,  0.1577,  0.0067, -0.1153],
          [ 0.1750,  0.1161, -0.0069, -0.0044,  0.1100],
          [-0.1456,  0.0942, -0.0637,  0.0621,  0.1154],
          [ 0.0873,  0.1486,  0.1285,  0.0642,  0.0339],
          [-0.0197, -0.0500,  0.0105,  0.0326, -0.0809]]]], requires_grad=True)

调用

network.conv1.weight.shape

输出 

torch.Size([6, 1, 5, 5])

有时候要拿到张量,这样写拿权重的张量

network.conv1.weight.data

理解了这些,我们就可以改参数了,为了我写的方便(doge),我们改个简单的全连接层。如果要改我们的卷积层,要先初始化一个torch.Size([6, 1, 5, 5])大小的张量,道理都一样。

in_feaure = torch.tensor([1,2,3,4],dtype = torch.float32)
weight_matrix = torch.tensor([
    [1,2,3,4],
    [2,3,4,5],
    [3,4,5,6]
],dtype = torch.float32)
fc = nn.Linear(in_features=4,out_features=3,bias=False)

我们直接使用weight_matrix换掉权值,注意改的是weight.data而不是weight。

# 冻结
fc.weight.data = nn.Parameter(weight_matrix, requires_grad = False)

# 不冻结
fc.weight.data = nn.Parameter(weight_matrix)

由于fc.weight属性要改的话只能是nn.Parameter或者none,所以要先转类型。

二.改冻结

由于fc.weight类型为nn.Parameter,转到nn.Parameter定义,看到Args,所以我们可以直接改requires_grad。

class Parameter(torch.Tensor):
#往下看
    """
    Args:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. See
            :ref:`locally-disable-grad-doc` for more details. Default: `True`
    """
fc.weight.requires_grad = False

你可能感兴趣的:(python,torch,pytorch,深度学习,cnn)