tf.reduce_sum vs. torch.sum

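  • The dimension selected by dim (PyTorch) or axis (TensorFlow) is summed over and disappears from the output shape.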
import torch
import tensorflow as tf

# PyTorch: the dim=2 axis (num_nodes) is summed away
this_platf_target_outputs = torch.reshape(this_platf_target_outputs, shape=(self.batch_size, self.seq_len, self.num_nodes, self.units))  # (64, 6, 30, 16)
this_platf_target_outputs = torch.sum(this_platf_target_outputs, dim=2)  # shape=(batch_size, seq_len, units), e.g. (64, 6, 16)

# TensorFlow: axis=2 is summed away in the same way
this_platf_diff_outputs = tf.reshape(this_platf_diff_outputs, shape=(self.batch_size, self.seq_len, self.num_nodes, self.units))
this_platf_diff_outputs = tf.reduce_sum(this_platf_diff_outputs, axis=2)  # shape=(batch_size, seq_len, units)
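The shape rule above can be sketched without either framework. The pure-Python code below computes the result shape of a sum over one axis and performs that sum on nested lists. The helper names reduced_shape, sum_axis and shape_of are hypothetical and only for illustration; in real code you would call torch.sum(x, dim=2) or tf.reduce_sum(x, axis=2). The keepdims flag mirrors torch.sum's keepdim=True and tf.reduce_sum's keepdims=True, which keep the reduced axis with size 1 instead of dropping it.

```python
from typing import Sequence, Tuple


def reduced_shape(shape: Sequence[int], axis: int, keepdims: bool = False) -> Tuple[int, ...]:
    """Hypothetical helper: shape after summing over `axis`.

    Mirrors torch.sum(x, dim=axis, keepdim=keepdims) and
    tf.reduce_sum(x, axis=axis, keepdims=keepdims).
    """
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for a {ndim}-D shape")
    axis %= ndim  # normalise negative axes, e.g. -2 -> 2 for a 4-D shape
    if keepdims:
        # The reduced axis stays, but with size 1
        return tuple(1 if i == axis else d for i, d in enumerate(shape))
    # The reduced axis disappears
    return tuple(d for i, d in enumerate(shape) if i != axis)


def _add(a, b):
    """Element-wise addition of two equally shaped nested lists (or scalars)."""
    if isinstance(a, list):
        return [_add(p, q) for p, q in zip(a, b)]
    return a + b


def sum_axis(x, axis: int):
    """Hypothetical helper: sum a nested list along a non-negative `axis`."""
    if axis == 0:
        # Add all sub-arrays along the outermost axis element-wise
        result = x[0]
        for item in x[1:]:
            result = _add(result, item)
        return result
    # Recurse into each sub-array until the target axis is outermost
    return [sum_axis(sub, axis - 1) for sub in x]


def shape_of(x) -> Tuple[int, ...]:
    """Hypothetical helper: shape of a regular (non-ragged) nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


if __name__ == "__main__":
    # Same layout as the snippet: (batch_size, seq_len, num_nodes, units)
    print(reduced_shape((64, 6, 30, 16), 2))                 # (64, 6, 16)
    print(reduced_shape((64, 6, 30, 16), 2, keepdims=True))  # (64, 6, 1, 16)

    # A small all-ones tensor of shape (2, 3, 4, 5); summing axis 2 adds 4 ones
    x = [[[[1] * 5 for _ in range(4)] for _ in range(3)] for _ in range(2)]
    y = sum_axis(x, 2)
    print(shape_of(y), y[0][0])                              # (2, 3, 5) [4, 4, 4, 4, 4]
```

The reduced axis is gone from the result, and each remaining element holds the sum of the num_nodes values that used to sit along that axis. This is exactly what torch.sum(..., dim=2) and tf.reduce_sum(..., axis=2) do to the (batch_size, seq_len, num_nodes, units) tensors above.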
