手写数字代码识别(pytorch)实现

数据预览: 

import pandas
df=pandas.read_csv('C:\\Users\\HP\\Desktop\\mnist_train.csv',header=None)
df.head()

手写数字代码识别(pytorch)实现_第1张图片

 

MNIST的每一行数据包含785个值。第一个值是图像所表示的数字,其余的784个值是图像(尺寸为28像素× 28像素)的像素值。¶

我们可以使用info()函数查看DataFrame的概况

df.info()

RangeIndex: 60000 entries, 0 to 59999
Columns: 785 entries, 0 to 784
dtypes: int64(785)
memory usage: 359.3 MB

 

以上结果告诉我们,该DataFrame有60 000行。这对应60 000幅训练图像。同时,我们也可以确认每行有785个值。¶

让我们将一行像素值转换成实际图像来直观地查看一下。

我们使用通用的matplotlib库来显示图像。在下面的代码中,我们导入matplotlib库的pyplot包

完整代码: 

#### 导入库
import torch
import torch.nn as nn
import pandas
import matplotlib.pyplot as plt
from torch.utils.data import Dataset #是pytorch加载和导入数据的方式
'''------------构建神经网络类------------'''
class Classifier(nn.Module):#nn.Module是所有类的父类
    """分类器"""
    def __init__(self):
        #初始化pytorch父类
        super().__init__()
        #定义神经网络层
        self.model=nn.Sequential(
        nn.Linear(784,200),
        nn.Sigmoid(),
        nn.Linear(200,10),
        nn.Sigmoid()
        )
        #创建损失函数(均方误差)
        self.loss_function=nn.MSELoss()
        #创建优化器,使用简单梯度下降
        self.optimiser=torch.optim.SGD(self.parameters(),lr=0.01)
        '''可视化'''
        #记录训练进展的计数器和列表
        self.counter=0
        self.progress=[]
    def forward(self,inputs):
        #直接运行模型
        return self.model(inputs)
    
    """训练器"""
    def train(self,inputs,targets):
        #计算网络的输出值
        outputs=self.forward(inputs)
        #计算损失值
        loss=self.loss_function(outputs,targets)
        """下一步是使用损失来更新网络的链接权值"""
        #梯度归零,反向传播,并更新权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        """可视化"""
        #每隔10个训练样本增加一次计数器的值,并将损失值添加进列表的末尾
        self.counter+=1
        if(self.counter%10==0):
            self.progress.append(loss.item())#这里使用item()的作用只是为了方便展开一个单值张量,获取里面的数字
            pass
        #每10000次训练后打印计数器的值,这样可以了解训练进展的快慢
        if(self.counter%10000==0):
            print("counter=",self.counter)
            pass
    """将损失值绘成图"""
    def plot_progress(self):
        df=pandas.DataFrame(self.progress,columns=['loss'])
        df.plot(ylim=(0,1.0),figsize=(16,8),alpha=0.1,marker='.',grid=True,yticks=(0,0.25,0.5))
        pass
'''-------------创建MnistDataset类--------------'''
class MnistDataset(Dataset):
    def __init__(self,csv_file):
        self.data_df=pandas.read_csv(csv_file,header=None)
        pass
    def __len__(self):
        return len(self.data_df)
    def __getitem__(self,index):
        #目标图像(标签)
        label=self.data_df.iloc[index,0]
        target=torch.zeros((10))
        target[label]=1.0
        #图像数据,取值范围是0~255,标准化为0~1
        image_values=torch.FloatTensor(self.data_df.iloc[index,1:].values)/255.0
        #返回标签,图像数据张量以及目标张量
        return label, image_values, target
    pass
    """可视化"""
    def plot_image(self,index):
        arr=self.data_df.iloc[index,1:].values.reshape(28,28)
        plt.title("label = " + str(self.data_df.iloc[index,0]))
        plt.imshow(arr,interpolation='none',cmap='Blues')
        pass
#检查一下到目前为止是否一切正常
mnist_dataset=MnistDataset('C:\\Users\\HP\\Desktop\\mnist_train.csv')
#mnist_dataset.plot_image(9)
"""训练分类器"""
#创建神经网络

C=Classifier()
#在MNIST数据集训练神经网络
for label, image_data_tensor, target_tensor in mnist_dataset:
    C.train(image_data_tensor,target_tensor)
    pass
# 绘制分类器损失值 
C.plot_progress()
# 加载MNIST测试数据 
mnist_test_dataset = MnistDataset('C:\\Users\\HP\\Desktop\\mnist_test.csv')
counter= 10000
counter= 20000
counter= 30000
counter= 40000
counter= 50000
counter= 60000
# 绘制分类器损失值
C.plot_progress()

手写数字代码识别(pytorch)实现_第2张图片

现在我们有了一个训练后的网络,可以进行图像分类了。我们将切换到包含10 000幅图像的MNIST测试数据集。这些是我们的神经网络从来没看到过的图像。让我们用一个新的Dataset对象加载数据集 

# 加载MNIST测试数据
mnist_test_dataset = MnistDataset('C:\\Users\\HP\\Desktop\\mnist_test.csv')
# 挑选一幅图像
record = 19
# 绘制图像和标签
mnist_test_dataset.plot_image(record)

手写数字代码识别(pytorch)实现_第3张图片

让我们看看训练过的神经网络是如何判断这幅图像的。下面的代码继续使用第20幅图像并提取像素值作为image_data。我们使用forward()函数将图像传递并通过神经网络 

image_data = mnist_test_dataset[record][1]
# 调用训练后的神经网络
output = C.forward(image_data)
# 绘制输出张量
pandas.DataFrame(output.detach().numpy()).plot(kind='bar',
legend=False, ylim=(0,1))

手写数字代码识别(pytorch)实现_第4张图片

 

你可能感兴趣的:(机器学习,机器学习,人工智能,pytorch,python)