CNN使用Transformer

目录

  • 1、torch.nn.Transformer()
  • 2 、CNN使用Transformer代码

1、torch.nn.Transformer()

API解释参考:《pytorch中的transformer》

2 、CNN使用Transformer代码

注意:
[1]nhead必须能被d_model整除(序列被几个头注意)
[2]CNN特征图通道512被当成序列,放到第一个维度,批次放到第二个维度
[3]Transformer必须有src和tgt两个向量,CNN是自相关性解算,都放入特征图向量

特别提示
Transformer计算量大,他提取的是全局信息。CNN更多偏向于局部信息,使用时可以利用深层的特征图部分通道做Transformer后,变换回去和其他未做Transformer的特征图cat继续卷积

这样即引入Transformer提供了全局语义信息,又不至于计算量过大

 net_2 = nn.Transformer(d_model=81, nhead=9, num_encoder_layers=12)                                                                                                        
 11 img = torch.rand((2, 512, 9, 9))
 12 print("img_shape:", img.shape)
 13 
 14 img_trans = img.permute((1, 0, 2, 3))
 15 print("img_trans:", img_trans.shape)
 16 
 17 img_view = img_trans.view((img_trans.size(0), img_trans.size(1), -1))
 18 print("img_view:", img_view.shape)
 19 
 20 out = net_2(img_view, img_view)
 21 print("out:", out.shape)

你可能感兴趣的:(深度学习,计算机视觉,pytorch)