pytorch学习3之torch.utils.data.Dataset

有时候我们需要自定义数据集。这时我们可以继承torch.utils.data.Dataset类,这是一个表示数据集的抽象类。当我们需要用到自定义的数据集时,可以去继承Dataset类并覆盖__len__()和__getitem__()方法,其中__len__()返回数据集的样本个数,getitem(index)返回训练集的第index个样本。
这次我使用的仍然是上一篇文章中的数据集,有需要的小伙伴可以去下载哦。pytorch学习2之数据加载与处理

from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
from skimage import io,transform

def show_landmarks(image,landmarks):
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    plt.pause(0.1)
class FaceLandsDataset(Dataset):
    #面部标记数据集
    def __init__(self,csv_file,root_dir,transform=None):
        #csv_file 带注释的csv文件的路径
        #root_dir 包含所有图像的目录
        #transform :一个样本上的可用的可选变换
        self.landmarks = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks)
    def __getitem__(self,index):
        img_name = os.path.join(self.root_dir,
        self.landmarks.iloc[index,0])
        image = io.imread(img_name)
        landmarks = self.landmarks.iloc[index,1:]
        landmarks = np.array([landmarks])
        landmarks = landmarks.astype('float').reshape(-1,2)
        sample = {'image':image,'landmarks':landmarks}
        if self.transform:
            sample = self.transform
        return sample

face_dataset = FaceLandsDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')
fig = plt.figure()
for i in range(len(face_dataset)):
    sample = face_dataset[i]
    print(i,sample['image'].shape,sample['landmarks'].shape)
    ax = plt.subplot(1,4,i+1)
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)
    if(i == 3):
        plt.show()
        break

结果如下:
数据的维度
pytorch学习3之torch.utils.data.Dataset_第1张图片
以上就是这篇文章的全部内容,感谢观看,一起进步!(以上内容来自pytorch官方教程中文版)

你可能感兴趣的:(python,人工智能,pytorch)