pytorch torch.cat()

#1. 输入两个二维张量(dim=0):dim=0对行进行拼接 (行数增加)
    a = torch.randn(2, 3)
    b = torch.randn(3, 3)
    c = torch.cat((a, b), dim=0)
    print(a,'\n',b,'\n',c)
    print(a.shape,b.shape,c.shape)  # torch.Size([2, 3]) torch.Size([3, 3]) torch.Size([5, 3])

    # 2. 输入两个二维张量(dim=1): dim=1对列进行拼接(列数增加)
    a = torch.randn(2, 3)
    b = torch.randn(2, 4)
    c = torch.cat((a, b), dim=1)
    print(a.shape, b.shape, c.shape)  #torch.Size([2, 3]) torch.Size([2, 4]) torch.Size([2, 7])

    # 3. 输入两个三维张量:dim=0 对通道进行拼接(第一个维度数增加)
    a = torch.randn(2, 3, 4)
    b = torch.randn(1, 3, 4)
    c = torch.cat((a, b), dim=0)
    print(a.shape, b.shape, c.shape)  # torch.Size([2, 3, 4]) torch.Size([1, 3, 4]) torch.Size([3, 3, 4])

    # 4. 输入两个三维张量:dim=-1对行进行拼接
    a = torch.randn(16,85,768)
    b = torch.randn(16,85,768)
    c = torch.cat((a, b), dim=-1)  # dim=-1:指的是按最后的一个维度拼接,相当于dim=2
    print(a.shape, b.shape, c.shape)  # torch.Size([16, 85, 768]) torch.Size([16, 85, 768]) torch.Size([16, 85, 1536])

你可能感兴趣的:(My_python_code,pytorch,深度学习,python)