1. Observation
When a dataset is split into batches with torch.utils.data.DataLoader() and the training loop fetches those batches via enumerate(), the training data returned at the same step (explained below) differs across epochs. In other words, when DataLoader() is combined with enumerate(), the batches drawn are dynamic rather than fixed.
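The full demo follows in the next section; as a quick illustration, the minimal sketch below (a toy TensorDataset with made-up values; all names are illustrative) shows the same effect and typically prints different labels at step 0 of each epoch:

import torch
from torch.utils.data import DataLoader, TensorDataset

# ten toy samples labelled 0..9, so the shuffle order is easy to read off
ds = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))
loader = DataLoader(ds, batch_size=4, shuffle=True)

for epoch in range(2):
    for step, (x, y) in enumerate(loader):
        print(f'epoch {epoch + 1}, step {step}, labels {y.tolist()}')
        break  # compare only step 0 of each epoch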
2. Code Demonstration
import argparse
import numpy as np
import torch
import torch.utils.data
from torchvision import transforms, datasets

torch.set_printoptions(threshold=np.inf)  # print tensors in full, without truncation
device = 'cuda' if torch.cuda.is_available() else 'cpu'
parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=0.00005, type=float, help='learning rate')
parser.add_argument('--dataset', default='CIFAR10', type=str, choices=['MNIST', 'FashionMNIST', 'CIFAR10'])
parser.add_argument('--batchsize', default=128, type=int)
parser.add_argument('--epoch', default=2, type=int)
args = parser.parse_args()
batch_size = args.batchsize
# Build the train/test DataLoaders for the dataset chosen via --dataset
def get_data_loader():
    if args.dataset == 'MNIST':
        data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.1307], std=[0.3081])
        ])
        trainset = datasets.MNIST(root='.', train=True, transform=data_transform,
                                  download=True)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                   shuffle=True, num_workers=2)
        testset = datasets.MNIST(root='.', train=False, transform=data_transform)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
    elif args.dataset == 'FashionMNIST':
        data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.1307], std=[0.3081])
        ])
        trainset = datasets.FashionMNIST(root='.', train=True, transform=data_transform,
                                         download=True)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                   shuffle=True, num_workers=2)
        testset = datasets.FashionMNIST(root='.', train=False, transform=data_transform)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
    elif args.dataset == 'CIFAR10':
        data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        trainset = datasets.CIFAR10(root='.', train=True, transform=data_transform,
                                    download=True)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                   shuffle=True, num_workers=2)
        testset = datasets.CIFAR10(root='.', train=False, transform=data_transform)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
    return train_loader, test_loader
if __name__ == '__main__':
    train_loader, test_loader = get_data_loader()
    for e in range(args.epoch):
        print("Epoch: ", e + 1)
        for step, (x, y) in enumerate(train_loader, 0):
            # .to(device) works on both CPU and CUDA; the original code only
            # defined input/label when CUDA was available, crashing on CPU
            input, label = x.to(device), y.to(device)
            print(f'step:{step}')
            print(f'training batch\t input:{input.shape}, label:{label.shape}')
            print(input[0][0][0], label[0])
            break  # break early: we only compare the same step (step 0) across epochs
        print('\n')
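To run the demo (assuming the script is saved as dataloader_demo.py; the filename is illustrative):

python dataloader_demo.py --dataset CIFAR10 --batchsize 128 --epoch 2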
We first run the data-loading code. Intuitively, one might expect the batch order of the training set to be fixed once this call returns:

train_loader, test_loader = get_data_loader()

We then fetch each batch from the training set via enumerate(), so at the same step, different epochs should return the same data. The output, however, shows that at step 0 the first and second epochs receive different training data. In other words, when DataLoader() is used together with enumerate(), sampling is dynamic: with shuffle=True, the for loop calls iter() on the loader at the start of each epoch, and the underlying RandomSampler draws a fresh random permutation of the dataset every time.
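If a static batch order is needed, there are two straightforward remedies: pass shuffle=False, or keep shuffling but drive it from an explicitly seeded torch.Generator via DataLoader's generator argument. A minimal sketch of both (the per-epoch re-seeding is one way to reproduce the same permutation each epoch; worth verifying on your PyTorch version):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10).float().unsqueeze(1), torch.arange(10))

# Option 1: disable shuffling -- every epoch iterates in dataset order
static_loader = DataLoader(ds, batch_size=4, shuffle=False)

# Option 2: keep shuffling, but control it with a seeded generator
g = torch.Generator()
seeded_loader = DataLoader(ds, batch_size=4, shuffle=True, generator=g)

for epoch in range(2):
    g.manual_seed(42)  # re-seeding each epoch -> identical permutation each epoch
    for step, (x, y) in enumerate(seeded_loader):
        print(f'epoch {epoch + 1}, step {step}, labels {y.tolist()}')
        break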