今天是尝试用 PyTorch 框架来跑 MNIST 手写数字数据集的第二天,主要学习加载 MNIST 数据集。本 blog 主要记录一个学习的路径以及学习资料的汇总。
注意:这是用 Python 2.7 版本写的代码
第一天(LeNet 网络的搭建):https://blog.csdn.net/qq_36627158/article/details/108098147
第二天(加载 MNIST 数据集):https://blog.csdn.net/qq_36627158/article/details/108119048
第三天(训练模型):https://blog.csdn.net/qq_36627158/article/details/108163693
第四天(单例测试):https://blog.csdn.net/qq_36627158/article/details/108183655
感谢 凯神 提供的代码与耐心指导!
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import glob
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
TRAIN_BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000
class MNIST(Dataset): # define a class named MNIST
# read all pictures' filename
def __init__(self, root, transform=None):
self.filenames = []
self.transform = transform
# read filenames
for i in range(10):
# 'root/0/all_png'
filenames = glob.glob(osp.join(root, str(i), '*.png'))
for fn in filenames:
# (filename, label)
self.filenames.append((fn, i))
self.len = len(self.filenames)
# Get a sample from the dataset
# Return an image and it's label
def __getitem__(self, index):
# open the image
image_fn, label = self.filenames[index]
image = Image.open(image_fn)
# May use transform function to transform samples
if self.transform is not None:
image = self.transform(image)
return image, label
# get the length of dataset
def __len__(self):
return self.len
# define the transformation
# PIL images -> torch tensors [0, 1]
transform = transforms.Compose([
transforms.ToTensor()
])
# 2. load the MNIST training dataset
trainset = MNIST(
root='/home/ubuntu/Downloads/C6/mnist_png/training',
transform=transform
)
# divide the dataset into batches
trainset_loader = DataLoader(
trainset,
batch_size=TRAIN_BATCH_SIZE,
shuffle=True,
num_workers=0
)
# 3. load the MNIST testing dataset
testset = MNIST(
root='/home/ubuntu/Downloads/C6/mnist_png/testing',
transform=transform
)
# divide the dataset into batches
testset_loader = DataLoader(
testset,
batch_size=TEST_BATCH_SIZE,
shuffle=False,
num_workers=0
)
https://pytorch-cn.readthedocs.io/zh/latest/package_references/data/
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
注意:__init__并不相当于C#中的构造函数,执行它的时候,实例已构造出来了。__init__作用是初始化已实例化后的对象。
图文均来自链接:https://www.cnblogs.com/insane-Mr-Li/p/9758776.html
__len__()
和 __getitem__() 函数。
前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。之前看代码,一直没有看到具体体现 __getitem__() 函数的使用地方。
后面查到了:只要
继承了 Dataset 这个类后,就可以通过类的实例化对象的索引来调用到 _getitem_() 了。如: data[0]
https://www.zhihu.com/question/383099903
(图也是链接里的)
将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
seq = ['one', 'two', 'three']
for i, element in enumerate(seq):
print i, element
# 0 one
# 1 two
# 2 three
https://www.runoob.com/python/python-func-enumerate.html
Batch Size的理解:https://blog.csdn.net/qq_34886403/article/details/82558399
batch size 设置技巧:https://blog.csdn.net/kl1411/article/details/82983971
顺便找到了一个小白科普贴:深度学习中GPU和显存分析
num_worker
https://www.cnblogs.com/hesse-summer/p/11343870.html
https://blog.csdn.net/breeze210/article/details/99679048
迭代是Python最强大的功能之一,是访问集合元素的一种方式。
迭代器是一个可以记住遍历的位置的对象。
迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。
迭代器有两个基本的方法:iter() 和 next()。
https://www.runoob.com/python3/python3-iterator-generator.html
https://zhuanlan.zhihu.com/p/76893455
https://www.cnblogs.com/ranjiewen/p/10128046.html
https://www.cnblogs.com/luminousjj/p/9359543.html
https://www.cnblogs.com/luminousjj/p/9359543.html