# pytorch-读数据

# -*- coding: utf-8 -*-

import os
import random
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import cv2



# random.seed(1)
#
#
# rmb_label = {"1": 0, "100": 1}
#
# train_transform = transforms.Compose([
#     transforms.Resize((32, 32)),
#     # transforms.RandomCrop(32, padding=4),
#     transforms.ToTensor(),
#     # transforms.Normalize(norm_mean, norm_std),
# ])

class RMBDataset(Dataset):
    """Map-style dataset of images stored in one folder per class.

    Every directory found under ``data_dir`` whose name is a key of the
    label map (``"1"`` or ``"100"`` by default) is treated as a class folder;
    the ``.jpg`` files directly inside it become samples carrying that
    folder's integer label.

    Attributes:
        label_name: Per-instance copy of the folder-name -> label mapping.
        data_info: List of ``(image_path, label)`` tuples; ``__getitem__``
            indexes into it.
        transform: Optional callable applied to the loaded PIL image.
    """

    # Single source of truth for the folder-name -> class-index mapping.
    LABEL_NAME = {"1": 0, "100": 1}

    def __init__(self, data_dir, transform=None):
        """Index all images below ``data_dir``.

        Args:
            data_dir: Directory that contains the class folders.
            transform: Optional callable applied to each image in
                ``__getitem__``; ``None`` returns the PIL image unchanged.
        """
        self.label_name = dict(self.LABEL_NAME)
        self.data_info = self.get_img_info(data_dir, self.label_name)
        self.transform = transform

    def __getitem__(self, index):
        """Load the sample at ``index`` and return ``(image, label)``."""
        path_img, label = self.data_info[index]

        # Image.open is lazy; the context manager guarantees the file handle
        # is released, while convert() returns an independent in-memory copy.
        with Image.open(path_img) as img_file:
            img = img_file.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir, label_map=None):
        """Collect ``(image_path, label)`` pairs for every class folder.

        Args:
            data_dir: Directory to walk.
            label_map: Optional folder-name -> label mapping; defaults to
                ``RMBDataset.LABEL_NAME``.

        Returns:
            A list of ``(path, int_label)`` tuples. Directories and file
            names are visited in sorted order so the result is reproducible.
            Directories whose name is not in the label map contribute no
            samples (they are still descended into by the walk).
        """
        if label_map is None:
            label_map = RMBDataset.LABEL_NAME

        data_info = []
        for root, dirs, _ in os.walk(data_dir):
            # Sorting in place also fixes the order in which os.walk descends.
            dirs.sort()
            for sub_dir in dirs:
                # Skip folders that are not classes instead of raising KeyError.
                if sub_dir not in label_map:
                    continue
                label = int(label_map[sub_dir])
                class_dir = os.path.join(root, sub_dir)
                for img_name in sorted(os.listdir(class_dir)):
                    if img_name.endswith('.jpg'):
                        data_info.append(
                            (os.path.join(class_dir, img_name), label)
                        )

        return data_info




# train_dir='./RMB_data/rmb_split/train/'
#
# train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
# # 构建DataLoader
# train_loader = DataLoader(dataset=train_data, batch_size=16)
#
#
# for epoch in range(10):
#     for i, data in enumerate(train_loader):
#         pass






# dir = "./data/"  #100张图片
#
# def 挤牙膏操作(dir,batch_size):
#     读取文件夹(dir)
#     image ,label=按照batch_size进行数据的拆分(batch_size)
#     image=数据预处理(image)
#     return image ,label
#
# for i in range(5): #5个epoch
#     for j in range(10):
#         image, label = 挤牙膏操作(dir, batch_size=10)


#         pre_label = 神经网络(image)
#         损失函数(pre_label,label)
#




#
# class MYDataset(Dataset):
#     def __init__(self, data_dir, transform=None):
#        ..
#
#     def __getitem__(self, index):
#         ..
#         return  ..,..
#
#     def __len__(self):
#         return len(..)
#
#
# train_dir='./'
#
# train_data = MYDataset(data_dir=train_dir)
# # 构建DataLoader
# train_loader = DataLoader(dataset=train_data, batch_size=16)
#
#
# for epoch in range(10):
#     for i, data in enumerate(train_loader):
#         pass

# --------- enumerate 函数的解释:它返回一个迭代器,每次迭代产生一个 (索引, 元素) 二元组
# list_train = [[1,2,3,4],[5,6,7,8],[9,1,2,3]]
# for i, data in enumerate(list_train):
#     print(i, data)
    # 0 [1, 2, 3, 4]
    # 1 [5, 6, 7, 8]
    # 2 [9, 1, 2, 3]













# 你可能感兴趣的:(Pytorch相关,pytorch,深度学习,神经网络)