【代码学习】读取和训练cifar10

代码参考【从入门到进阶】《PyTorch深度学习实践》P61-73

目录

        • 一、处理数据
        • 二、训练
        • 三、网络模型
        • 四、测试

一、处理数据

# readcifar10.py
import pickle

# Data-loading helper given on the CIFAR-10 website.
def unpickle(file):
    """Deserialize one CIFAR-10 batch file.

    Args:
        file: path to a pickled batch file (e.g. ``data_batch_1``).

    Returns:
        dict with bytes keys (b'data', b'labels', b'filenames', ...).
    """
    # NOTE: pickle.load can execute arbitrary code; only use this on the
    # official CIFAR-10 archive, never on untrusted files.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')  # renamed: don't shadow builtin `dict`
    return batch

label_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]

import glob
import numpy as np
import os
import cv2

# Locate the raw batch files and the output folders for the extracted images.
train_list = glob.glob("./dataset/cifar-10/data_batch_*")
test_list = glob.glob("./dataset/cifar-10/test_batch*")
train_save_path = "./dataset/cifar-10/train"
test_save_path = "./dataset/cifar-10/test"

# Extract the training batches and save each image under its class folder.
for l in train_list:
    l_dict = unpickle(l)
    # The batch dict uses bytes keys (b'...') because it was pickled in Python 2.
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]

        im_label_name = label_name[im_label]
        # Each row is a flat 3072-vector: 3 channel planes (R, G, B) of 32x32.
        # Reshape to CHW, then transpose to HWC for OpenCV.
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))

        save_dir = "{}/{}".format(train_save_path, im_label_name)
        if not os.path.exists(save_dir):
            print("not exist")  # BUG FIX: message previously read "not exit"
            # BUG FIX: makedirs also creates missing parent directories;
            # os.mkdir raised FileNotFoundError when ./train did not exist yet.
            os.makedirs(save_dir)

        # im_name is bytes; decode to str to build the file path.
        # BUG FIX: the planes are RGB but cv2.imwrite expects BGR, so convert
        # before writing to keep the saved colours correct.
        cv2.imwrite("{}/{}".format(save_dir, im_name.decode("utf-8")),
                    cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR))

# Extract the test batch the same way as the training batches.
for l in test_list:
    l_dict = unpickle(l)
    # bytes keys, as in the training batches
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]

        im_label_name = label_name[im_label]
        # flat RGB planes -> CHW -> HWC for OpenCV
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))

        save_dir = "{}/{}".format(test_save_path, im_label_name)
        if not os.path.exists(save_dir):
            # BUG FIX: makedirs also creates missing parents (mkdir would fail)
            os.makedirs(save_dir)

        # im_name is bytes; decode to str for the path.
        # BUG FIX: convert RGB -> BGR so cv2.imwrite stores correct colours.
        cv2.imwrite("{}/{}".format(save_dir, im_name.decode("utf-8")),
                    cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR))
# loadcifar10.py
from cProfile import label
from matplotlib.cbook import print_cycles
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob

# Class names in label-id order (0 = airplane ... 9 = truck).
label_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]

# Reverse mapping, class name -> integer label id.
# (Idiom: dict comprehension replaces the manual enumerate loop.)
label_dict = {name: idx for idx, name in enumerate(label_name)}

def default_loader(path):
    """Open the image at *path* and return it as an RGB PIL image."""
    image = Image.open(path)
    return image.convert("RGB")

# Training-time augmentation pipeline: random crop / flips / rotation /
# colour jitter, then conversion to a float tensor in [0, 1].
# NOTE(review): vertical flips and 90-degree rotations are unusually strong
# augmentation for CIFAR-10 — confirm they are intended.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])

# Test-time preprocessing: deterministic resize only, then tensor conversion.
test_transform = transforms.Compose([
    transforms.Resize((28,28)),
    transforms.ToTensor()
])

class MyDataset(Dataset):
    """Dataset over image paths laid out as .../<class_name>/<file>.png.

    Each item is (image, int_label); the label is derived from the name of
    the image's parent directory via the module-level ``label_dict``.
    """

    def __init__(self, im_list, transform=None,
                 loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []

        for im_item in im_list:
            # The parent directory name is the class label. Use os.path so
            # this also works with backslash paths returned by glob on
            # Windows (BUG FIX: the original split("/") did not).
            im_label_name = os.path.basename(os.path.dirname(im_item))
            imgs.append([im_item, label_dict[im_label_name]])

        self.imgs = imgs            # list of [path, int label]
        self.transform = transform  # optional torchvision transform
        self.loader = loader        # callable: path -> PIL image

    def __getitem__(self, index):
        """Load, optionally transform, and return (image, label)."""
        im_path, im_label = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)
        return im_data, im_label

    def __len__(self):
        return len(self.imgs)

