PyTorch 图像分类识别(一)定义及加载自己的数据集并可视化

文章目录

  • 前言
  • 一、Dataset、DataLoader是什么?
  • 二、如何定义Dataset?
    • 1.定义 Dataset
  • 三、如何使用DataLoader?
    • 1. 使用Dataloader加载数据集
  • 四、可视化源数据
  • 五、完整代码
  • 参考


前言

深度学习初入门小白,技艺不精,写下笔记记录自己的学习过程。欢迎评论区交流提问,力所能及之问题,定当毫无保留之相授。


一、Dataset、DataLoader是什么?

Dataset:是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中。
Dataloader:通过DataLoader这个函数,我们在加载数据集时候,批次读取数据及多线程并行处理,这样可以加快我们读取数据集的速度。

二、如何定义Dataset?

Dataset类是Pytorch中数据集加载类中应该继承的父类。通常包括这三部分:

1.*def __init__(self)*
2.*def __getitem__(self, index):*
3.*def __len__(self):*

其中父类中的两个私有成员函数,__len__和__getitem__必须被重载!

1.定义 Dataset

#root1和root2分别为训练集,验证集存放图片路径及标签的txt路径
root1 = r"C:\Users\asus\Desktop\mstar_classification\mstar\train.txt"
root2 = r"C:\Users\asus\Desktop\mstar_classification\mstar\val.txt"

# 1、构建数据集类
class Mydata(Dataset):

    # __init__
    # 该函数可以包含多个参数,如数据的读取路径和对数据的处理设置等一系列设定
    # txt:存放着图片数据的路径和标签信息,words[0]为图片的路径,words[1]为图片的标签,如下图所示。(txt需要事先生成,如何生成先挖个坑)
    # imgs:按行读取txt,并依次存放到列表中
    # transform为:图片数据增强,下文中会讲
    def __init__(self, txt, transform=None, target_transform=None):
        super(Mydata, self).__init__()
        imgs = []
        fh = open(txt, 'r')
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))  # imgs中包含有图像路径和标签
        self.txt = txt
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    # __getitem__
    # 接收一个index,然后返回图片路径和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。
    # 在本代码中,这个list为imgs[]
    # 图片打开方式为Image.open,三通道RGB格式。若数据集图片为单通道,可在transform中添加transforms.Grayscale(1)函数。
    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(os.path.join(self.txt[:-4], fn))#self.txt[:-4],下文加载txt时,路径中不需要有后缀,所以去掉.txt四个字符
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    #__len__ 
   #返回样本的总数量, 该方法提供了dataset的大小
    def __len__(self):
        return len(self.imgs)
        
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ColorJitter(),
                                      transforms.Grayscale(1), transforms.ToTensor(),
                                      transforms.Normalize([0.5], [0.5])])
test_transform = transforms.Compose([transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

train_data = Mydata(txt=root1, transform=train_transform)
test_data = Mydata(txt=root2, transform=test_transform)

txt中存放着图片的路径及标签
PyTorch 图像分类识别(一)定义及加载自己的数据集并可视化_第1张图片

三、如何使用DataLoader?

该函数的作用是将数据整理成一个batch,即根据batch_size的大小一次性在数据集中取出batch_size个数据。例如数据集中有100条数据,batch_size的值为20,则每次在100条数据中取出20条数据。

torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
# dataset: 加载torch.utils.data.Dataset对象数据,即为上文中的train_data和test_data
# batch_size: 每个batch的大小
# shuffle:是否对数据进行打乱
# drop_last:是否对无法整除的最后一个datasize进行丢弃
# um_workers:表示加载的时候子进程数,一般GPU使用

1. 使用Dataloader加载数据集

train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)

四、可视化源数据

examples = enumerate(train_loader)
batch_idx, (examples_data, examples_targets) = next(examples)

fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i+1)
    plt.tight_layout()#自动调整子图参数,使之填充满整个图像区域
    plt.imshow(examples_data[i][0], interpolation='none')
    plt.title("Category:{}".format(examples_targets[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

五、完整代码

注意:
1.数据集的路径需要改成自己的
2.前提需要生成相应的txt文件

import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import os

root1 = r"C:\Users\asus\Desktop\mstar_classification\mstar\train.txt"
root2 = r"C:\Users\asus\Desktop\mstar_classification\mstar\val.txt"


# 1、构建数据集
class Mydata(Dataset):
    def __init__(self, txt, transform=None, target_transform=None):
        super(Mydata, self).__init__()
        self.txt = txt
        fh = open(txt, 'r')
        imgs = []
        for line in fh:
            line = line.strip('\n')
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))  # imgs中包含有图像路径和标签
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(os.path.join(self.txt[:-4], fn))
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)


# 2.数据增强、加载数据
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ColorJitter(),
                                      transforms.Grayscale(1), transforms.ToTensor(),
                                      transforms.Normalize([0.5], [0.5])])
test_transform = transforms.Compose(
    [transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
# 是被封装进DataLoader里,实现该方法封装自己的数据和标签
train_data = Mydata(txt=root1, transform=train_transform)
test_data = Mydata(txt=root2, transform=test_transform)
# DataLoader被封装入DataLoader里,实现该方法达到数据的划分
# train_data 和test_data包含多有的训练与测试数据,调用DataLoader批量加载
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)


# 3.可视化源数据
examples = enumerate(train_loader)
batch_idx, (examples_data, examples_targets) = next(examples)

fig = plt.figure()
for i in range(6):
    plt.subplot(2, 3, i + 1)
    plt.tight_layout()  # 自动调整子图参数,使之填充满整个图像区域
    plt.imshow(examples_data[i][0], interpolation='none')
    plt.title("Category:{}".format(examples_targets[i]))
    plt.xticks([])
    plt.yticks([])
plt.show()

参考

https://blog.csdn.net/sinat_42239797/article/details/90641659
https://blog.csdn.net/ChaoFeiLi/article/details/109764566
https://blog.csdn.net/l8947943/article/details/103733473
https://blog.csdn.net/kahuifu/article/details/108654421
https://blog.csdn.net/wangkaidehao/article/details/104209685

你可能感兴趣的:(pytorch,python,深度学习)