在《Pytorch LeNet 2:手写体字符识别实现》中,我们处理了只有一个通道的简单数据集MNIST,准确率达到了96%。
本章节中,我们尝试使用LeNet训练复杂的彩色图片数据,看看LeNet对复杂数据的处理效果。
CIFAR10数据集共有10个类别:“飞机”,“汽车”,“鸟”,“猫”,“鹿”,“狗”,“青蛙”,“马”,“船”,“卡车”。CIFAR-10中的图片数据大小是3x32x32,即三通道彩色图,图片尺寸为32x32像素。
class LeNetTrain:
    """Train a LeNet model on the 'Mnist' or 'Cifar' dataset.

    The class attributes below are defaults; ``__init__`` overrides the
    dataset name, dataset root and image channel count per instance.
    """

    # ---- basic settings (class-level defaults) ----
    dataType = ''  # expected values: 'Mnist' or 'Cifar'
    dataRoot = "../datas"  # root directory the datasets are loaded from
    # Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    dataImgChannel = 3  # number of channels in the input images
    EPOCH = 500  # total number of training epochs

    def __init__(self, dataType, dataRoot, dataImgChannel):
        """Record which dataset to use, where it lives and its channel count."""
        self.dataType, self.dataRoot, self.dataImgChannel = (
            dataType,
            dataRoot,
            dataImgChannel,
        )
#定义数据加载器
def getDataLoader(self):
#use gpu to load and train data
#define transform:
#1:resize MNIST data to 32x32, so adapter to LeNet struct
#2:transform PIL.Image to torch.FloatTensor
resize = 32
normalize = transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
if self.dataImgChannel==1:
normalize = transforms.Normalize((0.5), (0.5))
elif self.dataImgChannel==3:
normalize = transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
transform = transforms.Compose(
[transforms.Resize(size=(resize, resize)),
transforms.ToTensor(),#将图片转换成(C,H, W)的Tensor格式,且/255归一化到[0,1.0]之间
normalize])#对每个通道通过(image-mean)/std将数据转换到[-1,1]之间
if dataType=='Mnist':
train_data = torchvision.datasets.MNIST(root=self.dataRoot, #data dir
train=True, #it is train data
transform=transform, #use defined transform
download=True) #use local data
test_data = torchvision.datasets.MNIST(root=self.dataRoot,
train=False,
transform=transform,
download=True)
elif dataType=='Cifar':
train_data = torchvision.datasets.CIFAR10(root='../datas/cifar', train=True, download=True, transform=transform)
test_data = torchvision.datasets.CIFAR10(root='../datas/cifar', train=False, download=True, transform=transform)
else:
print('dataType not in[Mnist,Cifar]')
return
#define DataLoder,and shuffle data
train_loader = torch.utils.data.DataLoader(dataset = train_data,batch_size =320 ,shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = test_data,batch_size = 10000,shuffle = False)
return train_loader,test_loader
我们在LeNetTrain的基础上进行扩展,使LeNet能够兼容3通道彩色图片的训练。
#初始化网络
def getLenet(self):
#getdataload
#use already define Lenet
net = LeNet2.Net(self.dataImgChannel).to(self.device)
loss_fuc = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(),lr = 0.01,weight_decay = 0.005)
return net,loss_fuc,optimizer
LeNet会根据传入的图片通道数self.dataImgChannel,初始化对应输入通道数的C1卷积层。
#开始训练
def train_net(self,net,loss_fuc,optimizer):
#getdataloader
train_loader,test_loader= self.getDataLoader()
#Star train
adjust_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)#定义学习率衰减函数
print('start train with epoch:',self.EPOCH)
iteration = 0
for epoch in range(self.EPOCH):
sum_loss = 0
#数据读取
for i,data in enumerate(train_loader):
inputs,labels = data
#有GPU则将数据置入GPU加速
inputs, labels = inputs.to(self.device), labels.to(self.device)
# 梯度清零
optimizer.zero_grad()
# 传递损失 + 更新参数
output = net(inputs)
loss = loss_fuc(output,labels)
loss.backward()
optimizer.step()
iteration = iteration+1
# print loss every 100 iteration
sum_loss += loss.item()
if True:#每次epochs,测试一次测试数据
lr = optimizer.param_groups[0]["lr"]#get current lr
#print(lr)
print('###iteration[:%d],[Epoch:%d],[Lr:%.08f] train loss: %.03f' % (iteration,epoch + 1, lr, sum_loss / 100))
#用网络测试验证数据
self.test_net(self.device,test_loader,net)
#保存模型
sum_loss = 0.0
adjust_lr_scheduler.step()#更新学习率
#保存模型
self.save_model(net,optimizer,epoch+1)
训练结果为:
start train with epoch: 500
###iteration[:157],[Epoch:1],[Lr:0.01000000] train loss: 2.790
test data avg accuracy:42%
save_modelt:
save net: ../models/lenet_Cifar1.pth
###iteration[:314],[Epoch:2],[Lr:0.01000000] train loss: 2.300
test data avg accuracy:48%
save_modelt:
save net: ../models/lenet_Cifar2.pth
###iteration[:471],[Epoch:3],[Lr:0.01000000] train loss: 2.179
test data avg accuracy:49%
save_modelt:
...
...
###iteration[:76145],[Epoch:485],[Lr:0.00001000] train loss: 1.383
test data avg accuracy:67%
save_modelt:
save net: ../models/lenet_Cifar485.pth
###iteration[:76302],[Epoch:486],[Lr:0.00001000] train loss: 1.382
test data avg accuracy:67%
...
save_modelt:
save net: ../models/lenet_Cifar492.pth
###iteration[:77401],[Epoch:493],[Lr:0.00001000] train loss: 1.382
test data avg accuracy:67%
如上结果所示,LeNet对复杂彩色图片的训练和测试效果并不理想,准确率只有67%。如果想要进一步提高准确率,我们需要使用更复杂的网络。
LeNet/LeNetTrain2.py