# Index the per-class PNG files produced by readcifar10.py.
im_train_list = glob.glob("./dataset/cifar-10/train/*/*.png")
im_test_list = glob.glob("./dataset/cifar-10/test/*/*.png")

train_dataset = MyDataset(im_train_list, transform=train_transform)
test_dataset = MyDataset(im_test_list, transform=test_transform)

# NOTE(review): these loaders are created at import time and num_workers=4
# spawns worker processes — on Windows the importing script must be guarded
# by `if __name__ == "__main__":`; confirm the intended platform.
train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=6,
                               shuffle=True,
                               num_workers=4)

test_data_loader = DataLoader(dataset=test_dataset,
                               batch_size=6,
                               shuffle=False,
                               num_workers=4)

# print("num of train", len(train_dataset))
# print("num of test", len(test_dataset))

二、训练

# train.py
from cProfile import label
import torch
import torch.nn as nn
import torch.nn.functional as F
from zmq import device
from vggnet import VGGNet
from resnet import resnet
from mobilenetv1 import MobileNetv1_small
from inceptionModule import InceptionNet_small
from loadcifar10 import train_data_loader, test_data_loader
import os
import tensorboardX
import torchvision

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

epoch_num = 2   # number of training epochs
lr = 0.01       # initial learning rate
batch_size = train_data_loader.batch_size

# Swap in one of the model definitions below to train a different network.
# net = VGGNet().to(device)
# net = resnet().to(device)
# net = MobileNetv1_small().to(device)
net = InceptionNet_small().to(device)

# loss
loss_func = nn.CrossEntropyLoss()

# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr = lr)
# optimizer = torch.optim.SGD(net.parameters(), lr = lr,
#                             momentum=0.9, weight_decay=5e-4)
# Step decay: with step_size=1 the learning rate is multiplied by gamma=0.9
# after every epoch. (The original comment claimed "every 5 epochs", which
# does not match step_size=1.)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

if not os.path.exists("Logs"):
    os.mkdir("Logs")
writer = tensorboardX.SummaryWriter("Logs")  # TensorBoard log writer
step_n = 0  # global training-step counter used as the TensorBoard x-axis

for epoch in range(epoch_num):
    # ---------------- training ----------------
    net.train()  # enable train-mode behaviour for BatchNorm / Dropout
    for i, data in enumerate(train_data_loader):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # predicted class = arg-max over the 10 output scores
        _, pred = torch.max(outputs.data, dim=1)
        correct = pred.eq(labels.data).cpu().sum()

        writer.add_scalar("train loss", loss.item(), global_step=step_n)
        writer.add_scalar("train correct",
                          100.0 * correct.item() / batch_size, global_step=step_n)
        im = torchvision.utils.make_grid(inputs)
        writer.add_image("train image", im, global_step=step_n)

        step_n += 1

    if not os.path.exists("models"):
        os.mkdir("models")
    torch.save(net.state_dict(), "models/{}.pth".format(epoch + 1))
    scheduler.step()  # decay the learning rate once per epoch

    # ---------------- evaluation ----------------
    net.eval()  # hoisted: the original redundantly set eval() every batch
    sum_loss = 0
    sum_correct = 0
    with torch.no_grad():  # no gradients needed during evaluation
        for i, data in enumerate(test_data_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = net(inputs)
            loss = loss_func(outputs, labels)

            _, pred = torch.max(outputs.data, dim=1)
            correct = pred.eq(labels.data).cpu().sum()

            sum_loss += loss.item()
            sum_correct += correct.item()
            im = torchvision.utils.make_grid(inputs)
            writer.add_image("test image", im, global_step=step_n)

    test_loss = sum_loss * 1.0 / len(test_data_loader)
    test_correct = sum_correct * 100.0 / len(test_data_loader) / batch_size
    # BUG FIX: the original logged the *last batch's* raw `loss` and `correct`
    # here instead of the epoch-level averages computed just above.
    writer.add_scalar("test loss", test_loss, global_step=epoch + 1)
    writer.add_scalar("test correct", test_correct, global_step=epoch + 1)
    print("test epoch is ", epoch + 1, "——loss is:", test_loss,
             "——mini-batch correct is:", test_correct)

writer.close()

三、网络模型

# vggnet.py
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter

class VGGbase(nn.Module):
    """VGG-style backbone for 3x28x28 CIFAR images with a 10-way
    log-softmax head."""

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        # One 3x3 conv (stride 1, padding 1) + BatchNorm + ReLU; keeps H x W.
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def __init__(self):
        super(VGGbase, self).__init__()

        # 3 x 28 x 28 -> 64 x 28 x 28, pooled to 64 x 14 x 14
        self.conv1 = self._conv_bn_relu(3, 64)
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 64 x 14 x 14 -> 128 x 14 x 14, pooled to 128 x 7 x 7
        self.conv2_1 = self._conv_bn_relu(64, 128)
        self.conv2_2 = self._conv_bn_relu(128, 128)
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 128 x 7 x 7 -> 256 x 7 x 7; padding=1 makes the odd 7 pool to 4
        self.conv3_1 = self._conv_bn_relu(128, 256)
        self.conv3_2 = self._conv_bn_relu(256, 256)
        self.max_pooling3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)

        # 256 x 4 x 4 -> 512 x 4 x 4, pooled to 512 x 2 x 2
        self.conv4_1 = self._conv_bn_relu(256, 512)
        self.conv4_2 = self._conv_bn_relu(512, 512)
        self.max_pooling4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Final feature map 512 x 2 x 2 is flattened to 512 * 4 features.
        self.fc = nn.Linear(512 * 4, 10)

    def forward(self, x):
        """Run the backbone; returns (batch, 10) log-probabilities."""
        out = self.max_pooling1(self.conv1(x))

        out = self.max_pooling2(self.conv2_2(self.conv2_1(out)))

        out = self.max_pooling3(self.conv3_2(self.conv3_1(out)))

        out = self.max_pooling4(self.conv4_2(self.conv4_1(out)))

        # (batch, 512, 2, 2) -> (batch, 2048) -> (batch, 10)
        flat = out.view(x.size(0), -1)
        return F.log_softmax(self.fc(flat), dim=1)
    
