pytorch:基于resnet18的transfer_learning(pokemon数据集实例分析+详细代码+完整数据集)

       本文大部分内容和pokemon+自己搭建resnet这篇一样,只有在模型部分(即第四部分)不太一样:本文用的是已经搭建好的resnet18的前17层做transfer_learning,而之前这篇是自己搭建的resnet。

       pokemon数据集请戳:缦旋律的资源合集

文章目录

  • 一.定义一个Pokemon的类,用于获取图片以及对应的label
  • 二.设置一些超参数
  • 三.载入数据
  • 四.初始化模型,设置loss_function/optimizer/evaluation
  • 五.开始训练并检验

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision.models import resnet18

一.定义一个Pokemon的类,用于获取图片以及对应的label

class Pokemon(Dataset):
    def __init__(self,root,resize,mode): #root是文件路径,resize是对原始图片进行裁剪,mode是选择模式(train、test、validation)
        super(Pokemon,self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {} #给每个种类分配一个数字,以该数字作为这一类别的label
        #name是宝可梦的种类,e.g:pikachu
        for name in sorted(os.listdir(os.path.join(self.root))): #listdir返回的顺序不固定,加上一个sorted使每一次的顺序都一样
            if not os.path.isdir(os.path.join(self.root,name)):#os.path.isdir()用于判断括号中的内容是否是一个未压缩的文件夹
                continue
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)
        
        self.images,self.labels = self.load_csv('images&labels.csv')
        #将全部数据分成train、validation、test
        if mode == 'train': #前60%作为训练集
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode == 'val': #60%~80%作为validation
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else: #后20%作为test set
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]
        
        
    
    def load_csv(self,filename): 
    #载入原始图片的路径,并保存到指定的CSV文件中,然后从该CSV文件中再次读入所有图片的存储路径和label。
    #如果CSV文件已经存在,则直接读入该CSV文件的内容
    #为什么保存的是图片的路径而不是图片?因为直接保存图片可能会造成内存爆炸
        
        if not os.path.exists(os.path.join(self.root,filename)): #如果filename这个文件不存在,那么执行以下代码,创建file
            images = []
            for name in self.name2label.keys():
                #glob.glob()返回的是括号中的路径中的所有文件的路径 
                # += 是把glob.glob()返回的结果依次append到image中,而不是以一个整体append
                # 这里只用了png/jpg/jepg是因为本次实验的图片只有这三种格式,如果有其他格式请自行添加
                images += glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root,name,'*.jpg'))
                images += glob.glob(os.path.join(self.root,name,'*.jpeg'))
            print(len(images))
            random.shuffle(images) #把所有图片路径顺序打乱
            with open(os.path.join(self.root,filename),mode='w',newline='') as f: #将图片路径及其对应的数字标签写到指定文件中
                writer = csv.writer(f)
                for img in images: #img e.g:'./pokemon/pikachu\\00000001.png'
                    name = img.split(os.sep)[-2] #即取出‘pikachu’
                    label = self.name2label[name] #根据name找到对应的数字标签
                    writer.writerow([img,label]) #把每张图片的路径和它对应的数字标签写到指定的CSV文件中
                print('image paths and labels have been writen into csv file:',filename)
        
        
        #把数据读入(如果filename存在就直接执行这一步,如果不存在就先创建file再读入数据)
        images,labels = [],[]
        with open(os.path.join(self.root,filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img,label = row
                label = int(label)
                
                images.append(img)
                labels.append(label)
                
        assert len(images) == len(labels) #确保它们长度一致
        
        return images,labels
                
         
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self,idx):
        img,label = self.images[idx],self.labels[idx]#此时img还是路径字符串,要把它转化成tensor
        #将图片resize成224*224,并转化成tensor,这个tensor的size是3*224*224(3是因为有RGB3个通道)
        trans = transforms.Compose((
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((self.resize,self.resize)), #必须要把长宽都一起写上啊!!!
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]) #这个数据是根据resnet中的图片统计得到的,直接拿来用就好
            )) 
        img = trans(img)
        label = torch.tensor(label)
        return img,label 

