MNIST 数据集主要由一些手写数字的图片和相应的标签组成,图片一共有10 类,分别对应从0~9 ,共10 个阿拉伯数字。
数据以.CSV形式存储在MNIST_dataset文件夹下,训练数据集共有42000个数据,测试集有28000个数据。
1. 首先定义文件读取函数
def read_csv(fileRoad):
    """Load the data rows of a CSV file as raw text lines.

    The first line of the file is skipped. No delimiter is passed to
    ``np.loadtxt``, so fields are split on whitespace only and a
    comma-separated row without spaces comes back as one string that the
    caller splits itself.

    Args:
        fileRoad: Path of the CSV file to read.

    Returns:
        A 1-D NumPy array of strings, one element per data row.
    """
    # ndmin=1 keeps the result one-dimensional even when the file holds a
    # single data row; without it loadtxt squeezes the result to a 0-d array
    # and len()/indexing in the callers fails.
    return np.loadtxt(fileRoad, dtype=str, skiprows=1, ndmin=1)
由于读取到的数据为字符串数组,一行即为一张图片(1 * 784,784 源于图像大小 28 * 28 = 784),因此采用以下代码将字符串变为浮点列表(便于转换为 ndarray 或者 tensor 进行后续处理):
# Split the selected CSV row on commas and convert every field to float.
data = list(map(float, datas[index].split(',')))
2. 定义图片索引查看和拼接全图查看函数
# View a single image by index, or every image stitched into one mosaic.
def _row_to_pixels(line):
    """Convert one CSV row string into its 784 pixel values (label dropped)."""
    values = [float(v) for v in line.split(',')]
    if len(values) == 785:
        # 785 values: the first column is the label, the rest are pixels.
        return values[1:]
    if len(values) == 784:
        return values
    raise ValueError(
        "expected 784 or 785 comma-separated values, got %d" % len(values))


