Pytorch函数

  • max()
    torch.max(input, dim)
    dim参数指出删去哪一维度,0-行,1-列;输出两个tensor,第一个得到最大值结果,第二个给出相对位置(0-index)
>>> a = torch.randn(4, 4)
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
(tensor([ 0.8475,  1.1949,  1.5717,  1.0036]), tensor([ 3,  0,  0,  1]))

dim=1,删除列的维度,只有1列,每一行为该行最大值,第二个tensor给出该最大值所在的列数
等同于a.max(1)
例:在训练网络时

output = net(img)
 _, predicted = output.max(1)

output为对img的预测输出,batch行label列,每行是一个图片的输出,每次输出batch组。所以预测结果需要看每行的最大值,找每行最大值的位置。output.max(1)找到每行最大值,有两个tensor输出,第一个为最大值,第二个为最大值所在位置,所关注的是位置,所以第一个下划线_舍弃掉最大值。

  • item()
    把tensor转换成数

  • torch.nn.Sequential语法
    nn.Sequential(a, b, c)
    括号,逗号

  • torchvision.transforms.Composed语法
    transforms.Composed([a, b, c])
    括号,方括号,逗号

你可能感兴趣的:(Pytorch函数)