定义: 什么是数据集, 通俗来说就是包含一堆数据的集合, 是进行下一步训练的必备素材资源
连数据都没有, 你还分析啥啊
from torch.utils.data import Dataset # 注意是大写的Data, 不是data \ Date
具体代码演示:
from PIL import Image # 操作图片的模块
import os # 操作文件的模块
class MyData(Dataset): # 继承Dataset
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
# 获取图片所在路径
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path_list = os.listdir(self.path) # 得到所有图片的名字列表
def __getitem__(self, idx):
img_path = os.path.join(self.path, self.img_path_list[idx])
img = Image.open(img_path) # 获得所需的图片对象
label = self.label_dir
return img, label # 返回图片, 标签元组
def __len__(self):
return len(self.img_path_list) # 返回数据的数量
这样, 我们就把需要的数据集类定义完成了
# 假如我的对应目录里有一组蚂蚁的图片
root_dir = r'datasets\train'
label_dir = 'ants'
ants_dataset = MyData(root_dir, label_dir)
Yahoo! 我们成功生成了一个数据集