1. 前言
最近在复现MCNN时发现一个问题,ShanghaiTech数据集图片的尺寸不一,转换为tensor后的shape形状不一致,无法直接进行多batch_size的数据加载。经过查找资料,有人提到可以定义dataloader的collate_fn函数,在加载时将数据裁剪为最小的图片尺寸,以便于堆叠成多个batch_size。
2. 代码
2.1 数据集的定义
dataset.py


import scipy.io as sio
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch
import os
import cv2
from PIL import Image
import torchvision
class myDatasets(Dataset):
    """Crowd-counting dataset yielding ``(image, density_map)`` pairs.

    Only the image directory is listed: the annotation that belongs to
    ``IMG_15.jpg`` is looked up in ``ann_path`` as ``GT_IMG_15.mat``.

    Args:
        img_path: Directory containing the images.
        ann_path: Directory containing the ``GT_*.mat`` annotation files.
        down_sample: When true, the density map is built at 1/4 of the image
            height and width (integer division) and head coordinates are
            divided by 4 accordingly.
        transform: Optional callable applied to the float32 grayscale array.
    """

    def __init__(self, img_path, ann_path, down_sample=False, transform=None):
        self.pre_img_path = img_path
        self.pre_ann_path = ann_path
        # os.listdir order is arbitrary; sorting makes index -> file stable.
        self.img_names = sorted(os.listdir(img_path))
        self.transform = transform
        self.down_sample = down_sample

    def __len__(self):
        """Return the number of entries found in the image directory."""
        return len(self.img_names)

    def __getitem__(self, index):
        """Load one grayscale image and build its density map.

        Returns:
            A tuple ``(img, density_map)``. ``img`` is whatever ``transform``
            returns (the raw float32 array when no transform is set);
            ``density_map`` is a float32 ``torch.Tensor``.
        """
        img_name = self.img_names[index]
        # Swap only the extension; str.replace('jpg', 'mat') would also
        # rewrite a "jpg" that appears elsewhere in the file name.
        mat_name = 'GT_' + os.path.splitext(img_name)[0] + '.mat'

        # The context manager closes the underlying file handle promptly.
        with Image.open(os.path.join(self.pre_img_path, img_name)) as raw:
            img = np.array(raw.convert('L')).astype(np.float32)

        if self.transform is not None:
            img = self.transform(img)

        # The last two axes are (H, W) both for the raw 2-D array and for a
        # channel-first tensor, so this also works when transform is None.
        # NOTE(review): assumes the transform keeps a channel-first or plain
        # (H, W) layout - confirm for custom transforms.
        h, w = img.shape[-2], img.shape[-1]

        # os.path.join works whether or not ann_path ends with a separator.
        anno = sio.loadmat(os.path.join(self.pre_ann_path, mat_name))
        xy = anno['image_info'][0][0][0][0][0]  # (N, 2) array of coordinates
        density_map = self.get_density((h, w), xy).astype(np.float32)
        density_map = torch.from_numpy(density_map)
        return img, density_map

    def get_density(self, img_shape, points):
        """Build a density map by placing one Gaussian blob per head.

        Args:
            img_shape: ``(height, width)`` of the image.
            points: Iterable of ``(x, y)`` head coordinates in image pixels.

        Returns:
            A float64 ``numpy.ndarray`` of shape ``(height, width)``, or of
            shape ``(height // 4, width // 4)`` when ``down_sample`` is set.
        """
        scale = 4 if self.down_sample else 1
        h, w = img_shape[0] // scale, img_shape[1] // scale

        # Density map starts as all zeros.
        labels = np.zeros(shape=(h, w))
        if h == 0 or w == 0:
            # Nothing can be drawn on a degenerate map.
            return labels

        f_sz = 15      # kernel size, i.e. the size of the neighbourhood
        sigma = 4.0    # Gaussian sigma
        radius = f_sz // 2
        # Kernels are cached per window shape so the full-size kernel (and
        # each clipped variant) is computed only once.
        kernels = {}

        for loc in points:
            # Head position, clamped to the last valid pixel so that the
            # window below always covers at least one row and one column.
            x = min(max(0, abs(int(loc[0] / scale))), w - 1)
            y = min(max(0, abs(int(loc[1] / scale))), h - 1)

            # Integer window centred on the head, clipped to the map.
            # The upper bounds are exclusive.
            x1, x2 = max(0, x - radius), min(w, x + radius + 1)
            y1, y2 = max(0, y - radius), min(h, y + radius + 1)

            # A clipped window gets a freshly built kernel of the clipped
            # size rather than a cropped one, mirroring the original code's
            # handling of heads near the border.
            shape = (y2 - y1, x2 - x1)
            kernel = kernels.get(shape)
            if kernel is None:
                kernel = kernels[shape] = self.fspecial(shape[0], shape[1], sigma)

            labels[y1:y2, x1:x2] += kernel
        return labels

    def fspecial(self, ksize_x=5, ksize_y=5, sigma=4):
        """Return a 2-D Gaussian kernel as the outer product of two 1-D kernels.

        The result has ``ksize_x`` rows and ``ksize_y`` columns.
        """
        kx = cv2.getGaussianKernel(ksize_x, sigma)
        ky = cv2.getGaussianKernel(ksize_y, sigma)
        return np.multiply(kx, np.transpose(ky))
