pytorch学习记录1-torch.max及torch.unsqueeze()的用法

1. torch.max(x,0):返回一列里的最大值及对应的索引。

2. torch.max(x,1):返回一行里的最大值及对应的索引。

3. torch.unsqueeze(input,dim):指定一个dim的位置插入一个维度为1的tensor

4. torch.cat((a,b),dim=0):将a,b按0维度(行)拼接,torch.cat((a,b),dim=1):将a,b按1维(列)度拼接,例子:

>>> import torch
>>> A=torch.ones(2,3) #2x3的张量(矩阵)                                     
>>> A
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> B=2*torch.ones(4,3)#4x3的张量(矩阵)                                    
>>> B
tensor([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]])
>>> C=torch.cat((A,B),0)#按维数0(行)拼接
>>> C
tensor([[ 1.,  1.,  1.],
         [ 1.,  1.,  1.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.]])
>>> C.size()
torch.Size([6, 3])
>>> D=2*torch.ones(2,4) #2x4的张量(矩阵)
>>> C=torch.cat((A,D),1)#按维数1(列)拼接
>>> C
tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
>>> C.size()
torch.Size([2, 7])

你可能感兴趣的:(python)