【Pytorch基础】torch.stack()函数解析

目录

  • 1 函数作用
  • 2 例子
    • 2.1 沿dim=0拼接
    • 2.2 dim=1
    • 2.3 dim=2
  • 3 参考文献

1 函数作用

  官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状,注意与torch.cat的区别
  浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

2 例子

import torch
import numpy as np
# 创建3*3的矩阵,a、b
a=np.array([[1,2,3],[4,5,6],[7,8,9]])
b=np.array([[10,20,30],[40,50,60],[70,80,90]])
# 将矩阵转化为Tensor
a = torch.from_numpy(a)
b = torch.from_numpy(b)

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]], dtype=torch.int32)
tensor([[10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]], dtype=torch.int32)

2.1 沿dim=0拼接

d = torch.stack((a, b), dim=0)
print(d)
print(d.size())
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]], dtype=torch.int32)
torch.Size([2, 3, 3])

  当dim = 0,原来的每一个矩阵也变成了一个维度,一个矩阵看做一个整体。有几个矩阵,新的维度就是几,第几个矩阵就是第几维。

2.2 dim=1

d = torch.stack((a, b), dim=1)
print(d)
print(d.size())
tensor([[[ 1,  2,  3],
         [10, 20, 30]],

        [[ 4,  5,  6],
         [40, 50, 60]],

        [[ 7,  8,  9],
         [70, 80, 90]]], dtype=torch.int32)
torch.Size([3, 2, 3])

  将每个矩阵的第一行组成第一维矩阵,依次下去,每个矩阵的第n行组成第n维矩阵。size=(n,i,y)

2.3 dim=2

d = torch.stack((a, b), dim=2)
print(d)
print(d.size())
tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],

        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],

        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]], dtype=torch.int32)
torch.Size([3, 3, 2])

dim=2的理解可以参考文献【3】

3 参考文献

[1]【Pytorch】torch.stack()的使用
[2]看完秒懂torch.stack()
[3]初学torch.stack()对dim的个人理解

你可能感兴趣的:(Pytorch学习,pytorch,深度学习,python)