pytorch学习(二)—自定义数据集

#在深度学习中经常需要生成带标签的图片名称列表文件(如 xxxlist.txt),
#下面的脚本为文件夹中的图片生成这样一个带标签的 txt 列表文件
import os 
def generate(dir, label, prefix='/cat',
             extensions=('.jpg', '.jpeg', '.png', '.bmp')):
    """Write ``train.txt`` into *dir*, listing every image file with *label*.

    Each output line has the form ``<prefix>/<file name> <label>``, e.g.
    ``/cat/001.jpg 0``; files are listed in sorted order.

    Args:
        dir: Directory containing the images; ``train.txt`` is written here
            (overwriting any previous one).
        label: Class label written after each file name (passed through ``int``).
        prefix: Path prefix prepended to every file name (default ``'/cat'``).
        extensions: File extensions, compared case-insensitively, that count
            as images. Every other entry -- including a previously generated
            ``train.txt`` -- is skipped.
    """
    files = sorted(os.listdir(dir))
    print("*****************")
    print("input = ", dir)
    print("Start...")
    wanted = tuple(ext.lower() for ext in extensions)
    label_text = str(int(label))
    # os.path.join instead of a hard-coded '\\' so this also works off Windows.
    with open(os.path.join(dir, 'train.txt'), 'w') as list_text:
        for file in files:
            # splitext (not split) yields the extension; the old code compared
            # the bare file name with '.jpg', which never matched, so every
            # file (including train.txt itself) ended up in the list.
            if os.path.splitext(file)[1].lower() not in wanted:
                continue
            list_text.write(prefix + '/' + file + ' ' + label_text + '\n')
    print("down!")
    print("********************")
if __name__ == '__main__':
    # Build train.txt for the cat images, every entry labelled 0.
    generate(dir='D:\\Spyder3Files\\data\\train\\cat', label=0)
import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
#定义 MyDataset 类,继承 Dataset,并重写方法 __len__() 和 __getitem__()
class MyDataset(Dataset):
    """Dataset built from a names file of ``<relative path> <label>`` lines.

    Each non-blank line of *names_file* holds an image path followed by an
    integer label. The path is appended to *root_dir* by plain string
    concatenation, so entries such as ``/cat/001.jpg`` should start with a
    slash. Samples are ``{'image': <io.imread result>, 'label': int}`` dicts,
    optionally passed through *transform*.
    """

    def __init__(self, root_dir, names_file, transform=None):
        self.root_dir = root_dir
        self.names_file = names_file
        self.transform = transform
        self.size = 0
        self.names_list = []

        if not os.path.isfile(self.names_file):
            # Previously this only printed and then crashed inside open();
            # fail fast with a readable message instead.
            raise FileNotFoundError(self.names_file + " does not exist!")
        with open(self.names_file) as file:
            for line in file:
                line = line.strip()
                if line:  # skip blank lines such as a trailing empty line
                    self.names_list.append(line)
        self.size = len(self.names_list)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        entry = self.names_list[idx]  # IndexError past the end ends iteration
        # Split on the LAST whitespace so file names containing spaces survive.
        parts = entry.rsplit(None, 1)
        if len(parts) != 2:
            raise ValueError('malformed line in {}: {!r}'.format(
                self.names_file, entry))
        rel_path, label = parts
        image_path = self.root_dir + rel_path
        if not os.path.isfile(image_path):
            # Raising (instead of returning None) keeps DataLoader/collate and
            # callers from failing later with a confusing TypeError.
            raise FileNotFoundError(image_path + " does not exist!")
        image = io.imread(image_path)

        sample = {'image': image, 'label': int(label)}
        if self.transform:
            sample = self.transform(sample)

        return sample


train_dataset = MyDataset(
    root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=None,
)

print(train_dataset.size)

# Preview the first 20 samples on a 4 x 5 grid, drawing each one as it loads.
plt.figure()
for cnt, sample in enumerate(train_dataset):
    image, label = sample['image'], sample['label']

    ax = plt.subplot(4, 5, cnt + 1)
    ax.axis('off')
    ax.imshow(image)
    ax.set_title('label {}'.format(label))
    plt.pause(0.001)

    if cnt == 19:
        break

