在使用pytorch进行网络训练的时候,有时候不可避免的使用迁移学习(transfer learning),即使用已经训练好的模型(如resnet、inception等),固定其已经训练好的网络层参数,然后进行finetune。
以下代码是以resnet-101为例使用pytorch进行finetune的操作:
# Import the required modules.
import torch
import torch.nn as nn
from torchvision import models

# Load torchvision's ResNet-101; pretrained=True downloads the
# ImageNet-trained weights automatically on first use.
model = models.resnet101(pretrained=True)

# Freeze every weight: requires_grad=False excludes a parameter
# from gradient computation, so it stays fixed during training.
for param in model.parameters():
    param.requires_grad = False

# With everything frozen nothing can learn, so re-enable gradients
# for the final fully connected layer (fc) only — that is the part
# we fine-tune.
for param in model.fc.parameters():
    param.requires_grad = True
如果想修改最后一层的话,可以这么修改:
class_num = 200  # assume the target task has 200 classes
channel_in = model.fc.in_features  # input feature count of the existing fc layer
# Replace ResNet-101's fc layer with a fresh 200-class (class_num) layer;
# the new nn.Linear has requires_grad=True by default, so it will train.
model.fc = nn.Linear(channel_in, class_num)
也可以删除最后一层或者几层
# Collect the model's top-level submodules once, then rebuild without
# the tail. [:-1] drops only the last child (the fc classifier).
# NOTE(review): wrapping children in nn.Sequential loses any logic in
# the original forward() (e.g. a flatten before fc) — verify downstream.
submodules = list(model.children())
new_model = nn.Sequential(*submodules[:-1])
# Or drop the last two children instead.
new_model = nn.Sequential(*submodules[:-2])
当然删除了最后几层,可能还要添加,可以直接把以上代码写入新的网络里,然后再在forward代码块中添加,这个比较基础,就不详说了。
这个时候是如果按常规训练模型的方法直接使用optimizer的话会出错误的,如:
# Deliberately WRONG example: passing all parameters — including the
# frozen (requires_grad=False) ones — makes SGD raise the ValueError
# shown below.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
会报错:ValueError: optimizing a parameter that doesn't require gradients
Traceback (most recent call last):
File "main.py", line 1, in
main()
File "main.py", line 20, in main
optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
File "C:\Anaconda3\lib\site-packages\torch\optim\sgd.py", line 64, in __init__
super(SGD, self).__init__(params, defaults)
File "C:\Anaconda3\lib\site-packages\torch\optim\optimizer.py", line 43, in __init__
self.add_param_group(param_group)
File "C:\Anaconda3\lib\site-packages\torch\optim\optimizer.py", line 193, in add_param_group
raise ValueError("optimizing a parameter that doesn't require gradients")
ValueError: optimizing a parameter that doesn't require gradients
这是因为optimizer的输入参数parameters必须都是可以修改、反向传播的,即requires_grad=True,但是我们刚才已经固定了除了最后一层的所有参数,所以会出错。
解决方法是optimizer中只输入需要反向传播的参数:
# Hand the optimizer only the trainable parameters, skipping every
# parameter frozen with requires_grad=False — this is the key step.
trainable_params = (p for p in model.parameters() if p.requires_grad)
optimizer = torch.optim.SGD(trainable_params, lr=0.1)
这样就可以进行正常的训练了。
这里引申一下:接下来的代码是如何输出网络模型的卷积方式以及权重数值
# Walk the model's top-level children, printing each submodule's
# configuration and then the values of its weight tensors.
for child in model.children():
    print(child)  # the module definition, e.g. Conv2d(3, 64, kernel_size=(7, 7), ...)
    for param in child.parameters():
        print(param)  # the parameter (weight) values
