这篇博文不讲原理,只讲一些使用方法和技巧。所有提供的信息仅供参考,不要当作金科玉律。
首先给出讲述的时候使用的基本程序框架。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
class My_Dataset(Dataset):
def __init__(self, list1, array2):
self.len = len(list1)
self.x_data = list1 # something support indexing, like a list, length = 16
self.y_data = array2 # something support indexing, like torch.Tensor, shape = (16, 4, 5)
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
# padding unequal length sequences
def collate_fn(batch_data):
return batch_data
# train dataloader & test dataloader
list1 = [chr(ord('a') + i) for i in range(16)] # 'a'~'p'
array2 = torch.randn((16, 4, 5))
my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
batch_size = 4,
collate_fn = collate_fn)
注意这个函数:
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
这代表,如果你用下标索引i从dataloader中取出值,返回值将会是一个长度为2的元组,下标为0的是list1[i]
(即第i+1个字母),下标为1的是array2[i]
(即一个size = (4, 5)
的tensor)。暂且称这种形式的数据为data[i]
。
此时如果你运行如下指令:
for batch_data in enumerate(my_dataloader):
# show batch_data
batch_data是一个长度为2的元组,下标为0的是这个batch的序号(在以上的程序里面是0~3),下标为1的是一个长度为4(batch_size)的support indexing的对象,这个对象的每个元素就是对应batch中应该包含的几个data[i]
,比如第0个batch的这个列表中的元素就分别是data[0],..data[3]
。至于data[i]
则是刚才说的由两项数据所构成的元组。
在这里,下标为1的对象是一个列表。而如果数据本身就是一个tensor的话,这里会给一个第一维维度为batch_size,其他维维度数对应的tensor.
此时如果你运行如下指令:
for batch_index, batch_data in enumerate(train_loader):
# show data
这里的batch_index对应元组的下标为0的元素,即这个batch的序号(在以上的程序里面是0~3);batch_data对应上面的列表(support indexing的对象)。显然这种更细致的处理是更常用的。
对于以上讲的两点,读者可以直接跑一下附录1所示的程序来获得直观感受。
在从dataloader中读取数据时,可以通过collate_fn
做处理,使读取的数据符合要求。
让我们审视这个函数:
def collate_fn(batch_data):
return batch_data
这里输入的batch_data就是上一节那个以batch_size为长度,以对应位置的data[i]
为元素的列表。如果要取得元素之后进行特定处理,可以在这个函数里面操作;这个函数的返回值会代替原来那个列表的位置。可以运行附录2的代码获得直观感受。
在自然语言处理中,可能要把不等长的tensor padding 成等长,这个步骤可以在collate_fn里面做。举个例子,下面的这个函数从不等长Tensor的列表生成一个padding成等长的高维tensor.
def collate_fn(data):
# self.data: list of tensors of different length
# data:[x[0], x[1], ..], x[0].shape = (20, 128), x[1].shape = (30, 128)
# x[2].shape = (28, 128), x[3].shape = (25, 128)
data.sort(key=lambda data: len(data[0]), reverse=True) # 按照序列长度降序排列
seq_len_list = [elem.shape[0] for elem in data]
data = pad_sequence(data, batch_first=True, padding_value=0)
seq_len_list = torch.Tensor(seq_len_list)
return data_batch, seq_len_list
# data_batch.shape = [4, 30, 128], seq_len_list = [20, 30, 28, 25]
函数的返回值包括合并的高维tensor和每个小tensor的实际长度,方便后续处理使用。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
torch.manual_seed(314)
class My_Dataset(Dataset):
def __init__(self, list1, array2):
self.len = len(list1)
self.x_data = list1 # something support indexing, like a list, length = 16
self.y_data = array2 # something support indexing, like torch.Tensor, shape = (16, 4, 5)
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
return self.len
# padding unequal length sequences
def collate_fn(batch_data):
return batch_data
# train dataloader & test dataloader
list1 = [chr(ord('a') + i) for i in range(16)] # 'a'~'p'
array2 = torch.randn((16, 4, 5))
my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
batch_size = 4,
collate_fn = collate_fn)
for batch_data in enumerate(my_dataloader):
# show batch_data
print("New Batch")
print(type(batch_data), len(batch_data), batch_data[0], type(batch_data[1]))
print(len(batch_data[1]), type(batch_data[1][0]))
print(batch_data[1][0][0], type(batch_data[1][0][1]), batch_data[1][0][1].shape)
for batch_index, batch_data in enumerate(my_dataloader):
# show batch_data
print("Batch", batch_index)
for i in range(len(batch_data)):
print(type(batch_data[i]), len(batch_data[i]))
print(batch_data[i][0], type(batch_data[i][1]), batch_data[i][1].shape)
...
my_dataset = My_Dataset(list1, array2)
my_dataloader = DataLoader(dataset = my_dataset,
batch_size = 4,
collate_fn = collate_fn)
for batch_index, batch_data in enumerate(my_dataloader):
# show batch_data
print("Batch", batch_index)
print(batch_data)