PyTorch Dataset与DataLoader

在模型训练或预测时,需要加载数据集,对数据集进行预处理,提取特征,并分批读取,在minibatch内对数据进行Padding。

训练时用到的数据处理和预测时用到的数据的处理可以在同一个Dataset中,这样可以复用一些数据处理的函数。

from torch.utils.data import Dataset
class MyDataset(Dataset):
    """Dataset skeleton shared by training and prediction.

    Keeping both modes in one Dataset lets the preprocessing helpers be
    reused. ``is_train`` defaults to True so ``MyDataset()`` also works.
    """

    def __init__(self, is_train=True):
        # Remember the mode; __getitem__ can branch on it to decide
        # whether a label is returned along with the features.
        self.is_train = is_train

    def __getitem__(self, idx):
        # Return one preprocessed sample for index ``idx``; in training
        # mode the sample should also carry its label.
        pass

    def __len__(self):
        # Return the total number of samples in the dataset.
        pass

DataLoader的collate_fn(注意:collate_fn接收的参数是一个「样本列表」,其中每个元素都是Dataset.__getitem__的一次返回值,需要先按字段转置再做批处理):

import torch
from torch.nn.utils.rnn import pad_sequence

def my_collate_fn(data, is_train, padding_value):
    """Collate a mini-batch of samples produced by MyDataset.

    Args:
        data: list of per-sample tuples as returned by ``__getitem__`` —
            ``(token_ids, seq_len, label)`` in training mode,
            ``(token_ids, seq_len)`` otherwise.
        is_train: whether the samples carry labels.
        padding_value: token id used to pad sequences to equal length.

    Returns:
        ``(token_ids, seq_len, labels)`` in training mode, otherwise
        ``(token_ids, seq_len)``; ``token_ids`` is a (batch, max_len)
        padded LongTensor, the others are stacked along dim 0.
    """
    if is_train:
        # DataLoader hands us a *list of samples*; transpose it into
        # per-field tuples first (data[0] alone would be one sample,
        # and pad_seq was an undefined name — use pad_sequence).
        token_ids, seq_len, labels = zip(*data)
        token_ids = pad_sequence(list(token_ids), batch_first=True,
                                 padding_value=padding_value)
        seq_len = torch.stack(list(seq_len), dim=0)
        labels = torch.stack(list(labels), dim=0)
        return token_ids, seq_len, labels
    token_ids, seq_len = zip(*data)
    token_ids = pad_sequence(list(token_ids), batch_first=True,
                             padding_value=padding_value)
    seq_len = torch.stack(list(seq_len), dim=0)
    return token_ids, seq_len

当使用DataLoader时:

from torch.utils.data import DataLoader
from functools import partial  # fixed: the module is functools, not "functionals"

# Build the dataset and loader; partial pre-binds the collate options so
# DataLoader can call collate_fn with just the list of samples.
dataset = MyDataset(is_train=False)
collate_fn = partial(my_collate_fn, is_train=False, padding_value=0)
my_loader = DataLoader(dataset, collate_fn=collate_fn)

# Train or evaluate. NOTE(review): `model` and `optimizer` must be defined
# by the surrounding code (the original used `self.model`, which only
# works inside a class); the backward/step calls only apply in training.
for data in my_loader:
    loss, predicts = model(*data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

你可能感兴趣的:(pytorch,NLP)