torch中dim(0/1)维度表示

dim维度


dim=0代表是列,dim=1代表是行

import torch
a = [[1,3,5],
	 [2,4,6],
	 [7,8,9]]
	 
a = torch.tensor(a).float()
t = a.mean(dim=0)  #dim=0代表是列
print(t)

输出结果(列求均值):
在这里插入图片描述

t = a.mean(dim=1) # dim=1代表是行
print(t)

输出结果(行求均值):

在这里插入图片描述

你可能感兴趣的:(经验分享,pytorch)