Intro: While I was processing a dataset and practicing how to manipulate data along a dimension, the “dim” parameter confused me.
Start: Let’s read the code. If you can already explain every output, congratulations!
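import torch
# A 3-D tensor of ones: dim 0 has size 2, dim 1 has size 5, dim 2 has size 4
a = torch.ones((2, 5, 4))
print(a.shape)
# With no axis, sum() adds all 40 elements and returns a 0-dim tensor (a scalar)
print("value of scalar: ", a.sum())
print("scalar: ", a.sum().shape)
print("a: ", a)
# axis/keepdims are NumPy-style spellings of PyTorch's dim/keepdim.
# keepdims=True keeps the reduced dimension with size 1 instead of dropping it.
print("axis=1", a.sum(axis=1))
print("axis=1, keepdims=True: \n", a.sum(axis=1, keepdims=True))
print("axis=2", a.sum(axis=2))
print("axis=2, keepdims=True: \n", a.sum(axis=2, keepdims=True))
print("axis=0", a.sum(axis=0))
print("axis=0, keepdims=True: \n", a.sum(axis=0, keepdims=True))
# A list of axes reduces several dimensions at once
print("axis=[0, 2]", a.sum(axis=[0, 2]))
print("axis=[0, 2], keepdims=True: \n", a.sum(axis=[0, 2], keepdims=True))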
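To check your answers, here is a small sketch that reruns the same reductions and asserts the shape of each result and the value of its entries. It uses PyTorch's own keyword names, dim and keepdim, which the axis and keepdims spellings above stand in for. The expected shapes and values are worked out by hand for a tensor of ones with shape (2, 5, 4).

```python
import torch

a = torch.ones((2, 5, 4))

# (description, result tensor, expected shape, expected value of every entry)
results = [
    ("sum()", a.sum(), (), 40.0),
    ("dim=0", a.sum(dim=0), (5, 4), 2.0),
    ("dim=0, keepdim", a.sum(dim=0, keepdim=True), (1, 5, 4), 2.0),
    ("dim=1", a.sum(dim=1), (2, 4), 5.0),
    ("dim=1, keepdim", a.sum(dim=1, keepdim=True), (2, 1, 4), 5.0),
    ("dim=2", a.sum(dim=2), (2, 5), 4.0),
    ("dim=2, keepdim", a.sum(dim=2, keepdim=True), (2, 5, 1), 4.0),
    ("dim=(0, 2)", a.sum(dim=(0, 2)), (5,), 8.0),
    ("dim=(0, 2), keepdim", a.sum(dim=(0, 2), keepdim=True), (1, 5, 1), 8.0),
]

for name, result, shape, value in results:
    # Each entry equals the number of ones that were added together,
    # i.e. the product of the sizes of the reduced dimensions.
    assert tuple(result.shape) == shape, name
    assert torch.all(result == value), name
    print(f"{name}: shape {shape}, each entry {value}")
```

Notice that each reduced dimension either disappears from the shape or, with keepdim=True, stays as a size-1 dimension.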
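1. Let’s focus on “.shape”.
print(a.shape) gives torch.Size([2, 5, 4]). Every position in “.shape” has an index: index 0 has size 2, index 1 has size 5, and index 2 has size 4. The value you pass as “dim” (or “axis”) refers to one of these indices.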
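A minimal sketch that pairs each index of “.shape” with its size, which is the mapping the “dim” argument relies on:

```python
import torch

a = torch.ones((2, 5, 4))

# torch.Size behaves like a tuple, so enumerate gives (index, size) pairs.
dims = list(enumerate(a.shape))
for index, size in dims:
    print(f"dim {index}: size {size}")

# a.dim() is the number of dimensions, i.e. the number of indices in .shape.
print("number of dims:", a.dim())
```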
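2. Let’s focus on the computation.
When “axis=1”, focus on the brackets at nesting level 1 in the printed tensor, counting the outermost [ ] as level 0. There are 2 of these [[ ]] blocks, one for each entry along dim 0. The elements inside each block are the innermost [ ] rows: 5 rows of 4 numbers each. Summing along axis 1 adds those 5 rows together element by element, so each block collapses to a single row of 4 values, each equal to 5. The result has shape (2, 4), or (2, 1, 4) with keepdims=True.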
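The same reduction can be spelled out with explicit loops. This is a sketch for illustration only: it swaps in torch.arange instead of torch.ones so that the entries differ and you can see which numbers get added together. Inside each of the 2 outer blocks, it adds the 5 rows (each of length 4) element by element, then compares the result with the built-in sum over dim 1.

```python
import torch

# Non-uniform values 0..39 arranged with the same shape as before.
a = torch.arange(40, dtype=torch.float32).reshape(2, 5, 4)

manual = torch.zeros((2, 4))
for i in range(a.shape[0]):      # the 2 blocks along dim 0 are kept
    for j in range(a.shape[1]):  # the 5 rows along dim 1 are summed away
        manual[i] += a[i, j]     # a[i, j] is one row of 4 numbers

builtin = a.sum(dim=1)
print(manual)
print(torch.equal(manual, builtin))
```

Only the loop over dim 1 is collapsed; the loop over dim 0 and the length-4 rows survive, which is why the result has shape (2, 4).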
Conclusion
The other cases work the same way. The key idea: when a computation reduces a tensor along the “dim” parameter, look at the value of “dim”, find the matching index in “.shape”, and match it to the corresponding level of brackets in the printed tensor. That dimension is the one that gets summed away: it disappears from the result's shape, or stays with size 1 when keepdims=True.
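The rule in the conclusion can be written down as a small shape-prediction function. The helper name reduced_shape is hypothetical (it is not part of PyTorch), and this sketch only handles non-negative dim values. It drops each reduced index from the shape, or replaces it with 1 when keepdim is set, and checks the prediction against what torch actually returns.

```python
import torch

def reduced_shape(shape, dims, keepdim=False):
    """Hypothetical helper: predict the shape of a sum over `dims` (non-negative only)."""
    dims = {dims} if isinstance(dims, int) else set(dims)
    if keepdim:
        # Reduced dimensions stay in place with size 1.
        return tuple(1 if i in dims else s for i, s in enumerate(shape))
    # Reduced dimensions are removed entirely.
    return tuple(s for i, s in enumerate(shape) if i not in dims)

a = torch.ones((2, 5, 4))
for dims in [0, 1, 2, (0, 2)]:
    for keepdim in (False, True):
        predicted = reduced_shape(tuple(a.shape), dims, keepdim)
        actual = tuple(a.sum(dim=dims, keepdim=keepdim).shape)
        assert predicted == actual, (dims, keepdim)
        print(dims, keepdim, actual)
```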