pytorch:torch.utils.data.DataLoader的使用

大佬文章
中文的(感觉有的地方翻译和使用的有问题)
在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。

代码:

train_dataloader = torch.utils.data.DataLoader(train_data, args.batch_size, collate_fn=collate_fn, shuffle=True)

第一个参数:训练集(或者是其他集)

第二个参数:batch_size(),每一次喂入模型的数据量。

第三个参数:collate_fn 不详。

第四个参数:在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

你可能感兴趣的:(pytorch)