A Code Implementation of CIFAR-10 Image Classification with PyTorch

While learning deep-learning image classification, I followed my instructor's explanation and implemented a CIFAR-10 image classifier. This post walks through the code.

Dataset: available from the netdisk link; the downloaded batch files need to be decoded into images.
The dataset contains 50,000 training images and 10,000 test images.
After downloading, create two new folders next to the batch files, one named Train and one named Test, to hold the decoded images, as shown below:
(Figure 1: the dataset directory with the new Train and Test folders)
The decoded training set:
(Figures 2 and 3: the decoded training images)

The decoding routine for the batch files is provided officially with the dataset:

def unpickle(file):
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding="bytes")
    return dict
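
Before writing the full decoding script, it helps to look at what one batch actually contains. A minimal sketch of how one might inspect a batch (the path below is an assumption; point it at your own data_batch_1):

import pickle

def unpickle(file):
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding="bytes")

# assumed path -- adjust to wherever cifar-10-batches-py was extracted
batch = unpickle(r"D:\genglijia\cifar-10-python\cifar-10-batches-py\data_batch_1")
print(batch.keys())           # dict_keys([b'batch_label', b'labels', b'data', b'filenames'])
print(batch[b'data'].shape)   # (10000, 3072): 10000 images, each 32*32*3 = 3072 bytes
print(batch[b'labels'][:10])  # integer labels in the range 0-9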

The full .py script that decodes the training set and saves the images into the Train folder:

import os
import pickle
import glob  # for matching the batch files by pattern
import cv2
import numpy as np

def unpickle(file):  # the decoding routine provided with the dataset
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding="bytes")
    return dict  # returns a dictionary

lable_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]


train_list = glob.glob(r"D:\*\cifar-10-python\cifar-10-batches-py\data_batch_*")  # the five training batches (* is a glob wildcard)
print(train_list)
# where the decoded images are saved; unlike glob, a wildcard cannot be used when
# writing files, so spell out the real directory name here
save_path = r"D:\genglijia\cifar-10-python\cifar-10-batches-py\Train"

for l in train_list:  # read and decode each batch file
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())

    # enumerate() pairs each element of an iterable with its index, so im_idx is the
    # image index and im_data the raw 3072-byte vector of one image
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_lable = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]

        im_lable_name = lable_name[im_lable]
        # each row stores the full 32x32 red plane, then green, then blue,
        # so reshape to [3, 32, 32] first ...
        im_data = np.reshape(im_data, [3, 32, 32])
        # ... then move the channel axis last to get a 32x32x3 image for OpenCV
        im_data = np.transpose(im_data, (1, 2, 0))

        # create one sub-folder per class under Train, then write the image into it
        if not os.path.exists("{}/{}".format(save_path, im_lable_name)):
            os.mkdir("{}/{}".format(save_path, im_lable_name))
        # note: cv2.imwrite interprets the array as BGR; saving the RGB data directly
        # swaps the red and blue channels (use cv2.cvtColor if that matters to you)
        cv2.imwrite("{}/{}/{}".format(save_path, im_lable_name, im_name.decode("utf-8")),
                    im_data)
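
To check that the decode produced what we expect, you can count the saved files per class; a small sketch (the Train path is assumed to match the save_path above), where each of the 10 folders should end up with 5,000 images:

import glob
import os

train_root = r"D:\genglijia\cifar-10-python\cifar-10-batches-py\Train"  # assumed: same as save_path above
for class_dir in sorted(glob.glob(os.path.join(train_root, "*"))):
    print(os.path.basename(class_dir), len(glob.glob(os.path.join(class_dir, "*.png"))))
# expected: 10 class folders with 5000 images each, 50000 in total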

The full .py script that decodes the test set and saves the images into the Test folder:

import os
import pickle
import glob  # for matching files by pattern
import cv2
import numpy as np

def unpickle(file):  # the decoding routine provided with the dataset
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding="bytes")
    return dict  # returns a dictionary

lable_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]


test_list = glob.glob(r"D:\genglijia\cifar-10-python\cifar-10-batches-py\test_batch")  # the single test batch
print(test_list)
save_path = r"D:\genglijia\cifar-10-python\cifar-10-batches-py\Test"  # where the decoded test images are saved

for l in test_list:  # read and decode the batch file
    print(l)
    l_dict = unpickle(l)
    print(l_dict.keys())

    # enumerate() pairs each element of an iterable with its index, so im_idx is the
    # image index and im_data the raw 3072-byte vector of one image
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_lable = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]

        im_lable_name = lable_name[im_lable]
        # each row stores the full 32x32 red plane, then green, then blue,
        # so reshape to [3, 32, 32] first ...
        im_data = np.reshape(im_data, [3, 32, 32])
        # ... then move the channel axis last to get a 32x32x3 image for OpenCV
        im_data = np.transpose(im_data, (1, 2, 0))

        # create one sub-folder per class under Test, then write the image into it
        if not os.path.exists("{}/{}".format(save_path, im_lable_name)):
            os.mkdir("{}/{}".format(save_path, im_lable_name))
        cv2.imwrite("{}/{}/{}".format(save_path,
                                      im_lable_name,
                                      im_name.decode("utf-8")),
                    im_data)
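
If the reshape/transpose step still looks opaque: each 3072-value row stores the whole red plane first, then the green plane, then the blue plane, each 32x32 block row by row. A tiny synthetic example makes that visible:

import numpy as np

row = np.arange(3072)                # a fake "image" whose values are simply their own index in the row

chw = np.reshape(row, [3, 32, 32])   # channel-first: chw[0] = red plane, chw[1] = green, chw[2] = blue
hwc = np.transpose(chw, (1, 2, 0))   # channel-last (height, width, channel), the layout OpenCV expects

print(chw[0, 0, :5])                 # [0 1 2 3 4]       -> the red plane occupies indices 0..1023
print(hwc[0, 0])                     # [   0 1024 2048]  -> the top-left pixel takes R, G, B from the three blocks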

The .py file that loads the local dataset (load_cifar10.py, imported by the training script below):

import glob
from torchvision import transforms                 # transforms is used for the data augmentation below
from torch.utils.data import DataLoader, Dataset   # classes for reading and batching the data
import os
from PIL import Image                               # used to read the image files, similar in spirit to OpenCV
import numpy as np

lable_name = ["airplane", "automobile", "bird",
              "cat", "deer", "dog", "frog",
              "horse", "ship", "truck"]

lable_dict = {}

# map every string label to an integer index
for idx, name in enumerate(lable_name):
    lable_dict[name] = idx
print(lable_dict)

def default_loader(path):
    return Image.open(path).convert("RGB")
    
# data augmentation and normalization for the training set
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])


test_transform = transforms.Compose([
    transforms.CenterCrop((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])
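
The mean (0.4914, 0.4822, 0.4465) and std (0.2023, 0.1994, 0.2010) used above are the CIFAR-10 training-set statistics commonly copied between tutorials. If you want to reproduce them yourself, here is a minimal sketch computed straight from the original batch files (paths are assumptions); note that the per-pixel std computed this way comes out around (0.247, 0.243, 0.262), slightly different from the widely quoted values, and either set works in practice:

import glob
import pickle
import numpy as np

def unpickle(file):
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding="bytes")

# assumed location of the original training batches
batches = glob.glob(r"D:\*\cifar-10-python\cifar-10-batches-py\data_batch_*")
data = np.concatenate([unpickle(b)[b'data'] for b in batches])   # (50000, 3072) uint8
data = data.reshape(-1, 3, 32, 32).astype(np.float32) / 255.0    # scale to [0, 1] like ToTensor does

print(data.mean(axis=(0, 2, 3)))   # roughly [0.4914 0.4822 0.4465]
print(data.std(axis=(0, 2, 3)))    # roughly [0.2470 0.2435 0.2616]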


class MyDataset(Dataset):
    # im_list is the list of image paths found under the Train / Test folder
    # (the paths could also be stored in a text file instead)
    def __init__(self, im_list, transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []
        for im_item in im_list:
            # a path looks like ...\cifar-10-batches-py\Test\airplane\xxx.png,
            # so the second-to-last path component is the class name
            im_lable_name = im_item.split("\\")[-2]
            imgs.append([im_item, lable_dict[im_lable_name]])

        self.imgs = imgs            # list of [path, label] pairs
        self.transform = transform  # the augmentation / preprocessing pipeline
        self.loader = loader        # function that reads one image from disk

    # return one (image, label) sample
    def __getitem__(self, index):
        im_path, im_lable = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_lable

    # number of samples
    def __len__(self):
        return len(self.imgs)


im_train_list = glob.glob(r"D:\*\cifar-10-python\cifar-10-batches-py\Train\*\*.png")
im_test_list = glob.glob(r"D:\*\cifar-10-python\cifar-10-batches-py\Test\*\*.png")

train_dataset = MyDataset(im_train_list, transform=train_transform)
# use test_transform here so the test images get the same normalization as the training images
test_dataset = MyDataset(im_test_list, transform=test_transform)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=64,
                          shuffle=True,
                          num_workers=0)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=64,
                         shuffle=False,
                         num_workers=0)  # num_workers > 0 loads batches in several worker processes

print("num_of_train", len(train_dataset))
print("num_of_test", len(test_dataset))

The .py file that defines the network, resnet.py (a classic ResNet-style residual network is used here; other architectures such as VGGNet or MobileNet would work as well):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, stride=1):
        super(ResBlock, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel,
                      kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel),
        )
        # identity shortcut by default; when the channel count or spatial size changes,
        # project the input with a strided convolution instead
        # (the classic ResNet uses a 1x1 convolution for this projection; 3x3 also works)
        self.shortcut = nn.Sequential()
        if in_channel != out_channel or stride > 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel,
                          kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channel),
            )

    def forward(self, x):
        out1 = self.layer(x)
        out2 = self.shortcut(x)
        out = out1 + out2
        out = F.relu(out)
        return out



class ResNet(nn.Module):

    # stack num_block residual blocks; only the first block of a stage uses the given
    # (possibly downsampling) stride, the rest use stride 1
    def make_layer(self, block, out_channel, stride, num_block):
        layers_list = []
        for i in range(num_block):
            if i == 0:
                in_stride = stride
            else:
                in_stride = 1
            layers_list.append(block(self.in_channel,out_channel, in_stride))
            self.in_channel = out_channel
        return nn.Sequential(*layers_list)

    def __init__(self, ResBlock):
        super(ResNet, self).__init__()
        self.in_channel = 32
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.layer1 = self.make_layer(ResBlock, 64, 2, 2)    # 32x32 -> 16x16
        self.layer2 = self.make_layer(ResBlock, 128, 2, 2)   # 16x16 -> 8x8
        self.layer3 = self.make_layer(ResBlock, 256, 2, 2)   # 8x8 -> 4x4
        self.layer4 = self.make_layer(ResBlock, 512, 2, 2)   # 4x4 -> 2x2

        self.fc = nn.Linear(512, 10)


    def forward(self,x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 2)        # 2x2 -> 1x1 average pool
        out = out.view(out.size(0), -1)   # flatten to (batch, 512)
        out = self.fc(out)
        return out

def resnet():
    return ResNet(ResBlock)
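
A quick shape check of the network with a dummy batch (not from the original post; it assumes the file above is saved as resnet.py, as the training script imports it):

import torch
from resnet import resnet

net = resnet()
dummy = torch.randn(4, 3, 32, 32)                 # a fake batch of four CIFAR-sized images
print(net(dummy).shape)                           # torch.Size([4, 10]) -- one logit per class
print(sum(p.numel() for p in net.parameters()))   # total number of trainable parameters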

The .py file for training and testing:

import torch
import torch.nn as nn
import torchvision
from resnet import resnet
from load_cifar10 import train_loader,test_loader
import os
# check whether a GPU is available
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epoch_num = 1
lr = 0.01
batch_size = 128  # note: the DataLoaders in load_cifar10.py are built with batch_size=64
net = resnet()  # append .to(device) here if a GPU is available

#loss
loss_func = nn.CrossEntropyLoss()

#optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
#optimizer = torch.optim.SGD(net.parameters(), lr=lr,
#                            momentum=0.9, weight_decay=5e-4)

# learning-rate schedule: StepLR multiplies the learning rate by gamma after every
# step_size epochs; with step_size=1 and gamma=0.9 it shrinks by 10% each epoch.
# A fixed learning rate would also work.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

if __name__ == '__main__':  # this guard is needed on Windows because the DataLoader may spawn worker processes
    for epoch in range(epoch_num):
        net.train()  # switch to training mode: BatchNorm uses batch statistics and
                     # Dropout (if any) is active; call net.eval() during evaluation instead
        for i, data in enumerate(train_loader):
            inputs, labels = data
            # inputs, labels = inputs.to(device), labels.to(device)  # move the batch to the GPU for training

            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_size = inputs.size(0)
            _, pred = torch.max(outputs.data, dim=1)
            correct = pred.eq(labels.data).sum()
            print("train step", i, "loss is:", loss.item(), "mini-batch correct is:", 1.0 * correct / batch_size)

        # save the model after every epoch
        if not os.path.exists("models"):
            os.mkdir("models")
        torch.save(net.state_dict(), "models/{}.pth".format(epoch + 1))
        scheduler.step()  # update the learning rate
        
        # evaluate on the test set after each epoch
        net.eval()  # switch BatchNorm (and Dropout) to evaluation mode
        sum_loss = 0
        sum_correct = 0
        sum_total = 0

        with torch.no_grad():  # no gradients / backward pass on the test set
            for i, data in enumerate(test_loader):
                inputs, labels = data
                # inputs, labels = inputs.to(device), labels.to(device)

                outputs = net(inputs)
                loss = loss_func(outputs, labels)
                _, pred = torch.max(outputs.data, dim=1)
                correct = pred.eq(labels.data).sum()

                sum_loss += loss.item()
                sum_correct += correct.item()
                sum_total += labels.size(0)

                im = torchvision.utils.make_grid(inputs)  # image grid of the batch, kept for optional visualization

        test_loss = sum_loss * 1.0 / len(test_loader)
        test_correct = sum_correct * 1.0 / sum_total  # divide by the actual number of test samples

        print("epoch", epoch + 1, "test loss is:", test_loss, "test accuracy is:", test_correct)


The test results:
(Figure 4: screenshot of the console output from the test run)

Here epoch_num is 1: without a GPU training is slow, so I only ran a single epoch, and the accuracy already reached about 70%. Training for more epochs (or changing the network architecture; there are plenty of options) should push it higher.

Once everything has run, you can open the Test folder and see the images already sorted into their class sub-folders.

**If you have any questions, feel free to leave a comment below~** Let's keep learning together!
