【PyTorch】pack_padded_sequence and pad_packed_sequence in Practice

1. Problem Background

  In NLP tasks, when we process text with an RNN or LSTM, the sentences in a batch usually have different lengths, so we commonly pad the shorter ones with <PAD> (token id 0) until every sequence reaches the same length. The lengths are then equal, but the batch is full of meaningless padding zeros; feeding it to the RNN as-is (1) wastes computation on the padded positions and (2) can distort the resulting hidden states.
  To avoid this, the padded batch is compressed with nn.utils.rnn.pack_padded_sequence before it is handed to the RNN, which strips out the invalid padding values. The RNN then also returns a packed sequence, so pad_packed_sequence is used afterwards to pad it back into a regular tensor for downstream processing.
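  For example, here is a minimal sketch of building such a padded batch with pad_sequence (the toy sequences are made up purely for illustration):

import torch
from torch.nn.utils.rnn import pad_sequence

seqs = [torch.tensor([1, 2, 3, 4]),   # length 4
        torch.tensor([2, 3]),         # length 2
        torch.tensor([4, 5, 6])]      # length 3

# pad_sequence fills the shorter sequences with 0 (the <PAD> id)
padded = pad_sequence(seqs, batch_first=True, padding_value=0)
print(padded)
# tensor([[1, 2, 3, 4],
#         [2, 3, 0, 0],
#         [4, 5, 6, 0]])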

2. Usage

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import torch.nn.functional as F

# Move a tensor onto the GPU when requested and available
def to_cuda(x, use_cuda=True):
    if use_cuda and torch.cuda.is_available():
        x = x.cuda()
    return x

class EncoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, dropout=None, \
        bidirectional=False, shared_embed=None, init_word_embed=None, rnn_type='lstm', use_cuda=True):
        super(EncoderRNN, self).__init__()
        if rnn_type not in ('lstm', 'gru'):
            raise RuntimeError('rnn_type is expected to be lstm or gru, got {}'.format(rnn_type))
        if bidirectional:
            print('[ Using bidirectional {} encoder ]'.format(rnn_type))
        else:
            print('[ Using {} encoder ]'.format(rnn_type))
        if bidirectional and hidden_size % 2 != 0:
            raise RuntimeError('hidden_size is expected to be even in the bidirectional mode!')
        self.dropout = dropout
        self.rnn_type = rnn_type
        self.use_cuda = use_cuda
        self.hidden_size = hidden_size // 2 if bidirectional else hidden_size
        self.num_directions = 2 if bidirectional else 1
        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, embed_size, padding_idx=0)
        model = nn.LSTM if rnn_type == 'lstm' else nn.GRU
        self.model = model(embed_size, self.hidden_size, 1, batch_first=True, bidirectional=bidirectional)
        if shared_embed is None:
            self.init_weights(init_word_embed)

    def init_weights(self, init_word_embed):
        if init_word_embed is not None:
            print('[ Using pretrained word embeddings ]')
            self.embed.weight.data.copy_(torch.from_numpy(init_word_embed))
        else:
            self.embed.weight.data.uniform_(-0.08, 0.08)

    def forward(self, x, x_len):
        """x: [batch_size * max_length]
           x_len: [batch_size] 45423
        """
        x = self.embed(x)
        if self.dropout:
            x = F.dropout(x, p=self.dropout, training=self.training)
        print("x = ", x)
        sorted_x_len, indx = torch.sort(x_len, dim=-1, descending=True)
        # print(sorted_x_len)
        # sorted_x_len holds the true lengths in descending order (pack_padded_sequence's
        # enforce_sorted parameter defaults to True, so the lengths must be sorted in decreasing order)
        x = pack_padded_sequence(x[indx], sorted_x_len.data.tolist(), batch_first=True)
        print(x)
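        # initial hidden state h0: [num_layers * num_directions, batch_size, hidden_size]; num_layers is fixed to 1 here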
        h0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)
        if self.rnn_type == 'lstm':
            c0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)
            packed_h, (packed_h_t, _) = self.model(x, (h0, c0))
            print("1:",packed_h)
            print("2:",packed_h_t)
            if self.num_directions == 2:
                packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)
        else:
            packed_h, packed_h_t = self.model(x, h0)
            if self.num_directions == 2:
                packed_h_t = packed_h_t.transpose(0, 1).contiguous().view(x_len.size(0), -1)
        
        hh, out_len = pad_packed_sequence(packed_h, batch_first=True)
        print("hh =", hh, out_len)
        # restore the original batch order (undo the sort) now that the packed sequence has been padded back
        o, inverse_indx = torch.sort(indx, 0)
        print(o, " ", inverse_indx)
        restore_hh = hh[inverse_indx]
        # restore_packed_h_t = packed_h_t[inverse_indx]
        return restore_hh # , restore_packed_h_t
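  As a side note, PyTorch 1.1 and later let you avoid the manual sort/unsort bookkeeping in the forward pass above by passing enforce_sorted=False: the returned PackedSequence then keeps track of the sorting internally, and pad_packed_sequence gives the output back in the original batch order. A minimal self-contained sketch of that variant (toy shapes chosen only for illustration):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Toy embedded batch: 3 sequences of true lengths 4, 2 and 3, padded to length 4
x = torch.randn(3, 4, 8)                 # [batch_size, max_length, embed_size]
x_len = torch.tensor([4, 2, 3])
lstm = nn.LSTM(8, 6, batch_first=True)

packed = pack_padded_sequence(x, x_len, batch_first=True, enforce_sorted=False)
packed_h, (h_t, c_t) = lstm(packed)      # default zero initial states
hh, out_len = pad_packed_sequence(packed_h, batch_first=True)
print(hh.shape, out_len)                 # torch.Size([3, 4, 6]) tensor([4, 2, 3])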

  Suppose our inputs are queries and query_lengths:

queries = torch.tensor([[1,2,3,4,0],[2,3,4,5,6],[4,5,6,7,0],[5,6,0,0,0],[6,7,8,0,0]])
query_lengths = torch.tensor([4,5,4,2,3])
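# Quick sanity check: each declared length equals the number of non-<PAD> (non-zero) tokens in its row
assert torch.equal((queries != 0).sum(dim=1), query_lengths)   # tensor([4, 5, 4, 2, 3])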

Que_encoder = EncoderRNN(vocab_size=10, embed_size=4, hidden_size=4, \
                        bidirectional=False, \
                        rnn_type='lstm', \
                        use_cuda=False)

Q_r = Que_encoder(queries, query_lengths)
print("编码后的Que为:"Q_r)

'''
The x returned by  x = pack_padded_sequence(x[indx], sorted_x_len.data.tolist(), batch_first=True)  is:
PackedSequence(data=tensor([[-0.0783, -0.0091, -0.0775,  0.0731],
        [ 0.0100, -0.0372, -0.0139, -0.0669],
        [ 0.0167,  0.0425, -0.0596, -0.0765],
        [ 0.0652, -0.0491,  0.0274, -0.0430],
        [ 0.0217,  0.0014,  0.0430, -0.0057],
        [-0.0014, -0.0411, -0.0739, -0.0768],
        [-0.0783, -0.0091, -0.0775,  0.0731],
        [ 0.0217,  0.0014,  0.0430, -0.0057],
        [ 0.0518,  0.0061,  0.0161,  0.0411],
        [ 0.0652, -0.0491,  0.0274, -0.0430],
        [ 0.0167,  0.0425, -0.0596, -0.0765],
        [-0.0014, -0.0411, -0.0739, -0.0768],
        [ 0.0652, -0.0491,  0.0274, -0.0430],
        [-0.0562,  0.0797, -0.0044, -0.0591],
        [ 0.0217,  0.0014,  0.0430, -0.0057],
        [ 0.0167,  0.0425, -0.0596, -0.0765],
        [ 0.0518,  0.0061,  0.0161,  0.0411],
        [ 0.0652, -0.0491,  0.0274, -0.0430]],
       grad_fn=), batch_sizes=tensor([5, 5, 4, 3, 1]), sorted_indices=None, unsorted_indices=None)
       
	   Note: this way no padding zeros are fed into the LSTM; the batch handed to the LSTM at each time step has size [5, 5, 4, 3, 1] respectively.
	   The final encoding of the queries has shape [batch_size, seq_max, hidden_size]:
	   tensor([[[-0.1572, -0.0628,  0.1359,  0.0133],
         [-0.2072, -0.0917,  0.1704, -0.0067],
         [-0.2320, -0.1170,  0.1780, -0.0071],
         [-0.2383, -0.1252,  0.1847, -0.0125],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1553, -0.0554,  0.1370, -0.0058],
         [-0.2124, -0.0962,  0.1708, -0.0013],
         [-0.2307, -0.1140,  0.1833, -0.0063],
         [-0.2367, -0.1202,  0.1781, -0.0072],
         [-0.2451, -0.1270,  0.1763, -0.0020]],

        [[-0.1570, -0.0598,  0.1424,  0.0083],
         [-0.2096, -0.0908,  0.1686,  0.0090],
         [-0.2342, -0.1123,  0.1744,  0.0117],
         [-0.2393, -0.1167,  0.1786,  0.0018],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1567, -0.0560,  0.1341,  0.0124],
         [-0.2144, -0.0937,  0.1665,  0.0188],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.1595, -0.0614,  0.1344,  0.0208],
         [-0.2122, -0.0900,  0.1692,  0.0147],
         [-0.2282, -0.1101,  0.1799, -0.0051],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]], grad_fn=)

'''
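  A common next step is to pick out the last valid hidden state of every sequence from this padded result, indexing with the true lengths so that the trailing zero padding is skipped; a minimal sketch for the unidirectional case:

# Q_r: [batch_size, seq_max, hidden_size]; query_lengths: [batch_size]
last_states = Q_r[torch.arange(Q_r.size(0)), query_lengths - 1]   # [batch_size, hidden_size]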

3. How pack_padded_sequence Works

pad_packed_sequence is effectively the inverse of pack_padded_sequence: it pads the packed sequence back into a regular [batch_size, seq_max, ...] tensor so downstream layers can consume it.
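The sketch below, reusing the toy batch from section 2, shows what packing actually does: the batch is sorted by true length, the tokens are then read column by column (one time step at a time) keeping only the valid entries, and batch_sizes records how many sequences are still active at each step.

import torch
from torch.nn.utils.rnn import pack_padded_sequence

queries = torch.tensor([[1,2,3,4,0],[2,3,4,5,6],[4,5,6,7,0],[5,6,0,0,0],[6,7,8,0,0]])
query_lengths = torch.tensor([4,5,4,2,3])

# Sort by true length (descending), exactly as the encoder above does, then pack
sorted_len, indx = torch.sort(query_lengths, descending=True)
packed = pack_padded_sequence(queries[indx], sorted_len, batch_first=True)

print(packed.batch_sizes)   # tensor([5, 5, 4, 3, 1]) -> sequences still active at each time step
print(packed.data.numel())  # 18 = 5 + 5 + 4 + 3 + 1 real token positions, padding excluded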

[Figure 1: step-by-step illustration of how pack_padded_sequence packs a padded batch]