def img_show(fileRoad, index=None):
    """Display images stored in an MNIST-style CSV file.

    The file used to define ``img_show`` twice; the second definition
    replaced the first, so the by-index viewer could never be called. Both
    behaviours now live behind one name.

    Args:
        fileRoad: Path of the CSV file to read.
        index: Row number of the single image to show. When omitted, all
            images are tiled into one near-square mosaic.

    Raises:
        ValueError: If the file has no data rows (mosaic mode) or a row does
            not contain 784 or 785 values.
    """
    datas = read_csv(fileRoad)
    if index is not None:
        canvas = np.array(_row_to_pixels(datas[index])).reshape(28, 28)
    else:
        dataAllNum = len(datas)
        if dataAllNum == 0:
            raise ValueError("no image rows found in %s" % fileRoad)
        columNum = int(np.sqrt(dataAllNum))
        # Ceiling division: just enough rows for every image, without the
        # fully blank extra row that int(N / cols) + 1 produced whenever N
        # was an exact multiple of the column count.
        rowNum = -(-dataAllNum // columNum)
        # Pre-fill with the padding value 1 so unused grid cells keep it.
        # Built-in float replaces np.float, which newer NumPy no longer has.
        canvas = np.ones((rowNum * 28, columNum * 28), dtype=float)
        # Write each image straight into its cell instead of growing the
        # picture with repeated hstack/vstack copies.
        for pos in range(dataAllNum):
            r, c = divmod(pos, columNum)
            tile = np.array(_row_to_pixels(datas[pos])).reshape(28, 28)
            canvas[r * 28:(r + 1) * 28, c * 28:(c + 1) * 28] = tile
    Image.fromarray(np.uint8(canvas)).show()
3. 定义dataset和dataloader
class Mydata(Dataset):
    """Dataset over the data rows of an MNIST-style CSV file.

    Rows are kept as raw comma-separated strings and parsed on access.
    A row with 785 values yields ``(sample, label)``; a row with 784 values
    yields ``sample`` only.
    """

    def __init__(self, dir):
        """Read all data rows of the CSV file at ``dir``."""
        self.datas = read_csv(dir)
        self.rootDir = dir

    def __getitem__(self, index):
        """Return the parsed row at ``index``.

        Returns:
            ``(sample, label)`` for 785-value rows, where ``label`` is the
            first value as an int and ``sample`` is a float tensor of the
            remaining 784 values; ``sample`` alone for 784-value rows.

        Raises:
            ValueError: If the row has any other number of values.
        """
        values = [float(v) for v in self.datas[index].split(',')]
        if len(values) == 785:
            return torch.tensor(values[1:]), int(values[0])
        if len(values) == 784:
            return torch.tensor(values)
        # Previously this fell through and returned None, which only
        # surfaced later as an obscure failure in the consumer.
        raise ValueError(
            "row %d has %d values, expected 784 or 785" % (index, len(values)))

    def __len__(self):
        """Return the number of data rows."""
        return len(self.datas)
# Serve the training set in shuffled mini-batches of 150 samples.
trainLoader = DataLoader(
    trainData,
    batch_size=150,
    shuffle=True,
)
其中batch_size为超参数,在测试后选取batch为150
class Net(nn.Module):
    """Fully connected classifier: in_num -> 1000 -> 800 -> out_num.

    The module also owns its optimizer (SGD, lr=0.08) and its loss function
    (cross entropy), exposed as ``optimizer`` and ``loss_func``.
    """

    def __init__(self, in_num, out_num):
        """Build the network.

        Args:
            in_num: Number of input features per sample.
            out_num: Number of output classes.
        """
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(in_num, 1000),
            nn.ReLU(),
            nn.Linear(1000, 800),
            nn.Sigmoid(),
            # The last layer emits raw logits. CrossEntropyLoss applies
            # log-softmax itself, so a Softmax layer here would be applied
            # twice and flatten the gradients. The arg-max of the output,
            # which is what the callers use, is the same either way.
            # The width now follows out_num instead of a hard-coded 10.
            nn.Linear(800, out_num),
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.08)
        self.loss_func = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the class logits for a batch ``x``."""
        return self.network(x)
设置10个epoch
# Train a fresh network for 10 epochs, recording the mean loss per epoch.
trainData = Mydata(trainRoad)
trainLoader = DataLoader(trainData, batch_size=150, shuffle=True)
myNet = Net(784, 10)
epochList = []
lossList = []
print("can't find MINIST_model, start training...")
for epoch in range(1, 11):
    # count_accuracy() accumulates into these module-level counters, so
    # they are reset at the start of every epoch.
    correctNum = 0
    dataNum = 0
    lossNum = 0
    for inputs, label in trainLoader:
        predY = myNet(inputs)
        loss = myNet.loss_func(predY, label)
        count_accuracy(torch.max(predY, 1)[1].numpy(), label.numpy())
        lossNum = lossNum + loss.item()
        myNet.optimizer.zero_grad()
        loss.backward()
        myNet.optimizer.step()
    # loss.item() is already the mean over one batch, so the epoch average
    # divides by the number of batches, not by the number of samples.
    lossList.append(lossNum / len(trainLoader))
    epochList.append(epoch - 1)
    accuracy = float(correctNum * 100) / dataNum
    print('epoch:%d | accuracy = %.2f%%' % (epoch, accuracy))
print("training complete")
训练结束后,仅当最后一个 epoch 的训练准确率不低于 95% 时才对模型进行保存
# Keep the trained model only when the final accuracy reaches 95%.
meets_threshold = accuracy >= 95
if meets_threshold:
    torch.save(myNet, 'MINIST_model.pkl')
    print("model saved")
else:
    print("accuracy is too low,didn't saved this model")
用 Matplotlib对10次epoch的loss绘制图像
# Label the axes, draw the per-epoch loss curve, then open the window.
plt.xlabel("epoch")
plt.ylabel("loss")
plt.plot(epochList, lossList)
plt.show()
发现随着epoch增加loss下降越来越不明显,增加epoch对改进模型效果不大
编写准确率测试函数,测试读入模型准确率
def test_existModle(myNet):
    """Print the accuracy of ``myNet`` on the training CSV.

    The module-level counters ``dataNum`` and ``correctNum`` are reset and
    then filled through ``count_accuracy``.

    Args:
        myNet: Callable model that maps a batch of samples to class scores.
    """
    global dataNum, correctNum
    # Start from zero so counts left over from an earlier run or training
    # epoch cannot leak into this measurement.
    dataNum = 0
    correctNum = 0
    trainData = Mydata(trainRoad)
    trainLoader = DataLoader(trainData, batch_size=150, shuffle=True)
    # Evaluation needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        for inputs, label in trainLoader:
            output = myNet(inputs)
            count_accuracy(torch.max(output, 1)[1].numpy(), label.numpy())
    print('accuracy = %.2f%%' % (float(correctNum * 100) / dataNum))
编写预测函数,通过读入已训练模型对test.csv进行预测,结果存为sample_submission_predict.csv
def predict(myNet):
    """Predict labels for the test CSV and write the submission file.

    Does nothing except print a notice when the output file already exists.
    Otherwise the sample submission is loaded as a template, its ``Label``
    column is replaced by the predicted classes, and the result is written
    to ``predictRoad``.

    Args:
        myNet: Callable model that maps a batch of samples to class scores.

    Raises:
        ValueError: If the number of predictions differs from the number of
            rows in the sample submission.
    """
    import os

    # Check for the file directly instead of parsing it inside a bare
    # ``except``, which also treated every unrelated error as "missing".
    if os.path.isfile(predictRoad):
        print("file sample_submission_predict.csv exist")
        return

    sampleForm = pd.read_csv(sampleRoad, encoding='utf-8')
    dataset = Mydata(testRoad)
    # Order must be preserved so predictions line up with the template rows.
    testdata = DataLoader(dataset, batch_size=150, shuffle=False)
    print("start to predict sample_submission.csv and save...")
    predictions = []
    with torch.no_grad():
        for batch in testdata:
            output = myNet(batch)
            predictions.extend(torch.max(output, 1)[1].tolist())
    # Assign the whole column at once. The former chained
    # ``sampleForm['Label'].loc[i] = ...`` writes to a temporary Series and
    # is not guaranteed to update the frame.
    sampleForm['Label'] = predictions
    sampleForm.to_csv(predictRoad, encoding='utf-8', index=False)
    print("complete save sample_submission_predict.csv")
程序先尝试读取MINIST_model.pkl模型文件,若没有则进行训练
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,Dataset
import matplotlib.pyplot as plt
import pandas as pd
from PIL import Image
# Running totals updated by count_accuracy(): samples seen and correct hits.
dataNum = 0
correctNum = 0
# Locations of the input CSV files and of the generated prediction file.
testRoad = "./MNIST_dataset/test.csv"
trainRoad = "./MNIST_dataset/train.csv"
sampleRoad = "./MNIST_dataset/sample_submission.csv"
predictRoad = "./MNIST_dataset/sample_submission_predict.csv"
def read_csv(fileRoad):
    """Load the data rows of a CSV file as raw text lines.

    The first line of the file is skipped. No delimiter is passed to
    ``np.loadtxt``, so fields are split on whitespace only and a
    comma-separated row without spaces comes back as one string that the
    caller splits itself.

    Args:
        fileRoad: Path of the CSV file to read.

    Returns:
        A 1-D NumPy array of strings, one element per data row.
    """
    # ndmin=1 keeps the result one-dimensional even when the file holds a
    # single data row; without it loadtxt squeezes the result to a 0-d array
    # and len()/indexing in the callers fails.
    return np.loadtxt(fileRoad, dtype=str, skiprows=1, ndmin=1)
def _row_to_pixels(line):
    """Convert one CSV row string into its 784 pixel values (label dropped)."""
    values = [float(v) for v in line.split(',')]
    if len(values) == 785:
        # 785 values: the first column is the label, the rest are pixels.
        return values[1:]
    if len(values) == 784:
        return values
    raise ValueError(
        "expected 784 or 785 comma-separated values, got %d" % len(values))


def img_show(fileRoad, index=None):
    """Display images stored in an MNIST-style CSV file.

    The file used to define ``img_show`` twice; the second definition
    replaced the first, so the by-index viewer could never be called. Both
    behaviours now live behind one name.

    Args:
        fileRoad: Path of the CSV file to read.
        index: Row number of the single image to show. When omitted, all
            images are tiled into one near-square mosaic.

    Raises:
        ValueError: If the file has no data rows (mosaic mode) or a row does
            not contain 784 or 785 values.
    """
    datas = read_csv(fileRoad)
    if index is not None:
        canvas = np.array(_row_to_pixels(datas[index])).reshape(28, 28)
    else:
        dataAllNum = len(datas)
        if dataAllNum == 0:
            raise ValueError("no image rows found in %s" % fileRoad)
        columNum = int(np.sqrt(dataAllNum))
        # Ceiling division: just enough rows for every image, without the
        # fully blank extra row that int(N / cols) + 1 produced whenever N
        # was an exact multiple of the column count.
        rowNum = -(-dataAllNum // columNum)
        # Pre-fill with the padding value 1 so unused grid cells keep it.
        # Built-in float replaces np.float, which newer NumPy no longer has.
        canvas = np.ones((rowNum * 28, columNum * 28), dtype=float)
        # Write each image straight into its cell instead of growing the
        # picture with repeated hstack/vstack copies.
        for pos in range(dataAllNum):
            r, c = divmod(pos, columNum)
            tile = np.array(_row_to_pixels(datas[pos])).reshape(28, 28)
            canvas[r * 28:(r + 1) * 28, c * 28:(c + 1) * 28] = tile
    Image.fromarray(np.uint8(canvas)).show()
def count_accuracy(pred_y, label):
    """Add one batch of predictions to the module-level accuracy counters.

    ``dataNum`` grows by the number of labels and ``correctNum`` by the
    number of positions where the prediction equals the label.

    Args:
        pred_y: Indexable sequence of predicted classes.
        label: Indexable sequence of true classes; its length drives the loop.
    """
    global dataNum, correctNum
    total = len(label)
    hits = sum(1 for pos in range(total) if pred_y[pos] == label[pos])
    dataNum = dataNum + total
    correctNum = correctNum + hits
def test_existModle(myNet):
    """Print the accuracy of ``myNet`` on the training CSV.

    The module-level counters ``dataNum`` and ``correctNum`` are reset and
    then filled through ``count_accuracy``.

    Args:
        myNet: Callable model that maps a batch of samples to class scores.
    """
    global dataNum, correctNum
    # Start from zero so counts left over from an earlier run or training
    # epoch cannot leak into this measurement.
    dataNum = 0
    correctNum = 0
    trainData = Mydata(trainRoad)
    trainLoader = DataLoader(trainData, batch_size=150, shuffle=True)
    # Evaluation needs no gradients; skipping them saves memory and time.
    with torch.no_grad():
        for inputs, label in trainLoader:
            output = myNet(inputs)
            count_accuracy(torch.max(output, 1)[1].numpy(), label.numpy())
    print('accuracy = %.2f%%' % (float(correctNum * 100) / dataNum))
def predict(myNet):
    """Predict labels for the test CSV and write the submission file.

    Does nothing except print a notice when the output file already exists.
    Otherwise the sample submission is loaded as a template, its ``Label``
    column is replaced by the predicted classes, and the result is written
    to ``predictRoad``.

    Args:
        myNet: Callable model that maps a batch of samples to class scores.

    Raises:
        ValueError: If the number of predictions differs from the number of
            rows in the sample submission.
    """
    import os

    # Check for the file directly instead of parsing it inside a bare
    # ``except``, which also treated every unrelated error as "missing".
    if os.path.isfile(predictRoad):
        print("file sample_submission_predict.csv exist")
        return

    sampleForm = pd.read_csv(sampleRoad, encoding='utf-8')
    dataset = Mydata(testRoad)
    # Order must be preserved so predictions line up with the template rows.
    testdata = DataLoader(dataset, batch_size=150, shuffle=False)
    print("start to predict sample_submission.csv and save...")
    predictions = []
    with torch.no_grad():
        for batch in testdata:
            output = myNet(batch)
            predictions.extend(torch.max(output, 1)[1].tolist())
    # Assign the whole column at once. The former chained
    # ``sampleForm['Label'].loc[i] = ...`` writes to a temporary Series and
    # is not guaranteed to update the frame.
    sampleForm['Label'] = predictions
    sampleForm.to_csv(predictRoad, encoding='utf-8', index=False)
    print("complete save sample_submission_predict.csv")
class Mydata(Dataset):
    """Dataset over the data rows of an MNIST-style CSV file.

    Rows are kept as raw comma-separated strings and parsed on access.
    A row with 785 values yields ``(sample, label)``; a row with 784 values
    yields ``sample`` only.
    """

    def __init__(self, dir):
        """Read all data rows of the CSV file at ``dir``."""
        self.datas = read_csv(dir)
        self.rootDir = dir

    def __getitem__(self, index):
        """Return the parsed row at ``index``.

        Returns:
            ``(sample, label)`` for 785-value rows, where ``label`` is the
            first value as an int and ``sample`` is a float tensor of the
            remaining 784 values; ``sample`` alone for 784-value rows.

        Raises:
            ValueError: If the row has any other number of values.
        """
        values = [float(v) for v in self.datas[index].split(',')]
        if len(values) == 785:
            return torch.tensor(values[1:]), int(values[0])
        if len(values) == 784:
            return torch.tensor(values)
        # Previously this fell through and returned None, which only
        # surfaced later as an obscure failure in the consumer.
        raise ValueError(
            "row %d has %d values, expected 784 or 785" % (index, len(values)))

    def __len__(self):
        """Return the number of data rows."""
        return len(self.datas)
class Net(nn.Module):
    """Fully connected classifier: in_num -> 1000 -> 800 -> out_num.

    The module also owns its optimizer (SGD, lr=0.08) and its loss function
    (cross entropy), exposed as ``optimizer`` and ``loss_func``.
    """

    def __init__(self, in_num, out_num):
        """Build the network.

        Args:
            in_num: Number of input features per sample.
            out_num: Number of output classes.
        """
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(in_num, 1000),
            nn.ReLU(),
            nn.Linear(1000, 800),
            nn.Sigmoid(),
            # The last layer emits raw logits. CrossEntropyLoss applies
            # log-softmax itself, so a Softmax layer here would be applied
            # twice and flatten the gradients. The arg-max of the output,
            # which is what the callers use, is the same either way.
            # The width now follows out_num instead of a hard-coded 10.
            nn.Linear(800, out_num),
        )
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.08)
        self.loss_func = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the class logits for a batch ``x``."""
        return self.network(x)
