torch.max()一共有两种形式,如下:
第一种:torch.max(input)
这种形式直接求出input中所有数的最大值,输出是一个数,且output.dim()=0,即无论input是几维,输出都为一个0维的数,注意input需为张量。
示例代码如下:
>>>x1 = torch.tensor([1, 2, 3])
>>>torch.max(x1)
tensor(3)
>>>x2 = torch.tensor([[[1, 2],[3, 4]],
[[5, 6],[7, 8]]])
>>>torch.max(x2)
tensor(8)
第二种:torch.max(input, dim, keepdim=False, *, out=None)
这种形式可以根据需要求第几维度上的最大值,且可以选择输出的维度是否改变,返回最大值和第几维度上最大值索引。
dim,求第dim维度的最大值,例如dim=0,求第0维上的最大值,dim=1,求第1维上的最大值;
keepdim,当keepdim=False时,输出维度input.dim()改变,否则不变。
示例代码如下:
>>>a = torch.tensor([[[[ 1., 2.]],
[[ 3., 7.]],
[[ 5., 6.]]],
[[[ 7., 8.]],
[[ 13., 10.]],
[[11., 12.]]]])
>>>a.shape
torch.Size([2, 3, 1, 2])
>>>max1_a = torch.max(a, 1)
>>>max_a
torch.return_types.max(
values=tensor([[[ 5., 7.]],
[[13., 12.]]]),
indices=tensor([[[2, 1]],
[[1, 2]]]))
>>>max_a[0]
tensor([[[ 5., 7.]],
[[13., 12.]]])
>>>max_a[0].shape
torch.Size([2, 1, 2])
>>>max2_a = torch.max(a, 1, keepdim=True)
>>>max2_a[0].shape
torch.Size([2, 1, 1, 2])
torch.mean()同 torch.max()一样也是有两种用法,一个是求输入种所有数的平均值,一个是求输入种第几维度的平均值。
使用时同 torch.max()一样,见上torch.max()的理解,只需将max换成mean就行,这里就不举例说明了。
torch.cat(tensors, dim=0, *, out=None)
torch.cat(),就是在第几维度上连接张量,输入为张量元组,例如dim=0,在第0维连接,dim=1,在第1维上连接。
代码示例如下:
>>>m = torch.tensor([[[[ 1., 2.]],
[[ 3., 4.]],
[[ 5., 6.]]],
[[[ 7., 8.]],
[[ 9., 10.]],
[[11., 12.]]]])
>>>n = torch.tensor([[[[ 1., 2.]],
[[ 1., 2.]],
[[ 2., 1.]]],
[[[ 4., 8.]],
[[ 5., 8.]],
[[ 6., 7.]]]])
>>>n.shape
torch.Size([2, 3, 1, 2])
>>>torch.cat((m, n), dim=0)
tensor([[[[ 1., 2.]],
[[ 3., 4.]],
[[ 5., 6.]]],
[[[ 7., 8.]],
[[ 9., 10.]],
[[11., 12.]]],
[[[ 1., 2.]],
[[ 1., 2.]],
[[ 2., 1.]]],
[[[ 4., 8.]],
[[ 5., 8.]],
[[ 6., 7.]]]])
>>>torch.cat((m, n), dim=0).shape
torch.Size([4, 3, 1, 2])
>>>torch.cat((m, n), dim=1)
tensor([[[[ 1., 2.]],
[[ 3., 4.]],
[[ 5., 6.]],
[[ 1., 2.]],
[[ 1., 2.]],
[[ 2., 1.]]],
[[[ 7., 8.]],
[[ 9., 10.]],
[[11., 12.]],
[[ 4., 8.]],
[[ 5., 8.]],
[[ 6., 7.]]]])
>>>torch.cat((m, n), dim=1).shape
torch.Size([2, 6, 1, 2])