pytorch中torch.cat()函数用法

torch.cat(seq,dim,out=None)

其中seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列
dim 表示以哪个维度连接,dim=0, 横向连接,dim=1,纵向连接
举例如下:

import torch
a = torch.ones([1, 2])
b = torch.ones([1, 2])
print(torch.cat([a, b], 1)) # dim=1纵向连接
print(torch.cat([a, b], 0)) # dim=0横向连接

输出结果:

tensor([[1., 1., 1., 1.]])  
tensor([[1., 1.],
        [1., 1.]])

纵向连接之后,维度变成1*4
横向连接之后,维度变成2*2

你可能感兴趣的:(pytorch,深度不学习)