# Build a 5x3 matrix holding 0..14 in row-major order.
# The explicit .float() matters: with integer arguments, torch.arange returns
# an int64 tensor on current PyTorch, and torch.mean rejects integer dtypes.
x = torch.arange(15).float().view(5, 3)

# Reduce over dim 0 (the 5 rows): one mean per column, shape (1, 3).
x_mean = torch.mean(x, dim=0, keepdim=True)

# Reduce over dim 1 (the 3 columns): one mean per row, shape (5, 1).
x_mean0 = torch.mean(x, dim=1, keepdim=True)

print('before bn:')
print(x)
print('x_mean:')
print(x_mean)
print('x_mean0:')
print(x_mean0)
before bn:
0 1 2
3 4 5
6 7 8
9 10 11
12 13 14
[torch.FloatTensor of size 5x3]
x_mean:
6 7 8
[torch.FloatTensor of size 1x3]
x_mean0:
1
4
7
10
13
[torch.FloatTensor of size 5x1]