在学习李沐在B站发布的《动手学深度学习》PyTorch版本教学视频中发现在操作使用PyTorch方面有许多地方看不懂,往往只是“动手”了,没有动脑。所以打算趁着寒假的时间好好恶补、整理一下PyTorch的操作,以便跟上课程。
学习资源:
Dataset
以及 Dataloader
是Pytorch中读取数据需要用到的两个重要的类
Dataset
:提供一种方式去获取数据及其lable,需要我们自己去写。Dataloader
:为后面的网络提供不同的数据形式常见的图片数据集有两种形式:
功能:
我们可以运行 from torch.utils.data import Dataset
(其中 utils
有实用工具的意思,理解为工具区)来导入Dataset这个类,同时也可以使用 Dataset??
和 help(Dataset)
来看如何使用Dataset
import torch
from torch.utils.data import Dataset
Dataset??
help(Dataset)
Help on class Dataset in module torch.utils.data.dataset:
class Dataset(typing.Generic)
| An abstract class representing a :class:`Dataset`.
|
| All datasets that represent a map from keys to data samples should subclass
| it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
| data sample for a given key. Subclasses could also optionally overwrite
| :meth:`__len__`, which is expected to return the size of the dataset by many
| :class:`~torch.utils.data.Sampler` implementations and the default options
| of :class:`~torch.utils.data.DataLoader`.
|
| .. note::
| :class:`~torch.utils.data.DataLoader` by default constructs a index
| sampler that yields integral indices. To make it work with a map-style
| dataset with non-integral indices/keys, a custom sampler must be provided.
|
| Method resolution order:
| Dataset
| typing.Generic
| builtins.object
|
| Methods defined here:
|
| __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
|
| __getitem__(self, index) -> +T_co
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| __orig_bases__ = (typing.Generic[+T_co],)
|
| __parameters__ = (+T_co,)
|
| ----------------------------------------------------------------------
| Class methods inherited from typing.Generic:
|
| __class_getitem__(params) from builtins.type
|
| __init_subclass__(*args, **kwargs) from builtins.type
| This method is called when a class is subclassed.
|
| The default implementation does nothing. It may be
| overridden to extend subclasses.
简言之:Dataset是个抽象类,所有的子类都要重写__getitem__
方法获取label,也可以重写__len__
方法获取长度
其中用到的两个库:
PIL
中的Image
:
PIL(Python Imaging Library):是Python的图像处理库
img = Image.open(image_path)
读取对应路径的图片为一个变量img.show()
使用系统默认图片打开方式打开此图片os
:operating system
os.path.join(root_dir, label_dir)
:将两个路径连起来,这个函数会根据操作系统自动调整路径的语法;其中dir = directory
目录os.listdir(dir_path)
:顾名思义,就是把括号里面的目录路径中的“所有文件的路径”生成一个列表,可以用类似a[0]
的语句取出 对应图片的路径from torch.utils.data import Dataset
from PIL import Image
import os
图片数据集格式为:label直接标在文件夹上
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
# 初始化函数,为后面的getitem和next方法提供所需要的量
self.root_dir = root_dir # 根目录的路径;self可以理解为当前类内部的一个全局变量
self.label_dir = label_dir # label目录的路径,因为下一行要合起来,并且label名是文件夹名,所以这里的label_dir可以直接取对应的label,如:"ants"
self.path = os.path.join(root_dir, label_dir) # 将两个路径连起来
self.img_path = os.listdir(self.path) # 获取该路径下所有文件的路径列表
def __getitem__(self, idx):
"""获取数据集中的每一个图片,输入索引,得到对应的图片"""
# idx是index索引的缩写
img_name = self.img_path[idx]
img_item_path = os.path.join(self.path, img_name)
image = Image.open(img_item_path)
label = self.label_dir
return image, label
def __len__(self):
return len(self.img_path)
对应的蜜蜂蚂蚁图片识别数据集见小土堆B站视频简介
root_dir = r"F:\Data and code\data\蚂蚁蜜蜂数据\hymenoptera_data\hymenoptera_data\train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
取一个试试看到底能不能取出来单独数据
img, label = ants_dataset[0]
# img.show()
img
"""下面的代码是将本节中所提到的格式转化成使用txt文档存放label的格式"""
"""这种格式就是新建一个ant_label文件夹,其中放的都是.txt文件。每一个文件的名字都是对应图片的名字,文件的内容则是对应的label"""
'''说实话还没看'''
root_dir = r"F:\Data and code\data\蚂蚁蜜蜂数据\hymenoptera_data\hymenoptera_data\train"
target_dir = "ants_image"
img_path_list = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = "ants_label"
for i in img_path_list:
file_name = i.split('.jpg')[0]
with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
f.write(label)