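While training CTPN, the `reshape` that follows `cv2.dnn.blobFromImage` in the dataset code raised this error. The cause was one image whose channel count × height × width came to 571428 elements instead of 3 × 351 × 407 = 428571, so it could not be reshaped to (3, 351, 407).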
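NumPy's reshape error only reports the total element count. Below is a minimal sketch of a hypothetical helper, `reshape_blob_to_chw` (not part of the CTPN code), that performs the same reshape but also reports how many channels the blob appears to have. The zero-filled array is an assumed stand-in for a real blob, which `blobFromImage` returns in NCHW layout with batch size 1.

```python
import numpy as np


def reshape_blob_to_chw(blob: np.ndarray, h: int, w: int) -> np.ndarray:
    """Reshape a batch-of-one NCHW blob to (3, h, w), with a clearer error.

    numpy's own message only gives the total size; this also infers how many
    channels the blob actually carries.
    """
    expected = 3 * h * w
    if blob.size != expected:
        channels, rem = divmod(blob.size, h * w)
        if rem == 0:
            detail = f"appears to have {channels} channels"
        else:
            detail = "is not a multiple of h*w"
        raise ValueError(
            f"blob has {blob.size} elements, expected 3*{h}*{w} = {expected}; "
            f"it {detail}"
        )
    return blob.reshape(3, h, w)


h, w = 351, 407
bad_blob = np.zeros((1, 4, h, w), dtype=np.float32)  # stand-in for a 4-channel blob
print(bad_blob.size)  # 571428

try:
    bad_blob.reshape(3, h, w)  # the plain reshape from the dataset code
except ValueError as e:
    print("numpy:", e)

try:
    reshape_blob_to_chw(bad_blob, h, w)
except ValueError as e:
    print("helper:", e)  # mentions the inferred 4 channels
```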
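Doing the math, 571428 = 4 × 351 × 407, so this image somehow had four channels. Setting a breakpoint to find it turned up a fake JPG: a PNG file saved with a `.jpg` extension, carrying an extra channel. PNG can store an alpha channel (RGBA), whereas JPEG has no alpha channel and stores plain RGB (or grayscale), so this file was loaded as a four-channel RGBA image.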
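Pillow identifies the format from the file contents rather than the extension (its `format` attribute would report `PNG` for such a file), which is why `Image.open` never complained. To find other disguised files before training, one option is to check the magic bytes directly. The sketch below is stdlib-only; the function names are hypothetical, and the colour-type table follows the PNG specification's IHDR chunk. Note that a disguised PNG is not always four-channel: the colour-type byte says whether it actually has alpha.

```python
import os
from typing import List, Optional, Tuple

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"
JPEG_SIGNATURE = b"\xff\xd8\xff"

# PNG IHDR colour types as defined by the PNG specification
PNG_COLOR_TYPES = {0: "grayscale", 2: "RGB", 3: "palette", 4: "grayscale+alpha", 6: "RGBA"}


def sniff_image(path: str) -> Tuple[str, Optional[str]]:
    """Return (real_format, png_color_type) based on the file's leading bytes."""
    with open(path, "rb") as f:
        header = f.read(26)
    if header.startswith(PNG_SIGNATURE):
        # signature (8) + chunk length (4) + b"IHDR" (4) + width (4)
        # + height (4) + bit depth (1) puts the colour-type byte at offset 25
        color = PNG_COLOR_TYPES.get(header[25]) if len(header) > 25 else None
        return "png", color
    if header.startswith(JPEG_SIGNATURE):
        return "jpeg", None
    return "unknown", None


def find_disguised_images(image_dir: str) -> List[Tuple[str, str, Optional[str]]]:
    """List .jpg/.jpeg files whose contents are not actually JPEG."""
    suspects = []
    for name in sorted(os.listdir(image_dir)):
        ext = os.path.splitext(name)[1].lower()
        if ext not in (".jpg", ".jpeg"):
            continue
        real, color = sniff_image(os.path.join(image_dir, name))
        if real != "jpeg":
            suspects.append((name, real, color))
    return suspects
```

Pointed at the training image directory, `find_disguised_images` would list a file like the one above as `(name, "png", "RGBA")`.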
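Fix: convert the image to three-channel RGB. The conversion goes in the dataset's `__getitem__`, right after the image is opened: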
```python
import os

import cv2
import numpy as np
import torch
from PIL import Image

# resize_img, cal_rpn and self.generate_gtboxes come from the CTPN training code.


def __getitem__(self, idx, scale=600, maxScale=900):
    img_name = self.img_names[idx]
    img_path = os.path.join(self.datadir, img_name)
    # The fix: force 3-channel RGB so PNGs with an alpha channel
    # (e.g. a PNG saved with a .jpg extension) don't produce 4-channel blobs
    img = Image.open(img_path).convert('RGB')
    img, rate = resize_img(img, scale, maxScale=maxScale)
    rescale_fac = 1 / rate
    h, w = img.shape[:2]
    # blobFromImage returns an NCHW float32 blob of shape (1, 3, h, w)
    img = cv2.dnn.blobFromImage(
        img, scalefactor=1.0, size=(w, h), swapRB=False, crop=False)
    img = img.reshape(3, h, w)  # drop the batch dimension
    img = torch.tensor(img)
    # splitext keeps file names containing extra dots intact
    xml_path = os.path.join(self.labelsdir, os.path.splitext(img_name)[0] + '.xml')
    gtbox = self.generate_gtboxes(xml_path, rescale_fac)
    # feature-map size at stride 16
    feature_size = (int(np.ceil(h / 16)), int(np.ceil(w / 16)))
    [cls, regr] = cal_rpn((h, w), feature_size, 16, gtbox)
    # regr = [anchor_nums, 3]: label column followed by the regression targets
    regr = np.hstack([cls.reshape(cls.shape[0], 1), regr])
    regr = torch.from_numpy(regr).float()
    # cls = [1, anchor_nums]
    cls = np.expand_dims(cls, axis=0)
    cls = torch.from_numpy(cls).float()
    return img, cls, regr
```
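The same conversion in isolation: a minimal sketch with a hypothetical helper `load_rgb` and a synthetic in-memory RGBA image standing in for the disguised file. It skips cv2 and uses a NumPy transpose for the HWC to CHW step, which, with the arguments used above (scale factor 1.0, no mean subtraction, unchanged size), is assumed to match the layout `blobFromImage` produces for a single image. It also covers grayscale inputs, which would otherwise yield a 2-D array.

```python
import numpy as np
from PIL import Image


def load_rgb(img: Image.Image) -> np.ndarray:
    """Force any PIL image (RGBA, LA, L, P, ...) to 3-channel RGB, return HWC uint8."""
    if img.mode != "RGB":
        img = img.convert("RGB")  # for RGBA this simply drops the alpha channel
    return np.asarray(img)


# Synthetic RGBA image; PIL sizes are (width, height), arrays are (height, width, C)
rgba = Image.new("RGBA", (407, 351), (10, 20, 30, 0))
print(np.asarray(rgba).shape)  # (351, 407, 4)

rgb = load_rgb(rgba)
print(rgb.shape)  # (351, 407, 3)

# HWC -> CHW, the layout the dataset's reshape(3, h, w) expects
chw = rgb.transpose(2, 0, 1)
print(chw.reshape(3, 351, 407).shape)  # (3, 351, 407)
```

Because `convert("RGB")` discards alpha instead of blending it, fully transparent pixels keep whatever RGB values they store. For text-detection images that is usually harmless; if it matters, composite onto a solid background before converting.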