参考:https://blog.csdn.net/weixin_43794311/article/details/118091799
一、首先需要导入库,下面两种方式都行
from torch.utils.data import *
from torch.utils.data import DataLoader,Dataset
二、先建立自己的Dataset子类
class my_Dataset(Dataset):
def __init__(self, 想要传入的参数):
#参数一般是路径
#对属性的赋值,一般得到所有的数据路径
def __len__(self):
return len(self.img_paths)#返回数据加载的数量
def __getitem__(self, index): # 对每个图片进行处理
#对每个加载的内容进行处理,最后返回需要使用的内容
三、定义DataLoader中参数collate_fn
这一步可以省略,但只能按照默认的格式输出,假设定义的DataSet的return中的返回两个对象
def collate_fn(batch):
renew_out=[]
for item in batch:#对一个batchsize的数据进行循环遍历后,控制输出
el,el1 = item #将返回的两个对象进行重新处理格式
renew_out+=[el,el1]
四、使用DataLoader加载
1、先实例化一个自己定义的Dataset对象,定义需要的数据
2、使用DataLoader加载生成需要的数据,其中设置了加载的线程,是否打乱,每个批数量等
data_set_object = my_Dataset(需要的参数) # 先实例化一个
data_loader = DataLoader(data_set_object,batch_size,num_work,collate_fn,shuffle)
from imutils import paths
from torch.utils.data import *
import matplotlib.pyplot as plt
import cv2
import numpy as np
def collate_fn(batch):
filenames=[]
heights = []
back_out = []
for filename,height in batch:
#print('file_name:',filename)
#print('height',height)
back_out+=[filename,height]
filenames.append(filename)
return back_out,filenames
class My_loader(Dataset):
def __init__(self, img_dir):
self.img_dir = img_dir
self.img_paths = []
self.img_paths += [el for el in paths.list_images(img_dir)]
def __len__(self):
return len(self.img_paths)
def __getitem__(self, index): # 对每个图片进行处理
filename = self.img_paths[index]
# Image = cv2.imread(filename) # 原始读取
plt_img = plt.imread(filename)
# 为了正确显示plt和cv2图片矩阵格式修改
Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img
height, width, _ = Image.shape # h w c
return filename,height #自己定义的DataSet的返回
# # 若参数路径错误,可能出现
# ValueError: num_samples should be a positive integer value, but got num_samples=0
#
if __name__ == '__main__':
train_dataset = My_loader('my_test_imgs')
print(len(train_dataset))
my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,)
print(len(my_dataloder))
for batch in my_dataloder:
img_dir,height = batch #这里的DataSet对象的返回
print(batch)
print("!"*40)
from imutils import paths
from torch.utils.data import *
import matplotlib.pyplot as plt
import cv2
import numpy as np
def collate_fn(batch):
filenames=[]
heights = []
back_out = []
for filename,height in batch:
# print('file_name:',filename)
# print('height',height)
back_out+=[filename,height] # 相加依然放在一个列表中,和后面的collate_fn比较
filenames.append(filename)
return back_out
class My_loader(Dataset):
def __init__(self, img_dir):
self.img_dir = img_dir
self.img_paths = []
self.img_paths += [el for el in paths.list_images(img_dir)]
def __len__(self):
return len(self.img_paths)
def __getitem__(self, index): # 对每个图片进行处理
filename = self.img_paths[index]
# Image = cv2.imread(filename) # 原始读取
plt_img = plt.imread(filename)
# 为了正确显示plt和cv2图片矩阵格式修改
Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img
height, width, _ = Image.shape # h w c
return filename,height
# # 若参数路径错误,可能出现
# ValueError: num_samples should be a positive integer value, but got num_samples=0
#
if __name__ == '__main__':
train_dataset = My_loader('my_test_imgs')
print(len(train_dataset))
my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,collate_fn=collate_fn)
print(len(my_dataloder))
for batch in my_dataloder:
# img_dir,height = batch
print(batch)
print("!"*40)
如果修改collate_fn,结果就会改变,修改back_out.append([filename,height])
,
def collate_fn(batch):
filenames=[]
heights = []
back_out = []
for filename,height in batch:
# print('file_name:',filename)
# print('height',height)
back_out.append([filename,height]) # 将一个列表作为最小的单位扩展放入空表中
return back_out
1.先执行main中的print
2.轮流执行DataSet和collate_fn函数中的print,执行DataSet的次数是batchsize的次数
from imutils import paths
from torch.utils.data import *
import matplotlib.pyplot as plt
import cv2
import numpy as np
def collate_fn(batch):
filenames=[]
heights = []
back_out = []
for filename,height in batch:
# print('file_name:',filename)
# print('height',height)
back_out.append([filename,height])
print('这是collatet_fn中的运行')
return back_out
class My_loader(Dataset):
def __init__(self, img_dir):
self.img_dir = img_dir
self.img_paths = []
self.img_paths += [el for el in paths.list_images(img_dir)]
def __len__(self):
return len(self.img_paths)
def __getitem__(self, index): # 对每个图片进行处理
filename = self.img_paths[index]
# Image = cv2.imread(filename) # 原始读取
plt_img = plt.imread(filename)
# 为了正确显示plt和cv2图片矩阵格式修改
Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img
height, width, _ = Image.shape # h w c
print("这是DataSet中的内容")
return filename,height
# # 若参数路径错误,可能出现
# ValueError: num_samples should be a positive integer value, but got num_samples=0
#
if __name__ == '__main__':
train_dataset = My_loader('my_test_imgs')
print(len(train_dataset))
my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,collate_fn=collate_fn)
print(len(my_dataloder))
for batch in my_dataloder:
# img_dir,height = batch
print(batch)
print("这是main中")
print("!"*40)
结果分析:先执行了main中,然后自定义的DataSet中__getitem__()和调整输出函数collate_fn中的内容交替执行。
这个函数只在实例化的时候执行一次,
发现问题的过程,只是将程序复制一份后出现下面问题;又对旧文件及环境进行了测试,依然正常;
最后解决:发现唯一不同的是pycharm中的解释器环境版本不同,修改了新的环境和旧的一样后正常显示
分析可能的原因:问题中是因为循环Dataloader中导致的线程出现问题,旧和新的torch是同一个版本,但导致不同结果,可能是其他库版本的问题
RuntimeError: Caught RuntimeError in DataLoader worker process 0
RuntimeError: Could not infer dtype of numpy.float32