class LaneDataset(Dataset):
    """Dataset of (image, mask) pairs listed in a CSV under ``<cwd>/data_list``.

    Every CSV row holds the path of an input image followed by the path of its
    label mask; ``__getitem__`` loads both with OpenCV, preprocesses them and
    optionally applies ``transform`` to the ``[image, mask]`` list.

    NOTE: the CSV is read with ``header=None``. If the file starts with a header
    row (``DataFrame.to_csv`` writes one by default), that row is kept as the
    first sample, and loading it presumably fails because its cells are column
    names rather than file paths.
    """

    def __init__(self, csv_file, transform=None):
        super().__init__()
        # An absolute csv_file overrides the cwd/data_list prefix (os.path.join semantics).
        list_path = os.path.join(os.getcwd(), "data_list", csv_file)
        self.data = pd.read_csv(list_path, header=None, names=["image", "label"])
        self.images = self.data["image"].values
        self.labels = self.data["label"].values
        self.transform = transform

    def __len__(self):
        """Return the number of rows read from the CSV."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``[image, mask]`` (or its transformed form)."""
        image = cv2.imread(self.images[idx])
        mask = cv2.imread(self.labels[idx], cv2.IMREAD_GRAYSCALE)
        image, mask = crop_resize_data(image, mask)
        mask = encode_labels(mask)
        sample = [image.copy(), mask.copy()]
        return self.transform(sample) if self.transform else sample
然后构建数据集,并用 DataLoader 加载数据:
# Training set: rows of data_list/train.csv, passed through the augmentation
# pipeline below (transforms.Compose applies the steps in list order).
training_dataset = LaneDataset(
    "train.csv",
    transform=transforms.Compose(
        [
            ImageAug(),
            DeformAug(),
            ScaleAug(),
            CutOut(32, 0.5),
            ToTensor(),
        ]
    ),
)

# Shuffled mini-batches of 16; the last incomplete batch is dropped.
# NOTE(review): `kwargs` is defined outside this snippet - presumably extra
# DataLoader options such as num_workers / pin_memory; confirm.
training_data_batch = DataLoader(
    training_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
    **kwargs,
)
在迭代数据时出现报错:TypeError: 'NoneType' object is not subscriptable。这说明某个样本的图像没有读取成功(cv2.imread 读取失败时不会抛出异常,而是返回 None)。检查后发现,原因是用 pandas 保存 csv 文件时默认写入了表头(header),而读取时又指定了 header=None,于是这一行并不是图片路径的表头也被当成一条样本读了进来。改进后的代码为:
class LaneDataset(Dataset):
    """Dataset of (image, mask) pairs listed in a CSV under ``<cwd>/data_list``.

    Every CSV row holds the path of an input image followed by the path of its
    label mask; ``__getitem__`` loads both with OpenCV, preprocesses them and
    optionally applies ``transform`` to the ``[image, mask]`` list.
    """

    def __init__(self, csv_file, transform=None, has_header=True):
        """Read the sample list.

        Args:
            csv_file: CSV file name, resolved relative to ``<cwd>/data_list``
                (an absolute path overrides that prefix, per ``os.path.join``).
            transform: optional callable applied to the ``[image, mask]`` list;
                its return value is what ``__getitem__`` yields.
            has_header: whether the first CSV row is a header line, as written by
                ``DataFrame.to_csv`` by default. Pass ``False`` for header-less
                files so their first sample is not discarded.
        """
        super().__init__()
        csv_path = os.path.join(os.getcwd(), "data_list", csv_file)
        # header=0 makes pandas consume the header row (instead of returning it as
        # a sample that must be deleted afterwards), and ``names`` replaces
        # whatever column names the file used. self.data therefore stays in sync
        # with self.images / self.labels.
        self.data = pd.read_csv(
            csv_path,
            header=0 if has_header else None,
            names=["image", "label"],
        )
        self.images = self.data["image"].values
        self.labels = self.data["label"].values
        self.transform = transform

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load, preprocess and optionally transform sample ``idx``.

        Returns:
            ``[image, mask]`` after ``crop_resize_data`` and ``encode_labels``,
            or whatever ``self.transform`` returns for that list.

        Raises:
            OSError: if OpenCV cannot read the image or the mask file.
        """
        image_path = self.images[idx]
        mask_path = self.labels[idx]
        ori_image = cv2.imread(image_path)
        ori_mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals a missing/unreadable file by returning None rather
        # than raising; fail here with the offending path instead of with an
        # opaque "'NoneType' object is not subscriptable" further down.
        if ori_image is None:
            raise OSError(f"cv2.imread could not read image file {image_path!r}")
        if ori_mask is None:
            raise OSError(f"cv2.imread could not read mask file {mask_path!r}")
        train_img, train_mask = crop_resize_data(ori_image, ori_mask)
        train_mask = encode_labels(train_mask)
        sample = [train_img.copy(), train_mask.copy()]
        if self.transform:
            sample = self.transform(sample)
        return sample
也就是让读入的数据去掉第一行(表头)就好了。注意:这只适用于带表头的 csv;如果文件本身没有表头,这样做会误删第一条样本。