DataLoader 加载数据并显示数据

# -*- coding: utf-8 -*-
import argparse
import os
import copy

import torch
from torch import nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from tqdm import tqdm
import numpy as np

from models import FSRCNN, Discriminator
from datasets import TrainDataset, EvalDataset
from utils import AverageMeter, calc_psnr
import PIL.Image as Image
import matplotlib.pyplot as plt

# print('pid:{}   GPU:{}'.format(os.getpid(), os.environ['CUDA_VISIBLE_DEVICES']))
class MyDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}   # 初始化部分
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):  # 函数功能是根据index索引去返回图片img以及标签label
        path_img = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img

    def __len__(self):   # 函数功能是用来查看数据的长度,也就是样本的数量
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):   # 函数功能是用来获取数据的路径以及标签
        data_info = list()
        print(data_dir)
        for root, dirs, files in os.walk(data_dir):
            print(files)
            # 遍历类别
            for file in files:
                if file.endswith('.png') or file.endswith('.PNG'):
                    img_names = os.path.join(root, file)
                    data_info.append(img_names)

        return data_info    # 有了data_info,就可以返回上面的__getitem__()函数中的self.data_info[index],根据index索取图片和标签


if __name__ == '__main__':
    
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    BATCH_SIZE = 12
    
    split_dir = os.path.join("./DIV2K_train_HR")
    train_dir = os.path.join(split_dir, "train")
    print(train_dir)
    
    # train_transform = transforms.Compose([
    # transforms.RandomCrop(32, padding=4),
    # transforms.ToTensor(),
    # transforms.Normalize(norm_mean, norm_std),
    # ])   # Resize的功能是缩放,RandomCrop的功能是裁剪,ToTensor的功能是把图片变为张量
    train_transform = transforms.Compose([
    transforms.RandomCrop(128, padding=0),
    transforms.ToTensor(),
    ])  
    valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
    ])
    train_data = MyDataset(data_dir=train_dir, transform=train_transform)  # data_dir是数据的路径,transform是数据预处理
    # 构建DataLoder
    train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)  # shuffle=True,每一个epoch中样本都是乱序的

    to_pil_image = transforms.ToPILImage()
    for data in train_loader:
        inputs = data
        # print('inputs.size():',inputs.size())
        
        #显示图片的第一种方法
        # 方法1:Image.show()
        # transforms.ToPILImage()中有一句
        # npimg = np.transpose(pic.numpy(), (1, 2, 0))
        # 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
        inputs = to_pil_image(inputs[0]) #去掉第一维度
        inputs.show()
        
        # 方法2:plt.imshow(ndarray)
        img = inputs[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
        img = img.numpy() # FloatTensor转为ndarray
        img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
        # 显示图片
        plt.imshow(img)
        plt.show()



#https://blog.csdn.net/qq_37388085/article/details/102663166?utm_medium=distribute.pc_relevant.none-task-blog-baidujs-2

 

你可能感兴趣的:(pytorch学习,DataLoader显示数据)