2.2 使用
demo.py


from config import get_args
from model import MCNN
from dataset import myDatasets
import torchvision
from torch.utils.data import DataLoader
import torch
from torch import nn
import time
from utils import get_mse_mae,show
import os
import numpy as np
import matplotlib.pyplot as plt
from debug_utils import ModelVerbose
import random
import cv2
# Parse the command-line options.
args = get_args()

# Only ShanghaiTech part A is supported; anything else is rejected up front.
if args.dataset != 'ShanghaiTechA':
    raise Exception(F'Dataset {args.dataset} Not Implement')

if os.name == 'nt':
    # Windows: the sub-directories are appended to the dataset root as
    # backslash-separated suffixes.
    train_imgs_path, train_labels_path, test_imgs_path, test_labels_path = (
        args.dataset_path + suffix
        for suffix in (
            r'\train_data\images\\',
            r'\train_data\ground-truth\\',
            r'\test_data\images\\',
            r'\test_data\ground-truth\\',
        )
    )
else:
    # Linux and other platforms: join the sub-directories onto the root.
    train_imgs_path, train_labels_path, test_imgs_path, test_labels_path = (
        os.path.join(args.dataset_path, relative)
        for relative in (
            'train_data/images/',
            'train_data/ground-truth/',
            'test_data/images/',
            'test_data/ground-truth/',
        )
    )
# print(F"{train_imgs_path=}\n{train_labels_path=}\n{test_imgs_path=}\n{test_labels_path=}")

# Preprocessing applied to every image; only the tensor conversion is
# active, the resize and normalisation steps are left disabled.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Resize((768,1024)),
    # torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def get_min_size(batch):
    """Return the smallest height and width found among a batch of images.

    Args:
        batch: Iterable of image tensors whose ``shape`` is ``(c, h, w)``.

    Returns:
        A tuple ``(min_height, min_width)``. For an empty batch both values
        stay at ``float('inf')``.
    """
    # NOTE(review): the published listing was garbled at this point (the
    # comparison and everything up to a later ``break`` was lost, leaving an
    # undefined ``h1``). The body below is reconstructed from the visible
    # initialisation and from how the collate function unpacks the result;
    # the remainder of the original script is not recoverable from here.
    min_ht, min_wd = (float('inf'), float('inf'))
    for img in batch:
        c, h, w = img.shape
        min_ht = min(min_ht, h)
        min_wd = min(min_wd, w)
    return min_ht, min_wd
2.3 配置
config.py


import argparse
def get_args(argv=None):
    """Parse the MCNN training options.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``, exactly as before.
            Passing an explicit list (for example ``[]``) is useful where
            ``sys.argv`` belongs to the host process, such as in a notebook.

    Returns:
        The ``argparse.Namespace`` holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='MCNN')
    parser.add_argument('--dataset', type=str, default='ShanghaiTechA')
    parser.add_argument('--dataset_path', type=str, default=r"C:\Users\ocean\Downloads\datasets\ShanghaiTech\part_A\\")
    parser.add_argument('--save_path', type=str, default='./save_file/')
    parser.add_argument('--print_freq', type=int, default=1)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--epochs', type=int, default=600)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--optimizer', type=str, default='Adam')
    # For a Jupyter notebook, call get_args([]) or use
    # parser.parse_known_args()[0] to ignore the kernel's own arguments.
    args = parser.parse_args(argv)
    return args
3. 总结
其中比较值得说道的是 collate_fn 函数 c_f(),它的代码如下所示:
def c_f(batch):
    """Collate variable-sized samples by cropping them to a common size.

    Args:
        batch: List of ``(img, density_map)`` tuples, one per sample.

    Returns:
        A two-element list ``[imgs, density_maps]`` of stacked tensors.

    Raises:
        TypeError: If the first image or density map is not a ``torch.Tensor``.
    """
    # Regroup [(img, den), ...] into (imgs...) and (dens...).
    transposed = list(zip(*batch))
    imgs, dens = transposed[0], transposed[1]

    # Report the type of the offending element itself; batch[0] is always a
    # tuple, which made the previous message uninformative.
    for sample in (imgs[0], dens[0]):
        if not isinstance(sample, torch.Tensor):
            raise TypeError("batch must contain tensors; found {}".format(type(sample)))

    # Smallest height/width in the batch is the common crop size.
    min_h, min_w = get_min_size(imgs)

    cropped_imgs = []
    cropped_dens = []
    for img, den in zip(imgs, dens):
        # random_crop is defined elsewhere; presumably it returns the image
        # and density map cropped to the requested size - TODO confirm.
        _img, _dtmap = random_crop(img, den, (min_h, min_w))
        cropped_imgs.append(_img)
        cropped_dens.append(_dtmap)

    # Stack the per-sample crops into batch tensors.
    return [torch.stack(cropped_imgs), torch.stack(cropped_dens)]
这里传入的参数batch是一个list,其长度是batch_size。它的每一个元素代表了一个数据集单元,即自定义数据集类中__getitem__方法return的值。由于我们的__getitem__方法return了img和density_map两个数据,所以batch的每一个数据单元其实是一个元组(img, density_map)。
list(zip(*batch))所做的事情是把batch中的imgs和density_maps分别拿出来各自成为一个列表,方便下一步的处理。
在处理最后还要将列表中的元素堆叠成tensor返回