PyTorch: Building and Training on the MNIST Dataset

This article covers: loading the official dataset, building a custom dataset, defining custom network architectures, training, and using the trained model.

Imports

import torch
import torchvision
import torchsummary
import os
import numpy as np
import matplotlib.pyplot as plt

Constants

BATCH_SIZE = 64
# image height and width in pixels
IMAGE_ROW = 28
IMAGE_COL = 28
# root directory of the raw MNIST files
DATA_SOURCE_DIR = "../datasets/MNIST/raw/"
TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.,),(1.,))   # mean 0, std 1: leaves the tensor unchanged
    ])
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(DEVICE)
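Note that Normalize((0.,),(1.,)) subtracts 0 and divides by 1, so it leaves the tensors produced by ToTensor unchanged. If you prefer zero-mean, unit-variance inputs, a common alternative (a sketch only; the rest of this article keeps the transform above) is the widely quoted MNIST statistics:

# Optional alternative: normalize with the commonly used MNIST mean/std.
TRANSFORM_NORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
    ])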

Loading the dataset

Dataset description:
Reference (in Chinese): https://www.cnblogs.com/xianhan/p/9145966.html
Dataset homepage: http://yann.lecun.com/exdb/mnist/
All four IDX files follow the layouts below; a short header-reading sketch comes after the tables.

train-labels-idx1-ubyte
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label

train-images-idx3-ubyte
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  60000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel

t10k-labels-idx1-ubyte
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  10000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label

t10k-images-idx3-ubyte
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  10000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel
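Before writing the full parser, a minimal sketch can read just the 16-byte header of the training image file and check it against the table above (this assumes the four raw files have already been extracted into DATA_SOURCE_DIR):

import struct
# Read only the header of train-images-idx3-ubyte and compare with the table.
with open(os.path.join(DATA_SOURCE_DIR, "train-images-idx3-ubyte"), "rb") as fp:
    header = fp.read(16)
magic, n_images, n_rows, n_cols = struct.unpack_from(">IIII", header, 0)
print(hex(magic), n_images, n_rows, n_cols)   # expected: 0x803 60000 28 28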

Loading the official dataset

TRAIN_DATASETS = torchvision.datasets.MNIST(root="../datasets",train=True,download=True,transform=TRANSFORM)
TRAIN_LOADER = torch.utils.data.DataLoader(TRAIN_DATASETS,shuffle=True,batch_size=BATCH_SIZE)
TEST_DATASETS = torchvision.datasets.MNIST(root="../datasets",train=False,download=True,transform=TRANSFORM)
TEST_LOADER = torch.utils.data.DataLoader(TEST_DATASETS,shuffle=False,batch_size=BATCH_SIZE)

Viewing an image

img,label = TRAIN_DATASETS[0]
img = img.numpy()
plt.title(label)
plt.imshow(img[0])
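To inspect more than one sample at a time, a short sketch using torchvision.utils.make_grid (assuming the TRAIN_LOADER defined above) tiles part of a batch into a single figure:

# Show the first 16 images of one batch; make_grid returns a C x H x W tensor.
images, labels = next(iter(TRAIN_LOADER))
grid = torchvision.utils.make_grid(images[:16], nrow=8)
plt.figure(figsize=(8, 2))
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.title(str(labels[:16].tolist()))
plt.axis("off")
plt.show()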

Custom dataset

Official torch documentation (English): https://pytorch.org/docs/stable/data.html
torch.utils.data.Dataset source: https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#Dataset

Basic structure of a custom Dataset:
class DataSets(torch.utils.data.Dataset):
    def __init__(self):
        super(DataSets,self).__init__()
        pass
    def __getitem__(self,idx):
        pass
    def __len__(self):
        pass

struct.unpack_from(fmt, buf, offset)
fmt: format string — '>' or '<' (byte order) followed by an optional count and a type code such as 'B', 'b', 'I', or 'i'
buf: the bytes buffer read from the file
offset: byte offset into the buffer at which parsing starts
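A tiny self-contained demo of the format string (the byte values below are made up for illustration): '>' means big-endian (MSB first, as in the IDX headers), 'I' is a 4-byte unsigned integer and 'B' a 1-byte unsigned integer.

import struct
# 8-byte header (two big-endian uint32) followed by three one-byte labels.
buf = bytes([0x00, 0x00, 0x08, 0x01,   # 0x00000801 = 2049
             0x00, 0x00, 0x00, 0x03,   # 3 items
             5, 0, 4])
print(struct.unpack_from(">II", buf, 0))   # (2049, 3)
print(struct.unpack_from(">3B", buf, 8))   # (5, 0, 4)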

import struct
def decode_idx1_ubyte(idx1_ubyte_file):
    with open(idx1_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
        # parse the header
        fmt = ">II"
        magic_number,label_number = struct.unpack_from(fmt, bin_data, 0)
        offset = 8  # byte offset where the labels start
        print("magic number:0x{:0>8x}({})\tlabel number:{}".format(magic_number,magic_number,label_number))
        fmt = ">B"
        label = []
        for idx in range(label_number):
            # unpack_from always returns a tuple; keep only the integer label
            label.append(struct.unpack_from(fmt,bin_data,offset+idx)[0])
    return label
def decode_idx3_ubyte(idx3_ubyte_file):
    with open(idx3_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
        # parse the header
        fmt = ">IIII"
        magic_number,image_number,rows,cols = struct.unpack_from(fmt, bin_data, 0)
        offset = 16  # byte offset where the pixel data starts
        print("magic number:0x{:0>8x}({})\t image number:{}".format(magic_number, magic_number, image_number))
        print("rows:{}\t columns:{}".format(rows, cols))
        fmt = '>'+str(rows*cols)+'B'
        image = []
        for idx in range(image_number):
            data = struct.unpack_from(fmt, bin_data, offset+idx*rows*cols)
            data = np.array(data,dtype=np.uint8).reshape((rows, cols))
            image.append(data)
    image = np.array(image)
    return image
class MyMNISTDataSets(torch.utils.data.Dataset):
    def __init__(self,root=DATA_SOURCE_DIR,train=True,transform=None):
        super(MyMNISTDataSets,self).__init__()
        self.root = root
        self.transform = transform
        self.train = train
        if self.train:
            image_path = "train"
            label_path = "train"
        else:
            image_path = "t10k"
            label_path = "t10k"
        image_path = image_path+"-images-idx3-ubyte"
        label_path = label_path+"-labels-idx1-ubyte"
        image_path = os.path.join(self.root,image_path)
        label_path = os.path.join(self.root,label_path)
        self.data, self.targets = decode_idx3_ubyte(image_path),decode_idx1_ubyte(label_path)
        
    def __getitem__(self,idx):
        data,label = self.data[idx], self.targets[idx]
        label = torch.as_tensor(label,dtype=torch.int64)
        if self.transform is not None:
            data = self.transform(data)
        data = data.type(torch.FloatTensor)
        return data,label
    
    def __len__(self):
        return len(self.data)
TRAIN_DATASETS = MyMNISTDataSets(train=True,transform=TRANSFORM)
TRAIN_LOADER = torch.utils.data.DataLoader(TRAIN_DATASETS,shuffle=True,batch_size=BATCH_SIZE)
TEST_DATASETS = MyMNISTDataSets(train=False,transform=TRANSFORM)
TEST_LOADER = torch.utils.data.DataLoader(TEST_DATASETS,shuffle=False,batch_size=BATCH_SIZE)
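The byte-by-byte struct loops in decode_idx1_ubyte and decode_idx3_ubyte above are easy to follow but slow for 60000 images. A vectorized sketch (same file layout, same assumptions about the raw files) reads the whole payload at once with numpy.frombuffer:

def decode_idx3_ubyte_fast(idx3_ubyte_file):
    # Skip the 16-byte header and reinterpret the remaining bytes as uint8 pixels.
    with open(idx3_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
    _, image_number, rows, cols = struct.unpack_from(">IIII", bin_data, 0)
    images = np.frombuffer(bin_data, dtype=np.uint8, offset=16)
    # copy() makes the array writable, which ToTensor/torch.from_numpy expect
    return images.reshape(image_number, rows, cols).copy()

def decode_idx1_ubyte_fast(idx1_ubyte_file):
    # Skip the 8-byte header; every remaining byte is one label.
    with open(idx1_ubyte_file, 'rb') as fp:
        bin_data = fp.read()
    return np.frombuffer(bin_data, dtype=np.uint8, offset=8).copy()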

Viewing an image

img,label = TRAIN_DATASETS[0]
img = img.numpy()
plt.title(label.item())
plt.imshow(img[0])

Network definitions

Custom fully-connected network

class LinearNet(torch.nn.Module):
    def __init__(self):
        super(LinearNet,self).__init__()
        self.l1 = torch.nn.Linear(28*28,512)
        self.l2 = torch.nn.Linear(512,256)
        self.l3 = torch.nn.Linear(256,128)
        self.l4 = torch.nn.Linear(128,64)
        self.l5 = torch.nn.Linear(64,10)
    def forward(self,x):
        x = x.view(-1,IMAGE_ROW*IMAGE_COL)
        x = torch.nn.functional.relu(self.l1(x))
        x = torch.nn.functional.relu(self.l2(x))
        x = torch.nn.functional.relu(self.l3(x))
        x = torch.nn.functional.relu(self.l4(x))
        y = self.l5(x)
        return y
model = LinearNet()
model.to(DEVICE)
torchsummary.summary(model,(1,28,28))

Custom CNN

class CNNNet(torch.nn.Module):
    def __init__(self):
        super(CNNNet,self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels = 1,out_channels = 10,kernel_size=5)
        self.conv2 = torch.nn.Conv2d(in_channels = 10,out_channels = 20,kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.fc = torch.nn.Linear(in_features = 320,out_features = 10)   # 320 = 20 channels * 4 * 4 after two conv+pool stages
        self.relu = torch.nn.ReLU()
    def forward(self,x):
        batch_size = x.size(0)
        x = self.conv1(x)
        x = self.pooling(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.pooling(x)
        x = self.relu(x)
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x
model = CNNNet()
model.to(DEVICE)
torchsummary.summary(model,(1,28,28))
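To see where in_features = 320 comes from: the 28x28 input goes through conv1 (5x5 kernel) to 24x24, a 2x2 max pool to 12x12, conv2 (5x5) to 8x8, and another pool to 4x4, so the flattened size is 20 channels * 4 * 4 = 320. A quick sketch pushes a dummy input through the conv/pool stages to confirm the shapes:

# Shape check for the CNN above (relu omitted; it does not change shapes).
with torch.no_grad():
    dummy = torch.zeros(1, 1, 28, 28, device=DEVICE)
    x = model.pooling(model.conv1(dummy))   # -> torch.Size([1, 10, 12, 12])
    x = model.pooling(model.conv2(x))       # -> torch.Size([1, 20, 4, 4])
    print(x.shape, x.numel())               # 20 * 4 * 4 = 320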

Model training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
import sys

for epoch in range(2):
    model.train()
    running_loss = 0.0
    for batch_idx,data in enumerate(TRAIN_LOADER):
        inputs,target = data
        inputs,target = inputs.to(DEVICE),target.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(inputs)
        target = target.squeeze()
        loss = criterion(outputs,target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if batch_idx % 50 == 49:
            sys.stdout.write("epoch:{:2d}\t {}\t:{:.2%}\t loss:{:.2f}\t\r".format(epoch,"train",(batch_idx+1)/len(TRAIN_LOADER),running_loss/(batch_idx+1)))
            sys.stdout.flush()
    sys.stdout.write('\n')
    sys.stdout.flush()
    
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx,data in enumerate(TEST_LOADER):
            inputs,target = data
            inputs,target = inputs.to(DEVICE),target.to(DEVICE)
            outputs = model(inputs)
            target = target.squeeze()
            _,predict = torch.max(outputs.data,dim=1)
            
            
            total += target.size(0)
            correct += (predict == target).sum().item()
            if batch_idx % 50 == 49:
                sys.stdout.write("epoch:{:2d}\t {}\t:{:.2%}\t accuracy:{:.2%}\t\r".format(epoch,"test",(batch_idx+1)/len(TEST_LOADER),correct/total))
                sys.stdout.flush()
        sys.stdout.write('\n')
        sys.stdout.flush()
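The training loop above keeps the weights only in memory. If you want to reuse the trained model later, a minimal save/load sketch looks like this (the file name mnist_cnn.pth is a placeholder, not part of the original article):

# Save only the weights (state_dict), not the whole module.
torch.save(model.state_dict(), "mnist_cnn.pth")

# Later: rebuild the same architecture, then load the weights.
restored = CNNNet()
restored.load_state_dict(torch.load("mnist_cnn.pth", map_location=DEVICE))
restored.to(DEVICE)
restored.eval()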

Testing the trained model

with torch.no_grad():
    choice = np.random.randint(0,len(TEST_DATASETS))
    inputs,target = TEST_DATASETS[choice]
    inputs = inputs.unsqueeze(0)   # add a batch dimension: (1, 1, 28, 28)
    inputs,target = inputs.to(DEVICE),target.to(DEVICE)
    outputs = model(inputs)
    print(outputs)
    _,predict = torch.max(outputs.data,dim=1)
    plt.title(predict.item())
    plt.imshow(inputs.to("cpu").numpy()[0,0])
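print(outputs) above shows the raw logits; applying a softmax turns them into class probabilities, which are easier to read (a small optional addition):

# Convert logits to probabilities for a more readable printout.
probs = torch.nn.functional.softmax(outputs, dim=1)
print(probs.cpu().numpy().round(3))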