二.设置一些超参数

batch_size = 32
lr = 1e-3
device = torch.device('cuda')
torch.manual_seed(1234)

三.载入数据

train_db = Pokemon('./pokemon',224,'train') #将所有图片(顺序已打乱)的前60%作为train_set
val_db = Pokemon('./pokemon',224,'val')  #60%~80%作为validation_set
test_db = Pokemon('./pokemon',224,'test') #80%~100%作为test_set
train_loader = DataLoader(train_db,batch_size=batch_size,shuffle=True) #之后调用一次train_loader就会把train_db划分成很多batch
val_loader = DataLoader(val_db,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_db,batch_size=batch_size,shuffle=True)
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}

四.初始化模型,设置loss_function/optimizer/evaluation

#首先定义一个Flatten类,用于后面的打平操作
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten,self).__init__()
    def forward(self,x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.reshape(-1,shape)

#初始化模型
trained_model = resnet18(pretrained = True) #拿到已经训练好的resnet18模型
model = nn.Sequential(*list(trained_model.children())[:-1], #拿出resnet18的前面17层,输出的size是b*512*1*1
                      Flatten(), #经过flatten之后的size是b*512
                      nn.Linear(512,5)).to(device)
print('模型需要训练的参数共有{}个'.format(sum(map(lambda p:p.numel(),model.parameters()))))

loss_fn = nn.CrossEntropyLoss() #选择loss_function

optimizer = optim.Adam(model.parameters(),lr=lr) #选择优化方式
模型需要训练的参数共有11179077个

flatten是进行打平操作,为什么不能像之前自己搭建resnet18那样,直接reshape就好?
因为这里是放到nn.Sequential()里面的,括号里面的必须是nn.Module里面的类,或者是继承了nn.Module的子类
所以我们这里得自己写一个Flatten的类,并让它以nn.Module为父类。

如果不放到nn.Module()中,那么就可以先让x经过前17层,得到一个输出(记为x_pro),然后x_pro.reshape(x_pro.size(0),-1),最后接一个linear就OK。

五.开始训练并检验

# 开始训练之前,先定义一个evaluate函数。evaluate用于检测模型的预测效果,validation_set和test_set是同样的evaluate方法
def evaluate(model,loader):
    correct_num = 0
    total_num = len(loader.dataset)
    for img,label in loader: #lodaer中包含了很多batch,每个batch有32张图片
        img,label = img.to(device),label.to(device)
        with torch.no_grad():
            logits = model(img)
            pre_label = logits.argmax(dim=1)
        correct_num += torch.eq(pre_label,label).sum().float().item()
    
    return correct_num/total_num 


#开始训练
best_epoch,best_acc = 0,0
for epoch in range(10): #时间关系,我们只训练10个epoch
    for batch_num,(img,label) in enumerate(train_loader):
        #img.size [b,3,224,224]  label.size [b]
        img,label = img.to(device),label.to(device)
        logits = model(img)
        loss = loss_fn(logits,label)
        if batch_num%5 == 0:
            print('这是第{}次迭代的第{}个batch,loss是{}'.format(epoch+1,batch_num+1,loss.item()))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        

    val_acc = evaluate(model,val_loader)
        #如果val_acc比之前的好,那么就把该epoch保存下来,并把此时模型的参数保存到指定txt文件里
    if val_acc>best_acc:
        print('验证集上的准确率是:{}'.format(val_acc))
        best_epoch = epoch
        best_acc = val_acc
        torch.save(model.state_dict(),'pokemon_ckp.txt')
    

print('best_acc:{},best_epoch:{}'.format(best_acc,best_epoch))
model.load_state_dict(torch.load('pokemon_ckp.txt'))

#开始检验
print('模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set')
test_acc = evaluate(model,test_loader)
print('测试集上的准确率是:{}'.format(test_acc))
这是第1次迭代的第1个batch,loss是1.8974512815475464
这是第1次迭代的第6个batch,loss是0.3152352571487427
这是第1次迭代的第11个batch,loss是0.31969892978668213
这是第1次迭代的第16个batch,loss是0.827768862247467
这是第1次迭代的第21个batch,loss是0.06569187343120575
验证集上的准确率是:0.9399141630901288
这是第2次迭代的第1个batch,loss是0.27959883213043213
这是第2次迭代的第6个batch,loss是0.26758652925491333
这是第2次迭代的第11个batch,loss是0.5397248268127441
这是第2次迭代的第16个batch,loss是0.26908379793167114
这是第2次迭代的第21个batch,loss是0.2528558373451233
这是第3次迭代的第1个batch,loss是0.5769810080528259
这是第3次迭代的第6个batch,loss是0.17315296828746796
这是第3次迭代的第11个batch,loss是0.19980908930301666
这是第3次迭代的第16个batch,loss是0.1564580649137497
这是第3次迭代的第21个batch,loss是0.021813027560710907
这是第4次迭代的第1个batch,loss是0.20928955078125
这是第4次迭代的第6个batch,loss是0.09454512596130371
这是第4次迭代的第11个batch,loss是0.026858791708946228
这是第4次迭代的第16个batch,loss是0.09628774225711823
这是第4次迭代的第21个batch,loss是0.22692246735095978
验证集上的准确率是:0.9484978540772532
这是第5次迭代的第1个batch,loss是0.04763159155845642
这是第5次迭代的第6个batch,loss是0.026739276945590973
这是第5次迭代的第11个batch,loss是0.4837387800216675
这是第5次迭代的第16个batch,loss是0.0742536336183548
这是第5次迭代的第21个batch,loss是0.1805519163608551
这是第6次迭代的第1个batch,loss是0.26089876890182495
这是第6次迭代的第6个batch,loss是0.04913238435983658
这是第6次迭代的第11个batch,loss是0.23098143935203552
这是第6次迭代的第16个batch,loss是0.055031076073646545
这是第6次迭代的第21个batch,loss是0.2681158483028412
这是第7次迭代的第1个batch,loss是0.09300532191991806
这是第7次迭代的第6个batch,loss是0.20092912018299103
这是第7次迭代的第11个batch,loss是0.016669772565364838
这是第7次迭代的第16个batch,loss是0.019372448325157166
这是第7次迭代的第21个batch,loss是0.025167152285575867
这是第8次迭代的第1个batch,loss是0.16009360551834106
这是第8次迭代的第6个batch,loss是0.05369710177183151
这是第8次迭代的第11个batch,loss是0.02474011480808258
这是第8次迭代的第16个batch,loss是0.22973166406154633
这是第8次迭代的第21个batch,loss是0.0449075773358345
验证集上的准确率是:0.9699570815450643
这是第9次迭代的第1个batch,loss是0.015333056449890137
这是第9次迭代的第6个batch,loss是0.07510494440793991
这是第9次迭代的第11个batch,loss是0.04943542182445526
这是第9次迭代的第16个batch,loss是0.34347304701805115
这是第9次迭代的第21个batch,loss是0.11908939480781555
验证集上的准确率是:0.9785407725321889
这是第10次迭代的第1个batch,loss是0.02673729509115219
这是第10次迭代的第6个batch,loss是0.013404056429862976
这是第10次迭代的第11个batch,loss是0.09280069917440414
这是第10次迭代的第16个batch,loss是0.04911276698112488
这是第10次迭代的第21个batch,loss是0.10042841732501984
best_acc:0.9785407725321889,best_epoch:8
模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set
测试集上的准确率是:0.9743589743589743

     可以看到,最优的epoch是8,即第九次训练时的模型,该模型在validation_set上的准确率是0.9785,在test_set上的准确率是0.9744.

     因此,和直接训练自己搭建的Resnet18相比(test_set上的准确率是0.902),通过transfer_learning学到的模型的效果有较为明显的提升。

pytorch:基于resnet18的transfer_learning(pokemon数据集实例分析+详细代码+完整数据集)_第1张图片

你可能感兴趣的:(pytorch)