What are the dimensions of each tensor in this Transformer code example?

Below is a Transformer code example:

def sample(self, batch_size, max_length=140, con_token_list=['is_JNK3', 'is_GSK3', 'high_QED', 'good_SA']):
    """
    Sample a batch of sequences.

    Args:
        batch_size: Number of sequences to sample
        max_length: Maximum length of the sequences

    Outputs:
        seqs: (batch_size, seq_length) The sampled sequences.
        log_probs: (batch_size) Log likelihood for each sequence.
        entropy: (batch_size) The entropies for the sequences. Not
                 currently used.
    """

    # conditional tokens
    con_token_list = Variable(self.voc.encode(con_token_list))

    con_tokens = Variable(torch.zeros(batch_size, len(con_token_list)).long())  # shape (batch_size, len(con_token_list)): the conditional-token tensor

    for ind, token in enumerate(con_token_list):
        con_tokens[:, ind] = token

    start_token = Variable(torch.zeros(batch_size, 1).long())  # shape (batch_size, 1): the start-of-sequence token
    start_token[:] = self.voc.vocab['GO']
    input_vector = start_token  # updated inside the loop; as the decoder input it has the same shape as sequences
    # print(batch_size)

    sequences = start_token
    log_probs = Variable(torch.zeros(batch_size))
    # log_probs1 = Variable(torch.zeros(batch_size))

    finished = torch.zeros(batch_size).byte()

    finished = finished.to(self.device)

    for step in range(max_length):
        logits = sample_forward_model(self.decodertf, input_vector, con_tokens)  # shape (batch_size, cur_len, vocab_size), where cur_len = step + 1 is the current length of input_vector

        logits_step = logits[:, step, :]  # logits of the current (last) time step, shape (batch_size, vocab_size)

        prob = F.softmax(logits_step, dim=1)
        log_prob = F.log_softmax(logits_step, dim=1)

        input_vector = torch.multinomial(prob, 1)  # one sampled token per sequence, shape (batch_size, 1)

        # concatenate the newly sampled token onto the prior tokens, recording every sampling step
        sequences = torch.cat((sequences, input_vector), 1)  # shape (batch_size, step + 2): the GO token plus all tokens sampled so far

        log_probs += self._nll_loss(log_prob, input_vector.view(-1))  # shape (batch_size): accumulated log likelihood of each generated sequence
        # log_probs1 += NLLLoss(log_prob, input_vector.view(-1))
        # print(log_probs1==-log_probs)

        EOS_sampled = (input_vector.view(-1) == self.voc.vocab['EOS']).data
        finished = torch.ge(finished + EOS_sampled, 1)  # shape (batch_size): binary flags marking which sequences have already ended

        if torch.prod(finished) == 1:
            # print('End')
            break

        # the Transformer has no recurrent hidden state, so the whole sequence generated
        # so far is fed back in as input_vector at every step
        input_vector = sequences

    return sequences[:, 1:].data, log_probs

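To make the shapes concrete, here is a minimal, self-contained sketch that mimics the sampling loop above with a dummy decoder. All names and sizes in it (dummy_decoder, vocab_size, the GO/EOS ids, and F.nll_loss with reduction='none' standing in for self._nll_loss) are assumptions made up for illustration; only the shape bookkeeping follows the code in question:

import torch
import torch.nn.functional as F

batch_size, vocab_size, max_length = 4, 30, 6
n_con = 4          # number of conditional tokens, e.g. ['is_JNK3', 'is_GSK3', 'high_QED', 'good_SA']
GO, EOS = 1, 2     # made-up ids standing in for self.voc.vocab['GO'] / self.voc.vocab['EOS']

def dummy_decoder(input_vector, con_tokens):
    """Stand-in for sample_forward_model(self.decodertf, ...):
    returns one logit vector per input position -> (batch_size, cur_len, vocab_size)."""
    cur_len = input_vector.size(1)
    return torch.randn(input_vector.size(0), cur_len, vocab_size)

con_tokens = torch.zeros(batch_size, n_con).long()               # (batch_size, n_con)
start_token = torch.full((batch_size, 1), GO, dtype=torch.long)  # (batch_size, 1)
sequences = start_token                                          # grows to (batch_size, step + 2)
input_vector = start_token
log_probs = torch.zeros(batch_size)                              # (batch_size)
finished = torch.zeros(batch_size, dtype=torch.bool)             # (batch_size)

for step in range(max_length):
    logits = dummy_decoder(input_vector, con_tokens)             # (batch_size, step + 1, vocab_size)
    logits_step = logits[:, step, :]                             # (batch_size, vocab_size)
    prob = F.softmax(logits_step, dim=1)
    log_prob = F.log_softmax(logits_step, dim=1)
    input_vector = torch.multinomial(prob, 1)                    # (batch_size, 1)
    sequences = torch.cat((sequences, input_vector), 1)          # (batch_size, step + 2)
    log_probs += F.nll_loss(log_prob, input_vector.view(-1), reduction='none')  # stays (batch_size)
    finished = finished | (input_vector.view(-1) == EOS)         # (batch_size)
    print(step, logits.shape, sequences.shape, log_probs.shape)
    if finished.all():
        break
    input_vector = sequences                                     # feed the whole prefix back in

print(sequences[:, 1:].shape)   # (batch_size, number of sampled tokens), GO column stripped off

Printing the shapes at each step shows logits growing along dimension 1 as (batch_size, step + 1, vocab_size) while log_probs and finished stay at (batch_size); sequences ends up as (batch_size, sampled tokens + 1), and dropping the first (GO) column gives the returned (batch_size, seq_length) tensor.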