Pytorch固定部分参数(层)进行训练

网络中所有操作对象都是Varoable对象,而Variable有两个参数可用于固定参数:requires_grad和volatile。

一:requires_grad参数设置

Method 1: 初始化时指定,如下语句所示:

x = Variable(torch.randn(2, 3), requires_grad=True)
y = Variable(torch.randn(2, 3), requires_grad=False)

注意:Variable中requires_grad的默认值为False,但是Module中的层在定义时,相关的Variable中的requires_grad默认都是True。在计算图中,如果有一个输入的requires_grad是True,那么输出的requires_grad也是True。所以为了更方便的进行参数固定,建议使用Method 2 。

Method 2:网络模型中设置。在训练中想要固定网络的底层,可以令这部分网络对应的子图的参数requires_grad为False。这样在反向传播的过程中就不会计算这些参数对应的梯度。需要在nn.Module中直接插入如下语句:

for p in self.parameters():
    p.requires_grad=False

 

For example 1: 在加载预训练模型后,在原来的基础上添加一部分的网络,这样可以固定原来的参数,然后只训练添加的这部分网络,结束后再全部训练。

class RESNET_attention(nn.Module):
    def __init__(self, model, pretrained):
        super(RESNET_attetnion, self).__init__()
        self.resnet = model(pretrained)
        for p in self.parameters():
            p.requires_grad = False
        self.f = nn.Conv2d(2048, 512, 1)
        self.g = nn.Conv2d(2048, 512, 1)
        self.h = nn.Conv2d(2048, 2048, 1)
        self.softmax = nn.Softmax(-1)
        self.gamma = nn.Parameter(torch.FloatTensor([0.0]))
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.resnet.fc = nn.Linear(2048, 10)

    这样就将for循环以上的参数固定,只训练下面的参数,f,g,h,gamma,fc等,注意需要在optimizer中添加以下语句:

         

filter(lambda p:p.requires_grad, model.parameters())

添加位置如下:

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)

For example 2: 自定义网络,添加for循环,固定for循环之前的参数。

class Net(nn.Module): def __init__(self): 
    super(Net, self).__init__() 
    self.conv1 = nn.Conv2d(1, 6, 5) 
    self.conv2 = nn.Conv2d(6, 16, 5) 
    for p in self.parameters(): 
        p.requires_grad=False 
    self.fc1 = nn.Linear(16 * 5 * 5, 120) 
    self.fc2 = nn.Linear(120, 84) 
    self.fc3 = nn.Linear(84, 10)

二:volatile参数设置

     Variable的参数volatile=True和requires_grad=False的功能差不多,但是volatile的力量更大。当有一个输入的volatile=True时,那么输出的volatile=True。volatile=True推荐在模型的推理过程(测试)中使用,这时只需要令输入的voliate=True,保证用最小的内存来执行推理,不会保存任何中间状态。
For example:

>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad #输出的requires_grad应该是True,因为中间层的Variable的requires_grad默认是True
True
>>> model(volatile_input).requires_grad#输出的requires_grad是False,因为输出的volatile是True(等价于requires_grad是False)
False
>>> model(volatile_input).volatile
True

Ref:

https://blog.csdn.net/VictoriaW/article/details/72779407

https://www.jianshu.com/p/4ec8e310495d

https://blog.csdn.net/guotong1988/article/details/79739775

 

你可能感兴趣的:(pytorch)