部分结果(第一层卷积层):
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Parameter containing:
tensor([[[[ 2.0222e-02, -4.3927e-03, -1.8274e-02, ..., -1.5180e-02,
-1.5794e-03, 9.3115e-03],
[-4.0752e-03, 3.2116e-03, -1.5956e-02, ..., -8.4465e-02,
-7.4997e-02, -4.0676e-02],
[ 3.5039e-03, 2.6746e-02, 5.0813e-02, ..., 3.3407e-02,
1.3659e-02, 2.7821e-02],
...,
[-3.6174e-02, -1.2986e-01, -3.0369e-01, ..., -3.7412e-01,
-1.3025e-01, 4.2633e-02],
[ 1.5479e-02, 2.3444e-02, 6.5222e-03, ..., -1.6439e-01,
-1.8245e-01, -9.7434e-02],
[-3.0444e-02, -1.1357e-02, 4.9984e-02, ..., 1.6412e-01,
1.0419e-01, -1.2681e-02]],
[[ 8.7115e-03, -5.8911e-03, -1.2204e-02, ..., -1.3515e-02,
1.5212e-02, 1.9115e-02],
[-6.8970e-03, 1.0470e-02, -7.7561e-03, ..., -7.9215e-02,
-5.9150e-02, -2.1380e-02],
[-2.4955e-03, 3.2179e-02, 7.6542e-02, ..., 8.9056e-02,
5.1445e-02, 3.4868e-02],
...,
[-6.3695e-02, -1.9335e-01, -4.2540e-01, ..., -5.4060e-01,
-2.1634e-01, 8.1145e-03],
[-9.9660e-03, 9.2712e-04, -4.3920e-02, ..., -2.9635e-01,
-2.8675e-01, -1.7962e-01],
[ 2.5959e-02, 7.8301e-02, 1.5091e-01, ..., 2.7260e-01,
2.0840e-01, 4.2282e-02]],
[[ 4.0026e-03, -9.1936e-03, -2.0569e-02, ..., 1.0229e-03,
1.5326e-02, 9.7999e-04],
[ 1.7419e-02, 3.0881e-02, 6.3634e-03, ..., -1.2466e-02,
7.4009e-03, 2.0044e-02],
[-1.3397e-02, -5.1027e-03, -3.2198e-02, ..., -5.9151e-02,
-3.1463e-02, -9.3953e-03],
...,
[-2.4370e-02, -8.1988e-02, -1.8747e-01, ..., -2.2849e-01,
-4.5930e-02, 5.7061e-02],
[-4.6034e-03, 1.7281e-02, 2.0939e-02, ..., -7.0717e-02,
-1.0450e-01, -8.8036e-02],
[-6.7201e-04, 1.1773e-02, 3.1304e-02, ..., 1.1040e-01,
9.6085e-02, -5.0325e-03]]],
[[[-1.3283e-02, -6.3476e-03, 1.7371e-02, ..., -2.6374e-02,
-3.9297e-02, 3.6769e-02],
[ 1.6402e-02, -9.2402e-03, 6.0237e-03, ..., 4.5398e-02,
6.1903e-02, -7.1906e-02],
[ 1.0526e-02, 1.2171e-02, 9.3159e-02, ..., -3.4628e-01,
6.6404e-02, 7.0177e-02],
...,
[-3.6881e-02, 7.5082e-02, -1.3295e-01, ..., 5.8651e-01,
1.0566e-01, -1.2248e-01],
[ 2.1673e-02, -5.9233e-02, 2.3098e-01, ..., -2.1465e-01,
2.1108e-01, -1.8771e-02],
[ 6.6413e-03, -5.1895e-02, -1.2541e-02, ..., -8.8446e-02,
-1.8269e-02, 3.2724e-02]],
[[-5.7876e-03, -6.5713e-03, 9.7605e-03, ..., -2.4045e-02,
-4.0857e-02, 3.7269e-02],
[ 1.7980e-02, -2.9323e-02, 9.4012e-03, ..., 3.5879e-02,
7.4789e-02, -3.7304e-02],
[ 1.8119e-02, 2.5782e-03, 9.7156e-02, ..., -4.4105e-01,
1.4873e-02, 1.0465e-01],
...,
[-4.4777e-02, 8.7919e-02, -2.3434e-01, ..., 7.4114e-01,
1.5556e-01, -1.4555e-01],
[ 3.1682e-02, -1.3436e-02, 2.4753e-01, ..., -2.5792e-01,
2.6018e-01, -1.2641e-02],
[ 2.1700e-02, -8.5579e-04, 2.5962e-02, ..., -1.0650e-01,
-2.9765e-02, 1.8971e-03]],
[[-5.0954e-03, 1.0142e-02, 2.8154e-03, ..., -2.2294e-03,
-3.5401e-02, 1.8325e-02],
[ 6.3929e-03, -1.6070e-02, -1.0550e-02, ..., 9.2290e-02,
6.2968e-02, -7.0256e-02],
[-3.6175e-03, -6.6921e-03, 5.0055e-02, ..., -2.8415e-01,
1.3017e-01, 7.0130e-02],
...,
[-3.1363e-02, 7.8047e-02, -7.0873e-02, ..., 4.2964e-01,
3.5380e-02, -8.7354e-02],
[ 3.0397e-02, -5.8957e-02, 2.3035e-01, ..., -2.0295e-01,
1.3268e-01, -1.0068e-02],
[ 8.1602e-03, -2.7920e-02, -3.8539e-02, ..., -2.7914e-02,
-2.1438e-02, 8.2988e-03]]],
[[[ 5.9179e-03, 1.1450e-02, -6.1776e-02, ..., -4.6299e-02,
6.7527e-02, 9.8275e-03],
[ 8.8669e-03, 7.3341e-02, -3.2901e-02, ..., -1.0453e-01,
1.7328e-01, 4.7307e-02],
[-1.1470e-02, 1.2149e-01, 2.7013e-03, ..., -1.5456e-01,
2.4886e-01, 3.5192e-02],
...,
[ 1.2814e-02, 1.1176e-01, -1.3131e-02, ..., -8.0056e-02,
2.6456e-01, 4.8966e-02],
[ 9.1849e-03, 8.3514e-02, -3.2001e-02, ..., -4.2758e-02,
1.8086e-01, 2.9671e-02],
[-3.6703e-03, 8.5900e-03, -7.5705e-02, ..., -3.2137e-02,
1.1875e-01, 1.3688e-02]],
[[-3.1539e-03, 3.1529e-02, -2.7158e-02, ..., -7.4867e-02,
4.2428e-02, 1.5711e-02],
[ 2.8099e-02, 1.2275e-01, -5.9815e-03, ..., -1.7016e-01,
1.5623e-01, 7.8585e-02],
[ 2.5077e-02, 2.0189e-01, 3.0908e-02, ..., -2.2413e-01,
2.7603e-01, 9.4537e-02],
...,
[ 3.5136e-02, 1.7748e-01, -1.4339e-02, ..., -1.4229e-01,
2.9657e-01, 1.0562e-01],
[ 1.0093e-02, 9.9101e-02, -4.6115e-02, ..., -8.1636e-02,
1.8805e-01, 6.2061e-02],
[ 8.4642e-03, 3.8329e-02, -5.6497e-02, ..., -5.7700e-02,
1.1272e-01, 4.3174e-02]],
[[-1.4282e-04, -2.3389e-03, -1.9654e-02, ..., -2.2358e-02,
1.9954e-02, -3.2502e-02],
[-2.6001e-03, 5.5482e-02, 1.5250e-02, ..., -8.4497e-02,
6.9498e-02, -1.2412e-02],
[-2.2391e-02, 9.0480e-02, 6.0876e-02, ..., -1.0753e-01,
1.3897e-01, -4.1540e-02],
...,
[-1.2348e-02, 8.0637e-02, 5.6340e-02, ..., -4.5426e-02,
1.3194e-01, -4.2526e-02],
[-1.7389e-02, 4.7423e-02, 1.9719e-02, ..., -2.9090e-02,
8.3996e-02, -2.8975e-02],
[-8.7350e-03, 2.0953e-02, -4.2252e-03, ..., -2.5390e-02,
4.4542e-02, -3.0590e-02]]],
...,
[[[ 2.1164e-02, 1.6922e-02, -2.4637e-02, ..., -3.5926e-03,
-4.5143e-03, -4.7802e-03],
[ 1.5557e-02, 4.1134e-02, 9.2594e-03, ..., 6.8536e-02,
-3.3796e-02, -1.2293e-01],
[-2.9705e-02, 1.1398e-02, -1.8864e-02, ..., 1.9850e-01,
1.9440e-01, -1.2550e-01],
...,
[ 1.4001e-02, -4.0922e-02, 2.5370e-01, ..., -4.6508e-01,
-1.2693e-01, 1.3241e-01],
[-5.6044e-03, -1.0473e-01, -1.6365e-02, ..., 4.1196e-02,
-1.8936e-01, 2.9299e-03],
[ 1.6253e-02, -5.6961e-02, -1.7359e-01, ..., 1.4994e-01,
-4.3681e-03, 3.1854e-02]],
[[-9.8924e-03, 5.4293e-03, -2.6138e-02, ..., 2.9590e-02,
7.4812e-03, -2.1530e-02],
[-1.5667e-02, 1.3143e-02, -1.7931e-02, ..., 1.3352e-01,
4.8710e-02, -5.9767e-02],
[-4.3947e-02, -1.8217e-02, -9.7404e-02, ..., 2.1595e-01,
3.1335e-01, -1.8402e-03],
...,
[ 6.3135e-02, 2.6133e-02, 3.3600e-01, ..., -7.0122e-01,
-2.6651e-01, 1.6588e-01],
[ 7.3457e-03, -5.4677e-02, 8.6192e-02, ..., -1.7732e-02,
-3.2929e-01, -5.9486e-02],
[ 3.2153e-03, -3.1568e-02, -9.5821e-02, ..., 1.8495e-01,
-7.3682e-02, -3.6894e-02]],
[[-5.5842e-03, -2.7447e-03, -2.5824e-02, ..., 9.0967e-04,
1.8520e-02, -2.8618e-03],
[-1.3200e-02, 5.1415e-03, -1.6833e-02, ..., 3.9462e-02,
1.5885e-03, -5.6381e-02],
[-2.9015e-02, -5.7217e-03, -3.4202e-02, ..., 1.6668e-01,
1.6170e-01, -6.7241e-02],
...,
[ 3.5389e-02, -2.1586e-02, 1.8604e-01, ..., -4.1291e-01,
-1.1603e-01, 1.0927e-01],
[ 2.3013e-03, -4.1210e-02, -1.8644e-02, ..., 2.2626e-02,
-2.0080e-01, -3.4115e-02],
[ 1.9892e-02, -1.5134e-03, -1.0327e-01, ..., 1.1701e-01,
-3.9063e-02, -1.6900e-02]]],
[[[-5.5538e-08, -2.6325e-08, -6.9138e-10, ..., 9.7495e-09,
1.3223e-08, 1.8906e-08],
[-2.6514e-08, -1.5337e-08, -1.8666e-08, ..., -2.9065e-08,
5.2469e-09, 1.9422e-08],
[-5.8648e-09, -1.9407e-08, -1.7830e-08, ..., -4.9517e-08,
-4.3926e-08, 5.7348e-09],
...,
[-7.6421e-08, -8.3873e-08, -6.7820e-08, ..., -3.4331e-08,
-2.4019e-08, 1.3504e-09],
[-5.3446e-08, -6.0856e-08, -5.9789e-08, ..., 7.6961e-09,
-7.3536e-09, 9.9825e-09],
[-7.5073e-08, -5.5320e-08, -4.5796e-08, ..., 9.2459e-09,
-3.1809e-09, 8.4687e-09]],
[[ 1.1481e-08, 4.6150e-08, 7.6648e-08, ..., 8.4643e-08,
8.0418e-08, 7.5105e-08],
[ 3.4774e-08, 5.7570e-08, 6.4259e-08, ..., 5.5634e-08,
8.3158e-08, 8.7862e-08],
[ 5.0421e-08, 3.9131e-08, 4.5809e-08, ..., 3.1533e-08,
4.0924e-08, 7.2495e-08],
...,
[-2.5365e-08, -2.9820e-08, -1.2614e-08, ..., 4.6912e-08,
5.6400e-08, 7.3814e-08],
[-2.6964e-09, -3.6669e-10, 6.5934e-10, ..., 8.3020e-08,
6.4033e-08, 7.0015e-08],
[-1.9816e-08, 4.0452e-09, 9.7432e-09, ..., 7.3291e-08,
4.6672e-08, 6.1159e-08]],
[[-2.2916e-08, 9.8197e-09, 4.0514e-08, ..., 4.4516e-08,
3.9299e-08, 3.6953e-08],
[ 1.9321e-08, 2.4614e-08, 2.8277e-08, ..., 2.2826e-08,
4.8924e-08, 4.8514e-08],
[ 6.3043e-08, 3.3363e-08, 3.0949e-08, ..., 9.2149e-09,
1.1062e-08, 3.2589e-08],
...,
[-1.7967e-08, -2.9080e-08, -7.1967e-10, ..., 3.7657e-08,
3.3544e-08, 2.4920e-08],
[ 5.6783e-09, 3.1829e-09, 5.8844e-09, ..., 6.8484e-08,
2.9343e-08, 1.6371e-08],
[-2.8144e-09, 1.3160e-08, 1.4706e-08, ..., 4.7713e-08,
1.7250e-08, 2.1211e-08]]],
[[[-6.7577e-03, -2.2072e-02, -1.6120e-02, ..., 1.9965e-02,
-2.3618e-02, -4.0877e-02],
[-1.0744e-02, 1.0183e-04, 4.3173e-03, ..., 3.2774e-02,
8.8146e-03, -1.4997e-02],
[-1.8087e-02, -5.0862e-03, 1.9248e-02, ..., 7.9772e-02,
5.7868e-02, 4.2854e-02],
...,
[ 1.0578e-02, 4.2848e-02, 7.4620e-02, ..., 1.1236e-01,
1.0815e-01, 1.0924e-01],
[-3.2028e-02, 3.7799e-03, 5.1643e-02, ..., 9.1239e-02,
6.8991e-02, 4.9936e-02],
[-5.3535e-02, -1.3198e-02, 1.6904e-02, ..., 8.6864e-02,
4.1772e-02, 3.9772e-02]],
[[-2.0383e-02, -3.3167e-02, -2.6665e-02, ..., 3.5744e-04,
-2.6331e-02, -4.0418e-02],
[-1.3169e-02, 1.9657e-03, 1.1509e-02, ..., 3.5650e-02,
1.7746e-02, -9.1710e-03],
[-2.7688e-02, 5.5393e-03, 3.6324e-02, ..., 7.0741e-02,
5.0612e-02, 2.8343e-02],
...,
[ 1.1707e-02, 4.1271e-02, 7.3833e-02, ..., 9.3581e-02,
1.0225e-01, 9.0626e-02],
[-1.0045e-02, 1.7943e-02, 5.3709e-02, ..., 8.3323e-02,
7.8453e-02, 4.8212e-02],
[-2.6048e-02, -1.3092e-04, 1.8727e-02, ..., 6.5370e-02,
2.8696e-02, 2.4652e-02]],
[[ 3.3157e-02, 1.2283e-03, 1.3401e-03, ..., 2.2791e-02,
-2.8388e-02, -5.0342e-02],
[ 2.1755e-02, 8.9925e-03, 6.1698e-03, ..., 1.7188e-02,
-1.5088e-02, -4.6674e-02],
[-4.7259e-03, -3.3812e-03, 5.7801e-03, ..., 2.7695e-02,
1.3164e-02, -1.7174e-02],
...,
[ 2.3359e-03, -1.1652e-02, -4.5875e-03, ..., -8.4557e-03,
1.5695e-02, -2.3720e-04],
[-6.0332e-03, -2.9486e-02, -9.2786e-03, ..., 1.7190e-02,
7.1863e-03, -2.3858e-02],
[-6.8612e-03, -3.4898e-02, -3.2312e-02, ..., -1.0367e-02,
-3.9500e-02, -3.4689e-02]]]])