首先介绍如何用pytorch加载网络现有数据集,然后介绍如何制作自己的图像数据集并批量读取来训练自己的网络。
提示:以下是本篇文章正文内容,下面案例可供参考
使用Pytorch进行读取本地的MINIST数据集并进行装载
# 训练数据和测试数据的下载
trainDataset = torchvision.datasets.MNIST( # torchvision可以实现数据集的训练集和测试集的下载
root="./data", # 下载数据,并且存放在data文件夹中
train=True, # train用于指定在数据集下载完成后需要载入哪部分数据,如果设置为True,则说明载入的是该数据集的训练集部分;如果设置为False,则说明载入的是该数据集的测试集部分。
transform=transforms.ToTensor(), # 数据的标准化等操作都在transforms中,此处是转换
download=True
)
testDataset = torchvision.datasets.MNIST(
root="./data",
train=False,
transform=transforms.ToTensor(),
download=True
)
训练神经网络需要标准输入图像和它的真值标签。
在分类问题中,比如猫、狗、船、车等等,我们可以用数字代表不同的分类。可以制作一个txt文档用于存放输入图像的地址和它对应的标签数字。
我现在有个任务需要以图像作为输入,以另一张处理过后的图像作为它的真值,所以我在txt文本下面写的是它们的路径。在项目路径下新建了一个train文件夹用于放训练图片,并在train文件夹下新建一个训练的txt用于标注训练图像和标签图像
PyTorch读取图片,主要是通过Dataset类是Pytorch中所有数据集加载类中应该继承的父类。我们通过继承改写Dataset类来读取自己的图像数据集。其中以下三个函数必须改写:
__init__方法里面进行读取数据文件
__getitem__方法进行支持下标访问
__len__方法返回自定义数据集的大小,方便后期遍历
class OpticalSARDataset(Data.Dataset):
"""
定义自己的数据集、读取数据、初始化数据
"""
def __init__(self, data_dir, part):
# 所有图片的绝对路径
assert part in ["train", "val"]
self.image_dir = os.path.join(data_dir, part)
self.img_names = []
self.label_names = []
with open(os.path.join(data_dir, part, "label.txt")) as f:
while True:
il = f.readline(1500) # 如果样本数据名称大于1500,修改该值
if not il:
break
a = il.split(sep=' ')
self.img_names.append(a[0])
self.label_names.append(a[1][0:-1]) # remove '\n'
self.samples_num = len(self.img_names)
# print(self.samples_num)
self.transform = torchvision.transforms.Compose([
# 将 PIL 图片转换成位于[0.0, 1.0]的floatTensor, shape (C x H x W)
torchvision.transforms.ToTensor()])
def __len__(self):
# 返回图像的数量
return self.samples_num
def __getitem__(self, index):
tp_img = Image.open(os.path.join(self.image_dir, self.img_names[index])
).convert('RGB')
tp_label = Image.open(os.path.join(self.image_dir, self.label_names[index])
).convert('RGB')
# PIL.Image.open 读取的图片数据是RGB格式;
tp_img = cv2.cvtColor(np.asarray(tp_img), cv2.COLOR_RGB2BGR)
tp_label = cv2.cvtColor(np.asarray(tp_label), cv2.COLOR_RGB2BGR) # 转换为BGR便于cv2.imshow,跟下面imshow之前RGB2BGR只用一种方法,这里统一为cv2的BGR格式
img = self.transform(tp_img)
label = self.transform(tp_label)
sample = {
"label": label, # shape
"image": img # shape: (3, *image_size)
}
return sample
# 利用之前创建好的OpticalSARDataset类去创建数据对象
train_dataset = OpticalSARDataset(data_dir, 'train') # 训练数据集
之前所说的Dataset类是读入数据集数据并且对读入的数据进行了索引。
但是光有这个功能是不够用的,在实际的加载数据集的过程中,我们的数据量往往都很大,对此我们还需要一下几个功能:
可以分批次读取:batch-size
可以对数据进行随机读取,可以对数据进行洗牌操作(shuffling),打乱数据集内数据分布的顺序
可以并行加载数据(利用多核处理器加快载入数据的效率)
Dataloader这个类并不需要我们自己设计代码,我们只需要利用DataLoader类读取我们设计好的类。
# 利用dataloader读取我们的数据对象,并设定batch-size和工作现场
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
shuffle=True, num_workers=0)
batch = iter(train_iter).next()
print(batch["image"].shape, batch["label"].shape)
print(batch["image"][0].shape)
参考博客:
定义自己的数据集
pytorch加载自己数据集
设计自己的数据
训练自己数据完整步骤
Dataset类