Loading a pretrained model in PyTorch and fine-tuning it: ResNet-101 as an example

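When training networks in PyTorch, transfer learning is often the practical choice: take a model that has already been trained (ResNet, Inception, and so on), freeze the parameters of its trained layers, and then fine-tune the rest.
The code below uses ResNet-101 as an example of fine-tuning in PyTorch: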

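# Import the required modules
import torch
import torch.nn as nn
from torchvision import models

# Load the ResNet-101 model that ships with torchvision. Because pretrained=True is set,
# the pretrained weights are downloaded automatically on first use.
model = models.resnet101(pretrained=True)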

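# Turn off gradient computation for every weight in the model, i.e. freeze the parameters
for param in model.parameters():
    param.requires_grad = False
# With everything frozen the model cannot learn anything, so unfreeze the last layer,
# the fully connected layer fc
for param in model.fc.parameters():
    param.requires_grad = True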

To change the last layer, replace it like this:

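class_num = 200  # assume there are 200 target classes
channel_in = model.fc.in_features  # number of input features of the fc layer
# Replace the fc layer of ResNet-101 with a fc layer that outputs 200 classes
model.fc = nn.Linear(channel_in, class_num)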

You can also remove the last layer, or the last few layers:

# [:-1] drops the last layer
new_model = nn.Sequential(*list(model.children())[:-1])
# or drop the last two layers
new_model = nn.Sequential(*list(model.children())[:-2])

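Of course, after removing the last few layers you will probably need to add new ones. You can put the code above into a new network class and add the new layers in its forward method. This is fairly basic, so it is not covered in detail here.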

At this point, with the PyTorch version used for this post, creating the optimizer the usual way raises an error. For example:

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

This raises ValueError: optimizing a parameter that doesn't require gradients

Traceback (most recent call last):
  File "main.py", line 1, in <module>
    main()
  File "main.py", line 20, in main
    optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
  File "C:\Anaconda3\lib\site-packages\torch\optim\sgd.py", line 64, in __init__
    super(SGD, self).__init__(params, defaults)
  File "C:\Anaconda3\lib\site-packages\torch\optim\optimizer.py", line 43, in __init__
    self.add_param_group(param_group)
  File "C:\Anaconda3\lib\site-packages\torch\optim\optimizer.py", line 193, in add_param_group
    raise ValueError("optimizing a parameter that doesn't require gradients")
ValueError: optimizing a parameter that doesn't require gradients

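The reason is that every parameter passed to the optimizer must be trainable, i.e. have requires_grad=True, but we have just frozen all parameters except those of the last layer.
The fix is to pass the optimizer only the parameters that require gradients: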

# filter() drops the parameters whose requires_grad is False
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),  # this line is the important one
    lr=0.1)

Training can then proceed as usual.

As an aside, the following code prints the configuration of each layer in the model together with its weight values:

for child in model.children():
    print(child)  # print the layer configuration
    for param in child.parameters():  # print the weight values
        print(param)

Partial output (the first convolutional layer):

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Parameter containing:
tensor([[[[ 2.0222e-02, -4.3927e-03, -1.8274e-02,  ..., -1.5180e-02,
           -1.5794e-03,  9.3115e-03],
          [-4.0752e-03,  3.2116e-03, -1.5956e-02,  ..., -8.4465e-02,
           -7.4997e-02, -4.0676e-02],
          [ 3.5039e-03,  2.6746e-02,  5.0813e-02,  ...,  3.3407e-02,
            1.3659e-02,  2.7821e-02],
          ...,
          [-3.6174e-02, -1.2986e-01, -3.0369e-01,  ..., -3.7412e-01,
           -1.3025e-01,  4.2633e-02],
          [ 1.5479e-02,  2.3444e-02,  6.5222e-03,  ..., -1.6439e-01,
           -1.8245e-01, -9.7434e-02],
          [-3.0444e-02, -1.1357e-02,  4.9984e-02,  ...,  1.6412e-01,
            1.0419e-01, -1.2681e-02]],

         [[ 8.7115e-03, -5.8911e-03, -1.2204e-02,  ..., -1.3515e-02,
            1.5212e-02,  1.9115e-02],
          [-6.8970e-03,  1.0470e-02, -7.7561e-03,  ..., -7.9215e-02,
           -5.9150e-02, -2.1380e-02],
          [-2.4955e-03,  3.2179e-02,  7.6542e-02,  ...,  8.9056e-02,
            5.1445e-02,  3.4868e-02],
          ...,
          [-6.3695e-02, -1.9335e-01, -4.2540e-01,  ..., -5.4060e-01,
           -2.1634e-01,  8.1145e-03],
          [-9.9660e-03,  9.2712e-04, -4.3920e-02,  ..., -2.9635e-01,
           -2.8675e-01, -1.7962e-01],
          [ 2.5959e-02,  7.8301e-02,  1.5091e-01,  ...,  2.7260e-01,
            2.0840e-01,  4.2282e-02]],

         [[ 4.0026e-03, -9.1936e-03, -2.0569e-02,  ...,  1.0229e-03,
            1.5326e-02,  9.7999e-04],
          [ 1.7419e-02,  3.0881e-02,  6.3634e-03,  ..., -1.2466e-02,
            7.4009e-03,  2.0044e-02],
          [-1.3397e-02, -5.1027e-03, -3.2198e-02,  ..., -5.9151e-02,
           -3.1463e-02, -9.3953e-03],
          ...,
          [-2.4370e-02, -8.1988e-02, -1.8747e-01,  ..., -2.2849e-01,
           -4.5930e-02,  5.7061e-02],
          [-4.6034e-03,  1.7281e-02,  2.0939e-02,  ..., -7.0717e-02,
           -1.0450e-01, -8.8036e-02],
          [-6.7201e-04,  1.1773e-02,  3.1304e-02,  ...,  1.1040e-01,
            9.6085e-02, -5.0325e-03]]],


        [[[-1.3283e-02, -6.3476e-03,  1.7371e-02,  ..., -2.6374e-02,
           -3.9297e-02,  3.6769e-02],
          [ 1.6402e-02, -9.2402e-03,  6.0237e-03,  ...,  4.5398e-02,
            6.1903e-02, -7.1906e-02],
          [ 1.0526e-02,  1.2171e-02,  9.3159e-02,  ..., -3.4628e-01,
            6.6404e-02,  7.0177e-02],
          ...,
          [-3.6881e-02,  7.5082e-02, -1.3295e-01,  ...,  5.8651e-01,
            1.0566e-01, -1.2248e-01],
          [ 2.1673e-02, -5.9233e-02,  2.3098e-01,  ..., -2.1465e-01,
            2.1108e-01, -1.8771e-02],
          [ 6.6413e-03, -5.1895e-02, -1.2541e-02,  ..., -8.8446e-02,
           -1.8269e-02,  3.2724e-02]],

         [[-5.7876e-03, -6.5713e-03,  9.7605e-03,  ..., -2.4045e-02,
           -4.0857e-02,  3.7269e-02],
          [ 1.7980e-02, -2.9323e-02,  9.4012e-03,  ...,  3.5879e-02,
            7.4789e-02, -3.7304e-02],
          [ 1.8119e-02,  2.5782e-03,  9.7156e-02,  ..., -4.4105e-01,
            1.4873e-02,  1.0465e-01],
          ...,
          [-4.4777e-02,  8.7919e-02, -2.3434e-01,  ...,  7.4114e-01,
            1.5556e-01, -1.4555e-01],
          [ 3.1682e-02, -1.3436e-02,  2.4753e-01,  ..., -2.5792e-01,
            2.6018e-01, -1.2641e-02],
          [ 2.1700e-02, -8.5579e-04,  2.5962e-02,  ..., -1.0650e-01,
           -2.9765e-02,  1.8971e-03]],

         [[-5.0954e-03,  1.0142e-02,  2.8154e-03,  ..., -2.2294e-03,
           -3.5401e-02,  1.8325e-02],
          [ 6.3929e-03, -1.6070e-02, -1.0550e-02,  ...,  9.2290e-02,
            6.2968e-02, -7.0256e-02],
          [-3.6175e-03, -6.6921e-03,  5.0055e-02,  ..., -2.8415e-01,
            1.3017e-01,  7.0130e-02],
          ...,
          [-3.1363e-02,  7.8047e-02, -7.0873e-02,  ...,  4.2964e-01,
            3.5380e-02, -8.7354e-02],
          [ 3.0397e-02, -5.8957e-02,  2.3035e-01,  ..., -2.0295e-01,
            1.3268e-01, -1.0068e-02],
          [ 8.1602e-03, -2.7920e-02, -3.8539e-02,  ..., -2.7914e-02,
           -2.1438e-02,  8.2988e-03]]],


        [[[ 5.9179e-03,  1.1450e-02, -6.1776e-02,  ..., -4.6299e-02,
            6.7527e-02,  9.8275e-03],
          [ 8.8669e-03,  7.3341e-02, -3.2901e-02,  ..., -1.0453e-01,
            1.7328e-01,  4.7307e-02],
          [-1.1470e-02,  1.2149e-01,  2.7013e-03,  ..., -1.5456e-01,
            2.4886e-01,  3.5192e-02],
          ...,
          [ 1.2814e-02,  1.1176e-01, -1.3131e-02,  ..., -8.0056e-02,
            2.6456e-01,  4.8966e-02],
          [ 9.1849e-03,  8.3514e-02, -3.2001e-02,  ..., -4.2758e-02,
            1.8086e-01,  2.9671e-02],
          [-3.6703e-03,  8.5900e-03, -7.5705e-02,  ..., -3.2137e-02,
            1.1875e-01,  1.3688e-02]],

         [[-3.1539e-03,  3.1529e-02, -2.7158e-02,  ..., -7.4867e-02,
            4.2428e-02,  1.5711e-02],
          [ 2.8099e-02,  1.2275e-01, -5.9815e-03,  ..., -1.7016e-01,
            1.5623e-01,  7.8585e-02],
          [ 2.5077e-02,  2.0189e-01,  3.0908e-02,  ..., -2.2413e-01,
            2.7603e-01,  9.4537e-02],
          ...,
          [ 3.5136e-02,  1.7748e-01, -1.4339e-02,  ..., -1.4229e-01,
            2.9657e-01,  1.0562e-01],
          [ 1.0093e-02,  9.9101e-02, -4.6115e-02,  ..., -8.1636e-02,
            1.8805e-01,  6.2061e-02],
          [ 8.4642e-03,  3.8329e-02, -5.6497e-02,  ..., -5.7700e-02,
            1.1272e-01,  4.3174e-02]],

         [[-1.4282e-04, -2.3389e-03, -1.9654e-02,  ..., -2.2358e-02,
            1.9954e-02, -3.2502e-02],
          [-2.6001e-03,  5.5482e-02,  1.5250e-02,  ..., -8.4497e-02,
            6.9498e-02, -1.2412e-02],
          [-2.2391e-02,  9.0480e-02,  6.0876e-02,  ..., -1.0753e-01,
            1.3897e-01, -4.1540e-02],
          ...,
          [-1.2348e-02,  8.0637e-02,  5.6340e-02,  ..., -4.5426e-02,
            1.3194e-01, -4.2526e-02],
          [-1.7389e-02,  4.7423e-02,  1.9719e-02,  ..., -2.9090e-02,
            8.3996e-02, -2.8975e-02],
          [-8.7350e-03,  2.0953e-02, -4.2252e-03,  ..., -2.5390e-02,
            4.4542e-02, -3.0590e-02]]],


        ...,


        [[[ 2.1164e-02,  1.6922e-02, -2.4637e-02,  ..., -3.5926e-03,
           -4.5143e-03, -4.7802e-03],
          [ 1.5557e-02,  4.1134e-02,  9.2594e-03,  ...,  6.8536e-02,
           -3.3796e-02, -1.2293e-01],
          [-2.9705e-02,  1.1398e-02, -1.8864e-02,  ...,  1.9850e-01,
            1.9440e-01, -1.2550e-01],
          ...,
          [ 1.4001e-02, -4.0922e-02,  2.5370e-01,  ..., -4.6508e-01,
           -1.2693e-01,  1.3241e-01],
          [-5.6044e-03, -1.0473e-01, -1.6365e-02,  ...,  4.1196e-02,
           -1.8936e-01,  2.9299e-03],
          [ 1.6253e-02, -5.6961e-02, -1.7359e-01,  ...,  1.4994e-01,
           -4.3681e-03,  3.1854e-02]],

         [[-9.8924e-03,  5.4293e-03, -2.6138e-02,  ...,  2.9590e-02,
            7.4812e-03, -2.1530e-02],
          [-1.5667e-02,  1.3143e-02, -1.7931e-02,  ...,  1.3352e-01,
            4.8710e-02, -5.9767e-02],
          [-4.3947e-02, -1.8217e-02, -9.7404e-02,  ...,  2.1595e-01,
            3.1335e-01, -1.8402e-03],
          ...,
          [ 6.3135e-02,  2.6133e-02,  3.3600e-01,  ..., -7.0122e-01,
           -2.6651e-01,  1.6588e-01],
          [ 7.3457e-03, -5.4677e-02,  8.6192e-02,  ..., -1.7732e-02,
           -3.2929e-01, -5.9486e-02],
          [ 3.2153e-03, -3.1568e-02, -9.5821e-02,  ...,  1.8495e-01,
           -7.3682e-02, -3.6894e-02]],

         [[-5.5842e-03, -2.7447e-03, -2.5824e-02,  ...,  9.0967e-04,
            1.8520e-02, -2.8618e-03],
          [-1.3200e-02,  5.1415e-03, -1.6833e-02,  ...,  3.9462e-02,
            1.5885e-03, -5.6381e-02],
          [-2.9015e-02, -5.7217e-03, -3.4202e-02,  ...,  1.6668e-01,
            1.6170e-01, -6.7241e-02],
          ...,
          [ 3.5389e-02, -2.1586e-02,  1.8604e-01,  ..., -4.1291e-01,
           -1.1603e-01,  1.0927e-01],
          [ 2.3013e-03, -4.1210e-02, -1.8644e-02,  ...,  2.2626e-02,
           -2.0080e-01, -3.4115e-02],
          [ 1.9892e-02, -1.5134e-03, -1.0327e-01,  ...,  1.1701e-01,
           -3.9063e-02, -1.6900e-02]]],


        [[[-5.5538e-08, -2.6325e-08, -6.9138e-10,  ...,  9.7495e-09,
            1.3223e-08,  1.8906e-08],
          [-2.6514e-08, -1.5337e-08, -1.8666e-08,  ..., -2.9065e-08,
            5.2469e-09,  1.9422e-08],
          [-5.8648e-09, -1.9407e-08, -1.7830e-08,  ..., -4.9517e-08,
           -4.3926e-08,  5.7348e-09],
          ...,
          [-7.6421e-08, -8.3873e-08, -6.7820e-08,  ..., -3.4331e-08,
           -2.4019e-08,  1.3504e-09],
          [-5.3446e-08, -6.0856e-08, -5.9789e-08,  ...,  7.6961e-09,
           -7.3536e-09,  9.9825e-09],
          [-7.5073e-08, -5.5320e-08, -4.5796e-08,  ...,  9.2459e-09,
           -3.1809e-09,  8.4687e-09]],

         [[ 1.1481e-08,  4.6150e-08,  7.6648e-08,  ...,  8.4643e-08,
            8.0418e-08,  7.5105e-08],
          [ 3.4774e-08,  5.7570e-08,  6.4259e-08,  ...,  5.5634e-08,
            8.3158e-08,  8.7862e-08],
          [ 5.0421e-08,  3.9131e-08,  4.5809e-08,  ...,  3.1533e-08,
            4.0924e-08,  7.2495e-08],
          ...,
          [-2.5365e-08, -2.9820e-08, -1.2614e-08,  ...,  4.6912e-08,
            5.6400e-08,  7.3814e-08],
          [-2.6964e-09, -3.6669e-10,  6.5934e-10,  ...,  8.3020e-08,
            6.4033e-08,  7.0015e-08],
          [-1.9816e-08,  4.0452e-09,  9.7432e-09,  ...,  7.3291e-08,
            4.6672e-08,  6.1159e-08]],

         [[-2.2916e-08,  9.8197e-09,  4.0514e-08,  ...,  4.4516e-08,
            3.9299e-08,  3.6953e-08],
          [ 1.9321e-08,  2.4614e-08,  2.8277e-08,  ...,  2.2826e-08,
            4.8924e-08,  4.8514e-08],
          [ 6.3043e-08,  3.3363e-08,  3.0949e-08,  ...,  9.2149e-09,
            1.1062e-08,  3.2589e-08],
          ...,
          [-1.7967e-08, -2.9080e-08, -7.1967e-10,  ...,  3.7657e-08,
            3.3544e-08,  2.4920e-08],
          [ 5.6783e-09,  3.1829e-09,  5.8844e-09,  ...,  6.8484e-08,
            2.9343e-08,  1.6371e-08],
          [-2.8144e-09,  1.3160e-08,  1.4706e-08,  ...,  4.7713e-08,
            1.7250e-08,  2.1211e-08]]],


        [[[-6.7577e-03, -2.2072e-02, -1.6120e-02,  ...,  1.9965e-02,
           -2.3618e-02, -4.0877e-02],
          [-1.0744e-02,  1.0183e-04,  4.3173e-03,  ...,  3.2774e-02,
            8.8146e-03, -1.4997e-02],
          [-1.8087e-02, -5.0862e-03,  1.9248e-02,  ...,  7.9772e-02,
            5.7868e-02,  4.2854e-02],
          ...,
          [ 1.0578e-02,  4.2848e-02,  7.4620e-02,  ...,  1.1236e-01,
            1.0815e-01,  1.0924e-01],
          [-3.2028e-02,  3.7799e-03,  5.1643e-02,  ...,  9.1239e-02,
            6.8991e-02,  4.9936e-02],
          [-5.3535e-02, -1.3198e-02,  1.6904e-02,  ...,  8.6864e-02,
            4.1772e-02,  3.9772e-02]],

         [[-2.0383e-02, -3.3167e-02, -2.6665e-02,  ...,  3.5744e-04,
           -2.6331e-02, -4.0418e-02],
          [-1.3169e-02,  1.9657e-03,  1.1509e-02,  ...,  3.5650e-02,
            1.7746e-02, -9.1710e-03],
          [-2.7688e-02,  5.5393e-03,  3.6324e-02,  ...,  7.0741e-02,
            5.0612e-02,  2.8343e-02],
          ...,
          [ 1.1707e-02,  4.1271e-02,  7.3833e-02,  ...,  9.3581e-02,
            1.0225e-01,  9.0626e-02],
          [-1.0045e-02,  1.7943e-02,  5.3709e-02,  ...,  8.3323e-02,
            7.8453e-02,  4.8212e-02],
          [-2.6048e-02, -1.3092e-04,  1.8727e-02,  ...,  6.5370e-02,
            2.8696e-02,  2.4652e-02]],

         [[ 3.3157e-02,  1.2283e-03,  1.3401e-03,  ...,  2.2791e-02,
           -2.8388e-02, -5.0342e-02],
          [ 2.1755e-02,  8.9925e-03,  6.1698e-03,  ...,  1.7188e-02,
           -1.5088e-02, -4.6674e-02],
          [-4.7259e-03, -3.3812e-03,  5.7801e-03,  ...,  2.7695e-02,
            1.3164e-02, -1.7174e-02],
          ...,
          [ 2.3359e-03, -1.1652e-02, -4.5875e-03,  ..., -8.4557e-03,
            1.5695e-02, -2.3720e-04],
          [-6.0332e-03, -2.9486e-02, -9.2786e-03,  ...,  1.7190e-02,
            7.1863e-03, -2.3858e-02],
          [-6.8612e-03, -3.4898e-02, -3.2312e-02,  ..., -1.0367e-02,
           -3.9500e-02, -3.4689e-02]]]])
