Dimension mismatch between seq2seq logits and labels

In the decode stage, the shape of the training output is largely determined by the sequence_length argument passed to TrainingHelper.

training_logits has shape [batch_size, target_sequence_length, output_size], where the second dimension is the fixed (padded) length of each batch of targets and the last dimension is determined by the decoder's output_layer (the target vocabulary size when output_layer projects onto the vocabulary, which is what sequence_loss expects).

    with tf.variable_scope("decode"):
        # Training helper: feeds the ground-truth embeddings and decodes for
        # exactly target_sequence_length steps per sequence.
        training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input,
                                                            sequence_length=target_sequence_length,
                                                            time_major=False)
        # Build the decoder from the cell, helper, initial state and output projection.
        training_decoder = tf.contrib.seq2seq.BasicDecoder(cell,
                                                           training_helper,
                                                           initial_state,
                                                           output_layer)
        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,
                                                                          impute_finished=True,
                                                                          maximum_iterations=max_target_sequence_length)

    # rnn_output holds the per-step logits; give them a stable name.
    training_logits = tf.identity(training_decoder_output.rnn_output, 'logits')
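The output_layer passed to BasicDecoder is not shown in this snippet. A minimal sketch of how it is commonly defined, assuming target_vocab_size is the size of the target vocabulary (an illustrative name, not necessarily the one used in the original code):

    from tensorflow.python.layers.core import Dense

    # Dense projection from the cell output onto the target vocabulary; this
    # makes the last dimension of rnn_output equal to target_vocab_size,
    # which is the logits shape that sequence_loss expects.
    output_layer = Dense(target_vocab_size,
                         kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))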

The second dimension of training_logits must match the second dimension of targets.

training_logits has shape [batch_size, target_sequence_length, output_size],

while targets has shape [batch_size, target_sequence_length].

So when logits and labels disagree in dimension, the cause is usually that the two tensors do not agree on target_sequence_length.
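One way to surface the problem early is to assert that the two second dimensions agree before building the loss; a small sketch using the tensor names from the snippets above:

    # Fails at run time with a clear message if the decoded length and the
    # padded target length disagree on the second dimension.
    with tf.control_dependencies([
            tf.assert_equal(tf.shape(training_logits)[1],
                            tf.shape(targets)[1],
                            message='logits and targets disagree on target_sequence_length')]):
        training_logits = tf.identity(training_logits)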

    # masks supplies the per-position loss weights (0 on padding); its second
    # dimension must also be target_sequence_length.
    cost = tf.contrib.seq2seq.sequence_loss(
        training_logits,
        targets,
        masks)
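The masks tensor used as the loss weights is not defined above; it is usually built with tf.sequence_mask so that padded positions contribute nothing to the loss. A sketch, assuming target_sequence_length and max_target_sequence_length are the same values fed to the decoder earlier:

    # Weight matrix of shape [batch_size, max_target_sequence_length]:
    # 1.0 up to each sequence's real length, 0.0 on the padding.
    masks = tf.sequence_mask(target_sequence_length,
                             max_target_sequence_length,
                             dtype=tf.float32,
                             name='masks')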

In the function that generates batches, the fix is simply to keep target_sequence_length consistent with the actual (padded) width of the targets:

        # Report the same length for every record in the batch: targets are
        # padded to max_decoder_seq_length, so the lengths must say so too.
        targets_lengths = [max_decoder_seq_length] * batch_size
        # Using the raw, unpadded lengths instead is what causes the mismatch:
        # targets_lengths = []
        # for target in targets_batch:
        #     targets_lengths.append(len(target))
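For context, a sketch of the padding step that has to agree with targets_lengths; pad_sentence_batch, targets_batch, target_pad_int and max_decoder_seq_length are illustrative names, not necessarily those of the original batch generator:

    import numpy as np

    def pad_sentence_batch(sentence_batch, pad_int, max_len):
        # Pad every sequence to the same width so the second dimension of
        # targets is max_decoder_seq_length for every batch.
        return [sentence + [pad_int] * (max_len - len(sentence))
                for sentence in sentence_batch]

    # Padded targets and the reported lengths now describe the same width.
    pad_targets_batch = np.array(
        pad_sentence_batch(targets_batch, target_pad_int, max_decoder_seq_length))
    targets_lengths = [max_decoder_seq_length] * batch_size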

