二维可以想象成一张纸,
三维可以想象成多张纸叠在一块
四维可以想成多沓纸
求和时,如果没设定keepdim=True,则会消去相加的那一维度,否则则将维度变为1
A = torch.arange(20).reshape(5, 4)
A,A.shape, A.sum()
(tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]]),
torch.Size([5, 4]),
tensor(190))
A_sum_axis0 = A.sum(axis=0)
A_sum_axis0, A_sum_axis0.shape
(tensor([40, 45, 50, 55]), torch.Size([4]))
A_sum_axis1 = A.sum(axis=1)
A_sum_axis1, A_sum_axis1.shape
(tensor([ 6, 22, 38, 54, 70]), torch.Size([5]))
# 等价于A.SUM()
A.sum(axis=[0, 1]), A.sum(axis=[0, 1]).shape
(tensor(190), torch.Size([]))
# 三维 测试
SA = torch.arange(20 * 2).reshape(2, 5, 4)
SA, SA.shape
(tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]],
[[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31],
[32, 33, 34, 35],
[36, 37, 38, 39]]]),
torch.Size([2, 5, 4]))
SA_sum_axis0 = SA.sum(axis=0)
SA_sum_axis0, SA_sum_axis0.shape
(tensor([[20, 22, 24, 26],
[28, 30, 32, 34],
[36, 38, 40, 42],
[44, 46, 48, 50],
[52, 54, 56, 58]]),
torch.Size([5, 4]))
SA_sum_axis1 = SA.sum(axis=1)
SA_sum_axis1, SA_sum_axis1.shape
(tensor([[ 40, 45, 50, 55],
[140, 145, 150, 155]]),
torch.Size([2, 4]))
SA_sum_axis2 = SA.sum(axis=2)
SA_sum_axis2, SA_sum_axis2.shape
(tensor([[ 6, 22, 38, 54, 70],
[ 86, 102, 118, 134, 150]]),
torch.Size([2, 5]))
A = torch.arange(20*2*2).reshape((2,2,5,4))
A, A.shape
(tensor([[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]],
[[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31],
[32, 33, 34, 35],
[36, 37, 38, 39]]],
[[[40, 41, 42, 43],
[44, 45, 46, 47],
[48, 49, 50, 51],
[52, 53, 54, 55],
[56, 57, 58, 59]],
[[60, 61, 62, 63],
[64, 65, 66, 67],
[68, 69, 70, 71],
[72, 73, 74, 75],
[76, 77, 78, 79]]]]),
torch.Size([2, 2, 5, 4]))
A_sum_axis0 = A.sum(axis=0)
A_sum_axis0, A_sum_axis0.shape
(tensor([[[ 40, 42, 44, 46],
[ 48, 50, 52, 54],
[ 56, 58, 60, 62],
[ 64, 66, 68, 70],
[ 72, 74, 76, 78]],
[[ 80, 82, 84, 86],
[ 88, 90, 92, 94],
[ 96, 98, 100, 102],
[104, 106, 108, 110],
[112, 114, 116, 118]]]),
torch.Size([2, 5, 4]))
A_sum_axis1 = A.sum(axis=1)
A_sum_axis1, A_sum_axis1.shape
(tensor([[[ 20, 22, 24, 26],
[ 28, 30, 32, 34],
[ 36, 38, 40, 42],
[ 44, 46, 48, 50],
[ 52, 54, 56, 58]],
[[100, 102, 104, 106],
[108, 110, 112, 114],
[116, 118, 120, 122],
[124, 126, 128, 130],
[132, 134, 136, 138]]]),
torch.Size([2, 5, 4]))