pytorch 基本运算函数

生成随机数:

torch.rand()
torch.randn()

tensor乘法:

torch.mm(tensor1, tensor2)

返回tensor最大值, 前K个最大值:

value, index = torch.max(input, dim)
value, index = torch.topk(input, k, dim)

tensor拼接:

torch.cat(list, index)

tensor转置:

torch.Tensor.permute(0,1,2,3,4)

unsqueeze:添加维度 [3,4] -> [1,3,4]

torch.Tensor.unsqueeze(dim)
torch.unsqueeze(tensor, dim)

squeeze:去掉维数为1的维度 [1,3,4] -> [3,4]

torch.Tensor.squeeze()
torch.squeeze(tensor)

 

 

 

 

 

 

 

你可能感兴趣的:(python)