本篇作为后期文章“特征融合”的基础。
特征融合分早融合和晚融合,早融合里的重要手段是concat和add
torch.cat()可以将多个张量合并为一个张量,我们接下来从简单到复杂一点点来盘这个函数
我们首先随机生成两个形状一致的张量:
import torch
A =torch.rand(3,2) #单通道,高为3.宽为2的张量
B=torch.rand(3,3) #单通道,高为2.宽为3的张量
print(A)
print(B)
让这个张量在第0维度进行拼接,也就是在高这个维度进行拼接:
C=torch.cat((A,B),dim=0)
print(C)
print(C.shape)
让这个张量在第1维度进行拼接,也就是在宽这个维度进行拼接:
C=torch.cat((A,B),dim=1)
print(C)
print(C.shape)
在第0维度拼接时,高可以不一样,但是宽需要一致,不然会报错:
import torch
A =torch.rand(3,3) #单通道,高为3.宽为2的张量
B=torch.rand(4,3) #单通道,高为2.宽为3的张量
print(A)
print(B)
C=torch.cat((A,B),dim=0)
print(C)
print(C.shape)
import torch
A =torch.rand(3,3) #单通道,高为3.宽为2的张量
B=torch.rand(3,5) #单通道,高为2.宽为3的张量
print(A)
print(B)
C=torch.cat((A,B),dim=0)
print(C)
print(C.shape)
直接报错:
在第1维度拼接时,高必须一致,宽可以不一样,不然会报错:
import torch
A =torch.rand(3,3) #单通道,高为3.宽为2的张量
B=torch.rand(3,5) #单通道,高为2.宽为3的张量
print(A)
print(B)
C=torch.cat((A,B),dim=1)
print(C)
print(C.shape)
import torch
A =torch.rand(3,3) #单通道,高为3.宽为2的张量
B=torch.rand(4,3) #单通道,高为2.宽为3的张量
print(A)
print(B)
C=torch.cat((A,B),dim=1)
print(C)
print(C.shape)
我们随机生成两个3通道的2X2图像
import torch
A =torch.rand(3,2,2) #单通道,高为3.宽为2的张量
B=torch.rand(3,2,2) #单通道,高为2.宽为3的张量
print(A)
print(B)
让他们在第0维度进行拼接(通道维度拼接):
相当于通道数堆叠了,变成了六个通道
让他们在第1维度进行拼接(高维度拼接):
让他们在第2维度进行拼接(宽维度拼接):
这两个堆叠结果就和之前的方法一样了
import torch
A =torch.rand(3,2,2) #单通道,高为3.宽为2的张量
B=torch.rand(3,2,2) #单通道,高为2.宽为3的张量
print(A)
print(B)
C=torch.add(A,B)
print(C)
print(C.shape)