最近在使用pytorch复现PointNet分割网络的过程中,在读入数据时遇到了一些问题,需要重写DataLoader中的sampler和collate_fn
sampler的作用是按照指定的顺序向batch里面读入数据,自定义的sampler可以根据我们的需要返回索引,DataLoader会根据我们返回的索引值提取数据,生成batch
注意:
重写sampler需要重写__len__()和__iter__()方法,其中__len__()返回你读入数据的总长度,iter()返回一个迭代器
例如,我们需要sampler根据样本点数返回索引
class sampler(data.Sampler):
"""
由于每个batch的点数可能不一致
例如 len(b[0])=10220, len(b[1])=23300, len(b[2])=24000 , ...
该sampler是为了将每个batch内的点数统一
首先将batch里的样本按照点数从小到大排列
返回排序之后的索引值
"""
def __init__(self, data_source):
super(sampler, self).__init__(data_source)
self.x = data_source
# y = data_source[1]
self.lst = []
for i in range(len(self.x)):
self.lst.append(self.x[i].shape[0])
self.idx = np.argsort(self.lst) # 排序之后的索引
def __iter__(self):
return iter(self.idx) # 这里的idx最后会返回给DataSets中的__getitem__方法
def __len__(self):
return len(self.x[0]) # 这里的__len__需要返回总长度
在data.DataLoader中找到了下面注释:
sampler (Sampler or Iterable, optional): defines the strategy to draw
samples from the dataset. Can be anyIterable
with__len__
implemented. If specified, :attr:shuffle
must not be specified.
意思是自己重写了sampler之后,shuffle关键字不能指定。
很重要的一点:
最后在传参的时候,传入的是sampler的示例
例如:
p_loader = data.DataLoader(p, batch_size, drop_last=True, sampler=sampler(sourcedata)) # p是DateSet的子类
collate_fn方法的作用是对于还未被连结batch进行操作,因为是没有被连结所以这里我说的batch是list类型
pytorch规定每一个batch中样本的点数必须相同,所以重写collate_fn方法,将每个batch中样本下采样到相同的数目
这里的函数的下采很简单,就是单纯的取得batch中的最小样本点数,将其他样本中的点shuffle之后取前最小个点数
def collate_fn(batch: list):
"""
DataLoader中的最后一步
对即将输出的batch进行操作
这里将每一个batch的样本点数降到最少点数即min_num_pts
:param batch: B*N*3 不同B中的N并不相同
:return:
"""
batch_size = len(batch)
ret_cls = []
for elem in batch:
ret_cls.append(elem[2])
ret_cls = torch.tensor(ret_cls)
num_pts_lst = []
for elem in batch:
X, y = elem[0], elem[1]
num_pts_lst.append(X.shape[0])
sorted_lst = np.argsort(num_pts_lst)
min_mun_pts = len(batch[sorted_lst[0]][0])
temp_points = []
temp_target = []
for i in range(batch_size):
X, y = batch[i][0], batch[i][1]
torch.manual_seed(2021)
X = X[torch.randperm(min_mun_pts)]
torch.manual_seed(2021)
y = y[torch.randperm(min_mun_pts)]
temp_points.append(X)
temp_target.append(y)
ret_points = temp_points[0].unsqueeze(0)
ret_target = temp_target[0].unsqueeze(0)
for i in range(1, batch_size):
cat_X = temp_points[i].unsqueeze(0)
cat_y = temp_target[i].unsqueeze(0)
ret_points = torch.cat([ret_points, cat_X], dim=0)
ret_target = torch.cat([ret_target, cat_y], dim=0)
return ret_points, ret_target, ret_cls
最后传入参数时,传入的是方法名
p_loader = data.DataLoader(p, batch_size, drop_last=True, collate_fn=collate_fn)
参考:
知乎大佬的sampler源码解读,很有用