网络中所有操作对象都是Varoable对象,而Variable有两个参数可用于固定参数:requires_grad和volatile。
一:requires_grad参数设置
Method 1: 初始化时指定,如下语句所示:
x = Variable(torch.randn(2, 3), requires_grad=True)
y = Variable(torch.randn(2, 3), requires_grad=False)
注意:Variable中requires_grad的默认值为False,但是Module中的层在定义时,相关的Variable中的requires_grad默认都是True。在计算图中,如果有一个输入的requires_grad是True,那么输出的requires_grad也是True。所以为了更方便的进行参数固定,建议使用Method 2 。
Method 2:网络模型中设置。在训练中想要固定网络的底层,可以令这部分网络对应的子图的参数requires_grad为False。这样在反向传播的过程中就不会计算这些参数对应的梯度。需要在nn.Module中直接插入如下语句:
for p in self.parameters():
p.requires_grad=False
For example 1: 在加载预训练模型后,在原来的基础上添加一部分的网络,这样可以固定原来的参数,然后只训练添加的这部分网络,结束后再全部训练。
class RESNET_attention(nn.Module):
def __init__(self, model, pretrained):
super(RESNET_attetnion, self).__init__()
self.resnet = model(pretrained)
for p in self.parameters():
p.requires_grad = False
self.f = nn.Conv2d(2048, 512, 1)
self.g = nn.Conv2d(2048, 512, 1)
self.h = nn.Conv2d(2048, 2048, 1)
self.softmax = nn.Softmax(-1)
self.gamma = nn.Parameter(torch.FloatTensor([0.0]))
self.avgpool = nn.AvgPool2d(7, stride=1)
self.resnet.fc = nn.Linear(2048, 10)
这样就将for循环以上的参数固定,只训练下面的参数,f,g,h,gamma,fc等,注意需要在optimizer中添加以下语句:
filter(lambda p:p.requires_grad, model.parameters())
添加位置如下:
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.0001, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
For example 2: 自定义网络,添加for循环,固定for循环之前的参数。
class Net(nn.Module): def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
for p in self.parameters():
p.requires_grad=False
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
二:volatile参数设置
Variable的参数volatile=True和requires_grad=False的功能差不多,但是volatile的力量更大。当有一个输入的volatile=True时,那么输出的volatile=True。volatile=True推荐在模型的推理过程(测试)中使用,这时只需要令输入的voliate=True,保证用最小的内存来执行推理,不会保存任何中间状态。
For example:
>>> regular_input = Variable(torch.randn(5, 5))
>>> volatile_input = Variable(torch.randn(5, 5), volatile=True)
>>> model = torchvision.models.resnet18(pretrained=True)
>>> model(regular_input).requires_grad #输出的requires_grad应该是True,因为中间层的Variable的requires_grad默认是True
True
>>> model(volatile_input).requires_grad#输出的requires_grad是False,因为输出的volatile是True(等价于requires_grad是False)
False
>>> model(volatile_input).volatile
True
Ref:
https://blog.csdn.net/VictoriaW/article/details/72779407
https://www.jianshu.com/p/4ec8e310495d
https://blog.csdn.net/guotong1988/article/details/79739775