这两节课的题目分别是:
P6:PyTorch加载数据初认识
P7:Dataset类代码实战
首先放上课程代码:
"""
@Author : 时礼
@Contact : [email protected]
@Software : pycharm
@File : MyData.py
@Desc : 小土堆课程P6的笔记
"""
from torch.utils.data import Dataset
from PIL import Image
#python中关于系统的一个库
import os
class MyData():
#注意:这里定义的初始化函数是init不是int
def __init__(self,root_dir,label_dir):
#使用self相当于创建了一个类中的全局变量
#存储根目录地址
self.root_dir=root_dir
#存储存图片的文件夹的名称
self.label_dir=label_dir
#可以将两个路径合并
#得到图片的实际
self.path=os.path.join(self.root_dir,self.label_dir)
#获得指定路径下的文件列表(文件及其下标)
self.img_path=os.listdir(self.path)
def __getitem__(self, idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.path,img_name)
img=Image.open(img_item_path)
label=self.label_dir
#所以这个return是返回了一个图片和存储图片的文件夹的名称?
return img,label
def __len__(self):
return len(self.img_path)
#注意,在服务器里复制到的路径还需要在前面加上一个/
root_dir = "/mnt/pycharmWorkspace1/pyCharmProject1/dataSet/train"
ants_label_dir = "ants_image"
#antsDataset = MyData(root_dir, ants_label_dir)
"""
root_dir="/mnt/pycharmWorkspace1/pyCharmProject1/dataSet/train"
ants_label_dir="ants_image"
/mnt/pycharmWorkspace1/pyCharmProject1/dataSet/train/ants_image
antsDataset=MyData(root_dir,ants_label_dir)
bees_laber_dir = "bees"
#可以将几个小的数据集进行拼接
trainDataset=antsDataset+beesDataset
#pycharm中整行复制的快捷键:ctrl+d
beesDataset=MyData(root_dir,ants_label_dir)
#查询数据长度的两种方式(起码在console里可以使用)
len(antsDataset)
antsDataset.__len__()位置
"""
博主没有学过python,但毕竟是面向对象的语言,本节课的主要内容就相当于是定义了一个class,然后实例化(创建对象)。
注意:
1、上文中没有注释的是定义的class,注释的是利用这个class来进行数据读取的代码,其中用到的数据集为:https://download.pytorch.org/tutorial/hymenoptera_data.zip
2、小土堆的在定义class的时候用的是:
class MyData(Dataset):
这样写的话在python console里可以运行,但如果直接写到main.py然后运行的话就会报错:
MyData is not found 我是上网上找了一个简单的不会报错的test.py文件,然后一点一点的将MyData里的代码放到这个test文件中,终于发现了如果将class的定义改成:
class MyData():
就不会再报错了。
3、这里讲一下我的实现方法,我将这个class代码单独放到了MyData.py中,然后通过main.py import MyData。这里又发生了一些问题(其实本质问题还是没有系统学过python,对python的一些操作不熟悉)。引用的时候报了TypeError: ‘module’ object is not callable的错。
为了解决这个问题,我又尝试了添加如下两行代码:
import sys
sys.path.append('./')
依然报错。而后看了一个文章说可能是因为class名称和python文件名一样,所以造成了报错。更改后还是报错。
最后发现是引用的格式不对,正确格式为:
from dataTest_test import dataTest
总结一下第三点遇到的情况:
添加的那两行代码应用的情况可能是当编写class的python文件不在工程文件中。并且我发现pycharm有一个比较方便的功能,那就是如果在main.py中引用成功了一个python文件中的class文件,如果将python文件重命名则main中的引用会自动改名,如上一个代码段中的代码更改完python文件名后的代码自动变为:
from dataTest import dataTest
4、创建class时需要注意的事:
(1)在定义init函数的时候注意不要打成int(我刚创建的时候报错了,半天没找到原因)
(2)在创建方法时使用的self相当于是一个整个类中的全局变量,如果在一个方法中使用self.变量名的方法给变量赋值,则这个变量可以在别的方法中读取。
5、当需要读取工程文件的地址时,可以使用FTP工具,不过在读取之后需要在前面加一个/表示从根目录开始。
6、使用到的技巧(如数据拼接,一会删)
(1)当需要将两个地址拼接时可以使用
self.path=os.path.join(self.root_dir,self.label_dir)
注意:当使用这个方法时需要引入os,代码如下:
import os
(2)如果使用了MyData方法创建了两个dataset,可以直接使用+将两个dataset拼接到一起
antsDataset=MyData(root_dir,ants_label_dir)
antsDataset=MyData(root_dir,bees_label_dir)
trainDataset=antsDataset+beesDataset
(3)查询一个dataset的数据长度的两种方式
1)
len(antsDataset)
2)
antsDataset.__len__()
7、本次使用到的pycharm快捷键(mac版)
(1)将一行代码整行复制:command + d
(2)将一行代码整行上下移动:command + shift + 上下键
接下来放下一段代码,用来创建image对应的label:
"""
@Author : rosyForever
@Contact : [email protected]
@Software : pycharm
@File : labelMake.py
@Desc : 用来生成图片对应的标签文件
"""
import os
root_dir = "/mnt/pyCharmProject1/dataSet/train"
target_dir = "bees_image"
#获得图片的文件列表
img_path = os.listdir(os.path.join(root_dir,target_dir))
label = target_dir.split('_')[0]
out_dir = "bees_label"
for i in img_path:
file_name=i.split('.jpg')[0]
with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w')as f:
f.write(label)
注意:
刚看到这段代码的时候不知道
label = target_dir.split('_')[0]
为什么最后要加一个[],于是写了一个test代码专门来测试这个方法
"""
@Author : rosyForever
@Contact : [email protected]
@Software : pycharm
@File : labelMake.py
@Time : 2022/10/19 21:22
@Desc : 用来测试split函数的使用以及为什么其后面要加一个[0]。
得出结论:slip是将字符串以输入到slipt函数中的字符串为分隔符进行分隔,生成
一个['','']格式的数组,使用[num]可以得到数组其中的一个值。注:如果分隔符
出现在最后一个,则会在数组最后一个值为空。
"""
file_name1="test1.jpg"
file_name2=file_name1.split('.jpg')
file_name3=file_name1.split('.jpg')[0]
print(file_name2)
print(file_name3)
"""
如果不使用[]则得到的是['test1', ''],如果使用[]则得到的时test1
"""
file_name4=file_name1.split('.')[0]
file_name5=file_name1.split('.')[1]
file_name6=file_name1.split('.')
print(file_name4)
print(file_name5)
print(file_name6)
"""
file_name4的内容是test1
file_name5的内容是jpg
file_name6的内容是['test1', 'jpg']
"""
得出结论:slip是将字符串以输入到slipt函数中的字符串为分隔符进行分隔,生成 一个['','']格式的数组,使用[num]可以得到数组其中的一个值。注:如果分隔符 出现在最后一个,则会在数组最后一个值为空。