def VGGNet():
    """Factory used by train.py; returns a fresh, untrained VGGbase."""
    return VGGbase()

# x = torch.rand(1, 3, 28, 28)
# net =VGGbase()
# with SummaryWriter(comment='VGGNet') as w:
#     w.add_graph(net, x)
# tensorboard --logdir ./runs/Jul29_16-59-43_ivilab1VGGNet
# http://localhost:6006/
# resnet.py
from turtle import forward
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

class ResBlock(nn.Module):
    """Basic two-conv residual block with an optional projection shortcut."""

    def __init__(self, in_channel, out_channel, stride=1):
        super(ResBlock, self).__init__()
        # Main path: 3x3 conv (possibly strided) -> BN -> ReLU -> 3x3 conv -> BN.
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel,
                      kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
            nn.Conv2d(out_channel, out_channel,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channel)
        )
        # Identity shortcut unless the shape changes (channel count or
        # stride), in which case project with a strided conv so both paths
        # have matching shapes before the addition.
        if in_channel == out_channel and stride <= 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel,
                          kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channel)
            )

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        main = self.layer(x)
        residual = self.shortcut(x)
        return F.relu(main + residual)

class ResNet(nn.Module):
    """Small ResNet: stem conv + four strided residual stages + linear head."""

    def make_layer(self, block, out_channel, stride, num_block):
        """Stack num_block blocks; only the first one applies the stride."""
        strides = [stride] + [1] * (num_block - 1)
        layers = []
        for in_stride in strides:
            layers.append(block(self.in_channel, out_channel, in_stride))
            self.in_channel = out_channel  # next block consumes this width
        return nn.Sequential(*layers)

    def __init__(self):
        super(ResNet, self).__init__()
        self.in_channel = 32  # running input width, updated by make_layer
        # Stem: 3 -> 32 channels, spatial size unchanged.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        # Each stage halves the spatial size (stride 2) and doubles the width.
        self.layer1 = self.make_layer(ResBlock, 64, 2, 2)
        self.layer2 = self.make_layer(ResBlock, 128, 2, 2)
        self.layer3 = self.make_layer(ResBlock, 256, 2, 2)
        self.layer4 = self.make_layer(ResBlock, 512, 2, 2)

        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        """Return (batch, 10) logits (no softmax applied here)."""
        features = self.conv1(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 2)
        return self.fc(pooled.view(pooled.size(0), -1))

def resnet():
    """Factory used by train.py; returns a fresh, untrained ResNet."""
    return ResNet()
# mobilenet.py
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F

