对象
:给定的序列化张量,即Tensor
型。
功能
:实现两个张量在指定维度上的拼接。
输出
:拼接后的张量。
函数以及参数
:torch.cat(tensor, dim),官方给出的有四个参数,但是我们平时只会用到前两个参数即可。
tensor
:有相同形状的张量序列,所有的张量需要有相同的形状才能够拼接,除非是在拼接维度上两个张量可以有不同的尺寸,或者两个张量都是空的。
dim
:两个张量或者多个张量拼接的维度。
应用实例1:两个张量形状相同
代码
:
import torch
x = torch.randn(2,4)
y = torch.randn(2,4)
print(f'x={x}','\n',f'y={y}')
print(f'z={torch.cat((x,y), 0)}')
输出
:
x=tensor([[-1.2870, -0.7040, 0.3016, -0.2970],
[-0.8151, -0.5236, -1.7680, 0.7675]])
y=tensor([[-1.4207, -0.2694, 0.2521, -0.7187],
[ 0.8776, -0.0352, -0.5094, 0.0602]])
z=tensor([[-1.2870, -0.7040, 0.3016, -0.2970],
[-0.8151, -0.5236, -1.7680, 0.7675],
[-1.4207, -0.2694, 0.2521, -0.7187],
[ 0.8776, -0.0352, -0.5094, 0.0602]])
应用实例2:多个张量形状相同
代码
:
import torch
x = torch.randn(2,4)
y = torch.randn(2,4)
print(f'x={x}','\n',f'y={y}')
print(f'z={torch.cat((x,y,x,y), 0)}')
输出
:
x=tensor([[ 0.4697, -0.4881, -2.0199, -0.8661],
[ 0.4911, -0.1259, 1.1939, 0.7730]])
y=tensor([[ 0.8633, 0.4438, -0.6975, 0.5440],
[ 0.1554, -1.6358, -1.2234, -0.6597]])
z=tensor([[ 0.4697, -0.4881, -2.0199, -0.8661],
[ 0.4911, -0.1259, 1.1939, 0.7730],
[ 0.8633, 0.4438, -0.6975, 0.5440],
[ 0.1554, -1.6358, -1.2234, -0.6597],
[ 0.4697, -0.4881, -2.0199, -0.8661],
[ 0.4911, -0.1259, 1.1939, 0.7730],
[ 0.8633, 0.4438, -0.6975, 0.5440],
[ 0.1554, -1.6358, -1.2234, -0.6597]])
应用实例3:两个张量形状不同,但只在拼接维度上
代码
:
import torch
x = torch.randn(3,4)
y = torch.randn(2,4)
print(f'x={x}','\n',f'y={y}')
print(f'z={torch.cat((x,y), 0)}')
x_1 = torch.randn(2,3)
y_1 = torch.randn(2,4)
print(f'x_1=\n{x_1}','\n',f'y_1=\n{y_1}')
print(f'z_1=\n{torch.cat((x_1,y_1), 1)}')
输出
:
x=tensor([[-0.1966, -0.9648, 1.2787, -1.4578],
[-1.2216, 0.1663, 0.5380, -0.0376],
[-1.7365, -0.4151, -1.0336, -0.6732]])
y=tensor([[ 1.4477, 0.3616, -0.1504, 0.4662],
[-1.1334, 1.3100, 0.1624, 0.8206]])
z=tensor([[-0.1966, -0.9648, 1.2787, -1.4578],
[-1.2216, 0.1663, 0.5380, -0.0376],
[-1.7365, -0.4151, -1.0336, -0.6732],
[ 1.4477, 0.3616, -0.1504, 0.4662],
[-1.1334, 1.3100, 0.1624, 0.8206]])
x_1=
tensor([[ 1.1418, 0.0774, 0.2047],
[-0.0673, -1.5794, 0.0131]])
y_1=
tensor([[ 1.4149, -1.9538, 0.1660, 1.1142],
[-1.6455, 0.5595, -0.1162, 0.8628]])
z_1=
tensor([[ 1.1418, 0.0774, 0.2047, 1.4149, -1.9538, 0.1660, 1.1142],
[-0.0673, -1.5794, 0.0131, -1.6455, 0.5595, -0.1162, 0.8628]])
Process finished with exit code 0
对象
:给定的张量,即Tensor
型。
功能
:在指定的维度上对张量进行重复扩充,也可以用来增加维度。
输出
:升维或扩充后的张量。
函数以及参数
:torch.tensor.repeat(size),size所在的索引表示扩充的维度的索引。
size
:表示张量在这个索引维度下的扩充倍数。
注意事项
:函数的参数量必须大于等于tensor的维度,如a.shape=(2,3),那么如果我们想扩充2倍a的第0个维度时,应该这么写a.repeat(2,1),对于不扩充的维度则写1。
应用实例1:一维张量扩充
代码
:
import torch
x = torch.randn(3)
print(f'x={x}')
print(f'x_1={x.repeat(2)}')
输出
:
x=tensor([-0.1485, 1.8445, 1.4257])
x_1=tensor([-0.1485, 1.8445, 1.4257, -0.1485, 1.8445, 1.4257])
应用实例2:多维张量扩充
代码
:
import torch
x = torch.randn(3, 4, 3)
print(f'x={x}')
#在第2个维度上扩充两倍,其他维度保持不变
print(f'x_1={x.repeat(1,1,2)}')
输出
:
x=tensor([[[-0.0294, 1.2902, 0.9825],
[-0.3032, 1.6733, 0.9163],
[ 0.3079, -0.0159, 0.2626],
[-0.2934, -0.6076, 0.1593]],
[[ 1.7661, -1.0698, 0.4074],
[-0.3660, -0.3219, 0.3732],
[-1.3314, -0.8263, -1.0793],
[ 1.2589, 0.1886, 0.5453]],
[[ 0.2520, -0.5695, -0.6685],
[ 0.5554, 0.0119, -0.5650],
[ 0.9733, -0.3812, 0.1963],
[-1.1284, 0.2561, 0.4507]]])
x_1=tensor([[[-0.0294, 1.2902, 0.9825, -0.0294, 1.2902, 0.9825],
[-0.3032, 1.6733, 0.9163, -0.3032, 1.6733, 0.9163],
[ 0.3079, -0.0159, 0.2626, 0.3079, -0.0159, 0.2626],
[-0.2934, -0.6076, 0.1593, -0.2934, -0.6076, 0.1593]],
[[ 1.7661, -1.0698, 0.4074, 1.7661, -1.0698, 0.4074],
[-0.3660, -0.3219, 0.3732, -0.3660, -0.3219, 0.3732],
[-1.3314, -0.8263, -1.0793, -1.3314, -0.8263, -1.0793],
[ 1.2589, 0.1886, 0.5453, 1.2589, 0.1886, 0.5453]],
[[ 0.2520, -0.5695, -0.6685, 0.2520, -0.5695, -0.6685],
[ 0.5554, 0.0119, -0.5650, 0.5554, 0.0119, -0.5650],
[ 0.9733, -0.3812, 0.1963, 0.9733, -0.3812, 0.1963],
[-1.1284, 0.2561, 0.4507, -1.1284, 0.2561, 0.4507]]])
应用实例3:张量维度扩充
代码
:
import torch
x = torch.randn(1,2)
print(f'x={x}')
#将a多扩充一个维度,这个维度扩充的倍数需要写在最前面,如此案例的3
print(f'x_1={x.repeat(3,1,1)}')
输出
:
x=tensor([[-0.2581, -0.8387]])
x_1=tensor([[[-0.2581, -0.8387]],
[[-0.2581, -0.8387]],
[[-0.2581, -0.8387]]])