pytorch:pokemon+resnet详细代码+数据集

pokemon数据集请戳:缦旋律的资源合集

文章目录

  • 一.定义一个Pokemon的类,用于获取图片以及对应的label
  • 二.搭建自己的Resnet
    • 1.构建resblock:
    • 2.搭建Resnet:
  • 三.设置一些超参数
  • 四.载入数据
  • 五.初始化模型,设置loss_function/optimizer/evaluation
  • 六.开始训练,并进行检验

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')

一.定义一个Pokemon的类,用于获取图片以及对应的label

对于自定义数据集,并使用DataLoader划分batch不熟悉的,可以戳:
自定义数据集+DataLoader.

class Pokemon(Dataset):
    def __init__(self,root,resize,mode): #root是文件路径,resize是对原始图片进行裁剪,mode是选择模式(train、test、validation)
        super(Pokemon,self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {} #给每个种类分配一个数字,以该数字作为这一类别的label
        #name是宝可梦的种类,e.g:pikachu
        for name in sorted(os.listdir(os.path.join(self.root))): #listdir返回的顺序不固定,加上一个sorted使每一次的顺序都一样
            if not os.path.isdir(os.path.join(self.root,name)):#os.path.isdir()用于判断括号中的内容是否是一个未压缩的文件夹
                continue
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)
        
        self.images,self.labels = self.load_csv('images&labels.csv')
        #将全部数据分成train、validation、test
        if mode == 'train': #前60%作为训练集
            self.images = self.images[:int(0.6*len(self.images))]
            self.labels = self.labels[:int(0.6*len(self.labels))]
        elif mode == 'val': #60%~80%作为validation
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else: #后20%作为test set
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]
        
        
    
    def load_csv(self,filename): 
    #载入原始图片的路径,并保存到指定的CSV文件中,然后从该CSV文件中再次读入所有图片的存储路径和label。
    #如果CSV文件已经存在,则直接读入该CSV文件的内容
    #为什么保存的是图片的路径而不是图片?因为直接保存图片可能会造成内存爆炸
        
        if not os.path.exists(os.path.join(self.root,filename)): #如果filename这个文件不存在,那么执行以下代码,创建file
            images = []
            for name in self.name2label.keys():
                #glob.glob()返回的是括号中的路径中的所有文件的路径 
                # += 是把glob.glob()返回的结果依次append到image中,而不是以一个整体append
                # 这里只用了png/jpg/jepg是因为本次实验的图片只有这三种格式,如果有其他格式请自行添加
                images += glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root,name,'*.jpg'))
                images += glob.glob(os.path.join(self.root,name,'*.jpeg'))
            print(len(images))
            random.shuffle(images) #把所有图片路径顺序打乱
            with open(os.path.join(self.root,filename),mode='w',newline='') as f: #将图片路径及其对应的数字标签写到指定文件中
                writer = csv.writer(f)
                for img in images: #img e.g:'./pokemon/pikachu\\00000001.png'
                    name = img.split(os.sep)[-2] #即取出‘pikachu’
                    label = self.name2label[name] #根据name找到对应的数字标签
                    writer.writerow([img,label]) #把每张图片的路径和它对应的数字标签写到指定的CSV文件中
                print('image paths and labels have been writen into csv file:',filename)
        
        
        #把数据读入(如果filename存在就直接执行这一步,如果不存在就先创建file再读入数据)
        images,labels = [],[]
        with open(os.path.join(self.root,filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img,label = row
                label = int(label)
                
                images.append(img)
                labels.append(label)
                
        assert len(images) == len(labels) #确保它们长度一致
        
        return images,labels
                
         
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self,idx):
        img,label = self.images[idx],self.labels[idx]#此时img还是路径字符串,要把它转化成tensor
        #将图片resize成224*224,并转化成tensor,这个tensor的size是3*224*224(3是因为有RGB3个通道)
        trans = transforms.Compose((
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((self.resize,self.resize)), #必须要把长宽都一起写上啊!!!
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]) #这个数据是根据resnet中的图片统计得到的,直接拿来用就好
            )) 
        img = trans(img)
        label = torch.tensor(label)
        return img,label 

二.搭建自己的Resnet

1.构建resblock:

对于resnet的构建不熟悉的,可以戳:
cifar-10+resnet 详细代码+解释.

