使用MNIST数据集训练第一个pytorch CNN手写数字识别神经网络

1、CNN卷积神经网络

2、torchvision.datasets

3、MINIST数据集

4、神经网络的训练

5、pytorch训练模型的保存

使用MNIST数据集训练第一个pytorch CNN手写数字识别神经网络_第1张图片

CNN

PyTorch 提供了许多预加载的数据集(例如 FashionMNIST),所有数据集都是torch.utils.data.Dataset 的子类,它们具有__getitem__和__len__实现的方法。因此,它们都可以传递给
torch.utils.data.DataLoader 也可以使用torch.multiprocessing并行加载多个样本的数据 。例如:

以下是如何从 TorchVision加载Fashion-MNIST数据集的示例。Fashion-MNIST由 60,000 个训练示例和 10,000 个测试示例组成。每个示例都包含一个 28×28 灰度图像和来自 10 个类别之一的相关标签。

MINIST数据

MINIST的数据分为2个部分:55000份训练数据(mnist.train)和10000份测试数据(mnist.test)。这个划分有重要的象征意义,他展示了在机器学习中如何使用数据。在训练的过程中,我们必须单独保留一份没有用于机器训练的数据作为验证的数据,这才能确保训练的结果的可行性。

前面已经提到,每一份MINIST数据都由图片以及标签组成。我们将图片命名为“x”,将标记数字的标签命名为“y”。训练数据集和测试数据集都是同样的结构,例如:训练的图片名为 mnist.train.images 而训练的标签名为 mnist.train.labels。

每一个图片均为28×28像素,我们可以将其理解为一个二维数组的结构:

使用MNIST数据集训练第一个pytorch CNN手写数字识别神经网络_第2张图片

MNIST

我们使用以下参数加载MNIST 数据集:

      • root ( string ) – 数据集所在MNIST/processed/training.pt 和 MNIST/processed/test.pt存在的根目录。
      • train ( bool , optional ) – 如果为 True,则从 中创建数据集training.pt,否则从test.pt.
      • download ( bool , optional ) – 如果为 true,则从 Internet 下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。
      • transform ( callable , optional ) – 一个函数/转换,它接收一个 PIL 图像并返回一个转换后的版本。例如,transforms.RandomCrop
      • target_transform ( callable , optional ) – 一个接收目标并对其进行转换的函数/转换。
torchvision.datasets.MNIST( root: str ,
                           train: bool = True , 
                           transform: Optional[Callable] = None , 
  target_transform: Optional[Callable] = None , 
    download: bool = False )

所有数据集都有几乎相似的 API。它们都有两个共同的参数: transform和 target_transform,本期文章,我们基于MNIST数据集

你可能感兴趣的:(pytorch,神经网络,cnn,人工智能,机器学习)