对pytorch中的dataset和dataloader的一些理解

dataset与dataloader

  • dataset
  • dataloader

dataset

pytorch为我们提供了一个torch.utils.data.Dataset的抽象类,在构建自己的数据集的时候,都必须继承这个类。
并且都要重写__len__和__getitem__这两两个方法。
__len__就是一共有多少条数据。
__getitem__就是在对自己的数据进行一些处理,比如说读取图片,对图片的大小,或者是通道进行调换(cv2读取的图片是RBG格式,而pytorch是RGB格式,因此要在这里进行一个转换)等操作。
然后还有一个是transform参数,一般这个参数就是对你的数据进行一个怎样的转化,比如说有一个transform.ToTensor的方法,就是将你读取到的图片等信息,变成一个Tensor形式。

dataloader

由于dataset在读取数据的时候,只能一条一条的读取,这就为我们进行mini-batch的训练造成了麻烦,而dataloader恰好就帮我们完成了这件事。
pytorch中在torch.utils.data.DataLoader提供了这个类,在初始化的时候,一般常用的参数有dataset,batch_size,shuffle,num_workers,这几个参数。
其中dataset就是自己创建的dataset类。在使用时,datalaoder可以直接用for循环遍历,也可以用next()来遍历。不过一般都用for循环。
其实遍历dataloader的时候,也是从dataset里面取数据,只不过改了一下规则。比如说打乱顺序之后,再从里面取数据。所以也会调用自己定义的dataset类中的__getitem__函数。

你可能感兴趣的:(pytorch,深度学习,python)