pytorch 常用函数 max ,eq

max找出tensor 的行或者列最大的值:

找出每行的最大值:

import torch

outputs=torch.FloatTensor([[1],[2],[3]])

print(torch.max(outputs.data,1))

输出:
(tensor([ 1., 2., 3.]), tensor([ 0, 0, 0]))

找出每列的最大值:

import torch

outputs=torch.FloatTensor([[1],[2],[3]])

print(torch.max(outputs.data,0))

输出结果:
(tensor([ 3.]), tensor([ 2]))

Tensor比较eq相等:

import torch

outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data))

输出结果:
tensor([[ 0],
[ 1],
[ 1]], dtype=torch.uint8)

使用sum() 统计相等的个数:

import torch

outputs=torch.FloatTensor([[1],[2],[3]])
targets=torch.FloatTensor([[0],[2],[3]])
print(targets.eq(outputs.data).cpu().sum())

输出结果:
tensor(2)

你可能感兴趣的:(pytorch)