pytorch中的Dataloader和dataset详细的collate_fn参数测试

DataLoder的参数

参考:https://blog.csdn.net/weixin_43794311/article/details/118091799

使用简单介绍

一、首先需要导入库,下面两种方式都行

from torch.utils.data import *
from torch.utils.data import DataLoader,Dataset

二、先建立自己的Dataset子类

class my_Dataset(Dataset):
	def __init__(self, 想要传入的参数):
		#参数一般是路径
		#对属性的赋值,一般得到所有的数据路径

    def __len__(self):
        return len(self.img_paths)#返回数据加载的数量

    def __getitem__(self, index):  # 对每个图片进行处理
    	#对每个加载的内容进行处理,最后返回需要使用的内容

三、定义DataLoader中参数collate_fn
这一步可以省略,但只能按照默认的格式输出,假设定义的DataSet的return中的返回两个对象

def collate_fn(batch):
	renew_out=[]
	for item in batch:#对一个batchsize的数据进行循环遍历后,控制输出
		el,el1 = item #将返回的两个对象进行重新处理格式
		renew_out+=[el,el1]

四、使用DataLoader加载
1、先实例化一个自己定义的Dataset对象,定义需要的数据
2、使用DataLoader加载生成需要的数据,其中设置了加载的线程,是否打乱,每个批数量等

data_set_object = my_Dataset(需要的参数)  # 先实例化一个
data_loader = DataLoader(data_set_object,batch_size,num_work,collate_fn,shuffle)

使用collate_fn和未使用自定义的不同

一、未使用collate_fn时

from imutils import paths
from torch.utils.data import *
import matplotlib.pyplot as plt
import cv2
import numpy as np

def collate_fn(batch):
	filenames=[]
	heights = []
	back_out = []
	for filename,height in batch:
		#print('file_name:',filename)
		#print('height',height)
		back_out+=[filename,height]
		filenames.append(filename)
	return back_out,filenames
	
class My_loader(Dataset):
    def __init__(self, img_dir):

        self.img_dir = img_dir
        self.img_paths = []
        self.img_paths += [el for el in paths.list_images(img_dir)]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):  # 对每个图片进行处理
        filename = self.img_paths[index]
        # Image = cv2.imread(filename) # 原始读取
        plt_img = plt.imread(filename)
        # 为了正确显示plt和cv2图片矩阵格式修改
        Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img

        height, width, _ = Image.shape  # h w c
        return filename,height  #自己定义的DataSet的返回

# # 若参数路径错误,可能出现
# ValueError: num_samples should be a positive integer value, but got num_samples=0
# 
if __name__ == '__main__':
	train_dataset = My_loader('my_test_imgs')
	print(len(train_dataset))  
	my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,)
	print(len(my_dataloder))
	for batch in my_dataloder:
		img_dir,height = batch #这里的DataSet对象的返回
		print(batch)
		print("!"*40)

pytorch中的Dataloader和dataset详细的collate_fn参数测试_第1张图片

二、使用collate_fn函数后

from imutils import paths
from torch.utils.data import *
import matplotlib.pyplot as plt
import cv2

import numpy as np

def collate_fn(batch):
	filenames=[]
	heights = []
	back_out = []
	for filename,height in batch:
		# print('file_name:',filename)
		# print('height',height)
		back_out+=[filename,height]  # 相加依然放在一个列表中,和后面的collate_fn比较
		filenames.append(filename)
	return back_out


class My_loader(Dataset):
    def __init__(self, img_dir):

        self.img_dir = img_dir
        self.img_paths = []
        self.img_paths += [el for el in paths.list_images(img_dir)]

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):  # 对每个图片进行处理
        filename = self.img_paths[index]
        # Image = cv2.imread(filename) # 原始读取
        plt_img = plt.imread(filename)
        # 为了正确显示plt和cv2图片矩阵格式修改
        Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img

        height, width, _ = Image.shape  # h w c
        return filename,height

# # 若参数路径错误,可能出现
# ValueError: num_samples should be a positive integer value, but got num_samples=0
# 
if __name__ == '__main__':
	train_dataset = My_loader('my_test_imgs')
	print(len(train_dataset))  
	my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,collate_fn=collate_fn)
	print(len(my_dataloder))
	for batch in my_dataloder:
		# img_dir,height = batch
		print(batch)
		print("!"*40)

pytorch中的Dataloader和dataset详细的collate_fn参数测试_第2张图片
如果修改collate_fn,结果就会改变,修改back_out.append([filename,height])

def collate_fn(batch):
	filenames=[]
	heights = []
	back_out = []
	for filename,height in batch:
		# print('file_name:',filename)
		# print('height',height)
		back_out.append([filename,height]) # 将一个列表作为最小的单位扩展放入空表中
	return back_out

pytorch中的Dataloader和dataset详细的collate_fn参数测试_第3张图片

DataLoader中运行情况

1.先执行main中的print
2.轮流执行DataSet和collate_fn函数中的print,执行DataSet的次数是batchsize的次数
from imutils import paths
from torch.utils.data import *
import matplotlib.pyplot as plt
import cv2

import numpy as np

def collate_fn(batch):
	filenames=[]
	heights = []
	back_out = []
	for filename,height in batch:
		# print('file_name:',filename)
		# print('height',height)
		back_out.append([filename,height])
	print('这是collatet_fn中的运行')
	return back_out

class My_loader(Dataset):
    def __init__(self, img_dir):

        self.img_dir = img_dir
        self.img_paths = []
        self.img_paths += [el for el in paths.list_images(img_dir)]
    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):  # 对每个图片进行处理
        filename = self.img_paths[index]
        # Image = cv2.imread(filename) # 原始读取
        plt_img = plt.imread(filename)
        # 为了正确显示plt和cv2图片矩阵格式修改
        Image= cv2.cvtColor(plt_img,cv2.COLOR_BGR2RGB) if plt_img.ndim>2 else plt_img

        height, width, _ = Image.shape  # h w c
        print("这是DataSet中的内容")
        return filename,height

# # 若参数路径错误,可能出现
# ValueError: num_samples should be a positive integer value, but got num_samples=0
# 
if __name__ == '__main__':
	train_dataset = My_loader('my_test_imgs')
	print(len(train_dataset))  
	my_dataloder = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=2,collate_fn=collate_fn)
	print(len(my_dataloder))
	for batch in my_dataloder:
		# img_dir,height = batch
		print(batch)
		print("这是main中")
		print("!"*40)

结果分析:先执行了main中,然后自定义的DataSet中__getitem__()和调整输出函数collate_fn中的内容交替执行。
pytorch中的Dataloader和dataset详细的collate_fn参数测试_第4张图片

注意:自定义的DataSet中的__init__()

这个函数只在实例化的时候执行一次

在对DataLoader的对象进行循环访问时出现问题

发现问题的过程,只是将程序复制一份后出现下面问题;又对旧文件及环境进行了测试,依然正常;
最后解决:发现唯一不同的是pycharm中的解释器环境版本不同,修改了新的环境和旧的一样后正常显示
分析可能的原因:问题中是因为循环Dataloader中导致的线程出现问题,旧和新的torch是同一个版本,但导致不同结果,可能是其他库版本的问题

RuntimeError: Caught RuntimeError in DataLoader worker process 0
RuntimeError: Could not infer dtype of numpy.float32

你可能感兴趣的:(笔记,一些自己的小用法,pytorch,深度学习)