1. Dataset Parsing
# Read and parse the raw CIFAR-10 batch files
import pickle
import numpy as np
import cv2

def unpickle(file):
    with open(file, 'rb') as fo:
        batch_dict = pickle.load(fo, encoding='bytes')
    return batch_dict
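For reference, each batch file unpacks to a dict keyed by byte strings. A quick inspection sketch (the path is an example; adjust it to wherever the dataset was extracted):

batch = unpickle(r"G:\pycharm-work\CIFAR10\data_batch_1")
print(batch.keys())          # dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
print(batch[b'data'].shape)  # (10000, 3072): each row is one 32x32 RGB image, channel-first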
label_name = ["airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"]
import glob
import os

# Parse the training set (data_batch_1 ... data_batch_5); raw strings keep
# backslashes like \t in Windows paths from being read as escape sequences
im_train_list = glob.glob(r"G:\pycharm-work\CIFAR10\data_batch_*")
print(im_train_list)
save_path = r"G:\pycharm-work\CIFAR10\TRAIN"
for l in im_train_list:
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        im_label_name = label_name[im_label]
        # Each row holds 3072 bytes in channel-first order: reshape to CHW, then transpose to HWC
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))
        # CIFAR-10 pixels are RGB, but OpenCV writes BGR, so swap channels before saving
        im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
        # Optional visual check:
        # cv2.imshow("im_data", cv2.resize(im_data, (200, 200)))
        # cv2.waitKey(0)
        os.makedirs("{}/{}".format(save_path, im_label_name), exist_ok=True)
        cv2.imwrite("{}/{}/{}".format(save_path, im_label_name,
                                      im_name.decode("utf-8")), im_data)
# Parse the test set (the file is named test_batch, with no trailing underscore)
im_test_list = glob.glob(r"G:\pycharm-work\CIFAR10\test_batch*")
print(im_test_list)
save_path = r"G:\pycharm-work\CIFAR10\TEST"
for l in im_test_list:
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        im_label_name = label_name[im_label]
        # Same layout as the training batches: CHW bytes -> HWC image
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))
        im_data = cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR)
        # Optional visual check:
        # cv2.imshow("im_data", cv2.resize(im_data, (200, 200)))
        # cv2.waitKey(0)
        os.makedirs("{}/{}".format(save_path, im_label_name), exist_ok=True)
        cv2.imwrite("{}/{}/{}".format(save_path, im_label_name,
                                      im_name.decode("utf-8")), im_data)
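After both loops finish, the images sit under one folder per class. A quick count should match the official CIFAR-10 split of 50000 training and 10000 test images:

print(len(glob.glob(r"G:\pycharm-work\CIFAR10\TRAIN\*\*.png")))  # expected: 50000
print(len(glob.glob(r"G:\pycharm-work\CIFAR10\TEST\*\*.png")))   # expected: 10000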
2. Custom Dataset Loading
import glob
import os
import time
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
label_name = ["airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"]
label_dict = {}
for idx, name in enumerate(label_name):
label_dict[name] = idx
# print(label_dict)
def default_loader(path):
    return Image.open(path).convert("RGB")
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])
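The transforms run in order on the PIL image, and ToTensor finally converts it to a float tensor scaled into [0, 1]. A minimal sanity check, assuming the TRAIN folder from section 1 exists:

sample_path = glob.glob(r"G:\pycharm-work\CIFAR10\TRAIN\*\*.png")[0]
sample = default_loader(sample_path)
augmented = train_transform(sample)
print(augmented.shape)  # torch.Size([3, 28, 28])
print(augmented.min().item(), augmented.max().item())  # values inside [0, 1]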
class MyDataset(Dataset):
    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        for im_item in im_list:
            # The class name is the parent directory, e.g. TRAIN/airplane/xxx.png
            im_label_name = os.path.basename(os.path.dirname(im_item))
            imgs.append([im_item, label_dict[im_label_name]])
        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        im_path, im_label = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_label

    def __len__(self):
        return len(self.imgs)
im_train_list = glob.glob(r"G:\pycharm-work\CIFAR10\TRAIN\*\*.png")
im_test_list = glob.glob(r"G:\pycharm-work\CIFAR10\TEST\*\*.png")
train_data_set = MyDataset(im_train_list, transform=train_transform)
test_data_set = MyDataset(im_test_list, transform=transforms.ToTensor())
# num_workers > 0 loads batches with worker subprocesses, e.g.:
# train_data_loader = DataLoader(dataset=train_data_set, batch_size=128, shuffle=True, num_workers=4)
train_data_loader = DataLoader(dataset=train_data_set, batch_size=128, shuffle=True)
test_data_loader = DataLoader(dataset=test_data_set, batch_size=128, shuffle=False)
print("num_of_train", len(train_data_set))
print("num_of_test", len(test_data_set))
3. Building the Neural Network Model
class VGGBase(nn.Module):
    def __init__(self):
        super(VGGBase, self).__init__()
        # input: 3 x 28 x 28
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # 28 x 28 -> 14 x 14
        self.max_pooling1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        # 14 x 14 -> 7 x 7
        self.max_pooling2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.conv3_2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        # 7 x 7 -> 4 x 4 (padding=1 makes the odd size round up)
        self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.conv4_2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        # 4 x 4 -> 2 x 2
        self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 512 channels * 2 * 2 spatial positions = 2048 features
        self.fc = nn.Linear(512 * 2 * 2, 10)

    def forward(self, x):
        batchsize = x.size(0)
        out = self.conv1(x)
        out = self.max_pooling1(out)
        out = self.conv2_1(out)
        out = self.conv2_2(out)
        out = self.max_pooling2(out)
        out = self.conv3(out)
        out = self.conv3_2(out)
        out = self.max_pooling3(out)
        out = self.conv4(out)
        out = self.conv4_2(out)
        out = self.max_pooling4(out)
        out = out.view(batchsize, -1)
        # Return raw logits: nn.CrossEntropyLoss below applies log-softmax itself,
        # so adding F.log_softmax here would apply softmax twice
        out = self.fc(out)
        return out

def VGGNet():
    return VGGBase()
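Before training, it is worth confirming the 28 -> 14 -> 7 -> 4 -> 2 spatial chain with a dummy input:

dummy = torch.randn(2, 3, 28, 28)  # batch of 2 fake images
print(VGGNet()(dummy).shape)       # torch.Size([2, 10]), one logit per class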
# Select the training device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = VGGNet()
net = net.to(device)
# Loss function (expects raw logits, see the note in forward above)
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# Optimizer
learning_rate = 0.01
# weight_decay adds L2 regularization, momentum accelerates SGD:
# optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
# Learning-rate schedule: multiply the lr by 0.9 after every epoch
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)
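With step_size=1 and gamma=0.9, the learning rate after epoch k is 0.01 * 0.9**k. A quick illustration on a throwaway optimizer, so the real one above stays untouched:

tmp_opt = torch.optim.Adam(VGGNet().parameters(), lr=0.01)
tmp_sched = torch.optim.lr_scheduler.StepLR(tmp_opt, step_size=1, gamma=0.9)
for _ in range(3):
    tmp_sched.step()
    print(tmp_sched.get_last_lr())  # [0.009], then [0.0081], then [0.00729]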
# Set up TensorBoard logging
if not os.path.exists("log"):
    os.mkdir("log")
writer = SummaryWriter("log")
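To watch the loss and accuracy curves and the logged image grids during training, start TensorBoard against the same directory and open the URL it prints:

tensorboard --logdir=log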
# Training bookkeeping
total_train_step = 0  # number of training iterations (used by the commented logging below)
total_test_step = 0   # number of evaluation passes
epoch = 30
batch_size = 128      # matches the DataLoader setting above
start_time = time.time()
step_n = 0
for epochidx in range(epoch):
    print("----- epoch {} training starts ------".format(epochidx + 1))
    # Training phase
    net.train()  # enable training behaviour for BatchNorm and Dropout
    for i, data in enumerate(train_data_loader):
        imgs, target = data
        imgs = imgs.to(device)
        target = target.to(device)
        output = net(imgs)
        loss = loss_fn(output, target)
        # Optimize the model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # total_train_step += 1
        # if total_train_step % 100 == 0:
        #     print("elapsed:", time.time() - start_time)
        #     print("step {}, Loss: {}".format(total_train_step, loss.item()))
        #     writer.add_scalar("train_loss", loss.item(), total_train_step)
        _, pred = torch.max(output.data, dim=1)
        correct = pred.eq(target.data).cpu().sum()
        print("train epoch is", epochidx)
        print("train lr is", optimizer.state_dict()["param_groups"][0]["lr"])
        # Divide by target.size(0) rather than batch_size: the last batch may be smaller
        print("train step", i, "loss is:", loss.item(),
              "mini_batch correct is:", 100.0 * correct / target.size(0))
        # Log to TensorBoard
        writer.add_scalar("train loss", loss.item(), global_step=step_n)
        writer.add_scalar("train correct", 100.0 * correct / target.size(0), global_step=step_n)
        im = torchvision.utils.make_grid(imgs)
        writer.add_image("train img", im, global_step=step_n)
        step_n += 1
    if not os.path.exists("models"):
        os.mkdir("models")
    torch.save(net.state_dict(), "models/{}.pth".format(epochidx + 1))
    # Update the learning rate after each epoch
    scheduler.step()
    # Evaluate the model after each epoch
    sum_loss = 0
    sum_correct = 0
    net.eval()  # switch BatchNorm/Dropout to inference mode
    with torch.no_grad():
        for j, data in enumerate(test_data_loader):
            imgs, target = data
            imgs = imgs.to(device)
            target = target.to(device)
            output = net(imgs)
            loss = loss_fn(output, target)
            _, pred = torch.max(output.data, dim=1)
            correct = pred.eq(target.data).cpu().sum()
            sum_loss += loss.item()
            sum_correct += correct.item()
            im = torchvision.utils.make_grid(imgs)
            writer.add_image("test img", im, global_step=step_n)
    test_loss = sum_loss * 1.0 / len(test_data_loader)
    # Average accuracy over the whole test set, not len(loader) * batch_size
    test_correct = sum_correct * 100.0 / len(test_data_set)
    # Log per epoch, so use epochidx rather than the inner loop counter j
    writer.add_scalar("test loss", test_loss, global_step=epochidx + 1)
    writer.add_scalar("test correct", test_correct, global_step=epochidx + 1)
    print("epoch is", epochidx + 1, "loss is:", test_loss,
          "test correct is:", test_correct)
writer.close()
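Because only state_dicts are saved, loading a checkpoint later needs the model class as well. A minimal inference sketch, assuming training ran for all 30 epochs so models/30.pth exists:

net = VGGNet()
net.load_state_dict(torch.load("models/30.pth", map_location=device))
net = net.to(device)
net.eval()
with torch.no_grad():
    imgs, target = next(iter(test_data_loader))
    pred = net(imgs.to(device)).argmax(dim=1)
    print([label_name[p] for p in pred[:8].cpu().tolist()])  # predicted class names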