PyTorch in practice: replacing the ResNet50 fully connected layer and retraining it for binary classification

Background (a bit rambling, feel free to skip):

The original goal was the same as in the SAN paper: treat the four image styles as four labels and, starting from the pretrained weights, retrain the whole ResNet50 (the paper uses ResNet101) for two epochs. The figure below shows the ResNet architecture.

[Figure 1: ResNet network architecture]

I replaced the final fully connected layer and retrained from scratch, but something was clearly wrong: as the excerpt of the training log below shows, the per-batch accuracy was almost always either 100 or 0.

Training dir : D:\postgraduate\major_related\face_related\datasets\SAN-300W-test\300W-Convert
epoch : [0/50] lr=0.01
after 1 epoch, prec:57.14285714285714%,loss:1.261704683303833,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 1 epoch, prec:100.0%,loss:0.0,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 1 epoch, prec:0.0%,loss:23.00307846069336,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 1 epoch, prec:0.0%,loss:15.251273155212402,input:torch.Size([6, 3, 224, 224]),output:torch.Size([6, 4])
epoch : [1/50] lr=0.009978950082958632
after 2 epoch, prec:0.0%,loss:32.996482849121094,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 2 epoch, prec:0.0%,loss:20.09764289855957,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 2 epoch, prec:0.0%,loss:1.9084116220474243,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 2 epoch, prec:0.0%,loss:2.1563243865966797,input:torch.Size([6, 3, 224, 224]),output:torch.Size([6, 4])
epoch : [2/50] lr=0.00995794447581801
after 3 epoch, prec:0.0%,loss:6.454265594482422,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 3 epoch, prec:0.0%,loss:2.507550001144409,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 3 epoch, prec:0.0%,loss:2.3660380840301514,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 3 epoch, prec:0.0%,loss:2.6066017150878906,input:torch.Size([6, 3, 224, 224]),output:torch.Size([6, 4])
epoch : [3/50] lr=0.009936983085306158
after 4 epoch, prec:0.0%,loss:1.5964949131011963,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 4 epoch, prec:0.0%,loss:1.3409065008163452,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 4 epoch, prec:0.0%,loss:2.040916681289673,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 4 epoch, prec:0.0%,loss:1.7939444780349731,input:torch.Size([6, 3, 224, 224]),output:torch.Size([6, 4])
epoch : [4/50] lr=0.00991606581834744
after 5 epoch, prec:14.285714285714285%,loss:1.5786248445510864,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 5 epoch, prec:0.0%,loss:2.005967855453491,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 5 epoch, prec:0.0%,loss:1.800467848777771,input:torch.Size([7, 3, 224, 224]),output:torch.Size([7, 4])
after 5 epoch, prec:0.0%,loss:1.7157955169677734,input:torch.Size([6, 3, 224, 224]),output:torch.Size([6, 4])

The code that replaces the fully connected layer looks like this. The first line reads the number of input features of the original fc layer, and the original fully connected layer is then replaced by the module below (ResNet50's original output layer maps 2048 features to 1000 classes; after the change it is 2048→256 followed by 256→2, so the output has two classes, with a ReLU and a Dropout layer in between). You could also simply replace it with nn.Linear(channel_in, class_num); a sketch of that simpler variant follows the code block.

channel_in = resnet.fc.in_features
class_num = 2
resnet.fc = nn.Sequential(
    nn.Linear(channel_in, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, class_num),
    nn.LogSoftmax(dim=1)
)
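
For reference, a minimal sketch of the simpler variant mentioned above: replacing the head with a single linear layer. In that case the head outputs raw logits, so you would typically train it with nn.CrossEntropyLoss instead of NLLLoss (that pairing is my note, not part of the original setup).

import torch.nn as nn
from torchvision import models

resnet = models.resnet50(pretrained=True)
channel_in = resnet.fc.in_features  # 2048 for ResNet50
class_num = 2
# single linear head producing raw logits for the two classes
resnet.fc = nn.Linear(channel_in, class_num)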

If you want to modify a particular layer, you need to know its name, which you can find by printing the network structure:

# print the network structure
for child in net.children():
    print(child)
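
A slightly more useful sketch (assuming net is the ResNet50 instance from the snippet above) prints each top-level layer together with its attribute name, which is the name you need when you want to touch a specific layer such as fc or layer4:

# print every top-level module together with its attribute name (conv1, bn1, layer1, ..., fc)
for name, child in net.named_children():
    print(name, '->', child.__class__.__name__)

# the printed name can then be used to access the layer directly, e.g.
print(net.fc)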

Because making all the weights trainable made training far too slow, I decided to turn this into an experiment on adapting a classic network to classify my own dataset.

If you want every weight to be updated by gradient descent, rather than only the fc layer, change the False in the line below to True:

for param in net.parameters():
    param.requires_grad = False

The same mechanism can be used to freeze specific layers; see the sketch below.
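
A minimal sketch of the freeze/unfreeze pattern used in the full code further down, assuming net is the modified ResNet50. The filtered optimizer in the last two lines is my addition and is optional, since parameters without gradients are simply never updated:

import torch.optim as optim

# freeze every parameter of the backbone ...
for param in net.parameters():
    param.requires_grad = False

# ... then unfreeze only the new fc head
for param in net.fc.parameters():
    param.requires_grad = True

# optionally hand only the trainable parameters to the optimizer
optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
                      lr=0.005, momentum=0.9)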

At first I was still using the SAN dataset and classifying the four styles.

I trained for only two epochs and the results were poor: during training the accuracy kept jumping wildly between 0 and 100. I suspected the learning rate was too large, so I switched to a smaller one combined with exponential decay (a sketch follows).
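
The full code below implements the decay by hand in adjust_learning_rate; a roughly equivalent sketch using PyTorch's built-in scheduler looks like this (it assumes net is the modified network, and the gamma value is an illustrative assumption, not taken from the original code):

import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR

optimizer = optim.SGD(net.parameters(), lr=0.005, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.99)  # multiply the lr by 0.99 after every epoch

for epoch in range(100):
    # ... run one training epoch here (forward, backward, optimizer.step()) ...
    scheduler.step()  # decay the learning rate once per epoch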

The loss and accuracy curves during training looked like this (one point on the horizontal axis per 10 batches):

[Figure 2: loss and accuracy curves while training on the SAN data]

The trained model scored 0% accuracy across the whole test set, and because the dataset is large I could not afford many repeated runs.

So I switched to a bees-vs-ants classification dataset found online (it is small, so running many experiments is feasible).

Dataset download: bees/ants binary classification dataset (the hymenoptera_data set).
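
loadTrainData and loadTestData in the full code below read this dataset with torchvision's ImageFolder, which derives the labels from the directory layout and assigns class indices alphabetically. A quick sketch to check the mapping (assuming train_path points at the train folder as in the full code; the ants/bees folder names are an assumption based on the standard hymenoptera_data layout):

import torchvision

# expected layout: hymenoptera_data/train/ants/*.jpg and hymenoptera_data/train/bees/*.jpg
trainset = torchvision.datasets.ImageFolder(train_path)
print(trainset.classes)       # e.g. ['ants', 'bees']
print(trainset.class_to_idx)  # e.g. {'ants': 0, 'bees': 1}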

I first set the learning rate to 0.001 and the decay rate to 0.5 (I don't remember the exact value).

After 60-odd epochs the accuracy settled around 80% and stopped improving; later it even dropped.

I stopped training, figuring the initial learning rate was too small and the decay too fast. I set the initial learning rate to 0.005 and the decay to 0.6, loaded the parameters from epoch 50 (which looked the best) and retrained for 100 epochs.

I also changed the plotting function. Previously I logged a point every 10 batches and used those points for the plot, so the curves did not reflect the whole dataset. Now every batch contributes to the running totals for its epoch, the totals are averaged over all batches at the end of each epoch, and the per-epoch accuracy and loss are plotted (see the sketch below). The figure after it shows the result of the final 100 epochs.
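
A minimal sketch of that per-epoch averaging, mirroring what the training loop in the full code below does (it assumes net, criterion, optimizer, trainloader, accuracy, Loss_list and Accuracy_list are defined as in that code):

epoch_loss, epoch_accuracy = 0.0, 0.0
for inputs, target in trainloader:
    output = net(inputs)
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epoch_loss += loss.item()                        # accumulate the scalar loss
    epoch_accuracy += accuracy(output.data, target)  # accumulate the per-batch accuracy

# one point per epoch: the average over all batches
epoch_loss /= len(trainloader)
epoch_accuracy /= len(trainloader)
Loss_list.append(epoch_loss)
Accuracy_list.append(epoch_accuracy)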

[Figure 3: loss and accuracy curves over the final 100 epochs]

I grabbed two random pictures of a bee and an ant from the web and classified them.

Both images were classified correctly.

[Figure 4: first test image]

[Figure 5: second test image]

The test1.jpg in the code below is the ant picture above.

If you want to evaluate the whole validation set instead, just call resnet_eval().
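
With the full code below in place, the two evaluation modes look like this:

# evaluate the entire validation set
resnet_eval()

# classify a single image; the second argument is the image path
resnet_eval(single_image=True, img_path='test1.jpg')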

The full code:

import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torchvision import transforms
from PIL import Image
import torch.optim as optim
import torchvision.transforms as visiontransforms
from util.time_utils import time_for_file, print_log
from util.visualiztion import draw_loss_and_accuracy
import random
import numpy as np
import os.path as osp
import time

log_save_root_path = "./"
model_save_root_path = "./"
train_path = r'D:\postgraduate\major_related\face_related\datasets\hymenoptera_data\train'
test_path = r'D:\postgraduate\major_related\face_related\datasets\hymenoptera_data\val'


def loadTrainData():
    vision_normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    trainset = torchvision.datasets.ImageFolder(train_path, transform=transforms.Compose(
        [visiontransforms.RandomResizedCrop(224),
         visiontransforms.RandomHorizontalFlip(),
         visiontransforms.ToTensor(),
         vision_normalize]
    ))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
    return trainloader


def loadTestData():
    vision_normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    testset = torchvision.datasets.ImageFolder(test_path, transform=transforms.Compose(
        [visiontransforms.Resize(256),
         visiontransforms.CenterCrop(224),
         visiontransforms.ToTensor(),
         vision_normalize]
    ))
    testloader = torch.utils.data.DataLoader(testset, batch_size=32)
    return testloader


def adjust_learning_rate(optimizer, epoch, train_epoch, learning_rate):
    # exponential decay: the lr drops from learning_rate to about 0.6 * learning_rate over the run
    lr = learning_rate * (0.6 ** (epoch / train_epoch))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def accuracy(output, target):
    # output: (batch, num_classes) scores; target: (batch,) ground-truth class indices
    output = output.numpy()
    target = target.numpy()
    prec = 0
    for i in range(output.shape[0]):
        pre_label = np.argmax(output[i])  # index of the highest score = predicted class
        if pre_label == target[i]:
            prec += 1
    prec /= target.size
    prec *= 100
    return prec


def resnet_finute(train_epoch, print_freq, learning_rate_start):
    log = open(osp.join(log_save_root_path, 'cluster_seed_{}_{}.txt'.format(random.randint(1, 10000), time_for_file())),
               'w')
    net = models.resnet50(pretrained=True)
    channel_in = net.fc.in_features
    class_num = 2
    net.fc = nn.Sequential(
        nn.Linear(channel_in, 256),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(256, class_num),
        nn.LogSoftmax(dim=1)
    )
    for param in net.parameters():
        param.requires_grad = False

    for param in net.fc.parameters():
        param.requires_grad = True

    # # print the network structure
    # for child in net.children():
    #     print(child)

    net.load_state_dict(
        torch.load(osp.join(model_save_root_path, 'resnet50_50_2020-04-09_22-15-11.pth')))
    # the later 100-epoch model was trained on top of this 50-epoch checkpoint; to train from
    # scratch, comment out this load, but you may then need to re-tune the learning rate and decay

    # lists used for plotting Loss and Accuracy
    Loss_list = []
    Accuracy_list = []

    trainloader = loadTrainData()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate_start, momentum=0.9)
    criterion = nn.NLLLoss()  # matches the LogSoftmax output of the new fc head
    print_log('Training dir : {:}'.format(train_path), log)
    for epoch in range(train_epoch):
        epoch_accuracy = 0
        epoch_loss = 0
        learning_rate = adjust_learning_rate(optimizer, epoch, train_epoch, learning_rate_start)
        print_log('epoch : [{}/{}] lr={}'.format(epoch, train_epoch, learning_rate), log)
        net.train()
        for i, (inputs, target) in enumerate(trainloader):
            output = net(inputs)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            prec = accuracy(output.data, target)
            epoch_accuracy += prec
            epoch_loss += loss.item()  # accumulate the scalar value, not the graph-carrying tensor
            if i % print_freq == 0 or i + 1 == len(trainloader):
                print_log(
                    'after {} epoch, {}th batchsize, prec:{}%,loss:{},input:{},output:{}'.format(epoch + 1, i + 1, prec,
                                                                                                 loss, inputs.size(),
                                                                                                 output.size()), log)
        epoch_loss /= len(trainloader)
        epoch_accuracy /= len(trainloader)
        Loss_list.append(epoch_loss)
        Accuracy_list.append(epoch_accuracy)
        torch.save(net.state_dict(), osp.join(model_save_root_path, 'resnet50_{}_{}.pth'.format(epoch + 1,
                                                                                                time.strftime(
                                                                                                    "%Y-%m-%d_%H-%M-%S"))))
    draw_loss_and_accuracy(Loss_list, Accuracy_list, train_epoch)


def resnet_eval(single_image=False, img_path=None):
    resnet = models.resnet50(pretrained=True)
    channel_in = resnet.fc.in_features
    class_num = 2
    resnet.fc = nn.Sequential(
        nn.Linear(channel_in, 256),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(256, class_num),
        nn.LogSoftmax(dim=1)
    )
    resnet.load_state_dict(
        torch.load(osp.join(model_save_root_path, 'resnet50_100_2020-04-10_01-07-34.pth')))  # put the name of your latest checkpoint here
    resnet.eval()  # eval mode for both branches (disables dropout)
    if not single_image:
        val_loader = loadTestData()
        criterion = torch.nn.NLLLoss()  # the model outputs log-probabilities, so use NLLLoss as in training
        sum_accuracy = 0
        for i, (inputs, target) in enumerate(val_loader):
            with torch.no_grad():
                output = resnet(inputs)
                loss = criterion(output, target)
                prec = accuracy(output.data, target)
                sum_accuracy += prec
                print('for {}th batchsize, Eval:Accuracy:{}%,loss:{},input:{},output:{}'.format(i + 1, prec, loss,
                                                                                                inputs.size(),
                                                                                   output.size()))
        sum_accuracy /= len(val_loader)
        print('sum of accuracy = {}'.format(sum_accuracy))
    else:
        # use the same preprocessing as the validation loader, including normalization
        transform = transforms.Compose(
            [visiontransforms.Resize(256),
             visiontransforms.CenterCrop(224),
             visiontransforms.ToTensor(),
             visiontransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        image_PIL = Image.open(img_path).convert('RGB')  # make sure the input has 3 channels
        img_tensor = transform(image_PIL)
        img_tensor.unsqueeze_(0)  # add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
        with torch.no_grad():
            result = resnet(img_tensor)
        result = result.numpy()
        pre_label = np.argmax(result[0])  # index of the highest log-probability
        # ImageFolder assigns indices alphabetically: ants -> 0, bees -> 1
        if pre_label == 0:
            pre_label = 'ant'
        else:
            pre_label = 'bee'
        print('predicted label is {}'.format(pre_label))


if __name__ == '__main__':
    # train_epoch = 100
    # print_freq = 5
    # learning_rate_start = 0.005
    # resnet_finute(train_epoch, print_freq, learning_rate_start)
    resnet_eval(True, 'test1.jpg')

The code above imports a couple of helper functions I wrote myself; they live in a util folder I created next to the script.

One of them is time_utils.py, with the following contents:

import time


def print_log(print_string, log):
    print("{}".format(print_string))
    if log is not None:
        log.write('{}\n'.format(print_string))
        log.flush()


def time_for_file():
    ISOTIMEFORMAT = '%d-%h-at-%H-%M-%S'
    return '{}'.format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))

The other is visualiztion.py (I misspelled the word when naming the file, haha), with the following contents:

import matplotlib.pyplot as plt

loss_img_save = './accuracy_loss.jpg'


def draw_loss_and_accuracy(Loss_list, Accuracy_list, train_epoch):
    x1 = range(0, len(Loss_list))
    x2 = range(0, len(Accuracy_list))
    y1 = Accuracy_list
    y2 = Loss_list
    plt.subplot(2, 1, 1)
    plt.plot(x1, y1, 'o-')
    plt.title('Accuracy vs. epochs')
    plt.ylabel('Accuracy')
    plt.subplot(2, 1, 2)
    plt.plot(x2, y2, '.-')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.savefig(loss_img_save)  # save before show(), otherwise the saved image may be blank
    plt.show()

Feel free to contact me if you have any questions~ This is my first CSDN post; next I plan to move my Evernote notes over here! ヾ(◍°∇°◍)ノ゙ Let's go!
