【小土堆 Study Notes】PyTorch p6-p7: A First Look at Loading Data & Hands-On Practice

Loading Data

Dataset

Dataset: provides a way to access the data samples and their labels.
It has to implement two things (a minimal skeleton follows below):
1. How to fetch each data sample and its label
2. How to report the total number of samples
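These two requirements map directly onto Python's __getitem__ and __len__ methods. A minimal sketch, purely for illustration (the full hands-on version comes later in this note):

from torch.utils.data import Dataset

class SketchDataset(Dataset):       # hypothetical name, illustration only
    def __getitem__(self, index):   # 1. fetch one sample and its label
        raise NotImplementedError
    def __len__(self):              # 2. total number of samples
        raise NotImplementedError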

DataLoader

DataLoader: supplies the data to the downstream network in different forms.
It packs samples from a Dataset into batches (batch_size) and then feeds the data into the network, as sketched below.
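A minimal sketch of wrapping a Dataset with a DataLoader. It assumes the samples have already been converted to tensors (e.g. with torchvision transforms), so the default collation can stack them; batch_size=4 and shuffle=True are just example settings:

from torch.utils.data import DataLoader

# train_dataset is assumed to be a Dataset like the one built in the hands-on part below
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
for labels, imgs in train_loader:   # each iteration yields one batch of batch_size samples
    pass                            # feed the batch into the network here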

# use of class Dataset
from torch.utils.data import Dataset
help(Dataset)

help(Dataset) prints the documentation for the Dataset class:

Help on class Dataset in module torch.utils.data.dataset:

class Dataset(typing.Generic)
 |  Dataset(*args, **kwds)
 |  
 |  An abstract class representing a :class:`Dataset`.
 |  
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |  
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices.  To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
 |  
 |  Method resolution order:
 |      Dataset
 |      typing.Generic
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
 |  
 |  __getattr__(self, attribute_name)
 |  
 |  __getitem__(self, index) -> +T_co
 |  
 |  ----------------------------------------------------------------------
 |  Class methods defined here:
 |  
 |  register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False) from builtins.type
 |  
 |  register_function(function_name, function) from builtins.type
 |  
 |  ----------------------------------------------------------------------
 |  Data descriptors defined here:
 |  
 |  __dict__
 |      dictionary for instance variables (if defined)
 |  
 |  __weakref__
 |      list of weak references to the object (if defined)
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __annotations__ = {'functions': typing.Dict[str, typing.Callable]}
 |  
 |  __orig_bases__ = (typing.Generic[+T_co],)
 |  
 |  __parameters__ = (+T_co,)
 |  
 |  functions = {'concat': functools.partial(

Hands-on:

from torch.utils.data import Dataset
from PIL import Image
import os

# help(Dataset)
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):  # runs when an instance of the class is created; stores attributes ("class-wide variables") for the other methods to use
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  # list of image file names under root_dir/label_dir

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir  # the folder name doubles as the label
        return label, img

    def __len__(self):
        return len(self.img_path)


root_dir = 'dataset/train'
box_label_dir = 'box'
box2_label_dir = 'box2'
box_dataset = MyData(root_dir, box_label_dir)
box2_dataset = MyData(root_dir, box2_label_dir)
train_dataset = box_dataset + box2_dataset  # concat dataset

You can also experiment with this in the Python Console at the bottom of PyCharm:
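A quick sketch of the kind of checks you might run there, continuing from the script above (it assumes the dataset/train/box and dataset/train/box2 folders actually exist on disk):

label, img = box_dataset[0]     # calls MyData.__getitem__(0); returns (label, img) as defined above
print(label)                    # 'box'
print(len(box_dataset), len(box2_dataset), len(train_dataset))  # __len__; the concatenated length is the sum
img.show()                      # open the image with the system viewer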
Summary of the main functions used in this section:

# os.path.join: concatenate path segments
os.path.join(path_a, path_b)
# os.listdir: list the file names under a path and store them in a list
path_1 = os.listdir(path)
# concatenate datasets: Dataset objects can be combined directly with +
train_dataset = box_dataset + box2_dataset
# read an image
img = Image.open(img_path)  # the resulting img object carries image-related attributes that can be accessed as img.xx
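As for img.xx, a small sketch of the attributes a PIL image exposes (img_path is assumed to point at an existing image file):

from PIL import Image

img = Image.open(img_path)
print(img.size)    # (width, height) in pixels
print(img.mode)    # color mode, e.g. 'RGB'
print(img.format)  # file format, e.g. 'JPEG'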
