pytorch dataset和dataloader使用实例(seq2seq)

场景:英译中,数据集包括训练集14533行,前面是英文后面是中文,中间用\t隔开
pytorch dataset和dataloader使用实例(seq2seq)_第1张图片
先明确我们的任务:
1 导入数据集,把所用句子加上‘BOS’和‘EOS’,中文和英文分开放在一个二维list中,里面嵌套的每个list表示一个句子,元素是单词。
2 建立词典,其中’unk’=0,'pad’=1
3 将单词根据词典编码,并按en中句子的长度排序
4 分成batch,记录每一个batch的行索引
5 记录每一个batch中的句子
6 将每一个batch中的句子填充成一样的长度,不足补0,并记录每个句子的原始长度
1.

def load_data(file):
    with open(file, 'r', encoding='utf-8') as f:
        # print(type(f))
        en = []
        cn = []
        for line in f:
            line = line.strip().split('\t')
            en.append(['BOS'] + [c for c in line[0].split()] + ['EOS'])
            cn.append(['BOS'] + [c for c in jieba.cut(line[1])] + ['EOS'])
        return en, cn
MAX_VOCAB_SIZE = 50000


def build_dict(text):
    vocab = Counter()
    for seq in text:
        for word in seq:
            vocab[word] += 1
    vocab_tuple = vocab.most_common(MAX_VOCAB_SIZE)  # [('BOS', 14533), ('EOS', 14533), ('I', 3221), ('the', 2976), ]
    word_to_idx = {word[0]: idx+2 for idx, word in enumerate(vocab_tuple)}
    word_to_idx[''] = 0
    word_to_idx[''] = 1
    return word_to_idx

后面几步全部放在dataloader中完成:

class EnToCn(torch.utils.data.Dataset):

    def __init__(self, en, cn, en_word_to_idx, cn_word_to_idx, batch_size):
        self.en_encoded = []
        for seq in en:
            self.en_encoded.append([en_word_to_idx.get(word, 0) for word in seq])
        self.en_encoded = [self.en_encoded[i] for i in sorted(range(len(self.en_encoded)), key=lambda x: len(self.en_encoded[x]))]
        # self.en_encoded = torch.Tensor(self.en_encoded).long()

        self.cn_encoded = []
        for seq in cn:
            self.cn_encoded.append([cn_word_to_idx.get(word, 0) for word in seq])
        self.cn_encoded = [self.cn_encoded[i] for i in sorted(range(len(self.en_encoded)), key=lambda x: len(self.en_encoded[x]))]
        # self.cn_encoded = torch.Tensor(self.cn_encoded).long()

        self.idx = np.arange(0, len(en), batch_size)
        print(len(self.idx))  # 228
        self.batch_idx = []
        for i in self.idx:
            self.batch_idx.append(np.arange(i, min((i+batch_size), len(en))))
        print(len(self.batch_idx))  # 228 = 14533/64
        print(len(self.batch_idx[0]))  # 64

    def __len__(self):
        return len(self.batch_idx)

    def __getitem__(self, idx):
        en_batch = [self.en_encoded[i] for i in self.batch_idx[idx]]
        print(len(en_batch))  # 64
        cn_batch = [self.cn_encoded[i] for i in self.batch_idx[idx]]
        en_batch_len = [len(seq) for seq in en_batch]
        cn_batch_len = [len(seq) for seq in cn_batch]
        en_batch_maxlen = max(en_batch_len)
        cn_batch_maxlen = max(cn_batch_len)
        num_seqs = len(en_batch)
        en_zero_pad = np.zeros((num_seqs, en_batch_maxlen)).astype('int32')
        cn_zero_pad = np.zeros((num_seqs, cn_batch_maxlen)).astype('int32')
        # print('type(en_zero_pad)', type(en_zero_pad))  # 
        # print(en_zero_pad)
        for idx, seq in enumerate(en_batch):
            en_zero_pad[idx, :len(seq)] = seq
        for idx, seq in enumerate(cn_batch):
            cn_zero_pad[idx, :len(seq)] = seq
        # print('type(en_zero_pad)', type(en_zero_pad))  # 
        # print(en_zero_pad)
        en_batch_len = np.array(en_batch_len).astype('int32')
        cn_batch_len = np.array(cn_batch_len).astype('int32')
        # print('type(en_batch_len)', type(en_batch_len))  # 

        return en_zero_pad, en_batch_len, cn_zero_pad, cn_batch_len

使用:

train_set = EnToCn(train_en, train_cn, en_word_to_idx, cn_word_to_idx, batch_size=64)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True)

打印:

print(next(iter(train_loader))

结果是一个[tensor([[[三维]]]),tensor([[二维]]),tensor([[[三维]]]),tensor([[二维]])] 一个batch
问题在于我返回的明明是四个np.array类型的啊。。。只能说明是在使用代码的第二行:Dataloader里自己转换成tensor了,而且还自己加了一个维度。。
原来的格式是(array(二维),array(一维),array(二维),array(一维)) 一个batch

for it, (en_zero_pad, en_batch_len, cn_zero_pad, cn_batch_len) in enumerate(train_loader):
    print(it)
    print(type(en_zero_pad))  # 
    print(type(en_batch_len))  # 
    print(en_zero_pad.squeeze(0))

这样就可以把加上的一维删掉,正常使用了。。。
至于为什么加了一维我不晓得,求大佬解答。。

你可能感兴趣的:(pytorch)