#  变换Resize    
class Resize(object):
    """Resize the ``'image'`` of a sample dict with ``skimage.transform.resize``.

    The label is passed through unchanged; a new dict is returned.
    """

    def __init__(self, output_size: tuple):
        self.output_size = output_size

    def __call__(self, sample):
        resized = transform.resize(sample['image'], self.output_size)
        return dict(image=resized, label=sample['label'])

#  变换ToTensor
class ToTensor(object):
    """Turn a sample's H x W x C numpy image into a C x H x W torch tensor.

    2-D (H x W) arrays, e.g. grayscale images, get a single channel axis
    first. The label is passed through unchanged.
    """

    def __call__(self, sample):
        image = sample['image']
        if image.ndim == 2:
            # A (2, 0, 1) transpose needs three axes; add the missing one.
            image = image[:, :, np.newaxis]
        # HWC -> CHW (channel-first layout); copy to a contiguous buffer so the
        # resulting tensor is contiguous as well.
        image_new = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
        return {'image': torch.from_numpy(image_new),
                'label': sample['label']}
        
# Apply the custom Resize/ToTensor transforms to the raw training set.
train_pipeline = transforms.Compose([Resize((224, 224)), ToTensor()])
transformed_trainset = MyDataset(
    root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=train_pipeline,
)

# DataLoader provides batching, shuffling and optional worker processes;
# num_workers=0 keeps all loading in the main process.
trainset_dataloader = DataLoader(
    dataset=transformed_trainset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
)

#  可视化
def show_images_batch(sample_batched):
    """Draw a whole batch of images as one grid on the current matplotlib axes.

    Args:
        sample_batched: Batch dict from the DataLoader whose ``'image'`` entry
            is an N x C x H x W tensor.
    """
    # make_grid tiles the N images into a single C x H' x W' tensor; imshow
    # expects channels last, hence the transpose.
    grid = make_grid(sample_batched['image'])
    plt.imshow(grid.numpy().transpose(1, 2, 0))


# Each batch's 'image' entry is an N x C x H x W tensor.
plt.figure()
for batch in trainset_dataloader:
    show_images_batch(batch)
    plt.axis('off')
    plt.ioff()
    plt.show()  # after ioff() this blocks until the window is closed


plt.show()
#  使用更简便的方式——ImageFolder
#  如果每种类别的样本放在各自的文件夹中,则可以直接使用ImageFolder.
import torch 
from torch.utils.data import DataLoader
from torchvision import transforms,datasets
import matplotlib.pyplot as plt
import numpy as np

# Resize to 224x224, randomly mirror horizontally, then convert to a tensor.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# ImageFolder assigns labels from the sub-directory each image lives in.
train_dataset = datasets.ImageFolder(root='./data/train', transform=data_transform)

loader_kwargs = dict(batch_size=4, shuffle=True, num_workers=0)
train_dataloader = DataLoader(dataset=train_dataset, **loader_kwargs)

def show_batch_images(sample_batch):
    """Plot every image of a DataLoader batch side by side, titled by label.

    Args:
        sample_batch: ``(images, labels)`` pair from the DataLoader, where
            ``images`` is an N x C x H x W tensor and ``labels`` holds N
            integer labels.
    """
    images_batch, labels_batch = sample_batch[0], sample_batch[1]
    # Use the real batch length: the final batch can be smaller than
    # batch_size, and a hard-coded 4 raised IndexError there.
    count = len(labels_batch)
    for i, (image, label) in enumerate(zip(images_batch, labels_batch)):
        ax = plt.subplot(1, count, i + 1)
        # CHW -> HWC, the layout matplotlib's imshow expects.
        ax.imshow(image.permute(1, 2, 0).numpy())
        ax.set_title(str(label.item()))
        ax.axis('off')
        #plt.pause(0.001)   #不用多线程这里不用考虑

        
# Show each ImageFolder batch; plt.show() waits for the window every time.
plt.figure()
for batch in train_dataloader:
    show_batch_images(batch)
    plt.show()

 

你可能感兴趣的:(pytorch)