Pytorch数据加载(Dataset+DataLoader模块)【讲解+代码】

数据载入由什么组成?

数据载入由dataset和dataloader组成。
dataset:提供一种方式去获取数据及其label
dataloader: 为后面网络提供不同的数据形式

1. Dataset的功能

dataset主要为了实现两个功能
1.如何获取每个数据及其label
2.告诉我们总共有多少的数据

2. Dataset代码:

1) 查看官方文档解释

首先,在anaconda prompt中输入如下代码,打开jupyter环境

conda activate <pytorch环境名称> #激活pytorch环境 
jupyter notebook #打开jupyter

然后,在jupyter中创建新的文件,并输入以下指令既可以看到关于dataset官方文档解释。

from torch.utils.data import Dataset
help(Dataset)

2) pycharm中进行程序编写

(1)下载数据集
蚁蜜蜂分类数据集
https://download.pytorch.org/tutorial/hymenoptera_data.zip
(2)
建立dataset文件,并将数据集放入程序下
Pytorch数据加载(Dataset+DataLoader模块)【讲解+代码】_第1张图片
(3)编写数据集载入程序,实现dataset两个功能,第一,如何获取每个数据及其label
第二,告诉我们总共有多少的数据。



from torch.utils.data import Dataset
from PIL import Image
import os
from torchvision import transforms

class mydata(Dataset):

    #设置全局参数
    def __init__(self,root_dir,label_dir):
        self.root_dir=root_dir
        self.label_dir=label_dir
        # 获得图片的路径地址
        self.path=os.path.join(self.root_dir,self.label_dir) #os.path.join()函数用于路径拼接文件路径,可以传入多个路径。如果不存在以’/’开始的参数,则函数会自动加上
        # 获得图片的所有列表
        self.img_path=os.listdir(self.path)  #os.listdir() 方法用于返回指定的文件夹包含的文件或文件夹的名字的列表。

    #获取每一个图片
    def __getitem__(self, item):
        #获取单张图片名称
        img_name=self.img_path[item]
        #获取单张图片相对路径
        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
        #读取单张图片
        img=Image.open(img_item_path)
        img=img.resize((256,256),Image.ANTIALIAS) #统一图片尺寸
        trans = transforms.ToTensor() #转换为tensor类型
        img_tensor = trans(img)#转换为tensor类型


        #获取lable
        label=self.label_dir
        return img_tensor, label

    #列表有多长
    def __len__(self):
        return len(self.img_path)

3. Dataloader的功能

dataloader的功能是为了实现从dataset中取数据,例如,每次取多少数据?,数据集是否打乱?,加载过程是单进程还是多进程?,如果最后剩余数据不足一次需要获取数据,剩余数据是否舍弃。

4. Dataloader代码

1) 查看官方文档解释

在jupyter中创建新的文件,并输入以下指令既可以看到关于dataset官方文档解释。

from torch.utils.data import DataLoader
help(DataLoader)

2) pycharm中进行程序编写

新建.py文件,写入以下内容

from dataset import mydata
from torch.utils.data import DataLoader
import torch

#准备测试数据集
root_dir= "dataset/val"
bees_label_dir="bees"
test_dataset=mydata(root_dir,bees_label_dir)#输入数据集路径

#数据载入
test_loader=DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

for data1 in test_loader:
    imgs,labs=data1
    print(imgs.shape)
    print(labs)

输出如下内容则表示数据载入成功
Pytorch数据加载(Dataset+DataLoader模块)【讲解+代码】_第2张图片

感谢: PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
视频网址:https://www.bilibili.com/video/BV1hE411t7RN?p=15&vd_source=5b6e0605c1ed0f1db9c92503dd5994e0

你可能感兴趣的:(pytorch,pytorch,深度学习,python)