ObjectFormer复现记录

以下所有内容是笔者在读完 ObjectFormer 论文后自己给出的 PyTorch 复现及过程记录,如有问题请不吝赐教。

一.ObjectFormer 复现代码

import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
import torch

# Width multiplier for patch embeddings: every patch token in this file has
# feature_dim * patch_height * patch_width entries.
feature_dim: int = 4


class BCIM(nn.Module):
    """Rescale every grid cell by how similar it is to its local neighbourhood.

    Input ``p_vector`` is a grid of shape ``(B, 2 * D, h_patches, w_patches)``
    with ``D = feature_dim * patch_height * patch_width``. Its channel axis is
    laid out as ``(D, 2)``: channel ``2 * d + s`` holds feature ``d`` of
    sequence half ``s``.

    For each cell, the vector is L2-normalised and dotted with the mean of the
    normalised vectors in the zero-padded ``window_size x window_size`` window
    around it (the cell itself included). The raw vector is then scaled by
    that score. Finally, the grid is unfolded back into tokens of shape
    ``(B, 2 * h_patches * w_patches, D)``, ordered as
    ``[s = 0 cells row-major, s = 1 cells row-major]``.

    Args:
        height, width: input image size in pixels.
        h_patches, w_patches: number of patches along each axis.
    """

    def __init__(self, height, width, h_patches, w_patches):
        super().__init__()
        self.height = height
        self.width = width
        self.h_patches = h_patches
        self.w_patches = w_patches
        self.patch_height = int(height / h_patches)
        self.patch_width = int(width / w_patches)
        self.window_size = 3
        # The window mean is computed with F.avg_pool2d in forward(). A dense
        # (C, C, k, k) identity kernel computes the same thing, but its size
        # grows with C**2 (C = 2 * feature_dim * patch_height * patch_width).

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Accept checkpoints written by the dense-kernel version of this module.

        Those checkpoints carry a frozen identity ``weight`` and a zero ``bias``.
        They hold no learned information, so they are dropped. This keeps
        ``load_state_dict(..., strict=True)`` working on old checkpoints.
        """
        state_dict.pop(prefix + 'weight', None)
        state_dict.pop(prefix + 'bias', None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self, p_vector):
        batch_size = p_vector.size()[0]
        token_dim = feature_dim * self.patch_height * self.patch_width
        # (B, 2*D, H, W) -> unit vectors along channels. F.normalize clamps the
        # norm at eps, so an all-zero cell yields zeros instead of NaN.
        unit_p_vector = F.normalize(p_vector, p=2, dim=1)
        # Zero-padded window mean (cell itself included). count_include_pad=True
        # divides by window_size**2 even at the borders.
        window_mean = F.avg_pool2d(unit_p_vector, kernel_size=self.window_size, stride=1,
                                   padding=(self.window_size - 1) // 2,
                                   count_include_pad=True)
        sim = torch.mul(unit_p_vector, window_mean).sum(dim=1, keepdim=True)  # (B, 1, H, W)
        p_vector = torch.mul(p_vector, sim)  # (B, 2*D, H, W)
        # Split channels as (D, 2), then move the sequence-half axis to the front.
        p_vector = p_vector.permute(0, 2, 3, 1).contiguous()  # (B, H, W, 2*D)
        p_vector = p_vector.view(batch_size, self.h_patches, self.w_patches, token_dim, 2)
        p_vector = p_vector.permute(0, 4, 1, 2, 3).contiguous()  # (B, 2, H, W, D)
        p_vector = p_vector.view(batch_size, 2 * self.h_patches * self.w_patches, token_dim)
        return p_vector  # (B, 2*H*W, D)


class ObjectDecoder(nn.Module):
    """Patch decoder: refine patch tokens with the object vectors, then run BCIM.

    Args:
        height, width: input image size in pixels.
        h_patches, w_patches: number of patches along each axis.

    Let ``D = feature_dim * patch_height * patch_width``. ``forward`` takes:

    - ``p_vector``: patch tokens of shape ``(B, 2 * h_patches * w_patches, D)``.
    - ``object_vector``: object vectors of shape ``(N, D)``.

    It returns patch tokens of shape ``(B, 2 * h_patches * w_patches, D)``.
    """

    def __init__(self, height, width, h_patches, w_patches):
        super().__init__()
        self.height, self.width = height, width
        self.h_patches, self.w_patches = h_patches, w_patches
        self.patch_height = int(height / h_patches)
        self.patch_width = int(width / w_patches)
        dim = feature_dim * self.patch_height * self.patch_width

        self.p_layernorm = nn.LayerNorm(dim)
        self.object_layernorm = nn.LayerNorm(dim)
        self.query_embedding_Matrix = nn.Linear(dim, dim)
        self.key_embedding_Matrix = nn.Linear(dim, dim)
        self.value_embedding_Matrix = nn.Linear(dim, dim)
        # Attention scores are (B, 2*H*W, N); normalise over the N objects.
        self.softmax_layer = nn.Softmax(dim=2)
        self.projection_layer = nn.Sequential(
            nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU(),
            nn.Linear(dim, dim), nn.LayerNorm(dim), nn.GELU(),
        )
        self.BCIM = BCIM(height, width, h_patches, w_patches)

    def forward(self, p_vector, object_vector):
        n_batch = p_vector.size()[0]
        tokens = self.p_layernorm(p_vector)             # (B, 2*H*W, D)
        objects = self.object_layernorm(object_vector)  # (N, D)

        queries = self.query_embedding_Matrix(tokens)                           # (B, 2*H*W, D)
        keys_t = self.key_embedding_Matrix(objects).permute(1, 0).contiguous()  # (D, N)
        values = self.value_embedding_Matrix(objects)                           # (N, D)

        # Each patch token attends over the N objects.
        weights = self.softmax_layer(torch.matmul(queries, keys_t))  # (B, 2*H*W, N)
        # The residual is added to the layer-normalised tokens (not the raw input).
        tokens = tokens + torch.matmul(weights, values)
        tokens = tokens + self.projection_layer(tokens)

        # Fold back to a grid: (B, 2*H*W, D) -> (B, D, 2*H*W) -> (B, 2*D, H, W).
        # Channel c = 2*d + s, where s = 0/1 selects the first/second half of the sequence.
        grid = tokens.permute(0, 2, 1).contiguous().view(
            n_batch,
            feature_dim * self.patch_height * self.patch_width * 2,
            self.h_patches,
            self.w_patches,
        )
        return self.BCIM(grid)  # (B, 2*H*W, D)


class ObjectEncoder(nn.Module):
    """Object encoder: update the object vectors by attending over patch tokens.

    Args:
        height, width: input image size in pixels.
        h_patches, w_patches: number of patches along each axis.

    Let ``D = feature_dim * patch_height * patch_width``. ``forward`` takes:

    - ``objector_vector``: object vectors of shape ``(N, D)``.
    - ``p_vector``: patch tokens of shape ``(B, L, D)``.

    It returns updated object vectors of shape ``(N, D)``.
    """

    def __init__(self, height, width, h_patches, w_patches):
        super().__init__()
        dim = feature_dim * int(height / h_patches) * int(width / w_patches)

        def linear_ln():
            # Linear projection followed by LayerNorm, both of width `dim`.
            return nn.Sequential(nn.Linear(dim, dim), nn.LayerNorm(dim))

        self.object_layernorm = nn.LayerNorm(dim)
        self.p_layernorm = nn.LayerNorm(dim)
        self.object_embedding_Matrix = linear_ln()
        self.key_embedding_Matrix = linear_ln()
        self.value_embedding_Matrix = linear_ln()
        # Attention scores are (B, N, L); normalise over the L patch tokens.
        self.softmax_layer = nn.Softmax(dim=2)
        self.interaction_Matrix = nn.Sequential(nn.Linear(dim, dim))
        self.linear_projection_layer = nn.Sequential(nn.Linear(dim, dim),
                                                     nn.GELU(),
                                                     nn.Linear(dim, dim))

    def forward(self, objector_vector, p_vector):
        n_batch = p_vector.size()[0]
        objects = self.object_layernorm(objector_vector)  # (N, D)
        tokens = self.p_layernorm(p_vector)               # (B, L, D)

        queries = self.object_embedding_Matrix(objects)                            # (N, D)
        keys_t = self.key_embedding_Matrix(tokens).permute(0, 2, 1).contiguous()  # (B, D, L)
        values = self.value_embedding_Matrix(tokens)                               # (B, L, D)

        # (N, D) @ (B, D, L) broadcasts to (B, N, L).
        weights = self.softmax_layer(torch.matmul(queries, keys_t))
        # Average the per-sample updates over the batch so the objects stay (N, D).
        objects = objects + torch.matmul(weights, values).sum(dim=0) / n_batch
        objects = objects + self.interaction_Matrix(objects)
        return objects + self.linear_projection_layer(objects)


class EnlargeModule(nn.Module):
    """Decode the patch-grid feature map back to the input image resolution.

    There are three conv-BN-ReLU stages (to 64, 128 and 256 channels), each
    followed by a bilinear resize:

    - The first two resizes double the spatial size.
    - The last resize goes straight to ``(height, width)``, so the output
      always has the input image size.

    For 8x8 patches that tile the image exactly, the last resize is the same
    x2 step as before. Other patch sizes, including non-square patches and
    images not divisible by the patch count, now also produce a full-size map.

    Args:
        height, width: input image size in pixels (the output spatial size).
        h_patches, w_patches: number of patches along each axis.

    ``forward(x)`` expects ``feature_dim * patch_height * patch_width * 2``
    input channels and returns a ``(B, 256, height, width)`` map.
    """

    def __init__(self, height, width, h_patches, w_patches):
        super().__init__()
        patch_height = int(height / h_patches)
        patch_width = int(width / w_patches)
        self.height = height
        self.width = width
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=feature_dim * patch_height * patch_width * 2, out_channels=64,
                                             kernel_size=(3, 3), stride=(1, 1), padding=1),
                                   nn.BatchNorm2d(64),
                                   nn.ReLU())
        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128,
                                             kernel_size=(3, 3), stride=(1, 1), padding=1),
                                   nn.BatchNorm2d(128),
                                   nn.ReLU())
        self.conv3 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256,
                                             kernel_size=(3, 3), stride=(1, 1), padding=1),
                                   nn.BatchNorm2d(256),
                                   nn.ReLU())

    def forward(self, x):
        # x: (B, feature_dim*ph*pw*2, gh, gw), where gh x gw is the patch grid.
        out = self.conv1(x)  # (B, 64, gh, gw)
        out = F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=True)  # (B, 64, 2gh, 2gw)
        out = self.conv2(out)  # (B, 128, 2gh, 2gw)
        out = F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=True)  # (B, 128, 4gh, 4gw)
        out = self.conv3(out)  # (B, 256, 4gh, 4gw)
        # Resize to the exact image size rather than a fixed x2. With
        # align_corners=True the sampling grid depends only on the input/output
        # sizes, so for 8x8 patches this equals scale_factor=2.
        out = F.interpolate(out, size=(self.height, self.width), mode='bilinear',
                            align_corners=True)  # (B, 256, height, width)
        return out


class Encoder_Decoder_Block(nn.Module):
    """One ObjectEncoder followed by one ObjectDecoder with the same geometry.

    ``forward(object_vector, p_vector)`` does two things:

    1. Updates the object vectors from the incoming patch tokens.
    2. Decodes the *incoming* patch tokens against those updated object
       vectors.

    It returns ``(updated_object_vector, updated_p_vector)``.
    """

    def __init__(self, height, width, h_patches, w_patches):
        super().__init__()
        geometry = (height, width, h_patches, w_patches)
        self.Encoder = ObjectEncoder(*geometry)
        self.Decoder = ObjectDecoder(*geometry)

    def forward(self, object_vector, p_vector):
        new_objects = self.Encoder(object_vector, p_vector)
        return new_objects, self.Decoder(p_vector, new_objects)



class ObjectFormer(nn.Module):
    """ObjectFormer reproduction that predicts a single-channel mask.

    The pipeline is:

    1. Two parallel conv streams (``Gr_Extractor`` on ``x``, ``Gf_Extractor``
       on ``y``) are cut into non-overlapping patches by strided convolutions.
    2. The two patch sequences are concatenated and refined, together with
       16 learnable object vectors, by two Encoder_Decoder_Blocks.
    3. The result is upsampled by ``EnlargeModule`` and mapped to a sigmoid
       mask by ``localization_layer``.

    Args:
        height, width: input image size in pixels.
        h_patches, w_patches: number of patches along each axis. The patch
            size is ``int(height / h_patches) x int(width / w_patches)``.
    """

    def __init__(self, height, width, h_patches, w_patches):
        super().__init__()
        print("Model:ObjectFormer")
        self.height = height
        self.width = width
        self.h_patches = h_patches
        self.w_patches = w_patches
        patch_height = int(height / h_patches)
        patch_width = int(width / w_patches)
        self.patch_height = patch_height
        self.patch_width = patch_width
        token_dim = feature_dim * patch_height * patch_width

        self.Gr_Extractor = self._make_extractor()
        self.Gf_Extractor = self._make_extractor()
        # A stride == kernel conv embeds each non-overlapping patch into one
        # token_dim-wide vector.
        self.Gr_Splicer = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=token_dim,
                      kernel_size=(patch_height, patch_width),
                      stride=(patch_height, patch_width),
                      padding=0))
        self.Gf_Splicer = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=token_dim,
                      kernel_size=(patch_height, patch_width),
                      stride=(patch_height, patch_width),
                      padding=0))

        # 16 learnable object prototypes, initialised close to zero. They are
        # trained by backprop only (see forward()).
        self.object_vector = nn.Parameter(
            torch.normal(mean=0, std=1e-5, size=(16, token_dim)))
        self.encoder_decoder_block1 = Encoder_Decoder_Block(height, width, h_patches, w_patches)
        self.encoder_decoder_block2 = Encoder_Decoder_Block(height, width, h_patches, w_patches)

        self.localization_layer = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256,
                                                          kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                                nn.BatchNorm2d(256),
                                                nn.ReLU(inplace=True),
                                                nn.Conv2d(in_channels=256, out_channels=256,
                                                          kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                                nn.BatchNorm2d(256),
                                                nn.ReLU(inplace=True),
                                                nn.Conv2d(in_channels=256, out_channels=128,
                                                          kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                                nn.BatchNorm2d(128),
                                                nn.ReLU(inplace=True),
                                                nn.Conv2d(in_channels=128, out_channels=1,
                                                          kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                                                nn.Sigmoid())
        # NOTE(review): classifier_layer is not used by forward(). If it is fed
        # the 256-channel enlarge_layer output, the stride-2 max-pool leaves
        # 256 * height * width / 4 features. The in_features below only equals
        # that when feature_dim * 2 == 256, so check it before enabling the branch.
        self.classifier_layer = nn.Sequential(nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=1),
                                              nn.Flatten(start_dim=1),
                                              nn.Linear(in_features=int(feature_dim * 2 * self.height * self.width / 4),
                                                        out_features=1),
                                              nn.Sigmoid())
        self.enlarge_layer = EnlargeModule(height, width, h_patches, w_patches)

    @staticmethod
    def _make_extractor():
        """Build three 3x3 conv-BN-ReLU layers (3 -> 256 -> 256 -> 256 channels)."""
        layers = []
        in_channels = 3
        for _ in range(3):
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=256, kernel_size=(3, 3),
                                 stride=(1, 1), padding=1),
                       nn.BatchNorm2d(256),
                       nn.ReLU(inplace=True)]
            in_channels = 256
        return nn.Sequential(*layers)

    def forward(self, x, y):
        """Predict a mask from the two 3-channel inputs ``x`` and ``y``.

        The spatial size of ``x``/``y`` is presumably ``(height, width)``
        (TODO confirm); the splicers must yield an ``h_patches x w_patches``
        grid. Returns a ``(B, 1, h, w)`` map in (0, 1), where ``(h, w)`` is
        the ``enlarge_layer`` output size.
        """
        batch_size = x.size()[0]
        token_dim = feature_dim * self.patch_height * self.patch_width
        n_patches = self.h_patches * self.w_patches

        Gr = self.Gr_Extractor(x)  # (B, 256, h, w)
        Gf = self.Gf_Extractor(y)  # (B, 256, h, w)
        Gr_p = self.Gr_Splicer(Gr).view(batch_size, token_dim, n_patches)  # (B, D, H*W)
        Gf_p = self.Gf_Splicer(Gf).view(batch_size, token_dim, n_patches)  # (B, D, H*W)

        # Token order: all Gr patches (row-major), then all Gf patches.
        p_vector = torch.cat([Gr_p, Gf_p], dim=2)          # (B, D, 2*H*W)
        p_vector = p_vector.permute(0, 2, 1).contiguous()  # (B, 2*H*W, D)

        # self.object_vector is only read here. Writing the encoder output back
        # into its .data on every call (as an earlier version did) makes eval
        # outputs depend on previously seen inputs and fights the optimizer.
        object_vector, p_vector = self.encoder_decoder_block1(self.object_vector, p_vector)
        object_vector, p_vector = self.encoder_decoder_block2(object_vector, p_vector)

        # (B, 2*H*W, D) -> (B, D, 2*H*W) -> (B, 2*D, H, W). Channel 2*d + s
        # holds feature d of the Gr (s = 0) or Gf (s = 1) patch at that cell.
        p_vector = p_vector.permute(0, 2, 1).contiguous()
        p_vector = p_vector.view(batch_size, token_dim * 2, self.h_patches, self.w_patches)
        p_vector = self.enlarge_layer(p_vector)  # (B, 256, h, w)
        pre_mask = self.localization_layer(p_vector)
        return pre_mask

使用方法:
1.首先新建一个ObjectFormer.py的文件,将上述所有代码拷贝进去,把没有装的包安装一下;
2.在你要建立训练模型的文件(如main.py)中利用

from ObjectFormer import ObjectFormer

将ObjectFormer模块导入;
3.设置训练模型

model = ObjectFormer(height=224, width=304, h_patches=28, w_patches=38)

其中,height 和 width 参数是训练输入图片的高和宽,h_patches 是图片在纵向上划分的份数,w_patches 是图片在横向上划分的份数。比如我这里输入的是 $224\times304$ 的图片,把图片划分成 $8\times8$ 大小的 patch,共 $28\times38$ 块。

二.对于复现代码的特别说明

1.该复现代码不包含下图红框中所示的分类任务分支,如有需要可自行添加,还是非常方便的。笔者的模型仅输出一个预测的 Mask,与输入图像大小一致,单通道,已经过 Sigmoid 映射到 $0\sim1$ 区间;
ObjectFormer复现记录_第1张图片
2.在计算 $G_{out}$ 的过程中有上采样,该上采样过程定义在

class EnlargeModule(nn.Module)

本实现约定每个 patch 的边长应为 2 的幂,且 patch 应为方形,例如 $2\times2$、$4\times4$、$8\times8$……,而不能是其他尺寸。如果你的 patch 是 $2^n\times2^n$,则在 EnlargeModule 的 forward 中就应进行 $n$ 次(每次 2 倍的)上采样,请自行调整;
3.笔者在实现Object Encoder 和 Patch Decoder的过程中,将这两个模块封装成了 Encoder_Decoder_Block,在ObjectFormer的前向传播过程中,笔者仅使用了2个该模块,作者原论文中使用了8个模块,如果需要修改请在ObjectFormer类中进行简单调整,不麻烦,稍微浏览一下就能知道如何调整。
4.笔者使用的学习率是1e-4。

三.实现过程经验记录

在笔者先前实现注意力有关的网络时,一直有这样一个问题,就是网络训练出来的Mask都是很奇怪的,如下面示例所示:
ObjectFormer复现记录_第2张图片
ObjectFormer复现记录_第3张图片
ObjectFormer复现记录_第4张图片
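造成这种奇怪结果的原因其实都是一样的:使用 reshape、view、permute 等函数的方式不当,对 tensor 的底层存储方式不清楚。一定要先弄清楚 tensor 的底层存储原理以及这些函数的具体作用再使用,否则张量的实际排列顺序会和你想的不一样,从而导致意外的错误。例如,view/reshape 只是按内存中的元素顺序重新解释形状,把形状为 (B, C, H, W) 的张量直接 view 成 (B, H*W, C) 并不能得到每个位置的 C 维特征,必须先 permute(0, 2, 3, 1) 再 reshape(或 contiguous() 后再 view)。(想要检验你对这些基本功掌握得怎么样其实很简单:拿一张图片,看看能不能顺利把它划分成 $8\times8$ 个 patch 并正常显示出来,千万不要觉得这件事很简单。)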

你可能感兴趣的:(Pytorch,pytorch,深度学习,神经网络)