Pytorch笔记:dataloader的collate_fn参数在加载数据集时的作用

1. 前言

最近在复现MCNN时发现一个问题:ShanghaiTech数据集中的图片尺寸不一,转换为tensor后shape不一致,无法直接以大于1的batch_size加载数据。经过查找资料,有人提到可以自定义dataloader的collate_fn函数,在加载时将同一个batch内的数据裁剪为其中最小的图片尺寸,以便堆叠成一个batch。

2. 代码

2.1 数据集的定义

dataset.py

import scipy.io as sio
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch
import os
import cv2
from PIL import Image
import torchvision

class myDatasets(Dataset):
    """Crowd-counting dataset pairing each image with a Gaussian density map.

    Every entry of ``os.listdir(img_path)`` is treated as an image. Its
    annotation is read from ``ann_path`` under the name ``GT_<stem>.mat``
    (for example ``IMG_15.jpg`` -> ``GT_IMG_15.mat``).
    """

    def __init__(self, img_path, ann_path, down_sample=False, transform=None):
        """Store the directories and options; no image is loaded here.

        Args:
            img_path: Directory containing the images.
            ann_path: Directory containing the ``GT_*.mat`` annotations.
            down_sample: If true, the density map is built at 1/4 of the
                image height and width.
            transform: Optional callable applied to the float32 grayscale
                array of each image.
        """
        self.pre_img_path = img_path
        self.pre_ann_path = ann_path
        # Annotation names are derived from image names, so only the image
        # directory has to be listed.
        self.img_names = os.listdir(img_path)
        self.transform = transform
        self.down_sample = down_sample

    def __len__(self):
        """Return the number of entries found in the image directory."""
        return len(self.img_names)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(img, density_map)``. ``img`` is the transform output
            (or the raw float32 array when no transform was given) and
            ``density_map`` is a float32 ``torch.Tensor``.
        """
        img_name = self.img_names[index]
        # Swap only the extension: str.replace('jpg', 'mat') would also
        # rewrite a 'jpg' occurring inside the file stem.
        mat_name = 'GT_' + os.path.splitext(img_name)[0] + '.mat'

        img = Image.open(os.path.join(self.pre_img_path, img_name)).convert('L')
        img = np.array(img).astype(np.float32)

        if self.transform is not None:
            img = self.transform(img)

        # Read height/width from the trailing axes so that both a
        # (C, H, W) transform output and the untransformed (H, W) array work.
        h, w = img.shape[-2], img.shape[-1]

        # os.path.join keeps working when ann_path has no trailing separator.
        anno = sio.loadmat(os.path.join(self.pre_ann_path, mat_name))
        # Head coordinates; an (N, 2) array according to the original
        # author's note (layout of the .mat file is not visible here).
        xy = anno['image_info'][0][0][0][0][0]
        density_map = self.get_density((h, w), xy).astype(np.float32)
        density_map = torch.from_numpy(density_map)

        return img, density_map

    def get_density(self, img_shape, points):
        """Build a density map by adding one Gaussian kernel per point.

        Args:
            img_shape: ``(height, width)`` of the image.
            points: Iterable of coordinates; ``loc[0]`` is used as x (column)
                and ``loc[1]`` as y (row).

        Returns:
            A float64 ``numpy`` array of shape ``(height, width)``, or of
            ``(height // 4, width // 4)`` when ``down_sample`` is set.
        """
        scale = 4 if self.down_sample else 1
        h, w = img_shape[0] // scale, img_shape[1] // scale

        # Density map starts as all zeros.
        labels = np.zeros(shape=(h, w))
        if len(points) == 0:
            return labels

        f_sz = 15      # kernel size, which is also the neighbourhood size
        sigma = 4.0
        half = f_sz / 2
        # The untruncated kernel is identical for every point: build it once.
        full_kernel = self.fspecial(f_sz, f_sz, sigma)

        for loc in points:
            # Clamp the head position into the map. After clamping it can
            # never exceed the bounds, so no further range check is needed.
            x = min(abs(int(loc[0] / scale)), w)
            y = min(abs(int(loc[1] / scale)), h)

            # Window around the point, clipped to the map.
            col0 = int(max(x - half, 0))
            row0 = int(max(y - half, 0))
            # NOTE(review): as in the original, a window overflowing the
            # right/bottom edge is cut at size-1 (exclusive), so the last
            # column/row never receives density from clipped windows.
            col1 = int(x + half) if x + half <= w else w - 1
            row1 = int(y + half) if y + half <= h else h - 1

            rows, cols = row1 - row0, col1 - col0
            if rows <= 0 or cols <= 0:
                continue

            # Size the kernel from the actual slice so the shapes always
            # match, even when the map is smaller than the kernel.
            if rows == f_sz and cols == f_sz:
                kernel = full_kernel
            else:
                # A clipped window gets a freshly normalised, smaller kernel.
                kernel = self.fspecial(rows, cols, sigma)
            labels[row0:row1, col0:col1] += kernel
        return labels

    def fspecial(self, ksize_x=5, ksize_y=5, sigma=4):
        """Return a 2-D Gaussian kernel of shape ``(ksize_x, ksize_y)``.

        The kernel is the outer product of two 1-D kernels from
        ``cv2.getGaussianKernel``; note that ``ksize_x`` is the row count.
        """
        kx = cv2.getGaussianKernel(ksize_x, sigma)
        ky = cv2.getGaussianKernel(ksize_y, sigma)
        return np.multiply(kx, np.transpose(ky))
View Code

2.2 使用

demo.py

from config import get_args
from model import MCNN
from dataset import myDatasets
import torchvision
from torch.utils.data import DataLoader
import torch
from torch import nn
import time
from utils import get_mse_mae,show
import os
import numpy as np
import matplotlib.pyplot as plt
from debug_utils import ModelVerbose
import random
import cv2

# Parse the command-line options once, at module level.
args = get_args()

# Only ShanghaiTech part A is supported; reject anything else up front.
if args.dataset != 'ShanghaiTechA':
    raise Exception(F'Dataset {args.dataset} Not Implement')

if os.name == 'nt':
    # Windows: append backslash-separated sub-directories to the root.
    train_imgs_path, train_labels_path, test_imgs_path, test_labels_path = (
        args.dataset_path + r'\train_data\images\\',
        args.dataset_path + r'\train_data\ground-truth\\',
        args.dataset_path + r'\test_data\images\\',
        args.dataset_path + r'\test_data\ground-truth\\',
    )
else:
    # Other systems (Linux in the author's setup): join with forward slashes.
    train_imgs_path, train_labels_path, test_imgs_path, test_labels_path = (
        os.path.join(args.dataset_path, 'train_data/images/'),
        os.path.join(args.dataset_path, 'train_data/ground-truth/'),
        os.path.join(args.dataset_path, 'test_data/images/'),
        os.path.join(args.dataset_path, 'test_data/ground-truth/'),
    )

# Image preprocessing: only the tensor conversion is active. The author left
# Resize((768, 1024)) and a three-channel Normalize step disabled.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

def get_min_size(batch):
    """Return the smallest height and smallest width found in ``batch``.

    NOTE(review): the published body of this function was mangled (the
    comparisons were lost), so this is a reconstruction based on its name,
    its initial values, and the ``min_h, min_w = get_min_size(imgs)`` call
    site.

    Args:
        batch: Non-empty iterable of arrays/tensors whose last two axes are
            height and width.

    Returns:
        A tuple ``(min_height, min_width)``; the two minima may come from
        different items.

    Raises:
        ValueError: If ``batch`` is empty, since no size can be derived.
    """
    min_ht, min_wd = (float('inf'), float('inf'))
    for img in batch:
        h, w = img.shape[-2], img.shape[-1]
        if h < min_ht:
            min_ht = h
        if w < min_wd:
            min_wd = w
    if min_ht == float('inf'):
        raise ValueError("get_min_size() needs a non-empty batch")
    return min_ht, min_wd
    
View Code

2.3 配置

config.py

import argparse


def get_args():
    """Parse the MCNN command-line options and return the namespace.

    Options are registered in the order listed below; each one is optional
    and falls back to the default shown.
    """
    option_table = (
        ('--dataset', str, 'ShanghaiTechA'),
        ('--dataset_path', str, r"C:\Users\ocean\Downloads\datasets\ShanghaiTech\part_A\\"),
        ('--save_path', str, './save_file/'),
        ('--print_freq', int, 1),
        ('--device', str, 'cuda'),
        ('--epochs', int, 600),
        ('--batch_size', int, 4),
        ('--lr', float, 1e-5),
        ('--optimizer', str, 'Adam'),
    )

    parser = argparse.ArgumentParser(description='MCNN')
    for flag, value_type, fallback in option_table:
        parser.add_argument(flag, type=value_type, default=fallback)

    # In a Jupyter notebook, parser.parse_known_args()[0] can be used instead
    # so that the notebook's own arguments are ignored.
    return parser.parse_args()
View Code

 3. 总结

其中比较值得说道的是collate_fn函数c_f(),它的代码如下所示

def c_f(batch):
    """Collate (image, density map) samples of differing sizes into a batch.

    Every pair is passed to ``random_crop`` with the smallest height and
    width reported by ``get_min_size`` for the images, and the results are
    stacked. (Both helpers are defined outside this block; presumably the
    crop returns an image/density pair of exactly that size, which
    ``torch.stack`` requires.)

    Args:
        batch: List of ``(img, density_map)`` tuples, one per sample, as
            returned by the dataset's ``__getitem__``.

    Returns:
        ``[imgs, dens]``, each a tensor with a new leading batch axis.

    Raises:
        TypeError: If any image or density map is not a ``torch.Tensor``.
    """
    # Transpose the list of samples into one column per field.
    transposed = list(zip(*batch))
    imgs, dens = transposed[0], transposed[1]

    # Check every element and report the type of the actual offender; the
    # sample container itself is always a tuple and says nothing useful.
    for item in (*imgs, *dens):
        if not isinstance(item, torch.Tensor):
            raise TypeError(
                "batch must contain tensors; found {}".format(type(item))
            )

    min_h, min_w = get_min_size(imgs)
    cropped_imgs = []
    cropped_dens = []
    for img, den in zip(imgs, dens):
        _img, _dtmap = random_crop(img, den, (min_h, min_w))
        cropped_imgs.append(_img)
        cropped_dens.append(_dtmap)
    return [torch.stack(cropped_imgs), torch.stack(cropped_dens)]

这里传入的参数batch是一个list,其长度是batch_size。它的每一个元素代表了一个数据集单元,即自定义数据集类中__getitem__方法return的值。由于我们的__getitem__方法return了img和density_map两个数据,所以batch的每一个数据单元其实是一个元组(img, density_map)。

list(zip(*batch))所做的事情是把batch中的imgs和density_maps分别拿出来各自成为一个列表,方便下一步的处理。

处理的最后,还要用torch.stack将两个列表中的元素分别堆叠成tensor再返回。

你可能感兴趣的:(Pytorch笔记:dataloader的collate_fn参数在加载数据集时的作用)