First, subclass torch.utils.data.Dataset and override __len__ and __getitem__.
The 'train' mode uses the training split and the 'test' mode uses the test split.
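For reference, a minimal sketch of that pattern (MyDataset and its two lists are hypothetical names for illustration, not part of the Pokemon code below):

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):  # hypothetical example, not the Pokemon class below
    def __init__(self, samples, labels):
        self.samples = samples  # e.g. a list of image paths or tensors
        self.labels = labels    # matching list of integer labels

    def __len__(self):  # number of samples in the split
        return len(self.samples)

    def __getitem__(self, idx):  # return one (sample, label) pair by index
        return self.samples[idx], torch.tensor(self.labels[idx])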
import torch
import os,glob
import random,csv
import torchvision.datasets
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
from PIL import Image
from torchvision.transforms import InterpolationMode
class Pokemon(Dataset):
    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}  # dict mapping class name to integer label, e.g. 'squirtle': 0
        for name in sorted(os.listdir(os.path.join(root))):  # iterate over every entry under the root directory
            if not os.path.isdir(os.path.join(root, name)):  # keep directories only
                continue
            self.name2label[name] = len(self.name2label.keys())
        print(self.name2label)
        self.images, self.labels = self.load_csv('images.csv')  # image paths and their numeric labels
        if mode == 'train':  # first 60%
            self.images = self.images[:int(0.6 * len(self.images))]
            self.labels = self.labels[:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 60%~80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # last 20%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]

    # build (path, label) pairs
    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):  # only build the csv if it does not exist yet
            images = []
            for name in self.name2label.keys():  # key: class name
                images += glob.glob(os.path.join(self.root, name, '*.png'))  # glob collects every matching file in the class folder
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
            # 1165 paths in total, e.g. pokeman/bulbasaur/00001.png
            # save the path-to-label mapping into the csv
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # pokeman/bulbasaur/00001.png
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]  # look up the numeric label by class name
                    writer.writerow([img, label])  # pokeman/bulbasaur/00001.png,0
                print('written into csv file:', filename)
        # read from csv
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):  # undo normalization to recover the image for visualization
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # x_hat = (x - mean) / std
        # x = x_hat * std + mean
        # x: [c, h, w]
        # mean: [3] => [3, 1, 1]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        print(mean.shape, std.shape)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        # self.images, self.labels
        # idx in [0, len(images))
        img, label = self.images[idx], self.labels[idx]  # image path and label read from the csv
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label


if __name__ == '__main__':
    import visdom
    import time

    viz = visdom.Visdom()
    # tf = transforms.Compose([
    #     transforms.Resize((64, 64)),
    #     transforms.ToTensor(),
    # ])
    # db = torchvision.datasets.ImageFolder(root='/home/lizheng/Study/yolov5-5.0/pytorch1/pokemon/pokeman', transform=tf)
    # loader = DataLoader(db, batch_size=32, shuffle=True)  # ImageFolder loads the dataset in a single line
    # print(db.class_to_idx)
    # for x, y in loader:
    #     viz.images(x, nrow=8, win='batch', opts=dict(title='batch'))
    #     viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
    #
    #     time.sleep(10)
    db = Pokemon('pokemon/pokeman', 64, 'train')
    x, y = next(iter(db))  # iterating the dataset calls __getitem__, so x is the first decoded image tensor
    print('sample:', x.shape, y.shape, y)
    viz.image(db.denormalize(x), win='sample_x', opts=dict(title='sample_x'))
    loader = DataLoader(db, batch_size=32, shuffle=True, num_workers=8)  # fetch a whole batch at a time instead of one sample
    for x, y in loader:
        viz.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
        viz.text(str(y.numpy()), win='label', opts=dict(title='batch-y'))
        time.sleep(10)
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    def __init__(self, ch_in, ch_out, stride=1):
        '''
        :param ch_in:
        :param ch_out:
        '''
        super(ResBlk, self).__init__()  # call the parent __init__ so nn.Module is set up properly
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.extra = nn.Sequential()
        if ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),  # Sequential holds module instances, so custom modules can be added as well
                nn.BatchNorm2d(ch_out)  # these modules call forward automatically when invoked; remember to return the result
            )

    def forward(self, x):
        '''
        :param x: [b, ch, h, w]
        :return:
        '''
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # shortcut
        # extra module: [b, ch_in, h, w] => [b, ch_out, h, w]
        # element-wise add
        out = self.extra(x) + out
        return out
class ResNet18(nn.Module):
    def __init__(self, num_class):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16)
        )
        # followed by 4 blocks
        # [b,16,h,w] => [b,32,h,w]
        self.blk1 = ResBlk(16, 32, stride=3)  # increase channels while shrinking the spatial size to keep the tensor small
        # [b,32,h,w] => [b,64,h,w]
        self.blk2 = ResBlk(32, 64, stride=3)
        # [b,64,h,w] => [b,128,h,w]
        self.blk3 = ResBlk(64, 128, stride=2)
        # [b,128,h,w] => [b,256,h,w]
        self.blk4 = ResBlk(128, 256, stride=2)
        # [b,256,3,3] for a 224x224 input
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)  # flatten the 256*3*3 features, output num_class logits

    def forward(self, x):
        '''
        :param x:
        :return:
        '''
        x = F.relu(self.conv1(x))
        # [b,16,h,w] => [b,256,h,w]
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # print('after conv:', x.shape)  # [b,256,3,3]
        # # [b,256,h,w] => [b,256,1,1]
        # x = F.adaptive_avg_pool2d(x, [1, 1])
        # # print('after pool:', x.shape)
        x = x.view(x.size(0), -1)
        x = self.outlayer(x)
        return x
if __name__ == '__main__':
    blk = ResBlk(64, 128)
    tmp = torch.randn(2, 64, 224, 224)
    out = blk(tmp)
    print('block', out.shape)

    model = ResNet18(5)  # 5-way classification
    tmp = torch.randn(2, 3, 224, 224)
    out = model(tmp)
    print('resnet:', out.shape)

    p = sum(map(lambda p: p.numel(), model.parameters()))
    print('parameters size:', p)
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pytorch1.pt3 import Pokemon
from resnet import ResNet18
batchsz=32
lr=1e-3
epochs=10
device=torch.device('cuda')
torch.manual_seed(1234)
train_db = Pokemon('pokemon/pokeman', 224, mode='train')  # __init__ builds the training split
val_db=Pokemon('pokemon/pokeman',224,mode='val')
test_db=Pokemon('pokemon/pokeman',224,mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)  # the loader yields decoded image tensors
val_loader=DataLoader(val_db,batch_size=batchsz,num_workers=4)
test_loader=DataLoader(test_db,batch_size=batchsz,num_workers=4)
viz = visdom.Visdom()
def evaluate(model, loader):
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, y).sum().float().item()
    return correct / total
if __name__ == '__main__':
    model = ResNet18(5).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))  # reset both curves
    for epoch in range(epochs):
        for step, (x, y) in enumerate(train_loader):
            # x: [b,3,224,224], y: [b]
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        if epoch % 2 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
                viz.line([val_acc], [global_step], win='val_acc', update='append')
    print('best acc:', best_acc, 'best_epoch:', best_epoch)
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')
    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)