if __name__ == '__main__':
    # Entry point: reuse a saved model when one exists, otherwise train one.
    #
    # Only a missing model file should lead to training. The former bare
    # ``except`` also wrapped the accuracy test and the prediction step, so
    # any failure there silently started a full retraining run.
    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load a file written by this script. Recent
        # PyTorch releases may reject a fully pickled module under their
        # default ``weights_only`` setting - TODO confirm against the
        # installed version.
        myNet = torch.load('MINIST_model.pkl')
    except FileNotFoundError:
        myNet = None

    if myNet is not None:
        print("MINIST_model exist and have been loaded")
        print("testing accuracy...")
        test_existModle(myNet)
        predict(myNet)
    else:
        trainData = Mydata(trainRoad)
        trainLoader = DataLoader(trainData, batch_size=150, shuffle=True)
        myNet = Net(784, 10)
        epochList = []
        lossList = []
        print("can't find MINIST_model, start training...")
        for epoch in range(1, 11):
            # count_accuracy() accumulates into these module-level counters,
            # so they are reset at the start of every epoch.
            correctNum = 0
            dataNum = 0
            lossNum = 0
            for inputs, label in trainLoader:
                predY = myNet(inputs)
                loss = myNet.loss_func(predY, label)
                count_accuracy(torch.max(predY, 1)[1].numpy(), label.numpy())
                lossNum = lossNum + loss.item()
                myNet.optimizer.zero_grad()
                loss.backward()
                myNet.optimizer.step()
            # loss.item() is already the mean over one batch, so the epoch
            # average divides by the number of batches, not of samples.
            lossList.append(lossNum / len(trainLoader))
            epochList.append(epoch - 1)
            accuracy = float(correctNum * 100) / dataNum
            print('epoch:%d | accuracy = %.2f%%' % (epoch, accuracy))
        print("training complete")
        # Keep the model only when the last epoch reached 95% accuracy.
        if accuracy >= 95:
            torch.save(myNet, 'MINIST_model.pkl')
            print("model saved")
        else:
            print("accuracy is too low,didn't saved this model")
        # Plot the mean loss of each epoch.
        plt.plot(epochList, lossList)
        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.show()