PyTorch LeNet 4: Training and Validating on the CIFAR Dataset

In "PyTorch LeNet 2: Handwritten Character Recognition" we worked with MNIST, a simple dataset with a single channel, and reached 96% accuracy.

In this part, we train LeNet on more complex color data to see how well it handles it.

The CIFAR-10 dataset

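The CIFAR-10 dataset has ten classes: "airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", and "truck". Each image in CIFAR-10 has shape 3x32x32, that is, a three-channel color image of 32x32 pixels.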
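To make the label layout and image size concrete, here is a small sketch in plain Python. The class names are listed in CIFAR-10's label order (0 = airplane through 9 = truck); the names CIFAR10_CLASSES and values_per_image are hypothetical helpers introduced only for this illustration.

```python
# Class names in label-index order (0..9) for CIFAR-10.
CIFAR10_CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")


def values_per_image(channels=3, height=32, width=32):
    """Number of scalar values in one image of shape (channels, height, width)."""
    return channels * height * width


print(len(CIFAR10_CLASSES))          # number of classes
print(CIFAR10_CLASSES[3])            # name for label index 3
print(values_per_image())            # one CIFAR-10 image: 3 * 32 * 32
print(values_per_image(1, 28, 28))   # one raw MNIST image, for comparison
```

A CIFAR-10 image carries roughly four times as many input values as a raw MNIST image, which is one reason it is a harder task for the same network.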


Loading and normalizing the data

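import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim import lr_scheduler

import LeNet2  # module providing LeNet2.Net, used below

class LeNetTrain:
    # Basic attributes
    dataType = ''  # 'Mnist' or 'Cifar'
    dataRoot = "../datas"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataImgChannel = 3
    EPOCH = 500  # total number of training epochs

    # Constructor
    def __init__(self, dataType, dataRoot, dataImgChannel):
        self.dataType = dataType
        self.dataRoot = dataRoot
        self.dataImgChannel = dataImgChannel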

    # Build the data loaders
    def getDataLoader(self):
        # Define the transform:
        # 1: resize images to 32x32 to match the LeNet input size (MNIST images are 28x28)
        # 2: convert PIL.Image to torch.FloatTensor
        # 3: normalize each channel
        resize = 32
        if self.dataImgChannel == 1:
            normalize = transforms.Normalize((0.5,), (0.5,))
        else:
            normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        transform = transforms.Compose(
            [transforms.Resize(size=(resize, resize)),
             transforms.ToTensor(),  # convert the image to a (C, H, W) tensor and divide by 255 to scale into [0, 1.0]
             normalize])  # apply (image - mean) / std per channel to map the data into [-1, 1]
        if self.dataType == 'Mnist':
            train_data = torchvision.datasets.MNIST(root=self.dataRoot,   # data directory
                                                    train=True,           # training split
                                                    transform=transform,  # use the transform defined above
                                                    download=True)        # download only if not already in root
            test_data = torchvision.datasets.MNIST(root=self.dataRoot,
                                                   train=False,
                                                   transform=transform,
                                                   download=True)
        elif self.dataType == 'Cifar':
            train_data = torchvision.datasets.CIFAR10(root='../datas/cifar', train=True, download=True, transform=transform)
            test_data = torchvision.datasets.CIFAR10(root='../datas/cifar', train=False, download=True, transform=transform)
        else:
            raise ValueError("dataType must be 'Mnist' or 'Cifar'")

        # Define the DataLoaders; only the training data is shuffled
        train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=320, shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=10000, shuffle=False)

        return train_loader, test_loader

Building on the existing LeNetTrain class, we make LeNet training compatible with 3-channel images.
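The comments on the transform say that ToTensor scales pixels into [0, 1] and that Normalize with mean 0.5 and std 0.5 then maps them into [-1, 1]. The arithmetic can be checked with a minimal plain-Python sketch; to_unit_range and normalize are hypothetical stand-ins that mimic the two transforms for a single pixel value, not the torchvision implementations.

```python
def to_unit_range(pixel):
    """Mimic ToTensor's scaling: 8-bit pixel value 0..255 -> float in [0, 1]."""
    return pixel / 255.0


def normalize(value, mean=0.5, std=0.5):
    """Mimic Normalize for one channel: (value - mean) / std."""
    return (value - mean) / std


# Darkest, mid-gray, and brightest pixel values.
for pixel in (0, 128, 255):
    print(pixel, normalize(to_unit_range(pixel)))
```

The darkest pixel lands on -1.0 and the brightest on 1.0, so every channel ends up centered around zero regardless of whether the image has one channel or three.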

Initializing the network and defining the loss function and optimizer

    # Initialize the network
    def getLenet(self):
        # Use the LeNet defined earlier, built for the given number of input channels
        net = LeNet2.Net(self.dataImgChannel).to(self.device)

        loss_fuc = nn.CrossEntropyLoss()
        optimizer = optim.Adam(net.parameters(), lr=0.01, weight_decay=0.005)
        return net, loss_fuc, optimizer

LeNet initializes its C1 convolutional layer differently depending on the number of image channels passed in through self.dataImgChannel.
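To see what changes in C1 when the input goes from one channel to three, the sketch below counts the layer's parameters. It assumes the classic LeNet-5 shape for C1 (6 filters of size 5x5); the actual definition of LeNet2.Net is not shown in this article, so treat those numbers as an assumption. conv_param_count is a hypothetical helper.

```python
def conv_param_count(in_channels, out_channels, kernel_size):
    """Weights plus biases of a 2-D convolution with a square kernel."""
    weights = out_channels * in_channels * kernel_size * kernel_size
    biases = out_channels
    return weights + biases


# Assumed C1 shape: 6 filters of 5x5 (classic LeNet-5, not confirmed for LeNet2.Net).
for channels in (1, 3):
    print(channels, conv_param_count(channels, 6, 5))
```

Only the first layer depends on the input channel count. Under this assumption the rest of the network is unchanged, because C1 still produces 6 feature maps either way.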

Training the network

    # Start training
    def train_net(self, net, loss_fuc, optimizer):
        # Get the data loaders
        train_loader, test_loader = self.getDataLoader()

        # Learning-rate decay: multiply the lr by 0.1 every 150 epochs
        adjust_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=150, gamma=0.1)
        print('start train with epoch:', self.EPOCH)
        iteration = 0
        for epoch in range(self.EPOCH):
            sum_loss = 0.0
            # Read the data batch by batch
            for i, data in enumerate(train_loader):
                inputs, labels = data
                # Move the data to the GPU if one is available
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # Zero the gradients
                optimizer.zero_grad()

                # Forward pass, backpropagate the loss, update the parameters
                output = net(inputs)
                loss = loss_fuc(output, labels)
                loss.backward()
                optimizer.step()

                iteration = iteration + 1

                # Accumulate the loss over the epoch
                sum_loss += loss.item()

            # After every epoch, report progress and evaluate on the test data
            lr = optimizer.param_groups[0]["lr"]  # current learning rate
            # Note: the epoch's summed loss is divided by a fixed 100, not by the number
            # of batches, so the printed value is a scaled sum rather than the mean batch loss
            print('###iteration[:%d],[Epoch:%d],[Lr:%.08f] train loss: %.03f' % (iteration, epoch + 1, lr, sum_loss / 100))
            # Evaluate the network on the test data
            self.test_net(self.device, test_loader, net)
            adjust_lr_scheduler.step()  # update the learning rate
            # Save the model
            self.save_model(net, optimizer, epoch + 1)
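The StepLR scheduler above uses step_size=150 and gamma=0.1, so the learning rate drops by a factor of ten every 150 epochs. A minimal sketch of that schedule in plain Python, using the closed form base_lr * gamma ** (steps // step_size); step_lr is a hypothetical helper written for this illustration, not the PyTorch implementation.

```python
def step_lr(base_lr, scheduler_steps, step_size=150, gamma=0.1):
    """Learning rate after `scheduler_steps` calls to scheduler.step()."""
    return base_lr * gamma ** (scheduler_steps // step_size)


for printed_epoch in (1, 150, 151, 301, 451, 485):
    # The loop prints epoch + 1 and reads the lr before calling scheduler.step(),
    # so printed epoch N corresponds to N - 1 completed scheduler steps.
    print(printed_epoch, "%.8f" % step_lr(0.01, printed_epoch - 1))
```

With 500 epochs this gives four plateaus: 0.01 for epochs 1 to 150, 0.001 for 151 to 300, 0.0001 for 301 to 450, and 0.00001 from epoch 451 on.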

The training output is:


start train with epoch: 500
###iteration[:157],[Epoch:1],[Lr:0.01000000] train loss: 2.790
test data avg accuracy:42%
save_modelt:
save net: ../models/lenet_Cifar1.pth
###iteration[:314],[Epoch:2],[Lr:0.01000000] train loss: 2.300
test data avg accuracy:48%
save_modelt:
save net: ../models/lenet_Cifar2.pth
###iteration[:471],[Epoch:3],[Lr:0.01000000] train loss: 2.179
test data avg accuracy:49%
save_modelt:
...
...
###iteration[:76145],[Epoch:485],[Lr:0.00001000] train loss: 1.383
test data avg accuracy:67%
save_modelt:
save net: ../models/lenet_Cifar485.pth
###iteration[:76302],[Epoch:486],[Lr:0.00001000] train loss: 1.382
test data avg accuracy:67%
...
save_modelt:
save net: ../models/lenet_Cifar492.pth
###iteration[:77401],[Epoch:493],[Lr:0.00001000] train loss: 1.382
test data avg accuracy:67%
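The iteration counter and the loss values in this log can be cross-checked with a little arithmetic. The sketch below assumes the log was produced by the training loop shown earlier (batch size 320, loss printed as sum_loss / 100) and that the CIFAR-10 training split has 50,000 images; mean_batch_loss is a hypothetical helper.

```python
import math

TRAIN_IMAGES = 50000   # size of the CIFAR-10 training split
BATCH_SIZE = 320       # batch_size of train_loader

# The last, partial batch is kept by default, hence the ceiling.
batches_per_epoch = math.ceil(TRAIN_IMAGES / BATCH_SIZE)
print(batches_per_epoch)          # iterations in one epoch
print(batches_per_epoch * 485)    # cumulative iterations after epoch 485


def mean_batch_loss(printed_loss, batches, divisor=100):
    """Undo the fixed divisor and divide by the real number of batches."""
    return printed_loss * divisor / batches


print(round(mean_batch_loss(2.790, batches_per_epoch), 3))  # epoch 1
print(round(mean_batch_loss(1.382, batches_per_epoch), 3))  # epoch 493
```

The counts line up with the log (157 iterations after epoch 1, 76145 after epoch 485). The rescaled values suggest that the true mean batch loss is lower than the printed "train loss", since the printed figure divides by 100 instead of 157.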

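As the output above shows, LeNet's training and test results on complex color images are not ideal: test accuracy reaches only 67%. To improve accuracy further, we need a more complex network.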

Code

LeNet/LeNetTrain2.py
