Pytorch torch.cat与torch.chunk

部分内容摘自pytorch中文文档

torch.cat

torch.cat(inputs, dimension=0) → Tensor ,在给定维度上对输入的张量序列seq 进行连接操作。
torch.cat()可以看做 torch.split()torch.chunk()的反操作, cat() 函数可以通过下面例子更好的理解。
参数:

  • inputs (sequence of Tensors) – 可以是任意相同Tensor 类型的python 序列
  • dimension (int, optional) – 沿着此维连接张量序列。
>>> x = torch.randn(2, 3)
>>> x

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x3]

>>> torch.cat((x, x, x), 0)

 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
 0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 6x3]

>>> torch.cat((x, x, x), 1)

 0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918  0.5983 -0.0341  2.4918
 1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735  1.5981 -0.5265 -0.8735
[torch.FloatTensor of size 2x9]

torch.chunk

torch.chunk(tensor, chunks, dim=0) 在给定维度(轴)上将输入张量进行分块儿。
参数:

  • tensor (Tensor) – 待分块的输入张量
  • chunks (int) – 分块的个数
  • dim (int) – 沿着此维度进行分块
# -*- coding: utf-8 -*-
import torch

a = torch.ones([4, 8])
b = torch.zeros([4, 8])
c = torch.cat([a, b], 0)  # 第0个维度stack
d1, d2 = torch.chunk(c, 2, 0)
print("1--------------------------------")
print('a.size: ',end='');print(a.size())
print(a)
print('b.size: ',end='');print(b.size())
print(b)
print('c.size: ',end='');print(c.size())
print(c)
print('d1.size: ',end='');print(d1.size())
print(d1)
print('d2.size: ',end='');print(d2.size())
print(d2)

print('2----------------------------------')
a2 = a.view(1, 1, a.size(0), a.size(1))
print('a2.size: ',end='');print(a2.size())
print(a2)
b2 = b.view(1,1,b.size(0),b.size(1))
print('b2.size: ',end='');print(b2.size())
print(b2)
c2 = torch.cat([a2, b2], 2)  # 第0个维度stack
print('c2.size: ',end='');print(c2.size())
print(c2)
dd1,dd2 = torch.chunk(c2,2,2)

print('3------------------------------------')
print('dd1.size: ',end='');print(dd1.size())
print(dd1)
print('dd2.size: ',end='');print(dd2.size())
print(dd2)
c22 = torch.cat([dd1,dd2],2)
print('c22.size: ',end='');print(c22.size())
print(c22)

输出结果

1--------------------------------
a.size: torch.Size([4, 8])
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
b.size: torch.Size([4, 8])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
c.size: torch.Size([8, 8])
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
d1.size: torch.Size([4, 8])
tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
d2.size: torch.Size([4, 8])
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
2----------------------------------
a2.size: torch.Size([1, 1, 4, 8])
tensor([[[[1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.]]]])
b2.size: torch.Size([1, 1, 4, 8])
tensor([[[[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]]]])
c2.size: torch.Size([1, 1, 8, 8])
tensor([[[[1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]]]])
3------------------------------------
dd1.size: torch.Size([1, 1, 4, 8])
tensor([[[[1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.]]]])
dd2.size: torch.Size([1, 1, 4, 8])
tensor([[[[0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]]]])
c22.size: torch.Size([1, 1, 8, 8])
tensor([[[[1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [1., 1., 1., 1., 1., 1., 1., 1.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0.]]]])

你可能感兴趣的:(Python)