查了好多资料,包括pytorch的官方文档(https://pytorch.org/docs/stable/generated/torch.mean.html#torch.mean)都没有说明mean中[a,b]到底是什么意思。写该文的目的是让更多的人能快速明白代码意思,并给自己做个备忘。
解析:torch.mean(x,[a,b],keepdim=True)中[a,b]的意思是,沿着将第a和第b维的维度变为1的方向做均值,其余维度不变。
直接上例子:
import torch
a = torch.tensor([
[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]]]).float()
b = torch.mean(a, [0, 1], keepdim=True)
print(a.shape)
print(b.shape)
print(b)
结果输出:
torch.Size([3, 2, 3])
torch.Size([1, 1, 3])
tensor([[[2.5000, 3.5000, 4.5000]]])
保留a的第2维,沿着a的第0和第1维做均值,a的维度为[3, 2, 3],输出b的维度变为[1, 1, 3]。b中2.5是通过计算a中(1 + 4) * 3 / 6 得来的,即按列方向求均值,其余元素同理。
再看一个例子:
import torch
a = torch.tensor([
[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]]]).float()
b = torch.mean(a, [0, 2], keepdim=True)
print(a.shape)
print(b.shape)
print(b)
结果输出:
torch.Size([3, 2, 3])
torch.Size([1, 2, 1])
tensor([[[2.],
[5.]]])
保留a的第1维,沿着a的第0和第2维做均值,a的维度为[3, 2, 3],输出b的维度变为[1, 2, 1]。b中2.是通过计算a中(1 + 2 + 3) * 3 / 9 得来的;b中5.是通过计算a中(4 + 5 + 6) * 3 / 9 得来的。即按行的方向求均值。
再看一个例子:
import torch
a = torch.tensor([
[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]],
[[1, 2, 3],
[4, 5, 6]]]).float()
b = torch.mean(a, [1, 2], keepdim=True)
print(a.shape)
print(b.shape)
print(b)
结果输出:
torch.Size([3, 2, 3])
torch.Size([3, 1, 1])
tensor([[[3.5000]],
[[3.5000]],
[[3.5000]]])
保留a的第0维,沿着a的第1和第2维做均值,a的维度为[3, 2, 3],输出b的维度变为[3, 1, 1]。b中第一个3.5是通过计算a中小矩阵(1 + 2 + 3 + 4 + 5 + 6) / 6 得来的,其余元素同理。即,保留第0维的维度,计算a中每个小矩阵的均值。
第一次写文章,转载、引用请注明出处,谢谢!