class resblock(nn.Module):
    def __init__(self,ch_in,ch_out,stride=1):
        super(resblock,self).__init__()
        self.conv_1 = nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=stride,padding=1)
        self.bn_1 = nn.BatchNorm2d(ch_out)
        self.conv_2 = nn.Conv2d(ch_out,ch_out,kernel_size=3,stride=1,padding=1)
        self.bn_2 = nn.BatchNorm2d(ch_out)
        self.ch_trans = nn.Sequential()
        if ch_in != ch_out:
            self.ch_trans = nn.Sequential(nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=stride),nn.BatchNorm2d(ch_out))
        #ch_trans表示通道数转变。因为要做short_cut,所以x_pro和x_ch的size应该完全一致
        
    def  forward(self,x):
        x_pro = F.relu(self.bn_1(self.conv_1(x)))
        x_pro = self.bn_2(self.conv_2(x_pro))
        
        #short_cut:
        x_ch = self.ch_trans(x)
        out = x_pro + x_ch
        out = F.relu(out)
        return out    

2.搭建Resnet:

class Resnet18(nn.Module):
    def __init__(self,num_class):
        super(Resnet18,self).__init__()
        self.conv_1 = nn.Sequential(
        nn.Conv2d(3,16,kernel_size=3,stride=3,padding=0),
        nn.BatchNorm2d(16))
        self.block1 = resblock(16,32,3) 
        self.block2 = resblock(32,64,3) 
        self.block3 = resblock(64,128,2)
        self.block4 = resblock(128,256,2)
        self.outlayer = nn.Linear(256*3*3,num_class)#这个256*3*3是根据forward中x经过4个resblock之后来决定的
        
    def forward(self,x):
        x = F.relu(self.conv_1(x))
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = x.reshape(x.size(0),-1)
        result = self.outlayer(x)
        return result      

三.设置一些超参数

batch_size = 32
lr = 1e-3
device = torch.device('cuda')
torch.manual_seed(1234)

四.载入数据

train_db = Pokemon('./pokemon',224,'train') #将所有图片(顺序已打乱)的前60%作为train_set
val_db = Pokemon('./pokemon',224,'val')  #60%~80%作为validation_set
test_db = Pokemon('./pokemon',224,'test') #80%~100%作为test_set
train_loader = DataLoader(train_db,batch_size=batch_size,shuffle=True) #之后调用一次train_loader就会把train_db划分成很多batch
val_loader = DataLoader(val_db,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_db,batch_size=batch_size,shuffle=True)
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}

五.初始化模型,设置loss_function/optimizer/evaluation

model = Resnet18(5).to(device) #模型初始化,5代表一共有5种类别
print('模型需要训练的参数共有{}个'.format(sum(map(lambda p:p.numel(),model.parameters()))))
loss_fn = nn.CrossEntropyLoss() #选择loss_function
optimizer = optim.Adam(model.parameters(),lr=lr) #选择优化方式

六.开始训练,并进行检验

# 开始训练前,先定义一个evaluate函数。evaluate用于检测模型的预测效果,validation_set和test_set是同样的evaluate方法
def evaluate(model,loader):
    correct_num = 0
    total_num = len(loader.dataset)
    for img,label in loader: #lodaer中包含了很多batch,每个batch有32张图片
        img,label = img.to(device),label.to(device)
        with torch.no_grad():
            logits = model(img)
            pre_label = logits.argmax(dim=1)
        correct_num += torch.eq(pre_label,label).sum().float().item()
    
    return correct_num/total_num 

#开始训练:
best_epoch,best_acc = 0,0
for epoch in range(10): #时间关系,我们只训练10个epoch
    for batch_num,(img,label) in enumerate(train_loader):
        #img.size [b,3,224,224]  label.size [b]
        img,label = img.to(device),label.to(device)
        logits = model(img)
        loss = loss_fn(logits,label)
        if batch_num%5 == 0:
            print('这是第{}次迭代的第{}个batch,loss是{}'.format(epoch+1,batch_num+1,loss.item()))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    if epoch%2==0: #这里设置的是每训练两次epoch就进行一次validation
        val_acc = evaluate(model,val_loader)
        #如果val_acc比之前的好,那么就把该epoch保存下来,并把此时模型的参数保存到指定txt文件里
        if val_acc>best_acc:
            print('验证集上的准确率是:{}'.format(val_acc))
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(),'pokemon_ckp.txt')
    

