目录
一、Dataset初识以及项目前期准备工作
二、MyData类
2.1 在python中定义类和方法
2.2 定义MyClass类
Dataset
2.3 获取图片
2.4 使用控制台调试对应信息
1. 获取ants集中第一章图片的绝对路径
2. 读取对应路径的图片
3. 显示图片:show方法
4. 获取图片信息列表
三、完善MyData类
3.1 初始化方法中需要的参数和方法
3.2 初始化init方法的书写
3.3 getitem方法的书写
3.4 生成实例
3.4 两个数据集的生成与相加操作
1. 生成蚂蚁和蜜蜂数据集
2. 数据集相加
四、完整代码
五、使用修改后数据集的代码练习
python文件、python控制台和jupter notebook的区别
遇到的问题:
1. jupyter notebook中配置pytorch
(71条消息) jupyter notebook中使用pytorch_一子慢的博客-CSDN博客_jupyternotebook使用pytorch
2. pycharm中matplotlib使用失败
(71条消息) Pycharm导入matplotlib失败的解决办法_c472769019的博客-CSDN博客_matplotlib导入失败
在notebook中使用help方法查看dataset类的功能以及操作:
前置操作:
1. 把数据集移动到项目所在的目录文件夹下
2. 右击想要查看路径的文件夹/图片:
可以复制需要的绝对路径/相对路径
- 在python中定义类的要求:class关键字定义类,后面跟着类的全名,括号(object)表示该类是从哪个类中继承下来的,如果没有合适的继承类,则使用object类,这是所有类都会继承的类。
- 在类里定义方法的要求:在类中定义方法时,第一个参数必须是self。
- 在类中定义方法的要求:self变量无需传递,其他参数正常传入。
例:
from torch.utils.data import Dataset
Dataset是一个抽象类,为了能够方便的读取,需要将要使用的数据包装为Dataset类。 自定义的Dataset需要继承它并且实现两个成员方法:
1. __getitem__()
该方法定义用索引(0
到 len(self)
)获取一条数据或一个样本,可以使用对象【item】进行访问
2. __len__()
该方法返回数据集的总长度
首先,重写init方法和getitem方法,后期重写len()方法
导入图片需要获取对应的图片image和对应的标签label,也需要获取图片所在的位置img_path
读取图片需要导入的模块
# 读取图片
from PIL import Image
控制台作用:可以显示定义的变量和相关属性
存入img_path变量中,复制后的路径需要再加一个双斜线进行转义。
使用Image中的open方法
可以看到右边出现了img变量的相关属性
如size值即为图片的大小,在控制台中可以对应输出
调用该方法后可以对应弹出显示图片的窗口
如图为img_path_list对象,可以看到集合了ants文件夹下所有图片的名称,共124张图片,因此列表大小为124
如果访问img_path_list列表的元素,如第一个元素,下标为0,则可以输出第一章图片的名称
获取到文件根目录和标签目录后,使用join方法进行地址的拼接,获取到对应图片文件夹的地址,然后使用listdir方法获取到该地址的图片列表
# 重写函数的初始化方法
def __init__(self,root_dir,label_dir):
# 初始化
self.root_dir=root_dir
self.label_dir=label_dir
# 获取图片文件夹的路径
self.path=os.path.join(self.root_dir,self.label_dir)
# 获取对应图片路径的图片名称列表
self.img_list=os.listdir(self.path)
作用:获取到图像列表中单个图片的对象以及其标签
idx:对应图片的索引值
使用拼接法:文件夹路径+图片名 可以获取到具体某一张图片的地址
open方法生成对应图片对象
python基础:如果有多个返回值,默认以元组形式打包,因此geitem方法返回的是(img,label)的元组
# 重写类的getitem方法
def __getitem__(self, idx):
# 获取单个图片名称
img_name=self.img_list[idx]
# 获取单个图片路径,使用拼接法
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
# 生成对应图片对象
img = Image.open(img_item_path)
# 对应标签
label = self.label_dir
# 返回图像和标签,以元组格式返回
return img,label
root_dir="dataset/train"
label_dir="ants"
#实例化MyData类
ants_datasets=MyData(root_dir,label_dir)
在控制台中进行测试,可以看到生成的ants_datasets对象中有了我们在上面初始化方法中进行定义的所有属性,如list,path等等
ants_datasets数据集的第一项即为第一张图片对象以及其label标签
img,label=ants_datasets[1],使用img和label接受元组中的img和label,可以看到变量中img和lable有了对应的具体值
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
# 生成MyData类的实例对象
ants_datasets=MyData(root_dir,ants_label_dir)
bees_datasets=MyData(root_dir,bees_label_dir)
可以看到相加后train_datasets的长度是两个数据集的和
from torch.utils.data import Dataset
# 读取图片
from PIL import Image
# 关于系统的库
import os
class MyData(Dataset):
# 重写函数的初始化方法
def __init__(self,root_dir,label_dir):
# 初始化
self.root_dir=root_dir
self.label_dir=label_dir
# 获取图片文件夹的路径
self.path=os.path.join(self.root_dir,self.label_dir)
# 获取对应图片路径的图片名称列表
self.img_list=os.listdir(self.path)
# 重写类的getitem方法
def __getitem__(self, idx):
# 获取单个图片名称
img_name=self.img_list[idx]
# 获取单个图片路径,使用拼接法
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
# 生成对应图片对象
img = Image.open(img_item_path)
# 对应标签
label = self.label_dir
# 返回图像和标签,元组
return img,label
def __len__(self):
return len(self.img_list)
root_dir="dataset/train"
ants_label_dir="ants"
bees_label_dir="bees"
# 生成MyData类的实例对象
ants_datasets=MyData(root_dir,ants_label_dir)
bees_datasets=MyData(root_dir,bees_label_dir)
# 两个数据集相加
train_datasets=ants_datasets+bees_datasets
修改后数据集结构如下图所示,图像和标签各有一个文件夹进行存储
标签文件夹下是各个图像的标签,为txt文件,文件名与图像名相同,并且文件内容仅有一行,即为标签内容ants
因此获取标签时需要使用file读取文件形式
from torch.utils.data import Dataset
from PIL import Image
import os
class MyDataset(Dataset):
def __init__(self,root_dir,img_dir,label_dir):
# 根文件路径
self.root_dir=root_dir
# 图片文件路径
self.img_dir=img_dir
#标签文件夹路径
self.label_dir=label_dir
# 获取图片文件夹路径并生成图片名称的列表
self.img_path=os.path.join(self.root_dir,self.img_dir)
self.img_list=os.listdir(self.img_path)
#获取标签文件夹路径并生成标签名称的列表
self.label_path=os.path.join(self.root_dir,self.label_dir)
self.label_list=os.listdir(self.label_path)
def __getitem__(self, item):
img_name=self.img_list[item]
img_item_path=os.path.join(self.img_path,img_name)
# 读取对应路径的图片内容,生成图片对象,存储在img中
img=Image.open(img_item_path)
label_name=self.label_list[item]
label_item_path=os.path.join(self.label_path,label_name)
# 打开对应路径的txt文件,读取对应内容,存储在label中
file1 = open(label_item_path,"r")
label= file1.readline()
return img,label
def __len__(self):
return len(self.img_list)
root_dir="datasets2/train"
ants_img_dir="ants_image"
ants_label_dir="ants_label"
bees_img_dir="bees_image"
bees_label_dir="bees_label"
ants_datasets=MyDataset(root_dir,ants_img_dir,ants_label_dir)
bees_datasets=MyDataset(root_dir,bees_img_dir,bees_label_dir)