pytorch学习(1) 数据集制作

(1) 数据集

数据集的制作

定义: 什么是数据集, 通俗来说就是包含一堆数据的集合, 是进行下一步训练的必备素材资源

连数据都没有, 你还分析啥啊

第一步: 引入Dataset模块

from torch.utils.data import Dataset	# 注意是大写的Data, 不是data \ Date

第二步: 创建一个数据集类

  • 该类需继承于Dataset父类, 并需要重写getitem, len魔法方法
  • 使用该类可以创建一个特殊的数据集, 即你需要的数据集, 可以自己定义具体内容
  • 最后将类实例化, 就得到了一个数据集

具体代码演示:

from PIL import Image	# 操作图片的模块
import os				# 操作文件的模块

class MyData(Dataset):	# 继承Dataset
    
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        # 获取图片所在路径
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path_list = os.listdir(self.path)	# 得到所有图片的名字列表
        
    def __getitem__(self, idx):
        img_path = os.path.join(self.path, self.img_path_list[idx])
        img = Image.open(img_path)	# 获得所需的图片对象
        label = self.label_dir
        
        return img, label	# 返回图片, 标签元组
    
    def __len__(self):
        return len(self.img_path_list)		# 返回数据的数量
    

这样, 我们就把需要的数据集类定义完成了

第三步: 生成数据集

# 假如我的对应目录里有一组蚂蚁的图片
root_dir = r'datasets\train'
label_dir = 'ants'

ants_dataset = MyData(root_dir, label_dir)

Yahoo! 我们成功生成了一个数据集

你可能感兴趣的:(PyTorch入门,pytorch,深度学习,python)