Pytorch 构建数据集dataset

额,这里我们在网上找了10类花朵的数据,将数据进行分类,放在各个文件夹,文件名是花朵的标签,然后对图片大小统一为256*256。

将数据集分成训练集(train)、验证集(validation)、测试集(test)

分别为训练集800张,验证集100张,测试集100张,训练集和验证集的需要进行灰度处理,测试集不需要。

 

1.准备数据集好后,将文件路径和标签保存在txt文件中

from torchvision import transforms
import os
path = './validation/'
a = []
data_tf = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

#图片名和label信息的文本
def photo_text(path, name):
    table = os.listdir(path)
    for i in table:
        if i =='train.txt' or i == 'test.txt' or i == 'validation.txt':
            continue
        else:
            a.append(path+i +'  '+i.split('_')[0])
    try:
        with open(path + name,'x', encoding='utf-8') as f:
            for j in a:
                f.write(j+'\n')
    except:
        os.remove(path+name)
        with open(path + name,'x', encoding='utf-8') as f:
            for j in a:
                f.write(j+'\n')
if __name__ =='__main__':
    photo_text(path,'validation.txt')

Pytorch 构建数据集dataset_第1张图片

2.创建数据集,可以得到

from PIL import Image
from torchvision import transforms
import torch
# 创建数据集
class MyDataset(torch.utils.data.Dataset):  # 创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
    def __init__(self, root, datatxt, transform=None, target_transform=None):  # 初始化一些需要传入的参数
        super(MyDataset, self).__init__()
        fh = open(root + datatxt, 'r', encoding='utf-8')  # 按照传入的路径和txt文本参数,打开这个文本,并读取内容
        imgs = []  # 创建一个名为img的空列表,一会儿用来装东西
        for line in fh:  # 按行循环txt文本中的内容
            line = line.rstrip()  # 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询python
            words = line.split()  # 通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
            imgs.append((words[0], int(words[1])))  # 把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定

        # 很显然,刚才截图所示txt的内容,words[0]是图片信息,words[1]是lable
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # 这个方法是必须要有的,用于按照索引读取每个元素的具体内容
        fn, label = self.imgs[index]  # fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
        img = Image.open(fn).convert('RGB')  # 按照path读入图片from PIL import Image # 按照路径读取图片
        if self.transform is not None:
            img = self.transform(img)  # 是否进行transform
        return img, label  # return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容

    def __len__(self):  # 这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
        return len(self.imgs)
path_train = './train/'
path_validation = './validation/'


# 数据预处理。transforms.ToTensor()将图片转换成PyTorch中处理的对象Tensor,并且进行标准化(数据在0~1之间)
# transforms.Normalize()做归一化。它进行了减均值,再除以标准差。两个参数分别是均值和标准差
# transforms.Compose()函数则是将各种预处理的操作组合到了一起
data_tf = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])])

#个人数据装载
train_data = MyDataset(path_train ,'train.txt', transform=data_tf)
test_data = MyDataset(path_validation ,'validation.txt', transform=data_tf)

#数据装载
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory = True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True)

 

你可能感兴趣的:(人工智能,python,深度学习,pytorch)