print('best_acc:{},best_epoch:{}'.format(best_acc,best_epoch))
model.load_state_dict(torch.load('pokemon_ckp.txt'))

#开始检验:
print('模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set')
test_acc = evaluate(model,test_loader)
print('测试集上的准确率是:{}'.format(test_acc))
这是第1次迭代的第1个batch,loss是0.09284534305334091
这是第1次迭代的第6个batch,loss是0.048977259546518326
这是第1次迭代的第11个batch,loss是0.012712553143501282
这是第1次迭代的第16个batch,loss是0.29784664511680603
这是第1次迭代的第21个batch,loss是0.04697053134441376
验证集上的准确率是:0.9055793991416309
这是第2次迭代的第1个batch,loss是0.1201871708035469
这是第2次迭代的第6个batch,loss是0.17532214522361755
这是第2次迭代的第11个batch,loss是0.4216356873512268
这是第2次迭代的第16个batch,loss是0.25439736247062683
这是第2次迭代的第21个batch,loss是0.08713945746421814
这是第3次迭代的第1个batch,loss是0.29146236181259155
这是第3次迭代的第6个batch,loss是0.054789893329143524
这是第3次迭代的第11个batch,loss是0.4522630572319031
这是第3次迭代的第16个batch,loss是0.08970004320144653
这是第3次迭代的第21个batch,loss是0.3429204821586609
这是第4次迭代的第1个batch,loss是0.06742320954799652
这是第4次迭代的第6个batch,loss是0.12285976856946945
这是第4次迭代的第11个batch,loss是0.15735840797424316
这是第4次迭代的第16个batch,loss是0.07834229618310928
这是第4次迭代的第21个batch,loss是0.20532763004302979
这是第5次迭代的第1个batch,loss是0.00593993067741394
这是第5次迭代的第6个batch,loss是0.03216344118118286
这是第5次迭代的第11个batch,loss是0.03481002524495125
这是第5次迭代的第16个batch,loss是0.15314869582653046
这是第5次迭代的第21个batch,loss是0.08527624607086182
这是第6次迭代的第1个batch,loss是0.05515890568494797
这是第6次迭代的第6个batch,loss是0.036611974239349365
这是第6次迭代的第11个batch,loss是0.007195517420768738
这是第6次迭代的第16个batch,loss是0.05695120990276337
这是第6次迭代的第21个batch,loss是0.15042126178741455
这是第7次迭代的第1个batch,loss是0.1088687926530838
这是第7次迭代的第6个batch,loss是0.002063468098640442
这是第7次迭代的第11个batch,loss是0.01613890379667282
这是第7次迭代的第16个batch,loss是0.012490876019001007
这是第7次迭代的第21个batch,loss是0.48446154594421387
验证集上的准确率是:0.9141630901287554
这是第8次迭代的第1个batch,loss是0.10298655182123184
这是第8次迭代的第6个batch,loss是0.05644068121910095
这是第8次迭代的第11个batch,loss是0.0563386008143425
这是第8次迭代的第16个batch,loss是0.00903283804655075
这是第8次迭代的第21个batch,loss是0.08256962895393372
这是第9次迭代的第1个batch,loss是0.014249928295612335
这是第9次迭代的第6个batch,loss是0.013826802372932434
这是第9次迭代的第11个batch,loss是0.0016943514347076416
这是第9次迭代的第16个batch,loss是0.1954154521226883
这是第9次迭代的第21个batch,loss是0.056067951023578644
这是第10次迭代的第1个batch,loss是0.014393903315067291
这是第10次迭代的第6个batch,loss是0.26919856667518616
这是第10次迭代的第11个batch,loss是0.03811478987336159
这是第10次迭代的第16个batch,loss是0.18677780032157898
这是第10次迭代的第21个batch,loss是0.018675178289413452
best_acc:0.9141630901287554,best_epoch:6
模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set
测试集上的准确率是:0.9017094017094017

可以看到,最优的epoch是6,即第7次迭代时候的模型,该模型在validation_set上的正确率是0.914,在test_set上的准确率是0.902


pytorch:pokemon+resnet详细代码+数据集_第1张图片

你可能感兴趣的:(pytorch)