Tensorflow代码转pytorch代码 函数的转换

tensoflow函数和pytorch函数之间的转换

tensorflow pytroch
tf.reshape(input, shape) input.view()
tf.expand_dims(input, dim) input.unsqueeze(dim) / input.view()
tf.squeeze(input, dim) torch.squeeze(dim)/ input.view()
tf.gather(input1, input2) input1[input2]
tf.tile(input, shape) input.repeat(shape)
tf.boolean_mask(input, mask) input[mask] #注意,mask是bool值,不是0,1的数值
tf.concat(input1, input2) torch.cat(input1, input2)
tf.matmul() torch.matmul()
tf.minium(input, min) torch.clamp(input, max=min)
tf.equal(input1, input2) torch.eq(input1, input2)/ input1 == input2
tf.logical_and(input1, input2) input1 & input2
tf.logical_not(input) ~ input
tf.reduce_logsumexp(input, [dim]) torch.logsumexp(input, dim=dim)
tf.reduce_any(input, dim) input.any(dim)
tf.reduce_mean(input) torch.mean(input)
tf.reduce_sum(input) input.sum()
tf.transpose(input) input.t()
tf.softmax_cross_entroy_with_logits(logits, labels) torch.nn.CrossEntropyLoss(logits, labels)

你可能感兴趣的:(pytorch,tensorflow)