class mobilenet(nn.Module):
    """MobileNet-v1-style net (depthwise separable convs) for 3x28x28 input."""

    def conv_dw(self, in_channel, out_inchannel, stride):
        """Depthwise 3x3 conv (groups=in_channel) + 1x1 pointwise conv."""
        return nn.Sequential(
            nn.Conv2d(in_channel, in_channel,
                      kernel_size=3, stride=stride, padding=1,
                      groups=in_channel, bias=False),
            nn.BatchNorm2d(in_channel),
            nn.ReLU(),

            # BUG FIX: a 1x1 conv must use padding=0. The original padding=1
            # grew the feature map by 2 pixels at every block, which made the
            # final view(..., 512) fail at runtime.
            nn.Conv2d(in_channel, out_inchannel,
                      kernel_size=1, stride=1, padding=0,
                      bias=False),
            nn.BatchNorm2d(out_inchannel),
            nn.ReLU()
        )

    def __init__(self):
        super(mobilenet, self).__init__()
        # BUG FIX: conv1 originally produced 64 channels while conv2_dw2
        # expected 32 inputs, so the forward pass crashed immediately.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )
        self.conv2_dw2 = self.conv_dw(32, 32, 1)
        self.conv2_dw3 = self.conv_dw(32, 64, 2)
        self.conv2_dw4 = self.conv_dw(64, 64, 1)
        self.conv2_dw5 = self.conv_dw(64, 128, 2)
        self.conv2_dw6 = self.conv_dw(128, 128, 1)
        self.conv2_dw7 = self.conv_dw(128, 256, 2)
        self.conv2_dw8 = self.conv_dw(256, 256, 1)
        self.conv2_dw9 = self.conv_dw(256, 512, 2)

        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        """Return (batch, 10) logits for a (batch, 3, 28, 28) input."""
        out = self.conv1(x)              # 32 x 28 x 28
        out = self.conv2_dw2(out)
        out = self.conv2_dw3(out)        # 64 x 14 x 14
        out = self.conv2_dw4(out)
        out = self.conv2_dw5(out)        # 128 x 7 x 7
        out = self.conv2_dw6(out)
        out = self.conv2_dw7(out)        # 256 x 4 x 4
        out = self.conv2_dw8(out)
        out = self.conv2_dw9(out)        # 512 x 2 x 2
        out = F.avg_pool2d(out, 2)       # 512 x 1 x 1
        out = out.view(out.size(0), -1)  # batch-size-safe flatten
        out = self.fc(out)

        return out

def MobileNetv1_small():
    """Factory used by train.py; returns a fresh, untrained mobilenet."""
    return mobilenet()
# inception.py
from turtle import forward
from numpy import pad
import torch
import torch.nn as nn
import torch.nn.functional as F

'''
input: A
inception: 
B1 = f1(A)
B2 = f2(A)
B3 = f3(A)
concat([B1, B2, B3])

resnet: B = g(A) + f(A)
'''

