本文主要聚焦以下几个问题:
shuffle=True
的时候究竟打乱的是什么shuffle=True
在非时序预测中,在自定义dataset
的时候根据index直接返回对应索引的一个样本就行了,此时不管dataloader
是否设置shuffle=True
,都可以正常返回数据,只是shuffle=False
的时候可能会稍微影响性能。
但是在时序数据里,如果我们不提前把数据处理成:一个样本就是一个序列。的形式,那么就需要在dataset中自己手动取一个序列,此时需要考虑是否打乱了,打乱会不会让时序关系彻底乱了。
什么意思呢,再解释一下,正常来说我们的数据是这样的:每一个样本就是一个一维向量。下面就是每一个样本有7个特征,一共100个样本。
data = [[i] * 7 for i in range(100)]
data = np.array(data)
如果没有时序关系,每次就取一行拿来训练就完了。此时Dataset就可以这样定义:
class aiDataset(Dataset):
def __init__(self, features, labels):
self.features = features
self.labels = labels
def __getitem__(self, index):
feature = self.features[index] # .values
if self.labels is not None:
label = self.labels[index]
return feature, label
else:
return feature
def __len__(self):
return len(self.features)
如果这100个数据是按照时间的,他们有时序关系,而我们又要使用GRU或者LSTM来做,那么我们的输入就不是一个时刻的数据了,而是当前时刻(含)之前若干个时刻组成的序列。当然,我们可以提前写一个for循环,把所有的序列组成一个新的矩阵,比如:
X = []
y = []
window_size = 2
for i in range(len(data) - window_size):
X.append(data[i:i+window_size, :])
y.append(labels[i+window_size-1])
X = np.array(X)
y = np.array(y)
print(X.shape) # [N-window_size, window_size, features_num]
print(y.shape) # [N-window_size]
但是,当数据量很大的时候,可能这个循环就要好久,好不如直接写在dataset里。
class aiDataset(Dataset):
def __init__(self, features, labels, window_size=2):
self.window_size = window_size
self.features = features
self.labels = labels
def __getitem__(self, index):
feature = self.features[index:index+self.window_size, :] # .values
if self.labels is not None:
label = self.labels[index+self.window_size-1]
return feature, label
else:
return feature
def __len__(self):
return len(self.features) - self.window_size
那么这就是时序数据在定义dataset的时候和非时序数据的区别了。
shuffle=True
当我们定义好dataset之后,不出意外的话,就可以直接定义dataloader了。
data = [[i] * 7 for i in range(100)]
data = np.array(data)
labels = [i*2 for i in range(100)]
train_dataset = aiDataset(data, None, 4) # window_size设置为4
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
上面的代码我们设置了shuffle=True
,我们来看看每一个batch的东西是什么。
可以看到,即便设置了shuffle=True
,每一个batch中的每一个样本依然是按照原来的时序关系的,所以不需要担心shuffle会打乱原来的时序数据。
Dataloader
中的shuffle到底shuffle了什么。这个其实很简单,具体可以看dataloader的源码,比较简单,结论就是打乱的是index
,所以我们看到的上面图中同一个batch内的样本是随机的,但是一个样本内是有序的。