深度学习—Pytorch—手写数字识别 :将MINIST数据集中的手写字体图片按照标签分类,分别保存到本地,并记录数量

注意:这里使用的是Pytorch,没有安装将无法使用本方法

废话不多说,上代码:


import os
from torchvision import datasets

#  输入文件地址
id1 = input("请输入MNIST文件存放的地址,如果您还没有下载该文件,请输入任意文件夹名称,将会自动下载MNIST:")
id2 = input("请输入存放图片文件夹名称,该文件夹默认存放于本文件同级目录下:")

#  判断输入的文件地址是否存在,如果没有,则会创建
if not os.path.exists(id1):
    os.makedirs(id1)

# 取出训练集与测试集总的数据,这里取出的数据形式表现为【列表中的元组】,即图片与标签的组合
traindata = [i for i in datasets.MNIST(root="{}/".format(id1),train=True,download=True)]
testdata = [i for i in datasets.MNIST(root="{}/".format(id1),train=False,download=True)]

# 循环保存数字 0 到 9的图片
for x in range(10):

    # 将元组中的图片按照循环次数分别取出,存入列表(这里i为图片,j为标签)
    data1,data2 = [ i for i,j in testdata if j == x ],[ i for i,j in traindata if j == x ]  
    # 定义存储路径
    path1,path2 = "{0}/test_image_{1}".format(id2,x),"{0}/train_image_{1}".format(id2,x)  

    print("正在存储【测试集】与【训练集】数字【{0}】图片,各有{3} ;{4}张,存放于{1} ;{2}"
            .format(x,path1,path2,len(data1),len(data2)))  # 这里分别为存储的:标签号、路径、元素个数
    if not os.path.exists(path1):  # 判断是否需要创建文件夹,存在则跳过
        os.makedirs(path1)
    if not os.path.exists(path2):
        os.makedirs(path2)
    for i, j in enumerate(data1):  # 存储测试集图片:将取出的所有图片迭代出,按照对应索引存入文件夹
        data1[i].save(path1+"\{}.png".format(i))
    for i, j in enumerate(data2):  # 存储训练集图片
        data2[i].save(path2+"\{}.png".format(i))

print("保存完毕,共存储图片{}张,感谢使用!".format(len(traindata)+len(testdata)))

由于设备差异,图片存储速度也将有所差别。
每一行代码都加了详细的注释,所以这里不加以额外说明,请浏览代码即可。
需要注意的是,MINIST中图片本身以PIL.Image.Image的格式与标签成对存放。
欢迎指出bug,以及更优化的方法。

你可能感兴趣的:(深度学习—Pytorch—手写数字识别 :将MINIST数据集中的手写字体图片按照标签分类,分别保存到本地,并记录数量)