pytorch中stack()和cat()的理解和区别图解

torch.cat() 和 torch.stack() 函数的作用都是将多个维度参数相同的张量连接成一个张量,不同之处在与 stock()相比于cat()多了一维。这里两个函数都有 dim 这个参数,但是指的意思却不一样。使用下图来解释,在这里将两个张量理解成树这种形式,希望可以帮助理解。

这里竖线代表一个维度,竖线上所有节点代表同一维度的所有元素,在下面所有图中,同颜色的元素都是按照从上往下按顺序排列的。

dim在cat()函数中表示索所要连接的维度,也就是连接 所要连接的多个张量 的这个维度上面的参数。

但是在stack()中,dim表示多出来的维度,这个维度被用来连接之后维度的参数。原来的维度则变成子节点了,例如dim=1,那么 原来张量的第一维度 就变成了 连接之后的张量 的第二维度

假设这里一个torch.randn(2, 3, 4)生成的两个张量,如下图
pytorch中stack()和cat()的理解和区别图解_第1张图片

红和蓝分别表示两个不同的张量,后面所有的图中左边的是使用stack()函数,右边是使用cat()函数,黄色的表示stack()函数生成的多的一维。
那么当 dim = 0时,如下图pytorch中stack()和cat()的理解和区别图解_第2张图片


dim = 1, 如下图

pytorch中stack()和cat()的理解和区别图解_第3张图片


dim = 2,如下图
pytorch中stack()和cat()的理解和区别图解_第4张图片


对于stack()函数生成的结果会多一个维度,所有在这个例子中会有3这个索引值所代表的第四维度,dim = 3是成立的,但是对于cat()函数则没有这个

pytorch中stack()和cat()的理解和区别图解_第5张图片

你可能感兴趣的:(python,pytorch,深度学习)