关于torch.max()、torch.mean()、torch.cat()的理解

torch.max()的理解

torch.max()一共有两种形式,如下:

第一种:torch.max(input)

这种形式直接求出input中所有数的最大值,输出是一个数,且output.dim()=0,即无论input是几维,输出都为一个0维的数,注意input需为张量。

示例代码如下:

>>>x1 = torch.tensor([1, 2, 3])
>>>torch.max(x1)
tensor(3)

>>>x2 = torch.tensor([[[1, 2],[3, 4]],
                  [[5, 6],[7, 8]]])
>>>torch.max(x2)
tensor(8)
第二种:torch.max(input, dim, keepdim=False, *, out=None)

这种形式可以根据需要求第几维度上的最大值,且可以选择输出的维度是否改变,返回最大值和第几维度上最大值索引。
dim,求第dim维度的最大值,例如dim=0,求第0维上的最大值,dim=1,求第1维上的最大值;
keepdim,当keepdim=False时,输出维度input.dim()改变,否则不变。
示例代码如下:

>>>a = torch.tensor([[[[ 1.,  2.]],
                   [[ 3.,  7.]],
                   [[ 5.,  6.]]],

                  [[[ 7.,  8.]],
                   [[ 13., 10.]],
                   [[11., 12.]]]])
>>>a.shape
torch.Size([2, 3, 1, 2])

>>>max1_a = torch.max(a, 1)
>>>max_a
torch.return_types.max(
values=tensor([[[ 5.,  7.]],

        [[13., 12.]]]),
indices=tensor([[[2, 1]],

        [[1, 2]]]))
        
>>>max_a[0]
tensor([[[ 5.,  7.]],

        [[13., 12.]]])
        
>>>max_a[0].shape
torch.Size([2, 1, 2])

>>>max2_a = torch.max(a, 1, keepdim=True)
>>>max2_a[0].shape
torch.Size([2, 1, 1, 2])

torch.mean()的理解

torch.mean()同 torch.max()一样也是有两种用法,一个是求输入种所有数的平均值,一个是求输入种第几维度的平均值。
使用时同 torch.max()一样,见上torch.max()的理解,只需将max换成mean就行,这里就不举例说明了。

torch.cat()的理解

torch.cat(tensors, dim=0, *, out=None)

torch.cat(),就是在第几维度上连接张量,输入为张量元组,例如dim=0,在第0维连接,dim=1,在第1维上连接。
代码示例如下:

>>>m = torch.tensor([[[[ 1.,  2.]],
                   [[ 3.,  4.]],
                   [[ 5.,  6.]]],

                  [[[ 7.,  8.]],
                   [[ 9., 10.]],
                   [[11., 12.]]]])      
>>>n = torch.tensor([[[[ 1.,  2.]],
                   [[ 1.,  2.]],
                   [[ 2.,  1.]]],

                  [[[ 4.,  8.]],
                   [[ 5.,  8.]],
                   [[ 6.,  7.]]]])
>>>n.shape
torch.Size([2, 3, 1, 2])     
                
>>>torch.cat((m, n), dim=0)
tensor([[[[ 1.,  2.]],

         [[ 3.,  4.]],

         [[ 5.,  6.]]],


        [[[ 7.,  8.]],

         [[ 9., 10.]],

         [[11., 12.]]],


        [[[ 1.,  2.]],

         [[ 1.,  2.]],

         [[ 2.,  1.]]],


        [[[ 4.,  8.]],

         [[ 5.,  8.]],

         [[ 6.,  7.]]]])

>>>torch.cat((m, n), dim=0).shape
torch.Size([4, 3, 1, 2])

>>>torch.cat((m, n), dim=1)
tensor([[[[ 1.,  2.]],

         [[ 3.,  4.]],

         [[ 5.,  6.]],

         [[ 1.,  2.]],

         [[ 1.,  2.]],

         [[ 2.,  1.]]],


        [[[ 7.,  8.]],

         [[ 9., 10.]],

         [[11., 12.]],

         [[ 4.,  8.]],

         [[ 5.,  8.]],

         [[ 6.,  7.]]]])
         
>>>torch.cat((m, n), dim=1).shape
torch.Size([2, 6, 1, 2])

你可能感兴趣的:(深度学习,python,pytorch,深度学习)