本文大部分内容和pokemon+自己搭建resnet这篇一样,只有在模型部分(即第四部分)不太一样:本文用的是已经搭建好的resnet18的前17层做transfer_learning,而之前这篇是自己搭建的resnet。
pokemon数据集请戳:缦旋律的资源合集
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from torch import optim
import os
import csv
from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision.models import resnet18
class Pokemon(Dataset):
def __init__(self,root,resize,mode): #root是文件路径,resize是对原始图片进行裁剪,mode是选择模式(train、test、validation)
super(Pokemon,self).__init__()
self.root = root
self.resize = resize
self.name2label = {} #给每个种类分配一个数字,以该数字作为这一类别的label
#name是宝可梦的种类,e.g:pikachu
for name in sorted(os.listdir(os.path.join(self.root))): #listdir返回的顺序不固定,加上一个sorted使每一次的顺序都一样
if not os.path.isdir(os.path.join(self.root,name)):#os.path.isdir()用于判断括号中的内容是否是一个未压缩的文件夹
continue
self.name2label[name] = len(self.name2label.keys())
print(self.name2label)
self.images,self.labels = self.load_csv('images&labels.csv')
#将全部数据分成train、validation、test
if mode == 'train': #前60%作为训练集
self.images = self.images[:int(0.6*len(self.images))]
self.labels = self.labels[:int(0.6*len(self.labels))]
elif mode == 'val': #60%~80%作为validation
self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
else: #后20%作为test set
self.images = self.images[int(0.8*len(self.images)):]
self.labels = self.labels[int(0.8*len(self.labels)):]
def load_csv(self,filename):
#载入原始图片的路径,并保存到指定的CSV文件中,然后从该CSV文件中再次读入所有图片的存储路径和label。
#如果CSV文件已经存在,则直接读入该CSV文件的内容
#为什么保存的是图片的路径而不是图片?因为直接保存图片可能会造成内存爆炸
if not os.path.exists(os.path.join(self.root,filename)): #如果filename这个文件不存在,那么执行以下代码,创建file
images = []
for name in self.name2label.keys():
#glob.glob()返回的是括号中的路径中的所有文件的路径
# += 是把glob.glob()返回的结果依次append到image中,而不是以一个整体append
# 这里只用了png/jpg/jepg是因为本次实验的图片只有这三种格式,如果有其他格式请自行添加
images += glob.glob(os.path.join(self.root,name,'*.png'))
images += glob.glob(os.path.join(self.root,name,'*.jpg'))
images += glob.glob(os.path.join(self.root,name,'*.jpeg'))
print(len(images))
random.shuffle(images) #把所有图片路径顺序打乱
with open(os.path.join(self.root,filename),mode='w',newline='') as f: #将图片路径及其对应的数字标签写到指定文件中
writer = csv.writer(f)
for img in images: #img e.g:'./pokemon/pikachu\\00000001.png'
name = img.split(os.sep)[-2] #即取出‘pikachu’
label = self.name2label[name] #根据name找到对应的数字标签
writer.writerow([img,label]) #把每张图片的路径和它对应的数字标签写到指定的CSV文件中
print('image paths and labels have been writen into csv file:',filename)
#把数据读入(如果filename存在就直接执行这一步,如果不存在就先创建file再读入数据)
images,labels = [],[]
with open(os.path.join(self.root,filename)) as f:
reader = csv.reader(f)
for row in reader:
img,label = row
label = int(label)
images.append(img)
labels.append(label)
assert len(images) == len(labels) #确保它们长度一致
return images,labels
def __len__(self):
return len(self.images)
def __getitem__(self,idx):
img,label = self.images[idx],self.labels[idx]#此时img还是路径字符串,要把它转化成tensor
#将图片resize成224*224,并转化成tensor,这个tensor的size是3*224*224(3是因为有RGB3个通道)
trans = transforms.Compose((
lambda x: Image.open(x).convert('RGB'),
transforms.Resize((self.resize,self.resize)), #必须要把长宽都一起写上啊!!!
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225]) #这个数据是根据resnet中的图片统计得到的,直接拿来用就好
))
img = trans(img)
label = torch.tensor(label)
return img,label
batch_size = 32
lr = 1e-3
device = torch.device('cuda')
torch.manual_seed(1234)
train_db = Pokemon('./pokemon',224,'train') #将所有图片(顺序已打乱)的前60%作为train_set
val_db = Pokemon('./pokemon',224,'val') #60%~80%作为validation_set
test_db = Pokemon('./pokemon',224,'test') #80%~100%作为test_set
train_loader = DataLoader(train_db,batch_size=batch_size,shuffle=True) #之后调用一次train_loader就会把train_db划分成很多batch
val_loader = DataLoader(val_db,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(test_db,batch_size=batch_size,shuffle=True)
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
#首先定义一个Flatten类,用于后面的打平操作
class Flatten(nn.Module):
def __init__(self):
super(Flatten,self).__init__()
def forward(self,x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.reshape(-1,shape)
#初始化模型
trained_model = resnet18(pretrained = True) #拿到已经训练好的resnet18模型
model = nn.Sequential(*list(trained_model.children())[:-1], #拿出resnet18的前面17层,输出的size是b*512*1*1
Flatten(), #经过flatten之后的size是b*512
nn.Linear(512,5)).to(device)
print('模型需要训练的参数共有{}个'.format(sum(map(lambda p:p.numel(),model.parameters()))))
loss_fn = nn.CrossEntropyLoss() #选择loss_function
optimizer = optim.Adam(model.parameters(),lr=lr) #选择优化方式
模型需要训练的参数共有11179077个
flatten是进行打平操作,为什么不能像之前自己搭建resnet18那样,直接reshape就好?
因为这里是放到nn.Sequential()里面的,括号里面的必须是nn.Module里面的类,或者是继承了nn.Module的子类
所以我们这里得自己写一个Flatten的类,并让它以nn.Module为父类。
如果不放到nn.Module()中,那么就可以先让x经过前17层,得到一个输出(记为x_pro),然后x_pro.reshape(x_pro.size(0),-1),最后接一个linear就OK。
# 开始训练之前,先定义一个evaluate函数。evaluate用于检测模型的预测效果,validation_set和test_set是同样的evaluate方法
def evaluate(model,loader):
correct_num = 0
total_num = len(loader.dataset)
for img,label in loader: #lodaer中包含了很多batch,每个batch有32张图片
img,label = img.to(device),label.to(device)
with torch.no_grad():
logits = model(img)
pre_label = logits.argmax(dim=1)
correct_num += torch.eq(pre_label,label).sum().float().item()
return correct_num/total_num
#开始训练
best_epoch,best_acc = 0,0
for epoch in range(10): #时间关系,我们只训练10个epoch
for batch_num,(img,label) in enumerate(train_loader):
#img.size [b,3,224,224] label.size [b]
img,label = img.to(device),label.to(device)
logits = model(img)
loss = loss_fn(logits,label)
if batch_num%5 == 0:
print('这是第{}次迭代的第{}个batch,loss是{}'.format(epoch+1,batch_num+1,loss.item()))
optimizer.zero_grad()
loss.backward()
optimizer.step()
val_acc = evaluate(model,val_loader)
#如果val_acc比之前的好,那么就把该epoch保存下来,并把此时模型的参数保存到指定txt文件里
if val_acc>best_acc:
print('验证集上的准确率是:{}'.format(val_acc))
best_epoch = epoch
best_acc = val_acc
torch.save(model.state_dict(),'pokemon_ckp.txt')
print('best_acc:{},best_epoch:{}'.format(best_acc,best_epoch))
model.load_state_dict(torch.load('pokemon_ckp.txt'))
#开始检验
print('模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set')
test_acc = evaluate(model,test_loader)
print('测试集上的准确率是:{}'.format(test_acc))
这是第1次迭代的第1个batch,loss是1.8974512815475464
这是第1次迭代的第6个batch,loss是0.3152352571487427
这是第1次迭代的第11个batch,loss是0.31969892978668213
这是第1次迭代的第16个batch,loss是0.827768862247467
这是第1次迭代的第21个batch,loss是0.06569187343120575
验证集上的准确率是:0.9399141630901288
这是第2次迭代的第1个batch,loss是0.27959883213043213
这是第2次迭代的第6个batch,loss是0.26758652925491333
这是第2次迭代的第11个batch,loss是0.5397248268127441
这是第2次迭代的第16个batch,loss是0.26908379793167114
这是第2次迭代的第21个batch,loss是0.2528558373451233
这是第3次迭代的第1个batch,loss是0.5769810080528259
这是第3次迭代的第6个batch,loss是0.17315296828746796
这是第3次迭代的第11个batch,loss是0.19980908930301666
这是第3次迭代的第16个batch,loss是0.1564580649137497
这是第3次迭代的第21个batch,loss是0.021813027560710907
这是第4次迭代的第1个batch,loss是0.20928955078125
这是第4次迭代的第6个batch,loss是0.09454512596130371
这是第4次迭代的第11个batch,loss是0.026858791708946228
这是第4次迭代的第16个batch,loss是0.09628774225711823
这是第4次迭代的第21个batch,loss是0.22692246735095978
验证集上的准确率是:0.9484978540772532
这是第5次迭代的第1个batch,loss是0.04763159155845642
这是第5次迭代的第6个batch,loss是0.026739276945590973
这是第5次迭代的第11个batch,loss是0.4837387800216675
这是第5次迭代的第16个batch,loss是0.0742536336183548
这是第5次迭代的第21个batch,loss是0.1805519163608551
这是第6次迭代的第1个batch,loss是0.26089876890182495
这是第6次迭代的第6个batch,loss是0.04913238435983658
这是第6次迭代的第11个batch,loss是0.23098143935203552
这是第6次迭代的第16个batch,loss是0.055031076073646545
这是第6次迭代的第21个batch,loss是0.2681158483028412
这是第7次迭代的第1个batch,loss是0.09300532191991806
这是第7次迭代的第6个batch,loss是0.20092912018299103
这是第7次迭代的第11个batch,loss是0.016669772565364838
这是第7次迭代的第16个batch,loss是0.019372448325157166
这是第7次迭代的第21个batch,loss是0.025167152285575867
这是第8次迭代的第1个batch,loss是0.16009360551834106
这是第8次迭代的第6个batch,loss是0.05369710177183151
这是第8次迭代的第11个batch,loss是0.02474011480808258
这是第8次迭代的第16个batch,loss是0.22973166406154633
这是第8次迭代的第21个batch,loss是0.0449075773358345
验证集上的准确率是:0.9699570815450643
这是第9次迭代的第1个batch,loss是0.015333056449890137
这是第9次迭代的第6个batch,loss是0.07510494440793991
这是第9次迭代的第11个batch,loss是0.04943542182445526
这是第9次迭代的第16个batch,loss是0.34347304701805115
这是第9次迭代的第21个batch,loss是0.11908939480781555
验证集上的准确率是:0.9785407725321889
这是第10次迭代的第1个batch,loss是0.02673729509115219
这是第10次迭代的第6个batch,loss是0.013404056429862976
这是第10次迭代的第11个batch,loss是0.09280069917440414
这是第10次迭代的第16个batch,loss是0.04911276698112488
这是第10次迭代的第21个batch,loss是0.10042841732501984
best_acc:0.9785407725321889,best_epoch:8
模型训练完毕,已将参数设置成训练过程中的最优值,现在开始测试test_set
测试集上的准确率是:0.9743589743589743
可以看到,最优的epoch是8,即第九次训练时的模型,该模型在validation_set上的准确率是0.9785,在test_set上的准确率是0.9744.
因此,和直接训练自己搭建的Resnet18相比(test_set上的准确率是0.902),通过transfer_learning学到的模型的效果有较为明显的提升。