在修改模型结构时,本来想着简单替换主干网络,用轻量级结构的替换原来的复杂模型,但是过程没想象中的顺利;其中比较关键的一点是两个主干网络输出的特征图类型不一致。
主干网络A(轻量级),它输出特征图的类型是tuple,输出维度是[1, 3, 640, 640];
主干网络B(复杂的),它输出特征图的类型是torch.Tensor,输出维度也是[1, 3, 640, 640];
但是如果直接把主干网络B替换为主干网络A,后面接着原来的特征提取结构和任务头,会报错的。
把主干网络B替换为主干网络A后,加多一步操作,将输出特征图从tuple 转 torch.Tensor即可。
转换的基本思路是:使用 torch.cat( ) 把特征图进行拼接起来,通常是在维度 dim=0 进行拼接的。
import torch
# 假设模型输出的特征图为 feature_map, feature_map 是一个 tuple
# 获取特征图个数
num_maps = len(feature_map)
# 打印原来的特征图信息
print("type feature_raw:", type(outs))
for out in feature_map:
print(out.size())
print("len feature_raw:", num_maps)
# 按第 0 维度拼接特征图
feature_map = torch.cat([fm for fm in feature_map], dim=0)
# 检查特征图类型
print("type feature_map:", type(feature_map))
# 输出:
# 检查特征图维度
print("size feature_map:", feature_map.size())
示例输出:
type feature_raw:
torch.Size([8, 32, 640, 640])
len feature_raw: 1
type feature_map:
feature_map: torch.Size([8, 32, 640, 640])
如果主干网络输出的特征图类型为tuple,而且它包含多个特征图。我们想把它们变为一个torch.Tensor
,可以使用torch.cat
函数把它们拼接在一起。
import torch
# 假设模型输出的特征图为 feature_map, feature_map 是一个 tuple
# 获取特征图个数
num_maps = len(feature_map)
# 打印原来的特征图信息
print("type feature_raw:", type(outs))
for out in feature_map:
print(out.size())
print("len feature_raw:", num_maps)
# 按第 0 维度拼接特征图
feature_map = torch.cat([fm.unsqueeze(0) for fm in feature_map], dim=0)
# 检查特征图类型
print("type feature_map:", type(feature_map))
# 输出:
# 检查特征图维度
print("size feature_map:", feature_map.size())
这样就可以将输出的特征图类型由tuple
变为torch.Tensor
了。拼接时,通过unsqueeze(0)
把每个特征图在第0维度上增加一维,这样才能用torch.cat
进行拼接。
示例输出:
type feature_raw:
torch.Size([8, 32, 640, 640])torch.Size([8, 32, 640, 640])
torch.Size([8, 32, 640, 640])
torch.Size([8, 32, 640, 640])
len feature_raw: 1
type feature_map:
feature_map: torch.Size([4, 8, 32, 640, 640])
分享完成,欢迎交流~