pytorch创建自定义数据集

pytorch为我们提供了Dataset类来提供所用数据集的创建任务。
数据集有两种情况:
1.pytorch中写好的数据集,如CIFAR10,我们在使用该数据集时只需要以下代码:data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)
datasets.CIFAR10就是Datasets的一个子类,data是这个类的一个实例。
2.利用Dataset自定义数据集:
模板为

class MovingMNISTdataset(Dataset):#需要继承Dataset类
    ##dataset class for moving MNIST data
    ##Initialize
    def __init__(self, path):
        self.path = path
        self.data = MNISTdataLoader(path)

    def __len__(self):
        return len(self.data[:, 0, 0, 0])

    def __getitem__(self, indx):
        ##getitem method
        self.trainsample_ = self.data[indx, ...]
        self.sample_ = self.trainsample_/255.0

        self.sample = torch.from_numpy(np.expand_dims(self.sample_, axis = 1)).float()
        return self.sample

其中 getitem(self, index), len(self) 两个内建方法,用来表示从索引到样本的映射(Map).

你可能感兴趣的:(pytorch,深度学习,机器学习)