函数作用
移除input的指定维度,并将得到的多个张量拼接为一个tuple。
例如原维度为[2,3,4],那么移除dim=0后得到2个shape为[3,4]的张量,这2个张量拼接为一个tuple;如果移除dim=1,则得到3个shape为[2,4]张量,这3个张量拼接为一个tuple。
图解
移除dim=1,即沿着与维度1垂直的绿面和黄面进行切片,切片后得到3个[2,4]的张量。如果不是很明白,结合下面的代码实例以及运行结果来理解会好很多。
代码实例
input = torch.arange(24).reshape(2,3,4)
"""input为:
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
"""
print(torch.unbind(input,dim=0))
"""得到两个[3,4]
(tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]]),
tensor([[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]))
"""
print(torch.unbind(input,dim=1))
"""得到三个[2,4]
(tensor([[ 0, 1, 2, 3],
[12, 13, 14, 15]]),
tensor([[ 4, 5, 6, 7],
[16, 17, 18, 19]]),
tensor([[ 8, 9, 10, 11],
[20, 21, 22, 23]]))
"""
print(torch.unbind(input,dim=2))
"""得到四个[2,3]
(tensor([[ 0, 4, 8],
[12, 16, 20]]),
tensor([[ 1, 5, 9],
[13, 17, 21]]),
tensor([[ 2, 6, 10],
[14, 18, 22]]),
tensor([[ 3, 7, 11],
[15, 19, 23]]))
"""