def ConvBNRelu(in_channel, out_channel, kernel_size):
    """Same-padding conv -> batch norm -> ReLU building block."""
    # padding = kernel_size // 2 keeps the spatial size for odd kernels.
    conv = nn.Conv2d(in_channel, out_channel,
                     kernel_size=kernel_size,
                     stride=1,
                     padding=kernel_size // 2)
    return nn.Sequential(conv, nn.BatchNorm2d(out_channel), nn.ReLU())

class BaseInception(nn.Module):
    """One inception block: four parallel branches concatenated on channels.

    Branches: 1x1 conv | 1x1 reduce -> 3x3 | 1x1 reduce -> 5x5 |
    3x3 max-pool -> 3x3 conv. Output channels = sum(out_channel_list).
    """

    def __init__(self, in_channel, out_channel_list, reduce_channel_list):
        super(BaseInception, self).__init__()
        # Branch 1: plain 1x1 convolution.
        self.branch1_conv = ConvBNRelu(in_channel, out_channel_list[0], 1)
        # Branch 2: 1x1 channel reduction followed by a 3x3 conv.
        self.branch2_conv1 = ConvBNRelu(in_channel, reduce_channel_list[0], 1)
        self.branch2_conv2 = ConvBNRelu(reduce_channel_list[0],
                                        out_channel_list[1], 3)
        # Branch 3: 1x1 channel reduction followed by a 5x5 conv.
        self.branch3_conv1 = ConvBNRelu(in_channel, reduce_channel_list[1], 1)
        self.branch3_conv2 = ConvBNRelu(reduce_channel_list[1],
                                        out_channel_list[2], 5)
        # Branch 4: 3x3 max-pool (stride 1, padding 1 keeps size), then 3x3 conv.
        self.branch4_pool = nn.MaxPool2d(kernel_size=3,
                                         stride=1,
                                         padding=1)
        self.branch4_conv = ConvBNRelu(in_channel, out_channel_list[3], 3)

    def forward(self, x):
        """Apply all four branches to x and concatenate along the channel dim."""
        branch1 = self.branch1_conv(x)
        branch2 = self.branch2_conv2(self.branch2_conv1(x))
        branch3 = self.branch3_conv2(self.branch3_conv1(x))
        branch4 = self.branch4_conv(self.branch4_pool(x))
        return torch.cat([branch1, branch2, branch3, branch4], dim=1)

class InceptionNet(nn.Module):
    """Small inception-style classifier: two conv stems + two inception stages."""

    def __init__(self):
        super(InceptionNet, self).__init__()
        # Stem 1: 3 -> 64 channels, stride-2 downsample.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64,
                        kernel_size=7,
                        stride=2,
                        padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        # Stem 2: 64 -> 128 channels, stride-2 downsample.
        self.block2 = nn.Sequential(
            nn.Conv2d(64, 128,
                        kernel_size=3,
                        stride=2,
                        padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        # Inception stage 1: 128 -> 4 * 64 = 256 channels, then pool /2.
        self.block3 = nn.Sequential(
            BaseInception(in_channel=128,
                          out_channel_list=[64,64,64,64],
                          reduce_channel_list=[16,16]),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # Inception stage 2: 256 -> 4 * 96 = 384 channels, then pool /2.
        self.block4 = nn.Sequential(
            BaseInception(in_channel=256,
                          out_channel_list=[96,96,96,96],
                          reduce_channel_list=[32,32]),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.fc = nn.Linear(384, 10)

    def forward(self, x):
        """Return (batch, 10) logits."""
        features = x
        for stage in (self.block1, self.block2, self.block3, self.block4):
            features = stage(features)
        pooled = F.avg_pool2d(features, 2)
        return self.fc(pooled.view(pooled.size(0), -1))

def InceptionNet_small():
    """Factory used by train.py; returns a fresh, untrained InceptionNet."""
    return InceptionNet()
# pre_resnet.py 从torchvision.models中调用网络模型
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class resnet18(nn.Module):
    """ImageNet-pretrained torchvision ResNet-18 with the final fc layer
    replaced by a 10-way head for CIFAR-10 (transfer learning)."""

    def __init__(self):
        super(resnet18, self).__init__()
        # pretrained=True downloads ImageNet weights on first use.
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favour of the `weights=` argument — confirm the installed version.
        self.model = models.resnet18(pretrained=True)
        self.num_features = self.model.fc.in_features  # input width of the original head
        # Replace the 1000-way classifier; only this layer starts untrained.
        self.model.fc = nn.Linear(self.num_features, 10)

    def forward(self, x):
        out = self.model(x)
        return out

def pytorch_resnet18():
    """Factory: returns the pretrained-backbone resnet18 wrapper above."""
    return resnet18()

四、测试

import torch
import glob
import cv2
from PIL import Image
from torchvision import transforms
import numpy as np
from pre_resnet import resnet18

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = resnet18().to(device)

# Load the checkpoint saved by train.py after epoch 1.
# NOTE(review): torch.load without map_location assumes the checkpoint's
# device is available — confirm when loading a GPU-trained model on CPU.
net.load_state_dict(torch.load("1.pth"))

# All extracted test images, one class per sub-directory.
im_list = glob.glob("./dataset/cifar-10/test/*/*.png")

# Class names indexed by label id (0 = airplane ... 9 = truck).
label_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]

# Deterministic preprocessing: resize to the training input size, convert to
# a tensor, then normalize with per-channel mean/std constants.
test_transform = transforms.Compose([
    transforms.Resize((28,28)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

# Inference only: set eval mode once (hoisted out of the loop) and disable
# gradient tracking for the whole pass.
net.eval()
with torch.no_grad():
    for im_path in im_list:
        im_data = Image.open(im_path)

        inputs = test_transform(im_data)
        inputs = torch.unsqueeze(inputs, dim=0)  # add a batch dimension
        inputs = inputs.to(device)
        outputs = net(inputs)  # call the module, not .forward(), so hooks run

        # predicted class = arg-max over the 10 scores
        _, pred = torch.max(outputs.data, dim=1)

        # BUG FIX: the original called the non-existent .numpu()
        print(label_name[pred.cpu().numpy()[0]])

        # Display with OpenCV. PIL yields RGB but OpenCV expects BGR, so
        # reverse the channel order (BUG FIX: the original [1, 2, 0]
        # permutation produced wrong colours).
        img = np.asarray(im_data)
        img = img[:, :, ::-1]
        img = cv2.resize(img, (300, 300))
        cv2.imshow("im", img)
        cv2.waitKey()

你可能感兴趣的:(学习,python,深度学习)