一张图像在计算机中的表示通常为三维tensor(张量),即[channels,height,width] 。也就是一张彩色图片通常有三色通道(R,G,B)组成,高和宽也就是常说的照片大小,比如224x224
在图像处理的时候会增加一个变量batch_size,也就是把多少张图片作为一批进行处理。所以就变成了四维张量,即[batch_size,channels,heigth,width],也即是[批量大小,通道数,高,宽]
如何判断一个tensor是几维张量最简单的办法就是看中括号数。例如 [[[[1,2,3]]]],是四维张量。
torch.cat()函数,官方文档是这样写的torch.cat(tensors, dim=0, *, out=None),也就是有两个参数,一个是要合并的张量,一个是在哪个维度上进行合并。
废话少说开始演示。
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)
#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])
定义了两个四维张量。维度都为[1,1,2,3],即批量大小为1,通道为1,高为2,宽为3
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)
#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])
#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)
#torch.Size([2, 1, 2, 3])
在维度0上进行合并,然后输出维度为[2,1,2,3],所以得出结论 四维张量在0维合并的时候 其实是在批量大小维度上进行合并。
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)
#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])
#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)
#torch.Size([2, 1, 2, 3])
#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)
#torch.Size([1, 2, 2, 3])
在1维度上进行合并,输出维度为[1,2,2,3],即在1维上合并是在通道维度上进行合并。
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)
#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])
#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)
#torch.Size([2, 1, 2, 3])
#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)
#torch.Size([1, 2, 2, 3])
#在维度2上进行合并
x=torch.cat((a,b),dim=2)
print(x.shape)
#torch.Size([1, 1, 4, 3])
在维度2上进行合并,输出维度为[1,1,4,3]。即在2维上进行合并是在高上进行合并(也可以说是在行维度进行合并)
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)
#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])
#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)
#torch.Size([2, 1, 2, 3])
#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)
#torch.Size([1, 2, 2, 3])
#在维度2上进行合并
x=torch.cat((a,b),dim=2)
print(x.shape)
#torch.Size([1, 1, 4, 3])
#在维度3上进行合并
x=torch.cat((a,b),dim=3)
print(x.shape)
#torch.Size([1, 1, 2, 6])
在维度3上进行合并,输出维度为[1,1,2,6],即在3维上进行合并是在宽维度进行合并(也可以说是列)
注:在拼接时 除了选择拼接的维度可以不同,其他维度要相同。什么意思?看代码
import torch
#定义两个变量[batch_size,channel,height,width]
a=torch.randn(size=(1,1,2,3))
b=torch.randn(size=(1,2,2,3))
#选择在1维度进行合并(也就是通道维度),注意a,b的通道维度不同,其他维度都相同。
x=torch.cat((a,b),dim=1)
print(x.shape)
#torch.Size([1, 3, 2, 3])
也就是选择合并的那个维度可以不同,其他维度要相同
如果不同,报错,如下。
import torch
#定义两个变量[batch_size,channel,height,width]
a=torch.randn(size=(1,1,2,3))
b=torch.randn(size=(2,2,2,3))
#选择在1维度进行合并(也就是通道维度),注意a,b的批量大小不同,维度不同,其他维度都相同。
x=torch.cat((a,b),dim=1)
print(x.shape)
#RuntimeError: Sizes of tensors must match except in dimension 1. Got 1 and 2 in dimension 0
可以看到当我们选择在通道维度合并时(通道数可以不同),但是其他的维度要相同(下面的a,b的批量大小也不同)。所以直接报错。
原文链接:https://blog.csdn.net/zwb619/article/details/127022873