使用pytorch训练网络之前的数据准备部分都要有两个流程:Dataset和DataLoader。前者用来定义自己的数据集类型(内部实现最主要的是__getitem__()
方法);而后者则是负责真正在运行的时侯给网络递送数据。
继承Dataset类,自定义数据处理类。必须重载实现len()、getitem()这两个方法。
其中__len__
返回数据集样本的数量,而__getitem__
应该编写支持数据集索引的函数,返回数据和对应label,例如:通过dataset[i]可以得到数据集中的第i+1个数据。
DataLoader完整的参数表如下:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
几个关键的参数意思:
dataset
:PyTorch已有的数据读取接口或自定义数据接口的输出
batch_size
:根据具体情况设置
shuffle
:设置为True的时候,每个迭代都会打乱数据集,一般在训练数据中会采用
num_workers
:这个参数必须大于等于0,0表示数据导入在主进程中进行,大于0表示通过多个进程来导入数据,可以加快数据导入速度。
collate_fn
:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
drop_last
:告诉如何处理数据集长度除于batch_size余下的数据。True抛弃,否则保留
通常说来,我们在编写完Dataset之后,其内部的__getitem__
会弹出一个[data, label]的一条数据,DataLoader中的collate_fn
函数将这些一条一条的数据组织成一个batch。
注意:
通常的,默认的collate_fn函数是要求一个batch中的图片都具备相同size,当一个batch中的图片大小都不一样时(或者想要定值batch的输出形式),使用自定义的collate_fn函数。
通过collate_fn函数可以对这些样本做进一步的处理(任何你想要的处理),原则上返回值应当是一个有结构的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。
可以使用collate_fn的同时,结合使用默认的default_collate。
from torch.utils.data.dataloader import default_collate # 导入这个函数
def my_collate_fn(batch):
"""
params:
batch :是一个列表,列表的长度是 batch_size,list中的每个元素都是__getitem__得到的结果。
列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
returns:
整理之后的新的batch
"""
# 这一部分是对 batch 进行重新 “校对、整理”的代码
return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。
然后调用时使用:
trainset = DataLoader(dataset=train_dataset,
batch_size=24,
shuffle=True,
collate_fn=my_collate_fn,
图神经网络时,将多张图合并为一张大图。
"""
Combine multiple graphs into one large graph.
"""
#- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
# collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果。
def collate_fn(batch):
nodes_list = [b[0] for b in batch] #b[0]=p.array(nodes)
nodes = np.concatenate(nodes_list, axis=0) #所有节点拼到一起,不扩展维度,拼成一个array
#map 对于node_list每一组,计算shape(即节点个数)。也就是一个可迭代对象。返回array(每个图的节点个数)
nodes_lens = np.fromiter(map(lambda l: l.shape[0], nodes_list), dtype=np.int64)
nodes_inds = np.cumsum(nodes_lens) #计算一个数组各行的累加值
nodes_num = nodes_inds[-1] #最后一个值,即总节点个数
nodes_inds = np.insert(nodes_inds, 0, 0) #在第一个位置插入0这个值
nodes_inds = np.delete(nodes_inds, -1) #按行展开后 删除最后一个元素
edges_list = [b[1] for b in batch] #np.array(edges)
edges_list = [e + i for e, i in zip(edges_list, nodes_inds)] #e是边的连接(每个图都从0开始) i是总节点个数
edges = np.concatenate(edges_list, axis=0)
m = edges_to_matrix(nodes_num, edges) #将每个batch拼接成一个邻接矩阵
labels = [b[2] for b in batch] #np.array([float(label)])
labels = np.concatenate(labels, axis=0)
# batch中第i个图的节点个数为k batch_mask数组分别为[0,..,0] [1,..,1] 其中...为节点个数
batch_mask = [np.array([i] * k, dtype=np.int32) for i, k in zip(range(len(batch)), nodes_lens)]
batch_mask = np.concatenate(batch_mask, axis=0)
#返回节点类型 邻接矩阵 预测值 batch_mask
return torch.from_numpy(nodes), torch.from_numpy(m).float(), torch.from_numpy(labels), torch.from_numpy(batch_mask)