前言
经过这几天学习,我算是把数据集这一块给摸清楚了,前面分布分支的学习总是有点模棱两可,不清楚这步到底要干啥,在网上找资料学习时,总是拿的pytorch官网给的数据集,没有针对性和专一性。这里教大家如何使用咱们自己的数据集,当然,在做实验时数据集是通过爬虫来获取的,关于爬虫的相关知识可以留言私信,或者看我第一篇博客哦
一、MyData类的定义
在自建数据集时需要自己去定义一个dataset类来继承torch.utils.data.Dataset
来看代码
class MyData(Dataset):
def __init__(self, root_dir, label_dir, transform=None): # 初始化类,为class提供全局变量
self.transform = transform
self.root_dir = root_dir # 根文件位置
self.label_dir = label_dir # 子文件名
self.path = os.path.join(self.root_dir, self.label_dir) # 合并,即具体位置
self.img_path = os.listdir(self.path) # 转换成列表的形式
def __getitem__(self, idx): # 获取列表中每一个图片
img_name = self.img_path[idx] # idx表示下标,即对应位置
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每一个图片的位置
img = Image.open(img_item_path) # 调用方法,拿到该图像
img = img.convert("RGB")
img = self.transform(img)
label = self.label_dir # 标签
return img, label # 返回img 图片 label 标签
def __len__(self): # 返回长度
return len(self.img_path)
这里要注意一下,跟第一篇学习笔记的不同在于第一篇没有定义transform导致返回的就是PIL类型,这边加入了几行代码,目的就是为了返回tensor类型,这里返回的是img,label两个对象
二、数据集的实际运用
上面的类是通法,在任何研究中都可以套用,下面来看在本次实验中的实际运用
tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([512, 512])])
root_dir = 'D://情绪图片' # 根目录
happy_label_dir = '开心' # 子目录
happy_dataset = MyData(root_dir, happy_label_dir, transform=tensor_trans) # 开心数据集创建完成
这边是通过totensor转换为tensor类型,同时将图片尺寸变为512*512
也算so easy吧
三、DataLoader类
dataloader是用来load数据集,其中batch_size=4是为了每次抓取4张;shuffle是按需求来是否需要打乱,即在等于True的时候是打乱的,False的时候是不打乱的;drop_last是表示在最后一次抓取时不满4个是否需要保留,比如一共10张图片,每次抓取4个,最后一次不满4个可以选择保留或者舍弃
test_loader = DataLoader(dataset=happy_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# img, label = happy_dataset[0]
step = 0
writer = SummaryWriter("dataloader")
for epoch in range(2):
for data in test_loader:
imgs, label = data
writer.add_images('{}'.format(epoch), imgs, step)
step = step + 1
writer.close()
四、源码
这边是最终的源码,大家可以按自己的要求选取哦
# -*- coding = utf-8 -*-
import torchvision
from torch.utils.data import Dataset
import cv2
from PIL import Image # 图像处理的库
import os
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torch.utils.data import DataLoader
class MyData(Dataset):
def __init__(self, root_dir, label_dir, transform=None): # 初始化类,为class提供全局变量
self.transform = transform
self.root_dir = root_dir # 根文件位置
self.label_dir = label_dir # 子文件名
self.path = os.path.join(self.root_dir, self.label_dir) # 合并,即具体位置
self.img_path = os.listdir(self.path) # 转换成列表的形式
def __getitem__(self, idx): # 获取列表中每一个图片
img_name = self.img_path[idx] # idx表示下标,即对应位置
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 每一个图片的位置
img = Image.open(img_item_path) # 调用方法,拿到该图像
img = img.convert("RGB")
img = self.transform(img)
label = self.label_dir # 标签
return img, label # 返回img 图片 label 标签
def __len__(self): # 返回长度
return len(self.img_path)
tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([512, 512])])
root_dir = 'D://情绪图片' # 根目录
happy_label_dir = '开心' # 子目录
happy_dataset = MyData(root_dir, happy_label_dir, transform=tensor_trans) # 开心数据集创建完成
# img, label = happy_dataset[2] # 由上面可知返回的是两个值
# print(label) # 分别调用
# img.show()
# batch_size每次取dataset的四个数据集并打包, shuffle是是否打乱,drop_last为False即最后一步不满4个时不舍,反之舍
# print(happy_dataset[0])
# tensor_trans = transforms.Compose([transforms.ToTensor(), transforms.Resize([224, 224])])
# test_data = torchvision.datasets(datasets=happy_dataset, transforms=tensor_trans)
test_loader = DataLoader(dataset=happy_dataset, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# img, label = happy_dataset[0]
step = 0
writer = SummaryWriter("dataloader")
for epoch in range(2):
for data in test_loader:
imgs, label = data
writer.add_images('{}'.format(epoch), imgs, step)
step = step + 1
writer.close()
下面来看结果
五、总结
这几天对数据集的学习对我这个初次接触pytorch的人来说也是挺头疼的,各种报错,包括loader的基本使用当时也是不太熟悉。现在大部分教学用的数据集都是通过torchvision.dataset.来获取pytorch自带的数据集,这篇也算是给大家提供另一种方法吧
最后有什么不明白或者报错不知道怎么解决的可以留言私信哦~说不定我曾经也经历过